hass.tibber_prices/tests/test_interval_pool_memory_leak.py
Julian Pawlowski 94615dc6cd refactor(interval_pool): improve reliability and test coverage
Added async_shutdown() method for proper cleanup on unload - cancels
debounce and background tasks to prevent orphaned task leaks.

Added Phase 1.5 to GC: removes empty fetch groups after dead interval
cleanup, with index rebuild to maintain consistency.

Added update_batch() to TimestampIndex for efficient batch updates.
Touch operations now use batch updates instead of N remove+add calls.

Rewrote memory leak tests for modular architecture - all 9 tests now
pass using new component APIs (cache, index, gc).

Impact: Prevents task leaks on HA restart/reload, reduces memory
overhead from empty groups, improves touch operation performance.
2025-12-23 10:10:35 +00:00

435 lines
16 KiB
Python

"""
Tests for memory leak prevention in interval pool.
This test module verifies that touch operations don't cause memory leaks by:
1. Reusing existing interval dicts (Python references, not copies)
2. Dead intervals being cleaned up by GC
3. Serialization filtering out dead intervals from storage
4. Empty fetch groups being removed after cleanup
Architecture:
The interval pool uses a modular architecture:
- TibberPricesIntervalPool (manager.py): Main coordinator
- TibberPricesIntervalPoolFetchGroupCache (cache.py): Fetch group storage
- TibberPricesIntervalPoolTimestampIndex (index.py): O(1) timestamp lookup
- TibberPricesIntervalPoolGarbageCollector (garbage_collector.py): Eviction/cleanup
- TibberPricesIntervalPoolFetcher (fetcher.py): Gap detection and API calls
Tests access internal components directly for fine-grained verification.
"""
import json
from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
from custom_components.tibber_prices.interval_pool.cache import (
TibberPricesIntervalPoolFetchGroupCache,
)
from custom_components.tibber_prices.interval_pool.garbage_collector import (
TibberPricesIntervalPoolGarbageCollector,
)
from custom_components.tibber_prices.interval_pool.index import (
TibberPricesIntervalPoolTimestampIndex,
)
from custom_components.tibber_prices.interval_pool.manager import (
TibberPricesIntervalPool,
)
@pytest.fixture
def mock_api() -> MagicMock:
"""Create a mock API client."""
return MagicMock()
@pytest.fixture
def pool(mock_api: MagicMock) -> TibberPricesIntervalPool:
"""Create an interval pool for testing (single-home architecture)."""
return TibberPricesIntervalPool(home_id="test_home_id", api=mock_api)
@pytest.fixture
def sample_intervals() -> list[dict]:
"""Create 24 sample intervals (1 day, hourly)."""
base_time = datetime(2025, 11, 25, 0, 0, 0, tzinfo=UTC)
return [
{
"startsAt": (base_time.replace(hour=h)).isoformat(),
"total": 10.0 + h,
"energy": 8.0 + h,
"tax": 2.0,
}
for h in range(24)
]
@pytest.fixture
def cache() -> TibberPricesIntervalPoolFetchGroupCache:
"""Create a fresh cache instance for testing."""
return TibberPricesIntervalPoolFetchGroupCache()
@pytest.fixture
def index() -> TibberPricesIntervalPoolTimestampIndex:
"""Create a fresh index instance for testing."""
return TibberPricesIntervalPoolTimestampIndex()
class TestTouchOperations:
"""Test touch operations (re-fetching same intervals)."""
def test_touch_operation_reuses_existing_intervals(
self,
pool: TibberPricesIntervalPool,
) -> None:
"""Test that touch operations reuse existing interval dicts (references, not copies)."""
fetch_time_1 = "2025-11-25T10:00:00+01:00"
fetch_time_2 = "2025-11-25T10:15:00+01:00"
# Create sample intervals
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(24)
]
# First fetch: Add intervals
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
# Access internal cache
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# Verify: 1 fetch group with 24 intervals
assert len(fetch_groups) == 1
assert len(fetch_groups[0]["intervals"]) == 24
# Get reference to first interval
first_interval_original = fetch_groups[0]["intervals"][0]
original_id = id(first_interval_original)
# Second fetch: Touch same intervals
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
# Re-fetch groups (list may have changed)
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# Verify: Now we have 2 fetch groups
assert len(fetch_groups) == 2
# Get reference to first interval from TOUCH group
first_interval_touched = fetch_groups[1]["intervals"][0]
touched_id = id(first_interval_touched)
# CRITICAL: Should be SAME object (same memory address)
assert original_id == touched_id, f"Memory addresses differ: {original_id} != {touched_id}"
assert first_interval_original is first_interval_touched, "Touch should reuse existing dict, not create copy"
def test_touch_operation_updates_index(
self,
pool: TibberPricesIntervalPool,
) -> None:
"""Test that touch operations update the index to point to new fetch group."""
fetch_time_1 = "2025-11-25T10:00:00+01:00"
fetch_time_2 = "2025-11-25T10:15:00+01:00"
# Create sample intervals
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(24)
]
# First fetch
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
# Verify index points to group 0
first_key = sample_intervals[0]["startsAt"][:19]
index_entry = pool._index.get(first_key) # noqa: SLF001
assert index_entry is not None
assert index_entry["fetch_group_index"] == 0
# Second fetch (touch)
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
# Verify index now points to group 1 (touch group)
index_entry = pool._index.get(first_key) # noqa: SLF001
assert index_entry is not None
assert index_entry["fetch_group_index"] == 1, "Index should point to touch group"
class TestGarbageCollection:
"""Test garbage collection and dead interval cleanup."""
def test_gc_cleanup_removes_dead_intervals(
self,
cache: TibberPricesIntervalPoolFetchGroupCache,
index: TibberPricesIntervalPoolTimestampIndex,
) -> None:
"""Test that GC cleanup removes dead intervals from old fetch groups."""
gc = TibberPricesIntervalPoolGarbageCollector(cache, index, "test_home")
# Create sample intervals
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(24)
]
# First fetch: Add to cache and index
fetch_time_1 = datetime(2025, 11, 25, 10, 0, 0, tzinfo=UTC)
group_idx_1 = cache.add_fetch_group(sample_intervals, fetch_time_1)
for i, interval in enumerate(sample_intervals):
index.add(interval, group_idx_1, i)
# Verify initial state
assert cache.count_total_intervals() == 24
assert index.count() == 24
# Second fetch (touch): Create new fetch group
fetch_time_2 = datetime(2025, 11, 25, 10, 15, 0, tzinfo=UTC)
group_idx_2 = cache.add_fetch_group(sample_intervals, fetch_time_2)
# Update index to point to new group (simulates touch)
for i, interval in enumerate(sample_intervals):
index.add(interval, group_idx_2, i)
# Before GC: 48 intervals in cache (24 dead + 24 living), 24 in index
assert cache.count_total_intervals() == 48
assert index.count() == 24
# Run GC
gc_changed = gc.run_gc()
# After GC: Dead intervals cleaned, empty group removed
assert gc_changed is True
assert cache.count_total_intervals() == 24, "Should only have living intervals"
def test_gc_removes_empty_fetch_groups(
self,
cache: TibberPricesIntervalPoolFetchGroupCache,
index: TibberPricesIntervalPoolTimestampIndex,
) -> None:
"""Test that GC removes empty fetch groups after dead interval cleanup."""
gc = TibberPricesIntervalPoolGarbageCollector(cache, index, "test_home")
# Create sample intervals
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(4) # Small set
]
# Add two fetch groups
fetch_time_1 = datetime(2025, 11, 25, 10, 0, 0, tzinfo=UTC)
fetch_time_2 = datetime(2025, 11, 25, 10, 15, 0, tzinfo=UTC)
cache.add_fetch_group(sample_intervals, fetch_time_1)
group_idx_2 = cache.add_fetch_group(sample_intervals, fetch_time_2)
# Index points only to second group
for i, interval in enumerate(sample_intervals):
index.add(interval, group_idx_2, i)
# Before GC: 2 groups
assert len(cache.get_fetch_groups()) == 2
# Run GC
gc.run_gc()
# After GC: Only 1 group (empty one removed)
fetch_groups = cache.get_fetch_groups()
assert len(fetch_groups) == 1, "Empty fetch group should be removed"
assert len(fetch_groups[0]["intervals"]) == 4
class TestSerialization:
"""Test serialization excludes dead intervals."""
def test_serialization_excludes_dead_intervals(
self,
pool: TibberPricesIntervalPool,
) -> None:
"""Test that to_dict() excludes dead intervals from serialization."""
fetch_time_1 = "2025-11-25T10:00:00+01:00"
fetch_time_2 = "2025-11-25T10:15:00+01:00"
# Create sample intervals
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(24)
]
# First fetch
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
# Second fetch (touch)
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
# Serialize WITHOUT running GC cleanup first
serialized = pool.to_dict()
# Verify serialization structure
assert "fetch_groups" in serialized
assert "home_id" in serialized
fetch_groups = serialized["fetch_groups"]
# CRITICAL: Should only serialize living intervals
# Old group with dead intervals should NOT be serialized
total_serialized_intervals = sum(len(g["intervals"]) for g in fetch_groups)
assert total_serialized_intervals == 24, (
f"Should only serialize 24 living intervals, got {total_serialized_intervals}"
)
def test_repeated_touch_operations_dont_grow_storage(
self,
pool: TibberPricesIntervalPool,
) -> None:
"""Test that repeated touch operations don't grow storage size unbounded."""
# Create sample intervals
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(24)
]
# Simulate 10 re-fetches of the same intervals
for i in range(10):
fetch_time = f"2025-11-25T{10 + i}:00:00+01:00"
pool._add_intervals(sample_intervals, fetch_time) # noqa: SLF001
# Serialize (filters dead intervals)
serialized = pool.to_dict()
serialized_groups = serialized["fetch_groups"]
# Storage should only have 24 living intervals total
total_intervals = sum(len(g["intervals"]) for g in serialized_groups)
assert total_intervals == 24, f"Should only have 24 living intervals, got {total_intervals}"
# Verify storage size is bounded
json_str = json.dumps(serialized)
json_size = len(json_str)
# Should still be < 10 KB even after 10 fetches
assert json_size < 10000, f"Storage grew unbounded: {json_size} bytes (expected < 10000)"
class TestIndexBatchUpdate:
"""Test batch index update functionality."""
def test_batch_update_efficiency(
self,
index: TibberPricesIntervalPoolTimestampIndex,
) -> None:
"""Test that batch update correctly updates multiple entries."""
# Create test intervals
timestamps = [f"2025-11-25T{h:02d}:00:00" for h in range(24)]
# Add intervals pointing to group 0
for i, ts in enumerate(timestamps):
index.add({"startsAt": ts}, 0, i)
# Verify initial state
assert index.count() == 24
for ts in timestamps:
entry = index.get(ts)
assert entry is not None
assert entry["fetch_group_index"] == 0
# Batch update to point to group 1
updates = [(ts, 1, i) for i, ts in enumerate(timestamps)]
index.update_batch(updates)
# Verify all entries now point to group 1
for ts in timestamps:
entry = index.get(ts)
assert entry is not None
assert entry["fetch_group_index"] == 1, f"Entry for {ts} should point to group 1"
def test_batch_update_with_partial_overlap(
self,
index: TibberPricesIntervalPoolTimestampIndex,
) -> None:
"""Test batch update with only some existing entries."""
# Add initial entries (0-11)
for i in range(12):
ts = f"2025-11-25T{i:02d}:00:00"
index.add({"startsAt": ts}, 0, i)
assert index.count() == 12
# Batch update: update first 6, add 6 new (12-17)
updates = [(f"2025-11-25T{i:02d}:00:00", 1, i) for i in range(18)]
index.update_batch(updates)
# Should now have 18 entries (12 existing + 6 new)
assert index.count() == 18
# All should point to group 1
for i in range(18):
ts = f"2025-11-25T{i:02d}:00:00"
entry = index.get(ts)
assert entry is not None
assert entry["fetch_group_index"] == 1
class TestIntervalIdentityPreservation:
"""Test that interval dict identity is preserved across operations."""
def test_interval_identity_preserved_across_touch(
self,
pool: TibberPricesIntervalPool,
) -> None:
"""Test that interval dict identity (memory address) is preserved across touch."""
# Create sample intervals
sample_intervals = [
{
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
"total": 10.0 + h,
}
for h in range(24)
]
# First fetch
pool._add_intervals(sample_intervals, "2025-11-25T10:00:00+01:00") # noqa: SLF001
# Get fetch groups
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# Collect memory addresses of intervals in original group
original_ids = [id(interval) for interval in fetch_groups[0]["intervals"]]
# Second fetch (touch)
pool._add_intervals(sample_intervals, "2025-11-25T10:15:00+01:00") # noqa: SLF001
# Re-fetch groups
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# Collect memory addresses of intervals in touch group
touched_ids = [id(interval) for interval in fetch_groups[1]["intervals"]]
# CRITICAL: All memory addresses should be identical (same objects)
assert original_ids == touched_ids, "Touch should preserve interval identity (memory addresses)"
# Third fetch (touch again)
pool._add_intervals(sample_intervals, "2025-11-25T10:30:00+01:00") # noqa: SLF001
# Re-fetch groups
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
# New touch group should also reference the SAME original objects
touched_ids_2 = [id(interval) for interval in fetch_groups[2]["intervals"]]
assert original_ids == touched_ids_2, "Multiple touches should preserve original identity"