Code quality improvements (v0.2.0-pre-alpha) (#5)
* Bumped version to 0.2.0 * Implemented Github Actions for HACS, Hassfest, Linting * Improved code quality (compliant with the linter now) * Added link to the finish page to automatically login on the same device/browser
This commit is contained in:
committed by
GitHub
parent
a30d42ffce
commit
b4a08b17ab
@@ -1,25 +1,50 @@
|
||||
import aiohttp
|
||||
"""OIDC Client class"""
|
||||
|
||||
import urllib.parse
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
from jose import jwt
|
||||
|
||||
from jose import jwk, jwt
|
||||
from typing import Optional
|
||||
import aiohttp
|
||||
from jose import jwt, jwk
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OIDCClientException(Exception):
|
||||
"Raised when the OIDC Client encounters an error"
|
||||
|
||||
|
||||
class OIDCDiscoveryInvalid(OIDCClientException):
|
||||
"Raised when the discovery document is not found, invalid or otherwise malformed."
|
||||
|
||||
|
||||
class OIDCTokenResponseInvalid(OIDCClientException):
|
||||
"Raised when the token request returns invalid."
|
||||
|
||||
|
||||
class OIDCJWKSInvalid(OIDCClientException):
|
||||
"Raised when the JWKS is invalid or cannot be obtained."
|
||||
|
||||
|
||||
class OIDCStateInvalid(OIDCClientException):
|
||||
"Raised when the state for your request cannot be matched against a stored state."
|
||||
|
||||
|
||||
class OIDCClient:
|
||||
"""OIDC Client implementation for Python, including PKCE."""
|
||||
|
||||
# Flows stores the state, code_verifier and nonce of all current flows.
|
||||
flows = {}
|
||||
|
||||
def __init__(self, discovery_url, client_id, scope):
|
||||
def __init__(self, discovery_url: str, client_id: str, scope: str):
|
||||
self.discovery_url = discovery_url
|
||||
self.discovery_document = None
|
||||
self.client_id = client_id
|
||||
self.scope = scope
|
||||
|
||||
async def fetch_discovery_document(self):
|
||||
async def _fetch_discovery_document(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.discovery_url) as response:
|
||||
@@ -27,47 +52,60 @@ class OIDCClient:
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
if e.status == 404:
|
||||
_LOGGER.warning(f"Error: Discovery document not found at {self.discovery_url}")
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document not found at %s", self.discovery_url
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}")
|
||||
return None
|
||||
|
||||
async def get_authorization_url(self, base_uri):
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
_LOGGER.warning("Error: %s - %s", e.status, e.message)
|
||||
raise OIDCDiscoveryInvalid from e
|
||||
|
||||
if not self.discovery_document:
|
||||
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
|
||||
"""Generates the authorization URL for the OIDC flow."""
|
||||
try:
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
auth_endpoint = self.discovery_document["authorization_endpoint"]
|
||||
|
||||
# Generate the necessary PKCE parameters, nonce & state
|
||||
code_verifier = (
|
||||
base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(
|
||||
hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
)
|
||||
.rstrip(b"=")
|
||||
.decode("utf-8")
|
||||
)
|
||||
nonce = (
|
||||
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
state = (
|
||||
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
|
||||
# Save all of them for later verification
|
||||
self.flows[state] = {"code_verifier": code_verifier, "nonce": nonce}
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": self.scope,
|
||||
"state": state,
|
||||
"nonce": nonce,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
|
||||
return url
|
||||
except OIDCClientException as e:
|
||||
_LOGGER.warning("Error generating authorization URL: %s", e)
|
||||
return None
|
||||
|
||||
auth_endpoint = self.discovery_document['authorization_endpoint']
|
||||
|
||||
# Generate the necessary PKCE parameters, nonce & state
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode('utf-8')
|
||||
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
|
||||
nonce = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
state = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
|
||||
# Save all of them for later verification
|
||||
self.flows[state] = {
|
||||
'code_verifier': code_verifier,
|
||||
'nonce': nonce
|
||||
}
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'scope': self.scope,
|
||||
'state': state,
|
||||
'nonce': nonce,
|
||||
'code_challenge': code_challenge,
|
||||
'code_challenge_method': 'S256',
|
||||
}
|
||||
|
||||
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
|
||||
return url
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -75,12 +113,9 @@ class OIDCClient:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
response_json = await response.json()
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}, Response: {response_json}")
|
||||
return None
|
||||
_LOGGER.warning("Error exchanging token: %s - %s", e.status, e.message)
|
||||
raise OIDCTokenResponseInvalid from e
|
||||
|
||||
return None
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
@@ -89,23 +124,18 @@ class OIDCClient:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning(f"Error fetching JWKS: {e.status} - {e.message}")
|
||||
return None
|
||||
_LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
|
||||
raise OIDCJWKSInvalid from e
|
||||
|
||||
async def _parse_id_token(self, id_token: str):
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
async def _parse_id_token(self, id_token):
|
||||
# Parse the id token to obtain the relevant details
|
||||
# Use python-jose
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
jwks_uri = self.discovery_document['jwks_uri']
|
||||
|
||||
jwks_uri = self.discovery_document["jwks_uri"]
|
||||
jwks_data = await self._get_jwks(jwks_uri)
|
||||
if not jwks_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(id_token)
|
||||
@@ -113,12 +143,11 @@ class OIDCClient:
|
||||
print("Could not parse JWT Header")
|
||||
return None
|
||||
|
||||
kid = unverified_header.get('kid')
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
print("JWT does not have kid (Key ID)")
|
||||
return None
|
||||
|
||||
|
||||
# Get the correct key
|
||||
rsa_key = None
|
||||
for key in jwks_data["keys"]:
|
||||
@@ -139,66 +168,60 @@ class OIDCClient:
|
||||
jwk_obj,
|
||||
algorithms=["RS256"], # Adjust if your algorithm is different
|
||||
audience=self.client_id,
|
||||
issuer=self.discovery_document['issuer'],
|
||||
issuer=self.discovery_document["issuer"],
|
||||
)
|
||||
return decoded_token
|
||||
|
||||
except jwt.JWTError as e:
|
||||
print(f"JWT Verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
async def complete_token_flow(self, base_uri, code, state):
|
||||
if state not in self.flows:
|
||||
|
||||
async def async_complete_token_flow(
|
||||
self, redirect_uri: str, code: str, state: str
|
||||
) -> dict[str, str | dict]:
|
||||
"""Completes the OIDC token flow to obtain a user's details."""
|
||||
|
||||
try:
|
||||
if state not in self.flows:
|
||||
raise OIDCStateInvalid
|
||||
|
||||
flow = self.flows[state]
|
||||
code_verifier = flow["code_verifier"]
|
||||
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
token_endpoint = self.discovery_document["token_endpoint"]
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": self.client_id,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
|
||||
token_response = await self._make_token_request(
|
||||
token_endpoint, query_params
|
||||
)
|
||||
id_token = token_response.get("id_token")
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
id_token = await self._parse_id_token(id_token)
|
||||
|
||||
# Verify nonce
|
||||
if id_token.get("nonce") != flow["nonce"]:
|
||||
_LOGGER.warning("Nonce mismatch!")
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": id_token.get("name"),
|
||||
"username": id_token.get("preferred_username"),
|
||||
"groups": id_token.get("groups"),
|
||||
}
|
||||
except OIDCClientException as e:
|
||||
_LOGGER.warning("Error completing token flow: %s", e)
|
||||
return None
|
||||
|
||||
flow = self.flows[state]
|
||||
code_verifier = flow['code_verifier']
|
||||
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
token_endpoint = self.discovery_document['token_endpoint']
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.client_id,
|
||||
'code': code,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'code_verifier': code_verifier,
|
||||
}
|
||||
|
||||
_LOGGER.debug(f"Token request params: {query_params}")
|
||||
|
||||
token_response = await self._make_token_request(token_endpoint, query_params)
|
||||
|
||||
if not token_response:
|
||||
return None
|
||||
|
||||
access_token = token_response.get('access_token')
|
||||
id_token = token_response.get('id_token')
|
||||
_LOGGER.debug(f"Access Token: {access_token}")
|
||||
_LOGGER.debug(f"ID Token: {id_token}")
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
id_token = await self._parse_id_token(id_token)
|
||||
|
||||
# Verify nonce
|
||||
if id_token.get('nonce') != flow['nonce']:
|
||||
_LOGGER.warning(f"Nonce mismatch!")
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": id_token.get("name"),
|
||||
"email": id_token.get("email"),
|
||||
"preferred_username": id_token.get("preferred_username"),
|
||||
"nickname": id_token.get("nickname"),
|
||||
"groups": id_token.get("groups"),
|
||||
}
|
||||
Reference in New Issue
Block a user