Fixes Home Assistant error about re-creating HTTP sessions (#22)

* Bump to 0.5.1

* Prevent HA errors about HTTP session left open
This commit is contained in:
Christiaan Goossens
2025-01-12 12:43:41 +01:00
committed by GitHub
parent bfad0418ad
commit 63f5f175ee
3 changed files with 29 additions and 13 deletions

View File

@@ -19,5 +19,5 @@
"jinja2>=3.1.4", "jinja2>=3.1.4",
"bcrypt>=4.2.0" "bcrypt>=4.2.0"
], ],
"version": "0.4.1" "version": "0.5.1"
} }

View File

@@ -58,6 +58,9 @@ class OIDCClient:
# Flows stores the state, code_verifier and nonce of all current flows. # Flows stores the state, code_verifier and nonce of all current flows.
flows = {} flows = {}
# HTTP session to be used
http_session: aiohttp.ClientSession = None
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
@@ -94,11 +97,13 @@ class OIDCClient:
self.tls_verify = network.get(NETWORK_TLS_VERIFY, True) self.tls_verify = network.get(NETWORK_TLS_VERIFY, True)
self.tls_ca_path = network.get(NETWORK_TLS_CA_PATH) self.tls_ca_path = network.get(NETWORK_TLS_CA_PATH)
_LOGGER.debug( def __del__(self):
"OIDC provider network options (verify certificates: %r, custom CA file: %s)", """Cleanup the HTTP session."""
self.tls_verify,
self.tls_ca_path, # HA never seems to run this, but it's good practice to close the session
) if self.http_session:
_LOGGER.debug("Closing HTTP session")
self.http_session.close()
def _base64url_encode(self, value: str) -> str: def _base64url_encode(self, value: str) -> str:
"""Uses base64url encoding on a given string""" """Uses base64url encoding on a given string"""
@@ -108,8 +113,18 @@ class OIDCClient:
"""Generates a random URL safe string (base64_url encoded)""" """Generates a random URL safe string (base64_url encoded)"""
return self._base64url_encode(os.urandom(length)) return self._base64url_encode(os.urandom(length))
async def _create_session(self): async def _get_http_session(self) -> aiohttp.ClientSession:
"""Create a new client session with custom networking/TLS options""" """Create or get the existing client session with custom networking/TLS options"""
if self.http_session is not None:
return self.http_session
_LOGGER.debug(
"Creating HTTP session provider with options: "
+ "verify certificates: %r, custom CA file: %s",
self.tls_verify,
self.tls_ca_path,
)
tcp_connector_args = {"verify_ssl": self.tls_verify} tcp_connector_args = {"verify_ssl": self.tls_verify}
if self.tls_ca_path: if self.tls_ca_path:
@@ -119,14 +134,15 @@ class OIDCClient:
) )
tcp_connector_args["ssl"] = ssl_context tcp_connector_args["ssl"] = ssl_context
return aiohttp.ClientSession( self.http_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(**tcp_connector_args) connector=aiohttp.TCPConnector(**tcp_connector_args)
) )
return self.http_session
async def _fetch_discovery_document(self): async def _fetch_discovery_document(self):
"""Fetches discovery document from the given URL.""" """Fetches discovery document from the given URL."""
try: try:
session = await self._create_session() session = await self._get_http_session()
async with session.get(self.discovery_url) as response: async with session.get(self.discovery_url) as response:
response.raise_for_status() response.raise_for_status()
@@ -143,7 +159,7 @@ class OIDCClient:
async def _get_jwks(self, jwks_uri): async def _get_jwks(self, jwks_uri):
"""Fetches JWKS from the given URL.""" """Fetches JWKS from the given URL."""
try: try:
session = await self._create_session() session = await self._get_http_session()
async with session.get(jwks_uri) as response: async with session.get(jwks_uri) as response:
response.raise_for_status() response.raise_for_status()
@@ -155,7 +171,7 @@ class OIDCClient:
async def _make_token_request(self, token_endpoint, query_params): async def _make_token_request(self, token_endpoint, query_params):
"""Performs the token POST call""" """Performs the token POST call"""
try: try:
session = await self._create_session() session = await self._get_http_session()
async with session.post(token_endpoint, data=query_params) as response: async with session.post(token_endpoint, data=query_params) as response:
response.raise_for_status() response.raise_for_status()

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "hass-oidc-auth" name = "hass-oidc-auth"
version = "0.4.1" version = "0.5.1"
description = "OIDC component for Home Assistant" description = "OIDC component for Home Assistant"
authors = [ authors = [
{ name = "Christiaan Goossens", email = "contact@christiaangoossens.nl" } { name = "Christiaan Goossens", email = "contact@christiaangoossens.nl" }