Skip to content

Commit

Permalink
Lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
syleshfb committed Jul 22, 2024
1 parent 38d4e24 commit 935e827
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/fairseq2/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
get_last_failed_example as get_last_failed_example,
)
from fairseq2.data.data_pipeline import list_files as list_files
from fairseq2.data.data_pipeline import read_sequence as read_sequence
from fairseq2.data.data_pipeline import read_iterator as read_iterator
from fairseq2.data.data_pipeline import read_sequence as read_sequence
from fairseq2.data.data_pipeline import read_zipped_records as read_zipped_records
from fairseq2.data.memory import MemoryBlock as MemoryBlock
from fairseq2.data.vocabulary_info import VocabularyInfo as VocabularyInfo
12 changes: 10 additions & 2 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
Optional,
Sequence,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
final,
)
Expand Down Expand Up @@ -386,7 +388,13 @@ def read_zipped_records(path: Path) -> DataPipelineBuilder:
"""Read each file in a zip archive"""
...

def read_iterator(iterator: Iterator[Any], reset_fn: Callable[[Iterator], None], infinite: bool) -> DataPipelineBuilder:
T = TypeVar("T", bound=Iterator[Any])

def read_iterator(
iterator: T,
reset_fn: Callable[[T], None],
infinite: bool,
) -> DataPipelineBuilder:
"""Read each element of ``iterator``.
:param iterator:
Expand Down Expand Up @@ -536,8 +544,8 @@ class RecordError(RuntimeError):
get_last_failed_example as get_last_failed_example,
)
from fairseq2n.bindings.data.data_pipeline import list_files as list_files
from fairseq2n.bindings.data.data_pipeline import read_sequence as read_sequence
from fairseq2n.bindings.data.data_pipeline import read_iterator as read_iterator
from fairseq2n.bindings.data.data_pipeline import read_sequence as read_sequence
from fairseq2n.bindings.data.data_pipeline import (
read_zipped_records as read_zipped_records,
)
Expand Down
29 changes: 12 additions & 17 deletions tests/unit/data/data_pipeline/test_read_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,44 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterator, Union

import pytest
from typing_extensions import Self

from fairseq2.data import DataPipelineError, read_iterator, read_sequence


class DefaultIterator:
def __init__(self):
self.i = 0

def reset(self):
class DefaultIterator(Iterator[int]):
def __init__(self) -> None:
self.i = 0

def __iter__(self):
self.reset()
def __iter__(self) -> Self:
return self

def __next__(self):
def __next__(self) -> int:
ret = self.i
self.i += 1
return ret


class EarlyStopIterator:
def __init__(self, n):
class EarlyStopIterator(Iterator[int]):
def __init__(self, n: int):
self.i = 0
self.n = n

def reset(self):
self.i = 0

def __iter__(self):
self.reset()
def __iter__(self) -> Self:
return self

def __next__(self):
def __next__(self) -> int:
ret = self.i
if ret >= self.n:
raise StopIteration
self.i += 1
return ret


def reset_fn(iterator):
def reset_fn(iterator: Union[DefaultIterator, EarlyStopIterator]) -> None:
iterator.i = 0


Expand Down

0 comments on commit 935e827

Please sign in to comment.