Implement initial flow (#2)

This commit is contained in:
Christiaan Goossens
2024-12-24 21:38:57 +01:00
committed by GitHub
parent 1c8c7ed14a
commit 8ba494c49c
15 changed files with 883 additions and 1805 deletions

View File

@@ -4,6 +4,13 @@ from typing import OrderedDict
import voluptuous as vol
from homeassistant.core import HomeAssistant
from .endpoints.welcome import OIDCWelcomeView
from .endpoints.redirect import OIDCRedirectView
from .endpoints.finish import OIDCFinishView
from .endpoints.callback import OIDCCallbackView
from .oidc_client import OIDCClient
DOMAIN = "auth_oidc"
_LOGGER = logging.getLogger(__name__)
@@ -13,7 +20,9 @@ CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.Schema(
{
vol.Required("client_id"): vol.Coerce(str),
vol.Optional("client_secret"): vol.Coerce(str),
vol.Required("discovery_url"): vol.Url(),
}
)
},
@@ -34,5 +43,18 @@ async def async_setup(hass: HomeAssistant, config):
providers.update(hass.auth._providers)
hass.auth._providers = providers
_LOGGER.debug("Added OIDC provider")
_LOGGER.debug("Added OIDC provider for Home Assistant")
# Define some fields
discovery_url = config[DOMAIN]["discovery_url"]
client_id = config[DOMAIN]["client_id"]
scope = "openid profile email"
oidc_client = oidc_client = OIDCClient(discovery_url, client_id, scope)
hass.http.register_view(OIDCWelcomeView())
hass.http.register_view(OIDCRedirectView(oidc_client))
hass.http.register_view(OIDCCallbackView(oidc_client, provider))
hass.http.register_view(OIDCFinishView())
return True

View File

@@ -1,41 +0,0 @@
from aiohttp import web
from homeassistant.components.http import HomeAssistantView
from homeassistant.core import HomeAssistant, callback
import logging
DATA_VIEW_REGISTERED = "oauth2_view_reg"
AUTH_CALLBACK_PATH = "/auth/oidc/callback"
_LOGGER = logging.getLogger(__name__)
@callback
def async_register_view(hass: HomeAssistant) -> None:
"""Make sure callback view is registered."""
if not hass.data.get(DATA_VIEW_REGISTERED, False):
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
hass.data[DATA_VIEW_REGISTERED] = True
class OAuth2AuthorizeCallbackView(HomeAssistantView):
"""OAuth2 Authorization Callback View."""
requires_auth = False
url = AUTH_CALLBACK_PATH
name = "auth:oidc:callback"
async def get(self, request: web.Request) -> web.Response:
"""Receive response."""
_LOGGER.debug(request.query)
hass = request.app["hass"]
flow_mgr = hass.auth.login_flow
await flow_mgr.async_configure(
flow_id=request.query["flow_id"], user_input=request.query["test"]
)
return web.Response(
headers={"content-type": "text/html"},
text="<script>if (window.opener) { window.opener.postMessage({type: 'externalCallback'}); } window.close();</script>",
)

View File

@@ -0,0 +1,49 @@
from aiohttp import web
from homeassistant.components.http import HomeAssistantView
import logging
from ..oidc_client import OIDCClient
from ..provider import OpenIDAuthProvider
PATH = "/auth/oidc/callback"
_LOGGER = logging.getLogger(__name__)
class OIDCCallbackView(HomeAssistantView):
"""OIDC Plugin Callback View."""
requires_auth = False
url = PATH
name = "auth:oidc:callback"
def __init__(
self, oidc_client: OIDCClient, oidc_provider: OpenIDAuthProvider
) -> None:
self.oidc_client = oidc_client
self.oidc_provider = oidc_provider
async def get(self, request: web.Request) -> web.Response:
"""Receive response."""
_LOGGER.debug("Callback view accessed")
params = request.rel_url.query
code = params.get("code")
state = params.get("state")
base_uri = str(request.url).split('/auth', 2)[0]
if not (code and state):
return web.Response(
headers={"content-type": "text/html"},
text="<h1>Error</h1><p>Missing code or state parameter</p>",
)
user_details = await self.oidc_client.complete_token_flow(base_uri, code, state)
if user_details is None:
return web.Response(
headers={"content-type": "text/html"},
text="<h1>Error</h1><p>Failed to get user details, see console.</p>",
)
code = await self.oidc_provider.save_user_info(user_details)
return web.HTTPFound(base_uri + "/auth/oidc/finish?code=" + code)

