This commit is contained in:
Julian Pawlowski 2025-05-25 22:15:25 +00:00
parent bd33fc7367
commit f57fdfde6b
15 changed files with 1230 additions and 991 deletions

View file

@ -51,9 +51,7 @@ async def async_setup_entry(
coordinator = TibberPricesDataUpdateCoordinator( coordinator = TibberPricesDataUpdateCoordinator(
hass=hass, hass=hass,
entry=entry, config_entry=entry,
logger=LOGGER,
name=DOMAIN,
) )
entry.runtime_data = TibberPricesData( entry.runtime_data = TibberPricesData(
client=TibberPricesApiClient( client=TibberPricesApiClient(
@ -88,7 +86,7 @@ async def async_unload_entry(
# Unregister services if this was the last config entry # Unregister services if this was the last config entry
if not hass.config_entries.async_entries(DOMAIN): if not hass.config_entries.async_entries(DOMAIN):
for service in ["get_price", "get_apexcharts_data", "get_apexcharts_yaml"]: for service in ["get_price", "get_apexcharts_data", "get_apexcharts_yaml", "refresh_user_data"]:
if hass.services.has_service(DOMAIN, service): if hass.services.has_service(DOMAIN, service):
hass.services.async_remove(DOMAIN, service) hass.services.async_remove(DOMAIN, service)

View file

@ -10,7 +10,6 @@ from enum import Enum
from typing import Any from typing import Any
import aiohttp import aiohttp
import async_timeout
from homeassistant.const import __version__ as ha_version from homeassistant.const import __version__ as ha_version
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -19,9 +18,10 @@ from .const import VERSION
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS = 429 HTTP_BAD_REQUEST = 400
HTTP_UNAUTHORIZED = 401 HTTP_UNAUTHORIZED = 401
HTTP_FORBIDDEN = 403 HTTP_FORBIDDEN = 403
HTTP_TOO_MANY_REQUESTS = 429
class QueryType(Enum): class QueryType(Enum):
@ -31,7 +31,7 @@ class QueryType(Enum):
DAILY_RATING = "daily" DAILY_RATING = "daily"
HOURLY_RATING = "hourly" HOURLY_RATING = "hourly"
MONTHLY_RATING = "monthly" MONTHLY_RATING = "monthly"
VIEWER = "viewer" USER = "user"
class TibberPricesApiClientError(Exception): class TibberPricesApiClientError(Exception):
@ -42,7 +42,8 @@ class TibberPricesApiClientError(Exception):
GRAPHQL_ERROR = "GraphQL error: {message}" GRAPHQL_ERROR = "GraphQL error: {message}"
EMPTY_DATA_ERROR = "Empty data received for {query_type}" EMPTY_DATA_ERROR = "Empty data received for {query_type}"
GENERIC_ERROR = "Something went wrong! {exception}" GENERIC_ERROR = "Something went wrong! {exception}"
RATE_LIMIT_ERROR = "Rate limit exceeded" RATE_LIMIT_ERROR = "Rate limit exceeded. Please wait {retry_after} seconds before retrying"
INVALID_QUERY_ERROR = "Invalid GraphQL query: {message}"
class TibberPricesApiClientCommunicationError(TibberPricesApiClientError): class TibberPricesApiClientCommunicationError(TibberPricesApiClientError):
@ -55,15 +56,33 @@ class TibberPricesApiClientCommunicationError(TibberPricesApiClientError):
class TibberPricesApiClientAuthenticationError(TibberPricesApiClientError): class TibberPricesApiClientAuthenticationError(TibberPricesApiClientError):
"""Exception to indicate an authentication error.""" """Exception to indicate an authentication error."""
INVALID_CREDENTIALS = "Invalid credentials" INVALID_CREDENTIALS = "Invalid access token or expired credentials"
class TibberPricesApiClientPermissionError(TibberPricesApiClientError):
"""Exception to indicate insufficient permissions."""
INSUFFICIENT_PERMISSIONS = "Access forbidden - insufficient permissions for this operation"
def _verify_response_or_raise(response: aiohttp.ClientResponse) -> None: def _verify_response_or_raise(response: aiohttp.ClientResponse) -> None:
"""Verify that the response is valid.""" """Verify that the response is valid."""
if response.status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): if response.status == HTTP_UNAUTHORIZED:
_LOGGER.error("Tibber API authentication failed - check access token")
raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS) raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS)
if response.status == HTTP_FORBIDDEN:
_LOGGER.error("Tibber API access forbidden - insufficient permissions")
raise TibberPricesApiClientPermissionError(TibberPricesApiClientPermissionError.INSUFFICIENT_PERMISSIONS)
if response.status == HTTP_TOO_MANY_REQUESTS: if response.status == HTTP_TOO_MANY_REQUESTS:
raise TibberPricesApiClientError(TibberPricesApiClientError.RATE_LIMIT_ERROR) # Check for Retry-After header that Tibber might send
retry_after = response.headers.get("Retry-After", "unknown")
_LOGGER.warning("Tibber API rate limit exceeded - retry after %s seconds", retry_after)
raise TibberPricesApiClientError(TibberPricesApiClientError.RATE_LIMIT_ERROR.format(retry_after=retry_after))
if response.status == HTTP_BAD_REQUEST:
_LOGGER.error("Tibber API rejected request - likely invalid GraphQL query")
raise TibberPricesApiClientError(
TibberPricesApiClientError.INVALID_QUERY_ERROR.format(message="Bad request - likely invalid GraphQL query")
)
response.raise_for_status() response.raise_for_status()
@ -72,21 +91,41 @@ async def _verify_graphql_response(response_json: dict, query_type: QueryType) -
if "errors" in response_json: if "errors" in response_json:
errors = response_json["errors"] errors = response_json["errors"]
if not errors: if not errors:
_LOGGER.error("Tibber API returned empty errors array")
raise TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR) raise TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR)
error = errors[0] # Take first error error = errors[0] # Take first error
if not isinstance(error, dict): if not isinstance(error, dict):
_LOGGER.error("Tibber API returned malformed error: %s", error)
raise TibberPricesApiClientError(TibberPricesApiClientError.MALFORMED_ERROR.format(error=error)) raise TibberPricesApiClientError(TibberPricesApiClientError.MALFORMED_ERROR.format(error=error))
message = error.get("message", "Unknown error") message = error.get("message", "Unknown error")
extensions = error.get("extensions", {}) extensions = error.get("extensions", {})
error_code = extensions.get("code")
if extensions.get("code") == "UNAUTHENTICATED": # Handle specific Tibber API error codes
if error_code == "UNAUTHENTICATED":
_LOGGER.error("Tibber API authentication error: %s", message)
raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS) raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS)
if error_code == "FORBIDDEN":
_LOGGER.error("Tibber API permission error: %s", message)
raise TibberPricesApiClientPermissionError(TibberPricesApiClientPermissionError.INSUFFICIENT_PERMISSIONS)
if error_code in ["RATE_LIMITED", "TOO_MANY_REQUESTS"]:
# Some GraphQL APIs return rate limit info in extensions
retry_after = extensions.get("retryAfter", "unknown")
_LOGGER.warning("Tibber API rate limited via GraphQL: %s (retry after %s)", message, retry_after)
raise TibberPricesApiClientError(
TibberPricesApiClientError.RATE_LIMIT_ERROR.format(retry_after=retry_after)
)
if error_code in ["VALIDATION_ERROR", "GRAPHQL_VALIDATION_FAILED"]:
_LOGGER.error("Tibber API validation error: %s", message)
raise TibberPricesApiClientError(TibberPricesApiClientError.INVALID_QUERY_ERROR.format(message=message))
_LOGGER.error("Tibber API GraphQL error (code: %s): %s", error_code or "unknown", message)
raise TibberPricesApiClientError(TibberPricesApiClientError.GRAPHQL_ERROR.format(message=message)) raise TibberPricesApiClientError(TibberPricesApiClientError.GRAPHQL_ERROR.format(message=message))
if "data" not in response_json or response_json["data"] is None: if "data" not in response_json or response_json["data"] is None:
_LOGGER.error("Tibber API response missing data object")
raise TibberPricesApiClientError( raise TibberPricesApiClientError(
TibberPricesApiClientError.GRAPHQL_ERROR.format(message="Response missing data object") TibberPricesApiClientError.GRAPHQL_ERROR.format(message="Response missing data object")
) )
@ -122,7 +161,7 @@ def _is_data_empty(data: dict, query_type: str) -> bool:
is_empty = False is_empty = False
try: try:
if query_type == "viewer": if query_type == "user":
has_user_id = ( has_user_id = (
"viewer" in data "viewer" in data
and isinstance(data["viewer"], dict) and isinstance(data["viewer"], dict)
@ -321,11 +360,16 @@ class TibberPricesApiClient:
"""Tibber API Client.""" """Tibber API Client."""
self._access_token = access_token self._access_token = access_token
self._session = session self._session = session
self._request_semaphore = asyncio.Semaphore(2) self._request_semaphore = asyncio.Semaphore(2) # Max 2 concurrent requests
self._last_request_time = dt_util.now() self._last_request_time = dt_util.now()
self._min_request_interval = timedelta(seconds=1) self._min_request_interval = timedelta(seconds=1) # Min 1 second between requests
self._max_retries = 5 self._max_retries = 5
self._retry_delay = 2 self._retry_delay = 2 # Base retry delay in seconds
# Timeout configuration - more granular control
self._connect_timeout = 10 # Connection timeout in seconds
self._request_timeout = 25 # Total request timeout in seconds
self._socket_connect_timeout = 5 # Socket connection timeout
async def async_get_viewer_details(self) -> Any: async def async_get_viewer_details(self) -> Any:
"""Test connection to the API.""" """Test connection to the API."""
@ -352,11 +396,11 @@ class TibberPricesApiClient:
} }
""" """
}, },
query_type=QueryType.VIEWER, query_type=QueryType.USER,
) )
async def async_get_price_info(self, home_id: str) -> dict: async def async_get_price_info(self) -> dict:
"""Get price info data in flat format for the specified home_id.""" """Get price info data in flat format for all homes."""
data = await self._api_wrapper( data = await self._api_wrapper(
data={ data={
"query": """ "query": """
@ -371,15 +415,21 @@ class TibberPricesApiClient:
query_type=QueryType.PRICE_INFO, query_type=QueryType.PRICE_INFO,
) )
homes = data.get("viewer", {}).get("homes", []) homes = data.get("viewer", {}).get("homes", [])
home = next((h for h in homes if h.get("id") == home_id), None)
if home and "currentSubscription" in home: homes_data = {}
data["priceInfo"] = _flatten_price_info(home["currentSubscription"]) for home in homes:
else: home_id = home.get("id")
data["priceInfo"] = {} if home_id:
if "currentSubscription" in home:
homes_data[home_id] = _flatten_price_info(home["currentSubscription"])
else:
homes_data[home_id] = {}
data["homes"] = homes_data
return data return data
async def async_get_daily_price_rating(self, home_id: str) -> dict: async def async_get_daily_price_rating(self) -> dict:
"""Get daily price rating data in flat format for the specified home_id.""" """Get daily price rating data in flat format for all homes."""
data = await self._api_wrapper( data = await self._api_wrapper(
data={ data={
"query": """ "query": """
@ -394,15 +444,21 @@ class TibberPricesApiClient:
query_type=QueryType.DAILY_RATING, query_type=QueryType.DAILY_RATING,
) )
homes = data.get("viewer", {}).get("homes", []) homes = data.get("viewer", {}).get("homes", [])
home = next((h for h in homes if h.get("id") == home_id), None)
if home and "currentSubscription" in home: homes_data = {}
data["priceRating"] = _flatten_price_rating(home["currentSubscription"]) for home in homes:
else: home_id = home.get("id")
data["priceRating"] = {} if home_id:
if "currentSubscription" in home:
homes_data[home_id] = _flatten_price_rating(home["currentSubscription"])
else:
homes_data[home_id] = {}
data["homes"] = homes_data
return data return data
async def async_get_hourly_price_rating(self, home_id: str) -> dict: async def async_get_hourly_price_rating(self) -> dict:
"""Get hourly price rating data in flat format for the specified home_id.""" """Get hourly price rating data in flat format for all homes."""
data = await self._api_wrapper( data = await self._api_wrapper(
data={ data={
"query": """ "query": """
@ -417,15 +473,21 @@ class TibberPricesApiClient:
query_type=QueryType.HOURLY_RATING, query_type=QueryType.HOURLY_RATING,
) )
homes = data.get("viewer", {}).get("homes", []) homes = data.get("viewer", {}).get("homes", [])
home = next((h for h in homes if h.get("id") == home_id), None)
if home and "currentSubscription" in home: homes_data = {}
data["priceRating"] = _flatten_price_rating(home["currentSubscription"]) for home in homes:
else: home_id = home.get("id")
data["priceRating"] = {} if home_id:
if "currentSubscription" in home:
homes_data[home_id] = _flatten_price_rating(home["currentSubscription"])
else:
homes_data[home_id] = {}
data["homes"] = homes_data
return data return data
async def async_get_monthly_price_rating(self, home_id: str) -> dict: async def async_get_monthly_price_rating(self) -> dict:
"""Get monthly price rating data in flat format for the specified home_id.""" """Get monthly price rating data in flat format for all homes."""
data = await self._api_wrapper( data = await self._api_wrapper(
data={ data={
"query": """ "query": """
@ -440,40 +502,52 @@ class TibberPricesApiClient:
query_type=QueryType.MONTHLY_RATING, query_type=QueryType.MONTHLY_RATING,
) )
homes = data.get("viewer", {}).get("homes", []) homes = data.get("viewer", {}).get("homes", [])
home = next((h for h in homes if h.get("id") == home_id), None)
if home and "currentSubscription" in home: homes_data = {}
data["priceRating"] = _flatten_price_rating(home["currentSubscription"]) for home in homes:
else: home_id = home.get("id")
data["priceRating"] = {} if home_id:
if "currentSubscription" in home:
homes_data[home_id] = _flatten_price_rating(home["currentSubscription"])
else:
homes_data[home_id] = {}
data["homes"] = homes_data
return data return data
async def async_get_data(self, home_id: str) -> dict: async def async_get_data(self) -> dict:
"""Get all data from the API by combining multiple queries in flat format for the specified home_id.""" """Get all data from the API by combining multiple queries in flat format for all homes."""
price_info = await self.async_get_price_info(home_id) price_info = await self.async_get_price_info()
daily_rating = await self.async_get_daily_price_rating(home_id) daily_rating = await self.async_get_daily_price_rating()
hourly_rating = await self.async_get_hourly_price_rating(home_id) hourly_rating = await self.async_get_hourly_price_rating()
monthly_rating = await self.async_get_monthly_price_rating(home_id) monthly_rating = await self.async_get_monthly_price_rating()
price_rating = {
"thresholdPercentages": daily_rating["priceRating"].get("thresholdPercentages"),
"daily": daily_rating["priceRating"].get("daily", []),
"hourly": hourly_rating["priceRating"].get("hourly", []),
"monthly": monthly_rating["priceRating"].get("monthly", []),
"currency": (
daily_rating["priceRating"].get("currency")
or hourly_rating["priceRating"].get("currency")
or monthly_rating["priceRating"].get("currency")
),
}
return {
"priceInfo": price_info["priceInfo"],
"priceRating": price_rating,
}
async def async_set_title(self, value: str) -> Any: all_home_ids = set()
"""Get data from the API.""" all_home_ids.update(price_info.get("homes", {}).keys())
return await self._api_wrapper( all_home_ids.update(daily_rating.get("homes", {}).keys())
data={"title": value}, all_home_ids.update(hourly_rating.get("homes", {}).keys())
) all_home_ids.update(monthly_rating.get("homes", {}).keys())
homes_combined = {}
for home_id in all_home_ids:
daily_data = daily_rating.get("homes", {}).get(home_id, {})
hourly_data = hourly_rating.get("homes", {}).get(home_id, {})
monthly_data = monthly_rating.get("homes", {}).get(home_id, {})
price_rating = {
"thresholdPercentages": daily_data.get("thresholdPercentages"),
"daily": daily_data.get("daily", []),
"hourly": hourly_data.get("hourly", []),
"monthly": monthly_data.get("monthly", []),
"currency": (daily_data.get("currency") or hourly_data.get("currency") or monthly_data.get("currency")),
}
homes_combined[home_id] = {
"priceInfo": price_info.get("homes", {}).get(home_id, {}),
"priceRating": price_rating,
}
return {"homes": homes_combined}
async def _make_request( async def _make_request(
self, self,
@ -481,23 +555,117 @@ class TibberPricesApiClient:
data: dict, data: dict,
query_type: QueryType, query_type: QueryType,
) -> dict: ) -> dict:
"""Make an API request with proper error handling.""" """Make an API request with comprehensive error handling for network issues."""
_LOGGER.debug("Making API request with data: %s", data) _LOGGER.debug("Making API request with data: %s", data)
response = await self._session.request( try:
method="POST", # More granular timeout configuration for better network failure handling
url="https://api.tibber.com/v1-beta/gql", timeout = aiohttp.ClientTimeout(
headers=headers, total=self._request_timeout, # Total request timeout: 25s
json=data, connect=self._connect_timeout, # Connection timeout: 10s
) sock_connect=self._socket_connect_timeout, # Socket connection: 5s
)
_verify_response_or_raise(response) response = await self._session.request(
response_json = await response.json() method="POST",
_LOGGER.debug("Received API response: %s", response_json) url="https://api.tibber.com/v1-beta/gql",
headers=headers,
json=data,
timeout=timeout,
)
await _verify_graphql_response(response_json, query_type) _verify_response_or_raise(response)
response_json = await response.json()
_LOGGER.debug("Received API response: %s", response_json)
return response_json["data"] await _verify_graphql_response(response_json, query_type)
return response_json["data"]
except aiohttp.ClientResponseError as error:
_LOGGER.exception("HTTP error during API request")
raise TibberPricesApiClientCommunicationError(
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
) from error
except aiohttp.ClientConnectorError as error:
_LOGGER.exception("Connection error - server unreachable or network down")
raise TibberPricesApiClientCommunicationError(
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
) from error
except aiohttp.ServerDisconnectedError as error:
_LOGGER.exception("Server disconnected during request")
raise TibberPricesApiClientCommunicationError(
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
) from error
except TimeoutError as error:
_LOGGER.exception(
"Request timeout after %d seconds - slow network or server overload", self._request_timeout
)
raise TibberPricesApiClientCommunicationError(
TibberPricesApiClientCommunicationError.TIMEOUT_ERROR.format(exception=str(error))
) from error
except socket.gaierror as error:
self._handle_dns_error(error)
except OSError as error:
self._handle_network_error(error)
def _handle_dns_error(self, error: socket.gaierror) -> None:
"""Handle DNS resolution errors with IPv4/IPv6 dual stack considerations."""
error_msg = str(error)
if "Name or service not known" in error_msg:
_LOGGER.exception("DNS resolution failed - domain name not found")
elif "Temporary failure in name resolution" in error_msg:
_LOGGER.exception("DNS resolution temporarily failed - network or DNS server issue")
elif "Address family for hostname not supported" in error_msg:
_LOGGER.exception("DNS resolution failed - IPv4/IPv6 address family not supported")
elif "No address associated with hostname" in error_msg:
_LOGGER.exception("DNS resolution failed - no IPv4/IPv6 addresses found")
else:
_LOGGER.exception("DNS resolution failed - check internet connection: %s", error_msg)
raise TibberPricesApiClientCommunicationError(
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
) from error
def _handle_network_error(self, error: OSError) -> None:
"""Handle network-level errors with IPv4/IPv6 dual stack considerations."""
error_msg = str(error)
errno = getattr(error, "errno", None)
# Common IPv4/IPv6 dual stack network error codes
errno_network_unreachable = 101 # ENETUNREACH
errno_host_unreachable = 113 # EHOSTUNREACH
errno_connection_refused = 111 # ECONNREFUSED
errno_connection_timeout = 110 # ETIMEDOUT
if errno == errno_network_unreachable:
_LOGGER.exception("Network unreachable - check internet connection or IPv4/IPv6 routing")
elif errno == errno_host_unreachable:
_LOGGER.exception("Host unreachable - routing issue or IPv4/IPv6 connectivity problem")
elif errno == errno_connection_refused:
_LOGGER.exception("Connection refused - server not accepting connections")
elif errno == errno_connection_timeout:
_LOGGER.exception("Connection timed out - network latency or server overload")
elif "Address family not supported" in error_msg:
_LOGGER.exception("Address family not supported - IPv4/IPv6 configuration issue")
elif "Protocol not available" in error_msg:
_LOGGER.exception("Protocol not available - IPv4/IPv6 stack configuration issue")
elif "Network is down" in error_msg:
_LOGGER.exception("Network interface is down - check network adapter")
elif "Permission denied" in error_msg:
_LOGGER.exception("Network permission denied - firewall or security restriction")
else:
_LOGGER.exception("Network error - internet may be down: %s", error_msg)
raise TibberPricesApiClientCommunicationError(
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
) from error
async def _handle_request( async def _handle_request(
self, self,
@ -517,19 +685,90 @@ class TibberPricesApiClient:
) )
await asyncio.sleep(sleep_time) await asyncio.sleep(sleep_time)
async with async_timeout.timeout(10): self._last_request_time = dt_util.now()
self._last_request_time = dt_util.now() return await self._make_request(
return await self._make_request( headers,
headers, data or {},
data or {}, query_type,
query_type, )
)
def _should_retry_error(self, error: Exception, retry: int) -> tuple[bool, int]:
"""Determine if an error should be retried and calculate delay."""
# Check if we've exceeded max retries first
if retry >= self._max_retries:
return False, 0
# Non-retryable errors - authentication and permission issues
if isinstance(error, (TibberPricesApiClientAuthenticationError, TibberPricesApiClientPermissionError)):
return False, 0
# Handle API-specific errors
if isinstance(error, TibberPricesApiClientError):
return self._handle_api_error_retry(error, retry)
# Network and timeout errors - retryable with exponential backoff
if isinstance(error, (aiohttp.ClientError, socket.gaierror, TimeoutError)):
delay = min(self._retry_delay * (2**retry), 30) # Cap at 30 seconds
return True, delay
# Unknown errors - not retryable
return False, 0
def _handle_api_error_retry(self, error: TibberPricesApiClientError, retry: int) -> tuple[bool, int]:
"""Handle retry logic for API-specific errors."""
error_msg = str(error)
# Non-retryable: Invalid queries
if "Invalid GraphQL query" in error_msg or "Bad request" in error_msg:
return False, 0
# Rate limits - special handling with extracted delay
if "Rate limit exceeded" in error_msg or "rate limited" in error_msg.lower():
delay = self._extract_retry_delay(error, retry)
return True, delay
# Empty data - retryable with capped exponential backoff
if "Empty data received" in error_msg:
delay = min(self._retry_delay * (2**retry), 60) # Cap at 60 seconds
return True, delay
# Other API errors - retryable with capped exponential backoff
delay = min(self._retry_delay * (2**retry), 30) # Cap at 30 seconds
return True, delay
def _extract_retry_delay(self, error: Exception, retry: int) -> int:
"""Extract retry delay from rate limit error or use exponential backoff."""
import re
error_msg = str(error)
# Try to extract Retry-After value from error message
retry_after_match = re.search(r"retry after (\d+) seconds", error_msg.lower())
if retry_after_match:
try:
retry_after = int(retry_after_match.group(1))
return min(retry_after + 1, 300) # Add buffer, max 5 minutes
except ValueError:
pass
# Try to extract generic seconds value
seconds_match = re.search(r"(\d+) seconds", error_msg)
if seconds_match:
try:
seconds = int(seconds_match.group(1))
return min(seconds + 1, 300) # Add buffer, max 5 minutes
except ValueError:
pass
# Fall back to exponential backoff with cap
base_delay = self._retry_delay * (2**retry)
return min(base_delay, 120) # Cap at 2 minutes for rate limits
async def _api_wrapper( async def _api_wrapper(
self, self,
data: dict | None = None, data: dict | None = None,
headers: dict | None = None, headers: dict | None = None,
query_type: QueryType = QueryType.VIEWER, query_type: QueryType = QueryType.USER,
) -> Any: ) -> Any:
"""Get information from the API with rate limiting and retry logic.""" """Get information from the API with rate limiting and retry logic."""
headers = headers or _prepare_headers(self._access_token) headers = headers or _prepare_headers(self._access_token)
@ -537,32 +776,34 @@ class TibberPricesApiClient:
for retry in range(self._max_retries + 1): for retry in range(self._max_retries + 1):
try: try:
return await self._handle_request( return await self._handle_request(headers, data or {}, query_type)
headers,
data or {},
query_type,
)
except TibberPricesApiClientAuthenticationError: except (
TibberPricesApiClientAuthenticationError,
TibberPricesApiClientPermissionError,
):
_LOGGER.exception("Non-retryable error occurred")
raise raise
except ( except (
TibberPricesApiClientError,
aiohttp.ClientError, aiohttp.ClientError,
socket.gaierror, socket.gaierror,
TimeoutError, TimeoutError,
TibberPricesApiClientError,
) as error: ) as error:
last_error = ( last_error = (
error error
if isinstance(error, TibberPricesApiClientError) if isinstance(error, TibberPricesApiClientError)
else TibberPricesApiClientError( else TibberPricesApiClientCommunicationError(
TibberPricesApiClientError.GENERIC_ERROR.format(exception=str(error)) TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
) )
) )
if retry < self._max_retries: should_retry, delay = self._should_retry_error(error, retry)
delay = self._retry_delay * (2**retry) if should_retry:
error_type = self._get_error_type(error)
_LOGGER.warning( _LOGGER.warning(
"Request failed, attempt %d/%d. Retrying in %d seconds: %s", "Tibber %s error, attempt %d/%d. Retrying in %d seconds: %s",
error_type,
retry + 1, retry + 1,
self._max_retries, self._max_retries,
delay, delay,
@ -571,6 +812,10 @@ class TibberPricesApiClient:
await asyncio.sleep(delay) await asyncio.sleep(delay)
continue continue
if "Invalid GraphQL query" in str(error):
_LOGGER.exception("Invalid query - not retrying")
raise
# Handle final error state # Handle final error state
if isinstance(last_error, TimeoutError): if isinstance(last_error, TimeoutError):
raise TibberPricesApiClientCommunicationError( raise TibberPricesApiClientCommunicationError(
@ -582,3 +827,11 @@ class TibberPricesApiClient:
) from last_error ) from last_error
raise last_error or TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR) raise last_error or TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR)
def _get_error_type(self, error: Exception) -> str:
"""Get a descriptive error type for logging."""
if "Rate limit" in str(error):
return "rate limit"
if isinstance(error, (aiohttp.ClientError, socket.gaierror, TimeoutError)):
return "network"
return "API"

View file

@ -204,10 +204,15 @@ class TibberPricesSubentryFlowHandler(ConfigSubentryFlow):
async def async_step_user(self, user_input: dict[str, Any] | None = None) -> SubentryFlowResult: async def async_step_user(self, user_input: dict[str, Any] | None = None) -> SubentryFlowResult:
"""User flow to add a new home.""" """User flow to add a new home."""
parent_entry = self._get_entry() parent_entry = self._get_entry()
if not parent_entry: if not parent_entry or not hasattr(parent_entry, "runtime_data") or not parent_entry.runtime_data:
return self.async_abort(reason="no_parent_entry") return self.async_abort(reason="no_parent_entry")
homes = parent_entry.data.get("homes", []) coordinator = parent_entry.runtime_data.coordinator
# Force refresh user data to get latest homes from Tibber API
await coordinator.refresh_user_data()
homes = coordinator.get_user_homes()
if not homes: if not homes:
return self.async_abort(reason="no_available_homes") return self.async_abort(reason="no_available_homes")
@ -233,11 +238,11 @@ class TibberPricesSubentryFlowHandler(ConfigSubentryFlow):
) )
# Get existing home IDs by checking all subentries for this parent # Get existing home IDs by checking all subentries for this parent
existing_home_ids = set() existing_home_ids = {
for entry in self.hass.config_entries.async_entries(DOMAIN): entry.data["home_id"]
# Check if this entry has home_id data (indicating it's a subentry) for entry in self.hass.config_entries.async_entries(DOMAIN)
if entry.data.get("home_id") and entry != parent_entry: if entry.data.get("home_id") and entry != parent_entry
existing_home_ids.add(entry.data["home_id"]) }
available_homes = [home for home in homes if home["id"] not in existing_home_ids] available_homes = [home for home in homes if home["id"] not in existing_home_ids]

File diff suppressed because it is too large Load diff

View file

@ -27,16 +27,47 @@ class TibberPricesEntity(CoordinatorEntity[TibberPricesDataUpdateCoordinator]):
"COTTAGE": "Cottage", "COTTAGE": "Cottage",
} }
# Get home info from Tibber API if available # Get user profile information from coordinator
user_profile = self.coordinator.get_user_profile()
# Check if this is a main entry or subentry
is_subentry = bool(self.coordinator.config_entry.data.get("home_id"))
# Initialize variables
home_name = "Tibber Home" home_name = "Tibber Home"
home_id = self.coordinator.config_entry.unique_id home_id = self.coordinator.config_entry.unique_id
home_type = None home_type = None
city = None
app_nickname = None if is_subentry:
address1 = None # For subentries, show specific home information
if coordinator.data: home_data = self.coordinator.config_entry.data.get("home_data", {})
home_id = self.coordinator.config_entry.data.get("home_id")
# Get home details
address = home_data.get("address", {})
address1 = address.get("address1", "")
city = address.get("city", "")
app_nickname = home_data.get("appNickname", "")
home_type = home_data.get("type", "")
# Compose home name
home_name = app_nickname or address1 or f"Tibber Home {home_id}"
if city:
home_name = f"{home_name}, {city}"
# Add user information if available
if user_profile and user_profile.get("name"):
home_name = f"{home_name} ({user_profile['name']})"
elif user_profile:
# For main entry, show user profile information
user_name = user_profile.get("name", "Tibber User")
user_email = user_profile.get("email", "")
home_name = f"Tibber - {user_name}"
if user_email:
home_name = f"{home_name} ({user_email})"
elif coordinator.data:
# Fallback to original logic if user data not available yet
try: try:
home_id = self.unique_id
address1 = str(coordinator.data.get("address", {}).get("address1", "")) address1 = str(coordinator.data.get("address", {}).get("address1", ""))
city = str(coordinator.data.get("address", {}).get("city", "")) city = str(coordinator.data.get("address", {}).get("city", ""))
app_nickname = str(coordinator.data.get("appNickname", "")) app_nickname = str(coordinator.data.get("appNickname", ""))
@ -47,8 +78,6 @@ class TibberPricesEntity(CoordinatorEntity[TibberPricesDataUpdateCoordinator]):
home_name = f"{home_name}, {city}" home_name = f"{home_name}, {city}"
except (KeyError, IndexError, TypeError): except (KeyError, IndexError, TypeError):
home_name = "Tibber Home" home_name = "Tibber Home"
else:
home_name = "Tibber Home"
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE, entry_type=DeviceEntryType.SERVICE,

View file

@ -9,6 +9,7 @@
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"issue_tracker": "https://github.com/jpawlowski/hass.tibber_prices/issues", "issue_tracker": "https://github.com/jpawlowski/hass.tibber_prices/issues",
"version": "0.1.0", "version": "0.1.0",
"homeassistant": "2024.1.0",
"requirements": [ "requirements": [
"aiofiles>=23.2.1" "aiofiles>=23.2.1"
] ]

View file

@ -13,6 +13,11 @@ from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers.entity_registry import async_get as async_get_entity_registry from homeassistant.helpers.entity_registry import async_get as async_get_entity_registry
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .api import (
TibberPricesApiClientAuthenticationError,
TibberPricesApiClientCommunicationError,
TibberPricesApiClientError,
)
from .const import ( from .const import (
DOMAIN, DOMAIN,
PRICE_LEVEL_CHEAP, PRICE_LEVEL_CHEAP,
@ -29,6 +34,7 @@ from .const import (
PRICE_SERVICE_NAME = "get_price" PRICE_SERVICE_NAME = "get_price"
APEXCHARTS_DATA_SERVICE_NAME = "get_apexcharts_data" APEXCHARTS_DATA_SERVICE_NAME = "get_apexcharts_data"
APEXCHARTS_YAML_SERVICE_NAME = "get_apexcharts_yaml" APEXCHARTS_YAML_SERVICE_NAME = "get_apexcharts_yaml"
REFRESH_USER_DATA_SERVICE_NAME = "refresh_user_data"
ATTR_DAY: Final = "day" ATTR_DAY: Final = "day"
ATTR_ENTRY_ID: Final = "entry_id" ATTR_ENTRY_ID: Final = "entry_id"
ATTR_TIME: Final = "time" ATTR_TIME: Final = "time"
@ -68,6 +74,12 @@ APEXCHARTS_SERVICE_SCHEMA: Final = vol.Schema(
} }
) )
REFRESH_USER_DATA_SERVICE_SCHEMA: Final = vol.Schema(
{
vol.Required(ATTR_ENTRY_ID): str,
}
)
# region Top-level functions (ordered by call hierarchy) # region Top-level functions (ordered by call hierarchy)
# --- Entry point: Service handler --- # --- Entry point: Service handler ---
@ -165,41 +177,59 @@ async def _get_apexcharts_data(call: ServiceCall) -> dict[str, Any]:
level_type = call.data.get("level_type", "rating_level") level_type = call.data.get("level_type", "rating_level")
level_key = call.data.get("level_key") level_key = call.data.get("level_key")
hass = call.hass hass = call.hass
# Get entry ID and verify it exists
entry_id = await _get_entry_id_from_entity_id(hass, entity_id) entry_id = await _get_entry_id_from_entity_id(hass, entity_id)
if not entry_id: if not entry_id:
raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entity_id") raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entity_id")
entry, coordinator, data = _get_entry_and_data(hass, entry_id) entry, coordinator, data = _get_entry_and_data(hass, entry_id)
points = []
# Get entries based on level_type
entries = _get_apexcharts_entries(coordinator, day, level_type)
if not entries:
return {"points": []}
# Ensure level_key is a string
if level_key is None:
raise ServiceValidationError(translation_domain=DOMAIN, translation_key="missing_level_key")
# Generate points for the chart
points = _generate_apexcharts_points(entries, str(level_key))
return {"points": points}
def _get_apexcharts_entries(coordinator: Any, day: str, level_type: str) -> list[dict]:
"""Get the appropriate entries for ApexCharts based on level_type and day."""
if level_type == "rating_level": if level_type == "rating_level":
entries = coordinator.data.get("priceRating", {}).get("hourly", []) entries = coordinator.data.get("priceRating", {}).get("hourly", [])
price_info = coordinator.data.get("priceInfo", {}) price_info = coordinator.data.get("priceInfo", {})
if day == "today": day_info = price_info.get(day, [])
prefixes = _get_day_prefixes(price_info.get("today", [])) prefixes = _get_day_prefixes(day_info)
if not prefixes:
return {"points": []} if not prefixes:
entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])] return []
elif day == "tomorrow":
prefixes = _get_day_prefixes(price_info.get("tomorrow", [])) return [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])]
if not prefixes:
return {"points": []} # For non-rating level types, return the price info for the specified day
entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])] return coordinator.data.get("priceInfo", {}).get(day, [])
elif day == "yesterday":
prefixes = _get_day_prefixes(price_info.get("yesterday", []))
if not prefixes: def _generate_apexcharts_points(entries: list[dict], level_key: str) -> list:
return {"points": []} """Generate data points for ApexCharts based on the entries and level key."""
entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])] points = []
else:
entries = coordinator.data.get("priceInfo", {}).get(day, [])
if not entries:
return {"points": []}
for i in range(len(entries) - 1): for i in range(len(entries) - 1):
p = entries[i] p = entries[i]
if p.get("level") != level_key: if p.get("level") != level_key:
continue continue
points.append([p.get("time") or p.get("startsAt"), round((p.get("total") or 0) * 100, 2)]) points.append([p.get("time") or p.get("startsAt"), round((p.get("total") or 0) * 100, 2)])
# Add a final point with null value if there are any points
if points: if points:
points.append([points[-1][0], None]) points.append([points[-1][0], None])
return {"points": points}
return points
async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]: async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]:
@ -265,6 +295,57 @@ async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]:
} }
async def _refresh_user_data(call: ServiceCall) -> dict[str, Any]:
"""Refresh user data for a specific config entry and return updated information."""
entry_id = call.data.get(ATTR_ENTRY_ID)
hass = call.hass
if not entry_id:
return {
"success": False,
"message": "Entry ID is required",
}
# Get the entry and coordinator
try:
entry, coordinator, data = _get_entry_and_data(hass, entry_id)
except ServiceValidationError as ex:
return {
"success": False,
"message": f"Invalid entry ID: {ex}",
}
# Force refresh user data using the public method
try:
updated = await coordinator.refresh_user_data()
except (
TibberPricesApiClientAuthenticationError,
TibberPricesApiClientCommunicationError,
TibberPricesApiClientError,
) as ex:
return {
"success": False,
"message": f"API error refreshing user data: {ex!s}",
}
else:
if updated:
user_profile = coordinator.get_user_profile()
homes = coordinator.get_user_homes()
return {
"success": True,
"message": "User data refreshed successfully",
"user_profile": user_profile,
"homes_count": len(homes),
"homes": homes,
"last_updated": user_profile.get("last_updated"),
}
return {
"success": False,
"message": "User data was already up to date",
}
# --- Direct helpers (called by service handler or each other) --- # --- Direct helpers (called by service handler or each other) ---
@ -470,8 +551,6 @@ def _annotate_intervals_with_times(
elif day == "tomorrow": elif day == "tomorrow":
prev_end = _get_adjacent_start_time(price_info_by_day, "today", first=False) prev_end = _get_adjacent_start_time(price_info_by_day, "today", first=False)
interval["previous_end_time"] = prev_end interval["previous_end_time"] = prev_end
elif day == "yesterday":
interval["previous_end_time"] = None
else: else:
interval["previous_end_time"] = None interval["previous_end_time"] = None
@ -712,6 +791,13 @@ def async_setup_services(hass: HomeAssistant) -> None:
schema=APEXCHARTS_SERVICE_SCHEMA, schema=APEXCHARTS_SERVICE_SCHEMA,
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,
) )
hass.services.async_register(
DOMAIN,
REFRESH_USER_DATA_SERVICE_NAME,
_refresh_user_data,
schema=REFRESH_USER_DATA_SERVICE_SCHEMA,
supports_response=SupportsResponse.ONLY,
)
# endregion # endregion

