From f57fdfde6b5865db41704b34e92a9fdfaa046699 Mon Sep 17 00:00:00 2001 From: Julian Pawlowski Date: Sun, 25 May 2025 22:15:25 +0000 Subject: [PATCH] update --- custom_components/tibber_prices/__init__.py | 6 +- custom_components/tibber_prices/api.py | 451 +++++-- .../tibber_prices/config_flow.py | 19 +- .../tibber_prices/coordinator.py | 1142 +++++------------ custom_components/tibber_prices/entity.py | 45 +- custom_components/tibber_prices/manifest.json | 1 + custom_components/tibber_prices/services.py | 132 +- custom_components/tibber_prices/services.yaml | 13 + .../tibber_prices/translations/de.json | 22 + .../tibber_prices/translations/en.json | 28 +- hacs.json | 2 +- tests/__init__.py | 1 + tests/test_coordinator_basic.py | 84 ++ tests/test_coordinator_enhanced.py | 255 ++++ tests/test_hello.py | 20 - 15 files changed, 1230 insertions(+), 991 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_coordinator_basic.py create mode 100644 tests/test_coordinator_enhanced.py delete mode 100644 tests/test_hello.py diff --git a/custom_components/tibber_prices/__init__.py b/custom_components/tibber_prices/__init__.py index c4a7ba6..c1aa0b8 100644 --- a/custom_components/tibber_prices/__init__.py +++ b/custom_components/tibber_prices/__init__.py @@ -51,9 +51,7 @@ async def async_setup_entry( coordinator = TibberPricesDataUpdateCoordinator( hass=hass, - entry=entry, - logger=LOGGER, - name=DOMAIN, + config_entry=entry, ) entry.runtime_data = TibberPricesData( client=TibberPricesApiClient( @@ -88,7 +86,7 @@ async def async_unload_entry( # Unregister services if this was the last config entry if not hass.config_entries.async_entries(DOMAIN): - for service in ["get_price", "get_apexcharts_data", "get_apexcharts_yaml"]: + for service in ["get_price", "get_apexcharts_data", "get_apexcharts_yaml", "refresh_user_data"]: if hass.services.has_service(DOMAIN, service): hass.services.async_remove(DOMAIN, service) diff --git a/custom_components/tibber_prices/api.py b/custom_components/tibber_prices/api.py index fd06c86..0c3b3db 100644 --- a/custom_components/tibber_prices/api.py +++ b/custom_components/tibber_prices/api.py @@ -10,7 +10,6 @@ from enum import Enum from typing import Any import aiohttp -import async_timeout from homeassistant.const import __version__ as ha_version from homeassistant.util import dt as dt_util @@ -19,9 +18,10 @@ from .const import VERSION _LOGGER = logging.getLogger(__name__) -HTTP_TOO_MANY_REQUESTS = 429 +HTTP_BAD_REQUEST = 400 HTTP_UNAUTHORIZED = 401 HTTP_FORBIDDEN = 403 +HTTP_TOO_MANY_REQUESTS = 429 class QueryType(Enum): @@ -31,7 +31,7 @@ class QueryType(Enum): DAILY_RATING = "daily" HOURLY_RATING = "hourly" MONTHLY_RATING = "monthly" - VIEWER = "viewer" + USER = "user" class TibberPricesApiClientError(Exception): @@ -42,7 +42,8 @@ class TibberPricesApiClientError(Exception): GRAPHQL_ERROR = "GraphQL error: {message}" EMPTY_DATA_ERROR = "Empty data received for {query_type}" GENERIC_ERROR = "Something went wrong! {exception}" - RATE_LIMIT_ERROR = "Rate limit exceeded" + RATE_LIMIT_ERROR = "Rate limit exceeded. Please wait {retry_after} seconds before retrying" + INVALID_QUERY_ERROR = "Invalid GraphQL query: {message}" class TibberPricesApiClientCommunicationError(TibberPricesApiClientError): @@ -55,15 +56,33 @@ class TibberPricesApiClientCommunicationError(TibberPricesApiClientError): class TibberPricesApiClientAuthenticationError(TibberPricesApiClientError): """Exception to indicate an authentication error.""" - INVALID_CREDENTIALS = "Invalid credentials" + INVALID_CREDENTIALS = "Invalid access token or expired credentials" + + +class TibberPricesApiClientPermissionError(TibberPricesApiClientError): + """Exception to indicate insufficient permissions.""" + + INSUFFICIENT_PERMISSIONS = "Access forbidden - insufficient permissions for this operation" def _verify_response_or_raise(response: aiohttp.ClientResponse) -> None: """Verify that the response is valid.""" - if response.status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): + if response.status == HTTP_UNAUTHORIZED: + _LOGGER.error("Tibber API authentication failed - check access token") raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS) + if response.status == HTTP_FORBIDDEN: + _LOGGER.error("Tibber API access forbidden - insufficient permissions") + raise TibberPricesApiClientPermissionError(TibberPricesApiClientPermissionError.INSUFFICIENT_PERMISSIONS) if response.status == HTTP_TOO_MANY_REQUESTS: - raise TibberPricesApiClientError(TibberPricesApiClientError.RATE_LIMIT_ERROR) + # Check for Retry-After header that Tibber might send + retry_after = response.headers.get("Retry-After", "unknown") + _LOGGER.warning("Tibber API rate limit exceeded - retry after %s seconds", retry_after) + raise TibberPricesApiClientError(TibberPricesApiClientError.RATE_LIMIT_ERROR.format(retry_after=retry_after)) + if response.status == HTTP_BAD_REQUEST: + _LOGGER.error("Tibber API rejected request - likely invalid GraphQL query") + raise TibberPricesApiClientError( + TibberPricesApiClientError.INVALID_QUERY_ERROR.format(message="Bad request - likely invalid GraphQL query") + ) response.raise_for_status() @@ -72,21 +91,41 @@ async def _verify_graphql_response(response_json: dict, query_type: QueryType) - if "errors" in response_json: errors = response_json["errors"] if not errors: + _LOGGER.error("Tibber API returned empty errors array") raise TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR) error = errors[0] # Take first error if not isinstance(error, dict): + _LOGGER.error("Tibber API returned malformed error: %s", error) raise TibberPricesApiClientError(TibberPricesApiClientError.MALFORMED_ERROR.format(error=error)) message = error.get("message", "Unknown error") extensions = error.get("extensions", {}) + error_code = extensions.get("code") - if extensions.get("code") == "UNAUTHENTICATED": + # Handle specific Tibber API error codes + if error_code == "UNAUTHENTICATED": + _LOGGER.error("Tibber API authentication error: %s", message) raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS) + if error_code == "FORBIDDEN": + _LOGGER.error("Tibber API permission error: %s", message) + raise TibberPricesApiClientPermissionError(TibberPricesApiClientPermissionError.INSUFFICIENT_PERMISSIONS) + if error_code in ["RATE_LIMITED", "TOO_MANY_REQUESTS"]: + # Some GraphQL APIs return rate limit info in extensions + retry_after = extensions.get("retryAfter", "unknown") + _LOGGER.warning("Tibber API rate limited via GraphQL: %s (retry after %s)", message, retry_after) + raise TibberPricesApiClientError( + TibberPricesApiClientError.RATE_LIMIT_ERROR.format(retry_after=retry_after) + ) + if error_code in ["VALIDATION_ERROR", "GRAPHQL_VALIDATION_FAILED"]: + _LOGGER.error("Tibber API validation error: %s", message) + raise TibberPricesApiClientError(TibberPricesApiClientError.INVALID_QUERY_ERROR.format(message=message)) + _LOGGER.error("Tibber API GraphQL error (code: %s): %s", error_code or "unknown", message) raise TibberPricesApiClientError(TibberPricesApiClientError.GRAPHQL_ERROR.format(message=message)) if "data" not in response_json or response_json["data"] is None: + _LOGGER.error("Tibber API response missing data object") raise TibberPricesApiClientError( TibberPricesApiClientError.GRAPHQL_ERROR.format(message="Response missing data object") ) @@ -122,7 +161,7 @@ def _is_data_empty(data: dict, query_type: str) -> bool: is_empty = False try: - if query_type == "viewer": + if query_type == "user": has_user_id = ( "viewer" in data and isinstance(data["viewer"], dict) @@ -321,11 +360,16 @@ class TibberPricesApiClient: """Tibber API Client.""" self._access_token = access_token self._session = session - self._request_semaphore = asyncio.Semaphore(2) + self._request_semaphore = asyncio.Semaphore(2) # Max 2 concurrent requests self._last_request_time = dt_util.now() - self._min_request_interval = timedelta(seconds=1) + self._min_request_interval = timedelta(seconds=1) # Min 1 second between requests self._max_retries = 5 - self._retry_delay = 2 + self._retry_delay = 2 # Base retry delay in seconds + + # Timeout configuration - more granular control + self._connect_timeout = 10 # Connection timeout in seconds + self._request_timeout = 25 # Total request timeout in seconds + self._socket_connect_timeout = 5 # Socket connection timeout async def async_get_viewer_details(self) -> Any: """Test connection to the API.""" @@ -352,11 +396,11 @@ class TibberPricesApiClient: } """ }, - query_type=QueryType.VIEWER, + query_type=QueryType.USER, ) - async def async_get_price_info(self, home_id: str) -> dict: - """Get price info data in flat format for the specified home_id.""" + async def async_get_price_info(self) -> dict: + """Get price info data in flat format for all homes.""" data = await self._api_wrapper( data={ "query": """ @@ -371,15 +415,21 @@ class TibberPricesApiClient: query_type=QueryType.PRICE_INFO, ) homes = data.get("viewer", {}).get("homes", []) - home = next((h for h in homes if h.get("id") == home_id), None) - if home and "currentSubscription" in home: - data["priceInfo"] = _flatten_price_info(home["currentSubscription"]) - else: - data["priceInfo"] = {} + + homes_data = {} + for home in homes: + home_id = home.get("id") + if home_id: + if "currentSubscription" in home: + homes_data[home_id] = _flatten_price_info(home["currentSubscription"]) + else: + homes_data[home_id] = {} + + data["homes"] = homes_data return data - async def async_get_daily_price_rating(self, home_id: str) -> dict: - """Get daily price rating data in flat format for the specified home_id.""" + async def async_get_daily_price_rating(self) -> dict: + """Get daily price rating data in flat format for all homes.""" data = await self._api_wrapper( data={ "query": """ @@ -394,15 +444,21 @@ class TibberPricesApiClient: query_type=QueryType.DAILY_RATING, ) homes = data.get("viewer", {}).get("homes", []) - home = next((h for h in homes if h.get("id") == home_id), None) - if home and "currentSubscription" in home: - data["priceRating"] = _flatten_price_rating(home["currentSubscription"]) - else: - data["priceRating"] = {} + + homes_data = {} + for home in homes: + home_id = home.get("id") + if home_id: + if "currentSubscription" in home: + homes_data[home_id] = _flatten_price_rating(home["currentSubscription"]) + else: + homes_data[home_id] = {} + + data["homes"] = homes_data return data - async def async_get_hourly_price_rating(self, home_id: str) -> dict: - """Get hourly price rating data in flat format for the specified home_id.""" + async def async_get_hourly_price_rating(self) -> dict: + """Get hourly price rating data in flat format for all homes.""" data = await self._api_wrapper( data={ "query": """ @@ -417,15 +473,21 @@ class TibberPricesApiClient: query_type=QueryType.HOURLY_RATING, ) homes = data.get("viewer", {}).get("homes", []) - home = next((h for h in homes if h.get("id") == home_id), None) - if home and "currentSubscription" in home: - data["priceRating"] = _flatten_price_rating(home["currentSubscription"]) - else: - data["priceRating"] = {} + + homes_data = {} + for home in homes: + home_id = home.get("id") + if home_id: + if "currentSubscription" in home: + homes_data[home_id] = _flatten_price_rating(home["currentSubscription"]) + else: + homes_data[home_id] = {} + + data["homes"] = homes_data return data - async def async_get_monthly_price_rating(self, home_id: str) -> dict: - """Get monthly price rating data in flat format for the specified home_id.""" + async def async_get_monthly_price_rating(self) -> dict: + """Get monthly price rating data in flat format for all homes.""" data = await self._api_wrapper( data={ "query": """ @@ -440,40 +502,52 @@ class TibberPricesApiClient: query_type=QueryType.MONTHLY_RATING, ) homes = data.get("viewer", {}).get("homes", []) - home = next((h for h in homes if h.get("id") == home_id), None) - if home and "currentSubscription" in home: - data["priceRating"] = _flatten_price_rating(home["currentSubscription"]) - else: - data["priceRating"] = {} + + homes_data = {} + for home in homes: + home_id = home.get("id") + if home_id: + if "currentSubscription" in home: + homes_data[home_id] = _flatten_price_rating(home["currentSubscription"]) + else: + homes_data[home_id] = {} + + data["homes"] = homes_data return data - async def async_get_data(self, home_id: str) -> dict: - """Get all data from the API by combining multiple queries in flat format for the specified home_id.""" - price_info = await self.async_get_price_info(home_id) - daily_rating = await self.async_get_daily_price_rating(home_id) - hourly_rating = await self.async_get_hourly_price_rating(home_id) - monthly_rating = await self.async_get_monthly_price_rating(home_id) - price_rating = { - "thresholdPercentages": daily_rating["priceRating"].get("thresholdPercentages"), - "daily": daily_rating["priceRating"].get("daily", []), - "hourly": hourly_rating["priceRating"].get("hourly", []), - "monthly": monthly_rating["priceRating"].get("monthly", []), - "currency": ( - daily_rating["priceRating"].get("currency") - or hourly_rating["priceRating"].get("currency") - or monthly_rating["priceRating"].get("currency") - ), - } - return { - "priceInfo": price_info["priceInfo"], - "priceRating": price_rating, - } + async def async_get_data(self) -> dict: + """Get all data from the API by combining multiple queries in flat format for all homes.""" + price_info = await self.async_get_price_info() + daily_rating = await self.async_get_daily_price_rating() + hourly_rating = await self.async_get_hourly_price_rating() + monthly_rating = await self.async_get_monthly_price_rating() - async def async_set_title(self, value: str) -> Any: - """Get data from the API.""" - return await self._api_wrapper( - data={"title": value}, - ) + all_home_ids = set() + all_home_ids.update(price_info.get("homes", {}).keys()) + all_home_ids.update(daily_rating.get("homes", {}).keys()) + all_home_ids.update(hourly_rating.get("homes", {}).keys()) + all_home_ids.update(monthly_rating.get("homes", {}).keys()) + + homes_combined = {} + for home_id in all_home_ids: + daily_data = daily_rating.get("homes", {}).get(home_id, {}) + hourly_data = hourly_rating.get("homes", {}).get(home_id, {}) + monthly_data = monthly_rating.get("homes", {}).get(home_id, {}) + + price_rating = { + "thresholdPercentages": daily_data.get("thresholdPercentages"), + "daily": daily_data.get("daily", []), + "hourly": hourly_data.get("hourly", []), + "monthly": monthly_data.get("monthly", []), + "currency": (daily_data.get("currency") or hourly_data.get("currency") or monthly_data.get("currency")), + } + + homes_combined[home_id] = { + "priceInfo": price_info.get("homes", {}).get(home_id, {}), + "priceRating": price_rating, + } + + return {"homes": homes_combined} async def _make_request( self, @@ -481,23 +555,117 @@ class TibberPricesApiClient: data: dict, query_type: QueryType, ) -> dict: - """Make an API request with proper error handling.""" + """Make an API request with comprehensive error handling for network issues.""" _LOGGER.debug("Making API request with data: %s", data) - response = await self._session.request( - method="POST", - url="https://api.tibber.com/v1-beta/gql", - headers=headers, - json=data, - ) + try: + # More granular timeout configuration for better network failure handling + timeout = aiohttp.ClientTimeout( + total=self._request_timeout, # Total request timeout: 25s + connect=self._connect_timeout, # Connection timeout: 10s + sock_connect=self._socket_connect_timeout, # Socket connection: 5s + ) - _verify_response_or_raise(response) - response_json = await response.json() - _LOGGER.debug("Received API response: %s", response_json) + response = await self._session.request( + method="POST", + url="https://api.tibber.com/v1-beta/gql", + headers=headers, + json=data, + timeout=timeout, + ) - await _verify_graphql_response(response_json, query_type) + _verify_response_or_raise(response) + response_json = await response.json() + _LOGGER.debug("Received API response: %s", response_json) - return response_json["data"] + await _verify_graphql_response(response_json, query_type) + + return response_json["data"] + + except aiohttp.ClientResponseError as error: + _LOGGER.exception("HTTP error during API request") + raise TibberPricesApiClientCommunicationError( + TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) + ) from error + + except aiohttp.ClientConnectorError as error: + _LOGGER.exception("Connection error - server unreachable or network down") + raise TibberPricesApiClientCommunicationError( + TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) + ) from error + + except aiohttp.ServerDisconnectedError as error: + _LOGGER.exception("Server disconnected during request") + raise TibberPricesApiClientCommunicationError( + TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) + ) from error + + except TimeoutError as error: + _LOGGER.exception( + "Request timeout after %d seconds - slow network or server overload", self._request_timeout + ) + raise TibberPricesApiClientCommunicationError( + TibberPricesApiClientCommunicationError.TIMEOUT_ERROR.format(exception=str(error)) + ) from error + + except socket.gaierror as error: + self._handle_dns_error(error) + + except OSError as error: + self._handle_network_error(error) + + def _handle_dns_error(self, error: socket.gaierror) -> None: + """Handle DNS resolution errors with IPv4/IPv6 dual stack considerations.""" + error_msg = str(error) + + if "Name or service not known" in error_msg: + _LOGGER.exception("DNS resolution failed - domain name not found") + elif "Temporary failure in name resolution" in error_msg: + _LOGGER.exception("DNS resolution temporarily failed - network or DNS server issue") + elif "Address family for hostname not supported" in error_msg: + _LOGGER.exception("DNS resolution failed - IPv4/IPv6 address family not supported") + elif "No address associated with hostname" in error_msg: + _LOGGER.exception("DNS resolution failed - no IPv4/IPv6 addresses found") + else: + _LOGGER.exception("DNS resolution failed - check internet connection: %s", error_msg) + + raise TibberPricesApiClientCommunicationError( + TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) + ) from error + + def _handle_network_error(self, error: OSError) -> None: + """Handle network-level errors with IPv4/IPv6 dual stack considerations.""" + error_msg = str(error) + errno = getattr(error, "errno", None) + + # Common IPv4/IPv6 dual stack network error codes + errno_network_unreachable = 101 # ENETUNREACH + errno_host_unreachable = 113 # EHOSTUNREACH + errno_connection_refused = 111 # ECONNREFUSED + errno_connection_timeout = 110 # ETIMEDOUT + + if errno == errno_network_unreachable: + _LOGGER.exception("Network unreachable - check internet connection or IPv4/IPv6 routing") + elif errno == errno_host_unreachable: + _LOGGER.exception("Host unreachable - routing issue or IPv4/IPv6 connectivity problem") + elif errno == errno_connection_refused: + _LOGGER.exception("Connection refused - server not accepting connections") + elif errno == errno_connection_timeout: + _LOGGER.exception("Connection timed out - network latency or server overload") + elif "Address family not supported" in error_msg: + _LOGGER.exception("Address family not supported - IPv4/IPv6 configuration issue") + elif "Protocol not available" in error_msg: + _LOGGER.exception("Protocol not available - IPv4/IPv6 stack configuration issue") + elif "Network is down" in error_msg: + _LOGGER.exception("Network interface is down - check network adapter") + elif "Permission denied" in error_msg: + _LOGGER.exception("Network permission denied - firewall or security restriction") + else: + _LOGGER.exception("Network error - internet may be down: %s", error_msg) + + raise TibberPricesApiClientCommunicationError( + TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) + ) from error async def _handle_request( self, @@ -517,19 +685,90 @@ class TibberPricesApiClient: ) await asyncio.sleep(sleep_time) - async with async_timeout.timeout(10): - self._last_request_time = dt_util.now() - return await self._make_request( - headers, - data or {}, - query_type, - ) + self._last_request_time = dt_util.now() + return await self._make_request( + headers, + data or {}, + query_type, + ) + + def _should_retry_error(self, error: Exception, retry: int) -> tuple[bool, int]: + """Determine if an error should be retried and calculate delay.""" + # Check if we've exceeded max retries first + if retry >= self._max_retries: + return False, 0 + + # Non-retryable errors - authentication and permission issues + if isinstance(error, (TibberPricesApiClientAuthenticationError, TibberPricesApiClientPermissionError)): + return False, 0 + + # Handle API-specific errors + if isinstance(error, TibberPricesApiClientError): + return self._handle_api_error_retry(error, retry) + + # Network and timeout errors - retryable with exponential backoff + if isinstance(error, (aiohttp.ClientError, socket.gaierror, TimeoutError)): + delay = min(self._retry_delay * (2**retry), 30) # Cap at 30 seconds + return True, delay + + # Unknown errors - not retryable + return False, 0 + + def _handle_api_error_retry(self, error: TibberPricesApiClientError, retry: int) -> tuple[bool, int]: + """Handle retry logic for API-specific errors.""" + error_msg = str(error) + + # Non-retryable: Invalid queries + if "Invalid GraphQL query" in error_msg or "Bad request" in error_msg: + return False, 0 + + # Rate limits - special handling with extracted delay + if "Rate limit exceeded" in error_msg or "rate limited" in error_msg.lower(): + delay = self._extract_retry_delay(error, retry) + return True, delay + + # Empty data - retryable with capped exponential backoff + if "Empty data received" in error_msg: + delay = min(self._retry_delay * (2**retry), 60) # Cap at 60 seconds + return True, delay + + # Other API errors - retryable with capped exponential backoff + delay = min(self._retry_delay * (2**retry), 30) # Cap at 30 seconds + return True, delay + + def _extract_retry_delay(self, error: Exception, retry: int) -> int: + """Extract retry delay from rate limit error or use exponential backoff.""" + import re + + error_msg = str(error) + + # Try to extract Retry-After value from error message + retry_after_match = re.search(r"retry after (\d+) seconds", error_msg.lower()) + if retry_after_match: + try: + retry_after = int(retry_after_match.group(1)) + return min(retry_after + 1, 300) # Add buffer, max 5 minutes + except ValueError: + pass + + # Try to extract generic seconds value + seconds_match = re.search(r"(\d+) seconds", error_msg) + if seconds_match: + try: + seconds = int(seconds_match.group(1)) + return min(seconds + 1, 300) # Add buffer, max 5 minutes + except ValueError: + pass + + # Fall back to exponential backoff with cap + base_delay = self._retry_delay * (2**retry) + return min(base_delay, 120) # Cap at 2 minutes for rate limits async def _api_wrapper( self, data: dict | None = None, headers: dict | None = None, - query_type: QueryType = QueryType.VIEWER, + query_type: QueryType = QueryType.USER, ) -> Any: """Get information from the API with rate limiting and retry logic.""" headers = headers or _prepare_headers(self._access_token) @@ -537,32 +776,34 @@ class TibberPricesApiClient: for retry in range(self._max_retries + 1): try: - return await self._handle_request( - headers, - data or {}, - query_type, - ) + return await self._handle_request(headers, data or {}, query_type) - except TibberPricesApiClientAuthenticationError: + except ( + TibberPricesApiClientAuthenticationError, + TibberPricesApiClientPermissionError, + ): + _LOGGER.exception("Non-retryable error occurred") raise except ( + TibberPricesApiClientError, aiohttp.ClientError, socket.gaierror, TimeoutError, - TibberPricesApiClientError, ) as error: last_error = ( error if isinstance(error, TibberPricesApiClientError) - else TibberPricesApiClientError( - TibberPricesApiClientError.GENERIC_ERROR.format(exception=str(error)) + else TibberPricesApiClientCommunicationError( + TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) ) ) - if retry < self._max_retries: - delay = self._retry_delay * (2**retry) + should_retry, delay = self._should_retry_error(error, retry) + if should_retry: + error_type = self._get_error_type(error) _LOGGER.warning( - "Request failed, attempt %d/%d. Retrying in %d seconds: %s", + "Tibber %s error, attempt %d/%d. Retrying in %d seconds: %s", + error_type, retry + 1, self._max_retries, delay, @@ -571,6 +812,10 @@ class TibberPricesApiClient: await asyncio.sleep(delay) continue + if "Invalid GraphQL query" in str(error): + _LOGGER.exception("Invalid query - not retrying") + raise + # Handle final error state if isinstance(last_error, TimeoutError): raise TibberPricesApiClientCommunicationError( @@ -582,3 +827,11 @@ class TibberPricesApiClient: ) from last_error raise last_error or TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR) + + def _get_error_type(self, error: Exception) -> str: + """Get a descriptive error type for logging.""" + if "Rate limit" in str(error): + return "rate limit" + if isinstance(error, (aiohttp.ClientError, socket.gaierror, TimeoutError)): + return "network" + return "API" diff --git a/custom_components/tibber_prices/config_flow.py b/custom_components/tibber_prices/config_flow.py index 4f552c6..fabd89d 100644 --- a/custom_components/tibber_prices/config_flow.py +++ b/custom_components/tibber_prices/config_flow.py @@ -204,10 +204,15 @@ class TibberPricesSubentryFlowHandler(ConfigSubentryFlow): async def async_step_user(self, user_input: dict[str, Any] | None = None) -> SubentryFlowResult: """User flow to add a new home.""" parent_entry = self._get_entry() - if not parent_entry: + if not parent_entry or not hasattr(parent_entry, "runtime_data") or not parent_entry.runtime_data: return self.async_abort(reason="no_parent_entry") - homes = parent_entry.data.get("homes", []) + coordinator = parent_entry.runtime_data.coordinator + + # Force refresh user data to get latest homes from Tibber API + await coordinator.refresh_user_data() + + homes = coordinator.get_user_homes() if not homes: return self.async_abort(reason="no_available_homes") @@ -233,11 +238,11 @@ class TibberPricesSubentryFlowHandler(ConfigSubentryFlow): ) # Get existing home IDs by checking all subentries for this parent - existing_home_ids = set() - for entry in self.hass.config_entries.async_entries(DOMAIN): - # Check if this entry has home_id data (indicating it's a subentry) - if entry.data.get("home_id") and entry != parent_entry: - existing_home_ids.add(entry.data["home_id"]) + existing_home_ids = { + entry.data["home_id"] + for entry in self.hass.config_entries.async_entries(DOMAIN) + if entry.data.get("home_id") and entry != parent_entry + } available_homes = [home for home in homes if home["id"] not in existing_home_ids] diff --git a/custom_components/tibber_prices/coordinator.py b/custom_components/tibber_prices/coordinator.py index 5eee345..057177d 100644 --- a/custom_components/tibber_prices/coordinator.py +++ b/custom_components/tibber_prices/coordinator.py @@ -1,916 +1,406 @@ -"""Coordinator for fetching Tibber price data.""" +"""Enhanced coordinator for fetching Tibber price data with comprehensive caching.""" from __future__ import annotations -import asyncio import logging -from datetime import date, datetime, timedelta -from typing import TYPE_CHECKING, Any, Final, cast +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any -import homeassistant.util.dt as dt_util +from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import ConfigEntryAuthFailed -from homeassistant.helpers.event import async_track_time_change +from homeassistant.helpers import aiohttp_client from homeassistant.helpers.storage import Store from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed +from homeassistant.util import dt as dt_util if TYPE_CHECKING: from collections.abc import Callable - from .data import TibberPricesConfigEntry + from homeassistant.config_entries import ConfigEntry from .api import ( + TibberPricesApiClient, TibberPricesApiClientAuthenticationError, TibberPricesApiClientCommunicationError, TibberPricesApiClientError, ) -from .const import DOMAIN, LOGGER +from .const import DOMAIN _LOGGER = logging.getLogger(__name__) -PRICE_UPDATE_RANDOM_MIN_HOUR: Final = 13 # Don't check before 13:00 -PRICE_UPDATE_RANDOM_MAX_HOUR: Final = 15 # Don't check after 15:00 -RANDOM_DELAY_MAX_MINUTES: Final = 120 # Maximum random delay in minutes -NO_DATA_ERROR_MSG: Final = "No data available" -STORAGE_VERSION: Final = 1 -UPDATE_INTERVAL: Final = timedelta(days=1) # Both price and rating data update daily -UPDATE_FAILED_MSG: Final = "Update failed" -AUTH_FAILED_MSG: Final = "Authentication failed" -MIN_RETRY_INTERVAL: Final = timedelta(minutes=10) -END_OF_DAY_HOUR: Final = 24 # End of day hour for logic clarity +# Storage version for storing data +STORAGE_VERSION = 1 + +# Update interval - fetch data every 15 minutes +UPDATE_INTERVAL = timedelta(minutes=15) -@callback -def _raise_no_data() -> None: - """Raise error when no data is available.""" - raise TibberPricesApiClientError(NO_DATA_ERROR_MSG) - - -@callback -def _get_latest_timestamp_from_prices( - price_data: dict | None, -) -> datetime | None: - """Get the latest timestamp from price data.""" - if not price_data: - return None - - try: - latest_timestamp = None - - # Check today's prices - if today_prices := price_data.get("today"): - for price in today_prices: - if starts_at := price.get("startsAt"): - timestamp = dt_util.parse_datetime(starts_at) - if timestamp and (not latest_timestamp or timestamp > latest_timestamp): - latest_timestamp = timestamp - - # Check tomorrow's prices - if tomorrow_prices := price_data.get("tomorrow"): - for price in tomorrow_prices: - if starts_at := price.get("startsAt"): - timestamp = dt_util.parse_datetime(starts_at) - if timestamp and (not latest_timestamp or timestamp > latest_timestamp): - latest_timestamp = timestamp - - except (KeyError, IndexError, TypeError): - return None - else: - return latest_timestamp - - -@callback -def _get_latest_timestamp_from_rating( - rating_data: dict | None, -) -> datetime | None: - """Get the latest timestamp from rating data.""" - if not rating_data or "priceRating" not in rating_data: - return None - - try: - price_rating = rating_data["priceRating"] - latest_timestamp = None - - # Check all rating types (hourly, daily, monthly) - for rating_type in ["hourly", "daily", "monthly"]: - if rating_entries := price_rating.get(rating_type, []): - for entry in rating_entries: - if time := entry.get("time"): - timestamp = dt_util.parse_datetime(time) - if timestamp and (not latest_timestamp or timestamp > latest_timestamp): - latest_timestamp = timestamp - except (KeyError, IndexError, TypeError): - return None - else: - return latest_timestamp - - -class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict]): - """Coordinator for fetching Tibber price data.""" +class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Any]]): + """Enhanced coordinator with main/subentry pattern and comprehensive caching.""" def __init__( self, hass: HomeAssistant, - entry: TibberPricesConfigEntry, - *args: Any, - **kwargs: Any, + config_entry: ConfigEntry, ) -> None: - """Initialize coordinator with cache.""" - super().__init__(hass, *args, **kwargs) - self.config_entry = entry - storage_key = f"{DOMAIN}.{entry.entry_id}" - self._store = Store(hass, STORAGE_VERSION, storage_key) - self._cached_price_data: dict | None = None - self._cached_rating_data_hourly: dict | None = None - self._cached_rating_data_daily: dict | None = None - self._cached_rating_data_monthly: dict | None = None - self._last_price_update: datetime | None = None - self._last_rating_update_hourly: datetime | None = None - self._last_rating_update_daily: datetime | None = None - self._last_rating_update_monthly: datetime | None = None - self._remove_update_listeners: list[Any] = [] - self._force_update = False - self._rotation_lock = asyncio.Lock() - self._last_attempted_price_update: datetime | None = None - self._random_update_minute: int | None = None - self._random_update_date: date | None = None - self._remove_update_listeners.append( - async_track_time_change( - hass, - self._async_refresh_quarter_hour, - minute=[0, 15, 30, 45], - second=0, - ) + """Initialize the coordinator.""" + super().__init__( + hass, + _LOGGER, + name=DOMAIN, + update_interval=UPDATE_INTERVAL, ) - async def async_shutdown(self) -> None: - """Clean up coordinator on shutdown.""" - await super().async_shutdown() - for listener in self._remove_update_listeners: - listener() + self.config_entry = config_entry + self.api = TibberPricesApiClient( + access_token=config_entry.data[CONF_ACCESS_TOKEN], + session=aiohttp_client.async_get_clientsession(hass), + ) - async def async_request_refresh(self) -> None: - """Request an immediate refresh of the data.""" - self._force_update = True - await self.async_refresh() + # Storage for persistence + storage_key = f"{DOMAIN}.{config_entry.entry_id}" + self._store = Store(hass, STORAGE_VERSION, storage_key) - async def _async_refresh_quarter_hour(self, now: datetime | None = None) -> None: - """Refresh at every quarter hour, and rotate at midnight before update.""" - if now and now.hour == 0 and now.minute == 0: - if self._is_today_data_stale(): - LOGGER.warning("Detected stale 'today' data (not from today) at midnight. Forcing full refresh.") - await self._fetch_all_data() - else: - await self._perform_midnight_rotation() - await self.async_refresh() + # User data cache (updated daily) + self._cached_user_data: dict[str, Any] | None = None + self._last_user_update: datetime | None = None + self._user_update_interval = timedelta(days=1) + + # Price data cache + self._cached_price_data: dict[str, Any] | None = None + self._last_price_update: datetime | None = None + + # Track if this is the main entry (first one created) + self._is_main_entry = not self._has_existing_main_coordinator() + + def _has_existing_main_coordinator(self) -> bool: + """Check if there's already a main coordinator in hass.data.""" + domain_data = self.hass.data.get(DOMAIN, {}) + return any( + isinstance(coordinator, TibberPricesDataUpdateCoordinator) and coordinator.is_main_entry() + for coordinator in domain_data.values() + ) + + def is_main_entry(self) -> bool: + """Return True if this is the main entry that fetches data for all homes.""" + return self._is_main_entry + + async def _async_update_data(self) -> dict[str, Any]: + """Fetch data from Tibber API.""" + # Load cache if not already loaded + if self._cached_price_data is None and self._cached_user_data is None: + await self._load_cache() + + current_time = dt_util.utcnow() - async def _async_update_data(self) -> dict: - """Fetch new state data for the coordinator. Handles expired credentials by raising ConfigEntryAuthFailed.""" - if self._cached_price_data is None: - await self._handle_initialization() try: - current_time = dt_util.now() - if self._force_update: - LOGGER.debug( - "Force updating data", - extra={ - "reason": "force_update", - "last_success": self.last_update_success, - "last_price_update": self._last_price_update, - "last_rating_updates": { - "hourly": self._last_rating_update_hourly, - "daily": self._last_rating_update_daily, - "monthly": self._last_rating_update_monthly, - }, - }, - ) - self._force_update = False - return await self._fetch_all_data() - return await self._handle_conditional_update(current_time) + if self.is_main_entry(): + # Main entry fetches data for all homes + return await self._handle_main_entry_update(current_time) + # Subentries get data from main coordinator + return await self._handle_subentry_update() + + except TibberPricesApiClientAuthenticationError as err: + msg = "Invalid access token" + raise ConfigEntryAuthFailed(msg) from err except ( - TibberPricesApiClientAuthenticationError, - TimeoutError, TibberPricesApiClientCommunicationError, TibberPricesApiClientError, - ) as exception: - return await self._handle_update_exception(exception) - - async def _handle_initialization(self) -> None: - """Handle initialization and related errors for cached price data.""" - try: - await self._async_initialize() - except TimeoutError as exception: - msg = "Timeout during initialization" - LOGGER.error( - "%s: %s", - msg, - exception, - extra={"error_type": "timeout_init"}, - ) - raise UpdateFailed(msg) from exception - except TibberPricesApiClientAuthenticationError as exception: - msg = "Authentication failed: credentials expired or invalid" - LOGGER.error( - "Authentication failed (likely expired credentials) during initialization", - extra={"error": str(exception), "error_type": "auth_failed_init"}, - ) - raise ConfigEntryAuthFailed(msg) from exception - except Exception as exception: - msg = "Unexpected error during initialization" - LOGGER.exception( - "%s", - msg, - extra={"error": str(exception), "error_type": "unexpected_init"}, - ) - raise UpdateFailed(msg) from exception - - async def _handle_update_exception(self, exception: Exception) -> dict: - """Handle exceptions during update and return fallback or raise.""" - if isinstance(exception, TibberPricesApiClientAuthenticationError): - msg = "Authentication failed: credentials expired or invalid" - LOGGER.error( - "Authentication failed (likely expired credentials)", - extra={"error": str(exception), "error_type": "auth_failed"}, - ) - raise ConfigEntryAuthFailed(msg) from exception - if isinstance(exception, TimeoutError): - msg = "Timeout during data update" - LOGGER.warning( - "%s: %s", - msg, - exception, - extra={"error_type": "timeout_runtime"}, - ) + ) as err: + # Use cached data as fallback if available if self._cached_price_data is not None: - LOGGER.info("Using cached data as fallback after timeout") - return self._merge_all_cached_data() - raise UpdateFailed(msg) from exception - if isinstance(exception, TibberPricesApiClientCommunicationError): - LOGGER.error( - "API communication error", - extra={ - "error": str(exception), - "error_type": "communication_error", - }, - ) - elif isinstance(exception, TibberPricesApiClientError): - LOGGER.error( - "API client error", - extra={"error": str(exception), "error_type": "client_error"}, - ) - else: - LOGGER.exception( - "Unexpected error", - extra={"error": str(exception), "error_type": "unexpected"}, - ) + _LOGGER.warning("API error, using cached data: %s", err) + return self._merge_cached_data() + msg = f"Error communicating with API: {err}" + raise UpdateFailed(msg) from err + + async def _handle_main_entry_update(self, current_time: datetime) -> dict[str, Any]: + """Handle update for main entry - fetch data for all homes.""" + # Update user data if needed (daily check) + await self._update_user_data_if_needed(current_time) + + # Check if we need to update price data + if self._should_update_price_data(current_time): + raw_data = await self._fetch_all_homes_data() + # Cache the data + self._cached_price_data = raw_data + self._last_price_update = current_time + await self._store_cache() + # Transform for main entry: provide aggregated view + return self._transform_data_for_main_entry(raw_data) + + # Use cached data if self._cached_price_data is not None: - LOGGER.info("Using cached data as fallback") - return self._merge_all_cached_data() - raise UpdateFailed(UPDATE_FAILED_MSG) from exception + return self._transform_data_for_main_entry(self._cached_price_data) - async def _handle_conditional_update(self, current_time: datetime) -> dict: - """Handle conditional update based on update conditions.""" - update_conditions = self._check_update_conditions(current_time) - if any(update_conditions.values()): - LOGGER.debug( - "Updating data based on conditions", - extra=update_conditions, - ) - return await self._fetch_all_data() - if self._cached_price_data is not None: - LOGGER.debug("Using cached data") - return self._merge_all_cached_data() - LOGGER.debug("No cached data available, fetching new data") - return await self._fetch_all_data() - - async def _fetch_all_data(self) -> dict: - """Fetch all data from the API without checking update conditions.""" - current_time = dt_util.now() - new_data = { - "price_data": None, - "rating_data": {"hourly": None, "daily": None, "monthly": None}, - } - try: - price_data = await self._fetch_price_data() - new_data["price_data"] = self._extract_data(price_data, "priceInfo", ("yesterday", "today", "tomorrow")) - for rating_type in ["hourly", "daily", "monthly"]: - try: - rating_data = await self._get_rating_data_for_type(rating_type) - new_data["rating_data"][rating_type] = rating_data - except TibberPricesApiClientError as ex: - LOGGER.error("Failed to fetch %s rating data: %s", rating_type, ex) - except TibberPricesApiClientError as ex: - LOGGER.error("Failed to fetch price data: %s", ex) - if self._cached_price_data is not None: - LOGGER.info("Using cached data as fallback after price data fetch failure") - return self._merge_all_cached_data() - raise - if new_data["price_data"] is None: - LOGGER.error("No price data available after fetch") - if self._cached_price_data is not None: - LOGGER.info("Using cached data as fallback due to missing price data") - return self._merge_all_cached_data() - _raise_no_data() - self._cached_price_data = cast("dict", new_data["price_data"]) + # No cached data, fetch new + raw_data = await self._fetch_all_homes_data() + self._cached_price_data = raw_data self._last_price_update = current_time - for rating_type, rating_data in new_data["rating_data"].items(): - if rating_data is not None: - self._update_rating_cache(rating_type, rating_data, current_time) await self._store_cache() - LOGGER.debug("Updated and stored all cache data at %s", current_time) - return self._merge_all_cached_data() + return self._transform_data_for_main_entry(raw_data) - async def _fetch_price_data(self) -> dict: - """Fetch fresh price data from API. Assumes errors are handled in api.py.""" - client = self.config_entry.runtime_data.client - home_id = self.config_entry.unique_id - if not home_id: - LOGGER.error("No home_id (unique_id) set in config entry!") - return {} - data = await client.async_get_price_info(home_id) - if not data: - return {} - price_info = data.get("priceInfo", {}) - if not price_info: - return {} - return price_info + async def _handle_subentry_update(self) -> dict[str, Any]: + """Handle update for subentry - get data from main coordinator.""" + main_data = await self._get_data_from_main_coordinator() + return self._transform_data_for_subentry(main_data) - async def _get_rating_data_for_type(self, rating_type: str) -> dict: - """Get fresh rating data for a specific type in flat format. Assumes errors are handled in api.py.""" - client = self.config_entry.runtime_data.client - home_id = self.config_entry.unique_id - if not home_id: - LOGGER.error("No home_id (unique_id) set in config entry!") - return {} - method_map = { - "hourly": client.async_get_hourly_price_rating, - "daily": client.async_get_daily_price_rating, - "monthly": client.async_get_monthly_price_rating, + async def _fetch_all_homes_data(self) -> dict[str, Any]: + """Fetch data for all homes (main coordinator only).""" + _LOGGER.debug("Fetching data for all homes") + + # Get price data for all homes + price_data = await self.api.async_get_price_info() + + # Get rating data for all homes + hourly_rating = await self.api.async_get_hourly_price_rating() + daily_rating = await self.api.async_get_daily_price_rating() + monthly_rating = await self.api.async_get_monthly_price_rating() + + all_homes_data = {} + homes_list = price_data.get("homes", {}) + + for home_id, home_price_data in homes_list.items(): + home_data = { + "price_info": home_price_data, + "hourly_rating": hourly_rating.get("homes", {}).get(home_id, {}), + "daily_rating": daily_rating.get("homes", {}).get(home_id, {}), + "monthly_rating": monthly_rating.get("homes", {}).get(home_id, {}), + } + all_homes_data[home_id] = home_data + + return { + "timestamp": dt_util.utcnow(), + "homes": all_homes_data, } - fetch_method = method_map.get(rating_type) - if not fetch_method: - msg = f"Unknown rating type: {rating_type}" - raise ValueError(msg) - data = await fetch_method(home_id) - if not data: - return {} - try: - price_rating = data.get("priceRating", data) - threshold = price_rating.get("thresholdPercentages") - entries = price_rating.get(rating_type, []) - currency = price_rating.get("currency") - except KeyError as ex: - LOGGER.error("Failed to extract rating data (flat format): %s", ex) - raise TibberPricesApiClientError( - TibberPricesApiClientError.EMPTY_DATA_ERROR.format(query_type=rating_type) - ) from ex - return {"priceRating": {rating_type: entries, "thresholdPercentages": threshold, "currency": currency}} - async def _async_initialize(self) -> None: - """Load stored data in flat format and check for stale 'today' data.""" - stored = await self._store.async_load() - if stored is None: - LOGGER.warning("No cache file found or cache is empty on startup.") - if stored: - self._cached_price_data = stored.get("price_data") - self._cached_rating_data_hourly = stored.get("rating_data_hourly") - self._cached_rating_data_daily = stored.get("rating_data_daily") - self._cached_rating_data_monthly = stored.get("rating_data_monthly") - self._last_price_update = self._recover_timestamp(self._cached_price_data, stored.get("last_price_update")) - self._last_rating_update_hourly = self._recover_timestamp( - self._cached_rating_data_hourly, - stored.get("last_rating_update_hourly"), - "hourly", - ) - self._last_rating_update_daily = self._recover_timestamp( - self._cached_rating_data_daily, - stored.get("last_rating_update_daily"), - "daily", - ) - self._last_rating_update_monthly = self._recover_timestamp( - self._cached_rating_data_monthly, - stored.get("last_rating_update_monthly"), - "monthly", - ) - LOGGER.debug( - "Loaded stored cache data - Price update: %s, Rating hourly: %s, daily: %s, monthly: %s", - self._last_price_update, - self._last_rating_update_hourly, - self._last_rating_update_daily, - self._last_rating_update_monthly, - ) - if self._cached_price_data is None: - LOGGER.warning("Cached price data missing after cache load!") - if self._last_price_update is None: - LOGGER.warning("Price update timestamp missing after cache load!") - # Stale data detection on startup - if self._is_today_data_stale(): - LOGGER.warning("Detected stale 'today' data on startup (not from today). Forcing full refresh.") - await self._fetch_all_data() - else: - LOGGER.info("No cache loaded; will fetch fresh data on first update.") + async def _get_data_from_main_coordinator(self) -> dict[str, Any]: + """Get data from the main coordinator (subentries only).""" + # Find the main coordinator + main_coordinator = self._find_main_coordinator() + if not main_coordinator: + msg = "Main coordinator not found" + raise UpdateFailed(msg) + + # Wait for main coordinator to have data + if main_coordinator.data is None: + main_coordinator.async_set_updated_data({}) + + # Return the main coordinator's data + return main_coordinator.data or {} + + def _find_main_coordinator(self) -> TibberPricesDataUpdateCoordinator | None: + """Find the main coordinator that fetches data for all homes.""" + domain_data = self.hass.data.get(DOMAIN, {}) + for coordinator in domain_data.values(): + if ( + isinstance(coordinator, TibberPricesDataUpdateCoordinator) + and coordinator.is_main_entry() + and coordinator != self + ): + return coordinator + return None + + async def _load_cache(self) -> None: + """Load cached data from storage.""" + try: + stored = await self._store.async_load() + if stored: + self._cached_price_data = stored.get("price_data") + self._cached_user_data = stored.get("user_data") + + # Restore timestamps + if last_price_update := stored.get("last_price_update"): + self._last_price_update = dt_util.parse_datetime(last_price_update) + if last_user_update := stored.get("last_user_update"): + self._last_user_update = dt_util.parse_datetime(last_user_update) + + _LOGGER.debug("Cache loaded successfully") + else: + _LOGGER.debug("No cache found, will fetch fresh data") + except OSError as ex: + _LOGGER.warning("Failed to load cache: %s", ex) async def _store_cache(self) -> None: - """Store cache data in flat format.""" + """Store cache data.""" data = { "price_data": self._cached_price_data, - "rating_data_hourly": self._cached_rating_data_hourly, - "rating_data_daily": self._cached_rating_data_daily, - "rating_data_monthly": self._cached_rating_data_monthly, + "user_data": self._cached_user_data, "last_price_update": (self._last_price_update.isoformat() if self._last_price_update else None), - "last_rating_update_hourly": ( - self._last_rating_update_hourly.isoformat() if self._last_rating_update_hourly else None - ), - "last_rating_update_daily": ( - self._last_rating_update_daily.isoformat() if self._last_rating_update_daily else None - ), - "last_rating_update_monthly": ( - self._last_rating_update_monthly.isoformat() if self._last_rating_update_monthly else None - ), + "last_user_update": (self._last_user_update.isoformat() if self._last_user_update else None), } - LOGGER.debug( - "Storing cache data with timestamps: %s", - {k: v for k, v in data.items() if k.startswith("last_")}, - ) - if data["price_data"] is None: - LOGGER.warning("Attempting to store cache with missing price_data!") - if data["last_price_update"] is None: - LOGGER.warning("Attempting to store cache with missing last_price_update!") + try: await self._store.async_save(data) - LOGGER.debug("Cache successfully written to disk.") - except OSError as ex: - LOGGER.error("Failed to write cache to disk: %s", ex) + _LOGGER.debug("Cache stored successfully") + except OSError: + _LOGGER.exception("Failed to store cache") - async def _perform_midnight_rotation(self) -> None: - """Perform the data rotation at midnight within the hourly update process.""" - LOGGER.info("Performing midnight data rotation as part of hourly update cycle") - if not self._cached_price_data: - LOGGER.debug("No cached price data available for midnight rotation") - return - async with self._rotation_lock: + async def _update_user_data_if_needed(self, current_time: datetime) -> None: + """Update user data if needed (daily check).""" + if self._last_user_update is None or current_time - self._last_user_update >= self._user_update_interval: try: - today_count = len(self._cached_price_data.get("today", [])) - tomorrow_count = len(self._cached_price_data.get("tomorrow", [])) - yesterday_count = len(self._cached_price_data.get("yesterday", [])) - LOGGER.debug( - "Before rotation - Yesterday: %d, Today: %d, Tomorrow: %d items", - yesterday_count, - today_count, - tomorrow_count, - ) - if today_data := self._cached_price_data.get("today"): - self._cached_price_data["yesterday"] = today_data - else: - LOGGER.warning("No today's data available to move to yesterday") - if tomorrow_data := self._cached_price_data.get("tomorrow"): - self._cached_price_data["today"] = tomorrow_data - self._cached_price_data["tomorrow"] = [] - else: - LOGGER.warning("No tomorrow's data available to move to today") - await self._store_cache() - LOGGER.info( - "Completed midnight rotation - Yesterday: %d, Today: %d, Tomorrow: %d items", - len(self._cached_price_data.get("yesterday", [])), - len(self._cached_price_data.get("today", [])), - len(self._cached_price_data.get("tomorrow", [])), - ) - self._force_update = True - except (KeyError, TypeError, ValueError) as ex: - LOGGER.error("Error during midnight data rotation in hourly update: %s", ex) - - @callback - def _check_update_conditions(self, current_time: datetime) -> dict[str, bool]: - """Check all update conditions and return results as a dictionary.""" - return { - "update_price": self._should_update_price_data(current_time), - "update_hourly": self._should_update_rating_type( - current_time, - self._cached_rating_data_hourly, - self._last_rating_update_hourly, - "hourly", - ), - "update_daily": self._should_update_rating_type( - current_time, - self._cached_rating_data_daily, - self._last_rating_update_daily, - "daily", - ), - "update_monthly": self._should_update_rating_type( - current_time, - self._cached_rating_data_monthly, - self._last_rating_update_monthly, - "monthly", - ), - } - - def _log_update_decision(self, ctx: dict) -> None: - """Log update decision context for debugging.""" - LOGGER.debug("[tibber_prices] Update decision: %s", ctx) - - def _get_tomorrow_data_status(self) -> tuple[int, bool]: - """Return (interval_count, tomorrow_data_complete) for tomorrow's prices (flat structure).""" - tomorrow_prices = [] - if self._cached_price_data: - raw_tomorrow = self._cached_price_data.get("tomorrow", []) - if raw_tomorrow is None: - LOGGER.warning( - "Tomorrow price data is None, treating as empty list. Full price_data: %s", - self._cached_price_data, - ) - tomorrow_prices = [] - elif not isinstance(raw_tomorrow, list): - LOGGER.warning( - "Tomorrow price data is not a list: %r. Full price_data: %s", - raw_tomorrow, - self._cached_price_data, - ) - tomorrow_prices = list(raw_tomorrow) if hasattr(raw_tomorrow, "__iter__") else [] - else: - tomorrow_prices = raw_tomorrow - else: - LOGGER.warning("No cached price_data available: %s", self._cached_price_data) - interval_count = len(tomorrow_prices) - min_tomorrow_intervals_hourly = 24 - min_tomorrow_intervals_15min = 96 - tomorrow_data_complete = interval_count in {min_tomorrow_intervals_hourly, min_tomorrow_intervals_15min} - if interval_count == 0: - LOGGER.debug( - "Tomorrow price data is empty at late hour. Raw tomorrow data: %s | Full price_data: %s", - tomorrow_prices, - self._cached_price_data, - ) - return interval_count, tomorrow_data_complete + _LOGGER.debug("Updating user data") + user_data = await self.api.async_get_viewer_details() + self._cached_user_data = user_data + self._last_user_update = current_time + _LOGGER.debug("User data updated successfully") + except (TibberPricesApiClientError, TibberPricesApiClientCommunicationError) as ex: + _LOGGER.warning("Failed to update user data: %s", ex) @callback def _should_update_price_data(self, current_time: datetime) -> bool: - """Decide if price data should be updated. Logs all decision points for debugging.""" - should_update, log_ctx = self._decide_price_update(current_time) - self._log_update_decision(log_ctx) - return should_update - - @callback - def _should_update_rating_type( - self, - current_time: datetime, - cached_data: dict | None, - last_update: datetime | None, - rating_type: str, - ) -> bool: - def extra_check_monthly(now: datetime, latest: datetime) -> bool: - current_month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - return latest < current_month_start - - if rating_type == "monthly": - return self._should_update_data( - current_time, - cached_data, - last_update, - lambda d: self._get_latest_rating_timestamp(d, rating_type), - config={ - "interval": timedelta(days=1), - "extra_check": extra_check_monthly, - }, - ) - return self._should_update_data( - current_time, - cached_data, - last_update, - lambda d: self._get_latest_rating_timestamp(d, rating_type), - config={ - "update_window": (PRICE_UPDATE_RANDOM_MIN_HOUR, PRICE_UPDATE_RANDOM_MAX_HOUR), - "interval": UPDATE_INTERVAL, - }, - ) - - @callback - def _should_update_data( - self, - current_time: datetime, - cached_data: dict | None, - last_update: datetime | None, - timestamp_func: Callable[[dict | None], datetime | None], - config: dict | None = None, - ) -> bool: - """Generalized update check for any data type.""" - config = config or {} - update_window = config.get("update_window") - interval = config.get("interval", UPDATE_INTERVAL) - extra_check = config.get("extra_check") - if cached_data is None: + """Check if price data should be updated.""" + if self._cached_price_data is None: return True - latest_timestamp = timestamp_func(cached_data) - if not latest_timestamp: + if self._last_price_update is None: return True - # Always use last_update if present and valid - if last_update and (current_time - last_update) < interval: - return False - if not last_update: - last_update = latest_timestamp - if update_window: - current_hour = current_time.hour - if update_window[0] <= current_hour <= update_window[1]: - tomorrow = (current_time + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0) - if latest_timestamp < tomorrow: - return True - if last_update and current_time - last_update >= interval: - return True - return extra_check(current_time, latest_timestamp) if extra_check else False + # Update every 15 minutes + return (current_time - self._last_price_update) >= UPDATE_INTERVAL @callback - def _extract_data(self, data: dict, container_key: str, keys: tuple[str, ...]) -> dict: - """Extract and harmonize data for caching in flat format.""" - # For price data, just flatten to {key: list} for each key - try: - container = data[container_key] - if not isinstance(container, dict): - LOGGER.error( - "Extracted %s is not a dict: %r. Full data: %s", - container_key, - container, - data, - ) - container = {} - extracted = {key: list(container.get(key, [])) for key in keys} - except (KeyError, IndexError, TypeError): - # For flat price data, just copy keys from data - extracted = {key: list(data.get(key, [])) for key in keys} - return extracted - - @callback - def _update_rating_cache(self, rating_type: str, rating_data: dict, current_time: datetime) -> None: - """Update the rating cache for a specific rating type.""" - if rating_type == "hourly": - self._cached_rating_data_hourly = cast("dict", rating_data) - self._last_rating_update_hourly = current_time - elif rating_type == "daily": - self._cached_rating_data_daily = cast("dict", rating_data) - self._last_rating_update_daily = current_time - else: - self._cached_rating_data_monthly = cast("dict", rating_data) - self._last_rating_update_monthly = current_time - LOGGER.debug("Updated %s rating data cache at %s", rating_type, current_time) - - @callback - def _merge_all_cached_data(self) -> dict: - """Merge all cached data into Home Assistant-style structure: priceInfo, priceRating, currency.""" + def _merge_cached_data(self) -> dict[str, Any]: + """Merge cached data into the expected format for main entry.""" if not self._cached_price_data: return {} - merged = { - "priceInfo": dict(self._cached_price_data), # 'today', 'tomorrow', 'yesterday' under 'priceInfo' - } + return self._transform_data_for_main_entry(self._cached_price_data) + + def _transform_data_for_main_entry(self, raw_data: dict[str, Any]) -> dict[str, Any]: + """Transform raw data for main entry (aggregated view of all homes).""" + # For main entry, we can show data from the first home as default + # or provide an aggregated view + homes_data = raw_data.get("homes", {}) + if not homes_data: + return { + "timestamp": raw_data.get("timestamp"), + "homes": {}, + "priceInfo": {}, + "priceRating": {}, + } + + # Use the first home's data as the main entry's data + first_home_data = next(iter(homes_data.values())) + price_info = first_home_data.get("price_info", {}) + + # Combine rating data price_rating = { - "hourly": [], - "daily": [], - "monthly": [], - "thresholdPercentages": None, - "currency": None, + "hourly": first_home_data.get("hourly_rating", {}), + "daily": first_home_data.get("daily_rating", {}), + "monthly": first_home_data.get("monthly_rating", {}), } - for rating_type, cached in zip( - ["hourly", "daily", "monthly"], - [self._cached_rating_data_hourly, self._cached_rating_data_daily, self._cached_rating_data_monthly], - strict=True, - ): - if cached and "priceRating" in cached: - entries = cached["priceRating"].get(rating_type, []) - price_rating[rating_type] = entries - if not price_rating["thresholdPercentages"]: - price_rating["thresholdPercentages"] = cached["priceRating"].get("thresholdPercentages") - if not price_rating["currency"]: - price_rating["currency"] = cached["priceRating"].get("currency") - merged["priceRating"] = price_rating - merged["currency"] = price_rating["currency"] - return merged - @callback - def _recover_timestamp( - self, - data: dict | None, - stored_timestamp: str | None, - rating_type: str | None = None, - ) -> datetime | None: - """Recover timestamp from stored value or data.""" - # Always prefer the stored timestamp if present and valid - if stored_timestamp: - ts = dt_util.parse_datetime(stored_timestamp) - if ts: - return ts - # Fallback to data-derived timestamp - if not data: - return None - if rating_type: - timestamp = self._get_latest_rating_timestamp(data, rating_type) - else: - timestamp = self._get_latest_price_timestamp(data) - if timestamp: - LOGGER.debug( - "Recovered %s timestamp from data: %s", - rating_type or "price", - timestamp, - ) - return timestamp + return { + "timestamp": raw_data.get("timestamp"), + "homes": homes_data, + "priceInfo": price_info, + "priceRating": price_rating, + } - @callback - def _get_latest_timestamp( - self, - data: dict | None, - container_key: str, - entry_key: str | None = None, - time_field: str = "startsAt", - ) -> datetime | None: - """Get the latest timestamp from a container in data, optionally for a subkey and time field.""" - if not data or container_key not in data: - return None - try: - container = data[container_key] - if entry_key: - container = container.get(entry_key, []) - latest = None - for entry in container: - time_str = entry.get(time_field) - if time_str: - timestamp = dt_util.parse_datetime(time_str) - if timestamp and (not latest or timestamp > latest): - latest = timestamp - except (KeyError, IndexError, TypeError): - return None - return latest + def _transform_data_for_subentry(self, main_data: dict[str, Any]) -> dict[str, Any]: + """Transform main coordinator data for subentry (home-specific view).""" + home_id = self.config_entry.data.get("home_id") + if not home_id: + return main_data - @callback - def _get_latest_price_timestamp(self, price_data: dict | None) -> datetime | None: - """Get the latest timestamp from price data (today and tomorrow).""" - if not price_data: - return None - today = self._get_latest_timestamp(price_data, "today", None, "startsAt") - tomorrow = self._get_latest_timestamp(price_data, "tomorrow", None, "startsAt") - if today and tomorrow: - return max(today, tomorrow) - return today or tomorrow + homes_data = main_data.get("homes", {}) + home_data = homes_data.get(home_id, {}) - @callback - def _get_latest_rating_timestamp(self, rating_data: dict | None, rating_type: str | None = None) -> datetime | None: - """Get the latest timestamp from rating data, optionally for a specific type.""" - if not rating_type: - latest = None - for rtype in ("hourly", "daily", "monthly"): - ts = self._get_latest_timestamp(rating_data, "priceRating", rtype, "time") - if ts and (not latest or ts > latest): - latest = ts - return latest - return self._get_latest_timestamp(rating_data, "priceRating", rating_type, "time") + if not home_data: + return { + "timestamp": main_data.get("timestamp"), + "priceInfo": {}, + "priceRating": {}, + } - @callback - def _get_latest_timestamp_from_rating_type(self, rating_data: dict | None, rating_type: str) -> datetime | None: - """Get the latest timestamp from a specific rating type.""" - if not rating_data or "priceRating" not in rating_data: + price_info = home_data.get("price_info", {}) + + # Combine rating data for this specific home + price_rating = { + "hourly": home_data.get("hourly_rating", {}), + "daily": home_data.get("daily_rating", {}), + "monthly": home_data.get("monthly_rating", {}), + } + + return { + "timestamp": main_data.get("timestamp"), + "priceInfo": price_info, + "priceRating": price_rating, + } + + # --- Methods expected by sensors and services --- + + def get_home_data(self, home_id: str) -> dict[str, Any] | None: + """Get data for a specific home.""" + if not self.data: return None - try: - price_rating = rating_data["priceRating"] - result = None + homes_data = self.data.get("homes", {}) + return homes_data.get(home_id) - if rating_entries := price_rating.get(rating_type, []): - for entry in rating_entries: - if time := entry.get("time"): - timestamp = dt_util.parse_datetime(time) - if timestamp and (not result or timestamp > result): - result = timestamp - except (KeyError, IndexError, TypeError): + def get_current_interval_data(self) -> dict[str, Any] | None: + """Get the price data for the current interval.""" + if not self.data: return None - return result - def get_all_intervals(self) -> list[dict]: - """Return a combined, sorted list of all price intervals for yesterday, today, and tomorrow.""" - price_info = self.data.get("priceInfo", {}) if self.data else {} - all_prices = price_info.get("yesterday", []) + price_info.get("today", []) + price_info.get("tomorrow", []) - return sorted( - all_prices, - key=lambda p: dt_util.parse_datetime(p.get("startsAt") or "") or dt_util.now(), - ) - - def get_interval_granularity(self) -> int | None: - """Return the interval granularity in minutes (e.g., 15 or 60) for today's data.""" - price_info = self.data.get("priceInfo", {}) if self.data else {} - today_prices = price_info.get("today", []) - from .sensor import detect_interval_granularity - - return detect_interval_granularity(today_prices) if today_prices else None - - def get_current_interval_data(self) -> dict | None: - """Return the price data for the current interval.""" - price_info = self.data.get("priceInfo", {}) if self.data else {} + price_info = self.data.get("priceInfo", {}) if not price_info: return None - now = dt_util.now() - interval_length = self.get_interval_granularity() + from .sensor import find_price_data_for_interval - return find_price_data_for_interval(price_info, now, interval_length) + now = dt_util.now() + return find_price_data_for_interval(price_info, now) - def get_combined_price_info(self) -> dict: - """Return a dict with all intervals under a single key 'all'.""" - return {"all": self.get_all_intervals()} + def get_all_intervals(self) -> list[dict[str, Any]]: + """Get all price intervals (today + tomorrow).""" + if not self.data: + return [] - def is_tomorrow_data_available(self) -> bool | None: - """Return True if tomorrow's data is fully available, False if not, None if unknown.""" - tomorrow_prices = self.data.get("priceInfo", {}).get("tomorrow", []) if self.data else [] - interval_count = len(tomorrow_prices) - min_tomorrow_intervals_hourly = 24 - min_tomorrow_intervals_15min = 96 - tomorrow_interval_counts = {min_tomorrow_intervals_hourly, min_tomorrow_intervals_15min} - return interval_count in tomorrow_interval_counts + price_info = self.data.get("priceInfo", {}) + today_prices = price_info.get("today", []) + tomorrow_prices = price_info.get("tomorrow", []) + return today_prices + tomorrow_prices - def _transform_api_response(self, data: dict[str, Any]) -> dict: - """Transform API response to coordinator data format.""" - return cast("dict", data) + def get_interval_granularity(self) -> int | None: + """Get the granularity of price intervals in minutes.""" + all_intervals = self.get_all_intervals() + if not all_intervals: + return None - def _should_update_random_window(self, current_time: datetime, log_ctx: dict) -> tuple[bool, dict]: - """Determine if a random update should occur in the random window (13:00-15:00).""" - today = current_time.date() - if self._random_update_date != today or self._random_update_minute is None: - self._random_update_date = today - import secrets + from .sensor import detect_interval_granularity - self._random_update_minute = secrets.randbelow(RANDOM_DELAY_MAX_MINUTES) - log_ctx["window"] = "random" - log_ctx["random_update_minute"] = self._random_update_minute - log_ctx["current_minute"] = current_time.minute - if current_time.minute == self._random_update_minute: - if self._last_attempted_price_update: - since_last = current_time - self._last_attempted_price_update - log_ctx["since_last_attempt"] = str(since_last) - if since_last >= MIN_RETRY_INTERVAL: - self._last_attempted_price_update = current_time - log_ctx["reason"] = "random window, random minute, min retry met" - log_ctx["decision"] = True - return True, log_ctx - log_ctx["reason"] = "random window, random minute, min retry not met" - log_ctx["decision"] = False - return False, log_ctx - self._last_attempted_price_update = current_time - log_ctx["reason"] = "random window, first attempt" - log_ctx["decision"] = True - return True, log_ctx - log_ctx["reason"] = "random window, not random minute" - log_ctx["decision"] = False - return False, log_ctx + return detect_interval_granularity(all_intervals) - def _decide_price_update(self, current_time: datetime) -> tuple[bool, dict]: - current_hour = current_time.hour - log_ctx = { - "current_time": str(current_time), - "current_hour": current_hour, - "has_cached_price_data": bool(self._cached_price_data), - "last_price_update": str(self._last_price_update) if self._last_price_update else None, - } - should_update = False - if current_hour < PRICE_UPDATE_RANDOM_MIN_HOUR: - should_update = not self._cached_price_data - log_ctx["window"] = "early" - log_ctx["reason"] = "no cache" if should_update else "cache present" - log_ctx["decision"] = should_update - return should_update, log_ctx - interval_count, tomorrow_data_complete = self._get_tomorrow_data_status() - log_ctx["interval_count"] = interval_count - log_ctx["tomorrow_data_complete"] = tomorrow_data_complete - in_random_window = PRICE_UPDATE_RANDOM_MIN_HOUR <= current_hour < PRICE_UPDATE_RANDOM_MAX_HOUR - in_late_window = PRICE_UPDATE_RANDOM_MAX_HOUR <= current_hour < END_OF_DAY_HOUR - if ( - tomorrow_data_complete - and self._last_price_update - and (current_time - self._last_price_update) < UPDATE_INTERVAL + async def refresh_user_data(self) -> bool: + """Force refresh of user data and return True if data was updated.""" + try: + current_time = dt_util.utcnow() + await self._update_user_data_if_needed(current_time) + await self._store_cache() + except ( + TibberPricesApiClientAuthenticationError, + TibberPricesApiClientCommunicationError, + TibberPricesApiClientError, ): - should_update = False - log_ctx["window"] = "any" - log_ctx["reason"] = "tomorrow_data_complete and last_price_update < 24h" - log_ctx["decision"] = should_update - return should_update, log_ctx - if in_random_window and not tomorrow_data_complete: - return self._should_update_random_window(current_time, log_ctx) - if in_late_window and not tomorrow_data_complete: - should_update = True - log_ctx["window"] = "late" - log_ctx["reason"] = "late window, tomorrow data missing (force update)" - log_ctx["decision"] = should_update - return should_update, log_ctx - should_update = False - log_ctx["window"] = "late-or-random" - log_ctx["reason"] = "no update needed" - log_ctx["decision"] = should_update - return should_update, log_ctx + return False + else: + return True - def _is_today_data_stale(self) -> bool: - """Return True if the first 'today' interval is not from today (stale cache).""" - if not self._cached_price_data: - return True - today_prices = self._cached_price_data.get("today", []) - if not today_prices: - return True # No data, treat as stale - first = today_prices[0] - starts_at = first.get("startsAt") - if not starts_at: - return True - dt = dt_util.parse_datetime(starts_at) - if not dt: - return True - return dt.date() != dt_util.now().date() + def get_user_profile(self) -> dict[str, Any]: + """Get user profile information.""" + return { + "last_updated": self._last_user_update, + "cached_user_data": self._cached_user_data is not None, + } + + def get_user_homes(self) -> list[dict[str, Any]]: + """Get list of user homes.""" + if not self._cached_user_data: + return [] + return self._cached_user_data.get("homes", []) + + @callback + def async_add_listener(self, update_callback: Callable[[], None]) -> Callable[[], None]: + """Add a listener for updates.""" + return super().async_add_listener(update_callback) diff --git a/custom_components/tibber_prices/entity.py b/custom_components/tibber_prices/entity.py index dd009eb..ab00ae4 100644 --- a/custom_components/tibber_prices/entity.py +++ b/custom_components/tibber_prices/entity.py @@ -27,16 +27,47 @@ class TibberPricesEntity(CoordinatorEntity[TibberPricesDataUpdateCoordinator]): "COTTAGE": "Cottage", } - # Get home info from Tibber API if available + # Get user profile information from coordinator + user_profile = self.coordinator.get_user_profile() + + # Check if this is a main entry or subentry + is_subentry = bool(self.coordinator.config_entry.data.get("home_id")) + + # Initialize variables home_name = "Tibber Home" home_id = self.coordinator.config_entry.unique_id home_type = None - city = None - app_nickname = None - address1 = None - if coordinator.data: + + if is_subentry: + # For subentries, show specific home information + home_data = self.coordinator.config_entry.data.get("home_data", {}) + home_id = self.coordinator.config_entry.data.get("home_id") + + # Get home details + address = home_data.get("address", {}) + address1 = address.get("address1", "") + city = address.get("city", "") + app_nickname = home_data.get("appNickname", "") + home_type = home_data.get("type", "") + + # Compose home name + home_name = app_nickname or address1 or f"Tibber Home {home_id}" + if city: + home_name = f"{home_name}, {city}" + + # Add user information if available + if user_profile and user_profile.get("name"): + home_name = f"{home_name} ({user_profile['name']})" + elif user_profile: + # For main entry, show user profile information + user_name = user_profile.get("name", "Tibber User") + user_email = user_profile.get("email", "") + home_name = f"Tibber - {user_name}" + if user_email: + home_name = f"{home_name} ({user_email})" + elif coordinator.data: + # Fallback to original logic if user data not available yet try: - home_id = self.unique_id address1 = str(coordinator.data.get("address", {}).get("address1", "")) city = str(coordinator.data.get("address", {}).get("city", "")) app_nickname = str(coordinator.data.get("appNickname", "")) @@ -47,8 +78,6 @@ class TibberPricesEntity(CoordinatorEntity[TibberPricesDataUpdateCoordinator]): home_name = f"{home_name}, {city}" except (KeyError, IndexError, TypeError): home_name = "Tibber Home" - else: - home_name = "Tibber Home" self._attr_device_info = DeviceInfo( entry_type=DeviceEntryType.SERVICE, diff --git a/custom_components/tibber_prices/manifest.json b/custom_components/tibber_prices/manifest.json index 86f0c92..e8a694d 100644 --- a/custom_components/tibber_prices/manifest.json +++ b/custom_components/tibber_prices/manifest.json @@ -9,6 +9,7 @@ "iot_class": "cloud_polling", "issue_tracker": "https://github.com/jpawlowski/hass.tibber_prices/issues", "version": "0.1.0", + "homeassistant": "2024.1.0", "requirements": [ "aiofiles>=23.2.1" ] diff --git a/custom_components/tibber_prices/services.py b/custom_components/tibber_prices/services.py index e1c6a66..f99457c 100644 --- a/custom_components/tibber_prices/services.py +++ b/custom_components/tibber_prices/services.py @@ -13,6 +13,11 @@ from homeassistant.exceptions import ServiceValidationError from homeassistant.helpers.entity_registry import async_get as async_get_entity_registry from homeassistant.util import dt as dt_util +from .api import ( + TibberPricesApiClientAuthenticationError, + TibberPricesApiClientCommunicationError, + TibberPricesApiClientError, +) from .const import ( DOMAIN, PRICE_LEVEL_CHEAP, @@ -29,6 +34,7 @@ from .const import ( PRICE_SERVICE_NAME = "get_price" APEXCHARTS_DATA_SERVICE_NAME = "get_apexcharts_data" APEXCHARTS_YAML_SERVICE_NAME = "get_apexcharts_yaml" +REFRESH_USER_DATA_SERVICE_NAME = "refresh_user_data" ATTR_DAY: Final = "day" ATTR_ENTRY_ID: Final = "entry_id" ATTR_TIME: Final = "time" @@ -68,6 +74,12 @@ APEXCHARTS_SERVICE_SCHEMA: Final = vol.Schema( } ) +REFRESH_USER_DATA_SERVICE_SCHEMA: Final = vol.Schema( + { + vol.Required(ATTR_ENTRY_ID): str, + } +) + # region Top-level functions (ordered by call hierarchy) # --- Entry point: Service handler --- @@ -165,41 +177,59 @@ async def _get_apexcharts_data(call: ServiceCall) -> dict[str, Any]: level_type = call.data.get("level_type", "rating_level") level_key = call.data.get("level_key") hass = call.hass + + # Get entry ID and verify it exists entry_id = await _get_entry_id_from_entity_id(hass, entity_id) if not entry_id: raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entity_id") + entry, coordinator, data = _get_entry_and_data(hass, entry_id) - points = [] + + # Get entries based on level_type + entries = _get_apexcharts_entries(coordinator, day, level_type) + if not entries: + return {"points": []} + + # Ensure level_key is a string + if level_key is None: + raise ServiceValidationError(translation_domain=DOMAIN, translation_key="missing_level_key") + + # Generate points for the chart + points = _generate_apexcharts_points(entries, str(level_key)) + return {"points": points} + + +def _get_apexcharts_entries(coordinator: Any, day: str, level_type: str) -> list[dict]: + """Get the appropriate entries for ApexCharts based on level_type and day.""" if level_type == "rating_level": entries = coordinator.data.get("priceRating", {}).get("hourly", []) price_info = coordinator.data.get("priceInfo", {}) - if day == "today": - prefixes = _get_day_prefixes(price_info.get("today", [])) - if not prefixes: - return {"points": []} - entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])] - elif day == "tomorrow": - prefixes = _get_day_prefixes(price_info.get("tomorrow", [])) - if not prefixes: - return {"points": []} - entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])] - elif day == "yesterday": - prefixes = _get_day_prefixes(price_info.get("yesterday", [])) - if not prefixes: - return {"points": []} - entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])] - else: - entries = coordinator.data.get("priceInfo", {}).get(day, []) - if not entries: - return {"points": []} + day_info = price_info.get(day, []) + prefixes = _get_day_prefixes(day_info) + + if not prefixes: + return [] + + return [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])] + + # For non-rating level types, return the price info for the specified day + return coordinator.data.get("priceInfo", {}).get(day, []) + + +def _generate_apexcharts_points(entries: list[dict], level_key: str) -> list: + """Generate data points for ApexCharts based on the entries and level key.""" + points = [] for i in range(len(entries) - 1): p = entries[i] if p.get("level") != level_key: continue points.append([p.get("time") or p.get("startsAt"), round((p.get("total") or 0) * 100, 2)]) + + # Add a final point with null value if there are any points if points: points.append([points[-1][0], None]) - return {"points": points} + + return points async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]: @@ -265,6 +295,57 @@ async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]: } +async def _refresh_user_data(call: ServiceCall) -> dict[str, Any]: + """Refresh user data for a specific config entry and return updated information.""" + entry_id = call.data.get(ATTR_ENTRY_ID) + hass = call.hass + + if not entry_id: + return { + "success": False, + "message": "Entry ID is required", + } + + # Get the entry and coordinator + try: + entry, coordinator, data = _get_entry_and_data(hass, entry_id) + except ServiceValidationError as ex: + return { + "success": False, + "message": f"Invalid entry ID: {ex}", + } + + # Force refresh user data using the public method + try: + updated = await coordinator.refresh_user_data() + except ( + TibberPricesApiClientAuthenticationError, + TibberPricesApiClientCommunicationError, + TibberPricesApiClientError, + ) as ex: + return { + "success": False, + "message": f"API error refreshing user data: {ex!s}", + } + else: + if updated: + user_profile = coordinator.get_user_profile() + homes = coordinator.get_user_homes() + + return { + "success": True, + "message": "User data refreshed successfully", + "user_profile": user_profile, + "homes_count": len(homes), + "homes": homes, + "last_updated": user_profile.get("last_updated"), + } + return { + "success": False, + "message": "User data was already up to date", + } + + # --- Direct helpers (called by service handler or each other) --- @@ -470,8 +551,6 @@ def _annotate_intervals_with_times( elif day == "tomorrow": prev_end = _get_adjacent_start_time(price_info_by_day, "today", first=False) interval["previous_end_time"] = prev_end - elif day == "yesterday": - interval["previous_end_time"] = None else: interval["previous_end_time"] = None @@ -712,6 +791,13 @@ def async_setup_services(hass: HomeAssistant) -> None: schema=APEXCHARTS_SERVICE_SCHEMA, supports_response=SupportsResponse.ONLY, ) + hass.services.async_register( + DOMAIN, + REFRESH_USER_DATA_SERVICE_NAME, + _refresh_user_data, + schema=REFRESH_USER_DATA_SERVICE_SCHEMA, + supports_response=SupportsResponse.ONLY, + ) # endregion diff --git a/custom_components/tibber_prices/services.yaml b/custom_components/tibber_prices/services.yaml index 2f136a5..1c338fb 100644 --- a/custom_components/tibber_prices/services.yaml +++ b/custom_components/tibber_prices/services.yaml @@ -111,3 +111,16 @@ get_apexcharts_yaml: - yesterday - today - tomorrow +refresh_user_data: + name: Refresh User Data + description: >- + Forces a refresh of the user data (homes, profile information) from the Tibber API. This can be useful after making changes to your Tibber account or when troubleshooting connectivity issues. + fields: + entry_id: + name: Entry ID + description: The config entry ID for the Tibber integration. + required: true + example: "1234567890abcdef" + selector: + config_entry: + integration: tibber_prices diff --git a/custom_components/tibber_prices/translations/de.json b/custom_components/tibber_prices/translations/de.json index b038441..60e5743 100644 --- a/custom_components/tibber_prices/translations/de.json +++ b/custom_components/tibber_prices/translations/de.json @@ -133,5 +133,27 @@ "name": "Daten für morgen verfügbar" } } + }, + "issues": { + "new_homes_available": { + "title": "Neue Tibber-Häuser erkannt", + "description": "Wir haben {count} neue(s) Zuhause in deinem Tibber-Konto erkannt: {homes}. Du kannst diese über die Tibber-Integration in Home Assistant hinzufügen." + }, + "homes_removed": { + "title": "Tibber-Häuser entfernt", + "description": "Wir haben erkannt, dass {count} Zuhause aus deinem Tibber-Konto entfernt wurde(n): {homes}. Bitte überprüfe deine Tibber-Integrationskonfiguration." + } + }, + "services": { + "refresh_user_data": { + "name": "Benutzerdaten aktualisieren", + "description": "Erzwingt eine Aktualisierung der Benutzerdaten (Häuser, Profilinformationen) aus der Tibber API. Dies kann nützlich sein, nachdem Änderungen an deinem Tibber-Konto vorgenommen wurden oder bei der Fehlerbehebung von Verbindungsproblemen.", + "fields": { + "entry_id": { + "name": "Eintrag-ID", + "description": "Die Konfigurationseintrag-ID für die Tibber-Integration." + } + } + } } } diff --git a/custom_components/tibber_prices/translations/en.json b/custom_components/tibber_prices/translations/en.json index 3ae316a..3ea30a8 100644 --- a/custom_components/tibber_prices/translations/en.json +++ b/custom_components/tibber_prices/translations/en.json @@ -33,11 +33,11 @@ }, "config_subentries": { "home": { - "title": "Home", + "title": "Add Home", "step": { "user": { "title": "Add Tibber Home", - "description": "Select a home to add to your Tibber integration.", + "description": "Select a home to add to your Tibber integration.\n\n**Note:** After adding this home, you can add additional homes from the integration's context menu by selecting \"Add Home\".", "data": { "home_id": "Home" } @@ -51,7 +51,7 @@ "no_access_token": "No access token available", "home_not_found": "Selected home not found", "api_error": "Failed to fetch homes from Tibber API", - "no_available_homes": "No additional homes available to add" + "no_available_homes": "No additional homes available to add. All homes from your Tibber account have already been added." } } }, @@ -148,5 +148,27 @@ "name": "Tomorrow's Data Available" } } + }, + "issues": { + "new_homes_available": { + "title": "New Tibber homes detected", + "description": "We detected {count} new home(s) on your Tibber account: {homes}. You can add them to Home Assistant through the Tibber integration configuration." + }, + "homes_removed": { + "title": "Tibber homes removed", + "description": "We detected that {count} home(s) have been removed from your Tibber account: {homes}. Please review your Tibber integration configuration." + } + }, + "services": { + "refresh_user_data": { + "name": "Refresh User Data", + "description": "Forces a refresh of the user data (homes, profile information) from the Tibber API. This can be useful after making changes to your Tibber account or when troubleshooting connectivity issues.", + "fields": { + "entry_id": { + "name": "Entry ID", + "description": "The config entry ID for the Tibber integration." + } + } + } } } diff --git a/hacs.json b/hacs.json index 2975773..33830aa 100644 --- a/hacs.json +++ b/hacs.json @@ -1,6 +1,6 @@ { "name": "Tibber Price Information & Ratings", - "homeassistant": "2025.4.2", + "homeassistant": "2025.5.0", "hacs": "2.0.1", "render_readme": true } diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d4839a6 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package diff --git a/tests/test_coordinator_basic.py b/tests/test_coordinator_basic.py new file mode 100644 index 0000000..6d2f90e --- /dev/null +++ b/tests/test_coordinator_basic.py @@ -0,0 +1,84 @@ +"""Test basic coordinator functionality with the enhanced coordinator.""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from custom_components.tibber_prices.coordinator import TibberPricesDataUpdateCoordinator + + +class TestBasicCoordinator: + """Test basic coordinator functionality.""" + + @pytest.fixture + def mock_hass(self): + """Create a mock Home Assistant instance.""" + hass = Mock() + hass.data = {} + return hass + + @pytest.fixture + def mock_config_entry(self): + """Create a mock config entry.""" + config_entry = Mock() + config_entry.unique_id = "test_home_123" + config_entry.entry_id = "test_entry" + config_entry.data = {"access_token": "test_token"} + return config_entry + + @pytest.fixture + def mock_session(self): + """Create a mock session.""" + return Mock() + + @pytest.fixture + def coordinator(self, mock_hass, mock_config_entry, mock_session): + """Create a coordinator instance.""" + with patch( + "custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession", + return_value=mock_session, + ): + with patch("custom_components.tibber_prices.coordinator.Store") as mock_store_class: + mock_store = Mock() + mock_store.async_load = AsyncMock(return_value=None) + mock_store.async_save = AsyncMock() + mock_store_class.return_value = mock_store + + return TibberPricesDataUpdateCoordinator(mock_hass, mock_config_entry) + + def test_coordinator_creation(self, coordinator): + """Test that coordinator can be created.""" + assert coordinator is not None + assert hasattr(coordinator, "get_current_interval_data") + assert hasattr(coordinator, "get_all_intervals") + assert hasattr(coordinator, "get_user_profile") + + def test_is_main_entry(self, coordinator): + """Test main entry detection.""" + # First coordinator should be main entry + assert coordinator.is_main_entry() is True + + def test_get_user_profile_no_data(self, coordinator): + """Test getting user profile when no data is cached.""" + profile = coordinator.get_user_profile() + assert profile == {"last_updated": None, "cached_user_data": False} + + def test_get_user_homes_no_data(self, coordinator): + """Test getting user homes when no data is cached.""" + homes = coordinator.get_user_homes() + assert homes == [] + + def test_get_current_interval_data_no_data(self, coordinator): + """Test getting current interval data when no data is available.""" + current_data = coordinator.get_current_interval_data() + assert current_data is None + + def test_get_all_intervals_no_data(self, coordinator): + """Test getting all intervals when no data is available.""" + intervals = coordinator.get_all_intervals() + assert intervals == [] + + def test_get_interval_granularity(self, coordinator): + """Test getting interval granularity.""" + granularity = coordinator.get_interval_granularity() + assert granularity is None diff --git a/tests/test_coordinator_enhanced.py b/tests/test_coordinator_enhanced.py new file mode 100644 index 0000000..eabe6f0 --- /dev/null +++ b/tests/test_coordinator_enhanced.py @@ -0,0 +1,255 @@ +"""Test enhanced coordinator functionality.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from custom_components.tibber_prices.const import DOMAIN +from custom_components.tibber_prices.coordinator import TibberPricesDataUpdateCoordinator + + +class TestEnhancedCoordinator: + """Test enhanced coordinator functionality.""" + + @pytest.fixture + def mock_config_entry(self) -> Mock: + """Create a mock config entry.""" + config_entry = Mock() + config_entry.unique_id = "test_home_id_123" + config_entry.entry_id = "test_entry_id" + config_entry.data = {"access_token": "test_token"} + return config_entry + + @pytest.fixture + def mock_hass(self) -> Mock: + """Create a mock Home Assistant instance.""" + import asyncio + + hass = Mock() + hass.data = {} + # Mock the event loop for time tracking + hass.loop = asyncio.get_event_loop() + return hass + + @pytest.fixture + def mock_store(self) -> Mock: + """Create a mock store.""" + store = Mock() + store.async_load = AsyncMock(return_value=None) + store.async_save = AsyncMock() + return store + + @pytest.fixture + def mock_api(self) -> Mock: + """Create a mock API client.""" + api = Mock() + api.async_get_viewer_details = AsyncMock(return_value={"homes": []}) + api.async_get_price_info = AsyncMock(return_value={"homes": {}}) + api.async_get_hourly_price_rating = AsyncMock(return_value={"homes": {}}) + api.async_get_daily_price_rating = AsyncMock(return_value={"homes": {}}) + api.async_get_monthly_price_rating = AsyncMock(return_value={"homes": {}}) + return api + + @pytest.fixture + def coordinator( + self, mock_hass: Mock, mock_config_entry: Mock, mock_store: Mock, mock_api: Mock + ) -> TibberPricesDataUpdateCoordinator: + """Create a coordinator for testing.""" + mock_session = Mock() + with ( + patch( + "custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession", + return_value=mock_session, + ), + patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store), + ): + coordinator = TibberPricesDataUpdateCoordinator( + hass=mock_hass, + config_entry=mock_config_entry, + ) + # Replace the API instance with our mock + coordinator.api = mock_api + return coordinator + + @pytest.mark.asyncio + async def test_main_subentry_pattern(self, mock_hass: Mock, mock_store: Mock) -> None: + """Test main/subentry coordinator pattern.""" + # Create main coordinator first + main_config_entry = Mock() + main_config_entry.unique_id = "main_home_id" + main_config_entry.entry_id = "main_entry_id" + main_config_entry.data = {"access_token": "test_token"} + + mock_session = Mock() + with ( + patch( + "custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession", + return_value=mock_session, + ), + patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store), + ): + main_coordinator = TibberPricesDataUpdateCoordinator( + hass=mock_hass, + config_entry=main_config_entry, + ) + + # Verify main coordinator is marked as main entry + assert main_coordinator.is_main_entry() + + # Create subentry coordinator + sub_config_entry = Mock() + sub_config_entry.unique_id = "sub_home_id" + sub_config_entry.entry_id = "sub_entry_id" + sub_config_entry.data = {"access_token": "test_token", "home_id": "sub_home_id"} + + # Set up domain data to simulate main coordinator being already registered + mock_hass.data[DOMAIN] = {"main_entry_id": main_coordinator} + + with ( + patch( + "custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession", + return_value=mock_session, + ), + patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store), + ): + sub_coordinator = TibberPricesDataUpdateCoordinator( + hass=mock_hass, + config_entry=sub_config_entry, + ) + + # Verify subentry coordinator is not marked as main entry + assert not sub_coordinator.is_main_entry() + + @pytest.mark.asyncio + async def test_user_data_functionality(self, coordinator: TibberPricesDataUpdateCoordinator) -> None: + """Test user data related functionality.""" + # Mock user data API + mock_user_data = { + "homes": [ + {"id": "home1", "appNickname": "Home 1"}, + {"id": "home2", "appNickname": "Home 2"}, + ] + } + coordinator.api.async_get_viewer_details = AsyncMock(return_value=mock_user_data) + + # Test refresh user data + result = await coordinator.refresh_user_data() + assert result + + # Test get user profile + profile = coordinator.get_user_profile() + assert isinstance(profile, dict) + assert "last_updated" in profile + assert "cached_user_data" in profile + + # Test get user homes + homes = coordinator.get_user_homes() + assert isinstance(homes, list) + + @pytest.mark.asyncio + async def test_data_update_with_multi_home_response(self, coordinator: TibberPricesDataUpdateCoordinator) -> None: + """Test coordinator handling multi-home API response.""" + # Mock API responses + mock_price_response = { + "homes": { + "test_home_id_123": { + "priceInfo": { + "today": [{"startsAt": "2025-05-25T00:00:00Z", "total": 0.25}], + "tomorrow": [], + "yesterday": [], + } + }, + "other_home_id": { + "priceInfo": { + "today": [{"startsAt": "2025-05-25T00:00:00Z", "total": 0.30}], + "tomorrow": [], + "yesterday": [], + } + }, + } + } + + mock_hourly_rating = {"homes": {"test_home_id_123": {"hourly": []}}} + mock_daily_rating = {"homes": {"test_home_id_123": {"daily": []}}} + mock_monthly_rating = {"homes": {"test_home_id_123": {"monthly": []}}} + + # Mock all API methods + coordinator.api.async_get_price_info = AsyncMock(return_value=mock_price_response) + coordinator.api.async_get_hourly_price_rating = AsyncMock(return_value=mock_hourly_rating) + coordinator.api.async_get_daily_price_rating = AsyncMock(return_value=mock_daily_rating) + coordinator.api.async_get_monthly_price_rating = AsyncMock(return_value=mock_monthly_rating) + + # Update the coordinator to fetch data + await coordinator.async_refresh() + + # Verify coordinator has data + assert coordinator.data is not None + assert "priceInfo" in coordinator.data + assert "priceRating" in coordinator.data + + # Test public API methods work + intervals = coordinator.get_all_intervals() + assert isinstance(intervals, list) + + @pytest.mark.asyncio + async def test_error_handling_with_cache_fallback(self, coordinator: TibberPricesDataUpdateCoordinator) -> None: + """Test error handling with fallback to cached data.""" + from custom_components.tibber_prices.api import TibberPricesApiClientCommunicationError + + # Set up cached data using the store mechanism + test_cached_data = { + "timestamp": "2025-05-25T00:00:00Z", + "homes": { + "test_home_id_123": { + "price_info": {"today": [], "tomorrow": [], "yesterday": []}, + "hourly_rating": {}, + "daily_rating": {}, + "monthly_rating": {}, + } + }, + } + + # Mock store to return cached data + coordinator._store.async_load = AsyncMock( + return_value={ + "price_data": test_cached_data, + "user_data": None, + "last_price_update": "2025-05-25T00:00:00Z", + "last_user_update": None, + } + ) + + # Load the cache + await coordinator._load_cache() + + # Mock API to raise communication error + coordinator.api.async_get_price_info = AsyncMock( + side_effect=TibberPricesApiClientCommunicationError("Network error") + ) + + # Should not raise exception but use cached data + await coordinator.async_refresh() + + # Verify coordinator has fallback data + assert coordinator.data is not None + + @pytest.mark.asyncio + async def test_cache_persistence(self, coordinator: TibberPricesDataUpdateCoordinator) -> None: + """Test that data is properly cached and persisted.""" + # Mock API responses + mock_price_response = { + "homes": {"test_home_id_123": {"priceInfo": {"today": [], "tomorrow": [], "yesterday": []}}} + } + + coordinator.api.async_get_price_info = AsyncMock(return_value=mock_price_response) + coordinator.api.async_get_hourly_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}}) + coordinator.api.async_get_daily_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}}) + coordinator.api.async_get_monthly_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}}) + + # Update the coordinator + await coordinator.async_refresh() + + # Verify data was cached (store should have been called) + coordinator._store.async_save.assert_called() diff --git a/tests/test_hello.py b/tests/test_hello.py deleted file mode 100644 index b591bc7..0000000 --- a/tests/test_hello.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest -from unittest.mock import Mock, patch - - -class TestReauthentication(unittest.TestCase): - @patch("your_module.connection") # Replace 'your_module' with the actual module name - def test_reauthentication_flow(self, mock_connection): - mock_connection.reauthenticate = Mock(return_value=True) - result = mock_connection.reauthenticate() - self.assertTrue(result) - - @patch("your_module.connection") # Replace 'your_module' with the actual module name - def test_connection_timeout(self, mock_connection): - mock_connection.connect = Mock(side_effect=TimeoutError) - with self.assertRaises(TimeoutError): - mock_connection.connect() - - -if __name__ == "__main__": - unittest.main()