Code quality improvements (v0.2.0-pre-alpha) (#5)

* Bumped version to 0.2.0
* Implemented Github Actions for HACS, Hassfest, Linting
* Improved code quality (compliant with the linter now)
* Added link to the finish page to automatically login on the same device/browser
This commit is contained in:
Christiaan Goossens
2024-12-27 00:20:38 +01:00
committed by GitHub
parent a30d42ffce
commit b4a08b17ab
18 changed files with 1148 additions and 278 deletions

View File

@@ -1,3 +1,5 @@
"""OIDC Integration for Home Assistant."""
import logging
from typing import OrderedDict
@@ -10,29 +12,31 @@ from .endpoints.finish import OIDCFinishView
from .endpoints.callback import OIDCCallbackView
from .oidc_client import OIDCClient
from .provider import OpenIDAuthProvider
DOMAIN = "auth_oidc"
_LOGGER = logging.getLogger(__name__)
from .provider import OpenIDAuthProvider
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.Schema(
{
vol.Required("client_id"): vol.Coerce(str),
vol.Optional("client_secret"): vol.Coerce(str),
vol.Required("discovery_url"): vol.Url(),
vol.Required("discovery_url"): vol.Coerce(str),
}
)
},
extra=vol.ALLOW_EXTRA,
)
async def async_setup(hass: HomeAssistant, config):
"""Add the OIDC Auth Provider to the providers in Home Assistant"""
providers = OrderedDict()
# Use private APIs until there is a real auth platform
# pylint: disable=protected-access
provider = OpenIDAuthProvider(
hass,
hass.auth._store,
@@ -42,13 +46,14 @@ async def async_setup(hass: HomeAssistant, config):
providers[(provider.type, provider.id)] = provider
providers.update(hass.auth._providers)
hass.auth._providers = providers
# pylint: enable=protected-access
_LOGGER.debug("Added OIDC provider for Home Assistant")
# Define some fields
discovery_url = config[DOMAIN]["discovery_url"]
client_id = config[DOMAIN]["client_id"]
scope = "openid profile email"
discovery_url: str = config[DOMAIN]["discovery_url"]
client_id: str = config[DOMAIN]["client_id"]
scope: str = "openid profile email"
oidc_client = oidc_client = OIDCClient(discovery_url, client_id, scope)

View File

@@ -1,12 +1,13 @@
from aiohttp import web
"""Callback route to return the user to after external OIDC interaction."""
from homeassistant.components.http import HomeAssistantView
import logging
from aiohttp import web
from ..oidc_client import OIDCClient
from ..provider import OpenIDAuthProvider
from ..helpers import get_url
PATH = "/auth/oidc/callback"
_LOGGER = logging.getLogger(__name__)
class OIDCCallbackView(HomeAssistantView):
"""OIDC Plugin Callback View."""
@@ -24,12 +25,9 @@ class OIDCCallbackView(HomeAssistantView):
async def get(self, request: web.Request) -> web.Response:
"""Receive response."""
_LOGGER.debug("Callback view accessed")
params = request.rel_url.query
code = params.get("code")
state = params.get("state")
base_uri = str(request.url).split('/auth', 2)[0]
if not (code and state):
return web.Response(
@@ -37,13 +35,16 @@ class OIDCCallbackView(HomeAssistantView):
text="<h1>Error</h1><p>Missing code or state parameter</p>",
)
user_details = await self.oidc_client.complete_token_flow(base_uri, code, state)
redirect_uri = get_url("/auth/oidc/callback")
user_details = await self.oidc_client.async_complete_token_flow(
redirect_uri, code, state
)
if user_details is None:
return web.Response(
headers={"content-type": "text/html"},
text="<h1>Error</h1><p>Failed to get user details, see console.</p>",
)
code = await self.oidc_provider.save_user_info(user_details)
code = await self.oidc_provider.async_save_user_info(user_details)
return web.HTTPFound(base_uri + "/auth/oidc/finish?code=" + code)
return web.HTTPFound(get_url("/auth/oidc/finish?code=" + code))

View File

@@ -1,10 +1,12 @@
from aiohttp import web
"""Finish route to allow the user to view their code."""
from homeassistant.components.http import HomeAssistantView
import logging
from aiohttp import web
from ..helpers import get_url
PATH = "/auth/oidc/finish"
_LOGGER = logging.getLogger(__name__)
class OIDCFinishView(HomeAssistantView):
"""OIDC Plugin Finish View."""
@@ -17,8 +19,20 @@ class OIDCFinishView(HomeAssistantView):
"""Receive response."""
code = request.query.get("code", "FAIL")
link = get_url("/")
return web.Response(
headers={"content-type": "text/html"},
text=f"<h1>Done!</h1><p>Your code is: <b>{code}</b></p><p>Please return to the Home Assistant login screen (or your mobile app) and fill in this code into the single login field. It should be visible if you select 'Login with OpenID Connect (SSO)'.</p>",
)
headers={
"content-type": "text/html",
"set-cookie": "auth_oidc_code="
+ code
+ "; Path=/auth/login_flow; SameSite=Strict; HttpOnly; Max-Age=300",
},
text=f"<h1>Done!</h1><p>Your code is: <b>{code}</b></p>"
+ "<p>Please return to the Home Assistant login "
+ "screen (or your mobile app) and fill in this code into the single login field. "
+ "It should be visible if you "
+ "select 'Login with OpenID Connect (SSO)'.</p><p><a href='"
+ link
+ "'>Click here to login automatically (on desktop).</a></p>",
)