View file

@ -111,3 +111,16 @@ get_apexcharts_yaml:
- yesterday - yesterday
- today - today
- tomorrow - tomorrow
refresh_user_data:
name: Refresh User Data
description: >-
Forces a refresh of the user data (homes, profile information) from the Tibber API. This can be useful after making changes to your Tibber account or when troubleshooting connectivity issues.
fields:
entry_id:
name: Entry ID
description: The config entry ID for the Tibber integration.
required: true
example: "1234567890abcdef"
selector:
config_entry:
integration: tibber_prices

View file

@ -133,5 +133,27 @@
"name": "Daten für morgen verfügbar" "name": "Daten für morgen verfügbar"
} }
} }
},
"issues": {
"new_homes_available": {
"title": "Neue Tibber-Häuser erkannt",
"description": "Wir haben {count} neue(s) Zuhause in deinem Tibber-Konto erkannt: {homes}. Du kannst diese über die Tibber-Integration in Home Assistant hinzufügen."
},
"homes_removed": {
"title": "Tibber-Häuser entfernt",
"description": "Wir haben erkannt, dass {count} Zuhause aus deinem Tibber-Konto entfernt wurde(n): {homes}. Bitte überprüfe deine Tibber-Integrationskonfiguration."
}
},
"services": {
"refresh_user_data": {
"name": "Benutzerdaten aktualisieren",
"description": "Erzwingt eine Aktualisierung der Benutzerdaten (Häuser, Profilinformationen) aus der Tibber API. Dies kann nützlich sein, nachdem Änderungen an deinem Tibber-Konto vorgenommen wurden oder bei der Fehlerbehebung von Verbindungsproblemen.",
"fields": {
"entry_id": {
"name": "Eintrag-ID",
"description": "Die Konfigurationseintrag-ID für die Tibber-Integration."
}
}
}
} }
} }