View File

@@ -0,0 +1,24 @@
from aiohttp import web
from homeassistant.components.http import HomeAssistantView
import logging
PATH = "/auth/oidc/finish"
_LOGGER = logging.getLogger(__name__)
class OIDCFinishView(HomeAssistantView):
"""OIDC Plugin Finish View."""
requires_auth = False
url = PATH
name = "auth:oidc:finish"
async def get(self, request: web.Request) -> web.Response:
"""Receive response."""
code = request.query.get("code", "FAIL")
return web.Response(
headers={"content-type": "text/html"},
text=f"<h1>Done!</h1><p>Your code is: <b>{code}</b></p><p>Please return to the Home Assistant login screen (or your mobile app) and fill in this code into the single login field. It should be visible if you select 'Login with OpenID Connect (SSO)'.</p>",
)

View File

@@ -0,0 +1,46 @@
from aiohttp import web
from homeassistant.components.http import HomeAssistantView
import logging
from ..oidc_client import OIDCClient
PATH = "/auth/oidc/redirect"
_LOGGER = logging.getLogger(__name__)
class OIDCRedirectView(HomeAssistantView):
"""OIDC Plugin Redirect View."""
requires_auth = False
url = PATH
name = "auth:oidc:redirect"
def __init__(
self, oidc_client: OIDCClient
) -> None:
self.oidc_client = oidc_client
async def get(self, request: web.Request) -> web.Response:
"""Receive response."""
_LOGGER.debug("Redirect view accessed")
base_uri = str(request.url).split('/auth', 2)[0]
_LOGGER.debug("Base URI: %s", base_uri)
auth_url = await self.oidc_client.get_authorization_url(base_uri)
_LOGGER.debug("Auth URL: %s", auth_url)
if auth_url:
return web.HTTPFound(auth_url)
else:
return web.Response(
headers={"content-type": "text/html"},
text="<h1>Plugin is misconfigured, discovery could not be obtained</h1>",
)
async def post(self, request: web.Request) -> web.Response:
"""POST"""
_LOGGER.debug("Redirect POST view accessed")
return await self.get(request)

View File

@@ -0,0 +1,24 @@
from aiohttp import web
from homeassistant.components.http import HomeAssistantView
import logging
PATH = "/auth/oidc/welcome"
_LOGGER = logging.getLogger(__name__)
class OIDCWelcomeView(HomeAssistantView):
"""OIDC Plugin Welcome View."""
requires_auth = False
url = PATH
name = "auth:oidc:welcome"
async def get(self, request: web.Request) -> web.Response:
"""Receive response."""
_LOGGER.debug("Welcome view accessed")
return web.Response(
headers={"content-type": "text/html"},
text="<h1>OIDC Login (beta)</h1><p><a href='/auth/oidc/redirect'>Login with OIDC</a></p>",
)

View File

