"""Tibber API Client.""" from __future__ import annotations import asyncio import base64 import logging import re import socket from datetime import timedelta from typing import TYPE_CHECKING, Any from zoneinfo import ZoneInfo import aiohttp from homeassistant.util import dt as dt_utils from .exceptions import ( TibberPricesApiClientAuthenticationError, TibberPricesApiClientCommunicationError, TibberPricesApiClientError, TibberPricesApiClientPermissionError, ) from .helpers import ( flatten_price_info, prepare_headers, verify_graphql_response, verify_response_or_raise, ) from .queries import TibberPricesQueryType if TYPE_CHECKING: from custom_components.tibber_prices.coordinator.time_service import TibberPricesTimeService _LOGGER = logging.getLogger(__name__) class TibberPricesApiClient: """Tibber API Client.""" def __init__( self, access_token: str, session: aiohttp.ClientSession, version: str, ) -> None: """Tibber API Client.""" self._access_token = access_token self._session = session self._version = version self._request_semaphore = asyncio.Semaphore(2) # Max 2 concurrent requests self.time: TibberPricesTimeService | None = None # Set externally by coordinator (optional during config flow) self._last_request_time = None # Set on first request self._min_request_interval = timedelta(seconds=1) # Min 1 second between requests self._max_retries = 5 self._retry_delay = 2 # Base retry delay in seconds # Timeout configuration - more granular control self._connect_timeout = 10 # Connection timeout in seconds self._request_timeout = 25 # Total request timeout in seconds self._socket_connect_timeout = 5 # Socket connection timeout async def async_get_viewer_details(self) -> Any: """Get comprehensive viewer and home details from Tibber API.""" return await self._api_wrapper( data={ "query": """ { viewer { userId name login accountType homes { id type appNickname appAvatar size timeZone mainFuseSize numberOfResidents primaryHeatingSource hasVentilationSystem address { address1 address2 address3 postalCode city country latitude longitude } owner { id firstName lastName isCompany name contactInfo { email mobile } language } meteringPointData { consumptionEan gridCompany gridAreaCode priceAreaCode productionEan energyTaxType vatType estimatedAnnualConsumption } currentSubscription { id status validFrom validTo priceInfo { current { currency } } } features { realTimeConsumptionEnabled } } } } """ }, query_type=TibberPricesQueryType.USER, ) async def async_get_price_info(self, home_ids: set[str], user_data: dict[str, Any]) -> dict: """ Get price info for specific homes using GraphQL aliases. Uses timezone-aware cursor calculation per home based on the home's actual timezone from Tibber API (not HA system timezone). This ensures correct "day before yesterday midnight" calculation for homes in different timezones. Args: home_ids: Set of home IDs to fetch price data for. user_data: User data dict containing home metadata (including timezone). REQUIRED - must be fetched before calling this method. Returns: Dict with "homes" key containing home_id -> price_data mapping. Raises: TibberPricesApiClientError: If TimeService not initialized or user_data missing. """ if not self.time: msg = "TimeService not initialized - required for price info processing" raise TibberPricesApiClientError(msg) if not user_data: msg = "User data required for timezone-aware price fetching - fetch user data first" raise TibberPricesApiClientError(msg) if not home_ids: return {"homes": {}} # Build home_id -> timezone mapping from user_data home_timezones = self._extract_home_timezones(user_data) # Build query with aliases for each home # Each home gets its own cursor based on its timezone home_queries = [] for idx, home_id in enumerate(sorted(home_ids)): alias = f"home{idx}" # Get timezone for this home (fallback to HA system timezone) home_tz = home_timezones.get(home_id) # Calculate cursor: day before yesterday midnight in home's timezone cursor = self._calculate_cursor_for_home(home_tz) home_query = f""" {alias}: home(id: "{home_id}") {{ id currentSubscription {{ priceInfoRange(resolution:QUARTER_HOURLY, first:192, after: "{cursor}") {{ pageInfo{{ count }} edges{{node{{ startsAt total level }}}} }} priceInfo(resolution:QUARTER_HOURLY) {{ today{{startsAt total level}} tomorrow{{startsAt total level}} }} }} }} """ home_queries.append(home_query) query = "{viewer{" + "".join(home_queries) + "}}" _LOGGER.debug("Fetching price info for %d specific home(s)", len(home_ids)) data = await self._api_wrapper( data={"query": query}, query_type=TibberPricesQueryType.PRICE_INFO, ) # Parse aliased response viewer = data.get("viewer", {}) homes_data = {} for idx, home_id in enumerate(sorted(home_ids)): alias = f"home{idx}" home = viewer.get(alias) if not home: _LOGGER.debug("Home %s not found in API response", home_id) homes_data[home_id] = {} continue if "currentSubscription" in home and home["currentSubscription"] is not None: homes_data[home_id] = flatten_price_info(home["currentSubscription"]) else: _LOGGER.debug( "Home %s has no active subscription - price data will be unavailable", home_id, ) homes_data[home_id] = {} data["homes"] = homes_data return data def _extract_home_timezones(self, user_data: dict[str, Any]) -> dict[str, str]: """ Extract home_id -> timezone mapping from user_data. Args: user_data: User data dict from async_get_viewer_details() (required). Returns: Dict mapping home_id to timezone string (e.g., "Europe/Oslo"). """ home_timezones = {} viewer = user_data.get("viewer", {}) homes = viewer.get("homes", []) for home in homes: home_id = home.get("id") timezone = home.get("timeZone") if home_id and timezone: home_timezones[home_id] = timezone _LOGGER.debug("Extracted timezone %s for home %s", timezone, home_id) elif home_id: _LOGGER.warning("Home %s has no timezone in user data, will use fallback", home_id) return home_timezones def _calculate_cursor_for_home(self, home_timezone: str | None) -> str: """ Calculate cursor (day before yesterday midnight) for a home's timezone. Args: home_timezone: Timezone string (e.g., "Europe/Oslo", "America/New_York"). If None, falls back to HA system timezone. Returns: Base64-encoded ISO timestamp string for use as GraphQL cursor. """ if not self.time: msg = "TimeService not initialized" raise TibberPricesApiClientError(msg) # Get current time now = self.time.now() # Convert to home's timezone or fallback to HA system timezone if home_timezone: try: tz = ZoneInfo(home_timezone) now_in_home_tz = now.astimezone(tz) except (KeyError, ValueError, OSError) as error: _LOGGER.warning( "Invalid timezone %s (%s), falling back to HA system timezone", home_timezone, error, ) now_in_home_tz = dt_utils.as_local(now) else: # Fallback to HA system timezone now_in_home_tz = dt_utils.as_local(now) # Calculate day before yesterday midnight in home's timezone day_before_yesterday_midnight = (now_in_home_tz - timedelta(days=2)).replace( hour=0, minute=0, second=0, microsecond=0 ) # Convert to ISO format and base64 encode iso_string = day_before_yesterday_midnight.isoformat() return base64.b64encode(iso_string.encode()).decode() async def _make_request( self, headers: dict[str, str], data: dict, query_type: TibberPricesQueryType, ) -> dict[str, Any]: """Make an API request with comprehensive error handling for network issues.""" _LOGGER.debug("Making API request with data: %s", data) try: # More granular timeout configuration for better network failure handling timeout = aiohttp.ClientTimeout( total=self._request_timeout, # Total request timeout: 25s connect=self._connect_timeout, # Connection timeout: 10s sock_connect=self._socket_connect_timeout, # Socket connection: 5s ) response = await self._session.request( method="POST", url="https://api.tibber.com/v1-beta/gql", headers=headers, json=data, timeout=timeout, ) verify_response_or_raise(response) response_json = await response.json() _LOGGER.debug("Received API response: %s", response_json) await verify_graphql_response(response_json, query_type) return response_json["data"] except aiohttp.ClientResponseError as error: _LOGGER.exception("HTTP error during API request") raise TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) ) from error except aiohttp.ClientConnectorError as error: _LOGGER.exception("Connection error - server unreachable or network down") raise TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) ) from error except aiohttp.ServerDisconnectedError as error: _LOGGER.exception("Server disconnected during request") raise TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) ) from error except TimeoutError as error: _LOGGER.exception( "Request timeout after %d seconds - slow network or server overload", self._request_timeout, ) raise TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.TIMEOUT_ERROR.format(exception=str(error)) ) from error except socket.gaierror as error: self._handle_dns_error(error) raise # Ensure type checker knows this path always raises except OSError as error: self._handle_network_error(error) raise # Ensure type checker knows this path always raises def _handle_dns_error(self, error: socket.gaierror) -> None: """Handle DNS resolution errors with IPv4/IPv6 dual stack considerations.""" error_msg = str(error) if "Name or service not known" in error_msg: _LOGGER.exception("DNS resolution failed - domain name not found") elif "Temporary failure in name resolution" in error_msg: _LOGGER.exception("DNS resolution temporarily failed - network or DNS server issue") elif "Address family for hostname not supported" in error_msg: _LOGGER.exception("DNS resolution failed - IPv4/IPv6 address family not supported") elif "No address associated with hostname" in error_msg: _LOGGER.exception("DNS resolution failed - no IPv4/IPv6 addresses found") else: _LOGGER.exception("DNS resolution failed - check internet connection: %s", error_msg) raise TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) ) from error def _handle_network_error(self, error: OSError) -> None: """Handle network-level errors with IPv4/IPv6 dual stack considerations.""" error_msg = str(error) errno = getattr(error, "errno", None) # Common IPv4/IPv6 dual stack network error codes errno_network_unreachable = 101 # ENETUNREACH errno_host_unreachable = 113 # EHOSTUNREACH errno_connection_refused = 111 # ECONNREFUSED errno_connection_timeout = 110 # ETIMEDOUT if errno == errno_network_unreachable: _LOGGER.exception("Network unreachable - check internet connection or IPv4/IPv6 routing") elif errno == errno_host_unreachable: _LOGGER.exception("Host unreachable - routing issue or IPv4/IPv6 connectivity problem") elif errno == errno_connection_refused: _LOGGER.exception("Connection refused - server not accepting connections") elif errno == errno_connection_timeout: _LOGGER.exception("Connection timed out - network latency or server overload") elif "Address family not supported" in error_msg: _LOGGER.exception("Address family not supported - IPv4/IPv6 configuration issue") elif "Protocol not available" in error_msg: _LOGGER.exception("Protocol not available - IPv4/IPv6 stack configuration issue") elif "Network is down" in error_msg: _LOGGER.exception("Network interface is down - check network adapter") elif "Permission denied" in error_msg: _LOGGER.exception("Network permission denied - firewall or security restriction") else: _LOGGER.exception("Network error - internet may be down: %s", error_msg) raise TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) ) from error async def _handle_request( self, headers: dict[str, str], data: dict, query_type: TibberPricesQueryType, ) -> Any: """Handle a single API request with rate limiting.""" async with self._request_semaphore: # Rate limiting: ensure minimum interval between requests if self.time and self._last_request_time: now = self.time.now() time_since_last_request = now - self._last_request_time if time_since_last_request < self._min_request_interval: sleep_time = (self._min_request_interval - time_since_last_request).total_seconds() _LOGGER.debug( "Rate limiting: waiting %s seconds before next request", sleep_time, ) await asyncio.sleep(sleep_time) if self.time: self._last_request_time = self.time.now() return await self._make_request( headers, data or {}, query_type, ) def _should_retry_error(self, error: Exception, retry: int) -> tuple[bool, int]: """Determine if an error should be retried and calculate delay.""" # Check if we've exceeded max retries first if retry >= self._max_retries: return False, 0 # Non-retryable errors - authentication and permission issues if isinstance( error, ( TibberPricesApiClientAuthenticationError, TibberPricesApiClientPermissionError, ), ): return False, 0 # Handle API-specific errors if isinstance(error, TibberPricesApiClientError): return self._handle_api_error_retry(error, retry) # Network and timeout errors - retryable with exponential backoff if isinstance(error, (aiohttp.ClientError, socket.gaierror, TimeoutError)): delay = min(self._retry_delay * (2**retry), 30) # Cap at 30 seconds return True, delay # Unknown errors - not retryable return False, 0 def _handle_api_error_retry(self, error: TibberPricesApiClientError, retry: int) -> tuple[bool, int]: """Handle retry logic for API-specific errors.""" error_msg = str(error) # Non-retryable: Invalid queries if "Invalid GraphQL query" in error_msg or "Bad request" in error_msg: return False, 0 # Rate limits - special handling with extracted delay if "Rate limit exceeded" in error_msg or "rate limited" in error_msg.lower(): delay = self._extract_retry_delay(error, retry) return True, delay # Empty data - retryable with capped exponential backoff if "Empty data received" in error_msg: delay = min(self._retry_delay * (2**retry), 60) # Cap at 60 seconds return True, delay # Other API errors - retryable with capped exponential backoff delay = min(self._retry_delay * (2**retry), 30) # Cap at 30 seconds return True, delay def _extract_retry_delay(self, error: Exception, retry: int) -> int: """Extract retry delay from rate limit error or use exponential backoff.""" error_msg = str(error) # Try to extract Retry-After value from error message retry_after_match = re.search(r"retry after (\d+) seconds", error_msg.lower()) if retry_after_match: try: retry_after = int(retry_after_match.group(1)) return min(retry_after + 1, 300) # Add buffer, max 5 minutes except ValueError: pass # Try to extract generic seconds value seconds_match = re.search(r"(\d+) seconds", error_msg) if seconds_match: try: seconds = int(seconds_match.group(1)) return min(seconds + 1, 300) # Add buffer, max 5 minutes except ValueError: pass # Fall back to exponential backoff with cap base_delay = self._retry_delay * (2**retry) return min(base_delay, 120) # Cap at 2 minutes for rate limits async def _api_wrapper( self, data: dict | None = None, headers: dict | None = None, query_type: TibberPricesQueryType = TibberPricesQueryType.USER, ) -> Any: """Get information from the API with rate limiting and retry logic.""" headers = headers or prepare_headers(self._access_token, self._version) last_error: Exception | None = None for retry in range(self._max_retries + 1): try: return await self._handle_request(headers, data or {}, query_type) except ( TibberPricesApiClientAuthenticationError, TibberPricesApiClientPermissionError, ): _LOGGER.exception("Non-retryable error occurred") raise except ( TibberPricesApiClientError, aiohttp.ClientError, socket.gaierror, TimeoutError, ) as error: last_error = ( error if isinstance(error, TibberPricesApiClientError) else TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error)) ) ) should_retry, delay = self._should_retry_error(error, retry) if should_retry: error_type = self._get_error_type(error) _LOGGER.warning( "Tibber %s error, attempt %d/%d. Retrying in %d seconds: %s", error_type, retry + 1, self._max_retries, delay, str(error), ) await asyncio.sleep(delay) continue if "Invalid GraphQL query" in str(error): _LOGGER.exception("Invalid query - not retrying") raise # Handle final error state if isinstance(last_error, TimeoutError): raise TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.TIMEOUT_ERROR.format(exception=last_error) ) from last_error if isinstance(last_error, (aiohttp.ClientError, socket.gaierror)): raise TibberPricesApiClientCommunicationError( TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=last_error) ) from last_error raise last_error or TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR) def _get_error_type(self, error: Exception) -> str: """Get a descriptive error type for logging.""" if "Rate limit" in str(error): return "rate limit" if isinstance(error, (aiohttp.ClientError, socket.gaierror, TimeoutError)): return "network" return "API"