diff --git a/custom_components/auth_oidc/__init__.py b/custom_components/auth_oidc/__init__.py index af14a9e..b8638b3 100644 --- a/custom_components/auth_oidc/__init__.py +++ b/custom_components/auth_oidc/__init__.py @@ -27,7 +27,6 @@ from .config import ( ROLES, NETWORK, FEATURES_INCLUDE_GROUPS_SCOPE, - FEATURES_DISABLE_FRONTEND_INJECTION, FEATURES_FORCE_HTTPS, REQUIRED_SCOPES, ) @@ -40,6 +39,7 @@ from .endpoints import ( OIDCFinishView, OIDCCallbackView, OIDCInjectedAuthPage, + OIDCDeviceSSE, ) from .tools.oidc_client import OIDCClient from .provider import OpenIDAuthProvider @@ -96,6 +96,10 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam provider = OpenIDAuthProvider(hass, hass.auth._store, my_config) providers[(provider.type, provider.id)] = provider + + # Get current provider count + has_other_auth_providers = len(hass.auth._providers) > 0 + providers.update(hass.auth._providers) hass.auth._providers = providers # pylint: enable=protected-access @@ -137,33 +141,22 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam ) # Register the views - is_frontend_injection_enabled = ( - features_config.get(FEATURES_DISABLE_FRONTEND_INJECTION, False) is False - ) name = display_name name = re.sub(r"[^A-Za-z0-9 _\-\(\)]", "", name) force_https = features_config.get(FEATURES_FORCE_HTTPS, False) hass.http.register_view( - OIDCWelcomeView( - name, - # Welcome view is not enabled if frontend injection is enabled - not is_frontend_injection_enabled, - force_https, - ) + OIDCWelcomeView(provider, name, force_https, has_other_auth_providers) ) - hass.http.register_view(OIDCRedirectView(oidc_client, force_https)) + hass.http.register_view(OIDCDeviceSSE(provider)) + hass.http.register_view(OIDCRedirectView(oidc_client, provider, force_https)) hass.http.register_view(OIDCCallbackView(oidc_client, provider, force_https)) - hass.http.register_view(OIDCFinishView()) + hass.http.register_view(OIDCFinishView(provider)) _LOGGER.info("Registered OIDC views") - # Inject OIDC code into the frontend for /auth/authorize if the user has the - # frontend injection feature enabled - if is_frontend_injection_enabled: - await OIDCInjectedAuthPage.inject(hass, name) - else: - _LOGGER.info("OIDC frontend changes are disabled, skipping injection") + # Inject OIDC code into the frontend for /auth/authorize for automatic redirect + await OIDCInjectedAuthPage.inject(hass) return True diff --git a/custom_components/auth_oidc/config/const.py b/custom_components/auth_oidc/config/const.py index 8538262..24b0297 100644 --- a/custom_components/auth_oidc/config/const.py +++ b/custom_components/auth_oidc/config/const.py @@ -28,7 +28,6 @@ FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking" FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation" FEATURES_DISABLE_PKCE = "disable_rfc7636" FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope" -FEATURES_DISABLE_FRONTEND_INJECTION = "disable_frontend_changes" FEATURES_FORCE_HTTPS = "force_https" CLAIMS = "claims" CLAIMS_DISPLAY_NAME = "display_name" diff --git a/custom_components/auth_oidc/config/schema.py b/custom_components/auth_oidc/config/schema.py index 2167503..e7b9b9c 100644 --- a/custom_components/auth_oidc/config/schema.py +++ b/custom_components/auth_oidc/config/schema.py @@ -14,7 +14,6 @@ from .const import ( FEATURES_AUTOMATIC_PERSON_CREATION, FEATURES_DISABLE_PKCE, FEATURES_INCLUDE_GROUPS_SCOPE, - FEATURES_DISABLE_FRONTEND_INJECTION, FEATURES_FORCE_HTTPS, CLAIMS, CLAIMS_DISPLAY_NAME, @@ -72,10 +71,6 @@ CONFIG_SCHEMA = vol.Schema( vol.Optional( FEATURES_INCLUDE_GROUPS_SCOPE, default=True ): vol.Coerce(bool), - # Disable frontend injection of OIDC login button - vol.Optional( - FEATURES_DISABLE_FRONTEND_INJECTION, default=False - ): vol.Coerce(bool), # Force HTTPS on all generated URLs (like redirect_uri) vol.Optional(FEATURES_FORCE_HTTPS, default=False): vol.Coerce( bool diff --git a/custom_components/auth_oidc/config/ui_flow.py b/custom_components/auth_oidc/config/ui_flow.py index 847b26c..6d21009 100644 --- a/custom_components/auth_oidc/config/ui_flow.py +++ b/custom_components/auth_oidc/config/ui_flow.py @@ -621,21 +621,18 @@ class OIDCConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): errors["client_id"] = "invalid_client_id" return errors, None - # Determine confidentiality by presence of client secret - client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip() - # If secret is empty, keep the existing one (if any) - if not client_secret: - client_secret = entry.data.get("client_secret") - # Build updated data data_updates = {"client_id": client_id} - if client_secret: - data_updates["client_secret"] = client_secret - elif "client_secret" in entry.data and not client_secret: - # Remove client secret if switching from confidential to public - data_updates = {**entry.data, **data_updates} - data_updates.pop("client_secret", None) + # The optional secret field is submitted explicitly when the form is used. + # An empty value means the user wants to keep the existing secret. + if CONF_CLIENT_SECRET in user_input: + client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip() + + if client_secret: + data_updates["client_secret"] = client_secret + elif "client_secret" in entry.data: + data_updates["client_secret"] = entry.data["client_secret"] return errors, data_updates diff --git a/custom_components/auth_oidc/endpoints/__init__.py b/custom_components/auth_oidc/endpoints/__init__.py index 3808b68..613fba0 100644 --- a/custom_components/auth_oidc/endpoints/__init__.py +++ b/custom_components/auth_oidc/endpoints/__init__.py @@ -5,3 +5,4 @@ from .finish import OIDCFinishView as OIDCFinishView from .injected_auth_page import OIDCInjectedAuthPage as OIDCInjectedAuthPage from .redirect import OIDCRedirectView as OIDCRedirectView from .welcome import OIDCWelcomeView as OIDCWelcomeView +from .device_sse import OIDCDeviceSSE as OIDCDeviceSSE diff --git a/custom_components/auth_oidc/endpoints/callback.py b/custom_components/auth_oidc/endpoints/callback.py index 178ee1e..19aea11 100644 --- a/custom_components/auth_oidc/endpoints/callback.py +++ b/custom_components/auth_oidc/endpoints/callback.py @@ -4,7 +4,7 @@ from homeassistant.components.http import HomeAssistantView from aiohttp import web from ..tools.oidc_client import OIDCClient from ..provider import OpenIDAuthProvider -from ..tools.helpers import get_url, get_view +from ..tools.helpers import error_response, get_url, get_valid_state_id PATH = "/auth/oidc/callback" @@ -29,42 +29,49 @@ class OIDCCallbackView(HomeAssistantView): async def get(self, request: web.Request) -> web.Response: """Receive response.""" + # Get cookie to get the state_id + state_id = await get_valid_state_id(request, self.oidc_provider) + if not state_id: + return await error_response("Missing state cookie, please restart login.") + + # Get the OIDC query parameters params = request.rel_url.query code = params.get("code") state = params.get("state") if not (code and state): - view_html = await get_view( - "error", - { - "error": "Missing code or state parameter.", - }, - ) - return web.Response(text=view_html, content_type="text/html") + return await error_response("Missing code or state parameter.") + # Check if the states match + if state != state_id: + return await error_response( + "State parameter does not match, possible CSRF attack." + ) + + # Complete the OIDC flow to get user details redirect_uri = get_url("/auth/oidc/callback", self.force_https) user_details = await self.oidc_client.async_complete_token_flow( redirect_uri, code, state ) if user_details is None: - view_html = await get_view( - "error", - { - "error": "Failed to get user details, " - + "see Home Assistant logs for more information.", - }, + return await error_response( + "Failed to get user details, see Home Assistant logs for more information.", + status=500, ) - return web.Response(text=view_html, content_type="text/html") if user_details.get("role") == "invalid": - view_html = await get_view( - "error", - { - "error": "User is not in the correct group to access Home Assistant, " - + "contact your administrator!", - }, + return await error_response( + "User is not in the correct group to access Home Assistant, " + + "contact your administrator!", + status=403, ) - return web.Response(text=view_html, content_type="text/html") - code = await self.oidc_provider.async_save_user_info(user_details) - raise web.HTTPFound(get_url("/auth/oidc/finish?code=" + code, self.force_https)) + # Finalize on the state + success = await self.oidc_provider.async_save_user_info(state_id, user_details) + if not success: + return await error_response( + "Failed to save user information, session probably expired. Please sign in again.", + status=500, + ) + + raise web.HTTPFound(get_url("/auth/oidc/finish", self.force_https)) diff --git a/custom_components/auth_oidc/endpoints/device_sse.py b/custom_components/auth_oidc/endpoints/device_sse.py new file mode 100644 index 0000000..bc7aad6 --- /dev/null +++ b/custom_components/auth_oidc/endpoints/device_sse.py @@ -0,0 +1,70 @@ +"""SSE handler for OIDC device authentication.""" + +import asyncio +from aiohttp import web +from homeassistant.components.http import HomeAssistantView +from ..provider import OpenIDAuthProvider +from ..tools.helpers import get_valid_state_id + +PATH = "/auth/oidc/device-sse" + + +class OIDCDeviceSSE(HomeAssistantView): + """OIDC Plugin SSE Handler.""" + + requires_auth = False + url = PATH + name = "auth:oidc:device-sse" + + def __init__(self, oidc_provider: OpenIDAuthProvider) -> None: + self.oidc_provider = oidc_provider + + async def get(self, req: web.Request) -> web.Response: + """Check for mobile sign-in completion with short server-side polling.""" + state_id = await get_valid_state_id(req, self.oidc_provider) + if not state_id: + raise web.HTTPBadRequest(text="Missing session cookie") + + timeout_seconds = 300 + started_at = asyncio.get_running_loop().time() + + response = web.StreamResponse( + status=200, + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + await response.prepare(req) + + try: + while True: + if ( + await self.oidc_provider.async_get_redirect_uri_for_state(state_id) + is None + ): + await response.write(b"event: expired\ndata: false\n\n") + break + + ready = await self.oidc_provider.async_is_state_ready(state_id) + if ready: + await response.write(b"event: ready\ndata: true\n\n") + break + + if asyncio.get_running_loop().time() - started_at >= timeout_seconds: + await response.write(b"event: timeout\ndata: false\n\n") + break + + await response.write(b"event: waiting\ndata: false\n\n") + await asyncio.sleep(0.5) + except (ConnectionResetError, RuntimeError): + # Client disconnected while listening for state changes. + pass + finally: + try: + await response.write_eof() + except ConnectionResetError: + pass + + return response diff --git a/custom_components/auth_oidc/endpoints/finish.py b/custom_components/auth_oidc/endpoints/finish.py index 6b60371..8421215 100644 --- a/custom_components/auth_oidc/endpoints/finish.py +++ b/custom_components/auth_oidc/endpoints/finish.py @@ -2,7 +2,12 @@ from homeassistant.components.http import HomeAssistantView from aiohttp import web -from ..tools.helpers import get_view +from ..provider import OpenIDAuthProvider +from ..tools.helpers import ( + error_response, + get_valid_state_id, + template_response, +) PATH = "/auth/oidc/finish" @@ -14,41 +19,62 @@ class OIDCFinishView(HomeAssistantView): url = PATH name = "auth:oidc:finish" + def __init__( + self, + oidc_provider: OpenIDAuthProvider, + ) -> None: + self.oidc_provider = oidc_provider + async def get(self, request: web.Request) -> web.Response: - """Show the finish screen to allow the user to view their code.""" + """Show the finish screen to pick between login & device code.""" + # Get cookie to get the state_id + state_id = await get_valid_state_id(request, self.oidc_provider) + if not state_id: + return await error_response("Missing state cookie, please restart login.") - code = request.query.get("code") - - if not code: - view_html = await get_view( - "error", - {"error": "Missing code to show the finish screen."}, - ) - return web.Response(text=view_html, content_type="text/html") - - view_html = await get_view("finish", {"code": code}) - return web.Response(text=view_html, content_type="text/html") + return await template_response("finish", {}) async def post(self, request: web.Request) -> web.Response: """Receive response.""" - # Get code from the message body - data = await request.post() - code = data.get("code") + # Get cookie to get the state_id + state_id = await get_valid_state_id(request, self.oidc_provider) + if not state_id: + return await error_response("Missing state cookie, please restart login.") - if not code: - return web.Response(text="No code received", status=500) - - # Return redirect to the main page for sign in with a cookie - raise web.HTTPFound( - location="/?storeToken=true", - headers={ - # Set a cookie to enable autologin on only the specific path used - # for the POST request, with all strict parameters set - # This cookie should not be read by any Javascript or any other paths. - # It can be really short lifetime as we redirect immediately (5 seconds) - "set-cookie": "auth_oidc_code=" - + code - + "; Path=/auth/login_flow; SameSite=Strict; HttpOnly; Max-Age=5", - }, + # Get redirect_uri from the state + redirect_uri = await self.oidc_provider.async_get_redirect_uri_for_state( + state_id ) + + if not redirect_uri: + return await error_response("Invalid state, please restart login.") + + # Get the message body + data = await request.post() + device_code = data.get("device_code") + + # We are trying sign-in on this browser + if not device_code: + # Add to the URL correctly (also handle case where it's just the root) + separator = "?" + if "?" in redirect_uri: + separator = "&" + + # Redirect to this new URL for login + new_url = ( + redirect_uri + separator + "storeToken=true&skip_oidc_redirect=true" + ) + raise web.HTTPFound(location=new_url) + + # Check if we can link this device + linked = await self.oidc_provider.async_link_state_to_code( + state_id, device_code + ) + + if not linked: + return await error_response( + "Failed to link state to device code, please restart login." + ) + + return await template_response("device_success", {}) diff --git a/custom_components/auth_oidc/endpoints/injected_auth_page.py b/custom_components/auth_oidc/endpoints/injected_auth_page.py index 74d0945..ba5c5c4 100644 --- a/custom_components/auth_oidc/endpoints/injected_auth_page.py +++ b/custom_components/auth_oidc/endpoints/injected_auth_page.py @@ -1,6 +1,5 @@ """Injected authorization page, replacing the original""" -import json import logging from functools import partial from homeassistant.components.http import HomeAssistantView, StaticPathConfig @@ -19,7 +18,7 @@ async def read_file(path: str) -> str: return await f.read() -async def frontend_injection(hass: HomeAssistant, sso_name: str) -> None: +async def frontend_injection(hass: HomeAssistant) -> None: """Inject new frontend code into /auth/authorize.""" router = hass.http.app.router frontend_path = None @@ -62,11 +61,8 @@ async def frontend_injection(hass: HomeAssistant, sso_name: str) -> None: frontend_code = await read_file(frontend_path) # Inject JS and register that route - injection_js = "" - sso_name_js = f"" - frontend_code = frontend_code.replace( - "", f"{injection_js}{sso_name_js}" - ) + injection_js = "" + frontend_code = frontend_code.replace("", f"{injection_js}") await hass.http.async_register_static_paths( [ @@ -100,10 +96,10 @@ class OIDCInjectedAuthPage(HomeAssistantView): self.html = html @staticmethod - async def inject(hass: HomeAssistant, sso_name: str) -> None: + async def inject(hass: HomeAssistant) -> None: """Inject the OIDC auth page into the frontend.""" try: - await frontend_injection(hass, sso_name) + await frontend_injection(hass) except Exception as e: # pylint: disable=broad-except _LOGGER.error("Failed to inject OIDC auth page: %s", e) diff --git a/custom_components/auth_oidc/endpoints/redirect.py b/custom_components/auth_oidc/endpoints/redirect.py index 5a0c1df..8ed5e4c 100644 --- a/custom_components/auth_oidc/endpoints/redirect.py +++ b/custom_components/auth_oidc/endpoints/redirect.py @@ -1,11 +1,13 @@ """Redirect route to redirect the user to the external OIDC server, can either be linked to directly or accessed through the welcome page.""" +from urllib.parse import quote from aiohttp import web from homeassistant.components.http import HomeAssistantView +from ..provider import OpenIDAuthProvider from ..tools.oidc_client import OIDCClient -from ..tools.helpers import get_url, get_view +from ..tools.helpers import error_response, get_url, get_valid_state_id, get_view PATH = "/auth/oidc/redirect" @@ -17,28 +19,44 @@ class OIDCRedirectView(HomeAssistantView): url = PATH name = "auth:oidc:redirect" - def __init__(self, oidc_client: OIDCClient, force_https: bool) -> None: + def __init__( + self, + oidc_client: OIDCClient, + oidc_provider: OpenIDAuthProvider, + force_https: bool, + ) -> None: self.oidc_client = oidc_client + self.oidc_provider = oidc_provider self.force_https = force_https - async def get(self, _: web.Request) -> web.Response: + async def get(self, req: web.Request) -> web.Response: """Receive response.""" + # Get cookie to get the state_id + state_id = await get_valid_state_id(req, self.oidc_provider) + + if not state_id: + # Direct access to the redirect endpoint, go to welcome page instead + welcome_url = get_url("/auth/oidc/welcome", self.force_https) + raise web.HTTPFound(welcome_url) + try: redirect_uri = get_url("/auth/oidc/callback", self.force_https) - auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri) + auth_url = await self.oidc_client.async_get_authorization_url( + redirect_uri, state_id + ) if auth_url: - raise web.HTTPFound(auth_url) + view_html = await get_view("redirect", {"url": quote(auth_url)}) + return web.Response(text=view_html, content_type="text/html") except RuntimeError: pass - view_html = await get_view( - "error", - {"error": "Integration is misconfigured, discovery could not be obtained."}, + return await error_response( + "Integration is misconfigured, discovery could not be obtained.", + status=500, ) - return web.Response(text=view_html, content_type="text/html") - async def post(self, request: web.Request) -> web.Response: + async def post(self, req: web.Request) -> web.Response: """POST""" - return await self.get(request) + return await self.get(req) diff --git a/custom_components/auth_oidc/endpoints/welcome.py b/custom_components/auth_oidc/endpoints/welcome.py index 7088bbb..973fcb5 100644 --- a/custom_components/auth_oidc/endpoints/welcome.py +++ b/custom_components/auth_oidc/endpoints/welcome.py @@ -1,8 +1,12 @@ """Welcome route to show the user the OIDC login button and give instructions.""" +import base64 +import binascii +from urllib.parse import urlparse, parse_qs, unquote from aiohttp import web from homeassistant.components.http import HomeAssistantView -from ..tools.helpers import get_url, get_view +from ..tools.helpers import error_response, get_url, template_response +from ..provider import OpenIDAuthProvider PATH = "/auth/oidc/welcome" @@ -14,16 +18,90 @@ class OIDCWelcomeView(HomeAssistantView): url = PATH name = "auth:oidc:welcome" - def __init__(self, name: str, is_enabled: bool, force_https: bool) -> None: + def __init__( + self, + oidc_provider: OpenIDAuthProvider, + name: str, + force_https: bool, + has_other_auth_providers: bool, + ) -> None: + self.oidc_provider = oidc_provider self.name = name - self.is_enabled = is_enabled self.force_https = force_https + self.has_other_auth_providers = has_other_auth_providers - async def get(self, _: web.Request) -> web.Response: + def determine_if_mobile(self, redirect_uri: str) -> bool: + """Determine if the client is a mobile client based on the redirect_uri.""" + oauth2_url = urlparse(redirect_uri) + client_id = parse_qs(oauth2_url.query).get("client_id") + + # If the client_id starts with https://home-assistant.io/ we assume it's a mobile client + return bool(client_id and client_id[0].startswith("https://home-assistant.io/")) + + async def get(self, req: web.Request) -> web.Response: """Receive response.""" - if not self.is_enabled: - raise web.HTTPTemporaryRedirect(get_url("/", self.force_https)) + # Get the query parameter with the redirect_uri + redirect_uri = req.query.get("redirect_uri") - view_html = await get_view("welcome", {"name": self.name}) - return web.Response(text=view_html, content_type="text/html") + # If set, determine if this is a mobile client based on the redirect_uri, + # otherwise assume it's not mobile + if redirect_uri: + try: + # decodeURIComponent(btoa(...)) -> unquote first, then base64 decode + redirect_uri = base64.b64decode( + unquote(redirect_uri), validate=True + ).decode("utf-8") + is_mobile = self.determine_if_mobile(redirect_uri) + except (binascii.Error, UnicodeDecodeError, ValueError): + return await error_response( + "Invalid redirect_uri, please restart login." + ) + else: + # Backwards compatibility with older versions that directly go to /auth/oidc/welcome + # If not set, redirect back to the main page and assume that this is a web client + redirect_uri = get_url("/", self.force_https) + is_mobile = False + + # Create OIDC state with the redirect_uri so we can use it later in the flow + state_id = await self.oidc_provider.async_create_state(redirect_uri) + cookie_header = self.oidc_provider.get_cookie_header( + state_id, secure=self.force_https or req.url.scheme == "https" + ) + + # If this is the only provider and we are on desktop, + # automatically go through the OIDC login + if not is_mobile and not self.has_other_auth_providers: + raise web.HTTPFound( + location=get_url("/auth/oidc/redirect", self.force_https), + headers=cookie_header, + ) + + # Otherwise display the screen with either mobile sign in or the buttons + # First generate code if mobile + code = None + if is_mobile: + # Create a code to login + code = await self.oidc_provider.async_generate_device_code(state_id) + if not code: + return await error_response( + "Failed to generate device code, please restart login.", + status=500, + ) + + # And add the other link if we have other auth providers + other_link = None + if self.has_other_auth_providers: + other_link = get_url("/?skip_oidc_redirect=true", self.force_https) + + # And display + response = await template_response( + "welcome", + { + "name": self.name, + "other_link": other_link, + "code": code, + }, + ) + response.headers.update(cookie_header) + return response diff --git a/custom_components/auth_oidc/provider.py b/custom_components/auth_oidc/provider.py index 02b864f..590a97f 100644 --- a/custom_components/auth_oidc/provider.py +++ b/custom_components/auth_oidc/provider.py @@ -22,7 +22,6 @@ from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE from homeassistant.core import HomeAssistant, callback from homeassistant.components import http, person from homeassistant.exceptions import HomeAssistantError -import voluptuous as vol from .config.const import ( FEATURES, @@ -30,13 +29,14 @@ from .config.const import ( FEATURES_AUTOMATIC_PERSON_CREATION, DEFAULT_TITLE, ) -from .stores.code_store import CodeStore +from .stores.state_store import StateStore from .tools.types import UserDetails _LOGGER = logging.getLogger(__name__) PROVIDER_TYPE = "auth_oidc" HASS_PROVIDER_TYPE = "homeassistant" +COOKIE_NAME = "auth_oidc_state" class InvalidAuthError(HomeAssistantError): @@ -68,7 +68,7 @@ class OpenIDAuthProvider(AuthProvider): ) self._user_meta: dict[UserDetails] = {} - self._code_store: CodeStore | None = None + self._state_store: StateStore | None = None self._init_lock = asyncio.Lock() features = config.get( @@ -89,29 +89,120 @@ class OpenIDAuthProvider(AuthProvider): async def async_initialize(self) -> None: """Initialize the auth provider.""" - # Init the code store first + # Init the store first # Use the same technique as the HomeAssistant auth provider for storage # (/auth/providers/homeassistant.py#L392) async with self._init_lock: - if self._code_store is not None: + if self._state_store is not None: return - store = CodeStore(self.hass) + store = StateStore(self.hass) await store.async_load() - self._code_store = store + self._state_store = store self._user_meta = {} # Listen for user creation events self.hass.bus.async_listen(EVENT_USER_ADDED, self.async_user_created) - async def async_get_subject(self, code: str) -> Optional[str]: - """Retrieve user from the code, return subject and save meta - for later use with this provider instance.""" - if self._code_store is None: - await self.async_initialize() - assert self._code_store is not None + def _resolve_ip(self, ip: str | None = None) -> str | None: + """Resolve client IP from explicit input or current request context.""" + if ip: + return ip - user_data = await self._code_store.receive_userinfo_for_code(code) + req = http.current_request.get() + if req and req.remote: + return req.remote + + return None + + async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str: + """Create a new OIDC state and return the state id.""" + if self._state_store is None: + await self.async_initialize() + assert self._state_store is not None + + return await self._state_store.async_create_state_from_url( + redirect_uri, self._resolve_ip(ip) + ) + + async def async_generate_device_code(self, state_id: str) -> Optional[str]: + """Generate a device code for the state, used for device login.""" + if self._state_store is None: + await self.async_initialize() + assert self._state_store is not None + + return await self._state_store.async_generate_code_for_state(state_id) + + async def async_save_user_info( + self, state_id: str, user_info: dict[str, dict | str] + ) -> bool: + """Save user info to the given state.""" + if self._state_store is None: + await self.async_initialize() + assert self._state_store is not None + + return await self._state_store.async_add_userinfo_to_state(state_id, user_info) + + async def async_get_redirect_uri_for_state( + self, state_id: str, ip: str | None = None + ) -> Optional[str]: + """Get the redirect_uri for the given state.""" + if self._state_store is None: + await self.async_initialize() + assert self._state_store is not None + + return await self._state_store.async_get_redirect_uri_for_state( + state_id, self._resolve_ip(ip) + ) + + async def async_is_state_valid(self, state_id: str, ip: str | None = None) -> bool: + """Check if a state exists, belongs to this IP, and is not expired.""" + if self._state_store is None: + await self.async_initialize() + assert self._state_store is not None + + return ( + await self._state_store.async_get_redirect_uri_for_state( + state_id, self._resolve_ip(ip) + ) + is not None + ) + + async def async_is_state_ready(self, state_id: str, ip: str | None = None) -> bool: + """Check if the state has received the user info from the OIDC callback.""" + if self._state_store is None: + await self.async_initialize() + assert self._state_store is not None + + return await self._state_store.async_is_state_ready( + state_id, self._resolve_ip(ip) + ) + + async def async_link_state_to_code( + self, state_id: str, code: str, ip: str | None = None + ) -> bool: + """Link two states together by copying the user info from one to the other.""" + if self._state_store is None: + await self.async_initialize() + assert self._state_store is not None + + return await self._state_store.async_link_state_to_code( + state_id, code, self._resolve_ip(ip) + ) + + async def async_get_subject( + self, state_id: str, ip: str | None = None + ) -> Optional[str]: + """Retrieve user from the state_id, return subject and save meta + for later use with this provider instance.""" + if self._state_store is None: + await self.async_initialize() + assert self._state_store is not None + + # This also deletes the state as we are using it for sign-in + user_data = await self._state_store.async_receive_userinfo_for_state( + state_id, self._resolve_ip(ip) + ) if user_data is None: return None @@ -119,14 +210,6 @@ class OpenIDAuthProvider(AuthProvider): self._user_meta[sub] = user_data return sub - async def async_save_user_info(self, user_info: dict[str, dict | str]) -> str: - """Save user info and return a code.""" - if self._code_store is None: - await self.async_initialize() - assert self._code_store is not None - - return await self._code_store.async_generate_code_for_userinfo(user_info) - async def _async_find_user_by_username(self, username: str) -> Optional[User]: """Find a user by username.""" users = await self.store.async_get_users() @@ -145,6 +228,18 @@ class OpenIDAuthProvider(AuthProvider): return None + def get_cookie_header(self, state_id: str, secure: bool = False): + """Get the cookie header to set the state_id cookie.""" + secure_flag = "; Secure" if secure else "" + return { + # Set a cookie for the other pages to know the state_id + # Keep cookie lifetime aligned with state lifetime in storage (5 minutes). + "set-cookie": f"{COOKIE_NAME}=" + + state_id + + "; Path=/auth/; SameSite=Strict; HttpOnly; Max-Age=300" + + secure_flag, + } + # ==== # Handler for user created and related functions (person creation) # ==== @@ -271,7 +366,7 @@ class OpenIDAuthProvider(AuthProvider): class OpenIdLoginFlow(LoginFlow): """Handler for the login flow.""" - async def _finalize_user(self, code: str) -> AuthFlowResult: + async def _finalize_user(self, state_id: str) -> AuthFlowResult: # Verify a dummy hash to make it last a bit longer # as security measure (limits the amount of attempts you have in 5 min) # Similar to what the HomeAssistant auth provider does @@ -280,7 +375,7 @@ class OpenIdLoginFlow(LoginFlow): # Actually look up the auth provider after, # this doesn't take a lot of time (regardless of it's in there or not) - sub = await self._auth_provider.async_get_subject(code) + sub = await self._auth_provider.async_get_subject(state_id) if sub: return await self.async_finish( { @@ -290,54 +385,23 @@ class OpenIdLoginFlow(LoginFlow): raise InvalidAuthError - def _show_login_form( - self, errors: Optional[dict[str, str]] = None - ) -> AuthFlowResult: - if errors is None: - errors = {} - - # Show the login form - # Abuses the MFA form, as it works better for our usecase - # UI suggestions are welcome (make a PR!) - return self.async_show_form( - step_id="mfa", - data_schema=vol.Schema( - { - vol.Required("code"): str, - } - ), - errors=errors, - ) - async def async_step_init( self, user_input: dict[str, str] | None = None ) -> AuthFlowResult: """Handle the step of the form.""" - # Try to use the user input first - if user_input is not None and "code" in user_input: - try: - return await self._finalize_user(user_input["code"]) - except InvalidAuthError: - return self._show_login_form({"base": "invalid_auth"}) - - # If not available, check the cookie + # Check if the cookie is present to login req = http.current_request.get() if req and req.cookies: - code_cookie = req.cookies.get("auth_oidc_code") + state_cookie = req.cookies.get(COOKIE_NAME) - if code_cookie: - _LOGGER.debug("Code cookie found on login: %s", code_cookie) + if state_cookie: + _LOGGER.debug("State cookie found on login: %s", state_cookie) try: - return await self._finalize_user(code_cookie) + return await self._finalize_user(state_cookie) except InvalidAuthError: pass - # If none are available, just show the form - return self._show_login_form() - - async def async_step_mfa( - self, user_input: dict[str, str] | None = None - ) -> AuthFlowResult: - # This is a dummy step function just to use the nicer MFA UI instead - return await self.async_step_init(user_input) + # If no cookie is found, abort. + # User should either be redirected or start manually on the welcome + return self.async_abort(reason="no_oidc_cookie_found") diff --git a/custom_components/auth_oidc/static/injection.js b/custom_components/auth_oidc/static/injection.js index efd0375..d2e299b 100644 --- a/custom_components/auth_oidc/static/injection.js +++ b/custom_components/auth_oidc/static/injection.js @@ -1,215 +1,82 @@ -function safeSetTextContent(element, value) { - if (!element) return - var textNode = Array.from(element.childNodes).find(node => node.nodeType === Node.TEXT_NODE && node.textContent.trim().length > 0) - if (!textNode || textNode.textContent === value) return - textNode.textContent = value +/** + * OIDC Frontend Redirect injection script + * This script is injected because the 'hass-oidc-auth' custom component is active. + */ + +function attempt_oidc_redirect() { + // Get URL parameters + const urlParams = new URLSearchParams(window.location.search); + + // Check if we have skip_oidc_redirect directly here + if (urlParams.get('skip_oidc_redirect') === 'true') { + // No console log because this is intended behavior + return; + } + + const originalUrl = urlParams.get('redirect_uri'); + if (!originalUrl) { + console.warn('[OIDC] No OAuth2 redirect_uri parameter found in the URL. Frontend redirect cancelled.'); + return; + } + + try { + // Parse the redirect URI + const redirectUrl = new URL(originalUrl); + + // Check if redirect URI has a query parameter to stop OIDC injection + if (redirectUrl.searchParams.get('skip_oidc_redirect') === 'true') { + // No console log because this is intended behavior + return; + } + } catch (error) { + console.error('[OIDC] Invalid redirect_uri parameter:', error); + } + + window.stop(); // Stop loading the current page before redirecting + + // Redirect to the OIDC auth URL + const base64encodeUrl = btoa(window.location.href); + const oidcAuthUrl = '/auth/oidc/welcome?redirect_uri=' + encodeURIComponent(base64encodeUrl); + window.location.href = oidcAuthUrl; } -let firstFocus = true -let showCodeOverride = null +function click_alternative_provider_instead() { + setTimeout(() => { + // Find ha-auth-flow + const authFlowElement = document.querySelector('ha-auth-flow'); -function isMobile() { - const clientId = new URL(location.href).searchParams.get("client_id") - return clientId && clientId.startsWith("https://home-assistant.io/iOS") || clientId.startsWith("https://home-assistant.io/android") + if (!authFlowElement) { + console.warn("[OIDC] ha-auth-flow element not found. Not automatically selecting HA provider."); + return; + } + + // Check if the text "Login aborted" is present on the page + if (!authFlowElement.innerText.includes('Login aborted')) { + console.warn("[OIDC] 'Login aborted' text not found. Not automatically selecting HA provider."); + return; + } + + // Find the ha-pick-auth-provider element + const authProviderElement = document.querySelector('ha-pick-auth-provider'); + + if (!authProviderElement) { + console.warn("[OIDC] ha-pick-auth-provider not found. Not automatically selecting HA provider."); + return; + } + + // Click the first ha-list-item element inside the ha-pick-auth-provider + const firstListItem = authProviderElement.shadowRoot?.querySelector('ha-list-item'); + if (!firstListItem) { + console.warn("[OIDC] No ha-list-item found inside ha-pick-auth-provider. Not automatically selecting HA provider."); + return; + } + + firstListItem.click(); + }, 500); } -function showCode() { - if (showCodeOverride !== null) return showCodeOverride - return isMobile() -} - -let ssoButton = null -let codeButton = null -let codeMessage = null -let codeToggle = null -let codeToggleText = null - -function update() { - const sso_name = window.sso_name || "Single Sign-On" - const loginHeader = document.querySelector(".card-content > ha-auth-flow > form > h1") - const authForm = document.querySelector("ha-auth-form") - const codeField = document.querySelector(".mdc-text-field__input[name=code]") - const haButtons = document.querySelectorAll("ha-button:not(.sso)") - const errorAlert = document.querySelector("ha-auth-form ha-alert[alert-type=error]") - const loginOptionList = document.querySelector("ha-pick-auth-provider")?.shadowRoot?.querySelector("ha-list") - const forgotPasswordLink = document.querySelector(".forgot-password") - - // Iterate over haButtons to find one with text "Login with code" - let loginButton = null - haButtons.forEach(button => { - if (button.textContent.trim() === "Log in") { - loginButton = button - } - }) - - // ==== - // Code input - if (codeField) { - if (codeField.placeholder !== "One-time code") { - codeField.placeholder = "One-time code" - codeField.autofocus = false - codeField.autocomplete = "off" - - if (firstFocus) { - firstFocus = false - - if (document.activeElement === codeField) { - setTimeout(() => { - codeField.blur() - let check = setInterval(() => { - const helperText = document.querySelector("#helper-text") - const invalidTextField = document.querySelector(".mdc-text-field--invalid") - const validationMsg = document.querySelector(".mdc-text-field-helper-text--validation-msg") - if (helperText && invalidTextField && validationMsg) { - clearInterval(check) - safeSetTextContent(helperText, "") - invalidTextField.classList.remove("mdc-text-field--invalid") - validationMsg.classList.remove("mdc-text-field-helper-text--validation-msg") - } - }, 1) - }, 0) - } - } - } - - if (errorAlert && errorAlert.textContent.trim().length === 0) { - errorAlert.setAttribute("title", "Invalid Code") - } - - authForm.style.display = showCode() ? "" : "none" - } - - if (authForm && !codeMessage) { - codeMessage = document.createElement("p") - codeMessage.innerHTML = `Please login on a different device to continue.
You can also use your mobile webbrowser.` - authForm.parentElement.insertBefore(codeMessage, authForm) - } - - if (codeMessage) { - codeMessage.style.display = showCode() ? "" : "none" - } - - if (showCode() && loginButton !== null && !codeButton) { - codeButton = document.createElement("ha-button") - codeButton.id = "code_button" - codeButton.classList.add("code") - codeButton.innerText = "Log in with code" - codeButton.setAttribute("raised", "") - codeButton.style.marginRight = "1em" - - // Copy the onclick handler the loginButton - codeButton.addEventListener("click", () => { - loginButton.click() - }) - loginButton.parentElement.prepend(codeButton) - } else if (!showCode() && loginButton !== null &&codeButton) { - codeButton.remove() - codeButton = null - } - - // ==== - // Toggle button - if (loginOptionList && !codeToggle && !isMobile()) { - codeToggle = document.createElement("ha-list-item") - codeToggle.setAttribute("hasmeta", "") - codeToggleText = document.createTextNode("") - codeToggle.appendChild(codeToggleText) - const codeToggleIcon = document.createElement("ha-icon-next") - codeToggleIcon.setAttribute("slot", "meta") - codeToggle.appendChild(codeToggleIcon) - - let ranHandler = false; - codeToggle.addEventListener("click", () => { - ranHandler = true; - showCodeOverride = !showCode() - update() - }) - - loginOptionList.addEventListener("click", (ev) => { - if (!ranHandler) { - showCodeOverride = false; - codeMessage = null; - } - ranHandler = false; - }) - - loginOptionList.appendChild(codeToggle) - } - - if (codeToggle) { - codeToggle.style.display = codeField ? "" : "none" - } - - if (codeToggleText) { - codeToggleText.textContent = showCode() ? "Single-Sign On" : "One-time device code" - } - - // ==== - // SSO Page - const shouldShowSSOButton = !showCode() && !!codeField - const isOurScreen = showCode() || shouldShowSSOButton - - if (loginButton !== null && !ssoButton) { - ssoButton = document.createElement("ha-button") - ssoButton.id = "sso_button" - ssoButton.classList.add("sso") - ssoButton.innerText = "Log in with " + sso_name - ssoButton.setAttribute("raised", "") - ssoButton.style.marginRight = "1em" - ssoButton.addEventListener("click", () => { - location.href = "/auth/oidc/redirect" - ssoButton.innerHTML = "Redirecting, please wait..." - ssoButton.disabled = true - }) - loginButton.parentElement.prepend(ssoButton) - } - - if (ssoButton) { - ssoButton.style.display = (!showCode() && codeField) ? "" : "none" - } - - // ==== - // Header hidden on our screens - if (loginHeader) { - if (isOurScreen) { - // Hide the header on our screens - loginHeader.style.display = "none" - if (loginButton !== null) { - loginButton.style.display = "none" - } - forgotPasswordLink.style.display = "none" - } else { - // Show the header on the login screen - loginHeader.style.display = "" - if (loginButton !== null) { - loginButton.style.display = "" - } - forgotPasswordLink.style.display = "" - } - } -} - -// Hide the content until ready -let ready = false -document.querySelector(".content").style.display = "none" - -const observer = new MutationObserver((mutationsList, observer) => { - update() - - if (!ready) { - ready = Boolean(ssoButton && codeMessage && codeToggle && codeToggleText) - if (ready) document.querySelector(".content").style.display = "" - } -}) - -observer.observe(document.body, { childList: true, subtree: true }) - -setTimeout(() => { - if (!ready) { - console.warn("hass-oidc-auth: Document was not ready after 300ms seconds, force displaying. This may indicate a problem with the UI injection.") - } - - // Force display the content - document.querySelector(".content").style.display = ""; - update(); -}, 300) \ No newline at end of file +// Run OIDC injection upon load +(() => { + attempt_oidc_redirect(); + click_alternative_provider_instead(); +})(); \ No newline at end of file diff --git a/custom_components/auth_oidc/static/style.css b/custom_components/auth_oidc/static/style.css index 5035935..226d772 100644 --- a/custom_components/auth_oidc/static/style.css +++ b/custom_components/auth_oidc/static/style.css @@ -1,2 +1,2 @@ -/*! tailwindcss v4.1.14 | MIT License | https://tailwindcss.com */ -@layer properties{@supports (((-webkit-hyphens:none)) and (not (margin-trim:inline))) or ((-moz-orient:inline) and (not (color:rgb(from red r g b)))){*,:before,:after,::backdrop{--tw-border-style:solid;--tw-font-weight:initial;--tw-shadow:0 0 #0000;--tw-shadow-color:initial;--tw-shadow-alpha:100%;--tw-inset-shadow:0 0 #0000;--tw-inset-shadow-color:initial;--tw-inset-shadow-alpha:100%;--tw-ring-color:initial;--tw-ring-shadow:0 0 #0000;--tw-inset-ring-color:initial;--tw-inset-ring-shadow:0 0 #0000;--tw-ring-inset:initial;--tw-ring-offset-width:0px;--tw-ring-offset-color:#fff;--tw-ring-offset-shadow:0 0 #0000}}}@layer theme{:root,:host{--font-sans:ui-sans-serif,system-ui,sans-serif,"Apple Color Emoji","Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";--font-mono:ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",monospace;--color-blue-100:oklch(93.2% .032 255.585);--color-blue-400:oklch(70.7% .165 254.624);--color-blue-500:oklch(62.3% .214 259.815);--color-blue-600:oklch(54.6% .245 262.881);--color-blue-700:oklch(48.8% .243 264.376);--color-gray-200:oklch(92.8% .006 264.531);--color-gray-800:oklch(27.8% .033 256.848);--color-white:#fff;--spacing:.25rem;--container-md:28rem;--text-sm:.875rem;--text-sm--line-height:calc(1.25/.875);--text-xl:1.25rem;--text-xl--line-height:calc(1.75/1.25);--text-2xl:1.5rem;--text-2xl--line-height:calc(2/1.5);--font-weight-semibold:600;--font-weight-bold:700;--radius-lg:.5rem;--animate-spin:spin 1s linear infinite;--default-font-family:var(--font-sans);--default-mono-font-family:var(--font-mono)}}@layer base{*,:after,:before,::backdrop{box-sizing:border-box;border:0 solid;margin:0;padding:0}::file-selector-button{box-sizing:border-box;border:0 solid;margin:0;padding:0}html,:host{-webkit-text-size-adjust:100%;tab-size:4;line-height:1.5;font-family:var(--default-font-family,ui-sans-serif,system-ui,sans-serif,"Apple Color Emoji","Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji");font-feature-settings:var(--default-font-feature-settings,normal);font-variation-settings:var(--default-font-variation-settings,normal);-webkit-tap-highlight-color:transparent}hr{height:0;color:inherit;border-top-width:1px}abbr:where([title]){-webkit-text-decoration:underline dotted;text-decoration:underline dotted}h1,h2,h3,h4,h5,h6{font-size:inherit;font-weight:inherit}a{color:inherit;-webkit-text-decoration:inherit;-webkit-text-decoration:inherit;-webkit-text-decoration:inherit;text-decoration:inherit}b,strong{font-weight:bolder}code,kbd,samp,pre{font-family:var(--default-mono-font-family,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",monospace);font-feature-settings:var(--default-mono-font-feature-settings,normal);font-variation-settings:var(--default-mono-font-variation-settings,normal);font-size:1em}small{font-size:80%}sub,sup{vertical-align:baseline;font-size:75%;line-height:0;position:relative}sub{bottom:-.25em}sup{top:-.5em}table{text-indent:0;border-color:inherit;border-collapse:collapse}:-moz-focusring{outline:auto}progress{vertical-align:baseline}summary{display:list-item}ol,ul,menu{list-style:none}img,svg,video,canvas,audio,iframe,embed,object{vertical-align:middle;display:block}img,video{max-width:100%;height:auto}button,input,select,optgroup,textarea{font:inherit;font-feature-settings:inherit;font-variation-settings:inherit;letter-spacing:inherit;color:inherit;opacity:1;background-color:#0000;border-radius:0}::file-selector-button{font:inherit;font-feature-settings:inherit;font-variation-settings:inherit;letter-spacing:inherit;color:inherit;opacity:1;background-color:#0000;border-radius:0}:where(select:is([multiple],[size])) optgroup{font-weight:bolder}:where(select:is([multiple],[size])) optgroup option{padding-inline-start:20px}::file-selector-button{margin-inline-end:4px}::placeholder{opacity:1}@supports (not ((-webkit-appearance:-apple-pay-button))) or (contain-intrinsic-size:1px){::placeholder{color:currentColor}@supports (color:color-mix(in lab, red, red)){::placeholder{color:color-mix(in oklab,currentcolor 50%,transparent)}}}textarea{resize:vertical}::-webkit-search-decoration{-webkit-appearance:none}::-webkit-date-and-time-value{min-height:1lh;text-align:inherit}::-webkit-datetime-edit{display:inline-flex}::-webkit-datetime-edit-fields-wrapper{padding:0}::-webkit-datetime-edit{padding-block:0}::-webkit-datetime-edit-year-field{padding-block:0}::-webkit-datetime-edit-month-field{padding-block:0}::-webkit-datetime-edit-day-field{padding-block:0}::-webkit-datetime-edit-hour-field{padding-block:0}::-webkit-datetime-edit-minute-field{padding-block:0}::-webkit-datetime-edit-second-field{padding-block:0}::-webkit-datetime-edit-millisecond-field{padding-block:0}::-webkit-datetime-edit-meridiem-field{padding-block:0}::-webkit-calendar-picker-indicator{line-height:1}:-moz-ui-invalid{box-shadow:none}button,input:where([type=button],[type=reset],[type=submit]){appearance:button}::file-selector-button{appearance:button}::-webkit-inner-spin-button{height:auto}::-webkit-outer-spin-button{height:auto}[hidden]:where(:not([hidden=until-found])){display:none!important}}@layer components;@layer utilities{.invisible{visibility:hidden}.visible{visibility:visible}.sr-only{clip-path:inset(50%);white-space:nowrap;border-width:0;width:1px;height:1px;margin:-1px;padding:0;position:absolute;overflow:hidden}.absolute{position:absolute}.fixed{position:fixed}.relative{position:relative}.static{position:static}.container{width:100%}@media (min-width:40rem){.container{max-width:40rem}}@media (min-width:48rem){.container{max-width:48rem}}@media (min-width:64rem){.container{max-width:64rem}}@media (min-width:80rem){.container{max-width:80rem}}@media (min-width:96rem){.container{max-width:96rem}}.my-6{margin-block:calc(var(--spacing)*6)}.my-12{margin-block:calc(var(--spacing)*12)}.mt-6{margin-top:calc(var(--spacing)*6)}.mb-4{margin-bottom:calc(var(--spacing)*4)}.mb-6{margin-bottom:calc(var(--spacing)*6)}.mb-8{margin-bottom:calc(var(--spacing)*8)}.block{display:block}.contents{display:contents}.flex{display:flex}.hidden{display:none}.inline{display:inline}.table{display:table}.h-10{height:calc(var(--spacing)*10)}.h-full{height:100%}.max-h-full{max-height:100%}.min-h-full{min-height:100%}.w-10{width:calc(var(--spacing)*10)}.w-full{width:100%}.max-w-md{max-width:var(--container-md)}.animate-spin{animation:var(--animate-spin)}.items-center{align-items:center}.justify-center{justify-content:center}.rounded{border-radius:.25rem}.rounded-lg{border-radius:var(--radius-lg)}.border{border-style:var(--tw-border-style);border-width:1px}.border-blue-400{border-color:var(--color-blue-400)}.bg-blue-100{background-color:var(--color-blue-100)}.bg-blue-500{background-color:var(--color-blue-500)}.bg-gray-200{background-color:var(--color-gray-200)}.bg-white{background-color:var(--color-white)}.fill-blue-600{fill:var(--color-blue-600)}.p-6{padding:calc(var(--spacing)*6)}.px-4{padding-inline:calc(var(--spacing)*4)}.py-2{padding-block:calc(var(--spacing)*2)}.py-3{padding-block:calc(var(--spacing)*3)}.text-center{text-align:center}.text-2xl{font-size:var(--text-2xl);line-height:var(--tw-leading,var(--text-2xl--line-height))}.text-sm{font-size:var(--text-sm);line-height:var(--tw-leading,var(--text-sm--line-height))}.text-xl{font-size:var(--text-xl);line-height:var(--tw-leading,var(--text-xl--line-height))}.font-bold{--tw-font-weight:var(--font-weight-bold);font-weight:var(--font-weight-bold)}.font-semibold{--tw-font-weight:var(--font-weight-semibold);font-weight:var(--font-weight-semibold)}.text-blue-600{color:var(--color-blue-600)}.text-blue-700{color:var(--color-blue-700)}.text-gray-200{color:var(--color-gray-200)}.text-gray-800{color:var(--color-gray-800)}.text-white{color:var(--color-white)}.shadow{--tw-shadow:0 1px 3px 0 var(--tw-shadow-color,#0000001a),0 1px 2px -1px var(--tw-shadow-color,#0000001a);box-shadow:var(--tw-inset-shadow),var(--tw-inset-ring-shadow),var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow)}.shadow-lg{--tw-shadow:0 10px 15px -3px var(--tw-shadow-color,#0000001a),0 4px 6px -4px var(--tw-shadow-color,#0000001a);box-shadow:var(--tw-inset-shadow),var(--tw-inset-ring-shadow),var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow)}.shadow-md{--tw-shadow:0 4px 6px -1px var(--tw-shadow-color,#0000001a),0 2px 4px -2px var(--tw-shadow-color,#0000001a);box-shadow:var(--tw-inset-shadow),var(--tw-inset-ring-shadow),var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow)}@media (hover:hover){.hover\:bg-blue-700:hover{background-color:var(--color-blue-700)}.hover\:text-blue-700:hover{color:var(--color-blue-700)}.hover\:underline:hover{text-decoration-line:underline}}.focus\:ring-2:focus{--tw-ring-shadow:var(--tw-ring-inset,)0 0 0 calc(2px + var(--tw-ring-offset-width))var(--tw-ring-color,currentcolor);box-shadow:var(--tw-inset-shadow),var(--tw-inset-ring-shadow),var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow)}.focus\:ring-blue-400:focus{--tw-ring-color:var(--color-blue-400)}.focus\:outline-none:focus{--tw-outline-style:none;outline-style:none}}@property --tw-border-style{syntax:"*";inherits:false;initial-value:solid}@property --tw-font-weight{syntax:"*";inherits:false}@property --tw-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-shadow-color{syntax:"*";inherits:false}@property --tw-shadow-alpha{syntax:"";inherits:false;initial-value:100%}@property --tw-inset-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-inset-shadow-color{syntax:"*";inherits:false}@property --tw-inset-shadow-alpha{syntax:"";inherits:false;initial-value:100%}@property --tw-ring-color{syntax:"*";inherits:false}@property --tw-ring-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-inset-ring-color{syntax:"*";inherits:false}@property --tw-inset-ring-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-ring-inset{syntax:"*";inherits:false}@property --tw-ring-offset-width{syntax:"";inherits:false;initial-value:0}@property --tw-ring-offset-color{syntax:"*";inherits:false;initial-value:#fff}@property --tw-ring-offset-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@keyframes spin{to{transform:rotate(360deg)}} \ No newline at end of file +/*! tailwindcss v4.2.1 | MIT License | https://tailwindcss.com */ +@layer properties{@supports (((-webkit-hyphens:none)) and (not (margin-trim:inline))) or ((-moz-orient:inline) and (not (color:rgb(from red r g b)))){*,:before,:after,::backdrop{--tw-border-style:solid;--tw-font-weight:initial;--tw-tracking:initial;--tw-shadow:0 0 #0000;--tw-shadow-color:initial;--tw-shadow-alpha:100%;--tw-inset-shadow:0 0 #0000;--tw-inset-shadow-color:initial;--tw-inset-shadow-alpha:100%;--tw-ring-color:initial;--tw-ring-shadow:0 0 #0000;--tw-inset-ring-color:initial;--tw-inset-ring-shadow:0 0 #0000;--tw-ring-inset:initial;--tw-ring-offset-width:0px;--tw-ring-offset-color:#fff;--tw-ring-offset-shadow:0 0 #0000}}}@layer theme{:root,:host{--font-sans:ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji";--font-mono:ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;--color-blue-100:oklch(93.2% .032 255.585);--color-blue-400:oklch(70.7% .165 254.624);--color-blue-500:oklch(62.3% .214 259.815);--color-blue-600:oklch(54.6% .245 262.881);--color-blue-700:oklch(48.8% .243 264.376);--color-gray-50:oklch(98.5% .002 247.839);--color-gray-100:oklch(96.7% .003 264.542);--color-gray-200:oklch(92.8% .006 264.531);--color-gray-300:oklch(87.2% .01 258.338);--color-gray-600:oklch(44.6% .03 256.802);--color-gray-700:oklch(37.3% .034 259.733);--color-gray-800:oklch(27.8% .033 256.848);--color-white:#fff;--spacing:.25rem;--container-md:28rem;--text-sm:.875rem;--text-sm--line-height:calc(1.25 / .875);--text-base:1rem;--text-base--line-height:calc(1.5 / 1);--text-lg:1.125rem;--text-lg--line-height:calc(1.75 / 1.125);--text-2xl:1.5rem;--text-2xl--line-height:calc(2 / 1.5);--text-3xl:1.875rem;--text-3xl--line-height:calc(2.25 / 1.875);--font-weight-semibold:600;--font-weight-bold:700;--tracking-wide:.025em;--radius-md:.375rem;--radius-lg:.5rem;--animate-spin:spin 1s linear infinite;--default-font-family:var(--font-sans);--default-mono-font-family:var(--font-mono)}}@layer base{*,:after,:before,::backdrop{box-sizing:border-box;border:0 solid;margin:0;padding:0}::file-selector-button{box-sizing:border-box;border:0 solid;margin:0;padding:0}html,:host{-webkit-text-size-adjust:100%;tab-size:4;line-height:1.5;font-family:var(--default-font-family,ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji");font-feature-settings:var(--default-font-feature-settings,normal);font-variation-settings:var(--default-font-variation-settings,normal);-webkit-tap-highlight-color:transparent}hr{height:0;color:inherit;border-top-width:1px}abbr:where([title]){-webkit-text-decoration:underline dotted;text-decoration:underline dotted}h1,h2,h3,h4,h5,h6{font-size:inherit;font-weight:inherit}a{color:inherit;-webkit-text-decoration:inherit;-webkit-text-decoration:inherit;-webkit-text-decoration:inherit;text-decoration:inherit}b,strong{font-weight:bolder}code,kbd,samp,pre{font-family:var(--default-mono-font-family,ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace);font-feature-settings:var(--default-mono-font-feature-settings,normal);font-variation-settings:var(--default-mono-font-variation-settings,normal);font-size:1em}small{font-size:80%}sub,sup{vertical-align:baseline;font-size:75%;line-height:0;position:relative}sub{bottom:-.25em}sup{top:-.5em}table{text-indent:0;border-color:inherit;border-collapse:collapse}:-moz-focusring{outline:auto}progress{vertical-align:baseline}summary{display:list-item}ol,ul,menu{list-style:none}img,svg,video,canvas,audio,iframe,embed,object{vertical-align:middle;display:block}img,video{max-width:100%;height:auto}button,input,select,optgroup,textarea{font:inherit;font-feature-settings:inherit;font-variation-settings:inherit;letter-spacing:inherit;color:inherit;opacity:1;background-color:#0000;border-radius:0}::file-selector-button{font:inherit;font-feature-settings:inherit;font-variation-settings:inherit;letter-spacing:inherit;color:inherit;opacity:1;background-color:#0000;border-radius:0}:where(select:is([multiple],[size])) optgroup{font-weight:bolder}:where(select:is([multiple],[size])) optgroup option{padding-inline-start:20px}::file-selector-button{margin-inline-end:4px}::placeholder{opacity:1}@supports (not ((-webkit-appearance:-apple-pay-button))) or (contain-intrinsic-size:1px){::placeholder{color:currentColor}@supports (color:color-mix(in lab, red, red)){::placeholder{color:color-mix(in oklab, currentcolor 50%, transparent)}}}textarea{resize:vertical}::-webkit-search-decoration{-webkit-appearance:none}::-webkit-date-and-time-value{min-height:1lh;text-align:inherit}::-webkit-datetime-edit{display:inline-flex}::-webkit-datetime-edit-fields-wrapper{padding:0}::-webkit-datetime-edit{padding-block:0}::-webkit-datetime-edit-year-field{padding-block:0}::-webkit-datetime-edit-month-field{padding-block:0}::-webkit-datetime-edit-day-field{padding-block:0}::-webkit-datetime-edit-hour-field{padding-block:0}::-webkit-datetime-edit-minute-field{padding-block:0}::-webkit-datetime-edit-second-field{padding-block:0}::-webkit-datetime-edit-millisecond-field{padding-block:0}::-webkit-datetime-edit-meridiem-field{padding-block:0}::-webkit-calendar-picker-indicator{line-height:1}:-moz-ui-invalid{box-shadow:none}button,input:where([type=button],[type=reset],[type=submit]){appearance:button}::file-selector-button{appearance:button}::-webkit-inner-spin-button{height:auto}::-webkit-outer-spin-button{height:auto}[hidden]:where(:not([hidden=until-found])){display:none!important}}@layer components;@layer utilities{.invisible{visibility:hidden}.visible{visibility:visible}.sr-only{clip-path:inset(50%);white-space:nowrap;border-width:0;width:1px;height:1px;margin:-1px;padding:0;position:absolute;overflow:hidden}.absolute{position:absolute}.fixed{position:fixed}.relative{position:relative}.static{position:static}.start{inset-inline-start:var(--spacing)}.end{inset-inline-end:var(--spacing)}.container{width:100%}@media (min-width:40rem){.container{max-width:40rem}}@media (min-width:48rem){.container{max-width:48rem}}@media (min-width:64rem){.container{max-width:64rem}}@media (min-width:80rem){.container{max-width:80rem}}@media (min-width:96rem){.container{max-width:96rem}}.my-6{margin-block:calc(var(--spacing) * 6)}.mt-4{margin-top:calc(var(--spacing) * 4)}.mb-2{margin-bottom:calc(var(--spacing) * 2)}.mb-4{margin-bottom:calc(var(--spacing) * 4)}.mb-8{margin-bottom:calc(var(--spacing) * 8)}.block{display:block}.contents{display:contents}.flex{display:flex}.hidden{display:none}.inline{display:inline}.inline-block{display:inline-block}.table{display:table}.h-10{height:calc(var(--spacing) * 10)}.h-full{height:100%}.max-h-full{max-height:100%}.min-h-full{min-height:100%}.w-10{width:calc(var(--spacing) * 10)}.w-full{width:100%}.max-w-md{max-width:var(--container-md)}.animate-spin{animation:var(--animate-spin)}.items-center{align-items:center}.justify-between{justify-content:space-between}.justify-center{justify-content:center}.rounded{border-radius:.25rem}.rounded-lg{border-radius:var(--radius-lg)}.rounded-md{border-radius:var(--radius-md)}.border{border-style:var(--tw-border-style);border-width:1px}.border-t{border-top-style:var(--tw-border-style);border-top-width:1px}.border-blue-400{border-color:var(--color-blue-400)}.border-blue-500{border-color:var(--color-blue-500)}.border-gray-200{border-color:var(--color-gray-200)}.border-gray-300{border-color:var(--color-gray-300)}.bg-blue-100{background-color:var(--color-blue-100)}.bg-blue-500{background-color:var(--color-blue-500)}.bg-gray-50{background-color:var(--color-gray-50)}.bg-gray-100{background-color:var(--color-gray-100)}.bg-gray-200{background-color:var(--color-gray-200)}.bg-white{background-color:var(--color-white)}.fill-blue-600{fill:var(--color-blue-600)}.p-6{padding:calc(var(--spacing) * 6)}.px-4{padding-inline:calc(var(--spacing) * 4)}.px-5{padding-inline:calc(var(--spacing) * 5)}.px-6{padding-inline:calc(var(--spacing) * 6)}.py-2{padding-block:calc(var(--spacing) * 2)}.py-3{padding-block:calc(var(--spacing) * 3)}.py-4{padding-block:calc(var(--spacing) * 4)}.pt-4{padding-top:calc(var(--spacing) * 4)}.text-center{text-align:center}.text-left{text-align:left}.text-2xl{font-size:var(--text-2xl);line-height:var(--tw-leading,var(--text-2xl--line-height))}.text-3xl{font-size:var(--text-3xl);line-height:var(--tw-leading,var(--text-3xl--line-height))}.text-base{font-size:var(--text-base);line-height:var(--tw-leading,var(--text-base--line-height))}.text-lg{font-size:var(--text-lg);line-height:var(--tw-leading,var(--text-lg--line-height))}.text-sm{font-size:var(--text-sm);line-height:var(--tw-leading,var(--text-sm--line-height))}.font-bold{--tw-font-weight:var(--font-weight-bold);font-weight:var(--font-weight-bold)}.font-semibold{--tw-font-weight:var(--font-weight-semibold);font-weight:var(--font-weight-semibold)}.tracking-\[0\.15em\]{--tw-tracking:.15em;letter-spacing:.15em}.tracking-wide{--tw-tracking:var(--tracking-wide);letter-spacing:var(--tracking-wide)}.text-blue-600{color:var(--color-blue-600)}.text-blue-700{color:var(--color-blue-700)}.text-gray-200{color:var(--color-gray-200)}.text-gray-600{color:var(--color-gray-600)}.text-gray-700{color:var(--color-gray-700)}.text-gray-800{color:var(--color-gray-800)}.text-white{color:var(--color-white)}.shadow{--tw-shadow:0 1px 3px 0 var(--tw-shadow-color,#0000001a), 0 1px 2px -1px var(--tw-shadow-color,#0000001a);box-shadow:var(--tw-inset-shadow), var(--tw-inset-ring-shadow), var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow)}.shadow-lg{--tw-shadow:0 10px 15px -3px var(--tw-shadow-color,#0000001a), 0 4px 6px -4px var(--tw-shadow-color,#0000001a);box-shadow:var(--tw-inset-shadow), var(--tw-inset-ring-shadow), var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow)}.shadow-md{--tw-shadow:0 4px 6px -1px var(--tw-shadow-color,#0000001a), 0 2px 4px -2px var(--tw-shadow-color,#0000001a);box-shadow:var(--tw-inset-shadow), var(--tw-inset-ring-shadow), var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow)}@media (hover:hover){.hover\:cursor-pointer:hover{cursor:pointer}.hover\:bg-blue-700:hover{background-color:var(--color-blue-700)}.hover\:bg-gray-100:hover{background-color:var(--color-gray-100)}.hover\:text-blue-700:hover{color:var(--color-blue-700)}.hover\:underline:hover{text-decoration-line:underline}}.focus\:ring-2:focus{--tw-ring-shadow:var(--tw-ring-inset,) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color,currentcolor);box-shadow:var(--tw-inset-shadow), var(--tw-inset-ring-shadow), var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow)}.focus\:ring-blue-400:focus{--tw-ring-color:var(--color-blue-400)}.focus\:outline-none:focus{--tw-outline-style:none;outline-style:none}}@property --tw-border-style{syntax:"*";inherits:false;initial-value:solid}@property --tw-font-weight{syntax:"*";inherits:false}@property --tw-tracking{syntax:"*";inherits:false}@property --tw-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-shadow-color{syntax:"*";inherits:false}@property --tw-shadow-alpha{syntax:"";inherits:false;initial-value:100%}@property --tw-inset-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-inset-shadow-color{syntax:"*";inherits:false}@property --tw-inset-shadow-alpha{syntax:"";inherits:false;initial-value:100%}@property --tw-ring-color{syntax:"*";inherits:false}@property --tw-ring-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-inset-ring-color{syntax:"*";inherits:false}@property --tw-inset-ring-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-ring-inset{syntax:"*";inherits:false}@property --tw-ring-offset-width{syntax:"";inherits:false;initial-value:0}@property --tw-ring-offset-color{syntax:"*";inherits:false;initial-value:#fff}@property --tw-ring-offset-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@keyframes spin{to{transform:rotate(360deg)}} \ No newline at end of file diff --git a/custom_components/auth_oidc/stores/code_store.py b/custom_components/auth_oidc/stores/code_store.py deleted file mode 100644 index d3aedab..0000000 --- a/custom_components/auth_oidc/stores/code_store.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Code Store, stores the codes and their associated authenticated user temporarily.""" - -import random -import string - -from datetime import datetime, timedelta, timezone -from typing import cast, Optional -from homeassistant.helpers.storage import Store -from homeassistant.core import HomeAssistant - -from ..tools.types import UserDetails - -STORAGE_VERSION = 1 -STORAGE_KEY = "auth_provider.auth_oidc.codes" - - -class CodeStore: - """Holds the codes and associated data""" - - def __init__(self, hass: HomeAssistant) -> None: - """Initialize the user data store.""" - self.hass = hass - self._store = Store[dict[str, UserDetails]]( - hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True - ) - self._data: dict[str, dict[str, dict | str]] | None = None - - async def async_load(self) -> None: - """Load stored data.""" - if (data := await self._store.async_load()) is None: - data = cast(dict[str, UserDetails], {}) - self._data = data - - async def _async_save(self) -> None: - """Save data.""" - if self._data is not None: - await self._store.async_save(self._data) - - def _generate_code(self) -> str: - """Generate a random six-digit code.""" - return "".join(random.choices(string.digits, k=6)) - - async def async_generate_code_for_userinfo(self, user_info: UserDetails) -> str: - """Generates a one time code and adds it to the database for 5 minutes.""" - if self._data is None: - raise RuntimeError("Data not loaded") - - code = self._generate_code() - expiration = datetime.now(timezone.utc) + timedelta(minutes=5) - - self._data[code] = { - "user_info": user_info, - "code": code, - "expiration": expiration.isoformat(), - } - - await self._async_save() - return code - - async def receive_userinfo_for_code(self, code: str) -> Optional[UserDetails]: - """Retrieve user info based on the code.""" - if self._data is None: - raise RuntimeError("Data not loaded") - - user_data = self._data.get(code) - - if user_data: - # We should now wipe it from the database, as it's one time use code - self._data.pop(code) - await self._async_save() - - if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.now( - timezone.utc - ): - return user_data["user_info"] - - return None - - def get_data(self): - """Get the internal data for testing purposes.""" - return self._data diff --git a/custom_components/auth_oidc/stores/state_store.py b/custom_components/auth_oidc/stores/state_store.py new file mode 100644 index 0000000..00a8455 --- /dev/null +++ b/custom_components/auth_oidc/stores/state_store.py @@ -0,0 +1,191 @@ +"""State Store, store authentication states (redirect_uri).""" + +import secrets +import random +import string + +from datetime import datetime, timedelta, timezone +from typing import cast, Optional +from homeassistant.helpers.storage import Store +from homeassistant.core import HomeAssistant + +from ..tools.types import OIDCState, UserDetails + +STORAGE_VERSION = 1 +STORAGE_KEY = "auth_provider.auth_oidc.states" +MAX_DEVICE_CODE_ATTEMPTS = 10 + + +class StateStore: + """Holds the authentication states and associated data""" + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the user data store.""" + self.hass = hass + self._store = Store[dict[str, OIDCState]]( + hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True + ) + self._data: dict[str, OIDCState] | None = None + + async def async_load(self) -> None: + """Load stored data.""" + if (data := await self._store.async_load()) is None: + data = cast(dict[str, OIDCState], {}) + self._data = data + + async def _async_save(self) -> None: + """Save data.""" + if self._data is not None: + await self._store.async_save(self._data) + + def _generate_id(self) -> str: + """Generate a random identifier.""" + return secrets.token_urlsafe(32) + + def _generate_code(self) -> str: + """Generate a random six-digit code.""" + return "".join(random.choices(string.digits, k=6)) + + def _is_expired(self, state: OIDCState) -> bool: + """Check if a state is expired.""" + return datetime.fromisoformat(state["expiration"]) < datetime.now(timezone.utc) + + def _is_valid(self, state: OIDCState, ip: str | None) -> bool: + """Check if a state is valid""" + return ( + not self._is_expired(state) + and bool(state["redirect_uri"]) + and ip is not None + and state["ip_address"] == ip + ) + + async def async_create_state_from_url(self, redirect_uri: str, ip: str) -> str: + """Generates a the OIDC state adds it to the database for 5 minutes.""" + if self._data is None: + raise RuntimeError("Data not loaded") + + state_id = self._generate_id() + expiration = datetime.now(timezone.utc) + timedelta(minutes=5) + + self._data[state_id] = { + "id": state_id, + "redirect_uri": redirect_uri, + "device_code": None, + "device_code_attempts": 0, + "user_details": None, + "expiration": expiration.isoformat(), + "ip_address": ip, + } + + await self._async_save() + return state_id + + async def async_generate_code_for_state(self, state_id: str) -> Optional[str]: + """Generates a one time code for the state to link device clients.""" + if self._data is None: + raise RuntimeError("Data not loaded") + + try: + code = self._generate_code() + self._data[state_id]["device_code"] = code + await self._async_save() + return code + except KeyError: + return None + + async def async_add_userinfo_to_state( + self, state_id: str, user_info: UserDetails + ) -> bool: + """Add userinfo to existing state to complete login""" + if self._data is None: + raise RuntimeError("Data not loaded") + + try: + self._data[state_id]["user_details"] = user_info + await self._async_save() + return True + except KeyError: + return False + + async def async_get_redirect_uri_for_state( + self, state_id: str, ip: str + ) -> Optional[str]: + """Get the redirect_uri for a given state_id.""" + if self._data is None: + raise RuntimeError("Data not loaded") + + state = self._data.get(state_id) + if state and self._is_valid(state, ip): + return state["redirect_uri"] + + return None + + async def async_is_state_ready(self, state_id: str, ip: str) -> bool: + """Check if the state has received the user info from the OIDC callback.""" + if self._data is None: + raise RuntimeError("Data not loaded") + + state = self._data.get(state_id) + return ( + state is not None + and state["user_details"] is not None + and self._is_valid(state, ip) + ) + + async def async_link_state_to_code( + self, state_id: str, code: str, ip: str | None + ) -> bool: + """Link a state to a device code, used for mobile sign-in.""" + if self._data is None: + raise RuntimeError("Data not loaded") + + state_data = self._data.get(state_id) + if ( + state_data + and self._is_valid(state_data, ip) + and state_data["user_details"] is not None + ): + attempts = state_data.get("device_code_attempts", 0) + if attempts >= MAX_DEVICE_CODE_ATTEMPTS: + return False + + # Find the state with the matching device code and link it + for state in self._data.values(): + if state["device_code"] == code and not self._is_expired(state): + # Set user details on the device state to allow it to complete login + state["user_details"] = state_data["user_details"] + + # Delete the 'donor' state as it's one time use + self._data.pop(state_id) + + # Save and return true + await self._async_save() + return True + + state_data["device_code_attempts"] = attempts + 1 + await self._async_save() + + return False + + async def async_receive_userinfo_for_state( + self, state_id: str, ip: str + ) -> Optional[OIDCState]: + """Retrieve user info based on the state_id.""" + if self._data is None: + raise RuntimeError("Data not loaded") + + user_data = self._data.get(state_id) + + if user_data: + # We should now wipe it from the database, as it's one time use + self._data.pop(state_id) + await self._async_save() + + if user_data and self._is_valid(user_data, ip): + return user_data["user_details"] + + return None + + def get_data(self): + """Get the internal data for testing purposes.""" + return self._data diff --git a/custom_components/auth_oidc/tools/helpers.py b/custom_components/auth_oidc/tools/helpers.py index ec984dd..8adc4e3 100644 --- a/custom_components/auth_oidc/tools/helpers.py +++ b/custom_components/auth_oidc/tools/helpers.py @@ -1,8 +1,17 @@ """Helper functions for the integration.""" +from typing import TYPE_CHECKING + from homeassistant.components import http +from aiohttp import web + from ..views.loader import AsyncTemplateRenderer +if TYPE_CHECKING: + from ..provider import OpenIDAuthProvider + +STATE_COOKIE_NAME = "auth_oidc_state" + def get_url(path: str, force_https: bool) -> str: """Returns the requested path appended to the current request base URL.""" @@ -22,3 +31,39 @@ async def get_view(template: str, parameters: dict | None = None) -> str: renderer = AsyncTemplateRenderer() return await renderer.render_template(f"{template}.html", **parameters) + + +def get_state_id(request: web.Request) -> str | None: + """Return the current OIDC state cookie, if present.""" + return request.cookies.get(STATE_COOKIE_NAME) + + +async def get_valid_state_id( + request: web.Request, oidc_provider: "OpenIDAuthProvider" +) -> str | None: + """Return state id only when cookie exists and state is still valid.""" + state_id = get_state_id(request) + if not state_id: + return None + + if not await oidc_provider.async_is_state_valid(state_id): + return None + + return state_id + + +def html_response(html: str, status: int = 200) -> web.Response: + """Return an HTML response with the standard content type.""" + return web.Response(text=html, content_type="text/html", status=status) + + +async def template_response( + template: str, parameters: dict | None = None +) -> web.Response: + """Render a template and return it as an HTML response.""" + return html_response(await get_view(template, parameters)) + + +async def error_response(message: str, status: int = 400) -> web.Response: + """Render the shared error view.""" + return html_response(await get_view("error", {"error": message}), status=status) diff --git a/custom_components/auth_oidc/tools/oidc_client.py b/custom_components/auth_oidc/tools/oidc_client.py index 1d3413a..c0cd501 100644 --- a/custom_components/auth_oidc/tools/oidc_client.py +++ b/custom_components/auth_oidc/tools/oidc_client.py @@ -289,9 +289,6 @@ class OIDCDiscoveryClient: class OIDCClient: """OIDC Client implementation for Python, including PKCE.""" - # Flows stores the state, code_verifier and nonce of all current flows. - flows = {} - # HTTP session to be used http_session: aiohttp.ClientSession = None @@ -312,6 +309,9 @@ class OIDCClient: self.client_id = client_id self.scope = scope + # Stores code_verifier and nonce for active authorization flows. + self.flows: dict[str, dict[str, str]] = {} + # Optional parameters self.client_secret = kwargs.get("client_secret") @@ -544,7 +544,9 @@ class OIDCClient: _LOGGER.warning("JWT verification failed: %s", e) return None - async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]: + async def async_get_authorization_url( + self, redirect_uri: str, state: str + ) -> Optional[str]: """Generates the authorization URL for the OIDC flow.""" try: discovery_document = await self._fetch_discovery_document() @@ -552,7 +554,6 @@ class OIDCClient: # Generate random nonce & state nonce = self._generate_random_url_string() - state = self._generate_random_url_string() # Generate PKCE (RFC 7636) parameters code_verifier = self._generate_random_url_string(32) @@ -644,11 +645,10 @@ class OIDCClient: """Completes the OIDC token flow to obtain a user's details.""" try: - if state not in self.flows: + flow = self.flows.pop(state, None) + if flow is None: raise OIDCStateInvalid - flow = self.flows[state] - discovery_document = await self._fetch_discovery_document() token_endpoint = discovery_document["token_endpoint"] diff --git a/custom_components/auth_oidc/tools/types.py b/custom_components/auth_oidc/tools/types.py index ef55c63..e302aef 100644 --- a/custom_components/auth_oidc/tools/types.py +++ b/custom_components/auth_oidc/tools/types.py @@ -16,3 +16,26 @@ class UserDetails(dict): username: str # Home Assistant role to assign to this user role: Literal["system-admin", "system-users", "invalid"] + + +class OIDCState(dict): + """OIDC State representation""" + + # ID of this state + id: str + + # User friendly device code + device_code: str | None + + # The redirect_uri associated with this state, + # to be able to redirect the user back after authentication + redirect_uri: str + + # User details, if available + user_details: UserDetails | None + + # Expiration time of this state, in ISO format + expiration: str + + # IP address + ip_address: str | None diff --git a/custom_components/auth_oidc/views/templates/device_success.html b/custom_components/auth_oidc/views/templates/device_success.html new file mode 100644 index 0000000..af3af5a --- /dev/null +++ b/custom_components/auth_oidc/views/templates/device_success.html @@ -0,0 +1,14 @@ +{% extends "base.html" %} +{% block title %}Done!{% endblock %} +{% block head %} +{{ super() }} +{% endblock %} +{% block content %} +
+

You have successfully logged in on your mobile device. It should continue the login soon.

You have been logged out on this device.

+
+ Restart +
+
+{% endblock %} \ No newline at end of file diff --git a/custom_components/auth_oidc/views/templates/finish.html b/custom_components/auth_oidc/views/templates/finish.html index 6c91500..37100de 100644 --- a/custom_components/auth_oidc/views/templates/finish.html +++ b/custom_components/auth_oidc/views/templates/finish.html @@ -4,28 +4,63 @@ {{ super() }} {% endblock %} {% block content %} -
-
-

I want to login to this browser

+
+

Logged in!

+ +
+

Continue on this device

+

Tap Continue to login to Home Assistant on this device.

- -
-
- -
-

I am on a mobile device

-

Your one-time code is: {{ code }}

-

You have 5 minutes to use this code on any device.
The code can only - be used once.

-

Please type the code into your app manually. If you don't see a code input, select - 'Login with - OpenID Connect (SSO)' first.

+
+
+ Use a code from another device +
+
+

On your other device, open the Home Assistant app. You will see a + 6-digit code.

+

Input that code here and click Approve to login on the other device. +

+
+
+ +
+ +
+
{% endblock %} \ No newline at end of file diff --git a/custom_components/auth_oidc/views/templates/redirect.html b/custom_components/auth_oidc/views/templates/redirect.html new file mode 100644 index 0000000..1df06bc --- /dev/null +++ b/custom_components/auth_oidc/views/templates/redirect.html @@ -0,0 +1,28 @@ +{% extends "base.html" %} +{% block title %}OIDC Redirect{% endblock %} +{% block head %} +{{ super() }} +{% endblock %} +{% block content %} +
+
+ + Redirecting... +
+
+ +{% endblock %} \ No newline at end of file diff --git a/custom_components/auth_oidc/views/templates/welcome.html b/custom_components/auth_oidc/views/templates/welcome.html index 2d67054..b427a48 100644 --- a/custom_components/auth_oidc/views/templates/welcome.html +++ b/custom_components/auth_oidc/views/templates/welcome.html @@ -12,41 +12,53 @@ dashboard

-

Home Assistant

-

You have been invited to login to Home Assistant.
Start the login process below.

- -
- - - + + {% else %} + + {% endif %} + + {% if other_link %} +

+ Use alternative sign-in method +

+ {% endif %}