Add unit tests (#133)
* Add initial test & add pipeline * Add very basic YAML config tests * Add coverage reporting * Add some webserver & template loading tests * Add test cases for the helpers * Implement initial OIDC server tests * Test codestore & discovery checker * Test basics of the config flow * Add test for the HA auth provider * Cleaned up tests & test injection
This commit is contained in:
committed by
GitHub
parent
5714e844a7
commit
404d2451df
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
8
tests/conftest.py
Normal file
8
tests/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Fixtures for testing."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def auto_enable_custom_integrations(enable_custom_integrations):
|
||||
yield
|
||||
0
tests/mocks/__init__.py
Normal file
0
tests/mocks/__init__.py
Normal file
14
tests/mocks/auth_page.html
Normal file
14
tests/mocks/auth_page.html
Normal file
@@ -0,0 +1,14 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Test</title>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
Test page
|
||||
</body>
|
||||
|
||||
</html>
|
||||
197
tests/mocks/oidc_server.py
Normal file
197
tests/mocks/oidc_server.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""A simple mock OIDC server for testing purposes."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
import time
|
||||
import logging
|
||||
import hashlib
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
from joserfc import jwt
|
||||
from joserfc.jwk import RSAKey, KeySet
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
BASE_URL = "https://oidc.example.com"
|
||||
SUBJECT = "testuser"
|
||||
|
||||
|
||||
class MockOIDCServer:
|
||||
"""A simple mock OIDC server for testing purposes."""
|
||||
|
||||
_code_storage = {}
|
||||
_scenario = {}
|
||||
|
||||
def __init__(self, scenario: str | None = None):
|
||||
"""Initialize the mock OIDC server."""
|
||||
# Create a JWK private key
|
||||
self._jwk = RSAKey.generate_key(
|
||||
2048, {"alg": "RS256", "use": "sig"}, private=True, auto_kid=True
|
||||
)
|
||||
|
||||
if scenario:
|
||||
# Load scenario JSON file from disk
|
||||
scenario_path = os.path.join(
|
||||
os.path.dirname(__file__), "scenarios", f"{scenario}.json"
|
||||
)
|
||||
with open(scenario_path, "r", encoding="utf-8") as f:
|
||||
self._scenario = json.load(f)
|
||||
|
||||
# Log it
|
||||
_LOGGER.debug("Loaded scenario: %s", self._scenario)
|
||||
|
||||
def get_random_code(self):
|
||||
"""Return a random authorization code."""
|
||||
return "".join(str(random.randint(0, 9)) for _ in range(6))
|
||||
|
||||
@staticmethod
|
||||
def get_discovery_url():
|
||||
"""Return the discovery URL for the given base URL."""
|
||||
return f"{BASE_URL}/.well-known/openid-configuration"
|
||||
|
||||
@staticmethod
|
||||
def get_authorize_url():
|
||||
"""Return the authorization URL for the given base URL."""
|
||||
return f"{BASE_URL}/authorize"
|
||||
|
||||
def process_request(self, url: str, method: str, body: dict) -> tuple[dict, int]:
|
||||
"""Process a request to the mock OIDC server."""
|
||||
_LOGGER.debug("Received %s request to %s in OIDC mock server", method, url)
|
||||
|
||||
if url == self.get_discovery_url() and method == "GET":
|
||||
response = self._get_discovery_document()
|
||||
elif url.startswith(self.get_authorize_url()) and method == "GET":
|
||||
response = self._get_authorize_response(url)
|
||||
elif url == f"{BASE_URL}/token" and method == "POST":
|
||||
response = self._get_token_response(body)
|
||||
elif url == f"{BASE_URL}/jwks" and method == "GET":
|
||||
response = self._get_jwks_response()
|
||||
else:
|
||||
response = {"error": "Unknown endpoint"}, 404
|
||||
|
||||
_LOGGER.debug("Responding with: %s", response)
|
||||
return response
|
||||
|
||||
def _get_discovery_document(self) -> tuple[dict, int]:
|
||||
"""Return a mock discovery document."""
|
||||
|
||||
if "discovery" in self._scenario:
|
||||
return self._scenario["discovery"], 200
|
||||
|
||||
return {
|
||||
"issuer": BASE_URL,
|
||||
"authorization_endpoint": self.get_authorize_url(),
|
||||
"token_endpoint": f"{BASE_URL}/token",
|
||||
"userinfo_endpoint": f"{BASE_URL}/userinfo",
|
||||
"jwks_uri": f"{BASE_URL}/jwks",
|
||||
"id_token_signing_alg_values_supported": ["RS256"],
|
||||
}, 200
|
||||
|
||||
def _get_authorize_response(self, url: str) -> tuple[dict, int]:
|
||||
"""Return a mock authorization response."""
|
||||
# Parse the url
|
||||
parsed_url = urlparse(url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
code = self.get_random_code()
|
||||
self._code_storage[code] = query_params
|
||||
|
||||
return {"code": code, "state": "xyz"}, 200
|
||||
|
||||
def _get_token_response(self, body: dict) -> tuple[dict, int]:
|
||||
"""Return a mock token response."""
|
||||
|
||||
if body.get("code") in self._code_storage:
|
||||
# TODO: Verify PKCE?
|
||||
return {
|
||||
"access_token": "exampleAccessToken",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"id_token": self._create_id_token(body.get("code")),
|
||||
}, 200
|
||||
else:
|
||||
return {"error": "invalid_request"}, 400
|
||||
|
||||
def _create_id_token(self, code: str) -> str:
|
||||
"""Create a mock ID token."""
|
||||
# Get the query params
|
||||
if code not in self._code_storage:
|
||||
raise ValueError("Invalid code")
|
||||
query_params = self._code_storage[code]
|
||||
_LOGGER.debug("Creating ID token with query params: %s", query_params)
|
||||
|
||||
# Get username
|
||||
if "username" in self._scenario:
|
||||
username = self._scenario["username"]
|
||||
else:
|
||||
username = "testuser"
|
||||
|
||||
# Create a simple signed JWT with our JWK
|
||||
header = {"alg": self._jwk.alg, "kid": self._jwk.kid}
|
||||
claims = {
|
||||
"iss": BASE_URL,
|
||||
"sub": SUBJECT,
|
||||
"aud": query_params.get("client_id", [""])[0],
|
||||
"nonce": query_params.get("nonce", [""])[0],
|
||||
"name": "Test Name",
|
||||
"preferred_username": username,
|
||||
}
|
||||
|
||||
now = int(time.time())
|
||||
claims["nbf"] = now
|
||||
claims["iat"] = now
|
||||
claims["exp"] = now + 3600 # 1 hour expiry
|
||||
|
||||
return jwt.encode(header, claims, self._jwk)
|
||||
|
||||
def _get_jwks_response(self) -> tuple[dict, int]:
|
||||
"""Return a mock JWKS response."""
|
||||
private_key = self._jwk
|
||||
public_key_dict = private_key.as_dict(private=False)
|
||||
public_key = RSAKey.import_key(
|
||||
public_key_dict, {"use": "sig", "alg": "RS256", "kid": private_key.kid}
|
||||
)
|
||||
|
||||
key_set = KeySet([public_key])
|
||||
|
||||
return key_set.as_dict(), 200
|
||||
|
||||
@staticmethod
|
||||
def get_final_subject():
|
||||
"""Return the subject that's returned to HA."""
|
||||
return hashlib.sha256(f"{BASE_URL}.{SUBJECT}".encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def mock_oidc_responses(scenario: str | None = None):
|
||||
"""Mock OIDC responses for testing."""
|
||||
|
||||
mock_oidc_server = MockOIDCServer(scenario)
|
||||
|
||||
def make_mock_response(json_data, status):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.__aenter__.return_value = mock_response
|
||||
mock_response.__aexit__.return_value = None
|
||||
mock_response.json = AsyncMock(return_value=json_data)
|
||||
mock_response.status = status
|
||||
return mock_response
|
||||
|
||||
def default_handler(method, url, *args, **kwargs):
|
||||
_LOGGER.debug("Mocked %s request to %s", method, url)
|
||||
body = kwargs.get("data") or kwargs.get("json") or None
|
||||
response = mock_oidc_server.process_request(url, method, body)
|
||||
return make_mock_response(response[0], response[1])
|
||||
|
||||
def get_side_effect(url, *args, **kwargs):
|
||||
return default_handler("GET", url, *args, **kwargs)
|
||||
|
||||
def post_side_effect(url, *args, **kwargs):
|
||||
return default_handler("POST", url, *args, **kwargs)
|
||||
|
||||
with (
|
||||
patch("aiohttp.ClientSession.get", side_effect=get_side_effect) as get_patch,
|
||||
patch("aiohttp.ClientSession.post", side_effect=post_side_effect) as post_patch,
|
||||
):
|
||||
yield (get_patch, post_patch, default_handler)
|
||||
5
tests/mocks/scenarios/empty.json
Normal file
5
tests/mocks/scenarios/empty.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"discovery": {
|
||||
|
||||
}
|
||||
}
|
||||
10
tests/mocks/scenarios/invalid_code_challenge_types.json
Normal file
10
tests/mocks/scenarios/invalid_code_challenge_types.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||
"response_types_supported": ["code"],
|
||||
"code_challenge_methods_supported": ["plain"]
|
||||
}
|
||||
}
|
||||
10
tests/mocks/scenarios/invalid_grant_types.json
Normal file
10
tests/mocks/scenarios/invalid_grant_types.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["refresh_token"]
|
||||
}
|
||||
}
|
||||
8
tests/mocks/scenarios/invalid_id_token_signing_alg.json
Normal file
8
tests/mocks/scenarios/invalid_id_token_signing_alg.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||
"jwks_uri": "https://mock-oidc-server.local/jwks"
|
||||
}
|
||||
}
|
||||
9
tests/mocks/scenarios/invalid_response_modes.json
Normal file
9
tests/mocks/scenarios/invalid_response_modes.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||
"response_modes_supported": ["post"]
|
||||
}
|
||||
}
|
||||
9
tests/mocks/scenarios/invalid_response_types.json
Normal file
9
tests/mocks/scenarios/invalid_response_types.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||
"response_types_supported": ["token"]
|
||||
}
|
||||
}
|
||||
8
tests/mocks/scenarios/invalid_url.json
Normal file
8
tests/mocks/scenarios/invalid_url.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||
"jwks_uri": "/jwks"
|
||||
}
|
||||
}
|
||||
7
tests/mocks/scenarios/missing_jwks.json
Normal file
7
tests/mocks/scenarios/missing_jwks.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||
"token_endpoint": "https://mock-oidc-server.local/token"
|
||||
}
|
||||
}
|
||||
6
tests/mocks/scenarios/missing_token.json
Normal file
6
tests/mocks/scenarios/missing_token.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize"
|
||||
}
|
||||
}
|
||||
5
tests/mocks/scenarios/only_issuer.json
Normal file
5
tests/mocks/scenarios/only_issuer.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local"
|
||||
}
|
||||
}
|
||||
3
tests/mocks/scenarios/username.json
Normal file
3
tests/mocks/scenarios/username.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"username": "foobar"
|
||||
}
|
||||
9
tests/mocks/scenarios/wrong_id_token_signing_alg.json
Normal file
9
tests/mocks/scenarios/wrong_id_token_signing_alg.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"discovery": {
|
||||
"issuer": "https://mock-oidc-server.local",
|
||||
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||
"id_token_signing_alg_values_supported": ["HS256"]
|
||||
}
|
||||
}
|
||||
1
tests/resources/fake_templates/index.html
Normal file
1
tests/resources/fake_templates/index.html
Normal file
@@ -0,0 +1 @@
|
||||
<p>Example template</p>
|
||||
90
tests/test_code_store.py
Normal file
90
tests/test_code_store.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Tests for the code store"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
import pytest
|
||||
|
||||
from auth_oidc.stores.code_store import CodeStore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_store_generate_and_receive_code(hass: HomeAssistant):
|
||||
"""Test generating and receiving a code."""
|
||||
store_mock = AsyncMock()
|
||||
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||
code_store = CodeStore(hass)
|
||||
|
||||
# Simulate loading with empty data
|
||||
store_mock.async_load.return_value = {}
|
||||
await code_store.async_load()
|
||||
assert code_store.get_data() == {}
|
||||
|
||||
user_info = {"sub": "user1", "name": "Test User"}
|
||||
code = await code_store.async_generate_code_for_userinfo(user_info)
|
||||
assert code in code_store.get_data()
|
||||
|
||||
# Should return user_info and remove the code
|
||||
with patch("custom_components.auth_oidc.stores.code_store.datetime") as dt_mock:
|
||||
dt_mock.utcnow.return_value = datetime.now(timezone.utc)
|
||||
dt_mock.fromisoformat.side_effect = datetime.fromisoformat
|
||||
result = await code_store.receive_userinfo_for_code(code)
|
||||
assert result == user_info
|
||||
assert code not in code_store.get_data()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_store_expired_code(hass):
|
||||
"""Test that expired codes return None."""
|
||||
store_mock = AsyncMock()
|
||||
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||
code_store = CodeStore(hass)
|
||||
store_mock.async_load.return_value = {}
|
||||
await code_store.async_load()
|
||||
assert code_store.get_data() == {}
|
||||
|
||||
user_info = {"sub": "user2", "name": "Expired User"}
|
||||
code = await code_store.async_generate_code_for_userinfo(user_info)
|
||||
|
||||
# Patch expiration to be in the past
|
||||
code_store.get_data()[code]["expiration"] = (
|
||||
datetime.now(timezone.utc) - timedelta(minutes=10)
|
||||
).isoformat()
|
||||
|
||||
with patch("custom_components.auth_oidc.stores.code_store.datetime") as dt_mock:
|
||||
dt_mock.utcnow.return_value = datetime.now(timezone.utc)
|
||||
dt_mock.fromisoformat.side_effect = datetime.fromisoformat
|
||||
result = await code_store.receive_userinfo_for_code(code)
|
||||
assert result is None
|
||||
assert code not in code_store.get_data()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_store_data_not_loaded(hass):
|
||||
"""Test that using the store before loading raises RuntimeError."""
|
||||
store_mock = AsyncMock()
|
||||
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||
code_store = CodeStore(hass)
|
||||
|
||||
# Data is not loaded yet, should result in RuntimeError
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await code_store.async_generate_code_for_userinfo({"sub": "user3"})
|
||||
with pytest.raises(RuntimeError):
|
||||
await code_store.receive_userinfo_for_code("123456")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_store_generate_code_length(hass):
|
||||
"""Test that generated codes are 6 digits."""
|
||||
store_mock = AsyncMock()
|
||||
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||
code_store = CodeStore(hass)
|
||||
store_mock.async_load.return_value = {}
|
||||
await code_store.async_load()
|
||||
assert code_store.get_data() == {}
|
||||
user_info = {"sub": "user4"}
|
||||
code = await code_store.async_generate_code_for_userinfo(user_info)
|
||||
assert len(code) == 6
|
||||
assert code.isdigit()
|
||||
229
tests/test_hass_auth_provider.py
Normal file
229
tests/test_hass_auth_provider.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Tests for the Auth Provider registration in HA"""
|
||||
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.components.person import DOMAIN as PERSON_DOMAIN
|
||||
|
||||
from custom_components.auth_oidc import DOMAIN
|
||||
from custom_components.auth_oidc.config.const import (
|
||||
DISCOVERY_URL,
|
||||
CLIENT_ID,
|
||||
FEATURES,
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||
FEATURES_AUTOMATIC_USER_LINKING,
|
||||
)
|
||||
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
|
||||
|
||||
|
||||
async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool:
|
||||
"""Set up the auth_oidc component."""
|
||||
result = await async_setup_component(hass, DOMAIN, {DOMAIN: config})
|
||||
|
||||
if expect_success:
|
||||
assert result
|
||||
assert DOMAIN in hass.data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_success_auth_provider_registration(hass: HomeAssistant):
|
||||
"""Test successful setup"""
|
||||
await setup(
|
||||
hass,
|
||||
{
|
||||
CLIENT_ID: "dummy",
|
||||
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
||||
},
|
||||
True,
|
||||
)
|
||||
|
||||
# Ensure the auth provider is registered
|
||||
auth_providers = hass.auth.get_auth_providers(DOMAIN)
|
||||
assert len(auth_providers) == 1
|
||||
|
||||
|
||||
async def login_user(hass: HomeAssistant, code: str):
|
||||
"""Helper to login a user."""
|
||||
|
||||
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||
flow = await provider.async_login_flow({})
|
||||
|
||||
result = await flow.async_step_init({"code": code})
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] is not None
|
||||
|
||||
data = result["data"]
|
||||
sub = data["sub"]
|
||||
assert sub == MockOIDCServer.get_final_subject()
|
||||
|
||||
# Get credentials
|
||||
credentials = await provider.async_get_or_create_credentials(data)
|
||||
assert credentials is not None
|
||||
assert credentials.data["sub"] == sub
|
||||
|
||||
user = await hass.auth.async_get_or_create_user(credentials)
|
||||
assert user.is_active
|
||||
return user
|
||||
|
||||
|
||||
async def get_login_code(hass: HomeAssistant, hass_client):
|
||||
"""Helper to get a login code."""
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||
assert resp.status == 302
|
||||
location = resp.headers["Location"]
|
||||
parsed_url = urlparse(location)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
state = query_params["state"][0]
|
||||
|
||||
session = async_get_clientsession(hass)
|
||||
resp = session.get(location, allow_redirects=False)
|
||||
assert resp.status == 200
|
||||
|
||||
json_parsed = await resp.json()
|
||||
assert "code" in json_parsed and json_parsed["code"]
|
||||
|
||||
code = json_parsed["code"]
|
||||
client = await hass_client()
|
||||
resp = await client.get(
|
||||
f"/auth/oidc/callback?code={code}&state={state}", allow_redirects=False
|
||||
)
|
||||
|
||||
assert resp.status == 302
|
||||
location = resp.headers["Location"]
|
||||
assert "/auth/oidc/finish?code=" in location
|
||||
|
||||
# Get the code from the finish URL
|
||||
code = location.split("code=")[1]
|
||||
return code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_login(hass: HomeAssistant, hass_client):
|
||||
"""Test a full login flow."""
|
||||
await setup(
|
||||
hass,
|
||||
{
|
||||
CLIENT_ID: "dummy",
|
||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||
FEATURES: {
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||
},
|
||||
},
|
||||
True,
|
||||
)
|
||||
|
||||
with mock_oidc_responses():
|
||||
# Actually start the login and get a code
|
||||
code = await get_login_code(hass, hass_client)
|
||||
|
||||
# Use the code to login directly with the registered auth provider
|
||||
# Inspired by tests for the built-in providers
|
||||
user = await login_user(hass, code)
|
||||
assert user.name == "Test Name"
|
||||
|
||||
# Login again to see if we trigger the re-use path
|
||||
code2 = await get_login_code(hass, hass_client)
|
||||
user2 = await login_user(hass, code2)
|
||||
assert user2.id == user.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_with_linking(hass: HomeAssistant, hass_client):
|
||||
"""Test a linking login."""
|
||||
await setup(
|
||||
hass,
|
||||
{
|
||||
CLIENT_ID: "dummy",
|
||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||
FEATURES: {
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||
FEATURES_AUTOMATIC_USER_LINKING: True,
|
||||
},
|
||||
},
|
||||
True,
|
||||
)
|
||||
|
||||
with mock_oidc_responses("username"):
|
||||
# Create a user first with username 'foobar'
|
||||
user = await hass.auth.async_create_user("Foo Bar")
|
||||
assert user.is_active
|
||||
|
||||
hass_provider = hass.auth.get_auth_providers("homeassistant")[0]
|
||||
credential = await hass_provider.async_get_or_create_credentials(
|
||||
{"username": "foobar"}
|
||||
)
|
||||
await hass.auth.async_link_user(user, credential)
|
||||
|
||||
# Actually start the login and get a code
|
||||
code = await get_login_code(hass, hass_client)
|
||||
|
||||
# Use the code to login directly with the registered auth provider
|
||||
user2 = await login_user(hass, code)
|
||||
assert user2.id == user.id # Assert that the user was linked
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_with_person_create(hass: HomeAssistant, hass_client):
|
||||
"""Test a person create."""
|
||||
await setup(
|
||||
hass,
|
||||
{
|
||||
CLIENT_ID: "dummy",
|
||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||
FEATURES: {
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||
},
|
||||
},
|
||||
True,
|
||||
)
|
||||
|
||||
await async_setup_component(hass, PERSON_DOMAIN, {})
|
||||
|
||||
with mock_oidc_responses():
|
||||
code = await get_login_code(hass, hass_client)
|
||||
user = await login_user(hass, code)
|
||||
assert user.is_active
|
||||
|
||||
# Find the person associated to this user using the PersonRegistry API
|
||||
person_store = hass.data[PERSON_DOMAIN][1]
|
||||
persons = person_store.async_items()
|
||||
assert len(persons) == 1
|
||||
|
||||
person = persons[0]
|
||||
assert person["user_id"] == user.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_shows_form(hass: HomeAssistant):
|
||||
"""Test a login"""
|
||||
await setup(
|
||||
hass,
|
||||
{
|
||||
CLIENT_ID: "dummy",
|
||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||
FEATURES: {
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||
},
|
||||
},
|
||||
True,
|
||||
)
|
||||
|
||||
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||
flow = await provider.async_login_flow({})
|
||||
|
||||
result = await flow.async_step_init({})
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "mfa"
|
||||
|
||||
# Attempt an invalid code
|
||||
result = await flow.async_step_init({"code": "invalid"})
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["errors"] == {"base": "invalid_auth"}
|
||||
287
tests/test_hass_oidc_client.py
Normal file
287
tests/test_hass_oidc_client.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Tests for the OIDC client"""
|
||||
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
import pytest
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
|
||||
from auth_oidc import DOMAIN
|
||||
from auth_oidc.tools.oidc_client import OIDCDiscoveryClient, OIDCDiscoveryInvalid
|
||||
from auth_oidc.config.const import (
|
||||
DISCOVERY_URL,
|
||||
CLIENT_ID,
|
||||
)
|
||||
|
||||
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
|
||||
|
||||
EXAMPLE_CLIENT_ID = "dummyclient"
|
||||
|
||||
|
||||
async def setup(hass: HomeAssistant):
|
||||
"""Set up the integration within Home Assistant"""
|
||||
mock_config = {
|
||||
DOMAIN: {
|
||||
CLIENT_ID: EXAMPLE_CLIENT_ID,
|
||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||
}
|
||||
}
|
||||
|
||||
result = await async_setup_component(hass, DOMAIN, mock_config)
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_oidc_flow(hass: HomeAssistant, hass_client):
|
||||
"""Test that one full OIDC flow works if OIDC is mocked."""
|
||||
|
||||
await setup(hass)
|
||||
|
||||
with mock_oidc_responses():
|
||||
# Start by going to /auth/oidc/redirect
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||
assert resp.status == 302
|
||||
assert resp.headers["Location"].startswith(MockOIDCServer.get_authorize_url())
|
||||
|
||||
# Parse the location header and test the query params for correctness
|
||||
location = resp.headers["Location"]
|
||||
parsed_url = urlparse(location)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
assert "response_type" in query_params and query_params.get(
|
||||
"response_type"
|
||||
) == ["code"]
|
||||
assert "client_id" in query_params and query_params.get("client_id") == [
|
||||
EXAMPLE_CLIENT_ID
|
||||
]
|
||||
assert "scope" in query_params and query_params.get("scope") == [
|
||||
"openid profile groups"
|
||||
]
|
||||
assert "state" in query_params and query_params["state"]
|
||||
state = query_params["state"][0]
|
||||
assert len(state) >= 16 # Ensure state is sufficiently long
|
||||
assert (
|
||||
"redirect_uri" in query_params
|
||||
and query_params["redirect_uri"]
|
||||
and query_params["redirect_uri"][0].endswith("/auth/oidc/callback")
|
||||
) # TODO: Also test that the URL itself is correct
|
||||
assert "nonce" in query_params and query_params["nonce"]
|
||||
assert "code_challenge_method" in query_params and query_params.get(
|
||||
"code_challenge_method"
|
||||
) == ["S256"]
|
||||
assert "code_challenge" in query_params and query_params["code_challenge"]
|
||||
|
||||
session = async_get_clientsession(hass)
|
||||
resp = session.get(location, allow_redirects=False)
|
||||
assert resp.status == 200
|
||||
|
||||
json_parsed = await resp.json()
|
||||
assert "code" in json_parsed and json_parsed["code"]
|
||||
|
||||
# Now go back to the callback with a sample code
|
||||
code = json_parsed["code"]
|
||||
client = await hass_client()
|
||||
resp = await client.get(
|
||||
f"/auth/oidc/callback?code={code}&state={state}", allow_redirects=False
|
||||
)
|
||||
|
||||
# TODO: Test if logged text contains our login
|
||||
# TODO: Test if the code actually works
|
||||
assert resp.status == 302
|
||||
assert "/auth/oidc/finish?code=" in resp.headers["Location"]
|
||||
|
||||
|
||||
async def discovery_test_through_redirect(
|
||||
hass_client, caplog, scenario: str, match_log_line: str
|
||||
):
|
||||
"""Test that discovery document retrieval fails gracefully through redirect endpoint."""
|
||||
with mock_oidc_responses(scenario):
|
||||
# Start by going to /auth/oidc/redirect
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||
|
||||
# Find matching log line
|
||||
assert match_log_line in caplog.text
|
||||
|
||||
# Assert that we get a 200 response with an error message
|
||||
assert resp.status == 200
|
||||
text = await resp.text()
|
||||
assert "Integration is misconfigured, discovery could not be obtained." in text
|
||||
|
||||
|
||||
async def direct_discovery_test(
|
||||
hass: HomeAssistant,
|
||||
scenario: str,
|
||||
match_type: str,
|
||||
match_log_line: str | None = None,
|
||||
):
|
||||
"""Test that discovery document retrieval fails with nice error directly."""
|
||||
with mock_oidc_responses(scenario):
|
||||
session = async_get_clientsession(hass)
|
||||
client = OIDCDiscoveryClient(
|
||||
MockOIDCServer.get_discovery_url(),
|
||||
session,
|
||||
{
|
||||
"id_token_signing_alg": "RS256",
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(OIDCDiscoveryInvalid) as exc_info:
|
||||
await client.fetch_discovery_document()
|
||||
|
||||
assert exc_info.value.type == match_type
|
||||
assert exc_info.value.get_detail_string().startswith("type: " + match_type)
|
||||
|
||||
if match_log_line:
|
||||
assert match_log_line in exc_info.value.get_detail_string()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discovery_failures(hass: HomeAssistant, hass_client, caplog):
|
||||
"""Test that discovery document retrieval fails gracefully."""
|
||||
|
||||
await setup(hass)
|
||||
|
||||
# Empty scenario
|
||||
await discovery_test_through_redirect(
|
||||
hass_client, caplog, "empty", "is missing required endpoint: issuer"
|
||||
)
|
||||
await direct_discovery_test(hass, "empty", "missing_endpoint", "endpoint: issuer")
|
||||
|
||||
# Missing authorization_endpoint
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"only_issuer",
|
||||
"is missing required endpoint: authorization_endpoint",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass, "only_issuer", "missing_endpoint", "endpoint: authorization_endpoint"
|
||||
)
|
||||
|
||||
# Missing token_endpoint
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"missing_token",
|
||||
"is missing required endpoint: token_endpoint",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass, "missing_token", "missing_endpoint", "endpoint: token_endpoint"
|
||||
)
|
||||
|
||||
# Missing jwks_uri
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"missing_jwks",
|
||||
"is missing required endpoint: jwks_uri",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass, "missing_jwks", "missing_endpoint", "endpoint: jwks_uri"
|
||||
)
|
||||
|
||||
# Invalid response_modes_supported
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"invalid_response_modes",
|
||||
"does not support required 'query' response mode, only supports: ['post']",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass, "invalid_response_modes", "does_not_support_response_mode", "post"
|
||||
)
|
||||
|
||||
# Invalid grant_types supported
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"invalid_grant_types",
|
||||
"does not support required 'authorization_code' grant type, only supports: ['refresh_token']",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass, "invalid_grant_types", "does_not_support_grant_type", "refresh_token"
|
||||
)
|
||||
|
||||
# Invalid response types
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"invalid_response_types",
|
||||
"does not support required 'code' response type, only supports: ['token']",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass, "invalid_response_types", "does_not_support_response_type", "token"
|
||||
)
|
||||
|
||||
# Invalid code_challenge types
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"invalid_code_challenge_types",
|
||||
"does not support required 'S256' code challenge method, only supports: ['plain']",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass,
|
||||
"invalid_code_challenge_types",
|
||||
"does_not_support_required_code_challenge_method",
|
||||
"plain",
|
||||
)
|
||||
|
||||
# Invalid id_token_signing alg
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"invalid_id_token_signing_alg",
|
||||
"does not have 'id_token_signing_alg_values_supported' field",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass, "invalid_id_token_signing_alg", "missing_id_token_signing_alg_values"
|
||||
)
|
||||
|
||||
# Not matching id_token_signing alg
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"wrong_id_token_signing_alg",
|
||||
"does not support requested id_token_signing_alg 'RS256', only supports: ['HS256']",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass,
|
||||
"wrong_id_token_signing_alg",
|
||||
"does_not_support_id_token_signing_alg",
|
||||
"requested: RS256, supported: ['HS256']",
|
||||
)
|
||||
|
||||
# Invalid URL
|
||||
await discovery_test_through_redirect(
|
||||
hass_client,
|
||||
caplog,
|
||||
"invalid_url",
|
||||
"has invalid URL in endpoint: jwks_uri (/jwks)",
|
||||
)
|
||||
await direct_discovery_test(
|
||||
hass,
|
||||
"invalid_url",
|
||||
"invalid_endpoint",
|
||||
"endpoint: jwks_uri, url: /jwks",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_jwks_fetch(hass: HomeAssistant):
|
||||
"""Test direct fetch of JWKS."""
|
||||
with mock_oidc_responses():
|
||||
session = async_get_clientsession(hass)
|
||||
client = OIDCDiscoveryClient(
|
||||
MockOIDCServer.get_discovery_url(),
|
||||
session,
|
||||
{
|
||||
"id_token_signing_alg": "RS256",
|
||||
},
|
||||
)
|
||||
|
||||
await client.fetch_discovery_document()
|
||||
jwks = await client.fetch_jwks()
|
||||
assert "keys" in jwks
|
||||
364
tests/test_hass_ui_config_flow.py
Normal file
364
tests/test_hass_ui_config_flow.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""Tests for the UI config flow"""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
|
||||
from custom_components.auth_oidc import DOMAIN, async_setup_entry
|
||||
from custom_components.auth_oidc.config.const import (
|
||||
OIDC_PROVIDERS,
|
||||
CLIENT_ID,
|
||||
CLIENT_SECRET,
|
||||
DISCOVERY_URL,
|
||||
DISPLAY_NAME,
|
||||
FEATURES,
|
||||
FEATURES_AUTOMATIC_USER_LINKING,
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE,
|
||||
CLAIMS,
|
||||
CLAIMS_DISPLAY_NAME,
|
||||
CLAIMS_GROUPS,
|
||||
CLAIMS_USERNAME,
|
||||
ROLES,
|
||||
ROLE_ADMINS,
|
||||
ROLE_USERS,
|
||||
)
|
||||
|
||||
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
|
||||
|
||||
DEMO_CLIENT_ID = "testing_example_client_id"
|
||||
DEMO_CLIENT_SECRET = "faz"
|
||||
DEMO_ADMIN_ROLE = "boo"
|
||||
DEMO_USER_ROLE = "far"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_config_flow_success(hass: HomeAssistant):
|
||||
"""Test a successful full config flow."""
|
||||
|
||||
with mock_oidc_responses():
|
||||
# 1. Start the user step
|
||||
# This simulates clicking "Add Integration" in the UI.
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
# Assert that it's a form and expects user input for the 'user' step
|
||||
# 'user' is always the first step if it is user triggered
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "user"
|
||||
assert result["data_schema"] is not None
|
||||
schema = result["data_schema"]
|
||||
# Extract the schema dict from voluptuous Schema
|
||||
schema_dict = schema.schema
|
||||
# Assert 'provider' is a key in the schema
|
||||
assert "provider" in schema_dict
|
||||
# Assert 'authentik' is one of the allowed values for 'provider'
|
||||
provider_field = schema_dict["provider"]
|
||||
# If provider_field is a voluptuous In validator, get its container
|
||||
allowed_providers = getattr(provider_field, "container", None)
|
||||
assert "authentik" in OIDC_PROVIDERS
|
||||
assert allowed_providers is not None and "authentik" in allowed_providers
|
||||
|
||||
assert result["errors"] == {}
|
||||
|
||||
# 2. Submit user input for the 'user' step
|
||||
# This simulates the user filling out host/port
|
||||
user_input_step_user = {"provider": "authentik"}
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input_step_user
|
||||
)
|
||||
|
||||
# Assert that it proceeds to the 'auth' step
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "discovery_url"
|
||||
assert result["data_schema"] is not None
|
||||
assert result["errors"] == {}
|
||||
|
||||
# Fill in the discovery URL
|
||||
user_input_step_discovery = {
|
||||
"discovery_url": MockOIDCServer.get_discovery_url()
|
||||
}
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input_step_discovery
|
||||
)
|
||||
|
||||
# Assert that it proceeds to the 'credentials' step
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "validate_connection"
|
||||
|
||||
# Assert that it validates correctly with our mock
|
||||
assert result["errors"] == {}
|
||||
|
||||
# Send in continue
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"action": "continue"}
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "client_config"
|
||||
assert result["data_schema"] is not None
|
||||
assert result["errors"] == {}
|
||||
|
||||
# Fill in the client config
|
||||
user_input_step_client_config = {
|
||||
"client_id": DEMO_CLIENT_ID,
|
||||
"client_secret": DEMO_CLIENT_SECRET,
|
||||
}
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input_step_client_config
|
||||
)
|
||||
|
||||
# Assert that we are at groups_config
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "groups_config"
|
||||
assert result["data_schema"] is not None
|
||||
assert result["errors"] == {}
|
||||
|
||||
# Fill in the groups config
|
||||
user_input_step_groups_config = {
|
||||
"admin_group": DEMO_ADMIN_ROLE,
|
||||
"user_group": DEMO_USER_ROLE,
|
||||
}
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input_step_groups_config
|
||||
)
|
||||
|
||||
# Assert that were are at user_linking config
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "user_linking"
|
||||
assert result["data_schema"] is not None
|
||||
assert result["errors"] == {}
|
||||
|
||||
# Fill in the user linking config
|
||||
user_input_step_user_linking = {"enable_user_linking": False}
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input_step_user_linking
|
||||
)
|
||||
|
||||
# Finally, assert that the flow is complete and a config entry is created
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == OIDC_PROVIDERS["authentik"]["name"]
|
||||
|
||||
expected_data = {
|
||||
"provider": "authentik",
|
||||
CLIENT_ID: DEMO_CLIENT_ID,
|
||||
CLIENT_SECRET: DEMO_CLIENT_SECRET,
|
||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
|
||||
FEATURES: {
|
||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE: True,
|
||||
},
|
||||
CLAIMS: {
|
||||
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"][
|
||||
"display_name"
|
||||
],
|
||||
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
|
||||
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
|
||||
},
|
||||
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
|
||||
}
|
||||
|
||||
assert result["data"] == expected_data
|
||||
|
||||
# Verify that the config entry was loaded into Home Assistant
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 1
|
||||
assert entries[0].data == expected_data
|
||||
|
||||
# You can also assert that `async_setup_entry` was called for this entry
|
||||
# (assuming it's mocked or you let it run if it's simple)
|
||||
# The PHCC `hass` fixture automatically mocks `async_setup_entry`
|
||||
# and `async_unload_entry` for you, making it easy to test that they're called.
|
||||
assert await async_setup_entry(hass, entries[0]) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_options_flow_success(hass: HomeAssistant):
|
||||
"""Test a successful options flow."""
|
||||
|
||||
# First, set up an initial config entry as in the full config flow
|
||||
initial_data = {
|
||||
"provider": "authentik",
|
||||
CLIENT_ID: DEMO_CLIENT_ID,
|
||||
CLIENT_SECRET: DEMO_CLIENT_SECRET,
|
||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
|
||||
FEATURES: {
|
||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE: True,
|
||||
},
|
||||
CLAIMS: {
|
||||
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"],
|
||||
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
|
||||
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
|
||||
},
|
||||
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
|
||||
}
|
||||
|
||||
entry = config_entries.ConfigEntry(
|
||||
version=1,
|
||||
minor_version=0,
|
||||
domain=DOMAIN,
|
||||
title=OIDC_PROVIDERS["authentik"]["name"],
|
||||
data=initial_data,
|
||||
source=config_entries.SOURCE_USER,
|
||||
entry_id="1",
|
||||
unique_id="test_unique_id",
|
||||
options={},
|
||||
pref_disable_new_entities=False,
|
||||
pref_disable_polling=False,
|
||||
discovery_keys=None,
|
||||
subentries_data=None,
|
||||
)
|
||||
|
||||
await hass.config_entries.async_add(entry)
|
||||
|
||||
# Start the reconfigure flow
|
||||
result = await hass.config_entries.options.async_init(entry.entry_id)
|
||||
|
||||
# Should start the options flow
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "init"
|
||||
assert result["data_schema"] is not None
|
||||
|
||||
# Assert that the schema is as expected
|
||||
# Schema contains enable_user_linking, enable_groups, admin_group & user_groups and no other keys
|
||||
schema = result["data_schema"]
|
||||
schema_dict = schema.schema
|
||||
# Assert that the schema contains the expected keys
|
||||
expected_keys = {
|
||||
"admin_group",
|
||||
"enable_user_linking",
|
||||
"enable_groups",
|
||||
"user_group",
|
||||
}
|
||||
assert set(schema_dict.keys()) == expected_keys
|
||||
|
||||
# Change the client_id and client_secret
|
||||
new_enable_linking = True
|
||||
new_enable_groups = True
|
||||
new_admin_group = "bazzbbb"
|
||||
new_user_group = "foobar"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"enable_user_linking": new_enable_linking,
|
||||
"enable_groups": new_enable_groups,
|
||||
"admin_group": new_admin_group,
|
||||
"user_group": new_user_group,
|
||||
},
|
||||
)
|
||||
|
||||
# Should finish and update the entry options
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
|
||||
# Optionally, check that the entry options are updated
|
||||
updated_entry = hass.config_entries.async_get_entry(entry.entry_id)
|
||||
assert updated_entry is not None
|
||||
|
||||
# Verify that the config entry was loaded into Home Assistant
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 1
|
||||
|
||||
assert (
|
||||
entries[0].data[FEATURES][FEATURES_AUTOMATIC_USER_LINKING] == new_enable_linking
|
||||
)
|
||||
assert entries[0].data[FEATURES][FEATURES_INCLUDE_GROUPS_SCOPE] == new_enable_groups
|
||||
assert entries[0].data[ROLES][ROLE_ADMINS] == new_admin_group
|
||||
assert entries[0].data[ROLES][ROLE_USERS] == new_user_group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconfigure_flow_success(hass: HomeAssistant):
|
||||
"""Test a successful reconfigure flow."""
|
||||
|
||||
# First, set up an initial config entry as in the full config flow
|
||||
initial_data = {
|
||||
"provider": "authentik",
|
||||
CLIENT_ID: DEMO_CLIENT_ID,
|
||||
CLIENT_SECRET: DEMO_CLIENT_SECRET,
|
||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
|
||||
FEATURES: {
|
||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE: True,
|
||||
},
|
||||
CLAIMS: {
|
||||
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"],
|
||||
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
|
||||
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
|
||||
},
|
||||
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
|
||||
}
|
||||
|
||||
entry = config_entries.ConfigEntry(
|
||||
version=1,
|
||||
minor_version=0,
|
||||
domain=DOMAIN,
|
||||
title=OIDC_PROVIDERS["authentik"]["name"],
|
||||
data=initial_data,
|
||||
source=config_entries.SOURCE_USER,
|
||||
entry_id="1",
|
||||
unique_id="test_unique_id",
|
||||
options={},
|
||||
pref_disable_new_entities=False,
|
||||
pref_disable_polling=False,
|
||||
discovery_keys=None,
|
||||
subentries_data=None,
|
||||
)
|
||||
|
||||
await hass.config_entries.async_add(entry)
|
||||
|
||||
# Start async_step_reconfigure to reconfigure the entry
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={
|
||||
"source": config_entries.SOURCE_RECONFIGURE,
|
||||
"entry_id": entry.entry_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Should start the reconfigure flow
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "reconfigure"
|
||||
assert result["data_schema"] is not None
|
||||
|
||||
# Assert that the schema is client_id & client_secret
|
||||
schema = result["data_schema"]
|
||||
schema_dict = schema.schema
|
||||
# Assert that the schema contains the expected keys
|
||||
expected_keys = {
|
||||
"client_id",
|
||||
"client_secret",
|
||||
}
|
||||
assert set(schema_dict.keys()) == expected_keys
|
||||
|
||||
# Change the client_id and client_secret
|
||||
new_client_id = "newclientid"
|
||||
new_client_secret = "newclientsecret"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"client_id": new_client_id,
|
||||
"client_secret": new_client_secret,
|
||||
},
|
||||
)
|
||||
|
||||
# Should finish and update the entry data
|
||||
assert result["type"] == FlowResultType.ABORT
|
||||
assert result["reason"] == "reconfigure_successful"
|
||||
|
||||
# Verify that the config entry was loaded into Home Assistant
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 1
|
||||
assert entries[0].data[CLIENT_ID] == new_client_id
|
||||
assert entries[0].data[CLIENT_SECRET] == new_client_secret
|
||||
151
tests/test_hass_webserver.py
Normal file
151
tests/test_hass_webserver.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for the registered webpages"""
|
||||
|
||||
import os
|
||||
from auth_oidc.config.const import (
|
||||
DISCOVERY_URL,
|
||||
CLIENT_ID,
|
||||
FEATURES,
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION,
|
||||
)
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.http import StaticPathConfig, DOMAIN as HTTP_DOMAIN
|
||||
|
||||
from custom_components.auth_oidc import DOMAIN
|
||||
|
||||
|
||||
async def setup(hass: HomeAssistant, enable_frontend_changes: bool = None):
|
||||
mock_config = {
|
||||
DOMAIN: {
|
||||
CLIENT_ID: "dummy",
|
||||
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
||||
FEATURES: {
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION: not enable_frontend_changes
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if enable_frontend_changes is None:
|
||||
del mock_config[DOMAIN][FEATURES][FEATURES_DISABLE_FRONTEND_INJECTION]
|
||||
|
||||
result = await async_setup_component(hass, DOMAIN, mock_config)
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_welcome_page_registration(hass: HomeAssistant, hass_client):
|
||||
"""Test that welcome page is present if frontend changes are disabled."""
|
||||
|
||||
await setup(hass, enable_frontend_changes=False)
|
||||
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/oidc/welcome", allow_redirects=False)
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_welcome_page_registration_with_changes(hass: HomeAssistant, hass_client):
|
||||
"""Test that welcome page is redirect if frontend changes are enabled."""
|
||||
|
||||
await setup(hass, enable_frontend_changes=True)
|
||||
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/oidc/welcome", allow_redirects=False)
|
||||
assert resp.status == 307
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_page_registration(hass: HomeAssistant, hass_client):
|
||||
"""Test that redirect page shows OIDC misconfiguration error if OIDC server is not reachable."""
|
||||
|
||||
await setup(hass)
|
||||
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||
assert resp.status == 200
|
||||
text = await resp.text()
|
||||
assert "Integration is misconfigured" in text
|
||||
|
||||
resp2 = await client.post("/auth/oidc/redirect", allow_redirects=False)
|
||||
assert resp2.status == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_registration(hass: HomeAssistant, hass_client):
|
||||
"""Test that callback page is reachable."""
|
||||
|
||||
await setup(hass)
|
||||
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/oidc/callback", allow_redirects=False)
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_registration(hass: HomeAssistant, hass_client):
|
||||
"""Test that finish page is reachable."""
|
||||
|
||||
await setup(hass)
|
||||
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/oidc/finish", allow_redirects=False)
|
||||
assert resp.status == 200
|
||||
text = await resp.text()
|
||||
|
||||
# Should miss the code parameter if called without it
|
||||
assert "Missing code" in text
|
||||
|
||||
resp2 = await client.get("/auth/oidc/finish?code=123456", allow_redirects=False)
|
||||
assert resp2.status == 200
|
||||
text2 = await resp2.text()
|
||||
assert "Missing code" not in text2
|
||||
assert "123456" in text2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_post(hass: HomeAssistant, hass_client):
|
||||
"""Test that finish page works with POST."""
|
||||
|
||||
await setup(hass)
|
||||
client = await hass_client()
|
||||
resp = await client.post("/auth/oidc/finish", data={}, allow_redirects=False)
|
||||
assert resp.status == 500
|
||||
|
||||
resp2 = await client.post(
|
||||
"/auth/oidc/finish", data={"code": "456888"}, allow_redirects=False
|
||||
)
|
||||
assert resp2.status == 302
|
||||
assert resp2.headers["Location"] == "/?storeToken=true"
|
||||
assert resp2.cookies["auth_oidc_code"].value == "456888"
|
||||
|
||||
|
||||
# Test the frontend injection
|
||||
@pytest.mark.asyncio
|
||||
async def test_frontend_injection(hass: HomeAssistant, hass_client):
|
||||
"""Test that frontend injection works."""
|
||||
|
||||
# Because there is no frontend in the test setup,
|
||||
# we'll have to fake /auth/authorize for the changes to register
|
||||
await async_setup_component(hass, HTTP_DOMAIN, {})
|
||||
|
||||
mock_html_path = os.path.join(os.path.dirname(__file__), "mocks", "auth_page.html")
|
||||
await hass.http.async_register_static_paths(
|
||||
[
|
||||
StaticPathConfig(
|
||||
"/auth/authorize",
|
||||
mock_html_path,
|
||||
cache_headers=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
await setup(hass, enable_frontend_changes=True)
|
||||
|
||||
client = await hass_client()
|
||||
resp = await client.get("/auth/authorize", allow_redirects=False)
|
||||
assert resp.status == 200
|
||||
text = await resp.text()
|
||||
|
||||
assert "<script src='/auth/oidc/static/injection.js" in text
|
||||
93
tests/test_hass_yaml_init.py
Normal file
93
tests/test_hass_yaml_init.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Tests for the YAML config setup of OIDC"""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from custom_components.auth_oidc import DOMAIN
|
||||
from custom_components.auth_oidc.config.const import ADDITIONAL_SCOPES
|
||||
|
||||
|
||||
async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool:
|
||||
"""Set up the auth_oidc component."""
|
||||
result = await async_setup_component(hass, DOMAIN, {DOMAIN: config})
|
||||
|
||||
if expect_success:
|
||||
assert result
|
||||
assert DOMAIN in hass.data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_success_yaml(hass: HomeAssistant):
|
||||
"""Test successful setup of a YAML configuration."""
|
||||
await setup(
|
||||
hass,
|
||||
{
|
||||
"client_id": "dummy",
|
||||
"discovery_url": "https://example.com/.well-known/openid-configuration",
|
||||
},
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_success_yaml_with_optional(hass: HomeAssistant):
|
||||
"""Test successful setup of a YAML configuration with optional parameters."""
|
||||
await setup(
|
||||
hass,
|
||||
{
|
||||
"client_id": "dummy",
|
||||
"discovery_url": "https://example.com/.well-known/openid-configuration",
|
||||
ADDITIONAL_SCOPES: "email phone",
|
||||
},
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_failure_empty_yaml(hass: HomeAssistant, caplog):
|
||||
"""Test failure setup of an empty YAML configuration."""
|
||||
await setup(hass, {}, False)
|
||||
|
||||
assert "required key 'client_id' not provided" in caplog.text
|
||||
assert "required key 'discovery_url' not provided" in caplog.text
|
||||
assert (
|
||||
"Setup failed for custom integration 'auth_oidc': Invalid config."
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_failure_partial_empty_yaml_discovery(hass: HomeAssistant, caplog):
|
||||
"""Test failure setup of an partial YAML configuration."""
|
||||
await setup(
|
||||
hass,
|
||||
{"discovery_url": "https://example.com/.well-known/openid-configuration"},
|
||||
False,
|
||||
)
|
||||
|
||||
assert "required key 'client_id' not provided" in caplog.text
|
||||
assert "required key 'discovery_url' not provided" not in caplog.text
|
||||
assert (
|
||||
"Setup failed for custom integration 'auth_oidc': Invalid config."
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_failure_partial_empty_yaml_client(hass: HomeAssistant, caplog):
|
||||
"""Test failure setup of an partial YAML configuration."""
|
||||
|
||||
await setup(
|
||||
hass,
|
||||
{"client_id": "test"},
|
||||
False,
|
||||
)
|
||||
|
||||
assert "required key 'client_id' not provided" not in caplog.text
|
||||
assert "required key 'discovery_url' not provided" in caplog.text
|
||||
assert (
|
||||
"Setup failed for custom integration 'auth_oidc': Invalid config."
|
||||
in caplog.text
|
||||
)
|
||||
85
tests/test_helpers.py
Normal file
85
tests/test_helpers.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Tests for the helpers and validation tools"""
|
||||
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
|
||||
from custom_components.auth_oidc.tools.helpers import get_url, get_view
|
||||
from custom_components.auth_oidc.tools.validation import (
|
||||
validate_client_id,
|
||||
sanitize_client_secret,
|
||||
validate_discovery_url,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_url():
|
||||
"""Test the get_url helper."""
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_url("https://example.com", "/test")
|
||||
assert str(excinfo.value) == "No current request in context"
|
||||
|
||||
# Mock homeassistant.components.http.current_request.get() to test the force HTTP flag
|
||||
with patch("homeassistant.components.http.current_request") as mock_current_request:
|
||||
fake_request = make_mocked_request("GET", "http://example.com")
|
||||
mock_current_request.get.return_value = fake_request
|
||||
result = get_url("/test", True)
|
||||
assert result == "https://example.com/test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_view():
|
||||
"""Test the get_view helper."""
|
||||
|
||||
data = await get_view("welcome")
|
||||
assert data.startswith("<!DOCTYPE html>")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_url():
|
||||
"""Test the validate_url helper."""
|
||||
|
||||
assert not validate_url("ftp://example.com")
|
||||
assert validate_url("http://example.com")
|
||||
assert validate_url("https://example.com")
|
||||
assert not validate_url("example.com")
|
||||
assert not validate_url(42)
|
||||
assert not validate_url([])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_discovery_url():
|
||||
"""Test the validate_discovery_url helper."""
|
||||
|
||||
assert not validate_discovery_url("ftp://example.com")
|
||||
assert not validate_discovery_url("http://example.com")
|
||||
assert not validate_discovery_url("https://example.com")
|
||||
assert not validate_discovery_url("example.com")
|
||||
assert not validate_discovery_url(
|
||||
"https://example.com/.well-known/openid_configuration"
|
||||
)
|
||||
assert validate_discovery_url(
|
||||
"https://example.com/.well-known/openid-configuration"
|
||||
)
|
||||
assert not validate_discovery_url(2)
|
||||
assert not validate_discovery_url([])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_secret():
|
||||
"""Test the sanitize_client_secret helper."""
|
||||
|
||||
assert sanitize_client_secret("test ") == "test"
|
||||
assert sanitize_client_secret("test2") == "test2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_id():
|
||||
"""Test the validate_client_id helper."""
|
||||
|
||||
assert not validate_client_id(" ")
|
||||
assert validate_client_id("test4")
|
||||
assert validate_client_id("test4 ")
|
||||
49
tests/test_view_template.py
Normal file
49
tests/test_view_template.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Tests for the view templates"""
|
||||
|
||||
import pytest
|
||||
from os import path
|
||||
|
||||
from custom_components.auth_oidc.views.loader import AsyncTemplateRenderer
|
||||
|
||||
FAKE_TEMPLATE_PATH = path.join(
|
||||
path.dirname(path.abspath(__file__)), "resources", "fake_templates"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_template_render():
|
||||
"""Test that view template can render an real existing template."""
|
||||
|
||||
renderer = AsyncTemplateRenderer()
|
||||
rendered = await renderer.render_template("welcome.html")
|
||||
assert "<!DOCTYPE html>" in rendered
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_template_render():
|
||||
"""Test that view template can render an fake existing template."""
|
||||
|
||||
renderer = AsyncTemplateRenderer(template_dir=FAKE_TEMPLATE_PATH)
|
||||
await renderer.fetch_templates()
|
||||
rendered = await renderer.render_template("index.html")
|
||||
assert "<p>Example template</p>" in rendered
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dir_render_error():
|
||||
"""Test that view template sends correct error if you try to render directory."""
|
||||
|
||||
renderer = AsyncTemplateRenderer(template_dir=FAKE_TEMPLATE_PATH)
|
||||
await renderer.fetch_templates()
|
||||
with pytest.raises(ValueError):
|
||||
await renderer.render_template("folder.html")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_random_render_error():
|
||||
"""Test that view template sends correct error if you try to render non-existing."""
|
||||
|
||||
renderer = AsyncTemplateRenderer(template_dir=FAKE_TEMPLATE_PATH)
|
||||
await renderer.fetch_templates()
|
||||
with pytest.raises(ValueError):
|
||||
await renderer.render_template("non_existing.html")
|
||||
Reference in New Issue
Block a user