From f698388167dc59202786f6d6380c307b91aed674 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Fri, 20 Dec 2024 08:47:32 -0800 Subject: [PATCH] add a lightning jsonargparse epi-patch for edge cases where the currently released patch causes conflicts --- .azure-pipelines/gpu-tests.yml | 2 +- src/fts_examples/patching/dep_patch_shim.py | 21 +++++++++++++++++-- .../patched_lightning_jsonargparse.py | 15 +++++++++++++ src/fts_examples/test_examples.py | 1 + tests/.experiments | 1 + tests/special_tests.sh | 4 ++-- 6 files changed, 39 insertions(+), 5 deletions(-) create mode 100644 src/fts_examples/patching/patched_lightning_jsonargparse.py diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index e490b99..650a755 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -94,7 +94,7 @@ jobs: - bash: | . /tmp/venvs/fts_dev/bin/activate - bash ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0" + bash ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0 1" displayName: 'Testing: experimental einsum patch' - bash: | diff --git a/src/fts_examples/patching/dep_patch_shim.py b/src/fts_examples/patching/dep_patch_shim.py index e2d4f76..8f23287 100644 --- a/src/fts_examples/patching/dep_patch_shim.py +++ b/src/fts_examples/patching/dep_patch_shim.py @@ -53,8 +53,13 @@ def _patch_triton(): target_mod = 'triton.runtime.jit' sys.modules.get(target_mod).__dict__.get('JITFunction').__init__ = _new_init +def _patch_lightning_jsonargparse(): + from fts_examples.patching.patched_lightning_jsonargparse import _updated_parse_known_args_patch + target_mod = 'lightning.pytorch.cli' + sys.modules.get(target_mod).__dict__.get('ArgumentParser')._parse_known_args = _updated_parse_known_args_patch -# required for `torch==2.5.x`, TBD wrt subsequent versions +# TODO: remove once `2.6.0` is minimum +# required for `torch==2.5.x`, TBD wrt subsequent versions though appears fixed in torch `2.6.0` nightlies einsum_strategies_patch = DependencyPatch( condition=(lwt_compare_version("torch", operator.le, "2.5.2"), lwt_compare_version("torch", operator.ge, "2.5.0"),), @@ -62,6 +67,17 @@ def _patch_triton(): function=_patch_einsum_strategies, patched_package='torch', description='Address trivial tp submesh limitation until PyTorch provides upstream fix') +# TODO: remove if lightning fixes `2.5.0` with a post or `2.6.0` is minimum +lightning_jsonargparse_patch = DependencyPatch( + condition=(lwt_compare_version("lightning", operator.eq, "2.5.0"), sys.version_info >= (3, 12, 8), + lwt_compare_version("jsonargparse", operator.ge, "4.35.0") ), + env_flag=OSEnvToggle("ENABLE_FTS_LIGHTNING_JSONARGPARSE_PATCH", default="1"), + function=_patch_lightning_jsonargparse, + patched_package='lightning', + description=('For the edge case where `lightning` patches `jsonargparse` in a manner that breaks ' + 'certain versions of `jsonargparse`') +) + # TODO: remove once `datasets==2.21.0` is minimum datasets_numpy_extractor_patch = DependencyPatch( condition=(lwt_compare_version("numpy", operator.ge, "2.0.0"), @@ -71,7 +87,7 @@ def _patch_triton(): patched_package='datasets', description='Adjust `NumpyArrowExtractor` to properly use `numpy` 2.0 copy semantics') -# only required for `torch==2.4.x` +# TODO: remove once `torch 2.5.0` is minimum, only required for `torch==2.4.x` triton_codgen_patch = DependencyPatch( condition=(lwt_compare_version("pytorch-triton", operator.eq, "3.0.0", "45fff310c8"),), env_flag=OSEnvToggle("ENABLE_FTS_TRITON_CODEGEN_PATCH", default="1"), @@ -82,6 +98,7 @@ class ExpPatch(Enum): EINSUM_STRATEGIES = einsum_strategies_patch NUMPY_EXTRACTOR = datasets_numpy_extractor_patch TRITON_CODEGEN = triton_codgen_patch + LIGHTNING_JSONARGPARSE = lightning_jsonargparse_patch _DEFINED_PATCHES = set(ExpPatch) _ACTIVE_PATCHES = set() diff --git a/src/fts_examples/patching/patched_lightning_jsonargparse.py b/src/fts_examples/patching/patched_lightning_jsonargparse.py new file mode 100644 index 0000000..42746ca --- /dev/null +++ b/src/fts_examples/patching/patched_lightning_jsonargparse.py @@ -0,0 +1,15 @@ +from fts_examples.patching._patch_utils import _prepare_module_ctx +from lightning.pytorch.cli import LightningCLI # noqa: F401 + + +globals().update(_prepare_module_ctx('lightning.pytorch.cli', globals())) + +# we ignore these for the entire file since we're using our global namespace trickeration to patch +# ruff: noqa: F821 +# pyright: reportUndefinedVariable=false + +def _updated_parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None, + intermixed: bool = False) -> tuple[Any, Any]: + namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, + intermixed=intermixed) # type: ignore + return namespace, args diff --git a/src/fts_examples/test_examples.py b/src/fts_examples/test_examples.py index b70b8d1..f3bc90b 100644 --- a/src/fts_examples/test_examples.py +++ b/src/fts_examples/test_examples.py @@ -56,6 +56,7 @@ "You are using `torch.load` with `weights_only=False`", # required for datasets <= 2.20.0 with python 3.12 'co_lnotab is deprecated, use co_lines instead.', + "is multi-threaded, use of fork", # expected with some tests that use fork and python >= 3.12.8 ] EXPECTED_WARNS.extend(ALL_EXAMPLE_EXPECTED) diff --git a/tests/.experiments b/tests/.experiments index 1b9f259..aa3ff2f 100644 --- a/tests/.experiments +++ b/tests/.experiments @@ -1,3 +1,4 @@ ENABLE_FTS_EINSUM_STRATEGY_PATCH ENABLE_FTS_NUMPY_EXTRACTOR_PATCH ENABLE_FTS_TRITON_CODEGEN_PATCH +ENABLE_FTS_LIGHTNING_JSONARGPARSE_PATCH diff --git a/tests/special_tests.sh b/tests/special_tests.sh index b81d9cb..6b050e1 100755 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -46,9 +46,9 @@ Usage: $0 # ./tests/special_tests.sh --mark_type=standalone --log_file=/tmp/some_parent_process_file_to_append_to.log # run all experimental tests following a pattern that are supported by a given experimental patch mask using the # default `tests/.experiments` experiments definition location: - # ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0" + # ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0 1" # same as above, but use a custom experiments definition location: - # ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='model_parallel' --experiments_list=tests/.my_experiments --experiment_patch_mask="1 0 0" + # ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='model_parallel' --experiments_list=tests/.my_experiments --experiment_patch_mask="1 0 0 1" EOF exit 1 }