This commit is contained in:
Julian Pawlowski 2025-05-24 13:59:46 +00:00
parent 130b51f5b6
commit 86c8073a51
2 changed files with 70 additions and 106 deletions

View file

@ -6,7 +6,7 @@ import asyncio
import logging
import socket
from datetime import timedelta
from enum import Enum, auto
from enum import Enum
from typing import Any
import aiohttp
@ -24,13 +24,6 @@ HTTP_UNAUTHORIZED = 401
HTTP_FORBIDDEN = 403
class TransformMode(Enum):
"""Data transformation mode."""
TRANSFORM = auto() # Transform price info data
SKIP = auto() # Return raw data without transformation
class QueryType(Enum):
"""Types of queries that can be made to the API."""
@ -248,44 +241,16 @@ def _prepare_headers(access_token: str) -> dict[str, str]:
}
def _transform_data(data: dict, query_type: QueryType) -> dict:
"""Transform API response data based on query type."""
if not data or "viewer" not in data:
_LOGGER.debug("No data to transform or missing viewer key")
return data
_LOGGER.debug("Starting data transformation for query type %s", query_type)
if query_type == QueryType.PRICE_INFO:
return _transform_price_info(data)
if query_type in (
QueryType.DAILY_RATING,
QueryType.HOURLY_RATING,
QueryType.MONTHLY_RATING,
):
return data
if query_type == QueryType.VIEWER:
return data
_LOGGER.warning("Unknown query type %s, returning raw data", query_type)
return data
def _transform_price_info(data: dict) -> dict:
"""Transform the price info data structure."""
if not data or "viewer" not in data:
_LOGGER.debug("No data to transform or missing viewer key")
return data
_LOGGER.debug("Starting price info transformation")
price_info = data["viewer"]["homes"][0]["currentSubscription"]["priceInfo"]
def _flatten_price_info(subscription: dict) -> dict:
"""Transform and flatten priceInfo from full API data structure."""
price_info = subscription.get("priceInfo", {})
# Get today and yesterday dates using Home Assistant's dt_util
today_local = dt_util.now().date()
yesterday_local = today_local - timedelta(days=1)
_LOGGER.debug("Processing data for yesterday's date: %s", yesterday_local)
# Transform edges data
# Transform edges data (extract yesterday's prices)
if "range" in price_info and "edges" in price_info["range"]:
edges = price_info["range"]["edges"]
yesterday_prices = []
@ -315,12 +280,6 @@ def _transform_price_info(data: dict) -> dict:
price_info["yesterday"] = yesterday_prices
del price_info["range"]
return data
def _flatten_price_info(subscription: dict) -> dict:
"""Extract and flatten priceInfo from subscription."""
price_info = subscription.get("priceInfo", {})
return {
"yesterday": price_info.get("yesterday", []),
"today": price_info.get("today", []),
@ -538,7 +497,7 @@ class TibberPricesApiClient:
await _verify_graphql_response(response_json, query_type)
return _transform_data(response_json["data"], query_type)
return response_json["data"]
async def _handle_request(
self,

View file

@ -167,6 +167,36 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict]):
async def _async_update_data(self) -> dict:
"""Fetch new state data for the coordinator. Handles expired credentials by raising ConfigEntryAuthFailed."""
if self._cached_price_data is None:
await self._handle_initialization()
try:
current_time = dt_util.now()
if self._force_update:
LOGGER.debug(
"Force updating data",
extra={
"reason": "force_update",
"last_success": self.last_update_success,
"last_price_update": self._last_price_update,
"last_rating_updates": {
"hourly": self._last_rating_update_hourly,
"daily": self._last_rating_update_daily,
"monthly": self._last_rating_update_monthly,
},
},
)
self._force_update = False
return await self._fetch_all_data()
return await self._handle_conditional_update(current_time)
except (
TibberPricesApiClientAuthenticationError,
TimeoutError,
TibberPricesApiClientCommunicationError,
TibberPricesApiClientError,
) as exception:
return await self._handle_update_exception(exception)
async def _handle_initialization(self) -> None:
"""Handle initialization and related errors for cached price data."""
try:
await self._async_initialize()
except TimeoutError as exception:
@ -193,35 +223,17 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict]):
extra={"error": str(exception), "error_type": "unexpected_init"},
)
raise UpdateFailed(msg) from exception
try:
current_time = dt_util.now()
result = None
if self._force_update:
LOGGER.debug(
"Force updating data",
extra={
"reason": "force_update",
"last_success": self.last_update_success,
"last_price_update": self._last_price_update,
"last_rating_updates": {
"hourly": self._last_rating_update_hourly,
"daily": self._last_rating_update_daily,
"monthly": self._last_rating_update_monthly,
},
},
)
self._force_update = False
result = await self._fetch_all_data()
else:
result = await self._handle_conditional_update(current_time)
except TibberPricesApiClientAuthenticationError as exception:
async def _handle_update_exception(self, exception: Exception) -> dict:
"""Handle exceptions during update and return fallback or raise."""
if isinstance(exception, TibberPricesApiClientAuthenticationError):
msg = "Authentication failed: credentials expired or invalid"
LOGGER.error(
"Authentication failed (likely expired credentials)",
extra={"error": str(exception), "error_type": "auth_failed"},
)
raise ConfigEntryAuthFailed(msg) from exception
except TimeoutError as exception:
if isinstance(exception, TimeoutError):
msg = "Timeout during data update"
LOGGER.warning(
"%s: %s",
@ -233,11 +245,6 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict]):
LOGGER.info("Using cached data as fallback after timeout")
return self._merge_all_cached_data()
raise UpdateFailed(msg) from exception
except (
TibberPricesApiClientCommunicationError,
TibberPricesApiClientError,
Exception,
) as exception:
if isinstance(exception, TibberPricesApiClientCommunicationError):
LOGGER.error(
"API communication error",
@ -260,8 +267,6 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict]):
LOGGER.info("Using cached data as fallback")
return self._merge_all_cached_data()
raise UpdateFailed(UPDATE_FAILED_MSG) from exception
else:
return result
async def _handle_conditional_update(self, current_time: datetime) -> dict:
"""Handle conditional update based on update conditions."""