Implement initial flow (#2)
This commit is contained in:
committed by
GitHub
parent
1c8c7ed14a
commit
8ba494c49c
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
Copyright 2022 Christiaan Goossens
|
||||
Copyright 2024 Christiaan Goossens
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
|
||||
66
README.md
66
README.md
@@ -1,28 +1,70 @@
|
||||
# OIDC Auth for Home Assistant
|
||||
|
||||
Status: in progress, but very slowly.
|
||||
> [!CAUTION]
|
||||
> This is a pre-alpha release. I give no guarantees about code quality, error handling or security at this stage. Please treat this repo as a proof of concept for now and only use it on development HA installs.
|
||||
|
||||
Current roadblocks:
|
||||
Provides an OIDC implementation for Home Assistant.
|
||||
|
||||
- [ ] Find a way to do a redirect within the login step in Home Assistant, we should not use window.open
|
||||
- [ ] Find out how to make this redirect work on all platforms (including mobile)
|
||||
### Background
|
||||
If you would like to read the background/open letter that lead to this component, please see https://community.home-assistant.io/t/open-letter-for-improving-home-assistants-authentication-system-oidc-sso/494223. It is currently one of the most upvoted feature requests for Home Assistant.
|
||||
|
||||
If this is solved, implementing OIDC itself is doable.
|
||||
|
||||
If you have any tips or would like to contribute, send me a message.
|
||||
|
||||
## Installation
|
||||
## How to use
|
||||
### Installation
|
||||
|
||||
Add this repository to [HACS](https://hacs.xyz/).
|
||||
|
||||
Update your configuration.yaml file with
|
||||
Update your `configuration.yaml` file with
|
||||
|
||||
```yaml
|
||||
auth_oidc:
|
||||
client_id: ""
|
||||
discovery_url: ""
|
||||
```
|
||||
|
||||
Register your client with your OIDC Provider (e.g. Authentik/Authelia) as a public client and get the client_id. Then, use the obtained client_id and discovery URLs to fill the fields in `configuration.yaml`.
|
||||
|
||||
For example:
|
||||
```yaml
|
||||
auth_oidc:
|
||||
client_id: "someValueForTheClientId"
|
||||
discovery_url: "https://example.com/application/o/application/.well-known/openid-configuration"
|
||||
```
|
||||
|
||||
Afterwards, restart Home Assistant.
|
||||
|
||||
### Login
|
||||
You should now be able to see a second option on your login screen ("OpenID Connect (SSO)"). It provides you with a single input field.
|
||||
|
||||
Sadly, the user experience is pretty poor right now. Go to `/auth/oidc/welcome` (for example `https://hass.io/auth/oidc/welcome`, replace the URL with your Home Assistant URL) and follow the prompts provided to login, then copy the code into the input field from before. You should now login automatically with your username from SSO.
|
||||
|
||||
> [!TIP]
|
||||
> You can use a different device to login instead. Open the `/auth/oidc/welcome` link on device A and then type the obtained code into the normal HA login on device B (can also be the mobile app) to login.
|
||||
|
||||
## Development
|
||||
This package uses poetry: https://github.com/python-poetry/poetry. Use `poetry install` to install.
|
||||
You can force the venv within the project with `poetry config virtualenvs.in-project true`.
|
||||
This project uses the Rye package manager for development. You can find installation instructions here: https://rye.astral.sh/guide/installation/.
|
||||
Start by installing the dependencies using `rye sync` and then point your editor towards the environment created in the `.venv` directory.
|
||||
|
||||
### Help wanted
|
||||
If you have any tips or would like to contribute, send me a message. You are also welcome to contribute a PR to fix any of the TODOs.
|
||||
|
||||
Currently, this is a pre-alpha, so I welcome issues but I cannot guarantee I can fix them (at least within a reasonable time). Please turn on watch for this repository to remain updated. When the component is in a beta stage, issues will likely get fixed more frequently.
|
||||
|
||||
### TODOs
|
||||
|
||||
- [X] Basic flow
|
||||
- [ ] Improve welcome screen UI, should render a simple centered Tailwind UI instructing users that you should login externally to obtain a code.
|
||||
- [ ] Improve finish screen UI, showing the code clearly with a copy button and instructions to paste it into Home Assistant.
|
||||
- [ ] Implement error handling on top of this proof of concept (discovery, JWKS, OIDC)
|
||||
- [ ] Make id_token claim used for the group (admin/user) configurable
|
||||
- [ ] Make id_token claim used for the username configurable
|
||||
- [ ] Make id_token claim used for the name configurable
|
||||
- [ ] Add instructions on how to deploy this with Authentik & Authelia
|
||||
- [ ] Configure Github Actions to automatically lint and build the package
|
||||
- [ ] Configure Dependabot for automatic updates
|
||||
|
||||
Currently impossible TODOs (waiting for assistance from HA devs, not possible without forking HA frontend & apps right now):
|
||||
|
||||
- [ ] Update the HA frontend code to allow a redirection to be requested from an auth provider instead of manually opening welcome page
|
||||
- [ ] Implement this redirection logic to open a new tab on desktop
|
||||
- [ ] Implement this redirection logic to open a Android Custom Tab (Android) / SFSafariViewController (iOS), instead of opening the link in the HA webview
|
||||
- [ ] Implement a final redirect back to the main page with the code as a query param instead of showing the finalize page
|
||||
|
||||
@@ -4,6 +4,13 @@ from typing import OrderedDict
|
||||
import voluptuous as vol
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .endpoints.welcome import OIDCWelcomeView
|
||||
from .endpoints.redirect import OIDCRedirectView
|
||||
from .endpoints.finish import OIDCFinishView
|
||||
from .endpoints.callback import OIDCCallbackView
|
||||
|
||||
from .oidc_client import OIDCClient
|
||||
|
||||
DOMAIN = "auth_oidc"
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,7 +20,9 @@ CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
|
||||
vol.Required("client_id"): vol.Coerce(str),
|
||||
vol.Optional("client_secret"): vol.Coerce(str),
|
||||
vol.Required("discovery_url"): vol.Url(),
|
||||
}
|
||||
)
|
||||
},
|
||||
@@ -34,5 +43,18 @@ async def async_setup(hass: HomeAssistant, config):
|
||||
providers.update(hass.auth._providers)
|
||||
hass.auth._providers = providers
|
||||
|
||||
_LOGGER.debug("Added OIDC provider")
|
||||
_LOGGER.debug("Added OIDC provider for Home Assistant")
|
||||
|
||||
# Define some fields
|
||||
discovery_url = config[DOMAIN]["discovery_url"]
|
||||
client_id = config[DOMAIN]["client_id"]
|
||||
scope = "openid profile email"
|
||||
|
||||
oidc_client = oidc_client = OIDCClient(discovery_url, client_id, scope)
|
||||
|
||||
hass.http.register_view(OIDCWelcomeView())
|
||||
hass.http.register_view(OIDCRedirectView(oidc_client))
|
||||
hass.http.register_view(OIDCCallbackView(oidc_client, provider))
|
||||
hass.http.register_view(OIDCFinishView())
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
import logging
|
||||
|
||||
DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
||||
AUTH_CALLBACK_PATH = "/auth/oidc/callback"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@callback
|
||||
def async_register_view(hass: HomeAssistant) -> None:
|
||||
"""Make sure callback view is registered."""
|
||||
if not hass.data.get(DATA_VIEW_REGISTERED, False):
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
|
||||
hass.data[DATA_VIEW_REGISTERED] = True
|
||||
|
||||
|
||||
class OAuth2AuthorizeCallbackView(HomeAssistantView):
|
||||
"""OAuth2 Authorization Callback View."""
|
||||
|
||||
requires_auth = False
|
||||
url = AUTH_CALLBACK_PATH
|
||||
name = "auth:oidc:callback"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug(request.query)
|
||||
|
||||
hass = request.app["hass"]
|
||||
flow_mgr = hass.auth.login_flow
|
||||
|
||||
await flow_mgr.async_configure(
|
||||
flow_id=request.query["flow_id"], user_input=request.query["test"]
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<script>if (window.opener) { window.opener.postMessage({type: 'externalCallback'}); } window.close();</script>",
|
||||
)
|
||||
49
custom_components/auth_oidc/endpoints/callback.py
Normal file
49
custom_components/auth_oidc/endpoints/callback.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
from ..oidc_client import OIDCClient
|
||||
from ..provider import OpenIDAuthProvider
|
||||
|
||||
PATH = "/auth/oidc/callback"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCCallbackView(HomeAssistantView):
|
||||
"""OIDC Plugin Callback View."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:callback"
|
||||
|
||||
def __init__(
|
||||
self, oidc_client: OIDCClient, oidc_provider: OpenIDAuthProvider
|
||||
) -> None:
|
||||
self.oidc_client = oidc_client
|
||||
self.oidc_provider = oidc_provider
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Callback view accessed")
|
||||
|
||||
params = request.rel_url.query
|
||||
code = params.get("code")
|
||||
state = params.get("state")
|
||||
base_uri = str(request.url).split('/auth', 2)[0]
|
||||
|
||||
if not (code and state):
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>Error</h1><p>Missing code or state parameter</p>",
|
||||
)
|
||||
|
||||
user_details = await self.oidc_client.complete_token_flow(base_uri, code, state)
|
||||
if user_details is None:
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>Error</h1><p>Failed to get user details, see console.</p>",
|
||||
)
|
||||
|
||||
code = await self.oidc_provider.save_user_info(user_details)
|
||||
|
||||
return web.HTTPFound(base_uri + "/auth/oidc/finish?code=" + code)
|
||||
24
custom_components/auth_oidc/endpoints/finish.py
Normal file
24
custom_components/auth_oidc/endpoints/finish.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
|
||||
PATH = "/auth/oidc/finish"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCFinishView(HomeAssistantView):
|
||||
"""OIDC Plugin Finish View."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:finish"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
code = request.query.get("code", "FAIL")
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text=f"<h1>Done!</h1><p>Your code is: <b>{code}</b></p><p>Please return to the Home Assistant login screen (or your mobile app) and fill in this code into the single login field. It should be visible if you select 'Login with OpenID Connect (SSO)'.</p>",
|
||||
)
|
||||
46
custom_components/auth_oidc/endpoints/redirect.py
Normal file
46
custom_components/auth_oidc/endpoints/redirect.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
|
||||
from ..oidc_client import OIDCClient
|
||||
|
||||
PATH = "/auth/oidc/redirect"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCRedirectView(HomeAssistantView):
|
||||
"""OIDC Plugin Redirect View."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:redirect"
|
||||
|
||||
def __init__(
|
||||
self, oidc_client: OIDCClient
|
||||
) -> None:
|
||||
self.oidc_client = oidc_client
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Redirect view accessed")
|
||||
|
||||
base_uri = str(request.url).split('/auth', 2)[0]
|
||||
_LOGGER.debug("Base URI: %s", base_uri)
|
||||
|
||||
auth_url = await self.oidc_client.get_authorization_url(base_uri)
|
||||
_LOGGER.debug("Auth URL: %s", auth_url)
|
||||
|
||||
if auth_url:
|
||||
return web.HTTPFound(auth_url)
|
||||
else:
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>Plugin is misconfigured, discovery could not be obtained</h1>",
|
||||
)
|
||||
|
||||
async def post(self, request: web.Request) -> web.Response:
|
||||
"""POST"""
|
||||
|
||||
_LOGGER.debug("Redirect POST view accessed")
|
||||
return await self.get(request)
|
||||
24
custom_components/auth_oidc/endpoints/welcome.py
Normal file
24
custom_components/auth_oidc/endpoints/welcome.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
|
||||
PATH = "/auth/oidc/welcome"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCWelcomeView(HomeAssistantView):
|
||||
"""OIDC Plugin Welcome View."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:welcome"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Welcome view accessed")
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>OIDC Login (beta)</h1><p><a href='/auth/oidc/redirect'>Login with OIDC</a></p>",
|
||||
)
|
||||
204
custom_components/auth_oidc/oidc_client.py
Normal file
204
custom_components/auth_oidc/oidc_client.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import aiohttp
|
||||
|
||||
import urllib.parse
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
from jose import jwt
|
||||
|
||||
from jose import jwk, jwt
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCClient:
|
||||
flows = {}
|
||||
|
||||
def __init__(self, discovery_url, client_id, scope):
|
||||
self.discovery_url = discovery_url
|
||||
self.client_id = client_id
|
||||
self.scope = scope
|
||||
|
||||
async def fetch_discovery_document(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.discovery_url) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
if e.status == 404:
|
||||
_LOGGER.warning(f"Error: Discovery document not found at {self.discovery_url}")
|
||||
else:
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}")
|
||||
return None
|
||||
|
||||
async def get_authorization_url(self, base_uri):
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
auth_endpoint = self.discovery_document['authorization_endpoint']
|
||||
|
||||
# Generate the necessary PKCE parameters, nonce & state
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode('utf-8')
|
||||
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
|
||||
nonce = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
state = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
|
||||
# Save all of them for later verification
|
||||
self.flows[state] = {
|
||||
'code_verifier': code_verifier,
|
||||
'nonce': nonce
|
||||
}
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'scope': self.scope,
|
||||
'state': state,
|
||||
'nonce': nonce,
|
||||
'code_challenge': code_challenge,
|
||||
'code_challenge_method': 'S256',
|
||||
}
|
||||
|
||||
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
|
||||
return url
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(token_endpoint, data=query_params) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
response_json = await response.json()
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}, Response: {response_json}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(jwks_uri) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning(f"Error fetching JWKS: {e.status} - {e.message}")
|
||||
return None
|
||||
|
||||
async def _parse_id_token(self, id_token):
|
||||
# Parse the id token to obtain the relevant details
|
||||
# Use python-jose
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
jwks_uri = self.discovery_document['jwks_uri']
|
||||
|
||||
jwks_data = await self._get_jwks(jwks_uri)
|
||||
if not jwks_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(id_token)
|
||||
if not unverified_header:
|
||||
print("Could not parse JWT Header")
|
||||
return None
|
||||
|
||||
kid = unverified_header.get('kid')
|
||||
if not kid:
|
||||
print("JWT does not have kid (Key ID)")
|
||||
return None
|
||||
|
||||
|
||||
# Get the correct key
|
||||
rsa_key = None
|
||||
for key in jwks_data["keys"]:
|
||||
if key["kid"] == kid:
|
||||
rsa_key = key
|
||||
break
|
||||
|
||||
if not rsa_key:
|
||||
print(f"Could not find matching key with kid:{kid}")
|
||||
return None
|
||||
|
||||
# Construct the JWK
|
||||
jwk_obj = jwk.construct(rsa_key)
|
||||
|
||||
# Verify the token
|
||||
decoded_token = jwt.decode(
|
||||
id_token,
|
||||
jwk_obj,
|
||||
algorithms=["RS256"], # Adjust if your algorithm is different
|
||||
audience=self.client_id,
|
||||
issuer=self.discovery_document['issuer'],
|
||||
)
|
||||
return decoded_token
|
||||
|
||||
except jwt.JWTError as e:
|
||||
print(f"JWT Verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def complete_token_flow(self, base_uri, code, state):
|
||||
if state not in self.flows:
|
||||
return None
|
||||
|
||||
flow = self.flows[state]
|
||||
code_verifier = flow['code_verifier']
|
||||
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
token_endpoint = self.discovery_document['token_endpoint']
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.client_id,
|
||||
'code': code,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'code_verifier': code_verifier,
|
||||
}
|
||||
|
||||
_LOGGER.debug(f"Token request params: {query_params}")
|
||||
|
||||
token_response = await self._make_token_request(token_endpoint, query_params)
|
||||
|
||||
if not token_response:
|
||||
return None
|
||||
|
||||
access_token = token_response.get('access_token')
|
||||
id_token = token_response.get('id_token')
|
||||
_LOGGER.debug(f"Access Token: {access_token}")
|
||||
_LOGGER.debug(f"ID Token: {id_token}")
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
id_token = await self._parse_id_token(id_token)
|
||||
|
||||
# Verify nonce
|
||||
if id_token.get('nonce') != flow['nonce']:
|
||||
_LOGGER.warning(f"Nonce mismatch!")
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": id_token.get("name"),
|
||||
"email": id_token.get("email"),
|
||||
"preferred_username": id_token.get("preferred_username"),
|
||||
"nickname": id_token.get("nickname"),
|
||||
"groups": id_token.get("groups"),
|
||||
}
|
||||
@@ -2,18 +2,22 @@
|
||||
Allow access to users based on login with an external OpenID Connect Identity Provider (IdP).
|
||||
"""
|
||||
import logging
|
||||
from secrets import token_hex
|
||||
from typing import Any, Dict, Optional, cast
|
||||
from typing import Dict, Optional
|
||||
from homeassistant.auth.providers import (
|
||||
AUTH_PROVIDERS,
|
||||
AuthProvider,
|
||||
LoginFlow,
|
||||
AuthFlowResult,
|
||||
Credentials,
|
||||
UserMeta,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import voluptuous as vol
|
||||
from homeassistant.helpers.network import get_url
|
||||
|
||||
from .callback import async_register_view, AUTH_CALLBACK_PATH
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
import string
|
||||
from homeassistant.helpers.storage import Store
|
||||
from collections.abc import Mapping
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +28,12 @@ class InvalidAuthError(HomeAssistantError):
|
||||
class OpenIDAuthProvider(AuthProvider):
|
||||
"""Allow access to users based on login with an external OpenID Connect Identity Provider (IdP)."""
|
||||
|
||||
DEFAULT_TITLE = "OpenID Connect"
|
||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the OpenIDAuthProvider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta = {}
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
@@ -33,43 +42,130 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
@property
|
||||
def support_mfa(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
|
||||
async_register_view(self.hass)
|
||||
return OpenIdLoginFlow(self)
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Mapping[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result["username"]
|
||||
for credential in await self.async_credentials():
|
||||
if credential.data["username"] == username:
|
||||
return credential
|
||||
|
||||
# Create new credentials.
|
||||
return self.async_create_credentials({"username": username})
|
||||
|
||||
async def async_user_meta_for_credentials(
|
||||
self, credentials: Credentials
|
||||
) -> UserMeta:
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Currently, supports name, group and local_only.
|
||||
"""
|
||||
meta = self._user_meta.get(credentials.data["username"], {})
|
||||
groups = meta.get("groups", [])
|
||||
|
||||
group = "system-admin" if "admins" in groups else "system-users"
|
||||
return UserMeta(
|
||||
name=meta.get("name"),
|
||||
is_active=True,
|
||||
group=group,
|
||||
local_only="true",
|
||||
)
|
||||
|
||||
async def save_user_info(self, user_info: dict) -> str:
|
||||
"""Save user info during login."""
|
||||
_LOGGER.info("User info to be saved: %s", user_info)
|
||||
|
||||
code = self._generate_code()
|
||||
expiration = datetime.utcnow() + timedelta(minutes=5)
|
||||
user_data = {
|
||||
"user_info": user_info,
|
||||
"code": code,
|
||||
"expiration": expiration.isoformat()
|
||||
}
|
||||
|
||||
await self._save_to_db(self._get_code_key(code), user_data)
|
||||
return code
|
||||
|
||||
async def async_retrieve_username(self, code: str) -> Optional[dict]:
|
||||
"""Retrieve user info based on the code."""
|
||||
user_data = await self._get_from_db(self._get_code_key(code))
|
||||
await self._wipe_from_db(self._get_code_key(code))
|
||||
|
||||
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow():
|
||||
username = user_data["user_info"]["preferred_username"]
|
||||
self._user_meta[username] = user_data["user_info"]
|
||||
return username
|
||||
return None
|
||||
|
||||
def _generate_code(self) -> str:
|
||||
"""Generate a random six-digit code."""
|
||||
return ''.join(random.choices(string.digits, k=6))
|
||||
|
||||
def _get_code_key(self, code: str) -> str:
|
||||
return f"provider_oidc_auth_user_{code}"
|
||||
|
||||
async def _save_to_db(self, key: str, value: dict) -> None:
|
||||
"""Save key-value data to the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
await store.async_save(value)
|
||||
|
||||
async def _get_from_db(self, key: str) -> Optional[dict]:
|
||||
"""Retrieve key-value data from the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
return await store.async_load()
|
||||
|
||||
async def _wipe_from_db(self, key: str) -> None:
|
||||
"""Delete key-value data from the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
return await store.async_remove()
|
||||
|
||||
|
||||
class OpenIdLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
external_data: Any
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
return await self.async_step_authenticate()
|
||||
|
||||
def redirect_uri(self) -> str:
|
||||
"""Return the redirect uri."""
|
||||
return f"{get_url(self.hass, allow_external=True, require_current_request=True)}{AUTH_CALLBACK_PATH}?test=value&flow_id={self.flow_id}"
|
||||
# Show the login form
|
||||
# Currently, this form looks bad because the frontend gives no options to make it look better
|
||||
# We will investigate options to make it look better in the future
|
||||
return self.async_show_form(
|
||||
step_id="mfa",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required("code"): str,
|
||||
}
|
||||
),
|
||||
errors={},
|
||||
)
|
||||
|
||||
|
||||
async def async_step_authenticate(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Authenticate user using external step."""
|
||||
async def async_step_mfa(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the result of the form."""
|
||||
|
||||
if user_input:
|
||||
self.external_data = str(user_input)
|
||||
return self.async_external_step_done(next_step_id="authorize")
|
||||
if user_input is None:
|
||||
return self.async_abort(reason="no_code_given")
|
||||
|
||||
return self.async_external_step(step_id="authenticate", url=self.redirect_uri())
|
||||
# Log
|
||||
_LOGGER.info("User input %s", user_input)
|
||||
_LOGGER.info("Code %s was entered", user_input["code"])
|
||||
|
||||
async def async_step_authorize(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Authorize user received from external step."""
|
||||
_LOGGER.debug(self.external_data)
|
||||
return self.async_abort(reason="invalid_auth")
|
||||
username = await self._auth_provider.async_retrieve_username(user_input["code"])
|
||||
if username:
|
||||
_LOGGER.info("Logged in user: %s", username)
|
||||
|
||||
return await self.async_finish({
|
||||
"username": username,
|
||||
})
|
||||
|
||||
return self.async_abort(reason="invalid_code")
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"name": "OpenID Connect",
|
||||
"render_readme": true,
|
||||
"homeassistant": "2022.11"
|
||||
}
|
||||
"homeassistant": "2024.12"
|
||||
}
|
||||
1706
poetry.lock
generated
1706
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,30 @@
|
||||
[tool.poetry]
|
||||
[project]
|
||||
name = "hass-oidc"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Christiaan Goossens <contact@christiaangoossens.nl>"]
|
||||
description = "OIDC component for Home Assistant"
|
||||
authors = [
|
||||
{ name = "Christiaan Goossens", email = "contact@christiaangoossens.nl" }
|
||||
]
|
||||
license = "MIT"
|
||||
dependencies = [
|
||||
"python-jose>=3.3.0",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.8"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "3.10.*"
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
homeassistant = "^2022.11.4"
|
||||
pylint = "^2.15.6"
|
||||
[tool.rye]
|
||||
managed = true
|
||||
dev-dependencies = [
|
||||
"homeassistant~=2024.12",
|
||||
"pylint~=3.3",
|
||||
]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
autopep8 = "^2.0.0"
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["custom_components/auth_oidc"]
|
||||
|
||||
282
requirements-dev.lock
Normal file
282
requirements-dev.lock
Normal file
@@ -0,0 +1,282 @@
|
||||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
# generate-hashes: false
|
||||
# universal: false
|
||||
|
||||
-e file:.
|
||||
acme==3.0.1
|
||||
# via hass-nabucasa
|
||||
aiodns==3.2.0
|
||||
# via homeassistant
|
||||
aiohappyeyeballs==2.4.4
|
||||
# via aiohttp
|
||||
aiohasupervisor==0.2.1
|
||||
# via homeassistant
|
||||
aiohttp==3.11.11
|
||||
# via aiohasupervisor
|
||||
# via aiohttp-cors
|
||||
# via aiohttp-fast-zlib
|
||||
# via hass-nabucasa
|
||||
# via homeassistant
|
||||
# via snitun
|
||||
aiohttp-cors==0.7.0
|
||||
# via homeassistant
|
||||
aiohttp-fast-zlib==0.2.0
|
||||
# via homeassistant
|
||||
aiooui==0.1.7
|
||||
# via bluetooth-adapters
|
||||
aiosignal==1.3.2
|
||||
# via aiohttp
|
||||
aiozoneinfo==0.2.1
|
||||
# via homeassistant
|
||||
anyio==4.7.0
|
||||
# via httpx
|
||||
astral==2.2
|
||||
# via homeassistant
|
||||
astroid==3.3.8
|
||||
# via pylint
|
||||
async-interrupt==1.2.0
|
||||
# via habluetooth
|
||||
# via homeassistant
|
||||
async-timeout==5.0.1
|
||||
# via snitun
|
||||
atomicwrites-homeassistant==1.4.1
|
||||
# via hass-nabucasa
|
||||
# via homeassistant
|
||||
attrs==24.2.0
|
||||
# via aiohttp
|
||||
# via hass-nabucasa
|
||||
# via homeassistant
|
||||
# via snitun
|
||||
audioop-lts==0.2.1
|
||||
# via homeassistant
|
||||
# via standard-aifc
|
||||
awesomeversion==24.6.0
|
||||
# via homeassistant
|
||||
bcrypt==4.2.0
|
||||
# via homeassistant
|
||||
bleak==0.22.3
|
||||
# via bleak-retry-connector
|
||||
# via bluetooth-adapters
|
||||
# via habluetooth
|
||||
bleak-retry-connector==3.6.0
|
||||
# via habluetooth
|
||||
bluetooth-adapters==0.20.2
|
||||
# via bleak-retry-connector
|
||||
# via bluetooth-auto-recovery
|
||||
# via habluetooth
|
||||
bluetooth-auto-recovery==1.4.2
|
||||
# via habluetooth
|
||||
bluetooth-data-tools==1.20.0
|
||||
# via habluetooth
|
||||
boto3==1.35.87
|
||||
# via pycognito
|
||||
botocore==1.35.87
|
||||
# via boto3
|
||||
# via s3transfer
|
||||
btsocket==0.3.0
|
||||
# via bluetooth-auto-recovery
|
||||
certifi==2024.12.14
|
||||
# via homeassistant
|
||||
# via httpcore
|
||||
# via httpx
|
||||
# via requests
|
||||
cffi==1.17.1
|
||||
# via cryptography
|
||||
# via pycares
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
ciso8601==2.3.1
|
||||
# via hass-nabucasa
|
||||
# via homeassistant
|
||||
cryptography==43.0.1
|
||||
# via acme
|
||||
# via bluetooth-data-tools
|
||||
# via hass-nabucasa
|
||||
# via homeassistant
|
||||
# via josepy
|
||||
# via pyjwt
|
||||
# via pyopenssl
|
||||
# via securetar
|
||||
# via snitun
|
||||
dbus-fast==2.24.4
|
||||
# via bleak
|
||||
# via bleak-retry-connector
|
||||
# via bluetooth-adapters
|
||||
dill==0.3.9
|
||||
# via pylint
|
||||
ecdsa==0.19.0
|
||||
# via python-jose
|
||||
envs==1.4
|
||||
# via pycognito
|
||||
fnv-hash-fast==1.0.2
|
||||
# via homeassistant
|
||||
fnvhash==0.1.0
|
||||
# via fnv-hash-fast
|
||||
frozenlist==1.5.0
|
||||
# via aiohttp
|
||||
# via aiosignal
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
habluetooth==3.6.0
|
||||
# via home-assistant-bluetooth
|
||||
hass-nabucasa==0.86.0
|
||||
# via homeassistant
|
||||
home-assistant-bluetooth==1.13.0
|
||||
# via homeassistant
|
||||
homeassistant==2024.12.5
|
||||
httpcore==1.0.7
|
||||
# via httpx
|
||||
httpx==0.27.2
|
||||
# via homeassistant
|
||||
idna==3.10
|
||||
# via anyio
|
||||
# via httpx
|
||||
# via requests
|
||||
# via yarl
|
||||
ifaddr==0.2.0
|
||||
# via homeassistant
|
||||
isort==5.13.2
|
||||
# via pylint
|
||||
jinja2==3.1.4
|
||||
# via homeassistant
|
||||
jmespath==1.0.1
|
||||
# via boto3
|
||||
# via botocore
|
||||
josepy==1.14.0
|
||||
# via acme
|
||||
lru-dict==1.3.0
|
||||
# via homeassistant
|
||||
markupsafe==3.0.2
|
||||
# via jinja2
|
||||
mashumaro==3.15
|
||||
# via aiohasupervisor
|
||||
# via webrtc-models
|
||||
mccabe==0.7.0
|
||||
# via pylint
|
||||
multidict==6.1.0
|
||||
# via aiohttp
|
||||
# via yarl
|
||||
orjson==3.10.12
|
||||
# via aiohasupervisor
|
||||
# via homeassistant
|
||||
# via webrtc-models
|
||||
packaging==24.2
|
||||
# via homeassistant
|
||||
pillow==11.0.0
|
||||
# via homeassistant
|
||||
platformdirs==4.3.6
|
||||
# via pylint
|
||||
propcache==0.2.1
|
||||
# via aiohttp
|
||||
# via homeassistant
|
||||
# via yarl
|
||||
psutil==6.1.1
|
||||
# via psutil-home-assistant
|
||||
psutil-home-assistant==0.0.1
|
||||
# via homeassistant
|
||||
pyasn1==0.6.1
|
||||
# via python-jose
|
||||
# via rsa
|
||||
pycares==4.5.0
|
||||
# via aiodns
|
||||
pycognito==2024.5.1
|
||||
# via hass-nabucasa
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pyjwt==2.10.1
|
||||
# via hass-nabucasa
|
||||
# via homeassistant
|
||||
# via pycognito
|
||||
pylint==3.3.3
|
||||
pyopenssl==24.2.1
|
||||
# via acme
|
||||
# via homeassistant
|
||||
# via josepy
|
||||
pyrfc3339==2.0.1
|
||||
# via acme
|
||||
pyric==0.1.6.3
|
||||
# via bluetooth-auto-recovery
|
||||
python-dateutil==2.9.0.post0
|
||||
# via botocore
|
||||
python-jose==3.3.0
|
||||
# via hass-oidc
|
||||
python-slugify==8.0.4
|
||||
# via homeassistant
|
||||
pytz==2024.2
|
||||
# via acme
|
||||
# via astral
|
||||
pyyaml==6.0.2
|
||||
# via homeassistant
|
||||
requests==2.32.3
|
||||
# via acme
|
||||
# via homeassistant
|
||||
# via pycognito
|
||||
rsa==4.9
|
||||
# via python-jose
|
||||
s3transfer==0.10.4
|
||||
# via boto3
|
||||
securetar==2024.11.0
|
||||
# via homeassistant
|
||||
setuptools==75.6.0
|
||||
# via acme
|
||||
six==1.17.0
|
||||
# via ecdsa
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
# via httpx
|
||||
snitun==0.39.1
|
||||
# via hass-nabucasa
|
||||
sqlalchemy==2.0.36
|
||||
# via homeassistant
|
||||
standard-aifc==3.13.0
|
||||
# via homeassistant
|
||||
standard-chunk==3.13.0
|
||||
# via standard-aifc
|
||||
standard-telnetlib==3.13.0
|
||||
# via homeassistant
|
||||
text-unidecode==1.3
|
||||
# via python-slugify
|
||||
tomlkit==0.13.2
|
||||
# via pylint
|
||||
typing-extensions==4.12.2
|
||||
# via homeassistant
|
||||
# via mashumaro
|
||||
# via sqlalchemy
|
||||
tzdata==2024.2
|
||||
# via aiozoneinfo
|
||||
uart-devices==0.1.0
|
||||
# via bluetooth-adapters
|
||||
ulid-transform==1.0.2
|
||||
# via homeassistant
|
||||
urllib3==1.26.20
|
||||
# via botocore
|
||||
# via homeassistant
|
||||
# via requests
|
||||
usb-devices==0.4.5
|
||||
# via bluetooth-adapters
|
||||
# via bluetooth-auto-recovery
|
||||
uv==0.5.4
|
||||
# via homeassistant
|
||||
voluptuous==0.15.2
|
||||
# via homeassistant
|
||||
# via voluptuous-openapi
|
||||
# via voluptuous-serialize
|
||||
voluptuous-openapi==0.0.5
|
||||
# via homeassistant
|
||||
voluptuous-serialize==2.6.0
|
||||
# via homeassistant
|
||||
webrtc-models==0.3.0
|
||||
# via hass-nabucasa
|
||||
# via homeassistant
|
||||
yarl==1.18.3
|
||||
# via aiohasupervisor
|
||||
# via aiohttp
|
||||
# via homeassistant
|
||||
23
requirements.lock
Normal file
23
requirements.lock
Normal file
@@ -0,0 +1,23 @@
|
||||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
# generate-hashes: false
|
||||
# universal: false
|
||||
|
||||
-e file:.
|
||||
ecdsa==0.19.0
|
||||
# via python-jose
|
||||
pyasn1==0.6.1
|
||||
# via python-jose
|
||||
# via rsa
|
||||
python-jose==3.3.0
|
||||
# via hass-oidc
|
||||
rsa==4.9
|
||||
# via python-jose
|
||||
six==1.17.0
|
||||
# via ecdsa
|
||||
Reference in New Issue
Block a user