From bb8f5aa8cc22a1247c47d2e6779d39ab9abef2ab Mon Sep 17 00:00:00 2001 From: Julian Pawlowski Date: Sat, 25 Apr 2026 22:46:43 +0000 Subject: [PATCH] chore(testing): add optional Pyright checks for tests Add a dedicated type-check-tests helper, wire it into check-all behind --with-test-types, and align the affected tests with current typing and helper contracts. Impact: No direct user-facing change. User-Impact: none --- scripts/check-all | 26 ++++++++- scripts/type-check-tests | 34 ++++++++++++ tests/services/test_power_scheduler.py | 4 +- tests/test_interval_pool_optimization.py | 65 ++++++++++++----------- tests/test_price_window.py | 9 ++++ tests/test_rating_threshold_validation.py | 1 + tests/test_resource_cleanup.py | 4 +- 7 files changed, 107 insertions(+), 36 deletions(-) create mode 100755 scripts/type-check-tests diff --git a/scripts/check-all b/scripts/check-all index 2569950..08b7e71 100755 --- a/scripts/check-all +++ b/scripts/check-all @@ -1,15 +1,17 @@ #!/bin/bash -# script/check-all: Run full checks for Python and non-Python files +# script/check-all: Run full checks for Python/non-Python files, optionally incl. test Pyright # # Runs project checks and validates formatting/lint state for common -# non-Python files. +# non-Python files. Optionally includes Pyright checks for tests. # # Usage: # ./scripts/check-all +# ./scripts/check-all --with-test-types # # Examples: # ./scripts/check-all +# ./scripts/check-all --with-test-types set -euo pipefail @@ -19,6 +21,21 @@ cd "$SCRIPT_DIR/.." # shellcheck source=scripts/.lib/output.sh source "$SCRIPT_DIR/.lib/output.sh" +run_test_type_check=false + +for arg in "$@"; do + case "$arg" in + --with-test-types) + run_test_type_check=true + ;; + *) + log_error "Unknown argument: $arg" + log_info "Usage: ./scripts/check-all [--with-test-types]" + exit 1 + ;; + esac +done + collect_shell_files() { local files=() local file shebang @@ -40,6 +57,11 @@ collect_shell_files() { log_header "Running Python checks" "$SCRIPT_DIR/check" +if [[ $run_test_type_check == true ]]; then + log_header "Running Pyright checks for tests" + "$SCRIPT_DIR/type-check-tests" +fi + log_header "Checking JSON/JSONC/Markdown with Prettier" npx --yes prettier --check "**/*.{json,jsonc,md,yml,yaml}" diff --git a/scripts/type-check-tests b/scripts/type-check-tests new file mode 100755 index 0000000..888183c --- /dev/null +++ b/scripts/type-check-tests @@ -0,0 +1,34 @@ +#!/bin/bash + +# script/type-check-tests: Run optional Pyright checks for test files +# +# Runs Pyright on tests without changing the main repository type-check scope. +# Defaults to the full tests/ tree, but accepts optional file or folder targets. +# +# Usage: +# ./scripts/type-check-tests +# ./scripts/type-check-tests tests/test_period_overlap.py tests/test_periods_hash.py + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +cd "$SCRIPT_DIR/.." + +# shellcheck source=scripts/.lib/output.sh +source "$SCRIPT_DIR/.lib/output.sh" + +if [[ -z ${VIRTUAL_ENV:-} ]]; then + # shellcheck source=/dev/null + source "$HOME/.venv/bin/activate" +fi + +targets=("$@") +if [[ ${#targets[@]} -eq 0 ]]; then + targets=("tests") +fi + +log_header "Running type checking tools for test files" + +pyright "${targets[@]}" + +log_success "Test type checking completed" diff --git a/tests/services/test_power_scheduler.py b/tests/services/test_power_scheduler.py index f7d05e6..366a333 100644 --- a/tests/services/test_power_scheduler.py +++ b/tests/services/test_power_scheduler.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta +from zoneinfo import ZoneInfo from custom_components.tibber_prices.services.charging.deadline_solver import resolve_deadline from custom_components.tibber_prices.services.charging.power_scheduler import build_power_schedule @@ -53,6 +54,7 @@ def test_stepped_mode_uses_smallest_sufficient_step() -> None: def test_resolve_deadline_next_peak_period() -> None: """Deadline helper should resolve the next future peak period start.""" now = datetime(2026, 1, 1, 0, 0, tzinfo=UTC) + home_tz = ZoneInfo("UTC") coordinator_data = { "pricePeriods": { "peak_price": { @@ -69,7 +71,7 @@ def test_resolve_deadline_next_peak_period() -> None: deadline, source = resolve_deadline( coordinator_data=coordinator_data, now=now, - home_tz=UTC, + home_tz=home_tz, must_reach_by_event="next_peak_period", ) diff --git a/tests/test_interval_pool_optimization.py b/tests/test_interval_pool_optimization.py index b89502a..7f7618b 100644 --- a/tests/test_interval_pool_optimization.py +++ b/tests/test_interval_pool_optimization.py @@ -47,12 +47,15 @@ def _create_intervals(start: datetime, count: int) -> list[dict]: return [_create_test_interval(start + timedelta(minutes=15 * i)) for i in range(count)] +def _create_pool(api_client: MagicMock) -> TibberPricesIntervalPool: + """Create an interval pool using the current constructor signature.""" + return TibberPricesIntervalPool(home_id="home123", api=api_client) + + @pytest.mark.asyncio @pytest.mark.unit async def test_no_cache_single_api_call() -> None: """Test: Empty cache → 1 API call for entire range.""" - pool = TibberPricesIntervalPool(home_id="home123") - # Mock API client api_client = MagicMock( spec=[ @@ -63,6 +66,7 @@ async def test_no_cache_single_api_call() -> None: "_calculate_day_before_yesterday_midnight", ] ) + pool = _create_pool(api_client) start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0) end = start + timedelta(hours=2) # 8 intervals @@ -80,19 +84,17 @@ async def test_no_cache_single_api_call() -> None: user_data = {"timeZone": "Europe/Berlin"} # Act - result = await pool.get_intervals(api_client, user_data, start, end) + intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end) # Assert: Exactly 1 API call assert api_client.async_get_price_info_for_range.call_count == 1 - assert len(result) == 8 + assert len(intervals) == 8 @pytest.mark.asyncio @pytest.mark.unit async def test_full_cache_zero_api_calls() -> None: """Test: Fully cached range → 0 API calls.""" - pool = TibberPricesIntervalPool(home_id="home123") - # Mock API client api_client = MagicMock( spec=[ @@ -103,6 +105,7 @@ async def test_full_cache_zero_api_calls() -> None: "_calculate_day_before_yesterday_midnight", ] ) + pool = _create_pool(api_client) start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0) end = start + timedelta(hours=2) # 8 intervals @@ -123,19 +126,17 @@ async def test_full_cache_zero_api_calls() -> None: assert api_client.async_get_price_info_for_range.call_count == 1 # Second call: should use cache - result = await pool.get_intervals(api_client, user_data, start, end) + intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end) # Assert: Still only 1 API call (from first request) assert api_client.async_get_price_info_for_range.call_count == 1 - assert len(result) == 8 + assert len(intervals) == 8 @pytest.mark.asyncio @pytest.mark.unit async def test_single_gap_single_api_call() -> None: """Test: One gap in cache → 1 API call for that gap only.""" - pool = TibberPricesIntervalPool(home_id="home123") - # Mock API client api_client = MagicMock( spec=[ @@ -146,6 +147,7 @@ async def test_single_gap_single_api_call() -> None: "_calculate_day_before_yesterday_midnight", ] ) + pool = _create_pool(api_client) start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0) end = start + timedelta(hours=3) # 12 intervals total @@ -175,19 +177,17 @@ async def test_single_gap_single_api_call() -> None: gap_intervals = _create_intervals(start + timedelta(hours=1), 4) api_client.async_get_price_info_for_range = AsyncMock(return_value=gap_intervals) - result = await pool.get_intervals(api_client, user_data, start, end) + intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end) # Assert: Exactly 1 additional API call (for the gap) assert api_client.async_get_price_info_for_range.call_count == call_count_before + 1 - assert len(result) == 12 # All intervals now available + assert len(intervals) == 12 # All intervals now available @pytest.mark.asyncio @pytest.mark.unit async def test_multiple_gaps_multiple_api_calls() -> None: """Test: Multiple gaps → one API call per continuous gap.""" - pool = TibberPricesIntervalPool(home_id="home123") - # Mock API client api_client = MagicMock( spec=[ @@ -198,6 +198,7 @@ async def test_multiple_gaps_multiple_api_calls() -> None: "_calculate_day_before_yesterday_midnight", ] ) + pool = _create_pool(api_client) start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0) end = start + timedelta(hours=4) # 16 intervals total @@ -247,19 +248,17 @@ async def test_multiple_gaps_multiple_api_calls() -> None: api_client.async_get_price_info_for_range = AsyncMock(side_effect=mock_fetch) - result = await pool.get_intervals(api_client, user_data, start, end) + intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end) # Assert: Exactly 3 additional API calls (one per gap) assert api_client.async_get_price_info_for_range.call_count == call_count_before + 3 - assert len(result) == 16 # All intervals now available + assert len(intervals) == 16 # All intervals now available @pytest.mark.asyncio @pytest.mark.unit async def test_partial_overlap_minimal_fetch() -> None: """Test: Overlapping request → fetch only new intervals.""" - pool = TibberPricesIntervalPool(home_id="home123") - # Mock API client api_client = MagicMock( spec=[ @@ -270,6 +269,7 @@ async def test_partial_overlap_minimal_fetch() -> None: "_calculate_day_before_yesterday_midnight", ] ) + pool = _create_pool(api_client) start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0) user_data = {"timeZone": "Europe/Berlin"} @@ -285,7 +285,7 @@ async def test_partial_overlap_minimal_fetch() -> None: batch2 = _create_intervals(start + timedelta(hours=2), 4) # Only new ones api_client.async_get_price_info_for_range = AsyncMock(return_value=batch2) - result = await pool.get_intervals( + intervals, _api_called = await pool.get_intervals( api_client, user_data, start + timedelta(hours=1), @@ -294,15 +294,13 @@ async def test_partial_overlap_minimal_fetch() -> None: # Assert: 1 additional API call (for 12:00-13:00 only) assert api_client.async_get_price_info_for_range.call_count == 2 - assert len(result) == 8 # 11:00-13:00 + assert len(intervals) == 8 # 11:00-13:00 @pytest.mark.asyncio @pytest.mark.unit async def test_detect_missing_ranges_optimization() -> None: """Test: Gap detection returns minimal set of ranges (tested via API behavior).""" - pool = TibberPricesIntervalPool(home_id="home123") - # Mock API client that tracks calls api_client = MagicMock( spec=[ @@ -313,6 +311,7 @@ async def test_detect_missing_ranges_optimization() -> None: "_calculate_day_before_yesterday_midnight", ] ) + pool = _create_pool(api_client) start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0) end = start + timedelta(hours=4) @@ -334,13 +333,17 @@ async def test_detect_missing_ranges_optimization() -> None: # Manually add to cache (simulate previous fetches) # Note: Accessing private _cache for test setup # Single-home architecture: directly populate internal structures - pool._fetch_groups = [ # noqa: SLF001 - { - "intervals": cached, - "fetch_time": dt_util.now().isoformat(), - } - ] - pool._timestamp_index = {interval["startsAt"]: idx for idx, interval in enumerate(cached)} # noqa: SLF001 + cache = pool._cache # noqa: SLF001 + index = pool._index # noqa: SLF001 + cache.set_fetch_groups( + [ + { + "intervals": cached, + "fetched_at": dt_util.now(), + } + ] + ) + index.rebuild(cache.get_fetch_groups()) # Mock responses for the 3 expected gaps gap1 = _create_intervals(start + timedelta(minutes=30), 2) # 10:30-11:00 @@ -362,10 +365,10 @@ async def test_detect_missing_ranges_optimization() -> None: api_client.async_get_price_info_for_range = AsyncMock(side_effect=mock_fetch) # Request entire range - should detect exactly 3 gaps - result = await pool.get_intervals(api_client, user_data, start, end) + intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end) # Assert: Exactly 3 API calls (one per gap) assert api_client.async_get_price_info_for_range.call_count == 3 # Verify all intervals are now available - assert len(result) == 16 # 2 + 2 + 2 + 2 + 1 + 7 = 16 intervals + assert len(intervals) == 16 # 2 + 2 + 2 + 2 + 1 + 7 = 16 intervals diff --git a/tests/test_price_window.py b/tests/test_price_window.py index 997ad0e..0033b80 100644 --- a/tests/test_price_window.py +++ b/tests/test_price_window.py @@ -541,6 +541,9 @@ class TestPriceComparison: cheap_stats = calculate_window_statistics(cheapest["intervals"]) expensive_stats = calculate_window_statistics(most_expensive["intervals"]) + assert cheap_stats["price_mean"] is not None + assert expensive_stats["price_mean"] is not None + spread_cheap_to_exp = expensive_stats["price_mean"] - cheap_stats["price_mean"] spread_exp_to_cheap = cheap_stats["price_mean"] - expensive_stats["price_mean"] @@ -582,6 +585,9 @@ class TestPriceComparison: cheap_stats = calculate_window_statistics(cheapest["intervals"]) expensive_stats = calculate_window_statistics(most_expensive["intervals"]) + assert cheap_stats["price_mean"] is not None + assert expensive_stats["price_mean"] is not None + spread = expensive_stats["price_mean"] - cheap_stats["price_mean"] assert abs(spread) < 0.0001 @@ -598,5 +604,8 @@ class TestPriceComparison: cheap_stats = calculate_window_statistics(cheapest["intervals"]) expensive_stats = calculate_window_statistics(most_expensive["intervals"]) + assert cheap_stats["price_mean"] is not None + assert expensive_stats["price_mean"] is not None + spread = expensive_stats["price_mean"] - cheap_stats["price_mean"] assert abs(spread) < 0.0001 diff --git a/tests/test_rating_threshold_validation.py b/tests/test_rating_threshold_validation.py index 79ad3bb..54336f9 100644 --- a/tests/test_rating_threshold_validation.py +++ b/tests/test_rating_threshold_validation.py @@ -280,6 +280,7 @@ def test_hysteresis_sequence_simulation() -> None: rating = calculate_rating_level( diff, threshold_low, threshold_high, previous_rating=previous, hysteresis=hysteresis ) + assert rating is not None results_with.append(rating) previous = rating assert results_with == expected_with_hysteresis diff --git a/tests/test_resource_cleanup.py b/tests/test_resource_cleanup.py index ad7cd07..3de14ce 100644 --- a/tests/test_resource_cleanup.py +++ b/tests/test_resource_cleanup.py @@ -202,7 +202,7 @@ class TestConfigEntryCleanup: coordinator._listener_manager = object.__new__(TibberPricesListenerManager) # noqa: SLF001 coordinator._data_transformer = object.__new__(TibberPricesDataTransformer) # noqa: SLF001 coordinator._period_calculator = object.__new__(TibberPricesPeriodCalculator) # noqa: SLF001 - coordinator._lifecycle_callbacks = [] # noqa: SLF001 + setattr(coordinator, "_lifecycle_callbacks", []) # Manually call the registration that happens in __init__ # This tests the pattern: entry.async_on_unload(entry.add_update_listener(...)) @@ -249,7 +249,7 @@ class TestCacheInvalidation: # Create calculator with cached data calculator = object.__new__(TibberPricesPeriodCalculator) - calculator._config_cache = {"some": "data"} # noqa: SLF001 + calculator._config_cache = {"best": {"some": "data"}} # noqa: SLF001 calculator._config_cache_valid = True # noqa: SLF001 calculator._cached_periods = {"cached": "periods"} # noqa: SLF001 calculator._last_periods_hash = "some_hash" # noqa: SLF001