Add data retrieval

This commit is contained in:
Julian Pawlowski 2025-04-18 21:14:36 +00:00
parent 5f8abf3a63
commit f092ad2839
4 changed files with 645 additions and 95 deletions

View file

@ -2,13 +2,42 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging
import socket import socket
from datetime import UTC, datetime, timedelta
from enum import Enum, auto
from typing import Any from typing import Any
import aiohttp import aiohttp
import async_timeout import async_timeout
from homeassistant.const import __version__ as ha_version from homeassistant.const import __version__ as ha_version
from .const import VERSION
_LOGGER = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS = 429
HTTP_UNAUTHORIZED = 401
HTTP_FORBIDDEN = 403
class TransformMode(Enum):
    """How an API payload should be post-processed.

    NOTE(review): not referenced by the visible code (dispatch uses
    QueryType) — confirm whether this enum is used elsewhere.
    """

    TRANSFORM = auto()  # Transform price info data
    SKIP = auto()  # Return raw data without transformation
class QueryType(Enum):
    """Types of queries that can be made to the API.

    The string values double as the keys used when checking response
    emptiness and formatting error messages.
    """

    PRICE_INFO = "price_info"  # today/tomorrow/last-48h prices
    DAILY_RATING = "daily"  # daily price rating entries
    HOURLY_RATING = "hourly"  # hourly price rating entries
    MONTHLY_RATING = "monthly"  # monthly price rating entries
    TEST = "test"  # connection test, skips empty-data validation
class TibberPricesApiClientError(Exception): class TibberPricesApiClientError(Exception):
"""Exception to indicate a general API error.""" """Exception to indicate a general API error."""
@ -16,50 +45,37 @@ class TibberPricesApiClientError(Exception):
UNKNOWN_ERROR = "Unknown GraphQL error" UNKNOWN_ERROR = "Unknown GraphQL error"
MALFORMED_ERROR = "Malformed GraphQL error: {error}" MALFORMED_ERROR = "Malformed GraphQL error: {error}"
GRAPHQL_ERROR = "GraphQL error: {message}" GRAPHQL_ERROR = "GraphQL error: {message}"
EMPTY_DATA_ERROR = "Empty data received for {query_type}"
GENERIC_ERROR = "Something went wrong! {exception}" GENERIC_ERROR = "Something went wrong! {exception}"
RATE_LIMIT_ERROR = "Rate limit exceeded"
class TibberPricesApiClientCommunicationError( class TibberPricesApiClientCommunicationError(TibberPricesApiClientError):
TibberPricesApiClientError,
):
"""Exception to indicate a communication error.""" """Exception to indicate a communication error."""
TIMEOUT_ERROR = "Timeout error fetching information - {exception}" TIMEOUT_ERROR = "Timeout error fetching information - {exception}"
CONNECTION_ERROR = "Error fetching information - {exception}" CONNECTION_ERROR = "Error fetching information - {exception}"
class TibberPricesApiClientAuthenticationError( class TibberPricesApiClientAuthenticationError(TibberPricesApiClientError):
TibberPricesApiClientError,
):
"""Exception to indicate an authentication error.""" """Exception to indicate an authentication error."""
INVALID_CREDENTIALS = "Invalid credentials"
def _verify_response_or_raise(response: aiohttp.ClientResponse) -> None:
    """Verify that the HTTP response is valid.

    Raises:
        TibberPricesApiClientAuthenticationError: on 401/403 responses.
        TibberPricesApiClientError: on 429 (rate limit exceeded).
        aiohttp.ClientResponseError: for any other non-success status
            (via raise_for_status).
    """
    if response.status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN):
        raise TibberPricesApiClientAuthenticationError(
            TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS
        )
    if response.status == HTTP_TOO_MANY_REQUESTS:
        raise TibberPricesApiClientError(TibberPricesApiClientError.RATE_LIMIT_ERROR)
    response.raise_for_status()
