Skip to content

Commit

Permalink
bugfix: Mutable types in DataClass
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed May 21, 2024
1 parent 830e16e commit dc23009
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions plugins/extract/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class ExtractorBatch:
image: list[np.ndarray] = field(default_factory=list)
detected_faces: Sequence[DetectedFace | list[DetectedFace]] = field(default_factory=list)
filename: list[str] = field(default_factory=list)
feed: np.ndarray = np.array([])
prediction: np.ndarray = np.array([])
feed: np.ndarray = field(default_factory=lambda: np.array([]))
prediction: np.ndarray = field(default_factory=lambda: np.array([]))
data: list[dict[str, T.Any]] = field(default_factory=list)

def __repr__(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions plugins/extract/align/_base/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ class AlignerBatch(ExtractorBatch):
"""
batch_id: int = 0
detected_faces: list[DetectedFace] = field(default_factory=list)
landmarks: np.ndarray = np.array([])
landmarks: np.ndarray = field(default_factory=lambda: np.array([]))
refeeds: list[np.ndarray] = field(default_factory=list)
second_pass: bool = False
second_pass_masks: np.ndarray = np.array([])
second_pass_masks: np.ndarray = field(default_factory=lambda: np.array([]))

def __repr__(self):
""" Prettier repr for debug printing """
Expand Down
2 changes: 1 addition & 1 deletion plugins/extract/detect/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class DetectorBatch(ExtractorBatch):
rotation_matrix: list[np.ndarray] = field(default_factory=list)
scale: list[float] = field(default_factory=list)
pad: list[tuple[int, int]] = field(default_factory=list)
initial_feed: np.ndarray = np.array([])
initial_feed: np.ndarray = field(default_factory=lambda: np.array([]))

def __repr__(self):
""" Prettier repr for debug printing """
Expand Down

0 comments on commit dc23009

Please sign in to comment.