from __future__ import annotations
import dataclasses
from collections.abc import Callable
from functools import wraps
from pathlib import Path
from typing import Any, TypeVar
import numpy.testing as npt
import pytest
from .compare import DEFAULT_ATOL, DEFAULT_RTOL, compare_intelligent
from .io import (
SENTINEL,
_cache_failed_test,
_store_test_diff,
_uncache_test,
read_snapshot,
snapshot_directory,
snapshot_filename,
write_snapshot,
)
F = TypeVar("F", bound=Callable[..., Any])


def auto_update(method: F) -> F:
    """
    Decorator that handles snapshot updates and comparisons for testing functions.

    The returned wrapper is called as ``wrapper(snapshot, value, *args, **kwargs)``.
    When a stored snapshot exists it runs
    ``method(value, snapshot.expected, *args, **kwargs)``.  In update mode
    (``snapshot.snaptol_update``) the snapshot file is rewritten with ``value``
    regardless of the comparison outcome.

    Parameters
    ----------
    method
        The testing function to wrap.  A mismatch is expected to be signalled by
        raising ``AssertionError`` (as the ``numpy.testing`` assertions do).

    Returns
    -------
    The wrapper, which returns ``True`` whenever it does not raise.

    Raises
    ------
    AssertionError
        If ``method`` reports a mismatch and ``snaptol_update`` is ``False``.
    TypeError
        If ``method`` raises ``TypeError`` and ``snaptol_update`` is ``False``.
    FileNotFoundError
        If the snapshot is not found and ``snaptol_update`` is ``False``.
    """

    @wraps(method)
    def wrapper(snapshot: Snapshot, value: Any, *args: Any, **kwargs: Any) -> bool:
        # Do the comparison and keep any exception so it can be re-raised only
        # after the bookkeeping below (snapshot update, diff, failure cache).
        comparison_matched = False
        caught_exception: Exception | None = None
        problem_found = False

        if snapshot.snapshot_found:
            try:
                method(value, snapshot.expected, *args, **kwargs)
                comparison_matched = True
            except (AssertionError, TypeError) as exc:
                # A mismatch is only a problem if we are NOT in update mode.
                caught_exception = exc
                problem_found = not snapshot.snaptol_update
        elif not snapshot.snaptol_update:
            # In update mode a missing snapshot is expected; otherwise it fails.
            caught_exception = FileNotFoundError("Snapshot file not found.")
            problem_found = True

        if snapshot.snaptol_update:
            write_snapshot(snapshot.snapshot_file, value)
            _uncache_test(snapshot.cache, snapshot.nodeid)

        # Show a diff if requested and if a difference exists.
        if snapshot.show_diff and not comparison_matched:
            _store_test_diff(
                snapshot.config,
                snapshot.snapshot_file,
                before=snapshot.expected if snapshot.snapshot_found else SENTINEL,
                after=value,
            )

        if problem_found:
            _cache_failed_test(
                snapshot.cache, snapshot.nodeid, snapshot.snapshot_file, value
            )
            # Every branch that sets problem_found also sets caught_exception.
            raise caught_exception from None
        return True

    return wrapper


