Respect force https on the initial redirect URI (#303)

* Also force HTTPS on the redirect URI

* Format & test
This commit is contained in:
Christiaan Goossens
2026-05-01 15:09:34 +02:00
committed by GitHub
parent 9d9025164a
commit 843c415f88
3 changed files with 54 additions and 4 deletions

View File

@@ -73,7 +73,7 @@ async def frontend_injection(
frontend_code = await read_file(frontend_path) frontend_code = await read_file(frontend_path)
# Inject JS and register that route # Inject JS and register that route
injection_js = "<script src='/auth/oidc/static/injection.js?v=6'></script>" injection_js = "<script src='/auth/oidc/static/injection.js?v=7'></script>"
frontend_code = frontend_code.replace("</body>", f"{injection_js}</body>") frontend_code = frontend_code.replace("</body>", f"{injection_js}</body>")
await hass.http.async_register_static_paths( await hass.http.async_register_static_paths(
@@ -156,8 +156,12 @@ class OIDCInjectedAuthPage(HomeAssistantView):
def _get_welcome_redirect_location(self, req: web.Request) -> str: def _get_welcome_redirect_location(self, req: web.Request) -> str:
"""Build the welcome URL for the injected auth page redirect.""" """Build the welcome URL for the injected auth page redirect."""
url = str(req.url)
if self.force_https:
url = url.replace("http://", "https://")
encoded_current_url = quote( encoded_current_url = quote(
base64.b64encode(str(req.url).encode("utf-8")).decode("ascii") base64.b64encode(url.encode("utf-8")).decode("ascii")
) )
return get_url( return get_url(
f"{WELCOME_PATH}?redirect_uri={encoded_current_url}", f"{WELCOME_PATH}?redirect_uri={encoded_current_url}",

View File

@@ -372,8 +372,7 @@ class OIDCClient:
tcp_connector_args["ssl"] = ssl_context tcp_connector_args["ssl"] = ssl_context
self.http_session = aiohttp.ClientSession( self.http_session = aiohttp.ClientSession(
trust_env=True, trust_env=True, connector=aiohttp.TCPConnector(**tcp_connector_args)
connector=aiohttp.TCPConnector(**tcp_connector_args)
) )
return self.http_session return self.http_session

View File

@@ -5,6 +5,8 @@ import os
from collections import OrderedDict from collections import OrderedDict
from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp import web
from auth_oidc.config.const import ( from auth_oidc.config.const import (
DISCOVERY_URL, DISCOVERY_URL,
CLIENT_ID, CLIENT_ID,
@@ -927,3 +929,48 @@ async def test_injected_auth_page_ignores_trusted_networks_when_not_first(
) )
assert_redirects_to_welcome(resp) assert_redirects_to_welcome(resp)
@pytest.mark.asyncio
async def test_injected_auth_page_converts_http_to_https_in_redirect(
hass: HomeAssistant,
):
"""_get_welcome_redirect_location should convert HTTP to HTTPS when force_https is True."""
await setup(hass)
provider = hass.auth.get_auth_providers(DOMAIN)[0]
injected_page = OIDCInjectedAuthPage(
html="<html></html>",
provider=provider,
force_https=True,
has_trusted_networks_provider_first=False,
)
# Create a mock request with HTTP URL
mock_req = MagicMock(spec=web.Request)
mock_req.url = "http://example.com/auth/authorize?redirect_uri=test"
with patch(
"custom_components.auth_oidc.endpoints.injected_auth_page.get_url"
) as mock_get_url:
mock_get_url.return_value = "https://example.com/auth/oidc/welcome?redirect_uri=..."
# pylint: disable=protected-access
injected_page._get_welcome_redirect_location(mock_req)
# pylint: enable=protected-access
# Verify that the URL was converted from HTTP to HTTPS before being passed to get_url
call_args = mock_get_url.call_args
assert call_args is not None
welcome_path_with_redirect = call_args[0][0] # First positional argument to get_url
# Extract the redirect_uri parameter and decode it
parsed = urlparse(welcome_path_with_redirect)
query_params = parse_qs(parsed.query)
encoded_redirect_uri = query_params.get("redirect_uri", [None])[0]
# Decode the base64-encoded redirect_uri
if encoded_redirect_uri:
decoded_redirect_uri = base64.b64decode(unquote(encoded_redirect_uri)).decode("utf-8")
# Verify it contains https:// instead of http://
assert decoded_redirect_uri.startswith("https://example.com")