mirror of
https://github.com/jpawlowski/hass.tibber_prices.git
synced 2026-03-29 21:03:40 +00:00
update
This commit is contained in:
parent
bd33fc7367
commit
f57fdfde6b
15 changed files with 1230 additions and 991 deletions
|
|
@ -51,9 +51,7 @@ async def async_setup_entry(
|
||||||
|
|
||||||
coordinator = TibberPricesDataUpdateCoordinator(
|
coordinator = TibberPricesDataUpdateCoordinator(
|
||||||
hass=hass,
|
hass=hass,
|
||||||
entry=entry,
|
config_entry=entry,
|
||||||
logger=LOGGER,
|
|
||||||
name=DOMAIN,
|
|
||||||
)
|
)
|
||||||
entry.runtime_data = TibberPricesData(
|
entry.runtime_data = TibberPricesData(
|
||||||
client=TibberPricesApiClient(
|
client=TibberPricesApiClient(
|
||||||
|
|
@ -88,7 +86,7 @@ async def async_unload_entry(
|
||||||
|
|
||||||
# Unregister services if this was the last config entry
|
# Unregister services if this was the last config entry
|
||||||
if not hass.config_entries.async_entries(DOMAIN):
|
if not hass.config_entries.async_entries(DOMAIN):
|
||||||
for service in ["get_price", "get_apexcharts_data", "get_apexcharts_yaml"]:
|
for service in ["get_price", "get_apexcharts_data", "get_apexcharts_yaml", "refresh_user_data"]:
|
||||||
if hass.services.has_service(DOMAIN, service):
|
if hass.services.has_service(DOMAIN, service):
|
||||||
hass.services.async_remove(DOMAIN, service)
|
hass.services.async_remove(DOMAIN, service)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import async_timeout
|
|
||||||
|
|
||||||
from homeassistant.const import __version__ as ha_version
|
from homeassistant.const import __version__ as ha_version
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
@ -19,9 +18,10 @@ from .const import VERSION
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
HTTP_TOO_MANY_REQUESTS = 429
|
HTTP_BAD_REQUEST = 400
|
||||||
HTTP_UNAUTHORIZED = 401
|
HTTP_UNAUTHORIZED = 401
|
||||||
HTTP_FORBIDDEN = 403
|
HTTP_FORBIDDEN = 403
|
||||||
|
HTTP_TOO_MANY_REQUESTS = 429
|
||||||
|
|
||||||
|
|
||||||
class QueryType(Enum):
|
class QueryType(Enum):
|
||||||
|
|
@ -31,7 +31,7 @@ class QueryType(Enum):
|
||||||
DAILY_RATING = "daily"
|
DAILY_RATING = "daily"
|
||||||
HOURLY_RATING = "hourly"
|
HOURLY_RATING = "hourly"
|
||||||
MONTHLY_RATING = "monthly"
|
MONTHLY_RATING = "monthly"
|
||||||
VIEWER = "viewer"
|
USER = "user"
|
||||||
|
|
||||||
|
|
||||||
class TibberPricesApiClientError(Exception):
|
class TibberPricesApiClientError(Exception):
|
||||||
|
|
@ -42,7 +42,8 @@ class TibberPricesApiClientError(Exception):
|
||||||
GRAPHQL_ERROR = "GraphQL error: {message}"
|
GRAPHQL_ERROR = "GraphQL error: {message}"
|
||||||
EMPTY_DATA_ERROR = "Empty data received for {query_type}"
|
EMPTY_DATA_ERROR = "Empty data received for {query_type}"
|
||||||
GENERIC_ERROR = "Something went wrong! {exception}"
|
GENERIC_ERROR = "Something went wrong! {exception}"
|
||||||
RATE_LIMIT_ERROR = "Rate limit exceeded"
|
RATE_LIMIT_ERROR = "Rate limit exceeded. Please wait {retry_after} seconds before retrying"
|
||||||
|
INVALID_QUERY_ERROR = "Invalid GraphQL query: {message}"
|
||||||
|
|
||||||
|
|
||||||
class TibberPricesApiClientCommunicationError(TibberPricesApiClientError):
|
class TibberPricesApiClientCommunicationError(TibberPricesApiClientError):
|
||||||
|
|
@ -55,15 +56,33 @@ class TibberPricesApiClientCommunicationError(TibberPricesApiClientError):
|
||||||
class TibberPricesApiClientAuthenticationError(TibberPricesApiClientError):
|
class TibberPricesApiClientAuthenticationError(TibberPricesApiClientError):
|
||||||
"""Exception to indicate an authentication error."""
|
"""Exception to indicate an authentication error."""
|
||||||
|
|
||||||
INVALID_CREDENTIALS = "Invalid credentials"
|
INVALID_CREDENTIALS = "Invalid access token or expired credentials"
|
||||||
|
|
||||||
|
|
||||||
|
class TibberPricesApiClientPermissionError(TibberPricesApiClientError):
|
||||||
|
"""Exception to indicate insufficient permissions."""
|
||||||
|
|
||||||
|
INSUFFICIENT_PERMISSIONS = "Access forbidden - insufficient permissions for this operation"
|
||||||
|
|
||||||
|
|
||||||
def _verify_response_or_raise(response: aiohttp.ClientResponse) -> None:
|
def _verify_response_or_raise(response: aiohttp.ClientResponse) -> None:
|
||||||
"""Verify that the response is valid."""
|
"""Verify that the response is valid."""
|
||||||
if response.status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN):
|
if response.status == HTTP_UNAUTHORIZED:
|
||||||
|
_LOGGER.error("Tibber API authentication failed - check access token")
|
||||||
raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS)
|
raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS)
|
||||||
|
if response.status == HTTP_FORBIDDEN:
|
||||||
|
_LOGGER.error("Tibber API access forbidden - insufficient permissions")
|
||||||
|
raise TibberPricesApiClientPermissionError(TibberPricesApiClientPermissionError.INSUFFICIENT_PERMISSIONS)
|
||||||
if response.status == HTTP_TOO_MANY_REQUESTS:
|
if response.status == HTTP_TOO_MANY_REQUESTS:
|
||||||
raise TibberPricesApiClientError(TibberPricesApiClientError.RATE_LIMIT_ERROR)
|
# Check for Retry-After header that Tibber might send
|
||||||
|
retry_after = response.headers.get("Retry-After", "unknown")
|
||||||
|
_LOGGER.warning("Tibber API rate limit exceeded - retry after %s seconds", retry_after)
|
||||||
|
raise TibberPricesApiClientError(TibberPricesApiClientError.RATE_LIMIT_ERROR.format(retry_after=retry_after))
|
||||||
|
if response.status == HTTP_BAD_REQUEST:
|
||||||
|
_LOGGER.error("Tibber API rejected request - likely invalid GraphQL query")
|
||||||
|
raise TibberPricesApiClientError(
|
||||||
|
TibberPricesApiClientError.INVALID_QUERY_ERROR.format(message="Bad request - likely invalid GraphQL query")
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -72,21 +91,41 @@ async def _verify_graphql_response(response_json: dict, query_type: QueryType) -
|
||||||
if "errors" in response_json:
|
if "errors" in response_json:
|
||||||
errors = response_json["errors"]
|
errors = response_json["errors"]
|
||||||
if not errors:
|
if not errors:
|
||||||
|
_LOGGER.error("Tibber API returned empty errors array")
|
||||||
raise TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR)
|
raise TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR)
|
||||||
|
|
||||||
error = errors[0] # Take first error
|
error = errors[0] # Take first error
|
||||||
if not isinstance(error, dict):
|
if not isinstance(error, dict):
|
||||||
|
_LOGGER.error("Tibber API returned malformed error: %s", error)
|
||||||
raise TibberPricesApiClientError(TibberPricesApiClientError.MALFORMED_ERROR.format(error=error))
|
raise TibberPricesApiClientError(TibberPricesApiClientError.MALFORMED_ERROR.format(error=error))
|
||||||
|
|
||||||
message = error.get("message", "Unknown error")
|
message = error.get("message", "Unknown error")
|
||||||
extensions = error.get("extensions", {})
|
extensions = error.get("extensions", {})
|
||||||
|
error_code = extensions.get("code")
|
||||||
|
|
||||||
if extensions.get("code") == "UNAUTHENTICATED":
|
# Handle specific Tibber API error codes
|
||||||
|
if error_code == "UNAUTHENTICATED":
|
||||||
|
_LOGGER.error("Tibber API authentication error: %s", message)
|
||||||
raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS)
|
raise TibberPricesApiClientAuthenticationError(TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS)
|
||||||
|
if error_code == "FORBIDDEN":
|
||||||
|
_LOGGER.error("Tibber API permission error: %s", message)
|
||||||
|
raise TibberPricesApiClientPermissionError(TibberPricesApiClientPermissionError.INSUFFICIENT_PERMISSIONS)
|
||||||
|
if error_code in ["RATE_LIMITED", "TOO_MANY_REQUESTS"]:
|
||||||
|
# Some GraphQL APIs return rate limit info in extensions
|
||||||
|
retry_after = extensions.get("retryAfter", "unknown")
|
||||||
|
_LOGGER.warning("Tibber API rate limited via GraphQL: %s (retry after %s)", message, retry_after)
|
||||||
|
raise TibberPricesApiClientError(
|
||||||
|
TibberPricesApiClientError.RATE_LIMIT_ERROR.format(retry_after=retry_after)
|
||||||
|
)
|
||||||
|
if error_code in ["VALIDATION_ERROR", "GRAPHQL_VALIDATION_FAILED"]:
|
||||||
|
_LOGGER.error("Tibber API validation error: %s", message)
|
||||||
|
raise TibberPricesApiClientError(TibberPricesApiClientError.INVALID_QUERY_ERROR.format(message=message))
|
||||||
|
|
||||||
|
_LOGGER.error("Tibber API GraphQL error (code: %s): %s", error_code or "unknown", message)
|
||||||
raise TibberPricesApiClientError(TibberPricesApiClientError.GRAPHQL_ERROR.format(message=message))
|
raise TibberPricesApiClientError(TibberPricesApiClientError.GRAPHQL_ERROR.format(message=message))
|
||||||
|
|
||||||
if "data" not in response_json or response_json["data"] is None:
|
if "data" not in response_json or response_json["data"] is None:
|
||||||
|
_LOGGER.error("Tibber API response missing data object")
|
||||||
raise TibberPricesApiClientError(
|
raise TibberPricesApiClientError(
|
||||||
TibberPricesApiClientError.GRAPHQL_ERROR.format(message="Response missing data object")
|
TibberPricesApiClientError.GRAPHQL_ERROR.format(message="Response missing data object")
|
||||||
)
|
)
|
||||||
|
|
@ -122,7 +161,7 @@ def _is_data_empty(data: dict, query_type: str) -> bool:
|
||||||
|
|
||||||
is_empty = False
|
is_empty = False
|
||||||
try:
|
try:
|
||||||
if query_type == "viewer":
|
if query_type == "user":
|
||||||
has_user_id = (
|
has_user_id = (
|
||||||
"viewer" in data
|
"viewer" in data
|
||||||
and isinstance(data["viewer"], dict)
|
and isinstance(data["viewer"], dict)
|
||||||
|
|
@ -321,11 +360,16 @@ class TibberPricesApiClient:
|
||||||
"""Tibber API Client."""
|
"""Tibber API Client."""
|
||||||
self._access_token = access_token
|
self._access_token = access_token
|
||||||
self._session = session
|
self._session = session
|
||||||
self._request_semaphore = asyncio.Semaphore(2)
|
self._request_semaphore = asyncio.Semaphore(2) # Max 2 concurrent requests
|
||||||
self._last_request_time = dt_util.now()
|
self._last_request_time = dt_util.now()
|
||||||
self._min_request_interval = timedelta(seconds=1)
|
self._min_request_interval = timedelta(seconds=1) # Min 1 second between requests
|
||||||
self._max_retries = 5
|
self._max_retries = 5
|
||||||
self._retry_delay = 2
|
self._retry_delay = 2 # Base retry delay in seconds
|
||||||
|
|
||||||
|
# Timeout configuration - more granular control
|
||||||
|
self._connect_timeout = 10 # Connection timeout in seconds
|
||||||
|
self._request_timeout = 25 # Total request timeout in seconds
|
||||||
|
self._socket_connect_timeout = 5 # Socket connection timeout
|
||||||
|
|
||||||
async def async_get_viewer_details(self) -> Any:
|
async def async_get_viewer_details(self) -> Any:
|
||||||
"""Test connection to the API."""
|
"""Test connection to the API."""
|
||||||
|
|
@ -352,11 +396,11 @@ class TibberPricesApiClient:
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
},
|
},
|
||||||
query_type=QueryType.VIEWER,
|
query_type=QueryType.USER,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_price_info(self, home_id: str) -> dict:
|
async def async_get_price_info(self) -> dict:
|
||||||
"""Get price info data in flat format for the specified home_id."""
|
"""Get price info data in flat format for all homes."""
|
||||||
data = await self._api_wrapper(
|
data = await self._api_wrapper(
|
||||||
data={
|
data={
|
||||||
"query": """
|
"query": """
|
||||||
|
|
@ -371,15 +415,21 @@ class TibberPricesApiClient:
|
||||||
query_type=QueryType.PRICE_INFO,
|
query_type=QueryType.PRICE_INFO,
|
||||||
)
|
)
|
||||||
homes = data.get("viewer", {}).get("homes", [])
|
homes = data.get("viewer", {}).get("homes", [])
|
||||||
home = next((h for h in homes if h.get("id") == home_id), None)
|
|
||||||
if home and "currentSubscription" in home:
|
homes_data = {}
|
||||||
data["priceInfo"] = _flatten_price_info(home["currentSubscription"])
|
for home in homes:
|
||||||
else:
|
home_id = home.get("id")
|
||||||
data["priceInfo"] = {}
|
if home_id:
|
||||||
|
if "currentSubscription" in home:
|
||||||
|
homes_data[home_id] = _flatten_price_info(home["currentSubscription"])
|
||||||
|
else:
|
||||||
|
homes_data[home_id] = {}
|
||||||
|
|
||||||
|
data["homes"] = homes_data
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def async_get_daily_price_rating(self, home_id: str) -> dict:
|
async def async_get_daily_price_rating(self) -> dict:
|
||||||
"""Get daily price rating data in flat format for the specified home_id."""
|
"""Get daily price rating data in flat format for all homes."""
|
||||||
data = await self._api_wrapper(
|
data = await self._api_wrapper(
|
||||||
data={
|
data={
|
||||||
"query": """
|
"query": """
|
||||||
|
|
@ -394,15 +444,21 @@ class TibberPricesApiClient:
|
||||||
query_type=QueryType.DAILY_RATING,
|
query_type=QueryType.DAILY_RATING,
|
||||||
)
|
)
|
||||||
homes = data.get("viewer", {}).get("homes", [])
|
homes = data.get("viewer", {}).get("homes", [])
|
||||||
home = next((h for h in homes if h.get("id") == home_id), None)
|
|
||||||
if home and "currentSubscription" in home:
|
homes_data = {}
|
||||||
data["priceRating"] = _flatten_price_rating(home["currentSubscription"])
|
for home in homes:
|
||||||
else:
|
home_id = home.get("id")
|
||||||
data["priceRating"] = {}
|
if home_id:
|
||||||
|
if "currentSubscription" in home:
|
||||||
|
homes_data[home_id] = _flatten_price_rating(home["currentSubscription"])
|
||||||
|
else:
|
||||||
|
homes_data[home_id] = {}
|
||||||
|
|
||||||
|
data["homes"] = homes_data
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def async_get_hourly_price_rating(self, home_id: str) -> dict:
|
async def async_get_hourly_price_rating(self) -> dict:
|
||||||
"""Get hourly price rating data in flat format for the specified home_id."""
|
"""Get hourly price rating data in flat format for all homes."""
|
||||||
data = await self._api_wrapper(
|
data = await self._api_wrapper(
|
||||||
data={
|
data={
|
||||||
"query": """
|
"query": """
|
||||||
|
|
@ -417,15 +473,21 @@ class TibberPricesApiClient:
|
||||||
query_type=QueryType.HOURLY_RATING,
|
query_type=QueryType.HOURLY_RATING,
|
||||||
)
|
)
|
||||||
homes = data.get("viewer", {}).get("homes", [])
|
homes = data.get("viewer", {}).get("homes", [])
|
||||||
home = next((h for h in homes if h.get("id") == home_id), None)
|
|
||||||
if home and "currentSubscription" in home:
|
homes_data = {}
|
||||||
data["priceRating"] = _flatten_price_rating(home["currentSubscription"])
|
for home in homes:
|
||||||
else:
|
home_id = home.get("id")
|
||||||
data["priceRating"] = {}
|
if home_id:
|
||||||
|
if "currentSubscription" in home:
|
||||||
|
homes_data[home_id] = _flatten_price_rating(home["currentSubscription"])
|
||||||
|
else:
|
||||||
|
homes_data[home_id] = {}
|
||||||
|
|
||||||
|
data["homes"] = homes_data
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def async_get_monthly_price_rating(self, home_id: str) -> dict:
|
async def async_get_monthly_price_rating(self) -> dict:
|
||||||
"""Get monthly price rating data in flat format for the specified home_id."""
|
"""Get monthly price rating data in flat format for all homes."""
|
||||||
data = await self._api_wrapper(
|
data = await self._api_wrapper(
|
||||||
data={
|
data={
|
||||||
"query": """
|
"query": """
|
||||||
|
|
@ -440,40 +502,52 @@ class TibberPricesApiClient:
|
||||||
query_type=QueryType.MONTHLY_RATING,
|
query_type=QueryType.MONTHLY_RATING,
|
||||||
)
|
)
|
||||||
homes = data.get("viewer", {}).get("homes", [])
|
homes = data.get("viewer", {}).get("homes", [])
|
||||||
home = next((h for h in homes if h.get("id") == home_id), None)
|
|
||||||
if home and "currentSubscription" in home:
|
homes_data = {}
|
||||||
data["priceRating"] = _flatten_price_rating(home["currentSubscription"])
|
for home in homes:
|
||||||
else:
|
home_id = home.get("id")
|
||||||
data["priceRating"] = {}
|
if home_id:
|
||||||
|
if "currentSubscription" in home:
|
||||||
|
homes_data[home_id] = _flatten_price_rating(home["currentSubscription"])
|
||||||
|
else:
|
||||||
|
homes_data[home_id] = {}
|
||||||
|
|
||||||
|
data["homes"] = homes_data
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def async_get_data(self, home_id: str) -> dict:
|
async def async_get_data(self) -> dict:
|
||||||
"""Get all data from the API by combining multiple queries in flat format for the specified home_id."""
|
"""Get all data from the API by combining multiple queries in flat format for all homes."""
|
||||||
price_info = await self.async_get_price_info(home_id)
|
price_info = await self.async_get_price_info()
|
||||||
daily_rating = await self.async_get_daily_price_rating(home_id)
|
daily_rating = await self.async_get_daily_price_rating()
|
||||||
hourly_rating = await self.async_get_hourly_price_rating(home_id)
|
hourly_rating = await self.async_get_hourly_price_rating()
|
||||||
monthly_rating = await self.async_get_monthly_price_rating(home_id)
|
monthly_rating = await self.async_get_monthly_price_rating()
|
||||||
price_rating = {
|
|
||||||
"thresholdPercentages": daily_rating["priceRating"].get("thresholdPercentages"),
|
|
||||||
"daily": daily_rating["priceRating"].get("daily", []),
|
|
||||||
"hourly": hourly_rating["priceRating"].get("hourly", []),
|
|
||||||
"monthly": monthly_rating["priceRating"].get("monthly", []),
|
|
||||||
"currency": (
|
|
||||||
daily_rating["priceRating"].get("currency")
|
|
||||||
or hourly_rating["priceRating"].get("currency")
|
|
||||||
or monthly_rating["priceRating"].get("currency")
|
|
||||||
),
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"priceInfo": price_info["priceInfo"],
|
|
||||||
"priceRating": price_rating,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def async_set_title(self, value: str) -> Any:
|
all_home_ids = set()
|
||||||
"""Get data from the API."""
|
all_home_ids.update(price_info.get("homes", {}).keys())
|
||||||
return await self._api_wrapper(
|
all_home_ids.update(daily_rating.get("homes", {}).keys())
|
||||||
data={"title": value},
|
all_home_ids.update(hourly_rating.get("homes", {}).keys())
|
||||||
)
|
all_home_ids.update(monthly_rating.get("homes", {}).keys())
|
||||||
|
|
||||||
|
homes_combined = {}
|
||||||
|
for home_id in all_home_ids:
|
||||||
|
daily_data = daily_rating.get("homes", {}).get(home_id, {})
|
||||||
|
hourly_data = hourly_rating.get("homes", {}).get(home_id, {})
|
||||||
|
monthly_data = monthly_rating.get("homes", {}).get(home_id, {})
|
||||||
|
|
||||||
|
price_rating = {
|
||||||
|
"thresholdPercentages": daily_data.get("thresholdPercentages"),
|
||||||
|
"daily": daily_data.get("daily", []),
|
||||||
|
"hourly": hourly_data.get("hourly", []),
|
||||||
|
"monthly": monthly_data.get("monthly", []),
|
||||||
|
"currency": (daily_data.get("currency") or hourly_data.get("currency") or monthly_data.get("currency")),
|
||||||
|
}
|
||||||
|
|
||||||
|
homes_combined[home_id] = {
|
||||||
|
"priceInfo": price_info.get("homes", {}).get(home_id, {}),
|
||||||
|
"priceRating": price_rating,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"homes": homes_combined}
|
||||||
|
|
||||||
async def _make_request(
|
async def _make_request(
|
||||||
self,
|
self,
|
||||||
|
|
@ -481,23 +555,117 @@ class TibberPricesApiClient:
|
||||||
data: dict,
|
data: dict,
|
||||||
query_type: QueryType,
|
query_type: QueryType,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Make an API request with proper error handling."""
|
"""Make an API request with comprehensive error handling for network issues."""
|
||||||
_LOGGER.debug("Making API request with data: %s", data)
|
_LOGGER.debug("Making API request with data: %s", data)
|
||||||
|
|
||||||
response = await self._session.request(
|
try:
|
||||||
method="POST",
|
# More granular timeout configuration for better network failure handling
|
||||||
url="https://api.tibber.com/v1-beta/gql",
|
timeout = aiohttp.ClientTimeout(
|
||||||
headers=headers,
|
total=self._request_timeout, # Total request timeout: 25s
|
||||||
json=data,
|
connect=self._connect_timeout, # Connection timeout: 10s
|
||||||
)
|
sock_connect=self._socket_connect_timeout, # Socket connection: 5s
|
||||||
|
)
|
||||||
|
|
||||||
_verify_response_or_raise(response)
|
response = await self._session.request(
|
||||||
response_json = await response.json()
|
method="POST",
|
||||||
_LOGGER.debug("Received API response: %s", response_json)
|
url="https://api.tibber.com/v1-beta/gql",
|
||||||
|
headers=headers,
|
||||||
|
json=data,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
await _verify_graphql_response(response_json, query_type)
|
_verify_response_or_raise(response)
|
||||||
|
response_json = await response.json()
|
||||||
|
_LOGGER.debug("Received API response: %s", response_json)
|
||||||
|
|
||||||
return response_json["data"]
|
await _verify_graphql_response(response_json, query_type)
|
||||||
|
|
||||||
|
return response_json["data"]
|
||||||
|
|
||||||
|
except aiohttp.ClientResponseError as error:
|
||||||
|
_LOGGER.exception("HTTP error during API request")
|
||||||
|
raise TibberPricesApiClientCommunicationError(
|
||||||
|
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
|
||||||
|
) from error
|
||||||
|
|
||||||
|
except aiohttp.ClientConnectorError as error:
|
||||||
|
_LOGGER.exception("Connection error - server unreachable or network down")
|
||||||
|
raise TibberPricesApiClientCommunicationError(
|
||||||
|
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
|
||||||
|
) from error
|
||||||
|
|
||||||
|
except aiohttp.ServerDisconnectedError as error:
|
||||||
|
_LOGGER.exception("Server disconnected during request")
|
||||||
|
raise TibberPricesApiClientCommunicationError(
|
||||||
|
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
|
||||||
|
) from error
|
||||||
|
|
||||||
|
except TimeoutError as error:
|
||||||
|
_LOGGER.exception(
|
||||||
|
"Request timeout after %d seconds - slow network or server overload", self._request_timeout
|
||||||
|
)
|
||||||
|
raise TibberPricesApiClientCommunicationError(
|
||||||
|
TibberPricesApiClientCommunicationError.TIMEOUT_ERROR.format(exception=str(error))
|
||||||
|
) from error
|
||||||
|
|
||||||
|
except socket.gaierror as error:
|
||||||
|
self._handle_dns_error(error)
|
||||||
|
|
||||||
|
except OSError as error:
|
||||||
|
self._handle_network_error(error)
|
||||||
|
|
||||||
|
def _handle_dns_error(self, error: socket.gaierror) -> None:
|
||||||
|
"""Handle DNS resolution errors with IPv4/IPv6 dual stack considerations."""
|
||||||
|
error_msg = str(error)
|
||||||
|
|
||||||
|
if "Name or service not known" in error_msg:
|
||||||
|
_LOGGER.exception("DNS resolution failed - domain name not found")
|
||||||
|
elif "Temporary failure in name resolution" in error_msg:
|
||||||
|
_LOGGER.exception("DNS resolution temporarily failed - network or DNS server issue")
|
||||||
|
elif "Address family for hostname not supported" in error_msg:
|
||||||
|
_LOGGER.exception("DNS resolution failed - IPv4/IPv6 address family not supported")
|
||||||
|
elif "No address associated with hostname" in error_msg:
|
||||||
|
_LOGGER.exception("DNS resolution failed - no IPv4/IPv6 addresses found")
|
||||||
|
else:
|
||||||
|
_LOGGER.exception("DNS resolution failed - check internet connection: %s", error_msg)
|
||||||
|
|
||||||
|
raise TibberPricesApiClientCommunicationError(
|
||||||
|
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
|
||||||
|
) from error
|
||||||
|
|
||||||
|
def _handle_network_error(self, error: OSError) -> None:
|
||||||
|
"""Handle network-level errors with IPv4/IPv6 dual stack considerations."""
|
||||||
|
error_msg = str(error)
|
||||||
|
errno = getattr(error, "errno", None)
|
||||||
|
|
||||||
|
# Common IPv4/IPv6 dual stack network error codes
|
||||||
|
errno_network_unreachable = 101 # ENETUNREACH
|
||||||
|
errno_host_unreachable = 113 # EHOSTUNREACH
|
||||||
|
errno_connection_refused = 111 # ECONNREFUSED
|
||||||
|
errno_connection_timeout = 110 # ETIMEDOUT
|
||||||
|
|
||||||
|
if errno == errno_network_unreachable:
|
||||||
|
_LOGGER.exception("Network unreachable - check internet connection or IPv4/IPv6 routing")
|
||||||
|
elif errno == errno_host_unreachable:
|
||||||
|
_LOGGER.exception("Host unreachable - routing issue or IPv4/IPv6 connectivity problem")
|
||||||
|
elif errno == errno_connection_refused:
|
||||||
|
_LOGGER.exception("Connection refused - server not accepting connections")
|
||||||
|
elif errno == errno_connection_timeout:
|
||||||
|
_LOGGER.exception("Connection timed out - network latency or server overload")
|
||||||
|
elif "Address family not supported" in error_msg:
|
||||||
|
_LOGGER.exception("Address family not supported - IPv4/IPv6 configuration issue")
|
||||||
|
elif "Protocol not available" in error_msg:
|
||||||
|
_LOGGER.exception("Protocol not available - IPv4/IPv6 stack configuration issue")
|
||||||
|
elif "Network is down" in error_msg:
|
||||||
|
_LOGGER.exception("Network interface is down - check network adapter")
|
||||||
|
elif "Permission denied" in error_msg:
|
||||||
|
_LOGGER.exception("Network permission denied - firewall or security restriction")
|
||||||
|
else:
|
||||||
|
_LOGGER.exception("Network error - internet may be down: %s", error_msg)
|
||||||
|
|
||||||
|
raise TibberPricesApiClientCommunicationError(
|
||||||
|
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
|
||||||
|
) from error
|
||||||
|
|
||||||
async def _handle_request(
|
async def _handle_request(
|
||||||
self,
|
self,
|
||||||
|
|
@ -517,19 +685,90 @@ class TibberPricesApiClient:
|
||||||
)
|
)
|
||||||
await asyncio.sleep(sleep_time)
|
await asyncio.sleep(sleep_time)
|
||||||
|
|
||||||
async with async_timeout.timeout(10):
|
self._last_request_time = dt_util.now()
|
||||||
self._last_request_time = dt_util.now()
|
return await self._make_request(
|
||||||
return await self._make_request(
|
headers,
|
||||||
headers,
|
data or {},
|
||||||
data or {},
|
query_type,
|
||||||
query_type,
|
)
|
||||||
)
|
|
||||||
|
def _should_retry_error(self, error: Exception, retry: int) -> tuple[bool, int]:
|
||||||
|
"""Determine if an error should be retried and calculate delay."""
|
||||||
|
# Check if we've exceeded max retries first
|
||||||
|
if retry >= self._max_retries:
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# Non-retryable errors - authentication and permission issues
|
||||||
|
if isinstance(error, (TibberPricesApiClientAuthenticationError, TibberPricesApiClientPermissionError)):
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# Handle API-specific errors
|
||||||
|
if isinstance(error, TibberPricesApiClientError):
|
||||||
|
return self._handle_api_error_retry(error, retry)
|
||||||
|
|
||||||
|
# Network and timeout errors - retryable with exponential backoff
|
||||||
|
if isinstance(error, (aiohttp.ClientError, socket.gaierror, TimeoutError)):
|
||||||
|
delay = min(self._retry_delay * (2**retry), 30) # Cap at 30 seconds
|
||||||
|
return True, delay
|
||||||
|
|
||||||
|
# Unknown errors - not retryable
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
def _handle_api_error_retry(self, error: TibberPricesApiClientError, retry: int) -> tuple[bool, int]:
|
||||||
|
"""Handle retry logic for API-specific errors."""
|
||||||
|
error_msg = str(error)
|
||||||
|
|
||||||
|
# Non-retryable: Invalid queries
|
||||||
|
if "Invalid GraphQL query" in error_msg or "Bad request" in error_msg:
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# Rate limits - special handling with extracted delay
|
||||||
|
if "Rate limit exceeded" in error_msg or "rate limited" in error_msg.lower():
|
||||||
|
delay = self._extract_retry_delay(error, retry)
|
||||||
|
return True, delay
|
||||||
|
|
||||||
|
# Empty data - retryable with capped exponential backoff
|
||||||
|
if "Empty data received" in error_msg:
|
||||||
|
delay = min(self._retry_delay * (2**retry), 60) # Cap at 60 seconds
|
||||||
|
return True, delay
|
||||||
|
|
||||||
|
# Other API errors - retryable with capped exponential backoff
|
||||||
|
delay = min(self._retry_delay * (2**retry), 30) # Cap at 30 seconds
|
||||||
|
return True, delay
|
||||||
|
|
||||||
|
def _extract_retry_delay(self, error: Exception, retry: int) -> int:
|
||||||
|
"""Extract retry delay from rate limit error or use exponential backoff."""
|
||||||
|
import re
|
||||||
|
|
||||||
|
error_msg = str(error)
|
||||||
|
|
||||||
|
# Try to extract Retry-After value from error message
|
||||||
|
retry_after_match = re.search(r"retry after (\d+) seconds", error_msg.lower())
|
||||||
|
if retry_after_match:
|
||||||
|
try:
|
||||||
|
retry_after = int(retry_after_match.group(1))
|
||||||
|
return min(retry_after + 1, 300) # Add buffer, max 5 minutes
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try to extract generic seconds value
|
||||||
|
seconds_match = re.search(r"(\d+) seconds", error_msg)
|
||||||
|
if seconds_match:
|
||||||
|
try:
|
||||||
|
seconds = int(seconds_match.group(1))
|
||||||
|
return min(seconds + 1, 300) # Add buffer, max 5 minutes
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fall back to exponential backoff with cap
|
||||||
|
base_delay = self._retry_delay * (2**retry)
|
||||||
|
return min(base_delay, 120) # Cap at 2 minutes for rate limits
|
||||||
|
|
||||||
async def _api_wrapper(
|
async def _api_wrapper(
|
||||||
self,
|
self,
|
||||||
data: dict | None = None,
|
data: dict | None = None,
|
||||||
headers: dict | None = None,
|
headers: dict | None = None,
|
||||||
query_type: QueryType = QueryType.VIEWER,
|
query_type: QueryType = QueryType.USER,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get information from the API with rate limiting and retry logic."""
|
"""Get information from the API with rate limiting and retry logic."""
|
||||||
headers = headers or _prepare_headers(self._access_token)
|
headers = headers or _prepare_headers(self._access_token)
|
||||||
|
|
@ -537,32 +776,34 @@ class TibberPricesApiClient:
|
||||||
|
|
||||||
for retry in range(self._max_retries + 1):
|
for retry in range(self._max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return await self._handle_request(
|
return await self._handle_request(headers, data or {}, query_type)
|
||||||
headers,
|
|
||||||
data or {},
|
|
||||||
query_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
except TibberPricesApiClientAuthenticationError:
|
except (
|
||||||
|
TibberPricesApiClientAuthenticationError,
|
||||||
|
TibberPricesApiClientPermissionError,
|
||||||
|
):
|
||||||
|
_LOGGER.exception("Non-retryable error occurred")
|
||||||
raise
|
raise
|
||||||
except (
|
except (
|
||||||
|
TibberPricesApiClientError,
|
||||||
aiohttp.ClientError,
|
aiohttp.ClientError,
|
||||||
socket.gaierror,
|
socket.gaierror,
|
||||||
TimeoutError,
|
TimeoutError,
|
||||||
TibberPricesApiClientError,
|
|
||||||
) as error:
|
) as error:
|
||||||
last_error = (
|
last_error = (
|
||||||
error
|
error
|
||||||
if isinstance(error, TibberPricesApiClientError)
|
if isinstance(error, TibberPricesApiClientError)
|
||||||
else TibberPricesApiClientError(
|
else TibberPricesApiClientCommunicationError(
|
||||||
TibberPricesApiClientError.GENERIC_ERROR.format(exception=str(error))
|
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(exception=str(error))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if retry < self._max_retries:
|
should_retry, delay = self._should_retry_error(error, retry)
|
||||||
delay = self._retry_delay * (2**retry)
|
if should_retry:
|
||||||
|
error_type = self._get_error_type(error)
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Request failed, attempt %d/%d. Retrying in %d seconds: %s",
|
"Tibber %s error, attempt %d/%d. Retrying in %d seconds: %s",
|
||||||
|
error_type,
|
||||||
retry + 1,
|
retry + 1,
|
||||||
self._max_retries,
|
self._max_retries,
|
||||||
delay,
|
delay,
|
||||||
|
|
@ -571,6 +812,10 @@ class TibberPricesApiClient:
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if "Invalid GraphQL query" in str(error):
|
||||||
|
_LOGGER.exception("Invalid query - not retrying")
|
||||||
|
raise
|
||||||
|
|
||||||
# Handle final error state
|
# Handle final error state
|
||||||
if isinstance(last_error, TimeoutError):
|
if isinstance(last_error, TimeoutError):
|
||||||
raise TibberPricesApiClientCommunicationError(
|
raise TibberPricesApiClientCommunicationError(
|
||||||
|
|
@ -582,3 +827,11 @@ class TibberPricesApiClient:
|
||||||
) from last_error
|
) from last_error
|
||||||
|
|
||||||
raise last_error or TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR)
|
raise last_error or TibberPricesApiClientError(TibberPricesApiClientError.UNKNOWN_ERROR)
|
||||||
|
|
||||||
|
def _get_error_type(self, error: Exception) -> str:
|
||||||
|
"""Get a descriptive error type for logging."""
|
||||||
|
if "Rate limit" in str(error):
|
||||||
|
return "rate limit"
|
||||||
|
if isinstance(error, (aiohttp.ClientError, socket.gaierror, TimeoutError)):
|
||||||
|
return "network"
|
||||||
|
return "API"
|
||||||
|
|
|
||||||
|
|
@ -204,10 +204,15 @@ class TibberPricesSubentryFlowHandler(ConfigSubentryFlow):
|
||||||
async def async_step_user(self, user_input: dict[str, Any] | None = None) -> SubentryFlowResult:
|
async def async_step_user(self, user_input: dict[str, Any] | None = None) -> SubentryFlowResult:
|
||||||
"""User flow to add a new home."""
|
"""User flow to add a new home."""
|
||||||
parent_entry = self._get_entry()
|
parent_entry = self._get_entry()
|
||||||
if not parent_entry:
|
if not parent_entry or not hasattr(parent_entry, "runtime_data") or not parent_entry.runtime_data:
|
||||||
return self.async_abort(reason="no_parent_entry")
|
return self.async_abort(reason="no_parent_entry")
|
||||||
|
|
||||||
homes = parent_entry.data.get("homes", [])
|
coordinator = parent_entry.runtime_data.coordinator
|
||||||
|
|
||||||
|
# Force refresh user data to get latest homes from Tibber API
|
||||||
|
await coordinator.refresh_user_data()
|
||||||
|
|
||||||
|
homes = coordinator.get_user_homes()
|
||||||
if not homes:
|
if not homes:
|
||||||
return self.async_abort(reason="no_available_homes")
|
return self.async_abort(reason="no_available_homes")
|
||||||
|
|
||||||
|
|
@ -233,11 +238,11 @@ class TibberPricesSubentryFlowHandler(ConfigSubentryFlow):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get existing home IDs by checking all subentries for this parent
|
# Get existing home IDs by checking all subentries for this parent
|
||||||
existing_home_ids = set()
|
existing_home_ids = {
|
||||||
for entry in self.hass.config_entries.async_entries(DOMAIN):
|
entry.data["home_id"]
|
||||||
# Check if this entry has home_id data (indicating it's a subentry)
|
for entry in self.hass.config_entries.async_entries(DOMAIN)
|
||||||
if entry.data.get("home_id") and entry != parent_entry:
|
if entry.data.get("home_id") and entry != parent_entry
|
||||||
existing_home_ids.add(entry.data["home_id"])
|
}
|
||||||
|
|
||||||
available_homes = [home for home in homes if home["id"] not in existing_home_ids]
|
available_homes = [home for home in homes if home["id"] not in existing_home_ids]
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -27,16 +27,47 @@ class TibberPricesEntity(CoordinatorEntity[TibberPricesDataUpdateCoordinator]):
|
||||||
"COTTAGE": "Cottage",
|
"COTTAGE": "Cottage",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get home info from Tibber API if available
|
# Get user profile information from coordinator
|
||||||
|
user_profile = self.coordinator.get_user_profile()
|
||||||
|
|
||||||
|
# Check if this is a main entry or subentry
|
||||||
|
is_subentry = bool(self.coordinator.config_entry.data.get("home_id"))
|
||||||
|
|
||||||
|
# Initialize variables
|
||||||
home_name = "Tibber Home"
|
home_name = "Tibber Home"
|
||||||
home_id = self.coordinator.config_entry.unique_id
|
home_id = self.coordinator.config_entry.unique_id
|
||||||
home_type = None
|
home_type = None
|
||||||
city = None
|
|
||||||
app_nickname = None
|
if is_subentry:
|
||||||
address1 = None
|
# For subentries, show specific home information
|
||||||
if coordinator.data:
|
home_data = self.coordinator.config_entry.data.get("home_data", {})
|
||||||
|
home_id = self.coordinator.config_entry.data.get("home_id")
|
||||||
|
|
||||||
|
# Get home details
|
||||||
|
address = home_data.get("address", {})
|
||||||
|
address1 = address.get("address1", "")
|
||||||
|
city = address.get("city", "")
|
||||||
|
app_nickname = home_data.get("appNickname", "")
|
||||||
|
home_type = home_data.get("type", "")
|
||||||
|
|
||||||
|
# Compose home name
|
||||||
|
home_name = app_nickname or address1 or f"Tibber Home {home_id}"
|
||||||
|
if city:
|
||||||
|
home_name = f"{home_name}, {city}"
|
||||||
|
|
||||||
|
# Add user information if available
|
||||||
|
if user_profile and user_profile.get("name"):
|
||||||
|
home_name = f"{home_name} ({user_profile['name']})"
|
||||||
|
elif user_profile:
|
||||||
|
# For main entry, show user profile information
|
||||||
|
user_name = user_profile.get("name", "Tibber User")
|
||||||
|
user_email = user_profile.get("email", "")
|
||||||
|
home_name = f"Tibber - {user_name}"
|
||||||
|
if user_email:
|
||||||
|
home_name = f"{home_name} ({user_email})"
|
||||||
|
elif coordinator.data:
|
||||||
|
# Fallback to original logic if user data not available yet
|
||||||
try:
|
try:
|
||||||
home_id = self.unique_id
|
|
||||||
address1 = str(coordinator.data.get("address", {}).get("address1", ""))
|
address1 = str(coordinator.data.get("address", {}).get("address1", ""))
|
||||||
city = str(coordinator.data.get("address", {}).get("city", ""))
|
city = str(coordinator.data.get("address", {}).get("city", ""))
|
||||||
app_nickname = str(coordinator.data.get("appNickname", ""))
|
app_nickname = str(coordinator.data.get("appNickname", ""))
|
||||||
|
|
@ -47,8 +78,6 @@ class TibberPricesEntity(CoordinatorEntity[TibberPricesDataUpdateCoordinator]):
|
||||||
home_name = f"{home_name}, {city}"
|
home_name = f"{home_name}, {city}"
|
||||||
except (KeyError, IndexError, TypeError):
|
except (KeyError, IndexError, TypeError):
|
||||||
home_name = "Tibber Home"
|
home_name = "Tibber Home"
|
||||||
else:
|
|
||||||
home_name = "Tibber Home"
|
|
||||||
|
|
||||||
self._attr_device_info = DeviceInfo(
|
self._attr_device_info = DeviceInfo(
|
||||||
entry_type=DeviceEntryType.SERVICE,
|
entry_type=DeviceEntryType.SERVICE,
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@
|
||||||
"iot_class": "cloud_polling",
|
"iot_class": "cloud_polling",
|
||||||
"issue_tracker": "https://github.com/jpawlowski/hass.tibber_prices/issues",
|
"issue_tracker": "https://github.com/jpawlowski/hass.tibber_prices/issues",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
|
"homeassistant": "2024.1.0",
|
||||||
"requirements": [
|
"requirements": [
|
||||||
"aiofiles>=23.2.1"
|
"aiofiles>=23.2.1"
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,11 @@ from homeassistant.exceptions import ServiceValidationError
|
||||||
from homeassistant.helpers.entity_registry import async_get as async_get_entity_registry
|
from homeassistant.helpers.entity_registry import async_get as async_get_entity_registry
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
from .api import (
|
||||||
|
TibberPricesApiClientAuthenticationError,
|
||||||
|
TibberPricesApiClientCommunicationError,
|
||||||
|
TibberPricesApiClientError,
|
||||||
|
)
|
||||||
from .const import (
|
from .const import (
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
PRICE_LEVEL_CHEAP,
|
PRICE_LEVEL_CHEAP,
|
||||||
|
|
@ -29,6 +34,7 @@ from .const import (
|
||||||
PRICE_SERVICE_NAME = "get_price"
|
PRICE_SERVICE_NAME = "get_price"
|
||||||
APEXCHARTS_DATA_SERVICE_NAME = "get_apexcharts_data"
|
APEXCHARTS_DATA_SERVICE_NAME = "get_apexcharts_data"
|
||||||
APEXCHARTS_YAML_SERVICE_NAME = "get_apexcharts_yaml"
|
APEXCHARTS_YAML_SERVICE_NAME = "get_apexcharts_yaml"
|
||||||
|
REFRESH_USER_DATA_SERVICE_NAME = "refresh_user_data"
|
||||||
ATTR_DAY: Final = "day"
|
ATTR_DAY: Final = "day"
|
||||||
ATTR_ENTRY_ID: Final = "entry_id"
|
ATTR_ENTRY_ID: Final = "entry_id"
|
||||||
ATTR_TIME: Final = "time"
|
ATTR_TIME: Final = "time"
|
||||||
|
|
@ -68,6 +74,12 @@ APEXCHARTS_SERVICE_SCHEMA: Final = vol.Schema(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
REFRESH_USER_DATA_SERVICE_SCHEMA: Final = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required(ATTR_ENTRY_ID): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# region Top-level functions (ordered by call hierarchy)
|
# region Top-level functions (ordered by call hierarchy)
|
||||||
|
|
||||||
# --- Entry point: Service handler ---
|
# --- Entry point: Service handler ---
|
||||||
|
|
@ -165,41 +177,59 @@ async def _get_apexcharts_data(call: ServiceCall) -> dict[str, Any]:
|
||||||
level_type = call.data.get("level_type", "rating_level")
|
level_type = call.data.get("level_type", "rating_level")
|
||||||
level_key = call.data.get("level_key")
|
level_key = call.data.get("level_key")
|
||||||
hass = call.hass
|
hass = call.hass
|
||||||
|
|
||||||
|
# Get entry ID and verify it exists
|
||||||
entry_id = await _get_entry_id_from_entity_id(hass, entity_id)
|
entry_id = await _get_entry_id_from_entity_id(hass, entity_id)
|
||||||
if not entry_id:
|
if not entry_id:
|
||||||
raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entity_id")
|
raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entity_id")
|
||||||
|
|
||||||
entry, coordinator, data = _get_entry_and_data(hass, entry_id)
|
entry, coordinator, data = _get_entry_and_data(hass, entry_id)
|
||||||
points = []
|
|
||||||
|
# Get entries based on level_type
|
||||||
|
entries = _get_apexcharts_entries(coordinator, day, level_type)
|
||||||
|
if not entries:
|
||||||
|
return {"points": []}
|
||||||
|
|
||||||
|
# Ensure level_key is a string
|
||||||
|
if level_key is None:
|
||||||
|
raise ServiceValidationError(translation_domain=DOMAIN, translation_key="missing_level_key")
|
||||||
|
|
||||||
|
# Generate points for the chart
|
||||||
|
points = _generate_apexcharts_points(entries, str(level_key))
|
||||||
|
return {"points": points}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_apexcharts_entries(coordinator: Any, day: str, level_type: str) -> list[dict]:
|
||||||
|
"""Get the appropriate entries for ApexCharts based on level_type and day."""
|
||||||
if level_type == "rating_level":
|
if level_type == "rating_level":
|
||||||
entries = coordinator.data.get("priceRating", {}).get("hourly", [])
|
entries = coordinator.data.get("priceRating", {}).get("hourly", [])
|
||||||
price_info = coordinator.data.get("priceInfo", {})
|
price_info = coordinator.data.get("priceInfo", {})
|
||||||
if day == "today":
|
day_info = price_info.get(day, [])
|
||||||
prefixes = _get_day_prefixes(price_info.get("today", []))
|
prefixes = _get_day_prefixes(day_info)
|
||||||
if not prefixes:
|
|
||||||
return {"points": []}
|
if not prefixes:
|
||||||
entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])]
|
return []
|
||||||
elif day == "tomorrow":
|
|
||||||
prefixes = _get_day_prefixes(price_info.get("tomorrow", []))
|
return [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])]
|
||||||
if not prefixes:
|
|
||||||
return {"points": []}
|
# For non-rating level types, return the price info for the specified day
|
||||||
entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])]
|
return coordinator.data.get("priceInfo", {}).get(day, [])
|
||||||
elif day == "yesterday":
|
|
||||||
prefixes = _get_day_prefixes(price_info.get("yesterday", []))
|
|
||||||
if not prefixes:
|
def _generate_apexcharts_points(entries: list[dict], level_key: str) -> list:
|
||||||
return {"points": []}
|
"""Generate data points for ApexCharts based on the entries and level key."""
|
||||||
entries = [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])]
|
points = []
|
||||||
else:
|
|
||||||
entries = coordinator.data.get("priceInfo", {}).get(day, [])
|
|
||||||
if not entries:
|
|
||||||
return {"points": []}
|
|
||||||
for i in range(len(entries) - 1):
|
for i in range(len(entries) - 1):
|
||||||
p = entries[i]
|
p = entries[i]
|
||||||
if p.get("level") != level_key:
|
if p.get("level") != level_key:
|
||||||
continue
|
continue
|
||||||
points.append([p.get("time") or p.get("startsAt"), round((p.get("total") or 0) * 100, 2)])
|
points.append([p.get("time") or p.get("startsAt"), round((p.get("total") or 0) * 100, 2)])
|
||||||
|
|
||||||
|
# Add a final point with null value if there are any points
|
||||||
if points:
|
if points:
|
||||||
points.append([points[-1][0], None])
|
points.append([points[-1][0], None])
|
||||||
return {"points": points}
|
|
||||||
|
return points
|
||||||
|
|
||||||
|
|
||||||
async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]:
|
async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]:
|
||||||
|
|
@ -265,6 +295,57 @@ async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _refresh_user_data(call: ServiceCall) -> dict[str, Any]:
|
||||||
|
"""Refresh user data for a specific config entry and return updated information."""
|
||||||
|
entry_id = call.data.get(ATTR_ENTRY_ID)
|
||||||
|
hass = call.hass
|
||||||
|
|
||||||
|
if not entry_id:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "Entry ID is required",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get the entry and coordinator
|
||||||
|
try:
|
||||||
|
entry, coordinator, data = _get_entry_and_data(hass, entry_id)
|
||||||
|
except ServiceValidationError as ex:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"Invalid entry ID: {ex}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Force refresh user data using the public method
|
||||||
|
try:
|
||||||
|
updated = await coordinator.refresh_user_data()
|
||||||
|
except (
|
||||||
|
TibberPricesApiClientAuthenticationError,
|
||||||
|
TibberPricesApiClientCommunicationError,
|
||||||
|
TibberPricesApiClientError,
|
||||||
|
) as ex:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"API error refreshing user data: {ex!s}",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
if updated:
|
||||||
|
user_profile = coordinator.get_user_profile()
|
||||||
|
homes = coordinator.get_user_homes()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "User data refreshed successfully",
|
||||||
|
"user_profile": user_profile,
|
||||||
|
"homes_count": len(homes),
|
||||||
|
"homes": homes,
|
||||||
|
"last_updated": user_profile.get("last_updated"),
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "User data was already up to date",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# --- Direct helpers (called by service handler or each other) ---
|
# --- Direct helpers (called by service handler or each other) ---
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -470,8 +551,6 @@ def _annotate_intervals_with_times(
|
||||||
elif day == "tomorrow":
|
elif day == "tomorrow":
|
||||||
prev_end = _get_adjacent_start_time(price_info_by_day, "today", first=False)
|
prev_end = _get_adjacent_start_time(price_info_by_day, "today", first=False)
|
||||||
interval["previous_end_time"] = prev_end
|
interval["previous_end_time"] = prev_end
|
||||||
elif day == "yesterday":
|
|
||||||
interval["previous_end_time"] = None
|
|
||||||
else:
|
else:
|
||||||
interval["previous_end_time"] = None
|
interval["previous_end_time"] = None
|
||||||
|
|
||||||
|
|
@ -712,6 +791,13 @@ def async_setup_services(hass: HomeAssistant) -> None:
|
||||||
schema=APEXCHARTS_SERVICE_SCHEMA,
|
schema=APEXCHARTS_SERVICE_SCHEMA,
|
||||||
supports_response=SupportsResponse.ONLY,
|
supports_response=SupportsResponse.ONLY,
|
||||||
)
|
)
|
||||||
|
hass.services.async_register(
|
||||||
|
DOMAIN,
|
||||||
|
REFRESH_USER_DATA_SERVICE_NAME,
|
||||||
|
_refresh_user_data,
|
||||||
|
schema=REFRESH_USER_DATA_SERVICE_SCHEMA,
|
||||||
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
||||||
|
|
@ -111,3 +111,16 @@ get_apexcharts_yaml:
|
||||||
- yesterday
|
- yesterday
|
||||||
- today
|
- today
|
||||||
- tomorrow
|
- tomorrow
|
||||||
|
refresh_user_data:
|
||||||
|
name: Refresh User Data
|
||||||
|
description: >-
|
||||||
|
Forces a refresh of the user data (homes, profile information) from the Tibber API. This can be useful after making changes to your Tibber account or when troubleshooting connectivity issues.
|
||||||
|
fields:
|
||||||
|
entry_id:
|
||||||
|
name: Entry ID
|
||||||
|
description: The config entry ID for the Tibber integration.
|
||||||
|
required: true
|
||||||
|
example: "1234567890abcdef"
|
||||||
|
selector:
|
||||||
|
config_entry:
|
||||||
|
integration: tibber_prices
|
||||||
|
|
|
||||||
|
|
@ -133,5 +133,27 @@
|
||||||
"name": "Daten für morgen verfügbar"
|
"name": "Daten für morgen verfügbar"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"issues": {
|
||||||
|
"new_homes_available": {
|
||||||
|
"title": "Neue Tibber-Häuser erkannt",
|
||||||
|
"description": "Wir haben {count} neue(s) Zuhause in deinem Tibber-Konto erkannt: {homes}. Du kannst diese über die Tibber-Integration in Home Assistant hinzufügen."
|
||||||
|
},
|
||||||
|
"homes_removed": {
|
||||||
|
"title": "Tibber-Häuser entfernt",
|
||||||
|
"description": "Wir haben erkannt, dass {count} Zuhause aus deinem Tibber-Konto entfernt wurde(n): {homes}. Bitte überprüfe deine Tibber-Integrationskonfiguration."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"refresh_user_data": {
|
||||||
|
"name": "Benutzerdaten aktualisieren",
|
||||||
|
"description": "Erzwingt eine Aktualisierung der Benutzerdaten (Häuser, Profilinformationen) aus der Tibber API. Dies kann nützlich sein, nachdem Änderungen an deinem Tibber-Konto vorgenommen wurden oder bei der Fehlerbehebung von Verbindungsproblemen.",
|
||||||
|
"fields": {
|
||||||
|
"entry_id": {
|
||||||
|
"name": "Eintrag-ID",
|
||||||
|
"description": "Die Konfigurationseintrag-ID für die Tibber-Integration."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,11 +33,11 @@
|
||||||
},
|
},
|
||||||
"config_subentries": {
|
"config_subentries": {
|
||||||
"home": {
|
"home": {
|
||||||
"title": "Home",
|
"title": "Add Home",
|
||||||
"step": {
|
"step": {
|
||||||
"user": {
|
"user": {
|
||||||
"title": "Add Tibber Home",
|
"title": "Add Tibber Home",
|
||||||
"description": "Select a home to add to your Tibber integration.",
|
"description": "Select a home to add to your Tibber integration.\n\n**Note:** After adding this home, you can add additional homes from the integration's context menu by selecting \"Add Home\".",
|
||||||
"data": {
|
"data": {
|
||||||
"home_id": "Home"
|
"home_id": "Home"
|
||||||
}
|
}
|
||||||
|
|
@ -51,7 +51,7 @@
|
||||||
"no_access_token": "No access token available",
|
"no_access_token": "No access token available",
|
||||||
"home_not_found": "Selected home not found",
|
"home_not_found": "Selected home not found",
|
||||||
"api_error": "Failed to fetch homes from Tibber API",
|
"api_error": "Failed to fetch homes from Tibber API",
|
||||||
"no_available_homes": "No additional homes available to add"
|
"no_available_homes": "No additional homes available to add. All homes from your Tibber account have already been added."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -148,5 +148,27 @@
|
||||||
"name": "Tomorrow's Data Available"
|
"name": "Tomorrow's Data Available"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"issues": {
|
||||||
|
"new_homes_available": {
|
||||||
|
"title": "New Tibber homes detected",
|
||||||
|
"description": "We detected {count} new home(s) on your Tibber account: {homes}. You can add them to Home Assistant through the Tibber integration configuration."
|
||||||
|
},
|
||||||
|
"homes_removed": {
|
||||||
|
"title": "Tibber homes removed",
|
||||||
|
"description": "We detected that {count} home(s) have been removed from your Tibber account: {homes}. Please review your Tibber integration configuration."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"refresh_user_data": {
|
||||||
|
"name": "Refresh User Data",
|
||||||
|
"description": "Forces a refresh of the user data (homes, profile information) from the Tibber API. This can be useful after making changes to your Tibber account or when troubleshooting connectivity issues.",
|
||||||
|
"fields": {
|
||||||
|
"entry_id": {
|
||||||
|
"name": "Entry ID",
|
||||||
|
"description": "The config entry ID for the Tibber integration."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "Tibber Price Information & Ratings",
|
"name": "Tibber Price Information & Ratings",
|
||||||
"homeassistant": "2025.4.2",
|
"homeassistant": "2025.5.0",
|
||||||
"hacs": "2.0.1",
|
"hacs": "2.0.1",
|
||||||
"render_readme": true
|
"render_readme": true
|
||||||
}
|
}
|
||||||
|
|
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
# Tests package
|
||||||
84
tests/test_coordinator_basic.py
Normal file
84
tests/test_coordinator_basic.py
Normal file
|
|
@ -0,0 +1,84 @@
|
||||||
|
"""Test basic coordinator functionality with the enhanced coordinator."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from custom_components.tibber_prices.coordinator import TibberPricesDataUpdateCoordinator
|
||||||
|
|
||||||
|
|
||||||
|
class TestBasicCoordinator:
|
||||||
|
"""Test basic coordinator functionality."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_hass(self):
|
||||||
|
"""Create a mock Home Assistant instance."""
|
||||||
|
hass = Mock()
|
||||||
|
hass.data = {}
|
||||||
|
return hass
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config_entry(self):
|
||||||
|
"""Create a mock config entry."""
|
||||||
|
config_entry = Mock()
|
||||||
|
config_entry.unique_id = "test_home_123"
|
||||||
|
config_entry.entry_id = "test_entry"
|
||||||
|
config_entry.data = {"access_token": "test_token"}
|
||||||
|
return config_entry
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session(self):
|
||||||
|
"""Create a mock session."""
|
||||||
|
return Mock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def coordinator(self, mock_hass, mock_config_entry, mock_session):
|
||||||
|
"""Create a coordinator instance."""
|
||||||
|
with patch(
|
||||||
|
"custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession",
|
||||||
|
return_value=mock_session,
|
||||||
|
):
|
||||||
|
with patch("custom_components.tibber_prices.coordinator.Store") as mock_store_class:
|
||||||
|
mock_store = Mock()
|
||||||
|
mock_store.async_load = AsyncMock(return_value=None)
|
||||||
|
mock_store.async_save = AsyncMock()
|
||||||
|
mock_store_class.return_value = mock_store
|
||||||
|
|
||||||
|
return TibberPricesDataUpdateCoordinator(mock_hass, mock_config_entry)
|
||||||
|
|
||||||
|
def test_coordinator_creation(self, coordinator):
|
||||||
|
"""Test that coordinator can be created."""
|
||||||
|
assert coordinator is not None
|
||||||
|
assert hasattr(coordinator, "get_current_interval_data")
|
||||||
|
assert hasattr(coordinator, "get_all_intervals")
|
||||||
|
assert hasattr(coordinator, "get_user_profile")
|
||||||
|
|
||||||
|
def test_is_main_entry(self, coordinator):
|
||||||
|
"""Test main entry detection."""
|
||||||
|
# First coordinator should be main entry
|
||||||
|
assert coordinator.is_main_entry() is True
|
||||||
|
|
||||||
|
def test_get_user_profile_no_data(self, coordinator):
|
||||||
|
"""Test getting user profile when no data is cached."""
|
||||||
|
profile = coordinator.get_user_profile()
|
||||||
|
assert profile == {"last_updated": None, "cached_user_data": False}
|
||||||
|
|
||||||
|
def test_get_user_homes_no_data(self, coordinator):
|
||||||
|
"""Test getting user homes when no data is cached."""
|
||||||
|
homes = coordinator.get_user_homes()
|
||||||
|
assert homes == []
|
||||||
|
|
||||||
|
def test_get_current_interval_data_no_data(self, coordinator):
|
||||||
|
"""Test getting current interval data when no data is available."""
|
||||||
|
current_data = coordinator.get_current_interval_data()
|
||||||
|
assert current_data is None
|
||||||
|
|
||||||
|
def test_get_all_intervals_no_data(self, coordinator):
|
||||||
|
"""Test getting all intervals when no data is available."""
|
||||||
|
intervals = coordinator.get_all_intervals()
|
||||||
|
assert intervals == []
|
||||||
|
|
||||||
|
def test_get_interval_granularity(self, coordinator):
|
||||||
|
"""Test getting interval granularity."""
|
||||||
|
granularity = coordinator.get_interval_granularity()
|
||||||
|
assert granularity is None
|
||||||
255
tests/test_coordinator_enhanced.py
Normal file
255
tests/test_coordinator_enhanced.py
Normal file
|
|
@ -0,0 +1,255 @@
|
||||||
|
"""Test enhanced coordinator functionality."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from custom_components.tibber_prices.const import DOMAIN
|
||||||
|
from custom_components.tibber_prices.coordinator import TibberPricesDataUpdateCoordinator
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnhancedCoordinator:
|
||||||
|
"""Test enhanced coordinator functionality."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config_entry(self) -> Mock:
|
||||||
|
"""Create a mock config entry."""
|
||||||
|
config_entry = Mock()
|
||||||
|
config_entry.unique_id = "test_home_id_123"
|
||||||
|
config_entry.entry_id = "test_entry_id"
|
||||||
|
config_entry.data = {"access_token": "test_token"}
|
||||||
|
return config_entry
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_hass(self) -> Mock:
|
||||||
|
"""Create a mock Home Assistant instance."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
hass = Mock()
|
||||||
|
hass.data = {}
|
||||||
|
# Mock the event loop for time tracking
|
||||||
|
hass.loop = asyncio.get_event_loop()
|
||||||
|
return hass
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_store(self) -> Mock:
|
||||||
|
"""Create a mock store."""
|
||||||
|
store = Mock()
|
||||||
|
store.async_load = AsyncMock(return_value=None)
|
||||||
|
store.async_save = AsyncMock()
|
||||||
|
return store
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_api(self) -> Mock:
|
||||||
|
"""Create a mock API client."""
|
||||||
|
api = Mock()
|
||||||
|
api.async_get_viewer_details = AsyncMock(return_value={"homes": []})
|
||||||
|
api.async_get_price_info = AsyncMock(return_value={"homes": {}})
|
||||||
|
api.async_get_hourly_price_rating = AsyncMock(return_value={"homes": {}})
|
||||||
|
api.async_get_daily_price_rating = AsyncMock(return_value={"homes": {}})
|
||||||
|
api.async_get_monthly_price_rating = AsyncMock(return_value={"homes": {}})
|
||||||
|
return api
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def coordinator(
|
||||||
|
self, mock_hass: Mock, mock_config_entry: Mock, mock_store: Mock, mock_api: Mock
|
||||||
|
) -> TibberPricesDataUpdateCoordinator:
|
||||||
|
"""Create a coordinator for testing."""
|
||||||
|
mock_session = Mock()
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession",
|
||||||
|
return_value=mock_session,
|
||||||
|
),
|
||||||
|
patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store),
|
||||||
|
):
|
||||||
|
coordinator = TibberPricesDataUpdateCoordinator(
|
||||||
|
hass=mock_hass,
|
||||||
|
config_entry=mock_config_entry,
|
||||||
|
)
|
||||||
|
# Replace the API instance with our mock
|
||||||
|
coordinator.api = mock_api
|
||||||
|
return coordinator
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_main_subentry_pattern(self, mock_hass: Mock, mock_store: Mock) -> None:
|
||||||
|
"""Test main/subentry coordinator pattern."""
|
||||||
|
# Create main coordinator first
|
||||||
|
main_config_entry = Mock()
|
||||||
|
main_config_entry.unique_id = "main_home_id"
|
||||||
|
main_config_entry.entry_id = "main_entry_id"
|
||||||
|
main_config_entry.data = {"access_token": "test_token"}
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession",
|
||||||
|
return_value=mock_session,
|
||||||
|
),
|
||||||
|
patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store),
|
||||||
|
):
|
||||||
|
main_coordinator = TibberPricesDataUpdateCoordinator(
|
||||||
|
hass=mock_hass,
|
||||||
|
config_entry=main_config_entry,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify main coordinator is marked as main entry
|
||||||
|
assert main_coordinator.is_main_entry()
|
||||||
|
|
||||||
|
# Create subentry coordinator
|
||||||
|
sub_config_entry = Mock()
|
||||||
|
sub_config_entry.unique_id = "sub_home_id"
|
||||||
|
sub_config_entry.entry_id = "sub_entry_id"
|
||||||
|
sub_config_entry.data = {"access_token": "test_token", "home_id": "sub_home_id"}
|
||||||
|
|
||||||
|
# Set up domain data to simulate main coordinator being already registered
|
||||||
|
mock_hass.data[DOMAIN] = {"main_entry_id": main_coordinator}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.tibber_prices.coordinator.aiohttp_client.async_get_clientsession",
|
||||||
|
return_value=mock_session,
|
||||||
|
),
|
||||||
|
patch("custom_components.tibber_prices.coordinator.Store", return_value=mock_store),
|
||||||
|
):
|
||||||
|
sub_coordinator = TibberPricesDataUpdateCoordinator(
|
||||||
|
hass=mock_hass,
|
||||||
|
config_entry=sub_config_entry,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subentry coordinator is not marked as main entry
|
||||||
|
assert not sub_coordinator.is_main_entry()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_data_functionality(self, coordinator: TibberPricesDataUpdateCoordinator) -> None:
|
||||||
|
"""Test user data related functionality."""
|
||||||
|
# Mock user data API
|
||||||
|
mock_user_data = {
|
||||||
|
"homes": [
|
||||||
|
{"id": "home1", "appNickname": "Home 1"},
|
||||||
|
{"id": "home2", "appNickname": "Home 2"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
coordinator.api.async_get_viewer_details = AsyncMock(return_value=mock_user_data)
|
||||||
|
|
||||||
|
# Test refresh user data
|
||||||
|
result = await coordinator.refresh_user_data()
|
||||||
|
assert result
|
||||||
|
|
||||||
|
# Test get user profile
|
||||||
|
profile = coordinator.get_user_profile()
|
||||||
|
assert isinstance(profile, dict)
|
||||||
|
assert "last_updated" in profile
|
||||||
|
assert "cached_user_data" in profile
|
||||||
|
|
||||||
|
# Test get user homes
|
||||||
|
homes = coordinator.get_user_homes()
|
||||||
|
assert isinstance(homes, list)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_data_update_with_multi_home_response(self, coordinator: TibberPricesDataUpdateCoordinator) -> None:
|
||||||
|
"""Test coordinator handling multi-home API response."""
|
||||||
|
# Mock API responses
|
||||||
|
mock_price_response = {
|
||||||
|
"homes": {
|
||||||
|
"test_home_id_123": {
|
||||||
|
"priceInfo": {
|
||||||
|
"today": [{"startsAt": "2025-05-25T00:00:00Z", "total": 0.25}],
|
||||||
|
"tomorrow": [],
|
||||||
|
"yesterday": [],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"other_home_id": {
|
||||||
|
"priceInfo": {
|
||||||
|
"today": [{"startsAt": "2025-05-25T00:00:00Z", "total": 0.30}],
|
||||||
|
"tomorrow": [],
|
||||||
|
"yesterday": [],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_hourly_rating = {"homes": {"test_home_id_123": {"hourly": []}}}
|
||||||
|
mock_daily_rating = {"homes": {"test_home_id_123": {"daily": []}}}
|
||||||
|
mock_monthly_rating = {"homes": {"test_home_id_123": {"monthly": []}}}
|
||||||
|
|
||||||
|
# Mock all API methods
|
||||||
|
coordinator.api.async_get_price_info = AsyncMock(return_value=mock_price_response)
|
||||||
|
coordinator.api.async_get_hourly_price_rating = AsyncMock(return_value=mock_hourly_rating)
|
||||||
|
coordinator.api.async_get_daily_price_rating = AsyncMock(return_value=mock_daily_rating)
|
||||||
|
coordinator.api.async_get_monthly_price_rating = AsyncMock(return_value=mock_monthly_rating)
|
||||||
|
|
||||||
|
# Update the coordinator to fetch data
|
||||||
|
await coordinator.async_refresh()
|
||||||
|
|
||||||
|
# Verify coordinator has data
|
||||||
|
assert coordinator.data is not None
|
||||||
|
assert "priceInfo" in coordinator.data
|
||||||
|
assert "priceRating" in coordinator.data
|
||||||
|
|
||||||
|
# Test public API methods work
|
||||||
|
intervals = coordinator.get_all_intervals()
|
||||||
|
assert isinstance(intervals, list)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_handling_with_cache_fallback(self, coordinator: TibberPricesDataUpdateCoordinator) -> None:
|
||||||
|
"""Test error handling with fallback to cached data."""
|
||||||
|
from custom_components.tibber_prices.api import TibberPricesApiClientCommunicationError
|
||||||
|
|
||||||
|
# Set up cached data using the store mechanism
|
||||||
|
test_cached_data = {
|
||||||
|
"timestamp": "2025-05-25T00:00:00Z",
|
||||||
|
"homes": {
|
||||||
|
"test_home_id_123": {
|
||||||
|
"price_info": {"today": [], "tomorrow": [], "yesterday": []},
|
||||||
|
"hourly_rating": {},
|
||||||
|
"daily_rating": {},
|
||||||
|
"monthly_rating": {},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mock store to return cached data
|
||||||
|
coordinator._store.async_load = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"price_data": test_cached_data,
|
||||||
|
"user_data": None,
|
||||||
|
"last_price_update": "2025-05-25T00:00:00Z",
|
||||||
|
"last_user_update": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load the cache
|
||||||
|
await coordinator._load_cache()
|
||||||
|
|
||||||
|
# Mock API to raise communication error
|
||||||
|
coordinator.api.async_get_price_info = AsyncMock(
|
||||||
|
side_effect=TibberPricesApiClientCommunicationError("Network error")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise exception but use cached data
|
||||||
|
await coordinator.async_refresh()
|
||||||
|
|
||||||
|
# Verify coordinator has fallback data
|
||||||
|
assert coordinator.data is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_persistence(self, coordinator: TibberPricesDataUpdateCoordinator) -> None:
|
||||||
|
"""Test that data is properly cached and persisted."""
|
||||||
|
# Mock API responses
|
||||||
|
mock_price_response = {
|
||||||
|
"homes": {"test_home_id_123": {"priceInfo": {"today": [], "tomorrow": [], "yesterday": []}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
coordinator.api.async_get_price_info = AsyncMock(return_value=mock_price_response)
|
||||||
|
coordinator.api.async_get_hourly_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}})
|
||||||
|
coordinator.api.async_get_daily_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}})
|
||||||
|
coordinator.api.async_get_monthly_price_rating = AsyncMock(return_value={"homes": {"test_home_id_123": {}}})
|
||||||
|
|
||||||
|
# Update the coordinator
|
||||||
|
await coordinator.async_refresh()
|
||||||
|
|
||||||
|
# Verify data was cached (store should have been called)
|
||||||
|
coordinator._store.async_save.assert_called()
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
import unittest
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
|
|
||||||
class TestReauthentication(unittest.TestCase):
|
|
||||||
@patch("your_module.connection") # Replace 'your_module' with the actual module name
|
|
||||||
def test_reauthentication_flow(self, mock_connection):
|
|
||||||
mock_connection.reauthenticate = Mock(return_value=True)
|
|
||||||
result = mock_connection.reauthenticate()
|
|
||||||
self.assertTrue(result)
|
|
||||||
|
|
||||||
@patch("your_module.connection") # Replace 'your_module' with the actual module name
|
|
||||||
def test_connection_timeout(self, mock_connection):
|
|
||||||
mock_connection.connect = Mock(side_effect=TimeoutError)
|
|
||||||
with self.assertRaises(TimeoutError):
|
|
||||||
mock_connection.connect()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Loading…
Reference in a new issue