diff --git a/custom_components/tibber_prices/coordinator/data_transformation.py b/custom_components/tibber_prices/coordinator/data_transformation.py index 01d6223..3fc6b17 100644 --- a/custom_components/tibber_prices/coordinator/data_transformation.py +++ b/custom_components/tibber_prices/coordinator/data_transformation.py @@ -21,6 +21,24 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) +def _build_period_calculation_intervals(enriched_intervals: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Return enriched intervals with raw Tibber levels restored for period logic.""" + period_intervals = copy.deepcopy(enriched_intervals) + + for interval in period_intervals: + original_level = interval.pop("_original_level", None) + if original_level is not None: + interval["level"] = original_level + + return period_intervals + + +def _strip_internal_enrichment_fields(enriched_intervals: list[dict[str, Any]]) -> None: + """Remove internal enrichment helpers before exposing priceInfo.""" + for interval in enriched_intervals: + interval.pop("_original_level", None) + + class TibberPricesDataTransformer: """Handles data transformation, enrichment, and period calculations.""" @@ -264,6 +282,9 @@ class TibberPricesDataTransformer: time=self.time, ) + period_intervals = _build_period_calculation_intervals(enriched_intervals) + _strip_internal_enrichment_fields(enriched_intervals) + # Store enriched intervals directly as priceInfo (flat list) transformed_data = { "home_id": home_id, @@ -281,7 +302,7 @@ class TibberPricesDataTransformer: # Calculate periods (best price and peak price) if "priceInfo" in transformed_data: transformed_data["pricePeriods"] = self._calculate_periods_fn( - transformed_data["priceInfo"], transformed_data.get("dayPatterns") + period_intervals, transformed_data.get("dayPatterns") ) # Cache the transformed data diff --git a/custom_components/tibber_prices/coordinator/period_handlers/relaxation.py b/custom_components/tibber_prices/coordinator/period_handlers/relaxation.py index 2fba51b..bd8ae8c 100644 --- a/custom_components/tibber_prices/coordinator/period_handlers/relaxation.py +++ b/custom_components/tibber_prices/coordinator/period_handlers/relaxation.py @@ -570,8 +570,10 @@ def calculate_periods_with_relaxation( Calculate periods with optional global filter relaxation and per-day target tracking. Strategy: a single global relaxation loop iterates flex levels (3% steps from - the configured base flex up to MAX_FLEX_HARD_LIMIT). After every step we re-run - period detection across all available days and check, per day, how many quality + the configured base flex up to MAX_FLEX_HARD_LIMIT). At each flex level we + first re-run period detection with the configured level filter still intact. + Only if that is still insufficient do we retry the same flex with + `level_filter="any"`. After every attempt we check, per day, how many quality periods (CV ≤ PERIOD_MAX_CV) have accumulated. Days that already meet the target (`min_periods`) are not re-processed; the loop exits as soon as **all** days meet their target. Days with very flat prices automatically need only 1 period @@ -580,8 +582,10 @@ def calculate_periods_with_relaxation( If after all flex levels some days still have ZERO periods, a last-resort `min_period_length` fallback is attempted (see `_try_min_duration_fallback`). - Phase 1: Increase flex threshold step-by-step (up to max_relaxation_attempts) - Phase 2: Disable level filter (set to "any") in combination with each flex step + Phase 1: Increase flex threshold step-by-step while preserving the configured + level filter. + Phase 2: Retry the same flex with `level_filter="any"` when a concrete level + filter is configured. Args: all_prices: All price data points @@ -861,10 +865,12 @@ def calculate_periods_with_relaxation( days_meeting_requirement += 1 elif enable_relaxation: + filter_combination_count = 2 if config.level_filter not in (None, "any") else 1 _LOGGER_DETAILS.debug( - "%sAll %d days met target with baseline - no relaxation needed", + "%sRelaxation strategy: 3%% fixed flex increment per step (%d flex levels x %d filter combinations)", INDENT_L1, total_days, + filter_combination_count, ) # Sort periods by start time @@ -917,10 +923,11 @@ def relax_all_prices( """ Relax filters for all prices until min_periods per day is reached. - Strategy: Try increasing flex by 3% increments, then relax level filter. - Processes all prices together (yesterday+today+tomorrow), allowing periods - to cross midnight boundaries. Returns when ALL days have min_periods - (or max attempts exhausted). + Strategy: Try increasing flex by 3% increments while keeping the configured + level filter. For each flex level, optionally retry with `level_filter="any"` + when a concrete level filter is configured. Processes all prices together + (yesterday+today+tomorrow), allowing periods to cross midnight boundaries. + Returns when ALL days have min_periods (or max attempts exhausted). Args: all_prices: All price intervals (yesterday+today+tomorrow). @@ -947,6 +954,10 @@ def relax_all_prices( existing_periods = list(baseline_periods) # Start with baseline phases_used = [] + filter_variants: list[tuple[str | None, str | None]] = [(None, original_level_filter)] + if original_level_filter not in (None, "any"): + filter_variants.append(("any", "any")) + # Get available days from prices for checking prices_by_day = group_prices_by_day(all_prices, time=time) total_days = len(prices_by_day) @@ -964,98 +975,103 @@ def relax_all_prices( ) break - phase_label = f"flex={current_flex * 100:.1f}%" + for level_override, applied_level_filter in filter_variants: + phase_label = f"flex={current_flex * 100:.1f}%" + phase_label_full = phase_label + if applied_level_filter is not None: + phase_label_full = f"{phase_label} +level_{applied_level_filter}" - # Skip this flex level if callback says not to show it - if not should_show_callback(phase_label): - continue + # The callback expects a level override (e.g. None or "any"), not a flex label. + if not should_show_callback(level_override): + continue + + if level_override == "any" and original_level_filter not in (None, "any"): + _LOGGER_DETAILS.debug( + "%s Flex=%.1f%%: OVERRIDING level_filter: %s → ANY", + INDENT_L2, + current_flex * 100, + original_level_filter, + ) + + # NOTE: config.flex is already normalized to positive by get_period_config() + relaxed_config = config._replace( + flex=current_flex, # Already positive from normalization + level_filter=applied_level_filter, + ) - # Try current flex with level="any" (in relaxation mode) - if original_level_filter != "any": _LOGGER_DETAILS.debug( - "%s Flex=%.1f%%: OVERRIDING level_filter: %s → ANY", + "%s Trying %s: config has %d intervals (all days together), level_filter=%s", INDENT_L2, - current_flex * 100, - original_level_filter, - ) - - # NOTE: config.flex is already normalized to positive by get_period_config() - relaxed_config = config._replace( - flex=current_flex, # Already positive from normalization - level_filter="any", - ) - - phase_label_full = f"flex={current_flex * 100:.1f}% +level_any" - _LOGGER_DETAILS.debug( - "%s Trying %s: config has %d intervals (all days together), level_filter=%s", - INDENT_L2, - phase_label_full, - len(all_prices), - relaxed_config.level_filter, - ) - - # Process ALL prices together (allows midnight crossing) - result = calculate_periods( - all_prices, - config=relaxed_config, - time=time, - day_patterns_by_date=day_patterns_by_date, - ) - new_periods = result["periods"] - - _LOGGER_DETAILS.debug( - "%s %s: calculate_periods returned %d periods", - INDENT_L2, - phase_label_full, - len(new_periods), - ) - - # Mark newly found periods with relaxation metadata BEFORE merging - mark_periods_with_relaxation( - new_periods, - relaxation_level=phase_label_full, - original_threshold=base_flex, - applied_threshold=current_flex, - reverse_sort=config.reverse_sort, - ) - - # Resolve overlaps between existing and new periods - combined, standalone_count = resolve_period_overlaps( - existing_periods=existing_periods, - new_relaxed_periods=new_periods, - all_prices=all_prices, - config=config, - time=time, - ) - - # Count periods per day with QUALITY GATE check - # Only periods with CV <= PERIOD_MAX_CV count towards min_periods requirement - days_meeting_requirement, quality_period_count = _count_quality_periods( - combined, all_prices, prices_by_day, min_periods, time=time - ) - - total_periods = len(combined) - _LOGGER_DETAILS.debug( - "%s %s: found %d periods total, %d/%d days meet requirement", - INDENT_L2, - phase_label_full, - total_periods, - days_meeting_requirement, - total_days, - ) - - existing_periods = combined - phases_used.append(phase_label_full) - - # Check if ALL days reached target - if days_meeting_requirement >= total_days: - _LOGGER.info( - "Success with %s - all %d days have %d+ periods (%d total)", phase_label_full, - total_days, - min_periods, - total_periods, + len(all_prices), + relaxed_config.level_filter, ) + + # Process ALL prices together (allows midnight crossing) + result = calculate_periods( + all_prices, + config=relaxed_config, + time=time, + day_patterns_by_date=day_patterns_by_date, + ) + new_periods = result["periods"] + + _LOGGER_DETAILS.debug( + "%s %s: calculate_periods returned %d periods", + INDENT_L2, + phase_label_full, + len(new_periods), + ) + + # Mark newly found periods with relaxation metadata BEFORE merging + mark_periods_with_relaxation( + new_periods, + relaxation_level=phase_label_full, + original_threshold=base_flex, + applied_threshold=current_flex, + reverse_sort=config.reverse_sort, + ) + + # Resolve overlaps between existing and new periods + combined, standalone_count = resolve_period_overlaps( + existing_periods=existing_periods, + new_relaxed_periods=new_periods, + all_prices=all_prices, + config=config, + time=time, + ) + + # Count periods per day with QUALITY GATE check + # Only periods with CV <= PERIOD_MAX_CV count towards min_periods requirement + days_meeting_requirement, quality_period_count = _count_quality_periods( + combined, all_prices, prices_by_day, min_periods, time=time + ) + + total_periods = len(combined) + _LOGGER_DETAILS.debug( + "%s %s: found %d periods total, %d/%d days meet requirement", + INDENT_L2, + phase_label_full, + total_periods, + days_meeting_requirement, + total_days, + ) + + existing_periods = combined + phases_used.append(phase_label_full) + + # Check if ALL days reached target + if days_meeting_requirement >= total_days: + _LOGGER.info( + "Success with %s - all %d days have %d+ periods (%d total)", + phase_label_full, + total_days, + min_periods, + total_periods, + ) + break + + if days_meeting_requirement >= total_days: break # Build final result diff --git a/custom_components/tibber_prices/utils/price.py b/custom_components/tibber_prices/utils/price.py index 427c519..0e006db 100644 --- a/custom_components/tibber_prices/utils/price.py +++ b/custom_components/tibber_prices/utils/price.py @@ -979,6 +979,10 @@ def enrich_price_info_with_differences( # Apply level gap tolerance as post-processing step # This smooths out isolated price level changes from Tibber's API if level_gap_tolerance > 0: + for interval in all_intervals: + level = interval.get("level") + if level is not None: + interval.setdefault("_original_level", level) _apply_level_gap_tolerance(all_intervals, level_gap_tolerance) return all_intervals diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py new file mode 100644 index 0000000..fff4844 --- /dev/null +++ b/tests/test_relaxation.py @@ -0,0 +1,86 @@ +"""Focused regression tests for relaxation phase sequencing.""" + +from __future__ import annotations + +from datetime import timedelta +from unittest.mock import Mock + +import pytest + +from custom_components.tibber_prices.coordinator.period_handlers import core as core_module +from custom_components.tibber_prices.coordinator.period_handlers.relaxation import relax_all_prices +from custom_components.tibber_prices.coordinator.period_handlers.types import TibberPricesPeriodConfig +from custom_components.tibber_prices.coordinator.time_service import TibberPricesTimeService +from homeassistant.util import dt as dt_util + + +def _create_interval(base_time, offset: int, price: float, level: str) -> dict: + """Create one quarter-hour interval for relaxation tests.""" + return { + "startsAt": base_time + timedelta(minutes=offset * 15), + "total": price, + "level": level, + } + + +@pytest.mark.unit +@pytest.mark.freeze_time("2025-11-22 12:00:00+01:00") +def test_relaxation_preserves_level_filter_before_trying_any(monkeypatch: pytest.MonkeyPatch) -> None: + """Relaxation should try flex-only phases before dropping the configured level filter.""" + base_time = dt_util.parse_datetime("2025-11-22T12:00:00+01:00") + assert base_time is not None + + mock_coordinator = Mock() + mock_coordinator.config_entry = Mock() + time_service = TibberPricesTimeService(mock_coordinator) + time_service.now = Mock(return_value=base_time) + + all_prices = [ + _create_interval(base_time, 0, 0.18, "CHEAP"), + _create_interval(base_time, 1, 0.19, "CHEAP"), + _create_interval(base_time, 2, 0.22, "NORMAL"), + _create_interval(base_time, 3, 0.31, "EXPENSIVE"), + ] + config = TibberPricesPeriodConfig( + reverse_sort=False, + flex=0.15, + min_distance_from_avg=5.0, + min_period_length=60, + level_filter="cheap", + gap_count=1, + ) + + calculate_periods_calls: list[tuple[float, str | None]] = [] + callback_args: list[str | None] = [] + + def fake_calculate_periods( + _all_prices: list[dict], + *, + config: TibberPricesPeriodConfig, + time: TibberPricesTimeService, + day_patterns_by_date: dict | None = None, + time_range=None, + ) -> dict: + calculate_periods_calls.append((round(config.flex, 2), config.level_filter)) + return {"periods": [], "metadata": {}, "reference_data": {}} + + monkeypatch.setattr(core_module, "calculate_periods", fake_calculate_periods) + + relax_all_prices( + all_prices=all_prices, + config=config, + min_periods=2, + max_relaxation_attempts=2, + should_show_callback=lambda level_override: callback_args.append(level_override) or True, + baseline_periods=[], + time=time_service, + config_entry=mock_coordinator.config_entry, + ) + + assert callback_args == [None, "any", None, "any"] + assert calculate_periods_calls == [ + (0.18, "cheap"), + (0.18, "any"), + (0.21, "cheap"), + (0.21, "any"), + ] diff --git a/tests/test_tomorrow_data_refresh.py b/tests/test_tomorrow_data_refresh.py index ec02db7..c971a33 100644 --- a/tests/test_tomorrow_data_refresh.py +++ b/tests/test_tomorrow_data_refresh.py @@ -40,6 +40,28 @@ def create_price_intervals(day_offset: int = 0) -> list[dict]: return intervals +def create_level_gap_intervals() -> list[dict]: + """Create a small interval sequence where level smoothing changes the display level.""" + base_time = dt_util.now().replace(hour=12, minute=0, second=0, microsecond=0) + levels = ["CHEAP", "CHEAP", "CHEAP", "NORMAL", "CHEAP", "CHEAP"] + totals = [0.10, 0.101, 0.102, 0.18, 0.103, 0.104] + + intervals: list[dict] = [] + for index, (level, total) in enumerate(zip(levels, totals, strict=True)): + interval_time = base_time + timedelta(minutes=index * 15) + intervals.append( + { + "startsAt": interval_time, + "total": total, + "energy": round(total - 0.02, 4), + "tax": 0.02, + "level": level, + } + ) + + return intervals + + @pytest.mark.unit def test_transformation_cache_invalidation_on_new_timestamp() -> None: """ @@ -222,3 +244,46 @@ def test_cache_preserved_when_neither_timestamp_nor_config_changed() -> None: # Verify period calculation was only called ONCE (during first transform) assert mock_period_calc.calculate_periods_for_price_info.call_count == 1 + + +@pytest.mark.unit +def test_transform_data_uses_raw_levels_for_period_calculation() -> None: + """Period calculation must see raw Tibber levels even when priceInfo is smoothed.""" + config_entry = Mock() + config_entry.entry_id = "test_entry" + config_entry.data = {"home_id": "home_123"} + config_entry.options = { + "price_level_gap_tolerance": 1, + "price_rating_gap_tolerance": 0, + } + + time_service = TibberPricesTimeService() + current_time = datetime(2025, 11, 22, 13, 15, 0, tzinfo=ZoneInfo("Europe/Oslo")) + captured_levels: list[str] = [] + + def _capture_period_levels(price_info: list[dict], _day_patterns: dict | None = None) -> dict[str, list]: + captured_levels.extend(interval["level"] for interval in price_info) + assert all("_original_level" not in interval for interval in price_info) + return {"best_price": [], "peak_price": []} + + transformer = TibberPricesDataTransformer( + config_entry=config_entry, + log_prefix="[Test]", + calculate_periods_fn=_capture_period_levels, + time=time_service, + ) + + result = transformer.transform_data( + { + "timestamp": current_time, + "home_id": "home_123", + "price_info": create_level_gap_intervals(), + "currency": "EUR", + } + ) + + smoothed_levels = [interval["level"] for interval in result["priceInfo"]] + + assert smoothed_levels == ["CHEAP", "CHEAP", "CHEAP", "CHEAP", "CHEAP", "CHEAP"] + assert captured_levels == ["CHEAP", "CHEAP", "CHEAP", "NORMAL", "CHEAP", "CHEAP"] + assert all("_original_level" not in interval for interval in result["priceInfo"])