From 862dfcb158b47314fda58acd2b09b5003da234cd Mon Sep 17 00:00:00 2001 From: Julian Pawlowski Date: Wed, 21 May 2025 15:14:55 +0000 Subject: [PATCH] fix --- .../tibber_prices/config_flow.py | 130 ++++++++++++------ .../tibber_prices/coordinator.py | 85 +++++++++++- requirements.txt | 7 +- tests/test_hello.py | 20 +++ 4 files changed, 188 insertions(+), 54 deletions(-) create mode 100644 tests/test_hello.py diff --git a/custom_components/tibber_prices/config_flow.py b/custom_components/tibber_prices/config_flow.py index f9590a0..a382cda 100644 --- a/custom_components/tibber_prices/config_flow.py +++ b/custom_components/tibber_prices/config_flow.py @@ -45,6 +45,11 @@ class TibberPricesFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Get the options flow for this handler.""" return TibberPricesOptionsFlowHandler(config_entry) + @staticmethod + def async_get_reauth_flow(entry: config_entries.ConfigEntry) -> config_entries.ConfigFlow: + """Return the reauth flow handler for this integration.""" + return TibberPricesReauthFlowHandler(entry) + def is_matching(self, other_flow: dict) -> bool: """Return True if match_dict matches this flow.""" return bool(other_flow.get("domain") == DOMAIN) @@ -153,6 +158,86 @@ class TibberPricesFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): return result["viewer"] +class TibberPricesReauthFlowHandler(config_entries.ConfigFlow): + """Handle a reauthentication flow for tibber_prices.""" + + def __init__(self, entry: config_entries.ConfigEntry) -> None: + """Initialize the reauth flow handler.""" + self._entry = entry + self._errors: dict[str, str] = {} + self._pending_user_input: dict | None = None + + async def async_step_user(self, user_input: dict | None = None) -> config_entries.ConfigFlowResult: + """Prompt for a new access token, then go to finish for home selection.""" + if user_input is not None: + try: + viewer = await TibberPricesApiClient( + access_token=user_input[CONF_ACCESS_TOKEN], + session=async_create_clientsession(self.hass), + ).async_get_viewer_details() + except TibberPricesApiClientAuthenticationError as exception: + LOGGER.warning(exception) + self._errors["base"] = "auth" + except TibberPricesApiClientCommunicationError as exception: + LOGGER.error(exception) + self._errors["base"] = "connection" + except TibberPricesApiClientError as exception: + LOGGER.exception(exception) + self._errors["base"] = "unknown" + else: + self._pending_user_input = { + "access_token": user_input[CONF_ACCESS_TOKEN], + "viewer": viewer.get("viewer", viewer), + } + return await self.async_step_finish() + + return self.async_show_form( + step_id="user", + data_schema=vol.Schema( + { + vol.Required(CONF_ACCESS_TOKEN): selector.TextSelector( + selector.TextSelectorConfig(type=selector.TextSelectorType.TEXT) + ), + } + ), + errors=self._errors, + ) + + async def async_step_finish(self, user_input: dict | None = None) -> config_entries.ConfigFlowResult: + """Show home selection, then update config entry.""" + if self._pending_user_input is not None and user_input is None: + viewer = self._pending_user_input["viewer"] + homes = viewer.get("homes", []) + home_choices = {} + for home in homes: + label = home.get("appNickname") or home.get("address", {}).get("address1") or home["id"] + if home.get("address", {}).get("city"): + label += f", {home['address']['city']}" + home_choices[home["id"]] = label + schema = vol.Schema({vol.Required("home_id"): vol.In(home_choices)}) + return self.async_show_form( + step_id="finish", + data_schema=schema, + description_placeholders={}, + errors={}, + last_step=True, + ) + if self._pending_user_input is not None and user_input is not None: + home_id = user_input["home_id"] + # Update the config entry with new token and home_id + self.hass.config_entries.async_update_entry( + self._entry, + data={ + **self._entry.data, + CONF_ACCESS_TOKEN: self._pending_user_input["access_token"], + "home_id": home_id, + }, + ) + self._pending_user_input = None + return self.async_abort(reason="reauth_successful") + return self.async_abort(reason="setup_complete") + + class TibberPricesOptionsFlowHandler(config_entries.OptionsFlow): """Tibber Prices config flow options handler.""" @@ -166,14 +251,6 @@ class TibberPricesOptionsFlowHandler(config_entries.OptionsFlow): # Build options schema options = { - vol.Required( - CONF_ACCESS_TOKEN, - default=self.config_entry.data.get(CONF_ACCESS_TOKEN, ""), - ): selector.TextSelector( - selector.TextSelectorConfig( - type=selector.TextSelectorType.TEXT, - ), - ), vol.Optional( CONF_EXTENDED_DESCRIPTIONS, default=self.config_entry.options.get( @@ -216,47 +293,10 @@ class TibberPricesOptionsFlowHandler(config_entries.OptionsFlow): } if user_input is not None: - # Validate new access token if changed - new_token = user_input.get(CONF_ACCESS_TOKEN, self.config_entry.data.get(CONF_ACCESS_TOKEN, "")) or "" - current_home_id = self.config_entry.data.get("home_id", "") - errors = {} - if new_token != self.config_entry.data.get(CONF_ACCESS_TOKEN, ""): - try: - client = TibberPricesApiClient( - access_token=new_token, - session=async_create_clientsession(self.hass), - ) - result = await client.async_get_viewer_details() - homes = result["viewer"].get("homes", []) - if not any(home["id"] == current_home_id for home in homes): - errors[CONF_ACCESS_TOKEN] = "different_home" - except TibberPricesApiClientAuthenticationError as exception: - LOGGER.warning(exception) - errors[CONF_ACCESS_TOKEN] = "auth" - except TibberPricesApiClientCommunicationError as exception: - LOGGER.error(exception) - errors[CONF_ACCESS_TOKEN] = "connection" - except TibberPricesApiClientError as exception: - LOGGER.exception(exception) - errors[CONF_ACCESS_TOKEN] = "unknown" - if errors: - # Show form again with errors - description_placeholders = { - "access_token": new_token, - "home_id": current_home_id, - } - return self.async_show_form( - step_id="init", - data_schema=vol.Schema(options), - errors=errors, - description_placeholders=description_placeholders, - ) - # Only update options and access token if valid return self.async_create_entry(title="", data=user_input) # Prepare read-only info for description placeholders description_placeholders = { - "access_token": self.config_entry.data.get(CONF_ACCESS_TOKEN, ""), "unique_id": self.config_entry.unique_id or "", } diff --git a/custom_components/tibber_prices/coordinator.py b/custom_components/tibber_prices/coordinator.py index a02e598..1f3f4d3 100644 --- a/custom_components/tibber_prices/coordinator.py +++ b/custom_components/tibber_prices/coordinator.py @@ -165,9 +165,34 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict]): await self.async_refresh() async def _async_update_data(self) -> dict: - """Fetch new state data for the coordinator.""" + """Fetch new state data for the coordinator. Handles expired credentials by raising ConfigEntryAuthFailed.""" if self._cached_price_data is None: - await self._async_initialize() + try: + await self._async_initialize() + except TimeoutError as exception: + msg = "Timeout during initialization" + LOGGER.error( + "%s: %s", + msg, + exception, + extra={"error_type": "timeout_init"}, + ) + raise UpdateFailed(msg) from exception + except TibberPricesApiClientAuthenticationError as exception: + msg = "Authentication failed: credentials expired or invalid" + LOGGER.error( + "Authentication failed (likely expired credentials) during initialization", + extra={"error": str(exception), "error_type": "auth_failed_init"}, + ) + raise ConfigEntryAuthFailed(msg) from exception + except Exception as exception: + msg = "Unexpected error during initialization" + LOGGER.exception( + "%s", + msg, + extra={"error": str(exception), "error_type": "unexpected_init"}, + ) + raise UpdateFailed(msg) from exception try: current_time = dt_util.now() result = None @@ -190,11 +215,24 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict]): else: result = await self._handle_conditional_update(current_time) except TibberPricesApiClientAuthenticationError as exception: + msg = "Authentication failed: credentials expired or invalid" LOGGER.error( - "Authentication failed", + "Authentication failed (likely expired credentials)", extra={"error": str(exception), "error_type": "auth_failed"}, ) - raise ConfigEntryAuthFailed(AUTH_FAILED_MSG) from exception + raise ConfigEntryAuthFailed(msg) from exception + except TimeoutError as exception: + msg = "Timeout during data update" + LOGGER.warning( + "%s: %s", + msg, + exception, + extra={"error_type": "timeout_runtime"}, + ) + if self._cached_price_data is not None: + LOGGER.info("Using cached data as fallback after timeout") + return self._merge_all_cached_data() + raise UpdateFailed(msg) from exception except ( TibberPricesApiClientCommunicationError, TibberPricesApiClientError, @@ -278,9 +316,44 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict]): return self._merge_all_cached_data() async def _fetch_price_data(self) -> dict: - """Fetch fresh price data from API.""" + """Fetch fresh price data from API and check for GraphQL errors.""" client = self.config_entry.runtime_data.client - return await client.async_get_price_info() + data = await client.async_get_price_info() + # Check for GraphQL errors at the top level + if isinstance(data, dict) and "errors" in data and data["errors"]: + errors = data["errors"] + # Look for authentication-related errors (extensions.code == 'UNAUTHENTICATED') + for err in errors: + code = err.get("extensions", {}).get("code") + msg = str(err.get("message", "")) + if code == "UNAUTHENTICATED": + LOGGER.error( + "GraphQL authentication error (UNAUTHENTICATED): %s", + msg, + extra={"error": msg, "error_type": "graphql_auth_failed", "code": code}, + ) + raise TibberPricesApiClientAuthenticationError(msg) + # Fallback: also check for other auth-related keywords in message/type + err_type = str(err.get("type", "")) + if any( + s in msg.lower() or s in err_type.lower() + for s in ("auth", "token", "credential", "unauth", "expired") + ): + LOGGER.error( + "GraphQL authentication error: %s", + msg, + extra={"error": msg, "error_type": "graphql_auth_failed_fallback"}, + ) + raise TibberPricesApiClientAuthenticationError(msg) + # If errors exist but not auth-related, log and raise generic error + msg = f"GraphQL error(s): {errors}" + LOGGER.error( + "GraphQL error(s) in response: %s", + errors, + extra={"error_type": "graphql_error"}, + ) + raise TibberPricesApiClientError(msg) + return data async def _get_rating_data_for_type(self, rating_type: str) -> dict: """Get fresh rating data for a specific type in flat format.""" diff --git a/requirements.txt b/requirements.txt index d097945..8b117f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -colorlog==6.9.0 -homeassistant==2025.4.2 +colorlog>=6.9.0,<7.0.0 +homeassistant>=2025.5.0,<2025.6.0 +pytest-homeassistant-custom-components>=0.13.0,<0.14.0 pip>=21.3.1 -ruff==0.11.6 \ No newline at end of file +ruff>=0.11.6,<0.12.0 diff --git a/tests/test_hello.py b/tests/test_hello.py new file mode 100644 index 0000000..b591bc7 --- /dev/null +++ b/tests/test_hello.py @@ -0,0 +1,20 @@ +import unittest +from unittest.mock import Mock, patch + + +class TestReauthentication(unittest.TestCase): + @patch("your_module.connection") # Replace 'your_module' with the actual module name + def test_reauthentication_flow(self, mock_connection): + mock_connection.reauthenticate = Mock(return_value=True) + result = mock_connection.reauthenticate() + self.assertTrue(result) + + @patch("your_module.connection") # Replace 'your_module' with the actual module name + def test_connection_timeout(self, mock_connection): + mock_connection.connect = Mock(side_effect=TimeoutError) + with self.assertRaises(TimeoutError): + mock_connection.connect() + + +if __name__ == "__main__": + unittest.main()