Skip to content

Commit

Permalink
Add compatibility for numpy 2 while preserving numpy 1 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
neilmehta24 authored and rlouf committed Nov 27, 2024
1 parent 7a9baad commit 63b4feb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
23 changes: 17 additions & 6 deletions outlines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,23 @@
from typing import Callable, Optional

import numpy as np
from numpy.lib.function_base import (
_calculate_shapes,
_parse_gufunc_signature,
_parse_input_dimensions,
_update_dim_sizes,
)

# Import required functions based on NumPy version
np_major_version = int(np.__version__.split(".")[0])
if np_major_version >= 2:
from numpy.lib._function_base_impl import (
_calculate_shapes,
_parse_gufunc_signature,
_parse_input_dimensions,
_update_dim_sizes,
)
else:
from numpy.lib.function_base import (
_calculate_shapes,
_parse_gufunc_signature,
_parse_input_dimensions,
_update_dim_sizes,
)

# Allow nested loops for running in notebook. We don't enable it globally as it
# may interfere with other libraries that use asyncio.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"jinja2",
"lark",
"nest_asyncio",
"numpy<2.0.0",
"numpy",
"cloudpickle",
"diskcache",
"pydantic>=2.0",
Expand Down

0 comments on commit 63b4feb

Please sign in to comment.