Implement initial flow (#2)
This commit is contained in:
committed by
GitHub
parent
1c8c7ed14a
commit
8ba494c49c
204
custom_components/auth_oidc/oidc_client.py
Normal file
204
custom_components/auth_oidc/oidc_client.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import aiohttp
|
||||
|
||||
import urllib.parse
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
from jose import jwt
|
||||
|
||||
from jose import jwk, jwt
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCClient:
|
||||
flows = {}
|
||||
|
||||
def __init__(self, discovery_url, client_id, scope):
|
||||
self.discovery_url = discovery_url
|
||||
self.client_id = client_id
|
||||
self.scope = scope
|
||||
|
||||
async def fetch_discovery_document(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.discovery_url) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
if e.status == 404:
|
||||
_LOGGER.warning(f"Error: Discovery document not found at {self.discovery_url}")
|
||||
else:
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}")
|
||||
return None
|
||||
|
||||
async def get_authorization_url(self, base_uri):
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
auth_endpoint = self.discovery_document['authorization_endpoint']
|
||||
|
||||
# Generate the necessary PKCE parameters, nonce & state
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode('utf-8')
|
||||
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
|
||||
nonce = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
state = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
|
||||
# Save all of them for later verification
|
||||
self.flows[state] = {
|
||||
'code_verifier': code_verifier,
|
||||
'nonce': nonce
|
||||
}
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'scope': self.scope,
|
||||
'state': state,
|
||||
'nonce': nonce,
|
||||
'code_challenge': code_challenge,
|
||||
'code_challenge_method': 'S256',
|
||||
}
|
||||
|
||||
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
|
||||
return url
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(token_endpoint, data=query_params) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
response_json = await response.json()
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}, Response: {response_json}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(jwks_uri) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning(f"Error fetching JWKS: {e.status} - {e.message}")
|
||||
return None
|
||||
|
||||
async def _parse_id_token(self, id_token):
|
||||
# Parse the id token to obtain the relevant details
|
||||
# Use python-jose
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
jwks_uri = self.discovery_document['jwks_uri']
|
||||
|
||||
jwks_data = await self._get_jwks(jwks_uri)
|
||||
if not jwks_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(id_token)
|
||||
if not unverified_header:
|
||||
print("Could not parse JWT Header")
|
||||
return None
|
||||
|
||||
kid = unverified_header.get('kid')
|
||||
if not kid:
|
||||
print("JWT does not have kid (Key ID)")
|
||||
return None
|
||||
|
||||
|
||||
# Get the correct key
|
||||
rsa_key = None
|
||||
for key in jwks_data["keys"]:
|
||||
if key["kid"] == kid:
|
||||
rsa_key = key
|
||||
break
|
||||
|
||||
if not rsa_key:
|
||||
print(f"Could not find matching key with kid:{kid}")
|
||||
return None
|
||||
|
||||
# Construct the JWK
|
||||
jwk_obj = jwk.construct(rsa_key)
|
||||
|
||||
# Verify the token
|
||||
decoded_token = jwt.decode(
|
||||
id_token,
|
||||
jwk_obj,
|
||||
algorithms=["RS256"], # Adjust if your algorithm is different
|
||||
audience=self.client_id,
|
||||
issuer=self.discovery_document['issuer'],
|
||||
)
|
||||
return decoded_token
|
||||
|
||||
except jwt.JWTError as e:
|
||||
print(f"JWT Verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def complete_token_flow(self, base_uri, code, state):
|
||||
if state not in self.flows:
|
||||
return None
|
||||
|
||||
flow = self.flows[state]
|
||||
code_verifier = flow['code_verifier']
|
||||
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
token_endpoint = self.discovery_document['token_endpoint']
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.client_id,
|
||||
'code': code,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'code_verifier': code_verifier,
|
||||
}
|
||||
|
||||
_LOGGER.debug(f"Token request params: {query_params}")
|
||||
|
||||
token_response = await self._make_token_request(token_endpoint, query_params)
|
||||
|
||||
if not token_response:
|
||||
return None
|
||||
|
||||
access_token = token_response.get('access_token')
|
||||
id_token = token_response.get('id_token')
|
||||
_LOGGER.debug(f"Access Token: {access_token}")
|
||||
_LOGGER.debug(f"ID Token: {id_token}")
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
id_token = await self._parse_id_token(id_token)
|
||||
|
||||
# Verify nonce
|
||||
if id_token.get('nonce') != flow['nonce']:
|
||||
_LOGGER.warning(f"Nonce mismatch!")
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": id_token.get("name"),
|
||||
"email": id_token.get("email"),
|
||||
"preferred_username": id_token.get("preferred_username"),
|
||||
"nickname": id_token.get("nickname"),
|
||||
"groups": id_token.get("groups"),
|
||||
}
|
||||
Reference in New Issue
Block a user