diff --git a/.ruff.toml b/.ruff.toml index 7e9f81b..2833cdd 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,6 +1,7 @@ # The contents of this file is based on https://github.com/home-assistant/core/blob/dev/pyproject.toml target-version = "py313" +line-length = 120 [lint] select = [ @@ -19,8 +20,12 @@ ignore = [ [lint.flake8-pytest-style] fixture-parentheses = false +[lint.isort] +force-single-line = false +known-first-party = ["custom_components", "homeassistant"] + [lint.pyupgrade] keep-runtime-typing = true [lint.mccabe] -max-complexity = 25 \ No newline at end of file +max-complexity = 25 diff --git a/custom_components/tibber_prices/api.py b/custom_components/tibber_prices/api.py index af0bb46..951c926 100644 --- a/custom_components/tibber_prices/api.py +++ b/custom_components/tibber_prices/api.py @@ -11,6 +11,7 @@ from typing import Any import aiohttp import async_timeout + from homeassistant.const import __version__ as ha_version from .const import VERSION diff --git a/custom_components/tibber_prices/binary_sensor.py b/custom_components/tibber_prices/binary_sensor.py index 741e3a1..bf1c3c8 100644 --- a/custom_components/tibber_prices/binary_sensor.py +++ b/custom_components/tibber_prices/binary_sensor.py @@ -5,6 +5,9 @@ from __future__ import annotations from datetime import UTC, datetime from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Callable + from homeassistant.components.binary_sensor import ( BinarySensorDeviceClass, BinarySensorEntity, @@ -70,18 +73,42 @@ class TibberPricesBinarySensor(TibberPricesEntity, BinarySensorEntity): """Initialize the binary_sensor class.""" super().__init__(coordinator) self.entity_description = entity_description - self._attr_unique_id = ( - f"{coordinator.config_entry.entry_id}_{entity_description.key}" - ) + self._attr_unique_id = f"{coordinator.config_entry.entry_id}_{entity_description.key}" + self._state_getter: Callable | None = self._get_state_getter() + self._attribute_getter: Callable | None = self._get_attribute_getter() + + def _get_state_getter(self) -> Callable | None: + """Return the appropriate state getter method based on the sensor type.""" + key = self.entity_description.key + + if key == "peak_hour": + return lambda: self._get_price_threshold_state(threshold_percentage=0.8, high_is_active=True) + if key == "best_price_hour": + return lambda: self._get_price_threshold_state(threshold_percentage=0.2, high_is_active=False) + if key == "connection": + return lambda: True if self.coordinator.data else None + + return None + + def _get_attribute_getter(self) -> Callable | None: + """Return the appropriate attribute getter method based on the sensor type.""" + key = self.entity_description.key + + if key == "peak_hour": + return lambda: self._get_price_hours_attributes(attribute_name="peak_hours", reverse_sort=True) + if key == "best_price_hour": + return lambda: self._get_price_hours_attributes(attribute_name="best_price_hours", reverse_sort=False) + + return None def _get_current_price_data(self) -> tuple[list[float], float] | None: """Get current price data if available.""" if not ( self.coordinator.data and ( - today_prices := self.coordinator.data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ]["priceInfo"].get("today", []) + today_prices := self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"][ + "priceInfo" + ].get("today", []) ) ): return None @@ -102,26 +129,59 @@ class TibberPricesBinarySensor(TibberPricesEntity, BinarySensorEntity): prices.sort() return prices, float(current_hour_data["total"]) + def _get_price_threshold_state(self, *, threshold_percentage: float, high_is_active: bool) -> bool | None: + """ + Determine if current price is above/below threshold. + + Args: + threshold_percentage: The percentage point in the sorted list (0.0-1.0) + high_is_active: If True, value >= threshold is active, otherwise value <= threshold is active + + """ + price_data = self._get_current_price_data() + if not price_data: + return None + + prices, current_price = price_data + threshold_index = int(len(prices) * threshold_percentage) + + if high_is_active: + return current_price >= prices[threshold_index] + + return current_price <= prices[threshold_index] + + def _get_price_hours_attributes(self, *, attribute_name: str, reverse_sort: bool) -> dict | None: + """Get price hours attributes.""" + if not self.coordinator.data: + return None + + price_info = self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] + + today_prices = price_info.get("today", []) + if not today_prices: + return None + + prices = [ + ( + datetime.fromisoformat(price["startsAt"]).hour, + float(price["total"]), + ) + for price in today_prices + ] + + # Sort by price (high to low for peak, low to high for best) + sorted_hours = sorted(prices, key=lambda x: x[1], reverse=reverse_sort)[:5] + + return {attribute_name: [{"hour": hour, "price": price} for hour, price in sorted_hours]} + @property def is_on(self) -> bool | None: """Return true if the binary_sensor is on.""" try: - price_data = self._get_current_price_data() - if not price_data: + if not self.coordinator.data or not self._state_getter: return None - prices, current_price = price_data - match self.entity_description.key: - case "peak_hour": - threshold_index = int(len(prices) * 0.8) - return current_price >= prices[threshold_index] - case "best_price_hour": - threshold_index = int(len(prices) * 0.2) - return current_price <= prices[threshold_index] - case "connection": - return True - case _: - return None + return self._state_getter() except (KeyError, ValueError, TypeError) as ex: self.coordinator.logger.exception( @@ -137,43 +197,10 @@ class TibberPricesBinarySensor(TibberPricesEntity, BinarySensorEntity): def extra_state_attributes(self) -> dict | None: """Return additional state attributes.""" try: - if not self.coordinator.data: + if not self.coordinator.data or not self._attribute_getter: return None - subscription = self.coordinator.data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ] - price_info = subscription["priceInfo"] - attributes = {} - - if self.entity_description.key in ["peak_hour", "best_price_hour"]: - today_prices = price_info.get("today", []) - if today_prices: - prices = [ - ( - datetime.fromisoformat(price["startsAt"]).hour, - float(price["total"]), - ) - for price in today_prices - ] - - if self.entity_description.key == "peak_hour": - # Get top 5 peak hours - peak_hours = sorted(prices, key=lambda x: x[1], reverse=True)[ - :5 - ] - attributes["peak_hours"] = [ - {"hour": hour, "price": price} for hour, price in peak_hours - ] - else: - # Get top 5 best price hours - best_hours = sorted(prices, key=lambda x: x[1])[:5] - attributes["best_price_hours"] = [ - {"hour": hour, "price": price} for hour, price in best_hours - ] - return attributes - else: - return None + return self._attribute_getter() except (KeyError, ValueError, TypeError) as ex: self.coordinator.logger.exception( diff --git a/custom_components/tibber_prices/config_flow.py b/custom_components/tibber_prices/config_flow.py index 57999c2..8529e27 100644 --- a/custom_components/tibber_prices/config_flow.py +++ b/custom_components/tibber_prices/config_flow.py @@ -3,11 +3,12 @@ from __future__ import annotations import voluptuous as vol +from slugify import slugify + from homeassistant import config_entries from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.helpers import selector from homeassistant.helpers.aiohttp_client import async_create_clientsession -from slugify import slugify from .api import ( TibberPricesApiClient, diff --git a/custom_components/tibber_prices/coordinator.py b/custom_components/tibber_prices/coordinator.py index 50a3f8c..ce55376 100644 --- a/custom_components/tibber_prices/coordinator.py +++ b/custom_components/tibber_prices/coordinator.py @@ -91,9 +91,7 @@ def _get_latest_timestamp_from_prices( for price in today_prices: if starts_at := price.get("startsAt"): timestamp = dt_util.parse_datetime(starts_at) - if timestamp and ( - not latest_timestamp or timestamp > latest_timestamp - ): + if timestamp and (not latest_timestamp or timestamp > latest_timestamp): latest_timestamp = timestamp # Check tomorrow's prices @@ -101,9 +99,7 @@ def _get_latest_timestamp_from_prices( for price in tomorrow_prices: if starts_at := price.get("startsAt"): timestamp = dt_util.parse_datetime(starts_at) - if timestamp and ( - not latest_timestamp or timestamp > latest_timestamp - ): + if timestamp and (not latest_timestamp or timestamp > latest_timestamp): latest_timestamp = timestamp except (KeyError, IndexError, TypeError): @@ -131,9 +127,7 @@ def _get_latest_timestamp_from_rating( for entry in rating_entries: if time := entry.get("time"): timestamp = dt_util.parse_datetime(time) - if timestamp and ( - not latest_timestamp or timestamp > latest_timestamp - ): + if timestamp and (not latest_timestamp or timestamp > latest_timestamp): latest_timestamp = timestamp except (KeyError, IndexError, TypeError): return None @@ -171,16 +165,12 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) # Schedule updates at the start of every hour self._remove_update_listeners.append( - async_track_time_change( - hass, self._async_refresh_hourly, minute=0, second=0 - ) + async_track_time_change(hass, self._async_refresh_hourly, minute=0, second=0) ) # Schedule data rotation at midnight self._remove_update_listeners.append( - async_track_time_change( - hass, self._async_handle_midnight_rotation, hour=0, minute=0, second=0 - ) + async_track_time_change(hass, self._async_handle_midnight_rotation, hour=0, minute=0, second=0) ) async def async_shutdown(self) -> None: @@ -189,18 +179,14 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) for listener in self._remove_update_listeners: listener() - async def _async_handle_midnight_rotation( - self, _now: datetime | None = None - ) -> None: + async def _async_handle_midnight_rotation(self, _now: datetime | None = None) -> None: """Handle data rotation at midnight.""" if not self._cached_price_data: return try: LOGGER.debug("Starting midnight data rotation") - subscription = self._cached_price_data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ] + subscription = self._cached_price_data["data"]["viewer"]["homes"][0]["currentSubscription"] price_info = subscription["priceInfo"] # Move today's data to yesterday @@ -262,20 +248,12 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) if stored: # Load cached data self._cached_price_data = cast("TibberPricesData", stored.get("price_data")) - self._cached_rating_data_hourly = cast( - "TibberPricesData", stored.get("rating_data_hourly") - ) - self._cached_rating_data_daily = cast( - "TibberPricesData", stored.get("rating_data_daily") - ) - self._cached_rating_data_monthly = cast( - "TibberPricesData", stored.get("rating_data_monthly") - ) + self._cached_rating_data_hourly = cast("TibberPricesData", stored.get("rating_data_hourly")) + self._cached_rating_data_daily = cast("TibberPricesData", stored.get("rating_data_daily")) + self._cached_rating_data_monthly = cast("TibberPricesData", stored.get("rating_data_monthly")) # Recover timestamps - self._last_price_update = self._recover_timestamp( - self._cached_price_data, stored.get("last_price_update") - ) + self._last_price_update = self._recover_timestamp(self._cached_price_data, stored.get("last_price_update")) self._last_rating_update_hourly = self._recover_timestamp( self._cached_rating_data_hourly, stored.get("last_rating_update_hourly"), @@ -293,8 +271,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) ) LOGGER.debug( - "Loaded stored cache data - " - "Price update: %s, Rating hourly: %s, daily: %s, monthly: %s", + "Loaded stored cache data - Price update: %s, Rating hourly: %s, daily: %s, monthly: %s", self._last_price_update, self._last_rating_update_hourly, self._last_rating_update_daily, @@ -379,46 +356,15 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) else: return result - async def _handle_conditional_update( - self, current_time: datetime - ) -> TibberPricesData: + async def _handle_conditional_update(self, current_time: datetime) -> TibberPricesData: """Handle conditional update based on update conditions.""" - should_update_price = self._should_update_price_data(current_time) - should_update_hourly = self._should_update_rating_type( - current_time, - self._cached_rating_data_hourly, - self._last_rating_update_hourly, - "hourly", - ) - should_update_daily = self._should_update_rating_type( - current_time, - self._cached_rating_data_daily, - self._last_rating_update_daily, - "daily", - ) - should_update_monthly = self._should_update_rating_type( - current_time, - self._cached_rating_data_monthly, - self._last_rating_update_monthly, - "monthly", - ) + # Simplified conditional update checking + update_conditions = self._check_update_conditions(current_time) - if any( - [ - should_update_price, - should_update_hourly, - should_update_daily, - should_update_monthly, - ] - ): + if any(update_conditions.values()): LOGGER.debug( "Updating data based on conditions", - extra={ - "update_price": should_update_price, - "update_hourly": should_update_hourly, - "update_daily": should_update_daily, - "update_monthly": should_update_monthly, - }, + extra=update_conditions, ) return await self._fetch_all_data() @@ -429,6 +375,31 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) LOGGER.debug("No cached data available, fetching new data") return await self._fetch_all_data() + @callback + def _check_update_conditions(self, current_time: datetime) -> dict[str, bool]: + """Check all update conditions and return results as a dictionary.""" + return { + "update_price": self._should_update_price_data(current_time), + "update_hourly": self._should_update_rating_type( + current_time, + self._cached_rating_data_hourly, + self._last_rating_update_hourly, + "hourly", + ), + "update_daily": self._should_update_rating_type( + current_time, + self._cached_rating_data_daily, + self._last_rating_update_daily, + "daily", + ), + "update_monthly": self._should_update_rating_type( + current_time, + self._cached_rating_data_monthly, + self._last_rating_update_monthly, + "monthly", + ), + } + async def _fetch_all_data(self) -> TibberPricesData: """ Fetch all data from the API without checking update conditions. @@ -462,9 +433,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) except TibberPricesApiClientError as ex: LOGGER.error("Failed to fetch price data: %s", ex) if self._cached_price_data is not None: - LOGGER.info( - "Using cached data as fallback after price data fetch failure" - ) + LOGGER.info("Using cached data as fallback after price data fetch failure") return self._merge_all_cached_data() raise @@ -483,24 +452,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) # Update rating data cache only for types that were successfully fetched for rating_type, rating_data in new_data["rating_data"].items(): if rating_data is not None: - if rating_type == "hourly": - self._cached_rating_data_hourly = cast( - "TibberPricesData", rating_data - ) - self._last_rating_update_hourly = current_time - elif rating_type == "daily": - self._cached_rating_data_daily = cast( - "TibberPricesData", rating_data - ) - self._last_rating_update_daily = current_time - else: # monthly - self._cached_rating_data_monthly = cast( - "TibberPricesData", rating_data - ) - self._last_rating_update_monthly = current_time - LOGGER.debug( - "Updated %s rating data cache at %s", rating_type, current_time - ) + self._update_rating_cache(rating_type, rating_data, current_time) # Store the updated cache await self._store_cache() @@ -509,13 +461,25 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) # Return merged data return self._merge_all_cached_data() + @callback + def _update_rating_cache(self, rating_type: str, rating_data: TibberPricesData, current_time: datetime) -> None: + """Update the rating cache for a specific rating type.""" + if rating_type == "hourly": + self._cached_rating_data_hourly = cast("TibberPricesData", rating_data) + self._last_rating_update_hourly = current_time + elif rating_type == "daily": + self._cached_rating_data_daily = cast("TibberPricesData", rating_data) + self._last_rating_update_daily = current_time + else: # monthly + self._cached_rating_data_monthly = cast("TibberPricesData", rating_data) + self._last_rating_update_monthly = current_time + LOGGER.debug("Updated %s rating data cache at %s", rating_type, current_time) + async def _store_cache(self) -> None: """Store cache data.""" # Recover any missing timestamps from the data if self._cached_price_data and not self._last_price_update: - latest_timestamp = _get_latest_timestamp_from_prices( - self._cached_price_data - ) + latest_timestamp = _get_latest_timestamp_from_prices(self._cached_price_data) if latest_timestamp: self._last_price_update = latest_timestamp LOGGER.debug( @@ -537,9 +501,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) for rating_type, (cached_data, last_update) in rating_types.items(): if cached_data and not last_update: - latest_timestamp = self._get_latest_timestamp_from_rating_type( - cached_data, rating_type - ) + latest_timestamp = self._get_latest_timestamp_from_rating_type(cached_data, rating_type) if latest_timestamp: if rating_type == "hourly": self._last_rating_update_hourly = latest_timestamp @@ -558,9 +520,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) "rating_data_hourly": self._cached_rating_data_hourly, "rating_data_daily": self._cached_rating_data_daily, "rating_data_monthly": self._cached_rating_data_monthly, - "last_price_update": self._last_price_update.isoformat() - if self._last_price_update - else None, + "last_price_update": self._last_price_update.isoformat() if self._last_price_update else None, "last_rating_update_hourly": self._last_rating_update_hourly.isoformat() if self._last_rating_update_hourly else None, @@ -586,9 +546,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) return True # Get the latest timestamp from our price data - latest_price_timestamp = _get_latest_timestamp_from_prices( - self._cached_price_data - ) + latest_price_timestamp = _get_latest_timestamp_from_prices(self._cached_price_data) if not latest_price_timestamp: LOGGER.debug("No valid timestamp found in price data, update needed") return True @@ -603,30 +561,22 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) # Check if we're in the update window (13:00-15:00) current_hour = current_time.hour - in_update_window = ( - PRICE_UPDATE_RANDOM_MIN_HOUR <= current_hour <= PRICE_UPDATE_RANDOM_MAX_HOUR - ) + in_update_window = PRICE_UPDATE_RANDOM_MIN_HOUR <= current_hour <= PRICE_UPDATE_RANDOM_MAX_HOUR # Get tomorrow's date at midnight - tomorrow = (current_time + timedelta(days=1)).replace( - hour=0, minute=0, second=0, microsecond=0 - ) + tomorrow = (current_time + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0) # If we're in the update window and don't have tomorrow's complete data if in_update_window and latest_price_timestamp < tomorrow: LOGGER.debug( - "In update window (%d:00) and latest price timestamp (%s) " - "is before tomorrow, update needed", + "In update window (%d:00) and latest price timestamp (%s) is before tomorrow, update needed", current_hour, latest_price_timestamp, ) return True # If it's been more than 24 hours since our last update - if ( - self._last_price_update - and current_time - self._last_price_update >= UPDATE_INTERVAL - ): + if self._last_price_update and current_time - self._last_price_update >= UPDATE_INTERVAL: LOGGER.debug( "More than 24 hours since last price update (%s), update needed", self._last_price_update, @@ -651,19 +601,13 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) """Check if specific rating type should be updated.""" # If no cached data, we definitely need an update if cached_data is None: - LOGGER.debug( - "No cached %s rating data available, update needed", rating_type - ) + LOGGER.debug("No cached %s rating data available, update needed", rating_type) return True # Get the latest timestamp from our rating data - latest_timestamp = self._get_latest_timestamp_from_rating_type( - cached_data, rating_type - ) + latest_timestamp = self._get_latest_timestamp_from_rating_type(cached_data, rating_type) if not latest_timestamp: - LOGGER.debug( - "No valid timestamp found in %s rating data, update needed", rating_type - ) + LOGGER.debug("No valid timestamp found in %s rating data, update needed", rating_type) return True # If we have rating data but no last_update timestamp, set it @@ -682,22 +626,16 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) last_update = latest_timestamp current_hour = current_time.hour - in_update_window = ( - PRICE_UPDATE_RANDOM_MIN_HOUR <= current_hour <= PRICE_UPDATE_RANDOM_MAX_HOUR - ) + in_update_window = PRICE_UPDATE_RANDOM_MIN_HOUR <= current_hour <= PRICE_UPDATE_RANDOM_MAX_HOUR should_update = False if rating_type == "monthly": - current_month_start = current_time.replace( - day=1, hour=0, minute=0, second=0, microsecond=0 - ) + current_month_start = current_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0) should_update = latest_timestamp < current_month_start or ( last_update and current_time - last_update >= timedelta(days=1) ) else: - tomorrow = (current_time + timedelta(days=1)).replace( - hour=0, minute=0, second=0, microsecond=0 - ) + tomorrow = (current_time + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0) should_update = ( in_update_window and latest_timestamp < tomorrow ) or current_time - last_update >= UPDATE_INTERVAL @@ -722,9 +660,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) @callback def _is_price_update_window(self, current_hour: int) -> bool: """Check if current hour is within price update window.""" - return ( - PRICE_UPDATE_RANDOM_MIN_HOUR <= current_hour <= PRICE_UPDATE_RANDOM_MAX_HOUR - ) + return PRICE_UPDATE_RANDOM_MIN_HOUR <= current_hour <= PRICE_UPDATE_RANDOM_MAX_HOUR async def _fetch_price_data(self) -> dict: """Fetch fresh price data from API.""" @@ -737,14 +673,10 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) try: # Try to access data in the transformed structure first try: - price_info = data["viewer"]["homes"][0]["currentSubscription"][ - "priceInfo" - ] + price_info = data["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] except KeyError: # If that fails, try the raw data structure - price_info = data["data"]["viewer"]["homes"][0]["currentSubscription"][ - "priceInfo" - ] + price_info = data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] # Ensure we have all required fields extracted_price_info = { @@ -771,15 +703,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) } } } - return { - "data": { - "viewer": { - "homes": [ - {"currentSubscription": {"priceInfo": extracted_price_info}} - ] - } - } - } + return {"data": {"viewer": {"homes": [{"currentSubscription": {"priceInfo": extracted_price_info}}]}}} @callback def _get_latest_timestamp_from_rating_type( @@ -790,9 +714,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) return None try: - subscription = rating_data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ] + subscription = rating_data["data"]["viewer"]["homes"][0]["currentSubscription"] price_rating = subscription["priceRating"] result = None @@ -823,15 +745,11 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) except KeyError: try: # If that fails, try the raw data structure - rating = data["data"]["viewer"]["homes"][0]["currentSubscription"][ - "priceRating" - ] + rating = data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceRating"] except KeyError as ex: LOGGER.error("Failed to extract rating data: %s", ex) raise TibberPricesApiClientError( - TibberPricesApiClientError.EMPTY_DATA_ERROR.format( - query_type=rating_type - ) + TibberPricesApiClientError.EMPTY_DATA_ERROR.format(query_type=rating_type) ) from ex else: return { @@ -841,9 +759,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) { "currentSubscription": { "priceRating": { - "thresholdPercentages": rating[ - "thresholdPercentages" - ], + "thresholdPercentages": rating["thresholdPercentages"], rating_type: rating[rating_type], } } @@ -860,9 +776,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) { "currentSubscription": { "priceRating": { - "thresholdPercentages": rating[ - "thresholdPercentages" - ], + "thresholdPercentages": rating["thresholdPercentages"], rating_type: rating[rating_type], } } @@ -880,9 +794,7 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) # Start with price info subscription = { - "priceInfo": self._cached_price_data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ]["priceInfo"], + "priceInfo": self._cached_price_data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"], "priceRating": { "thresholdPercentages": None, }, @@ -897,15 +809,11 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[TibberPricesData]) for rating_type, data in rating_data.items(): if data and "data" in data: - rating = data["data"]["viewer"]["homes"][0]["currentSubscription"][ - "priceRating" - ] + rating = data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceRating"] # Set thresholdPercentages from any available rating data if not subscription["priceRating"]["thresholdPercentages"]: - subscription["priceRating"]["thresholdPercentages"] = rating[ - "thresholdPercentages" - ] + subscription["priceRating"]["thresholdPercentages"] = rating["thresholdPercentages"] # Add the specific rating type data subscription["priceRating"][rating_type] = rating[rating_type] diff --git a/custom_components/tibber_prices/sensor.py b/custom_components/tibber_prices/sensor.py index ae9c855..c4bd659 100644 --- a/custom_components/tibber_prices/sensor.py +++ b/custom_components/tibber_prices/sensor.py @@ -3,7 +3,10 @@ from __future__ import annotations from datetime import UTC, datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable from homeassistant.components.sensor import ( SensorDeviceClass, @@ -16,6 +19,8 @@ from homeassistant.util import dt as dt_util from .entity import TibberPricesEntity if TYPE_CHECKING: + from collections.abc import Callable + from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -203,192 +208,164 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): """Initialize the sensor class.""" super().__init__(coordinator) self.entity_description = entity_description - self._attr_unique_id = ( - f"{coordinator.config_entry.entry_id}_{entity_description.key}" - ) + self._attr_unique_id = f"{coordinator.config_entry.entry_id}_{entity_description.key}" self._attr_has_entity_name = True + self._value_getter: Callable | None = self._get_value_getter() + + def _get_value_getter(self) -> Callable | None: + """Return the appropriate value getter method based on the sensor type.""" + key = self.entity_description.key + + # Map sensor keys to their handler methods + handlers = { + # Price level + "price_level": self._get_price_level_value, + # Price sensors + "current_price": lambda: self._get_hourly_price_value(hour_offset=0, in_euro=False), + "current_price_eur": lambda: self._get_hourly_price_value(hour_offset=0, in_euro=True), + "next_hour_price": lambda: self._get_hourly_price_value(hour_offset=1, in_euro=False), + "next_hour_price_eur": lambda: self._get_hourly_price_value(hour_offset=1, in_euro=True), + # Statistics sensors + "lowest_price_today": lambda: self._get_statistics_value(stat_func=min, in_euro=False), + "lowest_price_today_eur": lambda: self._get_statistics_value(stat_func=min, in_euro=True), + "highest_price_today": lambda: self._get_statistics_value(stat_func=max, in_euro=False), + "highest_price_today_eur": lambda: self._get_statistics_value(stat_func=max, in_euro=True), + "average_price_today": lambda: self._get_statistics_value( + stat_func=lambda prices: sum(prices) / len(prices), in_euro=False + ), + "average_price_today_eur": lambda: self._get_statistics_value( + stat_func=lambda prices: sum(prices) / len(prices), in_euro=True + ), + # Rating sensors + "hourly_rating": lambda: self._get_rating_value(rating_type="hourly"), + "daily_rating": lambda: self._get_rating_value(rating_type="daily"), + "monthly_rating": lambda: self._get_rating_value(rating_type="monthly"), + # Diagnostic sensors + "data_timestamp": self._get_data_timestamp, + "tomorrow_data_available": self._get_tomorrow_data_status, + } + + return handlers.get(key) def _get_current_hour_data(self) -> dict | None: """Get the price data for the current hour.""" if not self.coordinator.data: return None now = datetime.now(tz=UTC).astimezone() - price_info = self.coordinator.data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ]["priceInfo"] + price_info = self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] for price_data in price_info.get("today", []): starts_at = datetime.fromisoformat(price_data["startsAt"]) if starts_at.hour == now.hour: return price_data return None - def _get_price_value(self, price: float) -> float: - """Convert price based on unit.""" - return ( - price * 100 - if self.entity_description.native_unit_of_measurement == "ct/kWh" - else price - ) + def _get_price_level_value(self) -> str | None: + """Get the current price level value.""" + current_hour_data = self._get_current_hour_data() + return current_hour_data["level"] if current_hour_data else None - def _get_price_sensor_value(self) -> float | None: - """Handle price sensor values.""" + def _get_price_value(self, price: float, *, in_euro: bool) -> float: + """Convert price based on unit.""" + return price if in_euro else price * 100 + + def _get_hourly_price_value(self, *, hour_offset: int, in_euro: bool) -> float | None: + """Get price for current hour or with offset.""" if not self.coordinator.data: return None - subscription = self.coordinator.data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ] - price_info = subscription["priceInfo"] + price_info = self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] now = datetime.now(tz=UTC).astimezone() - current_hour_data = self._get_current_hour_data() + target_hour = (now.hour + hour_offset) % 24 - key = self.entity_description.key - if key in ["current_price", "current_price_eur"]: - if not current_hour_data: - return None - return ( - self._get_price_value(float(current_hour_data["total"])) - if key == "current_price" - else float(current_hour_data["total"]) - ) - - if key in ["next_hour_price", "next_hour_price_eur"]: - next_hour = (now.hour + 1) % 24 - for price_data in price_info.get("today", []): - starts_at = datetime.fromisoformat(price_data["startsAt"]) - if starts_at.hour == next_hour: - return ( - self._get_price_value(float(price_data["total"])) - if key == "next_hour_price" - else float(price_data["total"]) - ) - return None + for price_data in price_info.get("today", []): + starts_at = datetime.fromisoformat(price_data["startsAt"]) + if starts_at.hour == target_hour: + return self._get_price_value(float(price_data["total"]), in_euro=in_euro) return None - def _get_statistics_value(self) -> float | None: - """Handle statistics sensor values.""" + def _get_statistics_value(self, *, stat_func: Callable[[list[float]], float], in_euro: bool) -> float | None: + """Handle statistics sensor values using the provided statistical function.""" if not self.coordinator.data: return None - price_info = self.coordinator.data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ]["priceInfo"] + price_info = self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] today_prices = price_info.get("today", []) if not today_prices: return None - key = self.entity_description.key prices = [float(price["total"]) for price in today_prices] - - if key in ["lowest_price_today", "lowest_price_today_eur"]: - value = min(prices) - elif key in ["highest_price_today", "highest_price_today_eur"]: - value = max(prices) - elif key in ["average_price_today", "average_price_today_eur"]: - value = sum(prices) / len(prices) - else: + if not prices: return None - return self._get_price_value(value) if key.endswith("today") else value + value = stat_func(prices) + return self._get_price_value(value, in_euro=in_euro) - def _get_rating_value(self) -> float | None: + def _get_rating_value(self, *, rating_type: str) -> float | None: """Handle rating sensor values.""" if not self.coordinator.data: return None - def check_hourly(entry: dict) -> bool: - return datetime.fromisoformat(entry["time"]).hour == now.hour - - def check_daily(entry: dict) -> bool: - return datetime.fromisoformat(entry["time"]).date() == now.date() - - def check_monthly(entry: dict) -> bool: - dt = datetime.fromisoformat(entry["time"]) - return dt.month == now.month and dt.year == now.year - - subscription = self.coordinator.data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ] + subscription = self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"] price_rating = subscription.get("priceRating", {}) or {} now = datetime.now(tz=UTC).astimezone() - key = self.entity_description.key - if key == "hourly_rating": - rating_data = price_rating.get("hourly", {}) - entries = rating_data.get("entries", []) if rating_data else [] - time_match = check_hourly - elif key == "daily_rating": - rating_data = price_rating.get("daily", {}) - entries = rating_data.get("entries", []) if rating_data else [] - time_match = check_daily - elif key == "monthly_rating": - rating_data = price_rating.get("monthly", {}) - entries = rating_data.get("entries", []) if rating_data else [] - time_match = check_monthly - else: - return None + rating_data = price_rating.get(rating_type, {}) + entries = rating_data.get("entries", []) if rating_data else [] + + if rating_type == "hourly": + for entry in entries: + entry_time = datetime.fromisoformat(entry["time"]) + if entry_time.hour == now.hour: + return round(float(entry["difference"]) * 100, 1) + elif rating_type == "daily": + for entry in entries: + entry_time = datetime.fromisoformat(entry["time"]) + if entry_time.date() == now.date(): + return round(float(entry["difference"]) * 100, 1) + elif rating_type == "monthly": + for entry in entries: + entry_time = datetime.fromisoformat(entry["time"]) + if entry_time.month == now.month and entry_time.year == now.year: + return round(float(entry["difference"]) * 100, 1) - for entry in entries: - if time_match(entry): - return round(float(entry["difference"]) * 100, 1) return None - def _get_diagnostic_value(self) -> datetime | str | None: - """Handle diagnostic sensor values.""" + def _get_data_timestamp(self) -> datetime | None: + """Get the latest data timestamp.""" if not self.coordinator.data: return None - price_info = self.coordinator.data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ]["priceInfo"] - key = self.entity_description.key + price_info = self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] + latest_timestamp = None - if key == "data_timestamp": - latest_timestamp = None - for day in ["today", "tomorrow"]: - for price_data in price_info.get(day, []): - timestamp = datetime.fromisoformat(price_data["startsAt"]) - if not latest_timestamp or timestamp > latest_timestamp: - latest_timestamp = timestamp - return dt_util.as_utc(latest_timestamp) if latest_timestamp else None + for day in ["today", "tomorrow"]: + for price_data in price_info.get(day, []): + timestamp = datetime.fromisoformat(price_data["startsAt"]) + if not latest_timestamp or timestamp > latest_timestamp: + latest_timestamp = timestamp - if key == "tomorrow_data_available": - tomorrow_prices = price_info.get("tomorrow", []) - if not tomorrow_prices: - return "No" - return "Yes" if len(tomorrow_prices) == HOURS_IN_DAY else "Partial" + return dt_util.as_utc(latest_timestamp) if latest_timestamp else None - return None + def _get_tomorrow_data_status(self) -> str | None: + """Get tomorrow's data availability status.""" + if not self.coordinator.data: + return None + + price_info = self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] + tomorrow_prices = price_info.get("tomorrow", []) + + if not tomorrow_prices: + return "No" + return "Yes" if len(tomorrow_prices) == HOURS_IN_DAY else "Partial" @property def native_value(self) -> float | str | datetime | None: """Return the native value of the sensor.""" - result = None try: - if self.coordinator.data: - key = self.entity_description.key - current_hour_data = self._get_current_hour_data() - - if key == "price_level": - result = current_hour_data["level"] if current_hour_data else None - elif key in [ - "current_price", - "current_price_eur", - "next_hour_price", - "next_hour_price_eur", - ]: - result = self._get_price_sensor_value() - elif "price_today" in key: - result = self._get_statistics_value() - elif "rating" in key: - result = self._get_rating_value() - elif key in ["data_timestamp", "tomorrow_data_available"]: - result = self._get_diagnostic_value() - else: - result = None - else: - result = None + if not self.coordinator.data or not self._value_getter: + return None + return self._value_getter() except (KeyError, ValueError, TypeError) as ex: self.coordinator.logger.exception( "Error getting sensor value", @@ -397,101 +374,57 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): "entity": self.entity_description.key, }, ) - result = None - return result + return None @property - def extra_state_attributes(self) -> dict | None: # noqa: PLR0912 + def extra_state_attributes(self) -> dict | None: """Return additional state attributes.""" + if not self.coordinator.data: + return None + + attributes = self._get_sensor_attributes() + + # Add translated description + if attributes and self.hass is not None: + base_key = "entity.sensor" + key = f"{base_key}.{self.entity_description.translation_key}.description" + language_config = getattr(self.hass.config, "language", None) + if isinstance(language_config, dict): + description = language_config.get(key) + if description is not None: + attributes = dict(attributes) # Make a copy before modifying + attributes["description"] = description + + return attributes + + def _get_sensor_attributes(self) -> dict | None: + """Get attributes based on sensor type.""" try: - if not self.coordinator.data: - return None + key = self.entity_description.key + attributes: dict[str, Any] = {} - subscription = self.coordinator.data["data"]["viewer"]["homes"][0][ - "currentSubscription" - ] - price_info = subscription["priceInfo"] + # Get the timestamp attribute for different sensor types + price_info = self.coordinator.data["data"]["viewer"]["homes"][0]["currentSubscription"]["priceInfo"] - attributes = {} - - # Get current hour's data for timestamp - now = datetime.now(tz=UTC).astimezone() current_hour_data = self._get_current_hour_data() + now = datetime.now(tz=UTC).astimezone() - if self.entity_description.key in ["current_price", "current_price_eur"]: - attributes["timestamp"] = ( - current_hour_data["startsAt"] if current_hour_data else None - ) - - if self.entity_description.key in [ - "next_hour_price", - "next_hour_price_eur", - ]: + # Price sensors timestamps + if key in ["current_price", "current_price_eur", "price_level"]: + attributes["timestamp"] = current_hour_data["startsAt"] if current_hour_data else None + elif key in ["next_hour_price", "next_hour_price_eur"]: next_hour = (now.hour + 1) % 24 for price_data in price_info.get("today", []): starts_at = datetime.fromisoformat(price_data["startsAt"]) if starts_at.hour == next_hour: attributes["timestamp"] = price_data["startsAt"] break - - if self.entity_description.key == "price_level": - attributes["timestamp"] = ( - current_hour_data["startsAt"] if current_hour_data else None - ) - - if self.entity_description.key == "lowest_price_today": - attributes["timestamp"] = price_info.get("today", [{}])[0].get( - "startsAt" - ) - - if self.entity_description.key == "highest_price_today": - attributes["timestamp"] = price_info.get("today", [{}])[0].get( - "startsAt" - ) - - if self.entity_description.key == "average_price_today": - attributes["timestamp"] = price_info.get("today", [{}])[0].get( - "startsAt" - ) - - if self.entity_description.key == "hourly_rating": - attributes["timestamp"] = ( - current_hour_data["startsAt"] if current_hour_data else None - ) - - if self.entity_description.key == "daily_rating": - attributes["timestamp"] = price_info.get("today", [{}])[0].get( - "startsAt" - ) - - if self.entity_description.key == "monthly_rating": - attributes["timestamp"] = price_info.get("today", [{}])[0].get( - "startsAt" - ) - - if self.entity_description.key == "data_timestamp": - attributes["timestamp"] = price_info.get("today", [{}])[0].get( - "startsAt" - ) - - if self.entity_description.key == "tomorrow_data_available": - attributes["timestamp"] = price_info.get("today", [{}])[0].get( - "startsAt" - ) - - # Add translated description - if self.hass is not None: - base_key = "entity.sensor" - key = ( - f"{base_key}.{self.entity_description.translation_key}.description" - ) - language_config = getattr(self.hass.config, "language", None) - if isinstance(language_config, dict): - description = language_config.get(key) - if description is not None: - attributes["description"] = description - - return attributes if attributes else None # noqa: TRY300 + # Statistics, rating, and diagnostic sensors + elif any( + pattern in key for pattern in ["_price_today", "rating", "data_timestamp", "tomorrow_data_available"] + ): + first_timestamp = price_info.get("today", [{}])[0].get("startsAt") + attributes["timestamp"] = first_timestamp except (KeyError, ValueError, TypeError) as ex: self.coordinator.logger.exception( @@ -502,3 +435,5 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): }, ) return None + else: + return attributes if attributes else None diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1c4ac2a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +# pyproject.toml + +[tool.black] +line-length = 120 +target-version = ['py313'] +skip-string-normalization = false + +[tool.isort] +profile = "black" +line_length = 120