Commit

add a lightning jsonargparse exp-patch for edge cases where the currently released patch causes conflicts
speediedan committed Dec 20, 2024
1 parent 690d9f0 commit f698388
Showing 6 changed files with 39 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/gpu-tests.yml
@@ -94,7 +94,7 @@ jobs:
- bash: |
. /tmp/venvs/fts_dev/bin/activate
bash ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0"
bash ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0 1"
displayName: 'Testing: experimental einsum patch'
- bash: |
21 changes: 19 additions & 2 deletions src/fts_examples/patching/dep_patch_shim.py
@@ -53,15 +53,31 @@ def _patch_triton():
    target_mod = 'triton.runtime.jit'
    sys.modules.get(target_mod).__dict__.get('JITFunction').__init__ = _new_init

def _patch_lightning_jsonargparse():
    from fts_examples.patching.patched_lightning_jsonargparse import _updated_parse_known_args_patch
    target_mod = 'lightning.pytorch.cli'
    sys.modules.get(target_mod).__dict__.get('ArgumentParser')._parse_known_args = _updated_parse_known_args_patch

# required for `torch==2.5.x`, TBD wrt subsequent versions
# TODO: remove once `2.6.0` is minimum
# required for `torch==2.5.x`, TBD wrt subsequent versions though appears fixed in torch `2.6.0` nightlies
einsum_strategies_patch = DependencyPatch(
condition=(lwt_compare_version("torch", operator.le, "2.5.2"),
lwt_compare_version("torch", operator.ge, "2.5.0"),),
env_flag=OSEnvToggle("ENABLE_FTS_EINSUM_STRATEGY_PATCH", default="0"),
function=_patch_einsum_strategies, patched_package='torch',
description='Address trivial tp submesh limitation until PyTorch provides upstream fix')

# TODO: remove if lightning fixes `2.5.0` with a post or `2.6.0` is minimum
lightning_jsonargparse_patch = DependencyPatch(
condition=(lwt_compare_version("lightning", operator.eq, "2.5.0"), sys.version_info >= (3, 12, 8),
lwt_compare_version("jsonargparse", operator.ge, "4.35.0") ),
env_flag=OSEnvToggle("ENABLE_FTS_LIGHTNING_JSONARGPARSE_PATCH", default="1"),
function=_patch_lightning_jsonargparse,
patched_package='lightning',
description=('For the edge case where `lightning` patches `jsonargparse` in a manner that breaks '
'certain versions of `jsonargparse`')
)

# TODO: remove once `datasets==2.21.0` is minimum
datasets_numpy_extractor_patch = DependencyPatch(
condition=(lwt_compare_version("numpy", operator.ge, "2.0.0"),
@@ -71,7 +87,7 @@ def _patch_triton():
patched_package='datasets',
description='Adjust `NumpyArrowExtractor` to properly use `numpy` 2.0 copy semantics')

# only required for `torch==2.4.x`
# TODO: remove once `torch 2.5.0` is minimum, only required for `torch==2.4.x`
triton_codgen_patch = DependencyPatch(
condition=(lwt_compare_version("pytorch-triton", operator.eq, "3.0.0", "45fff310c8"),),
env_flag=OSEnvToggle("ENABLE_FTS_TRITON_CODEGEN_PATCH", default="1"),
@@ -82,6 +98,7 @@ class ExpPatch(Enum):
    EINSUM_STRATEGIES = einsum_strategies_patch
    NUMPY_EXTRACTOR = datasets_numpy_extractor_patch
    TRITON_CODEGEN = triton_codgen_patch
    LIGHTNING_JSONARGPARSE = lightning_jsonargparse_patch

_DEFINED_PATCHES = set(ExpPatch)
_ACTIVE_PATCHES = set()
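Note: the `DependencyPatch` activation machinery itself is not part of this diff, so the following is only a minimal sketch of how a patch such as `lightning_jsonargparse_patch` could be gated by its `condition` tuple and `env_flag` before `function` is invoked. The `OSEnvToggle`/`DependencyPatch` shapes and the `_maybe_apply` helper below are hypothetical stand-ins, not the repository's actual implementation.

```python
import os
from typing import Callable, NamedTuple

# Hypothetical stand-ins for illustration; the real definitions live in fts_examples.patching.
class OSEnvToggle(NamedTuple):
    env_var_name: str
    default: str

class DependencyPatch(NamedTuple):
    condition: tuple       # every entry must be truthy for the patch to be eligible
    env_flag: OSEnvToggle  # opt-in/opt-out switch read from the environment
    function: Callable     # callable that performs the monkey-patch
    patched_package: str
    description: str

def _maybe_apply(patch: DependencyPatch) -> bool:
    """Apply `patch.function` only when all conditions hold and the env toggle resolves to '1'."""
    enabled = os.environ.get(patch.env_flag.env_var_name, patch.env_flag.default) == "1"
    if enabled and all(patch.condition):
        patch.function()
        return True
    return False
```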
15 changes: 15 additions & 0 deletions src/fts_examples/patching/patched_lightning_jsonargparse.py
@@ -0,0 +1,15 @@
from fts_examples.patching._patch_utils import _prepare_module_ctx
from lightning.pytorch.cli import LightningCLI # noqa: F401


globals().update(_prepare_module_ctx('lightning.pytorch.cli', globals()))

# we ignore these for the entire file since we're using our global namespace trickeration to patch
# ruff: noqa: F821
# pyright: reportUndefinedVariable=false

def _updated_parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None,
                                    intermixed: bool = False) -> tuple[Any, Any]:
    namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace,
                                                                    intermixed=intermixed)  # type: ignore
    return namespace, args
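Once `_patch_lightning_jsonargparse()` in `dep_patch_shim.py` has run, the `ArgumentParser` class exposed by `lightning.pytorch.cli` delegates `_parse_known_args` to this replacement, which simply forwards to the parent implementation with `intermixed` passed by keyword. A quick sanity-check sketch, assuming the shim has already applied the patch:

```python
import lightning.pytorch.cli as cli_mod

from fts_examples.patching.patched_lightning_jsonargparse import _updated_parse_known_args_patch

# Illustrative check only: after the shim applies the patch, the class attribute
# should point at the replacement function defined above.
assert cli_mod.ArgumentParser._parse_known_args is _updated_parse_known_args_patch, (
    "lightning jsonargparse patch does not appear to be active")
```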
1 change: 1 addition & 0 deletions src/fts_examples/test_examples.py
@@ -56,6 +56,7 @@
"You are using `torch.load` with `weights_only=False`",
# required for datasets <= 2.20.0 with python 3.12
'co_lnotab is deprecated, use co_lines instead.',
"is multi-threaded, use of fork", # expected with some tests that use fork and python >= 3.12.8
]

EXPECTED_WARNS.extend(ALL_EXAMPLE_EXPECTED)
1 change: 1 addition & 0 deletions tests/.experiments
@@ -1,3 +1,4 @@
ENABLE_FTS_EINSUM_STRATEGY_PATCH
ENABLE_FTS_NUMPY_EXTRACTOR_PATCH
ENABLE_FTS_TRITON_CODEGEN_PATCH
ENABLE_FTS_LIGHTNING_JSONARGPARSE_PATCH
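For context on the mask changes in `.azure-pipelines/gpu-tests.yml` and `tests/special_tests.sh`: `--experiment_patch_mask` appears to be positional against the entries in this file, which is why the fourth entry requires a fourth mask digit. A rough illustration of that mapping, assuming digits are consumed in file order (the actual parsing in `special_tests.sh` is not shown in this diff):

```python
# Illustrative only: pair each tests/.experiments entry with its mask digit
# to see what an --experiment_patch_mask of "1 0 0 1" would enable.
experiments = [
    "ENABLE_FTS_EINSUM_STRATEGY_PATCH",
    "ENABLE_FTS_NUMPY_EXTRACTOR_PATCH",
    "ENABLE_FTS_TRITON_CODEGEN_PATCH",
    "ENABLE_FTS_LIGHTNING_JSONARGPARSE_PATCH",
]
mask = "1 0 0 1".split()
enabled = [name for name, digit in zip(experiments, mask) if digit == "1"]
print(enabled)  # ['ENABLE_FTS_EINSUM_STRATEGY_PATCH', 'ENABLE_FTS_LIGHTNING_JSONARGPARSE_PATCH']
```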
4 changes: 2 additions & 2 deletions tests/special_tests.sh
@@ -46,9 +46,9 @@ Usage: $0
# ./tests/special_tests.sh --mark_type=standalone --log_file=/tmp/some_parent_process_file_to_append_to.log
# run all experimental tests following a pattern that are supported by a given experimental patch mask using the
# default `tests/.experiments` experiments definition location:
# ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0"
# ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0 1"
# same as above, but use a custom experiments definition location:
# ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='model_parallel' --experiments_list=tests/.my_experiments --experiment_patch_mask="1 0 0"
# ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='model_parallel' --experiments_list=tests/.my_experiments --experiment_patch_mask="1 0 0 1"
EOF
exit 1
}
