Skip to content

Commit

Permalink
Merge pull request #9 from paul-ollis/xdist
Browse files Browse the repository at this point in the history
Add support for pytest-xdist for **much faster** Textual tests.
  • Loading branch information
willmcgugan authored Jul 22, 2024
2 parents 380386c + 4a9e1f7 commit 7222778
Showing 1 changed file with 150 additions and 45 deletions.
195 changes: 150 additions & 45 deletions pytest_textual_snapshot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

import os
import pickle
import re
import shutil
from dataclasses import dataclass
from datetime import datetime
from operator import attrgetter
from os import PathLike
from pathlib import Path, PurePath
from tempfile import mkdtemp
from typing import Awaitable, Union, List, Optional, Callable, Iterable, TYPE_CHECKING

import pytest
Expand All @@ -16,14 +20,61 @@
from jinja2 import Template
from rich.console import Console
from syrupy import SnapshotAssertion
from syrupy.extensions.single_file import (
SingleFileSnapshotExtension, WriteMode)

if TYPE_CHECKING:
from _pytest.nodes import Item
from textual.app import App
from textual.pilot import Pilot

TEXTUAL_SNAPSHOT_SVG_KEY = pytest.StashKey[str]()
TEXTUAL_ACTUAL_SVG_KEY = pytest.StashKey[str]()
TEXTUAL_SNAPSHOT_PASS = pytest.StashKey[bool]()

class SVGImageExtension(SingleFileSnapshotExtension):
_file_extension = "svg"
_write_mode = WriteMode.TEXT


class TemporaryDirectory:
"""A temporary that survives forking.
This provides something akin to tempfile.TemporaryDirectory, but this
version is not removed automatically when a process exits.
"""

def __init__(self, name: str = ''):
if name:
self.name = name
else:
self.name = mkdtemp(None, None, None)

def cleanup(self):
"""Clean up the temporary directory."""
shutil.rmtree(self.name, ignore_errors=True)


@dataclass
class PseudoConsole:
"""Something that looks enough like a Console to fill a Jinja2 template."""

legacy_windows: bool
size: ConsoleDimensions


@dataclass
class PseudoApp:
"""Something that looks enough like an App to fill a Jinja2 template.
This can be pickled OK, whereas the 'real' application involved in a test
may contain unpickleable data.
"""

console: PseudoConsole


def rename_styles(svg: str, suffix: str) -> str:
"""Rename style names to prevent clashes when combined in HTML report."""
return re.sub(
r'terminal-(\d+)-r(\d+)', rf'terminal-\1-r\2-{suffix}', svg)


def pytest_addoption(parser):
Expand All @@ -39,6 +90,24 @@ def app_stash_key() -> pytest.StashKey:
app_stash_key._key = pytest.StashKey[App]()
return app_stash_key()


def node_to_report_path(node: Item) -> Path:
"""Generate a report file name for a test node."""
tempdir = get_tempdir()
path, _, name = node.reportinfo()
temp = Path(path.parent)
base = []
while temp != temp.parent and temp.name != 'tests':
base.append(temp.name)
temp = temp.parent
parts = []
if base:
parts.append('_'.join(reversed(base)))
parts.append(path.name.replace('.', '_'))
parts.append(name.replace('[', '_').replace(']', '_'))
return Path(tempdir.name) / '_'.join(parts)