@@ -0,0 +1,204 @@
import aiohttp
import urllib.parse
import logging
import os
import base64
import hashlib
from jose import jwt
from jose import jwk, jwt
_LOGGER = logging.getLogger(__name__)
class OIDCClient:
flows = {}
def __init__(self, discovery_url, client_id, scope):
self.discovery_url = discovery_url
self.client_id = client_id
self.scope = scope
async def fetch_discovery_document(self):
try:
async with aiohttp.ClientSession() as session:
async with session.get(self.discovery_url) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
if e.status == 404:
_LOGGER.warning(f"Error: Discovery document not found at {self.discovery_url}")
else:
_LOGGER.warning(f"Error: {e.status} - {e.message}")
return None
async def get_authorization_url(self, base_uri):
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
if not self.discovery_document:
return None
auth_endpoint = self.discovery_document['authorization_endpoint']
# Generate the necessary PKCE parameters, nonce & state
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode('utf-8')
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
nonce = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
state = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
# Save all of them for later verification
self.flows[state] = {
'code_verifier': code_verifier,
'nonce': nonce
}
# Construct the params
query_params = {
'response_type': 'code',
'client_id': self.client_id,
'redirect_uri': base_uri + '/auth/oidc/callback',
'scope': self.scope,
'state': state,
'nonce': nonce,
'code_challenge': code_challenge,
'code_challenge_method': 'S256',
}
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
return url
async def _make_token_request(self, token_endpoint, query_params):
try:
async with aiohttp.ClientSession() as session:
async with session.post(token_endpoint, data=query_params) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
response_json = await response.json()
_LOGGER.warning(f"Error: {e.status} - {e.message}, Response: {response_json}")
return None
return None
async def _get_jwks(self, jwks_uri):
"""Fetches JWKS from the given URL."""
try:
async with aiohttp.ClientSession() as session:
async with session.get(jwks_uri) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientResponseError as e:
_LOGGER.warning(f"Error fetching JWKS: {e.status} - {e.message}")
return None
async def _parse_id_token(self, id_token):
# Parse the id token to obtain the relevant details
# Use python-jose
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
if not self.discovery_document:
return None
jwks_uri = self.discovery_document['jwks_uri']
jwks_data = await self._get_jwks(jwks_uri)
if not jwks_data:
return None
try:
unverified_header = jwt.get_unverified_header(id_token)
if not unverified_header:
print("Could not parse JWT Header")
return None
kid = unverified_header.get('kid')
if not kid:
print("JWT does not have kid (Key ID)")
return None
# Get the correct key
rsa_key = None
for key in jwks_data["keys"]:
if key["kid"] == kid:
rsa_key = key
break
if not rsa_key:
print(f"Could not find matching key with kid:{kid}")
return None
# Construct the JWK
jwk_obj = jwk.construct(rsa_key)
# Verify the token
decoded_token = jwt.decode(
id_token,
jwk_obj,
algorithms=["RS256"], # Adjust if your algorithm is different
audience=self.client_id,
issuer=self.discovery_document['issuer'],
)
return decoded_token
except jwt.JWTError as e:
print(f"JWT Verification failed: {e}")
return None
except Exception as e:
print(f"Unexpected error: {e}")
return None
async def complete_token_flow(self, base_uri, code, state):
if state not in self.flows:
return None
flow = self.flows[state]
code_verifier = flow['code_verifier']
if not hasattr(self, 'discovery_document'):
self.discovery_document = await self.fetch_discovery_document()
if not self.discovery_document:
return None
token_endpoint = self.discovery_document['token_endpoint']
# Construct the params
query_params = {
'grant_type': 'authorization_code',
'client_id': self.client_id,
'code': code,
'redirect_uri': base_uri + '/auth/oidc/callback',
'code_verifier': code_verifier,
}
_LOGGER.debug(f"Token request params: {query_params}")
token_response = await self._make_token_request(token_endpoint, query_params)
if not token_response:
return None
access_token = token_response.get('access_token')
id_token = token_response.get('id_token')
_LOGGER.debug(f"Access Token: {access_token}")
_LOGGER.debug(f"ID Token: {id_token}")
# Parse the id token to obtain the relevant details
id_token = await self._parse_id_token(id_token)
# Verify nonce
if id_token.get('nonce') != flow['nonce']:
_LOGGER.warning(f"Nonce mismatch!")
return None
return {
"name": id_token.get("name"),
"email": id_token.get("email"),
"preferred_username": id_token.get("preferred_username"),
"nickname": id_token.get("nickname"),
"groups": id_token.get("groups"),
}

View File

