# File: hass-oidc-auth/tests/test_state_store.py
# Snapshot: 2026-04-13 22:51:31 +02:00 (261 lines, 9.3 KiB, Python)

"""Tests for the state store."""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.core import HomeAssistant
from auth_oidc.stores.state_store import MAX_DEVICE_CODE_ATTEMPTS, StateStore
# Client IP address passed to every StateStore call made by the tests below.
TEST_IP = "127.0.0.1"
@pytest.mark.asyncio
async def test_state_store_generate_and_receive_state(hass: HomeAssistant):
    """Create a state, attach user info to it, and consume that info once.

    The state must survive attaching user info and readiness checks, and be
    removed from the store once its user info has been received.
    """
    storage_mock = AsyncMock()
    with patch("homeassistant.helpers.storage.Store", return_value=storage_mock):
        store = StateStore(hass)
        storage_mock.async_load.return_value = {}
        await store.async_load()
        assert store.get_data() == {}

        callback_url = "https://example.com/callback"
        state = await store.async_create_state_from_url(callback_url, TEST_IP)
        assert state in store.get_data()

        resolved_uri = await store.async_get_redirect_uri_for_state(state, TEST_IP)
        assert resolved_uri == callback_url

        claims = {
            "sub": "user1",
            "display_name": "Test User",
            "username": "testuser",
            "role": "system-users",
        }
        attached = await store.async_add_userinfo_to_state(state, claims)
        assert attached is True
        # Attaching user info must not consume the state.
        assert state in store.get_data()

        ready = await store.async_is_state_ready(state, TEST_IP)
        assert ready is True
        # Polling readiness must not consume the state either.
        assert state in store.get_data()

        received = await store.async_receive_userinfo_for_state(state, TEST_IP)
        assert received == claims
        # Receiving the user info removes the state from the store.
        assert state not in store.get_data()
@pytest.mark.asyncio
async def test_state_store_generate_code_and_link_state(hass: HomeAssistant):
    """Link a donor state to a target state via the target's device code.

    The donor carries the user info. After a successful link the donor is gone,
    and the target becomes ready and hands out the donor's user info once.
    """
    storage_mock = AsyncMock()
    with patch("homeassistant.helpers.storage.Store", return_value=storage_mock):
        store = StateStore(hass)
        storage_mock.async_load.return_value = {}
        await store.async_load()

        donor = await store.async_create_state_from_url(
            "https://example.com/donor", TEST_IP
        )
        target = await store.async_create_state_from_url(
            "https://example.com/target", TEST_IP
        )

        # The device code is expected to be a six-digit numeric string.
        device_code = await store.async_generate_code_for_state(target)
        assert device_code is not None
        assert len(device_code) == 6
        assert device_code.isdigit()

        claims = {
            "sub": "user2",
            "display_name": "Device User",
            "username": "deviceuser",
            "role": "system-admin",
        }
        attached = await store.async_add_userinfo_to_state(donor, claims)
        assert attached is True
        assert donor in store.get_data()

        linked = await store.async_link_state_to_code(donor, device_code, TEST_IP)
        assert linked is True
        # A successful link consumes the donor state.
        assert donor not in store.get_data()

        ready = await store.async_is_state_ready(target, TEST_IP)
        assert ready is True
        assert target in store.get_data()

        received = await store.async_receive_userinfo_for_state(target, TEST_IP)
        assert received == claims
        assert target not in store.get_data()
@pytest.mark.asyncio
async def test_state_store_link_state_returns_false_for_wrong_code(hass: HomeAssistant):
    """Test linking fails when the device code does not match any state.

    After the failed link, the donor state must still be in the store and the
    target state must not be ready.
    """
    store_mock = AsyncMock()
    with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
        state_store = StateStore(hass)
        store_mock.async_load.return_value = {}
        await state_store.async_load()

        donor_state = await state_store.async_create_state_from_url(
            "https://example.com/donor", TEST_IP
        )
        target_state = await state_store.async_create_state_from_url(
            "https://example.com/target", TEST_IP
        )
        code = await state_store.async_generate_code_for_state(target_state)
        assert code is not None
        # A fixed "000000" would collide with the generated code whenever the
        # store happens to produce exactly that value, making the test flaky.
        # Choose a code that is guaranteed to differ from the real one.
        wrong_code = "000000" if code != "000000" else "111111"

        user_info = {
            "sub": "user3",
            "display_name": "Wrong Code User",
            "username": "wrongcode",
            "role": "system-users",
        }
        assert (
            await state_store.async_add_userinfo_to_state(donor_state, user_info)
            is True
        )
        assert (
            await state_store.async_link_state_to_code(
                donor_state, wrong_code, TEST_IP
            )
            is False
        )
        # A failed link must leave the donor state in place...
        assert donor_state in state_store.get_data()
        # ...and must not hand its user info to the target state.
        assert await state_store.async_is_state_ready(target_state, TEST_IP) is False