View file

@ -33,11 +33,11 @@
}, },
"config_subentries": { "config_subentries": {
"home": { "home": {
"title": "Home", "title": "Add Home",
"step": { "step": {
"user": { "user": {
"title": "Add Tibber Home", "title": "Add Tibber Home",
"description": "Select a home to add to your Tibber integration.", "description": "Select a home to add to your Tibber integration.\n\n**Note:** After adding this home, you can add additional homes from the integration's context menu by selecting \"Add Home\".",
"data": { "data": {
"home_id": "Home" "home_id": "Home"
} }
@ -51,7 +51,7 @@
"no_access_token": "No access token available", "no_access_token": "No access token available",
"home_not_found": "Selected home not found", "home_not_found": "Selected home not found",
"api_error": "Failed to fetch homes from Tibber API", "api_error": "Failed to fetch homes from Tibber API",
"no_available_homes": "No additional homes available to add" "no_available_homes": "No additional homes available to add. All homes from your Tibber account have already been added."
} }
} }
}, },
@ -148,5 +148,27 @@
"name": "Tomorrow's Data Available" "name": "Tomorrow's Data Available"
} }
} }
},
"issues": {
"new_homes_available": {
"title": "New Tibber homes detected",
"description": "We detected {count} new home(s) on your Tibber account: {homes}. You can add them to Home Assistant through the Tibber integration configuration."
},
"homes_removed": {
"title": "Tibber homes removed",
"description": "We detected that {count} home(s) have been removed from your Tibber account: {homes}. Please review your Tibber integration configuration."
}
},
"services": {
"refresh_user_data": {
"name": "Refresh User Data",
"description": "Forces a refresh of the user data (homes, profile information) from the Tibber API. This can be useful after making changes to your Tibber account or when troubleshooting connectivity issues.",
"fields": {
"entry_id": {
"name": "Entry ID",
"description": "The config entry ID for the Tibber integration."
}
}
}
} }
} }

