From cb62be6b2aedfdd4c89e14b318419d19c72f7f9c Mon Sep 17 00:00:00 2001 From: Tom Vercauteren Date: Thu, 26 Sep 2024 22:49:02 +0100 Subject: [PATCH] bumping python and pytorch version in CI --- .github/workflows/python-package.yml | 4 ++-- torchsparsegradutils/indexed_matmul.py | 6 ++++++ torchsparsegradutils/tests/test_indexed_matmul.py | 5 +++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 6164b04..f11e4f6 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -12,8 +12,8 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.1", "2.0.1"] + python-version: ["3.8", "3.10", "3.12"] + torch-version: ["1.13.1", "2.4.1"] steps: - uses: actions/checkout@v3 diff --git a/torchsparsegradutils/indexed_matmul.py b/torchsparsegradutils/indexed_matmul.py index 298fd6b..f6dc153 100644 --- a/torchsparsegradutils/indexed_matmul.py +++ b/torchsparsegradutils/indexed_matmul.py @@ -27,6 +27,9 @@ def segment_mm(a, b, seglen_a): Returns: torch.Tensor: The output dense matrix of shape ``(N, D2)`` """ + if torch.__version__ < (2, 4): + raise NotImplementedError("PyTorch version is too old for nested tesors") + if dgl_installed: # DGL is probably more computationally efficient # See https://github.com/pytorch/pytorch/issues/136747 @@ -74,6 +77,9 @@ def gather_mm(a, b, idx_b): Returns: torch.Tensor: The output dense matrix of shape ``(N, D2)`` """ + if torch.__version__ < (2, 4): + raise NotImplementedError("PyTorch version is too old for nested tesors") + if dgl_installed: # DGL is more computationally efficient # See https://github.com/pytorch/pytorch/issues/136747 diff --git a/torchsparsegradutils/tests/test_indexed_matmul.py b/torchsparsegradutils/tests/test_indexed_matmul.py index 491f143..7c9a527 100644 --- a/torchsparsegradutils/tests/test_indexed_matmul.py +++ b/torchsparsegradutils/tests/test_indexed_matmul.py @@ -1,6 +1,11 @@ import torch import pytest +if torch.__version__ < (2, 4): + pytest.skip( + "Skipping test based on nested tensors since an old version of pytorch is used", allow_module_level=True + ) + from torchsparsegradutils import gather_mm, segment_mm # Identify Testing Parameters