View File

@@ -1,12 +1,14 @@
"""Redirect route to redirect the user to the external OIDC server,
can either be linked to directly or accessed through the welcome page."""
from aiohttp import web
from homeassistant.components.http import HomeAssistantView
import logging
from ..oidc_client import OIDCClient
from ..helpers import get_url
PATH = "/auth/oidc/redirect"
_LOGGER = logging.getLogger(__name__)
class OIDCRedirectView(HomeAssistantView):
"""OIDC Plugin Redirect View."""
@@ -15,32 +17,23 @@ class OIDCRedirectView(HomeAssistantView):
url = PATH
name = "auth:oidc:redirect"
def __init__(
self, oidc_client: OIDCClient
) -> None:
def __init__(self, oidc_client: OIDCClient) -> None:
self.oidc_client = oidc_client
async def get(self, request: web.Request) -> web.Response:
async def get(self, _: web.Request) -> web.Response:
"""Receive response."""
_LOGGER.debug("Redirect view accessed")
base_uri = str(request.url).split('/auth', 2)[0]
_LOGGER.debug("Base URI: %s", base_uri)
auth_url = await self.oidc_client.get_authorization_url(base_uri)
_LOGGER.debug("Auth URL: %s", auth_url)
redirect_uri = get_url("/auth/oidc/callback")
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
if auth_url:
return web.HTTPFound(auth_url)
else:
return web.Response(
return web.Response(
headers={"content-type": "text/html"},
text="<h1>Plugin is misconfigured, discovery could not be obtained</h1>",
)
async def post(self, request: web.Request) -> web.Response:
"""POST"""
_LOGGER.debug("Redirect POST view accessed")
return await self.get(request)
return await self.get(request)

View File

@@ -1,10 +1,10 @@
"""Welcome route to show the user the OIDC login button and give instructions."""
from aiohttp import web
from homeassistant.components.http import HomeAssistantView
import logging
PATH = "/auth/oidc/welcome"
_LOGGER = logging.getLogger(__name__)
class OIDCWelcomeView(HomeAssistantView):
"""OIDC Plugin Welcome View."""
@@ -13,12 +13,10 @@ class OIDCWelcomeView(HomeAssistantView):
url = PATH
name = "auth:oidc:welcome"
async def get(self, request: web.Request) -> web.Response:
async def get(self, _: web.Request) -> web.Response:
"""Receive response."""
_LOGGER.debug("Welcome view accessed")
return web.Response(
headers={"content-type": "text/html"},
text="<h1>OIDC Login (beta)</h1><p><a href='/auth/oidc/redirect'>Login with OIDC</a></p>",
)
text="<h1>OIDC Login</h1><p><a href='/auth/oidc/redirect'>Login with OIDC</a></p>",
)

View File

@@ -0,0 +1,12 @@
"""Helper functions for the integration."""
from homeassistant.components import http
def get_url(path: str) -> str:
"""Returns the requested path appended to the current request base URL."""
if (req := http.current_request.get()) is None:
raise RuntimeError("No current request in context")
base_uri = str(req.url).split("/auth", 2)[0]
return f"{base_uri}{path}"

View File