async def _verify_graphql_response(response_json: dict) -> None: async def _verify_graphql_response(response_json: dict) -> None:
""" """Verify the GraphQL response for errors and data completeness."""
Verify the GraphQL response for errors.
GraphQL errors follow this structure:
{
"errors": [{
"message": "error message",
"locations": [...],
"path": [...],
"extensions": {
"code": "ERROR_CODE"
}
}]
}
"""
if "errors" in response_json: if "errors" in response_json:
errors = response_json["errors"] errors = response_json["errors"]
if not errors: if not errors:
@ -74,11 +90,11 @@ async def _verify_graphql_response(response_json: dict) -> None:
message = error.get("message", "Unknown error") message = error.get("message", "Unknown error")
extensions = error.get("extensions", {}) extensions = error.get("extensions", {})
# Check for authentication errors first
if extensions.get("code") == "UNAUTHENTICATED": if extensions.get("code") == "UNAUTHENTICATED":
raise TibberPricesApiClientAuthenticationError(message) raise TibberPricesApiClientAuthenticationError(
TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS
)
# Handle all other GraphQL errors
raise TibberPricesApiClientError( raise TibberPricesApiClientError(
TibberPricesApiClientError.GRAPHQL_ERROR.format(message=message) TibberPricesApiClientError.GRAPHQL_ERROR.format(message=message)
) )
@ -91,6 +107,134 @@ async def _verify_graphql_response(response_json: dict) -> None:
) )
def _is_data_empty(data: dict, query_type: str) -> bool:
    """Return True when the API response lacks the data the query should contain.

    Args:
        data: The (already transformed) GraphQL ``data`` payload.
        query_type: One of ``"price_info"``, ``"daily"``, ``"hourly"`` or
            ``"monthly"``; any other value is treated as non-empty.

    Structural problems (missing keys, empty homes list, wrong types) are
    treated as empty data and return True.
    """
    _LOGGER.debug("Checking if data is empty for query_type %s", query_type)
    try:
        subscription = data["viewer"]["homes"][0]["currentSubscription"]
        if query_type == "price_info":
            price_info = subscription["priceInfo"]
            # Either the raw "range" edges or the already-transformed
            # "yesterday" list counts as historical data.
            has_range = "range" in price_info and price_info["range"]["edges"]
            has_yesterday = "yesterday" in price_info and price_info["yesterday"]
            has_historical = has_range or has_yesterday
            is_empty = not has_historical or not price_info["today"]
            _LOGGER.debug(
                "Price info check - historical data: %s, today: %s, is_empty: %s",
                bool(has_historical),
                bool(price_info["today"]),
                is_empty,
            )
            return is_empty
        if query_type in ("daily", "hourly", "monthly"):
            rating = subscription["priceRating"]
            if not rating["thresholdPercentages"]:
                _LOGGER.debug(
                    "Missing threshold percentages for %s rating", query_type
                )
                return True
            entries = rating[query_type]["entries"]
            # "not entries" already covers both None and an empty list.
            is_empty = not entries
            _LOGGER.debug(
                "%s rating check - entries count: %d, is_empty: %s",
                query_type,
                len(entries) if entries else 0,
                is_empty,
            )
            return is_empty
    except (KeyError, IndexError, TypeError) as error:
        # Malformed payload: we cannot trust it, so report it as empty.
        _LOGGER.debug("Error checking data emptiness: %s", error)
        return True
    # Unknown query types are passed through as non-empty.  (The previous
    # version had an unreachable "else: return False" clause after a try
    # block whose every path returned; it has been removed.)
    _LOGGER.debug("Unknown query type %s, treating as non-empty", query_type)
    return False
def _prepare_headers(access_token: str) -> dict[str, str]:
    """Prepare HTTP headers for an API request.

    The User-Agent identifies both the Home Assistant core version and this
    integration's version, which helps the API operator attribute traffic.
    """
    return {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/json",
        "User-Agent": f"HomeAssistant/{ha_version} tibber_prices/{VERSION}",
    }
def _transform_data(data: dict, query_type: QueryType) -> dict:
    """Transform API response data based on query type.

    Only PRICE_INFO responses need reshaping (see _transform_price_info);
    every other known query type is returned untouched.  Unknown types are
    passed through with a warning.
    """
    if not data or "viewer" not in data:
        _LOGGER.debug("No data to transform or missing viewer key")
        return data
    _LOGGER.debug("Starting data transformation for query type %s", query_type)
    if query_type == QueryType.PRICE_INFO:
        return _transform_price_info(data)
    if query_type in (
        QueryType.DAILY_RATING,
        QueryType.HOURLY_RATING,
        QueryType.MONTHLY_RATING,
        QueryType.TEST,
    ):
        # Rating and test payloads are already in their final shape.
        # (Previously these were two separate branches with identical bodies.)
        return data
    _LOGGER.warning("Unknown query type %s, returning raw data", query_type)
    return data
