Reimplement UI injection (#236)
This commit is contained in:
committed by
GitHub
parent
fdc93e2719
commit
fd3643685d
@@ -1,81 +0,0 @@
|
||||
"""Code Store, stores the codes and their associated authenticated user temporarily."""
|
||||
|
||||
import random
|
||||
import string
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast, Optional
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from ..tools.types import UserDetails
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "auth_provider.auth_oidc.codes"
|
||||
|
||||
|
||||
class CodeStore:
|
||||
"""Holds the codes and associated data"""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the user data store."""
|
||||
self.hass = hass
|
||||
self._store = Store[dict[str, UserDetails]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._data: dict[str, dict[str, dict | str]] | None = None
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load stored data."""
|
||||
if (data := await self._store.async_load()) is None:
|
||||
data = cast(dict[str, UserDetails], {})
|
||||
self._data = data
|
||||
|
||||
async def _async_save(self) -> None:
|
||||
"""Save data."""
|
||||
if self._data is not None:
|
||||
await self._store.async_save(self._data)
|
||||
|
||||
def _generate_code(self) -> str:
|
||||
"""Generate a random six-digit code."""
|
||||
return "".join(random.choices(string.digits, k=6))
|
||||
|
||||
async def async_generate_code_for_userinfo(self, user_info: UserDetails) -> str:
|
||||
"""Generates a one time code and adds it to the database for 5 minutes."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
code = self._generate_code()
|
||||
expiration = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||
|
||||
self._data[code] = {
|
||||
"user_info": user_info,
|
||||
"code": code,
|
||||
"expiration": expiration.isoformat(),
|
||||
}
|
||||
|
||||
await self._async_save()
|
||||
return code
|
||||
|
||||
async def receive_userinfo_for_code(self, code: str) -> Optional[UserDetails]:
|
||||
"""Retrieve user info based on the code."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
user_data = self._data.get(code)
|
||||
|
||||
if user_data:
|
||||
# We should now wipe it from the database, as it's one time use code
|
||||
self._data.pop(code)
|
||||
await self._async_save()
|
||||
|
||||
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.now(
|
||||
timezone.utc
|
||||
):
|
||||
return user_data["user_info"]
|
||||
|
||||
return None
|
||||
|
||||
def get_data(self):
|
||||
"""Get the internal data for testing purposes."""
|
||||
return self._data
|
||||
191
custom_components/auth_oidc/stores/state_store.py
Normal file
191
custom_components/auth_oidc/stores/state_store.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""State Store, store authentication states (redirect_uri)."""
|
||||
|
||||
import secrets
|
||||
import random
|
||||
import string
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast, Optional
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from ..tools.types import OIDCState, UserDetails
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "auth_provider.auth_oidc.states"
|
||||
MAX_DEVICE_CODE_ATTEMPTS = 10
|
||||
|
||||
|
||||
class StateStore:
|
||||
"""Holds the authentication states and associated data"""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the user data store."""
|
||||
self.hass = hass
|
||||
self._store = Store[dict[str, OIDCState]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._data: dict[str, OIDCState] | None = None
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load stored data."""
|
||||
if (data := await self._store.async_load()) is None:
|
||||
data = cast(dict[str, OIDCState], {})
|
||||
self._data = data
|
||||
|
||||
async def _async_save(self) -> None:
|
||||
"""Save data."""
|
||||
if self._data is not None:
|
||||
await self._store.async_save(self._data)
|
||||
|
||||
def _generate_id(self) -> str:
|
||||
"""Generate a random identifier."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def _generate_code(self) -> str:
|
||||
"""Generate a random six-digit code."""
|
||||
return "".join(random.choices(string.digits, k=6))
|
||||
|
||||
def _is_expired(self, state: OIDCState) -> bool:
|
||||
"""Check if a state is expired."""
|
||||
return datetime.fromisoformat(state["expiration"]) < datetime.now(timezone.utc)
|
||||
|
||||
def _is_valid(self, state: OIDCState, ip: str | None) -> bool:
|
||||
"""Check if a state is valid"""
|
||||
return (
|
||||
not self._is_expired(state)
|
||||
and bool(state["redirect_uri"])
|
||||
and ip is not None
|
||||
and state["ip_address"] == ip
|
||||
)
|
||||
|
||||
async def async_create_state_from_url(self, redirect_uri: str, ip: str) -> str:
|
||||
"""Generates a the OIDC state adds it to the database for 5 minutes."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
state_id = self._generate_id()
|
||||
expiration = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||
|
||||
self._data[state_id] = {
|
||||
"id": state_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"device_code": None,
|
||||
"device_code_attempts": 0,
|
||||
"user_details": None,
|
||||
"expiration": expiration.isoformat(),
|
||||
"ip_address": ip,
|
||||
}
|
||||
|
||||
await self._async_save()
|
||||
return state_id
|
||||
|
||||
async def async_generate_code_for_state(self, state_id: str) -> Optional[str]:
|
||||
"""Generates a one time code for the state to link device clients."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
try:
|
||||
code = self._generate_code()
|
||||
self._data[state_id]["device_code"] = code
|
||||
await self._async_save()
|
||||
return code
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
async def async_add_userinfo_to_state(
|
||||
self, state_id: str, user_info: UserDetails
|
||||
) -> bool:
|
||||
"""Add userinfo to existing state to complete login"""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
try:
|
||||
self._data[state_id]["user_details"] = user_info
|
||||
await self._async_save()
|
||||
return True
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
async def async_get_redirect_uri_for_state(
|
||||
self, state_id: str, ip: str
|
||||
) -> Optional[str]:
|
||||
"""Get the redirect_uri for a given state_id."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
state = self._data.get(state_id)
|
||||
if state and self._is_valid(state, ip):
|
||||
return state["redirect_uri"]
|
||||
|
||||
return None
|
||||
|
||||
async def async_is_state_ready(self, state_id: str, ip: str) -> bool:
|
||||
"""Check if the state has received the user info from the OIDC callback."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
state = self._data.get(state_id)
|
||||
return (
|
||||
state is not None
|
||||
and state["user_details"] is not None
|
||||
and self._is_valid(state, ip)
|
||||
)
|
||||
|
||||
async def async_link_state_to_code(
|
||||
self, state_id: str, code: str, ip: str | None
|
||||
) -> bool:
|
||||
"""Link a state to a device code, used for mobile sign-in."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
state_data = self._data.get(state_id)
|
||||
if (
|
||||
state_data
|
||||
and self._is_valid(state_data, ip)
|
||||
and state_data["user_details"] is not None
|
||||
):
|
||||
attempts = state_data.get("device_code_attempts", 0)
|
||||
if attempts >= MAX_DEVICE_CODE_ATTEMPTS:
|
||||
return False
|
||||
|
||||
# Find the state with the matching device code and link it
|
||||
for state in self._data.values():
|
||||
if state["device_code"] == code and not self._is_expired(state):
|
||||
# Set user details on the device state to allow it to complete login
|
||||
state["user_details"] = state_data["user_details"]
|
||||
|
||||
# Delete the 'donor' state as it's one time use
|
||||
self._data.pop(state_id)
|
||||
|
||||
# Save and return true
|
||||
await self._async_save()
|
||||
return True
|
||||
|
||||
state_data["device_code_attempts"] = attempts + 1
|
||||
await self._async_save()
|
||||
|
||||
return False
|
||||
|
||||
async def async_receive_userinfo_for_state(
|
||||
self, state_id: str, ip: str
|
||||
) -> Optional[OIDCState]:
|
||||
"""Retrieve user info based on the state_id."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
user_data = self._data.get(state_id)
|
||||
|
||||
if user_data:
|
||||
# We should now wipe it from the database, as it's one time use
|
||||
self._data.pop(state_id)
|
||||
await self._async_save()
|
||||
|
||||
if user_data and self._is_valid(user_data, ip):
|
||||
return user_data["user_details"]
|
||||
|
||||
return None
|
||||
|
||||
def get_data(self):
|
||||
"""Get the internal data for testing purposes."""
|
||||
return self._data
|
||||
Reference in New Issue
Block a user