From e92a7dc466cef81b041740978cc047594fea6b4d Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 11 Apr 2024 17:19:30 -0500 Subject: [PATCH] fix keras/tf dependency specification (#320) --- pyproject.toml | 6 ++++-- scikeras/_saving_utils.py | 5 ++--- tests/conftest.py | 9 ++++----- tests/test_compile_kwargs.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b9f5cafa..f4e3ef1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,8 @@ version = "0.13.0" [tool.poetry.dependencies] python = ">=3.9.0,<4" -scikit-learn = ">=1.4.1.post1" -keras = { git = "https://github.com/keras-team/keras.git", rev = "master" } +scikit-learn = ">=1.4.2" +keras = ">=3.2.0" tensorflow = { version = ">=2.16.1", optional = true } [tool.poetry.extras] @@ -51,6 +51,8 @@ pre-commit = ">=2.20.0" pytest = ">=7.1.2" pytest-cov = ">=3.0.0" sphinx = ">=5.0.2" +tensorflow = ">=2.16.1" +tensorflow-io-gcs-filesystem = { version = "<=0.31.0", markers = "python_version < '3.12' and sys_platform == 'win32'" } [tool.ruff] select = [ diff --git a/scikeras/_saving_utils.py b/scikeras/_saving_utils.py index 7b2b5869..f10ae21d 100644 --- a/scikeras/_saving_utils.py +++ b/scikeras/_saving_utils.py @@ -3,9 +3,8 @@ import keras as keras import keras.saving -import keras.saving.object_registration import numpy as np -from keras.saving.saving_lib import load_model, save_model +from keras.src.saving.saving_lib import load_model, save_model def unpack_keras_model( @@ -25,7 +24,7 @@ def pack_keras_model( """Support for Pythons's Pickle protocol.""" tp = type(model) out = BytesIO() - if tp not in keras.saving.object_registration.GLOBAL_CUSTOM_OBJECTS: + if tp not in keras.saving.get_custom_objects(): module = ".".join(tp.__qualname__.split(".")[:-1]) name = tp.__qualname__.split(".")[-1] keras.saving.register_keras_serializable(module, name)(tp) diff --git a/tests/conftest.py b/tests/conftest.py index 830e27d2..ee03cfac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,7 @@ from typing import TYPE_CHECKING, Iterator import pytest -from keras.backend import config as backend_config -from keras.backend import set_floatx +from keras import backend if TYPE_CHECKING: from _pytest.fixtures import FixtureRequest @@ -12,10 +11,10 @@ @pytest.fixture(autouse=True) def set_floatx_and_backend_config(request: FixtureRequest) -> Iterator[None]: - current = backend_config.floatx() + current = backend.floatx() floatx = getattr(request, "param", "float32") - set_floatx(floatx) + backend.set_floatx(floatx) try: yield finally: - set_floatx(current) + backend.set_floatx(current) diff --git a/tests/test_compile_kwargs.py b/tests/test_compile_kwargs.py index f0b803a6..29cbcb8e 100644 --- a/tests/test_compile_kwargs.py +++ b/tests/test_compile_kwargs.py @@ -5,9 +5,9 @@ from keras import losses as losses_module from keras import metrics as metrics_module from keras import optimizers as optimizers_module -from keras.backend.common.variables import KerasVariable from keras.layers import Dense, Input from keras.models import Model +from keras.src.backend.common.variables import KerasVariable from sklearn.datasets import make_classification from scikeras.wrappers import KerasClassifier