def _transform_price_info(data: dict) -> dict:
    """Transform the price info data structure in place.

    Replaces the ``range`` edge list with a flat ``yesterday`` list that
    contains only the price entries whose local date is yesterday.  Data
    without a ``range`` object is returned unchanged.
    """
    if not data or "viewer" not in data:
        _LOGGER.debug("No data to transform or missing viewer key")
        return data
    _LOGGER.debug("Starting price info transformation")
    price_info = data["viewer"]["homes"][0]["currentSubscription"]["priceInfo"]
    # Determine yesterday's date in local time (UTC now -> local date).
    today_utc = datetime.now(tz=UTC)
    today_local = today_utc.astimezone().date()
    yesterday_local = today_local - timedelta(days=1)
    _LOGGER.debug("Processing data for yesterday's date: %s", yesterday_local)
    if "range" in price_info and "edges" in price_info["range"]:
        yesterday_prices = []
        for edge in price_info["range"]["edges"]:
            if "node" not in edge:
                _LOGGER.debug("Skipping edge without node: %s", edge)
                continue
            price_data = edge["node"]
            starts_at = datetime.fromisoformat(price_data["startsAt"])
            if starts_at.tzinfo is None:
                _LOGGER.debug(
                    "Found naive timestamp, treating as local time: %s", starts_at
                )
            # astimezone() both localises naive timestamps and converts aware
            # ones to local time, so a single call covers both cases (the
            # previous version duplicated it across identical if/else arms).
            starts_at = starts_at.astimezone()
            # Only include prices from yesterday.
            if starts_at.date() == yesterday_local:
                yesterday_prices.append(price_data)
        _LOGGER.debug("Found %d price entries for yesterday", len(yesterday_prices))
        # Replace the entire range object with the filtered yesterday list.
        price_info["yesterday"] = yesterday_prices
        del price_info["range"]
    return data
