chore(testing): add optional Pyright checks for tests
Some checks failed
Validate / HACS validation (push) Has been cancelled
Lint / Ruff (push) Has been cancelled
Validate / Hassfest validation (push) Has been cancelled

Add a dedicated type-check-tests helper, wire it into check-all behind --with-test-types, and align the affected tests with current typing and helper contracts.

Impact: No direct user-facing change.

User-Impact: none
This commit is contained in:
Julian Pawlowski 2026-04-25 22:46:43 +00:00
parent bbcfdd4443
commit bb8f5aa8cc
7 changed files with 107 additions and 36 deletions

View file

@ -1,15 +1,17 @@
#!/bin/bash
# script/check-all: Run full checks for Python and non-Python files
# script/check-all: Run full checks for Python/non-Python files, optionally incl. test Pyright
#
# Runs project checks and validates formatting/lint state for common
# non-Python files.
# non-Python files. Optionally includes Pyright checks for tests.
#
# Usage:
# ./scripts/check-all
# ./scripts/check-all --with-test-types
#
# Examples:
# ./scripts/check-all
# ./scripts/check-all --with-test-types
set -euo pipefail
@ -19,6 +21,21 @@ cd "$SCRIPT_DIR/.."
# shellcheck source=scripts/.lib/output.sh
source "$SCRIPT_DIR/.lib/output.sh"
run_test_type_check=false
for arg in "$@"; do
case "$arg" in
--with-test-types)
run_test_type_check=true
;;
*)
log_error "Unknown argument: $arg"
log_info "Usage: ./scripts/check-all [--with-test-types]"
exit 1
;;
esac
done
collect_shell_files() {
local files=()
local file shebang
@ -40,6 +57,11 @@ collect_shell_files() {
log_header "Running Python checks"
"$SCRIPT_DIR/check"
if [[ $run_test_type_check == true ]]; then
log_header "Running Pyright checks for tests"
"$SCRIPT_DIR/type-check-tests"
fi
log_header "Checking JSON/JSONC/Markdown with Prettier"
npx --yes prettier --check "**/*.{json,jsonc,md,yml,yaml}"

34
scripts/type-check-tests Executable file
View file

@ -0,0 +1,34 @@
#!/bin/bash
# script/type-check-tests: Run optional Pyright checks for test files
#
# Runs Pyright on tests without changing the main repository type-check scope.
# Defaults to the full tests/ tree, but accepts optional file or folder targets.
#
# Usage:
# ./scripts/type-check-tests
# ./scripts/type-check-tests tests/test_period_overlap.py tests/test_periods_hash.py
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR/.."
# shellcheck source=scripts/.lib/output.sh
source "$SCRIPT_DIR/.lib/output.sh"
if [[ -z ${VIRTUAL_ENV:-} ]]; then
# shellcheck source=/dev/null
source "$HOME/.venv/bin/activate"
fi
targets=("$@")
if [[ ${#targets[@]} -eq 0 ]]; then
targets=("tests")
fi
log_header "Running type checking tools for test files"
pyright "${targets[@]}"
log_success "Test type checking completed"

View file

@ -3,6 +3,7 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from zoneinfo import ZoneInfo
from custom_components.tibber_prices.services.charging.deadline_solver import resolve_deadline
from custom_components.tibber_prices.services.charging.power_scheduler import build_power_schedule
@ -53,6 +54,7 @@ def test_stepped_mode_uses_smallest_sufficient_step() -> None:
def test_resolve_deadline_next_peak_period() -> None:
"""Deadline helper should resolve the next future peak period start."""
now = datetime(2026, 1, 1, 0, 0, tzinfo=UTC)
home_tz = ZoneInfo("UTC")
coordinator_data = {
"pricePeriods": {
"peak_price": {
@ -69,7 +71,7 @@ def test_resolve_deadline_next_peak_period() -> None:
deadline, source = resolve_deadline(
coordinator_data=coordinator_data,
now=now,
home_tz=UTC,
home_tz=home_tz,
must_reach_by_event="next_peak_period",
)

View file

@ -47,12 +47,15 @@ def _create_intervals(start: datetime, count: int) -> list[dict]:
return [_create_test_interval(start + timedelta(minutes=15 * i)) for i in range(count)]
def _create_pool(api_client: MagicMock) -> TibberPricesIntervalPool:
"""Create an interval pool using the current constructor signature."""
return TibberPricesIntervalPool(home_id="home123", api=api_client)
@pytest.mark.asyncio
@pytest.mark.unit
async def test_no_cache_single_api_call() -> None:
"""Test: Empty cache → 1 API call for entire range."""
pool = TibberPricesIntervalPool(home_id="home123")
# Mock API client
api_client = MagicMock(
spec=[
@ -63,6 +66,7 @@ async def test_no_cache_single_api_call() -> None:
"_calculate_day_before_yesterday_midnight",
]
)
pool = _create_pool(api_client)
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
end = start + timedelta(hours=2) # 8 intervals
@ -80,19 +84,17 @@ async def test_no_cache_single_api_call() -> None:
user_data = {"timeZone": "Europe/Berlin"}
# Act
result = await pool.get_intervals(api_client, user_data, start, end)
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
# Assert: Exactly 1 API call
assert api_client.async_get_price_info_for_range.call_count == 1
assert len(result) == 8
assert len(intervals) == 8
@pytest.mark.asyncio
@pytest.mark.unit
async def test_full_cache_zero_api_calls() -> None:
"""Test: Fully cached range → 0 API calls."""
pool = TibberPricesIntervalPool(home_id="home123")
# Mock API client
api_client = MagicMock(
spec=[
@ -103,6 +105,7 @@ async def test_full_cache_zero_api_calls() -> None:
"_calculate_day_before_yesterday_midnight",
]
)
pool = _create_pool(api_client)
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
end = start + timedelta(hours=2) # 8 intervals
@ -123,19 +126,17 @@ async def test_full_cache_zero_api_calls() -> None:
assert api_client.async_get_price_info_for_range.call_count == 1
# Second call: should use cache
result = await pool.get_intervals(api_client, user_data, start, end)
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
# Assert: Still only 1 API call (from first request)
assert api_client.async_get_price_info_for_range.call_count == 1
assert len(result) == 8
assert len(intervals) == 8
@pytest.mark.asyncio
@pytest.mark.unit
async def test_single_gap_single_api_call() -> None:
"""Test: One gap in cache → 1 API call for that gap only."""
pool = TibberPricesIntervalPool(home_id="home123")
# Mock API client
api_client = MagicMock(
spec=[
@ -146,6 +147,7 @@ async def test_single_gap_single_api_call() -> None:
"_calculate_day_before_yesterday_midnight",
]
)
pool = _create_pool(api_client)
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
end = start + timedelta(hours=3) # 12 intervals total
@ -175,19 +177,17 @@ async def test_single_gap_single_api_call() -> None:
gap_intervals = _create_intervals(start + timedelta(hours=1), 4)
api_client.async_get_price_info_for_range = AsyncMock(return_value=gap_intervals)
result = await pool.get_intervals(api_client, user_data, start, end)
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
# Assert: Exactly 1 additional API call (for the gap)
assert api_client.async_get_price_info_for_range.call_count == call_count_before + 1
assert len(result) == 12 # All intervals now available
assert len(intervals) == 12 # All intervals now available
@pytest.mark.asyncio
@pytest.mark.unit
async def test_multiple_gaps_multiple_api_calls() -> None:
"""Test: Multiple gaps → one API call per continuous gap."""
pool = TibberPricesIntervalPool(home_id="home123")
# Mock API client
api_client = MagicMock(
spec=[
@ -198,6 +198,7 @@ async def test_multiple_gaps_multiple_api_calls() -> None:
"_calculate_day_before_yesterday_midnight",
]
)
pool = _create_pool(api_client)
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
end = start + timedelta(hours=4) # 16 intervals total
@ -247,19 +248,17 @@ async def test_multiple_gaps_multiple_api_calls() -> None:
api_client.async_get_price_info_for_range = AsyncMock(side_effect=mock_fetch)
result = await pool.get_intervals(api_client, user_data, start, end)
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
# Assert: Exactly 3 additional API calls (one per gap)
assert api_client.async_get_price_info_for_range.call_count == call_count_before + 3
assert len(result) == 16 # All intervals now available
assert len(intervals) == 16 # All intervals now available
@pytest.mark.asyncio
@pytest.mark.unit
async def test_partial_overlap_minimal_fetch() -> None:
"""Test: Overlapping request → fetch only new intervals."""
pool = TibberPricesIntervalPool(home_id="home123")
# Mock API client
api_client = MagicMock(
spec=[
@ -270,6 +269,7 @@ async def test_partial_overlap_minimal_fetch() -> None:
"_calculate_day_before_yesterday_midnight",
]
)
pool = _create_pool(api_client)
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
user_data = {"timeZone": "Europe/Berlin"}
@ -285,7 +285,7 @@ async def test_partial_overlap_minimal_fetch() -> None:
batch2 = _create_intervals(start + timedelta(hours=2), 4) # Only new ones
api_client.async_get_price_info_for_range = AsyncMock(return_value=batch2)
result = await pool.get_intervals(
intervals, _api_called = await pool.get_intervals(
api_client,
user_data,
start + timedelta(hours=1),
@ -294,15 +294,13 @@ async def test_partial_overlap_minimal_fetch() -> None:
# Assert: 1 additional API call (for 12:00-13:00 only)
assert api_client.async_get_price_info_for_range.call_count == 2
assert len(result) == 8 # 11:00-13:00
assert len(intervals) == 8 # 11:00-13:00
@pytest.mark.asyncio
@pytest.mark.unit
async def test_detect_missing_ranges_optimization() -> None:
"""Test: Gap detection returns minimal set of ranges (tested via API behavior)."""
pool = TibberPricesIntervalPool(home_id="home123")
# Mock API client that tracks calls
api_client = MagicMock(
spec=[
@ -313,6 +311,7 @@ async def test_detect_missing_ranges_optimization() -> None:
"_calculate_day_before_yesterday_midnight",
]
)
pool = _create_pool(api_client)
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
end = start + timedelta(hours=4)
@ -334,13 +333,17 @@ async def test_detect_missing_ranges_optimization() -> None:
# Manually add to cache (simulate previous fetches)
# Note: Accessing private _cache for test setup
# Single-home architecture: directly populate internal structures
pool._fetch_groups = [ # noqa: SLF001
cache = pool._cache # noqa: SLF001
index = pool._index # noqa: SLF001
cache.set_fetch_groups(
[
{
"intervals": cached,
"fetch_time": dt_util.now().isoformat(),
"fetched_at": dt_util.now(),
}
]
pool._timestamp_index = {interval["startsAt"]: idx for idx, interval in enumerate(cached)} # noqa: SLF001
)
index.rebuild(cache.get_fetch_groups())
# Mock responses for the 3 expected gaps
gap1 = _create_intervals(start + timedelta(minutes=30), 2) # 10:30-11:00
@ -362,10 +365,10 @@ async def test_detect_missing_ranges_optimization() -> None:
api_client.async_get_price_info_for_range = AsyncMock(side_effect=mock_fetch)
# Request entire range - should detect exactly 3 gaps
result = await pool.get_intervals(api_client, user_data, start, end)
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
# Assert: Exactly 3 API calls (one per gap)
assert api_client.async_get_price_info_for_range.call_count == 3
# Verify all intervals are now available
assert len(result) == 16 # 2 + 2 + 2 + 2 + 1 + 7 = 16 intervals
assert len(intervals) == 16 # 2 + 2 + 2 + 2 + 1 + 7 = 16 intervals

View file

@ -541,6 +541,9 @@ class TestPriceComparison:
cheap_stats = calculate_window_statistics(cheapest["intervals"])
expensive_stats = calculate_window_statistics(most_expensive["intervals"])
assert cheap_stats["price_mean"] is not None
assert expensive_stats["price_mean"] is not None
spread_cheap_to_exp = expensive_stats["price_mean"] - cheap_stats["price_mean"]
spread_exp_to_cheap = cheap_stats["price_mean"] - expensive_stats["price_mean"]
@ -582,6 +585,9 @@ class TestPriceComparison:
cheap_stats = calculate_window_statistics(cheapest["intervals"])
expensive_stats = calculate_window_statistics(most_expensive["intervals"])
assert cheap_stats["price_mean"] is not None
assert expensive_stats["price_mean"] is not None
spread = expensive_stats["price_mean"] - cheap_stats["price_mean"]
assert abs(spread) < 0.0001
@ -598,5 +604,8 @@ class TestPriceComparison:
cheap_stats = calculate_window_statistics(cheapest["intervals"])
expensive_stats = calculate_window_statistics(most_expensive["intervals"])
assert cheap_stats["price_mean"] is not None
assert expensive_stats["price_mean"] is not None
spread = expensive_stats["price_mean"] - cheap_stats["price_mean"]
assert abs(spread) < 0.0001

View file

@ -280,6 +280,7 @@ def test_hysteresis_sequence_simulation() -> None:
rating = calculate_rating_level(
diff, threshold_low, threshold_high, previous_rating=previous, hysteresis=hysteresis
)
assert rating is not None
results_with.append(rating)
previous = rating
assert results_with == expected_with_hysteresis

View file

@ -202,7 +202,7 @@ class TestConfigEntryCleanup:
coordinator._listener_manager = object.__new__(TibberPricesListenerManager) # noqa: SLF001
coordinator._data_transformer = object.__new__(TibberPricesDataTransformer) # noqa: SLF001
coordinator._period_calculator = object.__new__(TibberPricesPeriodCalculator) # noqa: SLF001
coordinator._lifecycle_callbacks = [] # noqa: SLF001
setattr(coordinator, "_lifecycle_callbacks", [])
# Manually call the registration that happens in __init__
# This tests the pattern: entry.async_on_unload(entry.add_update_listener(...))
@ -249,7 +249,7 @@ class TestCacheInvalidation:
# Create calculator with cached data
calculator = object.__new__(TibberPricesPeriodCalculator)
calculator._config_cache = {"some": "data"} # noqa: SLF001
calculator._config_cache = {"best": {"some": "data"}} # noqa: SLF001
calculator._config_cache_valid = True # noqa: SLF001
calculator._cached_periods = {"cached": "periods"} # noqa: SLF001
calculator._last_periods_hash = "some_hash" # noqa: SLF001