diff --git a/custom_components/tibber_prices/coordinator.py b/custom_components/tibber_prices/coordinator.py index 2a25814..fecea2d 100644 --- a/custom_components/tibber_prices/coordinator.py +++ b/custom_components/tibber_prices/coordinator.py @@ -30,6 +30,7 @@ from .const import ( DEFAULT_PRICE_RATING_THRESHOLD_LOW, DOMAIN, ) +from .price_utils import enrich_price_info_with_differences _LOGGER = logging.getLogger(__name__) @@ -157,24 +158,12 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Any]]): # Get price data for all homes price_data = await self.api.async_get_price_info() - # Get rating data for all homes - hourly_rating = await self.api.async_get_hourly_price_rating() - daily_rating = await self.api.async_get_daily_price_rating() - monthly_rating = await self.api.async_get_monthly_price_rating() - all_homes_data = {} homes_list = price_data.get("homes", {}) for home_id, home_price_data in homes_list.items(): - hourly_data = hourly_rating.get("homes", {}).get(home_id, {}) - daily_data = daily_rating.get("homes", {}).get(home_id, {}) - monthly_data = monthly_rating.get("homes", {}).get(home_id, {}) - home_data = { "price_info": home_price_data, - "hourly_rating": hourly_data.get("hourly", []), - "daily_rating": daily_data.get("daily", []), - "monthly_rating": monthly_data.get("monthly", []), } all_homes_data[home_id] = home_data @@ -292,26 +281,26 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Any]]): "timestamp": raw_data.get("timestamp"), "homes": {}, "priceInfo": {}, - "priceRating": {}, } # Use the first home's data as the main entry's data first_home_data = next(iter(homes_data.values())) price_info = first_home_data.get("price_info", {}) - # Combine rating data - wrap entries in dict for sensor compatibility - price_rating = { - "hourly": {"entries": first_home_data.get("hourly_rating", [])}, - "daily": {"entries": first_home_data.get("daily_rating", [])}, - "monthly": {"entries": first_home_data.get("monthly_rating", [])}, - "thresholdPercentages": self._get_threshold_percentages(), - } + # Get threshold percentages for enrichment + thresholds = self._get_threshold_percentages() + + # Enrich price info with calculated differences (trailing 24h averages) + price_info = enrich_price_info_with_differences( + price_info, + threshold_low=thresholds["low"], + threshold_high=thresholds["high"], + ) return { "timestamp": raw_data.get("timestamp"), "homes": homes_data, "priceInfo": price_info, - "priceRating": price_rating, } def _transform_data_for_subentry(self, main_data: dict[str, Any]) -> dict[str, Any]: @@ -327,23 +316,23 @@ class TibberPricesDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Any]]): return { "timestamp": main_data.get("timestamp"), "priceInfo": {}, - "priceRating": {}, } price_info = home_data.get("price_info", {}) - # Combine rating data for this specific home - wrap entries in dict for sensor compatibility - price_rating = { - "hourly": {"entries": home_data.get("hourly_rating", [])}, - "daily": {"entries": home_data.get("daily_rating", [])}, - "monthly": {"entries": home_data.get("monthly_rating", [])}, - "thresholdPercentages": self._get_threshold_percentages(), - } + # Get threshold percentages for enrichment + thresholds = self._get_threshold_percentages() + + # Enrich price info with calculated differences (trailing 24h averages) + price_info = enrich_price_info_with_differences( + price_info, + threshold_low=thresholds["low"], + threshold_high=thresholds["high"], + ) return { "timestamp": main_data.get("timestamp"), "priceInfo": price_info, - "priceRating": price_rating, } # --- Methods expected by sensors and services --- diff --git a/custom_components/tibber_prices/price_utils.py b/custom_components/tibber_prices/price_utils.py new file mode 100644 index 0000000..9284a6b --- /dev/null +++ b/custom_components/tibber_prices/price_utils.py @@ -0,0 +1,258 @@ +"""Utility functions for price data calculations.""" + +from __future__ import annotations + +import logging +from datetime import datetime, timedelta +from typing import Any + +from homeassistant.util import dt as dt_util + +_LOGGER = logging.getLogger(__name__) + +MINUTES_PER_INTERVAL = 15 + + +def calculate_trailing_average_for_interval( + interval_start: datetime, + all_prices: list[dict[str, Any]], +) -> float | None: + """ + Calculate the trailing 24-hour average price for a specific interval. + + Args: + interval_start: The start time of the interval we're calculating for + all_prices: List of all available price intervals (yesterday + today + tomorrow) + + Returns: + The average price of all intervals in the 24 hours before interval_start, + or None if insufficient data is available. + + """ + if not all_prices: + return None + + # Calculate the lookback period: 24 hours before this interval + lookback_start = interval_start - timedelta(hours=24) + + # Collect all prices that fall within the 24-hour lookback window + matching_prices = [] + + for price_data in all_prices: + starts_at_str = price_data.get("startsAt") + if not starts_at_str: + continue + + # Parse the timestamp + price_time = dt_util.parse_datetime(starts_at_str) + if price_time is None: + continue + + # Convert to local timezone for comparison + price_time = dt_util.as_local(price_time) + + # Check if this price falls within our lookback window + # Include prices that start >= lookback_start and start < interval_start + if lookback_start <= price_time < interval_start: + total_price = price_data.get("total") + if total_price is not None: + matching_prices.append(float(total_price)) + + if not matching_prices: + _LOGGER.debug( + "No prices found in 24-hour lookback window for interval starting at %s (lookback: %s to %s)", + interval_start, + lookback_start, + interval_start, + ) + return None + + # Calculate and return the average + average = sum(matching_prices) / len(matching_prices) + _LOGGER.debug( + "Calculated trailing 24h average for interval %s: %.6f from %d prices", + interval_start, + average, + len(matching_prices), + ) + return average + + +def calculate_difference_percentage( + current_price: float, + trailing_average: float | None, +) -> float | None: + """ + Calculate the difference percentage between current price and trailing average. + + This mimics the API's "difference" field from priceRating endpoint. + + Args: + current_price: The current interval's price + trailing_average: The 24-hour trailing average price + + Returns: + The percentage difference: ((current - average) / average) * 100 + or None if trailing_average is None or zero. + + """ + if trailing_average is None or trailing_average == 0: + return None + + return ((current_price - trailing_average) / trailing_average) * 100 + + +def calculate_rating_level( + difference: float | None, + threshold_low: float, + threshold_high: float, +) -> str | None: + """ + Calculate the rating level based on difference percentage and thresholds. + + This mimics the API's "level" field from priceRating endpoint. + + Args: + difference: The difference percentage (from calculate_difference_percentage) + threshold_low: The low threshold percentage (typically -100 to 0) + threshold_high: The high threshold percentage (typically 0 to 100) + + Returns: + "LOW" if difference <= threshold_low + "HIGH" if difference >= threshold_high + "NORMAL" otherwise + None if difference is None + + """ + if difference is None: + return None + + # If difference falls in both ranges (shouldn't normally happen), return NORMAL + if difference <= threshold_low and difference >= threshold_high: + return "NORMAL" + + # Classify based on thresholds + if difference <= threshold_low: + return "LOW" + + if difference >= threshold_high: + return "HIGH" + + return "NORMAL" + + +def _process_price_interval( + price_interval: dict[str, Any], + all_prices: list[dict[str, Any]], + threshold_low: float, + threshold_high: float, + day_label: str, +) -> None: + """ + Process a single price interval and add difference and rating_level. + + Args: + price_interval: The price interval to process (modified in place) + all_prices: All available price intervals for lookback calculation + threshold_low: Low threshold percentage + threshold_high: High threshold percentage + day_label: Label for logging ("today" or "tomorrow") + + """ + starts_at_str = price_interval.get("startsAt") + if not starts_at_str: + return + + starts_at = dt_util.parse_datetime(starts_at_str) + if starts_at is None: + return + + starts_at = dt_util.as_local(starts_at) + current_price = price_interval.get("total") + + if current_price is None: + return + + # Calculate trailing average + trailing_avg = calculate_trailing_average_for_interval(starts_at, all_prices) + + # Calculate and set the difference and rating_level + if trailing_avg is not None: + difference = calculate_difference_percentage(float(current_price), trailing_avg) + price_interval["difference"] = difference + + # Calculate rating_level based on difference + rating_level = calculate_rating_level(difference, threshold_low, threshold_high) + price_interval["rating_level"] = rating_level + + _LOGGER.debug( + "Set difference and rating_level for %s interval %s: difference=%.2f%%, level=%s (price: %.6f, avg: %.6f)", + day_label, + starts_at, + difference if difference is not None else 0, + rating_level, + float(current_price), + trailing_avg, + ) + else: + # Set to None if we couldn't calculate + price_interval["difference"] = None + price_interval["rating_level"] = None + _LOGGER.debug( + "Could not calculate trailing average for %s interval %s", + day_label, + starts_at, + ) + + +def enrich_price_info_with_differences( + price_info: dict[str, Any], + threshold_low: float | None = None, + threshold_high: float | None = None, +) -> dict[str, Any]: + """ + Enrich price info with calculated 'difference' and 'rating_level' values. + + Computes the trailing 24-hour average, difference percentage, and rating level + for each interval in today and tomorrow (excluding yesterday since it's historical). + + Args: + price_info: Dictionary with 'yesterday', 'today', 'tomorrow' keys + threshold_low: Low threshold percentage for rating_level (defaults to -10) + threshold_high: High threshold percentage for rating_level (defaults to 10) + + Returns: + Updated price_info dict with 'difference' and 'rating_level' added + + """ + if threshold_low is None: + threshold_low = -10 + if threshold_high is None: + threshold_high = 10 + + yesterday_prices = price_info.get("yesterday", []) + today_prices = price_info.get("today", []) + tomorrow_prices = price_info.get("tomorrow", []) + + # Combine all prices for lookback calculation + all_prices = yesterday_prices + today_prices + tomorrow_prices + + _LOGGER.debug( + "Enriching price info with differences and rating levels: " + "yesterday=%d, today=%d, tomorrow=%d, thresholds: low=%.2f, high=%.2f", + len(yesterday_prices), + len(today_prices), + len(tomorrow_prices), + threshold_low, + threshold_high, + ) + + # Process today's prices + for price_interval in today_prices: + _process_price_interval(price_interval, all_prices, threshold_low, threshold_high, "today") + + # Process tomorrow's prices + for price_interval in tomorrow_prices: + _process_price_interval(price_interval, all_prices, threshold_low, threshold_high, "tomorrow") + + return price_info diff --git a/custom_components/tibber_prices/sensor.py b/custom_components/tibber_prices/sensor.py index 6f408f1..2c5769a 100644 --- a/custom_components/tibber_prices/sensor.py +++ b/custom_components/tibber_prices/sensor.py @@ -161,18 +161,6 @@ RATING_SENSORS = ( name="Current Price Rating", icon="mdi:clock-outline", ), - SensorEntityDescription( - key="daily_rating", - translation_key="daily_rating", - name="Daily Price Rating", - icon="mdi:calendar-today", - ), - SensorEntityDescription( - key="monthly_rating", - translation_key="monthly_rating", - name="Monthly Price Rating", - icon="mdi:calendar-month", - ), ) # Diagnostic sensors for data availability @@ -258,9 +246,7 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): stat_func=lambda prices: sum(prices) / len(prices), in_euro=True, decimals=4 ), # Rating sensors - "price_rating": lambda: self._get_rating_value(rating_type="hourly"), - "daily_rating": lambda: self._get_rating_value(rating_type="daily"), - "monthly_rating": lambda: self._get_rating_value(rating_type="monthly"), + "price_rating": lambda: self._get_rating_value(rating_type="current"), # Diagnostic sensors "data_timestamp": self._get_data_timestamp, # Price forecast sensor @@ -431,72 +417,30 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): return en_translations["sensor"]["price_rating"]["price_levels"][level] return level - def _find_rating_entry(self, entries: list[dict], now: datetime, rating_type: str) -> dict | None: - """Find the correct rating entry for the given type and time.""" - if not entries: - return None - predicate = None - if rating_type == "hourly": - - def interval_predicate(entry_time: datetime) -> bool: - interval_end = entry_time + timedelta(minutes=MINUTES_PER_INTERVAL) - return entry_time <= now < interval_end and entry_time.date() == now.date() - - predicate = interval_predicate - elif rating_type == "daily": - - def daily_predicate(entry_time: datetime) -> bool: - return dt_util.as_local(entry_time).date() == now.date() - - predicate = daily_predicate - elif rating_type == "monthly": - - def monthly_predicate(entry_time: datetime) -> bool: - local_time = dt_util.as_local(entry_time) - return local_time.month == now.month and local_time.year == now.year - - predicate = monthly_predicate - if predicate: - for entry in entries: - entry_time = dt_util.parse_datetime(entry["time"]) - if entry_time and predicate(entry_time): - return entry - # For hourly, fallback to hour match if not found - if rating_type == "hourly": - for entry in entries: - entry_time = dt_util.parse_datetime(entry["time"]) - if entry_time: - entry_time = dt_util.as_local(entry_time) - if entry_time.hour == now.hour and entry_time.date() == now.date(): - return entry - return None - def _get_rating_value(self, *, rating_type: str) -> str | None: """ - Handle rating sensor values for hourly, daily, and monthly ratings. + Get the price rating level from the current price interval in priceInfo. Returns the translated rating level as the main status, and stores the original level and percentage difference as attributes. """ - if not self.coordinator.data: + if not self.coordinator.data or rating_type != "current": self._last_rating_difference = None self._last_rating_level = None return None - price_rating = self.coordinator.data.get("priceRating", {}) + now = dt_util.now() - # price_rating[rating_type] contains a dict with "entries" key, extract it - rating_data = price_rating.get(rating_type, {}) - if isinstance(rating_data, dict): - entries = rating_data.get("entries", []) - else: - entries = rating_data if isinstance(rating_data, list) else [] - entry = self._find_rating_entry(entries, now, rating_type) - if entry: - difference = entry.get("difference") - level = entry.get("level") - self._last_rating_difference = float(difference) if difference is not None else None - self._last_rating_level = level if level is not None else None - return self._translate_rating_level(level or "") + price_info = self.coordinator.data.get("priceInfo", {}) + current_interval = find_price_data_for_interval(price_info, now) + + if current_interval: + rating_level = current_interval.get("rating_level") + difference = current_interval.get("difference") + if rating_level is not None: + self._last_rating_difference = float(difference) if difference is not None else None + self._last_rating_level = rating_level + return self._translate_rating_level(rating_level) + self._last_rating_difference = None self._last_rating_level = None return None @@ -542,7 +486,6 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): return None price_info = self.coordinator.data.get("priceInfo", {}) - price_rating = self.coordinator.data.get("priceRating", {}) today_prices = price_info.get("today", []) tomorrow_prices = price_info.get("tomorrow", []) @@ -559,22 +502,6 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): # Track the maximum intervals to return intervals_to_return = MAX_FORECAST_INTERVALS if max_intervals is None else max_intervals - # Extract hourly rating data for enriching the forecast - rating_data = {} - hourly_rating = price_rating.get("hourly", {}) - if hourly_rating and "entries" in hourly_rating: - for entry in hourly_rating.get("entries", []): - if entry.get("time"): - timestamp = dt_util.parse_datetime(entry["time"]) - if timestamp: - timestamp = dt_util.as_local(timestamp) - # Store with ISO format key for easier lookup - time_key = timestamp.replace(second=0, microsecond=0).isoformat() - rating_data[time_key] = { - "difference": float(entry.get("difference", 0)), - "rating_level": entry.get("level"), - } - for day_key in ["today", "tomorrow"]: for price_data in price_info.get(day_key, []): starts_at = dt_util.parse_datetime(price_data["startsAt"]) @@ -585,25 +512,21 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): interval_end = starts_at + timedelta(minutes=MINUTES_PER_INTERVAL) if starts_at > now: - starts_at_key = starts_at.replace(second=0, microsecond=0).isoformat() - - interval_rating = rating_data.get(starts_at_key) or {} - future_prices.append( { "interval_start": starts_at.isoformat(), "interval_end": interval_end.isoformat(), "price": float(price_data["total"]), "price_cents": round(float(price_data["total"]) * 100, 2), - "level": price_data.get("level", "NORMAL"), # Price level from priceInfo - "rating": interval_rating.get("difference", None), # Rating from priceRating - "rating_level": interval_rating.get("rating_level"), # Level from priceRating + "level": price_data.get("level", "NORMAL"), + "rating": price_data.get("difference", None), + "rating_level": price_data.get("rating_level"), "day": day_key, } ) # Sort by start time - future_prices.sort(key=lambda x: x["interval_start"]) # Updated sort key + future_prices.sort(key=lambda x: x["interval_start"]) # Limit to the requested number of intervals return future_prices[:intervals_to_return] if future_prices else None @@ -868,10 +791,11 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): break def _add_statistics_attributes(self, attributes: dict) -> None: - """Add attributes for statistics, rating, and diagnostic sensors.""" + """Add attributes for statistics and rating sensors.""" key = self.entity_description.key price_info = self.coordinator.data.get("priceInfo", {}) now = dt_util.now() + if key == "price_rating": interval_data = find_price_data_for_interval(price_info, now) attributes["timestamp"] = interval_data["startsAt"] if interval_data else None @@ -880,21 +804,6 @@ class TibberPricesSensor(TibberPricesEntity, SensorEntity): if hasattr(self, "_last_rating_level") and self._last_rating_level is not None: attributes["level_id"] = self._last_rating_level attributes["level_value"] = PRICE_RATING_MAPPING.get(self._last_rating_level, self._last_rating_level) - elif key == "daily_rating": - attributes["timestamp"] = now.replace(hour=0, minute=0, second=0, microsecond=0).isoformat() - if hasattr(self, "_last_rating_difference") and self._last_rating_difference is not None: - attributes["difference_" + PERCENTAGE] = self._last_rating_difference - if hasattr(self, "_last_rating_level") and self._last_rating_level is not None: - attributes["level_id"] = self._last_rating_level - attributes["level_value"] = PRICE_RATING_MAPPING.get(self._last_rating_level, self._last_rating_level) - elif key == "monthly_rating": - first_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - attributes["timestamp"] = first_of_month.isoformat() - if hasattr(self, "_last_rating_difference") and self._last_rating_difference is not None: - attributes["difference_" + PERCENTAGE] = self._last_rating_difference - if hasattr(self, "_last_rating_level") and self._last_rating_level is not None: - attributes["level_id"] = self._last_rating_level - attributes["level_value"] = PRICE_RATING_MAPPING.get(self._last_rating_level, self._last_rating_level) else: # Fallback: use the first timestamp of today first_timestamp = price_info.get("today", [{}])[0].get("startsAt") diff --git a/custom_components/tibber_prices/services.py b/custom_components/tibber_prices/services.py index b8ea9b8..1437dd9 100644 --- a/custom_components/tibber_prices/services.py +++ b/custom_components/tibber_prices/services.py @@ -80,21 +80,11 @@ REFRESH_USER_DATA_SERVICE_SCHEMA: Final = vol.Schema( } ) -# region Top-level functions (ordered by call hierarchy) - # --- Entry point: Service handler --- async def _get_price(call: ServiceCall) -> dict[str, Any]: - """ - Return merged priceInfo and priceRating for the requested day and config entry. - - If 'time' is provided, it must be in HH:mm or HH:mm:ss format and is combined with the selected 'day'. - This only affects 'previous', 'current', and 'next' fields, not the 'prices' list. - If 'time' is not provided, the current time is used for all days. - If 'day' is not provided, the prices list will include today and tomorrow, but stats and interval - selection are only for today. - """ + """Return price information for the requested day and config entry.""" hass = call.hass entry_id_raw = call.data.get(ATTR_ENTRY_ID) if entry_id_raw is None: @@ -102,62 +92,62 @@ async def _get_price(call: ServiceCall) -> dict[str, Any]: entry_id: str = str(entry_id_raw) time_value = call.data.get(ATTR_TIME) explicit_day = ATTR_DAY in call.data - day = call.data.get(ATTR_DAY) + day = call.data.get(ATTR_DAY, "today") - entry, coordinator, data = _get_entry_and_data(hass, entry_id) - price_info_data, price_rating_data, hourly_ratings, rating_threshold_percentages, currency = _extract_price_data( - data + _, coordinator, _ = _get_entry_and_data(hass, entry_id) + price_info_data, currency = _extract_price_data(coordinator.data) + + # Determine which days to include + if explicit_day: + day_key = day if day in ("yesterday", "today", "tomorrow") else "today" + prices_raw = price_info_data.get(day_key, []) + stats_raw = prices_raw + else: + # No explicit day: include today + tomorrow for prices, use today for stats + today_raw = price_info_data.get("today", []) + tomorrow_raw = price_info_data.get("tomorrow", []) + prices_raw = today_raw + tomorrow_raw + stats_raw = today_raw + day_key = "today" + + # Transform to service format + prices_transformed = _transform_price_intervals(prices_raw) + stats_transformed = _transform_price_intervals(stats_raw) + + # Calculate stats + price_stats = _get_price_stats(stats_transformed) + + # Determine now and simulation flag + now, is_simulated = _determine_now_and_simulation(time_value, stats_transformed) + + # Select intervals + previous_interval, current_interval, next_interval = _select_intervals( + stats_transformed, coordinator, day_key, now, simulated=is_simulated ) - price_info_by_day, day_prefixes, ratings_by_day = _prepare_day_structures(price_info_data, hourly_ratings) + # Add end_time to intervals + _annotate_end_times(prices_transformed, price_info_data, day_key) - ( - merged, - stats_merged, - interval_selection_merged, - interval_selection_ratings, - interval_selection_day, - ) = _select_merge_strategy( - explicit_day=explicit_day, - day=day if day is not None else "today", - price_info_by_day=price_info_by_day, - ratings_by_day=ratings_by_day, - ) - - _annotate_intervals_with_times( - merged, - price_info_by_day, - interval_selection_day, - ) - - price_stats = _get_price_stats(stats_merged) - - now, is_simulated = _determine_now_and_simulation(time_value, interval_selection_merged) - - ctx = IntervalContext( - merged=interval_selection_merged, - all_ratings=interval_selection_ratings, - coordinator=coordinator, - day=interval_selection_day, - now=now, - is_simulated=is_simulated, - ) - previous_interval, current_interval, next_interval = _select_intervals(ctx) - - for interval in merged: - if "previous_end_time" in interval: - del interval["previous_end_time"] + # Clean up temp fields from all intervals + for interval in prices_transformed: if "start_dt" in interval: del interval["start_dt"] + # Also clean up from selected intervals + if previous_interval and "start_dt" in previous_interval: + del previous_interval["start_dt"] + if current_interval and "start_dt" in current_interval: + del current_interval["start_dt"] + if next_interval and "start_dt" in next_interval: + del next_interval["start_dt"] + response_ctx = PriceResponseContext( price_stats=price_stats, previous_interval=previous_interval, current_interval=current_interval, next_interval=next_interval, currency=currency, - rating_threshold_percentages=rating_threshold_percentages, - merged=merged, + merged=prices_transformed, ) return _build_price_response(response_ctx) @@ -185,7 +175,7 @@ async def _get_apexcharts_data(call: ServiceCall) -> dict[str, Any]: if not entry_id: raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entity_id") - entry, coordinator, data = _get_entry_and_data(hass, entry_id) + _, coordinator, _ = _get_entry_and_data(hass, entry_id) # Get entries based on level_type entries = _get_apexcharts_entries(coordinator, day, level_type) @@ -201,21 +191,12 @@ async def _get_apexcharts_data(call: ServiceCall) -> dict[str, Any]: return {"points": points} -def _get_apexcharts_entries(coordinator: Any, day: str, level_type: str) -> list[dict]: - """Get the appropriate entries for ApexCharts based on level_type and day.""" - if level_type == "rating_level": - entries = coordinator.data.get("priceRating", {}).get("hourly", []) - price_info = coordinator.data.get("priceInfo", {}) - day_info = price_info.get(day, []) - prefixes = _get_day_prefixes(day_info) - - if not prefixes: - return [] - - return [e for e in entries if e.get("time", e.get("startsAt", "")).startswith(prefixes[0])] - - # For non-rating level types, return the price info for the specified day - return coordinator.data.get("priceInfo", {}).get(day, []) +def _get_apexcharts_entries(coordinator: Any, day: str, _: str) -> list[dict]: + """Get the appropriate entries for ApexCharts based on day.""" + # Price info is already enriched with difference and rating_level from coordinator + price_info = coordinator.data.get("priceInfo", {}) + day_info = price_info.get(day, []) + return day_info if day_info else [] def _generate_apexcharts_points(entries: list[dict], level_key: str) -> list: @@ -310,7 +291,7 @@ async def _refresh_user_data(call: ServiceCall) -> dict[str, Any]: # Get the entry and coordinator try: - entry, coordinator, data = _get_entry_and_data(hass, entry_id) + entry, coordinator, _ = _get_entry_and_data(hass, entry_id) except ServiceValidationError as ex: return { "success": False, @@ -348,7 +329,7 @@ async def _refresh_user_data(call: ServiceCall) -> dict[str, Any]: } -# --- Direct helpers (called by service handler or each other) --- +# --- Helpers --- def _get_entry_and_data(hass: HomeAssistant, entry_id: str) -> tuple[Any, Any, dict]: @@ -363,219 +344,72 @@ def _get_entry_and_data(hass: HomeAssistant, entry_id: str) -> tuple[Any, Any, d return entry, coordinator, data -def _extract_price_data(data: dict) -> tuple[dict, dict, list, Any, Any]: - """Extract price info and rating data from coordinator data.""" +def _extract_price_data(data: dict) -> tuple[dict, Any]: + """Extract price info from enriched coordinator data.""" price_info_data = data.get("priceInfo") or {} - price_rating_data = data.get("priceRating") or {} - hourly_ratings = price_rating_data.get("hourly") or [] - rating_threshold_percentages = price_rating_data.get("thresholdPercentages") currency = price_info_data.get("currency") - return price_info_data, price_rating_data, hourly_ratings, rating_threshold_percentages, currency + return price_info_data, currency -def _prepare_day_structures(price_info_data: dict, hourly_ratings: list) -> tuple[dict, dict, dict]: - """Prepare price info, day prefixes, and ratings by day.""" - price_info_by_day = {d: price_info_data.get(d) or [] for d in ("yesterday", "today", "tomorrow")} - day_prefixes = {d: _get_day_prefixes(price_info_by_day[d]) for d in ("yesterday", "today", "tomorrow")} - ratings_by_day = { - d: [ - r - for r in hourly_ratings - if isinstance(r, dict) - and day_prefixes[d] - and r.get("time", r.get("startsAt", "")).startswith(day_prefixes[d][0]) - ] - if price_info_by_day[d] and day_prefixes[d] - else [] - for d in ("yesterday", "today", "tomorrow") - } - return price_info_by_day, day_prefixes, ratings_by_day - - -def _select_merge_strategy( - *, - explicit_day: bool, - day: str, - price_info_by_day: dict, - ratings_by_day: dict, -) -> tuple[list, list, list, list, str]: - """Select merging strategy for intervals and stats.""" - if not explicit_day: - merged_today = _merge_priceinfo_and_pricerating(price_info_by_day["today"], ratings_by_day["today"]) - merged_tomorrow = _merge_priceinfo_and_pricerating(price_info_by_day["tomorrow"], ratings_by_day["tomorrow"]) - merged = merged_today + merged_tomorrow - stats_merged = merged_today - interval_selection_merged = merged_today - interval_selection_ratings = ratings_by_day["today"] - interval_selection_day = "today" - else: - day_key = day if day in ("yesterday", "today", "tomorrow") else "today" - merged = _merge_priceinfo_and_pricerating(price_info_by_day[day_key], ratings_by_day[day_key]) - stats_merged = merged - interval_selection_merged = merged - interval_selection_ratings = ratings_by_day[day_key] - interval_selection_day = day_key - return ( - merged, - stats_merged, - interval_selection_merged, - interval_selection_ratings, - interval_selection_day, - ) - - -def _get_day_prefixes(day_info: list[dict]) -> list[str]: - """Return a list of unique day prefixes from the intervals' start datetimes.""" - prefixes = set() - for interval in day_info: - dt_str = interval.get("time") or interval.get("startsAt") - if not dt_str: - continue - start_dt = dt_util.parse_datetime(dt_str) - if start_dt: - prefixes.add(start_dt.date().isoformat()) - return list(prefixes) - - -def _get_adjacent_start_time(price_info_by_day: dict, day_key: str, *, first: bool) -> str | None: - """Get the start_time from the first/last interval of an adjacent day.""" - info = price_info_by_day.get(day_key) or [] - if not info: - return None - idx = 0 if first else -1 - return info[idx].get("startsAt") - - -def _merge_priceinfo_and_pricerating(price_info: list[dict], price_rating: list[dict]) -> list[dict]: - """ - Merge priceInfo and priceRating intervals by timestamp, prefixing rating fields. - - Also rename startsAt to start_time. Preserves item order. - Adds 'start_dt' (datetime) to each merged interval for reliable sorting/comparison. - """ - rating_by_time = {(r.get("time") or r.get("startsAt")): r for r in price_rating or []} - merged = [] +def _transform_price_intervals(price_info: list[dict]) -> list[dict]: + """Transform priceInfo intervals to service output format.""" + result = [] for interval in price_info or []: ts = interval.get("startsAt") start_dt = dt_util.parse_datetime(ts) if ts else None - merged_interval = {"start_time": ts, "start_dt": start_dt} if ts is not None else {"start_dt": None} + item = {"start_time": ts, "start_dt": start_dt} if ts else {"start_dt": None} + for k, v in interval.items(): if k == "startsAt": continue if k == "total": - merged_interval["price"] = v - merged_interval["price_minor"] = round(v * 100, 2) + item["price"] = v + item["price_minor"] = round(v * 100, 2) elif k not in ("energy", "tax"): - merged_interval[k] = v - rating = rating_by_time.get(ts) - if rating: - for k, v in rating.items(): - if k in ("time", "startsAt", "total", "tax", "energy"): - continue - if k == "difference": - merged_interval["rating_difference_%"] = v - elif k == "rating": - merged_interval["rating"] = v - else: - merged_interval[f"rating_{k}"] = v - merged.append(merged_interval) - # Always sort by start_dt (datetime), None values last - merged.sort(key=lambda x: (x.get("start_dt") is None, x.get("start_dt"))) - return merged + item[k] = v + + result.append(item) + + # Sort by datetime + result.sort(key=lambda x: (x.get("start_dt") is None, x.get("start_dt"))) + return result -def _find_previous_interval( - merged: list[dict], - all_ratings: list[dict], - coordinator: Any, - day: str, -) -> Any: - """Find previous interval from previous day if needed.""" - if merged and day == "today": - yday_info = coordinator.data.get("priceInfo", {}).get("yesterday", []) - if yday_info: - yday_ratings = [ - r - for r in all_ratings - if r.get("time", r.get("startsAt", "")).startswith(_get_day_prefixes(yday_info)[0]) - ] - yday_merged = _merge_priceinfo_and_pricerating(yday_info, yday_ratings) - if yday_merged: - return yday_merged[-1] - return None - - -def _find_next_interval( - merged: list[dict], - all_ratings: list[dict], - coordinator: Any, - day: str, -) -> Any: - """Find next interval from next day if needed.""" - if merged and day == "today": - tmrw_info = coordinator.data.get("priceInfo", {}).get("tomorrow", []) - if tmrw_info: - tmrw_ratings = [ - r - for r in all_ratings - if r.get("time", r.get("startsAt", "")).startswith(_get_day_prefixes(tmrw_info)[0]) - ] - tmrw_merged = _merge_priceinfo_and_pricerating(tmrw_info, tmrw_ratings) - if tmrw_merged: - return tmrw_merged[0] - return None - - -def _annotate_intervals_with_times( - merged: list[dict], - price_info_by_day: dict, - day: str, -) -> None: - """Annotate merged intervals with end_time and previous_end_time.""" +def _annotate_end_times(merged: list[dict], price_info_by_day: dict, day: str) -> None: + """Annotate merged intervals with end_time.""" for idx, interval in enumerate(merged): # Default: next interval's start_time if idx + 1 < len(merged): interval["end_time"] = merged[idx + 1].get("start_time") - # Last interval: look into tomorrow if today, or None otherwise - elif day == "today": - next_start = _get_adjacent_start_time(price_info_by_day, "tomorrow", first=True) - interval["end_time"] = next_start - elif day == "yesterday": - next_start = _get_adjacent_start_time(price_info_by_day, "today", first=True) - interval["end_time"] = next_start - elif day == "tomorrow": - interval["end_time"] = None + # Last interval: look into next day's first interval else: - interval["end_time"] = None - # First interval: look into yesterday if today, or None otherwise - if idx == 0: - if day == "today": - prev_end = _get_adjacent_start_time(price_info_by_day, "yesterday", first=False) - interval["previous_end_time"] = prev_end - elif day == "tomorrow": - prev_end = _get_adjacent_start_time(price_info_by_day, "today", first=False) - interval["previous_end_time"] = prev_end + next_day = "tomorrow" if day == "today" else (day if day == "tomorrow" else None) + if next_day and price_info_by_day.get(next_day): + first_of_next = price_info_by_day[next_day][0] + interval["end_time"] = first_of_next.get("startsAt") else: - interval["previous_end_time"] = None + interval["end_time"] = None def _get_price_stats(merged: list[dict]) -> PriceStats: - """Calculate average, min, and max price and their intervals from merged data.""" + """Calculate average, min, and max price from merged data.""" if merged: price_sum = sum(float(interval.get("price", 0)) for interval in merged if "price" in interval) price_avg = round(price_sum / len(merged), 4) else: price_avg = 0 - price_min, price_min_start_time, price_min_end_time = _get_price_stat(merged, "min") - price_max, price_max_start_time, price_max_end_time = _get_price_stat(merged, "max") + price_min, price_min_interval = _get_price_stat(merged, "min") + price_max, price_max_interval = _get_price_stat(merged, "max") return PriceStats( price_avg=price_avg, price_min=price_min, - price_min_start_time=price_min_start_time, - price_min_end_time=price_min_end_time, + price_min_start_time=price_min_interval.get("start_time") if price_min_interval else None, + price_min_end_time=price_min_interval.get("end_time") if price_min_interval else None, price_max=price_max, - price_max_start_time=price_max_start_time, - price_max_end_time=price_max_end_time, + price_max_start_time=price_max_interval.get("start_time") if price_max_interval else None, + price_max_end_time=price_max_interval.get("end_time") if price_max_interval else None, + price_min_interval=price_min_interval, + price_max_interval=price_max_interval, stats_merged=merged, ) @@ -587,7 +421,6 @@ def _determine_now_and_simulation( is_simulated = False if time_value: if not interval_selection_merged or not interval_selection_merged[0].get("start_time"): - # Instead of raising, return a simulated now for the requested day (structure will be empty) now = dt_util.now().replace(second=0, microsecond=0) is_simulated = True return now, is_simulated @@ -616,25 +449,14 @@ def _determine_now_and_simulation( return now, is_simulated -def _select_intervals(ctx: IntervalContext) -> tuple[Any, Any, Any]: - """ - Select previous, current, and next intervals for the given day and time. - - If is_simulated is True, always calculate previous/current/next for all days, but: - - For 'yesterday', never fetch previous from the day before yesterday. - - For 'tomorrow', never fetch next from the day after tomorrow. - If is_simulated is False, previous/current/next are None for 'yesterday' and 'tomorrow'. - """ - merged = ctx.merged - all_ratings = ctx.all_ratings - coordinator = ctx.coordinator - day = ctx.day - now = ctx.now - is_simulated = ctx.is_simulated - - if not merged or (not is_simulated and day in ("yesterday", "tomorrow")): +def _select_intervals( + merged: list[dict], coordinator: Any, day: str, now: datetime, *, simulated: bool +) -> tuple[Any, Any, Any]: + """Select previous, current, and next intervals for the given day and time.""" + if not merged or (not simulated and day in ("yesterday", "tomorrow")): return None, None, None + # Find current interval by time idx = None cmp_now = dt_util.as_local(now) if now.tzinfo is None else now for i, interval in enumerate(merged): @@ -654,89 +476,103 @@ def _select_intervals(ctx: IntervalContext) -> tuple[Any, Any, Any]: merged[idx + 1] if idx is not None and idx + 1 < len(merged) else (merged[0] if idx is None else None) ) + # For today, try to fetch adjacent intervals from neighboring days if day == "today": - if idx == 0: - previous_interval = _find_previous_interval(merged, all_ratings, coordinator, day) - if idx == len(merged) - 1: - next_interval = _find_next_interval(merged, all_ratings, coordinator, day) + if idx == 0 and previous_interval is None: + yday_info = coordinator.data.get("priceInfo", {}).get("yesterday", []) + if yday_info: + yday_transformed = _transform_price_intervals(yday_info) + if yday_transformed: + previous_interval = yday_transformed[-1] + + if idx == len(merged) - 1 and next_interval is None: + tmrw_info = coordinator.data.get("priceInfo", {}).get("tomorrow", []) + if tmrw_info: + tmrw_transformed = _transform_price_intervals(tmrw_info) + if tmrw_transformed: + next_interval = tmrw_transformed[0] return previous_interval, current_interval, next_interval -# --- Indirect helpers (called by helpers above) --- - - def _build_price_response(ctx: PriceResponseContext) -> dict[str, Any]: """Build the response dictionary for the price service.""" price_stats = ctx.price_stats + + # Helper to clean internal fields from interval + def clean_interval(interval: dict | None) -> dict | None: + """Remove internal fields like start_dt from interval.""" + if not interval: + return interval + return {k: v for k, v in interval.items() if k != "start_dt"} + + # Build average interval (synthetic, using first interval as template) + average_interval = {} + if price_stats.stats_merged: + first = price_stats.stats_merged[0] + # Copy all attributes from first interval (excluding internal fields) + for k in first: + if k not in ("start_time", "end_time", "start_dt", "price", "price_minor"): + average_interval[k] = first[k] + return { "average": { + **average_interval, "start_time": price_stats.stats_merged[0].get("start_time") if price_stats.stats_merged else None, "end_time": price_stats.stats_merged[0].get("end_time") if price_stats.stats_merged else None, "price": price_stats.price_avg, "price_minor": round(price_stats.price_avg * 100, 2), }, - "minimum": { + "minimum": clean_interval( + { + **price_stats.price_min_interval, + "price": price_stats.price_min, + "price_minor": round(price_stats.price_min * 100, 2), + } + ) + if price_stats.price_min_interval + else { "start_time": price_stats.price_min_start_time, "end_time": price_stats.price_min_end_time, "price": price_stats.price_min, "price_minor": round(price_stats.price_min * 100, 2), }, - "maximum": { + "maximum": clean_interval( + { + **price_stats.price_max_interval, + "price": price_stats.price_max, + "price_minor": round(price_stats.price_max * 100, 2), + } + ) + if price_stats.price_max_interval + else { "start_time": price_stats.price_max_start_time, "end_time": price_stats.price_max_end_time, "price": price_stats.price_max, "price_minor": round(price_stats.price_max * 100, 2), }, - "previous": ctx.previous_interval, - "current": ctx.current_interval, - "next": ctx.next_interval, + "previous": clean_interval(ctx.previous_interval), + "current": clean_interval(ctx.current_interval), + "next": clean_interval(ctx.next_interval), "currency": ctx.currency, - "rating_threshold_%": ctx.rating_threshold_percentages, "interval_count": len(ctx.merged), "intervals": ctx.merged, } -def _get_price_stat(merged: list[dict], stat: str) -> tuple[float, str | None, str | None]: - """Return min or max price and its start and end time from merged intervals.""" +def _get_price_stat(merged: list[dict], stat: str) -> tuple[float, dict | None]: + """Return min or max price and its full interval from merged intervals.""" if not merged: - return 0, None, None + return 0, None values = [float(interval.get("price", 0)) for interval in merged if "price" in interval] if not values: - return 0, None, None + return 0, None val = min(values) if stat == "min" else max(values) - start_time = next((interval.get("start_time") for interval in merged if interval.get("price") == val), None) - end_time = next((interval.get("end_time") for interval in merged if interval.get("price") == val), None) - return val, start_time, end_time + interval = next((interval for interval in merged if interval.get("price") == val), None) + return val, interval -# endregion - -# region Main classes (dataclasses) - - -@dataclass -class IntervalContext: - """ - Context for selecting price intervals. - - Attributes: - merged: List of merged price and rating intervals for the selected day. - all_ratings: All rating intervals for the selected day. - coordinator: Data update coordinator for the integration. - day: The day being queried ('yesterday', 'today', or 'tomorrow'). - now: The datetime used for interval selection. - is_simulated: Whether the time is simulated (from user input) or real. - - """ - - merged: list[dict] - all_ratings: list[dict] - coordinator: Any - day: str - now: datetime - is_simulated: bool +# --- Dataclasses --- @dataclass @@ -747,9 +583,11 @@ class PriceStats: price_min: float price_min_start_time: str | None price_min_end_time: str | None + price_min_interval: dict | None price_max: float price_max_start_time: str | None price_max_end_time: str | None + price_max_interval: dict | None stats_merged: list[dict] @@ -762,13 +600,10 @@ class PriceResponseContext: current_interval: dict | None next_interval: dict | None currency: str | None - rating_threshold_percentages: Any merged: list[dict] -# endregion - -# region Service registration +# --- Service registration --- @callback @@ -802,6 +637,3 @@ def async_setup_services(hass: HomeAssistant) -> None: schema=REFRESH_USER_DATA_SERVICE_SCHEMA, supports_response=SupportsResponse.ONLY, ) - - -# endregion diff --git a/custom_components/tibber_prices/services_old.py b/custom_components/tibber_prices/services_old.py new file mode 100644 index 0000000..bc3a5f7 --- /dev/null +++ b/custom_components/tibber_prices/services_old.py @@ -0,0 +1,662 @@ +"""Services for Tibber Prices integration.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Final + +import voluptuous as vol + +from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse, callback +from homeassistant.exceptions import ServiceValidationError +from homeassistant.helpers.entity_registry import async_get as async_get_entity_registry +from homeassistant.util import dt as dt_util + +from .api import ( + TibberPricesApiClientAuthenticationError, + TibberPricesApiClientCommunicationError, + TibberPricesApiClientError, +) +from .const import ( + DOMAIN, + PRICE_LEVEL_CHEAP, + PRICE_LEVEL_EXPENSIVE, + PRICE_LEVEL_NORMAL, + PRICE_LEVEL_VERY_CHEAP, + PRICE_LEVEL_VERY_EXPENSIVE, + PRICE_RATING_HIGH, + PRICE_RATING_LOW, + PRICE_RATING_NORMAL, + get_price_level_translation, +) + +PRICE_SERVICE_NAME = "get_price" +APEXCHARTS_DATA_SERVICE_NAME = "get_apexcharts_data" +APEXCHARTS_YAML_SERVICE_NAME = "get_apexcharts_yaml" +REFRESH_USER_DATA_SERVICE_NAME = "refresh_user_data" +ATTR_DAY: Final = "day" +ATTR_ENTRY_ID: Final = "entry_id" +ATTR_TIME: Final = "time" + +PRICE_SERVICE_SCHEMA: Final = vol.Schema( + { + vol.Required(ATTR_ENTRY_ID): str, + vol.Optional(ATTR_DAY): vol.In(["yesterday", "today", "tomorrow"]), + vol.Optional(ATTR_TIME): vol.Match(r"^(\d{2}:\d{2}(:\d{2})?)$"), # HH:mm or HH:mm:ss + } +) + +APEXCHARTS_DATA_SERVICE_SCHEMA: Final = vol.Schema( + { + vol.Required("entity_id"): str, + vol.Required("day"): vol.In(["yesterday", "today", "tomorrow"]), + vol.Required("level_type"): vol.In(["level", "rating_level"]), + vol.Required("level_key"): vol.In( + [ + PRICE_LEVEL_CHEAP, + PRICE_LEVEL_EXPENSIVE, + PRICE_LEVEL_NORMAL, + PRICE_LEVEL_VERY_CHEAP, + PRICE_LEVEL_VERY_EXPENSIVE, + PRICE_RATING_HIGH, + PRICE_RATING_LOW, + PRICE_RATING_NORMAL, + ] + ), + } +) + +APEXCHARTS_SERVICE_SCHEMA: Final = vol.Schema( + { + vol.Required("entity_id"): str, + vol.Optional("day", default="today"): vol.In(["yesterday", "today", "tomorrow"]), + } +) + +REFRESH_USER_DATA_SERVICE_SCHEMA: Final = vol.Schema( + { + vol.Required(ATTR_ENTRY_ID): str, + } +) + +# region Top-level functions (ordered by call hierarchy) + +# --- Entry point: Service handler --- + + +async def _get_price(call: ServiceCall) -> dict[str, Any]: + """ + Return price information with enriched rating data for the requested day and config entry. + + If 'time' is provided, it must be in HH:mm or HH:mm:ss format and is combined with the selected 'day'. + This only affects 'previous', 'current', and 'next' fields, not the 'prices' list. + If 'day' is not provided, prices list includes today and tomorrow, stats/interval selection for today. + """ + hass = call.hass + entry_id_raw = call.data.get(ATTR_ENTRY_ID) + if entry_id_raw is None: + raise ServiceValidationError(translation_domain=DOMAIN, translation_key="missing_entry_id") + entry_id: str = str(entry_id_raw) + time_value = call.data.get(ATTR_TIME) + explicit_day = ATTR_DAY in call.data + day = call.data.get(ATTR_DAY, "today") + + entry, coordinator, data = _get_entry_and_data(hass, entry_id) + price_info_data, currency = _extract_price_data(data) + + # Determine which days to include + if explicit_day: + day_key = day if day in ("yesterday", "today", "tomorrow") else "today" + prices_raw = price_info_data.get(day_key, []) + stats_raw = prices_raw + else: + # No explicit day: include today + tomorrow for prices, use today for stats + today_raw = price_info_data.get("today", []) + tomorrow_raw = price_info_data.get("tomorrow", []) + prices_raw = today_raw + tomorrow_raw + stats_raw = today_raw + day_key = "today" + + # Transform to service format + prices_transformed = _transform_price_intervals(prices_raw) + stats_transformed = _transform_price_intervals(stats_raw) + + # Calculate stats only from stats_raw + price_stats = _get_price_stats(stats_transformed) + + # Determine now and simulation flag + now, is_simulated = _determine_now_and_simulation(time_value, stats_transformed) + + # Select intervals + previous_interval, current_interval, next_interval = _select_intervals( + stats_transformed, coordinator, day_key, now, is_simulated + ) + + # Add end_time to intervals + _annotate_end_times(prices_transformed, stats_transformed, day_key, price_info_data) + + # Clean up temp fields + for interval in prices_transformed: + if "start_dt" in interval: + del interval["start_dt"] + + response_ctx = PriceResponseContext( + price_stats=price_stats, + previous_interval=previous_interval, + current_interval=current_interval, + next_interval=next_interval, + currency=currency, + merged=prices_transformed, + ) + + return _build_price_response(response_ctx) + + +async def _get_entry_id_from_entity_id(hass: HomeAssistant, entity_id: str) -> str | None: + """Return the config entry_id for a given entity_id.""" + entity_registry = async_get_entity_registry(hass) + entry = entity_registry.async_get(entity_id) + if entry is not None: + return entry.config_entry_id + return None + + +async def _get_apexcharts_data(call: ServiceCall) -> dict[str, Any]: + """Return points for ApexCharts for a single level type (e.g., LOW, NORMAL, HIGH, etc).""" + entity_id = call.data.get("entity_id", "sensor.tibber_price_today") + day = call.data.get("day", "today") + level_type = call.data.get("level_type", "rating_level") + level_key = call.data.get("level_key") + hass = call.hass + + # Get entry ID and verify it exists + entry_id = await _get_entry_id_from_entity_id(hass, entity_id) + if not entry_id: + raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entity_id") + + entry, coordinator, data = _get_entry_and_data(hass, entry_id) + + # Get entries based on level_type + entries = _get_apexcharts_entries(coordinator, day, level_type) + if not entries: + return {"points": []} + + # Ensure level_key is a string + if level_key is None: + raise ServiceValidationError(translation_domain=DOMAIN, translation_key="missing_level_key") + + # Generate points for the chart + points = _generate_apexcharts_points(entries, str(level_key)) + return {"points": points} + + +def _get_apexcharts_entries(coordinator: Any, day: str, level_type: str) -> list[dict]: + """Get the appropriate entries for ApexCharts based on level_type and day.""" + if level_type == "rating_level": + # price_info is now enriched with difference and rating_level from the coordinator + price_info = coordinator.data.get("priceInfo", {}) + day_info = price_info.get(day, []) + return day_info if day_info else [] + + # For non-rating level types, return the price info for the specified day + return coordinator.data.get("priceInfo", {}).get(day, []) + + +def _generate_apexcharts_points(entries: list[dict], level_key: str) -> list: + """Generate data points for ApexCharts based on the entries and level key.""" + points = [] + for i in range(len(entries) - 1): + p = entries[i] + if p.get("level") != level_key: + continue + points.append([p.get("time") or p.get("startsAt"), round((p.get("total") or 0) * 100, 2)]) + + # Add a final point with null value if there are any points + if points: + points.append([points[-1][0], None]) + + return points + + +async def _get_apexcharts_yaml(call: ServiceCall) -> dict[str, Any]: + """Return a YAML snippet for an ApexCharts card using the get_apexcharts_data service for each level.""" + entity_id = call.data.get("entity_id", "sensor.tibber_price_today") + day = call.data.get("day", "today") + level_type = call.data.get("level_type", "rating_level") + if level_type == "rating_level": + series_levels = [ + (PRICE_RATING_LOW, "#2ecc71"), + (PRICE_RATING_NORMAL, "#f1c40f"), + (PRICE_RATING_HIGH, "#e74c3c"), + ] + else: + series_levels = [ + (PRICE_LEVEL_VERY_CHEAP, "#2ecc71"), + (PRICE_LEVEL_CHEAP, "#27ae60"), + (PRICE_LEVEL_NORMAL, "#f1c40f"), + (PRICE_LEVEL_EXPENSIVE, "#e67e22"), + (PRICE_LEVEL_VERY_EXPENSIVE, "#e74c3c"), + ] + series = [] + for level_key, color in series_levels: + name = get_price_level_translation(level_key, "en") or level_key + data_generator = ( + f"const data = await hass.callService('tibber_prices', 'get_apexcharts_data', " + f"{{ entity_id: '{entity_id}', day: '{day}', level_type: '{level_type}', level_key: '{level_key}' }});\n" + f"return data.points;" + ) + series.append( + { + "entity": entity_id, + "name": name, + "type": "area", + "color": color, + "yaxis_id": "price", + "show": {"extremas": level_key != "NORMAL"}, + "data_generator": data_generator, + } + ) + title = "Preisphasen Tagesverlauf" if level_type == "rating" else "Preisniveau" + return { + "type": "custom:apexcharts-card", + "update_interval": "5m", + "span": {"start": "day"}, + "header": { + "show": True, + "title": title, + "show_states": False, + }, + "apex_config": { + "stroke": {"curve": "stepline"}, + "fill": {"opacity": 0.4}, + "tooltip": {"x": {"format": "HH:mm"}}, + "legend": {"show": True}, + }, + "yaxis": [ + {"id": "price", "decimals": 0, "min": 0}, + ], + "now": {"show": True, "color": "#8e24aa", "label": "🕒 LIVE"}, + "all_series_config": {"stroke_width": 1, "show": {"legend_value": False}}, + "series": series, + } + + +async def _refresh_user_data(call: ServiceCall) -> dict[str, Any]: + """Refresh user data for a specific config entry and return updated information.""" + entry_id = call.data.get(ATTR_ENTRY_ID) + hass = call.hass + + if not entry_id: + return { + "success": False, + "message": "Entry ID is required", + } + + # Get the entry and coordinator + try: + entry, coordinator, data = _get_entry_and_data(hass, entry_id) + except ServiceValidationError as ex: + return { + "success": False, + "message": f"Invalid entry ID: {ex}", + } + + # Force refresh user data using the public method + try: + updated = await coordinator.refresh_user_data() + except ( + TibberPricesApiClientAuthenticationError, + TibberPricesApiClientCommunicationError, + TibberPricesApiClientError, + ) as ex: + return { + "success": False, + "message": f"API error refreshing user data: {ex!s}", + } + else: + if updated: + user_profile = coordinator.get_user_profile() + homes = coordinator.get_user_homes() + + return { + "success": True, + "message": "User data refreshed successfully", + "user_profile": user_profile, + "homes_count": len(homes), + "homes": homes, + "last_updated": user_profile.get("last_updated"), + } + return { + "success": False, + "message": "User data was already up to date", + } + + +# --- Direct helpers (called by service handler or each other) --- + + +def _get_entry_and_data(hass: HomeAssistant, entry_id: str) -> tuple[Any, Any, dict]: + """Validate entry and extract coordinator and data.""" + if not entry_id: + raise ServiceValidationError(translation_domain=DOMAIN, translation_key="missing_entry_id") + entry = next((e for e in hass.config_entries.async_entries(DOMAIN) if e.entry_id == entry_id), None) + if not entry or not hasattr(entry, "runtime_data") or not entry.runtime_data: + raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entry_id") + coordinator = entry.runtime_data.coordinator + data = coordinator.data or {} + return entry, coordinator, data + + +def _extract_price_data(data: dict) -> tuple[dict, Any]: + """ + Extract price info from enriched coordinator data. + + The price_info_data returned already includes 'difference' and 'rating_level' + enrichment from the coordinator, so no separate rating data extraction is needed. + """ + price_info_data = data.get("priceInfo") or {} + currency = price_info_data.get("currency") + return price_info_data, currency + + + + +def _transform_price_intervals(price_info: list[dict]) -> list[dict]: + """Transform priceInfo intervals to service output format.""" + result = [] + for interval in price_info or []: + ts = interval.get("startsAt") + start_dt = dt_util.parse_datetime(ts) if ts else None + item = {"start_time": ts, "start_dt": start_dt} if ts else {"start_dt": None} + + for k, v in interval.items(): + if k == "startsAt": + continue + if k == "total": + item["price"] = v + item["price_minor"] = round(v * 100, 2) + elif k not in ("energy", "tax"): + item[k] = v + + result.append(item) + + # Sort by datetime + result.sort(key=lambda x: (x.get("start_dt") is None, x.get("start_dt"))) + return result + + +def _annotate_intervals_with_times( + merged: list[dict], + price_info_by_day: dict, + day: str, +) -> None: + """Annotate merged intervals with end_time and previous_end_time.""" + for idx, interval in enumerate(merged): + # Default: next interval's start_time + if idx + 1 < len(merged): + interval["end_time"] = merged[idx + 1].get("start_time") + # Last interval: look into tomorrow if today, or None otherwise + elif day == "today": + next_start = _get_adjacent_start_time(price_info_by_day, "tomorrow", first=True) + interval["end_time"] = next_start + elif day == "yesterday": + next_start = _get_adjacent_start_time(price_info_by_day, "today", first=True) + interval["end_time"] = next_start + elif day == "tomorrow": + interval["end_time"] = None + else: + interval["end_time"] = None + # First interval: look into yesterday if today, or None otherwise + if idx == 0: + if day == "today": + prev_end = _get_adjacent_start_time(price_info_by_day, "yesterday", first=False) + interval["previous_end_time"] = prev_end + elif day == "tomorrow": + prev_end = _get_adjacent_start_time(price_info_by_day, "today", first=False) + interval["previous_end_time"] = prev_end + else: + interval["previous_end_time"] = None + + +def _get_price_stats(merged: list[dict]) -> PriceStats: + """Calculate average, min, and max price and their intervals from merged data.""" + if merged: + price_sum = sum(float(interval.get("price", 0)) for interval in merged if "price" in interval) + price_avg = round(price_sum / len(merged), 4) + else: + price_avg = 0 + price_min, price_min_start_time, price_min_end_time = _get_price_stat(merged, "min") + price_max, price_max_start_time, price_max_end_time = _get_price_stat(merged, "max") + return PriceStats( + price_avg=price_avg, + price_min=price_min, + price_min_start_time=price_min_start_time, + price_min_end_time=price_min_end_time, + price_max=price_max, + price_max_start_time=price_max_start_time, + price_max_end_time=price_max_end_time, + stats_merged=merged, + ) + + +def _determine_now_and_simulation( + time_value: str | None, interval_selection_merged: list[dict] +) -> tuple[datetime, bool]: + """Determine the 'now' datetime and simulation flag.""" + is_simulated = False + if time_value: + if not interval_selection_merged or not interval_selection_merged[0].get("start_time"): + # Instead of raising, return a simulated now for the requested day (structure will be empty) + now = dt_util.now().replace(second=0, microsecond=0) + is_simulated = True + return now, is_simulated + day_prefix = interval_selection_merged[0]["start_time"].split("T")[0] + dt_str = f"{day_prefix}T{time_value}" + try: + now = datetime.fromisoformat(dt_str) + except ValueError as exc: + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="invalid_time", + translation_placeholders={"error": str(exc)}, + ) from exc + is_simulated = True + elif not interval_selection_merged or not interval_selection_merged[0].get("start_time"): + now = dt_util.now().replace(second=0, microsecond=0) + else: + day_prefix = interval_selection_merged[0]["start_time"].split("T")[0] + current_time = dt_util.now().time().replace(second=0, microsecond=0) + dt_str = f"{day_prefix}T{current_time.isoformat()}" + try: + now = datetime.fromisoformat(dt_str) + except ValueError: + now = dt_util.now().replace(second=0, microsecond=0) + is_simulated = True + return now, is_simulated + + +def _select_intervals(ctx: IntervalContext) -> tuple[Any, Any, Any]: + """ + Select previous, current, and next intervals for the given day and time. + + If is_simulated is True, always calculate previous/current/next for all days, but: + - For 'yesterday', never fetch previous from the day before yesterday. + - For 'tomorrow', never fetch next from the day after tomorrow. + If is_simulated is False, previous/current/next are None for 'yesterday' and 'tomorrow'. + """ + merged = ctx.merged + coordinator = ctx.coordinator + day = ctx.day + now = ctx.now + is_simulated = ctx.is_simulated + + if not merged or (not is_simulated and day in ("yesterday", "tomorrow")): + return None, None, None + + idx = None + cmp_now = dt_util.as_local(now) if now.tzinfo is None else now + for i, interval in enumerate(merged): + start_dt = interval.get("start_dt") + if not start_dt: + continue + if start_dt.tzinfo is None: + start_dt = dt_util.as_local(start_dt) + if start_dt <= cmp_now: + idx = i + elif start_dt > cmp_now: + break + + previous_interval = merged[idx - 1] if idx is not None and idx > 0 else None + current_interval = merged[idx] if idx is not None else None + next_interval = ( + merged[idx + 1] if idx is not None and idx + 1 < len(merged) else (merged[0] if idx is None else None) + ) + + if day == "today": + if idx == 0: + previous_interval = _find_previous_interval(merged, coordinator, day) + if idx == len(merged) - 1: + next_interval = _find_next_interval(merged, coordinator, day) + + return previous_interval, current_interval, next_interval + + +# --- Indirect helpers (called by helpers above) --- + + +def _build_price_response(ctx: PriceResponseContext) -> dict[str, Any]: + """Build the response dictionary for the price service.""" + price_stats = ctx.price_stats + return { + "average": { + "start_time": price_stats.stats_merged[0].get("start_time") if price_stats.stats_merged else None, + "end_time": price_stats.stats_merged[0].get("end_time") if price_stats.stats_merged else None, + "price": price_stats.price_avg, + "price_minor": round(price_stats.price_avg * 100, 2), + }, + "minimum": { + "start_time": price_stats.price_min_start_time, + "end_time": price_stats.price_min_end_time, + "price": price_stats.price_min, + "price_minor": round(price_stats.price_min * 100, 2), + }, + "maximum": { + "start_time": price_stats.price_max_start_time, + "end_time": price_stats.price_max_end_time, + "price": price_stats.price_max, + "price_minor": round(price_stats.price_max * 100, 2), + }, + "previous": ctx.previous_interval, + "current": ctx.current_interval, + "next": ctx.next_interval, + "currency": ctx.currency, + "interval_count": len(ctx.merged), + "intervals": ctx.merged, + } + + +def _get_price_stat(merged: list[dict], stat: str) -> tuple[float, str | None, str | None]: + """Return min or max price and its start and end time from merged intervals.""" + if not merged: + return 0, None, None + values = [float(interval.get("price", 0)) for interval in merged if "price" in interval] + if not values: + return 0, None, None + val = min(values) if stat == "min" else max(values) + start_time = next((interval.get("start_time") for interval in merged if interval.get("price") == val), None) + end_time = next((interval.get("end_time") for interval in merged if interval.get("price") == val), None) + return val, start_time, end_time + + +# endregion + +# region Main classes (dataclasses) + + +@dataclass +class IntervalContext: + """ + Context for selecting price intervals. + + Attributes: + merged: List of merged price and rating intervals for the selected day. + coordinator: Data update coordinator for the integration. + day: The day being queried ('yesterday', 'today', or 'tomorrow'). + now: The datetime used for interval selection. + is_simulated: Whether the time is simulated (from user input) or real. + + """ + + merged: list[dict] + coordinator: Any + day: str + now: datetime + is_simulated: bool + + +@dataclass +class PriceStats: + """Encapsulates price statistics and their intervals for the Tibber Prices service.""" + + price_avg: float + price_min: float + price_min_start_time: str | None + price_min_end_time: str | None + price_max: float + price_max_start_time: str | None + price_max_end_time: str | None + stats_merged: list[dict] + + +@dataclass +class PriceResponseContext: + """Context for building the price response.""" + + price_stats: PriceStats + previous_interval: dict | None + current_interval: dict | None + next_interval: dict | None + currency: str | None + merged: list[dict] + + +# endregion + +# region Service registration + + +@callback +def async_setup_services(hass: HomeAssistant) -> None: + """Set up services for Tibber Prices integration.""" + hass.services.async_register( + DOMAIN, + PRICE_SERVICE_NAME, + _get_price, + schema=PRICE_SERVICE_SCHEMA, + supports_response=SupportsResponse.ONLY, + ) + hass.services.async_register( + DOMAIN, + APEXCHARTS_DATA_SERVICE_NAME, + _get_apexcharts_data, + schema=APEXCHARTS_DATA_SERVICE_SCHEMA, + supports_response=SupportsResponse.ONLY, + ) + hass.services.async_register( + DOMAIN, + APEXCHARTS_YAML_SERVICE_NAME, + _get_apexcharts_yaml, + schema=APEXCHARTS_SERVICE_SCHEMA, + supports_response=SupportsResponse.ONLY, + ) + hass.services.async_register( + DOMAIN, + REFRESH_USER_DATA_SERVICE_NAME, + _refresh_user_data, + schema=REFRESH_USER_DATA_SERVICE_SCHEMA, + supports_response=SupportsResponse.ONLY, + ) + + +# endregion diff --git a/tests/test_price_utils.py b/tests/test_price_utils.py new file mode 100644 index 0000000..4d29fc5 --- /dev/null +++ b/tests/test_price_utils.py @@ -0,0 +1,168 @@ +"""Test price utils calculations.""" + +from datetime import timedelta + +from custom_components.tibber_prices.price_utils import ( + calculate_difference_percentage, + calculate_rating_level, + calculate_trailing_average_for_interval, + enrich_price_info_with_differences, +) +from homeassistant.util import dt as dt_util + + +def test_calculate_trailing_average_for_interval() -> None: + """Test trailing average calculation for a specific interval.""" + # Create sample price data spanning 24 hours + base_time = dt_util.now().replace(hour=12, minute=0, second=0, microsecond=0) + + prices = [] + # Create 96 quarter-hourly intervals (24 hours worth) + for i in range(96): + price_time = base_time - timedelta(hours=24) + timedelta(minutes=15 * i) + prices.append( + { + "startsAt": price_time.isoformat(), + "total": 0.1 + (i * 0.001), # Incrementing price + } + ) + + # Test interval at current time (should average last 24 hours) + test_time = base_time + average = calculate_trailing_average_for_interval(test_time, prices) + + assert average is not None + # Average of 96 prices from 0.1 to 0.195 (0.1 + 95*0.001) + expected_avg = (0.1 + 0.195) / 2 # ~0.1475 + assert abs(average - expected_avg) < 0.001 + + +def test_calculate_difference_percentage() -> None: + """Test difference percentage calculation.""" + current = 0.15 + average = 0.10 + + diff = calculate_difference_percentage(current, average) + assert diff is not None + assert abs(diff - 50.0) < 0.01 # 50% higher than average + + # Test with same price + diff = calculate_difference_percentage(0.10, 0.10) + assert diff == 0.0 + + # Test with None average + diff = calculate_difference_percentage(0.15, None) + assert diff is None + + # Test with zero average + diff = calculate_difference_percentage(0.15, 0.0) + assert diff is None + + +def test_enrich_price_info_with_differences() -> None: + """Test enriching price info with difference values.""" + base_time = dt_util.now().replace(hour=12, minute=0, second=0, microsecond=0) + + # Create mock price data covering 48 hours + price_info = { + "yesterday": [], + "today": [], + "tomorrow": [], + } + + # Fill yesterday with constant price + for i in range(96): # 96 intervals = 24 hours + price_time = base_time - timedelta(days=1) + timedelta(minutes=15 * i) + price_info["yesterday"].append( + { + "startsAt": price_time.isoformat(), + "total": 0.10, + } + ) + + # Add one interval for today + price_info["today"].append( + { + "startsAt": base_time.isoformat(), + "total": 0.15, + } + ) + + # Add one interval for tomorrow + price_info["tomorrow"].append( + { + "startsAt": (base_time + timedelta(days=1)).isoformat(), + "total": 0.12, + } + ) + + enriched = enrich_price_info_with_differences(price_info) + + # Today's price should have a difference calculated + assert "difference" in enriched["today"][0] + assert enriched["today"][0]["difference"] is not None + # 0.15 vs average of 0.10 = 50% higher + assert abs(enriched["today"][0]["difference"] - 50.0) < 1.0 + + # Today's price should also have a rating_level (50% > 10% threshold = HIGH) + assert "rating_level" in enriched["today"][0] + assert enriched["today"][0]["rating_level"] == "HIGH" + + # Tomorrow's price should also have a difference + assert "difference" in enriched["tomorrow"][0] + assert enriched["tomorrow"][0]["difference"] is not None + + # Tomorrow's price should have a rating_level + # The average will be pulled from yesterday (0.10) and today (0.15) + # With tomorrow price at 0.12, it should be close to NORMAL or LOW + assert "rating_level" in enriched["tomorrow"][0] + rating_level_tomorrow = enriched["tomorrow"][0]["rating_level"] + assert rating_level_tomorrow in {"LOW", "NORMAL"} + + +def test_calculate_rating_level() -> None: + """Test rating level calculation based on difference percentage and thresholds.""" + threshold_low = -10 + threshold_high = 10 + + # Test LOW threshold + level = calculate_rating_level(-15.0, threshold_low, threshold_high) + assert level == "LOW" + + # Test exact low threshold + level = calculate_rating_level(-10.0, threshold_low, threshold_high) + assert level == "LOW" + + # Test HIGH threshold + level = calculate_rating_level(15.0, threshold_low, threshold_high) + assert level == "HIGH" + + # Test exact high threshold + level = calculate_rating_level(10.0, threshold_low, threshold_high) + assert level == "HIGH" + + # Test NORMAL (between thresholds) + level = calculate_rating_level(0.0, threshold_low, threshold_high) + assert level == "NORMAL" + + level = calculate_rating_level(5.0, threshold_low, threshold_high) + assert level == "NORMAL" + + level = calculate_rating_level(-5.0, threshold_low, threshold_high) + assert level == "NORMAL" + + # Test None difference + level = calculate_rating_level(None, threshold_low, threshold_high) + assert level is None + + # Test edge case: difference in both ranges (both ranges simultaneously) + # This shouldn't normally happen, but if low > high, return NORMAL + level = calculate_rating_level(5.0, 10, -10) # inverted thresholds + assert level == "NORMAL" + + +if __name__ == "__main__": + test_calculate_trailing_average_for_interval() + test_calculate_difference_percentage() + test_enrich_price_info_with_differences() + test_calculate_rating_level() diff --git a/tests/test_price_utils_integration.py b/tests/test_price_utils_integration.py new file mode 100644 index 0000000..8405464 --- /dev/null +++ b/tests/test_price_utils_integration.py @@ -0,0 +1,131 @@ +"""Integration test for price utils with realistic data.""" + +from datetime import datetime, timedelta + +from custom_components.tibber_prices.price_utils import enrich_price_info_with_differences +from homeassistant.util import dt as dt_util + + +def generate_price_intervals(base_time: datetime, hours: int, base_price: float, variation: float = 0.05) -> list: + """Generate realistic price intervals.""" + intervals = [] + for i in range(hours * 4): # 4 intervals per hour (15-minute intervals) + time = base_time + timedelta(minutes=15 * i) + # Add sinusoidal variation (peak at 18:00, low at 6:00) + hour_of_day = time.hour + time.minute / 60 + variation_factor = 1 + variation * (((hour_of_day - 6) / 12) * 3.14159) + price = base_price * (1 + 0.1 * (variation_factor - 1)) + + intervals.append( + { + "startsAt": time.isoformat(), + "total": price, + "energy": price * 0.75, + "tax": price * 0.25, + "level": "NORMAL", + } + ) + + return intervals + + +def test_realistic_day_pricing() -> None: + """Test with realistic pricing patterns across 48 hours.""" + base_time = dt_util.now().replace(hour=12, minute=0, second=0, microsecond=0) + + # Generate realistic data + price_info = { + "yesterday": generate_price_intervals(base_time - timedelta(days=1), hours=24, base_price=0.12, variation=0.08), + "today": generate_price_intervals( + base_time.replace(hour=0, minute=0), hours=24, base_price=0.15, variation=0.10 + ), + "tomorrow": generate_price_intervals( + base_time.replace(hour=0, minute=0) + timedelta(days=1), hours=24, base_price=0.13, variation=0.07 + ), + } + + # Enrich with differences + enriched = enrich_price_info_with_differences(price_info) + + # Verify all today intervals have differences + today_intervals = enriched["today"] + for interval in today_intervals: + assert "difference" in interval, f"Missing difference in today interval {interval['startsAt']}" + assert "rating_level" in interval, f"Missing rating_level in today interval {interval['startsAt']}" + + # Verify all tomorrow intervals have differences + tomorrow_intervals = enriched["tomorrow"] + for interval in tomorrow_intervals: + assert "difference" in interval, f"Missing difference in tomorrow interval {interval['startsAt']}" + assert "rating_level" in interval, f"Missing rating_level in tomorrow interval {interval['startsAt']}" + + # Verify yesterday is unchanged (except for missing difference) + yesterday_intervals = enriched["yesterday"] + assert len(yesterday_intervals) == 96 + + # Analyze statistics + today_diffs = [i.get("difference") for i in today_intervals if i.get("difference") is not None] + today_levels = [i.get("rating_level") for i in today_intervals if i.get("rating_level") is not None] + tomorrow_levels = [i.get("rating_level") for i in tomorrow_intervals if i.get("rating_level") is not None] + + # Verify rating_level values are valid + valid_levels = {"LOW", "NORMAL", "HIGH"} + assert all(level in valid_levels for level in today_levels), "Invalid rating_level in today intervals" + assert all(level in valid_levels for level in tomorrow_levels), "Invalid rating_level in tomorrow intervals" + + # With realistic pricing variation and default thresholds of -10/+10, + # we should have at least 2 different levels (most likely HIGH and NORMAL for today, + # and NORMAL for tomorrow due to cheaper prices) + unique_today_levels = set(today_levels) + assert len(unique_today_levels) >= 1, "Today should have at least one rating level" + + +def test_day_boundary_calculations() -> None: + """Test calculations across midnight boundary.""" + midnight = dt_util.now().replace(hour=0, minute=0, second=0, microsecond=0) + + # Create data that spans the midnight boundary + price_info = { + "yesterday": generate_price_intervals(midnight - timedelta(days=1), hours=24, base_price=0.10), + "today": generate_price_intervals(midnight, hours=24, base_price=0.15), + "tomorrow": generate_price_intervals(midnight + timedelta(days=1), hours=24, base_price=0.12), + } + + enriched = enrich_price_info_with_differences(price_info) + + # Check the midnight boundary interval (first of tomorrow) + midnight_tomorrow = enriched["tomorrow"][0] + + # This should include all 96 intervals from yesterday and all 96 from today + assert "difference" in midnight_tomorrow + diff = midnight_tomorrow.get("difference") + + # Since tomorrow is cheaper (0.12) than both yesterday (0.10) and today (0.15) + # The difference could be negative (cheap) or positive (expensive) depending on the mix + diff = midnight_tomorrow.get("difference") + assert diff is not None, "Midnight boundary interval should have difference" + + +def test_early_morning_calculations() -> None: + """Test calculations in early morning hours.""" + base_time = dt_util.now().replace(hour=6, minute=0, second=0, microsecond=0) + + price_info = { + "yesterday": generate_price_intervals(base_time - timedelta(days=1), hours=24, base_price=0.12), + "today": generate_price_intervals(base_time.replace(hour=0, minute=0), hours=24, base_price=0.15), + "tomorrow": generate_price_intervals( + base_time.replace(hour=0, minute=0) + timedelta(days=1), hours=24, base_price=0.13 + ), + } + + enriched = enrich_price_info_with_differences(price_info) + + # Get 6 AM interval (24th interval of the day) + six_am_interval = enriched["today"][24] + assert "difference" in six_am_interval + + # At 6 AM, we should include: + # - Yesterday from 6 AM to midnight (68 intervals) + # - Today from midnight to 6 AM (24 intervals) + # Total: 92 intervals (not quite 24 hours) + assert "difference" in six_am_interval diff --git a/tests/test_services_enrich.py b/tests/test_services_enrich.py new file mode 100644 index 0000000..178c81b --- /dev/null +++ b/tests/test_services_enrich.py @@ -0,0 +1,106 @@ +"""Test that min/max/average include enriched attributes.""" + +from datetime import datetime + +import pytest + +from custom_components.tibber_prices.services import _get_price_stat, _get_price_stats + + +def test_min_max_intervals_include_enriched_attributes(): + """Test that min/max intervals contain difference and rating_level.""" + merged = [ + { + "start_time": "2025-11-01T00:00:00+01:00", + "end_time": "2025-11-01T01:00:00+01:00", + "start_dt": datetime(2025, 11, 1, 0, 0), + "price": 0.15, + "price_minor": 15, + "difference": -10.5, + "rating_level": "LOW", + "level": "VERY_CHEAP", + }, + { + "start_time": "2025-11-01T01:00:00+01:00", + "end_time": "2025-11-01T02:00:00+01:00", + "start_dt": datetime(2025, 11, 1, 1, 0), + "price": 0.25, + "price_minor": 25, + "difference": 5.0, + "rating_level": "NORMAL", + "level": "NORMAL", + }, + { + "start_time": "2025-11-01T02:00:00+01:00", + "end_time": "2025-11-01T03:00:00+01:00", + "start_dt": datetime(2025, 11, 1, 2, 0), + "price": 0.35, + "price_minor": 35, + "difference": 25.3, + "rating_level": "HIGH", + "level": "EXPENSIVE", + }, + ] + + stats = _get_price_stats(merged) + + # Verify min interval has all attributes + assert stats.price_min == 0.15 + assert stats.price_min_interval is not None + assert stats.price_min_interval["difference"] == -10.5 + assert stats.price_min_interval["rating_level"] == "LOW" + assert stats.price_min_interval["level"] == "VERY_CHEAP" + + # Verify max interval has all attributes + assert stats.price_max == 0.35 + assert stats.price_max_interval is not None + assert stats.price_max_interval["difference"] == 25.3 + assert stats.price_max_interval["rating_level"] == "HIGH" + assert stats.price_max_interval["level"] == "EXPENSIVE" + + # Verify average price is calculated + assert stats.price_avg == pytest.approx((0.15 + 0.25 + 0.35) / 3, rel=1e-4) + + +def test_get_price_stat_returns_full_interval(): + """Test that _get_price_stat returns the complete interval dict.""" + merged = [ + { + "start_time": "2025-11-01T00:00:00+01:00", + "price": 0.10, + "difference": -15.0, + "rating_level": "LOW", + }, + { + "start_time": "2025-11-01T01:00:00+01:00", + "price": 0.20, + "difference": 0.0, + "rating_level": "NORMAL", + }, + ] + + min_price, min_interval = _get_price_stat(merged, "min") + max_price, max_interval = _get_price_stat(merged, "max") + + # Min should be first interval + assert min_price == 0.10 + assert min_interval is not None + assert min_interval["difference"] == -15.0 + assert min_interval["rating_level"] == "LOW" + + # Max should be second interval + assert max_price == 0.20 + assert max_interval is not None + assert max_interval["difference"] == 0.0 + assert max_interval["rating_level"] == "NORMAL" + + +def test_empty_merged_returns_none_intervals(): + """Test that empty merged list returns None for intervals.""" + stats = _get_price_stats([]) + + assert stats.price_min == 0 + assert stats.price_min_interval is None + assert stats.price_max == 0 + assert stats.price_max_interval is None + assert stats.price_avg == 0