@@ -1,14 +1,20 @@
{
"domain": "auth_oidc",
"name": "OIDC Authentication",
"documentation": "",
"requirements": [],
"ssdp": [],
"zeroconf": [],
"homekit": {},
"dependencies": [
"auth"
"codeowners": [
"@christiaangoossens"
],
"codeowners": ["@christiaangoossens"],
"version": "0.1"
}
"config_flow": false,
"dependencies": [
"auth",
"http"
],
"documentation": "https://github.com/christiaangoossens/hass-oidc-auth",
"integration_type": "service",
"iot_class": "calculated",
"issue_tracker": "https://github.com/christiaangoossens/hass-oidc-auth/issues",
"requirements": [
"python-jose>=3.3.0"
],
"version": "0.2.0"
}

View File

@@ -1,25 +1,50 @@
import aiohttp
"""OIDC Client class"""
import urllib.parse
import logging
import os
import base64
import hashlib
from jose import jwt
from jose import jwk, jwt
from typing import Optional
import aiohttp
from jose import jwt, jwk
_LOGGER = logging.getLogger(__name__)
class OIDCClientException(Exception):
"Raised when the OIDC Client encounters an error"
class OIDCDiscoveryInvalid(OIDCClientException):
"Raised when the discovery document is not found, invalid or otherwise malformed."
class OIDCTokenResponseInvalid(OIDCClientException):
"Raised when the token request returns invalid."
class OIDCJWKSInvalid(OIDCClientException):
"Raised when the JWKS is invalid or cannot be obtained."
class OIDCStateInvalid(OIDCClientException):
"Raised when the state for your request cannot be matched against a stored state."
class OIDCClient:
"""OIDC Client implementation for Python, including PKCE."""
# Flows stores the state, code_verifier and nonce of all current flows.
flows = {}
def __init__(self, discovery_url, client_id, scope):
def __init__(self, discovery_url: str, client_id: str, scope: str):
self.discovery_url = discovery_url
self.discovery_document = None
self.client_id = client_id
self.scope = scope
async def fetch_discovery_document(self):
async def _fetch_discovery_document(self):
try:
async with aiohttp.ClientSession() as session:
async with session.get(self.discovery_url) as response:
@@ -27,47 +52,60 @@ class OIDCClient:
return await response.json()
except aiohttp.ClientResponseError as e:
if e.status == 404:
_LOGGER.warning(f"Error: Discovery document not found at {self.discovery_url}")
_LOGGER.warning(
"Error: Discovery document not found at %s", self.discovery_url
)
else:
_LOGGER.warning(f"Error: {e.status} - {e.message}")
return None
async def get_authorization_url(self, base_uri):
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
_LOGGER.warning("Error: %s - %s", e.status, e.message)
raise OIDCDiscoveryInvalid from e
if not self.discovery_document:
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
"""Generates the authorization URL for the OIDC flow."""
try:
if self.discovery_document is None:
self.discovery_document = await self._fetch_discovery_document()
auth_endpoint = self.discovery_document["authorization_endpoint"]
# Generate the necessary PKCE parameters, nonce & state
code_verifier = (
base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode("utf-8")
)
code_challenge = (
base64.urlsafe_b64encode(
hashlib.sha256(code_verifier.encode("utf-8")).digest()
)
.rstrip(b"=")
.decode("utf-8")
)
nonce = (
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
)
state = (
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
)
# Save all of them for later verification
self.flows[state] = {"code_verifier": code_verifier, "nonce": nonce}
# Construct the params
query_params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": redirect_uri,
"scope": self.scope,
"state": state,
"nonce": nonce,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
return url
except OIDCClientException as e:
_LOGGER.warning("Error generating authorization URL: %s", e)
return None
auth_endpoint = self.discovery_document['authorization_endpoint']
# Generate the necessary PKCE parameters, nonce & state
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode('utf-8')
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
nonce = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
state = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
# Save all of them for later verification
self.flows[state] = {
'code_verifier': code_verifier,
'nonce': nonce
}
# Construct the params
query_params = {
'response_type': 'code',
'client_id': self.client_id,
'redirect_uri': base_uri + '/auth/oidc/callback',
'scope': self.scope,
'state': state,
'nonce': nonce,
'code_challenge': code_challenge,
'code_challenge_method': 'S256',
}
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
return url
async def _make_token_request(self, token_endpoint, query_params):
try:
async with aiohttp.ClientSession() as session:
@@ -75,12 +113,9 @@ class OIDCClient:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
response_json = await response.json()
_LOGGER.warning(f"Error: {e.status} - {e.message}, Response: {response_json}")
return None
_LOGGER.warning("Error exchanging token: %s - %s", e.status, e.message)
raise OIDCTokenResponseInvalid from e
return None
async def _get_jwks(self, jwks_uri):
"""Fetches JWKS from the given URL."""
try:
@@ -89,23 +124,18 @@ class OIDCClient:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
_LOGGER.warning(f"Error fetching JWKS: {e.status} - {e.message}")
return None
_LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
raise OIDCJWKSInvalid from e
async def _parse_id_token(self, id_token: str):
if self.discovery_document is None:
self.discovery_document = await self._fetch_discovery_document()
async def _parse_id_token(self, id_token):
# Parse the id token to obtain the relevant details
# Use python-jose
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
if not self.discovery_document:
return None
jwks_uri = self.discovery_document['jwks_uri']
jwks_uri = self.discovery_document["jwks_uri"]
jwks_data = await self._get_jwks(jwks_uri)
if not jwks_data:
return None
try:
unverified_header = jwt.get_unverified_header(id_token)
@@ -113,12 +143,11 @@ class OIDCClient:
print("Could not parse JWT Header")
return None
kid = unverified_header.get('kid')
kid = unverified_header.get("kid")
if not kid:
print("JWT does not have kid (Key ID)")
return None
# Get the correct key
rsa_key = None
for key in jwks_data["keys"]:
@@ -139,66 +168,60 @@ class OIDCClient:
jwk_obj,
algorithms=["RS256"], # Adjust if your algorithm is different
audience=self.client_id,
issuer=self.discovery_document['issuer'],
issuer=self.discovery_document["issuer"],
)
return decoded_token
except jwt.JWTError as e:
print(f"JWT Verification failed: {e}")
return None
except Exception as e:
print(f"Unexpected error: {e}")
return None
async def complete_token_flow(self, base_uri, code, state):
if state not in self.flows:
async def async_complete_token_flow(
self, redirect_uri: str, code: str, state: str
) -> dict[str, str | dict]:
"""Completes the OIDC token flow to obtain a user's details."""
try:
if state not in self.flows:
raise OIDCStateInvalid
flow = self.flows[state]
code_verifier = flow["code_verifier"]
if self.discovery_document is None:
self.discovery_document = await self._fetch_discovery_document()
token_endpoint = self.discovery_document["token_endpoint"]
# Construct the params
query_params = {
"grant_type": "authorization_code",
"client_id": self.client_id,
"code": code,
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
}
token_response = await self._make_token_request(
token_endpoint, query_params
)
id_token = token_response.get("id_token")
# Parse the id token to obtain the relevant details
id_token = await self._parse_id_token(id_token)
# Verify nonce
if id_token.get("nonce") != flow["nonce"]:
_LOGGER.warning("Nonce mismatch!")
return None
return {
"name": id_token.get("name"),
"username": id_token.get("preferred_username"),
"groups": id_token.get("groups"),
}
except OIDCClientException as e:
_LOGGER.warning("Error completing token flow: %s", e)
return None
flow = self.flows[state]
code_verifier = flow['code_verifier']
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
if not self.discovery_document:
return None
token_endpoint = self.discovery_document['token_endpoint']
# Construct the params
query_params = {
'grant_type': 'authorization_code',
'client_id': self.client_id,
'code': code,
'redirect_uri': base_uri + '/auth/oidc/callback',
'code_verifier': code_verifier,
}
_LOGGER.debug(f"Token request params: {query_params}")
token_response = await self._make_token_request(token_endpoint, query_params)
if not token_response:
return None
access_token = token_response.get('access_token')
id_token = token_response.get('id_token')
_LOGGER.debug(f"Access Token: {access_token}")
_LOGGER.debug(f"ID Token: {id_token}")
# Parse the id token to obtain the relevant details
id_token = await self._parse_id_token(id_token)
# Verify nonce
if id_token.get('nonce') != flow['nonce']:
_LOGGER.warning(f"Nonce mismatch!")
return None
return {
"name": id_token.get("name"),
"email": id_token.get("email"),
"preferred_username": id_token.get("preferred_username"),
"nickname": id_token.get("nickname"),
"groups": id_token.get("groups"),
}

