refactor(interval_pool): improve reliability and test coverage

Added async_shutdown() method for proper cleanup on unload - cancels
debounce and background tasks to prevent orphaned task leaks.

Added Phase 1.5 to GC: removes empty fetch groups after dead interval
cleanup, with index rebuild to maintain consistency.

Added update_batch() to TimestampIndex for efficient batch updates.
Touch operations now use batch updates instead of N remove+add calls.

Rewrote memory leak tests for modular architecture - all 9 tests now
pass using new component APIs (cache, index, gc).

Impact: Prevents task leaks on HA restart/reload, reduces memory
overhead from empty groups, improves touch operation performance.
This commit is contained in:
Julian Pawlowski 2025-12-23 10:10:35 +00:00
parent fc64aecdd9
commit 94615dc6cd
5 changed files with 501 additions and 388 deletions

View file

@ -298,6 +298,9 @@ async def async_unload_entry(
await async_save_pool_state(hass, entry.entry_id, pool_state) await async_save_pool_state(hass, entry.entry_id, pool_state)
LOGGER.debug("[%s] Interval pool state saved on unload", entry.title) LOGGER.debug("[%s] Interval pool state saved on unload", entry.title)
# Shutdown interval pool (cancels background tasks)
await entry.runtime_data.interval_pool.async_shutdown()
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok and entry.runtime_data is not None: if unload_ok and entry.runtime_data is not None:

View file

@ -77,6 +77,15 @@ class TibberPricesIntervalPoolGarbageCollector:
self._home_id, self._home_id,
) )
# Phase 1.5: Remove empty fetch groups (after dead interval cleanup)
empty_removed = self._remove_empty_groups(fetch_groups)
if empty_removed > 0:
_LOGGER_DETAILS.debug(
"GC removed %d empty fetch groups (home %s)",
empty_removed,
self._home_id,
)
# Phase 2: Count total intervals after cleanup # Phase 2: Count total intervals after cleanup
total_intervals = self._cache.count_total_intervals() total_intervals = self._cache.count_total_intervals()
@ -94,7 +103,7 @@ class TibberPricesIntervalPoolGarbageCollector:
if not evicted_indices: if not evicted_indices:
# All intervals are protected, cannot evict # All intervals are protected, cannot evict
return dead_count > 0 return dead_count > 0 or empty_removed > 0
# Phase 4: Rebuild cache and index # Phase 4: Rebuild cache and index
new_fetch_groups = [group for idx, group in enumerate(fetch_groups) if idx not in evicted_indices] new_fetch_groups = [group for idx, group in enumerate(fetch_groups) if idx not in evicted_indices]
@ -110,6 +119,35 @@ class TibberPricesIntervalPoolGarbageCollector:
return True return True
def _remove_empty_groups(self, fetch_groups: list[dict[str, Any]]) -> int:
"""
Remove fetch groups with no intervals.
After dead interval cleanup, some groups may be completely empty.
These should be removed to prevent memory accumulation.
Note: This modifies the cache's internal list in-place and rebuilds
the index to maintain consistency.
Args:
fetch_groups: List of fetch groups (will be modified).
Returns:
Number of empty groups removed.
"""
# Find non-empty groups
non_empty_groups = [group for group in fetch_groups if group["intervals"]]
removed_count = len(fetch_groups) - len(non_empty_groups)
if removed_count > 0:
# Update cache with filtered list
self._cache.set_fetch_groups(non_empty_groups)
# Rebuild index since group indices changed
self._index.rebuild(non_empty_groups)
return removed_count
def _cleanup_dead_intervals(self, fetch_groups: list[dict[str, Any]]) -> int: def _cleanup_dead_intervals(self, fetch_groups: list[dict[str, Any]]) -> int:
""" """
Remove dead intervals from all fetch groups. Remove dead intervals from all fetch groups.

View file

