* Add initial test & add pipeline * Add very basic YAML config tests * Add coverage reporting * Add some webserver & template loading tests * Add test cases for the helpers * Implement initial OIDC server tests * Test codestore & discovery checker * Test basics of the config flow * Add test for the HA auth provider * Cleaned up tests & test injection
198 lines
6.7 KiB
Python
198 lines
6.7 KiB
Python
"""A simple mock OIDC server for testing purposes."""
|
|
|
|
from contextlib import contextmanager
|
|
import time
|
|
import logging
|
|
import hashlib
|
|
import random
|
|
import json
|
|
import os
|
|
from unittest.mock import AsyncMock, patch
|
|
from urllib.parse import urlparse, parse_qs
|
|
from joserfc import jwt
|
|
from joserfc.jwk import RSAKey, KeySet
|
|
|
|
# Module-level logger; every debug message emitted by the mock goes through it.
_LOGGER = logging.getLogger(__name__)

# Base URL of the fake provider. It is used both as the prefix of every
# endpoint URL the mock answers and as the `iss` claim of issued ID tokens.
BASE_URL = "https://oidc.example.com"

# Value placed in the `sub` claim of every ID token the mock issues.
SUBJECT = "testuser"
|
|
|
|
|
|
class MockOIDCServer:
|
|
"""A simple mock OIDC server for testing purposes."""
|
|
|
|
_code_storage = {}
|
|
_scenario = {}
|
|
|
|
def __init__(self, scenario: str | None = None):
|
|
"""Initialize the mock OIDC server."""
|
|
# Create a JWK private key
|
|
self._jwk = RSAKey.generate_key(
|
|
2048, {"alg": "RS256", "use": "sig"}, private=True, auto_kid=True
|
|
)
|
|
|
|
if scenario:
|
|
# Load scenario JSON file from disk
|
|
scenario_path = os.path.join(
|
|
os.path.dirname(__file__), "scenarios", f"{scenario}.json"
|
|
)
|
|
with open(scenario_path, "r", encoding="utf-8") as f:
|
|
self._scenario = json.load(f)
|
|
|
|
# Log it
|
|
_LOGGER.debug("Loaded scenario: %s", self._scenario)
|
|
|
|
def get_random_code(self):
|
|
"""Return a random authorization code."""
|
|
return "".join(str(random.randint(0, 9)) for _ in range(6))
|
|
|
|
@staticmethod
|
|
def get_discovery_url():
|
|
"""Return the discovery URL for the given base URL."""
|
|
return f"{BASE_URL}/.well-known/openid-configuration"
|
|
|
|
@staticmethod
|
|
def get_authorize_url():
|
|
"""Return the authorization URL for the given base URL."""
|
|
return f"{BASE_URL}/authorize"
|
|
|
|
def process_request(self, url: str, method: str, body: dict) -> tuple[dict, int]:
|
|
"""Process a request to the mock OIDC server."""
|
|
_LOGGER.debug("Received %s request to %s in OIDC mock server", method, url)
|
|
|
|
if url == self.get_discovery_url() and method == "GET":
|
|
response = self._get_discovery_document()
|
|
elif url.startswith(self.get_authorize_url()) and method == "GET":
|
|
response = self._get_authorize_response(url)
|
|
elif url == f"{BASE_URL}/token" and method == "POST":
|
|
response = self._get_token_response(body)
|
|
elif url == f"{BASE_URL}/jwks" and method == "GET":
|
|
response = self._get_jwks_response()
|
|
else:
|
|
response = {"error": "Unknown endpoint"}, 404
|
|
|
|
_LOGGER.debug("Responding with: %s", response)
|
|
return response
|
|
|
|
def _get_discovery_document(self) -> tuple[dict, int]:
|
|
"""Return a mock discovery document."""
|
|
|
|
if "discovery" in self._scenario:
|
|
return self._scenario["discovery"], 200
|
|
|
|
return {
|
|
"issuer": BASE_URL,
|
|
"authorization_endpoint": self.get_authorize_url(),
|
|
"token_endpoint": f"{BASE_URL}/token",
|
|
"userinfo_endpoint": f"{BASE_URL}/userinfo",
|
|
"jwks_uri": f"{BASE_URL}/jwks",
|
|
"id_token_signing_alg_values_supported": ["RS256"],
|
|
}, 200
|
|
|
|
def _get_authorize_response(self, url: str) -> tuple[dict, int]:
|
|
"""Return a mock authorization response."""
|
|
# Parse the url
|
|
parsed_url = urlparse(url)
|
|
query_params = parse_qs(parsed_url.query)
|
|
|
|
code = self.get_random_code()
|
|
self._code_storage[code] = query_params
|
|
|
|
return {"code": code, "state": "xyz"}, 200
|
|
|
|
def _get_token_response(self, body: dict) -> tuple[dict, int]:
|
|
"""Return a mock token response."""
|
|
|
|
if body.get("code") in self._code_storage:
|
|
# TODO: Verify PKCE?
|
|
return {
|
|
"access_token": "exampleAccessToken",
|
|
"token_type": "Bearer",
|
|
"expires_in": 3600,
|
|
"id_token": self._create_id_token(body.get("code")),
|
|
}, 200
|
|
else:
|
|
return {"error": "invalid_request"}, 400
|
|
|
|
def _create_id_token(self, code: str) -> str:
|
|
"""Create a mock ID token."""
|
|
# Get the query params
|
|
if code not in self._code_storage:
|
|
raise ValueError("Invalid code")
|
|
query_params = self._code_storage[code]
|
|
_LOGGER.debug("Creating ID token with query params: %s", query_params)
|
|
|
|
# Get username
|
|
if "username" in self._scenario:
|
|
username = self._scenario["username"]
|
|
else:
|
|
username = "testuser"
|
|
|
|
# Create a simple signed JWT with our JWK
|
|
header = {"alg": self._jwk.alg, "kid": self._jwk.kid}
|
|
claims = {
|
|
"iss": BASE_URL,
|
|
"sub": SUBJECT,
|
|
"aud": query_params.get("client_id", [""])[0],
|
|
"nonce": query_params.get("nonce", [""])[0],
|
|
"name": "Test Name",
|
|
"preferred_username": username,
|
|
}
|
|
|
|
now = int(time.time())
|
|
claims["nbf"] = now
|
|
claims["iat"] = now
|
|
claims["exp"] = now + 3600 # 1 hour expiry
|
|
|
|
return jwt.encode(header, claims, self._jwk)
|
|
|
|
def _get_jwks_response(self) -> tuple[dict, int]:
|
|
"""Return a mock JWKS response."""
|
|
private_key = self._jwk
|
|
public_key_dict = private_key.as_dict(private=False)
|
|
public_key = RSAKey.import_key(
|
|
public_key_dict, {"use": "sig", "alg": "RS256", "kid": private_key.kid}
|
|
)
|
|
|
|
key_set = KeySet([public_key])
|
|
|
|
return key_set.as_dict(), 200
|
|
|
|
@staticmethod
|
|
def get_final_subject():
|
|
"""Return the subject that's returned to HA."""
|
|
return hashlib.sha256(f"{BASE_URL}.{SUBJECT}".encode("utf-8")).hexdigest()
|
|
|
|
|
|
@contextmanager
def mock_oidc_responses(scenario: str | None = None):
    """Answer aiohttp client requests from a ``MockOIDCServer``.

    While the context is active, ``aiohttp.ClientSession.get`` and
    ``aiohttp.ClientSession.post`` are patched with mocks whose side effect
    forwards the call to a ``MockOIDCServer`` created for ``scenario``.

    Args:
        scenario: Scenario name handed to ``MockOIDCServer``; ``None`` uses
            the server defaults.

    Yields:
        A ``(get_patch, post_patch, default_handler)`` tuple: the two mocks
        installed by ``patch`` and the ``(method, url, *args, **kwargs)``
        handler both of them delegate to.
    """
    server = MockOIDCServer(scenario)

    def build_response(payload, status_code):
        # The mock is its own async context manager result, so it can be
        # used as `async with session.get(...) as response`.
        fake = AsyncMock()
        fake.status = status_code
        fake.json = AsyncMock(return_value=payload)
        fake.__aenter__.return_value = fake
        fake.__aexit__.return_value = None
        return fake

    def default_handler(method, url, *args, **kwargs):
        _LOGGER.debug("Mocked %s request to %s", method, url)
        # Prefer `data`, fall back to `json`; any falsy payload becomes None.
        payload_in = kwargs.get("data")
        if not payload_in:
            payload_in = kwargs.get("json") or None
        payload_out, status_code = server.process_request(url, method, payload_in)
        return build_response(payload_out, status_code)

    # Both patched methods funnel into default_handler with a fixed verb.
    on_get = lambda url, *args, **kwargs: default_handler(  # noqa: E731
        "GET", url, *args, **kwargs
    )
    on_post = lambda url, *args, **kwargs: default_handler(  # noqa: E731
        "POST", url, *args, **kwargs
    )

    get_patcher = patch("aiohttp.ClientSession.get", side_effect=on_get)
    post_patcher = patch("aiohttp.ClientSession.post", side_effect=on_post)
    with get_patcher as get_patch, post_patcher as post_patch:
        yield (get_patch, post_patch, default_handler)
|