View File

@@ -1,8 +1,11 @@
"""OIDC Authentication provider.
Allow access to users based on login with an external OpenID Connect Identity Provider (IdP).
"""
import logging
from typing import Dict, Optional
import asyncio
from homeassistant.auth.providers import (
AUTH_PROVIDERS,
AuthProvider,
@@ -11,30 +14,26 @@ from homeassistant.auth.providers import (
Credentials,
UserMeta,
)
from homeassistant.components import http
from homeassistant.exceptions import HomeAssistantError
import voluptuous as vol
from datetime import datetime, timedelta
import random
import string
from homeassistant.helpers.storage import Store
from collections.abc import Mapping
from .stores.code_store import CodeStore
_LOGGER = logging.getLogger(__name__)
class InvalidAuthError(HomeAssistantError):
"""Raised when submitting invalid authentication."""
@AUTH_PROVIDERS.register("oidc")
class OpenIDAuthProvider(AuthProvider):
"""Allow access to users based on login with an external OpenID Connect Identity Provider (IdP)."""
"""Allow access to users based on login with an external
OpenID Connect Identity Provider (IdP)."""
DEFAULT_TITLE = "OpenID Connect (SSO)"
def __init__(self, *args, **kwargs):
"""Initialize the OpenIDAuthProvider."""
super().__init__(*args, **kwargs)
self._user_meta = {}
@property
def type(self) -> str:
return "auth_oidc"
@@ -42,13 +41,62 @@ class OpenIDAuthProvider(AuthProvider):
@property
def support_mfa(self) -> bool:
return False
def __init__(self, *args, **kwargs):
"""Initialize the OpenIDAuthProvider."""
super().__init__(*args, **kwargs)
self._user_meta = {}
self._code_store: CodeStore | None = None
self._init_lock = asyncio.Lock()
async def async_initialize(self) -> None:
"""Initialize the auth provider."""
# Init the code store first
# Use the same technique as the HomeAssistant auth provider for storage
# (/auth/providers/homeassistant.py#L392)
async with self._init_lock:
if self._code_store is not None:
return
store = CodeStore(self.hass)
await store.async_load()
self._code_store = store
self._user_meta = {}
async def async_retrieve_username(self, code: str) -> Optional[str]:
"""Retrieve user from the code, return username and save meta
for later use with this provider instance."""
if self._code_store is None:
await self.async_initialize()
assert self._code_store is not None
user_data = await self._code_store.receive_userinfo_for_code(code)
if user_data is None:
return None
username = user_data["username"]
self._user_meta[username] = user_data
return username
async def async_save_user_info(self, user_info: dict[str, dict | str]) -> str:
"""Save user info and return a code."""
if self._code_store is None:
await self.async_initialize()
assert self._code_store is not None
return await self._code_store.async_generate_code_for_userinfo(user_info)
# ====
# Required functions for Home Assistant Auth Providers
# ====
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login."""
return OpenIdLoginFlow(self)
async def async_get_or_create_credentials(
self, flow_result: Mapping[str, str]
self, flow_result: dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result."""
username = flow_result["username"]
@@ -64,7 +112,7 @@ class OpenIDAuthProvider(AuthProvider):
) -> UserMeta:
"""Return extra user metadata for credentials.
Currently, supports name, group and local_only.
Currently, supports name, is_active, group and local_only.
"""
meta = self._user_meta.get(credentials.data["username"], {})
groups = meta.get("groups", [])
@@ -76,96 +124,70 @@ class OpenIDAuthProvider(AuthProvider):
group=group,
local_only="true",
)
async def save_user_info(self, user_info: dict) -> str:
"""Save user info during login."""
_LOGGER.info("User info to be saved: %s", user_info)
code = self._generate_code()
expiration = datetime.utcnow() + timedelta(minutes=5)
user_data = {
"user_info": user_info,
"code": code,
"expiration": expiration.isoformat()
}
await self._save_to_db(self._get_code_key(code), user_data)
return code
async def async_retrieve_username(self, code: str) -> Optional[dict]:
"""Retrieve user info based on the code."""
user_data = await self._get_from_db(self._get_code_key(code))
await self._wipe_from_db(self._get_code_key(code))
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow():
username = user_data["user_info"]["preferred_username"]
self._user_meta[username] = user_data["user_info"]
return username
return None
def _generate_code(self) -> str:
"""Generate a random six-digit code."""
return ''.join(random.choices(string.digits, k=6))
def _get_code_key(self, code: str) -> str:
return f"provider_oidc_auth_user_{code}"
async def _save_to_db(self, key: str, value: dict) -> None:
"""Save key-value data to the Home Assistant storage."""
store = Store(self.hass, 1, key)
await store.async_save(value)
async def _get_from_db(self, key: str) -> Optional[dict]:
"""Retrieve key-value data from the Home Assistant storage."""
store = Store(self.hass, 1, key)
return await store.async_load()
async def _wipe_from_db(self, key: str) -> None:
"""Delete key-value data from the Home Assistant storage."""
store = Store(self.hass, 1, key)
return await store.async_remove()
class OpenIdLoginFlow(LoginFlow):
"""Handler for the login flow."""
async def _finalize_user(self, code: str) -> AuthFlowResult:
username = await self._auth_provider.async_retrieve_username(code)
if username:
_LOGGER.info("Logged in user: %s", username)
return await self.async_finish(
{
"username": username,
}
)
raise InvalidAuthError
def _show_login_form(
self, errors: Optional[dict[str, str]] = None
) -> AuthFlowResult:
if errors is None:
errors = {}
# Show the login form
# Abuses the MFA form, as it works better for our usecase
# UI suggestions are welcome (make a PR!)
return self.async_show_form(
step_id="mfa",
data_schema=vol.Schema(
{
vol.Required("code"): str,
}
),
errors=errors,
)
async def async_step_init(
self, user_input: dict[str, str] | None = None
) -> AuthFlowResult:
"""Handle the step of the form."""
# Show the login form
# Currently, this form looks bad because the frontend gives no options to make it look better
# We will investigate options to make it look better in the future
return self.async_show_form(
step_id="mfa",
data_schema=vol.Schema(
{
vol.Required("code"): str,
}
),
errors={},
)
# Try to use the user input first
if user_input is not None:
try:
return await self._finalize_user(user_input["code"])
except InvalidAuthError:
return self._show_login_form({"base": "invalid_auth"})
# If not available, check the cookie
req = http.current_request.get()
code_cookie = req.cookies.get("auth_oidc_code")
if code_cookie:
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
try:
return await self._finalize_user(code_cookie)
except InvalidAuthError:
pass
# If none are available, just show the form
return self._show_login_form()
async def async_step_mfa(
self, user_input: dict[str, str] | None = None
) -> AuthFlowResult:
"""Handle the result of the form."""
if user_input is None:
return self.async_abort(reason="no_code_given")
# Log
_LOGGER.info("User input %s", user_input)
_LOGGER.info("Code %s was entered", user_input["code"])
username = await self._auth_provider.async_retrieve_username(user_input["code"])
if username:
_LOGGER.info("Logged in user: %s", username)
return await self.async_finish({
"username": username,
})
return self.async_abort(reason="invalid_code")
# This is a dummy step function just to use the nicer MFA UI instead
return await self.async_step_init(user_input)