@ -93,6 +93,28 @@ class TibberPricesIntervalPoolTimestampIndex:
starts_at_normalized = self._normalize_timestamp(timestamp) starts_at_normalized = self._normalize_timestamp(timestamp)
self._index.pop(starts_at_normalized, None) self._index.pop(starts_at_normalized, None)
def update_batch(
    self,
    updates: list[tuple[str, int, int]],
) -> None:
    """
    Update multiple index entries efficiently in a single operation.

    More efficient than calling remove() + add() per entry because each
    entry is written with a single dict assignment. Note that each
    timestamp is still normalized individually.

    Args:
        updates: List of (timestamp, fetch_group_index, interval_index)
            tuples. Timestamps are normalized automatically.
    """
    for timestamp, fetch_group_index, interval_index in updates:
        key = self._normalize_timestamp(timestamp)
        # Overwrites any existing entry for this timestamp; creates the
        # entry if it did not exist before.
        self._index[key] = {
            "fetch_group_index": fetch_group_index,
            "interval_index": interval_index,
        }
def clear(self) -> None: def clear(self) -> None:
"""Clear entire index.""" """Clear entire index."""
self._index.clear() self._index.clear()

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -372,13 +373,13 @@ class TibberPricesIntervalPool:
# Add touch group to cache # Add touch group to cache
touch_group_index = self._cache.add_fetch_group(touch_intervals, fetch_time_dt) touch_group_index = self._cache.add_fetch_group(touch_intervals, fetch_time_dt)
# Update index to point to new fetch group # Update index to point to new fetch group using batch operation
for interval_index, (starts_at_normalized, _) in enumerate(intervals_to_touch): # This is more efficient than individual remove+add calls
# Remove old index entry index_updates = [
self._index.remove(starts_at_normalized) (starts_at_normalized, touch_group_index, interval_index)
# Add new index entry pointing to touch group for interval_index, (starts_at_normalized, _) in enumerate(intervals_to_touch)
interval = touch_intervals[interval_index] ]
self._index.add(interval, touch_group_index, interval_index) self._index.update_batch(index_updates)
_LOGGER.debug( _LOGGER.debug(
"Touched %d cached intervals for home %s (moved to fetch group %d, fetched at %s)", "Touched %d cached intervals for home %s (moved to fetch group %d, fetched at %s)",
@ -419,6 +420,36 @@ class TibberPricesIntervalPool:
_LOGGER.debug("Auto-save timer cancelled (expected - new changes arrived)") _LOGGER.debug("Auto-save timer cancelled (expected - new changes arrived)")
raise raise
async def async_shutdown(self) -> None:
    """
    Clean shutdown - cancel pending background tasks.

    Should be called when the config entry is unloaded to prevent
    orphaned tasks and ensure clean resource cleanup.
    """
    _LOGGER.debug("Shutting down interval pool for home %s", self._home_id)

    # Cancel the auto-save debounce task if one is still pending.
    debounce = self._save_debounce_task
    if debounce is not None and not debounce.done():
        debounce.cancel()
        # Awaiting a cancelled task raises CancelledError; suppress it so
        # shutdown proceeds normally.
        with contextlib.suppress(asyncio.CancelledError):
            await debounce
        _LOGGER.debug("Cancelled pending auto-save task")

    # Cancel any remaining background tasks and wait for the cancellation
    # to complete. (A single truthiness check suffices — the original
    # checked the same condition twice in a row.)
    if self._background_tasks:
        for task in list(self._background_tasks):
            if not task.done():
                task.cancel()
        # return_exceptions=True absorbs CancelledError from each task so
        # one failing task cannot abort the shutdown sequence.
        await asyncio.gather(*self._background_tasks, return_exceptions=True)
        _LOGGER.debug("Cancelled %d background tasks", len(self._background_tasks))
        self._background_tasks.clear()

    _LOGGER.debug("Interval pool shutdown complete for home %s", self._home_id)
async def _auto_save_pool_state(self) -> None: async def _auto_save_pool_state(self) -> None:
"""Auto-save pool state to storage with lock protection.""" """Auto-save pool state to storage with lock protection."""
if self._hass is None or self._entry_id is None: if self._hass is None or self._entry_id is None:

View file

@ -5,39 +5,54 @@ This test module verifies that touch operations don't cause memory leaks by:
1. Reusing existing interval dicts (Python references, not copies) 1. Reusing existing interval dicts (Python references, not copies)
2. Dead intervals being cleaned up by GC 2. Dead intervals being cleaned up by GC
3. Serialization filtering out dead intervals from storage 3. Serialization filtering out dead intervals from storage
4. Empty fetch groups being removed after cleanup
NOTE: These tests are currently skipped due to the interval pool refactoring. Architecture:
The tests access internal attributes (_fetch_groups, _timestamp_index, _gc_cleanup_dead_intervals) The interval pool uses a modular architecture:
that were part of the old monolithic pool.py implementation. After the refactoring into - TibberPricesIntervalPool (manager.py): Main coordinator
separate modules (cache.py, index.py, garbage_collector.py, fetcher.py, manager.py), - TibberPricesIntervalPoolFetchGroupCache (cache.py): Fetch group storage
these internal APIs changed and the tests need to be rewritten. - TibberPricesIntervalPoolTimestampIndex (index.py): O(1) timestamp lookup
- TibberPricesIntervalPoolGarbageCollector (garbage_collector.py): Eviction/cleanup
- TibberPricesIntervalPoolFetcher (fetcher.py): Gap detection and API calls
TODO: Rewrite these tests to work with the new modular architecture: Tests access internal components directly for fine-grained verification.
- Mock the api parameter (TibberPricesApiClient)
- Use public APIs instead of accessing internal attributes
- Test garbage collection through the manager's public interface
""" """
import json import json
from datetime import UTC, datetime from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest import pytest
from custom_components.tibber_prices.interval_pool import TibberPricesIntervalPool from custom_components.tibber_prices.interval_pool.cache import (
TibberPricesIntervalPoolFetchGroupCache,
# Skip all tests in this module until they are rewritten for the new modular architecture )
pytestmark = pytest.mark.skip(reason="Tests need rewrite for modular architecture (manager/cache/index/gc/fetcher)") from custom_components.tibber_prices.interval_pool.garbage_collector import (
TibberPricesIntervalPoolGarbageCollector,
)
from custom_components.tibber_prices.interval_pool.index import (
TibberPricesIntervalPoolTimestampIndex,
)
from custom_components.tibber_prices.interval_pool.manager import (
TibberPricesIntervalPool,
)
@pytest.fixture @pytest.fixture
def pool() -> TibberPricesIntervalPool: def mock_api() -> MagicMock:
"""Create a shared interval pool for testing (single-home architecture).""" """Create a mock API client."""
return TibberPricesIntervalPool(home_id="test_home_id") return MagicMock()
@pytest.fixture
def pool(mock_api: MagicMock) -> TibberPricesIntervalPool:
"""Create an interval pool for testing (single-home architecture)."""
return TibberPricesIntervalPool(home_id="test_home_id", api=mock_api)
@pytest.fixture @pytest.fixture
def sample_intervals() -> list[dict]: def sample_intervals() -> list[dict]:
"""Create 24 sample intervals (1 day).""" """Create 24 sample intervals (1 day, hourly)."""
base_time = datetime(2025, 11, 25, 0, 0, 0, tzinfo=UTC) base_time = datetime(2025, 11, 25, 0, 0, 0, tzinfo=UTC)
return [ return [
{ {
@ -50,15 +65,30 @@ def sample_intervals() -> list[dict]:
] ]
@pytest.fixture
def cache() -> TibberPricesIntervalPoolFetchGroupCache:
"""Create a fresh cache instance for testing."""
return TibberPricesIntervalPoolFetchGroupCache()
@pytest.fixture
def index() -> TibberPricesIntervalPoolTimestampIndex:
"""Create a fresh index instance for testing."""
return TibberPricesIntervalPoolTimestampIndex()
class TestTouchOperations:
"""Test touch operations (re-fetching same intervals)."""
def test_touch_operation_reuses_existing_intervals( def test_touch_operation_reuses_existing_intervals(
self,
pool: TibberPricesIntervalPool, pool: TibberPricesIntervalPool,
) -> None: ) -> None:
"""Test that touch operations reuse existing interval dicts (references, not copies).""" """Test that touch operations reuse existing interval dicts (references, not copies)."""
# home_id not needed (single-home architecture)
fetch_time_1 = "2025-11-25T10:00:00+01:00" fetch_time_1 = "2025-11-25T10:00:00+01:00"
fetch_time_2 = "2025-11-25T10:15:00+01:00" fetch_time_2 = "2025-11-25T10:15:00+01:00"
# Create sample intervals for this test # Create sample intervals
sample_intervals = [ sample_intervals = [
{ {
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(), "startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
@ -70,8 +100,8 @@ def test_touch_operation_reuses_existing_intervals(
# First fetch: Add intervals # First fetch: Add intervals
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001 pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
# Direct property access (single-home architecture) # Access internal cache
fetch_groups = pool._fetch_groups # noqa: SLF001 fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# Verify: 1 fetch group with 24 intervals # Verify: 1 fetch group with 24 intervals
assert len(fetch_groups) == 1 assert len(fetch_groups) == 1
@ -84,6 +114,9 @@ def test_touch_operation_reuses_existing_intervals(
# Second fetch: Touch same intervals # Second fetch: Touch same intervals
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001 pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
# Re-fetch groups (list may have changed)
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# Verify: Now we have 2 fetch groups # Verify: Now we have 2 fetch groups
assert len(fetch_groups) == 2 assert len(fetch_groups) == 2
@ -95,12 +128,11 @@ def test_touch_operation_reuses_existing_intervals(
assert original_id == touched_id, f"Memory addresses differ: {original_id} != {touched_id}" assert original_id == touched_id, f"Memory addresses differ: {original_id} != {touched_id}"
assert first_interval_original is first_interval_touched, "Touch should reuse existing dict, not create copy" assert first_interval_original is first_interval_touched, "Touch should reuse existing dict, not create copy"
def test_touch_operation_updates_index(
def test_touch_operation_leaves_dead_intervals_in_old_group( self,
pool: TibberPricesIntervalPool, pool: TibberPricesIntervalPool,
) -> None: ) -> None:
"""Test that touch operations leave 'dead' intervals in old fetch groups.""" """Test that touch operations update the index to point to new fetch group."""
# home_id not needed (single-home architecture)
fetch_time_1 = "2025-11-25T10:00:00+01:00" fetch_time_1 = "2025-11-25T10:00:00+01:00"
fetch_time_2 = "2025-11-25T10:15:00+01:00" fetch_time_2 = "2025-11-25T10:15:00+01:00"
@ -115,35 +147,32 @@ def test_touch_operation_leaves_dead_intervals_in_old_group(
# First fetch # First fetch
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001 pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
# Direct property access (single-home architecture)
fetch_groups = pool._fetch_groups # noqa: SLF001
# Second fetch (touch all intervals) # Verify index points to group 0
first_key = sample_intervals[0]["startsAt"][:19]
index_entry = pool._index.get(first_key) # noqa: SLF001
assert index_entry is not None
assert index_entry["fetch_group_index"] == 0
# Second fetch (touch)
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001 pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
# BEFORE GC cleanup: # Verify index now points to group 1 (touch group)
# - Old group still has 24 intervals (but they're all "dead" - index points elsewhere) index_entry = pool._index.get(first_key) # noqa: SLF001
# - Touch group has 24 intervals (living - index points here) assert index_entry is not None
assert len(fetch_groups) == 2, "Should have 2 fetch groups"
assert len(fetch_groups[0]["intervals"]) == 24, "Old group should still have intervals (dead)"
assert len(fetch_groups[1]["intervals"]) == 24, "Touch group should have intervals (living)"
# Verify index points to touch group (not old group)
timestamp_index = pool._timestamp_index # noqa: SLF001
first_key = sample_intervals[0]["startsAt"][:19]
index_entry = timestamp_index[first_key]
assert index_entry["fetch_group_index"] == 1, "Index should point to touch group" assert index_entry["fetch_group_index"] == 1, "Index should point to touch group"
class TestGarbageCollection:
"""Test garbage collection and dead interval cleanup."""
def test_gc_cleanup_removes_dead_intervals( def test_gc_cleanup_removes_dead_intervals(
pool: TibberPricesIntervalPool, self,
cache: TibberPricesIntervalPoolFetchGroupCache,
index: TibberPricesIntervalPoolTimestampIndex,
) -> None: ) -> None:
"""Test that GC cleanup removes dead intervals from old fetch groups.""" """Test that GC cleanup removes dead intervals from old fetch groups."""
# home_id not needed (single-home architecture) gc = TibberPricesIntervalPoolGarbageCollector(cache, index, "test_home")
fetch_time_1 = "2025-11-25T10:00:00+01:00"
fetch_time_2 = "2025-11-25T10:15:00+01:00"
# Create sample intervals # Create sample intervals
sample_intervals = [ sample_intervals = [
@ -154,37 +183,83 @@ def test_gc_cleanup_removes_dead_intervals(
for h in range(24) for h in range(24)
] ]
# First fetch # First fetch: Add to cache and index
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001 fetch_time_1 = datetime(2025, 11, 25, 10, 0, 0, tzinfo=UTC)
group_idx_1 = cache.add_fetch_group(sample_intervals, fetch_time_1)
for i, interval in enumerate(sample_intervals):
index.add(interval, group_idx_1, i)
# Second fetch (touch all intervals) # Verify initial state
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001 assert cache.count_total_intervals() == 24
assert index.count() == 24
# Direct property access (single-home architecture) # Second fetch (touch): Create new fetch group
fetch_groups = pool._fetch_groups # noqa: SLF001 fetch_time_2 = datetime(2025, 11, 25, 10, 15, 0, tzinfo=UTC)
timestamp_index = pool._timestamp_index # noqa: SLF001 group_idx_2 = cache.add_fetch_group(sample_intervals, fetch_time_2)
# Before cleanup: old group has 24 intervals # Update index to point to new group (simulates touch)
assert len(fetch_groups[0]["intervals"]) == 24, "Before cleanup" for i, interval in enumerate(sample_intervals):
index.add(interval, group_idx_2, i)
# Run GC cleanup explicitly # Before GC: 48 intervals in cache (24 dead + 24 living), 24 in index
dead_count = pool._gc_cleanup_dead_intervals(fetch_groups, timestamp_index) # noqa: SLF001 assert cache.count_total_intervals() == 48
assert index.count() == 24
# Verify: 24 dead intervals were removed # Run GC
assert dead_count == 24, f"Expected 24 dead intervals, got {dead_count}" gc_changed = gc.run_gc()
# After cleanup: old group should be empty # After GC: Dead intervals cleaned, empty group removed
assert len(fetch_groups[0]["intervals"]) == 0, "Old group should be empty after cleanup" assert gc_changed is True
assert cache.count_total_intervals() == 24, "Should only have living intervals"
# Touch group still has 24 living intervals def test_gc_removes_empty_fetch_groups(
assert len(fetch_groups[1]["intervals"]) == 24, "Touch group should still have intervals" self,
cache: TibberPricesIntervalPoolFetchGroupCache,
index: TibberPricesIntervalPoolTimestampIndex,
) -> None:
"""Test that GC removes empty fetch groups after dead interval cleanup."""
gc = TibberPricesIntervalPoolGarbageCollector(cache, index, "test_home")
# Create sample intervals
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(4) # Small set
]
# Add two fetch groups
fetch_time_1 = datetime(2025, 11, 25, 10, 0, 0, tzinfo=UTC)
fetch_time_2 = datetime(2025, 11, 25, 10, 15, 0, tzinfo=UTC)
cache.add_fetch_group(sample_intervals, fetch_time_1)
group_idx_2 = cache.add_fetch_group(sample_intervals, fetch_time_2)
# Index points only to second group
for i, interval in enumerate(sample_intervals):
index.add(interval, group_idx_2, i)
# Before GC: 2 groups
assert len(cache.get_fetch_groups()) == 2
# Run GC
gc.run_gc()
# After GC: Only 1 group (empty one removed)
fetch_groups = cache.get_fetch_groups()
assert len(fetch_groups) == 1, "Empty fetch group should be removed"
assert len(fetch_groups[0]["intervals"]) == 4
class TestSerialization:
"""Test serialization excludes dead intervals."""
def test_serialization_excludes_dead_intervals( def test_serialization_excludes_dead_intervals(
self,
pool: TibberPricesIntervalPool, pool: TibberPricesIntervalPool,
) -> None: ) -> None:
"""Test that to_dict() excludes dead intervals from serialization.""" """Test that to_dict() excludes dead intervals from serialization."""
# home_id not needed (single-home architecture)
fetch_time_1 = "2025-11-25T10:00:00+01:00" fetch_time_1 = "2025-11-25T10:00:00+01:00"
fetch_time_2 = "2025-11-25T10:15:00+01:00" fetch_time_2 = "2025-11-25T10:15:00+01:00"
@ -200,7 +275,7 @@ def test_serialization_excludes_dead_intervals(
# First fetch # First fetch
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001 pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
# Second fetch (touch all intervals) # Second fetch (touch)
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001 pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
# Serialize WITHOUT running GC cleanup first # Serialize WITHOUT running GC cleanup first
@ -211,27 +286,18 @@ def test_serialization_excludes_dead_intervals(
assert "home_id" in serialized assert "home_id" in serialized
fetch_groups = serialized["fetch_groups"] fetch_groups = serialized["fetch_groups"]
# CRITICAL: Should only serialize touch group (living intervals) # CRITICAL: Should only serialize living intervals
# Old group with all dead intervals should NOT be serialized # Old group with dead intervals should NOT be serialized
assert len(fetch_groups) == 1, "Should only serialize groups with living intervals" total_serialized_intervals = sum(len(g["intervals"]) for g in fetch_groups)
assert total_serialized_intervals == 24, (
# Touch group should have all 24 intervals f"Should only serialize 24 living intervals, got {total_serialized_intervals}"
assert len(fetch_groups[0]["intervals"]) == 24, "Touch group should have all intervals" )
# Verify JSON size is reasonable (not 2x the size)
json_str = json.dumps(serialized)
json_size = len(json_str)
# Each interval is ~100-150 bytes, 24 intervals = ~2.4-3.6 KB
# With metadata + structure, expect < 5 KB
assert json_size < 5000, f"JSON too large: {json_size} bytes (expected < 5000)"
def test_repeated_touch_operations_dont_grow_storage( def test_repeated_touch_operations_dont_grow_storage(
self,
pool: TibberPricesIntervalPool, pool: TibberPricesIntervalPool,
) -> None: ) -> None:
"""Test that repeated touch operations don't grow storage size unbounded.""" """Test that repeated touch operations don't grow storage size unbounded."""
# home_id not needed (single-home architecture)
# Create sample intervals # Create sample intervals
sample_intervals = [ sample_intervals = [
{ {
@ -246,22 +312,13 @@ def test_repeated_touch_operations_dont_grow_storage(
fetch_time = f"2025-11-25T{10 + i}:00:00+01:00" fetch_time = f"2025-11-25T{10 + i}:00:00+01:00"
pool._add_intervals(sample_intervals, fetch_time) # noqa: SLF001 pool._add_intervals(sample_intervals, fetch_time) # noqa: SLF001
# Memory state: 10 fetch groups (9 empty, 1 with all intervals)
# Direct property access (single-home architecture)
fetch_groups = pool._fetch_groups # noqa: SLF001
assert len(fetch_groups) == 10, "Should have 10 fetch groups in memory"
# Total intervals in memory: 240 references (24 per group, mostly dead)
total_refs = sum(len(g["intervals"]) for g in fetch_groups)
assert total_refs == 24 * 10, "Memory should have 240 interval references"
# Serialize (filters dead intervals) # Serialize (filters dead intervals)
serialized = pool.to_dict() serialized = pool.to_dict()
serialized_groups = serialized["fetch_groups"] serialized_groups = serialized["fetch_groups"]
# Storage should only have 1 group with 24 living intervals # Storage should only have 24 living intervals total
assert len(serialized_groups) == 1, "Should only serialize 1 group (with living intervals)" total_intervals = sum(len(g["intervals"]) for g in serialized_groups)
assert len(serialized_groups[0]["intervals"]) == 24, "Should only have 24 living intervals" assert total_intervals == 24, f"Should only have 24 living intervals, got {total_intervals}"
# Verify storage size is bounded # Verify storage size is bounded
json_str = json.dumps(serialized) json_str = json.dumps(serialized)
@ -270,110 +327,73 @@ def test_repeated_touch_operations_dont_grow_storage(
assert json_size < 10000, f"Storage grew unbounded: {json_size} bytes (expected < 10000)" assert json_size < 10000, f"Storage grew unbounded: {json_size} bytes (expected < 10000)"
def test_gc_cleanup_with_partial_touch( class TestIndexBatchUpdate:
pool: TibberPricesIntervalPool, """Test batch index update functionality."""
sample_intervals: list[dict],
def test_batch_update_efficiency(
self,
index: TibberPricesIntervalPoolTimestampIndex,
) -> None: ) -> None:
"""Test GC cleanup when only some intervals are touched (partial overlap).""" """Test that batch update correctly updates multiple entries."""
# home_id not needed (single-home architecture) # Create test intervals
fetch_time_1 = "2025-11-25T10:00:00+01:00" timestamps = [f"2025-11-25T{h:02d}:00:00" for h in range(24)]
fetch_time_2 = "2025-11-25T10:15:00+01:00"
# First fetch: All 24 intervals # Add intervals pointing to group 0
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001 for i, ts in enumerate(timestamps):
index.add({"startsAt": ts}, 0, i)
# Second fetch: Only first 12 intervals (partial touch) # Verify initial state
partial_intervals = sample_intervals[:12] assert index.count() == 24
pool._add_intervals(partial_intervals, fetch_time_2) # noqa: SLF001 for ts in timestamps:
entry = index.get(ts)
assert entry is not None
assert entry["fetch_group_index"] == 0
# Direct property access (single-home architecture) # Batch update to point to group 1
fetch_groups = pool._fetch_groups # noqa: SLF001 updates = [(ts, 1, i) for i, ts in enumerate(timestamps)]
timestamp_index = pool._timestamp_index # noqa: SLF001 index.update_batch(updates)
# Before cleanup: # Verify all entries now point to group 1
# - Old group: 24 intervals (12 dead, 12 living) for ts in timestamps:
# - Touch group: 12 intervals (all living) entry = index.get(ts)
assert len(fetch_groups[0]["intervals"]) == 24, "Old group should have 24 intervals" assert entry is not None
assert len(fetch_groups[1]["intervals"]) == 12, "Touch group should have 12 intervals" assert entry["fetch_group_index"] == 1, f"Entry for {ts} should point to group 1"
# Run GC cleanup def test_batch_update_with_partial_overlap(
dead_count = pool._gc_cleanup_dead_intervals(fetch_groups, timestamp_index) # noqa: SLF001 self,
index: TibberPricesIntervalPoolTimestampIndex,
# Should clean 12 dead intervals (the ones that were touched)
assert dead_count == 12, f"Expected 12 dead intervals, got {dead_count}"
# After cleanup:
# - Old group: 12 intervals (the ones that were NOT touched)
# - Touch group: 12 intervals (unchanged)
assert len(fetch_groups[0]["intervals"]) == 12, "Old group should have 12 living intervals left"
assert len(fetch_groups[1]["intervals"]) == 12, "Touch group should still have 12 intervals"
def test_memory_leak_prevention_integration(
pool: TibberPricesIntervalPool,
) -> None: ) -> None:
"""Integration test: Verify no memory leak over multiple operations.""" """Test batch update with only some existing entries."""
# home_id not needed (single-home architecture) # Add initial entries (0-11)
for i in range(12):
ts = f"2025-11-25T{i:02d}:00:00"
index.add({"startsAt": ts}, 0, i)
# Create sample intervals assert index.count() == 12
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(24)
]
# Simulate typical usage pattern over time # Batch update: update first 6, add 6 new (12-17)
# Day 1: Fetch 24 intervals updates = [(f"2025-11-25T{i:02d}:00:00", 1, i) for i in range(18)]
pool._add_intervals(sample_intervals, "2025-11-25T10:00:00+01:00") # noqa: SLF001 index.update_batch(updates)
# Day 1: Re-fetch (touch) - updates fetch time # Should now have 18 entries (12 existing + 6 new)
pool._add_intervals(sample_intervals, "2025-11-25T14:00:00+01:00") # noqa: SLF001 assert index.count() == 18
# Day 1: Re-fetch (touch) again # All should point to group 1
pool._add_intervals(sample_intervals, "2025-11-25T18:00:00+01:00") # noqa: SLF001 for i in range(18):
ts = f"2025-11-25T{i:02d}:00:00"
entry = index.get(ts)
assert entry is not None
assert entry["fetch_group_index"] == 1
# Direct property access (single-home architecture)
fetch_groups = pool._fetch_groups # noqa: SLF001
timestamp_index = pool._timestamp_index # noqa: SLF001
# Memory state BEFORE cleanup:
# - 3 fetch groups
# - Total: 72 interval references (24 per group)
# - Dead: 48 (first 2 groups have all dead intervals)
# - Living: 24 (last group has all living intervals)
assert len(fetch_groups) == 3, "Should have 3 fetch groups"
total_refs = sum(len(g["intervals"]) for g in fetch_groups)
assert total_refs == 72, "Should have 72 interval references in memory"
# Run GC cleanup
dead_count = pool._gc_cleanup_dead_intervals(fetch_groups, timestamp_index) # noqa: SLF001
assert dead_count == 48, "Should clean 48 dead intervals"
# Memory state AFTER cleanup:
# - 3 fetch groups (2 empty, 1 with all intervals)
# - Total: 24 interval references
# - Dead: 0
# - Living: 24
total_refs_after = sum(len(g["intervals"]) for g in fetch_groups)
assert total_refs_after == 24, "Should only have 24 interval references after cleanup"
# Verify serialization excludes empty groups
serialized = pool.to_dict()
serialized_groups = serialized["fetch_groups"]
# Should only serialize 1 group (the one with living intervals)
assert len(serialized_groups) == 1, "Should only serialize groups with living intervals"
assert len(serialized_groups[0]["intervals"]) == 24, "Should have 24 intervals"
class TestIntervalIdentityPreservation:
"""Test that interval dict identity is preserved across operations."""
def test_interval_identity_preserved_across_touch( def test_interval_identity_preserved_across_touch(
self,
pool: TibberPricesIntervalPool, pool: TibberPricesIntervalPool,
) -> None: ) -> None:
"""Test that interval dict identity (memory address) is preserved across touch.""" """Test that interval dict identity (memory address) is preserved across touch."""
# home_id not needed (single-home architecture)
# Create sample intervals # Create sample intervals
sample_intervals = [ sample_intervals = [
{ {
@ -386,8 +406,8 @@ def test_interval_identity_preserved_across_touch(
# First fetch # First fetch
pool._add_intervals(sample_intervals, "2025-11-25T10:00:00+01:00") # noqa: SLF001 pool._add_intervals(sample_intervals, "2025-11-25T10:00:00+01:00") # noqa: SLF001
# Direct property access (single-home architecture) # Get fetch groups
fetch_groups = pool._fetch_groups # noqa: SLF001 fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# Collect memory addresses of intervals in original group # Collect memory addresses of intervals in original group
original_ids = [id(interval) for interval in fetch_groups[0]["intervals"]] original_ids = [id(interval) for interval in fetch_groups[0]["intervals"]]
@ -395,6 +415,9 @@ def test_interval_identity_preserved_across_touch(
# Second fetch (touch) # Second fetch (touch)
pool._add_intervals(sample_intervals, "2025-11-25T10:15:00+01:00") # noqa: SLF001 pool._add_intervals(sample_intervals, "2025-11-25T10:15:00+01:00") # noqa: SLF001
# Re-fetch groups
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# Collect memory addresses of intervals in touch group # Collect memory addresses of intervals in touch group
touched_ids = [id(interval) for interval in fetch_groups[1]["intervals"]] touched_ids = [id(interval) for interval in fetch_groups[1]["intervals"]]
@ -404,13 +427,9 @@ def test_interval_identity_preserved_across_touch(
# Third fetch (touch again) # Third fetch (touch again)
pool._add_intervals(sample_intervals, "2025-11-25T10:30:00+01:00") # noqa: SLF001 pool._add_intervals(sample_intervals, "2025-11-25T10:30:00+01:00") # noqa: SLF001
# Re-fetch groups
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# New touch group should also reference the SAME original objects # New touch group should also reference the SAME original objects
touched_ids_2 = [id(interval) for interval in fetch_groups[2]["intervals"]] touched_ids_2 = [id(interval) for interval in fetch_groups[2]["intervals"]]
assert original_ids == touched_ids_2, "Multiple touches should preserve original identity" assert original_ids == touched_ids_2, "Multiple touches should preserve original identity"
# Verify: All 3 groups have references to THE SAME interval dicts
# Only the list entries differ (8 bytes each), not the interval dicts (600+ bytes each)
for i in range(24):
assert fetch_groups[0]["intervals"][i] is fetch_groups[1]["intervals"][i] is fetch_groups[2]["intervals"][i], (
f"Interval {i} should be the same object across all groups"
)