@@ -2,18 +2,22 @@
Allow access to users based on login with an external OpenID Connect Identity Provider (IdP).
"""
import logging
from secrets import token_hex
from typing import Any, Dict, Optional, cast
from typing import Dict, Optional
from homeassistant.auth.providers import (
AUTH_PROVIDERS,
AuthProvider,
LoginFlow,
AuthFlowResult,
Credentials,
UserMeta,
)
from homeassistant.exceptions import HomeAssistantError
import voluptuous as vol
from homeassistant.helpers.network import get_url
from .callback import async_register_view, AUTH_CALLBACK_PATH
from datetime import datetime, timedelta
import random
import string
from homeassistant.helpers.storage import Store
from collections.abc import Mapping
_LOGGER = logging.getLogger(__name__)
@@ -24,7 +28,12 @@ class InvalidAuthError(HomeAssistantError):
class OpenIDAuthProvider(AuthProvider):
"""Allow access to users based on login with an external OpenID Connect Identity Provider (IdP)."""
DEFAULT_TITLE = "OpenID Connect"
DEFAULT_TITLE = "OpenID Connect (SSO)"
def __init__(self, *args, **kwargs):
"""Initialize the OpenIDAuthProvider."""
super().__init__(*args, **kwargs)
self._user_meta = {}
@property
def type(self) -> str:
@@ -33,43 +42,130 @@ class OpenIDAuthProvider(AuthProvider):
@property
def support_mfa(self) -> bool:
return False
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login."""
async_register_view(self.hass)
return OpenIdLoginFlow(self)
async def async_get_or_create_credentials(
self, flow_result: Mapping[str, str]
) -> Credentials:
"""Get credentials based on the flow result."""
username = flow_result["username"]
for credential in await self.async_credentials():
if credential.data["username"] == username:
return credential
# Create new credentials.
return self.async_create_credentials({"username": username})
async def async_user_meta_for_credentials(
self, credentials: Credentials
) -> UserMeta:
"""Return extra user metadata for credentials.
Currently, supports name, group and local_only.
"""
meta = self._user_meta.get(credentials.data["username"], {})
groups = meta.get("groups", [])
group = "system-admin" if "admins" in groups else "system-users"
return UserMeta(
name=meta.get("name"),
is_active=True,
group=group,
local_only="true",
)
async def save_user_info(self, user_info: dict) -> str:
"""Save user info during login."""
_LOGGER.info("User info to be saved: %s", user_info)
code = self._generate_code()
expiration = datetime.utcnow() + timedelta(minutes=5)
user_data = {
"user_info": user_info,
"code": code,
"expiration": expiration.isoformat()
}
await self._save_to_db(self._get_code_key(code), user_data)
return code
async def async_retrieve_username(self, code: str) -> Optional[dict]:
"""Retrieve user info based on the code."""
user_data = await self._get_from_db(self._get_code_key(code))
await self._wipe_from_db(self._get_code_key(code))
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow():
username = user_data["user_info"]["preferred_username"]
self._user_meta[username] = user_data["user_info"]
return username
return None
def _generate_code(self) -> str:
"""Generate a random six-digit code."""
return ''.join(random.choices(string.digits, k=6))
def _get_code_key(self, code: str) -> str:
return f"provider_oidc_auth_user_{code}"
async def _save_to_db(self, key: str, value: dict) -> None:
"""Save key-value data to the Home Assistant storage."""
store = Store(self.hass, 1, key)
await store.async_save(value)
async def _get_from_db(self, key: str) -> Optional[dict]:
"""Retrieve key-value data from the Home Assistant storage."""
store = Store(self.hass, 1, key)
return await store.async_load()
async def _wipe_from_db(self, key: str) -> None:
"""Delete key-value data from the Home Assistant storage."""
store = Store(self.hass, 1, key)
return await store.async_remove()
class OpenIdLoginFlow(LoginFlow):
"""Handler for the login flow."""
external_data: Any
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> AuthFlowResult:
"""Handle the step of the form."""
return await self.async_step_authenticate()
def redirect_uri(self) -> str:
"""Return the redirect uri."""
return f"{get_url(self.hass, allow_external=True, require_current_request=True)}{AUTH_CALLBACK_PATH}?test=value&flow_id={self.flow_id}"
# Show the login form
# Currently, this form looks bad because the frontend gives no options to make it look better
# We will investigate options to make it look better in the future
return self.async_show_form(
step_id="mfa",
data_schema=vol.Schema(
{
vol.Required("code"): str,
}
),
errors={},
)
async def async_step_authenticate(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
"""Authenticate user using external step."""
async def async_step_mfa(
self, user_input: dict[str, str] | None = None
) -> AuthFlowResult:
"""Handle the result of the form."""
if user_input:
self.external_data = str(user_input)
return self.async_external_step_done(next_step_id="authorize")
if user_input is None:
return self.async_abort(reason="no_code_given")
return self.async_external_step(step_id="authenticate", url=self.redirect_uri())
# Log
_LOGGER.info("User input %s", user_input)
_LOGGER.info("Code %s was entered", user_input["code"])
async def async_step_authorize(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
"""Authorize user received from external step."""
_LOGGER.debug(self.external_data)
return self.async_abort(reason="invalid_auth")
username = await self._auth_provider.async_retrieve_username(user_input["code"])
if username:
_LOGGER.info("Logged in user: %s", username)
return await self.async_finish({
"username": username,
})
return self.async_abort(reason="invalid_code")