mirror of
https://github.com/jpawlowski/hass.tibber_prices.git
synced 2026-03-29 21:03:40 +00:00
Added async_shutdown() method for proper cleanup on unload - cancels debounce and background tasks to prevent orphaned task leaks. Added Phase 1.5 to GC: removes empty fetch groups after dead interval cleanup, with index rebuild to maintain consistency. Added update_batch() to TimestampIndex for efficient batch updates. Touch operations now use batch updates instead of N remove+add calls. Rewrote memory leak tests for modular architecture - all 9 tests now pass using new component APIs (cache, index, gc). Impact: Prevents task leaks on HA restart/reload, reduces memory overhead from empty groups, improves touch operation performance.
435 lines
16 KiB
Python
435 lines
16 KiB
Python
"""
|
|
Tests for memory leak prevention in interval pool.
|
|
|
|
This test module verifies that touch operations don't cause memory leaks by:
|
|
1. Reusing existing interval dicts (Python references, not copies)
|
|
2. Dead intervals being cleaned up by GC
|
|
3. Serialization filtering out dead intervals from storage
|
|
4. Empty fetch groups being removed after cleanup
|
|
|
|
Architecture:
|
|
The interval pool uses a modular architecture:
|
|
- TibberPricesIntervalPool (manager.py): Main coordinator
|
|
- TibberPricesIntervalPoolFetchGroupCache (cache.py): Fetch group storage
|
|
- TibberPricesIntervalPoolTimestampIndex (index.py): O(1) timestamp lookup
|
|
- TibberPricesIntervalPoolGarbageCollector (garbage_collector.py): Eviction/cleanup
|
|
- TibberPricesIntervalPoolFetcher (fetcher.py): Gap detection and API calls
|
|
|
|
Tests access internal components directly for fine-grained verification.
|
|
"""
|
|
|
|
import json
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from custom_components.tibber_prices.interval_pool.cache import (
|
|
TibberPricesIntervalPoolFetchGroupCache,
|
|
)
|
|
from custom_components.tibber_prices.interval_pool.garbage_collector import (
|
|
TibberPricesIntervalPoolGarbageCollector,
|
|
)
|
|
from custom_components.tibber_prices.interval_pool.index import (
|
|
TibberPricesIntervalPoolTimestampIndex,
|
|
)
|
|
from custom_components.tibber_prices.interval_pool.manager import (
|
|
TibberPricesIntervalPool,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_api() -> MagicMock:
|
|
"""Create a mock API client."""
|
|
return MagicMock()
|
|
|
|
|
|
@pytest.fixture
|
|
def pool(mock_api: MagicMock) -> TibberPricesIntervalPool:
|
|
"""Create an interval pool for testing (single-home architecture)."""
|
|
return TibberPricesIntervalPool(home_id="test_home_id", api=mock_api)
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_intervals() -> list[dict]:
|
|
"""Create 24 sample intervals (1 day, hourly)."""
|
|
base_time = datetime(2025, 11, 25, 0, 0, 0, tzinfo=UTC)
|
|
return [
|
|
{
|
|
"startsAt": (base_time.replace(hour=h)).isoformat(),
|
|
"total": 10.0 + h,
|
|
"energy": 8.0 + h,
|
|
"tax": 2.0,
|
|
}
|
|
for h in range(24)
|
|
]
|
|
|
|
|
|
@pytest.fixture
|
|
def cache() -> TibberPricesIntervalPoolFetchGroupCache:
|
|
"""Create a fresh cache instance for testing."""
|
|
return TibberPricesIntervalPoolFetchGroupCache()
|
|
|
|
|
|
@pytest.fixture
|
|
def index() -> TibberPricesIntervalPoolTimestampIndex:
|
|
"""Create a fresh index instance for testing."""
|
|
return TibberPricesIntervalPoolTimestampIndex()
|
|
|
|
|
|
class TestTouchOperations:
|
|
"""Test touch operations (re-fetching same intervals)."""
|
|
|
|
def test_touch_operation_reuses_existing_intervals(
|
|
self,
|
|
pool: TibberPricesIntervalPool,
|
|
) -> None:
|
|
"""Test that touch operations reuse existing interval dicts (references, not copies)."""
|
|
fetch_time_1 = "2025-11-25T10:00:00+01:00"
|
|
fetch_time_2 = "2025-11-25T10:15:00+01:00"
|
|
|
|
# Create sample intervals
|
|
sample_intervals = [
|
|
{
|
|
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
|
|
"total": 10.0 + h,
|
|
}
|
|
for h in range(24)
|
|
]
|
|
|
|
# First fetch: Add intervals
|
|
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
|
|
|
|
# Access internal cache
|
|
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
|
|
|
|
# Verify: 1 fetch group with 24 intervals
|
|
assert len(fetch_groups) == 1
|
|
assert len(fetch_groups[0]["intervals"]) == 24
|
|
|
|
# Get reference to first interval
|
|
first_interval_original = fetch_groups[0]["intervals"][0]
|
|
original_id = id(first_interval_original)
|
|
|
|
# Second fetch: Touch same intervals
|
|
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
|
|
|
|
# Re-fetch groups (list may have changed)
|
|
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
|
|
|
|
# Verify: Now we have 2 fetch groups
|
|
assert len(fetch_groups) == 2
|
|
|
|
# Get reference to first interval from TOUCH group
|
|
first_interval_touched = fetch_groups[1]["intervals"][0]
|
|
touched_id = id(first_interval_touched)
|
|
|
|
# CRITICAL: Should be SAME object (same memory address)
|
|
assert original_id == touched_id, f"Memory addresses differ: {original_id} != {touched_id}"
|
|
assert first_interval_original is first_interval_touched, "Touch should reuse existing dict, not create copy"
|
|
|
|
def test_touch_operation_updates_index(
|
|
self,
|
|
pool: TibberPricesIntervalPool,
|
|
) -> None:
|
|
"""Test that touch operations update the index to point to new fetch group."""
|
|
fetch_time_1 = "2025-11-25T10:00:00+01:00"
|
|
fetch_time_2 = "2025-11-25T10:15:00+01:00"
|
|
|
|
# Create sample intervals
|
|
sample_intervals = [
|
|
{
|
|
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
|
|
"total": 10.0 + h,
|
|
}
|
|
for h in range(24)
|
|
]
|
|
|
|
# First fetch
|
|
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
|
|
|
|
# Verify index points to group 0
|
|
first_key = sample_intervals[0]["startsAt"][:19]
|
|
index_entry = pool._index.get(first_key) # noqa: SLF001
|
|
assert index_entry is not None
|
|
assert index_entry["fetch_group_index"] == 0
|
|
|
|
# Second fetch (touch)
|
|
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
|
|
|
|
# Verify index now points to group 1 (touch group)
|
|
index_entry = pool._index.get(first_key) # noqa: SLF001
|
|
assert index_entry is not None
|
|
assert index_entry["fetch_group_index"] == 1, "Index should point to touch group"
|
|
|
|
|
|
class TestGarbageCollection:
|
|
"""Test garbage collection and dead interval cleanup."""
|
|
|
|
def test_gc_cleanup_removes_dead_intervals(
|
|
self,
|
|
cache: TibberPricesIntervalPoolFetchGroupCache,
|
|
index: TibberPricesIntervalPoolTimestampIndex,
|
|
) -> None:
|
|
"""Test that GC cleanup removes dead intervals from old fetch groups."""
|
|
gc = TibberPricesIntervalPoolGarbageCollector(cache, index, "test_home")
|
|
|
|
# Create sample intervals
|
|
sample_intervals = [
|
|
{
|
|
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
|
|
"total": 10.0 + h,
|
|
}
|
|
for h in range(24)
|
|
]
|
|
|
|
# First fetch: Add to cache and index
|
|
fetch_time_1 = datetime(2025, 11, 25, 10, 0, 0, tzinfo=UTC)
|
|
group_idx_1 = cache.add_fetch_group(sample_intervals, fetch_time_1)
|
|
for i, interval in enumerate(sample_intervals):
|
|
index.add(interval, group_idx_1, i)
|
|
|
|
# Verify initial state
|
|
assert cache.count_total_intervals() == 24
|
|
assert index.count() == 24
|
|
|
|
# Second fetch (touch): Create new fetch group
|
|
fetch_time_2 = datetime(2025, 11, 25, 10, 15, 0, tzinfo=UTC)
|
|
group_idx_2 = cache.add_fetch_group(sample_intervals, fetch_time_2)
|
|
|
|
# Update index to point to new group (simulates touch)
|
|
for i, interval in enumerate(sample_intervals):
|
|
index.add(interval, group_idx_2, i)
|
|
|
|
# Before GC: 48 intervals in cache (24 dead + 24 living), 24 in index
|
|
assert cache.count_total_intervals() == 48
|
|
assert index.count() == 24
|
|
|
|
# Run GC
|
|
gc_changed = gc.run_gc()
|
|
|
|
# After GC: Dead intervals cleaned, empty group removed
|
|
assert gc_changed is True
|
|
assert cache.count_total_intervals() == 24, "Should only have living intervals"
|
|
|
|
def test_gc_removes_empty_fetch_groups(
|
|
self,
|
|
cache: TibberPricesIntervalPoolFetchGroupCache,
|
|
index: TibberPricesIntervalPoolTimestampIndex,
|
|
) -> None:
|
|
"""Test that GC removes empty fetch groups after dead interval cleanup."""
|
|
gc = TibberPricesIntervalPoolGarbageCollector(cache, index, "test_home")
|
|
|
|
# Create sample intervals
|
|
sample_intervals = [
|
|
{
|
|
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
|
|
"total": 10.0 + h,
|
|
}
|
|
for h in range(4) # Small set
|
|
]
|
|
|
|
# Add two fetch groups
|
|
fetch_time_1 = datetime(2025, 11, 25, 10, 0, 0, tzinfo=UTC)
|
|
fetch_time_2 = datetime(2025, 11, 25, 10, 15, 0, tzinfo=UTC)
|
|
|
|
cache.add_fetch_group(sample_intervals, fetch_time_1)
|
|
group_idx_2 = cache.add_fetch_group(sample_intervals, fetch_time_2)
|
|
|
|
# Index points only to second group
|
|
for i, interval in enumerate(sample_intervals):
|
|
index.add(interval, group_idx_2, i)
|
|
|
|
# Before GC: 2 groups
|
|
assert len(cache.get_fetch_groups()) == 2
|
|
|
|
# Run GC
|
|
gc.run_gc()
|
|
|
|
# After GC: Only 1 group (empty one removed)
|
|
fetch_groups = cache.get_fetch_groups()
|
|
assert len(fetch_groups) == 1, "Empty fetch group should be removed"
|
|
assert len(fetch_groups[0]["intervals"]) == 4
|
|
|
|
|
|
class TestSerialization:
|
|
"""Test serialization excludes dead intervals."""
|
|
|
|
def test_serialization_excludes_dead_intervals(
|
|
self,
|
|
pool: TibberPricesIntervalPool,
|
|
) -> None:
|
|
"""Test that to_dict() excludes dead intervals from serialization."""
|
|
fetch_time_1 = "2025-11-25T10:00:00+01:00"
|
|
fetch_time_2 = "2025-11-25T10:15:00+01:00"
|
|
|
|
# Create sample intervals
|
|
sample_intervals = [
|
|
{
|
|
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
|
|
"total": 10.0 + h,
|
|
}
|
|
for h in range(24)
|
|
]
|
|
|
|
# First fetch
|
|
pool._add_intervals(sample_intervals, fetch_time_1) # noqa: SLF001
|
|
|
|
# Second fetch (touch)
|
|
pool._add_intervals(sample_intervals, fetch_time_2) # noqa: SLF001
|
|
|
|
# Serialize WITHOUT running GC cleanup first
|
|
serialized = pool.to_dict()
|
|
|
|
# Verify serialization structure
|
|
assert "fetch_groups" in serialized
|
|
assert "home_id" in serialized
|
|
fetch_groups = serialized["fetch_groups"]
|
|
|
|
# CRITICAL: Should only serialize living intervals
|
|
# Old group with dead intervals should NOT be serialized
|
|
total_serialized_intervals = sum(len(g["intervals"]) for g in fetch_groups)
|
|
assert total_serialized_intervals == 24, (
|
|
f"Should only serialize 24 living intervals, got {total_serialized_intervals}"
|
|
)
|
|
|
|
def test_repeated_touch_operations_dont_grow_storage(
|
|
self,
|
|
pool: TibberPricesIntervalPool,
|
|
) -> None:
|
|
"""Test that repeated touch operations don't grow storage size unbounded."""
|
|
# Create sample intervals
|
|
sample_intervals = [
|
|
{
|
|
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
|
|
"total": 10.0 + h,
|
|
}
|
|
for h in range(24)
|
|
]
|
|
|
|
# Simulate 10 re-fetches of the same intervals
|
|
for i in range(10):
|
|
fetch_time = f"2025-11-25T{10 + i}:00:00+01:00"
|
|
pool._add_intervals(sample_intervals, fetch_time) # noqa: SLF001
|
|
|
|
# Serialize (filters dead intervals)
|
|
serialized = pool.to_dict()
|
|
serialized_groups = serialized["fetch_groups"]
|
|
|
|
# Storage should only have 24 living intervals total
|
|
total_intervals = sum(len(g["intervals"]) for g in serialized_groups)
|
|
assert total_intervals == 24, f"Should only have 24 living intervals, got {total_intervals}"
|
|
|
|
# Verify storage size is bounded
|
|
json_str = json.dumps(serialized)
|
|
json_size = len(json_str)
|
|
# Should still be < 10 KB even after 10 fetches
|
|
assert json_size < 10000, f"Storage grew unbounded: {json_size} bytes (expected < 10000)"
|
|
|
|
|
|
class TestIndexBatchUpdate:
|
|
"""Test batch index update functionality."""
|
|
|
|
def test_batch_update_efficiency(
|
|
self,
|
|
index: TibberPricesIntervalPoolTimestampIndex,
|
|
) -> None:
|
|
"""Test that batch update correctly updates multiple entries."""
|
|
# Create test intervals
|
|
timestamps = [f"2025-11-25T{h:02d}:00:00" for h in range(24)]
|
|
|
|
# Add intervals pointing to group 0
|
|
for i, ts in enumerate(timestamps):
|
|
index.add({"startsAt": ts}, 0, i)
|
|
|
|
# Verify initial state
|
|
assert index.count() == 24
|
|
for ts in timestamps:
|
|
entry = index.get(ts)
|
|
assert entry is not None
|
|
assert entry["fetch_group_index"] == 0
|
|
|
|
# Batch update to point to group 1
|
|
updates = [(ts, 1, i) for i, ts in enumerate(timestamps)]
|
|
index.update_batch(updates)
|
|
|
|
# Verify all entries now point to group 1
|
|
for ts in timestamps:
|
|
entry = index.get(ts)
|
|
assert entry is not None
|
|
assert entry["fetch_group_index"] == 1, f"Entry for {ts} should point to group 1"
|
|
|
|
def test_batch_update_with_partial_overlap(
|
|
self,
|
|
index: TibberPricesIntervalPoolTimestampIndex,
|
|
) -> None:
|
|
"""Test batch update with only some existing entries."""
|
|
# Add initial entries (0-11)
|
|
for i in range(12):
|
|
ts = f"2025-11-25T{i:02d}:00:00"
|
|
index.add({"startsAt": ts}, 0, i)
|
|
|
|
assert index.count() == 12
|
|
|
|
# Batch update: update first 6, add 6 new (12-17)
|
|
updates = [(f"2025-11-25T{i:02d}:00:00", 1, i) for i in range(18)]
|
|
index.update_batch(updates)
|
|
|
|
# Should now have 18 entries (12 existing + 6 new)
|
|
assert index.count() == 18
|
|
|
|
# All should point to group 1
|
|
for i in range(18):
|
|
ts = f"2025-11-25T{i:02d}:00:00"
|
|
entry = index.get(ts)
|
|
assert entry is not None
|
|
assert entry["fetch_group_index"] == 1
|
|
|
|
|
|
class TestIntervalIdentityPreservation:
|
|
"""Test that interval dict identity is preserved across operations."""
|
|
|
|
def test_interval_identity_preserved_across_touch(
|
|
self,
|
|
pool: TibberPricesIntervalPool,
|
|
) -> None:
|
|
"""Test that interval dict identity (memory address) is preserved across touch."""
|
|
# Create sample intervals
|
|
sample_intervals = [
|
|
{
|
|
"startsAt": datetime(2025, 11, 25, h, 0, 0, tzinfo=UTC).isoformat(),
|
|
"total": 10.0 + h,
|
|
}
|
|
for h in range(24)
|
|
]
|
|
|
|
# First fetch
|
|
pool._add_intervals(sample_intervals, "2025-11-25T10:00:00+01:00") # noqa: SLF001
|
|
|
|
# Get fetch groups
|
|
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
|
|
|
|
# Collect memory addresses of intervals in original group
|
|
original_ids = [id(interval) for interval in fetch_groups[0]["intervals"]]
|
|
|
|
# Second fetch (touch)
|
|
pool._add_intervals(sample_intervals, "2025-11-25T10:15:00+01:00") # noqa: SLF001
|
|
|
|
# Re-fetch groups
|
|
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
|
|
|
|
# Collect memory addresses of intervals in touch group
|
|
touched_ids = [id(interval) for interval in fetch_groups[1]["intervals"]]
|
|
|
|
# CRITICAL: All memory addresses should be identical (same objects)
|
|
assert original_ids == touched_ids, "Touch should preserve interval identity (memory addresses)"
|
|
|
|
# Third fetch (touch again)
|
|
pool._add_intervals(sample_intervals, "2025-11-25T10:30:00+01:00") # noqa: SLF001
|
|
|
|
# Re-fetch groups
|
|
fetch_groups = pool._cache.get_fetch_groups() # noqa: SLF001
|
|
|
|
# New touch group should also reference the SAME original objects
|
|
touched_ids_2 = [id(interval) for interval in fetch_groups[2]["intervals"]]
|
|
assert original_ids == touched_ids_2, "Multiple touches should preserve original identity"
|