Add unit tests (#133)

* Add initial test & add pipeline

* Add very basic YAML config tests

* Add coverage reporting

* Add some webserver & template loading tests

* Add test cases for the helpers

* Implement initial OIDC server tests

* Test codestore & discovery checker

* Test basics of the config flow

* Add test for the HA auth provider

* Cleaned up tests & test injection
This commit is contained in:
Christiaan Goossens
2025-10-05 21:03:02 +02:00
committed by GitHub
parent 5714e844a7
commit 404d2451df
42 changed files with 2331 additions and 91 deletions

197
tests/mocks/oidc_server.py Normal file
View File

@@ -0,0 +1,197 @@
"""A simple mock OIDC server for testing purposes."""
from contextlib import contextmanager
import time
import logging
import hashlib
import random
import json
import os
from unittest.mock import AsyncMock, patch
from urllib.parse import urlparse, parse_qs
from joserfc import jwt
from joserfc.jwk import RSAKey, KeySet
_LOGGER = logging.getLogger(__name__)
BASE_URL = "https://oidc.example.com"
SUBJECT = "testuser"
class MockOIDCServer:
"""A simple mock OIDC server for testing purposes."""
_code_storage = {}
_scenario = {}
def __init__(self, scenario: str | None = None):
"""Initialize the mock OIDC server."""
# Create a JWK private key
self._jwk = RSAKey.generate_key(
2048, {"alg": "RS256", "use": "sig"}, private=True, auto_kid=True
)
if scenario:
# Load scenario JSON file from disk
scenario_path = os.path.join(
os.path.dirname(__file__), "scenarios", f"{scenario}.json"
)
with open(scenario_path, "r", encoding="utf-8") as f:
self._scenario = json.load(f)
# Log it
_LOGGER.debug("Loaded scenario: %s", self._scenario)
def get_random_code(self):
"""Return a random authorization code."""
return "".join(str(random.randint(0, 9)) for _ in range(6))
@staticmethod
def get_discovery_url():
"""Return the discovery URL for the given base URL."""
return f"{BASE_URL}/.well-known/openid-configuration"
@staticmethod
def get_authorize_url():
"""Return the authorization URL for the given base URL."""
return f"{BASE_URL}/authorize"
def process_request(self, url: str, method: str, body: dict) -> tuple[dict, int]:
"""Process a request to the mock OIDC server."""
_LOGGER.debug("Received %s request to %s in OIDC mock server", method, url)
if url == self.get_discovery_url() and method == "GET":
response = self._get_discovery_document()
elif url.startswith(self.get_authorize_url()) and method == "GET":
response = self._get_authorize_response(url)
elif url == f"{BASE_URL}/token" and method == "POST":
response = self._get_token_response(body)
elif url == f"{BASE_URL}/jwks" and method == "GET":
response = self._get_jwks_response()
else:
response = {"error": "Unknown endpoint"}, 404
_LOGGER.debug("Responding with: %s", response)
return response
def _get_discovery_document(self) -> tuple[dict, int]:
"""Return a mock discovery document."""
if "discovery" in self._scenario:
return self._scenario["discovery"], 200
return {
"issuer": BASE_URL,
"authorization_endpoint": self.get_authorize_url(),
"token_endpoint": f"{BASE_URL}/token",
"userinfo_endpoint": f"{BASE_URL}/userinfo",
"jwks_uri": f"{BASE_URL}/jwks",
"id_token_signing_alg_values_supported": ["RS256"],
}, 200
def _get_authorize_response(self, url: str) -> tuple[dict, int]:
"""Return a mock authorization response."""
# Parse the url
parsed_url = urlparse(url)
query_params = parse_qs(parsed_url.query)
code = self.get_random_code()
self._code_storage[code] = query_params
return {"code": code, "state": "xyz"}, 200
def _get_token_response(self, body: dict) -> tuple[dict, int]:
"""Return a mock token response."""
if body.get("code") in self._code_storage:
# TODO: Verify PKCE?
return {
"access_token": "exampleAccessToken",
"token_type": "Bearer",
"expires_in": 3600,
"id_token": self._create_id_token(body.get("code")),
}, 200
else:
return {"error": "invalid_request"}, 400
def _create_id_token(self, code: str) -> str:
"""Create a mock ID token."""
# Get the query params
if code not in self._code_storage:
raise ValueError("Invalid code")
query_params = self._code_storage[code]
_LOGGER.debug("Creating ID token with query params: %s", query_params)
# Get username
if "username" in self._scenario:
username = self._scenario["username"]
else:
username = "testuser"
# Create a simple signed JWT with our JWK
header = {"alg": self._jwk.alg, "kid": self._jwk.kid}
claims = {
"iss": BASE_URL,
"sub": SUBJECT,
"aud": query_params.get("client_id", [""])[0],
"nonce": query_params.get("nonce", [""])[0],
"name": "Test Name",
"preferred_username": username,
}
now = int(time.time())
claims["nbf"] = now
claims["iat"] = now
claims["exp"] = now + 3600 # 1 hour expiry
return jwt.encode(header, claims, self._jwk)
def _get_jwks_response(self) -> tuple[dict, int]:
"""Return a mock JWKS response."""
private_key = self._jwk
public_key_dict = private_key.as_dict(private=False)
public_key = RSAKey.import_key(
public_key_dict, {"use": "sig", "alg": "RS256", "kid": private_key.kid}
)
key_set = KeySet([public_key])
return key_set.as_dict(), 200
@staticmethod
def get_final_subject():
"""Return the subject that's returned to HA."""
return hashlib.sha256(f"{BASE_URL}.{SUBJECT}".encode("utf-8")).hexdigest()
@contextmanager
def mock_oidc_responses(scenario: str | None = None):
"""Mock OIDC responses for testing."""
mock_oidc_server = MockOIDCServer(scenario)
def make_mock_response(json_data, status):
mock_response = AsyncMock()
mock_response.__aenter__.return_value = mock_response
mock_response.__aexit__.return_value = None
mock_response.json = AsyncMock(return_value=json_data)
mock_response.status = status
return mock_response
def default_handler(method, url, *args, **kwargs):
_LOGGER.debug("Mocked %s request to %s", method, url)
body = kwargs.get("data") or kwargs.get("json") or None
response = mock_oidc_server.process_request(url, method, body)
return make_mock_response(response[0], response[1])
def get_side_effect(url, *args, **kwargs):
return default_handler("GET", url, *args, **kwargs)
def post_side_effect(url, *args, **kwargs):
return default_handler("POST", url, *args, **kwargs)
with (
patch("aiohttp.ClientSession.get", side_effect=get_side_effect) as get_patch,
patch("aiohttp.ClientSession.post", side_effect=post_side_effect) as post_patch,
):
yield (get_patch, post_patch, default_handler)