class TibberPricesApiClient: class TibberPricesApiClient:
"""Tibber API Client.""" """Tibber API Client."""
@ -102,6 +246,11 @@ class TibberPricesApiClient:
"""Tibber API Client.""" """Tibber API Client."""
self._access_token = access_token self._access_token = access_token
self._session = session self._session = session
self._request_semaphore = asyncio.Semaphore(2)
self._last_request_time = datetime.now(tz=UTC)
self._min_request_interval = timedelta(seconds=1)
self._max_retries = 3
self._retry_delay = 2
async def async_test_connection(self) -> Any: async def async_test_connection(self) -> Any:
"""Test connection to the API.""" """Test connection to the API."""
@ -114,27 +263,111 @@ class TibberPricesApiClient:
} }
} }
""" """
} },
query_type=QueryType.TEST,
) )
async def async_get_data(self) -> Any: async def async_get_price_info(self) -> Any:
"""Get data from the API.""" """Get price info data including today, tomorrow and last 48 hours."""
return await self._api_wrapper( return await self._api_wrapper(
data={ data={
"query": """ "query": """
query { {viewer{homes{currentSubscription{priceInfo{
viewer { range(resolution:HOURLY,last:48){edges{node{
homes { startsAt total energy tax level
timeZone }}}
currentSubscription { today{startsAt total energy tax level}
status tomorrow{startsAt total energy tax level}
} }}}}}"""
},
query_type=QueryType.PRICE_INFO,
)
async def async_get_daily_price_rating(self) -> Any:
    """Get daily price rating data.

    Returns the GraphQL ``data`` payload containing ``priceRating`` with
    ``thresholdPercentages`` and daily ``entries``.
    """
    return await self._api_wrapper(
        data={
            "query": """
                {viewer{homes{currentSubscription{priceRating{
                thresholdPercentages{low high}
                daily{entries{time total energy tax difference level}}
                }}}}}"""
        },
        query_type=QueryType.DAILY_RATING,
    )
async def async_get_hourly_price_rating(self) -> Any:
    """Get hourly price rating data.

    Returns the GraphQL ``data`` payload containing ``priceRating`` with
    ``thresholdPercentages`` and hourly ``entries``.
    """
    return await self._api_wrapper(
        data={
            "query": """
                {viewer{homes{currentSubscription{priceRating{
                thresholdPercentages{low high}
                hourly{entries{time total energy tax difference level}}
                }}}}}"""
        },
        query_type=QueryType.HOURLY_RATING,
    )
async def async_get_monthly_price_rating(self) -> Any:
    """Get monthly price rating data.

    Returns the GraphQL ``data`` payload containing ``priceRating`` with
    ``thresholdPercentages`` and monthly ``entries`` (including currency).
    """
    return await self._api_wrapper(
        data={
            "query": """
                {viewer{homes{currentSubscription{priceRating{
                thresholdPercentages{low high}
                monthly{
                currency
                entries{time total energy tax difference level}
                }
                }}}}}"""
        },
        query_type=QueryType.MONTHLY_RATING,
    )
async def async_get_data(self) -> Any:
    """Get all data from the API by combining multiple queries.

    Returns a dict shaped like a raw GraphQL response (wrapped in a
    top-level "data" key) merging price info with the daily, hourly and
    monthly price ratings.

    Raises:
        TibberPricesApiClientError: if any sub-query fails or returns
            empty data.
    """
    # Fetch all four queries truly concurrently (the previous code claimed
    # "concurrently" but awaited sequentially).  The request semaphore and
    # minimum-interval logic in _handle_request still throttle the calls.
    price_info, daily_rating, hourly_rating, monthly_rating = await asyncio.gather(
        self.async_get_price_info(),
        self.async_get_daily_price_rating(),
        self.async_get_hourly_price_rating(),
        self.async_get_monthly_price_rating(),
    )

    def get_base_path(response: dict) -> dict:
        """Get the base subscription path from the response."""
        return response["viewer"]["homes"][0]["currentSubscription"]

    def get_rating_data(response: dict) -> dict:
        """Get the price rating data from the response."""
        return get_base_path(response)["priceRating"]

    price_info_data = get_base_path(price_info)["priceInfo"]

    # Combine the results into a single GraphQL-shaped payload.
    return {
        "data": {
            "viewer": {
                "homes": [
                    {
                        "currentSubscription": {
                            "priceInfo": price_info_data,
                            "priceRating": {
                                "thresholdPercentages": get_rating_data(
                                    daily_rating
                                )["thresholdPercentages"],
                                "daily": get_rating_data(daily_rating)["daily"],
                                "hourly": get_rating_data(hourly_rating)["hourly"],
                                "monthly": get_rating_data(monthly_rating)[
                                    "monthly"
                                ],
                            },
                        }
                    }
                ]
            }
        }
    }
async def async_set_title(self, value: str) -> Any: async def async_set_title(self, value: str) -> Any:
"""Get data from the API.""" """Get data from the API."""
@ -142,55 +375,132 @@ class TibberPricesApiClient:
data={"title": value}, data={"title": value},
) )
async def _make_request(
    self,
    headers: dict[str, str],
    data: dict,
    query_type: QueryType,
) -> dict:
    """Make a single POST to the Tibber GraphQL endpoint.

    Verifies the HTTP status and the GraphQL payload, then returns the
    transformed ``data`` object for the given query type.

    Raises:
        TibberPricesApiClientAuthenticationError: for 401/403 responses or
            UNAUTHENTICATED GraphQL errors.
        TibberPricesApiClientError: for rate limiting or other GraphQL errors.
    """
    _LOGGER.debug("Making API request with data: %s", data)
    response = await self._session.request(
        method="POST",
        url="https://api.tibber.com/v1-beta/gql",
        headers=headers,
        json=data,
    )
    _verify_response_or_raise(response)
    response_json = await response.json()
    _LOGGER.debug("Received API response: %s", response_json)
    await _verify_graphql_response(response_json)
    # Only the GraphQL "data" object is returned, reshaped per query type.
    return _transform_data(response_json["data"], query_type)
async def _handle_request(
    self,
    headers: dict[str, str],
    data: dict,
    query_type: QueryType,
) -> Any:
    """Handle a single API request with rate limiting.

    Concurrency is capped by the instance semaphore, consecutive requests
    are spaced by at least ``_min_request_interval``, and the request itself
    is bounded by a 10 second timeout.

    Raises:
        TibberPricesApiClientError: when a non-TEST query returns empty data
            (which lets the caller's retry logic attempt again).
    """
    async with self._request_semaphore:
        now = datetime.now(tz=UTC)
        time_since_last_request = now - self._last_request_time
        if time_since_last_request < self._min_request_interval:
            sleep_time = (
                self._min_request_interval - time_since_last_request
            ).total_seconds()
            _LOGGER.debug(
                "Rate limiting: waiting %s seconds before next request",
                sleep_time,
            )
            await asyncio.sleep(sleep_time)
        async with async_timeout.timeout(10):
            # Record the start time before the request so the interval is
            # measured between request starts, not completions.
            self._last_request_time = datetime.now(tz=UTC)
            response_data = await self._make_request(
                headers,
                data or {},
                query_type,
            )
            if query_type != QueryType.TEST and _is_data_empty(
                response_data, query_type.value
            ):
                _LOGGER.debug("Empty data detected for query_type: %s", query_type)
                raise TibberPricesApiClientError(
                    TibberPricesApiClientError.EMPTY_DATA_ERROR.format(
                        query_type=query_type.value
                    )
                )
            return response_data
async def _api_wrapper(
    self,
    data: dict | None = None,
    headers: dict | None = None,
    query_type: QueryType = QueryType.TEST,
) -> Any:
    """Get information from the API with rate limiting and retry logic.

    Retries transient failures up to ``_max_retries`` times with exponential
    backoff (``_retry_delay * 2**attempt`` seconds).

    Raises:
        TibberPricesApiClientAuthenticationError: immediately, never retried.
        TibberPricesApiClientCommunicationError: for timeouts and
            connection-level failures after retries are exhausted.
        TibberPricesApiClientError: for any other API failure.
    """
    headers = headers or _prepare_headers(self._access_token)
    last_error: Exception | None = None

    for retry in range(self._max_retries + 1):
        try:
            return await self._handle_request(
                headers,
                data or {},
                query_type,
            )
        except TibberPricesApiClientAuthenticationError:
            # Authentication problems will not resolve by retrying.
            raise
        except (
            aiohttp.ClientError,
            socket.gaierror,
            TimeoutError,
            TibberPricesApiClientError,
        ) as error:
            # Keep the ORIGINAL exception.  The previous version wrapped
            # non-client errors into TibberPricesApiClientError here, which
            # made the isinstance(TimeoutError/ClientError) dispatch below
            # unreachable dead code.
            last_error = error
            if retry < self._max_retries:
                delay = self._retry_delay * (2**retry)
                _LOGGER.warning(
                    "Request failed, attempt %d/%d. Retrying in %d seconds: %s",
                    retry + 1,
                    self._max_retries,
                    delay,
                    str(error),
                )
                await asyncio.sleep(delay)
                continue

    # Retries exhausted: translate the last raw error into a typed one.
    if isinstance(last_error, TimeoutError):
        raise TibberPricesApiClientCommunicationError(
            TibberPricesApiClientCommunicationError.TIMEOUT_ERROR.format(
                exception=last_error
            )
        ) from last_error
    if isinstance(last_error, (aiohttp.ClientError, socket.gaierror)):
        raise TibberPricesApiClientCommunicationError(
            TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(
                exception=last_error
            )
        ) from last_error
    if isinstance(last_error, TibberPricesApiClientError):
        raise last_error
    raise TibberPricesApiClientError(
        TibberPricesApiClientError.GENERIC_ERROR.format(exception=last_error)
    ) from last_error

View file

@ -31,7 +31,9 @@ class TibberPricesFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
_errors = {} _errors = {}
if user_input is not None: if user_input is not None:
try: try:
await self._test_credentials(access_token=user_input[CONF_ACCESS_TOKEN]) name = await self._test_credentials(
access_token=user_input[CONF_ACCESS_TOKEN]
)
except TibberPricesApiClientAuthenticationError as exception: except TibberPricesApiClientAuthenticationError as exception:
LOGGER.warning(exception) LOGGER.warning(exception)
_errors["base"] = "auth" _errors["base"] = "auth"
@ -42,15 +44,10 @@ class TibberPricesFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
LOGGER.exception(exception) LOGGER.exception(exception)
_errors["base"] = "unknown" _errors["base"] = "unknown"
else: else:
await self.async_set_unique_id( await self.async_set_unique_id(unique_id=slugify(name))
## Do NOT use this in production code
## The unique_id should never be something that can change
## https://developers.home-assistant.io/docs/config_entries_config_flow_handler#unique-ids
unique_id=slugify(user_input[CONF_ACCESS_TOKEN])
)
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
return self.async_create_entry( return self.async_create_entry(
title=user_input[CONF_ACCESS_TOKEN], title=name,
data=user_input, data=user_input,
) )
@ -73,10 +70,11 @@ class TibberPricesFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
errors=_errors, errors=_errors,
) )
async def _test_credentials(self, access_token: str) -> None: async def _test_credentials(self, access_token: str) -> str:
"""Validate credentials.""" """Validate credentials and return the user's name."""
client = TibberPricesApiClient( client = TibberPricesApiClient(
access_token=access_token, access_token=access_token,
session=async_create_clientsession(self.hass), session=async_create_clientsession(self.hass),
) )
await client.async_test_connection() result = await client.async_test_connection()
return result["viewer"]["name"]

View file

@ -4,5 +4,7 @@ from logging import Logger, getLogger
LOGGER: Logger = getLogger(__package__) LOGGER: Logger = getLogger(__package__)
NAME = "Tibber Price Information & Ratings"
VERSION = "0.1.0" # Must match version in manifest.json
DOMAIN = "tibber_prices" DOMAIN = "tibber_prices"
ATTRIBUTION = "Data provided by https://tibber.com/" ATTRIBUTION = "Data provided by https://tibber.com/"

View file

@ -2,31 +2,271 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any import asyncio
import random
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, cast
from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.storage import Store
from homeassistant.helpers.update_coordinator import (
DataUpdateCoordinator,
UpdateFailed,
)
from homeassistant.util import dt as dt_util
from .api import ( from .api import (
TibberPricesApiClientAuthenticationError, TibberPricesApiClientAuthenticationError,
TibberPricesApiClientError, TibberPricesApiClientError,
) )
from .const import DOMAIN, LOGGER
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.core import HomeAssistant
from .data import TibberPricesConfigEntry from .data import TibberPricesConfigEntry
PRICE_UPDATE_RANDOM_MIN_HOUR = 13 # Don't check before 13:00
PRICE_UPDATE_RANDOM_MAX_HOUR = 15 # Don't check after 15:00
PRICE_UPDATE_INTERVAL = timedelta(days=1)
RATING_UPDATE_INTERVAL = timedelta(hours=1)
NO_DATA_ERROR_MSG = "No data available"
STORAGE_VERSION = 1
STORAGE_KEY = f"{DOMAIN}.coordinator"
def _raise_no_data() -> None:
    """Raise TibberPricesApiClientError when no data is available.

    Helper so callers can keep an always-raising call in branches that
    otherwise return values (keeps linters happy about abstract raises).
    """
    raise TibberPricesApiClientError(NO_DATA_ERROR_MSG)
# https://developers.home-assistant.io/docs/integration_fetching_data#coordinated-single-api-poll-for-data-for-all-entities # https://developers.home-assistant.io/docs/integration_fetching_data#coordinated-single-api-poll-for-data-for-all-entities
class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator): class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator):
"""Class to manage fetching data from the API.""" """Class to manage fetching data from the API."""
config_entry: TibberPricesConfigEntry config_entry: TibberPricesConfigEntry
def __init__(
    self,
    hass: HomeAssistant,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Initialize coordinator with cache."""
    super().__init__(hass, *args, **kwargs)
    # Persistent store so cached data survives Home Assistant restarts.
    self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
    # In-memory caches of the last successful price/rating payloads.
    self._cached_price_data: dict | None = None
    self._cached_rating_data: dict | None = None
    self._last_price_update: datetime | None = None
    self._last_rating_update: datetime | None = None
    # NOTE(review): not assigned anywhere in the visible code — confirm
    # whether a scheduled-update task exists elsewhere or this is unused.
    self._scheduled_price_update: asyncio.Task | None = None
async def _async_initialize(self) -> None:
    """Load previously persisted cache data from storage, if any."""
    stored = await self._store.async_load()
    if stored:
        self._cached_price_data = stored.get("price_data")
        self._cached_rating_data = stored.get("rating_data")
        # Timestamps are persisted as ISO strings; parse them back.
        if last_price := stored.get("last_price_update"):
            self._last_price_update = dt_util.parse_datetime(last_price)
        if last_rating := stored.get("last_rating_update"):
            self._last_rating_update = dt_util.parse_datetime(last_rating)
        LOGGER.debug(
            "Loaded stored cache data - Price from: %s, Rating from: %s",
            self._last_price_update,
            self._last_rating_update,
        )
async def _async_update_data(self) -> Any:
    """Update data via library.

    Lazily loads the persisted cache on first run, then delegates to
    _update_all_data, translating API exceptions into the coordinator's
    auth / update-failed exceptions.
    """
    if self._cached_price_data is None:
        # First run after startup, load stored data.
        await self._async_initialize()
    try:
        data = await self._update_all_data()
    except TibberPricesApiClientAuthenticationError as exception:
        raise ConfigEntryAuthFailed(exception) from exception
    except TibberPricesApiClientError as exception:
        raise UpdateFailed(exception) from exception
    else:
        return data
async def _update_all_data(self) -> dict[str, Any]:
    """Update all data and manage cache.

    Price data refreshes at most daily (within the afternoon window, after
    a random delay), rating data at most hourly; otherwise cached data is
    merged and returned.

    Raises:
        TibberPricesApiClientError: when neither fresh nor cached data is
            available (via _raise_no_data).
    """
    current_time = dt_util.now()
    processed_data: dict[str, Any] | None = None
    is_initial_setup = self._cached_price_data is None

    # Handle price data update if needed
    if self._should_update_price_data(current_time):
        # Check if we're within the allowed time window for price updates
        # or if this is initial setup
        current_hour = current_time.hour
        if is_initial_setup or self._is_price_update_window(current_hour):
            # Add random delay only for regular updates, not initial setup
            if not is_initial_setup:
                delay = random.randint(0, 120)  # noqa: S311
                # NOTE(review): this sleeps up to 120 minutes INSIDE the
                # coordinator refresh, blocking it for that long — confirm
                # this is intended rather than a scheduled callback.
                LOGGER.debug(
                    "Adding random delay of %d minutes before price update",
                    delay,
                )
                await asyncio.sleep(delay * 60)
            # Get fresh price data
            data = await self._fetch_price_data()
            self._cached_price_data = self._extract_price_data(data)
            self._last_price_update = current_time
            await self._store_cache()
            LOGGER.debug("Updated price data cache at %s", current_time)
            # NOTE(review): this is the raw client payload (no "data"
            # envelope), while the rating and cached branches below produce
            # an enveloped dict — consumers may see inconsistent shapes.
            processed_data = data

    # Handle rating data update if needed
    if self._should_update_rating_data(current_time):
        rating_data = await self._get_rating_data()
        self._cached_rating_data = self._extract_rating_data(rating_data)
        self._last_rating_update = current_time
        await self._store_cache()
        LOGGER.debug("Updated rating data cache at %s", current_time)
        processed_data = rating_data

    # If we have cached data but no updates were needed
    if (
        processed_data is None
        and self._cached_price_data
        and self._cached_rating_data
    ):
        LOGGER.debug(
            "Using cached data - Price from: %s, Rating from: %s",
            self._last_price_update,
            self._last_rating_update,
        )
        processed_data = self._merge_cached_data()

    if processed_data is None:
        _raise_no_data()

    return cast("dict[str, Any]", processed_data)
async def _store_cache(self) -> None:
    """Persist the in-memory cache (payloads and timestamps) to storage."""
    payload = {
        "price_data": self._cached_price_data,
        "rating_data": self._cached_rating_data,
        # Timestamps are stored as ISO strings; None when never updated.
        "last_price_update": (
            self._last_price_update.isoformat() if self._last_price_update else None
        ),
        "last_rating_update": (
            self._last_rating_update.isoformat()
            if self._last_rating_update
            else None
        ),
    }
    await self._store.async_save(payload)
def _should_update_price_data(self, current_time: datetime) -> bool:
    """Return True when cached price data is missing or at least a day old."""
    if self._cached_price_data is None or self._last_price_update is None:
        return True
    age = current_time - self._last_price_update
    return age >= PRICE_UPDATE_INTERVAL
def _should_update_rating_data(self, current_time: datetime) -> bool:
    """Return True when cached rating data is missing or at least an hour old."""
    if self._cached_rating_data is None or self._last_rating_update is None:
        return True
    age = current_time - self._last_rating_update
    return age >= RATING_UPDATE_INTERVAL
def _is_price_update_window(self, current_hour: int) -> bool:
    """Return True when *current_hour* falls inside the allowed update window."""
    if current_hour < PRICE_UPDATE_RANDOM_MIN_HOUR:
        return False
    return current_hour <= PRICE_UPDATE_RANDOM_MAX_HOUR
async def _fetch_price_data(self) -> dict:
    """Fetch fresh price data from the API client (no "data" envelope)."""
    client = self.config_entry.runtime_data.client
    return await client.async_get_price_info()
def _extract_price_data(self, data: dict) -> dict:
    """Wrap just the priceInfo part of *data* in a cacheable envelope."""
    subscription = data["viewer"]["homes"][0]["currentSubscription"]
    home = {"currentSubscription": {"priceInfo": subscription["priceInfo"]}}
    return {"data": {"viewer": {"homes": [home]}}}
def _extract_rating_data(self, data: dict) -> dict:
    """Wrap just the priceRating part of *data* in a cacheable envelope.

    Unlike _extract_price_data, the input here is already "data"-enveloped
    (it comes from _get_rating_data).
    """
    subscription = data["data"]["viewer"]["homes"][0]["currentSubscription"]
    home = {"currentSubscription": {"priceRating": subscription["priceRating"]}}
    return {"data": {"viewer": {"homes": [home]}}}
def _merge_cached_data(self) -> dict:
    """Combine cached price and rating payloads into one enveloped dict.

    Returns an empty dict when either cache is missing.
    """
    if not self._cached_price_data or not self._cached_rating_data:
        return {}
    price_sub = self._cached_price_data["data"]["viewer"]["homes"][0][
        "currentSubscription"
    ]
    rating_sub = self._cached_rating_data["data"]["viewer"]["homes"][0][
        "currentSubscription"
    ]
    merged = {
        "priceInfo": price_sub["priceInfo"],
        "priceRating": rating_sub["priceRating"],
    }
    return {"data": {"viewer": {"homes": [{"currentSubscription": merged}]}}}
async def _get_rating_data(self) -> dict:
    """Get fresh rating data from the API.

    Performs three queries (daily, hourly, monthly) and merges them into a
    single "data"-enveloped dict; thresholdPercentages are taken from the
    daily response, which the API returns identically for all resolutions.
    """
    client = self.config_entry.runtime_data.client
    daily = await client.async_get_daily_price_rating()
    hourly = await client.async_get_hourly_price_rating()
    monthly = await client.async_get_monthly_price_rating()
    # Client responses carry no "data" envelope (the client already
    # unwrapped it), hence the direct ["viewer"] indexing here.
    rating_base = daily["viewer"]["homes"][0]["currentSubscription"]["priceRating"]
    return {
        "data": {
            "viewer": {
                "homes": [
                    {
                        "currentSubscription": {
                            "priceRating": {
                                "thresholdPercentages": rating_base[
                                    "thresholdPercentages"
                                ],
                                "daily": rating_base["daily"],
                                "hourly": hourly["viewer"]["homes"][0][
                                    "currentSubscription"
                                ]["priceRating"]["hourly"],
                                "monthly": monthly["viewer"]["homes"][0][
                                    "currentSubscription"
                                ]["priceRating"]["monthly"],
                            }
                        }
                    }
                ]
            }
        }
    }