204 lines
6.8 KiB
Python
204 lines
6.8 KiB
Python
import aiohttp
|
|
|
|
import urllib.parse
|
|
import logging
|
|
import os
|
|
import base64
|
|
import hashlib
|
|
from jose import jwt
|
|
|
|
from jose import jwk, jwt
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
class OIDCClient:
    """OpenID Connect client for the authorization-code flow with PKCE.

    The client reads the provider's endpoints from its discovery document,
    builds the authorization URL, exchanges the returned code for tokens and
    validates the ID token against the provider's JWKS.

    All network helpers return ``None`` on failure instead of raising, and
    the public methods propagate that ``None`` to the caller.
    """

    # Pending authorization flows keyed by ``state``. This is a class
    # attribute, so the mapping is shared by every instance in the process.
    # NOTE(review): flows that are started but never completed stay in this
    # mapping; consider adding an expiry if that matters for the deployment.
    flows = {}

    # Path appended to the caller-supplied base URI to form the redirect URI.
    # The same value must be sent in the authorization and token requests.
    _CALLBACK_PATH = '/auth/oidc/callback'

    # Signature algorithm accepted for ID tokens.
    _ID_TOKEN_ALGORITHM = "RS256"

    def __init__(self, discovery_url, client_id, scope):
        """Store the provider configuration.

        Args:
            discovery_url: URL of the provider's discovery document.
            client_id: Client identifier sent to the provider and used as the
                expected ``aud`` of the ID token.
            scope: Value sent as the ``scope`` authorization parameter.
        """
        self.discovery_url = discovery_url
        self.client_id = client_id
        self.scope = scope

    @staticmethod
    def _random_token(num_bytes):
        """Return ``num_bytes`` of OS randomness as unpadded URL-safe base64."""
        return base64.urlsafe_b64encode(os.urandom(num_bytes)).rstrip(b'=').decode('utf-8')

    async def fetch_discovery_document(self):
        """Download and return the discovery document as parsed JSON.

        Returns:
            The decoded JSON body, or ``None`` if the request failed.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(self.discovery_url) as response:
                    response.raise_for_status()
                    return await response.json()
        except aiohttp.ClientResponseError as e:
            if e.status == 404:
                _LOGGER.warning("Error: Discovery document not found at %s", self.discovery_url)
            else:
                _LOGGER.warning("Error: %s - %s", e.status, e.message)
        except aiohttp.ClientError as e:
            # Connection-level failures (DNS, refused, TLS) carry no status.
            _LOGGER.warning("Error fetching discovery document from %s: %s", self.discovery_url, e)
        return None

    async def _get_discovery_document(self):
        """Return the cached discovery document, fetching it when missing.

        A failed fetch leaves a falsy value in ``self.discovery_document``,
        which triggers a new attempt on the next call rather than disabling
        the client permanently.
        """
        if not getattr(self, 'discovery_document', None):
            self.discovery_document = await self.fetch_discovery_document()
        return self.discovery_document

    async def get_authorization_url(self, base_uri):
        """Start a new flow and return the URL to redirect the user to.

        Generates the PKCE verifier/challenge, ``nonce`` and ``state``, and
        records the verifier and nonce in ``flows`` under the state.

        Args:
            base_uri: Base URI of this application; the callback path is
                appended to it to form ``redirect_uri``.

        Returns:
            The authorization URL, or ``None`` if the discovery document is
            unavailable or lacks an ``authorization_endpoint``.
        """
        document = await self._get_discovery_document()
        if not document:
            return None

        auth_endpoint = document.get('authorization_endpoint')
        if not auth_endpoint:
            _LOGGER.warning("Discovery document has no authorization_endpoint")
            return None

        # Generate the necessary PKCE parameters, nonce & state
        code_verifier = self._random_token(32)
        challenge_digest = hashlib.sha256(code_verifier.encode('utf-8')).digest()
        code_challenge = base64.urlsafe_b64encode(challenge_digest).rstrip(b'=').decode('utf-8')
        nonce = self._random_token(16)
        state = self._random_token(16)

        # Save them for verification when the provider calls back
        self.flows[state] = {
            'code_verifier': code_verifier,
            'nonce': nonce,
        }

        query_params = {
            'response_type': 'code',
            'client_id': self.client_id,
            'redirect_uri': base_uri + self._CALLBACK_PATH,
            'scope': self.scope,
            'state': state,
            'nonce': nonce,
            'code_challenge': code_challenge,
            'code_challenge_method': 'S256',
        }

        # The endpoint may already carry a query string; extend it if so.
        separator = '&' if '?' in auth_endpoint else '?'
        return f"{auth_endpoint}{separator}{urllib.parse.urlencode(query_params)}"

    async def _make_token_request(self, token_endpoint, query_params):
        """POST ``query_params`` as form data to the token endpoint.

        Returns:
            The decoded JSON token response, or ``None`` on any HTTP or
            transport error.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(token_endpoint, data=query_params) as response:
                    if response.status >= 400:
                        # Read the error body while the response is still
                        # open; it is no longer readable once the context
                        # manager has released the connection.
                        error_body = await response.text()
                        _LOGGER.warning(
                            "Error: %s - %s, Response: %s",
                            response.status, response.reason, error_body,
                        )
                        return None
                    return await response.json()
        except aiohttp.ClientError as e:
            _LOGGER.warning("Error during token request to %s: %s", token_endpoint, e)
        return None

    async def _get_jwks(self, jwks_uri):
        """Fetches JWKS from the given URL.

        Returns:
            The decoded JSON key set, or ``None`` if the request failed.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(jwks_uri) as response:
                    response.raise_for_status()
                    return await response.json()
        except aiohttp.ClientResponseError as e:
            _LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
        except aiohttp.ClientError as e:
            _LOGGER.warning("Error fetching JWKS from %s: %s", jwks_uri, e)
        return None

    async def _parse_id_token(self, id_token, access_token=None):
        """Verify an ID token with python-jose and return its claims.

        The signing key is selected from the provider's JWKS by the ``kid``
        in the token header. Signature, audience (``client_id``) and issuer
        (from the discovery document) are checked by ``jwt.decode``.

        Args:
            id_token: The encoded ID token.
            access_token: Access token issued with the ID token. It is passed
                to ``jwt.decode`` so an ``at_hash`` claim can be validated.

        Returns:
            The decoded claims, or ``None`` if anything fails.
        """
        document = await self._get_discovery_document()
        if not document:
            return None

        jwks_uri = document.get('jwks_uri')
        if not jwks_uri:
            _LOGGER.warning("Discovery document has no jwks_uri")
            return None

        jwks_data = await self._get_jwks(jwks_uri)
        if not jwks_data:
            return None

        try:
            unverified_header = jwt.get_unverified_header(id_token)
            if not unverified_header:
                _LOGGER.warning("Could not parse JWT Header")
                return None

            kid = unverified_header.get('kid')
            if not kid:
                _LOGGER.warning("JWT does not have kid (Key ID)")
                return None

            # Get the correct key; entries without a kid are skipped
            signing_key = next(
                (key for key in jwks_data.get("keys", []) if key.get("kid") == kid),
                None,
            )
            if not signing_key:
                _LOGGER.warning("Could not find matching key with kid:%s", kid)
                return None

            # JWKS entries may omit "alg"; fall back to the accepted
            # algorithm so the key can still be constructed.
            jwk_obj = jwk.construct(
                signing_key,
                algorithm=signing_key.get('alg') or self._ID_TOKEN_ALGORITHM,
            )

            return jwt.decode(
                id_token,
                jwk_obj,
                algorithms=[self._ID_TOKEN_ALGORITHM],
                audience=self.client_id,
                issuer=document['issuer'],
                access_token=access_token,
            )
        except jwt.JWTError as e:
            _LOGGER.warning("JWT Verification failed: %s", e)
        except Exception as e:
            # Deliberate catch-all: any parsing problem means "no identity".
            _LOGGER.warning("Unexpected error: %s", e)

        return None

    async def complete_token_flow(self, base_uri, code, state):
        """Exchange the authorization code and return the user's details.

        The flow stored under ``state`` is removed as soon as it is looked
        up, so each state can be redeemed only once.

        Args:
            base_uri: Base URI used when the flow was started.
            code: Authorization code returned by the provider.
            state: ``state`` value returned by the provider.

        Returns:
            A dict with ``name``, ``email``, ``preferred_username``,
            ``nickname`` and ``groups`` taken from the ID token claims, or
            ``None`` if the state is unknown or any step fails.
        """
        flow = self.flows.pop(state, None)
        if flow is None:
            return None

        document = await self._get_discovery_document()
        if not document:
            return None

        token_endpoint = document.get('token_endpoint')
        if not token_endpoint:
            _LOGGER.warning("Discovery document has no token_endpoint")
            return None

        # The code, PKCE verifier and tokens are credentials and are
        # intentionally never written to the log.
        token_params = {
            'grant_type': 'authorization_code',
            'client_id': self.client_id,
            'code': code,
            'redirect_uri': base_uri + self._CALLBACK_PATH,
            'code_verifier': flow['code_verifier'],
        }

        token_response = await self._make_token_request(token_endpoint, token_params)
        if not token_response:
            return None

        raw_id_token = token_response.get('id_token')
        if not raw_id_token:
            _LOGGER.warning("Token response did not contain an id_token")
            return None

        # Parse the id token to obtain the relevant details
        claims = await self._parse_id_token(
            raw_id_token, access_token=token_response.get('access_token')
        )
        if not claims:
            # _parse_id_token has already logged the reason.
            return None

        # Verify nonce
        if claims.get('nonce') != flow['nonce']:
            _LOGGER.warning("Nonce mismatch!")
            return None

        return {
            "name": claims.get("name"),
            "email": claims.get("email"),
            "preferred_username": claims.get("preferred_username"),
            "nickname": claims.get("nickname"),
            "groups": claims.get("groups"),
        }