View file

@ -1,6 +1,6 @@
{ {
"name": "Tibber Price Information & Ratings", "name": "Tibber Price Information & Ratings",
"homeassistant": "2025.4.2", "homeassistant": "2025.5.0",
"hacs": "2.0.1", "hacs": "2.0.1",
"render_readme": true "render_readme": true
} }

1
tests/__init__.py Normal file
View file

@ -0,0 +1 @@
# Tests package

View file

@ -0,0 +1,84 @@
"""Test basic coordinator functionality with the enhanced coordinator."""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from custom_components.tibber_prices.coordinator import TibberPricesDataUpdateCoordinator
class TestBasicCoordinator:
"""Test basic coordinator functionality."""
@pytest.fixture
def mock_hass(self):
"""Create a mock Home Assistant instance."""
hass = Mock()
hass.data = {}
return hass
@pytest.fixture
def mock_config_entry(self):
"""Create a mock config entry."""
config_entry = Mock()
config_entry.unique_id = "test_home_123"
config_entry.entry_id = "test_entry"
config_entry.data = {"access_token": "test_token"}
return config_entry
@pytest.fixture
def mock_session(self):
"""Create a mock session."""
return Mock()
@pytest.fixture
def coordinator(self, mock_hass, mock_config_entry, mock_session):
"""Create a coordinator instance."""
with patch(
"custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession",
return_value=mock_session,
):
with patch("custom_components.tibber_prices.coordinator.Store") as mock_store_class:
mock_store = Mock()
mock_store.async_load = AsyncMock(return_value=None)
mock_store.async_save = AsyncMock()
mock_store_class.return_value = mock_store
return TibberPricesDataUpdateCoordinator(mock_hass, mock_config_entry)
def test_coordinator_creation(self, coordinator):
"""Test that coordinator can be created."""
assert coordinator is not None
assert hasattr(coordinator, "get_current_interval_data")
assert hasattr(coordinator, "get_all_intervals")
assert hasattr(coordinator, "get_user_profile")
def test_is_main_entry(self, coordinator):
"""Test main entry detection."""
# First coordinator should be main entry
assert coordinator.is_main_entry() is True
def test_get_user_profile_no_data(self, coordinator):
"""Test getting user profile when no data is cached."""
profile = coordinator.get_user_profile()
assert profile == {"last_updated": None, "cached_user_data": False}
def test_get_user_homes_no_data(self, coordinator):
"""Test getting user homes when no data is cached."""
homes = coordinator.get_user_homes()
assert homes == []
def test_get_current_interval_data_no_data(self, coordinator):
"""Test getting current interval data when no data is available."""
current_data = coordinator.get_current_interval_data()
assert current_data is None
def test_get_all_intervals_no_data(self, coordinator):
"""Test getting all intervals when no data is available."""
intervals = coordinator.get_all_intervals()
assert intervals == []
def test_get_interval_granularity(self, coordinator):
"""Test getting interval granularity."""
granularity = coordinator.get_interval_granularity()
assert granularity is None

