feat: enable verification of certs via network.tls_verify and private CA chains with network.tls_ca_path (#16)

Signed-off-by: Christopher Klein <ckl@dreitier.com>
This commit is contained in:
Schakko
2025-01-06 10:09:30 +01:00
committed by GitHub
parent 00da053f50
commit bfad0418ad
4 changed files with 94 additions and 14 deletions

View File

@@ -5,9 +5,12 @@ import logging
import os
import base64
import hashlib
import ssl
from typing import Optional
from functools import partial
import aiohttp
from jose import jwt, jwk
from homeassistant.core import HomeAssistant
from .types import UserDetails
from .config import (
@@ -17,6 +20,8 @@ from .config import (
CLAIMS_GROUPS,
ROLE_ADMINS,
ROLE_USERS,
NETWORK_TLS_VERIFY,
NETWORK_TLS_CA_PATH,
)
_LOGGER = logging.getLogger(__name__)
@@ -53,7 +58,15 @@ class OIDCClient:
# Flows stores the state, code_verifier and nonce of all current flows.
flows = {}
def __init__(self, discovery_url: str, client_id: str, scope: str, **kwargs: str):
def __init__(
self,
hass: HomeAssistant,
discovery_url: str,
client_id: str,
scope: str,
**kwargs: str,
):
self.hass = hass
self.discovery_url = discovery_url
self.discovery_document = None
self.client_id = client_id
@@ -70,13 +83,22 @@ class OIDCClient:
features = kwargs.get("features")
claims = kwargs.get("claims")
roles = kwargs.get("roles")
network = kwargs.get("network")
self.disable_pkce: bool = features.get(FEATURES_DISABLE_PKCE)
self.disable_pkce = features.get(FEATURES_DISABLE_PKCE, False)
self.display_name_claim = claims.get(CLAIMS_DISPLAY_NAME, "name")
self.username_claim = claims.get(CLAIMS_USERNAME, "preferred_username")
self.groups_claim = claims.get(CLAIMS_GROUPS, "groups")
self.user_role = roles.get(ROLE_USERS, None)
self.admin_role = roles.get(ROLE_ADMINS, "admins")
self.tls_verify = network.get(NETWORK_TLS_VERIFY, True)
self.tls_ca_path = network.get(NETWORK_TLS_CA_PATH)
_LOGGER.debug(
"OIDC provider network options (verify certificates: %r, custom CA file: %s)",
self.tls_verify,
self.tls_ca_path,
)
def _base64url_encode(self, value: str) -> str:
"""Uses base64url encoding on a given string"""
@@ -86,13 +108,29 @@ class OIDCClient:
"""Generates a random URL safe string (base64_url encoded)"""
return self._base64url_encode(os.urandom(length))
async def _create_session(self):
"""Create a new client session with custom networking/TLS options"""
tcp_connector_args = {"verify_ssl": self.tls_verify}
if self.tls_ca_path:
# Move to hass' executor to prevent blocking code inside non-blocking method
ssl_context = await self.hass.loop.run_in_executor(
None, partial(ssl.create_default_context, cafile=self.tls_ca_path)
)
tcp_connector_args["ssl"] = ssl_context
return aiohttp.ClientSession(
connector=aiohttp.TCPConnector(**tcp_connector_args)
)
async def _fetch_discovery_document(self):
"""Fetches discovery document from the given URL."""
try:
async with aiohttp.ClientSession() as session:
async with session.get(self.discovery_url) as response:
response.raise_for_status()
return await response.json()
session = await self._create_session()
async with session.get(self.discovery_url) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
if e.status == 404:
_LOGGER.warning(
@@ -105,10 +143,11 @@ class OIDCClient:
async def _get_jwks(self, jwks_uri):
"""Fetches JWKS from the given URL."""
try:
async with aiohttp.ClientSession() as session:
async with session.get(jwks_uri) as response:
response.raise_for_status()
return await response.json()
session = await self._create_session()
async with session.get(jwks_uri) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
_LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
raise OIDCJWKSInvalid from e
@@ -116,10 +155,11 @@ class OIDCClient:
async def _make_token_request(self, token_endpoint, query_params):
"""Performs the token POST call"""
try:
async with aiohttp.ClientSession() as session:
async with session.post(token_endpoint, data=query_params) as response:
response.raise_for_status()
return await response.json()
session = await self._create_session()
async with session.post(token_endpoint, data=query_params) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
if e.status == 400:
_LOGGER.warning(