Reimplement UI injection (#236)
This commit is contained in:
committed by
GitHub
parent
fdc93e2719
commit
fd3643685d
@@ -5,3 +5,4 @@ from .finish import OIDCFinishView as OIDCFinishView
|
||||
from .injected_auth_page import OIDCInjectedAuthPage as OIDCInjectedAuthPage
|
||||
from .redirect import OIDCRedirectView as OIDCRedirectView
|
||||
from .welcome import OIDCWelcomeView as OIDCWelcomeView
|
||||
from .device_sse import OIDCDeviceSSE as OIDCDeviceSSE
|
||||
|
||||
@@ -4,7 +4,7 @@ from homeassistant.components.http import HomeAssistantView
|
||||
from aiohttp import web
|
||||
from ..tools.oidc_client import OIDCClient
|
||||
from ..provider import OpenIDAuthProvider
|
||||
from ..tools.helpers import get_url, get_view
|
||||
from ..tools.helpers import error_response, get_url, get_valid_state_id
|
||||
|
||||
PATH = "/auth/oidc/callback"
|
||||
|
||||
@@ -29,42 +29,49 @@ class OIDCCallbackView(HomeAssistantView):
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
# Get cookie to get the state_id
|
||||
state_id = await get_valid_state_id(request, self.oidc_provider)
|
||||
if not state_id:
|
||||
return await error_response("Missing state cookie, please restart login.")
|
||||
|
||||
# Get the OIDC query parameters
|
||||
params = request.rel_url.query
|
||||
code = params.get("code")
|
||||
state = params.get("state")
|
||||
|
||||
if not (code and state):
|
||||
view_html = await get_view(
|
||||
"error",
|
||||
{
|
||||
"error": "Missing code or state parameter.",
|
||||
},
|
||||
)
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
return await error_response("Missing code or state parameter.")
|
||||
|
||||
# Check if the states match
|
||||
if state != state_id:
|
||||
return await error_response(
|
||||
"State parameter does not match, possible CSRF attack."
|
||||
)
|
||||
|
||||
# Complete the OIDC flow to get user details
|
||||
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
|
||||
user_details = await self.oidc_client.async_complete_token_flow(
|
||||
redirect_uri, code, state
|
||||
)
|
||||
if user_details is None:
|
||||
view_html = await get_view(
|
||||
"error",
|
||||
{
|
||||
"error": "Failed to get user details, "
|
||||
+ "see Home Assistant logs for more information.",
|
||||
},
|
||||
return await error_response(
|
||||
"Failed to get user details, see Home Assistant logs for more information.",
|
||||
status=500,
|
||||
)
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
if user_details.get("role") == "invalid":
|
||||
view_html = await get_view(
|
||||
"error",
|
||||
{
|
||||
"error": "User is not in the correct group to access Home Assistant, "
|
||||
+ "contact your administrator!",
|
||||
},
|
||||
return await error_response(
|
||||
"User is not in the correct group to access Home Assistant, "
|
||||
+ "contact your administrator!",
|
||||
status=403,
|
||||
)
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
code = await self.oidc_provider.async_save_user_info(user_details)
|
||||
raise web.HTTPFound(get_url("/auth/oidc/finish?code=" + code, self.force_https))
|
||||
# Finalize on the state
|
||||
success = await self.oidc_provider.async_save_user_info(state_id, user_details)
|
||||
if not success:
|
||||
return await error_response(
|
||||
"Failed to save user information, session probably expired. Please sign in again.",
|
||||
status=500,
|
||||
)
|
||||
|
||||
raise web.HTTPFound(get_url("/auth/oidc/finish", self.force_https))
|
||||
|
||||
70
custom_components/auth_oidc/endpoints/device_sse.py
Normal file
70
custom_components/auth_oidc/endpoints/device_sse.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""SSE handler for OIDC device authentication."""
|
||||
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from ..provider import OpenIDAuthProvider
|
||||
from ..tools.helpers import get_valid_state_id
|
||||
|
||||
PATH = "/auth/oidc/device-sse"
|
||||
|
||||
|
||||
class OIDCDeviceSSE(HomeAssistantView):
|
||||
"""OIDC Plugin SSE Handler."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:device-sse"
|
||||
|
||||
def __init__(self, oidc_provider: OpenIDAuthProvider) -> None:
|
||||
self.oidc_provider = oidc_provider
|
||||
|
||||
async def get(self, req: web.Request) -> web.Response:
|
||||
"""Check for mobile sign-in completion with short server-side polling."""
|
||||
state_id = await get_valid_state_id(req, self.oidc_provider)
|
||||
if not state_id:
|
||||
raise web.HTTPBadRequest(text="Missing session cookie")
|
||||
|
||||
timeout_seconds = 300
|
||||
started_at = asyncio.get_running_loop().time()
|
||||
|
||||
response = web.StreamResponse(
|
||||
status=200,
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
await response.prepare(req)
|
||||
|
||||
try:
|
||||
while True:
|
||||
if (
|
||||
await self.oidc_provider.async_get_redirect_uri_for_state(state_id)
|
||||
is None
|
||||
):
|
||||
await response.write(b"event: expired\ndata: false\n\n")
|
||||
break
|
||||
|
||||
ready = await self.oidc_provider.async_is_state_ready(state_id)
|
||||
if ready:
|
||||
await response.write(b"event: ready\ndata: true\n\n")
|
||||
break
|
||||
|
||||
if asyncio.get_running_loop().time() - started_at >= timeout_seconds:
|
||||
await response.write(b"event: timeout\ndata: false\n\n")
|
||||
break
|
||||
|
||||
await response.write(b"event: waiting\ndata: false\n\n")
|
||||
await asyncio.sleep(0.5)
|
||||
except (ConnectionResetError, RuntimeError):
|
||||
# Client disconnected while listening for state changes.
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
await response.write_eof()
|
||||
except ConnectionResetError:
|
||||
pass
|
||||
|
||||
return response
|
||||
@@ -2,7 +2,12 @@
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from aiohttp import web
|
||||
from ..tools.helpers import get_view
|
||||
from ..provider import OpenIDAuthProvider
|
||||
from ..tools.helpers import (
|
||||
error_response,
|
||||
get_valid_state_id,
|
||||
template_response,
|
||||
)
|
||||
|
||||
PATH = "/auth/oidc/finish"
|
||||
|
||||
@@ -14,41 +19,62 @@ class OIDCFinishView(HomeAssistantView):
|
||||
url = PATH
|
||||
name = "auth:oidc:finish"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
oidc_provider: OpenIDAuthProvider,
|
||||
) -> None:
|
||||
self.oidc_provider = oidc_provider
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Show the finish screen to allow the user to view their code."""
|
||||
"""Show the finish screen to pick between login & device code."""
|
||||
# Get cookie to get the state_id
|
||||
state_id = await get_valid_state_id(request, self.oidc_provider)
|
||||
if not state_id:
|
||||
return await error_response("Missing state cookie, please restart login.")
|
||||
|
||||
code = request.query.get("code")
|
||||
|
||||
if not code:
|
||||
view_html = await get_view(
|
||||
"error",
|
||||
{"error": "Missing code to show the finish screen."},
|
||||
)
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
view_html = await get_view("finish", {"code": code})
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
return await template_response("finish", {})
|
||||
|
||||
async def post(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
# Get code from the message body
|
||||
data = await request.post()
|
||||
code = data.get("code")
|
||||
# Get cookie to get the state_id
|
||||
state_id = await get_valid_state_id(request, self.oidc_provider)
|
||||
if not state_id:
|
||||
return await error_response("Missing state cookie, please restart login.")
|
||||
|
||||
if not code:
|
||||
return web.Response(text="No code received", status=500)
|
||||
|
||||
# Return redirect to the main page for sign in with a cookie
|
||||
raise web.HTTPFound(
|
||||
location="/?storeToken=true",
|
||||
headers={
|
||||
# Set a cookie to enable autologin on only the specific path used
|
||||
# for the POST request, with all strict parameters set
|
||||
# This cookie should not be read by any Javascript or any other paths.
|
||||
# It can be really short lifetime as we redirect immediately (5 seconds)
|
||||
"set-cookie": "auth_oidc_code="
|
||||
+ code
|
||||
+ "; Path=/auth/login_flow; SameSite=Strict; HttpOnly; Max-Age=5",
|
||||
},
|
||||
# Get redirect_uri from the state
|
||||
redirect_uri = await self.oidc_provider.async_get_redirect_uri_for_state(
|
||||
state_id
|
||||
)
|
||||
|
||||
if not redirect_uri:
|
||||
return await error_response("Invalid state, please restart login.")
|
||||
|
||||
# Get the message body
|
||||
data = await request.post()
|
||||
device_code = data.get("device_code")
|
||||
|
||||
# We are trying sign-in on this browser
|
||||
if not device_code:
|
||||
# Add to the URL correctly (also handle case where it's just the root)
|
||||
separator = "?"
|
||||
if "?" in redirect_uri:
|
||||
separator = "&"
|
||||
|
||||
# Redirect to this new URL for login
|
||||
new_url = (
|
||||
redirect_uri + separator + "storeToken=true&skip_oidc_redirect=true"
|
||||
)
|
||||
raise web.HTTPFound(location=new_url)
|
||||
|
||||
# Check if we can link this device
|
||||
linked = await self.oidc_provider.async_link_state_to_code(
|
||||
state_id, device_code
|
||||
)
|
||||
|
||||
if not linked:
|
||||
return await error_response(
|
||||
"Failed to link state to device code, please restart login."
|
||||
)
|
||||
|
||||
return await template_response("device_success", {})
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Injected authorization page, replacing the original"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from functools import partial
|
||||
from homeassistant.components.http import HomeAssistantView, StaticPathConfig
|
||||
@@ -19,7 +18,7 @@ async def read_file(path: str) -> str:
|
||||
return await f.read()
|
||||
|
||||
|
||||
async def frontend_injection(hass: HomeAssistant, sso_name: str) -> None:
|
||||
async def frontend_injection(hass: HomeAssistant) -> None:
|
||||
"""Inject new frontend code into /auth/authorize."""
|
||||
router = hass.http.app.router
|
||||
frontend_path = None
|
||||
@@ -62,11 +61,8 @@ async def frontend_injection(hass: HomeAssistant, sso_name: str) -> None:
|
||||
frontend_code = await read_file(frontend_path)
|
||||
|
||||
# Inject JS and register that route
|
||||
injection_js = "<script src='/auth/oidc/static/injection.js?v=3'></script>"
|
||||
sso_name_js = f"<script>window.sso_name = {json.dumps(sso_name)};</script>"
|
||||
frontend_code = frontend_code.replace(
|
||||
"</body>", f"{injection_js}{sso_name_js}</body>"
|
||||
)
|
||||
injection_js = "<script src='/auth/oidc/static/injection.js?v=4'></script>"
|
||||
frontend_code = frontend_code.replace("</body>", f"{injection_js}</body>")
|
||||
|
||||
await hass.http.async_register_static_paths(
|
||||
[
|
||||
@@ -100,10 +96,10 @@ class OIDCInjectedAuthPage(HomeAssistantView):
|
||||
self.html = html
|
||||
|
||||
@staticmethod
|
||||
async def inject(hass: HomeAssistant, sso_name: str) -> None:
|
||||
async def inject(hass: HomeAssistant) -> None:
|
||||
"""Inject the OIDC auth page into the frontend."""
|
||||
try:
|
||||
await frontend_injection(hass, sso_name)
|
||||
await frontend_injection(hass)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
_LOGGER.error("Failed to inject OIDC auth page: %s", e)
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
"""Redirect route to redirect the user to the external OIDC server,
|
||||
can either be linked to directly or accessed through the welcome page."""
|
||||
|
||||
from urllib.parse import quote
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
|
||||
from ..provider import OpenIDAuthProvider
|
||||
from ..tools.oidc_client import OIDCClient
|
||||
from ..tools.helpers import get_url, get_view
|
||||
from ..tools.helpers import error_response, get_url, get_valid_state_id, get_view
|
||||
|
||||
PATH = "/auth/oidc/redirect"
|
||||
|
||||
@@ -17,28 +19,44 @@ class OIDCRedirectView(HomeAssistantView):
|
||||
url = PATH
|
||||
name = "auth:oidc:redirect"
|
||||
|
||||
def __init__(self, oidc_client: OIDCClient, force_https: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
oidc_client: OIDCClient,
|
||||
oidc_provider: OpenIDAuthProvider,
|
||||
force_https: bool,
|
||||
) -> None:
|
||||
self.oidc_client = oidc_client
|
||||
self.oidc_provider = oidc_provider
|
||||
self.force_https = force_https
|
||||
|
||||
async def get(self, _: web.Request) -> web.Response:
|
||||
async def get(self, req: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
# Get cookie to get the state_id
|
||||
state_id = await get_valid_state_id(req, self.oidc_provider)
|
||||
|
||||
if not state_id:
|
||||
# Direct access to the redirect endpoint, go to welcome page instead
|
||||
welcome_url = get_url("/auth/oidc/welcome", self.force_https)
|
||||
raise web.HTTPFound(welcome_url)
|
||||
|
||||
try:
|
||||
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
|
||||
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
|
||||
auth_url = await self.oidc_client.async_get_authorization_url(
|
||||
redirect_uri, state_id
|
||||
)
|
||||
|
||||
if auth_url:
|
||||
raise web.HTTPFound(auth_url)
|
||||
view_html = await get_view("redirect", {"url": quote(auth_url)})
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
view_html = await get_view(
|
||||
"error",
|
||||
{"error": "Integration is misconfigured, discovery could not be obtained."},
|
||||
return await error_response(
|
||||
"Integration is misconfigured, discovery could not be obtained.",
|
||||
status=500,
|
||||
)
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
async def post(self, request: web.Request) -> web.Response:
|
||||
async def post(self, req: web.Request) -> web.Response:
|
||||
"""POST"""
|
||||
return await self.get(request)
|
||||
return await self.get(req)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
"""Welcome route to show the user the OIDC login button and give instructions."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from urllib.parse import urlparse, parse_qs, unquote
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from ..tools.helpers import get_url, get_view
|
||||
from ..tools.helpers import error_response, get_url, template_response
|
||||
from ..provider import OpenIDAuthProvider
|
||||
|
||||
PATH = "/auth/oidc/welcome"
|
||||
|
||||
@@ -14,16 +18,90 @@ class OIDCWelcomeView(HomeAssistantView):
|
||||
url = PATH
|
||||
name = "auth:oidc:welcome"
|
||||
|
||||
def __init__(self, name: str, is_enabled: bool, force_https: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
oidc_provider: OpenIDAuthProvider,
|
||||
name: str,
|
||||
force_https: bool,
|
||||
has_other_auth_providers: bool,
|
||||
) -> None:
|
||||
self.oidc_provider = oidc_provider
|
||||
self.name = name
|
||||
self.is_enabled = is_enabled
|
||||
self.force_https = force_https
|
||||
self.has_other_auth_providers = has_other_auth_providers
|
||||
|
||||
async def get(self, _: web.Request) -> web.Response:
|
||||
def determine_if_mobile(self, redirect_uri: str) -> bool:
|
||||
"""Determine if the client is a mobile client based on the redirect_uri."""
|
||||
oauth2_url = urlparse(redirect_uri)
|
||||
client_id = parse_qs(oauth2_url.query).get("client_id")
|
||||
|
||||
# If the client_id starts with https://home-assistant.io/ we assume it's a mobile client
|
||||
return bool(client_id and client_id[0].startswith("https://home-assistant.io/"))
|
||||
|
||||
async def get(self, req: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
if not self.is_enabled:
|
||||
raise web.HTTPTemporaryRedirect(get_url("/", self.force_https))
|
||||
# Get the query parameter with the redirect_uri
|
||||
redirect_uri = req.query.get("redirect_uri")
|
||||
|
||||
view_html = await get_view("welcome", {"name": self.name})
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
# If set, determine if this is a mobile client based on the redirect_uri,
|
||||
# otherwise assume it's not mobile
|
||||
if redirect_uri:
|
||||
try:
|
||||
# decodeURIComponent(btoa(...)) -> unquote first, then base64 decode
|
||||
redirect_uri = base64.b64decode(
|
||||
unquote(redirect_uri), validate=True
|
||||
).decode("utf-8")
|
||||
is_mobile = self.determine_if_mobile(redirect_uri)
|
||||
except (binascii.Error, UnicodeDecodeError, ValueError):
|
||||
return await error_response(
|
||||
"Invalid redirect_uri, please restart login."
|
||||
)
|
||||
else:
|
||||
# Backwards compatibility with older versions that directly go to /auth/oidc/welcome
|
||||
# If not set, redirect back to the main page and assume that this is a web client
|
||||
redirect_uri = get_url("/", self.force_https)
|
||||
is_mobile = False
|
||||
|
||||
# Create OIDC state with the redirect_uri so we can use it later in the flow
|
||||
state_id = await self.oidc_provider.async_create_state(redirect_uri)
|
||||
cookie_header = self.oidc_provider.get_cookie_header(
|
||||
state_id, secure=self.force_https or req.url.scheme == "https"
|
||||
)
|
||||
|
||||
# If this is the only provider and we are on desktop,
|
||||
# automatically go through the OIDC login
|
||||
if not is_mobile and not self.has_other_auth_providers:
|
||||
raise web.HTTPFound(
|
||||
location=get_url("/auth/oidc/redirect", self.force_https),
|
||||
headers=cookie_header,
|
||||
)
|
||||
|
||||
# Otherwise display the screen with either mobile sign in or the buttons
|
||||
# First generate code if mobile
|
||||
code = None
|
||||
if is_mobile:
|
||||
# Create a code to login
|
||||
code = await self.oidc_provider.async_generate_device_code(state_id)
|
||||
if not code:
|
||||
return await error_response(
|
||||
"Failed to generate device code, please restart login.",
|
||||
status=500,
|
||||
)
|
||||
|
||||
# And add the other link if we have other auth providers
|
||||
other_link = None
|
||||
if self.has_other_auth_providers:
|
||||
other_link = get_url("/?skip_oidc_redirect=true", self.force_https)
|
||||
|
||||
# And display
|
||||
response = await template_response(
|
||||
"welcome",
|
||||
{
|
||||
"name": self.name,
|
||||
"other_link": other_link,
|
||||
"code": code,
|
||||
},
|
||||
)
|
||||
response.headers.update(cookie_header)
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user