192 lines
6.3 KiB
Python
192 lines
6.3 KiB
Python
"""State Store, store authentication states (redirect_uri)."""
|
|
|
|
import secrets
|
|
import random
|
|
import string
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import cast, Optional
|
|
from homeassistant.helpers.storage import Store
|
|
from homeassistant.core import HomeAssistant
|
|
|
|
from ..tools.types import OIDCState, UserDetails
|
|
|
|
STORAGE_VERSION = 1
|
|
STORAGE_KEY = "auth_provider.auth_oidc.states"
|
|
MAX_DEVICE_CODE_ATTEMPTS = 10
|
|
|
|
|
|
class StateStore:
|
|
"""Holds the authentication states and associated data"""
|
|
|
|
def __init__(self, hass: HomeAssistant) -> None:
|
|
"""Initialize the user data store."""
|
|
self.hass = hass
|
|
self._store = Store[dict[str, OIDCState]](
|
|
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
|
)
|
|
self._data: dict[str, OIDCState] | None = None
|
|
|
|
async def async_load(self) -> None:
|
|
"""Load stored data."""
|
|
if (data := await self._store.async_load()) is None:
|
|
data = cast(dict[str, OIDCState], {})
|
|
self._data = data
|
|
|
|
async def _async_save(self) -> None:
|
|
"""Save data."""
|
|
if self._data is not None:
|
|
await self._store.async_save(self._data)
|
|
|
|
def _generate_id(self) -> str:
|
|
"""Generate a random identifier."""
|
|
return secrets.token_urlsafe(32)
|
|
|
|
def _generate_code(self) -> str:
|
|
"""Generate a random six-digit code."""
|
|
return "".join(random.choices(string.digits, k=6))
|
|
|
|
def _is_expired(self, state: OIDCState) -> bool:
|
|
"""Check if a state is expired."""
|
|
return datetime.fromisoformat(state["expiration"]) < datetime.now(timezone.utc)
|
|
|
|
def _is_valid(self, state: OIDCState, ip: str | None) -> bool:
|
|
"""Check if a state is valid"""
|
|
return (
|
|
not self._is_expired(state)
|
|
and bool(state["redirect_uri"])
|
|
and ip is not None
|
|
and state["ip_address"] == ip
|
|
)
|
|
|
|
async def async_create_state_from_url(self, redirect_uri: str, ip: str) -> str:
|
|
"""Generates a the OIDC state adds it to the database for 5 minutes."""
|
|
if self._data is None:
|
|
raise RuntimeError("Data not loaded")
|
|
|
|
state_id = self._generate_id()
|
|
expiration = datetime.now(timezone.utc) + timedelta(minutes=5)
|
|
|
|
self._data[state_id] = {
|
|
"id": state_id,
|
|
"redirect_uri": redirect_uri,
|
|
"device_code": None,
|
|
"device_code_attempts": 0,
|
|
"user_details": None,
|
|
"expiration": expiration.isoformat(),
|
|
"ip_address": ip,
|
|
}
|
|
|
|
await self._async_save()
|
|
return state_id
|
|
|
|
async def async_generate_code_for_state(self, state_id: str) -> Optional[str]:
|
|
"""Generates a one time code for the state to link device clients."""
|
|
if self._data is None:
|
|
raise RuntimeError("Data not loaded")
|
|
|
|
try:
|
|
code = self._generate_code()
|
|
self._data[state_id]["device_code"] = code
|
|
await self._async_save()
|
|
return code
|
|
except KeyError:
|
|
return None
|
|
|
|
async def async_add_userinfo_to_state(
|
|
self, state_id: str, user_info: UserDetails
|
|
) -> bool:
|
|
"""Add userinfo to existing state to complete login"""
|
|
if self._data is None:
|
|
raise RuntimeError("Data not loaded")
|
|
|
|
try:
|
|
self._data[state_id]["user_details"] = user_info
|
|
await self._async_save()
|
|
return True
|
|
except KeyError:
|
|
return False
|
|
|
|
async def async_get_redirect_uri_for_state(
|
|
self, state_id: str, ip: str
|
|
) -> Optional[str]:
|
|
"""Get the redirect_uri for a given state_id."""
|
|
if self._data is None:
|
|
raise RuntimeError("Data not loaded")
|
|
|
|
state = self._data.get(state_id)
|
|
if state and self._is_valid(state, ip):
|
|
return state["redirect_uri"]
|
|
|
|
return None
|
|
|
|
async def async_is_state_ready(self, state_id: str, ip: str) -> bool:
|
|
"""Check if the state has received the user info from the OIDC callback."""
|
|
if self._data is None:
|
|
raise RuntimeError("Data not loaded")
|
|
|
|
state = self._data.get(state_id)
|
|
return (
|
|
state is not None
|
|
and state["user_details"] is not None
|
|
and self._is_valid(state, ip)
|
|
)
|
|
|
|
async def async_link_state_to_code(
|
|
self, state_id: str, code: str, ip: str | None
|
|
) -> bool:
|
|
"""Link a state to a device code, used for mobile sign-in."""
|
|
if self._data is None:
|
|
raise RuntimeError("Data not loaded")
|
|
|
|
state_data = self._data.get(state_id)
|
|
if (
|
|
state_data
|
|
and self._is_valid(state_data, ip)
|
|
and state_data["user_details"] is not None
|
|
):
|
|
attempts = state_data.get("device_code_attempts", 0)
|
|
if attempts >= MAX_DEVICE_CODE_ATTEMPTS:
|
|
return False
|
|
|
|
# Find the state with the matching device code and link it
|
|
for state in self._data.values():
|
|
if state["device_code"] == code and not self._is_expired(state):
|
|
# Set user details on the device state to allow it to complete login
|
|
state["user_details"] = state_data["user_details"]
|
|
|
|
# Delete the 'donor' state as it's one time use
|
|
self._data.pop(state_id)
|
|
|
|
# Save and return true
|
|
await self._async_save()
|
|
return True
|
|
|
|
state_data["device_code_attempts"] = attempts + 1
|
|
await self._async_save()
|
|
|
|
return False
|
|
|
|
async def async_receive_userinfo_for_state(
|
|
self, state_id: str, ip: str
|
|
) -> Optional[OIDCState]:
|
|
"""Retrieve user info based on the state_id."""
|
|
if self._data is None:
|
|
raise RuntimeError("Data not loaded")
|
|
|
|
user_data = self._data.get(state_id)
|
|
|
|
if user_data:
|
|
# We should now wipe it from the database, as it's one time use
|
|
self._data.pop(state_id)
|
|
await self._async_save()
|
|
|
|
if user_data and self._is_valid(user_data, ip):
|
|
return user_data["user_details"]
|
|
|
|
return None
|
|
|
|
def get_data(self):
|
|
"""Get the internal data for testing purposes."""
|
|
return self._data
|