Merge remote-tracking branch 'upstream/main' into fix-numpy-column-projection
rjzamora committed Jan 9, 2025
2 parents d677414 + ecb0cdd commit c7b2552
Showing 24 changed files with 147 additions and 785 deletions.
88 changes: 15 additions & 73 deletions dask/_task_spec.py
@@ -195,13 +195,11 @@ def __contains__(self, o: object) -> bool:
SubgraphType = None


def _execute_subgraph(inner_dsk, outkey, inkeys, external_deps):
def _execute_subgraph(inner_dsk, outkey, inkeys):
final = {}
final.update(inner_dsk)
for k, v in inkeys.items():
final[k] = DataNode(None, v)
for k, v in external_deps.items():
final[k] = DataNode(None, v)
res = execute_graph(final, keys=[outkey])
return res[outkey]
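
The `external_deps` parameter is gone: everything a subgraph needs from outside now arrives through the single `inkeys` mapping and is materialized as `DataNode`s. A minimal sketch of the resulting behavior, assuming the `dask._task_spec` internals referenced in this hunk (`Task`, `TaskRef`, `DataNode`, `execute_graph`):

```python
import operator

from dask._task_spec import DataNode, Task, TaskRef, execute_graph

# Inner graph: "out" depends on two keys bound from outside the subgraph.
inner_dsk = {"out": Task("out", operator.add, TaskRef("a"), TaskRef("b"))}
inkeys = {"a": 1, "b": 2}

# What _execute_subgraph(inner_dsk, "out", inkeys) does internally:
final = dict(inner_dsk)
for k, v in inkeys.items():
    final[k] = DataNode(None, v)  # each bound input becomes a plain data node
assert execute_graph(final, keys=["out"])["out"] == 3
```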

@@ -214,72 +212,17 @@ def convert_legacy_task(
if isinstance(task, GraphNode):
return task

global SubgraphType

if SubgraphType is None:
from dask.optimization import SubgraphCallable

SubgraphType = SubgraphCallable

if type(task) is tuple and task and callable(task[0]):
func, args = task[0], task[1:]
if isinstance(func, SubgraphType):
subgraph = func
all_keys_inner = _MultiContainer(
subgraph.dsk, set(subgraph.inkeys), all_keys
)
sub_dsk = subgraph.dsk
deps: set[KeyType] = set()
converted_subgraph = convert_legacy_graph(sub_dsk, all_keys_inner)
for v in converted_subgraph.values():
if isinstance(v, GraphNode):
deps.update(v.dependencies)

# There is an explicit and implicit way to provide dependencies /
# data to a SubgraphCallable. Since we're not recursing into any
# containers any more we'll have to provide those arguments in a
# more explicit way as a separate argument.

# The explicit way are arguments to the subgraph callable. Those can
# again be tasks.

explicit_inkeys = dict()
for k, target in zip(subgraph.inkeys, args):
explicit_inkeys[k] = t = convert_legacy_task(None, target, all_keys)
if isinstance(t, GraphNode):
deps.update(t.dependencies)
explicit_inkeys_wrapped = Dict(explicit_inkeys)
deps.update(explicit_inkeys_wrapped.dependencies)

# The implicit way is when the tasks inside of the subgraph are
# referencing keys that are not part of the subgraph but are part of
# the outer graph.

deps -= set(subgraph.dsk)
deps -= set(subgraph.inkeys)

implicit_inkeys = dict()
for k in deps - explicit_inkeys_wrapped.dependencies:
assert k is not None
implicit_inkeys[k] = Alias(k)
return Task(
key,
_execute_subgraph,
converted_subgraph,
func.outkey,
explicit_inkeys_wrapped,
Dict(implicit_inkeys),
)
else:
new_args = []
new: object
for a in args:
if isinstance(a, dict):
new = Dict(a)
else:
new = convert_legacy_task(None, a, all_keys)
new_args.append(new)
return Task(key, func, *new_args)
new_args = []
new: object
for a in args:
if isinstance(a, dict):
new = Dict(a)
else:
new = convert_legacy_task(None, a, all_keys)
new_args.append(new)
return Task(key, func, *new_args)
try:
if isinstance(task, (bytes, int, float, str, tuple)):
if task in all_keys:
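
With the `SubgraphCallable` branch deleted, every legacy tuple task goes through the same argument-by-argument conversion. A rough sketch of the intended behavior (the values and the bare `set()` for `all_keys` are illustrative; `convert_legacy_task` is an internal helper):

```python
import operator

from dask._task_spec import Task, convert_legacy_task

# A classic dask tuple task: (callable, *args).
legacy = (operator.add, 1, 2)
node = convert_legacy_task("k", legacy, set())
assert isinstance(node, Task)  # converted uniformly; no SubgraphCallable case
```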
@@ -529,8 +472,7 @@ def fuse(*tasks: GraphNode, key: KeyType | None = None) -> GraphNode:
{t.key: t for t in tasks},
outkey,
(Dict({k: Alias(k) for k in external_deps}) if external_deps else {}),
{},
data_producer=any(t.data_producer for t in tasks),
_data_producer=any(t.data_producer for t in tasks),
)
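
The fused task also switches to the renamed private keyword (see the `Task.__init__` hunk below). A sketch of `fuse` at work, under the same internal-API assumptions as above:

```python
import operator

from dask._task_spec import Task, TaskRef, execute_graph, fuse

t1 = Task("a", operator.add, 1, 2)
t2 = Task("b", operator.mul, TaskRef("a"), 10)

fused = fuse(t1, t2, key="b")  # one GraphNode computing the whole chain
assert not fused.data_producer  # any(t.data_producer for t in tasks) is False
assert execute_graph({fused.key: fused}, keys=[fused.key])[fused.key] == 30
```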


@@ -685,7 +627,7 @@ def __init__(
func: Callable,
/,
*args: Any,
data_producer: bool = False,
_data_producer: bool = False,
_dependencies: set | frozenset | None = None,
**kwargs: Any,
):
@@ -711,7 +653,7 @@ def __init__(
self._is_coro = None
self._token = None
self._repr = None
self._data_producer = data_producer
self._data_producer = _data_producer

@property
def data_producer(self) -> bool:
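
The constructor keyword becomes private while the public read side stays the same: callers set `_data_producer=` and read the flag back through this `data_producer` property. For example (a sketch):

```python
from dask._task_spec import Task

t = Task("load-x", lambda: [1, 2, 3], _data_producer=True)
assert t.data_producer  # unchanged public property
```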
@@ -829,7 +771,7 @@ def substitute(
key or self.key,
self.func,
*new_args,
data_producer=self.data_producer,
_data_producer=self.data_producer,
**new_kwargs,
)
elif key is None or key == self.key:
@@ -840,7 +782,7 @@
key,
self.func,
*self.args,
data_producer=self.data_producer,
_data_producer=self.data_producer,
**self.kwargs,
)

3 changes: 3 additions & 0 deletions dask/array/api.py
@@ -0,0 +1,3 @@
from __future__ import annotations

from dask.array.core import normalize_chunks, normalize_chunks_cached # noqa: F401
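
The new module appears to exist so that downstream code has a stable import path for chunk normalization instead of reaching into `dask.array.core` directly; the expected values below are taken from the tests in this commit:

```python
from dask.array.api import normalize_chunks, normalize_chunks_cached

assert normalize_chunks(3, (4, 6)) == ((3, 1), (3, 3))
assert normalize_chunks_cached(3, (4, 6)) == ((3, 1), (3, 3))
```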
23 changes: 21 additions & 2 deletions dask/array/core.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import functools
import math
import operator
import os
@@ -311,7 +312,7 @@ def graph_from_arraylike(
ArraySliceDep(chunks),
out_ind,
numblocks={},
data_producer=True,
_data_producer=True,
**kwargs,
)
return HighLevelGraph.from_collections(name, layer)
@@ -329,7 +330,7 @@ def graph_from_arraylike(
ArraySliceDep(chunks),
out_ind,
numblocks={},
data_producer=True,
_data_producer=True,
**kwargs,
)
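
Both branches of `graph_from_arraylike` now tag their root tasks with the renamed private flag, so graphs built by `da.from_array` identify themselves as data producers. A small usage sketch (the flag itself is internal):

```python
import numpy as np
import dask.array as da

x = da.from_array(np.arange(16).reshape(4, 4), chunks=(2, 2))
# The chunk-materializing root tasks of this graph carry data_producer=True,
# set internally via _data_producer=True as in the hunks above.
assert x.sum().compute() == 120
```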

@@ -3007,6 +3008,24 @@ def ensure_int(f):
return i


@functools.lru_cache
def normalize_chunks_cached(
chunks, shape=None, limit=None, dtype=None, previous_chunks=None
):
"""Cached version of normalize_chunks.
.. note::
chunks and previous_chunks are expected to be hashable. Dicts and lists aren't
allowed for this function.
See :func:`normalize_chunks` for further documentation.
"""
return normalize_chunks(
chunks, shape=shape, limit=limit, dtype=dtype, previous_chunks=previous_chunks
)
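
Because the wrapper is memoized with `functools.lru_cache`, every argument must be hashable; dict and list chunk specs still have to go through the uncached function:

```python
from dask.array.core import normalize_chunks, normalize_chunks_cached

normalize_chunks_cached(3, (4, 6))        # fine: ints and tuples are hashable
normalize_chunks_cached((3, -1), (5, 5))  # fine
# normalize_chunks_cached({0: 3}, (5, 5)) # TypeError: unhashable type: 'dict'
normalize_chunks({0: 3}, (5, 5))          # dicts go to the uncached version
```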


def normalize_chunks(chunks, shape=None, limit=None, dtype=None, previous_chunks=None):
"""Normalize chunks to tuple of tuples
43 changes: 22 additions & 21 deletions dask/array/tests/test_array_core.py
@@ -47,6 +47,7 @@
graph_from_arraylike,
load_store_chunk,
normalize_chunks,
normalize_chunks_cached,
optimize,
stack,
store,
@@ -3056,30 +3057,30 @@ def test_from_array_with_missing_chunks():
assert d.chunks == da.from_array(x, chunks=(2, 2, 3)).chunks


def test_normalize_chunks():
assert normalize_chunks(3, (4, 6)) == ((3, 1), (3, 3))
assert normalize_chunks(((3, 3), (8,)), (6, 8)) == ((3, 3), (8,))
assert normalize_chunks((4, 5), (9,)) == ((4, 5),)
assert normalize_chunks((4, 5), (9, 9)) == ((4, 4, 1), (5, 4))
assert normalize_chunks(-1, (5, 5)) == ((5,), (5,))
assert normalize_chunks((3, -1), (5, 5)) == ((3, 2), (5,))
assert normalize_chunks((3, None), (5, 5)) == ((3, 2), (5,))
assert normalize_chunks({0: 3}, (5, 5)) == ((3, 2), (5,))
assert normalize_chunks([[2, 2], [3, 3]]) == ((2, 2), (3, 3))
assert normalize_chunks(10, (30, 5)) == ((10, 10, 10), (5,))
assert normalize_chunks((), (0, 0)) == ((0,), (0,))
assert normalize_chunks(-1, (0, 3)) == ((0,), (3,))
assert normalize_chunks(((float("nan"),),)) == ((np.nan,),)

assert normalize_chunks("auto", shape=(20,), limit=5, dtype="uint8") == (
(5, 5, 5, 5),
)
assert normalize_chunks(("auto", None), (5, 5), dtype=int) == ((5,), (5,))
@pytest.mark.parametrize("func", [normalize_chunks, normalize_chunks_cached])
def test_normalize_chunks(func):
assert func(3, (4, 6)) == ((3, 1), (3, 3))
assert func(((3, 3), (8,)), (6, 8)) == ((3, 3), (8,))
assert func((4, 5), (9,)) == ((4, 5),)
assert func((4, 5), (9, 9)) == ((4, 4, 1), (5, 4))
assert func(-1, (5, 5)) == ((5,), (5,))
assert func((3, -1), (5, 5)) == ((3, 2), (5,))
assert func((3, None), (5, 5)) == ((3, 2), (5,))
if func is normalize_chunks:
assert func({0: 3}, (5, 5)) == ((3, 2), (5,))
assert func([[2, 2], [3, 3]]) == ((2, 2), (3, 3))
assert func(10, (30, 5)) == ((10, 10, 10), (5,))
assert func((), (0, 0)) == ((0,), (0,))
assert func(-1, (0, 3)) == ((0,), (3,))
assert func(((float("nan"),),)) == ((np.nan,),)

assert func("auto", shape=(20,), limit=5, dtype="uint8") == ((5, 5, 5, 5),)
assert func(("auto", None), (5, 5), dtype=int) == ((5,), (5,))

with pytest.raises(ValueError):
normalize_chunks(((10,),), (11,))
func(((10,),), (11,))
with pytest.raises(ValueError):
normalize_chunks(((5,), (5,)), (5,))
func(((5,), (5,)), (5,))


def test_single_element_tuple():
74 changes: 6 additions & 68 deletions dask/array/tests/test_optimization.py
@@ -10,51 +10,11 @@

import dask
import dask.array as da
from dask.array.chunk import getitem as da_getitem
from dask.array.core import getter as da_getter
from dask.array.core import getter_nofancy as da_getter_nofancy
from dask.array.chunk import getitem
from dask.array.core import getter
from dask.array.optimization import fuse_slice, optimize, optimize_blockwise
from dask.array.utils import assert_eq
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import SubgraphCallable


def _wrap_getter(func, wrap):
"""
Getters generated from a Blockwise layer might be wrapped in a SubgraphCallable.
Make sure that the optimization functions can still work if that is the case.
"""
if wrap:
return SubgraphCallable({"key": (func, "index")}, outkey="key", inkeys="index")
else:
return func


@pytest.fixture(params=[True, False])
def getter(request):
"""
Parameterized fixture for dask.array.core.getter both alone (False)
and wrapped in a SubgraphCallable (True).
"""
yield _wrap_getter(da_getter, request.param)


@pytest.fixture(params=[True, False])
def getitem(request):
"""
Parameterized fixture for dask.array.chunk.getitem both alone (False)
and wrapped in a SubgraphCallable (True).
"""
yield _wrap_getter(da_getitem, request.param)


@pytest.fixture(params=[True, False])
def getter_nofancy(request):
"""
Parameterized fixture for dask.array.chunk.getter_nofancy both alone (False)
and wrapped in a SubgraphCallable (True).
"""
yield _wrap_getter(da_getter_nofancy, request.param)


def _check_get_task_eq(a, b) -> bool:
@@ -66,12 +26,8 @@ def _check_get_task_eq(a, b) -> bool:
if len(a) < 1 or len(a) != len(b):
return False

a_callable = (
list(a[0].dsk.values())[0][0] if isinstance(a[0], SubgraphCallable) else a[0]
)
b_callable = (
list(b[0].dsk.values())[0][0] if isinstance(b[0], SubgraphCallable) else b[0]
)
a_callable = a[0]
b_callable = b[0]
if a_callable != b_callable:
return False

@@ -84,25 +40,7 @@
return True


def _assert_getter_dsk_eq(a, b):
"""
Compare two getter dsks.
TODO: this is here to support the fact that low-level array slice fusion needs to be
able to introspect slicing tasks. But some slicing tasks (e.g. `from_array`) could
be hidden within SubgraphCallables. This and _check_get_task_eq should be removed
when high-level slicing lands, and replaced with basic equality checks.
"""
assert a.keys() == b.keys()
for k, av in a.items():
bv = b[k]
if dask.core.istask(av):
assert _check_get_task_eq(av, bv)
else:
assert av == bv


def test_optimize_with_getitem_fusion(getter):
def test_optimize_with_getitem_fusion():
dsk = {
"a": "some-array",
"b": (getter, "a", (slice(10, 20), slice(100, 200))),
@@ -185,7 +123,7 @@ def test_dont_fuse_numpy_arrays():
)


def test_fuse_slices_with_alias(getter, getitem):
def test_fuse_slices_with_alias():
dsk = {
"x": np.arange(16).reshape((4, 4)),
("dx", 0, 0): (getter, "x", (slice(0, 4), slice(0, 4))),
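
With `SubgraphCallable` gone from the optimization path, getters reach these tests unwrapped, so the wrapping fixtures and `_assert_getter_dsk_eq` are no longer needed and `_check_get_task_eq` can compare the callables directly. Roughly, inside this test module:

```python
from dask.array.core import getter

# Two equivalent slicing tasks; no SubgraphCallable unwrapping step required.
a = (getter, "x", (slice(0, 4), slice(0, 4)))
b = (getter, "x", (slice(0, 4), slice(0, 4)))
assert _check_get_task_eq(a, b)  # helper defined in the hunk above
```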
7 changes: 3 additions & 4 deletions dask/bag/core.py
@@ -112,8 +112,7 @@ def lazify_task(task, start=True):
assert len(task.args) == 1
task = task.args[0]
if task.func is _execute_subgraph:
subgraph = task.args[0]
outkey = task.args[1]
subgraph, outkey, inkeys = task.args
# If there is a reify at the output of the subgraph we don't want to act
final_task = lazify_task(subgraph[outkey], True)
subgraph = {
@@ -125,8 +124,8 @@
_execute_subgraph,
subgraph,
outkey,
*task.args[2:],
**task.kwargs,
inkeys,
_data_producer=task.data_producer,
)
return Task(
task.key,