View file

@ -0,0 +1,255 @@
"""Test enhanced coordinator functionality."""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock, patch
import pytest
from custom_components.tibber_prices.const import DOMAIN
from custom_components.tibber_prices.coordinator import TibberPricesDataUpdateCoordinator
class TestEnhancedCoordinator:
"""Test enhanced coordinator functionality."""
@pytest.fixture
def mock_config_entry(self) -> Mock:
"""Create a mock config entry."""
config_entry = Mock()
config_entry.unique_id = "test_home_id_123"
config_entry.entry_id = "test_entry_id"
config_entry.data = {"access_token": "test_token"}
return config_entry
@pytest.fixture
def mock_hass(self) -> Mock:
"""Create a mock Home Assistant instance."""
import asyncio
hass = Mock()
hass.data = {}
# Mock the event loop for time tracking
hass.loop = asyncio.get_event_loop()
return hass
@pytest.fixture
def mock_store(self) -> Mock:
"""Create a mock store."""
store = Mock()
store.async_load = AsyncMock(return_value=None)
store.async_save = AsyncMock()
return store
@pytest.fixture
def mock_api(self) -> Mock:
"""Create a mock API client."""
api = Mock()
api.async_get_viewer_details = AsyncMock(return_value={"homes": []})
api.async_get_price_info = AsyncMock(return_value={"homes": {}})
api.async_get_hourly_price_rating = AsyncMock(return_value={"homes": {}})
api.async_get_daily_price_rating = AsyncMock(return_value={"homes": {}})
api.async_get_monthly_price_rating = AsyncMock(return_value={"homes": {}})
return api
@pytest.fixture
def coordinator(
self, mock_hass: Mock, mock_config_entry: Mock, mock_store: Mock, mock_api: Mock
) -> TibberPricesDataUpdateCoordinator:
"""Create a coordinator for testing."""
mock_session = Mock()
with (
patch(
"custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession",
return_value=mock_session,
),
patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store),
):
coordinator = TibberPricesDataUpdateCoordinator(
hass=mock_hass,
config_entry=mock_config_entry,
)
# Replace the API instance with our mock
coordinator.api = mock_api
return coordinator
@pytest.mark.asyncio
async def test_main_subentry_pattern(self, mock_hass: Mock, mock_store: Mock) -> None:
"""Test main/subentry coordinator pattern."""
# Create main coordinator first
main_config_entry = Mock()
main_config_entry.unique_id = "main_home_id"
main_config_entry.entry_id = "main_entry_id"
main_config_entry.data = {"access_token": "test_token"}
mock_session = Mock()
with (
patch(
"custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession",
return_value=mock_session,
),
patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store),
):
main_coordinator = TibberPricesDataUpdateCoordinator(
hass=mock_hass,
config_entry=main_config_entry,
)
# Verify main coordinator is marked as main entry
assert main_coordinator.is_main_entry()
# Create subentry coordinator
sub_config_entry = Mock()
sub_config_entry.unique_id = "sub_home_id"
sub_config_entry.entry_id = "sub_entry_id"
sub_config_entry.data = {"access_token": "test_token", "home_id": "sub_home_id"}
# Set up domain data to simulate main coordinator being already registered
mock_hass.data[DOMAIN] = {"main_entry_id": main_coordinator}
with (
patch(
"custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession",
return_value=mock_session,
),
patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store),
):
sub_coordinator = TibberPricesDataUpdateCoordinator(
hass=mock_hass,
config_entry=sub_config_entry,
)
# Verify subentry coordinator is not marked as main entry
assert not sub_coordinator.is_main_entry()
@pytest.mark.asyncio
async def test_user_data_functionality(self, coordinator: TibberPricesDataUpdateCoordinator) -> None:
"""Test user data related functionality."""
# Mock user data API
mock_user_data = {
"homes": [
{"id": "home1", "appNickname": "Home 1"},
{"id": "home2", "appNickname": "Home 2"},
]
}
coordinator.api.async_get_viewer_details = AsyncMock(return_value=mock_user_data)
# Test refresh user data
result = await coordinator.refresh_user_data()
assert result
# Test get user profile
profile = coordinator.get_user_profile()
assert isinstance(profile, dict)
assert "last_updated" in profile
assert "cached_user_data" in profile
# Test get user homes
homes = coordinator.get_user_homes()
assert isinstance(homes, list)
@pytest.mark.asyncio
async def test_data_update_with_multi_home_response(self, coordinator: TibberPricesDataUpdateCoordinator) -> None:
"""Test coordinator handling multi-home API response."""
# Mock API responses
mock_price_response = {
"homes": {
"test_home_id_123": {
"priceInfo": {
"today": [{"startsAt": "2025-05-25T00:00:00Z", "total": 0.25}],
"tomorrow": [],
"yesterday": [],
}
},
"other_home_id": {
"priceInfo": {
"today": [{"startsAt": "2025-05-25T00:00:00Z", "total": 0.30}],
"tomorrow": [],
"yesterday": [],
}
},
}
}
mock_hourly_rating = {"homes": {"test_home_id_123": {"hourly": []}}}
mock_daily_rating = {"homes": {"test_home_id_123": {"daily": []}}}
mock_monthly_rating = {"homes": {"test_home_id_123": {"monthly": []}}}
# Mock all API methods
coordinator.api.async_get_price_info = AsyncMock(return_value=mock_price_response)
coordinator.api.async_get_hourly_price_rating = AsyncMock(return_value=mock_hourly_rating)
coordinator.api.async_get_daily_price_rating = AsyncMock(return_value=mock_daily_rating)
coordinator.api.async_get_monthly_price_rating = AsyncMock(return_value=mock_monthly_rating)
# Update the coordinator to fetch data
await coordinator.async_refresh()
# Verify coordinator has data
assert coordinator.data is not None
assert "priceInfo" in coordinator.data
assert "priceRating" in coordinator.data
# Test public API methods work
intervals = coordinator.get_all_intervals()
assert isinstance(intervals, list)
@pytest.mark.asyncio
async def test_error_handling_with_cache_fallback(self, coordinator: TibberPricesDataUpdateCoordinator) -> None:
"""Test error handling with fallback to cached data."""
from custom_components.tibber_prices.api import TibberPricesApiClientCommunicationError
# Set up cached data using the store mechanism
test_cached_data = {
"timestamp": "2025-05-25T00:00:00Z",
"homes": {
"test_home_id_123": {
"price_info": {"today": [], "tomorrow": [], "yesterday": []},
"hourly_rating": {},
"daily_rating": {},
"monthly_rating": {},
}
},
}
# Mock store to return cached data
coordinator._store.async_load = AsyncMock(
return_value={
"price_data": test_cached_data,
"user_data": None,
"last_price_update": "2025-05-25T00:00:00Z",
"last_user_update": None,
}
)
# Load the cache
await coordinator._load_cache()
# Mock API to raise communication error
coordinator.api.async_get_price_info = AsyncMock(
side_effect=TibberPricesApiClientCommunicationError("Network error")
)
# Should not raise exception but use cached data
await coordinator.async_refresh()
# Verify coordinator has fallback data
assert coordinator.data is not None
@pytest.mark.asyncio
async def test_cache_persistence(self, coordinator: TibberPricesDataUpdateCoordinator) -> None:
"""Test that data is properly cached and persisted."""
# Mock API responses
mock_price_response = {
"homes": {"test_home_id_123": {"priceInfo": {"today": [], "tomorrow": [], "yesterday": []}}}
}
coordinator.api.async_get_price_info = AsyncMock(return_value=mock_price_response)
coordinator.api.async_get_hourly_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}})
coordinator.api.async_get_daily_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}})
coordinator.api.async_get_monthly_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}})
# Update the coordinator
await coordinator.async_refresh()
# Verify data was cached (store should have been called)
coordinator._store.async_save.assert_called()

View file

@ -1,20 +0,0 @@
import unittest
from unittest.mock import Mock, patch
class TestReauthentication(unittest.TestCase):
@patch("your_module.connection") # Replace 'your_module' with the actual module name
def test_reauthentication_flow(self, mock_connection):
mock_connection.reauthenticate = Mock(return_value=True)
result = mock_connection.reauthenticate()
self.assertTrue(result)
@patch("your_module.connection") # Replace 'your_module' with the actual module name
def test_connection_timeout(self, mock_connection):
mock_connection.connect = Mock(side_effect=TimeoutError)
with self.assertRaises(TimeoutError):
mock_connection.connect()
if __name__ == "__main__":
unittest.main()