@pytest.mark.asyncio
async def test_state_store_throttles_device_code_link_attempts(hass: HomeAssistant):
    """Test that repeated wrong device codes are throttled per state.

    After ``MAX_DEVICE_CODE_ATTEMPTS`` failed link attempts from the same donor
    state, even the correct device code must be rejected, and the target state
    must therefore not become ready.
    """
    store_mock = AsyncMock()
    with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
        state_store = StateStore(hass)
        store_mock.async_load.return_value = {}
        await state_store.async_load()

        donor_state = await state_store.async_create_state_from_url(
            "https://example.com/donor", TEST_IP
        )
        target_state = await state_store.async_create_state_from_url(
            "https://example.com/target", TEST_IP
        )
        code = await state_store.async_generate_code_for_state(target_state)
        assert code is not None
        # Guarantee the "wrong" code really is wrong. If the generated code
        # were "000000", the first attempt would succeed and the test would
        # fail spuriously.
        wrong_code = "000000" if code != "000000" else "111111"

        user_info = {
            "sub": "user-throttle",
            "display_name": "Throttle User",
            "username": "throttle",
            "role": "system-users",
        }
        assert (
            await state_store.async_add_userinfo_to_state(donor_state, user_info)
            is True
        )

        # Use up every allowed attempt with a code that matches no state.
        for _ in range(MAX_DEVICE_CODE_ATTEMPTS):
            assert (
                await state_store.async_link_state_to_code(
                    donor_state, wrong_code, TEST_IP
                )
                is False
            )

        # Once throttled, even the valid code is rejected...
        assert (
            await state_store.async_link_state_to_code(donor_state, code, TEST_IP)
            is False
        )
        # ...so the target state never receives the donor's user info.
        assert await state_store.async_is_state_ready(target_state, TEST_IP) is False
@pytest.mark.asyncio
async def test_state_store_expired_state(hass: HomeAssistant):
    """An expired state must be treated as invalid.

    The state's expiration is moved ten minutes into the past. After that,
    redirect lookup returns None, readiness is False, and user-info retrieval
    returns None.
    """
    storage_mock = AsyncMock()
    with patch("homeassistant.helpers.storage.Store", return_value=storage_mock):
        store = StateStore(hass)
        storage_mock.async_load.return_value = {}
        await store.async_load()

        state = await store.async_create_state_from_url(
            "https://example.com/expired", TEST_IP
        )

        # Backdate the stored expiration. It is written as an ISO-8601 string,
        # presumably matching the store's own format.
        ten_minutes_ago = datetime.now(timezone.utc) - timedelta(minutes=10)
        store.get_data()[state]["expiration"] = ten_minutes_ago.isoformat()

        redirect = await store.async_get_redirect_uri_for_state(state, TEST_IP)
        assert redirect is None

        ready = await store.async_is_state_ready(state, TEST_IP)
        assert ready is False

        received = await store.async_receive_userinfo_for_state(state, TEST_IP)
        assert received is None
@pytest.mark.asyncio
async def test_state_store_data_not_loaded(hass: HomeAssistant):
    """Each StateStore operation raises RuntimeError before async_load() runs."""
    storage_mock = AsyncMock()
    with patch("homeassistant.helpers.storage.Store", return_value=storage_mock):
        store = StateStore(hass)
        # async_load() is deliberately never awaited in this test. Each entry
        # below builds the call lazily so that the call runs inside
        # pytest.raises.
        pending_calls = [
            lambda: store.async_create_state_from_url(
                "https://example.com", TEST_IP
            ),
            lambda: store.async_generate_code_for_state("state"),
            lambda: store.async_add_userinfo_to_state(
                "state",
                {
                    "sub": "user4",
                    "display_name": "Not Loaded",
                    "username": "notloaded",
                    "role": "system-users",
                },
            ),
            lambda: store.async_get_redirect_uri_for_state("state", TEST_IP),
            lambda: store.async_is_state_ready("state", TEST_IP),
            lambda: store.async_link_state_to_code("state", "123456", TEST_IP),
            lambda: store.async_receive_userinfo_for_state("state", TEST_IP),
        ]
        for make_call in pending_calls:
            with pytest.raises(RuntimeError):
                await make_call()
@pytest.mark.asyncio
async def test_state_store_missing_keys(hass: HomeAssistant):
    """Test how the store responds to a state ID it does not hold.

    Unknown state IDs produce sentinel results (None / False) instead of
    exceptions.
    """
    store_mock = AsyncMock()
    with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
        state_store = StateStore(hass)
        store_mock.async_load.return_value = {}
        await state_store.async_load()

        # No device code can be generated for an unknown state ID.
        generated = await state_store.async_generate_code_for_state("nonexistent")
        assert generated is None

        # User info cannot be attached to an unknown state ID.
        user_info = {
            "sub": "user5",
            "display_name": "Missing Keys",
            "username": "missingkeys",
            "role": "system-users",
        }
        attached = await state_store.async_add_userinfo_to_state(
            "nonexistent", user_info
        )
        assert attached is False