Add unit tests (#133)

* Add initial test & add pipeline

* Add very basic YAML config tests

* Add coverage reporting

* Add some webserver & template loading tests

* Add test cases for the helpers

* Implement initial OIDC server tests

* Test codestore & discovery checker

* Test basics of the config flow

* Add test for the HA auth provider

* Cleaned up tests & test injection
This commit is contained in:
Christiaan Goossens
2025-10-05 21:03:02 +02:00
committed by GitHub
parent 5714e844a7
commit 404d2451df
42 changed files with 2331 additions and 91 deletions

View File

@@ -67,6 +67,4 @@ class OIDCCallbackView(HomeAssistantView):
return web.Response(text=view_html, content_type="text/html")
code = await self.oidc_provider.async_save_user_info(user_details)
return web.HTTPFound(
get_url("/auth/oidc/finish?code=" + code, self.force_https)
)
raise web.HTTPFound(get_url("/auth/oidc/finish?code=" + code, self.force_https))

View File

@@ -40,7 +40,7 @@ class OIDCFinishView(HomeAssistantView):
return web.Response(text="No code received", status=500)
# Return redirect to the main page for sign in with a cookie
return web.HTTPFound(
raise web.HTTPFound(
location="/?storeToken=true",
headers={
# Set a cookie to enable autologin on only the specific path used

View File

@@ -25,10 +25,14 @@ class OIDCRedirectView(HomeAssistantView):
"""Receive response."""
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
if auth_url:
return web.HTTPFound(auth_url)
try:
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
if auth_url:
raise web.HTTPFound(auth_url)
except RuntimeError:
pass
view_html = await get_view(
"error",

View File

@@ -23,7 +23,7 @@ class OIDCWelcomeView(HomeAssistantView):
"""Receive response."""
if not self.is_enabled:
return web.HTTPTemporaryRedirect(get_url("/", self.force_https))
raise web.HTTPTemporaryRedirect(get_url("/", self.force_https))
view_html = await get_view("welcome", {"name": self.name})
return web.Response(text=view_html, content_type="text/html")

View File

@@ -177,9 +177,9 @@ class OpenIDAuthProvider(AuthProvider):
# If person creation is enabled, add a person for this user
if self.create_persons:
user_meta = await self.async_user_meta_for_credentials(credential)
await self.async_create_person(user, user_meta.name)
await self._async_create_person(user, user_meta.name)
async def async_create_person(self, user: User, name: str) -> None:
async def _async_create_person(self, user: User, name: str) -> None:
"""Create a person for the user."""
_LOGGER.info("Automatically creating person for new user %s", user.id)
@@ -194,7 +194,7 @@ class OpenIDAuthProvider(AuthProvider):
# pylint: disable=broad-exception-caught
except Exception:
_LOGGER.warning(
"Requested automatic person creation, but person creation failed."
"Requested automatic person creation, but person creation failed"
)
# pylint: enable=broad-exception-caught
@@ -315,7 +315,7 @@ class OpenIdLoginFlow(LoginFlow):
"""Handle the step of the form."""
# Try to use the user input first
if user_input is not None:
if user_input is not None and "code" in user_input:
try:
return await self._finalize_user(user_input["code"])
except InvalidAuthError:
@@ -323,14 +323,15 @@ class OpenIdLoginFlow(LoginFlow):
# If not available, check the cookie
req = http.current_request.get()
code_cookie = req.cookies.get("auth_oidc_code")
if req and req.cookies:
code_cookie = req.cookies.get("auth_oidc_code")
if code_cookie:
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
try:
return await self._finalize_user(code_cookie)
except InvalidAuthError:
pass
if code_cookie:
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
try:
return await self._finalize_user(code_cookie)
except InvalidAuthError:
pass
# If none are available, just show the form
return self._show_login_form()

View File

@@ -3,7 +3,7 @@
import random
import string
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import cast, Optional
from homeassistant.helpers.storage import Store
from homeassistant.core import HomeAssistant
@@ -31,7 +31,7 @@ class CodeStore:
data = cast(dict[str, UserDetails], {})
self._data = data
async def async_save(self) -> None:
async def _async_save(self) -> None:
"""Save data."""
if self._data is not None:
await self._store.async_save(self._data)
@@ -46,7 +46,7 @@ class CodeStore:
raise RuntimeError("Data not loaded")
code = self._generate_code()
expiration = datetime.utcnow() + timedelta(minutes=5)
expiration = datetime.now(timezone.utc) + timedelta(minutes=5)
self._data[code] = {
"user_info": user_info,
@@ -54,7 +54,7 @@ class CodeStore:
"expiration": expiration.isoformat(),
}
await self.async_save()
await self._async_save()
return code
async def receive_userinfo_for_code(self, code: str) -> Optional[UserDetails]:
@@ -67,12 +67,15 @@ class CodeStore:
if user_data:
# We should now wipe it from the database, as it's one time use code
self._data.pop(code)
await self.async_save()
await self._async_save()
if (
user_data
and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow()
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.now(
timezone.utc
):
return user_data["user_info"]
return None
def get_data(self):
"""Get the internal data for testing purposes."""
return self._data

View File

@@ -39,12 +39,8 @@ class OIDCDiscoveryInvalid(OIDCClientException):
type: Optional[str]
details: Optional[dict]
def __init__(self, *args, **kwargs):
if args:
self.message = args[0]
else:
self.message = "OIDC Discovery document is invalid"
def __init__(self, **kwargs):
self.message = "OIDC Discovery document is invalid"
self.type = kwargs.pop("type", None)
self.details = kwargs.pop("details", None)
super().__init__(self.message)
@@ -196,7 +192,7 @@ class OIDCDiscoveryClient:
)
raise OIDCDiscoveryInvalid(
type="does_not_support_response_mode",
modes=document["response_modes_supported"],
details={"modes": document["response_modes_supported"]},
)
# If grant_types_supported is set, should support 'authorization_code'
@@ -281,7 +277,7 @@ class OIDCDiscoveryClient:
await self._validate_discovery_document(document)
return document
async def fetch_jwks(self, jwks_uri: str | None):
async def fetch_jwks(self, jwks_uri: str | None = None):
"""Fetches JWKS."""
if jwks_uri is None:
discovery_document = await self._fetch_discovery_document()

View File

@@ -10,7 +10,7 @@ def validate_url(url: str) -> bool:
try:
parsed = urlparse(url.strip())
return bool(parsed.scheme in ("http", "https") and parsed.netloc)
except (ValueError, TypeError):
except (ValueError, TypeError, AttributeError):
return False
@@ -23,7 +23,7 @@ def validate_discovery_url(url: str) -> bool:
and parsed.netloc
and parsed.path.endswith("/.well-known/openid-configuration")
)
except (ValueError, TypeError):
except (ValueError, TypeError, AttributeError):
return False

View File

@@ -40,7 +40,7 @@ class AsyncTemplateRenderer:
) as f:
content = await f.read()
templates[filename] = content
except (OSError, IOError) as e:
except (OSError, IOError) as e: # pragma: no cover
_LOGGER.warning("Error reading template file %s: %s", filename, e)
async def render_template(self, template_name: str, **kwargs: Any) -> str: