From 1d1f6ec3ca6d5923568910dfc287244bce9e9998 Mon Sep 17 00:00:00 2001 From: Julian Pawlowski Date: Tue, 20 May 2025 19:25:10 +0000 Subject: [PATCH] fix --- custom_components/tibber_prices/__init__.py | 2 +- custom_components/tibber_prices/services.py | 371 ++++++++++++++------ 2 files changed, 268 insertions(+), 105 deletions(-) diff --git a/custom_components/tibber_prices/__init__.py b/custom_components/tibber_prices/__init__.py index baf0158..5419d90 100644 --- a/custom_components/tibber_prices/__init__.py +++ b/custom_components/tibber_prices/__init__.py @@ -86,7 +86,7 @@ async def async_unload_entry( # Unregister services if this was the last config entry if not hass.config_entries.async_entries(DOMAIN): - for service in ("get_priceinfo", "get_pricerating"): + for service in "get_price": if hass.services.has_service(DOMAIN, service): hass.services.async_remove(DOMAIN, service) diff --git a/custom_components/tibber_prices/services.py b/custom_components/tibber_prices/services.py index 92313fe..569593d 100644 --- a/custom_components/tibber_prices/services.py +++ b/custom_components/tibber_prices/services.py @@ -2,6 +2,7 @@ from __future__ import annotations +from dataclasses import dataclass from datetime import datetime from typing import Any, Final @@ -21,7 +22,7 @@ ATTR_TIME: Final = "time" SERVICE_SCHEMA: Final = vol.Schema( { vol.Required(ATTR_ENTRY_ID): str, - vol.Optional(ATTR_DAY, default="today"): vol.In(["yesterday", "today", "tomorrow"]), + vol.Optional(ATTR_DAY): vol.In(["yesterday", "today", "tomorrow"]), vol.Optional(ATTR_TIME): vol.Match(r"^(\d{2}:\d{2}(:\d{2})?)$"), # HH:mm or HH:mm:ss } ) @@ -32,18 +33,20 @@ def _merge_priceinfo_and_pricerating(price_info: list[dict], price_rating: list[ Merge priceInfo and priceRating intervals by timestamp, prefixing rating fields. Also rename startsAt to start_time. Preserves item order. + Adds 'start_dt' (datetime) to each merged interval for reliable sorting/comparison. """ rating_by_time = {(r.get("time") or r.get("startsAt")): r for r in price_rating or []} merged = [] for interval in price_info or []: ts = interval.get("startsAt") - merged_interval = {"start_time": ts} if ts is not None else {} + start_dt = dt_util.parse_datetime(ts) if ts else None + merged_interval = {"start_time": ts, "start_dt": start_dt} if ts is not None else {"start_dt": None} for k, v in interval.items(): if k == "startsAt": continue if k == "total": merged_interval["price"] = v - merged_interval["price_ct"] = round(v * 100, 2) + merged_interval["price_minor"] = round(v * 100, 2) elif k not in ("energy", "tax"): merged_interval[k] = v rating = rating_by_time.get(ts) @@ -56,6 +59,8 @@ def _merge_priceinfo_and_pricerating(price_info: list[dict], price_rating: list[ else: merged_interval[f"rating_{k}"] = v merged.append(merged_interval) + # Always sort by start_dt (datetime), None values last + merged.sort(key=lambda x: (x.get("start_dt") is None, x.get("start_dt"))) return merged @@ -101,9 +106,30 @@ def _find_next_interval( return None -def _select_intervals( - merged: list[dict], all_ratings: list[dict], coordinator: Any, day: str, now: datetime, *, is_simulated: bool -) -> tuple[Any, Any, Any]: +@dataclass +class IntervalContext: + """ + Context for selecting price intervals. + + Attributes: + merged: List of merged price and rating intervals for the selected day. + all_ratings: All rating intervals for the selected day. + coordinator: Data update coordinator for the integration. + day: The day being queried ('yesterday', 'today', or 'tomorrow'). + now: The datetime used for interval selection. + is_simulated: Whether the time is simulated (from user input) or real. + + """ + + merged: list[dict] + all_ratings: list[dict] + coordinator: Any + day: str + now: datetime + is_simulated: bool + + +def _select_intervals(ctx: IntervalContext) -> tuple[Any, Any, Any]: """ Select previous, current, and next intervals for the given day and time. @@ -112,50 +138,39 @@ def _select_intervals( - For 'tomorrow', never fetch next from the day after tomorrow. If is_simulated is False, previous/current/next are None for 'yesterday' and 'tomorrow'. """ - if not merged: - return None, None, None + merged = ctx.merged + all_ratings = ctx.all_ratings + coordinator = ctx.coordinator + day = ctx.day + now = ctx.now + is_simulated = ctx.is_simulated - if not is_simulated and day in ("yesterday", "tomorrow"): + if not merged or (not is_simulated and day in ("yesterday", "tomorrow")): return None, None, None idx = None + cmp_now = dt_util.as_local(now) if now.tzinfo is None else now for i, interval in enumerate(merged): - start_time = interval.get("start_time") - if not start_time: + start_dt = interval.get("start_dt") + if not start_dt: continue - start_dt = dt_util.parse_datetime(start_time) - if start_dt is None: - try: - start_dt = datetime.fromisoformat(start_time) - except ValueError: - continue if start_dt.tzinfo is None: start_dt = dt_util.as_local(start_dt) - cmp_now = now - if cmp_now.tzinfo is None: - cmp_now = dt_util.as_local(cmp_now) if start_dt <= cmp_now: idx = i - if start_dt > cmp_now: + elif start_dt > cmp_now: break - previous_interval = None - current_interval = None - next_interval = None + previous_interval = merged[idx - 1] if idx is not None and idx > 0 else None + current_interval = merged[idx] if idx is not None else None + next_interval = ( + merged[idx + 1] if idx is not None and idx + 1 < len(merged) else (merged[0] if idx is None else None) + ) - if idx is None: - next_interval = merged[0] - else: - current_interval = merged[idx] - previous_interval = merged[idx - 1] if idx > 0 else None - if idx + 1 < len(merged): - next_interval = merged[idx + 1] - - # For today, allow previous/next from adjacent days if day == "today": - if idx is not None and idx == 0: + if idx == 0: previous_interval = _find_previous_interval(merged, all_ratings, coordinator, day) - if idx is not None and idx == len(merged) - 1: + if idx == len(merged) - 1: next_interval = _find_next_interval(merged, all_ratings, coordinator, day) return previous_interval, current_interval, next_interval @@ -218,6 +233,86 @@ def get_price_stat(merged: list[dict], stat: str) -> tuple[float, str | None, st return val, start_time, end_time +def _get_price_stats(merged: list[dict]) -> PriceStats: + """Calculate average, min, and max price and their intervals from merged data.""" + if merged: + price_sum = sum(float(interval.get("price", 0)) for interval in merged if "price" in interval) + price_avg = round(price_sum / len(merged), 4) + else: + price_avg = 0 + price_min, price_min_start_time, price_min_end_time = get_price_stat(merged, "min") + price_max, price_max_start_time, price_max_end_time = get_price_stat(merged, "max") + return PriceStats( + price_avg=price_avg, + price_min=price_min, + price_min_start_time=price_min_start_time, + price_min_end_time=price_min_end_time, + price_max=price_max, + price_max_start_time=price_max_start_time, + price_max_end_time=price_max_end_time, + stats_merged=merged, + ) + + +@dataclass +class PriceStats: + """Encapsulates price statistics and their intervals for the Tibber Prices service.""" + + price_avg: float + price_min: float + price_min_start_time: str | None + price_min_end_time: str | None + price_max: float + price_max_start_time: str | None + price_max_end_time: str | None + stats_merged: list[dict] + + +@dataclass +class PriceResponseContext: + """Context for building the price response.""" + + price_stats: PriceStats + previous_interval: dict | None + current_interval: dict | None + next_interval: dict | None + currency: str | None + rating_threshold_percentages: Any + merged: list[dict] + + +def _build_price_response(ctx: PriceResponseContext) -> dict[str, Any]: + """Build the response dictionary for the price service.""" + price_stats = ctx.price_stats + return { + "average": { + "start_time": price_stats.stats_merged[0].get("start_time") if price_stats.stats_merged else None, + "end_time": price_stats.stats_merged[0].get("end_time") if price_stats.stats_merged else None, + "price": price_stats.price_avg, + "price_minor": round(price_stats.price_avg * 100, 2), + }, + "minimum": { + "start_time": price_stats.price_min_start_time, + "end_time": price_stats.price_min_end_time, + "price": price_stats.price_min, + "price_minor": round(price_stats.price_min * 100, 2), + }, + "maximum": { + "start_time": price_stats.price_max_start_time, + "end_time": price_stats.price_max_end_time, + "price": price_stats.price_max, + "price_minor": round(price_stats.price_max * 100, 2), + }, + "previous": ctx.previous_interval, + "current": ctx.current_interval, + "next": ctx.next_interval, + "currency": ctx.currency, + "rating_threshold_%": ctx.rating_threshold_percentages, + "interval_count": len(ctx.merged), + "intervals": ctx.merged, + } + + async def _get_price(call: ServiceCall) -> dict[str, Any]: """ Return merged priceInfo and priceRating for the requested day and config entry. @@ -225,11 +320,77 @@ async def _get_price(call: ServiceCall) -> dict[str, Any]: If 'time' is provided, it must be in HH:mm or HH:mm:ss format and is combined with the selected 'day'. This only affects 'previous', 'current', and 'next' fields, not the 'prices' list. If 'time' is not provided, the current time is used for all days. + If 'day' is not provided, the prices list will include today and tomorrow, but stats and interval + selection are only for today. """ hass = call.hass - day = call.data.get(ATTR_DAY, "today") - entry_id = call.data.get(ATTR_ENTRY_ID) + entry_id_raw = call.data.get(ATTR_ENTRY_ID) + if entry_id_raw is None: + raise ServiceValidationError(translation_domain=DOMAIN, translation_key="missing_entry_id") + entry_id: str = str(entry_id_raw) time_value = call.data.get(ATTR_TIME) + explicit_day = ATTR_DAY in call.data + day = call.data.get(ATTR_DAY) + + entry, coordinator, data = _get_entry_and_data(hass, entry_id) + price_info_data, price_rating_data, hourly_ratings, rating_threshold_percentages, currency = _extract_price_data( + data + ) + + price_info_by_day, day_prefixes, ratings_by_day = _prepare_day_structures(price_info_data, hourly_ratings) + + ( + merged, + stats_merged, + interval_selection_merged, + interval_selection_ratings, + interval_selection_day, + ) = _select_merge_strategy( + explicit_day=explicit_day, + day=day if day is not None else "today", + price_info_by_day=price_info_by_day, + ratings_by_day=ratings_by_day, + ) + + annotate_intervals_with_times( + merged, + price_info_by_day, + interval_selection_day, + ) + + price_stats = _get_price_stats(stats_merged) + + now, is_simulated = _determine_now_and_simulation(time_value, interval_selection_merged) + + ctx = IntervalContext( + merged=interval_selection_merged, + all_ratings=interval_selection_ratings, + coordinator=coordinator, + day=interval_selection_day, + now=now, + is_simulated=is_simulated, + ) + previous_interval, current_interval, next_interval = _select_intervals(ctx) + + for interval in merged: + if "previous_end_time" in interval: + del interval["previous_end_time"] + + response_ctx = PriceResponseContext( + price_stats=price_stats, + previous_interval=previous_interval, + current_interval=current_interval, + next_interval=next_interval, + currency=currency, + rating_threshold_percentages=rating_threshold_percentages, + merged=merged, + ) + + return _build_price_response(response_ctx) + + +def _get_entry_and_data(hass: HomeAssistant, entry_id: str) -> tuple[Any, Any, dict]: + """Validate entry and extract coordinator and data.""" if not entry_id: raise ServiceValidationError(translation_domain=DOMAIN, translation_key="missing_entry_id") entry = next((e for e in hass.config_entries.async_entries(DOMAIN) if e.entry_id == entry_id), None) @@ -237,13 +398,21 @@ async def _get_price(call: ServiceCall) -> dict[str, Any]: raise ServiceValidationError(translation_domain=DOMAIN, translation_key="invalid_entry_id") coordinator = entry.runtime_data.coordinator data = coordinator.data or {} + return entry, coordinator, data + + +def _extract_price_data(data: dict) -> tuple[dict, dict, list, Any, Any]: + """Extract price info and rating data from coordinator data.""" price_info_data = data.get("priceInfo") or {} price_rating_data = data.get("priceRating") or {} hourly_ratings = price_rating_data.get("hourly") or [] rating_threshold_percentages = price_rating_data.get("thresholdPercentages") currency = price_rating_data.get("currency") + return price_info_data, price_rating_data, hourly_ratings, rating_threshold_percentages, currency - # Fetch all relevant day data once + +def _prepare_day_structures(price_info_data: dict, hourly_ratings: list) -> tuple[dict, dict, dict]: + """Prepare price info, day prefixes, and ratings by day.""" price_info_by_day = {d: price_info_data.get(d) or [] for d in ("yesterday", "today", "tomorrow")} day_prefixes = {d: _get_day_prefixes(price_info_by_day[d]) for d in ("yesterday", "today", "tomorrow")} ratings_by_day = { @@ -256,95 +425,89 @@ async def _get_price(call: ServiceCall) -> dict[str, Any]: else [] for d in ("yesterday", "today", "tomorrow") } + return price_info_by_day, day_prefixes, ratings_by_day - price_info = price_info_by_day[day] - all_ratings = ratings_by_day[day] - merged = _merge_priceinfo_and_pricerating(price_info, all_ratings) - annotate_intervals_with_times(merged, price_info_by_day, day) - - price_avg = ( - round(sum(float(interval.get("price", 0)) for interval in merged if "price" in interval) / len(merged), 4) - if merged - else 0 +def _select_merge_strategy( + *, + explicit_day: bool, + day: str, + price_info_by_day: dict, + ratings_by_day: dict, +) -> tuple[list, list, list, list, str]: + """Select merging strategy for intervals and stats.""" + if not explicit_day: + merged_today = _merge_priceinfo_and_pricerating(price_info_by_day["today"], ratings_by_day["today"]) + merged_tomorrow = _merge_priceinfo_and_pricerating(price_info_by_day["tomorrow"], ratings_by_day["tomorrow"]) + merged = merged_today + merged_tomorrow + stats_merged = merged_today + interval_selection_merged = merged_today + interval_selection_ratings = ratings_by_day["today"] + interval_selection_day = "today" + else: + day_key = day if day in ("yesterday", "today", "tomorrow") else "today" + merged = _merge_priceinfo_and_pricerating(price_info_by_day[day_key], ratings_by_day[day_key]) + stats_merged = merged + interval_selection_merged = merged + interval_selection_ratings = ratings_by_day[day_key] + interval_selection_day = day_key + return ( + merged, + stats_merged, + interval_selection_merged, + interval_selection_ratings, + interval_selection_day, ) - price_min, price_min_start_time, price_min_end_time = get_price_stat(merged, "min") - price_max, price_max_start_time, price_max_end_time = get_price_stat(merged, "max") + +def _determine_now_and_simulation( + time_value: str | None, interval_selection_merged: list[dict] +) -> tuple[datetime, bool]: + """Determine the 'now' datetime and simulation flag.""" + is_simulated = False if time_value: - if not price_info or not price_info[0].get("startsAt"): + if not interval_selection_merged or not interval_selection_merged[0].get("start_time"): raise ServiceValidationError( translation_domain=DOMAIN, translation_key="no_data_for_day", ) - day_prefix = price_info[0]["startsAt"].split("T")[0] + day_prefix = interval_selection_merged[0]["start_time"].split("T")[0] dt_str = f"{day_prefix}T{time_value}" try: now = datetime.fromisoformat(dt_str) except ValueError as exc: raise ServiceValidationError( translation_domain=DOMAIN, - translation_key="invalid_simulate_time", + translation_key="invalid_time", translation_placeholders={"error": str(exc)}, ) from exc is_simulated = True + elif not interval_selection_merged or not interval_selection_merged[0].get("start_time"): + now = dt_util.now().replace(second=0, microsecond=0) else: - if not price_info or not price_info[0].get("startsAt"): + day_prefix = interval_selection_merged[0]["start_time"].split("T")[0] + current_time = dt_util.now().time().replace(second=0, microsecond=0) + dt_str = f"{day_prefix}T{current_time.isoformat()}" + try: + now = datetime.fromisoformat(dt_str) + except ValueError: now = dt_util.now().replace(second=0, microsecond=0) - else: - day_prefix = price_info[0]["startsAt"].split("T")[0] - current_time = dt_util.now().time().replace(second=0, microsecond=0) - dt_str = f"{day_prefix}T{current_time.isoformat()}" - try: - now = datetime.fromisoformat(dt_str) - except ValueError: - now = dt_util.now().replace(second=0, microsecond=0) - is_simulated = True - - previous_interval, current_interval, next_interval = _select_intervals( - merged, ratings_by_day[day], coordinator, day, now, is_simulated=is_simulated - ) - - # Remove 'previous_end_time' from output intervals - for interval in merged: - if "previous_end_time" in interval: - del interval["previous_end_time"] - - return { - "average": { - "start_time": merged[0].get("start_time") if merged else None, - "end_time": merged[0].get("end_time") if merged else None, - "price": price_avg, - "price_ct": round(price_avg * 100, 2), - }, - "minimum": { - "start_time": price_min_start_time, - "end_time": price_min_end_time, - "price": price_min, - "price_ct": round(price_min * 100, 2), - }, - "maximum": { - "start_time": price_max_start_time, - "end_time": price_max_end_time, - "price": price_max, - "price_ct": round(price_max * 100, 2), - }, - "previous": previous_interval, - "current": current_interval, - "next": next_interval, - "currency": currency, - "rating_threshold_%": rating_threshold_percentages, - "prices": merged, - } + return now, is_simulated -def _get_day_prefixes(price_info: list[dict]) -> list[str]: - """Get ISO date prefixes for the requested day from price_info intervals.""" +DAY_PREFIX_LENGTH = 10 + + +def _get_day_prefixes(day_info: list[dict]) -> list[str]: + """Return a list of unique day prefixes from the intervals' start datetimes.""" prefixes = set() - for interval in price_info: - ts = interval.get("startsAt") - if ts and "T" in ts: - prefixes.add(ts.split("T")[0]) + for interval in day_info: + dt_str = interval.get("time") or interval.get("startsAt") + if not dt_str: + continue + start_dt = dt_util.parse_datetime(dt_str) + if start_dt: + prefixes.add(start_dt.date().isoformat()) return list(prefixes)