@dataclasses.dataclass
class Snapshot:
    """
    Stored test value that is compared with ``==`` and can be updated in place.

    Instances are normally built with :meth:`from_request`.  During
    construction ``__post_init__`` reads ``snapshot_file`` into ``expected``
    and records in ``snapshot_found`` whether the file existed.
    """

    # pytest node id of the owning test; passed to the cache helpers and hashed.
    nodeid: str
    # File the snapshot value is read from and (in update mode) written to.
    snapshot_file: Path
    # Snapshot directory derived from the test module's directory.
    snapshot_dir: Path
    # When True, comparisons never fail and the snapshot file is rewritten.
    snaptol_update: bool = False
    # Overwritten by __post_init__ depending on whether the file could be read.
    snapshot_found: bool = False
    # When True, a diff is stored for every comparison that did not match.
    show_diff: bool = False
    # Tolerances / NaN handling forwarded to compare_intelligent.
    rtol: float = DEFAULT_RTOL
    atol: float = DEFAULT_ATOL
    equal_nan: bool = False
    # Stored snapshot value (None when no file exists); set in __post_init__.
    expected: Any = dataclasses.field(init=False, repr=False)
    cache: pytest.Cache | None = None
    config: pytest.Config | None = None

    @classmethod
    def from_request(cls, request: pytest.FixtureRequest) -> Snapshot:
        """
        Create a ``Snapshot`` instance from a pytest request object.

        Parameters
        ----------
        request
            The pytest request fixture containing test information.

        Returns
        -------
        Snapshot
            The instantiated ``Snapshot``; its stored value is loaded from disk
            during construction.
        """
        config = request.config
        nodeid = request.node.nodeid
        # Snapshot locations are derived from the test module's directory.
        test_dir = Path(request.fspath).parent
        return cls(
            nodeid=nodeid,
            snapshot_file=snapshot_filename(nodeid, test_dir=test_dir),
            snapshot_dir=snapshot_directory(test_dir=test_dir),
            snaptol_update=(
                config.getoption("--snaptol-update")
                or config.getoption("--snaptol-update-all")
            ),
            show_diff=config.getoption("--snaptol-show-diff"),
            cache=config.cache,
            config=config,
        )

    def __post_init__(self) -> None:
        # Load the stored value; a missing file just means "no snapshot yet".
        try:
            self.expected = read_snapshot(self.snapshot_file)
            self.snapshot_found = True
        except FileNotFoundError:
            self.expected = None
            self.snapshot_found = False

    def __eq__(self, value: Any) -> bool:
        """
        Compare ``value`` with the stored snapshot (``assert snapshot == value``).

        In update mode the snapshot file is rewritten with ``value`` and ``True``
        is returned.  Otherwise returns whether ``compare_intelligent`` reports
        a match; a mismatch is also recorded via ``_cache_failed_test``.

        Raises
        ------
        TypeError
            If ``compare_intelligent`` raises it and not in update mode.
        FileNotFoundError
            If no snapshot exists and not in update mode.
        """
        # Do the comparison and keep any exception so it can be re-raised only
        # after the bookkeeping below (snapshot update, diff, failure cache).
        comparison_matched = False
        caught_exception = None
        problem_found = False

        if self.snapshot_found:
            try:
                comparison_matched = compare_intelligent(
                    value, self.expected, self.rtol, self.atol, self.equal_nan
                )
                # A mismatch is only a problem if we are NOT in update mode.
                if not comparison_matched:
                    problem_found = not self.snaptol_update
            except TypeError as exc:
                caught_exception = exc
                problem_found = not self.snaptol_update
        elif not self.snaptol_update:
            # In update mode a missing snapshot is expected; otherwise it fails.
            caught_exception = FileNotFoundError("Snapshot file not found.")
            problem_found = True

        if self.snaptol_update:
            write_snapshot(self.snapshot_file, value)
            _uncache_test(self.cache, self.nodeid)

        # Show a diff if requested and if a difference exists.
        if self.show_diff and not comparison_matched:
            _store_test_diff(
                self.config,
                self.snapshot_file,
                before=self.expected if self.snapshot_found else SENTINEL,
                after=value,
            )

        if problem_found:
            _cache_failed_test(self.cache, self.nodeid, self.snapshot_file, value)
            if caught_exception is not None:
                raise caught_exception from None
            return False
        return True

    def __hash__(self):
        # Explicit __hash__ keeps instances hashable despite the custom __eq__.
        return hash(self.nodeid)

    def __call__(
        self,
        *,
        rtol: float = DEFAULT_RTOL,
        atol: float = DEFAULT_ATOL,
        equal_nan: bool | None = None,
    ) -> Snapshot:
        """
        Return a copy of this snapshot using the given comparison settings.

        ``equal_nan=None`` keeps the current ``equal_nan`` setting.  The copy is
        built with ``dataclasses.replace``, which runs ``__post_init__`` again
        and therefore re-reads the snapshot file.
        """
        changes: dict[str, Any] = {"rtol": rtol, "atol": atol}
        if equal_nan is not None:
            changes["equal_nan"] = equal_nan
        return dataclasses.replace(self, **changes)

    def match(
        self,
        value,
        *,
        rtol: float = DEFAULT_RTOL,
        atol: float = DEFAULT_ATOL,
        equal_nan: bool | None = None,
    ) -> bool:
        """
        Compare a value with the stored snapshot.

        Returns ``True`` if the values match (or when in update mode), ``False``
        otherwise.  May raise like ``__eq__`` (``TypeError``,
        ``FileNotFoundError``) when not in update mode.

        Parameters
        ----------
        value
            The value to compare with the snapshot.
        rtol
            Relative tolerance for comparison.
        atol
            Absolute tolerance for comparison.
        equal_nan
            Whether NaNs compare equal; ``None`` keeps the current setting.
        """
        return self(rtol=rtol, atol=atol, equal_nan=equal_nan) == value

    def matches(self, *args, **kwargs) -> bool:
        """
        Alias for match() method. Compare a value with the stored snapshot.
        """
        return self.match(*args, **kwargs)


# Snapshot-aware counterparts of the ``numpy.testing`` assertions.  Each one is
# called as ``fn(snapshot, value, *args, **kwargs)`` and, when a snapshot is
# stored, checks ``value`` against ``snapshot.expected`` via ``auto_update``.
(
    assert_allclose,
    assert_array_almost_equal_nulp,
    assert_array_max_ulp,
    assert_array_equal,
    assert_equal,
    assert_string_equal,
) = map(
    auto_update,
    (
        npt.assert_allclose,
        npt.assert_array_almost_equal_nulp,
        npt.assert_array_max_ulp,
        npt.assert_array_equal,
        npt.assert_equal,
        npt.assert_string_equal,
    ),
)