Allow forcing HTTPS in URL generation (#92)
* Force HTTPS feature * Add docs
This commit is contained in:
committed by
GitHub
parent
054f0e4bca
commit
e22f960d69
@@ -24,6 +24,8 @@ from .config import (
|
||||
ROLES,
|
||||
NETWORK,
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE,
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION,
|
||||
FEATURES_FORCE_HTTPS,
|
||||
)
|
||||
|
||||
# pylint: enable=useless-import-alias
|
||||
@@ -93,14 +95,23 @@ async def async_setup(hass: HomeAssistant, config):
|
||||
|
||||
# Register the views
|
||||
is_frontend_injection_enabled = (
|
||||
features_config.get("disable_frontend_changes", False) is False
|
||||
features_config.get(FEATURES_DISABLE_FRONTEND_INJECTION, False) is False
|
||||
)
|
||||
name = config[DOMAIN].get(DISPLAY_NAME, DEFAULT_TITLE)
|
||||
name = re.sub(r"[^A-Za-z0-9 _\-\(\)]", "", name)
|
||||
|
||||
hass.http.register_view(OIDCWelcomeView(name, is_frontend_injection_enabled))
|
||||
hass.http.register_view(OIDCRedirectView(oidc_client))
|
||||
hass.http.register_view(OIDCCallbackView(oidc_client, provider))
|
||||
force_https = features_config.get(FEATURES_FORCE_HTTPS, False)
|
||||
|
||||
hass.http.register_view(
|
||||
OIDCWelcomeView(
|
||||
name,
|
||||
# Welcome view is not enabled if frontend injection is enabled
|
||||
not is_frontend_injection_enabled,
|
||||
force_https,
|
||||
)
|
||||
)
|
||||
hass.http.register_view(OIDCRedirectView(oidc_client, force_https))
|
||||
hass.http.register_view(OIDCCallbackView(oidc_client, provider, force_https))
|
||||
hass.http.register_view(OIDCFinishView())
|
||||
|
||||
_LOGGER.info("Registered OIDC views")
|
||||
|
||||
@@ -14,7 +14,8 @@ FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
|
||||
FEATURES_DISABLE_PKCE = "disable_rfc7636"
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope"
|
||||
FEATURE_DISABLE_FRONTEND_INJECTION = "disable_frontend_changes"
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION = "disable_frontend_changes"
|
||||
FEATURES_FORCE_HTTPS = "force_https"
|
||||
CLAIMS = "claims"
|
||||
CLAIMS_DISPLAY_NAME = "display_name"
|
||||
CLAIMS_USERNAME = "username"
|
||||
@@ -72,8 +73,12 @@ CONFIG_SCHEMA = vol.Schema(
|
||||
): vol.Coerce(bool),
|
||||
# Disable frontend injection of OIDC login button
|
||||
vol.Optional(
|
||||
FEATURE_DISABLE_FRONTEND_INJECTION, default=False
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION, default=False
|
||||
): vol.Coerce(bool),
|
||||
# Force HTTPS on all generated URLs (like redirect_uri)
|
||||
vol.Optional(FEATURES_FORCE_HTTPS, default=False): vol.Coerce(
|
||||
bool
|
||||
),
|
||||
}
|
||||
),
|
||||
# Determine which specific claims will be used from the id_token
|
||||
|
||||
@@ -17,10 +17,14 @@ class OIDCCallbackView(HomeAssistantView):
|
||||
name = "auth:oidc:callback"
|
||||
|
||||
def __init__(
|
||||
self, oidc_client: OIDCClient, oidc_provider: OpenIDAuthProvider
|
||||
self,
|
||||
oidc_client: OIDCClient,
|
||||
oidc_provider: OpenIDAuthProvider,
|
||||
force_https: bool,
|
||||
) -> None:
|
||||
self.oidc_client = oidc_client
|
||||
self.oidc_provider = oidc_provider
|
||||
self.force_https = force_https
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
@@ -38,7 +42,7 @@ class OIDCCallbackView(HomeAssistantView):
|
||||
)
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
redirect_uri = get_url("/auth/oidc/callback")
|
||||
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
|
||||
user_details = await self.oidc_client.async_complete_token_flow(
|
||||
redirect_uri, code, state
|
||||
)
|
||||
@@ -63,4 +67,6 @@ class OIDCCallbackView(HomeAssistantView):
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
code = await self.oidc_provider.async_save_user_info(user_details)
|
||||
return web.HTTPFound(get_url("/auth/oidc/finish?code=" + code))
|
||||
return web.HTTPFound(
|
||||
get_url("/auth/oidc/finish?code=" + code, self.force_https)
|
||||
)
|
||||
|
||||
@@ -17,13 +17,14 @@ class OIDCRedirectView(HomeAssistantView):
|
||||
url = PATH
|
||||
name = "auth:oidc:redirect"
|
||||
|
||||
def __init__(self, oidc_client: OIDCClient) -> None:
|
||||
def __init__(self, oidc_client: OIDCClient, force_https: bool) -> None:
|
||||
self.oidc_client = oidc_client
|
||||
self.force_https = force_https
|
||||
|
||||
async def get(self, _: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
redirect_uri = get_url("/auth/oidc/callback")
|
||||
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
|
||||
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
|
||||
|
||||
if auth_url:
|
||||
|
||||
@@ -14,15 +14,16 @@ class OIDCWelcomeView(HomeAssistantView):
|
||||
url = PATH
|
||||
name = "auth:oidc:welcome"
|
||||
|
||||
def __init__(self, name: str, is_frontend_injection_enabled: bool) -> None:
|
||||
def __init__(self, name: str, is_enabled: bool, force_https: bool) -> None:
|
||||
self.name = name
|
||||
self.is_enabled = not is_frontend_injection_enabled
|
||||
self.is_enabled = is_enabled
|
||||
self.force_https = force_https
|
||||
|
||||
async def get(self, _: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
if not self.is_enabled:
|
||||
return web.HTTPTemporaryRedirect(get_url("/"))
|
||||
return web.HTTPTemporaryRedirect(get_url("/", self.force_https))
|
||||
|
||||
view_html = await get_view("welcome", {"name": self.name})
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
@@ -4,12 +4,14 @@ from homeassistant.components import http
|
||||
from .views.loader import AsyncTemplateRenderer
|
||||
|
||||
|
||||
def get_url(path: str) -> str:
|
||||
def get_url(path: str, force_https: bool) -> str:
|
||||
"""Returns the requested path appended to the current request base URL."""
|
||||
if (req := http.current_request.get()) is None:
|
||||
raise RuntimeError("No current request in context")
|
||||
|
||||
base_uri = str(req.url).split("/auth", 2)[0]
|
||||
if force_https:
|
||||
base_uri = base_uri.replace("http://", "https://")
|
||||
return f"{base_uri}{path}"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user