diff --git a/custom_components/auth_oidc/__init__.py b/custom_components/auth_oidc/__init__.py
index af14a9e..b8638b3 100644
--- a/custom_components/auth_oidc/__init__.py
+++ b/custom_components/auth_oidc/__init__.py
@@ -27,7 +27,6 @@ from .config import (
ROLES,
NETWORK,
FEATURES_INCLUDE_GROUPS_SCOPE,
- FEATURES_DISABLE_FRONTEND_INJECTION,
FEATURES_FORCE_HTTPS,
REQUIRED_SCOPES,
)
@@ -40,6 +39,7 @@ from .endpoints import (
OIDCFinishView,
OIDCCallbackView,
OIDCInjectedAuthPage,
+ OIDCDeviceSSE,
)
from .tools.oidc_client import OIDCClient
from .provider import OpenIDAuthProvider
@@ -96,6 +96,10 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam
provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)
providers[(provider.type, provider.id)] = provider
+
+ # Get current provider count
+ has_other_auth_providers = len(hass.auth._providers) > 0
+
providers.update(hass.auth._providers)
hass.auth._providers = providers
# pylint: enable=protected-access
@@ -137,33 +141,22 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam
)
# Register the views
- is_frontend_injection_enabled = (
- features_config.get(FEATURES_DISABLE_FRONTEND_INJECTION, False) is False
- )
name = display_name
name = re.sub(r"[^A-Za-z0-9 _\-\(\)]", "", name)
force_https = features_config.get(FEATURES_FORCE_HTTPS, False)
hass.http.register_view(
- OIDCWelcomeView(
- name,
- # Welcome view is not enabled if frontend injection is enabled
- not is_frontend_injection_enabled,
- force_https,
- )
+ OIDCWelcomeView(provider, name, force_https, has_other_auth_providers)
)
- hass.http.register_view(OIDCRedirectView(oidc_client, force_https))
+ hass.http.register_view(OIDCDeviceSSE(provider))
+ hass.http.register_view(OIDCRedirectView(oidc_client, provider, force_https))
hass.http.register_view(OIDCCallbackView(oidc_client, provider, force_https))
- hass.http.register_view(OIDCFinishView())
+ hass.http.register_view(OIDCFinishView(provider))
_LOGGER.info("Registered OIDC views")
- # Inject OIDC code into the frontend for /auth/authorize if the user has the
- # frontend injection feature enabled
- if is_frontend_injection_enabled:
- await OIDCInjectedAuthPage.inject(hass, name)
- else:
- _LOGGER.info("OIDC frontend changes are disabled, skipping injection")
+ # Inject OIDC code into the frontend for /auth/authorize for automatic redirect
+ await OIDCInjectedAuthPage.inject(hass)
return True
diff --git a/custom_components/auth_oidc/config/const.py b/custom_components/auth_oidc/config/const.py
index 8538262..24b0297 100644
--- a/custom_components/auth_oidc/config/const.py
+++ b/custom_components/auth_oidc/config/const.py
@@ -28,7 +28,6 @@ FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
FEATURES_DISABLE_PKCE = "disable_rfc7636"
FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope"
-FEATURES_DISABLE_FRONTEND_INJECTION = "disable_frontend_changes"
FEATURES_FORCE_HTTPS = "force_https"
CLAIMS = "claims"
CLAIMS_DISPLAY_NAME = "display_name"
diff --git a/custom_components/auth_oidc/config/schema.py b/custom_components/auth_oidc/config/schema.py
index 2167503..e7b9b9c 100644
--- a/custom_components/auth_oidc/config/schema.py
+++ b/custom_components/auth_oidc/config/schema.py
@@ -14,7 +14,6 @@ from .const import (
FEATURES_AUTOMATIC_PERSON_CREATION,
FEATURES_DISABLE_PKCE,
FEATURES_INCLUDE_GROUPS_SCOPE,
- FEATURES_DISABLE_FRONTEND_INJECTION,
FEATURES_FORCE_HTTPS,
CLAIMS,
CLAIMS_DISPLAY_NAME,
@@ -72,10 +71,6 @@ CONFIG_SCHEMA = vol.Schema(
vol.Optional(
FEATURES_INCLUDE_GROUPS_SCOPE, default=True
): vol.Coerce(bool),
- # Disable frontend injection of OIDC login button
- vol.Optional(
- FEATURES_DISABLE_FRONTEND_INJECTION, default=False
- ): vol.Coerce(bool),
# Force HTTPS on all generated URLs (like redirect_uri)
vol.Optional(FEATURES_FORCE_HTTPS, default=False): vol.Coerce(
bool
diff --git a/custom_components/auth_oidc/config/ui_flow.py b/custom_components/auth_oidc/config/ui_flow.py
index 847b26c..6d21009 100644
--- a/custom_components/auth_oidc/config/ui_flow.py
+++ b/custom_components/auth_oidc/config/ui_flow.py
@@ -621,21 +621,18 @@ class OIDCConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors["client_id"] = "invalid_client_id"
return errors, None
- # Determine confidentiality by presence of client secret
- client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip()
- # If secret is empty, keep the existing one (if any)
- if not client_secret:
- client_secret = entry.data.get("client_secret")
-
# Build updated data
data_updates = {"client_id": client_id}
- if client_secret:
- data_updates["client_secret"] = client_secret
- elif "client_secret" in entry.data and not client_secret:
- # Remove client secret if switching from confidential to public
- data_updates = {**entry.data, **data_updates}
- data_updates.pop("client_secret", None)
+ # The optional secret field is submitted explicitly when the form is used.
+ # An empty value means the user wants to keep the existing secret.
+ if CONF_CLIENT_SECRET in user_input:
+ client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip()
+
+ if client_secret:
+ data_updates["client_secret"] = client_secret
+ elif "client_secret" in entry.data:
+ data_updates["client_secret"] = entry.data["client_secret"]
return errors, data_updates
diff --git a/custom_components/auth_oidc/endpoints/__init__.py b/custom_components/auth_oidc/endpoints/__init__.py
index 3808b68..613fba0 100644
--- a/custom_components/auth_oidc/endpoints/__init__.py
+++ b/custom_components/auth_oidc/endpoints/__init__.py
@@ -5,3 +5,4 @@ from .finish import OIDCFinishView as OIDCFinishView
from .injected_auth_page import OIDCInjectedAuthPage as OIDCInjectedAuthPage
from .redirect import OIDCRedirectView as OIDCRedirectView
from .welcome import OIDCWelcomeView as OIDCWelcomeView
+from .device_sse import OIDCDeviceSSE as OIDCDeviceSSE
diff --git a/custom_components/auth_oidc/endpoints/callback.py b/custom_components/auth_oidc/endpoints/callback.py
index 178ee1e..19aea11 100644
--- a/custom_components/auth_oidc/endpoints/callback.py
+++ b/custom_components/auth_oidc/endpoints/callback.py
@@ -4,7 +4,7 @@ from homeassistant.components.http import HomeAssistantView
from aiohttp import web
from ..tools.oidc_client import OIDCClient
from ..provider import OpenIDAuthProvider
-from ..tools.helpers import get_url, get_view
+from ..tools.helpers import error_response, get_url, get_valid_state_id
PATH = "/auth/oidc/callback"
@@ -29,42 +29,49 @@ class OIDCCallbackView(HomeAssistantView):
async def get(self, request: web.Request) -> web.Response:
"""Receive response."""
+ # Get cookie to get the state_id
+ state_id = await get_valid_state_id(request, self.oidc_provider)
+ if not state_id:
+ return await error_response("Missing state cookie, please restart login.")
+
+ # Get the OIDC query parameters
params = request.rel_url.query
code = params.get("code")
state = params.get("state")
if not (code and state):
- view_html = await get_view(
- "error",
- {
- "error": "Missing code or state parameter.",
- },
- )
- return web.Response(text=view_html, content_type="text/html")
+ return await error_response("Missing code or state parameter.")
+ # Check if the states match
+ if state != state_id:
+ return await error_response(
+ "State parameter does not match, possible CSRF attack."
+ )
+
+ # Complete the OIDC flow to get user details
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
user_details = await self.oidc_client.async_complete_token_flow(
redirect_uri, code, state
)
if user_details is None:
- view_html = await get_view(
- "error",
- {
- "error": "Failed to get user details, "
- + "see Home Assistant logs for more information.",
- },
+ return await error_response(
+ "Failed to get user details, see Home Assistant logs for more information.",
+ status=500,
)
- return web.Response(text=view_html, content_type="text/html")
if user_details.get("role") == "invalid":
- view_html = await get_view(
- "error",
- {
- "error": "User is not in the correct group to access Home Assistant, "
- + "contact your administrator!",
- },
+ return await error_response(
+ "User is not in the correct group to access Home Assistant, "
+ + "contact your administrator!",
+ status=403,
)
- return web.Response(text=view_html, content_type="text/html")
- code = await self.oidc_provider.async_save_user_info(user_details)
- raise web.HTTPFound(get_url("/auth/oidc/finish?code=" + code, self.force_https))
+ # Finalize on the state
+ success = await self.oidc_provider.async_save_user_info(state_id, user_details)
+ if not success:
+ return await error_response(
+ "Failed to save user information, session probably expired. Please sign in again.",
+ status=500,
+ )
+
+ raise web.HTTPFound(get_url("/auth/oidc/finish", self.force_https))
diff --git a/custom_components/auth_oidc/endpoints/device_sse.py b/custom_components/auth_oidc/endpoints/device_sse.py
new file mode 100644
index 0000000..bc7aad6
--- /dev/null
+++ b/custom_components/auth_oidc/endpoints/device_sse.py
@@ -0,0 +1,70 @@
+"""SSE handler for OIDC device authentication."""
+
+import asyncio
+from aiohttp import web
+from homeassistant.components.http import HomeAssistantView
+from ..provider import OpenIDAuthProvider
+from ..tools.helpers import get_valid_state_id
+
+PATH = "/auth/oidc/device-sse"
+
+
+class OIDCDeviceSSE(HomeAssistantView):
+ """OIDC Plugin SSE Handler."""
+
+ requires_auth = False
+ url = PATH
+ name = "auth:oidc:device-sse"
+
+ def __init__(self, oidc_provider: OpenIDAuthProvider) -> None:
+ self.oidc_provider = oidc_provider
+
+ async def get(self, req: web.Request) -> web.Response:
+ """Check for mobile sign-in completion with short server-side polling."""
+ state_id = await get_valid_state_id(req, self.oidc_provider)
+ if not state_id:
+ raise web.HTTPBadRequest(text="Missing session cookie")
+
+ timeout_seconds = 300
+ started_at = asyncio.get_running_loop().time()
+
+ response = web.StreamResponse(
+ status=200,
+ headers={
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ },
+ )
+ await response.prepare(req)
+
+ try:
+ while True:
+ if (
+ await self.oidc_provider.async_get_redirect_uri_for_state(state_id)
+ is None
+ ):
+ await response.write(b"event: expired\ndata: false\n\n")
+ break
+
+ ready = await self.oidc_provider.async_is_state_ready(state_id)
+ if ready:
+ await response.write(b"event: ready\ndata: true\n\n")
+ break
+
+ if asyncio.get_running_loop().time() - started_at >= timeout_seconds:
+ await response.write(b"event: timeout\ndata: false\n\n")
+ break
+
+ await response.write(b"event: waiting\ndata: false\n\n")
+ await asyncio.sleep(0.5)
+ except (ConnectionResetError, RuntimeError):
+ # Client disconnected while listening for state changes.
+ pass
+ finally:
+ try:
+ await response.write_eof()
+ except ConnectionResetError:
+ pass
+
+ return response
diff --git a/custom_components/auth_oidc/endpoints/finish.py b/custom_components/auth_oidc/endpoints/finish.py
index 6b60371..8421215 100644
--- a/custom_components/auth_oidc/endpoints/finish.py
+++ b/custom_components/auth_oidc/endpoints/finish.py
@@ -2,7 +2,12 @@
from homeassistant.components.http import HomeAssistantView
from aiohttp import web
-from ..tools.helpers import get_view
+from ..provider import OpenIDAuthProvider
+from ..tools.helpers import (
+ error_response,
+ get_valid_state_id,
+ template_response,
+)
PATH = "/auth/oidc/finish"
@@ -14,41 +19,62 @@ class OIDCFinishView(HomeAssistantView):
url = PATH
name = "auth:oidc:finish"
+ def __init__(
+ self,
+ oidc_provider: OpenIDAuthProvider,
+ ) -> None:
+ self.oidc_provider = oidc_provider
+
async def get(self, request: web.Request) -> web.Response:
- """Show the finish screen to allow the user to view their code."""
+ """Show the finish screen to pick between login & device code."""
+ # Get cookie to get the state_id
+ state_id = await get_valid_state_id(request, self.oidc_provider)
+ if not state_id:
+ return await error_response("Missing state cookie, please restart login.")
- code = request.query.get("code")
-
- if not code:
- view_html = await get_view(
- "error",
- {"error": "Missing code to show the finish screen."},
- )
- return web.Response(text=view_html, content_type="text/html")
-
- view_html = await get_view("finish", {"code": code})
- return web.Response(text=view_html, content_type="text/html")
+ return await template_response("finish", {})
async def post(self, request: web.Request) -> web.Response:
"""Receive response."""
- # Get code from the message body
- data = await request.post()
- code = data.get("code")
+ # Get cookie to get the state_id
+ state_id = await get_valid_state_id(request, self.oidc_provider)
+ if not state_id:
+ return await error_response("Missing state cookie, please restart login.")
- if not code:
- return web.Response(text="No code received", status=500)
-
- # Return redirect to the main page for sign in with a cookie
- raise web.HTTPFound(
- location="/?storeToken=true",
- headers={
- # Set a cookie to enable autologin on only the specific path used
- # for the POST request, with all strict parameters set
- # This cookie should not be read by any Javascript or any other paths.
- # It can be really short lifetime as we redirect immediately (5 seconds)
- "set-cookie": "auth_oidc_code="
- + code
- + "; Path=/auth/login_flow; SameSite=Strict; HttpOnly; Max-Age=5",
- },
+ # Get redirect_uri from the state
+ redirect_uri = await self.oidc_provider.async_get_redirect_uri_for_state(
+ state_id
)
+
+ if not redirect_uri:
+ return await error_response("Invalid state, please restart login.")
+
+ # Get the message body
+ data = await request.post()
+ device_code = data.get("device_code")
+
+ # We are trying sign-in on this browser
+ if not device_code:
+ # Add to the URL correctly (also handle case where it's just the root)
+ separator = "?"
+ if "?" in redirect_uri:
+ separator = "&"
+
+ # Redirect to this new URL for login
+ new_url = (
+ redirect_uri + separator + "storeToken=true&skip_oidc_redirect=true"
+ )
+ raise web.HTTPFound(location=new_url)
+
+ # Check if we can link this device
+ linked = await self.oidc_provider.async_link_state_to_code(
+ state_id, device_code
+ )
+
+ if not linked:
+ return await error_response(
+ "Failed to link state to device code, please restart login."
+ )
+
+ return await template_response("device_success", {})
diff --git a/custom_components/auth_oidc/endpoints/injected_auth_page.py b/custom_components/auth_oidc/endpoints/injected_auth_page.py
index 74d0945..ba5c5c4 100644
--- a/custom_components/auth_oidc/endpoints/injected_auth_page.py
+++ b/custom_components/auth_oidc/endpoints/injected_auth_page.py
@@ -1,6 +1,5 @@
"""Injected authorization page, replacing the original"""
-import json
import logging
from functools import partial
from homeassistant.components.http import HomeAssistantView, StaticPathConfig
@@ -19,7 +18,7 @@ async def read_file(path: str) -> str:
return await f.read()
-async def frontend_injection(hass: HomeAssistant, sso_name: str) -> None:
+async def frontend_injection(hass: HomeAssistant) -> None:
"""Inject new frontend code into /auth/authorize."""
router = hass.http.app.router
frontend_path = None
@@ -62,11 +61,8 @@ async def frontend_injection(hass: HomeAssistant, sso_name: str) -> None:
frontend_code = await read_file(frontend_path)
# Inject JS and register that route
- injection_js = ""
- sso_name_js = f""
- frontend_code = frontend_code.replace(
- "