View File

@@ -0,0 +1,80 @@
"""Code Store, stores the codes and their associated authenticated user temporarily."""
import random
import string
from datetime import datetime, timedelta
from typing import cast, Optional
from homeassistant.helpers.storage import Store
from homeassistant.core import HomeAssistant
STORAGE_VERSION = 1
STORAGE_KEY = "auth_provider.auth_oidc.codes"
class CodeStore:
"""Holds the codes and associated data"""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the user data store."""
self.hass = hass
self._store = Store[dict[str, dict[str, dict | str]]](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
)
self._data: dict[str, dict[str, dict | str]] | None = None
async def async_load(self) -> None:
"""Load stored data."""
if (data := await self._store.async_load()) is None:
data = cast(dict[str, dict[str, dict | str]], {})
self._data = data
async def async_save(self) -> None:
"""Save data."""
if self._data is not None:
await self._store.async_save(self._data)
def _generate_code(self) -> str:
"""Generate a random six-digit code."""
return "".join(random.choices(string.digits, k=6))
async def async_generate_code_for_userinfo(
self, user_info: dict[str, dict | str]
) -> str:
"""Generates a one time code and adds it to the database for 5 minutes."""
if self._data is None:
raise RuntimeError("Data not loaded")
code = self._generate_code()
expiration = datetime.utcnow() + timedelta(minutes=5)
self._data[code] = {
"user_info": user_info,
"code": code,
"expiration": expiration.isoformat(),
}
await self.async_save()
return code
async def receive_userinfo_for_code(
self, code: str
) -> Optional[dict[str, dict | str]]:
"""Retrieve user info based on the code."""
if self._data is None:
raise RuntimeError("Data not loaded")
user_data = self._data.get(code)
if user_data:
# We should now wipe it from the database, as it's one time use code
self._data.pop(code)
await self.async_save()
if (
user_data
and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow()
):
return user_data["user_info"]
return None