mirror of
https://github.com/jpawlowski/hass.tibber_prices.git
synced 2026-03-29 21:03:40 +00:00
add data retrieving
This commit is contained in:
parent
5f8abf3a63
commit
f092ad2839
4 changed files with 645 additions and 95 deletions
|
|
@ -2,13 +2,42 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import async_timeout
|
||||
from homeassistant.const import __version__ as ha_version
|
||||
|
||||
from .const import VERSION
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
HTTP_TOO_MANY_REQUESTS = 429
|
||||
HTTP_UNAUTHORIZED = 401
|
||||
HTTP_FORBIDDEN = 403
|
||||
|
||||
|
||||
class TransformMode(Enum):
|
||||
"""Data transformation mode."""
|
||||
|
||||
TRANSFORM = auto() # Transform price info data
|
||||
SKIP = auto() # Return raw data without transformation
|
||||
|
||||
|
||||
class QueryType(Enum):
|
||||
"""Types of queries that can be made to the API."""
|
||||
|
||||
PRICE_INFO = "price_info"
|
||||
DAILY_RATING = "daily"
|
||||
HOURLY_RATING = "hourly"
|
||||
MONTHLY_RATING = "monthly"
|
||||
TEST = "test"
|
||||
|
||||
|
||||
class TibberPricesApiClientError(Exception):
|
||||
"""Exception to indicate a general API error."""
|
||||
|
|
@ -16,50 +45,37 @@ class TibberPricesApiClientError(Exception):
|
|||
UNKNOWN_ERROR = "Unknown GraphQL error"
|
||||
MALFORMED_ERROR = "Malformed GraphQL error: {error}"
|
||||
GRAPHQL_ERROR = "GraphQL error: {message}"
|
||||
EMPTY_DATA_ERROR = "Empty data received for {query_type}"
|
||||
GENERIC_ERROR = "Something went wrong! {exception}"
|
||||
RATE_LIMIT_ERROR = "Rate limit exceeded"
|
||||
|
||||
|
||||
class TibberPricesApiClientCommunicationError(
|
||||
TibberPricesApiClientError,
|
||||
):
|
||||
class TibberPricesApiClientCommunicationError(TibberPricesApiClientError):
|
||||
"""Exception to indicate a communication error."""
|
||||
|
||||
TIMEOUT_ERROR = "Timeout error fetching information - {exception}"
|
||||
CONNECTION_ERROR = "Error fetching information - {exception}"
|
||||
|
||||
|
||||
class TibberPricesApiClientAuthenticationError(
|
||||
TibberPricesApiClientError,
|
||||
):
|
||||
class TibberPricesApiClientAuthenticationError(TibberPricesApiClientError):
|
||||
"""Exception to indicate an authentication error."""
|
||||
|
||||
INVALID_CREDENTIALS = "Invalid credentials"
|
||||
|
||||
|
||||
def _verify_response_or_raise(response: aiohttp.ClientResponse) -> None:
|
||||
"""Verify that the response is valid."""
|
||||
if response.status in (401, 403):
|
||||
msg = "Invalid credentials"
|
||||
if response.status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN):
|
||||
raise TibberPricesApiClientAuthenticationError(
|
||||
msg,
|
||||
TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS
|
||||
)
|
||||
if response.status == HTTP_TOO_MANY_REQUESTS:
|
||||
raise TibberPricesApiClientError(TibberPricesApiClientError.RATE_LIMIT_ERROR)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
async def _verify_graphql_response(response_json: dict) -> None:
|
||||
"""
|
||||
Verify the GraphQL response for errors.
|
||||
|
||||
GraphQL errors follow this structure:
|
||||
{
|
||||
"errors": [{
|
||||
"message": "error message",
|
||||
"locations": [...],
|
||||
"path": [...],
|
||||
"extensions": {
|
||||
"code": "ERROR_CODE"
|
||||
}
|
||||
}]
|
||||
}
|
||||
"""
|
||||
"""Verify the GraphQL response for errors and data completeness."""
|
||||
if "errors" in response_json:
|
||||
errors = response_json["errors"]
|
||||
if not errors:
|
||||
|
|
@ -74,11 +90,11 @@ async def _verify_graphql_response(response_json: dict) -> None:
|
|||
message = error.get("message", "Unknown error")
|
||||
extensions = error.get("extensions", {})
|
||||
|
||||
# Check for authentication errors first
|
||||
if extensions.get("code") == "UNAUTHENTICATED":
|
||||
raise TibberPricesApiClientAuthenticationError(message)
|
||||
raise TibberPricesApiClientAuthenticationError(
|
||||
TibberPricesApiClientAuthenticationError.INVALID_CREDENTIALS
|
||||
)
|
||||
|
||||
# Handle all other GraphQL errors
|
||||
raise TibberPricesApiClientError(
|
||||
TibberPricesApiClientError.GRAPHQL_ERROR.format(message=message)
|
||||
)
|
||||
|
|
@ -91,6 +107,134 @@ async def _verify_graphql_response(response_json: dict) -> None:
|
|||
)
|
||||
|
||||
|
||||
def _is_data_empty(data: dict, query_type: str) -> bool:
|
||||
"""Check if the response data is empty or incomplete."""
|
||||
_LOGGER.debug("Checking if data is empty for query_type %s", query_type)
|
||||
|
||||
try:
|
||||
subscription = data["viewer"]["homes"][0]["currentSubscription"]
|
||||
|
||||
if query_type == "price_info":
|
||||
price_info = subscription["priceInfo"]
|
||||
# Check either range or yesterday, since we transform range into yesterday
|
||||
has_range = "range" in price_info and price_info["range"]["edges"]
|
||||
has_yesterday = "yesterday" in price_info and price_info["yesterday"]
|
||||
has_historical = has_range or has_yesterday
|
||||
is_empty = not has_historical or not price_info["today"]
|
||||
_LOGGER.debug(
|
||||
"Price info check - historical data: %s, today: %s, is_empty: %s",
|
||||
bool(has_historical),
|
||||
bool(price_info["today"]),
|
||||
is_empty,
|
||||
)
|
||||
return is_empty
|
||||
|
||||
if query_type in ["daily", "hourly", "monthly"]:
|
||||
rating = subscription["priceRating"]
|
||||
if not rating["thresholdPercentages"]:
|
||||
_LOGGER.debug("Missing threshold percentages for %s rating", query_type)
|
||||
return True
|
||||
entries = rating[query_type]["entries"]
|
||||
is_empty = not entries or len(entries) == 0
|
||||
_LOGGER.debug(
|
||||
"%s rating check - entries count: %d, is_empty: %s",
|
||||
query_type,
|
||||
len(entries) if entries else 0,
|
||||
is_empty,
|
||||
)
|
||||
return is_empty
|
||||
|
||||
_LOGGER.debug("Unknown query type %s, treating as non-empty", query_type)
|
||||
return False
|
||||
except (KeyError, IndexError, TypeError) as error:
|
||||
_LOGGER.debug("Error checking data emptiness: %s", error)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _prepare_headers(access_token: str) -> dict[str, str]:
|
||||
"""Prepare headers for API request."""
|
||||
return {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": f"HomeAssistant/{ha_version} tibber_prices/{VERSION}",
|
||||
}
|
||||
|
||||
|
||||
def _transform_data(data: dict, query_type: QueryType) -> dict:
|
||||
"""Transform API response data based on query type."""
|
||||
if not data or "viewer" not in data:
|
||||
_LOGGER.debug("No data to transform or missing viewer key")
|
||||
return data
|
||||
|
||||
_LOGGER.debug("Starting data transformation for query type %s", query_type)
|
||||
|
||||
if query_type == QueryType.PRICE_INFO:
|
||||
return _transform_price_info(data)
|
||||
if query_type in (
|
||||
QueryType.DAILY_RATING,
|
||||
QueryType.HOURLY_RATING,
|
||||
QueryType.MONTHLY_RATING,
|
||||
):
|
||||
return data
|
||||
if query_type == QueryType.TEST:
|
||||
return data
|
||||
|
||||
_LOGGER.warning("Unknown query type %s, returning raw data", query_type)
|
||||
return data
|
||||
|
||||
|
||||
def _transform_price_info(data: dict) -> dict:
|
||||
"""Transform the price info data structure."""
|
||||
if not data or "viewer" not in data:
|
||||
_LOGGER.debug("No data to transform or missing viewer key")
|
||||
return data
|
||||
|
||||
_LOGGER.debug("Starting price info transformation")
|
||||
price_info = data["viewer"]["homes"][0]["currentSubscription"]["priceInfo"]
|
||||
|
||||
# Get yesterday's date in UTC first, then convert to local for comparison
|
||||
today_utc = datetime.now(tz=UTC)
|
||||
today_local = today_utc.astimezone().date()
|
||||
yesterday_local = today_local - timedelta(days=1)
|
||||
_LOGGER.debug("Processing data for yesterday's date: %s", yesterday_local)
|
||||
|
||||
# Transform edges data
|
||||
if "range" in price_info and "edges" in price_info["range"]:
|
||||
edges = price_info["range"]["edges"]
|
||||
yesterday_prices = []
|
||||
|
||||
for edge in edges:
|
||||
if "node" not in edge:
|
||||
_LOGGER.debug("Skipping edge without node: %s", edge)
|
||||
continue
|
||||
|
||||
price_data = edge["node"]
|
||||
# First parse startsAt time, then handle timezone conversion
|
||||
starts_at = datetime.fromisoformat(price_data["startsAt"])
|
||||
if starts_at.tzinfo is None:
|
||||
_LOGGER.debug(
|
||||
"Found naive timestamp, treating as local time: %s", starts_at
|
||||
)
|
||||
starts_at = starts_at.astimezone()
|
||||
else:
|
||||
starts_at = starts_at.astimezone()
|
||||
|
||||
price_date = starts_at.date()
|
||||
|
||||
# Only include prices from yesterday
|
||||
if price_date == yesterday_local:
|
||||
yesterday_prices.append(price_data)
|
||||
|
||||
_LOGGER.debug("Found %d price entries for yesterday", len(yesterday_prices))
|
||||
# Replace the entire range object with yesterday prices
|
||||
price_info["yesterday"] = yesterday_prices
|
||||
del price_info["range"]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class TibberPricesApiClient:
|
||||
"""Tibber API Client."""
|
||||
|
||||
|
|
@ -102,6 +246,11 @@ class TibberPricesApiClient:
|
|||
"""Tibber API Client."""
|
||||
self._access_token = access_token
|
||||
self._session = session
|
||||
self._request_semaphore = asyncio.Semaphore(2)
|
||||
self._last_request_time = datetime.now(tz=UTC)
|
||||
self._min_request_interval = timedelta(seconds=1)
|
||||
self._max_retries = 3
|
||||
self._retry_delay = 2
|
||||
|
||||
async def async_test_connection(self) -> Any:
|
||||
"""Test connection to the API."""
|
||||
|
|
@ -114,27 +263,111 @@ class TibberPricesApiClient:
|
|||
}
|
||||
}
|
||||
"""
|
||||
}
|
||||
},
|
||||
query_type=QueryType.TEST,
|
||||
)
|
||||
|
||||
async def async_get_data(self) -> Any:
|
||||
"""Get data from the API."""
|
||||
async def async_get_price_info(self) -> Any:
|
||||
"""Get price info data including today, tomorrow and last 48 hours."""
|
||||
return await self._api_wrapper(
|
||||
data={
|
||||
"query": """
|
||||
query {
|
||||
viewer {
|
||||
homes {
|
||||
timeZone
|
||||
currentSubscription {
|
||||
status
|
||||
}
|
||||
{viewer{homes{currentSubscription{priceInfo{
|
||||
range(resolution:HOURLY,last:48){edges{node{
|
||||
startsAt total energy tax level
|
||||
}}}
|
||||
today{startsAt total energy tax level}
|
||||
tomorrow{startsAt total energy tax level}
|
||||
}}}}}"""
|
||||
},
|
||||
query_type=QueryType.PRICE_INFO,
|
||||
)
|
||||
|
||||
async def async_get_daily_price_rating(self) -> Any:
|
||||
"""Get daily price rating data."""
|
||||
return await self._api_wrapper(
|
||||
data={
|
||||
"query": """
|
||||
{viewer{homes{currentSubscription{priceRating{
|
||||
thresholdPercentages{low high}
|
||||
daily{entries{time total energy tax difference level}}
|
||||
}}}}}"""
|
||||
},
|
||||
query_type=QueryType.DAILY_RATING,
|
||||
)
|
||||
|
||||
async def async_get_hourly_price_rating(self) -> Any:
|
||||
"""Get hourly price rating data."""
|
||||
return await self._api_wrapper(
|
||||
data={
|
||||
"query": """
|
||||
{viewer{homes{currentSubscription{priceRating{
|
||||
thresholdPercentages{low high}
|
||||
hourly{entries{time total energy tax difference level}}
|
||||
}}}}}"""
|
||||
},
|
||||
query_type=QueryType.HOURLY_RATING,
|
||||
)
|
||||
|
||||
async def async_get_monthly_price_rating(self) -> Any:
|
||||
"""Get monthly price rating data."""
|
||||
return await self._api_wrapper(
|
||||
data={
|
||||
"query": """
|
||||
{viewer{homes{currentSubscription{priceRating{
|
||||
thresholdPercentages{low high}
|
||||
monthly{
|
||||
currency
|
||||
entries{time total energy tax difference level}
|
||||
}
|
||||
}}}}}"""
|
||||
},
|
||||
query_type=QueryType.MONTHLY_RATING,
|
||||
)
|
||||
|
||||
async def async_get_data(self) -> Any:
|
||||
"""Get all data from the API by combining multiple queries."""
|
||||
# Get all data concurrently
|
||||
price_info = await self.async_get_price_info()
|
||||
daily_rating = await self.async_get_daily_price_rating()
|
||||
hourly_rating = await self.async_get_hourly_price_rating()
|
||||
monthly_rating = await self.async_get_monthly_price_rating()
|
||||
|
||||
# Extract the base paths to make the code more readable
|
||||
def get_base_path(response: dict) -> dict:
|
||||
"""Get the base subscription path from the response."""
|
||||
return response["viewer"]["homes"][0]["currentSubscription"]
|
||||
|
||||
def get_rating_data(response: dict) -> dict:
|
||||
"""Get the price rating data from the response."""
|
||||
return get_base_path(response)["priceRating"]
|
||||
|
||||
price_info_data = get_base_path(price_info)["priceInfo"]
|
||||
|
||||
# Combine the results
|
||||
return {
|
||||
"data": {
|
||||
"viewer": {
|
||||
"homes": [
|
||||
{
|
||||
"currentSubscription": {
|
||||
"priceInfo": price_info_data,
|
||||
"priceRating": {
|
||||
"thresholdPercentages": get_rating_data(
|
||||
daily_rating
|
||||
)["thresholdPercentages"],
|
||||
"daily": get_rating_data(daily_rating)["daily"],
|
||||
"hourly": get_rating_data(hourly_rating)["hourly"],
|
||||
"monthly": get_rating_data(monthly_rating)[
|
||||
"monthly"
|
||||
],
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
""",
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async def async_set_title(self, value: str) -> Any:
|
||||
"""Get data from the API."""
|
||||
|
|
@ -142,55 +375,132 @@ class TibberPricesApiClient:
|
|||
data={"title": value},
|
||||
)
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
headers: dict[str, str],
|
||||
data: dict,
|
||||
query_type: QueryType,
|
||||
) -> dict:
|
||||
"""Make an API request with proper error handling."""
|
||||
_LOGGER.debug("Making API request with data: %s", data)
|
||||
|
||||
response = await self._session.request(
|
||||
method="POST",
|
||||
url="https://api.tibber.com/v1-beta/gql",
|
||||
headers=headers,
|
||||
json=data,
|
||||
)
|
||||
|
||||
_verify_response_or_raise(response)
|
||||
response_json = await response.json()
|
||||
_LOGGER.debug("Received API response: %s", response_json)
|
||||
|
||||
await _verify_graphql_response(response_json)
|
||||
|
||||
return _transform_data(response_json["data"], query_type)
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
headers: dict[str, str],
|
||||
data: dict,
|
||||
query_type: QueryType,
|
||||
) -> Any:
|
||||
"""Handle a single API request with rate limiting."""
|
||||
async with self._request_semaphore:
|
||||
now = datetime.now(tz=UTC)
|
||||
time_since_last_request = now - self._last_request_time
|
||||
if time_since_last_request < self._min_request_interval:
|
||||
sleep_time = (
|
||||
self._min_request_interval - time_since_last_request
|
||||
).total_seconds()
|
||||
_LOGGER.debug(
|
||||
"Rate limiting: waiting %s seconds before next request",
|
||||
sleep_time,
|
||||
)
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
async with async_timeout.timeout(10):
|
||||
self._last_request_time = datetime.now(tz=UTC)
|
||||
response_data = await self._make_request(
|
||||
headers,
|
||||
data or {},
|
||||
query_type,
|
||||
)
|
||||
|
||||
if query_type != QueryType.TEST and _is_data_empty(
|
||||
response_data, query_type.value
|
||||
):
|
||||
_LOGGER.debug("Empty data detected for query_type: %s", query_type)
|
||||
raise TibberPricesApiClientError(
|
||||
TibberPricesApiClientError.EMPTY_DATA_ERROR.format(
|
||||
query_type=query_type.value
|
||||
)
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
async def _api_wrapper(
|
||||
self,
|
||||
data: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
query_type: QueryType = QueryType.TEST,
|
||||
) -> Any:
|
||||
"""
|
||||
Get information from the API.
|
||||
"""Get information from the API with rate limiting and retry logic."""
|
||||
headers = headers or _prepare_headers(self._access_token)
|
||||
last_error: Exception | None = None
|
||||
|
||||
Returns the contents of the 'data' object from the GraphQL response.
|
||||
Raises an error if the response doesn't contain a 'data' object.
|
||||
"""
|
||||
try:
|
||||
async with async_timeout.timeout(10):
|
||||
headers = headers or {}
|
||||
if headers.get("Authorization") is None:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
if headers.get("Accept") is None:
|
||||
headers["Accept"] = "application/json"
|
||||
if headers.get("User-Agent") is None:
|
||||
headers["User-Agent"] = (
|
||||
f"HomeAssistant/{ha_version} (tibber_prices; +https://github.com/jpawlowski/hass.tibber_prices/)"
|
||||
for retry in range(self._max_retries + 1):
|
||||
try:
|
||||
return await self._handle_request(
|
||||
headers,
|
||||
data or {},
|
||||
query_type,
|
||||
)
|
||||
|
||||
except TibberPricesApiClientAuthenticationError:
|
||||
raise
|
||||
except (
|
||||
aiohttp.ClientError,
|
||||
socket.gaierror,
|
||||
TimeoutError,
|
||||
TibberPricesApiClientError,
|
||||
) as error:
|
||||
last_error = (
|
||||
error
|
||||
if isinstance(error, TibberPricesApiClientError)
|
||||
else TibberPricesApiClientError(
|
||||
TibberPricesApiClientError.GENERIC_ERROR.format(
|
||||
exception=str(error)
|
||||
)
|
||||
)
|
||||
response = await self._session.request(
|
||||
method="POST",
|
||||
url="https://api.tibber.com/v1-beta/gql",
|
||||
headers=headers,
|
||||
json=data,
|
||||
)
|
||||
_verify_response_or_raise(response)
|
||||
response_json = await response.json()
|
||||
await _verify_graphql_response(response_json)
|
||||
return response_json["data"]
|
||||
|
||||
except TimeoutError as exception:
|
||||
if retry < self._max_retries:
|
||||
delay = self._retry_delay * (2**retry)
|
||||
_LOGGER.warning(
|
||||
"Request failed, attempt %d/%d. Retrying in %d seconds: %s",
|
||||
retry + 1,
|
||||
self._max_retries,
|
||||
delay,
|
||||
str(error),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
# Handle final error state
|
||||
if isinstance(last_error, TimeoutError):
|
||||
raise TibberPricesApiClientCommunicationError(
|
||||
TibberPricesApiClientCommunicationError.TIMEOUT_ERROR.format(
|
||||
exception=last_error
|
||||
)
|
||||
) from last_error
|
||||
if isinstance(last_error, (aiohttp.ClientError, socket.gaierror)):
|
||||
raise TibberPricesApiClientCommunicationError(
|
||||
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(
|
||||
exception=exception
|
||||
exception=last_error
|
||||
)
|
||||
) from exception
|
||||
except (aiohttp.ClientError, socket.gaierror) as exception:
|
||||
raise TibberPricesApiClientCommunicationError(
|
||||
TibberPricesApiClientCommunicationError.CONNECTION_ERROR.format(
|
||||
exception=exception
|
||||
)
|
||||
) from exception
|
||||
except TibberPricesApiClientAuthenticationError:
|
||||
# Re-raise authentication errors directly
|
||||
raise
|
||||
except Exception as exception: # pylint: disable=broad-except
|
||||
raise TibberPricesApiClientError(
|
||||
TibberPricesApiClientError.GENERIC_ERROR.format(exception=exception)
|
||||
) from exception
|
||||
) from last_error
|
||||
|
||||
raise last_error or TibberPricesApiClientError(
|
||||
TibberPricesApiClientError.UNKNOWN_ERROR
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,9 @@ class TibberPricesFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
_errors = {}
|
||||
if user_input is not None:
|
||||
try:
|
||||
await self._test_credentials(access_token=user_input[CONF_ACCESS_TOKEN])
|
||||
name = await self._test_credentials(
|
||||
access_token=user_input[CONF_ACCESS_TOKEN]
|
||||
)
|
||||
except TibberPricesApiClientAuthenticationError as exception:
|
||||
LOGGER.warning(exception)
|
||||
_errors["base"] = "auth"
|
||||
|
|
@ -42,15 +44,10 @@ class TibberPricesFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
LOGGER.exception(exception)
|
||||
_errors["base"] = "unknown"
|
||||
else:
|
||||
await self.async_set_unique_id(
|
||||
## Do NOT use this in production code
|
||||
## The unique_id should never be something that can change
|
||||
## https://developers.home-assistant.io/docs/config_entries_config_flow_handler#unique-ids
|
||||
unique_id=slugify(user_input[CONF_ACCESS_TOKEN])
|
||||
)
|
||||
await self.async_set_unique_id(unique_id=slugify(name))
|
||||
self._abort_if_unique_id_configured()
|
||||
return self.async_create_entry(
|
||||
title=user_input[CONF_ACCESS_TOKEN],
|
||||
title=name,
|
||||
data=user_input,
|
||||
)
|
||||
|
||||
|
|
@ -73,10 +70,11 @@ class TibberPricesFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
errors=_errors,
|
||||
)
|
||||
|
||||
async def _test_credentials(self, access_token: str) -> None:
|
||||
"""Validate credentials."""
|
||||
async def _test_credentials(self, access_token: str) -> str:
|
||||
"""Validate credentials and return the user's name."""
|
||||
client = TibberPricesApiClient(
|
||||
access_token=access_token,
|
||||
session=async_create_clientsession(self.hass),
|
||||
)
|
||||
await client.async_test_connection()
|
||||
result = await client.async_test_connection()
|
||||
return result["viewer"]["name"]
|
||||
|
|
|
|||
|
|
@ -4,5 +4,7 @@ from logging import Logger, getLogger
|
|||
|
||||
LOGGER: Logger = getLogger(__package__)
|
||||
|
||||
NAME = "Tibber Price Information & Ratings"
|
||||
VERSION = "0.1.0" # Must match version in manifest.json
|
||||
DOMAIN = "tibber_prices"
|
||||
ATTRIBUTION = "Data provided by https://tibber.com/"
|
||||
|
|
|
|||
|
|
@ -2,31 +2,271 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.helpers.update_coordinator import (
|
||||
DataUpdateCoordinator,
|
||||
UpdateFailed,
|
||||
)
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .api import (
|
||||
TibberPricesApiClientAuthenticationError,
|
||||
TibberPricesApiClientError,
|
||||
)
|
||||
from .const import DOMAIN, LOGGER
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .data import TibberPricesConfigEntry
|
||||
|
||||
|
||||
PRICE_UPDATE_RANDOM_MIN_HOUR = 13 # Don't check before 13:00
|
||||
PRICE_UPDATE_RANDOM_MAX_HOUR = 15 # Don't check after 15:00
|
||||
PRICE_UPDATE_INTERVAL = timedelta(days=1)
|
||||
RATING_UPDATE_INTERVAL = timedelta(hours=1)
|
||||
NO_DATA_ERROR_MSG = "No data available"
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = f"{DOMAIN}.coordinator"
|
||||
|
||||
|
||||
def _raise_no_data() -> None:
|
||||
"""Raise error when no data is available."""
|
||||
raise TibberPricesApiClientError(NO_DATA_ERROR_MSG)
|
||||
|
||||
|
||||
# https://developers.home-assistant.io/docs/integration_fetching_data#coordinated-single-api-poll-for-data-for-all-entities
|
||||
class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
"""Class to manage fetching data from the API."""
|
||||
|
||||
config_entry: TibberPricesConfigEntry
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize coordinator with cache."""
|
||||
super().__init__(hass, *args, **kwargs)
|
||||
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
self._cached_price_data: dict | None = None
|
||||
self._cached_rating_data: dict | None = None
|
||||
self._last_price_update: datetime | None = None
|
||||
self._last_rating_update: datetime | None = None
|
||||
self._scheduled_price_update: asyncio.Task | None = None
|
||||
|
||||
async def _async_initialize(self) -> None:
|
||||
"""Load stored data."""
|
||||
stored = await self._store.async_load()
|
||||
if stored:
|
||||
self._cached_price_data = stored.get("price_data")
|
||||
self._cached_rating_data = stored.get("rating_data")
|
||||
if last_price := stored.get("last_price_update"):
|
||||
self._last_price_update = dt_util.parse_datetime(last_price)
|
||||
if last_rating := stored.get("last_rating_update"):
|
||||
self._last_rating_update = dt_util.parse_datetime(last_rating)
|
||||
LOGGER.debug(
|
||||
"Loaded stored cache data - Price from: %s, Rating from: %s",
|
||||
self._last_price_update,
|
||||
self._last_rating_update,
|
||||
)
|
||||
|
||||
async def _async_update_data(self) -> Any:
|
||||
"""Update data via library."""
|
||||
if self._cached_price_data is None:
|
||||
# First run after startup, load stored data
|
||||
await self._async_initialize()
|
||||
|
||||
try:
|
||||
return await self.config_entry.runtime_data.client.async_get_data()
|
||||
data = await self._update_all_data()
|
||||
except TibberPricesApiClientAuthenticationError as exception:
|
||||
raise ConfigEntryAuthFailed(exception) from exception
|
||||
except TibberPricesApiClientError as exception:
|
||||
raise UpdateFailed(exception) from exception
|
||||
else:
|
||||
return data
|
||||
|
||||
async def _update_all_data(self) -> dict[str, Any]:
|
||||
"""Update all data and manage cache."""
|
||||
current_time = dt_util.now()
|
||||
processed_data: dict[str, Any] | None = None
|
||||
is_initial_setup = self._cached_price_data is None
|
||||
|
||||
# Handle price data update if needed
|
||||
if self._should_update_price_data(current_time):
|
||||
# Check if we're within the allowed time window for price updates
|
||||
# or if this is initial setup
|
||||
current_hour = current_time.hour
|
||||
if is_initial_setup or self._is_price_update_window(current_hour):
|
||||
# Add random delay only for regular updates, not initial setup
|
||||
if not is_initial_setup:
|
||||
delay = random.randint(0, 120) # noqa: S311
|
||||
LOGGER.debug(
|
||||
"Adding random delay of %d minutes before price update",
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay * 60)
|
||||
|
||||
# Get fresh price data
|
||||
data = await self._fetch_price_data()
|
||||
self._cached_price_data = self._extract_price_data(data)
|
||||
self._last_price_update = current_time
|
||||
await self._store_cache()
|
||||
LOGGER.debug("Updated price data cache at %s", current_time)
|
||||
processed_data = data
|
||||
|
||||
# Handle rating data update if needed
|
||||
if self._should_update_rating_data(current_time):
|
||||
rating_data = await self._get_rating_data()
|
||||
self._cached_rating_data = self._extract_rating_data(rating_data)
|
||||
self._last_rating_update = current_time
|
||||
await self._store_cache()
|
||||
LOGGER.debug("Updated rating data cache at %s", current_time)
|
||||
processed_data = rating_data
|
||||
|
||||
# If we have cached data but no updates were needed
|
||||
if (
|
||||
processed_data is None
|
||||
and self._cached_price_data
|
||||
and self._cached_rating_data
|
||||
):
|
||||
LOGGER.debug(
|
||||
"Using cached data - Price from: %s, Rating from: %s",
|
||||
self._last_price_update,
|
||||
self._last_rating_update,
|
||||
)
|
||||
processed_data = self._merge_cached_data()
|
||||
|
||||
if processed_data is None:
|
||||
_raise_no_data()
|
||||
|
||||
return cast("dict[str, Any]", processed_data)
|
||||
|
||||
async def _store_cache(self) -> None:
|
||||
"""Store cache data."""
|
||||
last_price = (
|
||||
self._last_price_update.isoformat() if self._last_price_update else None
|
||||
)
|
||||
last_rating = (
|
||||
self._last_rating_update.isoformat() if self._last_rating_update else None
|
||||
)
|
||||
data = {
|
||||
"price_data": self._cached_price_data,
|
||||
"rating_data": self._cached_rating_data,
|
||||
"last_price_update": last_price,
|
||||
"last_rating_update": last_rating,
|
||||
}
|
||||
await self._store.async_save(data)
|
||||
|
||||
def _should_update_price_data(self, current_time: datetime) -> bool:
|
||||
"""Check if price data should be updated."""
|
||||
return (
|
||||
self._cached_price_data is None
|
||||
or self._last_price_update is None
|
||||
or current_time - self._last_price_update >= PRICE_UPDATE_INTERVAL
|
||||
)
|
||||
|
||||
def _should_update_rating_data(self, current_time: datetime) -> bool:
|
||||
"""Check if rating data should be updated."""
|
||||
return (
|
||||
self._cached_rating_data is None
|
||||
or self._last_rating_update is None
|
||||
or current_time - self._last_rating_update >= RATING_UPDATE_INTERVAL
|
||||
)
|
||||
|
||||
def _is_price_update_window(self, current_hour: int) -> bool:
|
||||
"""Check if current hour is within price update window."""
|
||||
return (
|
||||
PRICE_UPDATE_RANDOM_MIN_HOUR <= current_hour <= PRICE_UPDATE_RANDOM_MAX_HOUR
|
||||
)
|
||||
|
||||
async def _fetch_price_data(self) -> dict:
|
||||
"""Fetch fresh price data from API."""
|
||||
client = self.config_entry.runtime_data.client
|
||||
return await client.async_get_price_info()
|
||||
|
||||
def _extract_price_data(self, data: dict) -> dict:
|
||||
"""Extract price data for caching."""
|
||||
price_info = data["viewer"]["homes"][0]["currentSubscription"]["priceInfo"]
|
||||
return {
|
||||
"data": {
|
||||
"viewer": {
|
||||
"homes": [{"currentSubscription": {"priceInfo": price_info}}]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _extract_rating_data(self, data: dict) -> dict:
|
||||
"""Extract rating data for caching."""
|
||||
return {
|
||||
"data": {
|
||||
"viewer": {
|
||||
"homes": [
|
||||
{
|
||||
"currentSubscription": {
|
||||
"priceRating": data["data"]["viewer"]["homes"][0][
|
||||
"currentSubscription"
|
||||
]["priceRating"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _merge_cached_data(self) -> dict:
|
||||
"""Merge cached price and rating data."""
|
||||
if not self._cached_price_data or not self._cached_rating_data:
|
||||
return {}
|
||||
|
||||
subscription = {
|
||||
"priceInfo": self._cached_price_data["data"]["viewer"]["homes"][0][
|
||||
"currentSubscription"
|
||||
]["priceInfo"],
|
||||
"priceRating": self._cached_rating_data["data"]["viewer"]["homes"][0][
|
||||
"currentSubscription"
|
||||
]["priceRating"],
|
||||
}
|
||||
|
||||
return {"data": {"viewer": {"homes": [{"currentSubscription": subscription}]}}}
|
||||
|
||||
async def _get_rating_data(self) -> dict:
|
||||
"""Get fresh rating data from API."""
|
||||
client = self.config_entry.runtime_data.client
|
||||
daily = await client.async_get_daily_price_rating()
|
||||
hourly = await client.async_get_hourly_price_rating()
|
||||
monthly = await client.async_get_monthly_price_rating()
|
||||
|
||||
rating_base = daily["viewer"]["homes"][0]["currentSubscription"]["priceRating"]
|
||||
|
||||
return {
|
||||
"data": {
|
||||
"viewer": {
|
||||
"homes": [
|
||||
{
|
||||
"currentSubscription": {
|
||||
"priceRating": {
|
||||
"thresholdPercentages": rating_base[
|
||||
"thresholdPercentages"
|
||||
],
|
||||
"daily": rating_base["daily"],
|
||||
"hourly": hourly["viewer"]["homes"][0][
|
||||
"currentSubscription"
|
||||
]["priceRating"]["hourly"],
|
||||
"monthly": monthly["viewer"]["homes"][0][
|
||||
"currentSubscription"
|
||||
]["priceRating"]["monthly"],
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue