diff --git a/ci/notebook_list.py b/ci/notebook_list.py index b3a235f146..55fdd33735 100644 --- a/ci/notebook_list.py +++ b/ci/notebook_list.py @@ -31,6 +31,16 @@ def skip_book_dir(runtype): return runtype in runtype_dict and Path(runtype_dict.get(runtype)).is_file() +def _get_cuda_version_string(): + status, version = runtime.getLocalRuntimeVersion() + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA runtime version.") + major, minor = divmod(version, 1000) + minor //= 10 + return f"{major}.{minor}" + +cuda_version_string = _get_cuda_version_string() + parser = argparse.ArgumentParser(description="Condition for running the notebook tests") parser.add_argument("runtype", type=str)