From 404d2451dff963e7fea6f7fed9b6a6c6e4e70b7a Mon Sep 17 00:00:00 2001 From: Christiaan Goossens <9487666+christiaangoossens@users.noreply.github.com> Date: Sun, 5 Oct 2025 21:03:02 +0200 Subject: [PATCH] Add unit tests (#133) * Add initial test & add pipeline * Add very basic YAML config tests * Add coverage reporting * Add some webserver & template loading tests * Add test cases for the helpers * Implement initial OIDC server tests * Test codestore & discovery checker * Test basics of the config flow * Add test for the HA auth provider * Cleaned up tests & test injection --- .github/workflows/test.yaml | 24 + .gitignore | 4 +- .../auth_oidc/endpoints/callback.py | 4 +- .../auth_oidc/endpoints/finish.py | 2 +- .../auth_oidc/endpoints/redirect.py | 10 +- .../auth_oidc/endpoints/welcome.py | 2 +- custom_components/auth_oidc/provider.py | 23 +- .../auth_oidc/stores/code_store.py | 19 +- .../auth_oidc/tools/oidc_client.py | 12 +- .../auth_oidc/tools/validation.py | 4 +- custom_components/auth_oidc/views/loader.py | 2 +- pyproject.toml | 10 + scripts/coverage-report | 3 + scripts/test | 2 + tests/__init__.py | 0 tests/conftest.py | 8 + tests/mocks/__init__.py | 0 tests/mocks/auth_page.html | 14 + tests/mocks/oidc_server.py | 197 ++++++ tests/mocks/scenarios/empty.json | 5 + .../invalid_code_challenge_types.json | 10 + .../mocks/scenarios/invalid_grant_types.json | 10 + .../invalid_id_token_signing_alg.json | 8 + .../scenarios/invalid_response_modes.json | 9 + .../scenarios/invalid_response_types.json | 9 + tests/mocks/scenarios/invalid_url.json | 8 + tests/mocks/scenarios/missing_jwks.json | 7 + tests/mocks/scenarios/missing_token.json | 6 + tests/mocks/scenarios/only_issuer.json | 5 + tests/mocks/scenarios/username.json | 3 + .../scenarios/wrong_id_token_signing_alg.json | 9 + .../fake_templates/folder.html/empty.txt | 0 tests/resources/fake_templates/index.html | 1 + tests/test_code_store.py | 90 +++ tests/test_hass_auth_provider.py | 229 +++++++ tests/test_hass_oidc_client.py | 287 ++++++++ tests/test_hass_ui_config_flow.py | 364 ++++++++++ tests/test_hass_webserver.py | 151 ++++ tests/test_hass_yaml_init.py | 93 +++ tests/test_helpers.py | 85 +++ tests/test_view_template.py | 49 ++ uv.lock | 644 ++++++++++++++++-- 42 files changed, 2331 insertions(+), 91 deletions(-) create mode 100644 .github/workflows/test.yaml create mode 100755 scripts/coverage-report create mode 100755 scripts/test create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/mocks/__init__.py create mode 100644 tests/mocks/auth_page.html create mode 100644 tests/mocks/oidc_server.py create mode 100644 tests/mocks/scenarios/empty.json create mode 100644 tests/mocks/scenarios/invalid_code_challenge_types.json create mode 100644 tests/mocks/scenarios/invalid_grant_types.json create mode 100644 tests/mocks/scenarios/invalid_id_token_signing_alg.json create mode 100644 tests/mocks/scenarios/invalid_response_modes.json create mode 100644 tests/mocks/scenarios/invalid_response_types.json create mode 100644 tests/mocks/scenarios/invalid_url.json create mode 100644 tests/mocks/scenarios/missing_jwks.json create mode 100644 tests/mocks/scenarios/missing_token.json create mode 100644 tests/mocks/scenarios/only_issuer.json create mode 100644 tests/mocks/scenarios/username.json create mode 100644 tests/mocks/scenarios/wrong_id_token_signing_alg.json create mode 100644 tests/resources/fake_templates/folder.html/empty.txt create mode 100644 tests/resources/fake_templates/index.html create mode 100644 tests/test_code_store.py create mode 100644 tests/test_hass_auth_provider.py create mode 100644 tests/test_hass_oidc_client.py create mode 100644 tests/test_hass_ui_config_flow.py create mode 100644 tests/test_hass_webserver.py create mode 100644 tests/test_hass_yaml_init.py create mode 100644 tests/test_helpers.py create mode 100644 tests/test_view_template.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..c63598a --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,24 @@ +--- +name: Tests (pytest) + +on: + push: + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: "Set up Python" + uses: actions/setup-python@v6 + with: + python-version-file: ".python-version" + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + - name: Sync dependencies + run: scripts/sync + - name: Test + run: scripts/test diff --git a/.gitignore b/.gitignore index 8d4e076..120c37b 100644 --- a/.gitignore +++ b/.gitignore @@ -107,4 +107,6 @@ dmypy.json # End of https://www.gitignore.io/api/python /config/ -.venv \ No newline at end of file +.venv + +.pytest_logs.log \ No newline at end of file diff --git a/custom_components/auth_oidc/endpoints/callback.py b/custom_components/auth_oidc/endpoints/callback.py index 0b92052..178ee1e 100644 --- a/custom_components/auth_oidc/endpoints/callback.py +++ b/custom_components/auth_oidc/endpoints/callback.py @@ -67,6 +67,4 @@ class OIDCCallbackView(HomeAssistantView): return web.Response(text=view_html, content_type="text/html") code = await self.oidc_provider.async_save_user_info(user_details) - return web.HTTPFound( - get_url("/auth/oidc/finish?code=" + code, self.force_https) - ) + raise web.HTTPFound(get_url("/auth/oidc/finish?code=" + code, self.force_https)) diff --git a/custom_components/auth_oidc/endpoints/finish.py b/custom_components/auth_oidc/endpoints/finish.py index 1b896aa..6b60371 100644 --- a/custom_components/auth_oidc/endpoints/finish.py +++ b/custom_components/auth_oidc/endpoints/finish.py @@ -40,7 +40,7 @@ class OIDCFinishView(HomeAssistantView): return web.Response(text="No code received", status=500) # Return redirect to the main page for sign in with a cookie - return web.HTTPFound( + raise web.HTTPFound( location="/?storeToken=true", headers={ # Set a cookie to enable autologin on only the specific path used diff --git a/custom_components/auth_oidc/endpoints/redirect.py b/custom_components/auth_oidc/endpoints/redirect.py index 0421edf..e203b38 100644 --- a/custom_components/auth_oidc/endpoints/redirect.py +++ b/custom_components/auth_oidc/endpoints/redirect.py @@ -25,10 +25,14 @@ class OIDCRedirectView(HomeAssistantView): """Receive response.""" redirect_uri = get_url("/auth/oidc/callback", self.force_https) - auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri) - if auth_url: - return web.HTTPFound(auth_url) + try: + auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri) + + if auth_url: + raise web.HTTPFound(auth_url) + except RuntimeError: + pass view_html = await get_view( "error", diff --git a/custom_components/auth_oidc/endpoints/welcome.py b/custom_components/auth_oidc/endpoints/welcome.py index 3b3eba7..7088bbb 100644 --- a/custom_components/auth_oidc/endpoints/welcome.py +++ b/custom_components/auth_oidc/endpoints/welcome.py @@ -23,7 +23,7 @@ class OIDCWelcomeView(HomeAssistantView): """Receive response.""" if not self.is_enabled: - return web.HTTPTemporaryRedirect(get_url("/", self.force_https)) + raise web.HTTPTemporaryRedirect(get_url("/", self.force_https)) view_html = await get_view("welcome", {"name": self.name}) return web.Response(text=view_html, content_type="text/html") diff --git a/custom_components/auth_oidc/provider.py b/custom_components/auth_oidc/provider.py index cd50cfa..02b864f 100644 --- a/custom_components/auth_oidc/provider.py +++ b/custom_components/auth_oidc/provider.py @@ -177,9 +177,9 @@ class OpenIDAuthProvider(AuthProvider): # If person creation is enabled, add a person for this user if self.create_persons: user_meta = await self.async_user_meta_for_credentials(credential) - await self.async_create_person(user, user_meta.name) + await self._async_create_person(user, user_meta.name) - async def async_create_person(self, user: User, name: str) -> None: + async def _async_create_person(self, user: User, name: str) -> None: """Create a person for the user.""" _LOGGER.info("Automatically creating person for new user %s", user.id) @@ -194,7 +194,7 @@ class OpenIDAuthProvider(AuthProvider): # pylint: disable=broad-exception-caught except Exception: _LOGGER.warning( - "Requested automatic person creation, but person creation failed." + "Requested automatic person creation, but person creation failed" ) # pylint: enable=broad-exception-caught @@ -315,7 +315,7 @@ class OpenIdLoginFlow(LoginFlow): """Handle the step of the form.""" # Try to use the user input first - if user_input is not None: + if user_input is not None and "code" in user_input: try: return await self._finalize_user(user_input["code"]) except InvalidAuthError: @@ -323,14 +323,15 @@ class OpenIdLoginFlow(LoginFlow): # If not available, check the cookie req = http.current_request.get() - code_cookie = req.cookies.get("auth_oidc_code") + if req and req.cookies: + code_cookie = req.cookies.get("auth_oidc_code") - if code_cookie: - _LOGGER.debug("Code cookie found on login: %s", code_cookie) - try: - return await self._finalize_user(code_cookie) - except InvalidAuthError: - pass + if code_cookie: + _LOGGER.debug("Code cookie found on login: %s", code_cookie) + try: + return await self._finalize_user(code_cookie) + except InvalidAuthError: + pass # If none are available, just show the form return self._show_login_form() diff --git a/custom_components/auth_oidc/stores/code_store.py b/custom_components/auth_oidc/stores/code_store.py index 6c2447e..d3aedab 100644 --- a/custom_components/auth_oidc/stores/code_store.py +++ b/custom_components/auth_oidc/stores/code_store.py @@ -3,7 +3,7 @@ import random import string -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import cast, Optional from homeassistant.helpers.storage import Store from homeassistant.core import HomeAssistant @@ -31,7 +31,7 @@ class CodeStore: data = cast(dict[str, UserDetails], {}) self._data = data - async def async_save(self) -> None: + async def _async_save(self) -> None: """Save data.""" if self._data is not None: await self._store.async_save(self._data) @@ -46,7 +46,7 @@ class CodeStore: raise RuntimeError("Data not loaded") code = self._generate_code() - expiration = datetime.utcnow() + timedelta(minutes=5) + expiration = datetime.now(timezone.utc) + timedelta(minutes=5) self._data[code] = { "user_info": user_info, @@ -54,7 +54,7 @@ class CodeStore: "expiration": expiration.isoformat(), } - await self.async_save() + await self._async_save() return code async def receive_userinfo_for_code(self, code: str) -> Optional[UserDetails]: @@ -67,12 +67,15 @@ class CodeStore: if user_data: # We should now wipe it from the database, as it's one time use code self._data.pop(code) - await self.async_save() + await self._async_save() - if ( - user_data - and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow() + if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.now( + timezone.utc ): return user_data["user_info"] return None + + def get_data(self): + """Get the internal data for testing purposes.""" + return self._data diff --git a/custom_components/auth_oidc/tools/oidc_client.py b/custom_components/auth_oidc/tools/oidc_client.py index a23dccc..59f88cc 100644 --- a/custom_components/auth_oidc/tools/oidc_client.py +++ b/custom_components/auth_oidc/tools/oidc_client.py @@ -39,12 +39,8 @@ class OIDCDiscoveryInvalid(OIDCClientException): type: Optional[str] details: Optional[dict] - def __init__(self, *args, **kwargs): - if args: - self.message = args[0] - else: - self.message = "OIDC Discovery document is invalid" - + def __init__(self, **kwargs): + self.message = "OIDC Discovery document is invalid" self.type = kwargs.pop("type", None) self.details = kwargs.pop("details", None) super().__init__(self.message) @@ -196,7 +192,7 @@ class OIDCDiscoveryClient: ) raise OIDCDiscoveryInvalid( type="does_not_support_response_mode", - modes=document["response_modes_supported"], + details={"modes": document["response_modes_supported"]}, ) # If grant_types_supported is set, should support 'authorization_code' @@ -281,7 +277,7 @@ class OIDCDiscoveryClient: await self._validate_discovery_document(document) return document - async def fetch_jwks(self, jwks_uri: str | None): + async def fetch_jwks(self, jwks_uri: str | None = None): """Fetches JWKS.""" if jwks_uri is None: discovery_document = await self._fetch_discovery_document() diff --git a/custom_components/auth_oidc/tools/validation.py b/custom_components/auth_oidc/tools/validation.py index 9a98566..3e5920e 100644 --- a/custom_components/auth_oidc/tools/validation.py +++ b/custom_components/auth_oidc/tools/validation.py @@ -10,7 +10,7 @@ def validate_url(url: str) -> bool: try: parsed = urlparse(url.strip()) return bool(parsed.scheme in ("http", "https") and parsed.netloc) - except (ValueError, TypeError): + except (ValueError, TypeError, AttributeError): return False @@ -23,7 +23,7 @@ def validate_discovery_url(url: str) -> bool: and parsed.netloc and parsed.path.endswith("/.well-known/openid-configuration") ) - except (ValueError, TypeError): + except (ValueError, TypeError, AttributeError): return False diff --git a/custom_components/auth_oidc/views/loader.py b/custom_components/auth_oidc/views/loader.py index eba1ff4..b63f885 100644 --- a/custom_components/auth_oidc/views/loader.py +++ b/custom_components/auth_oidc/views/loader.py @@ -40,7 +40,7 @@ class AsyncTemplateRenderer: ) as f: content = await f.read() templates[filename] = content - except (OSError, IOError) as e: + except (OSError, IOError) as e: # pragma: no cover _LOGGER.warning("Error reading template file %s: %s", filename, e) async def render_template(self, template_name: str, **kwargs: Any) -> str: diff --git a/pyproject.toml b/pyproject.toml index 7910a83..0185df4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "aiofiles~=24.1", "jinja2~=3.1", "bcrypt~=4.2", + "joserfc>=1.3.4", ] readme = "README.md" requires-python = "~=3.13.7" @@ -19,6 +20,10 @@ requires-python = "~=3.13.7" dev = [ "homeassistant~=2025.8", "pylint~=3.3", + "pytest>=8.4.2", + "pytest-asyncio>=1.2.0", + "pytest-cov>=7.0.0", + "pytest-homeassistant-custom-component>=0.13.286", "ruff~=0.12", ] @@ -34,3 +39,8 @@ allow-direct-references = true [tool.hatch.build.targets.wheel] packages = ["custom_components/auth_oidc"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +addopts = "--cov=custom_components --cov-fail-under=0" +log_level = "DEBUG" diff --git a/scripts/coverage-report b/scripts/coverage-report new file mode 100755 index 0000000..4b3abf5 --- /dev/null +++ b/scripts/coverage-report @@ -0,0 +1,3 @@ +#! /bin/bash +uv run pytest --cov-report html tests/ +uv run python -m http.server 8000 -d htmlcov \ No newline at end of file diff --git a/scripts/test b/scripts/test new file mode 100755 index 0000000..c6c38a1 --- /dev/null +++ b/scripts/test @@ -0,0 +1,2 @@ +#! /bin/bash +uv run pytest --cov-report term:skip-covered tests/ \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..83add63 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +"""Fixtures for testing.""" + +import pytest + + +@pytest.fixture(autouse=True) +def auto_enable_custom_integrations(enable_custom_integrations): + yield diff --git a/tests/mocks/__init__.py b/tests/mocks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/mocks/auth_page.html b/tests/mocks/auth_page.html new file mode 100644 index 0000000..758e37d --- /dev/null +++ b/tests/mocks/auth_page.html @@ -0,0 +1,14 @@ + + + + + + + Test + + + + Test page + + + \ No newline at end of file diff --git a/tests/mocks/oidc_server.py b/tests/mocks/oidc_server.py new file mode 100644 index 0000000..6558ddb --- /dev/null +++ b/tests/mocks/oidc_server.py @@ -0,0 +1,197 @@ +"""A simple mock OIDC server for testing purposes.""" + +from contextlib import contextmanager +import time +import logging +import hashlib +import random +import json +import os +from unittest.mock import AsyncMock, patch +from urllib.parse import urlparse, parse_qs +from joserfc import jwt +from joserfc.jwk import RSAKey, KeySet + +_LOGGER = logging.getLogger(__name__) + +BASE_URL = "https://oidc.example.com" +SUBJECT = "testuser" + + +class MockOIDCServer: + """A simple mock OIDC server for testing purposes.""" + + _code_storage = {} + _scenario = {} + + def __init__(self, scenario: str | None = None): + """Initialize the mock OIDC server.""" + # Create a JWK private key + self._jwk = RSAKey.generate_key( + 2048, {"alg": "RS256", "use": "sig"}, private=True, auto_kid=True + ) + + if scenario: + # Load scenario JSON file from disk + scenario_path = os.path.join( + os.path.dirname(__file__), "scenarios", f"{scenario}.json" + ) + with open(scenario_path, "r", encoding="utf-8") as f: + self._scenario = json.load(f) + + # Log it + _LOGGER.debug("Loaded scenario: %s", self._scenario) + + def get_random_code(self): + """Return a random authorization code.""" + return "".join(str(random.randint(0, 9)) for _ in range(6)) + + @staticmethod + def get_discovery_url(): + """Return the discovery URL for the given base URL.""" + return f"{BASE_URL}/.well-known/openid-configuration" + + @staticmethod + def get_authorize_url(): + """Return the authorization URL for the given base URL.""" + return f"{BASE_URL}/authorize" + + def process_request(self, url: str, method: str, body: dict) -> tuple[dict, int]: + """Process a request to the mock OIDC server.""" + _LOGGER.debug("Received %s request to %s in OIDC mock server", method, url) + + if url == self.get_discovery_url() and method == "GET": + response = self._get_discovery_document() + elif url.startswith(self.get_authorize_url()) and method == "GET": + response = self._get_authorize_response(url) + elif url == f"{BASE_URL}/token" and method == "POST": + response = self._get_token_response(body) + elif url == f"{BASE_URL}/jwks" and method == "GET": + response = self._get_jwks_response() + else: + response = {"error": "Unknown endpoint"}, 404 + + _LOGGER.debug("Responding with: %s", response) + return response + + def _get_discovery_document(self) -> tuple[dict, int]: + """Return a mock discovery document.""" + + if "discovery" in self._scenario: + return self._scenario["discovery"], 200 + + return { + "issuer": BASE_URL, + "authorization_endpoint": self.get_authorize_url(), + "token_endpoint": f"{BASE_URL}/token", + "userinfo_endpoint": f"{BASE_URL}/userinfo", + "jwks_uri": f"{BASE_URL}/jwks", + "id_token_signing_alg_values_supported": ["RS256"], + }, 200 + + def _get_authorize_response(self, url: str) -> tuple[dict, int]: + """Return a mock authorization response.""" + # Parse the url + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + + code = self.get_random_code() + self._code_storage[code] = query_params + + return {"code": code, "state": "xyz"}, 200 + + def _get_token_response(self, body: dict) -> tuple[dict, int]: + """Return a mock token response.""" + + if body.get("code") in self._code_storage: + # TODO: Verify PKCE? + return { + "access_token": "exampleAccessToken", + "token_type": "Bearer", + "expires_in": 3600, + "id_token": self._create_id_token(body.get("code")), + }, 200 + else: + return {"error": "invalid_request"}, 400 + + def _create_id_token(self, code: str) -> str: + """Create a mock ID token.""" + # Get the query params + if code not in self._code_storage: + raise ValueError("Invalid code") + query_params = self._code_storage[code] + _LOGGER.debug("Creating ID token with query params: %s", query_params) + + # Get username + if "username" in self._scenario: + username = self._scenario["username"] + else: + username = "testuser" + + # Create a simple signed JWT with our JWK + header = {"alg": self._jwk.alg, "kid": self._jwk.kid} + claims = { + "iss": BASE_URL, + "sub": SUBJECT, + "aud": query_params.get("client_id", [""])[0], + "nonce": query_params.get("nonce", [""])[0], + "name": "Test Name", + "preferred_username": username, + } + + now = int(time.time()) + claims["nbf"] = now + claims["iat"] = now + claims["exp"] = now + 3600 # 1 hour expiry + + return jwt.encode(header, claims, self._jwk) + + def _get_jwks_response(self) -> tuple[dict, int]: + """Return a mock JWKS response.""" + private_key = self._jwk + public_key_dict = private_key.as_dict(private=False) + public_key = RSAKey.import_key( + public_key_dict, {"use": "sig", "alg": "RS256", "kid": private_key.kid} + ) + + key_set = KeySet([public_key]) + + return key_set.as_dict(), 200 + + @staticmethod + def get_final_subject(): + """Return the subject that's returned to HA.""" + return hashlib.sha256(f"{BASE_URL}.{SUBJECT}".encode("utf-8")).hexdigest() + + +@contextmanager +def mock_oidc_responses(scenario: str | None = None): + """Mock OIDC responses for testing.""" + + mock_oidc_server = MockOIDCServer(scenario) + + def make_mock_response(json_data, status): + mock_response = AsyncMock() + mock_response.__aenter__.return_value = mock_response + mock_response.__aexit__.return_value = None + mock_response.json = AsyncMock(return_value=json_data) + mock_response.status = status + return mock_response + + def default_handler(method, url, *args, **kwargs): + _LOGGER.debug("Mocked %s request to %s", method, url) + body = kwargs.get("data") or kwargs.get("json") or None + response = mock_oidc_server.process_request(url, method, body) + return make_mock_response(response[0], response[1]) + + def get_side_effect(url, *args, **kwargs): + return default_handler("GET", url, *args, **kwargs) + + def post_side_effect(url, *args, **kwargs): + return default_handler("POST", url, *args, **kwargs) + + with ( + patch("aiohttp.ClientSession.get", side_effect=get_side_effect) as get_patch, + patch("aiohttp.ClientSession.post", side_effect=post_side_effect) as post_patch, + ): + yield (get_patch, post_patch, default_handler) diff --git a/tests/mocks/scenarios/empty.json b/tests/mocks/scenarios/empty.json new file mode 100644 index 0000000..4bc4cb7 --- /dev/null +++ b/tests/mocks/scenarios/empty.json @@ -0,0 +1,5 @@ +{ + "discovery": { + + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/invalid_code_challenge_types.json b/tests/mocks/scenarios/invalid_code_challenge_types.json new file mode 100644 index 0000000..8d516e4 --- /dev/null +++ b/tests/mocks/scenarios/invalid_code_challenge_types.json @@ -0,0 +1,10 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize", + "token_endpoint": "https://mock-oidc-server.local/token", + "jwks_uri": "https://mock-oidc-server.local/jwks", + "response_types_supported": ["code"], + "code_challenge_methods_supported": ["plain"] + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/invalid_grant_types.json b/tests/mocks/scenarios/invalid_grant_types.json new file mode 100644 index 0000000..565172a --- /dev/null +++ b/tests/mocks/scenarios/invalid_grant_types.json @@ -0,0 +1,10 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize", + "token_endpoint": "https://mock-oidc-server.local/token", + "jwks_uri": "https://mock-oidc-server.local/jwks", + "response_types_supported": ["code"], + "grant_types_supported": ["refresh_token"] + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/invalid_id_token_signing_alg.json b/tests/mocks/scenarios/invalid_id_token_signing_alg.json new file mode 100644 index 0000000..fe8b5d4 --- /dev/null +++ b/tests/mocks/scenarios/invalid_id_token_signing_alg.json @@ -0,0 +1,8 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize", + "token_endpoint": "https://mock-oidc-server.local/token", + "jwks_uri": "https://mock-oidc-server.local/jwks" + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/invalid_response_modes.json b/tests/mocks/scenarios/invalid_response_modes.json new file mode 100644 index 0000000..8cd1b46 --- /dev/null +++ b/tests/mocks/scenarios/invalid_response_modes.json @@ -0,0 +1,9 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize", + "token_endpoint": "https://mock-oidc-server.local/token", + "jwks_uri": "https://mock-oidc-server.local/jwks", + "response_modes_supported": ["post"] + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/invalid_response_types.json b/tests/mocks/scenarios/invalid_response_types.json new file mode 100644 index 0000000..7b8cc39 --- /dev/null +++ b/tests/mocks/scenarios/invalid_response_types.json @@ -0,0 +1,9 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize", + "token_endpoint": "https://mock-oidc-server.local/token", + "jwks_uri": "https://mock-oidc-server.local/jwks", + "response_types_supported": ["token"] + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/invalid_url.json b/tests/mocks/scenarios/invalid_url.json new file mode 100644 index 0000000..0c7df6d --- /dev/null +++ b/tests/mocks/scenarios/invalid_url.json @@ -0,0 +1,8 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize", + "token_endpoint": "https://mock-oidc-server.local/token", + "jwks_uri": "/jwks" + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/missing_jwks.json b/tests/mocks/scenarios/missing_jwks.json new file mode 100644 index 0000000..06c5bfa --- /dev/null +++ b/tests/mocks/scenarios/missing_jwks.json @@ -0,0 +1,7 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize", + "token_endpoint": "https://mock-oidc-server.local/token" + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/missing_token.json b/tests/mocks/scenarios/missing_token.json new file mode 100644 index 0000000..a3c3ad2 --- /dev/null +++ b/tests/mocks/scenarios/missing_token.json @@ -0,0 +1,6 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize" + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/only_issuer.json b/tests/mocks/scenarios/only_issuer.json new file mode 100644 index 0000000..a0a0390 --- /dev/null +++ b/tests/mocks/scenarios/only_issuer.json @@ -0,0 +1,5 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local" + } +} \ No newline at end of file diff --git a/tests/mocks/scenarios/username.json b/tests/mocks/scenarios/username.json new file mode 100644 index 0000000..de8f9e7 --- /dev/null +++ b/tests/mocks/scenarios/username.json @@ -0,0 +1,3 @@ +{ + "username": "foobar" +} \ No newline at end of file diff --git a/tests/mocks/scenarios/wrong_id_token_signing_alg.json b/tests/mocks/scenarios/wrong_id_token_signing_alg.json new file mode 100644 index 0000000..1da8b18 --- /dev/null +++ b/tests/mocks/scenarios/wrong_id_token_signing_alg.json @@ -0,0 +1,9 @@ +{ + "discovery": { + "issuer": "https://mock-oidc-server.local", + "authorization_endpoint": "https://mock-oidc-server.local/authorize", + "token_endpoint": "https://mock-oidc-server.local/token", + "jwks_uri": "https://mock-oidc-server.local/jwks", + "id_token_signing_alg_values_supported": ["HS256"] + } +} \ No newline at end of file diff --git a/tests/resources/fake_templates/folder.html/empty.txt b/tests/resources/fake_templates/folder.html/empty.txt new file mode 100644 index 0000000..e69de29 diff --git a/tests/resources/fake_templates/index.html b/tests/resources/fake_templates/index.html new file mode 100644 index 0000000..b4b912c --- /dev/null +++ b/tests/resources/fake_templates/index.html @@ -0,0 +1 @@ +

Example template

\ No newline at end of file diff --git a/tests/test_code_store.py b/tests/test_code_store.py new file mode 100644 index 0000000..055cdbc --- /dev/null +++ b/tests/test_code_store.py @@ -0,0 +1,90 @@ +"""Tests for the code store""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, patch +from homeassistant.core import HomeAssistant + +import pytest + +from auth_oidc.stores.code_store import CodeStore + + +@pytest.mark.asyncio +async def test_code_store_generate_and_receive_code(hass: HomeAssistant): + """Test generating and receiving a code.""" + store_mock = AsyncMock() + with patch("homeassistant.helpers.storage.Store", return_value=store_mock): + code_store = CodeStore(hass) + + # Simulate loading with empty data + store_mock.async_load.return_value = {} + await code_store.async_load() + assert code_store.get_data() == {} + + user_info = {"sub": "user1", "name": "Test User"} + code = await code_store.async_generate_code_for_userinfo(user_info) + assert code in code_store.get_data() + + # Should return user_info and remove the code + with patch("custom_components.auth_oidc.stores.code_store.datetime") as dt_mock: + dt_mock.utcnow.return_value = datetime.now(timezone.utc) + dt_mock.fromisoformat.side_effect = datetime.fromisoformat + result = await code_store.receive_userinfo_for_code(code) + assert result == user_info + assert code not in code_store.get_data() + + +@pytest.mark.asyncio +async def test_code_store_expired_code(hass): + """Test that expired codes return None.""" + store_mock = AsyncMock() + with patch("homeassistant.helpers.storage.Store", return_value=store_mock): + code_store = CodeStore(hass) + store_mock.async_load.return_value = {} + await code_store.async_load() + assert code_store.get_data() == {} + + user_info = {"sub": "user2", "name": "Expired User"} + code = await code_store.async_generate_code_for_userinfo(user_info) + + # Patch expiration to be in the past + code_store.get_data()[code]["expiration"] = ( + datetime.now(timezone.utc) - timedelta(minutes=10) + ).isoformat() + + with patch("custom_components.auth_oidc.stores.code_store.datetime") as dt_mock: + dt_mock.utcnow.return_value = datetime.now(timezone.utc) + dt_mock.fromisoformat.side_effect = datetime.fromisoformat + result = await code_store.receive_userinfo_for_code(code) + assert result is None + assert code not in code_store.get_data() + + +@pytest.mark.asyncio +async def test_code_store_data_not_loaded(hass): + """Test that using the store before loading raises RuntimeError.""" + store_mock = AsyncMock() + with patch("homeassistant.helpers.storage.Store", return_value=store_mock): + code_store = CodeStore(hass) + + # Data is not loaded yet, should result in RuntimeError + + with pytest.raises(RuntimeError): + await code_store.async_generate_code_for_userinfo({"sub": "user3"}) + with pytest.raises(RuntimeError): + await code_store.receive_userinfo_for_code("123456") + + +@pytest.mark.asyncio +async def test_code_store_generate_code_length(hass): + """Test that generated codes are 6 digits.""" + store_mock = AsyncMock() + with patch("homeassistant.helpers.storage.Store", return_value=store_mock): + code_store = CodeStore(hass) + store_mock.async_load.return_value = {} + await code_store.async_load() + assert code_store.get_data() == {} + user_info = {"sub": "user4"} + code = await code_store.async_generate_code_for_userinfo(user_info) + assert len(code) == 6 + assert code.isdigit() diff --git a/tests/test_hass_auth_provider.py b/tests/test_hass_auth_provider.py new file mode 100644 index 0000000..a5342f8 --- /dev/null +++ b/tests/test_hass_auth_provider.py @@ -0,0 +1,229 @@ +"""Tests for the Auth Provider registration in HA""" + +from urllib.parse import urlparse, parse_qs +import pytest + +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType +from homeassistant.setup import async_setup_component +from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.components.person import DOMAIN as PERSON_DOMAIN + +from custom_components.auth_oidc import DOMAIN +from custom_components.auth_oidc.config.const import ( + DISCOVERY_URL, + CLIENT_ID, + FEATURES, + FEATURES_AUTOMATIC_PERSON_CREATION, + FEATURES_AUTOMATIC_USER_LINKING, +) +from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses + + +async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool: + """Set up the auth_oidc component.""" + result = await async_setup_component(hass, DOMAIN, {DOMAIN: config}) + + if expect_success: + assert result + assert DOMAIN in hass.data + + +@pytest.mark.asyncio +async def test_setup_success_auth_provider_registration(hass: HomeAssistant): + """Test successful setup""" + await setup( + hass, + { + CLIENT_ID: "dummy", + DISCOVERY_URL: "https://example.com/.well-known/openid-configuration", + }, + True, + ) + + # Ensure the auth provider is registered + auth_providers = hass.auth.get_auth_providers(DOMAIN) + assert len(auth_providers) == 1 + + +async def login_user(hass: HomeAssistant, code: str): + """Helper to login a user.""" + + provider = hass.auth.get_auth_providers(DOMAIN)[0] + flow = await provider.async_login_flow({}) + + result = await flow.async_step_init({"code": code}) + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["data"] is not None + + data = result["data"] + sub = data["sub"] + assert sub == MockOIDCServer.get_final_subject() + + # Get credentials + credentials = await provider.async_get_or_create_credentials(data) + assert credentials is not None + assert credentials.data["sub"] == sub + + user = await hass.auth.async_get_or_create_user(credentials) + assert user.is_active + return user + + +async def get_login_code(hass: HomeAssistant, hass_client): + """Helper to get a login code.""" + client = await hass_client() + resp = await client.get("/auth/oidc/redirect", allow_redirects=False) + assert resp.status == 302 + location = resp.headers["Location"] + parsed_url = urlparse(location) + query_params = parse_qs(parsed_url.query) + state = query_params["state"][0] + + session = async_get_clientsession(hass) + resp = session.get(location, allow_redirects=False) + assert resp.status == 200 + + json_parsed = await resp.json() + assert "code" in json_parsed and json_parsed["code"] + + code = json_parsed["code"] + client = await hass_client() + resp = await client.get( + f"/auth/oidc/callback?code={code}&state={state}", allow_redirects=False + ) + + assert resp.status == 302 + location = resp.headers["Location"] + assert "/auth/oidc/finish?code=" in location + + # Get the code from the finish URL + code = location.split("code=")[1] + return code + + +@pytest.mark.asyncio +async def test_full_login(hass: HomeAssistant, hass_client): + """Test a full login flow.""" + await setup( + hass, + { + CLIENT_ID: "dummy", + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + FEATURES: { + FEATURES_AUTOMATIC_PERSON_CREATION: False, + FEATURES_AUTOMATIC_USER_LINKING: False, + }, + }, + True, + ) + + with mock_oidc_responses(): + # Actually start the login and get a code + code = await get_login_code(hass, hass_client) + + # Use the code to login directly with the registered auth provider + # Inspired by tests for the built-in providers + user = await login_user(hass, code) + assert user.name == "Test Name" + + # Login again to see if we trigger the re-use path + code2 = await get_login_code(hass, hass_client) + user2 = await login_user(hass, code2) + assert user2.id == user.id + + +@pytest.mark.asyncio +async def test_login_with_linking(hass: HomeAssistant, hass_client): + """Test a linking login.""" + await setup( + hass, + { + CLIENT_ID: "dummy", + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + FEATURES: { + FEATURES_AUTOMATIC_PERSON_CREATION: False, + FEATURES_AUTOMATIC_USER_LINKING: True, + }, + }, + True, + ) + + with mock_oidc_responses("username"): + # Create a user first with username 'foobar' + user = await hass.auth.async_create_user("Foo Bar") + assert user.is_active + + hass_provider = hass.auth.get_auth_providers("homeassistant")[0] + credential = await hass_provider.async_get_or_create_credentials( + {"username": "foobar"} + ) + await hass.auth.async_link_user(user, credential) + + # Actually start the login and get a code + code = await get_login_code(hass, hass_client) + + # Use the code to login directly with the registered auth provider + user2 = await login_user(hass, code) + assert user2.id == user.id # Assert that the user was linked + + +@pytest.mark.asyncio +async def test_login_with_person_create(hass: HomeAssistant, hass_client): + """Test a person create.""" + await setup( + hass, + { + CLIENT_ID: "dummy", + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + FEATURES: { + FEATURES_AUTOMATIC_PERSON_CREATION: True, + FEATURES_AUTOMATIC_USER_LINKING: False, + }, + }, + True, + ) + + await async_setup_component(hass, PERSON_DOMAIN, {}) + + with mock_oidc_responses(): + code = await get_login_code(hass, hass_client) + user = await login_user(hass, code) + assert user.is_active + + # Find the person associated to this user using the PersonRegistry API + person_store = hass.data[PERSON_DOMAIN][1] + persons = person_store.async_items() + assert len(persons) == 1 + + person = persons[0] + assert person["user_id"] == user.id + + +@pytest.mark.asyncio +async def test_login_shows_form(hass: HomeAssistant): + """Test a login""" + await setup( + hass, + { + CLIENT_ID: "dummy", + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + FEATURES: { + FEATURES_AUTOMATIC_PERSON_CREATION: False, + FEATURES_AUTOMATIC_USER_LINKING: False, + }, + }, + True, + ) + + provider = hass.auth.get_auth_providers(DOMAIN)[0] + flow = await provider.async_login_flow({}) + + result = await flow.async_step_init({}) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "mfa" + + # Attempt an invalid code + result = await flow.async_step_init({"code": "invalid"}) + assert result["type"] == FlowResultType.FORM + assert result["errors"] == {"base": "invalid_auth"} diff --git a/tests/test_hass_oidc_client.py b/tests/test_hass_oidc_client.py new file mode 100644 index 0000000..e5e4176 --- /dev/null +++ b/tests/test_hass_oidc_client.py @@ -0,0 +1,287 @@ +"""Tests for the OIDC client""" + +from urllib.parse import urlparse, parse_qs +import pytest +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component +from homeassistant.helpers.aiohttp_client import async_get_clientsession + +from auth_oidc import DOMAIN +from auth_oidc.tools.oidc_client import OIDCDiscoveryClient, OIDCDiscoveryInvalid +from auth_oidc.config.const import ( + DISCOVERY_URL, + CLIENT_ID, +) + +from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses + +EXAMPLE_CLIENT_ID = "dummyclient" + + +async def setup(hass: HomeAssistant): + """Set up the integration within Home Assistant""" + mock_config = { + DOMAIN: { + CLIENT_ID: EXAMPLE_CLIENT_ID, + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + } + } + + result = await async_setup_component(hass, DOMAIN, mock_config) + assert result + + +@pytest.mark.asyncio +async def test_full_oidc_flow(hass: HomeAssistant, hass_client): + """Test that one full OIDC flow works if OIDC is mocked.""" + + await setup(hass) + + with mock_oidc_responses(): + # Start by going to /auth/oidc/redirect + client = await hass_client() + resp = await client.get("/auth/oidc/redirect", allow_redirects=False) + assert resp.status == 302 + assert resp.headers["Location"].startswith(MockOIDCServer.get_authorize_url()) + + # Parse the location header and test the query params for correctness + location = resp.headers["Location"] + parsed_url = urlparse(location) + query_params = parse_qs(parsed_url.query) + + assert "response_type" in query_params and query_params.get( + "response_type" + ) == ["code"] + assert "client_id" in query_params and query_params.get("client_id") == [ + EXAMPLE_CLIENT_ID + ] + assert "scope" in query_params and query_params.get("scope") == [ + "openid profile groups" + ] + assert "state" in query_params and query_params["state"] + state = query_params["state"][0] + assert len(state) >= 16 # Ensure state is sufficiently long + assert ( + "redirect_uri" in query_params + and query_params["redirect_uri"] + and query_params["redirect_uri"][0].endswith("/auth/oidc/callback") + ) # TODO: Also test that the URL itself is correct + assert "nonce" in query_params and query_params["nonce"] + assert "code_challenge_method" in query_params and query_params.get( + "code_challenge_method" + ) == ["S256"] + assert "code_challenge" in query_params and query_params["code_challenge"] + + session = async_get_clientsession(hass) + resp = session.get(location, allow_redirects=False) + assert resp.status == 200 + + json_parsed = await resp.json() + assert "code" in json_parsed and json_parsed["code"] + + # Now go back to the callback with a sample code + code = json_parsed["code"] + client = await hass_client() + resp = await client.get( + f"/auth/oidc/callback?code={code}&state={state}", allow_redirects=False + ) + + # TODO: Test if logged text contains our login + # TODO: Test if the code actually works + assert resp.status == 302 + assert "/auth/oidc/finish?code=" in resp.headers["Location"] + + +async def discovery_test_through_redirect( + hass_client, caplog, scenario: str, match_log_line: str +): + """Test that discovery document retrieval fails gracefully through redirect endpoint.""" + with mock_oidc_responses(scenario): + # Start by going to /auth/oidc/redirect + client = await hass_client() + resp = await client.get("/auth/oidc/redirect", allow_redirects=False) + + # Find matching log line + assert match_log_line in caplog.text + + # Assert that we get a 200 response with an error message + assert resp.status == 200 + text = await resp.text() + assert "Integration is misconfigured, discovery could not be obtained." in text + + +async def direct_discovery_test( + hass: HomeAssistant, + scenario: str, + match_type: str, + match_log_line: str | None = None, +): + """Test that discovery document retrieval fails with nice error directly.""" + with mock_oidc_responses(scenario): + session = async_get_clientsession(hass) + client = OIDCDiscoveryClient( + MockOIDCServer.get_discovery_url(), + session, + { + "id_token_signing_alg": "RS256", + }, + ) + + with pytest.raises(OIDCDiscoveryInvalid) as exc_info: + await client.fetch_discovery_document() + + assert exc_info.value.type == match_type + assert exc_info.value.get_detail_string().startswith("type: " + match_type) + + if match_log_line: + assert match_log_line in exc_info.value.get_detail_string() + + +@pytest.mark.asyncio +async def test_discovery_failures(hass: HomeAssistant, hass_client, caplog): + """Test that discovery document retrieval fails gracefully.""" + + await setup(hass) + + # Empty scenario + await discovery_test_through_redirect( + hass_client, caplog, "empty", "is missing required endpoint: issuer" + ) + await direct_discovery_test(hass, "empty", "missing_endpoint", "endpoint: issuer") + + # Missing authorization_endpoint + await discovery_test_through_redirect( + hass_client, + caplog, + "only_issuer", + "is missing required endpoint: authorization_endpoint", + ) + await direct_discovery_test( + hass, "only_issuer", "missing_endpoint", "endpoint: authorization_endpoint" + ) + + # Missing token_endpoint + await discovery_test_through_redirect( + hass_client, + caplog, + "missing_token", + "is missing required endpoint: token_endpoint", + ) + await direct_discovery_test( + hass, "missing_token", "missing_endpoint", "endpoint: token_endpoint" + ) + + # Missing jwks_uri + await discovery_test_through_redirect( + hass_client, + caplog, + "missing_jwks", + "is missing required endpoint: jwks_uri", + ) + await direct_discovery_test( + hass, "missing_jwks", "missing_endpoint", "endpoint: jwks_uri" + ) + + # Invalid response_modes_supported + await discovery_test_through_redirect( + hass_client, + caplog, + "invalid_response_modes", + "does not support required 'query' response mode, only supports: ['post']", + ) + await direct_discovery_test( + hass, "invalid_response_modes", "does_not_support_response_mode", "post" + ) + + # Invalid grant_types supported + await discovery_test_through_redirect( + hass_client, + caplog, + "invalid_grant_types", + "does not support required 'authorization_code' grant type, only supports: ['refresh_token']", + ) + await direct_discovery_test( + hass, "invalid_grant_types", "does_not_support_grant_type", "refresh_token" + ) + + # Invalid response types + await discovery_test_through_redirect( + hass_client, + caplog, + "invalid_response_types", + "does not support required 'code' response type, only supports: ['token']", + ) + await direct_discovery_test( + hass, "invalid_response_types", "does_not_support_response_type", "token" + ) + + # Invalid code_challenge types + await discovery_test_through_redirect( + hass_client, + caplog, + "invalid_code_challenge_types", + "does not support required 'S256' code challenge method, only supports: ['plain']", + ) + await direct_discovery_test( + hass, + "invalid_code_challenge_types", + "does_not_support_required_code_challenge_method", + "plain", + ) + + # Invalid id_token_signing alg + await discovery_test_through_redirect( + hass_client, + caplog, + "invalid_id_token_signing_alg", + "does not have 'id_token_signing_alg_values_supported' field", + ) + await direct_discovery_test( + hass, "invalid_id_token_signing_alg", "missing_id_token_signing_alg_values" + ) + + # Not matching id_token_signing alg + await discovery_test_through_redirect( + hass_client, + caplog, + "wrong_id_token_signing_alg", + "does not support requested id_token_signing_alg 'RS256', only supports: ['HS256']", + ) + await direct_discovery_test( + hass, + "wrong_id_token_signing_alg", + "does_not_support_id_token_signing_alg", + "requested: RS256, supported: ['HS256']", + ) + + # Invalid URL + await discovery_test_through_redirect( + hass_client, + caplog, + "invalid_url", + "has invalid URL in endpoint: jwks_uri (/jwks)", + ) + await direct_discovery_test( + hass, + "invalid_url", + "invalid_endpoint", + "endpoint: jwks_uri, url: /jwks", + ) + + +@pytest.mark.asyncio +async def test_direct_jwks_fetch(hass: HomeAssistant): + """Test direct fetch of JWKS.""" + with mock_oidc_responses(): + session = async_get_clientsession(hass) + client = OIDCDiscoveryClient( + MockOIDCServer.get_discovery_url(), + session, + { + "id_token_signing_alg": "RS256", + }, + ) + + await client.fetch_discovery_document() + jwks = await client.fetch_jwks() + assert "keys" in jwks diff --git a/tests/test_hass_ui_config_flow.py b/tests/test_hass_ui_config_flow.py new file mode 100644 index 0000000..e49c998 --- /dev/null +++ b/tests/test_hass_ui_config_flow.py @@ -0,0 +1,364 @@ +"""Tests for the UI config flow""" + +import pytest + +from homeassistant import config_entries +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType + + +from custom_components.auth_oidc import DOMAIN, async_setup_entry +from custom_components.auth_oidc.config.const import ( + OIDC_PROVIDERS, + CLIENT_ID, + CLIENT_SECRET, + DISCOVERY_URL, + DISPLAY_NAME, + FEATURES, + FEATURES_AUTOMATIC_USER_LINKING, + FEATURES_AUTOMATIC_PERSON_CREATION, + FEATURES_INCLUDE_GROUPS_SCOPE, + CLAIMS, + CLAIMS_DISPLAY_NAME, + CLAIMS_GROUPS, + CLAIMS_USERNAME, + ROLES, + ROLE_ADMINS, + ROLE_USERS, +) + +from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses + +DEMO_CLIENT_ID = "testing_example_client_id" +DEMO_CLIENT_SECRET = "faz" +DEMO_ADMIN_ROLE = "boo" +DEMO_USER_ROLE = "far" + + +@pytest.mark.asyncio +async def test_full_config_flow_success(hass: HomeAssistant): + """Test a successful full config flow.""" + + with mock_oidc_responses(): + # 1. Start the user step + # This simulates clicking "Add Integration" in the UI. + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + # Assert that it's a form and expects user input for the 'user' step + # 'user' is always the first step if it is user triggered + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "user" + assert result["data_schema"] is not None + schema = result["data_schema"] + # Extract the schema dict from voluptuous Schema + schema_dict = schema.schema + # Assert 'provider' is a key in the schema + assert "provider" in schema_dict + # Assert 'authentik' is one of the allowed values for 'provider' + provider_field = schema_dict["provider"] + # If provider_field is a voluptuous In validator, get its container + allowed_providers = getattr(provider_field, "container", None) + assert "authentik" in OIDC_PROVIDERS + assert allowed_providers is not None and "authentik" in allowed_providers + + assert result["errors"] == {} + + # 2. Submit user input for the 'user' step + # This simulates the user filling out host/port + user_input_step_user = {"provider": "authentik"} + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input_step_user + ) + + # Assert that it proceeds to the 'auth' step + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "discovery_url" + assert result["data_schema"] is not None + assert result["errors"] == {} + + # Fill in the discovery URL + user_input_step_discovery = { + "discovery_url": MockOIDCServer.get_discovery_url() + } + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input_step_discovery + ) + + # Assert that it proceeds to the 'credentials' step + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "validate_connection" + + # Assert that it validates correctly with our mock + assert result["errors"] == {} + + # Send in continue + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {"action": "continue"} + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "client_config" + assert result["data_schema"] is not None + assert result["errors"] == {} + + # Fill in the client config + user_input_step_client_config = { + "client_id": DEMO_CLIENT_ID, + "client_secret": DEMO_CLIENT_SECRET, + } + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input_step_client_config + ) + + # Assert that we are at groups_config + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "groups_config" + assert result["data_schema"] is not None + assert result["errors"] == {} + + # Fill in the groups config + user_input_step_groups_config = { + "admin_group": DEMO_ADMIN_ROLE, + "user_group": DEMO_USER_ROLE, + } + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input_step_groups_config + ) + + # Assert that were are at user_linking config + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "user_linking" + assert result["data_schema"] is not None + assert result["errors"] == {} + + # Fill in the user linking config + user_input_step_user_linking = {"enable_user_linking": False} + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input_step_user_linking + ) + + # Finally, assert that the flow is complete and a config entry is created + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == OIDC_PROVIDERS["authentik"]["name"] + + expected_data = { + "provider": "authentik", + CLIENT_ID: DEMO_CLIENT_ID, + CLIENT_SECRET: DEMO_CLIENT_SECRET, + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"], + FEATURES: { + FEATURES_AUTOMATIC_USER_LINKING: False, + FEATURES_AUTOMATIC_PERSON_CREATION: True, + FEATURES_INCLUDE_GROUPS_SCOPE: True, + }, + CLAIMS: { + CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"][ + "display_name" + ], + CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"], + CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"], + }, + ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE}, + } + + assert result["data"] == expected_data + + # Verify that the config entry was loaded into Home Assistant + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + assert entries[0].data == expected_data + + # You can also assert that `async_setup_entry` was called for this entry + # (assuming it's mocked or you let it run if it's simple) + # The PHCC `hass` fixture automatically mocks `async_setup_entry` + # and `async_unload_entry` for you, making it easy to test that they're called. + assert await async_setup_entry(hass, entries[0]) is True + + +@pytest.mark.asyncio +async def test_options_flow_success(hass: HomeAssistant): + """Test a successful options flow.""" + + # First, set up an initial config entry as in the full config flow + initial_data = { + "provider": "authentik", + CLIENT_ID: DEMO_CLIENT_ID, + CLIENT_SECRET: DEMO_CLIENT_SECRET, + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"], + FEATURES: { + FEATURES_AUTOMATIC_USER_LINKING: False, + FEATURES_AUTOMATIC_PERSON_CREATION: True, + FEATURES_INCLUDE_GROUPS_SCOPE: True, + }, + CLAIMS: { + CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"], + CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"], + CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"], + }, + ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE}, + } + + entry = config_entries.ConfigEntry( + version=1, + minor_version=0, + domain=DOMAIN, + title=OIDC_PROVIDERS["authentik"]["name"], + data=initial_data, + source=config_entries.SOURCE_USER, + entry_id="1", + unique_id="test_unique_id", + options={}, + pref_disable_new_entities=False, + pref_disable_polling=False, + discovery_keys=None, + subentries_data=None, + ) + + await hass.config_entries.async_add(entry) + + # Start the reconfigure flow + result = await hass.config_entries.options.async_init(entry.entry_id) + + # Should start the options flow + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "init" + assert result["data_schema"] is not None + + # Assert that the schema is as expected + # Schema contains enable_user_linking, enable_groups, admin_group & user_groups and no other keys + schema = result["data_schema"] + schema_dict = schema.schema + # Assert that the schema contains the expected keys + expected_keys = { + "admin_group", + "enable_user_linking", + "enable_groups", + "user_group", + } + assert set(schema_dict.keys()) == expected_keys + + # Change the client_id and client_secret + new_enable_linking = True + new_enable_groups = True + new_admin_group = "bazzbbb" + new_user_group = "foobar" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + { + "enable_user_linking": new_enable_linking, + "enable_groups": new_enable_groups, + "admin_group": new_admin_group, + "user_group": new_user_group, + }, + ) + + # Should finish and update the entry options + assert result["type"] == FlowResultType.CREATE_ENTRY + + # Optionally, check that the entry options are updated + updated_entry = hass.config_entries.async_get_entry(entry.entry_id) + assert updated_entry is not None + + # Verify that the config entry was loaded into Home Assistant + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + + assert ( + entries[0].data[FEATURES][FEATURES_AUTOMATIC_USER_LINKING] == new_enable_linking + ) + assert entries[0].data[FEATURES][FEATURES_INCLUDE_GROUPS_SCOPE] == new_enable_groups + assert entries[0].data[ROLES][ROLE_ADMINS] == new_admin_group + assert entries[0].data[ROLES][ROLE_USERS] == new_user_group + + +@pytest.mark.asyncio +async def test_reconfigure_flow_success(hass: HomeAssistant): + """Test a successful reconfigure flow.""" + + # First, set up an initial config entry as in the full config flow + initial_data = { + "provider": "authentik", + CLIENT_ID: DEMO_CLIENT_ID, + CLIENT_SECRET: DEMO_CLIENT_SECRET, + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"], + FEATURES: { + FEATURES_AUTOMATIC_USER_LINKING: False, + FEATURES_AUTOMATIC_PERSON_CREATION: True, + FEATURES_INCLUDE_GROUPS_SCOPE: True, + }, + CLAIMS: { + CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"], + CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"], + CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"], + }, + ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE}, + } + + entry = config_entries.ConfigEntry( + version=1, + minor_version=0, + domain=DOMAIN, + title=OIDC_PROVIDERS["authentik"]["name"], + data=initial_data, + source=config_entries.SOURCE_USER, + entry_id="1", + unique_id="test_unique_id", + options={}, + pref_disable_new_entities=False, + pref_disable_polling=False, + discovery_keys=None, + subentries_data=None, + ) + + await hass.config_entries.async_add(entry) + + # Start async_step_reconfigure to reconfigure the entry + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={ + "source": config_entries.SOURCE_RECONFIGURE, + "entry_id": entry.entry_id, + }, + ) + + # Should start the reconfigure flow + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "reconfigure" + assert result["data_schema"] is not None + + # Assert that the schema is client_id & client_secret + schema = result["data_schema"] + schema_dict = schema.schema + # Assert that the schema contains the expected keys + expected_keys = { + "client_id", + "client_secret", + } + assert set(schema_dict.keys()) == expected_keys + + # Change the client_id and client_secret + new_client_id = "newclientid" + new_client_secret = "newclientsecret" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "client_id": new_client_id, + "client_secret": new_client_secret, + }, + ) + + # Should finish and update the entry data + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reconfigure_successful" + + # Verify that the config entry was loaded into Home Assistant + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + assert entries[0].data[CLIENT_ID] == new_client_id + assert entries[0].data[CLIENT_SECRET] == new_client_secret diff --git a/tests/test_hass_webserver.py b/tests/test_hass_webserver.py new file mode 100644 index 0000000..67320d6 --- /dev/null +++ b/tests/test_hass_webserver.py @@ -0,0 +1,151 @@ +"""Tests for the registered webpages""" + +import os +from auth_oidc.config.const import ( + DISCOVERY_URL, + CLIENT_ID, + FEATURES, + FEATURES_DISABLE_FRONTEND_INJECTION, +) +import pytest + +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component +from homeassistant.components.http import StaticPathConfig, DOMAIN as HTTP_DOMAIN + +from custom_components.auth_oidc import DOMAIN + + +async def setup(hass: HomeAssistant, enable_frontend_changes: bool = None): + mock_config = { + DOMAIN: { + CLIENT_ID: "dummy", + DISCOVERY_URL: "https://example.com/.well-known/openid-configuration", + FEATURES: { + FEATURES_DISABLE_FRONTEND_INJECTION: not enable_frontend_changes + }, + } + } + + if enable_frontend_changes is None: + del mock_config[DOMAIN][FEATURES][FEATURES_DISABLE_FRONTEND_INJECTION] + + result = await async_setup_component(hass, DOMAIN, mock_config) + assert result + + +@pytest.mark.asyncio +async def test_welcome_page_registration(hass: HomeAssistant, hass_client): + """Test that welcome page is present if frontend changes are disabled.""" + + await setup(hass, enable_frontend_changes=False) + + client = await hass_client() + resp = await client.get("/auth/oidc/welcome", allow_redirects=False) + assert resp.status == 200 + + +@pytest.mark.asyncio +async def test_welcome_page_registration_with_changes(hass: HomeAssistant, hass_client): + """Test that welcome page is redirect if frontend changes are enabled.""" + + await setup(hass, enable_frontend_changes=True) + + client = await hass_client() + resp = await client.get("/auth/oidc/welcome", allow_redirects=False) + assert resp.status == 307 + + +@pytest.mark.asyncio +async def test_redirect_page_registration(hass: HomeAssistant, hass_client): + """Test that redirect page shows OIDC misconfiguration error if OIDC server is not reachable.""" + + await setup(hass) + + client = await hass_client() + resp = await client.get("/auth/oidc/redirect", allow_redirects=False) + assert resp.status == 200 + text = await resp.text() + assert "Integration is misconfigured" in text + + resp2 = await client.post("/auth/oidc/redirect", allow_redirects=False) + assert resp2.status == 200 + + +@pytest.mark.asyncio +async def test_callback_registration(hass: HomeAssistant, hass_client): + """Test that callback page is reachable.""" + + await setup(hass) + + client = await hass_client() + resp = await client.get("/auth/oidc/callback", allow_redirects=False) + assert resp.status == 200 + + +@pytest.mark.asyncio +async def test_finish_registration(hass: HomeAssistant, hass_client): + """Test that finish page is reachable.""" + + await setup(hass) + + client = await hass_client() + resp = await client.get("/auth/oidc/finish", allow_redirects=False) + assert resp.status == 200 + text = await resp.text() + + # Should miss the code parameter if called without it + assert "Missing code" in text + + resp2 = await client.get("/auth/oidc/finish?code=123456", allow_redirects=False) + assert resp2.status == 200 + text2 = await resp2.text() + assert "Missing code" not in text2 + assert "123456" in text2 + + +@pytest.mark.asyncio +async def test_finish_post(hass: HomeAssistant, hass_client): + """Test that finish page works with POST.""" + + await setup(hass) + client = await hass_client() + resp = await client.post("/auth/oidc/finish", data={}, allow_redirects=False) + assert resp.status == 500 + + resp2 = await client.post( + "/auth/oidc/finish", data={"code": "456888"}, allow_redirects=False + ) + assert resp2.status == 302 + assert resp2.headers["Location"] == "/?storeToken=true" + assert resp2.cookies["auth_oidc_code"].value == "456888" + + +# Test the frontend injection +@pytest.mark.asyncio +async def test_frontend_injection(hass: HomeAssistant, hass_client): + """Test that frontend injection works.""" + + # Because there is no frontend in the test setup, + # we'll have to fake /auth/authorize for the changes to register + await async_setup_component(hass, HTTP_DOMAIN, {}) + + mock_html_path = os.path.join(os.path.dirname(__file__), "mocks", "auth_page.html") + await hass.http.async_register_static_paths( + [ + StaticPathConfig( + "/auth/authorize", + mock_html_path, + cache_headers=False, + ) + ] + ) + + await setup(hass, enable_frontend_changes=True) + + client = await hass_client() + resp = await client.get("/auth/authorize", allow_redirects=False) + assert resp.status == 200 + text = await resp.text() + + assert "