Code quality improvements (v0.2.0-pre-alpha) (#5)
* Bumped version to 0.2.0 * Implemented Github Actions for HACS, Hassfest, Linting * Improved code quality (compliant with the linter now) * Added link to the finish page to automatically login on the same device/browser
This commit is contained in:
committed by
GitHub
parent
a30d42ffce
commit
b4a08b17ab
@@ -1,8 +1,11 @@
|
||||
"""OIDC Authentication provider.
|
||||
Allow access to users based on login with an external OpenID Connect Identity Provider (IdP).
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
from homeassistant.auth.providers import (
|
||||
AUTH_PROVIDERS,
|
||||
AuthProvider,
|
||||
@@ -11,30 +14,26 @@ from homeassistant.auth.providers import (
|
||||
Credentials,
|
||||
UserMeta,
|
||||
)
|
||||
from homeassistant.components import http
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import voluptuous as vol
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
import string
|
||||
from homeassistant.helpers.storage import Store
|
||||
from collections.abc import Mapping
|
||||
|
||||
from .stores.code_store import CodeStore
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidAuthError(HomeAssistantError):
|
||||
"""Raised when submitting invalid authentication."""
|
||||
|
||||
|
||||
@AUTH_PROVIDERS.register("oidc")
|
||||
class OpenIDAuthProvider(AuthProvider):
|
||||
"""Allow access to users based on login with an external OpenID Connect Identity Provider (IdP)."""
|
||||
"""Allow access to users based on login with an external
|
||||
OpenID Connect Identity Provider (IdP)."""
|
||||
|
||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the OpenIDAuthProvider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta = {}
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "auth_oidc"
|
||||
@@ -42,13 +41,62 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
@property
|
||||
def support_mfa(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the OpenIDAuthProvider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta = {}
|
||||
self._code_store: CodeStore | None = None
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
"""Initialize the auth provider."""
|
||||
|
||||
# Init the code store first
|
||||
# Use the same technique as the HomeAssistant auth provider for storage
|
||||
# (/auth/providers/homeassistant.py#L392)
|
||||
async with self._init_lock:
|
||||
if self._code_store is not None:
|
||||
return
|
||||
|
||||
store = CodeStore(self.hass)
|
||||
await store.async_load()
|
||||
self._code_store = store
|
||||
self._user_meta = {}
|
||||
|
||||
async def async_retrieve_username(self, code: str) -> Optional[str]:
|
||||
"""Retrieve user from the code, return username and save meta
|
||||
for later use with this provider instance."""
|
||||
if self._code_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._code_store is not None
|
||||
|
||||
user_data = await self._code_store.receive_userinfo_for_code(code)
|
||||
if user_data is None:
|
||||
return None
|
||||
|
||||
username = user_data["username"]
|
||||
self._user_meta[username] = user_data
|
||||
return username
|
||||
|
||||
async def async_save_user_info(self, user_info: dict[str, dict | str]) -> str:
|
||||
"""Save user info and return a code."""
|
||||
if self._code_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._code_store is not None
|
||||
|
||||
return await self._code_store.async_generate_code_for_userinfo(user_info)
|
||||
|
||||
# ====
|
||||
# Required functions for Home Assistant Auth Providers
|
||||
# ====
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return OpenIdLoginFlow(self)
|
||||
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Mapping[str, str]
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result["username"]
|
||||
@@ -64,7 +112,7 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
) -> UserMeta:
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Currently, supports name, group and local_only.
|
||||
Currently, supports name, is_active, group and local_only.
|
||||
"""
|
||||
meta = self._user_meta.get(credentials.data["username"], {})
|
||||
groups = meta.get("groups", [])
|
||||
@@ -76,96 +124,70 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
group=group,
|
||||
local_only="true",
|
||||
)
|
||||
|
||||
async def save_user_info(self, user_info: dict) -> str:
|
||||
"""Save user info during login."""
|
||||
_LOGGER.info("User info to be saved: %s", user_info)
|
||||
|
||||
code = self._generate_code()
|
||||
expiration = datetime.utcnow() + timedelta(minutes=5)
|
||||
user_data = {
|
||||
"user_info": user_info,
|
||||
"code": code,
|
||||
"expiration": expiration.isoformat()
|
||||
}
|
||||
|
||||
await self._save_to_db(self._get_code_key(code), user_data)
|
||||
return code
|
||||
|
||||
async def async_retrieve_username(self, code: str) -> Optional[dict]:
|
||||
"""Retrieve user info based on the code."""
|
||||
user_data = await self._get_from_db(self._get_code_key(code))
|
||||
await self._wipe_from_db(self._get_code_key(code))
|
||||
|
||||
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow():
|
||||
username = user_data["user_info"]["preferred_username"]
|
||||
self._user_meta[username] = user_data["user_info"]
|
||||
return username
|
||||
return None
|
||||
|
||||
def _generate_code(self) -> str:
|
||||
"""Generate a random six-digit code."""
|
||||
return ''.join(random.choices(string.digits, k=6))
|
||||
|
||||
def _get_code_key(self, code: str) -> str:
|
||||
return f"provider_oidc_auth_user_{code}"
|
||||
|
||||
async def _save_to_db(self, key: str, value: dict) -> None:
|
||||
"""Save key-value data to the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
await store.async_save(value)
|
||||
|
||||
async def _get_from_db(self, key: str) -> Optional[dict]:
|
||||
"""Retrieve key-value data from the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
return await store.async_load()
|
||||
|
||||
async def _wipe_from_db(self, key: str) -> None:
|
||||
"""Delete key-value data from the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
return await store.async_remove()
|
||||
|
||||
|
||||
class OpenIdLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
async def _finalize_user(self, code: str) -> AuthFlowResult:
|
||||
username = await self._auth_provider.async_retrieve_username(code)
|
||||
if username:
|
||||
_LOGGER.info("Logged in user: %s", username)
|
||||
return await self.async_finish(
|
||||
{
|
||||
"username": username,
|
||||
}
|
||||
)
|
||||
|
||||
raise InvalidAuthError
|
||||
|
||||
def _show_login_form(
|
||||
self, errors: Optional[dict[str, str]] = None
|
||||
) -> AuthFlowResult:
|
||||
if errors is None:
|
||||
errors = {}
|
||||
|
||||
# Show the login form
|
||||
# Abuses the MFA form, as it works better for our usecase
|
||||
# UI suggestions are welcome (make a PR!)
|
||||
return self.async_show_form(
|
||||
step_id="mfa",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required("code"): str,
|
||||
}
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
|
||||
# Show the login form
|
||||
# Currently, this form looks bad because the frontend gives no options to make it look better
|
||||
# We will investigate options to make it look better in the future
|
||||
return self.async_show_form(
|
||||
step_id="mfa",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required("code"): str,
|
||||
}
|
||||
),
|
||||
errors={},
|
||||
)
|
||||
|
||||
# Try to use the user input first
|
||||
if user_input is not None:
|
||||
try:
|
||||
return await self._finalize_user(user_input["code"])
|
||||
except InvalidAuthError:
|
||||
return self._show_login_form({"base": "invalid_auth"})
|
||||
|
||||
# If not available, check the cookie
|
||||
req = http.current_request.get()
|
||||
code_cookie = req.cookies.get("auth_oidc_code")
|
||||
|
||||
if code_cookie:
|
||||
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
|
||||
try:
|
||||
return await self._finalize_user(code_cookie)
|
||||
except InvalidAuthError:
|
||||
pass
|
||||
|
||||
# If none are available, just show the form
|
||||
return self._show_login_form()
|
||||
|
||||
async def async_step_mfa(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the result of the form."""
|
||||
|
||||
if user_input is None:
|
||||
return self.async_abort(reason="no_code_given")
|
||||
|
||||
# Log
|
||||
_LOGGER.info("User input %s", user_input)
|
||||
_LOGGER.info("Code %s was entered", user_input["code"])
|
||||
|
||||
username = await self._auth_provider.async_retrieve_username(user_input["code"])
|
||||
if username:
|
||||
_LOGGER.info("Logged in user: %s", username)
|
||||
|
||||
return await self.async_finish({
|
||||
"username": username,
|
||||
})
|
||||
|
||||
return self.async_abort(reason="invalid_code")
|
||||
# This is a dummy step function just to use the nicer MFA UI instead
|
||||
return await self.async_step_init(user_input)
|
||||
|
||||
Reference in New Issue
Block a user