Reimplement UI injection (#236)
This commit is contained in:
committed by
GitHub
parent
fdc93e2719
commit
fd3643685d
@@ -1,8 +1,17 @@
|
||||
"""Helper functions for the integration."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.components import http
|
||||
from aiohttp import web
|
||||
|
||||
from ..views.loader import AsyncTemplateRenderer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..provider import OpenIDAuthProvider
|
||||
|
||||
STATE_COOKIE_NAME = "auth_oidc_state"
|
||||
|
||||
|
||||
def get_url(path: str, force_https: bool) -> str:
|
||||
"""Returns the requested path appended to the current request base URL."""
|
||||
@@ -22,3 +31,39 @@ async def get_view(template: str, parameters: dict | None = None) -> str:
|
||||
|
||||
renderer = AsyncTemplateRenderer()
|
||||
return await renderer.render_template(f"{template}.html", **parameters)
|
||||
|
||||
|
||||
def get_state_id(request: web.Request) -> str | None:
|
||||
"""Return the current OIDC state cookie, if present."""
|
||||
return request.cookies.get(STATE_COOKIE_NAME)
|
||||
|
||||
|
||||
async def get_valid_state_id(
|
||||
request: web.Request, oidc_provider: "OpenIDAuthProvider"
|
||||
) -> str | None:
|
||||
"""Return state id only when cookie exists and state is still valid."""
|
||||
state_id = get_state_id(request)
|
||||
if not state_id:
|
||||
return None
|
||||
|
||||
if not await oidc_provider.async_is_state_valid(state_id):
|
||||
return None
|
||||
|
||||
return state_id
|
||||
|
||||
|
||||
def html_response(html: str, status: int = 200) -> web.Response:
|
||||
"""Return an HTML response with the standard content type."""
|
||||
return web.Response(text=html, content_type="text/html", status=status)
|
||||
|
||||
|
||||
async def template_response(
|
||||
template: str, parameters: dict | None = None
|
||||
) -> web.Response:
|
||||
"""Render a template and return it as an HTML response."""
|
||||
return html_response(await get_view(template, parameters))
|
||||
|
||||
|
||||
async def error_response(message: str, status: int = 400) -> web.Response:
|
||||
"""Render the shared error view."""
|
||||
return html_response(await get_view("error", {"error": message}), status=status)
|
||||
|
||||
@@ -289,9 +289,6 @@ class OIDCDiscoveryClient:
|
||||
class OIDCClient:
|
||||
"""OIDC Client implementation for Python, including PKCE."""
|
||||
|
||||
# Flows stores the state, code_verifier and nonce of all current flows.
|
||||
flows = {}
|
||||
|
||||
# HTTP session to be used
|
||||
http_session: aiohttp.ClientSession = None
|
||||
|
||||
@@ -312,6 +309,9 @@ class OIDCClient:
|
||||
self.client_id = client_id
|
||||
self.scope = scope
|
||||
|
||||
# Stores code_verifier and nonce for active authorization flows.
|
||||
self.flows: dict[str, dict[str, str]] = {}
|
||||
|
||||
# Optional parameters
|
||||
self.client_secret = kwargs.get("client_secret")
|
||||
|
||||
@@ -544,7 +544,9 @@ class OIDCClient:
|
||||
_LOGGER.warning("JWT verification failed: %s", e)
|
||||
return None
|
||||
|
||||
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
|
||||
async def async_get_authorization_url(
|
||||
self, redirect_uri: str, state: str
|
||||
) -> Optional[str]:
|
||||
"""Generates the authorization URL for the OIDC flow."""
|
||||
try:
|
||||
discovery_document = await self._fetch_discovery_document()
|
||||
@@ -552,7 +554,6 @@ class OIDCClient:
|
||||
|
||||
# Generate random nonce & state
|
||||
nonce = self._generate_random_url_string()
|
||||
state = self._generate_random_url_string()
|
||||
|
||||
# Generate PKCE (RFC 7636) parameters
|
||||
code_verifier = self._generate_random_url_string(32)
|
||||
@@ -644,11 +645,10 @@ class OIDCClient:
|
||||
"""Completes the OIDC token flow to obtain a user's details."""
|
||||
|
||||
try:
|
||||
if state not in self.flows:
|
||||
flow = self.flows.pop(state, None)
|
||||
if flow is None:
|
||||
raise OIDCStateInvalid
|
||||
|
||||
flow = self.flows[state]
|
||||
|
||||
discovery_document = await self._fetch_discovery_document()
|
||||
token_endpoint = discovery_document["token_endpoint"]
|
||||
|
||||
|
||||
@@ -16,3 +16,26 @@ class UserDetails(dict):
|
||||
username: str
|
||||
# Home Assistant role to assign to this user
|
||||
role: Literal["system-admin", "system-users", "invalid"]
|
||||
|
||||
|
||||
class OIDCState(dict):
|
||||
"""OIDC State representation"""
|
||||
|
||||
# ID of this state
|
||||
id: str
|
||||
|
||||
# User friendly device code
|
||||
device_code: str | None
|
||||
|
||||
# The redirect_uri associated with this state,
|
||||
# to be able to redirect the user back after authentication
|
||||
redirect_uri: str
|
||||
|
||||
# User details, if available
|
||||
user_details: UserDetails | None
|
||||
|
||||
# Expiration time of this state, in ISO format
|
||||
expiration: str
|
||||
|
||||
# IP address
|
||||
ip_address: str | None
|
||||
|
||||
Reference in New Issue
Block a user