mirror of
https://github.com/jpawlowski/hass.tibber_prices.git
synced 2026-05-28 18:43:40 +00:00
chore(testing): add optional Pyright checks for tests
Add a dedicated type-check-tests helper, wire it into check-all behind --with-test-types, and align the affected tests with current typing and helper contracts. Impact: No direct user-facing change. User-Impact: none
This commit is contained in:
parent
bbcfdd4443
commit
bb8f5aa8cc
7 changed files with 107 additions and 36 deletions
|
|
@ -1,15 +1,17 @@
|
|||
#!/bin/bash
|
||||
|
||||
# script/check-all: Run full checks for Python and non-Python files
|
||||
# script/check-all: Run full checks for Python/non-Python files, optionally incl. test Pyright
|
||||
#
|
||||
# Runs project checks and validates formatting/lint state for common
|
||||
# non-Python files.
|
||||
# non-Python files. Optionally includes Pyright checks for tests.
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/check-all
|
||||
# ./scripts/check-all --with-test-types
|
||||
#
|
||||
# Examples:
|
||||
# ./scripts/check-all
|
||||
# ./scripts/check-all --with-test-types
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
|
@ -19,6 +21,21 @@ cd "$SCRIPT_DIR/.."
|
|||
# shellcheck source=scripts/.lib/output.sh
|
||||
source "$SCRIPT_DIR/.lib/output.sh"
|
||||
|
||||
run_test_type_check=false
|
||||
|
||||
for arg in "$@"; do
|
||||
case "$arg" in
|
||||
--with-test-types)
|
||||
run_test_type_check=true
|
||||
;;
|
||||
*)
|
||||
log_error "Unknown argument: $arg"
|
||||
log_info "Usage: ./scripts/check-all [--with-test-types]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
collect_shell_files() {
|
||||
local files=()
|
||||
local file shebang
|
||||
|
|
@ -40,6 +57,11 @@ collect_shell_files() {
|
|||
log_header "Running Python checks"
|
||||
"$SCRIPT_DIR/check"
|
||||
|
||||
if [[ $run_test_type_check == true ]]; then
|
||||
log_header "Running Pyright checks for tests"
|
||||
"$SCRIPT_DIR/type-check-tests"
|
||||
fi
|
||||
|
||||
log_header "Checking JSON/JSONC/Markdown with Prettier"
|
||||
npx --yes prettier --check "**/*.{json,jsonc,md,yml,yaml}"
|
||||
|
||||
|
|
|
|||
34
scripts/type-check-tests
Executable file
34
scripts/type-check-tests
Executable file
|
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
|
||||
# script/type-check-tests: Run optional Pyright checks for test files
|
||||
#
|
||||
# Runs Pyright on tests without changing the main repository type-check scope.
|
||||
# Defaults to the full tests/ tree, but accepts optional file or folder targets.
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/type-check-tests
|
||||
# ./scripts/type-check-tests tests/test_period_overlap.py tests/test_periods_hash.py
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
cd "$SCRIPT_DIR/.."
|
||||
|
||||
# shellcheck source=scripts/.lib/output.sh
|
||||
source "$SCRIPT_DIR/.lib/output.sh"
|
||||
|
||||
if [[ -z ${VIRTUAL_ENV:-} ]]; then
|
||||
# shellcheck source=/dev/null
|
||||
source "$HOME/.venv/bin/activate"
|
||||
fi
|
||||
|
||||
targets=("$@")
|
||||
if [[ ${#targets[@]} -eq 0 ]]; then
|
||||
targets=("tests")
|
||||
fi
|
||||
|
||||
log_header "Running type checking tools for test files"
|
||||
|
||||
pyright "${targets[@]}"
|
||||
|
||||
log_success "Test type checking completed"
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from custom_components.tibber_prices.services.charging.deadline_solver import resolve_deadline
|
||||
from custom_components.tibber_prices.services.charging.power_scheduler import build_power_schedule
|
||||
|
|
@ -53,6 +54,7 @@ def test_stepped_mode_uses_smallest_sufficient_step() -> None:
|
|||
def test_resolve_deadline_next_peak_period() -> None:
|
||||
"""Deadline helper should resolve the next future peak period start."""
|
||||
now = datetime(2026, 1, 1, 0, 0, tzinfo=UTC)
|
||||
home_tz = ZoneInfo("UTC")
|
||||
coordinator_data = {
|
||||
"pricePeriods": {
|
||||
"peak_price": {
|
||||
|
|
@ -69,7 +71,7 @@ def test_resolve_deadline_next_peak_period() -> None:
|
|||
deadline, source = resolve_deadline(
|
||||
coordinator_data=coordinator_data,
|
||||
now=now,
|
||||
home_tz=UTC,
|
||||
home_tz=home_tz,
|
||||
must_reach_by_event="next_peak_period",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -47,12 +47,15 @@ def _create_intervals(start: datetime, count: int) -> list[dict]:
|
|||
return [_create_test_interval(start + timedelta(minutes=15 * i)) for i in range(count)]
|
||||
|
||||
|
||||
def _create_pool(api_client: MagicMock) -> TibberPricesIntervalPool:
|
||||
"""Create an interval pool using the current constructor signature."""
|
||||
return TibberPricesIntervalPool(home_id="home123", api=api_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.unit
|
||||
async def test_no_cache_single_api_call() -> None:
|
||||
"""Test: Empty cache → 1 API call for entire range."""
|
||||
pool = TibberPricesIntervalPool(home_id="home123")
|
||||
|
||||
# Mock API client
|
||||
api_client = MagicMock(
|
||||
spec=[
|
||||
|
|
@ -63,6 +66,7 @@ async def test_no_cache_single_api_call() -> None:
|
|||
"_calculate_day_before_yesterday_midnight",
|
||||
]
|
||||
)
|
||||
pool = _create_pool(api_client)
|
||||
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(hours=2) # 8 intervals
|
||||
|
||||
|
|
@ -80,19 +84,17 @@ async def test_no_cache_single_api_call() -> None:
|
|||
user_data = {"timeZone": "Europe/Berlin"}
|
||||
|
||||
# Act
|
||||
result = await pool.get_intervals(api_client, user_data, start, end)
|
||||
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
|
||||
|
||||
# Assert: Exactly 1 API call
|
||||
assert api_client.async_get_price_info_for_range.call_count == 1
|
||||
assert len(result) == 8
|
||||
assert len(intervals) == 8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.unit
|
||||
async def test_full_cache_zero_api_calls() -> None:
|
||||
"""Test: Fully cached range → 0 API calls."""
|
||||
pool = TibberPricesIntervalPool(home_id="home123")
|
||||
|
||||
# Mock API client
|
||||
api_client = MagicMock(
|
||||
spec=[
|
||||
|
|
@ -103,6 +105,7 @@ async def test_full_cache_zero_api_calls() -> None:
|
|||
"_calculate_day_before_yesterday_midnight",
|
||||
]
|
||||
)
|
||||
pool = _create_pool(api_client)
|
||||
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(hours=2) # 8 intervals
|
||||
|
||||
|
|
@ -123,19 +126,17 @@ async def test_full_cache_zero_api_calls() -> None:
|
|||
assert api_client.async_get_price_info_for_range.call_count == 1
|
||||
|
||||
# Second call: should use cache
|
||||
result = await pool.get_intervals(api_client, user_data, start, end)
|
||||
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
|
||||
|
||||
# Assert: Still only 1 API call (from first request)
|
||||
assert api_client.async_get_price_info_for_range.call_count == 1
|
||||
assert len(result) == 8
|
||||
assert len(intervals) == 8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.unit
|
||||
async def test_single_gap_single_api_call() -> None:
|
||||
"""Test: One gap in cache → 1 API call for that gap only."""
|
||||
pool = TibberPricesIntervalPool(home_id="home123")
|
||||
|
||||
# Mock API client
|
||||
api_client = MagicMock(
|
||||
spec=[
|
||||
|
|
@ -146,6 +147,7 @@ async def test_single_gap_single_api_call() -> None:
|
|||
"_calculate_day_before_yesterday_midnight",
|
||||
]
|
||||
)
|
||||
pool = _create_pool(api_client)
|
||||
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(hours=3) # 12 intervals total
|
||||
|
||||
|
|
@ -175,19 +177,17 @@ async def test_single_gap_single_api_call() -> None:
|
|||
gap_intervals = _create_intervals(start + timedelta(hours=1), 4)
|
||||
api_client.async_get_price_info_for_range = AsyncMock(return_value=gap_intervals)
|
||||
|
||||
result = await pool.get_intervals(api_client, user_data, start, end)
|
||||
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
|
||||
|
||||
# Assert: Exactly 1 additional API call (for the gap)
|
||||
assert api_client.async_get_price_info_for_range.call_count == call_count_before + 1
|
||||
assert len(result) == 12 # All intervals now available
|
||||
assert len(intervals) == 12 # All intervals now available
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.unit
|
||||
async def test_multiple_gaps_multiple_api_calls() -> None:
|
||||
"""Test: Multiple gaps → one API call per continuous gap."""
|
||||
pool = TibberPricesIntervalPool(home_id="home123")
|
||||
|
||||
# Mock API client
|
||||
api_client = MagicMock(
|
||||
spec=[
|
||||
|
|
@ -198,6 +198,7 @@ async def test_multiple_gaps_multiple_api_calls() -> None:
|
|||
"_calculate_day_before_yesterday_midnight",
|
||||
]
|
||||
)
|
||||
pool = _create_pool(api_client)
|
||||
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(hours=4) # 16 intervals total
|
||||
|
||||
|
|
@ -247,19 +248,17 @@ async def test_multiple_gaps_multiple_api_calls() -> None:
|
|||
|
||||
api_client.async_get_price_info_for_range = AsyncMock(side_effect=mock_fetch)
|
||||
|
||||
result = await pool.get_intervals(api_client, user_data, start, end)
|
||||
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
|
||||
|
||||
# Assert: Exactly 3 additional API calls (one per gap)
|
||||
assert api_client.async_get_price_info_for_range.call_count == call_count_before + 3
|
||||
assert len(result) == 16 # All intervals now available
|
||||
assert len(intervals) == 16 # All intervals now available
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.unit
|
||||
async def test_partial_overlap_minimal_fetch() -> None:
|
||||
"""Test: Overlapping request → fetch only new intervals."""
|
||||
pool = TibberPricesIntervalPool(home_id="home123")
|
||||
|
||||
# Mock API client
|
||||
api_client = MagicMock(
|
||||
spec=[
|
||||
|
|
@ -270,6 +269,7 @@ async def test_partial_overlap_minimal_fetch() -> None:
|
|||
"_calculate_day_before_yesterday_midnight",
|
||||
]
|
||||
)
|
||||
pool = _create_pool(api_client)
|
||||
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
|
||||
|
||||
user_data = {"timeZone": "Europe/Berlin"}
|
||||
|
|
@ -285,7 +285,7 @@ async def test_partial_overlap_minimal_fetch() -> None:
|
|||
batch2 = _create_intervals(start + timedelta(hours=2), 4) # Only new ones
|
||||
api_client.async_get_price_info_for_range = AsyncMock(return_value=batch2)
|
||||
|
||||
result = await pool.get_intervals(
|
||||
intervals, _api_called = await pool.get_intervals(
|
||||
api_client,
|
||||
user_data,
|
||||
start + timedelta(hours=1),
|
||||
|
|
@ -294,15 +294,13 @@ async def test_partial_overlap_minimal_fetch() -> None:
|
|||
|
||||
# Assert: 1 additional API call (for 12:00-13:00 only)
|
||||
assert api_client.async_get_price_info_for_range.call_count == 2
|
||||
assert len(result) == 8 # 11:00-13:00
|
||||
assert len(intervals) == 8 # 11:00-13:00
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.unit
|
||||
async def test_detect_missing_ranges_optimization() -> None:
|
||||
"""Test: Gap detection returns minimal set of ranges (tested via API behavior)."""
|
||||
pool = TibberPricesIntervalPool(home_id="home123")
|
||||
|
||||
# Mock API client that tracks calls
|
||||
api_client = MagicMock(
|
||||
spec=[
|
||||
|
|
@ -313,6 +311,7 @@ async def test_detect_missing_ranges_optimization() -> None:
|
|||
"_calculate_day_before_yesterday_midnight",
|
||||
]
|
||||
)
|
||||
pool = _create_pool(api_client)
|
||||
|
||||
start = dt_util.now().replace(hour=10, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(hours=4)
|
||||
|
|
@ -334,13 +333,17 @@ async def test_detect_missing_ranges_optimization() -> None:
|
|||
# Manually add to cache (simulate previous fetches)
|
||||
# Note: Accessing private _cache for test setup
|
||||
# Single-home architecture: directly populate internal structures
|
||||
pool._fetch_groups = [ # noqa: SLF001
|
||||
cache = pool._cache # noqa: SLF001
|
||||
index = pool._index # noqa: SLF001
|
||||
cache.set_fetch_groups(
|
||||
[
|
||||
{
|
||||
"intervals": cached,
|
||||
"fetch_time": dt_util.now().isoformat(),
|
||||
"fetched_at": dt_util.now(),
|
||||
}
|
||||
]
|
||||
pool._timestamp_index = {interval["startsAt"]: idx for idx, interval in enumerate(cached)} # noqa: SLF001
|
||||
)
|
||||
index.rebuild(cache.get_fetch_groups())
|
||||
|
||||
# Mock responses for the 3 expected gaps
|
||||
gap1 = _create_intervals(start + timedelta(minutes=30), 2) # 10:30-11:00
|
||||
|
|
@ -362,10 +365,10 @@ async def test_detect_missing_ranges_optimization() -> None:
|
|||
api_client.async_get_price_info_for_range = AsyncMock(side_effect=mock_fetch)
|
||||
|
||||
# Request entire range - should detect exactly 3 gaps
|
||||
result = await pool.get_intervals(api_client, user_data, start, end)
|
||||
intervals, _api_called = await pool.get_intervals(api_client, user_data, start, end)
|
||||
|
||||
# Assert: Exactly 3 API calls (one per gap)
|
||||
assert api_client.async_get_price_info_for_range.call_count == 3
|
||||
|
||||
# Verify all intervals are now available
|
||||
assert len(result) == 16 # 2 + 2 + 2 + 2 + 1 + 7 = 16 intervals
|
||||
assert len(intervals) == 16 # 2 + 2 + 2 + 2 + 1 + 7 = 16 intervals
|
||||
|
|
|
|||
|
|
@ -541,6 +541,9 @@ class TestPriceComparison:
|
|||
cheap_stats = calculate_window_statistics(cheapest["intervals"])
|
||||
expensive_stats = calculate_window_statistics(most_expensive["intervals"])
|
||||
|
||||
assert cheap_stats["price_mean"] is not None
|
||||
assert expensive_stats["price_mean"] is not None
|
||||
|
||||
spread_cheap_to_exp = expensive_stats["price_mean"] - cheap_stats["price_mean"]
|
||||
spread_exp_to_cheap = cheap_stats["price_mean"] - expensive_stats["price_mean"]
|
||||
|
||||
|
|
@ -582,6 +585,9 @@ class TestPriceComparison:
|
|||
cheap_stats = calculate_window_statistics(cheapest["intervals"])
|
||||
expensive_stats = calculate_window_statistics(most_expensive["intervals"])
|
||||
|
||||
assert cheap_stats["price_mean"] is not None
|
||||
assert expensive_stats["price_mean"] is not None
|
||||
|
||||
spread = expensive_stats["price_mean"] - cheap_stats["price_mean"]
|
||||
assert abs(spread) < 0.0001
|
||||
|
||||
|
|
@ -598,5 +604,8 @@ class TestPriceComparison:
|
|||
cheap_stats = calculate_window_statistics(cheapest["intervals"])
|
||||
expensive_stats = calculate_window_statistics(most_expensive["intervals"])
|
||||
|
||||
assert cheap_stats["price_mean"] is not None
|
||||
assert expensive_stats["price_mean"] is not None
|
||||
|
||||
spread = expensive_stats["price_mean"] - cheap_stats["price_mean"]
|
||||
assert abs(spread) < 0.0001
|
||||
|
|
|
|||
|
|
@ -280,6 +280,7 @@ def test_hysteresis_sequence_simulation() -> None:
|
|||
rating = calculate_rating_level(
|
||||
diff, threshold_low, threshold_high, previous_rating=previous, hysteresis=hysteresis
|
||||
)
|
||||
assert rating is not None
|
||||
results_with.append(rating)
|
||||
previous = rating
|
||||
assert results_with == expected_with_hysteresis
|
||||
|
|
|
|||
|
|
@ -202,7 +202,7 @@ class TestConfigEntryCleanup:
|
|||
coordinator._listener_manager = object.__new__(TibberPricesListenerManager) # noqa: SLF001
|
||||
coordinator._data_transformer = object.__new__(TibberPricesDataTransformer) # noqa: SLF001
|
||||
coordinator._period_calculator = object.__new__(TibberPricesPeriodCalculator) # noqa: SLF001
|
||||
coordinator._lifecycle_callbacks = [] # noqa: SLF001
|
||||
setattr(coordinator, "_lifecycle_callbacks", [])
|
||||
|
||||
# Manually call the registration that happens in __init__
|
||||
# This tests the pattern: entry.async_on_unload(entry.add_update_listener(...))
|
||||
|
|
@ -249,7 +249,7 @@ class TestCacheInvalidation:
|
|||
|
||||
# Create calculator with cached data
|
||||
calculator = object.__new__(TibberPricesPeriodCalculator)
|
||||
calculator._config_cache = {"some": "data"} # noqa: SLF001
|
||||
calculator._config_cache = {"best": {"some": "data"}} # noqa: SLF001
|
||||
calculator._config_cache_valid = True # noqa: SLF001
|
||||
calculator._cached_periods = {"cached": "periods"} # noqa: SLF001
|
||||
calculator._last_periods_hash = "some_hash" # noqa: SLF001
|
||||
|
|
|
|||
Loading…
Reference in a new issue