Code quality improvements (v0.2.0-pre-alpha) (#5)

* Bumped version to 0.2.0
* Implemented Github Actions for HACS, Hassfest, Linting
* Improved code quality (compliant with the linter now)
* Added link to the finish page to automatically login on the same device/browser
This commit is contained in:
Christiaan Goossens
2024-12-27 00:20:38 +01:00
committed by GitHub
parent a30d42ffce
commit b4a08b17ab
18 changed files with 1148 additions and 278 deletions

View File

@@ -1,25 +1,50 @@
import aiohttp
"""OIDC Client class"""
import urllib.parse
import logging
import os
import base64
import hashlib
from jose import jwt
from jose import jwk, jwt
from typing import Optional
import aiohttp
from jose import jwt, jwk
_LOGGER = logging.getLogger(__name__)
class OIDCClientException(Exception):
"Raised when the OIDC Client encounters an error"
class OIDCDiscoveryInvalid(OIDCClientException):
"Raised when the discovery document is not found, invalid or otherwise malformed."
class OIDCTokenResponseInvalid(OIDCClientException):
"Raised when the token request returns invalid."
class OIDCJWKSInvalid(OIDCClientException):
"Raised when the JWKS is invalid or cannot be obtained."
class OIDCStateInvalid(OIDCClientException):
"Raised when the state for your request cannot be matched against a stored state."
class OIDCClient:
"""OIDC Client implementation for Python, including PKCE."""
# Flows stores the state, code_verifier and nonce of all current flows.
flows = {}
def __init__(self, discovery_url, client_id, scope):
def __init__(self, discovery_url: str, client_id: str, scope: str):
self.discovery_url = discovery_url
self.discovery_document = None
self.client_id = client_id
self.scope = scope
async def fetch_discovery_document(self):
async def _fetch_discovery_document(self):
try:
async with aiohttp.ClientSession() as session:
async with session.get(self.discovery_url) as response:
@@ -27,47 +52,60 @@ class OIDCClient:
return await response.json()
except aiohttp.ClientResponseError as e:
if e.status == 404:
_LOGGER.warning(f"Error: Discovery document not found at {self.discovery_url}")
_LOGGER.warning(
"Error: Discovery document not found at %s", self.discovery_url
)
else:
_LOGGER.warning(f"Error: {e.status} - {e.message}")
return None
async def get_authorization_url(self, base_uri):
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
_LOGGER.warning("Error: %s - %s", e.status, e.message)
raise OIDCDiscoveryInvalid from e
if not self.discovery_document:
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
"""Generates the authorization URL for the OIDC flow."""
try:
if self.discovery_document is None:
self.discovery_document = await self._fetch_discovery_document()
auth_endpoint = self.discovery_document["authorization_endpoint"]
# Generate the necessary PKCE parameters, nonce & state
code_verifier = (
base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode("utf-8")
)
code_challenge = (
base64.urlsafe_b64encode(
hashlib.sha256(code_verifier.encode("utf-8")).digest()
)
.rstrip(b"=")
.decode("utf-8")
)
nonce = (
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
)
state = (
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
)
# Save all of them for later verification
self.flows[state] = {"code_verifier": code_verifier, "nonce": nonce}
# Construct the params
query_params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": redirect_uri,
"scope": self.scope,
"state": state,
"nonce": nonce,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
return url
except OIDCClientException as e:
_LOGGER.warning("Error generating authorization URL: %s", e)
return None
auth_endpoint = self.discovery_document['authorization_endpoint']
# Generate the necessary PKCE parameters, nonce & state
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode('utf-8')
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
nonce = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
state = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
# Save all of them for later verification
self.flows[state] = {
'code_verifier': code_verifier,
'nonce': nonce
}
# Construct the params
query_params = {
'response_type': 'code',
'client_id': self.client_id,
'redirect_uri': base_uri + '/auth/oidc/callback',
'scope': self.scope,
'state': state,
'nonce': nonce,
'code_challenge': code_challenge,
'code_challenge_method': 'S256',
}
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
return url
async def _make_token_request(self, token_endpoint, query_params):
try:
async with aiohttp.ClientSession() as session:
@@ -75,12 +113,9 @@ class OIDCClient:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
response_json = await response.json()
_LOGGER.warning(f"Error: {e.status} - {e.message}, Response: {response_json}")
return None
_LOGGER.warning("Error exchanging token: %s - %s", e.status, e.message)
raise OIDCTokenResponseInvalid from e
return None
async def _get_jwks(self, jwks_uri):
"""Fetches JWKS from the given URL."""
try:
@@ -89,23 +124,18 @@ class OIDCClient:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
_LOGGER.warning(f"Error fetching JWKS: {e.status} - {e.message}")
return None
_LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
raise OIDCJWKSInvalid from e
async def _parse_id_token(self, id_token: str):
if self.discovery_document is None:
self.discovery_document = await self._fetch_discovery_document()
async def _parse_id_token(self, id_token):
# Parse the id token to obtain the relevant details
# Use python-jose
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
if not self.discovery_document:
return None
jwks_uri = self.discovery_document['jwks_uri']
jwks_uri = self.discovery_document["jwks_uri"]
jwks_data = await self._get_jwks(jwks_uri)
if not jwks_data:
return None
try:
unverified_header = jwt.get_unverified_header(id_token)
@@ -113,12 +143,11 @@ class OIDCClient:
print("Could not parse JWT Header")
return None
kid = unverified_header.get('kid')
kid = unverified_header.get("kid")
if not kid:
print("JWT does not have kid (Key ID)")
return None
# Get the correct key
rsa_key = None
for key in jwks_data["keys"]:
@@ -139,66 +168,60 @@ class OIDCClient:
jwk_obj,
algorithms=["RS256"], # Adjust if your algorithm is different
audience=self.client_id,
issuer=self.discovery_document['issuer'],
issuer=self.discovery_document["issuer"],
)
return decoded_token
except jwt.JWTError as e:
print(f"JWT Verification failed: {e}")
return None
except Exception as e:
print(f"Unexpected error: {e}")
return None
async def complete_token_flow(self, base_uri, code, state):
if state not in self.flows:
async def async_complete_token_flow(
self, redirect_uri: str, code: str, state: str
) -> dict[str, str | dict]:
"""Completes the OIDC token flow to obtain a user's details."""
try:
if state not in self.flows:
raise OIDCStateInvalid
flow = self.flows[state]
code_verifier = flow["code_verifier"]
if self.discovery_document is None:
self.discovery_document = await self._fetch_discovery_document()
token_endpoint = self.discovery_document["token_endpoint"]
# Construct the params
query_params = {
"grant_type": "authorization_code",
"client_id": self.client_id,
"code": code,
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
}
token_response = await self._make_token_request(
token_endpoint, query_params
)
id_token = token_response.get("id_token")
# Parse the id token to obtain the relevant details
id_token = await self._parse_id_token(id_token)
# Verify nonce
if id_token.get("nonce") != flow["nonce"]:
_LOGGER.warning("Nonce mismatch!")
return None
return {
"name": id_token.get("name"),
"username": id_token.get("preferred_username"),
"groups": id_token.get("groups"),
}
except OIDCClientException as e:
_LOGGER.warning("Error completing token flow: %s", e)
return None
flow = self.flows[state]
code_verifier = flow['code_verifier']
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
if not self.discovery_document:
return None
token_endpoint = self.discovery_document['token_endpoint']
# Construct the params
query_params = {
'grant_type': 'authorization_code',
'client_id': self.client_id,
'code': code,
'redirect_uri': base_uri + '/auth/oidc/callback',
'code_verifier': code_verifier,
}
_LOGGER.debug(f"Token request params: {query_params}")
token_response = await self._make_token_request(token_endpoint, query_params)
if not token_response:
return None
access_token = token_response.get('access_token')
id_token = token_response.get('id_token')
_LOGGER.debug(f"Access Token: {access_token}")
_LOGGER.debug(f"ID Token: {id_token}")
# Parse the id token to obtain the relevant details
id_token = await self._parse_id_token(id_token)
# Verify nonce
if id_token.get('nonce') != flow['nonce']:
_LOGGER.warning(f"Nonce mismatch!")
return None
return {
"name": id_token.get("name"),
"email": id_token.get("email"),
"preferred_username": id_token.get("preferred_username"),
"nickname": id_token.get("nickname"),
"groups": id_token.get("groups"),
}