@pytest.fixture
def snap_compare(
snapshot: SnapshotAssertion, request: FixtureRequest
Expand All @@ -48,6 +117,8 @@ def snap_compare(
app with the output of the same app in the past. This is snapshot testing, and it
used to catch regressions in output.
"""
# Switch so one file per snapshot, stored as plain simple SVG file.
snapshot = snapshot.use_extension(SVGImageExtension)

def compare(
app_path: str | PurePath,
Expand Down Expand Up @@ -93,17 +164,18 @@ def compare(
terminal_size=terminal_size,
run_before=run_before,
)
console = Console(legacy_windows=False, force_terminal=True)
p_app = PseudoApp(PseudoConsole(console.legacy_windows, console.size))

result = snapshot == actual_screenshot
expected_svg_text = str(snapshot)
full_path, line_number, name = request.node.reportinfo()

if result is False:
# The split and join below is a mad hack, sorry...
node.stash[TEXTUAL_SNAPSHOT_SVG_KEY] = "\n".join(
str(snapshot).splitlines()[1:-1]
)
node.stash[TEXTUAL_ACTUAL_SVG_KEY] = actual_screenshot
node.stash[app_stash_key()] = app
else:
node.stash[TEXTUAL_SNAPSHOT_PASS] = True
data = (
result, expected_svg_text, actual_screenshot, p_app, full_path,
line_number, name)
data_path = node_to_report_path(request.node)
data_path.write_bytes(pickle.dumps(data))

return result

Expand All @@ -125,37 +197,69 @@ class SvgSnapshotDiff:
environment: dict


def pytest_sessionstart(
session: Session,
) -> None:
"""Set up a temporary directory to store snapshots.
The temporary directory name is stored in an environment vairable so that
pytest-xdist worker child processes can retrieve it.
"""
if os.environ.get('PYTEST_XDIST_WORKER') is None:
tempdir = TemporaryDirectory()
os.environ['TEXTUAL_SNAPSHOT_TEMPDIR'] = tempdir.name


def get_tempdir():
"""Get the TemporaryDirectory."""
return TemporaryDirectory(os.environ['TEXTUAL_SNAPSHOT_TEMPDIR'])


def pytest_sessionfinish(
session: Session,
exitstatus: Union[int, ExitCode],
) -> None:
"""Called after whole test run finished, right before returning the exit status to the system.
Generates the snapshot report and writes it to disk.
"""
diffs: List[SvgSnapshotDiff] = []
num_snapshots_passing = 0

for item in session.items:
# Grab the data our fixture attached to the pytest node
num_snapshots_passing += int(item.stash.get(TEXTUAL_SNAPSHOT_PASS, False))
snapshot_svg = item.stash.get(TEXTUAL_SNAPSHOT_SVG_KEY, None)
actual_svg = item.stash.get(TEXTUAL_ACTUAL_SVG_KEY, None)
app = item.stash.get(app_stash_key(), None)

if app:
path, line_index, name = item.reportinfo()
diffs.append(
SvgSnapshotDiff(
snapshot=str(snapshot_svg),
actual=str(actual_svg),
test_name=name,
path=path,
line_number=line_index + 1,
app=app,
environment=dict(os.environ),
)
)
if os.environ.get('PYTEST_XDIST_WORKER') is None:
tempdir = get_tempdir()
diffs, num_snapshots_passing = retrieve_svg_diffs(tempdir)
save_svg_diffs(diffs, session, num_snapshots_passing)
tempdir.cleanup()


def retrieve_svg_diffs(
tempdir: TemporaryDirectory,
) -> tuple[list[SvgSnapshotDiff], int]:
"""Retrieve snapshot diffs from the temporary directory."""
diffs: list[SvgSnapshotDiff] = []
pass_count = 0

n = 0
for data_path in Path(tempdir.name).iterdir():
(passed, expect_svg_text, svg_text, app, full_path, line_index, name
) = pickle.loads(data_path.read_bytes())
pass_count += 1 if passed else 0
if not passed:
n += 1
diffs.append(SvgSnapshotDiff(
snapshot=rename_styles(str(expect_svg_text), f'exp{n}'),
actual=rename_styles(svg_text, f'act{n}'),
test_name=name,
path=full_path,
line_number=line_index + 1,
app=app,
environment=dict(os.environ)))
return diffs, pass_count


def save_svg_diffs(
diffs: list[SvgSnapshotDiff],
session: Session,
num_snapshots_passing: int,
) -> None:
"""Save any detected differences to an HTML formatted report."""
if diffs:
diff_sort_key = attrgetter("test_name")
diffs = sorted(diffs, key=diff_sort_key)
Expand Down Expand Up @@ -198,13 +302,14 @@ def pytest_terminal_summary(
"""Add a section to terminal summary reporting.
Displays the link to the snapshot report that was generated in a prior hook.
"""
diffs = getattr(config, "_textual_snapshots", None)
console = Console(legacy_windows=False, force_terminal=True)
if diffs:
snapshot_report_location = config._textual_snapshot_html_report
console.print("[b red]Textual Snapshot Report", style="red")
console.print(
f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n"
)
console.print(f"[dim]{snapshot_report_location}\n")
if os.environ.get('PYTEST_XDIST_WORKER') is None:
diffs = getattr(config, "_textual_snapshots", None)
console = Console(legacy_windows=False, force_terminal=True)
if diffs:
snapshot_report_location = config._textual_snapshot_html_report
console.print("[b red]Textual Snapshot Report", style="red")
console.print(
f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n"
)
console.print(f"[dim]{snapshot_report_location}\n")

0 comments on commit 7222778

Please sign in to comment.