Skip to content

Commit

Permalink
fix keras/tf dependency specification (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Apr 11, 2024
1 parent 390789f commit e92a7dc
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ version = "0.13.0"

[tool.poetry.dependencies]
python = ">=3.9.0,<4"
scikit-learn = ">=1.4.1.post1"
keras = { git = "https://github.com/keras-team/keras.git", rev = "master" }
scikit-learn = ">=1.4.2"
keras = ">=3.2.0"
tensorflow = { version = ">=2.16.1", optional = true }

[tool.poetry.extras]
Expand All @@ -51,6 +51,8 @@ pre-commit = ">=2.20.0"
pytest = ">=7.1.2"
pytest-cov = ">=3.0.0"
sphinx = ">=5.0.2"
tensorflow = ">=2.16.1"
tensorflow-io-gcs-filesystem = { version = "<=0.31.0", markers = "python_version < '3.12' and sys_platform == 'win32'" }

[tool.ruff]
select = [
Expand Down
5 changes: 2 additions & 3 deletions scikeras/_saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import keras as keras
import keras.saving
import keras.saving.object_registration
import numpy as np
from keras.saving.saving_lib import load_model, save_model
from keras.src.saving.saving_lib import load_model, save_model


def unpack_keras_model(
Expand All @@ -25,7 +24,7 @@ def pack_keras_model(
"""Support for Pythons's Pickle protocol."""
tp = type(model)
out = BytesIO()
if tp not in keras.saving.object_registration.GLOBAL_CUSTOM_OBJECTS:
if tp not in keras.saving.get_custom_objects():
module = ".".join(tp.__qualname__.split(".")[:-1])
name = tp.__qualname__.split(".")[-1]
keras.saving.register_keras_serializable(module, name)(tp)
Expand Down
9 changes: 4 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
from typing import TYPE_CHECKING, Iterator

import pytest
from keras.backend import config as backend_config
from keras.backend import set_floatx
from keras import backend

if TYPE_CHECKING:
from _pytest.fixtures import FixtureRequest


@pytest.fixture(autouse=True)
def set_floatx_and_backend_config(request: FixtureRequest) -> Iterator[None]:
current = backend_config.floatx()
current = backend.floatx()
floatx = getattr(request, "param", "float32")
set_floatx(floatx)
backend.set_floatx(floatx)
try:
yield
finally:
set_floatx(current)
backend.set_floatx(current)
2 changes: 1 addition & 1 deletion tests/test_compile_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from keras import losses as losses_module
from keras import metrics as metrics_module
from keras import optimizers as optimizers_module
from keras.backend.common.variables import KerasVariable
from keras.layers import Dense, Input
from keras.models import Model
from keras.src.backend.common.variables import KerasVariable
from sklearn.datasets import make_classification

from scikeras.wrappers import KerasClassifier
Expand Down

0 comments on commit e92a7dc

Please sign in to comment.