diff --git a/oidc_provider/lib/utils/common.py b/oidc_provider/lib/utils/common.py index 26b68af3..d98bf9dd 100644 --- a/oidc_provider/lib/utils/common.py +++ b/oidc_provider/lib/utils/common.py @@ -1,10 +1,9 @@ from hashlib import sha224 -import django -if django.VERSION >= (1, 11): +try: from django.urls import reverse -else: +except ImportError: from django.core.urlresolvers import reverse from django.http import HttpResponse diff --git a/oidc_provider/lib/utils/token.py b/oidc_provider/lib/utils/token.py index 73fd62ea..00ec9b66 100644 --- a/oidc_provider/lib/utils/token.py +++ b/oidc_provider/lib/utils/token.py @@ -2,12 +2,10 @@ import time import uuid -from Cryptodome.PublicKey.RSA import importKey from django.utils import dateformat, timezone -from jwkest.jwk import RSAKey as jwk_RSAKey -from jwkest.jwk import SYMKey -from jwkest.jws import JWS -from jwkest.jwt import JWT +import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.backends import default_backend from oidc_provider.lib.utils.common import get_issuer from oidc_provider.models import ( @@ -71,18 +69,17 @@ def encode_id_token(payload, client): Represent the ID Token as a JSON Web Token (JWT). Return a hash. """ - keys = get_client_alg_keys(client) - _jws = JWS(payload, alg=client.jwt_alg) - return _jws.sign_compact(keys) - - -def decode_id_token(token, client): - """ - Represent the ID Token as a JSON Web Token (JWT). - Return a hash. - """ - keys = get_client_alg_keys(client) - return JWS().verify_compact(token, keys=keys) + key = client.client_secret + if client.jwt_alg == 'RS256': + rsakeys = RSAKey.objects.all() + if not rsakeys: + raise Exception('You must have an RSA Key.') + rsakey = rsakeys[0] + key = serialization.load_pem_private_key(rsakey.key.encode(), None, default_backend()).private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption()).decode('utf8') + return jwt.encode(payload, key, algorithm=client.jwt_alg).decode() def client_id_from_id_token(id_token): @@ -90,8 +87,7 @@ def client_id_from_id_token(id_token): Extracts the client id from a JSON Web Token (JWT). Returns a string or None. """ - payload = JWT().unpack(id_token).payload() - return payload.get('aud', None) + return jwt.decode(id_token, verify=False).get('aud', None) def create_token(user, client, scope, id_token_dic=None): @@ -138,22 +134,3 @@ def create_code(user, client, scope, nonce, is_authentication, code.is_authentication = is_authentication return code - - -def get_client_alg_keys(client): - """ - Takes a client and returns the set of keys associated with it. - Returns a list of keys. - """ - if client.jwt_alg == 'RS256': - keys = [] - for rsakey in RSAKey.objects.all(): - keys.append(jwk_RSAKey(key=importKey(rsakey.key), kid=rsakey.kid)) - if not keys: - raise Exception('You must add at least one RSA Key.') - elif client.jwt_alg == 'HS256': - keys = [SYMKey(key=client.client_secret, alg=client.jwt_alg)] - else: - raise Exception('Unsupported key algorithm.') - - return keys diff --git a/oidc_provider/management/commands/creatersakey.py b/oidc_provider/management/commands/creatersakey.py index d5d423f7..fdd7439c 100644 --- a/oidc_provider/management/commands/creatersakey.py +++ b/oidc_provider/management/commands/creatersakey.py @@ -1,4 +1,6 @@ -from Cryptodome.PublicKey import RSA +from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.backends import default_backend from django.core.management.base import BaseCommand from oidc_provider.models import RSAKey @@ -9,8 +11,12 @@ class Command(BaseCommand): def handle(self, *args, **options): try: - key = RSA.generate(1024) - rsakey = RSAKey(key=key.exportKey('PEM').decode('utf8')) + key = generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) + rsakey = RSAKey( + key=key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption()).decode('utf8')) rsakey.save() self.stdout.write(u'RSA key successfully created with kid: {0}'.format(rsakey.kid)) except Exception as e: diff --git a/oidc_provider/views.py b/oidc_provider/views.py index 72d941e2..7808375f 100644 --- a/oidc_provider/views.py +++ b/oidc_provider/views.py @@ -6,16 +6,17 @@ except ImportError: from urllib.parse import urlsplit, parse_qs, urlunsplit, urlencode -from Cryptodome.PublicKey import RSA +from base64 import urlsafe_b64encode from django.contrib.auth.views import ( redirect_to_login, logout, ) -import django -if django.VERSION >= (1, 11): +from struct import pack + +try: from django.urls import reverse -else: +except ImportError: from django.core.urlresolvers import reverse from django.contrib.auth import logout as django_user_logout @@ -26,7 +27,9 @@ from django.views.decorators.clickjacking import xframe_options_exempt from django.views.decorators.http import require_http_methods from django.views.generic import View -from jwkest import long_to_base64 + +from cryptography.hazmat.primitives.serialization import load_pem_private_key +from cryptography.hazmat.backends import default_backend from oidc_provider.lib.claims import StandardScopeClaims from oidc_provider.lib.endpoints.authorize import AuthorizeEndpoint @@ -284,18 +287,29 @@ def get(self, request, *args, **kwargs): class JwksView(View): + def _num_to_b64_string(self, value): + int_array = [] + while value: + value, r = divmod(value, 256) + int_array.insert(0, r) + char_array = pack('B'*len(int_array), *int_array) + return urlsafe_b64encode(char_array).rstrip(b'=').decode("ascii") + def get(self, request, *args, **kwargs): dic = dict(keys=[]) for rsakey in RSAKey.objects.all(): - public_key = RSA.importKey(rsakey.key).publickey() + public_numbers = load_pem_private_key( + rsakey.key.encode(), + None, + default_backend()).public_key().public_numbers() dic['keys'].append({ 'kty': 'RSA', 'alg': 'RS256', 'use': 'sig', 'kid': rsakey.kid, - 'n': long_to_base64(public_key.n), - 'e': long_to_base64(public_key.e), + 'n': self._num_to_b64_string(public_numbers.n), + 'e': self._num_to_b64_string(public_numbers.e), }) response = JsonResponse(dic) diff --git a/setup.py b/setup.py index 61ce8ccb..fd8b87fa 100644 --- a/setup.py +++ b/setup.py @@ -38,10 +38,13 @@ test_suite='runtests.runtests', tests_require=[ 'pyjwkest>=1.3.0', + 'cryptography>==2.0', + 'pyjwt>==1.5.0', 'mock>=2.0.0', ], install_requires=[ - 'pyjwkest>=1.3.0', + 'cryptography>==2.0', + 'pyjwt>==1.5.0', ], ) diff --git a/tox.ini b/tox.ini index e3b52b12..8e5162ee 100644 --- a/tox.ini +++ b/tox.ini @@ -36,4 +36,4 @@ commands= basepython=python deps=flake8 commands = - flake8 --max-line-length=120 + flake8 --max-line-length=120 --exclude=.svn,CVS,.bzr,.hg,.git,__pycache__,.tox,oidc_provider/migrations