Reimplement UI injection (#236)
This commit is contained in:
committed by
GitHub
parent
fdc93e2719
commit
fd3643685d
@@ -22,7 +22,6 @@ from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.components import http, person
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import voluptuous as vol
|
||||
|
||||
from .config.const import (
|
||||
FEATURES,
|
||||
@@ -30,13 +29,14 @@ from .config.const import (
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||
DEFAULT_TITLE,
|
||||
)
|
||||
from .stores.code_store import CodeStore
|
||||
from .stores.state_store import StateStore
|
||||
from .tools.types import UserDetails
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER_TYPE = "auth_oidc"
|
||||
HASS_PROVIDER_TYPE = "homeassistant"
|
||||
COOKIE_NAME = "auth_oidc_state"
|
||||
|
||||
|
||||
class InvalidAuthError(HomeAssistantError):
|
||||
@@ -68,7 +68,7 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
)
|
||||
|
||||
self._user_meta: dict[UserDetails] = {}
|
||||
self._code_store: CodeStore | None = None
|
||||
self._state_store: StateStore | None = None
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
features = config.get(
|
||||
@@ -89,29 +89,120 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
async def async_initialize(self) -> None:
|
||||
"""Initialize the auth provider."""
|
||||
|
||||
# Init the code store first
|
||||
# Init the store first
|
||||
# Use the same technique as the HomeAssistant auth provider for storage
|
||||
# (/auth/providers/homeassistant.py#L392)
|
||||
async with self._init_lock:
|
||||
if self._code_store is not None:
|
||||
if self._state_store is not None:
|
||||
return
|
||||
|
||||
store = CodeStore(self.hass)
|
||||
store = StateStore(self.hass)
|
||||
await store.async_load()
|
||||
self._code_store = store
|
||||
self._state_store = store
|
||||
self._user_meta = {}
|
||||
|
||||
# Listen for user creation events
|
||||
self.hass.bus.async_listen(EVENT_USER_ADDED, self.async_user_created)
|
||||
|
||||
async def async_get_subject(self, code: str) -> Optional[str]:
|
||||
"""Retrieve user from the code, return subject and save meta
|
||||
for later use with this provider instance."""
|
||||
if self._code_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._code_store is not None
|
||||
def _resolve_ip(self, ip: str | None = None) -> str | None:
|
||||
"""Resolve client IP from explicit input or current request context."""
|
||||
if ip:
|
||||
return ip
|
||||
|
||||
user_data = await self._code_store.receive_userinfo_for_code(code)
|
||||
req = http.current_request.get()
|
||||
if req and req.remote:
|
||||
return req.remote
|
||||
|
||||
return None
|
||||
|
||||
async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str:
|
||||
"""Create a new OIDC state and return the state id."""
|
||||
if self._state_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._state_store is not None
|
||||
|
||||
return await self._state_store.async_create_state_from_url(
|
||||
redirect_uri, self._resolve_ip(ip)
|
||||
)
|
||||
|
||||
async def async_generate_device_code(self, state_id: str) -> Optional[str]:
|
||||
"""Generate a device code for the state, used for device login."""
|
||||
if self._state_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._state_store is not None
|
||||
|
||||
return await self._state_store.async_generate_code_for_state(state_id)
|
||||
|
||||
async def async_save_user_info(
|
||||
self, state_id: str, user_info: dict[str, dict | str]
|
||||
) -> bool:
|
||||
"""Save user info to the given state."""
|
||||
if self._state_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._state_store is not None
|
||||
|
||||
return await self._state_store.async_add_userinfo_to_state(state_id, user_info)
|
||||
|
||||
async def async_get_redirect_uri_for_state(
|
||||
self, state_id: str, ip: str | None = None
|
||||
) -> Optional[str]:
|
||||
"""Get the redirect_uri for the given state."""
|
||||
if self._state_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._state_store is not None
|
||||
|
||||
return await self._state_store.async_get_redirect_uri_for_state(
|
||||
state_id, self._resolve_ip(ip)
|
||||
)
|
||||
|
||||
async def async_is_state_valid(self, state_id: str, ip: str | None = None) -> bool:
|
||||
"""Check if a state exists, belongs to this IP, and is not expired."""
|
||||
if self._state_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._state_store is not None
|
||||
|
||||
return (
|
||||
await self._state_store.async_get_redirect_uri_for_state(
|
||||
state_id, self._resolve_ip(ip)
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
async def async_is_state_ready(self, state_id: str, ip: str | None = None) -> bool:
|
||||
"""Check if the state has received the user info from the OIDC callback."""
|
||||
if self._state_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._state_store is not None
|
||||
|
||||
return await self._state_store.async_is_state_ready(
|
||||
state_id, self._resolve_ip(ip)
|
||||
)
|
||||
|
||||
async def async_link_state_to_code(
|
||||
self, state_id: str, code: str, ip: str | None = None
|
||||
) -> bool:
|
||||
"""Link two states together by copying the user info from one to the other."""
|
||||
if self._state_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._state_store is not None
|
||||
|
||||
return await self._state_store.async_link_state_to_code(
|
||||
state_id, code, self._resolve_ip(ip)
|
||||
)
|
||||
|
||||
async def async_get_subject(
|
||||
self, state_id: str, ip: str | None = None
|
||||
) -> Optional[str]:
|
||||
"""Retrieve user from the state_id, return subject and save meta
|
||||
for later use with this provider instance."""
|
||||
if self._state_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._state_store is not None
|
||||
|
||||
# This also deletes the state as we are using it for sign-in
|
||||
user_data = await self._state_store.async_receive_userinfo_for_state(
|
||||
state_id, self._resolve_ip(ip)
|
||||
)
|
||||
if user_data is None:
|
||||
return None
|
||||
|
||||
@@ -119,14 +210,6 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
self._user_meta[sub] = user_data
|
||||
return sub
|
||||
|
||||
async def async_save_user_info(self, user_info: dict[str, dict | str]) -> str:
|
||||
"""Save user info and return a code."""
|
||||
if self._code_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._code_store is not None
|
||||
|
||||
return await self._code_store.async_generate_code_for_userinfo(user_info)
|
||||
|
||||
async def _async_find_user_by_username(self, username: str) -> Optional[User]:
|
||||
"""Find a user by username."""
|
||||
users = await self.store.async_get_users()
|
||||
@@ -145,6 +228,18 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
|
||||
return None
|
||||
|
||||
def get_cookie_header(self, state_id: str, secure: bool = False):
|
||||
"""Get the cookie header to set the state_id cookie."""
|
||||
secure_flag = "; Secure" if secure else ""
|
||||
return {
|
||||
# Set a cookie for the other pages to know the state_id
|
||||
# Keep cookie lifetime aligned with state lifetime in storage (5 minutes).
|
||||
"set-cookie": f"{COOKIE_NAME}="
|
||||
+ state_id
|
||||
+ "; Path=/auth/; SameSite=Strict; HttpOnly; Max-Age=300"
|
||||
+ secure_flag,
|
||||
}
|
||||
|
||||
# ====
|
||||
# Handler for user created and related functions (person creation)
|
||||
# ====
|
||||
@@ -271,7 +366,7 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
class OpenIdLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
async def _finalize_user(self, code: str) -> AuthFlowResult:
|
||||
async def _finalize_user(self, state_id: str) -> AuthFlowResult:
|
||||
# Verify a dummy hash to make it last a bit longer
|
||||
# as security measure (limits the amount of attempts you have in 5 min)
|
||||
# Similar to what the HomeAssistant auth provider does
|
||||
@@ -280,7 +375,7 @@ class OpenIdLoginFlow(LoginFlow):
|
||||
|
||||
# Actually look up the auth provider after,
|
||||
# this doesn't take a lot of time (regardless of it's in there or not)
|
||||
sub = await self._auth_provider.async_get_subject(code)
|
||||
sub = await self._auth_provider.async_get_subject(state_id)
|
||||
if sub:
|
||||
return await self.async_finish(
|
||||
{
|
||||
@@ -290,54 +385,23 @@ class OpenIdLoginFlow(LoginFlow):
|
||||
|
||||
raise InvalidAuthError
|
||||
|
||||
def _show_login_form(
|
||||
self, errors: Optional[dict[str, str]] = None
|
||||
) -> AuthFlowResult:
|
||||
if errors is None:
|
||||
errors = {}
|
||||
|
||||
# Show the login form
|
||||
# Abuses the MFA form, as it works better for our usecase
|
||||
# UI suggestions are welcome (make a PR!)
|
||||
return self.async_show_form(
|
||||
step_id="mfa",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required("code"): str,
|
||||
}
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
|
||||
# Try to use the user input first
|
||||
if user_input is not None and "code" in user_input:
|
||||
try:
|
||||
return await self._finalize_user(user_input["code"])
|
||||
except InvalidAuthError:
|
||||
return self._show_login_form({"base": "invalid_auth"})
|
||||
|
||||
# If not available, check the cookie
|
||||
# Check if the cookie is present to login
|
||||
req = http.current_request.get()
|
||||
if req and req.cookies:
|
||||
code_cookie = req.cookies.get("auth_oidc_code")
|
||||
state_cookie = req.cookies.get(COOKIE_NAME)
|
||||
|
||||
if code_cookie:
|
||||
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
|
||||
if state_cookie:
|
||||
_LOGGER.debug("State cookie found on login: %s", state_cookie)
|
||||
try:
|
||||
return await self._finalize_user(code_cookie)
|
||||
return await self._finalize_user(state_cookie)
|
||||
except InvalidAuthError:
|
||||
pass
|
||||
|
||||
# If none are available, just show the form
|
||||
return self._show_login_form()
|
||||
|
||||
async def async_step_mfa(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
# This is a dummy step function just to use the nicer MFA UI instead
|
||||
return await self.async_step_init(user_input)
|
||||
# If no cookie is found, abort.
|
||||
# User should either be redirected or start manually on the welcome
|
||||
return self.async_abort(reason="no_oidc_cookie_found")
|
||||
|
||||
Reference in New Issue
Block a user