diff --git a/.github/workflows/pull_requests.yml b/.github/workflows/pull_requests.yml index dbd841e3..af8e891a 100644 --- a/.github/workflows/pull_requests.yml +++ b/.github/workflows/pull_requests.yml @@ -4,6 +4,7 @@ on: pull_request: branches: - master + - dev jobs: test: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ebac7dd8..60ba5c2b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: --per-file-ignores, '__init__.py:F401 dexp/cli/*.py:E501'] - repo: https://github.com/psf/black - rev: 21.12b0 + rev: 22.3.0 hooks: - id: black diff --git a/dexp/cli/dexp_commands/_test/test_cli.py b/dexp/cli/dexp_commands/_test/test_cli.py new file mode 100644 index 00000000..25f60d26 --- /dev/null +++ b/dexp/cli/dexp_commands/_test/test_cli.py @@ -0,0 +1,96 @@ +import subprocess +from typing import Sequence + +import pytest + +from dexp.utils.testing import cupy_only + + +@pytest.mark.parametrize( + "dexp_zarr_path", + [dict(dataset_type="nuclei", dtype="uint16")], + indirect=True, +) +@pytest.mark.parametrize( + "command", + [ + ["info"], + ["check"], + ["background", "-c", "image", "-a", "binary_area_threshold=500"], + ["copy", "-wk", "2", "-s", "[:,:,40:240,40:240]"], + ["deconv", "-c", "image", "-m", "lr", "-bp", "wb", "-i", "3"], + ["deconv", "-c", "image", "-m", "admm", "-w"], + ["denoise", "-c", "image"], + ["extract-psf", "-pt", "100", "-c", "image"], + ["generic", "-f", "gaussian_filter", "-pkg", "cupyx.scipy.ndimage", "-ts", "100", "-a", "sigma=1,2,2"], + ["histogram", "-c", "image", "-m", "0"], + ["projrender", "-c", "image", "-lt", "test projrender", "-lsi", "0.5"], + ["segment", "-dc", "image"], + ["tiff", "--split", "-wk", "2"], + ["tiff", "-p", "0", "-wk", "2"], + ["view", "-q", "-d", "2"], + ], +) +@cupy_only +def test_cli_commands_nuclei_dataset(command: Sequence[str], dexp_zarr_path: str) -> None: + assert subprocess.run(["dexp"] + command + [dexp_zarr_path]).returncode == 0 + + +@pytest.mark.parametrize( + "dexp_zarr_path", + [dict(dataset_type="nuclei", n_time_pts=10, dtype="uint16")], + indirect=True, +) +@pytest.mark.parametrize("command", [["stabilize", "-mr", "2", "-wk", "2"]]) +@cupy_only +def test_cli_commands_long_nuclei_dataset(command: Sequence[str], dexp_zarr_path: str) -> None: + assert subprocess.run(["dexp"] + command + [dexp_zarr_path]).returncode == 0 + + +@pytest.mark.parametrize( + "dexp_zarr_path", + [dict(dataset_type="fusion", dtype="uint16")], + indirect=True, +) +@pytest.mark.parametrize( + "command", + [ + ["register", "-c", "image-C0L0,image-C1L0"], + ["fuse", "-c", "image-C0L0,image-C1L0", "-lr"], # must be after registration + ], +) +@cupy_only +def test_cli_commands_fusion_dataset(command: Sequence[str], dexp_zarr_path: str) -> None: + assert subprocess.run(["dexp"] + command + [dexp_zarr_path]).returncode == 0 + + +@pytest.mark.parametrize( + "dexp_zarr_path", + [dict(dataset_type="nuclei-skewed", dtype="uint16")], + indirect=True, +) +@pytest.mark.parametrize( + "command", + [ + ["deskew", "-c", "image", "-xx", "1", "-zz", "1", "-a", "45"], + ], +) +@cupy_only +def test_cli_commands_skewed_dataset(command: Sequence[str], dexp_zarr_path: str) -> None: + assert subprocess.run(["dexp"] + command + [dexp_zarr_path]).returncode == 0 + + +@pytest.mark.parametrize( + "dexp_zarr_path", + [dict(dataset_type="fusion-skewed", dtype="uint16")], + indirect=True, +) +@pytest.mark.parametrize( + "command", + [ + ["fuse", "-c", "image-C0L0,image-C1L0", "-m", "mvsols", "-zpa", "0, 0"], + ], +) +@cupy_only 
+def test_cli_commands_fusion_skewed_dataset(command: Sequence[str], dexp_zarr_path: str) -> None: + assert subprocess.run(["dexp"] + command + [dexp_zarr_path]).returncode == 0 diff --git a/dexp/cli/dexp_commands/add.py b/dexp/cli/dexp_commands/add.py index b257360a..c5e03fb9 100644 --- a/dexp/cli/dexp_commands/add.py +++ b/dexp/cli/dexp_commands/add.py @@ -1,43 +1,51 @@ +from typing import Sequence + import click from arbol.arbol import aprint, asection from dexp.cli.defaults import DEFAULT_STORE -from dexp.cli.parsing import _get_output_path, _parse_channels -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import ( + channels_option, + input_dataset_argument, + optional_channels_callback, +) +from dexp.datasets import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option("--channels", "-c", default=None, help="List of channels, all channels when ommited.") +@input_dataset_argument() +@channels_option() +@click.option("--output-path", "-o", required=True) @click.option( - "--rename", + "--rename-channels", "-rc", default=None, help="You can rename channels: e.g. if channels are ‘channel1,anotherc’ then ‘gfp,rfp’ would rename the ‘channel1’ channel to ‘gfp’, and ‘anotherc’ to ‘rfp’ ", + callback=optional_channels_callback, ) @click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) @click.option("--overwrite", "-w", is_flag=True, help="Forces overwrite of target", show_default=True) @click.option( "--projection", "-p/-np", is_flag=True, default=True, help="If flags should be copied.", show_default=True ) -def add(input_paths, output_path, channels, rename, store, overwrite, projection): +def add( + input_dataset: ZDataset, + output_path: str, + channels: Sequence[str], + rename_channels: Sequence[str], + store: str, + overwrite: bool, + projection: bool, +) -> None: """Adds the channels selected from INPUT_PATHS to the given output dataset (created if not existing).""" - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_add") - channels = _parse_channels(input_dataset, channels) - - if rename is None: - rename = input_dataset.channels() - else: - rename = rename.split(",") - - with asection(f"Adding channels: {channels} from: {input_paths} to {output_path}, with new names: {rename}"): + with asection( + f"Adding channels: {channels} from: {input_dataset.path} to {output_path}, with new names: {rename_channels}" + ): input_dataset.add_channels_to( output_path, channels=channels, - rename=rename, + rename=rename_channels, store=store, overwrite=overwrite, add_projections=projection, diff --git a/dexp/cli/dexp_commands/background.py b/dexp/cli/dexp_commands/background.py new file mode 100644 index 00000000..a292d1d0 --- /dev/null +++ b/dexp/cli/dexp_commands/background.py @@ -0,0 +1,70 @@ +from typing import Optional, Sequence, Union + +import click +from arbol import asection + +from dexp.cli.parsing import ( + KwargsCommand, + args_option, + channels_option, + func_args_to_str, + input_dataset_argument, + multi_devices_option, + output_dataset_options, + parse_args_to_kwargs, + slicing_option, + validate_function_kwargs, +) +from dexp.datasets import BaseDataset, ZDataset +from dexp.datasets.operations.background import dataset_background +from dexp.processing.crop.background import foreground_mask + + +@click.command( + 
cls=KwargsCommand, + context_settings=dict(ignore_unknown_options=True), + epilog=func_args_to_str(foreground_mask, ["array"]), +) +@input_dataset_argument() +@output_dataset_options() +@channels_option() +@multi_devices_option() +@slicing_option() +@args_option() +@click.option( + "--reference-channel", "-rc", type=str, default=None, help="Optional channel to use for background removal." +) +@click.option( + "--merge-channels", + "-mc", + type=bool, + is_flag=True, + default=False, + help="Flag to indicate to merge (with addition) the image channels to detect the foreground.", +) +def background( + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + reference_channel: Optional[str], + merge_channels: bool, + devices: Union[str, Sequence[int]], + args: Sequence[str], +) -> None: + """Remove background by detecting foreground regions using morphological analysis.""" + kwargs = parse_args_to_kwargs(args) + validate_function_kwargs(foreground_mask, kwargs) + + with asection(f"Removing background of {input_dataset.path}."): + dataset_background( + input_dataset=input_dataset, + output_dataset=output_dataset, + channels=channels, + reference_channel=reference_channel, + merge_channels=merge_channels, + devices=devices, + **kwargs, + ) + + input_dataset.close() + output_dataset.close() diff --git a/dexp/cli/dexp_commands/check.py b/dexp/cli/dexp_commands/check.py index a29e9473..a7e15fc2 100644 --- a/dexp/cli/dexp_commands/check.py +++ b/dexp/cli/dexp_commands/check.py @@ -1,20 +1,19 @@ +from typing import Sequence + import click from arbol.arbol import aprint, asection -from dexp.cli.parsing import _parse_channels -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import channels_option, input_dataset_argument +from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--channels", "-c", default=None, help="List of channels, all channels when ommited.") -def check(input_paths, channels): +@input_dataset_argument() +@channels_option() +def check(input_dataset: ZDataset, channels: Sequence[int]) -> None: """Checks the integrity of a dataset.""" - input_dataset, input_paths = glob_datasets(input_paths) - channels = _parse_channels(input_dataset, channels) - - with asection(f"checking integrity of datasets {input_paths}, channels: {channels}"): + with asection(f"checking integrity of datasets {input_dataset.path}, channels: {channels}"): result = input_dataset.check_integrity(channels) input_dataset.close() diff --git a/dexp/cli/dexp_commands/copy.py b/dexp/cli/dexp_commands/copy.py index 121f80c3..d77a088d 100644 --- a/dexp/cli/dexp_commands/copy.py +++ b/dexp/cli/dexp_commands/copy.py @@ -1,43 +1,27 @@ +from typing import Sequence + import click from arbol.arbol import aprint, asection -from dexp.cli.defaults import ( - DEFAULT_CLEVEL, - DEFAULT_CODEC, - DEFAULT_STORE, - DEFAULT_WORKERS_BACKEND, -) +from dexp.cli.defaults import DEFAULT_WORKERS_BACKEND from dexp.cli.parsing import ( - _get_output_path, - _parse_channels, - _parse_chunks, - _parse_slicing, + channels_option, + input_dataset_argument, + output_dataset_options, + slicing_option, + workers_option, ) -from dexp.datasets.open_dataset import glob_datasets +from dexp.datasets.base_dataset import BaseDataset from dexp.datasets.operations.copy import dataset_copy +from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' 
-@click.option("--output_path", "-o") # , help='output path' -@click.option("--channels", "-c", default=None, help="List of channels, all channels when ommited.") -@click.option( - "--slicing", - "-s", - default=None, - help="Dataset slice (TZYX), e.g. [0:5] (first five stacks) [:,0:100] (cropping in z) ", -) -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option("--chunks", "-chk", default=None, help="Dataset chunks dimensions, e.g. (1, 126, 512, 512).") -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="Compression codec: zstd for ’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", - show_default=True, -) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option("--overwrite", "-w", is_flag=True, help="Forces overwrite of target", show_default=True) +@input_dataset_argument() +@output_dataset_options() +@channels_option() +@slicing_option() +@workers_option() @click.option( "--zerolevel", "-zl", @@ -45,14 +29,7 @@ default=0, help="‘zero-level’ i.e. the pixel values in the restoration (to be substracted)", show_default=True, -) # -@click.option( - "--workers", - "-wk", - default=-4, - help="Number of worker threads to spawn. Negative numbers n correspond to: number_of _cores / |n| ", - show_default=True, -) # +) @click.option( "--workersbackend", "-wkb", @@ -60,46 +37,27 @@ default=DEFAULT_WORKERS_BACKEND, help="What backend to spawn workers with, can be ‘loky’ (multi-process) or ‘threading’ (multi-thread) ", show_default=True, -) # -@click.option("--check", "-ck", default=True, help="Checking integrity of written file.", show_default=True) # +) def copy( - input_paths, - output_path, - channels, - slicing, - store, - chunks, - codec, - clevel, - overwrite, - zerolevel, - workers, - workersbackend, - check, + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + zerolevel: int, + workers: int, + workersbackend: str, ): """Copies a dataset, channels can be selected, cropping can be performed, compression can be changed, ...""" - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_copy") - slicing = _parse_slicing(slicing) - channels = _parse_channels(input_dataset, channels) - chunks = _parse_chunks(chunks) - - with asection(f"Copying from: {input_paths} to {output_path} for channels: {channels}, slicing: {slicing} "): + with asection( + f"Copying from: {input_dataset.path} to {output_dataset.path} for channels: {channels}, slicing: {input_dataset.slicing} " + ): dataset_copy( - input_dataset, - output_path, + input_dataset=input_dataset, + output_dataset=output_dataset, channels=channels, - slicing=slicing, - store=store, - chunks=chunks, - compression=codec, - compression_level=clevel, - overwrite=overwrite, zerolevel=zerolevel, workers=workers, workersbackend=workersbackend, - check=check, ) input_dataset.close() diff --git a/dexp/cli/dexp_commands/crop.py b/dexp/cli/dexp_commands/crop.py index 22a8ec1a..bda814f2 100644 --- a/dexp/cli/dexp_commands/crop.py +++ b/dexp/cli/dexp_commands/crop.py @@ -1,21 +1,30 @@ +from typing import Sequence + import click from arbol.arbol import aprint, asection -from dexp.cli.defaults import DEFAULT_CLEVEL, DEFAULT_CODEC, DEFAULT_STORE -from dexp.cli.parsing import _get_output_path, _parse_channels, _parse_chunks -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing 
import ( + channels_callback, + channels_option, + input_dataset_argument, + output_dataset_options, + workers_option, +) +from dexp.datasets.base_dataset import BaseDataset from dexp.datasets.operations.crop import dataset_crop +from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option("--channels", "-c", default=None, help="List of channels, all channels when ommited.") +@input_dataset_argument() +@output_dataset_options() +@channels_option() +@workers_option() @click.option( "--quantile", "-q", default=0.99, - type=float, + type=click.FloatRange(0, 1), help="Quantile parameter for lower bound of brightness for thresholding.", show_default=True, ) @@ -24,66 +33,30 @@ "-rc", default=None, help="Reference channel to estimate cropping. If no provided it picks the first one.", + callback=channels_callback, ) -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option("--chunks", "-chk", default=None, help="Dataset chunks dimensions, e.g. (1, 126, 512, 512).") -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="Compression codec: zstd for ’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", - show_default=True, -) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option("--overwrite", "-w", is_flag=True, help="Forces overwrite of target", show_default=True) -@click.option( - "--workers", - "-wk", - default=-4, - help="Number of worker threads to spawn. Negative numbers n correspond to: number_of _cores / |n| ", - show_default=True, -) # -@click.option("--check", "-ck", default=True, help="Checking integrity of written file.", show_default=True) # def crop( - input_paths, - output_path, - channels, - quantile, - reference_channel, - store, - chunks, - codec, - clevel, - overwrite, - workers, - check, + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + quantile: float, + reference_channel: str, + workers: int, ): - - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_crop") - channels = _parse_channels(input_dataset, channels) - if reference_channel is None: - reference_channel = input_dataset.channels()[0] - chunks = _parse_chunks(chunks) + """Automatically crops the dataset given a reference channel.""" + reference_channel = reference_channel[0] with asection( - f"Cropping from: {input_paths} to {output_path} for channels: {channels}, " + f"Cropping from: {input_dataset.path} to {output_dataset.path} for channels: {channels}, " f"using channel {reference_channel} as a reference." 
): - dataset_crop( input_dataset, - output_path, + output_dataset, channels=channels, reference_channel=reference_channel, quantile=quantile, - store=store, - chunks=chunks, - compression=codec, - compression_level=clevel, - overwrite=overwrite, workers=workers, - check=check, ) input_dataset.close() diff --git a/dexp/cli/dexp_commands/deconv.py b/dexp/cli/dexp_commands/deconv.py index 1f1341b6..f39731a8 100644 --- a/dexp/cli/dexp_commands/deconv.py +++ b/dexp/cli/dexp_commands/deconv.py @@ -1,43 +1,29 @@ +from typing import List, Optional, Sequence, Tuple + import click from arbol.arbol import aprint, asection -from dexp.cli.defaults import ( - DEFAULT_CLEVEL, - DEFAULT_CODEC, - DEFAULT_STORE, - DEFAULT_WORKERS_BACKEND, -) from dexp.cli.parsing import ( - _get_output_path, - _parse_channels, - _parse_slicing, - parse_devices, + channels_option, + input_dataset_argument, + multi_devices_option, + output_dataset_options, + slicing_option, + tilesize_option, + tuple_callback, ) -from dexp.datasets.open_dataset import glob_datasets +from dexp.datasets.base_dataset import BaseDataset from dexp.datasets.operations.deconv import dataset_deconv +from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option("--channels", "-c", default=None, help="list of channels, all channels when ommited.") -@click.option( - "--slicing", - "-s", - default=None, - help="dataset slice (TZYX), e.g. [0:5] (first five stacks) [:,0:100] (cropping in z) ", -) -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="compression codec: ‘zstd’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", - show_default=True, -) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option("--overwrite", "-w", is_flag=True, help="to force overwrite of target", show_default=True) -@click.option("--tilesize", "-ts", type=int, default=512, help="Tile size for tiled computation", show_default=True) +@input_dataset_argument() +@output_dataset_options() +@channels_option() +@multi_devices_option() +@slicing_option() +@tilesize_option() @click.option( "--method", "-m", @@ -50,13 +36,13 @@ "--iterations", "-i", type=int, - default=None, + default=5, help="Number of deconvolution iterations. More iterations takes longer, will be sharper, but might also be potentially more noisy depending on method. " "The default number of iterations depends on the other parameters, in particular it depends on the choice of backprojection operator. For ‘wb’ as little as 3 iterations suffice. 
", show_default=True, ) @click.option( - "--maxcorrection", + "--max-correction", "-mc", type=int, default=None, @@ -72,7 +58,7 @@ show_default=True, ) @click.option( - "--blindspot", + "--blind-spot", "-bs", type=int, default=0, @@ -80,9 +66,9 @@ show_default=True, ) @click.option( - "--backprojection", + "--back-projection", "-bp", - type=str, + type=click.Choice(["tpsf", "wb"]), default="tpsf", help="Back projection operator, can be: ‘tpsf’ (transposed PSF = classic) or ‘wb’ (Wiener-Butterworth = accelerated) ", show_default=True, @@ -104,12 +90,12 @@ help="Microscope objective to use for computing psf, can be: nikon16x08na, olympus20x10na, or path to file", show_default=True, ) -@click.option("--numape", "-na", type=float, default=None, help="Overrides NA value for objective.", show_default=True) +@click.option("--num-ape", "-na", type=float, default=None, help="Overrides NA value for objective.", show_default=True) @click.option("--dxy", "-dxy", type=float, default=0.485, help="Voxel size along x and y in microns", show_default=True) @click.option("--dz", "-dz", type=float, default=4 * 0.485, help="Voxel size along z in microns", show_default=True) -@click.option("--xysize", "-sxy", type=int, default=31, help="PSF size along xy in voxels", show_default=True) -@click.option("--zsize", "-sz", type=int, default=31, help="PSF size along z in voxels", show_default=True) -@click.option("--showpsf", "-sp", is_flag=True, help="Show point spread function (PSF) with napari", show_default=True) +@click.option("--xy-size", "-sxy", type=int, default=31, help="PSF size along xy in voxels", show_default=True) +@click.option("--z-size", "-sz", type=int, default=31, help="PSF size along z in voxels", show_default=True) +@click.option("--show-psf", "-sp", is_flag=True, help="Show point spread function (PSF) with napari", show_default=True) @click.option( "--scaling", "-sc", @@ -117,102 +103,58 @@ default="1,1,1", help="Scales input image along the three axis: sz,sy,sx (numpy order). For example: 2,1,1 upscales along z by a factor 2", show_default=True, -) # -@click.option( - "--workers", - "-k", - type=int, - default=-1, - help="Number of worker threads to spawn, if -1 then num workers = num devices", - show_default=True, + callback=tuple_callback(dtype=float, length=3), ) -@click.option( - "--workersbackend", - "-wkb", - type=str, - default=DEFAULT_WORKERS_BACKEND, - help="What backend to spawn workers with, can be ‘loky’ (multi-process) or ‘threading’ (multi-thread) ", - show_default=True, -) # -@click.option( - "--devices", "-d", type=str, default="0", help="Sets the CUDA devices id, e.g. 
0,1,2 or ‘all’", show_default=True -) # -@click.option("--check", "-ck", default=True, help="Checking integrity of written file.", show_default=True) # def deconv( - input_paths, - output_path, - channels, - slicing, - store, - codec, - clevel, - overwrite, - tilesize, - method, - iterations, - maxcorrection, - power, - blindspot, - backprojection, - wb_order, - objective, - numape, - dxy, - dz, - xysize, - zsize, - showpsf, - scaling, - workers, - workersbackend, - devices, - check, + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + tilesize: int, + method: str, + iterations: Optional[int], + max_correction: int, + power: float, + blind_spot: int, + back_projection: Optional[str], + wb_order: int, + objective: str, + num_ape: Optional[float], + dxy: float, + dz: float, + xy_size: int, + z_size: int, + show_psf: bool, + scaling: Tuple[float], + devices: List[int], ): """Deconvolves all or selected channels of a dataset.""" - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_deconv") - - slicing = _parse_slicing(slicing) - channels = _parse_channels(input_dataset, channels) - devices = parse_devices(devices) - - if "," in scaling: - scaling = tuple(float(v) for v in scaling.split(",")) - with asection( - f"Deconvolving dataset: {input_paths}, saving it at: {output_path}, for channels: {channels}, slicing: {slicing} " + f"Deconvolving dataset: {input_dataset.path}, saving it at: {output_dataset.path}, for channels: {channels}, slicing: {input_dataset.slicing} " ): dataset_deconv( input_dataset, - output_path, + output_dataset, channels=channels, - slicing=slicing, - store=store, - compression=codec, - compression_level=clevel, - overwrite=overwrite, tilesize=tilesize, method=method, num_iterations=iterations, - max_correction=maxcorrection, + max_correction=max_correction, power=power, - blind_spot=blindspot, - back_projection=backprojection, + blind_spot=blind_spot, + back_projection=back_projection, wb_order=wb_order, psf_objective=objective, - psf_na=numape, + psf_na=num_ape, psf_dxy=dxy, psf_dz=dz, - psf_xy_size=xysize, - psf_z_size=zsize, - psf_show=showpsf, + psf_xy_size=xy_size, + psf_z_size=z_size, + psf_show=show_psf, scaling=scaling, - workers=workers, - workersbackend=workersbackend, devices=devices, - check=check, ) input_dataset.close() + output_dataset.close() aprint("Done!") diff --git a/dexp/cli/dexp_commands/denoise.py b/dexp/cli/dexp_commands/denoise.py index 3a654e29..8059330e 100644 --- a/dexp/cli/dexp_commands/denoise.py +++ b/dexp/cli/dexp_commands/denoise.py @@ -1,82 +1,39 @@ +from typing import Sequence, Union + import click from arbol import asection -from dexp.cli.defaults import DEFAULT_CLEVEL, DEFAULT_CODEC, DEFAULT_STORE from dexp.cli.parsing import ( - _get_output_path, - _parse_channels, - _parse_chunks, - _parse_slicing, - parse_devices, + channels_option, + input_dataset_argument, + multi_devices_option, + output_dataset_options, + slicing_option, + tilesize_option, ) -from dexp.datasets.open_dataset import glob_datasets +from dexp.datasets.base_dataset import BaseDataset from dexp.datasets.operations.denoise import dataset_denoise from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option("--channels", "-c", default=None, help="list of channels, all channels when ommited.") -@click.option( - "--slicing", - 
"-s", - default=None, - help="dataset slice (TZYX), e.g. [0:5] (first five stacks) [:,0:100] (cropping in z) ", -) -@click.option("--tilesize", "-ts", type=int, default=320, help="Tile size for tiled computation", show_default=True) -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option("--chunks", "-chk", default=None, help="Dataset chunks dimensions, e.g. (1, 126, 512, 512).") -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="Compression codec: zstd for ’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", - show_default=True, -) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option("--overwrite", "-w", is_flag=True, help="Forces overwrite of target", show_default=True) -@click.option( - "--devices", "-d", type=str, default="0", help="Sets the CUDA devices id, e.g. 0,1,2 or ‘all’", show_default=True -) +@input_dataset_argument() +@output_dataset_options() +@channels_option() +@multi_devices_option() +@slicing_option() +@tilesize_option() def denoise( - input_paths, - output_path, - channels, - slicing, - tilesize, - store, - chunks, - codec, - clevel, - overwrite, - devices, + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + tilesize: int, + devices: Union[str, Sequence[int]], ): """Denoises input image using butterworth filter, parameters are estimated automatically using noise2self j-invariant cross-validation loss. """ - - input_dataset, input_paths = glob_datasets(input_paths) - channels = _parse_channels(input_dataset, channels) - slicing = _parse_slicing(slicing) - devices = parse_devices(devices) - - if slicing is not None: - input_dataset.set_slicing(slicing) - - output_path = _get_output_path(input_paths[0], output_path, "_denoised") - mode = "w" + ("" if overwrite else "-") - output_dataset = ZDataset( - output_path, - mode=mode, - store=store, - codec=codec, - clevel=clevel, - chunks=_parse_chunks(chunks), - parent=input_dataset, - ) - - with asection(f"Denoising: {input_paths} to {output_path} for channels {channels}"): + with asection(f"Denoising data to {output_dataset.path} for channels {channels}"): dataset_denoise( input_dataset, output_dataset, diff --git a/dexp/cli/dexp_commands/deskew.py b/dexp/cli/dexp_commands/deskew.py index 3d55c0a9..317f85a0 100644 --- a/dexp/cli/dexp_commands/deskew.py +++ b/dexp/cli/dexp_commands/deskew.py @@ -1,60 +1,53 @@ +from typing import Optional, Sequence + import click from arbol.arbol import aprint, asection -from dexp.cli.defaults import ( - DEFAULT_CLEVEL, - DEFAULT_CODEC, - DEFAULT_STORE, - DEFAULT_WORKERS_BACKEND, -) from dexp.cli.parsing import ( - _get_output_path, - _parse_channels, - _parse_slicing, - parse_devices, + channels_option, + input_dataset_argument, + multi_devices_option, + output_dataset_options, + slicing_option, ) -from dexp.datasets.open_dataset import glob_datasets from dexp.datasets.operations.deskew import dataset_deskew +from dexp.datasets.zarr_dataset import ZDataset + + +def _flip_callback(ctx: click.Context, opt: click.Option, value: Optional[str]) -> Sequence[str]: + length = len(ctx.params["input_dataset"].channels()) + if value is None: + value = (False,) * length + else: + value = tuple(bool(s) for s in value.split(",")) + if len(value) != length: + raise ValueError("-fl --flips must have length equal to number of channels.") + return value @click.command() -@click.argument("input_paths", nargs=-1) 
# , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option( - "--channels", - "-c", - default=None, - help="list of channels for the view in standard order for the microscope type (C0L0, C0L1, C1L0, C1L1,...)", -) -@click.option( - "--slicing", - "-s", - default=None, - help="dataset slice (TZYX), e.g. [0:5] (first five stacks) [:,0:100] (cropping in z) ", -) # -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) +@input_dataset_argument() +@output_dataset_options() +@slicing_option() +@channels_option() +@multi_devices_option() @click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="compression codec: ‘zstd’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", + "--mode", + "-m", + type=click.Choice(["yang", "classic"]), + default="yang", + help="Deskew algorithm", + show_default=True, ) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) +@click.option("--delta-x", "-xx", type=float, default=None, help="Pixel size of the camera", show_default=True) @click.option( - "--overwrite", "-w", is_flag=True, help="to force overwrite of target", show_default=True -) # , help='dataset slice' -@click.option( - "--mode", "-m", type=str, default="yang", help="Deskew algorithm: 'yang' or 'classic'. ", show_default=True -) # -@click.option("--deltax", "-xx", type=float, default=None, help="Pixel size of the camera", show_default=True) # -@click.option( - "--deltaz", + "--delta-z", "-zz", type=float, default=None, help="Scanning step (stage or galvo scanning step, not the same as the distance between the slices)", show_default=True, -) # +) @click.option( "--angle", "-a", @@ -70,99 +63,56 @@ default=None, help="Flips image to deskew in the opposite orientation (True for view 0 and False for view 1)", show_default=True, -) # + callback=_flip_callback, +) @click.option( - "--camorientation", + "--camera-orientation", "-co", type=int, default=0, help="Camera orientation correction expressed as a number of 90 deg rotations to be performed per 2D image in stack -- if required.", show_default=True, -) # -@click.option("--depthaxis", "-za", type=int, default=0, help="Depth axis.", show_default=True) # -@click.option("--lateralaxis", "-xa", type=int, default=1, help="Lateral axis.", show_default=True) # -@click.option( - "--workers", - "-k", - type=int, - default=-1, - help="Number of worker threads to spawn, if -1 then num workers = num devices", - show_default=True, -) # -@click.option( - "--workersbackend", - "-wkb", - type=str, - default=DEFAULT_WORKERS_BACKEND, - help="What backend to spawn workers with, can be ‘loky’ (multi-process) or ‘threading’ (multi-thread) ", - show_default=True, -) # +) +@click.option("--depth-axis", "-za", type=int, default=0, help="Depth axis.", show_default=True) +@click.option("--lateral-axis", "-xa", type=int, default=1, help="Lateral axis.", show_default=True) @click.option( - "--devices", "-d", type=str, default="0", help="Sets the CUDA devices id, e.g. 0,1,2 or ‘all’", show_default=True -) # -@click.option("--check", "-ck", default=True, help="Checking integrity of written file.", show_default=True) # + "--padding", "-p", type=bool, default=False, is_flag=True, help="Pads output image to fit deskewed results."
+) def deskew( - input_paths, - output_path, - channels, - slicing, - store, - codec, - clevel, - overwrite, - mode, - deltax, - deltaz, - angle, - flips, - camorientation, - depthaxis, - lateralaxis, - workers, - workersbackend, - devices, - check, + input_dataset: ZDataset, + output_dataset: ZDataset, + channels: Sequence[str], + mode: str, + delta_x: float, + delta_z: float, + angle: float, + flips: Sequence[bool], + camera_orientation: int, + depth_axis: int, + lateral_axis: int, + padding: bool, + devices: Sequence[int], ): """Deskews all or selected channels of a dataset.""" - - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_deskew") - - slicing = _parse_slicing(slicing) - channels = _parse_channels(input_dataset, channels) - devices = parse_devices(devices) - - if flips is None: - flips = (False,) * len(channels) - else: - flips = tuple(bool(s) for s in flips.split(",")) - with asection( - f"Deskewing dataset: {input_paths}, saving it at: {output_path}, for channels: {channels}, slicing: {slicing} " + f"Deskewing dataset: {input_dataset.path}, saving it at: {output_dataset.path}, for channels: {channels}" ): - aprint(f"Devices used: {devices}, workers: {workers} ") + aprint(f"Devices used: {devices}") dataset_deskew( input_dataset, - output_path, + output_dataset, channels=channels, - slicing=slicing, - dx=deltax, - dz=deltaz, + dx=delta_x, + dz=delta_z, angle=angle, flips=flips, - camera_orientation=camorientation, - depth_axis=depthaxis, - lateral_axis=lateralaxis, + camera_orientation=camera_orientation, + depth_axis=depth_axis, + lateral_axis=lateral_axis, mode=mode, - store=store, - compression=codec, - compression_level=clevel, - overwrite=overwrite, - workers=workers, - workersbackend=workersbackend, + padding=padding, devices=devices, - check=check, ) - input_dataset.close() - aprint("Done!") + input_dataset.close() + output_dataset.close() diff --git a/dexp/cli/dexp_commands/extract_psf.py b/dexp/cli/dexp_commands/extract_psf.py index f05fee4d..5350eb80 100644 --- a/dexp/cli/dexp_commands/extract_psf.py +++ b/dexp/cli/dexp_commands/extract_psf.py @@ -1,26 +1,26 @@ +from typing import Sequence + import click from arbol.arbol import aprint, asection -from dexp.cli.parsing import _parse_channels, _parse_slicing -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import ( + channels_option, + input_dataset_argument, + multi_devices_option, + slicing_option, + verbose_option, +) +from dexp.datasets.base_dataset import BaseDataset from dexp.datasets.operations.extract_psf import dataset_extract_psf @click.command(name="extract-psf") -@click.argument("input_paths", nargs=-1, required=True) -@click.option("--out-prefix-path", "-o", default="psf_", help="Output PSF file prefix") -@click.option( - "--channels", - "-c", - default=None, - help="list of channels to extract the PSF, each channel will have their separater PSF (i.e., psf_c0.npy, psf_c1.npy", -) -@click.option( - "--slicing", - "-s", - default=None, - help="dataset slice (TZYX), e.g. 
[0:5] (first five stacks) [:,0:100] (cropping in z) ", -) +@input_dataset_argument() +@slicing_option() +@channels_option() +@multi_devices_option() +@verbose_option() +@click.option("--out-prefix-path", "-o", default="psf_", help="Output PSF file prefix", type=str) @click.option( "--peak-threshold", "-pt", @@ -38,43 +38,28 @@ help="Threshold of PSF selection given the similarity (cosine distance) to median PSF.", ) @click.option("--psf_size", "-ps", type=int, default=35, show_default=True, help="Size (shape) of the PSF") -@click.option( - "--device", - "-d", - type=int, - default=0, - help="Sets the CUDA device id, e.g. 0,1,2. It works for only a single device!", - show_default=True, -) -@click.option("--verbose", "-v", type=bool, is_flag=True, default=False, help="Flag to display intermediated results.") def extract_psf( - input_paths, - out_prefix_path, - channels, - slicing, - peak_threshold, - similarity_threshold, - psf_size, - device, - verbose, + input_dataset: BaseDataset, + out_prefix_path: str, + channels: Sequence[str], + peak_threshold: int, + similarity_threshold: float, + psf_size: int, + devices: Sequence[int], + verbose: bool, ): - """Extracts the PSF from beads.""" - input_dataset, input_paths = glob_datasets(input_paths) - slicing = _parse_slicing(slicing) - channels = _parse_channels(input_dataset, channels) - + """Detects and extracts the PSF from beads.""" with asection( - f"Extracting PSF of dataset: {input_paths}, saving it with prefix: {out_prefix_path}, for channels: {channels}, slicing: {slicing} " + f"Extracting PSF of dataset: {input_dataset.path}, saving it with prefix: {out_prefix_path}, for channels: {channels}" ): - aprint(f"Device used: {device}") + aprint(f"Device used: {devices}") dataset_extract_psf( - dataset=input_dataset, + input_dataset=input_dataset, dest_path=out_prefix_path, channels=channels, - slicing=slicing, peak_threshold=peak_threshold, similarity_threshold=similarity_threshold, psf_size=psf_size, verbose=verbose, - device=device, + devices=devices, ) diff --git a/dexp/cli/dexp_commands/fastcopy.py b/dexp/cli/dexp_commands/fastcopy.py index 3ad2b736..ccdc3677 100644 --- a/dexp/cli/dexp_commands/fastcopy.py +++ b/dexp/cli/dexp_commands/fastcopy.py @@ -3,24 +3,18 @@ import click from arbol.arbol import aprint, asection -from dexp.cli.parsing import _get_output_path +from dexp.cli.parsing import _get_output_path, workers_option from dexp.utils.robocopy import robocopy @click.command() -@click.argument("input_path", nargs=1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option( - "--workers", - "-wk", - default=8, - help="Number of worker threads to spawn. Negative numbers n correspond to: number_of _cores / |n| ", - show_default=True, -) # +@click.argument("input_path", nargs=1, type=click.Path(exists=True)) +@click.option("--output_path", "-o") +@workers_option() @click.option( "--large_files", "-lf", is_flag=True, help="Set to true to speed up large file transfer", show_default=True ) -def fastcopy(input_path, output_path, workers, large_files): +def fastcopy(input_path: str, output_path: str, workers: int, large_files: bool) -> None: """Copies a dataset fast, with no processing, just moves the data as fast as possible. 
For each operating system it uses the best method.""" output_path = _get_output_path(input_path, output_path, "_copy") diff --git a/dexp/cli/dexp_commands/fromraw.py b/dexp/cli/dexp_commands/fromraw.py index b464cb8d..bb023528 100644 --- a/dexp/cli/dexp_commands/fromraw.py +++ b/dexp/cli/dexp_commands/fromraw.py @@ -1,62 +1,31 @@ import click from arbol import asection -from dexp.cli.defaults import DEFAULT_CLEVEL, DEFAULT_CODEC, DEFAULT_STORE -from dexp.cli.parsing import _get_output_path, _parse_chunks -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import ( + input_dataset_argument, + output_dataset_options, + workers_option, +) +from dexp.datasets import BaseDataset, ZDataset from dexp.datasets.operations.fromraw import dataset_fromraw -from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option("--ch-prefix", "-c", default=None, help="Prefix of channels, usually according to wavelength") -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option("--chunks", "-chk", default=None, help="Dataset chunks dimensions, e.g. (1, 126, 512, 512).") -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="Compression codec: zstd for ’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", - show_default=True, -) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option("--overwrite", "-w", is_flag=True, help="Forces overwrite of target", show_default=True) -@click.option( - "--workers", - "-wk", - default=-4, - help="Number of worker threads to spawn. 
Negative numbers n correspond to: number_of _cores / |n| ", - show_default=True, -) # +@input_dataset_argument() +@output_dataset_options() +@workers_option() +@click.option("--ch-prefix", "-c", default=None, help="Prefix of channels, usually according to wavelength", type=str) def fromraw( - input_paths, - output_path, - ch_prefix, - store, - chunks, - codec, - clevel, - workers, - overwrite, + input_dataset: BaseDataset, + output_dataset: ZDataset, + ch_prefix: str, + workers: int, ): """Copies a dataset, channels can be selected, cropping can be performed, compression can be changed, ...""" - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_fromraw") - mode = "w" + ("" if overwrite else "-") - output_dataset = ZDataset( - output_path, - mode=mode, - store=store, - codec=codec, - clevel=clevel, - chunks=_parse_chunks(chunks), - parent=input_dataset, - ) - - with asection(f"Copying `fromraw`: {input_paths} to {output_path} for channels with prefix {ch_prefix}"): + with asection( + f"Copying `fromraw`: {input_dataset.path} to {output_dataset.path} for channels with prefix {ch_prefix}" + ): dataset_fromraw( input_dataset, output_dataset, diff --git a/dexp/cli/dexp_commands/fuse.py b/dexp/cli/dexp_commands/fuse.py index deb83f15..e125f8a6 100644 --- a/dexp/cli/dexp_commands/fuse.py +++ b/dexp/cli/dexp_commands/fuse.py @@ -1,43 +1,26 @@ +from typing import Sequence, Tuple + import click from arbol.arbol import aprint, asection -from dexp.cli.defaults import DEFAULT_CLEVEL, DEFAULT_CODEC, DEFAULT_STORE from dexp.cli.parsing import ( - _get_output_path, - _parse_channels, - _parse_slicing, - parse_devices, + channels_option, + input_dataset_argument, + multi_devices_option, + output_dataset_options, + slicing_option, + tuple_callback, ) -from dexp.datasets.open_dataset import glob_datasets +from dexp.datasets import BaseDataset, ZDataset from dexp.datasets.operations.fuse import dataset_fuse @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option( - "--channels", - "-c", - default=None, - help="list of channels for the view in standard order for the microscope type (C0L0, C0L1, C1L0, C1L1,...)", -) -@click.option( - "--slicing", - "-s", - default=None, - help="dataset slice (TZYX), e.g. [0:5] (first five stacks) [:,0:100] (cropping in z) ", -) # -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="compression codec: ‘zstd’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", -) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option( - "--overwrite", "-w", is_flag=True, help="to force overwrite of target", show_default=True -) # , help='dataset slice' +@input_dataset_argument() +@output_dataset_options() +@channels_option() +@slicing_option() +@multi_devices_option() @click.option( "--microscope", "-m", @@ -57,6 +40,7 @@ "--equalisemode", "-eqm", default="first", + type=click.Choice(["first", "all"]), help="Equalisation modes: compute correction ratios only for first time point: ‘first’ or for all time points: ‘all’.", show_default=True, ) @@ -67,7 +51,7 @@ default=0, help="‘zero-level’ i.e. 
the pixel values in the restoration (to be substracted)", show_default=True, -) # +) @click.option( "--cliphigh", "-ch", @@ -75,18 +59,24 @@ default=0, help="Clips voxel values above the given value, if zero no clipping is done", show_default=True, -) # +) @click.option( - "--fusion", "-f", type=str, default="tg", help="Fusion mode, can be: ‘tg’ or ‘dct’. ", show_default=True -) # + "--fusion", + "-f", + type=click.Choice(["tg", "dct"]), + default="tg", + help="Fusion mode, can be: ‘tg’ or ‘dct’.", + show_default=True, +) @click.option( "--fusion_bias_strength", "-fbs", - type=(float, float), - default=(0.5, 0.02), + type=str, + default="0.5,0.02", help="Fusion bias strength for illumination and detection ‘fbs_i fbs_d’, set to ‘0 0’) if fusing a cropped region", show_default=True, -) # + callback=tuple_callback(dtype=float, length=2), +) @click.option( "--dehaze_size", "-dhs", @@ -94,7 +84,7 @@ default=65, help="Filter size (scale) for dehazing the final regsitered and fused image to reduce effect of scattered and out-of-focus light. Set to zero to deactivate.", show_default=True, -) # +) @click.option( "--dark-denoise-threshold", "-ddt", @@ -102,27 +92,29 @@ default=0, help="Threshold for denoises the dark pixels of the image -- helps increase compression ratio. Set to zero to deactivate.", show_default=True, -) # +) @click.option( "--zpadapodise", "-zpa", - type=(int, int), - default=(8, 96), + type=str, + default="8,96", help="Pads and apodises the views along z before fusion: ‘pad apo’, where pad is a padding length, and apo is apodisation length, both in voxels. If pad=apo, no original voxel is modified and only added voxels are apodised.", show_default=True, -) # + callback=tuple_callback(length=2, dtype=int), +) @click.option( "--loadreg", "-lr", is_flag=True, help="Turn on to load the registration parameters from a previous run", show_default=True, -) # +) @click.option( "--model-filename", "-mf", help="Model filename to load or save registration model list", default="registration_models.txt", + type=click.Path(exists=True, dir_okay=False), show_default=True, ) @click.option( @@ -132,7 +124,7 @@ default=4, help="Number of iterations for warp registration (if applicable).", show_default=True, -) # +) @click.option( "--minconfidence", "-mc", @@ -140,7 +132,7 @@ default=0.3, help="Minimal confidence for registration parameters, if below that level the registration parameters for previous time points is used.", show_default=True, -) # +) @click.option( "--maxchange", "-md", @@ -148,14 +140,14 @@ default=16, help="Maximal change in registration parameters, if above that level the registration parameters for previous time points is used.", show_default=True, -) # +) @click.option( "--regedgefilter", "-ref", is_flag=True, help="Use this flag to apply an edge filter to help registration.", show_default=True, -) # +) @click.option( "--maxproj/--no-maxproj", "-mp/-nmp", @@ -170,10 +162,7 @@ is_flag=True, help="Use this flag to indicate that the the dataset is _huge_ and that memory allocation should be optimised at the detriment of processing speed.", show_default=True, -) # -@click.option( - "--devices", "-d", type=str, default="0", help="Sets the CUDA devices id, e.g. 0,1,2 or ‘all’", show_default=True -) # +) @click.option( "--pad", "-p", is_flag=True, default=False, help="Use this flag to pad views according to the registration models." 
) @@ -194,64 +183,46 @@ default=False, help="Use this flag to remove beads before equalizing and fusing", ) -@click.option("--check", "-ck", default=True, help="Checking integrity of written file.", show_default=True) # def fuse( - input_paths, - output_path, - channels, - slicing, - store, - codec, - clevel, - overwrite, - microscope, - equalise, - equalisemode, - zerolevel, - cliphigh, - fusion, - fusion_bias_strength, - dehaze_size, - dark_denoise_threshold, - zpadapodise, - loadreg, - model_filename, - warpregiter, - minconfidence, - maxchange, - regedgefilter, - maxproj, - hugedataset, - devices, - pad, - white_top_hat_size, - white_top_hat_sampling, - remove_beads, - check, + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + microscope: str, + equalise: bool, + equalisemode: str, + zerolevel: float, + cliphigh: float, + fusion: str, + fusion_bias_strength: Tuple[float, float], + dehaze_size: int, + dark_denoise_threshold: int, + zpadapodise: Tuple[float, float], + loadreg: bool, + model_filename: str, + warpregiter: int, + minconfidence: float, + maxchange: float, + regedgefilter: bool, + maxproj: bool, + hugedataset: bool, + devices: Sequence[int], + pad: bool, + white_top_hat_size: float, + white_top_hat_sampling: int, + remove_beads: bool, ): """Fuses the views of a multi-view light-sheet microscope dataset (available: simview and mvsols)""" - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_fused") - - slicing = _parse_slicing(slicing) - channels = _parse_channels(input_dataset, channels) - devices = parse_devices(devices) - with asection( - f"Fusing dataset: {input_paths}, saving it at: {output_path}, for channels: {channels}, slicing: {slicing} " + f"Fusing dataset: {input_dataset.path}, saving it at: {output_dataset.path}," + + f"for channels: {channels}, slicing: {input_dataset.slicing}" ): aprint(f"Microscope type: {microscope}, fusion type: {fusion}") aprint(f"Devices used: {devices}") dataset_fuse( - input_dataset, - output_path, + input_dataset=input_dataset, + output_dataset=output_dataset, channels=channels, - slicing=slicing, - store=store, - compression=codec, - compression_level=clevel, - overwrite=overwrite, microscope=microscope, equalise=equalise, equalise_mode=equalisemode, @@ -276,8 +247,8 @@ def fuse( white_top_hat_size=white_top_hat_size, white_top_hat_sampling=white_top_hat_sampling, remove_beads=remove_beads, - check=check, ) input_dataset.close() + output_dataset.close() aprint("Done!") diff --git a/dexp/cli/dexp_commands/generic.py b/dexp/cli/dexp_commands/generic.py new file mode 100644 index 00000000..ac53e16b --- /dev/null +++ b/dexp/cli/dexp_commands/generic.py @@ -0,0 +1,87 @@ +import inspect +from importlib import import_module +from typing import Callable, Optional, Sequence, Union + +import click +from arbol import asection +from toolz import curry + +from dexp.cli.parsing import ( + args_option, + channels_option, + input_dataset_argument, + multi_devices_option, + output_dataset_options, + parse_args_to_kwargs, + slicing_option, + tilesize_option, + validate_function_kwargs, +) +from dexp.datasets.base_dataset import BaseDataset +from dexp.datasets.operations.generic import dataset_generic +from dexp.datasets.zarr_dataset import ZDataset + + +def _parse_function(func_name: str, pkg_import: str, args: Sequence[str]) -> Callable: + pkg = import_module(pkg_import) + if not hasattr(pkg, func_name): + functions = inspect.getmembers(pkg, 
inspect.isfunction) + functions = [f[0] for f in functions] + raise ValueError(f"Could not find function {func_name} in {pkg_import}.\nAvailable functions: {functions}") + + kwargs = parse_args_to_kwargs(args) + func = getattr(pkg, func_name) + validate_function_kwargs(func, kwargs) + + func = curry(func, **kwargs) + return func + + +@click.command(context_settings=dict(ignore_unknown_options=True)) +@input_dataset_argument() +@output_dataset_options() +@channels_option() +@multi_devices_option() +@slicing_option() +@tilesize_option(default=None) +@args_option() +@click.option( + "--func-name", + "-f", + type=str, + required=True, + help="Function to be called; arguments should be provided as keyword arguments on the CLI.", +) +@click.option( + "--pkg-import", + "-pkg", + type=str, + required=True, + help="Package to be imported for function usage. For example, `cupyx.scipy.ndimage` when using `gaussian_filter`.", +) +def generic( + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + func_name: str, + pkg_import: str, + tilesize: Optional[int], + devices: Union[str, Sequence[int]], + args: Sequence[str], +): + """ + Executes a generic function on multiple GPUs using dask and cupy. + """ + func = _parse_function(func_name=func_name, pkg_import=pkg_import, args=args) + with asection(f"Applying {func.__name__} to {output_dataset.path} for channels {channels}"): + dataset_generic( + input_dataset, + output_dataset, + func=func, + channels=channels, + tilesize=tilesize, + devices=devices, + ) + + input_dataset.close() + output_dataset.close() diff --git a/dexp/cli/dexp_commands/histogram.py b/dexp/cli/dexp_commands/histogram.py index 297a931e..5912bd13 100644 --- a/dexp/cli/dexp_commands/histogram.py +++ b/dexp/cli/dexp_commands/histogram.py @@ -1,21 +1,14 @@ import click -from dexp.cli.parsing import _parse_channels, _parse_slicing -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import channels_option, device_option, input_dataset_argument from dexp.datasets.operations.histogram import dataset_histogram @click.command() -@click.argument("input-paths", nargs=-1) +@input_dataset_argument() +@channels_option() +@device_option() @click.option("--output-directory", "-o", default="histogram", show_default=True, type=str) -@click.option("--channels", "-c", default=None, help="list of channels, all channels when ommited.") -@click.option( - "--slicing", - "-s", - default=None, - help="dataset slice (TZYX), e.g. 
[0:5] (first five stacks) [:,0:100] (cropping in z) ", -) -@click.option("--device", "-d", type=int, default=0, show_default=True) @click.option( "--minimum-count", "-m", @@ -32,15 +25,8 @@ type=float, help="Maximum count to clip histogram and improve visualization.", ) -def histogram(input_paths, output_directory, channels, slicing, device, minimum_count, maximum_count): - - input_dataset, input_paths = glob_datasets(input_paths) - channels = _parse_channels(input_dataset, channels) - slicing = _parse_slicing(slicing) - - if slicing is not None: - input_dataset.set_slicing(slicing) - +def histogram(input_dataset, output_directory, channels, device, minimum_count, maximum_count): + """Computes histogram of selected channel.""" dataset_histogram( dataset=input_dataset, output_dir=output_directory, diff --git a/dexp/cli/dexp_commands/info.py b/dexp/cli/dexp_commands/info.py index 87124ab7..215c4127 100644 --- a/dexp/cli/dexp_commands/info.py +++ b/dexp/cli/dexp_commands/info.py @@ -1,17 +1,16 @@ import click from arbol.arbol import aprint, asection -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import input_dataset_argument +from dexp.datasets.base_dataset import BaseDataset @click.command() -@click.argument("input_paths", nargs=-1) -def info(input_paths): +@input_dataset_argument() +def info(input_dataset: BaseDataset): """Prints out extensive information about dataset.""" - input_dataset, input_paths = glob_datasets(input_paths) - - with asection(f"Information on dataset at: {input_paths}"): + with asection(f"Information on dataset at: {input_dataset.path}"): aprint(input_dataset.info()) input_dataset.close() aprint("Done!") diff --git a/dexp/cli/dexp_commands/isonet.py b/dexp/cli/dexp_commands/isonet.py deleted file mode 100644 index 78237ead..00000000 --- a/dexp/cli/dexp_commands/isonet.py +++ /dev/null @@ -1,92 +0,0 @@ -import click -from arbol.arbol import aprint, asection - -from dexp.cli.defaults import DEFAULT_CLEVEL, DEFAULT_CODEC, DEFAULT_STORE -from dexp.cli.parsing import _get_output_path, _parse_slicing -from dexp.datasets.open_dataset import glob_datasets -from dexp.datasets.operations.isonet import dataset_isonet - - -@click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option("--dxy", "-dxy", type=float, required=True) -@click.option("--dz", "-dz", type=float, required=True) -@click.option("--binning", "-b", type=int, default=1) # TODO help -@click.option("--sharpening", "-sh", is_flag=True, help="Dehazes the image with a low-pass rejection filter.") -@click.option("--channel", "-ch", default=None, help="dataset channel name") -@click.option( - "--slicing", - "-s", - default=None, - help="dataset slice (TZYX), e.g. 
[0:5] (first five stacks) [:,0:100] (cropping in z) ", -) # -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="compression codec: ‘zstd’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", - show_default=True, -) # -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option( - "--overwrite", "-w", is_flag=True, help="to force overwrite of target", show_default=True -) # , help='dataset slice' -@click.option( - "--context", "-c", default="default", help="IsoNet context name", show_default=True -) # , help='dataset slice' -@click.option( - "--mode", "-m", default="pta", help="mode: 'pta' -> prepare, train and apply, 'a' just apply ", show_default=True -) # , help='dataset slice' -@click.option( - "--max_epochs", "-e", type=int, default="50", help="to force overwrite of target", show_default=True -) # , help='dataset slice' -@click.option("--check", "-ck", default=True, help="Checking integrity of written file.", show_default=True) # -def isonet( - input_paths, - output_path, - dxy, - dz, - sharpening, - binning, - channel, - slicing, - store, - codec, - clevel, - overwrite, - context, - mode, - max_epochs, - check, -): - """Recovers isotropic resolution using the ISONET approach (Weigert et al.)""" - - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_isonet") - slicing = _parse_slicing(slicing) - - with asection(f"Applying Isonet to: {input_paths}, saving result to: {output_path}, slicing: {slicing} "): - dataset_isonet( - input_dataset, - output_path, - dxy=dxy, - dz=dz, - channel=channel, - slicing=slicing, - store=store, - compression=codec, - compression_level=clevel, - overwrite=overwrite, - context=context, - mode=mode, - max_epochs=max_epochs, - check=check, - sharpening=sharpening, - binning=binning, - training_tp_index=None, - ) - - input_dataset.close() - aprint("Done!") diff --git a/dexp/cli/dexp_commands/projrender.py b/dexp/cli/dexp_commands/projrender.py index b9d6139b..2cffe514 100644 --- a/dexp/cli/dexp_commands/projrender.py +++ b/dexp/cli/dexp_commands/projrender.py @@ -1,40 +1,37 @@ +from pathlib import Path +from typing import Optional, Sequence, Union + import click from arbol.arbol import aprint, asection -from dexp.cli.defaults import DEFAULT_WORKERS_BACKEND from dexp.cli.parsing import ( - _get_output_path, - _parse_channels, - _parse_slicing, - parse_devices, + channels_option, + input_dataset_argument, + multi_devices_option, + overwrite_option, + slicing_option, + tuple_callback, ) -from dexp.datasets.open_dataset import glob_datasets from dexp.datasets.operations.projrender import dataset_projection_rendering +from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) +@input_dataset_argument() +@slicing_option() +@channels_option() +@multi_devices_option() +@overwrite_option() @click.option( "--output_path", "-o", - type=str, - default=None, - help="Output folder to store rendered PNGs. Default is: frames_", -) -@click.option("--channels", "-c", type=str, default=None, help="list of channels to color, all channels when ommited.") -@click.option( - "--slicing", - "-s", - type=str, - default=None, - help="dataset slice (TZYX), e.g. 
[0:5] (first five stacks) [:,0:100] (cropping in z).", + type=click.Path(path_type=Path), + default=Path("frames"), + help="Output folder to store rendered PNGs. Default is: frames/", ) -@click.option( - "--overwrite", "-w", is_flag=True, help="to force overwrite of target", show_default=True -) # , help='dataset slice' @click.option( "--axis", "-ax", type=int, default=0, help="Sets the projection axis: 0->Z, 1->Y, 2->X ", show_default=True -) # , help='dataset slice' +) @click.option( "--dir", "-di", @@ -46,7 +43,7 @@ @click.option( "--mode", "-m", - type=str, + type=click.Choice(["max", "maxcolor", "colormax"]), default="colormax", help="Sets the projection mode: ‘max’: classic max projection, ‘colormax’: color max projection, i.e. color codes for depth, ‘maxcolor’ same as colormax but first does depth-coding by color and then max projects (acheives some level of transparency). ", show_default=True, @@ -57,6 +54,7 @@ type=str, default=None, help="Sets the contrast limits, i.e. -cl 0,1000 sets the contrast limits to [0,1000]", + callback=tuple_callback(dtype=float, length=2), ) @click.option( "--attenuation", @@ -65,7 +63,7 @@ default=0.1, help="Sets the projection attenuation coefficient, should be within [0, 1] ideally close to 0. Larger values mean more attenuation.", show_default=True, -) # , help='dataset slice' +) @click.option( "--gamma", "-g", @@ -73,7 +71,7 @@ default=1.0, help="Sets the gamma coefficient pre-applied to the raw voxel values (before projection or any subsequent processing).", show_default=True, -) # , help='dataset slice' +) @click.option( "--dlim", "-dl", @@ -81,7 +79,8 @@ default=None, help="Sets the depth limits. Depth limits. For example, a value of (0.1, 0.7) means that the colormap start at a normalised depth of 0.1, and ends at a normalised depth of 0.7, other values are clipped. 
Only used for colormax mode.", show_default=True, -) # , help='dataset slice' + callback=tuple_callback(dtype=float, length=2), +) @click.option( "--colormap", "-cm", @@ -106,7 +105,7 @@ show_default=True, ) @click.option( - "--legendsize", + "--legend-size", "-lsi", type=float, default=1.0, @@ -114,7 +113,7 @@ show_default=True, ) @click.option( - "--legendscale", + "--legend-scale", "-lsc", type=float, default=1.0, @@ -122,7 +121,7 @@ show_default=True, ) @click.option( - "--legendtitle", + "--legend-title", "-lt", type=str, default="color-coded depth (voxels)", @@ -130,15 +129,16 @@ show_default=True, ) @click.option( - "--legendtitlecolor", + "--legend-title-color", "-ltc", type=str, default="1,1,1,1", help="Legend title color as a tuple of normalised floats: R, G, B, A (values between 0 and 1).", show_default=True, + callback=tuple_callback(dtype=float, length=4), ) @click.option( - "--legendposition", + "--legend-position", "-lp", type=str, default="bottom_left", @@ -146,106 +146,66 @@ show_default=True, ) @click.option( - "--legendalpha", + "--legend-alpha", "-la", type=float, default=1, help="Transparency for legend (1 means opaque, 0 means completely transparent)", show_default=True, ) -@click.option("--step", "-sp", type=int, default=1, help="Process every ‘step’ frames.", show_default=True) -@click.option( - "--workers", - "-k", - type=int, - default=-1, - help="Number of worker threads to spawn, if -1 then num workers = num devices", - show_default=True, -) # -@click.option( - "--workersbackend", - "-wkb", - type=str, - default=DEFAULT_WORKERS_BACKEND, - help="What backend to spawn workers with, can be ‘loky’ (multi-process) or ‘threading’ (multi-thread) ", - show_default=True, -) # -@click.option( - "--devices", "-d", type=str, default="0", help="Sets the CUDA devices id, e.g. 
0,1,2 or ‘all’", show_default=True -) # def projrender( - input_paths, - output_path, - channels, - slicing, - overwrite, - axis, - dir, - mode, - clim, - attenuation, - gamma, - dlim, - colormap, - rgbgamma, - transparency, - legendsize, - legendscale, - legendtitle, - legendtitlecolor, - legendposition, - legendalpha, - step, - workers, - workersbackend, - devices, - stop_at_exception=True, -): + input_dataset: ZDataset, + output_path: Path, + channels: Sequence[str], + overwrite: bool, + axis: int, + dir: int, + mode: str, + clim: Optional[Sequence[float]], + attenuation: float, + gamma: float, + dlim: Optional[Sequence[float]], + colormap: str, + rgbgamma: float, + transparency: bool, + legend_size: float, + legend_scale: float, + legend_title: str, + legend_title_color: Sequence[float], + legend_position: Union[str, Sequence[int]], + legend_alpha: float, + devices: Sequence[int], +) -> None: """Renders datatset using 2D projections.""" - input_dataset, input_paths = glob_datasets(input_paths) - channels = _parse_channels(input_dataset, channels) - slicing = _parse_slicing(slicing) - devices = parse_devices(devices) - - output_path = _get_output_path(input_paths[0], output_path, f"_{mode}_projection") - - dlim = None if dlim is None else tuple(float(strvalue) for strvalue in dlim.split(",")) - legendtitlecolor = tuple(float(v) for v in legendtitlecolor.split(",")) - - if "," in legendposition: - legendposition = tuple(float(strvalue) for strvalue in legendposition.split(",")) + if "," in legend_position: + legend_position = tuple(float(strvalue) for strvalue in legend_position.split(",")) with asection( - f"Projection rendering of: {input_paths} to {output_path} for channels: {channels}, slicing: {slicing} " + f"Projection rendering of: {input_dataset.path} to {output_path} for channels: {channels}, slicing: {input_dataset.slicing} " ): dataset_projection_rendering( - input_dataset, - output_path, - channels, - slicing, - overwrite, - axis, - dir, - mode, - clim, - attenuation, - gamma, - dlim, - colormap, - rgbgamma, - transparency, - legendsize, - legendscale, - legendtitle, - legendtitlecolor, - legendposition, - legendalpha, - step, - workers, - workersbackend, - devices, - stop_at_exception, + input_dataset=input_dataset, + output_path=output_path, + channels=channels, + overwrite=overwrite, + devices=devices, + axis=axis, + dir=dir, + mode=mode, + clim=clim, + attenuation=attenuation, + gamma=gamma, + dlim=dlim, + cmap=colormap, + rgb_gamma=rgbgamma, + transparency=transparency, + legend_size=legend_size, + legend_scale=legend_scale, + legend_title=legend_title, + legend_title_color=legend_title_color, + legend_position=legend_position, + legend_alpha=legend_alpha, ) input_dataset.close() diff --git a/dexp/cli/dexp_commands/register.py b/dexp/cli/dexp_commands/register.py index f7ee6a6f..456e6e62 100644 --- a/dexp/cli/dexp_commands/register.py +++ b/dexp/cli/dexp_commands/register.py @@ -1,32 +1,30 @@ +from typing import Sequence + import click from arbol.arbol import aprint, asection -from dexp.cli.parsing import _parse_channels, _parse_slicing, parse_devices -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import ( + channels_option, + input_dataset_argument, + multi_devices_option, + slicing_option, +) +from dexp.datasets import BaseDataset from dexp.datasets.operations.register import dataset_register @click.command() -@click.argument("input_paths", nargs=-1, required=True) +@input_dataset_argument() +@slicing_option() +@channels_option() 
+@multi_devices_option() @click.option("--out-model-path", "-o", default="registration_models.txt", show_default=True) -@click.option( - "--channels", - "-c", - default=None, - help="list of channels for the view in standard order for the microscope type (C0L0, C0L1, C1L0, C1L1,...)", -) -@click.option( - "--slicing", - "-s", - default=None, - help="dataset slice (TZYX), e.g. [0:5] (first five stacks) [:,0:100] (cropping in z) ", -) @click.option( "--microscope", "-m", - type=str, + type=click.Choice(("simview", "mvsols")), default="simview", - help="Microscope objective to use for computing psf, can be: simview or mvsols", + help="Microscope used to acquire the data, this selects the fusion algorithm.", show_default=True, ) @click.option( @@ -52,7 +50,9 @@ help="Clips voxel values above the given value, if zero no clipping is done", show_default=True, ) -@click.option("--fusion", "-f", type=str, default="tg", help="Fusion mode, can be: ‘tg’ or ‘dct’. ", show_default=True) +@click.option( + "--fusion", "-f", type=click.Choice(("tg", "dct", "dft")), default="tg", help="Fusion method.", show_default=True +) @click.option( "--fusion_bias_strength", "-fbs", @@ -84,9 +84,6 @@ help="Registers using only the maximum intensity projection from each stack.", show_default=True, ) -@click.option( - "--devices", "-d", type=str, default="0", help="Sets the CUDA devices id, e.g. 0,1,2 or ‘all’", show_default=True -) @click.option( "--white-top-hat-size", "-wth", @@ -105,34 +102,28 @@ help="Use this flag to remove beads before equalizing and fusing", ) def register( - input_paths, - out_model_path, - channels, - slicing, - microscope, - equalise, - zero_level, - clip_high, - fusion, - fusion_bias_strength, - dehaze_size, - edge_filter, - max_proj, - white_top_hat_size, - white_top_hat_sampling, - remove_beads, - devices, -): + input_dataset: BaseDataset, + out_model_path: str, + channels: Sequence[str], + microscope: str, + equalise: bool, + zero_level: int, + clip_high: int, + fusion: str, + fusion_bias_strength: float, + dehaze_size: int, + edge_filter: bool, + max_proj: bool, + white_top_hat_size: float, + white_top_hat_sampling: int, + remove_beads: bool, + devices: Sequence[int], +) -> None: """ Computes registration model for fusing. 
""" - input_dataset, input_paths = glob_datasets(input_paths) - slicing = _parse_slicing(slicing) - channels = _parse_channels(input_dataset, channels) - devices = parse_devices(devices) - with asection( - f"Fusing dataset: {input_paths}, saving it at: {out_model_path}, for channels: {channels}, slicing: {slicing} " + f"Computing fusion model of dataset: {input_dataset.path}, saving it at: {out_model_path}, for channels: {channels}" ): aprint(f"Microscope type: {microscope}, fusion type: {fusion}") aprint(f"Devices used: {devices}") @@ -140,7 +131,6 @@ def register( dataset=input_dataset, model_path=out_model_path, channels=channels, - slicing=slicing, microscope=microscope, equalise=equalise, zero_level=zero_level, @@ -155,3 +145,5 @@ def register( max_proj=max_proj, devices=devices, ) + + input_dataset.close() diff --git a/dexp/cli/dexp_commands/segment.py b/dexp/cli/dexp_commands/segment.py index 765906b5..4ab55032 100644 --- a/dexp/cli/dexp_commands/segment.py +++ b/dexp/cli/dexp_commands/segment.py @@ -1,45 +1,36 @@ import click from arbol import asection -from dexp.cli.defaults import DEFAULT_CLEVEL, DEFAULT_CODEC, DEFAULT_STORE from dexp.cli.parsing import ( - _get_output_path, - _parse_channels, - _parse_chunks, - _parse_slicing, - parse_devices, + channels_callback, + input_dataset_argument, + multi_devices_option, + output_dataset_options, + slicing_option, ) -from dexp.datasets.open_dataset import glob_datasets from dexp.datasets.operations.segment import dataset_segment -from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o", default=None) # , help='output path' -@click.option("--detection-channels", "-dc", default=None, help="list of channels used to detect cells.") -@click.option("--features-channels", "-fc", default=None, help="list of channels for cell features extraction.") +@input_dataset_argument() +@output_dataset_options() +@multi_devices_option() +@slicing_option() @click.option( - "--slicing", - "-s", + "--detection-channels", + "-dc", default=None, - help="dataset slice (TZYX), e.g. [0:5] (first five stacks) [:,0:100] (cropping in z) ", + help="list of channels used to detect cells.", + callback=channels_callback, ) -@click.option("--out-channel", "-oc", default="Segments", help="Output channel name.", show_default=True) -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option("--chunks", "-chk", default=None, help="Dataset chunks dimensions, e.g. (1, 126, 512, 512).") -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="Compression codec: zstd for ’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", - show_default=True, -) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option("--overwrite", "-w", is_flag=True, help="Forces overwrite of target", show_default=True) @click.option( - "--devices", "-d", type=str, default="0", help="Sets the CUDA devices id, e.g. 
0,1,2 or ‘all’", show_default=True + "--features-channels", + "-fc", + default=None, + help="list of channels for cell features extraction.", + callback=channels_callback, ) +@click.option("--out-channel", "-oc", default="Segments", help="Output channel name.", show_default=True) @click.option("--z-scale", "-z", type=int, default=1, help="Anisotropy over z axis.") @click.option( "--area-threshold", @@ -47,6 +38,7 @@ type=click.FloatRange(min=0), default=1e4, help="Parameter for cell detection, a smaller values will detect fewer segments.", + show_default=True, ) @click.option("--minimum-area", "-m", type=float, default=500, help="Minimum area (sum of pixels) per segment.") @click.option( @@ -55,6 +47,7 @@ type=float, default=1, help="Parameter to adjust the number of segments, smaller results in more segments.", + show_default=True, ) @click.option( "--compactness", @@ -62,6 +55,7 @@ type=click.FloatRange(min=0.0), default=0, help="Cell compactness (convexity) penalization, it penalizes non-convex shapes.", + show_default=True, ) @click.option( "--gamma", @@ -69,6 +63,7 @@ type=click.FloatRange(min=0.0), default=1.0, help="Gamma parameter to equalize (exponent) the image intensities.", + show_default=True, ) @click.option( "--use-edt", @@ -76,46 +71,17 @@ is_flag=True, default=True, help="Use Euclidean Distance Transform (EDT) or image intensity for segmentation.", + show_default=True, ) -def segment( - input_paths, - **kwargs, -): - """Detects and segment cells using their intensity, returning a table with some properties""" - - input_dataset, input_paths = glob_datasets(input_paths) - features_channels = _parse_channels(input_dataset, kwargs.pop("features_channels")) - detection_channels = _parse_channels(input_dataset, kwargs.pop("detection_channels")) - slicing = _parse_slicing(kwargs.pop("slicing")) - devices = parse_devices(kwargs.pop("devices")) - - if slicing is not None: - input_dataset.set_slicing(slicing) - - output_path = _get_output_path(input_paths[0], kwargs.pop("output_path"), "_segments") - mode = "w" + ("" if kwargs.pop("overwrite") else "-") - output_dataset = ZDataset( - output_path, - mode=mode, - store=kwargs.pop("store"), - codec=kwargs.pop("codec"), - clevel=kwargs.pop("clevel"), - chunks=_parse_chunks(kwargs.pop("chunks")), - parent=input_dataset, - ) - +def segment(**kwargs): + """Segment cells of a combination of multiple detection channels and extract features of each segment from the feature channels.""" with asection( - f"Segmenting {input_paths} with {detection_channels} and extracting features" - f"of {features_channels}, results saved to to {output_path}" + f"Segmenting {kwargs['input_dataset'].path} with {kwargs['detection_channels']} and extracting features" + f"of {kwargs['features_channels']}, results saved to to {kwargs['output_dataset'].path}" ): dataset_segment( - input_dataset, - output_dataset, - features_channels=features_channels, - detection_channels=detection_channels, - devices=devices, **kwargs, ) - input_dataset.close() - output_dataset.close() + kwargs["input_dataset"].close() + kwargs["output_dataset"].close() diff --git a/dexp/cli/dexp_commands/serve.py b/dexp/cli/dexp_commands/serve.py deleted file mode 100644 index f125bd50..00000000 --- a/dexp/cli/dexp_commands/serve.py +++ /dev/null @@ -1,21 +0,0 @@ -import click -from arbol.arbol import aprint, asection - -from dexp.datasets.open_dataset import glob_datasets -from dexp.datasets.operations.serve import dataset_serve - - -@click.command() -@click.argument("input_paths", nargs=-1) # , 
help='input path' -@click.option("--host", "-h", type=str, default="0.0.0.0", help="Host to serve from", show_default=True) -@click.option("--port", "-p", type=int, default=8000, help="Port to serve from", show_default=True) -def serve(input_paths, host, port): - """Serves dataset across network.""" - - input_dataset, input_paths = glob_datasets(input_paths) - - with asection(f"Serving dataset(s): {input_paths}"): - dataset_serve(input_dataset, host=host, port=port) - - input_dataset.close() - aprint("Done!") diff --git a/dexp/cli/dexp_commands/stabilize.py b/dexp/cli/dexp_commands/stabilize.py index 90d43355..dba48dc6 100644 --- a/dexp/cli/dexp_commands/stabilize.py +++ b/dexp/cli/dexp_commands/stabilize.py @@ -1,44 +1,33 @@ +from typing import Optional, Sequence + import click from arbol.arbol import aprint, asection -from dexp.cli.defaults import ( - DEFAULT_CLEVEL, - DEFAULT_CODEC, - DEFAULT_STORE, - DEFAULT_WORKERS_BACKEND, -) -from dexp.cli.parsing import _get_output_path, _parse_channels, _parse_slicing -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.defaults import DEFAULT_WORKERS_BACKEND +from dexp.cli.parsing import ( + channels_option, + device_option, + input_dataset_argument, + output_dataset_options, + workers_option, +) from dexp.datasets.operations.stabilize import dataset_stabilize +from dexp.datasets.zarr_dataset import ZDataset @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option("--channels", "-c", default=None, help="List of channels, all channels when omitted.") +@input_dataset_argument() +@output_dataset_options() +@channels_option() @click.option( "--reference-channel", "-rc", default=None, help="Reference channel for single stabilization model computation." ) -@click.option( - "--slicing", - "-s", - default=None, - help="Dataset slice (TZYX), e.g. 
[0:5] (first five stacks) [:,0:100] (cropping in z) ", -) -@click.option("--store", "-st", default=DEFAULT_STORE, help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", show_default=True) -@click.option( - "--codec", - "-z", - default=DEFAULT_CODEC, - help="Compression codec: ‘zstd’, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’ ", - show_default=True, -) -@click.option("--clevel", "-l", type=int, default=DEFAULT_CLEVEL, help="Compression level", show_default=True) -@click.option("--overwrite", "-w", is_flag=True, help="Forces overwrite of target", show_default=True) +@workers_option() +@device_option() @click.option( "--maxrange", "-mr", - type=int, + type=click.IntRange(min=1), default=7, help="Maximal distance, in time points, between pairs of images to registrate.", show_default=True, @@ -46,22 +35,23 @@ @click.option( "--minconfidence", "-mc", - type=float, + type=click.FloatRange(min=0, max=1), default=0.5, help="Minimal confidence for registration parameters, if below that level the registration parameters for previous time points is used.", show_default=True, -) # +) @click.option( "--com/--no-com", type=bool, default=False, + is_flag=True, help="Enable center of mass fallback when standard registration fails.", show_default=True, ) @click.option( "--quantile", "-q", - type=float, + type=click.FloatRange(min=0, max=1), default=0.5, help="Quantile to cut-off background in center-of-mass calculation.", show_default=True, @@ -110,6 +100,7 @@ "-ef", type=bool, default=False, + is_flag=True, help="Applies sobel edge filter to input images.", show_default=True, ) @@ -145,14 +136,6 @@ show_default=True, help="Output path for computed registration model", ) -@click.option( - "--workers", - "-k", - type=int, - default=-4, - help="Number of worker threads to spawn. Negative numbers n correspond to: number_of _cores / |n|. Be careful, starting two many workers is know to cause trouble (unfortunately unclear why!).", - show_default=True, -) @click.option( "--workersbackend", "-wkb", @@ -160,63 +143,44 @@ default=DEFAULT_WORKERS_BACKEND, help="What backend to spawn workers with, can be ‘loky’ (multi-process) or ‘threading’ (multi-thread) ", show_default=True, -) # +) @click.option("--device", "-d", type=int, default=0, help="Sets the CUDA devices id, e.g. 
0,1,2", show_default=True) # -@click.option("--check", "-ck", default=True, help="Checking integrity of written file.", show_default=True) # def stabilize( - input_paths, - output_path, - channels, - reference_channel, - slicing, - store, - codec, - clevel, - overwrite, - maxrange, - minconfidence, - com, - quantile, - tolerance, - ordererror, - orderreg, - alphareg, - pcsigma, - dsigma, - logcomp, - edgefilter, - detrend, - maxproj, - model_input_path, - model_output_path, - workers, - workersbackend, - device, - check, + input_dataset: ZDataset, + output_dataset: ZDataset, + channels: Sequence[str], + reference_channel: Optional[str], + maxrange: int, + minconfidence: float, + com: bool, + quantile: float, + tolerance: float, + ordererror: float, + orderreg: float, + alphareg: float, + pcsigma: float, + dsigma: float, + logcomp: bool, + edgefilter: bool, + detrend: bool, + maxproj: bool, + model_input_path: str, + model_output_path: str, + workers: int, + workersbackend: str, + device: int, ): """Stabilises dataset against translations across time.""" - - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path, "_stabilized") - - slicing = _parse_slicing(slicing) - channels = _parse_channels(input_dataset, channels) - with asection( - f"Stabilizing dataset(s): {input_paths}, saving it at: {output_path}, for channels: {channels}, slicing: {slicing} " + f"Stabilizing dataset(s): {input_dataset.path}, saving it at: {output_dataset.path}, for channels: {channels}" ): dataset_stabilize( input_dataset, - output_path, + output_dataset, channels=channels, model_output_path=model_output_path, model_input_path=model_input_path, reference_channel=reference_channel, - slicing=slicing, - zarr_store=store, - compression_codec=codec, - compression_level=clevel, - overwrite=overwrite, max_range=maxrange, min_confidence=minconfidence, enable_com=com, @@ -234,9 +198,8 @@ def stabilize( workers=workers, workers_backend=workersbackend, device=device, - check=check, - debug_output="stabilization", ) input_dataset.close() + output_dataset.close() aprint("Done!") diff --git a/dexp/cli/dexp_commands/tiff.py b/dexp/cli/dexp_commands/tiff.py index a2a1fb39..8f89a605 100644 --- a/dexp/cli/dexp_commands/tiff.py +++ b/dexp/cli/dexp_commands/tiff.py @@ -1,43 +1,39 @@ +from pathlib import Path +from typing import Optional, Sequence + import click from arbol.arbol import aprint, asection from dexp.cli.defaults import DEFAULT_WORKERS_BACKEND -from dexp.cli.parsing import _get_output_path, _parse_channels, _parse_slicing -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import ( + _get_output_path, + channels_option, + input_dataset_argument, + slicing_option, + workers_option, +) +from dexp.datasets.base_dataset import BaseDataset from dexp.datasets.operations.tiff import dataset_tiff @click.command() -@click.argument("input_paths", nargs=-1) # , help='input path' -@click.option("--output_path", "-o") # , help='output path' -@click.option("--channels", "-c", default=None, help="selected channels.") # -@click.option( - "--slicing", - "-s", - default=None, - help="dataset slice (TZYX), e.g. 
[0:5] (first five stacks) [:,0:100] (cropping in z) ", -) # +@input_dataset_argument() +@channels_option() +@slicing_option() +@workers_option() +@click.option("--output-path", "-o", type=click.Path(path_type=Path), default=None) +@click.option("--overwrite", "-w", is_flag=True, help="to force overwrite of target", show_default=True) @click.option( - "--overwrite", "-w", is_flag=True, help="to force overwrite of target", show_default=True -) # , help='dataset slice' -@click.option( - "--project", "-p", type=int, default=None, help="max projection over given axis (0->T, 1->Z, 2->Y, 3->X)" -) # , help='dataset slice' + "--project", "-p", type=int, default=None, help="max projection over given an spatial axis (0->Z, 1->Y, 2->X)" +) @click.option( "--split", is_flag=True, help="Splits dataset along first dimension, be carefull, if you slice to a single time point this will split along z!", -) # , help='dataset slice' +) @click.option( "--clevel", "-l", type=int, default=0, help="Compression level, 0 means no compression, max is 9", show_default=True -) # , help='dataset slice' -@click.option( - "--workers", - "-k", - default=-4, - help="Number of worker threads to spawn. Negative numbers n correspond to: number_of _cores / |n|", - show_default=True, -) # +) @click.option( "--workersbackend", "-wkb", @@ -46,22 +42,29 @@ help="What backend to spawn workers with, can be ‘loky’ (multi-process) or ‘threading’ (multi-thread) ", show_default=True, ) -def tiff(input_paths, output_path, channels, slicing, overwrite, project, split, clevel, workers, workersbackend): +def tiff( + input_dataset: BaseDataset, + output_path: Optional[Path], + channels: Sequence[str], + overwrite: bool, + project: bool, + split: bool, + clevel: int, + workers: int, + workersbackend: str, +): """Exports dataset as TIFF file(s).""" - input_dataset, input_paths = glob_datasets(input_paths) - output_path = _get_output_path(input_paths[0], output_path) - slicing = _parse_slicing(slicing) - channels = _parse_channels(input_dataset, channels) + if output_path is None: + output_path = Path(_get_output_path(input_dataset.path, None)) with asection( - f"Exporting to TIFF datset: {input_paths}, channels: {channels}, slice: {slicing}, project:{project}, split:{split}" + f"Exporting to TIFF datset: {input_dataset.path}, channels: {channels}, project:{project}, split:{split}" ): dataset_tiff( input_dataset, output_path, channels=channels, - slicing=slicing, overwrite=overwrite, project=project, one_file_per_first_dim=split, diff --git a/dexp/cli/dexp_commands/view.py b/dexp/cli/dexp_commands/view.py index c6cc5578..7a352a42 100644 --- a/dexp/cli/dexp_commands/view.py +++ b/dexp/cli/dexp_commands/view.py @@ -1,28 +1,23 @@ +from typing import Sequence, Tuple + import click from arbol.arbol import aprint, asection -from dexp.cli.parsing import _parse_channels, _parse_slicing -from dexp.datasets.open_dataset import glob_datasets +from dexp.cli.parsing import ( + channels_option, + empty_channels_callback, + input_dataset_argument, +) +from dexp.datasets.base_dataset import BaseDataset +from dexp.datasets.joined_dataset import JoinedDataset from dexp.datasets.operations.view import dataset_view -from dexp.datasets.operations.view_remote import dataset_view_remote @click.command() -@click.argument("input_paths", nargs=-1) -@click.option("--channels", "-c", default=None, help="list of channels, all channels when ommited.") +@input_dataset_argument() +@channels_option() @click.option( - "--slicing", - "-s", - default=None, - help="dataset slice 
(TZYX), e.g. [0:5] (first five stacks) [:,0:100] (cropping in z).", -) -@click.option("--aspect", "-a", type=float, default=4, help="sets aspect ratio e.g. 4", show_default=True) -@click.option( - "--rescale-time", - "-rt", - is_flag=True, - help="Rescale time axis, useful when channels acquisition rate was not uniform.", - show_default=True, + "--downscale", "-d", type=int, default=1, help="Downscale value to speedup visualization", show_default=True ) @click.option( "--clim", @@ -31,19 +26,20 @@ default="0,512", help="Sets the contrast limits, i.e. -cl 0,1000 sets the contrast limits to [0,1000]", show_default=True, + callback=lambda x, y, clim: tuple(float(v.strip()) for v in clim.split(",")), ) @click.option( "--colormap", "-cm", type=str, - default="viridis", + default="magma", help="sets colormap, e.g. viridis, gray, magma, plasma, inferno ", show_default=True, ) @click.option( "--windowsize", "-ws", - type=int, + type=click.IntRange(min=1), default=1536, help="Sets the napari window size. i.e. -ws 400 sets the window to 400x400", show_default=True, @@ -52,47 +48,47 @@ "--projectionsonly", "-po", is_flag=True, help="To view only the projections, if present.", show_default=True ) @click.option("--volumeonly", "-vo", is_flag=True, help="To view only the volumetric data.", show_default=True) -def view(input_paths, channels, slicing, aspect, rescale_time, clim, colormap, windowsize, projectionsonly, volumeonly): +@click.option("--quiet", "-q", is_flag=True, default=False, help="Quiet mode. Doesn't display GUI.") +@click.option( + "--labels", + "-l", + type=str, + default=None, + help="Channels to be displayed as labels layer", + callback=empty_channels_callback, +) +def view( + input_dataset: BaseDataset, + channels: Sequence[str], + clim: Tuple[float], + downscale: int, + colormap: str, + windowsize: int, + projectionsonly: bool, + volumeonly: bool, + quiet: bool, + labels: Sequence[str], +): """Views dataset using napari (napari.org)""" - slicing = _parse_slicing(slicing) - - name = input_paths[0] + "..." if len(input_paths) > 1 else "" - - contrast_limits = tuple(float(v.strip()) for v in clim.split(",")) - - if len(input_paths) == 1 and "http" in input_paths[0]: - - with asection(f"Viewing dataset at: {input_paths}, channels: {channels}, slicing: {slicing}, aspect:{aspect} "): - dataset_view_remote( - input_path=input_paths[0], - name=name, - aspect=aspect, - channels=channels, - contrast_limits=contrast_limits, - colormap=colormap, - slicing=slicing, - windowsize=windowsize, - ) - aprint("Done!") + name = input_dataset.path + if isinstance(input_dataset, JoinedDataset): + name = name + " ..." 
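# Illustrative sketch (not part of the patch): click hands every callback the triple
# (ctx, param, value), which is why the `--clim` lambda above takes three arguments and
# why `tuple_callback` in parsing.py is curried over the extra `dtype`/`length` settings.
# The `demo` command below is hypothetical and only shows the equivalent explicit form.
import click


def parse_clim(ctx: click.Context, param: click.Parameter, value: str):
    # "0,512" -> (0.0, 512.0); same behaviour as the inline lambda used by `dexp view`.
    return tuple(float(v.strip()) for v in value.split(","))


@click.command()
@click.option("--clim", "-cl", default="0,512", callback=parse_clim, show_default=True)
def demo(clim):
    click.echo(f"contrast limits: {clim}")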
- else: - input_dataset, input_paths = glob_datasets(input_paths) - channels = _parse_channels(input_dataset, channels) + with asection(f"Viewing dataset at: {input_dataset.path}, channels: {channels}, downscale: {downscale}"): + dataset_view( + input_dataset=input_dataset, + name=name, + channels=channels, + contrast_limits=clim, + colormap=colormap, + scale=downscale, + windowsize=windowsize, + projections_only=projectionsonly, + volume_only=volumeonly, + quiet=quiet, + labels=labels, + ) - with asection(f"Viewing dataset at: {input_paths}, channels: {channels}, slicing: {slicing}, aspect:{aspect} "): - dataset_view( - input_dataset=input_dataset, - name=name, - aspect=aspect, - channels=channels, - contrast_limits=contrast_limits, - colormap=colormap, - slicing=slicing, - windowsize=windowsize, - projections_only=projectionsonly, - volume_only=volumeonly, - rescale_time=rescale_time, - ) - input_dataset.close() - aprint("Done!") + input_dataset.close() + aprint("Done!") diff --git a/dexp/cli/dexp_main.py b/dexp/cli/dexp_main.py index 121a1ca1..11ef9f7d 100644 --- a/dexp/cli/dexp_main.py +++ b/dexp/cli/dexp_main.py @@ -2,6 +2,7 @@ from arbol.arbol import aprint, asection from dexp.cli.dexp_commands.add import add +from dexp.cli.dexp_commands.background import background from dexp.cli.dexp_commands.check import check from dexp.cli.dexp_commands.copy import copy from dexp.cli.dexp_commands.crop import crop @@ -12,13 +13,12 @@ from dexp.cli.dexp_commands.fastcopy import fastcopy from dexp.cli.dexp_commands.fromraw import fromraw from dexp.cli.dexp_commands.fuse import fuse +from dexp.cli.dexp_commands.generic import generic from dexp.cli.dexp_commands.histogram import histogram from dexp.cli.dexp_commands.info import info -from dexp.cli.dexp_commands.isonet import isonet from dexp.cli.dexp_commands.projrender import projrender from dexp.cli.dexp_commands.register import register from dexp.cli.dexp_commands.segment import segment -from dexp.cli.dexp_commands.serve import serve from dexp.cli.dexp_commands.speedtest import speedtest from dexp.cli.dexp_commands.stabilize import stabilize from dexp.cli.dexp_commands.tiff import tiff @@ -59,29 +59,26 @@ def cli(): aprint("'cupy' module not found! 
ignored!") -cli.add_command(info) +cli.add_command(add) +cli.add_command(background) cli.add_command(check) cli.add_command(copy) cli.add_command(crop) -cli.add_command(fastcopy) -cli.add_command(add) -cli.add_command(tiff) -cli.add_command(view) -cli.add_command(serve) +cli.add_command(deconv) +cli.add_command(denoise) +cli.add_command(deskew) cli.add_command(extract_psf) +cli.add_command(fastcopy) cli.add_command(fromraw) - -cli.add_command(deskew) cli.add_command(fuse) -cli.add_command(register) -cli.add_command(stabilize) -cli.add_command(deconv) -cli.add_command(denoise) -cli.add_command(isonet) +cli.add_command(generic) cli.add_command(histogram) +cli.add_command(info) +cli.add_command(projrender) +cli.add_command(register) cli.add_command(segment) - cli.add_command(speedtest) - +cli.add_command(stabilize) +cli.add_command(tiff) +cli.add_command(view) cli.add_command(volrender) -cli.add_command(projrender) diff --git a/dexp/cli/parsing.py b/dexp/cli/parsing.py index b70ba77b..2e5c66fb 100644 --- a/dexp/cli/parsing.py +++ b/dexp/cli/parsing.py @@ -1,7 +1,17 @@ -from typing import Optional, Sequence, Tuple, Union +import ast +import inspect +import warnings +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +import click from arbol import aprint from numpy import s_ +from toolz import curry + +from dexp.cli.defaults import DEFAULT_CLEVEL, DEFAULT_CODEC, DEFAULT_STORE +from dexp.datasets import ZDataset +from dexp.datasets.open_dataset import glob_datasets +from dexp.utils import overwrite2mode def _get_output_path(input_path, output_path, postfix=""): @@ -47,7 +57,7 @@ def _parse_channels(input_dataset, channels): aprint(f"Selected channel(s) : '{channels}'") return channels - + def parse_devices(devices: str) -> Union[str, Sequence[int]]: aprint(f"Requested devices : '{'--All--' if 'all' in devices else devices}' ") @@ -67,5 +77,353 @@ def parse_devices(devices: str) -> Union[str, Sequence[int]]: def _parse_chunks(chunks: Optional[str]) -> Optional[Tuple[int]]: if chunks is not None: - chunks = eval(chunks) + chunks = (int(c) for c in chunks.split(",")) return chunks + + +def devices_callback(ctx: click.Context, opt: click.Option, value: str) -> Sequence[int]: + return parse_devices(value) + + +def multi_devices_option() -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--devices", + "-d", + type=str, + default="all", + help="Sets the CUDA devices id, e.g. 0,1,2 or ‘all’", + show_default=True, + callback=devices_callback, + )(f) + + return decorator + + +def device_option() -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--device", + "-d", + type=int, + default=0, + show_default=True, + help="Selects the CUDA device by id, starting from 0.", + )(f) + + return decorator + + +def workers_option() -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--workers", + "-wk", + default=-4, + help="Number of worker threads to spawn. Negative numbers n correspond to: number_of _cores / |n| ", + show_default=True, + )(f) + + return decorator + + +def slicing_callback(ctx: click.Context, opt: click.Option, value: str) -> None: + slicing = _parse_slicing(value) + if slicing is not None: + ctx.params["input_dataset"].set_slicing(slicing) + opt.expose_value = False + + +def slicing_option() -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--slicing", + "-s", + default=None, + help="dataset slice (TZYX), e.g. 
[0:5] (first five stacks) [:,0:100] (cropping in z) ", + callback=slicing_callback, + )(f) + + return decorator + + +def channels_callback(ctx: click.Context, opt: click.Option, value: Optional[str]) -> Sequence[str]: + return _parse_channels(ctx.params["input_dataset"], value) + + +def channels_option() -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--channels", + "-c", + default=None, + help="list of channels, all channels when ommited.", + callback=channels_callback, + )(f) + + return decorator + + +def optional_channels_callback(ctx: click.Context, opt: click.Option, value: Optional[str]) -> Sequence[str]: + """Parses optional channels name, returns `input_dataset` channels if nothing is provided.""" + return ctx.params["input_dataset"].channels() if value is None else value.split(",") + + +def empty_channels_callback(ctx: click.Context, opt: click.Option, value: Optional[str]) -> Sequence[str]: + """Returns empty list if no channels are provided.""" + if value is None: + return [] + return _parse_channels(ctx.params["input_dataset"], value) + + +def input_dataset_callback(ctx: click.Context, arg: click.Argument, value: Sequence[str]) -> None: + try: + ctx.params["input_dataset"], _ = glob_datasets(value) + arg.expose_value = False + + except ValueError as e: + warnings.warn(str(e)) + + +def input_dataset_argument() -> Callable: + def decorator(f: Callable) -> Callable: + return click.argument("input-paths", nargs=-1, callback=input_dataset_callback, is_eager=True)(f) + + return decorator + + +def output_dataset_callback(ctx: click.Context, opt: click.Option, value: Optional[str]) -> None: + mode = overwrite2mode(ctx.params.pop("overwrite")) + if value is None: + # new name with suffix if value is None + value = _get_output_path(ctx.params["input_dataset"].path, None, "." + ctx.command.name) + + ctx.params["output_dataset"] = ZDataset( + value, + mode=mode, + store=ctx.params.pop("store"), + codec=ctx.params.pop("codec"), + clevel=ctx.params.pop("clevel"), + chunks=_parse_chunks(ctx.params.pop("chunks")), + parent=ctx.params["input_dataset"], + ) + # removing used parameters + opt.expose_value = False + + +def output_dataset_options() -> Callable: + click_options = [ + click.option( + "--output-path", "-o", default=None, help="Dataset output path.", callback=output_dataset_callback + ), + overwrite_option(), + click.option( + "--store", + "-st", + default=DEFAULT_STORE, + help="Zarr store: ‘dir’, ‘ndir’, or ‘zip’", + show_default=True, + is_eager=True, + ), + click.option( + "--chunks", "-chk", default=None, help="Dataset chunks dimensions, e.g. 
(1, 126, 512, 512).", is_eager=True + ), + click.option( + "--codec", + "-z", + default=DEFAULT_CODEC, + help="Compression codec: ‘zstd‘, ‘blosclz’, ‘lz4’, ‘lz4hc’, ‘zlib’ or ‘snappy’", + show_default=True, + is_eager=True, + ), + click.option( + "--clevel", + "-l", + type=int, + default=DEFAULT_CLEVEL, + help="Compression level", + show_default=True, + is_eager=True, + ), + ] + + def decorator(f: Callable) -> Callable: + for opt in click_options: + f = opt(f) + return f + + return decorator + + +def overwrite_option() -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--overwrite", + "-w", + is_flag=True, + help="Forces overwrite of output", + show_default=True, + default=False, + is_eager=True, + )(f) + + return decorator + + +def tilesize_option(default: Optional[int] = 320) -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--tilesize", "-ts", type=int, default=default, help="Tile size for tiled computation", show_default=True + )(f) + + return decorator + + +def args_option() -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--args", + "-a", + type=str, + required=False, + multiple=True, + default=list(), + help="Function arguments, it must be used multiple times for multiple arguments. Example: -a sigma=1,2,3 -a pad=constant", + )(f) + + return decorator + + +def verbose_option() -> Callable: + def decorator(f: Callable) -> Callable: + return click.option( + "--verbose", "-v", type=bool, is_flag=True, default=False, help="Flag to display intermediated results." + )(f) + + return decorator + + +class KwargsCommand(click.Command): + def format_epilog(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """Writes the epilog into the formatter if it exists.""" + if self.epilog: + formatter.write_paragraph() + with formatter.indentation(): + formatter.write(self.epilog) + + +def parse_args_to_kwargs(args: Sequence[str], sep: str = "=") -> Dict[str, Any]: + """Parses list of strings with keys and values to a dictionary, + given a separator to split the key and values. + + Parameters + ---------- + args : Sequence[str] + List of strings of key/values, for example ["sigma=1,2,3", "pad=constant"] + sep : str, optional + key-value separator, by default "=" + + Returns + ------- + Dict[str, Any] + Dictionary of key word arguments. + """ + args = [arg.split(sep) for arg in args] + assert all(len(arg) == 2 for arg in args) + + kwargs = dict(args) + kwargs = {k: ast.literal_eval(v) for k, v in kwargs.items()} + + aprint(f"Using key word arguments {kwargs}") + + return kwargs + + +def validate_function_kwargs(func: Callable, kwargs: Dict[str, Any]) -> None: + """Checks if key-word arguments are valid arguments to the provided function. + + Parameters + ---------- + func : Callable + Reference function. + kwargs : Dict[str, Any] + Key-word arguments as dict. + + Raises + ------ + ValueError + If any key isn't an argument of the given function. 
+ """ + sig = inspect.signature(func) + for key in kwargs: + if key not in sig.parameters: + raise ValueError(f"Function {func.__name__} doesn't support keyword {key}") + + +def func_args_to_str(func: Callable, ignore: Sequence[str] = []) -> str: + """Parses functions arguments to string format for click.Command.epilog + + Usage example: + @click.command( + cls=KwargsCommand, + epilog=func_args_to_str(, ), + ) + def func(args: Sequence[str]): + pass + + Parameters + ---------- + func : Callable + Function to extract arguments and default values + ignore : Sequence[str], optional + Arguments to ignore, by default [] + + Returns + ------- + str + String for click.command(epilog) + """ + sig = inspect.signature(func) + text = "Known key-word arguments:\n" + for k, v in sig.parameters.items(): + if k not in ignore: + text += f" -a, --args {k}={v.default}\n" + + return text + + +@curry +def tuple_callback( + ctx: click.Context, + opt: click.Option, + value: str, + dtype: Callable = int, + length: Optional[int] = None, +) -> Optional[Tuple[Any]]: + """Parses string to tuple given dtype and optional length. + Returns None if None is supplied. + + Parameters + ---------- + ctx : click.Context + CLI context, not used. + opt : click.Option + CLI option, not used. + value : str + Input value. + dtype : Callable, optional + Data type for type casting, by default int + length : Optional[int], optional + Optional length for length checking, by default None + + Returns + ------- + Tuple[Any] + Tuple of given dtype and length (optional). + """ + if value is None: + return None + tup = tuple(dtype(s) for s in value.split(",")) + if length is not None and length != len(tup): + raise ValueError(f"Expected tuple of length {length}, got input {tup}") + return tup diff --git a/dexp/conftest.py b/dexp/conftest.py index 525b079b..465e3870 100644 --- a/dexp/conftest.py +++ b/dexp/conftest.py @@ -2,6 +2,7 @@ from dexp.datasets.synthetic_datasets import ( binary_blobs, + generate_dexp_zarr_dataset, generate_fusion_test_data, generate_nuclei_background_data, ) @@ -22,6 +23,13 @@ def dexp_fusion_test_data(request): return generate_fusion_test_data(**request.param) +@pytest.fixture(scope="session") +def dexp_zarr_path(request, tmp_path_factory) -> str: + ds_path = tmp_path_factory.mktemp("data") / "input.zarr" + generate_dexp_zarr_dataset(ds_path, **request.param) + return str(ds_path) + + def pytest_addoption(parser): """pytest command line parser""" parser.addoption("--display", action="store", type=bool, default=False) diff --git a/dexp/datasets/base_dataset.py b/dexp/datasets/base_dataset.py index 3ed0c6a7..67a34149 100644 --- a/dexp/datasets/base_dataset.py +++ b/dexp/datasets/base_dataset.py @@ -1,5 +1,7 @@ +import warnings from abc import ABC, abstractmethod -from typing import Any, Sequence, Tuple +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy @@ -7,11 +9,16 @@ class BaseDataset(ABC): - def __init__(self, dask_backed=False): + def __init__(self, dask_backed=False, path: Union[Path, str] = ""): """Instanciates a Base Dataset""" self.dask_backed = dask_backed self._slicing = None + if not isinstance(path, str): + path = str(path) + + self._path = path + def _selected_channels(self, channels: Sequence[str]): if channels is None: selected_channels = self.channels() @@ -107,5 +114,17 @@ def check_integrity(self, channels: Sequence[str]) -> bool: def set_slicing(self, slicing: slice) -> None: self._slicing = slicing + @property + def slicing(self) -> slice: + 
return self._slicing + def __getitem__(self, channel: str) -> StackIterator: return StackIterator(self.get_array(channel), self._slicing) + + @property + def path(self) -> str: + return self._path + + def get_resolution(self, channel: Optional[str] = None) -> List[float]: + warnings.warn("`get_resolution` not implemented, returning list of 1.0 as default.") + return [1] * len(self.shape(channel)) diff --git a/dexp/datasets/clearcontrol_dataset.py b/dexp/datasets/clearcontrol_dataset.py index 8f283d8a..e55602c2 100755 --- a/dexp/datasets/clearcontrol_dataset.py +++ b/dexp/datasets/clearcontrol_dataset.py @@ -18,11 +18,10 @@ class CCDataset(BaseDataset): def __init__(self, path, cache_size=8e9): - super().__init__(dask_backed=False) + super().__init__(dask_backed=False, path=path) config_blosc() - self.folder = path self._channels = [] self._index_files = {} @@ -90,8 +89,8 @@ def _parse_channel(self, channel): def _get_stack_file_name(self, channel, time_point): - compressed_file_name = join(self.folder, "stacks", channel, str(time_point).zfill(6) + ".blc") - raw_file_name = join(self.folder, "stacks", channel, str(time_point).zfill(6) + ".raw") + compressed_file_name = join(self._path, "stacks", channel, str(time_point).zfill(6) + ".blc") + raw_file_name = join(self._path, "stacks", channel, str(time_point).zfill(6) + ".raw") if self._is_compressed is None: if exists(compressed_file_name): @@ -253,4 +252,4 @@ def _find_out_if_compressed(self, path): return False def time_sec(self, channel: str) -> np.ndarray: - return np.asarray(self._times_sec[channel]) + return np.asarray(self._times_sec[channel]) diff --git a/dexp/datasets/joined_dataset.py b/dexp/datasets/joined_dataset.py index 4bbd11b7..59bb9c27 100755 --- a/dexp/datasets/joined_dataset.py +++ b/dexp/datasets/joined_dataset.py @@ -147,3 +147,7 @@ def write_stack(self, channel: str, time_point: int, stack_array: numpy.ndarray) def write_array(self, channel: str, array: numpy.ndarray): raise NotImplementedError("Cannot write to a joined dataset!") + + @property + def path(self) -> str: + return ",".join([ds.path for ds in self._dataset_list]) diff --git a/dexp/datasets/operations/_test/test_deconv.py b/dexp/datasets/operations/_test/test_deconv.py index c5771997..d56f4bd6 100644 --- a/dexp/datasets/operations/_test/test_deconv.py +++ b/dexp/datasets/operations/_test/test_deconv.py @@ -3,11 +3,6 @@ from dexp.datasets.operations.demo.demo_deconv import _demo_deconv from dexp.utils.backends import CupyBackend -# deconv CLI only available for cuda -# def test_deconv_numpy(): -# with NumpyBackend(): -# _demo_deconv(display=False) - def test_deconv_cupy(): try: diff --git a/dexp/datasets/operations/_test/test_denoise.py b/dexp/datasets/operations/_test/test_denoise.py index 98380729..30919a76 100644 --- a/dexp/datasets/operations/_test/test_denoise.py +++ b/dexp/datasets/operations/_test/test_denoise.py @@ -6,7 +6,7 @@ from dexp.cli.parsing import parse_devices from dexp.datasets import ZDataset from dexp.datasets.operations.denoise import dataset_denoise -from dexp.utils.backends.cupy_backend import is_cupy_available +from dexp.utils.testing.testing import cupy_only @pytest.mark.parametrize( @@ -14,11 +14,8 @@ [dict(add_noise=True)], indirect=True, ) +@cupy_only def test_dataset_denoise(dexp_nuclei_background_data, tmp_path: Path, display_test: bool): - - if not is_cupy_available(): - pytest.skip(f"Cupy not found. 
Skipping {test_dataset_denoise.__name__} gpu test.") - # Load _, _, image = dexp_nuclei_background_data @@ -54,6 +51,8 @@ def test_dataset_denoise(dexp_nuclei_background_data, tmp_path: Path, display_te if __name__ == "__main__": + from dexp.utils.testing import test_as_demo + # the same as executing from the CLI # pytest -s --display True - pytest.main([__file__, "-s", "--display", "True"]) + test_as_demo(__file__) diff --git a/dexp/datasets/operations/_test/test_deskew.py b/dexp/datasets/operations/_test/test_deskew.py index 84940807..fe95424a 100644 --- a/dexp/datasets/operations/_test/test_deskew.py +++ b/dexp/datasets/operations/_test/test_deskew.py @@ -1,12 +1,7 @@ from arbol import aprint from dexp.datasets.operations.demo.demo_deskew import _demo_deskew -from dexp.utils.backends import CupyBackend, NumpyBackend - - -def test_deskew_numpy(): - with NumpyBackend(): - _demo_deskew(display=False) +from dexp.utils.backends import CupyBackend def test_deskew_cupy(): diff --git a/dexp/datasets/operations/_test/test_segment.py b/dexp/datasets/operations/_test/test_segment.py index a7f60a75..65adf009 100644 --- a/dexp/datasets/operations/_test/test_segment.py +++ b/dexp/datasets/operations/_test/test_segment.py @@ -8,7 +8,7 @@ from dexp.datasets import ZDataset from dexp.datasets.operations.segment import dataset_segment from dexp.utils.backends.backend import Backend -from dexp.utils.backends.cupy_backend import is_cupy_available +from dexp.utils.testing import cupy_only @pytest.mark.parametrize( @@ -16,11 +16,8 @@ [dict(add_noise=True, length_z_factor=4)], indirect=True, ) +@cupy_only def test_dataset_segment(dexp_nuclei_background_data, tmp_path: Path, display_test: bool): - - if not is_cupy_available(): - pytest.skip(f"Cupy not found. Skipping {test_dataset_segment.__name__} gpu test.") - # Load _, _, image = dexp_nuclei_background_data diff --git a/dexp/datasets/operations/background.py b/dexp/datasets/operations/background.py new file mode 100644 index 00000000..b6e9f47a --- /dev/null +++ b/dexp/datasets/operations/background.py @@ -0,0 +1,91 @@ +from typing import Callable, Dict, Optional, Sequence + +import dask +import numpy as np +from arbol import aprint, asection, section +from dask.distributed import Client +from dask_cuda import LocalCUDACluster +from toolz import curry, reduce + +from dexp.datasets import BaseDataset, ZDataset +from dexp.datasets.stack_iterator import StackIterator +from dexp.processing.crop.background import foreground_mask +from dexp.utils.backends.cupy_backend import CupyBackend + + +@dask.delayed +@curry +def _process( + time_point: int, + time_scale: Dict[str, int], + stacks: Dict[str, StackIterator], + out_dataset: ZDataset, + reference_channel: Optional[str], + merge_channels: bool, + foreground_mask_func: Callable, +) -> None: + + foreground_mask_func = section("Detecting foreground.")(foreground_mask_func) + + with CupyBackend() as bkd: + + # selects stacks while scaling time points + time_pts = {ch: int(round(time_point / s)) for ch, s in time_scale.items()} + stacks = {ch: bkd.to_backend(stacks[ch][time_pts[ch]]) for ch in stacks.keys()} + + with asection(f"Removing background of channels {time_pts.keys()} at time points {time_pts.values()}."): + # computes reference mask (or not) + if merge_channels: + reference = foreground_mask_func(reduce(np.add, stacks.values())) + elif reference_channel is not None: + reference = foreground_mask_func(stacks[reference_channel]) + else: + reference = None + + # removes background + for ch in stacks.keys(): + 
stack = stacks[ch] + mask = foreground_mask_func(stack) if reference is None else reference + stack[np.logical_not(mask)] = 0 + out_dataset.write_stack(ch, time_pts[ch], bkd.to_numpy(stack)) + + +def dataset_background( + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + reference_channel: Optional[str], + merge_channels: bool, + devices: Sequence[int], + **kwargs, +) -> None: + lazy_computations = [] + + if reference_channel is not None and merge_channels: + raise ValueError("`reference_channel` cannot be supplied with `merge_channels` option.") + + arrays = {ch: input_dataset[ch] for ch in channels} + max_t = max(len(arr) for arr in arrays.values()) + time_scale = {ch: input_dataset.get_resolution(ch)[0] for ch in channels} + process = _process( + time_scale=time_scale, + stacks=arrays, + out_dataset=output_dataset, + reference_channel=reference_channel, + merge_channels=merge_channels, + foreground_mask_func=curry(foreground_mask, **kwargs), + ) + + for ch in channels: + output_dataset.add_channel(ch, input_dataset.shape(ch), input_dataset.dtype(ch)) + + lazy_computations = [process(time_point=t) for t in range(max_t)] + + cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) + client = Client(cluster) + aprint("Dask client", client) + + # Compute everything + dask.compute(*lazy_computations) + + output_dataset.check_integrity() diff --git a/dexp/datasets/operations/copy.py b/dexp/datasets/operations/copy.py index 7552d261..84bd9a56 100644 --- a/dexp/datasets/operations/copy.py +++ b/dexp/datasets/operations/copy.py @@ -1,84 +1,77 @@ from typing import Optional, Sequence -import numpy +import numpy as np from arbol.arbol import aprint, asection from joblib import Parallel, delayed +from toolz import curry -from dexp.datasets import BaseDataset +from dexp.datasets import BaseDataset, ZDataset +from dexp.datasets.clearcontrol_dataset import CCDataset +from dexp.datasets.stack_iterator import StackIterator from dexp.utils.misc import compute_num_workers -from dexp.utils.slicing import slice_from_shape + + +@curry +def _process( + i: int, + in_array: StackIterator, + output_dataset: ZDataset, + channel: str, + zerolevel: int, + stop_at_exception: bool = True, +) -> None: + try: + aprint(f"Processing time point: {i} ...") + tp_array = np.asarray(in_array[i]) + if zerolevel != 0: + tp_array = np.clip(tp_array, a_min=zerolevel, a_max=None, out=tp_array) + tp_array -= zerolevel + output_dataset.write_stack(channel=channel, time_point=i, stack_array=tp_array) + + except Exception as error: + aprint(error) + aprint(f"Error occurred while copying time point {i} !") + import traceback + + traceback.print_exc() + if stop_at_exception: + raise error def dataset_copy( - dataset: BaseDataset, - dest_path: str, + input_dataset: BaseDataset, + output_dataset: ZDataset, channels: Sequence[str], - slicing, - store: str = "dir", - chunks: Optional[Sequence[int]] = None, - compression: str = "zstd", - compression_level: int = 3, - overwrite: bool = False, zerolevel: int = 0, workers: int = 1, workersbackend: Optional[str] = None, - check: bool = True, - stop_at_exception: bool = True, ): - - # Create destination dataset: - from dexp.datasets import ZDataset - - mode = "w" + ("" if overwrite else "-") - dest_dataset = ZDataset(dest_path, mode, store, parent=dataset) - # Process each channel: - for channel in dataset._selected_channels(channels): + for channel in channels: + array = input_dataset[channel] + process = _process(in_array=array, output_dataset=output_dataset, 
channel=channel, zerolevel=zerolevel) with asection(f"Copying channel {channel}:"): - array = dataset.get_array(channel) - - aprint(f"Slicing with: {slicing}") - out_shape, volume_slicing, time_points = slice_from_shape(array.shape, slicing) - - dtype = array.dtype - dest_dataset.add_channel( - name=channel, shape=out_shape, dtype=dtype, chunks=chunks, codec=compression, clevel=compression_level - ) - def process(i): - tp = time_points[i] - try: - aprint(f"Processing time point: {i} ...") - tp_array = array[tp][volume_slicing] - if zerolevel != 0: - tp_array = numpy.array(tp_array) - tp_array = numpy.clip(tp_array, a_min=zerolevel, a_max=None, out=tp_array) - tp_array -= zerolevel - dest_dataset.write_stack(channel=channel, time_point=i, stack_array=tp_array) - except Exception as error: - aprint(error) - aprint(f"Error occurred while copying time point {i} !") - import traceback - - traceback.print_exc() - if stop_at_exception: - raise error + aprint(f"Slicing with: {array.slicing}") + output_dataset.add_channel(name=channel, shape=array.shape, dtype=array.dtype) if workers == 1: - for i in range(len(time_points)): + for i in range(len(array)): process(i) else: - n_jobs = compute_num_workers(workers, len(time_points)) - + n_jobs = compute_num_workers(workers, len(array)) parallel = Parallel(n_jobs=n_jobs, backend=workersbackend) - parallel(delayed(process)(i) for i in range(len(time_points))) + parallel(delayed(process)(i) for i in range(len(array))) + + if isinstance(input_dataset, CCDataset): + time = input_dataset.time_sec(channel).tolist() + output_dataset.append_metadata({channel: dict(time=time)}) # Dataset info: - aprint(dest_dataset.info()) + aprint(output_dataset.info()) # Check dataset integrity: - if check: - dest_dataset.check_integrity() + output_dataset.check_integrity() # close destination dataset: - dest_dataset.close() + output_dataset.close() diff --git a/dexp/datasets/operations/crop.py b/dexp/datasets/operations/crop.py index 9468f00e..5a692384 100644 --- a/dexp/datasets/operations/crop.py +++ b/dexp/datasets/operations/crop.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Tuple +from typing import Sequence, Tuple import numpy as np import zarr @@ -6,22 +6,25 @@ from joblib import Parallel, delayed from numpy.typing import ArrayLike from scipy import ndimage as ndi +from toolz import curry from dexp.datasets import BaseDataset +from dexp.datasets.stack_iterator import StackIterator +from dexp.datasets.zarr_dataset import ZDataset from dexp.processing.filters.fft_convolve import fft_convolve from dexp.utils.backends import Backend, BestBackend from dexp.utils.misc import compute_num_workers -def _estimate_crop(array: ArrayLike, quantile: float = 0.99) -> Sequence[Tuple[int]]: - window_size = 31 +def _estimate_crop(array: ArrayLike, quantile: float) -> Sequence[Tuple[int]]: + window_size = 15 step = 4 shape = array.shape with BestBackend(): xp = Backend.get_xp_module() array = Backend.to_backend(array[::step, ::step, ::step], dtype=xp.float16) array = xp.clip(array - xp.mean(array), 0, None) # removing background noise - kernel = xp.ones((window_size, window_size, window_size)) / (window_size ** 3) + kernel = xp.ones((window_size, window_size, window_size)) / (window_size**3) kernel = kernel.astype(xp.float16) array = fft_convolve(array, kernel, in_place=True) lower = xp.quantile(array, quantile) @@ -55,82 +58,58 @@ def compute_crop_slicing(array: zarr.Array, time_points: Sequence[int], quantile return tuple(slice(int(l), int(u)) for l, u in zip(lower, upper)) 
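# Illustrative sketch (not part of the patch): copy, crop and background now share the same
# shape -- a `toolz.curry`-wrapped per-time-point worker is bound once with everything except
# the frame index, then mapped over the stack serially or with joblib. `scale_frame` and
# `run` below are hypothetical names; only `curry`, `Parallel` and `delayed` are real APIs.
import numpy as np
from joblib import Parallel, delayed
from toolz import curry


@curry
def scale_frame(i: int, stack: np.ndarray, scale: float) -> np.ndarray:
    # `stack` and `scale` are pre-bound; only the time index `i` varies per task.
    return stack[i] * scale


def run(stack: np.ndarray, workers: int = 1):
    process = scale_frame(stack=stack, scale=2.0)  # curried: still waiting for `i`
    if workers == 1:
        return [process(i) for i in range(len(stack))]
    return Parallel(n_jobs=workers)(delayed(process)(i) for i in range(len(stack)))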
+@curry +def _process(i: int, array: StackIterator, output_dataset: ZDataset, channel: str): + try: + aprint(f"Processing time point: {i} ...") + output_dataset.write_stack(channel=channel, time_point=i, stack_array=np.asarray(array[i])) + except Exception as error: + aprint(error) + aprint(f"Error occurred while copying time point {i} !") + import traceback + + traceback.print_exc() + raise error + + def dataset_crop( - dataset: BaseDataset, - dest_path: str, + input_dataset: BaseDataset, + output_dataset: ZDataset, reference_channel: str, channels: Sequence[str], quantile: float, - store: str = "dir", - chunks: Optional[Sequence[int]] = None, - compression: str = "zstd", - compression_level: int = 3, - overwrite: bool = False, - workers: int = 1, - check: bool = True, - stop_at_exception: bool = True, + workers: int, ): - - # Create destination dataset: - from dexp.datasets import ZDataset - - mode = "w" + ("" if overwrite else "-") - dest_dataset = ZDataset(dest_path, mode, store, parent=dataset) - with asection("Estimating region of interest"): - nb_time_pts = dataset.nb_timepoints(reference_channel) + nb_time_pts = input_dataset.nb_timepoints(reference_channel) slicing = compute_crop_slicing( - dataset.get_array(reference_channel), [0, nb_time_pts // 2, nb_time_pts - 1], quantile + input_dataset.get_array(reference_channel), [0, nb_time_pts // 2, nb_time_pts - 1], quantile ) aprint("Estimated slicing of", slicing) - volume_shape = tuple(s.stop - s.start for s in slicing) translation = {k: s.start for k, s in zip(("tz", "ty", "tx"), slicing)} - dest_dataset.append_metadata(translation) + output_dataset.append_metadata(translation) + input_dataset.set_slicing(slicing=(slice(None, None),) + slicing) # Process each channel: - for channel in dataset._selected_channels(channels): + for channel in input_dataset._selected_channels(channels): with asection(f"Cropping channel {channel}:"): - array = dataset.get_array(channel) - - dtype = array.dtype - dest_dataset.add_channel( - name=channel, - shape=(len(array),) + volume_shape, - dtype=dtype, - chunks=chunks, - codec=compression, - clevel=compression_level, - ) - - def process(tp): - try: - aprint(f"Processing time point: {tp} ...") - tp_array = array[tp][slicing] - dest_dataset.write_stack(channel=channel, time_point=tp, stack_array=tp_array) - except Exception as error: - aprint(error) - aprint(f"Error occurred while copying time point {tp} !") - import traceback - - traceback.print_exc() - if stop_at_exception: - raise error + array = input_dataset[channel] + output_dataset.add_channel(name=channel, shape=array.shape, dtype=array.dtype) + process = _process(array=array, output_dataset=output_dataset, channel=channel) if workers == 1: for i in range(len(array)): process(i) else: n_jobs = compute_num_workers(workers, len(array)) - parallel = Parallel(n_jobs=n_jobs) parallel(delayed(process)(i) for i in range(len(array))) # Dataset info: - aprint(dest_dataset.info()) + aprint(output_dataset.info()) # Check dataset integrity: - if check: - dest_dataset.check_integrity() + output_dataset.check_integrity() # close destination dataset: - dest_dataset.close() + output_dataset.close() diff --git a/dexp/datasets/operations/deconv.py b/dexp/datasets/operations/deconv.py index 0eacc536..1796a4e4 100644 --- a/dexp/datasets/operations/deconv.py +++ b/dexp/datasets/operations/deconv.py @@ -1,15 +1,18 @@ import functools from pathlib import Path -from typing import List, Optional, Sequence, Tuple +from typing import Callable, List, Optional, Sequence, 
Tuple import dask -import numpy +import numpy as np import scipy from arbol.arbol import aprint, asection from dask.distributed import Client from dask_cuda import LocalCUDACluster +from toolz import curry from dexp.datasets import BaseDataset +from dexp.datasets.stack_iterator import StackIterator +from dexp.datasets.zarr_dataset import ZDataset from dexp.optics.psf.standard_psfs import nikon16x08na, olympus20x10na from dexp.processing.deconvolution import ( admm_deconvolution, @@ -17,79 +20,29 @@ ) from dexp.processing.filters.fft_convolve import fft_convolve from dexp.processing.utils.scatter_gather_i2i import scatter_gather_i2i -from dexp.utils.backends import Backend, BestBackend -from dexp.utils.slicing import slice_from_shape +from dexp.utils.backends import BestBackend -def dataset_deconv( - dataset: BaseDataset, - dest_path: str, - channels: Sequence[str], - slicing, - store: str = "dir", - compression: str = "zstd", - compression_level: int = 3, - overwrite: bool = False, - tilesize: Optional[Tuple[int]] = None, - method: str = "lr", - num_iterations: int = 16, - max_correction: int = 16, - power: float = 1, - blind_spot: int = 0, - back_projection: Optional[str] = None, - wb_order: int = 5, - psf_objective: str = "nikon16x08na", - psf_na: float = 0.8, - psf_dxy: float = 0.485, - psf_dz: float = 2, - psf_xy_size: int = 17, - psf_z_size: int = 17, - psf_show: bool = False, - scaling: Optional[Tuple[float]] = None, - workers: int = 1, - workersbackend: str = "", - devices: Optional[List[int]] = None, - check: bool = True, - stop_at_exception: bool = True, -): - - from dexp.datasets import ZDataset - - mode = "w" + ("" if overwrite else "-") - dest_dataset = ZDataset(dest_path, mode, store, parent=dataset) - - # Default tile size: - if tilesize is None: - tilesize = 320 # very conservative +def get_psf( + psf_objective: str, + psf_na: float, + psf_dxy: float, + psf_dz: float, + psf_z_size: int, + psf_xy_size: int, + scaling: Tuple[float], + psf_show: bool, +) -> np.ndarray: - # Scaling default value: - if scaling is None: - scaling = (1, 1, 1) + # This is not ideal but difficult to avoid right now: sz, sy, sx = scaling - aprint(f"Input images will be scaled by: (sz,sy,sx)={scaling}") - - # CUDA DASK cluster - cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) - client = Client(cluster) - aprint("Dask Client", client) - lazy_computation = [] - - for channel in dataset._selected_channels(channels): - array = dataset.get_array(channel) - - aprint(f"Slicing with: {slicing}") - out_shape, volume_slicing, time_points = slice_from_shape(array.shape, slicing) - - out_shape = tuple(int(round(u * v)) for u, v in zip(out_shape, (1,) + scaling)) - dtype = numpy.float16 if method == "admm" else array.dtype - - # Adds destination array channel to dataset - dest_array = dest_dataset.add_channel( - name=channel, shape=out_shape, dtype=dtype, codec=compression, clevel=compression_level - ) - - # This is not ideal but difficult to avoid right now: + if Path(psf_objective).exists(): + # loading pre-computed PSF + psf_kernel = np.load(psf_objective) + if sz != 1.0 or sy != 1.0 or sx != 1.0: + psf_kernel = scipy.ndimage.zoom(psf_kernel, zoom=(sz, sy, sx), order=1) + else: sxy = (sx + sy) / 2 # PSF paraneters: @@ -112,126 +65,183 @@ def dataset_deconv( psf_kernel = nikon16x08na(**psf_kwargs) elif psf_objective == "olympus20x10na": psf_kernel = olympus20x10na(**psf_kwargs) - elif Path(psf_objective).exists(): - psf_kernel = numpy.load(psf_objective) - if sz != 1.0 or sy != 1.0 or sx != 1.0: - 
psf_kernel = scipy.ndimage.zoom(psf_kernel, zoom=(sz, sy, sx), order=1) - psf_z_size = psf_kernel.shape[0] + 10 - psf_xy_size = max(psf_kernel.shape[1:]) + 10 else: raise RuntimeError(f"Object/path {psf_objective} not found.") - # usefull for debugging: - if psf_show: - import napari - - viewer = napari.Viewer(title="DEXP | viewing PSF with napari", ndisplay=3) - viewer.add_image(psf_kernel) - napari.run() - - margins = max(psf_xy_size, psf_z_size) - - if method == "lr": - normalize = False - convolve = functools.partial(fft_convolve, in_place=False, mode="reflect", internal_dtype=numpy.float32) - - def deconv(image): - min_value = image.min() - max_value = image.max() - - return lucy_richardson_deconvolution( - image=image, - psf=psf_kernel, - num_iterations=num_iterations, - max_correction=max_correction, - normalise_minmax=(min_value, max_value), - power=power, - blind_spot=blind_spot, - blind_spot_mode="median+uniform", - blind_spot_axis_exclusion=(0,), - wb_order=wb_order, - back_projection=back_projection, - convolve_method=convolve, - ) - - elif method == "admm": - normalize = True - - def deconv(image): - out = admm_deconvolution( - image, - psf=psf_kernel, - iterations=num_iterations, - derivative=2, - ) - return out + # usefull for debugging: + if psf_show: + import napari + + viewer = napari.Viewer(title="DEXP | viewing PSF with napari", ndisplay=3) + viewer.add_image(psf_kernel) + napari.run() + + return psf_kernel + + +def get_deconv_func( + method: str, + psf_kernel: np.ndarray, + num_iterations: int, + max_correction: int, + blind_spot: int, + power: float, + wb_order: int, + back_projection: str, +) -> Callable: + + if method == "lr" or method == "wb": + convolve = functools.partial(fft_convolve, in_place=False, mode="reflect", internal_dtype=np.float32) + + def deconv(image): + min_value = image.min() + max_value = image.max() + + return lucy_richardson_deconvolution( + image=image, + psf=psf_kernel, + num_iterations=num_iterations, + max_correction=max_correction, + normalise_minmax=(min_value, max_value), + power=power, + blind_spot=blind_spot, + blind_spot_mode="median+uniform", + blind_spot_axis_exclusion=(0,), + wb_order=wb_order, + back_projection=back_projection, + convolve_method=convolve, + ) + + elif method == "admm": + + def deconv(image): + out = admm_deconvolution( + image, + psf=psf_kernel, + iterations=num_iterations, + derivative=2, + ) + return out + + else: + raise ValueError(f"Unknown deconvolution mode: {method}") + + return deconv + + +@dask.delayed +@curry +def _process( + time_point: int, + stacks: StackIterator, + out_dataset: ZDataset, + channel: str, + deconv_func: Callable, + scaling: Tuple[int], +): + + with asection(f"Deconvolving time point for time point {time_point}/{len(stacks)}"): + with asection(f"Loading channel: {channel}"): + stack = np.asarray(stacks[time_point]) - else: - raise ValueError(f"Unknown deconvolution mode: {method}") - - @dask.delayed - def process(i): - tp = time_points[i] - try: - with asection(f"Deconvolving time point for time point {i}/{len(time_points)}"): - with asection(f"Loading channel: {channel}"): - tp_array = numpy.asarray(array[tp][volume_slicing]) - - with BestBackend(exclusive=True, enable_unified_memory=True): - - if sz != 1.0 or sy != 1.0 or sx != 1.0: - with asection(f"Applying scaling {(sz, sy, sx)} to image."): - sp = Backend.get_sp_module() - tp_array = Backend.to_backend(tp_array) - tp_array = sp.ndimage.zoom(tp_array, zoom=(sz, sy, sx), order=1) - tp_array = Backend.to_numpy(tp_array) - - with 
asection( - f"Deconvolving image of shape: {tp_array.shape}, with tile size: {tilesize}, " - + "margins: {margins} " - ): - aprint(f"Number of iterations: {num_iterations}, back_projection:{back_projection}, ") - tp_array = scatter_gather_i2i( - deconv, - tp_array, - tiles=tilesize, - margins=margins, - normalise=normalize, - internal_dtype=dtype, - ) - - with asection("Moving array from backend to numpy."): - tp_array = Backend.to_numpy(tp_array, dtype=dest_array.dtype, force_copy=False) - - with asection( - f"Saving deconvolved stack for time point {i}, shape:{tp_array.shape}, dtype:{array.dtype}" - ): - dest_dataset.write_stack(channel=channel, time_point=i, stack_array=tp_array) - - aprint(f"Done processing time point: {i}/{len(time_points)} .") - - except Exception as error: - aprint(error) - aprint(f"Error occurred while processing time point {i} !") - import traceback - - traceback.print_exc() - - if stop_at_exception: - raise error - - for i in range(len(time_points)): - lazy_computation.append(process(i)) + with BestBackend() as bkd: + if any(s != 1.0 for s in scaling): + with asection(f"Applying scaling {scaling} to image."): + sp = bkd.get_sp_module() + stack = bkd.to_backend(stack) + stack = sp.ndimage.zoom(stack, zoom=scaling, order=1) + stack = bkd.to_numpy(stack) + + with asection("Deconvolving ..."): + stack = deconv_func(stack) + + with asection("Moving array from backend to numpy."): + stack = bkd.to_numpy(stack, dtype=out_dataset.dtype(channel), force_copy=False) + + with asection( + f"Saving deconvolved stack for time point {time_point}, shape:{stack.shape}, dtype:{stack.dtype}" + ): + out_dataset.write_stack(channel=channel, time_point=time_point, stack_array=stack) + + aprint(f"Done processing time point: {time_point}/{len(stacks)} .") + + +def dataset_deconv( + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + tilesize: int, + method: str, + num_iterations: Optional[int], + max_correction: int, + power: float, + blind_spot: int, + back_projection: Optional[str], + wb_order: int, + psf_objective: str, + psf_na: Optional[float], + psf_dxy: float, + psf_dz: float, + psf_xy_size: int, + psf_z_size: int, + psf_show: bool, + scaling: Tuple[float], + devices: List[int], +): + aprint(f"Input images will be scaled by: (sz,sy,sx)={scaling}") + + psf_kernel = get_psf(psf_objective, psf_na, psf_dxy, psf_dz, psf_z_size, psf_xy_size, scaling, psf_show) + margins = max(psf_kernel.shape[1], psf_kernel.shape[0]) + + deconv_func = get_deconv_func( + method, + psf_kernel, + num_iterations, + max_correction, + blind_spot, + power, + wb_order, + back_projection, + ) + deconv_func = curry( + scatter_gather_i2i, + function=deconv_func, + tiles=tilesize, + margins=margins, + normalise=method == "admm", + ) + + lazy_computation = [] + + for channel in channels: + stacks = input_dataset[channel] + + out_shape = tuple(int(round(u * v)) for u, v in zip(stacks.shape, (1,) + scaling)) + dtype = np.float16 if method == "admm" else stacks.dtype + + # Adds destination array channel to dataset + output_dataset.add_channel(name=channel, shape=out_shape, dtype=dtype) + + process = _process( + stacks=stacks, + out_dataset=output_dataset, + channel=channel, + scaling=scaling, + deconv_func=deconv_func(internal_dtype=dtype), + ) + + for t in range(len(stacks)): + lazy_computation.append(process(time_point=t)) dask.compute(*lazy_computation) + # CUDA DASK cluster + cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) + client = Client(cluster) + aprint("Dask Client", 
client) + # Dataset info: - aprint(dest_dataset.info()) + aprint(output_dataset.info()) # Check dataset integrity: - if check: - dest_dataset.check_integrity() - - # close destination dataset: - dest_dataset.close() - client.close() + output_dataset.check_integrity() diff --git a/dexp/datasets/operations/demo/demo_copy.py b/dexp/datasets/operations/demo/demo_copy.py index 7fb7549a..91fc3732 100644 --- a/dexp/datasets/operations/demo/demo_copy.py +++ b/dexp/datasets/operations/demo/demo_copy.py @@ -75,18 +75,18 @@ def _demo_copy(length_xy=96, zoom=1, n=16, display=True): with asection("Copy..."): # output_folder: output_path = join(tmpdir, "copy.zarr") + dataset.set_slicing((slice(0, 4),)) + copied_dataset = ZDataset(path=output_path, mode="w") # Do actual copy: dataset_copy( - dataset=dataset, - dest_path=output_path, + input_dataset=dataset, + output_dataset=copied_dataset, channels=("channel",), - slicing=(slice(0, 4), ...), workers=4, workersbackend="threading", ) - copied_dataset = ZDataset(path=output_path, mode="a") copied_array = copied_dataset.get_array("channel") assert copied_array.shape[0] == 4 diff --git a/dexp/datasets/operations/demo/demo_deconv.py b/dexp/datasets/operations/demo/demo_deconv.py index d6d211d4..e3341203 100644 --- a/dexp/datasets/operations/demo/demo_deconv.py +++ b/dexp/datasets/operations/demo/demo_deconv.py @@ -1,6 +1,6 @@ import random import tempfile -from os.path import join +from pathlib import Path from arbol import aprint, asection @@ -59,26 +59,52 @@ def _demo_deconv(length_xy=96, zoom=1, n=8, display=True): with tempfile.TemporaryDirectory() as tmpdir: aprint("created temporary directory", tmpdir) + tmpdir = Path(tmpdir) with asection("Prepare dataset..."): - input_path = join(tmpdir, "dataset.zarr") + input_path = tmpdir / "dataset.zarr" dataset = ZDataset(path=input_path, mode="w", store="dir") dataset.add_channel(name="channel", shape=images.shape, chunks=(1, 64, 64, 64), dtype=images.dtype) dataset.write_array(channel="channel", array=Backend.to_numpy(images)) + dataset.set_slicing((slice(2, 3), ...)) + source_array = dataset.get_array("channel") with asection("Deconvolve..."): # output_folder: - output_path = join(tmpdir, "deconv.zarr") + output_path = tmpdir / "deconv.zarr" + output_dataset = ZDataset(path=output_path, mode="w", parent=dataset) # Do deconvolution: - dataset_deconv(dataset=dataset, dest_path=output_path, channels=("channel",), slicing=(slice(2, 3), ...)) - - deconv_dataset = ZDataset(path=output_path, mode="a") - deconv_array = deconv_dataset.get_array("channel") + dataset_deconv( + input_dataset=dataset, + output_dataset=output_dataset, + channels=("channel",), + tilesize=320, + method="lr", + num_iterations=None, + max_correction=None, + power=1.0, + blind_spot=0, + back_projection="tpsf", + wb_order=5, + psf_objective="nikon16x08na", + psf_na=None, + psf_dxy=0.485, + psf_dz=4 * 0.485, + psf_xy_size=31, + psf_z_size=31, + scaling=(1, 1, 1), + devices=[0], + psf_show=False, + ) + + output_dataset.close() + + deconv_array = output_dataset.get_array("channel") assert deconv_array.shape[0] == 1 assert deconv_array.shape[1:] == source_array.shape[1:] @@ -94,6 +120,7 @@ def _c(array): viewer.add_image(_c(source_array), name="source_array") viewer.add_image(_c(deconv_array), name="deconv_array") viewer.grid.enabled = True + viewer.dims.set_point(0, 0) napari.run() diff --git a/dexp/datasets/operations/demo/demo_deskew.py b/dexp/datasets/operations/demo/demo_deskew.py index 2c2e68da..cb4c47d3 100644 --- 
a/dexp/datasets/operations/demo/demo_deskew.py +++ b/dexp/datasets/operations/demo/demo_deskew.py @@ -1,6 +1,5 @@ -import math import tempfile -from os.path import join +from pathlib import Path from arbol import aprint, asection @@ -8,7 +7,7 @@ from dexp.datasets.operations.deskew import dataset_deskew from dexp.datasets.synthetic_datasets import generate_nuclei_background_data from dexp.optics.psf.standard_psfs import nikon16x08na -from dexp.processing.filters.fft_convolve import fft_convolve +from dexp.processing.deskew import skew from dexp.utils.backends import Backend, CupyBackend, NumpyBackend @@ -27,9 +26,8 @@ def demo_deskew_cupy(): return False -def _demo_deskew(length=96, zoom=1, shift=1, display=True): +def _demo_deskew(length=96, zoom=1, shift=1, angle=45, display=True): xp = Backend.get_xp_module() - sp = Backend.get_sp_module() with asection("prepare simulated timelapse:"): # generate nuclei image: @@ -45,68 +43,51 @@ def _demo_deskew(length=96, zoom=1, shift=1, display=True): dtype=xp.float32, ) - # Pad: - pad_width = ( - (int(shift * zoom * length // 2), int(shift * zoom * length // 2)), - (0, 0), - (0, 0), - ) - image = xp.pad(image, pad_width=pad_width) - - # Add blur: - psf = nikon16x08na() - psf = psf.astype(dtype=image.dtype, copy=False) - image = fft_convolve(image, psf) - - # apply skew: - angle = 45 - matrix = xp.asarray( - [[math.cos(angle * math.pi / 180), math.sin(angle * math.pi / 180), 0], [0, 1, 0], [0, 0, 1]] - ) - offset = 0 * xp.asarray([image.shape[0] // 2, 0, 0]) - # matrix = xp.linalg.inv(matrix) - skewed = sp.ndimage.affine_transform(image, matrix, offset=offset) - - # Add noise and clip - # skewed += xp.random.uniform(-1, 1) - skewed = xp.clip(skewed, a_min=0, a_max=None) - - # cast to uint16: - skewed = skewed.astype(dtype=xp.uint16) - + skewed = skew(image, nikon16x08na(), shift=shift, angle=angle, zoom=zoom, axis=0) # Single timepoint: skewed = skewed[xp.newaxis, ...] 
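Note for readers of this patch: these demos also rely on the new slicing flow — instead of threading a slicing argument through every operation, the slice is set once on the input ZDataset and the per-channel stack iterators honour it. A hedged sketch of that usage, with an illustrative path and channel name:

from dexp.datasets import ZDataset

dataset = ZDataset(path="dataset.zarr", mode="r")   # illustrative path
dataset.set_slicing((slice(0, 1), ...))             # keep only the first time point
stacks = dataset["channel"]                         # stack iterator honouring the slicing
print(len(stacks), stacks.shape, stacks.dtype)
dataset.close()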
with tempfile.TemporaryDirectory() as tmpdir: aprint("created temporary directory", tmpdir) + tmpdir = Path(tmpdir) with asection("Prepare dataset..."): - input_path = join(tmpdir, "dataset.zarr") - dataset = ZDataset(path=input_path, mode="w", store="dir") + input_path = tmpdir / "dataset.zarr" + input_dataset = ZDataset(path=input_path, mode="w", store="dir") - dataset.add_channel(name="channel", shape=skewed.shape, chunks=(1, 64, 64, 64), dtype=skewed.dtype) + input_dataset.add_channel(name="channel", shape=skewed.shape, chunks=(1, 64, 64, 64), dtype=skewed.dtype) - dataset.write_array(channel="channel", array=Backend.to_numpy(skewed)) + input_dataset.write_array(channel="channel", array=Backend.to_numpy(skewed)) - source_array = dataset.get_array("channel") + source_array = input_dataset.get_array("channel") with asection("Deskew..."): # output_folder: - output_path = join(tmpdir, "deskewed.zarr") + output_path = tmpdir / "deskewed.zarr" + output_dataset = ZDataset(path=output_path, mode="w") # deskew: dataset_deskew( - dataset=dataset, - dest_path=output_path, + input_dataset=input_dataset, + output_dataset=output_dataset, channels=("channel",), - slicing=(slice(0, 1),), + # slicing=(slice(0, 1),), + mode="yang", dx=1, dz=1, angle=angle, + flips=(False,), + camera_orientation=0, + depth_axis=0, + lateral_axis=1, + padding=False, + devices=[ + 0, + ], ) - deskew_dataset = ZDataset(path=output_path, mode="a") - deskew_array = deskew_dataset.get_array("channel") + output_dataset = ZDataset(path=output_path, mode="r") + deskew_array = output_dataset.get_array("channel") assert deskew_array.shape[0] == 1 diff --git a/dexp/datasets/operations/demo/demo_deskew_real.py b/dexp/datasets/operations/demo/demo_deskew_real.py index 7ac71fb4..03875898 100644 --- a/dexp/datasets/operations/demo/demo_deskew_real.py +++ b/dexp/datasets/operations/demo/demo_deskew_real.py @@ -1,7 +1,8 @@ import tempfile -from os.path import join +from pathlib import Path from arbol import aprint, asection +from toolz import curry from dexp.datasets import ZDataset from dexp.datasets.operations.deskew import dataset_deskew @@ -20,37 +21,54 @@ def demo_deskew_numpy(): def demo_deskew_cupy(): try: with CupyBackend(): - _demo_deskew(length=128, zoom=4) + _demo_deskew() return True except ModuleNotFoundError: aprint("Cupy module not found! 
demo ignored") return False -def _demo_deskew(length=96, zoom=1, shift=1, display=True): +def _demo_deskew(display=True): xp = Backend.get_xp_module() with asection("Deskew..."): with tempfile.TemporaryDirectory() as tmpdir: aprint("created temporary directory", tmpdir) + tmpdir = Path(tmpdir) # Open input dataset: dataset = ZDataset(path=input_path, mode="r") + dataset.set_slicing((slice(0, 1),)) # Input array: input_array = dataset.get_array("v0c0") + angle = 45 + + deskew_func = curry( + dataset_deskew, + input_dataset=dataset, + channels=("v0c0",), + dx=1, + dz=1, + angle=angle, + flips=(True,), + camera_orientation=0, + depth_axis=0, + lateral_axis=1, + padding=True, + devices=[ + 0, + ], + ) with asection("Deskew yang..."): # output path: - output_path = join(tmpdir, "deskewed_yang.zarr") + output_path = tmpdir / "deskewed_yang.zarr" + output_dataset = ZDataset(path=input_path, mode="w", parent=dataset) # deskew: - dataset_deskew( - dataset=dataset, - dest_path=output_path, - channels=("v0c0",), - flips=(True,), - slicing=(slice(0, 1),), + deskew_func( + output_dataset=output_dataset, mode="yang", ) @@ -62,17 +80,13 @@ def _demo_deskew(length=96, zoom=1, shift=1, display=True): with asection("Deskew classic..."): # output path: - output_path = join(tmpdir, "deskewed_classic.zarr") + output_path = tmpdir / "deskewed_classic.zarr" + output_dataset = ZDataset(path=input_path, mode="w", parent=dataset) # deskew: - dataset_deskew( - dataset=dataset, - dest_path=output_path, - channels=("v0c0",), - flips=(True,), - slicing=(slice(0, 1),), + deskew_func( + output_dataset=output_dataset, mode="classic", - padding=True, ) # read result diff --git a/dexp/datasets/operations/demo/demo_projrender.py b/dexp/datasets/operations/demo/demo_projrender.py index 44e4159a..23fbe9ef 100644 --- a/dexp/datasets/operations/demo/demo_projrender.py +++ b/dexp/datasets/operations/demo/demo_projrender.py @@ -1,7 +1,6 @@ -import os import random import tempfile -from os.path import join +from pathlib import Path from arbol import aprint, asection from dask.array.image import imread @@ -61,9 +60,10 @@ def _demo_projrender(length_xy=96, zoom=1, n=64, display=True): with tempfile.TemporaryDirectory() as tmpdir: aprint("created temporary directory", tmpdir) + tmpdir = Path(tmpdir) with asection("Prepare dataset..."): - input_path = join(tmpdir, "dataset.zarr") + input_path = tmpdir / "dataset.zarr" dataset = ZDataset(path=input_path, mode="w", store="dir") dataset.add_channel(name="channel", shape=images.shape, chunks=(1, 64, 64, 64), dtype=images.dtype) @@ -72,16 +72,21 @@ def _demo_projrender(length_xy=96, zoom=1, n=64, display=True): with asection("Project..."): # output_folder: - output_path = join(tmpdir, "projections") + output_path = tmpdir / "projections" # Do actual projection: dataset_projection_rendering( - input_dataset=dataset, channels=("channel",), output_path=output_path, legendsize=0.05 + input_dataset=dataset, + channels=("channel",), + output_path=output_path, + legend_size=0.05, + devices=[0], + overwrite=True, ) # load images into a dask array: - filename_pattern = os.path.join(output_path, "frame_*.png") - images = imread(filename_pattern) + filename_pattern = output_path / "frame_*.png" + images = imread(str(filename_pattern)) if display: from napari import Viewer, gui_qt diff --git a/dexp/datasets/operations/demo/demo_stabilize.py b/dexp/datasets/operations/demo/demo_stabilize.py index 4463314a..04835480 100644 --- a/dexp/datasets/operations/demo/demo_stabilize.py +++ 
b/dexp/datasets/operations/demo/demo_stabilize.py @@ -2,9 +2,8 @@ import tempfile from os.path import join -import numpy +import numpy as np from arbol import aprint, asection -from pytest import approx from dexp.datasets import ZDataset from dexp.datasets.operations.stabilize import dataset_stabilize @@ -120,6 +119,7 @@ def _demo_stabilize( input_path = join(tmpdir, "input.zarr") output_path = join(tmpdir, "output.zarr") input_dataset = ZDataset(path=input_path, mode="w", store="dir") + output_dataset = ZDataset(path=output_path, mode="w") zarr_input_array = input_dataset.add_channel( name="channel", shape=shifted.shape, chunks=(1, 50, 50, 50), dtype=shifted.dtype, codec="zstd", clevel=3 @@ -127,15 +127,13 @@ def _demo_stabilize( input_dataset.write_array(channel="channel", array=Backend.to_numpy(shifted)) dataset_stabilize( - dataset=input_dataset, - output_path=output_path, + input_dataset=input_dataset, + output_dataset=output_dataset, model_output_path=join(tmpdir, "model.json"), channels=["channel"], ) - output_dataset = ZDataset(path=output_path, mode="r", store="dir") - - zarr_output_array = numpy.asarray(output_dataset.get_array("channel")) + zarr_output_array = np.asarray(output_dataset.get_array("channel")) # stabilise with lower level API: model = image_stabilisation_proj(shifted, axis=0, projection_type="max") @@ -146,50 +144,49 @@ def _demo_stabilize( aprint(f"stabilised_seq_ll.shape={stabilised_seq_ll.shape}") aprint(f"zarr_output_array.shape={zarr_output_array.shape}") - assert approx(stabilised_seq_ll.shape[1], zarr_output_array.shape[1], abs=10) - assert approx(stabilised_seq_ll.shape[2], zarr_output_array.shape[2], abs=10) - assert approx(stabilised_seq_ll.shape[3], zarr_output_array.shape[3], abs=10) - error = numpy.mean(numpy.abs(stabilised_seq_ll[:, 0:64, 0:64, 0:64] - zarr_output_array[:, 0:64, 0:64, 0:64])) - error_null = numpy.mean( - numpy.abs(stabilised_seq_ll[:, 0:64, 0:64, 0:64] - Backend.to_numpy(shifted)[:, 0:64, 0:64, 0:64]) + np.testing.assert_allclose(stabilised_seq_ll.shape[1], zarr_output_array.shape[1], atol=10) + np.testing.assert_allclose(stabilised_seq_ll.shape[2], zarr_output_array.shape[2], atol=10) + np.testing.assert_allclose(stabilised_seq_ll.shape[3], zarr_output_array.shape[3], atol=10) + error = np.mean(np.abs(stabilised_seq_ll[:, 0:64, 0:64, 0:64] - zarr_output_array[:, 0:64, 0:64, 0:64])) + error_null = np.mean( + np.abs(stabilised_seq_ll[:, 0:64, 0:64, 0:64] - Backend.to_numpy(shifted)[:, 0:64, 0:64, 0:64]) ) aprint(f"error={error} versus error_null={error_null}") - assert error < 20 - assert error < error_null + assert error < error_null * 0.75 if display: - from napari import Viewer, gui_qt - - with gui_qt(): - - def _c(array): - return Backend.to_numpy(array) - - viewer = Viewer(ndisplay=2) - viewer.add_image(_c(image), name="image", colormap="bop orange", blending="additive", visible=True) - viewer.add_image(_c(shifted), name="shifted", colormap="bop purple", blending="additive", visible=True) - viewer.add_image( - _c(zarr_input_array), - name="zarr_input_array", - colormap="bop purple", - blending="additive", - visible=True, - ) - viewer.add_image( - _c(zarr_output_array), - name="zarr_output_array", - colormap="bop blue", - blending="additive", - visible=True, - ) - viewer.add_image( - _c(stabilised_seq_ll), - name="stabilised_seq_ll", - colormap="bop blue", - blending="additive", - visible=True, - ) - viewer.grid.enabled = True + import napari + + def _c(array): + return Backend.to_numpy(array) + + viewer = napari.Viewer(ndisplay=2) 
+ viewer.add_image(_c(image), name="image", colormap="bop orange", blending="additive", visible=True) + viewer.add_image(_c(shifted), name="shifted", colormap="bop purple", blending="additive", visible=True) + viewer.add_image( + _c(zarr_input_array), + name="zarr_input_array", + colormap="bop purple", + blending="additive", + visible=True, + ) + viewer.add_image( + _c(zarr_output_array), + name="zarr_output_array", + colormap="bop blue", + blending="additive", + visible=True, + ) + viewer.add_image( + _c(stabilised_seq_ll), + name="stabilised_seq_ll", + colormap="bop blue", + blending="additive", + visible=True, + ) + viewer.grid.enabled = True + + napari.run() if __name__ == "__main__": diff --git a/dexp/datasets/operations/demo/demo_tiff.py b/dexp/datasets/operations/demo/demo_tiff.py index 91acb648..f5cdc630 100644 --- a/dexp/datasets/operations/demo/demo_tiff.py +++ b/dexp/datasets/operations/demo/demo_tiff.py @@ -1,6 +1,6 @@ import random import tempfile -from os.path import exists, join +from pathlib import Path from arbol import aprint, asection from tifffile import imread @@ -62,27 +62,29 @@ def _demo_tiff(length_xy=96, zoom=1, n=16, display=True): with tempfile.TemporaryDirectory() as tmpdir: aprint("created temporary directory", tmpdir) + tmpdir = Path(tmpdir) with asection("Prepare dataset..."): - input_path = join(tmpdir, "dataset.zarr") + input_path = tmpdir / "dataset.zarr" dataset = ZDataset(path=input_path, mode="w", store="dir") dataset.add_channel(name="channel", shape=images.shape, chunks=(1, 64, 64, 64), dtype=images.dtype) dataset.write_array(channel="channel", array=Backend.to_numpy(images)) + dataset.set_slicing((slice(0, 4), ...)) source_array = dataset.get_array("channel") with asection("Export to tiff..."): # output_folder: - output_path = join(tmpdir, "file") + output_path = tmpdir / "file" # Do actual tiff export: - dataset_tiff(dataset=dataset, dest_path=output_path, channels=("channel",), slicing=(slice(0, 4), ...)) + dataset_tiff(dataset=dataset, dest_path=output_path, channels=("channel",)) - output_file = output_path + ".tiff" + output_file = Path(str(output_path) + ".tiff") - assert exists(output_file) + assert output_file.exists() tiff_image = imread(output_file) @@ -112,16 +114,14 @@ def _c(array): with asection("Export to tiff, one file per timepoint..."): # output_folder: - output_path = join(tmpdir, "projection") + output_path = tmpdir / "projection" # Do actual tiff export: - dataset_tiff( - dataset=dataset, dest_path=output_path, channels=("channel",), slicing=(slice(0, 4), ...), project=0 - ) + dataset_tiff(dataset=dataset, dest_path=output_path, channels=("channel",), project=0) - output_file = output_path + ".tiff" + output_file = Path(str(output_path) + ".tiff") - assert exists(output_file) + assert output_file.exists() tiff_image = imread(output_file) diff --git a/dexp/datasets/operations/denoise.py b/dexp/datasets/operations/denoise.py index c9ef6999..e23e8c43 100644 --- a/dexp/datasets/operations/denoise.py +++ b/dexp/datasets/operations/denoise.py @@ -11,6 +11,7 @@ from dexp.processing.denoising import calibrate_denoise_butterworth, denoise_butterworth from dexp.processing.utils.scatter_gather_i2i import scatter_gather_i2i from dexp.utils.backends import CupyBackend +from dexp.utils.fft import clear_fft_plan_cache @curry @@ -32,6 +33,7 @@ def _process( denoised = scatter_gather(function=denoise_fun, image=stack) out_dataset.write_stack(channel, time_point, bkd.to_numpy(denoised)) + clear_fft_plan_cache() def dataset_denoise( diff --git 
a/dexp/datasets/operations/deskew.py
+++ b/dexp/datasets/operations/deskew.py
@@ -1,171 +1,118 @@
-from typing import List, Optional, Sequence
+from typing import Callable, List, Optional, Sequence
 
+import dask
+import fasteners
 from arbol.arbol import aprint, asection
-from joblib import Parallel, delayed
-from zarr.errors import ContainsArrayError, ContainsGroupError
+from dask.distributed import Client
+from dask_cuda import LocalCUDACluster
+from toolz import curry
 
 from dexp.datasets import BaseDataset
-from dexp.processing.deskew.classic_deskew import classic_deskew
-from dexp.processing.deskew.yang_deskew import yang_deskew
-from dexp.utils.backends import Backend, BestBackend
+from dexp.datasets.stack_iterator import StackIterator
+from dexp.datasets.zarr_dataset import ZDataset
+from dexp.processing.deskew import deskew_functions
+from dexp.utils.backends import BestBackend
+from dexp.utils.lock import create_lock
+
+
+@dask.delayed
+@curry
+def _process(
+    time_point: int,
+    stacks: StackIterator,
+    channel: str,
+    output_dataset: ZDataset,
+    lock: fasteners.InterProcessLock,
+    deskew_func: Callable,
+) -> None:
+    with asection(f"Deskewing channel {channel} at time point {time_point}."):
+        with BestBackend() as bkd:
+            with asection("Loading data"):
+                stack = bkd.to_backend(stacks[time_point])
+            with asection("Processing"):
+                stack = deskew_func(stack)
+                stack = bkd.to_numpy(stack)
+
+        with lock:
+            if channel not in output_dataset:
+                output_dataset.add_channel(
+                    channel,
+                    (len(stacks),) + stack.shape,
+                    dtype=stack.dtype,
+                )
+
+        with asection("Saving deskewed array"):
+            output_dataset.write_stack(channel, time_point, stack)
 
 
 def dataset_deskew(
-    dataset: BaseDataset,
-    dest_path: str,
+    input_dataset: BaseDataset,
+    output_dataset: ZDataset,
     channels: Sequence[str],
-    slicing,
-    dx: Optional[float] = None,
-    dz: Optional[float] = None,
-    angle: Optional[float] = None,
-    flips: Optional[Sequence[bool]] = None,
-    camera_orientation: int = 0,
-    depth_axis: int = 0,
-    lateral_axis: int = 1,
-    mode: str = "classic",
-    padding: bool = True,
-    store: str = "dir",
-    compression: str = "zstd",
-    compression_level: int = 3,
-    overwrite: bool = False,
-    workers: int = 1,
-    workersbackend: str = "threading",
-    devices: Optional[List[int]] = None,
-    check: bool = True,
-    stop_at_exception=True,
+    dx: Optional[float],
+    dz: Optional[float],
+    angle: Optional[float],
+    flips: Sequence[bool],
+    camera_orientation: int,
+    depth_axis: int,
+    lateral_axis: int,
+    mode: str,
+    padding: bool,
+    devices: List[int],
 ):
-
-    # Collect arrays for selected channels:
-    arrays = tuple(dataset.get_array(channel, per_z_slice=False, wrap_with_dask=True) for channel in channels)
-
-    # Notify of selected channels and corresponding properties:
-    with asection("Channels:"):
-        for view, channel in zip(arrays, channels):
-            aprint(f"Channel: {channel} of shape: {view.shape} and dtype: {view.dtype}")
-
-    # Slicing:
-    if slicing is not None:
-        aprint(f"Slicing with: {slicing}")
-        arrays = tuple(view[slicing] for view in arrays)
-
     # Default flipping:
     if flips is None:
-        flips = (False,) * len(arrays)
-
-    # Default devices:
-    if devices is None:
-        devices = [0]
-
-    # We allocate last minute once we know the shape...
- from dexp.datasets import ZDataset - - zarr_mode = "w" + ("" if overwrite else "-") - dest_dataset = ZDataset(dest_path, zarr_mode, store, parent=dataset) + flips = (False,) * len(channels) # Metadata for deskewing: - metadata = dataset.get_metadata() + metadata = input_dataset.get_metadata() aprint(f"Dataset metadata: {metadata}") if dx is None and "res" in metadata: dx = float(metadata["res"]) if dz is None and "dz" in metadata: - dz = float(metadata["dz"]) if dz is None else dz + dz = float(metadata["dz"]) if angle is None and "angle" in metadata: - angle = float(metadata["angle"]) if angle is None else angle + angle = float(metadata["angle"]) + # setting up fixed parameters aprint(f"Deskew parameters: dx={dx}, dz={dz}, angle={angle}") - - # Iterate through channels:: - for array, channel, flip in zip(arrays, channels, flips): - - # shape and dtype of views to deskew: - shape = array[0].shape - nb_timepoints = array.shape[0] - - def process(channel, tp, device): - try: - - with asection(f"Loading channel {channel} for time point {tp}"): - array_tp = array[tp].compute() - - with BestBackend(device, exclusive=True, enable_unified_memory=True): - - if "yang" in mode: - deskewed_view_tp = yang_deskew( - image=array_tp, - depth_axis=depth_axis, - lateral_axis=lateral_axis, - flip_depth_axis=flip, - dx=dx, - dz=dz, - angle=angle, - camera_orientation=camera_orientation, - ) - elif "classic" in mode: - deskewed_view_tp = classic_deskew( - image=array_tp, - depth_axis=depth_axis, - lateral_axis=lateral_axis, - flip_depth_axis=flip, - dx=dx, - dz=dz, - angle=angle, - camera_orientation=camera_orientation, - padding=padding, - ) - else: - raise ValueError(f"Deskew mode: {mode} not supported.") - - with asection("Moving array from backend to numpy."): - deskewed_view_tp = Backend.to_numpy(deskewed_view_tp, dtype=array.dtype, force_copy=False) - - if channel not in dest_dataset.channels(): - try: - dest_dataset.add_channel( - channel, - shape=(array.shape[0],) + deskewed_view_tp.shape, - dtype=deskewed_view_tp.dtype, - codec=compression, - clevel=compression_level, - ) - except (ContainsArrayError, ContainsGroupError): - aprint("Other thread/process created channel before... 
") - - with asection( - f"Saving fused stack for time point {tp}, shape:{deskewed_view_tp.shape}, " - + f"dtype:{deskewed_view_tp.dtype}" - ): - dest_dataset.write_stack(channel=channel, time_point=tp, stack_array=deskewed_view_tp) - - aprint(f"Done processing time point: {tp} .") - - except Exception as error: - aprint(error) - aprint(f"Error occurred while deskewing time point {tp} !") - import traceback - - traceback.print_exc() - - if stop_at_exception: - raise error - - if workers == -1: - workers = len(devices) - aprint(f"Number of workers: {workers}") - - if workers > 1: - Parallel(n_jobs=workers, backend=workersbackend)( - delayed(process)(channel, tp, devices[tp % len(devices)]) for tp in range(0, shape[0]) + deskew_func = curry( + deskew_functions[mode], + depth_axis=depth_axis, + lateral_axis=lateral_axis, + dx=dx, + dz=dz, + angle=angle, + camera_orientation=camera_orientation, + padding=padding, + ) + + lazy_computations = [] + + # Iterate through channels + for i, channel in enumerate(channels): + stacks = input_dataset[channel] + lock = create_lock(channel) + lazy_computations += [ + _process( + time_point=t, + stacks=stacks, + channel=channel, + lock=lock, + output_dataset=output_dataset, + deskew_func=deskew_func(flip_depth_axis=flips[i]), ) - else: - for tp in range(0, nb_timepoints): - process(channel, tp, devices[0]) + for t in range(len(stacks)) + ] + + # setting up dask compute scheduler + cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) + client = Client(cluster) - # Dataset info: - aprint(dest_dataset.info()) + aprint("Dask client", client) + dask.compute(*lazy_computations) - # Check dataset integrity: - if check: - dest_dataset.check_integrity() + # shape and dtype of views to deskew: + output_dataset.check_integrity() - # close destination dataset: - dest_dataset.close() + aprint(output_dataset.info()) diff --git a/dexp/datasets/operations/extract_psf.py b/dexp/datasets/operations/extract_psf.py index deb3fb6e..91e9cd55 100644 --- a/dexp/datasets/operations/extract_psf.py +++ b/dexp/datasets/operations/extract_psf.py @@ -1,67 +1,63 @@ -from typing import Sequence, Union +import warnings +from typing import Callable, Sequence +import dask import numpy as np from arbol import aprint, asection +from dask.distributed import Client +from dask_cuda import LocalCUDACluster +from toolz import curry -from dexp.datasets import BaseDataset +from dexp.datasets.stack_iterator import StackIterator +from dexp.datasets.zarr_dataset import ZDataset from dexp.processing.remove_beads import BeadsRemover -from dexp.utils.backends import BestBackend -from dexp.utils.slicing import slice_from_shape +from dexp.utils.backends.best_backend import BestBackend + + +@curry +def _process( + time_point: int, + stacks: StackIterator, + create_beads_remover: Callable[..., BeadsRemover], +) -> np.ndarray: + + with BestBackend() as bkd: + beads_remover = create_beads_remover() + with asection(f"Removing beads from time point {time_point}."): + return beads_remover.detect_beads(bkd.to_backend(stacks[time_point])) def dataset_extract_psf( - dataset: BaseDataset, + input_dataset: ZDataset, dest_path: str, channels: Sequence[str], - slicing: Union[Sequence[slice], slice], - peak_threshold: int = 500, - similarity_threshold: float = 0.5, - psf_size: int = 35, - device: int = 0, - stop_at_exception: bool = True, + peak_threshold: int, + similarity_threshold: float, + psf_size: int, + devices: Sequence[int], verbose: bool = True, ) -> None: """ Computes PSF from beads. 
Additional information at dexp.processing.remove_beads.beadremover documentation. """ + warnings.warn("This command is subpar, it should be improved.") dest_path = dest_path.split(".")[0] - remove_beads = BeadsRemover( - peak_threshold=peak_threshold, similarity_threshold=similarity_threshold, psf_size=psf_size, verbose=verbose - ) - - for channel in dataset._selected_channels(channels): - array = dataset.get_array(channel) - _, volume_slicing, time_points = slice_from_shape(array.shape, slicing) - - psfs = [] - for i in range(len(time_points)): - tp = time_points[i] - try: - with asection(f"Removing beads of channel: {channel}"): - with asection(f"Loading time point {i}/{len(time_points)}"): - tp_array = np.asarray(array[tp][volume_slicing]) - - with asection("Processing"): - with BestBackend(exclusive=True, enable_unified_memory=True, device_id=device) as backend: - tp_array = backend.to_backend(tp_array) - estimated_psf = remove_beads.detect_beads(tp_array) - - aprint(f"Done extracting PSF from time point: {i}/{len(time_points)} .") - - except Exception as error: - aprint(error) - aprint(f"Error occurred while processing time point {i} !") - import traceback + cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) + client = Client(cluster) + aprint("Dask client", client) - traceback.print_exc() + def create_beads_remover() -> BeadsRemover: + return BeadsRemover( + peak_threshold=peak_threshold, similarity_threshold=similarity_threshold, psf_size=psf_size, verbose=verbose + ) - if stop_at_exception: - raise error + for ch in channels: + stacks = input_dataset[ch] + lazy_computations = [_process(t, stacks, create_beads_remover) for t in range(len(stacks))] - psfs.append(estimated_psf) + psf = np.stack(dask.compute(*lazy_computations)).mean(axis=0) + np.save(dest_path + ch + ".npy", psf) - psfs = np.stack(psfs).mean(axis=0) - print(psfs.shape) - np.save(dest_path + channel + ".npy", psfs) + client.close() diff --git a/dexp/datasets/operations/fromraw.py b/dexp/datasets/operations/fromraw.py index 8fe425a0..59312f43 100644 --- a/dexp/datasets/operations/fromraw.py +++ b/dexp/datasets/operations/fromraw.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np from arbol import aprint, asection from joblib import Parallel, delayed @@ -17,11 +19,15 @@ def _write(in_ds: CCDataset, out_ds: ZDataset, channel: str, indices: np.ndarray def dataset_fromraw( in_dataset: CCDataset, out_dataset: ZDataset, - channel_prefix: str, + channel_prefix: Optional[str], workers: int, ) -> None: - channels = list(filter(lambda x: x.startswith(channel_prefix), in_dataset.channels())) + if channel_prefix is not None: + channels = list(filter(lambda x: x.startswith(channel_prefix), in_dataset.channels())) + else: + channels = in_dataset.channels() + aprint(f"Selected channels {channels}") argmin = channels[0] diff --git a/dexp/datasets/operations/fuse.py b/dexp/datasets/operations/fuse.py index 913ec164..0edf0db7 100644 --- a/dexp/datasets/operations/fuse.py +++ b/dexp/datasets/operations/fuse.py @@ -1,259 +1,295 @@ +from pathlib import Path +from typing import Callable, Dict, List, Optional, Sequence, Tuple + import dask import numpy as np from arbol.arbol import aprint, asection +from dask.distributed import Client +from dask_cuda import LocalCUDACluster from toolz import curry -from dexp.datasets import ZDataset +from dexp.datasets import BaseDataset, ZDataset +from dexp.datasets.stack_iterator import StackIterator from dexp.processing.multiview_lightsheet.fusion.mvsols import msols_fuse_1C2L from 
dexp.processing.multiview_lightsheet.fusion.simview import SimViewFusion from dexp.processing.registration.model.model_io import ( model_list_from_file, model_list_to_file, ) -from dexp.utils.backends import Backend, BestBackend, NumpyBackend -from dexp.utils.slicing import slice_from_shape +from dexp.processing.registration.model.pairwise_registration_model import ( + PairwiseRegistrationModel, +) +from dexp.utils import xpArray +from dexp.utils.backends import Backend, BestBackend + + +def load_registration_models(model_list_filename: Path, n_time_pts: int) -> Sequence[PairwiseRegistrationModel]: + aprint(f"Loading registration shifts from existing file! ({model_list_filename})") + models = model_list_from_file(model_list_filename) + if len(models) == 1: + models = [models[0] for _ in range(n_time_pts)] + elif len(models) != n_time_pts: + raise ValueError( + f"Number of registration models provided ({len(models)})" + f"differs from number of input time points ({n_time_pts})" + ) + return models + + +@curry +def get_fusion_func( + model: Optional[PairwiseRegistrationModel], + microscope: str, + metadata: Dict, + equalisation_ratios: Optional[List], + equalise: bool, + zero_level: float, + clip_too_high: float, + fusion: str, + fusion_bias_strength_i: float, + fusion_bias_strength_d: float, + dehaze_size: int, + dark_denoise_threshold: int, + z_pad_apodise: Tuple[int, int], + loadreg: bool, + warpreg_num_iterations: int, + min_confidence: float, + max_change: float, + registration_edge_filter: bool, + maxproj: bool, + huge_dataset: bool, + pad: bool, + white_top_hat_size: float, + white_top_hat_sampling: int, + remove_beads: bool, +) -> Callable[[Dict], Tuple[xpArray, List, PairwiseRegistrationModel]]: + + if microscope == "simview": + if equalisation_ratios is None: + equalisation_ratios = [None, None, None] + + def fusion_func(views: Dict) -> Tuple: + fuse_obj = SimViewFusion( + registration_model=model, + equalise=equalise, + equalisation_ratios=equalisation_ratios, + zero_level=zero_level, + clip_too_high=clip_too_high, + fusion=fusion, + fusion_bias_exponent=2, + fusion_bias_strength_i=fusion_bias_strength_i, + fusion_bias_strength_d=fusion_bias_strength_d, + dehaze_before_fusion=True, + dehaze_size=dehaze_size, + dehaze_correct_max_level=True, + dark_denoise_threshold=dark_denoise_threshold, + dark_denoise_size=9, + white_top_hat_size=white_top_hat_size, + white_top_hat_sampling=white_top_hat_sampling, + remove_beads=remove_beads, + butterworth_filter_cutoff=1, + flip_camera1=True, + pad=pad, + ) + + stack = fuse_obj(**views) + new_equalisation_ratios = fuse_obj._equalisation_ratios + return stack, new_equalisation_ratios, model + + elif microscope == "mvsols": + angle = metadata["angle"] + dz = metadata["dz"] + res = metadata["res"] + illumination_correction_sigma = metadata["ic_sigma"] if "ic_sigma" in metadata else None + if equalisation_ratios is None: + equalisation_ratios = [None] + + def fusion_func(views: Dict) -> Tuple: + stack, new_model, new_equalisation_ratios = msols_fuse_1C2L( + *list(views.values()), + z_pad=z_pad_apodise[0], + z_apodise=z_pad_apodise[1], + registration_num_iterations=warpreg_num_iterations, + registration_force_model=loadreg, + registration_model=model, + registration_min_confidence=min_confidence, + registration_max_change=max_change, + registration_edge_filter=registration_edge_filter, + equalise=equalise, + equalisation_ratios=equalisation_ratios, + zero_level=zero_level, + clip_too_high=clip_too_high, + fusion=fusion, + fusion_bias_exponent=2, + 
fusion_bias_strength_x=fusion_bias_strength_i, + dehaze_size=dehaze_size, + dark_denoise_threshold=dark_denoise_threshold, + angle=angle, + dx=res, + dz=dz, + illumination_correction_sigma=illumination_correction_sigma, + registration_mode="projection" if maxproj else "full", + huge_dataset_mode=huge_dataset, + ) + return stack, new_equalisation_ratios, new_model + + else: + raise NotImplementedError + + return fusion_func + + +@curry +def _process( + time_point: int, + views: Dict[str, StackIterator], + out_dataset: ZDataset, + fusion_func: Callable, +) -> Tuple[List, PairwiseRegistrationModel]: + + stack = list(views.values())[0] + in_shape = stack.shape + dtype = stack.dtype + + with asection(f"Fusing time point for time point {time_point}/{in_shape[0]}"): + with asection(f"Loading channels {list(views.keys())}"): + views_tp = {k: np.asarray(view[time_point]) for k, view in views.items()} + + with BestBackend(): + stack, new_equalisation_ratios, model = fusion_func(views_tp) + + with asection("Moving array from backend to numpy."): + stack = Backend.to_numpy(stack, dtype=dtype, force_copy=False) + + aprint(f"Last equalisation ratios: {new_equalisation_ratios}") + + # moving to numpy to allow pickling + if model is not None: + model = model.to_numpy() + + new_equalisation_ratios = [None if v is None else Backend.to_numpy(v) for v in new_equalisation_ratios] + + with asection(f"Saving fused stack for time point {time_point}"): + + if time_point == 0: + # We allocate last minute once we know the shape... because we don't always know + # the shape in advance!! + # @jordao NOTE: that could be pre computed + out_dataset.add_channel("fused", shape=(in_shape[0],) + stack.shape, dtype=stack.dtype) + + out_dataset.write_stack(channel="fused", time_point=time_point, stack_array=stack) + + aprint(f"Done processing time point: {time_point} / {in_shape[0]}.") + + return new_equalisation_ratios, model def dataset_fuse( - dataset, - output_path, - channels, - slicing, - store, - compression, - compression_level, - overwrite, - microscope, - equalise, - equalise_mode, - zero_level, - clip_too_high, - fusion, - fusion_bias_strength_i, - fusion_bias_strength_d, - dehaze_size, - dark_denoise_threshold, - z_pad_apodise, - loadreg, - model_list_filename, - warpreg_num_iterations, - min_confidence, - max_change, - registration_edge_filter, - maxproj, - huge_dataset, - devices, - check, - pad, - white_top_hat_size, - white_top_hat_sampling, - remove_beads, - stop_at_exception=True, + input_dataset: BaseDataset, + output_dataset: ZDataset, + microscope: str, + model_list_filename: str, + channels: Sequence[str], + equalise: bool, + equalise_mode: str, + zero_level: float, + clip_too_high: float, + fusion: str, + fusion_bias_strength_i: float, + fusion_bias_strength_d: float, + dehaze_size: int, + dark_denoise_threshold: int, + z_pad_apodise: Tuple[int, int], + loadreg: bool, + warpreg_num_iterations: int, + min_confidence: float, + max_change: float, + registration_edge_filter: bool, + maxproj: bool, + huge_dataset: bool, + pad: bool, + white_top_hat_size: float, + white_top_hat_sampling: int, + remove_beads: bool, + devices: Sequence[int], ): - views = {channel.split("-")[-1]: dataset.get_array(channel, per_z_slice=False) for channel in channels} + views = {channel.split("-")[-1]: input_dataset[channel] for channel in channels} + n_time_pts = len(list(views.values())[0]) with asection("Views:"): for channel, view in views.items(): aprint(f"View: {channel} of shape: {view.shape} and dtype: {view.dtype}") - key 
= list(views.keys())[0] - dtype = views[key].dtype - in_nb_time_points = views[key].shape[0] - aprint(f"Slicing with: {slicing}") - _, volume_slicing, time_points = slice_from_shape(views[key].shape, slicing) - if microscope == "simview": views = SimViewFusion.validate_views(views) - # load registration models: - with NumpyBackend(): - if loadreg: - aprint(f"Loading registration shifts from existing file! ({model_list_filename})") - models = model_list_from_file(model_list_filename) - if len(models) == 1: - models = [models[0] for _ in range(in_nb_time_points)] - elif len(models) != in_nb_time_points: - raise ValueError( - f"Number of registration models provided ({len(models)})" - f"differs from number of input time points ({in_nb_time_points})" - ) - else: - models = [None] * len(time_points) - - @curry - def process(i, params, device_id=0): - equalisation_ratios_reference, model, dest_dataset = params - tp = time_points[i] - try: - with asection(f"Fusing time point for time point {i}/{len(time_points)}"): - with asection(f"Loading channels {channels}"): - views_tp = {k: np.asarray(view[tp][volume_slicing]) for k, view in views.items()} - - with BestBackend(exclusive=True, enable_unified_memory=True, device_id=device_id): - if models[i] is not None: - model = models[i] - # otherwise it could be a None model or from the first iteration if equalisation mode was 'first' - if microscope == "simview": - fuse_obj = SimViewFusion( - registration_model=model, - equalise=equalise, - equalisation_ratios=equalisation_ratios_reference, - zero_level=zero_level, - clip_too_high=clip_too_high, - fusion=fusion, - fusion_bias_exponent=2, - fusion_bias_strength_i=fusion_bias_strength_i, - fusion_bias_strength_d=fusion_bias_strength_d, - dehaze_before_fusion=True, - dehaze_size=dehaze_size, - dehaze_correct_max_level=True, - dark_denoise_threshold=dark_denoise_threshold, - dark_denoise_size=9, - white_top_hat_size=white_top_hat_size, - white_top_hat_sampling=white_top_hat_sampling, - remove_beads=remove_beads, - butterworth_filter_cutoff=1, - flip_camera1=True, - pad=pad, - ) - - tp_array = fuse_obj(**views_tp) - new_equalisation_ratios = fuse_obj._equalisation_ratios - - elif microscope == "mvsols": - metadata = dataset.get_metadata() - angle = metadata["angle"] - dz = metadata["dz"] - res = metadata["res"] - illumination_correction_sigma = metadata["ic_sigma"] if "ic_sigma" in metadata else None - - tp_array, model, new_equalisation_ratios = msols_fuse_1C2L( - *list(views_tp.values()), - z_pad=z_pad_apodise[0], - z_apodise=z_pad_apodise[1], - registration_num_iterations=warpreg_num_iterations, - registration_force_model=loadreg, - registration_model=model, - registration_min_confidence=min_confidence, - registration_max_change=max_change, - registration_edge_filter=registration_edge_filter, - equalise=equalise, - equalisation_ratios=equalisation_ratios_reference, - zero_level=zero_level, - clip_too_high=clip_too_high, - fusion=fusion, - fusion_bias_exponent=2, - fusion_bias_strength_x=fusion_bias_strength_i, - dehaze_size=dehaze_size, - dark_denoise_threshold=dark_denoise_threshold, - angle=angle, - dx=res, - dz=dz, - illumination_correction_sigma=illumination_correction_sigma, - registration_mode="projection" if maxproj else "full", - huge_dataset_mode=huge_dataset, - ) - else: - raise NotImplementedError - - with asection("Moving array from backend to numpy."): - tp_array = Backend.to_numpy(tp_array, dtype=dtype, force_copy=False) - - del views_tp - Backend.current().clear_memory_pool() - - 
aprint(f"Last equalisation ratios: {new_equalisation_ratios}") - - # moving to numpy to allow pickling - if model is not None: - model = model.to_numpy() - new_equalisation_ratios = [ - None if v is None else Backend.to_numpy(v) for v in new_equalisation_ratios - ] - - with asection(f"Saving fused stack for time point {i}, shape:{tp_array.shape}, dtype:{tp_array.dtype}"): - - if i == 0: - # We allocate last minute once we know the shape... because we don't always know - # the shape in advance!!! - mode = "w" + ("" if overwrite else "-") - dest_dataset = ZDataset(output_path, mode, store, parent=dataset) - dest_dataset.add_channel( - "fused", - shape=(len(time_points),) + tp_array.shape, - dtype=tp_array.dtype, - codec=compression, - clevel=compression_level, - ) - - dest_dataset.write_stack(channel="fused", time_point=i, stack_array=tp_array) - - aprint(f"Done processing time point: {i}/{len(time_points)} .") - - except Exception as error: - aprint(error) - aprint(f"Error occurred while processing time point {i} !") - import traceback - - traceback.print_exc() - - if stop_at_exception: - raise error - - return new_equalisation_ratios, model, dest_dataset - - if len(devices) == 1: - process = process(device_id=devices[0]) - else: - from dask.distributed import Client - from dask_cuda import LocalCUDACluster + if loadreg: + models = load_registration_models(model_list_filename, n_time_pts) + + # returns partial fusion function + fusion_func = get_fusion_func( + microscope=microscope, + metadata=input_dataset.get_metadata(), + equalise=equalise, + zero_level=zero_level, + clip_too_high=clip_too_high, + fusion=fusion, + fusion_bias_strength_i=fusion_bias_strength_i, + fusion_bias_strength_d=fusion_bias_strength_d, + dehaze_size=dehaze_size, + dark_denoise_threshold=dark_denoise_threshold, + z_pad_apodise=z_pad_apodise, + warpreg_num_iterations=warpreg_num_iterations, + min_confidence=min_confidence, + max_change=max_change, + registration_edge_filter=registration_edge_filter, + maxproj=maxproj, + huge_dataset=huge_dataset, + pad=pad, + loadreg=loadreg, + white_top_hat_size=white_top_hat_size, + white_top_hat_sampling=white_top_hat_sampling, + remove_beads=remove_beads, + ) - process = dask.delayed(process) + # it creates the output dataset from the first time point output shape + equalisation_ratios, model = _process( + time_point=0, + views=views, + out_dataset=output_dataset, + fusion_func=fusion_func( + model=models[0] if loadreg else None, + equalisation_ratios=None, + ), + ) - cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) - client = Client(cluster) + if equalise_mode == "all": + equalisation_ratios = None - aprint("Dask Client", client) + output_models = [model] - # the parameters are (equalisation rations, registration model, dest. 
dataset) - if microscope == "simview": - init_params = ([None, None, None], None, None) - elif microscope == "mvsols": - init_params = ([None], None, None) - else: - raise NotImplementedError + cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) + client = Client(cluster) + aprint("Dask Client", client) - # it creates the output dataset from the first time point output shape - params = process(0, init_params) - if len(devices) > 1: - params = params.persist() + lazy_computations = [] + for t in range(1, n_time_pts): + lazy_computations.append( + dask.delayed(_process)( + time_point=t, + views=views, + out_dataset=output_dataset, + fusion_func=fusion_func(equalisation_ratios=equalisation_ratios, model=models[t] if loadreg else None), + ) + ) - if equalise_mode == "all": - params = init_params[:2] + (params[2],) - - lazy_computations = [] # it is only lazy if len(devices) > 1 - for i in range(1, len(time_points)): - lazy_computations.append(process(i, params)) - - if len(devices) > 1: - if equalise_mode == "all": - first_model, dest_dataset = params[1:] # not a dask operation - else: - first_model, dest_dataset = params.compute()[1:] - models = [first_model] + [output[1] for output in dask.compute(*lazy_computations)] - else: - models = [params[1]] + [output[1] for output in lazy_computations] - dest_dataset = params[2] + # compute remaining stacks and save models + output_models += [output[1] for output in dask.compute(*lazy_computations)] - if not loadreg and models[0] is not None: - model_list_to_file(model_list_filename, models) + if not loadreg and output_models[0] is not None: + model_list_to_file(model_list_filename, output_models) - aprint(dest_dataset.info()) - if check: - dest_dataset.check_integrity() + aprint(output_dataset.info()) - # close destination dataset: - dest_dataset.close() - if len(devices) > 1: - client.close() + output_dataset.check_integrity() diff --git a/dexp/datasets/operations/generic.py b/dexp/datasets/operations/generic.py new file mode 100644 index 00000000..2b08945d --- /dev/null +++ b/dexp/datasets/operations/generic.py @@ -0,0 +1,61 @@ +from typing import Callable, Optional, Sequence, Tuple + +import dask +from arbol import aprint, asection +from dask.distributed import Client +from dask_cuda import LocalCUDACluster +from toolz import curry + +from dexp.datasets import BaseDataset, ZDataset +from dexp.datasets.stack_iterator import StackIterator +from dexp.processing.utils.scatter_gather_i2i import scatter_gather_i2i +from dexp.utils.backends import CupyBackend + + +@dask.delayed +@curry +def _process( + time_point: int, + stacks: StackIterator, + out_dataset: ZDataset, + channel: str, + func: Callable, +) -> None: + + with CupyBackend() as bkd: + with asection(f"Applying {func.__name__} for channel {channel} at time point {time_point}"): + stack = bkd.to_backend(stacks[time_point]) + stack = func(stack) + out_dataset.write_stack(channel, time_point, bkd.to_numpy(stack)) + + +def dataset_generic( + input_dataset: BaseDataset, + output_dataset: ZDataset, + channels: Sequence[str], + func: Callable, + tilesize: Optional[Tuple[int]], + devices: Sequence[int], +) -> None: + lazy_computations = [] + + if tilesize is not None: + func = curry(scatter_gather_i2i, function=func, tiles=tilesize, margins=32) + + for ch in channels: + stacks = input_dataset[ch] + output_dataset.add_channel(ch, stacks.shape, dtype=input_dataset.dtype(ch)) + + process = _process(stacks=stacks, out_dataset=output_dataset, channel=ch, func=func) + + # Stores functions to be computed 
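# ---------------------------------------------------------------------------
# Example: the per-time-point orchestration pattern used by the refactored
# fusion and generic operations. A curried worker processes one time point;
# time point 0 is computed eagerly so the output channel can be allocated from
# the first result, and the remaining time points are submitted as Dask-delayed
# tasks (dexp dispatches them to a dask_cuda.LocalCUDACluster). The in-memory
# "stacks"/"results" stand-ins below are hypothetical, not dexp APIs; this is a
# minimal sketch of the pattern only.
import dask
import numpy as np
from toolz import curry


@curry
def _process(time_point, stacks, results):
    # in dexp this would load the stack on a GPU backend and write it to a ZDataset
    results[time_point] = stacks[time_point] * 2


stacks = [np.full((4, 4), t, dtype=np.float32) for t in range(5)]
results = {}

_process(stacks=stacks, results=results)(time_point=0)  # eager: fixes output shape/dtype
lazy = [
    dask.delayed(_process)(time_point=t, stacks=stacks, results=results)
    for t in range(1, len(stacks))
]
dask.compute(*lazy)  # remaining time points run in parallel
# ---------------------------------------------------------------------------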
+ lazy_computations += [process(time_point=t) for t in range(len(stacks))] + + cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) + client = Client(cluster) + aprint("Dask client", client) + + # Compute everything + dask.compute(*lazy_computations) + + output_dataset.check_integrity() diff --git a/dexp/datasets/operations/isonet.py b/dexp/datasets/operations/isonet.py deleted file mode 100644 index e3bb31f8..00000000 --- a/dexp/datasets/operations/isonet.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Optional, Sequence - -import numpy -from arbol.arbol import aprint -from skimage.transform import downscale_local_mean - -from dexp.datasets import BaseDataset -from dexp.utils.timeit import timeit - - -def dataset_isonet( - dataset: BaseDataset, - path: str, - channel: Optional[Sequence[str]], - slicing, - store: str, - compression: str, - compression_level: int, - overwrite: bool, - context: str, - mode: str, - dxy: float, - dz: float, - binning: int, - sharpening: bool, - training_tp_index: Optional[int], - max_epochs: int, - check: bool, -): - if channel is None: - channel = "fused" - - if training_tp_index is None: - training_tp_index = dataset.nb_timepoints(channel) // 2 - - aprint(f"Selected channel {channel}") - - aprint("getting Dask arrays to apply isonet on...") - array = dataset.get_array(channel, per_z_slice=False, wrap_with_dask=True) - - if slicing is not None: - aprint(f"Slicing with: {slicing}") - array = array[slicing] - - aprint(f"Binning image by a factor {binning}...") - dxy *= binning - - subsampling = dz / dxy - aprint(f"Parameters: dxy={dxy}, dz={dz}, subsampling={subsampling}") - - psf = numpy.ones((1, 1)) / 1 - print(f"PSF (along xy): {psf}") - from dexp.processing.isonet import IsoNet - - isonet = IsoNet(context, subsampling=subsampling) - - if "p" in mode: - training_array_tp = array[training_tp_index].compute() - training_downscaled_array_tp = downscale_local_mean(training_array_tp, factors=(1, binning, binning)) - print(f"Training image shape: {training_downscaled_array_tp.shape} ") - isonet.prepare(training_downscaled_array_tp, psf=psf, threshold=0.999) - - if "t" in mode: - isonet.train(max_epochs=max_epochs) - - if "a" in mode: - from dexp.datasets import ZDataset - - mode = "w" + ("" if overwrite else "-") - dest_dataset = ZDataset(path, mode, store, parent=dataset) - zarr_array = None - - for tp in range(0, array.shape[0] - 1): - with timeit("Elapsed time: "): - - aprint(f"Processing time point: {tp} ...") - tp_array = array[tp].compute() - - aprint("Downscaling image...") - tp_array = downscale_local_mean(tp_array, factors=(1, binning, binning)) - # array_tp_downscaled = zoom(tp_array, zoom=(1, 1.0/binning, 1.0/binning), order=0) - - if sharpening: - aprint("Sharpening image...") - from dexp.processing.restoration import dehazing - - tp_array = dehazing(tp_array, mode="hybrid", min=0, max=1024, margin_pad=False) - - aprint("Applying IsoNet to image...") - tp_array = isonet.apply(tp_array) - - aprint(f"Result: image of shape: {tp_array.shape}, dtype: {tp_array.dtype} ") - - if zarr_array is None: - shape = (array.shape[0],) + tp_array.shape - aprint(f"Creating Zarr array of shape: {shape} ") - zarr_array = dest_dataset.add_channel( - name=channel, shape=shape, dtype=array.dtype, codec=compression, clevel=compression_level - ) - - aprint("Writing image to Zarr file...") - dest_dataset.write_stack(channel=channel, time_point=tp, stack_array=tp_array) - - aprint(f"Done processing time point: {tp}") - - aprint(dest_dataset.info()) - if check: - 
dest_dataset.check_integrity() - - # close destination dataset: - dest_dataset.close() diff --git a/dexp/datasets/operations/projrender.py b/dexp/datasets/operations/projrender.py index dbfe1ccf..649d8fe2 100644 --- a/dexp/datasets/operations/projrender.py +++ b/dexp/datasets/operations/projrender.py @@ -1,129 +1,84 @@ -from os import makedirs -from os.path import exists, join -from typing import Any, Sequence, Tuple +from pathlib import Path +from typing import Callable, Sequence +import dask import imageio +import numpy as np from arbol.arbol import aprint, asection -from joblib import Parallel, delayed +from dask.distributed import Client +from toolz import curry from dexp.datasets import BaseDataset +from dexp.datasets.stack_iterator import StackIterator from dexp.processing.color.projection import project_image -from dexp.utils.backends import Backend, BestBackend +from dexp.utils.backends import BestBackend +from dexp.utils.backends.cupy_backend import is_cupy_available + + +@dask.delayed +@curry +def _process(time_point: int, stacks: StackIterator, outpath: Path, project_func: Callable, overwrite: bool) -> None: + with asection(f"Rendering Frame \t: {time_point:05}"): + + filename = outpath / f"frame_{time_point:05}.png" + + if overwrite or not filename.exists(): + + with asection("Loading stack..."): + stack = np.asarray(stacks[time_point]) + + with BestBackend() as bkd: + with asection(f"Projecting image of shape: {stack.shape} "): + projection = bkd.to_numpy(project_func(bkd.to_backend(stack))) + + with asection(f"Saving frame {time_point} as: {filename}"): + imageio.imwrite(filename, projection, compress_level=1) def dataset_projection_rendering( input_dataset: BaseDataset, - output_path: str = None, - channels: Sequence[str] = None, - slicing: Any = None, - overwrite: bool = False, - axis: int = 0, - dir: int = -1, - mode: str = "colormax", - clim: Tuple[float, float] = None, - attenuation: float = 0.05, - gamma: float = 1.0, - dlim: Tuple[float, float] = None, - colormap: str = None, - rgbgamma: float = 1.0, - transparency: bool = False, - legendsize: float = 1.0, - legendscale: float = 1.0, - legendtitle: str = "color-coded depth (voxels)", - legendtitlecolor: Tuple[float, float, float, float] = (1, 1, 1, 1), - legendposition: str = "bottom_left", - legendalpha: float = 1.0, - step: int = 1, - workers: int = -1, - workersbackend: str = "threading", - devices: Tuple[int, ...] 
= (0,), - stop_at_exception=True, + output_path: Path, + channels: Sequence[str], + overwrite: bool, + devices: Sequence[int], + **kwargs, ): - for channel in channels: + project_func = curry(project_image, **kwargs) + lazy_computations = [] + for channel in channels: # Ensures that the output folder exists per channel: if len(channels) == 1: channel_output_path = output_path else: - channel_output_path = output_path + f"_{channel}" + channel_output_path = output_path / channel + + channel_output_path.mkdir(exist_ok=True, parents=True) - makedirs(channel_output_path, exist_ok=True) + stacks = input_dataset[channel] - with asection(f"Channel '{channel}' shape: {input_dataset.shape(channel)}:"): + process = _process( + stacks=stacks, + outpath=channel_output_path, + project_func=project_func, + overwrite=overwrite, + ) + + with asection(f"Channel '{channel}' shape: {stacks.shape}:"): aprint(input_dataset.info(channel)) - array = input_dataset.get_array(channel, wrap_with_dask=True) - - if slicing: - array = array[slicing] - - aprint(f"Rendering array of shape={array.shape} and dtype={array.dtype} for channel '{channel}'.") - - nbframes = array.shape[0] - - with asection("Rendering:"): - - def process(tp, _clim, device): - try: - with asection(f"Rendering Frame : {tp:05}"): - - filename = join(channel_output_path, f"frame_{tp:05}.png") - - if overwrite or not exists(filename): - - with asection("Loading stack..."): - stack = array[tp].compute() - - with BestBackend(device, exclusive=True, enable_unified_memory=True): - if _clim is not None: - aprint(f"Using provided min and max for contrast limits: {_clim}") - min_value, max_value = (float(strvalue) for strvalue in _clim.split(",")) - _clim = (min_value, max_value) - - with asection(f"Projecting image of shape: {stack.shape} "): - projection = project_image( - stack, - axis=axis, - dir=dir, - mode=mode, - attenuation=attenuation, - attenuation_min_density=0.002, - attenuation_filtering=4, - gamma=gamma, - clim=_clim, - cmap=colormap, - dlim=dlim, - rgb_gamma=rgbgamma, - transparency=transparency, - legend_size=legendsize, - legend_scale=legendscale, - legend_title=legendtitle, - legend_title_color=legendtitlecolor, - legend_position=legendposition, - legend_alpha=legendalpha, - ) - - with asection(f"Saving frame {tp} as: {filename}"): - imageio.imwrite(filename, Backend.to_numpy(projection), compress_level=1) - - except Exception as error: - aprint(error) - aprint(f"Error occurred while processing time point {tp} !") - import traceback - - traceback.print_exc() - - if stop_at_exception: - raise error - - if workers == -1: - workers = len(devices) - aprint(f"Number of workers: {workers}") - - if workers > 1: - Parallel(n_jobs=workers, backend=workersbackend)( - delayed(process)(tp, clim, devices[tp % len(devices)]) for tp in range(0, nbframes, step) - ) - else: - for tp in range(0, nbframes, step): - process(tp, clim, devices[0]) + lazy_computations += [process(time_point=t) for t in range(len(stacks))] + + if len(devices) > 0 and is_cupy_available(): + from dask_cuda import LocalCUDACluster + + cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) + client = Client(cluster) + else: + client = Client() + aprint("Dask client", client) + + # Compute everything + dask.compute(*lazy_computations) + + client.close() diff --git a/dexp/datasets/operations/register.py b/dexp/datasets/operations/register.py index e57dc844..579bed23 100644 --- a/dexp/datasets/operations/register.py +++ b/dexp/datasets/operations/register.py @@ -1,4 +1,5 @@ -from 
typing import List +from copy import deepcopy +from typing import Dict, List, Sequence import dask import numpy @@ -6,123 +7,118 @@ from arbol.arbol import aprint, asection from dask.distributed import Client from dask_cuda import LocalCUDACluster +from toolz import curry +from dexp.datasets.base_dataset import BaseDataset +from dexp.datasets.stack_iterator import StackIterator +from dexp.processing.multiview_lightsheet.fusion.basefusion import BaseFusion from dexp.processing.multiview_lightsheet.fusion.simview import SimViewFusion from dexp.processing.registration.model.model_io import model_list_to_file from dexp.processing.registration.model.translation_registration_model import ( TranslationRegistrationModel, ) -from dexp.utils.backends import Backend, BestBackend -from dexp.utils.slicing import slice_from_shape +from dexp.utils.backends import BestBackend + + +@dask.delayed +@curry +def _process( + tp: int, + views: Dict[str, StackIterator], + fuse_model: BaseFusion, + max_proj: bool, + registration_edge_filter: bool, +) -> TranslationRegistrationModel: + + fuse_model = deepcopy(fuse_model) + + with BestBackend(exclusive=True, enable_unified_memory=True): + with asection(f"Loading volume {tp}"): + views_tp = {k: np.asarray(view[tp]) for k, view in views.items()} + + with asection(f"Registring volume {tp}:"): + C0Lx, C1Lx = fuse_model.preprocess(**views_tp) + fuse_model.compute_registration( + C0Lx, + C1Lx, + mode="projection" if max_proj else "full", + edge_filter=registration_edge_filter, + crop_factor_along_z=0.3, + ) + model = fuse_model.registration_model.to_numpy() + + return model def dataset_register( - dataset, - model_path, - channels, - slicing, - microscope, - equalise, - zero_level, - clip_too_high, - fusion, - fusion_bias_strength_i, - dehaze_size, - registration_edge_filter, - max_proj, - white_top_hat_size, - white_top_hat_sampling, - remove_beads, - devices, - stop_at_exception=True, -): - - views = {channel.split("-")[-1]: dataset.get_array(channel, per_z_slice=False) for channel in channels} + dataset: BaseDataset, + model_path: str, + channels: Sequence[str], + microscope: str, + equalise: bool, + zero_level: int, + clip_too_high: int, + fusion: str, + fusion_bias_strength_i: float, + dehaze_size: int, + registration_edge_filter: bool, + max_proj: bool, + white_top_hat_size: float, + white_top_hat_sampling: int, + remove_beads: bool, + devices: Sequence[int], +) -> None: + + views = {channel.split("-")[-1]: dataset[channel] for channel in channels} + n_time_pts = len(list(views.values())[0]) with asection("Views:"): for channel, view in views.items(): aprint(f"View: {channel} of shape: {view.shape} and dtype: {view.dtype}") - key = list(views.keys())[0] - print(f"Slicing with: {slicing}") - _, volume_slicing, time_points = slice_from_shape(views[key].shape, slicing) - if microscope == "simview": - views = SimViewFusion.validate_views(views) - - @dask.delayed - def process(i): - tp = time_points[i] - try: - with asection(f"Loading channels {channel} for time point {i}/{len(time_points)}"): - views_tp = {k: np.asarray(view[tp][volume_slicing]) for k, view in views.items()} - - with BestBackend(exclusive=True, enable_unified_memory=True): - if microscope == "simview": - fuse_obj = SimViewFusion( - registration_model=None, - equalise=equalise, - equalisation_ratios=[None, None, None], - zero_level=zero_level, - clip_too_high=clip_too_high, - fusion=fusion, - fusion_bias_exponent=2, - fusion_bias_strength_i=fusion_bias_strength_i, - fusion_bias_strength_d=0.0, - 
dehaze_before_fusion=True, - dehaze_size=dehaze_size, - dehaze_correct_max_level=True, - dark_denoise_threshold=0, - dark_denoise_size=0, - butterworth_filter_cutoff=0.0, - white_top_hat_size=white_top_hat_size, - white_top_hat_sampling=white_top_hat_sampling, - remove_beads=remove_beads, - flip_camera1=True, - ) - - C0Lx, C1Lx = fuse_obj.preprocess(**views_tp) - del views_tp - Backend.current().clear_memory_pool() - fuse_obj.compute_registration( - C0Lx, - C1Lx, - mode="projection" if max_proj else "full", - edge_filter=registration_edge_filter, - crop_factor_along_z=0.3, - ) - del C0Lx, C1Lx - Backend.current().clear_memory_pool() - model = fuse_obj.registration_model.to_numpy() - else: - raise NotImplementedError - - aprint(f"Done processing time point: {i}/{len(time_points)} .") - - except Exception as error: - aprint(error) - aprint(f"Error occurred while processing time point {i} !") - import traceback - - traceback.print_exc() - - if stop_at_exception: - raise error - - return model + fuse_model = SimViewFusion( + registration_model=None, + equalise=equalise, + equalisation_ratios=[None, None, None], + zero_level=zero_level, + clip_too_high=clip_too_high, + fusion=fusion, + fusion_bias_exponent=2, + fusion_bias_strength_i=fusion_bias_strength_i, + fusion_bias_strength_d=0.0, + dehaze_before_fusion=True, + dehaze_size=dehaze_size, + dehaze_correct_max_level=True, + dark_denoise_threshold=0, + dark_denoise_size=0, + butterworth_filter_cutoff=0.0, + white_top_hat_size=white_top_hat_size, + white_top_hat_sampling=white_top_hat_sampling, + remove_beads=remove_beads, + flip_camera1=True, + ) + views = fuse_model.validate_views(views) + else: + raise NotImplementedError + + process = _process( + views=views, fuse_model=fuse_model, max_proj=max_proj, registration_edge_filter=registration_edge_filter + ) cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=devices) client = Client(cluster) aprint("Dask Client", client) lazy_computations = [] - for i in range(len(time_points)): - lazy_computations.append(process(i)) + for i in range(n_time_pts): + lazy_computations.append(process(tp=i)) models = dask.compute(*lazy_computations) mode_model = compute_median_translation(models) model_list_to_file(model_path, [mode_model]) + client.close() diff --git a/dexp/datasets/operations/segment.py b/dexp/datasets/operations/segment.py index 13b2d576..ea250e9b 100644 --- a/dexp/datasets/operations/segment.py +++ b/dexp/datasets/operations/segment.py @@ -7,12 +7,12 @@ from dask.distributed import Client from dask_cuda import LocalCUDACluster from skimage.measure import regionprops_table -from skimage.segmentation import relabel_sequential from toolz import curry from dexp.datasets import BaseDataset, ZDataset from dexp.datasets.stack_iterator import StackIterator from dexp.processing.morphology import area_white_top_hat +from dexp.processing.segmentation import roi_watershed_from_minima from dexp.utils.backends import CupyBackend @@ -39,7 +39,6 @@ def _process( from cucim.skimage import morphology as morph from cucim.skimage.filters import threshold_otsu from edt import edt - from pyift.shortestpath import watershed_from_minima with CupyBackend() as bkd: with asection(f"Segmenting time point {time_point}:"): @@ -72,8 +71,12 @@ def _process( with asection(f"Detecting cells of channel {i} ..."): ch_detection = wth > threshold_otsu(wth) del wth - ch_detection = morph.binary_closing(ch_detection, morph.ball(np.sqrt(2))) - ch_detection = morph.remove_small_objects(ch_detection, min_size=minimum_area) + + # removing 
small white/black elements + ch_detection = morph.binary_opening(ch_detection, morph.ball(2)) + ch_detection = morph.binary_closing(ch_detection, morph.ball(2)) + + # ch_detection = morph.remove_small_objects(ch_detection, min_size=minimum_area) detection |= ch_detection count = detection.sum() @@ -88,7 +91,7 @@ def _process( detection = bkd.to_numpy(detection) with asection("Segmenting ..."): - _, labels = watershed_from_minima( + labels = roi_watershed_from_minima( image=basins, mask=detection, H_minima=h_minima, @@ -98,10 +101,6 @@ def _process( del basins, detection labels = labels.astype(np.int32) - with asection("Relabeling ..."): - labels[labels < 0] = 0 - labels, _, _ = relabel_sequential(labels) - out_dataset.write_stack(out_channel, time_point, labels) feature_image = np.stack( @@ -177,7 +176,7 @@ def dataset_segment( client = Client(cluster) aprint("Dask client", client) - df_path = output_dataset._path.replace(".zarr", ".csv") + df_path = output_dataset.path.replace(".zarr", ".csv") df = pd.concat(dask.compute(*lazy_computations)) df.to_csv(df_path, index=False) diff --git a/dexp/datasets/operations/serve.py b/dexp/datasets/operations/serve.py deleted file mode 100644 index dc80457e..00000000 --- a/dexp/datasets/operations/serve.py +++ /dev/null @@ -1,18 +0,0 @@ -from arbol.arbol import aprint - -from dexp.datasets import ZDataset - - -def dataset_serve(dataset: ZDataset, host: str = "0.0.0.0", port: int = 8000): - if not type(dataset) == ZDataset: - aprint("Cannot serve a non-Zarr dataset!") - return - - aprint(dataset.info()) - try: - from simple_zarr_server import serve - - serve(dataset._root_group, host=host, port=port) - finally: - # close destination dataset: - dataset.close() diff --git a/dexp/datasets/operations/stabilize.py b/dexp/datasets/operations/stabilize.py index 9c0e26cb..2dff5025 100644 --- a/dexp/datasets/operations/stabilize.py +++ b/dexp/datasets/operations/stabilize.py @@ -3,8 +3,11 @@ import numpy from arbol.arbol import aprint, asection from joblib import Parallel, delayed +from toolz import curry from dexp.datasets import BaseDataset +from dexp.datasets.stack_iterator import StackIterator +from dexp.datasets.zarr_dataset import ZDataset from dexp.processing.registration.model.model_io import from_json from dexp.processing.registration.model.sequence_registration_model import ( SequenceRegistrationModel, @@ -18,7 +21,6 @@ def _compute_model( input_dataset: BaseDataset, channel: str, - slicing, max_range: int, min_confidence: float, enable_com: bool, @@ -35,15 +37,15 @@ def _compute_model( maxproj: bool = True, device: int = 0, workers: int = 1, - debug_output=None, ) -> SequenceRegistrationModel: # get channel array: array = input_dataset.get_array(channel, per_z_slice=False, wrap_with_dask=True) # Perform slicing: - if slicing is not None: - array = array[slicing] + if input_dataset.slicing is not None: + # dexp slicing API doesn't work with dask + array = array[input_dataset.slicing] # Shape and chunks for array: ndim = array.ndim @@ -71,7 +73,6 @@ def _compute_model( edge_filter=edge_filter, internal_dtype=numpy.float16, workers=workers, - debug_output=debug_output, ) model.to_numpy() @@ -79,8 +80,8 @@ def _compute_model( projections = [] for axis in range(array.ndim - 1): projection = input_dataset.get_projection_array(channel=channel, axis=axis, wrap_with_dask=False) - if slicing is not None: - projection = projection[slicing] + if input_dataset.slicing is not None: + projection = projection[input_dataset.slicing] proj_axis = list(1 + a for a in 
range(array.ndim - 1) if a != axis) projections.append((*proj_axis, projection)) @@ -103,7 +104,6 @@ def _compute_model( edge_filter=edge_filter, ndim=ndim - 1, internal_dtype=numpy.float16, - debug_output=debug_output, detrend=detrend, ) model.to_numpy() @@ -111,18 +111,39 @@ def _compute_model( return model +# definition of function that processes each time point: +@curry +def _process( + i: int, + array: StackIterator, + output_dataset: ZDataset, + channel: str, + model: SequenceRegistrationModel, + pad: bool, + integral: bool, +) -> None: + with asection(f"Processing time point: {i}/{len(array)}"): + with NumpyBackend() as bkd: + with asection("Loading stack"): + tp_array = bkd.to_backend(array[i]) + + with asection("Applying model..."): + tp_array = model.apply(tp_array, index=i, pad=pad, integral=integral) + + with asection( + f"Saving stabilized stack for time point {i}/{len(array)}, shape:{tp_array.shape}, " + + "dtype:{array.dtype}" + ): + output_dataset.write_stack(channel=channel, time_point=i, stack_array=tp_array) + + def dataset_stabilize( - dataset: BaseDataset, - output_path: str, + input_dataset: BaseDataset, + output_dataset: ZDataset, channels: Sequence[str], model_output_path: str, model_input_path: Optional[str] = None, reference_channel: Optional[str] = None, - slicing=None, - zarr_store: str = "dir", - compression_codec: str = "zstd", - compression_level: int = 3, - overwrite: bool = False, max_range: int = 7, min_confidence: float = 0.5, enable_com: bool = False, @@ -142,9 +163,6 @@ def dataset_stabilize( workers: int = -1, workers_backend: str = "threading", device: int = 0, - check: bool = True, - stop_at_exception: bool = True, - debug_output=None, ): """ Takes an input dataset and performs image stabilisation and outputs a stabilised dataset @@ -152,19 +170,14 @@ def dataset_stabilize( Parameters ---------- - dataset: Input dataset - output_path: Output path for Zarr storage + input_dataset: Input dataset + output_dataset: Output dataset channels: selected channels model_output_path: str Path to store computed stabilization model for reproducibility. model_input_path: Optional[str] = None, Path of previously computed stabilization model. reference_channel: use this channel to compute the stabilization model and apply to every channel. - slicing: selected array slicing - zarr_store: type of store, can be 'dir', 'ndir', or 'zip' - compression_codec: compression codec to be used ('zstd', 'blosclz', 'lz4', 'lz4hc', 'zlib' or 'snappy'). - compression_level: An integer between 0 and 9 specifying the compression level. - overwrite: overwrite output dataset if already exists max_range: maximal distance, in time points, between pairs of images to registrate. min_confidence: minimal confidence below which pairwise registrations are rejected for the stabilisation. enable_com: enable center of mass fallback when standard registration fails. @@ -186,30 +199,22 @@ def dataset_stabilize( workers: number of workers, if -1 then the number of workers == number of cores or devices workers_backend: What backend to spawn workers with, can be ‘loky’ (multi-process) or ‘threading’ (multi-thread) device: Sets the CUDA devices id, e.g. 0,1,2 - check: Checking integrity of written file. - stop_at_exception: True to stop as soon as there is an exception during processing. 
""" if model_input_path is not None and reference_channel is not None: raise ValueError("`model_input_path` and `reference channel` cannot be supplied at the same time.") - from dexp.datasets import ZDataset - - mode = "w" + ("" if overwrite else "-") - dest_dataset = ZDataset(output_path, mode, zarr_store, parent=dataset) - model = None if model_input_path is not None: with open(model_input_path) as f: model = from_json(f.read()) elif reference_channel is not None: - if reference_channel not in dataset.channels(): + if reference_channel not in input_dataset.channels(): raise ValueError(f"Reference channel {reference_channel} not found.") model = _compute_model( - input_dataset=dataset, + input_dataset=input_dataset, channel=reference_channel, - slicing=slicing, max_range=max_range, min_confidence=min_confidence, enable_com=enable_com, @@ -226,17 +231,15 @@ def dataset_stabilize( maxproj=maxproj, device=device, workers=workers, - debug_output=debug_output, ) with open(model_output_path, mode="w") as f: f.write(model.to_json()) - for channel in dataset._selected_channels(channels): + for channel in input_dataset._selected_channels(channels): if model is None: channel_model = _compute_model( - input_dataset=dataset, + input_dataset=input_dataset, channel=channel, - slicing=slicing, max_range=max_range, min_confidence=min_confidence, enable_com=enable_com, @@ -253,7 +256,6 @@ def dataset_stabilize( maxproj=maxproj, device=device, workers=workers, - debug_output=debug_output, ) formatted_path = f"{model_output_path.split('.')[0]}_{channel}.json" with open(formatted_path, mode="w") as f: @@ -262,7 +264,7 @@ def dataset_stabilize( else: channel_model = model - metadata = dataset.get_metadata() + metadata = input_dataset.get_metadata() if "dt" in metadata.get(channel, {}): prev_displacement = channel_model.total_displacement() channel_model = channel_model.reduce(int(round(metadata[channel]["dt"]))) @@ -276,12 +278,9 @@ def dataset_stabilize( # minimum displacement different results in translations shift from reference channel metadata[channel][name] = current[0] - prev[0] aprint(f"{name.upper()}: {current[0] - prev[0]}") - dest_dataset.append_metadata(metadata) - - array = dataset.get_array(channel, per_z_slice=False, wrap_with_dask=True) - if slicing is not None: - array = array[slicing] + output_dataset.append_metadata(metadata) + array = input_dataset[channel] shape = array.shape dtype = array.dtype nb_timepoints = shape[0] @@ -297,37 +296,17 @@ def dataset_stabilize( # Shape of the resulting array: padded_shape = (nb_timepoints,) + channel_model.padded_shape(shape[1:]) - - dest_dataset.add_channel( - name=channel, shape=padded_shape, dtype=dtype, codec=compression_codec, clevel=compression_level + output_dataset.add_channel(name=channel, shape=padded_shape, dtype=dtype) + + process = _process( + array=array, + output_dataset=output_dataset, + channel=channel, + model=channel_model, + pad=pad, + integral=integral, ) - # definition of function that processes each time point: - def process(tp): - try: - with asection(f"Processing time point: {tp}/{nb_timepoints} ."): - with asection("Loading stack"): - tp_array = array[tp].compute() - - with NumpyBackend(): - with asection("Applying model..."): - tp_array = channel_model.apply(tp_array, index=tp, pad=pad, integral=integral) - - with asection( - f"Saving stabilized stack for time point {tp}/{nb_timepoints}, shape:{tp_array.shape}, " - + "dtype:{array.dtype}" - ): - dest_dataset.write_stack(channel=channel, time_point=tp, stack_array=tp_array) - 
- except Exception as error: - aprint(error) - aprint(f"Error occurred while processing time point {tp}/{nb_timepoints} !") - import traceback - - traceback.print_exc() - if stop_at_exception: - raise error - # start jobs: if workers == 1: for tp in range(0, nb_timepoints): @@ -336,17 +315,8 @@ def process(tp): n_jobs = compute_num_workers(workers, nb_timepoints) Parallel(n_jobs=n_jobs, backend=workers_backend)(delayed(process)(tp) for tp in range(0, nb_timepoints)) - # printout output dataset info: - aprint(dest_dataset.info()) - if check: - dest_dataset.check_integrity() - # Dataset info: - aprint(dest_dataset.info()) + aprint(output_dataset.info()) # Check dataset integrity: - if check: - dest_dataset.check_integrity() - - # close destination dataset: - dest_dataset.close() + output_dataset.check_integrity() diff --git a/dexp/datasets/operations/tiff.py b/dexp/datasets/operations/tiff.py index 9cdf1bbc..ccf1e9d7 100644 --- a/dexp/datasets/operations/tiff.py +++ b/dexp/datasets/operations/tiff.py @@ -1,132 +1,125 @@ import os -from os.path import join +from pathlib import Path from typing import Sequence, Union +import numpy as np from arbol.arbol import aprint, asection from joblib import Parallel, delayed from tifffile import memmap +from toolz import curry from dexp.datasets import BaseDataset +from dexp.datasets.stack_iterator import StackIterator from dexp.io.io import tiff_save +@curry +def _save_multi_file( + tp: int, stacks: StackIterator, dest_path: Path, channel: str, clevel: int, overwrite: bool, project: bool +) -> None: + with asection(f"Saving time point {tp}: "): + tiff_file_path = dest_path / f"file{tp}_{channel}.tiff" + if overwrite or not tiff_file_path.exists(): + stack = np.asarray(stacks[tp]) + + if type(project) == int: + # project is the axis for projection, but here we are not considering + # the T dimension anymore... + aprint(f"Projecting along axis {project}") + stack = stack.max(axis=project) + + aprint( + f"Writing time point: {tp} of shape: {stack.shape}, dtype:{stack.dtype} " + + f"as TIFF file: '{tiff_file_path}', with compression: {clevel}" + ) + tiff_save(tiff_file_path, stack, compress=clevel) + aprint(f"Done writing time point: {tp} !") + else: + aprint(f"File for time point (or z slice): {tp} already exists.") + + +@curry +def _save_single_file(tp: int, stacks: StackIterator, memmap_image: np.memmap, project: bool) -> None: + aprint(f"Processing time point {tp}") + stack = np.asarray(stacks[tp]) + + if type(project) == int: + # project is the axis for projection, but here we are not considering the T dimension anymore... 
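# ---------------------------------------------------------------------------
# Example: the stabilisation-model round trip described in the docstring of
# dataset_stabilize above. A first run estimates the model (here from a
# reference channel) and persists it as JSON via `model_output_path`; a later
# run can re-apply exactly the same model through `model_input_path` instead of
# re-estimating it. Paths and the channel name are placeholders; this is a
# sketch of the intended workflow only.
from dexp.datasets import ZDataset
from dexp.datasets.operations.stabilize import dataset_stabilize

input_ds = ZDataset("input.zarr", mode="r")
output_ds = ZDataset("stabilized.zarr", mode="w-")

# first pass: estimate the model from a reference channel and store it
dataset_stabilize(
    input_ds,
    output_ds,
    channels=["image"],
    model_output_path="stabilization_model.json",
    reference_channel="image",
)

# later pass (e.g. into another output dataset): reuse the stored model as-is
# dataset_stabilize(input_ds, other_output_ds, channels=["image"],
#                   model_output_path="model_copy.json",
#                   model_input_path="stabilization_model.json")
# ---------------------------------------------------------------------------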
+ aprint(f"Projecting along axis {project}") + stack = stack.max(axis=project) + + memmap_image[tp] = stack + + def dataset_tiff( dataset: BaseDataset, - dest_path: str, + dest_path: Path, channels: Sequence[str], - slicing, overwrite: bool = False, project: Union[int, bool] = False, one_file_per_first_dim: bool = False, clevel: int = 0, workers: int = 1, workersbackend: str = "", - stop_at_exception: bool = True, ): - - selected_channels = dataset._selected_channels(channels) - - aprint(f"getting Dask arrays for channels {selected_channels}") - arrays = list(dataset.get_array(channel, per_z_slice=False, wrap_with_dask=True) for channel in selected_channels) - - if slicing is not None: - aprint(f"Slicing with: {slicing}") - arrays = list(array[slicing] for array in arrays) - aprint("Done slicing.") + aprint(f"getting Dask arrays for channels {channels}") if workers == -1: workers = max(1, os.cpu_count() // abs(workers)) + aprint(f"Number of workers: {workers}") if one_file_per_first_dim: aprint(f"Saving one TIFF file for each tp (or Z if already sliced) to: {dest_path}.") - os.makedirs(dest_path, exist_ok=True) - - def process(tp): - try: - with asection(f"Saving time point {tp}: "): - for channel, array in zip(selected_channels, arrays): - tiff_file_path = join(dest_path, f"file{tp}_{channel}.tiff") - if overwrite or not os.path.exists(tiff_file_path): - stack = array[tp].compute() - - if project is not False and type(project) == int: - # project is the axis for projection, but here we are not considering - # the T dimension anymore... - aprint(f"Projecting along axis {project}") - stack = stack.max(axis=project) - - aprint( - f"Writing time point: {tp} of shape: {stack.shape}, dtype:{stack.dtype} " - + f"as TIFF file: '{tiff_file_path}', with compression: {clevel}" - ) - tiff_save(tiff_file_path, stack, compress=clevel) - aprint(f"Done writing time point: {tp} !") - else: - aprint(f"File for time point (or z slice): {tp} already exists.") - except Exception as error: - aprint(error) - aprint(f"Error occurred while processing time point {tp} !") - import traceback - - traceback.print_exc() - - if stop_at_exception: - raise error - - if workers > 1: - Parallel(n_jobs=workers, backend=workersbackend)( - delayed(process)(tp) for tp in range(0, arrays[0].shape[0]) - ) - else: - for tp in range(0, arrays[0].shape[0]): - process(tp) + dest_path.mkdir(exist_ok=True) - else: + process = _save_multi_file(dest_path=dest_path, clevel=clevel, overwrite=overwrite, project=project) - for channel, array in zip(selected_channels, arrays): - if len(selected_channels) > 1: - tiff_file_path = f"{dest_path}_{channel}.tiff" + for channel in channels: + stacks = dataset[channel] + _process = process(stacks=stacks, channel=channel) + + _workers = min(workers, len(stacks)) + if _workers > 1: + Parallel(n_jobs=_workers, backend=workersbackend)(delayed(_process)(tp) for tp in range(len(stacks))) else: + for tp in range(len(stacks)): + _process(tp) + + else: + for channel in channels: + if len(channels) > 1: + tiff_file_path = f"{dest_path}_{channel}.tiff" + + elif not dest_path.name.endswith((".tiff", ".tif")): tiff_file_path = f"{dest_path}.tiff" - if not overwrite and os.path.exists(tiff_file_path): + tiff_file_path = Path(tiff_file_path) + + if not overwrite and tiff_file_path.exists(): aprint(f"File {tiff_file_path} already exists! 
Set option -w to overwrite.") return - with asection( - f"Saving array ({array.shape}, {array.dtype}) for channel {channel} into " - + f"TIFF file at: {tiff_file_path}:" - ): + stacks = dataset[channel] + shape = stacks.shape - shape = array.shape + if type(project) == int: + shape = list(shape) + shape.pop(1 + project) + shape = tuple(shape) - if project is not False and type(project) == int: - shape = list(shape) - shape.pop(1 + project) - shape = tuple(shape) + memmap_image = memmap(tiff_file_path, shape=shape, dtype=stacks.dtype, bigtiff=True, imagej=True) - memmap_image = memmap(tiff_file_path, shape=shape, dtype=array.dtype, bigtiff=True, imagej=True) + _process = _save_single_file(stacks=stacks, memmap_image=memmap_image, project=project) - def process(tp): - aprint(f"Processing time point {tp}") - stack = array[tp].compute() + _workers = min(len(stacks), workers) + if _workers > 1: + Parallel(n_jobs=_workers, backend=workersbackend)(delayed(_process)(tp) for tp in range(len(stacks))) - if project is not False and type(project) == int: - # project is the axis for projection, but here we are not considering the T dimension anymore... - aprint(f"Projecting along axis {project}") - stack = stack.max(axis=project) - - memmap_image[tp] = stack - - if workers > 1: - Parallel(n_jobs=workers, backend=workersbackend)( - delayed(process)(tp) for tp in range(0, array.shape[0]) - ) - else: - for tp in range(0, array.shape[0]): - process(tp) + else: + for tp in range(len(stacks)): + _process(tp) - memmap_image.flush() - del memmap_image + memmap_image.flush() + del memmap_image diff --git a/dexp/datasets/operations/view.py b/dexp/datasets/operations/view.py index a6d4f50c..165fc436 100644 --- a/dexp/datasets/operations/view.py +++ b/dexp/datasets/operations/view.py @@ -1,126 +1,178 @@ from typing import List, Sequence, Union import dask +import numpy as np from arbol import aprint from dask.array import reshape +from toolz import curry -from dexp.datasets import BaseDataset +from dexp.datasets import ZDataset +from dexp.datasets.base_dataset import BaseDataset def dataset_view( input_dataset: BaseDataset, channels: Sequence[str], - slicing, - aspect: float, + scale: int, contrast_limits: Union[List[int], List[float]], colormap: str, name: str, windowsize: int, projections_only: bool, volume_only: bool, - rescale_time: bool, + quiet: bool, + labels: Sequence[str], ): import napari + import tensorstore as ts from napari.layers.utils._link_layers import ( _get_common_evented_attributes, link_layers, ) - if rescale_time: - nb_time_points = [input_dataset.shape(channel)[0] for channel in input_dataset.channels()] - max_time_points = max(nb_time_points) - time_scales = [round(max_time_points / n) for n in nb_time_points] + if quiet: + from napari.components.viewer_model import ViewerModel + + viewer = ViewerModel() else: - time_scales = [1] * len(input_dataset.channels()) + viewer = napari.Viewer(title=f"DEXP | viewing with napari: {name} ", ndisplay=2) + viewer.window.resize(windowsize + 256, windowsize) - viewer = napari.Viewer(title=f"DEXP | viewing with napari: {name} ", ndisplay=2) viewer.grid.enabled = True - viewer.window.resize(windowsize + 256, windowsize) - for time_scale, channel in zip(time_scales, channels): + scale_array = np.asarray((1, scale, scale, scale)) + + for channel in channels: aprint(f"Channel '{channel}' shape: {input_dataset.shape(channel)}") aprint(input_dataset.info(channel)) - array = input_dataset.get_array(channel, wrap_with_dask=True) + # setting up downsample and 
layer type + if channel in labels: + downsample = curry(ts.downsample, method="stride") - layers = [] - try: - if not volume_only: - for axis in range(array.ndim - 1): - proj_array = input_dataset.get_projection_array(channel, axis=axis, wrap_with_dask=True) + def add_layer(data, **kwargs): + return viewer.add_labels(data, **kwargs) - # if the data format does not support projections we skip: - if proj_array is None: - continue + else: + downsample = curry(ts.downsample, method="mean") - if axis == 1: - proj_array = dask.array.flip(proj_array, 1) # flipping y - elif axis == 2: - proj_array = dask.array.rot90(proj_array, axes=(2, 1)) # - - shape = ( - proj_array.shape[0], - 1, - ) + proj_array.shape[1:] - proj_array = reshape(proj_array, shape=shape) - - if proj_array is not None: - proj_layer = viewer.add_image( - proj_array, - name=channel + "_proj_" + str(axis), - contrast_limits=contrast_limits, - blending="additive", - colormap=colormap, - ) - layers.append(proj_layer) - - if aspect is not None: - if axis == 0: - proj_layer.scale = (time_scale, 1, 1, 1) - elif axis == 1: - proj_layer.scale = (time_scale, 1, aspect, 1) - elif axis == 2: - proj_layer.scale = (time_scale, 1, 1, aspect) # reordered due to rotation - - aprint(f"Setting aspect ratio for projection (layer.scale={proj_layer.scale})") + def add_layer(data, **kwargs): + return viewer.add_image( + data, + contrast_limits=contrast_limits, + blending="additive", + colormap=colormap, + rendering="attenuated_mip", + **kwargs, + ) - except KeyError: - aprint("Warning: can't find projections!") + resolution = np.asarray(input_dataset.get_resolution(channel)) - if not projections_only: + if isinstance(input_dataset, ZDataset): + array = input_dataset.get_array(channel, wrap_with_tensorstore=True) + if scale != 1: + array = downsample(array, scale_array) + resolution *= scale_array + else: + array = input_dataset.get_array(channel, wrap_with_dask=True) - if slicing: - array = array[slicing] + layers = [] + if not projections_only: aprint(f"Adding array of shape={array.shape} and dtype={array.dtype} for channel '{channel}'.") # flip x for second camera: if "C1" in channel: - array = dask.array.flip(array, -1) + if isinstance(array, ts.TensorStore): + array = array[..., ::-1] + else: + array = dask.array.flip(array, -1) - layer = viewer.add_image( + layer = add_layer( array, name=channel, - contrast_limits=contrast_limits, - blending="additive", - colormap=colormap, - attenuation=0.04, - rendering="attenuated_mip", + scale=resolution, ) layers.append(layer) - if aspect is not None: - if array.ndim == 3: - layer.scale = (aspect, 1, 1) - elif array.ndim == 4: - layer.scale = (time_scale, aspect, 1, 1) - aprint(f"Setting time scale to {time_scale}") - aprint(f"Setting aspect ratio to {aspect} (layer.scale={layer.scale})") + try: + if not volume_only: + for axis in (1, 2, 0): + if isinstance(array, ts.TensorStore): + proj_array = input_dataset.get_projection_array(channel, axis=axis, wrap_with_tensorstore=True) + else: + proj_array = input_dataset.get_projection_array(channel, axis=axis, wrap_with_dask=True) + + # if the data format does not support projections we skip: + if proj_array is None: + continue + + if isinstance(array, ts.TensorStore): + if axis == 1: + proj_array = proj_array[..., ::-1, :] # flipping y + + elif axis == 2: + # rotating 90 degrees on (2,1)-plane, equivalent to transposing (1,2) -> (-2, 1) + proj_array = proj_array[ + ts.IndexTransform( + input_rank=proj_array.ndim, + output=[ + ts.OutputIndexMap(input_dimension=0), + 
ts.OutputIndexMap(input_dimension=2, stride=-1), + ts.OutputIndexMap(input_dimension=1), + ], + ) + ] + + proj_array = proj_array[:, np.newaxis] # appending dummy dim + + if scale != 1: + proj_array = downsample(proj_array, scale_array) + + else: + if axis == 1: + proj_array = dask.array.flip(proj_array, 1) # flipping y + elif axis == 2: + proj_array = dask.array.rot90(proj_array, axes=(2, 1)) + + shape = ( + proj_array.shape[0], + 1, + ) + proj_array.shape[1:] + + proj_array = reshape(proj_array, shape=shape) + + proj_resolution = list(resolution) + proj_resolution.pop(axis + 1) + proj_resolution.insert(1, 1) + + if axis == 2: + proj_resolution[-2:] = reversed(proj_resolution[-2:]) + + # flip x for second camera: + if "C1" in channel: + if isinstance(array, ts.TensorStore): + array = array[..., ::-1] + else: + array = dask.array.flip(array, -1) + + proj_layer = add_layer( + proj_array, + name=f"{channel}_proj_{axis}", + scale=proj_resolution, + ) + layers.append(proj_layer) + + except KeyError: + aprint("Warning: couldn't find projections!") if layers: attr = _get_common_evented_attributes(layers) attr.remove("visible") link_layers(layers, attr) - napari.run() + viewer.dims.set_point(1, 0) # moving to z = 0 + + if not quiet: + napari.run() diff --git a/dexp/datasets/operations/view_remote.py b/dexp/datasets/operations/view_remote.py deleted file mode 100644 index 20ab0d4f..00000000 --- a/dexp/datasets/operations/view_remote.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Sequence, Tuple - -from arbol import aprint, asection -from zarr.errors import ArrayNotFoundError - - -def dataset_view_remote( - input_path: str, - name: str, - aspect: float, - channels: Sequence[str], - contrast_limits: Tuple[float, float], - colormap: str, - slicing, - windowsize: int, -): - # Annoying napari induced warnings: - import warnings - - warnings.filterwarnings("ignore") - - if channels is None: - aprint("Channel(s) must be specified!") - return - - if ":" not in input_path.replace("http://", ""): - input_path = f"{input_path}:8000" - - with asection(f"Viewing remote dataset at: {input_path}, channel(s): {channels}"): - channels = tuple(channel.strip() for channel in channels.split(",")) - channels = list(set(channels)) - - from napari import Viewer, gui_qt - - with gui_qt(): - viewer = Viewer(title=f"DEXP | viewing with napari: {name} ", ndisplay=2) - viewer.window.resize(windowsize + 256, windowsize) - - for channel in channels: - import dask.array as da - - if "/" in channel: - array = da.from_zarr(f"{input_path}/{channel}") - else: - array = da.from_zarr(f"{input_path}/{channel}/{channel}") - - if slicing is not None: - array = array[slicing] - - layer = viewer.add_image( - array, - name=channel, - visible=True, - contrast_limits=contrast_limits, - blending="additive", - colormap=colormap, - attenuation=0.04, - rendering="attenuated_mip", - ) - - if aspect is not None: - if array.ndim == 3: - layer.scale = (aspect, 1, 1) - elif array.ndim == 4: - layer.scale = (1, aspect, 1, 1) - aprint(f"Setting aspect ratio to {aspect} (layer.scale={layer.scale})") - - try: - for axis in range(array.ndim): - if "/" in channel: - proj_array = da.from_zarr(f"{input_path}/{channel}_max{axis}") - else: - proj_array = da.from_zarr(f"{input_path}/{channel}/{channel}_max{axis}") - - proj_layer = viewer.add_image( - proj_array, - name=f"{channel}_max{axis}", - visible=True, - contrast_limits=contrast_limits, - blending="additive", - colormap=colormap, - ) - - if aspect is not None: - if axis == 0: - proj_layer.scale = 
(1, 1) - elif axis == 1: - proj_layer.scale = (aspect, 1) - elif axis == 2: - proj_layer.scale = (aspect, 1) - - aprint(f"Setting aspect ratio for projection (layer.scale={proj_layer.scale})") - - except (KeyError, ArrayNotFoundError): - aprint("Warning: could not find projections in dataset!") diff --git a/dexp/datasets/stack_iterator.py b/dexp/datasets/stack_iterator.py index a6be555b..07853095 100644 --- a/dexp/datasets/stack_iterator.py +++ b/dexp/datasets/stack_iterator.py @@ -20,7 +20,11 @@ def shape(self) -> Tuple[int]: @property def dtype(self) -> np.dtype: - return self._array.dtype + dtype = self._array.dtype + if not isinstance(dtype, np.dtype): + # converting tensorstore dtype + dtype = dtype.numpy_dtype + return dtype @property def slicing(self) -> Tuple[slice]: diff --git a/dexp/datasets/synthetic_datasets/__init__.py b/dexp/datasets/synthetic_datasets/__init__.py index 233e9bd0..45826470 100644 --- a/dexp/datasets/synthetic_datasets/__init__.py +++ b/dexp/datasets/synthetic_datasets/__init__.py @@ -1,4 +1,5 @@ from dexp.datasets.synthetic_datasets.binary_blobs import binary_blobs +from dexp.datasets.synthetic_datasets.dexp_dataset import generate_dexp_zarr_dataset from dexp.datasets.synthetic_datasets.multiview_data import generate_fusion_test_data from dexp.datasets.synthetic_datasets.nuclei_background_data import ( generate_nuclei_background_data, diff --git a/dexp/datasets/synthetic_datasets/dexp_dataset.py b/dexp/datasets/synthetic_datasets/dexp_dataset.py new file mode 100644 index 00000000..14105fc5 --- /dev/null +++ b/dexp/datasets/synthetic_datasets/dexp_dataset.py @@ -0,0 +1,65 @@ +from dexp.datasets.synthetic_datasets.multiview_data import generate_fusion_test_data +from dexp.datasets.synthetic_datasets.nuclei_background_data import ( + generate_nuclei_background_data, +) +from dexp.datasets.zarr_dataset import ZDataset +from dexp.processing.deskew import skew +from dexp.utils.backends import BestBackend + + +def generate_dexp_zarr_dataset(path: str, dataset_type: str, n_time_pts: int = 2, **kwargs) -> ZDataset: + """Auxiliary function to generate dexps' zarr dataset from synthetic datas. + + Parameters + ---------- + path : str + Dataset path. + + dataset_type : str + Dataset type, options are `fusion` or `nuclei`. 
+ + n_time_pts : int, optional + Number of time points (repeats), by default 2 + + Returns + ------- + ZDataset + Generated dataset object + """ + DS_TYPES = ("fusion", "fusion-skewed", "nuclei", "nuclei-skewed") + if dataset_type not in DS_TYPES: + raise ValueError(f"`dataset_type` must be {DS_TYPES}, found {dataset_type}") + + ds = ZDataset(path, mode="w") + + if dataset_type.startswith("fusion"): + names = ["ground-truth", "low-quality", "blend-a", "blend-b", "image-C0L0", "image-C1L0"] + images = generate_fusion_test_data(**kwargs) + + elif dataset_type.startswith("nuclei"): + names = ["ground-truth", "background", "image"] + images = generate_nuclei_background_data(**kwargs) + + else: + raise NotImplementedError + + if dataset_type.endswith("skewed"): + with BestBackend() as bkd: + xp = bkd.get_xp_module() + images = list(images) + for i in range(len(images)): + if names[i].startswith("image"): + skewed = bkd.to_numpy(skew(bkd.to_backend(images[i]), xp.ones((5,) * images[-1].ndim), 1, 45, 1)) + images[i] = skewed + + ds.append_metadata(dict(dz=1, res=1, angle=45)) # metadata used for mvsols fusion + + for name, image in zip(names, images): + ds.add_channel(name, (n_time_pts,) + image.shape, dtype=image.dtype) + for t in range(n_time_pts): + # adding a small shift to each time point + ds.write_stack(name, t, image) + + ds.check_integrity() + + return ds diff --git a/dexp/datasets/synthetic_datasets/nuclei_background_data.py b/dexp/datasets/synthetic_datasets/nuclei_background_data.py index 70ce87ef..c45a5fe3 100644 --- a/dexp/datasets/synthetic_datasets/nuclei_background_data.py +++ b/dexp/datasets/synthetic_datasets/nuclei_background_data.py @@ -53,9 +53,6 @@ def generate_nuclei_background_data( if type(Backend.current()) is NumpyBackend: internal_dtype = xp.float32 - if type(Backend.current()) is NumpyBackend: - dtype = numpy.float32 - with asection("generate blob images"): image_gt = binary_blobs( length=length_xy, @@ -88,7 +85,7 @@ def generate_nuclei_background_data( y = xp.linspace(-1, 1, num=ly) z = xp.linspace(-1, 1, num=lz) xx, yy, zz = numpy.meshgrid(x, y, z, sparse=True) - distance = xp.sqrt(xx ** 2 + yy ** 2 + zz ** 2) + distance = xp.sqrt(xx**2 + yy**2 + zz**2) mask = distance < radius # f = 0.5*(1 + level**3 / (1 + xp.absolute(level)**3)) diff --git a/dexp/datasets/tiff_dataset.py b/dexp/datasets/tiff_dataset.py index d04eb5c8..0e80b75a 100644 --- a/dexp/datasets/tiff_dataset.py +++ b/dexp/datasets/tiff_dataset.py @@ -10,7 +10,7 @@ class TIFDataset(BaseDataset): def __init__(self, path: str): - super().__init__(dask_backed=False) + super().__init__(dask_backed=False, path=path) self._channel = "stack" self._array = imread(path) diff --git a/dexp/datasets/zarr_dataset.py b/dexp/datasets/zarr_dataset.py index 3e916e81..967b57d4 100755 --- a/dexp/datasets/zarr_dataset.py +++ b/dexp/datasets/zarr_dataset.py @@ -5,19 +5,20 @@ import sys from os.path import exists, isdir, isfile, join from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import dask import numpy import zarr from arbol.arbol import aprint, asection from ome_zarr.format import CurrentFormat -from zarr import Blosc, CopyError, Group, convenience, open_group +from zarr import Blosc, CopyError, convenience, open_group from dexp.cli.defaults import DEFAULT_CLEVEL, DEFAULT_CODEC from dexp.datasets.base_dataset import BaseDataset from dexp.datasets.ome_dataset import create_coord_transform, default_omero_metadata 
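# ---------------------------------------------------------------------------
# Example: using the `generate_dexp_zarr_dataset` helper defined above to build
# one of the small synthetic Zarr datasets that the CLI tests consume. The
# output path is a placeholder, and extra keyword arguments (here `dtype`, as
# in the tests) are forwarded to the underlying synthetic-data generators.
from dexp.datasets.synthetic_datasets import generate_dexp_zarr_dataset

ds = generate_dexp_zarr_dataset("/tmp/synthetic_nuclei.zarr", "nuclei", n_time_pts=2, dtype="uint16")
print(ds.channels())  # expected to include 'ground-truth', 'background' and 'image'
ds.close()
# ---------------------------------------------------------------------------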
from dexp.datasets.stack_iterator import StackIterator +from dexp.utils import compress_dictionary_lists_length from dexp.utils.backends import Backend, BestBackend from dexp.utils.config import config_blosc @@ -65,94 +66,91 @@ def __init__( """ config_blosc() - super().__init__(dask_backed=False) + super().__init__(dask_backed=False, path=path) - if not isinstance(path, str): - path = str(path) - - self._path = path self._store = None self._root_group = None - self._arrays: Dict[str, zarr.Array] = {} - self._projections = {} self._codec = codec self._clevel = clevel self._chunks = chunks # Open remote store: - if "http" in path: - aprint(f"Opening a remote store at: {path}") + if "http" in self._path: + aprint(f"Opening a remote store at: {self._path}") from fsspec import get_mapper - self.store = get_mapper(path) + self.store = get_mapper(self._path) self._root_group = zarr.open(self.store, mode=mode) - self._initialise_existing() return # Correct path to adhere to convention: if "a" in mode or "w" in mode: - if path.endswith(".zarr.zip") or store == "zip": - path = path + ".zip" if path.endswith(".zarr") else path - path = path if path.endswith(".zarr.zip") else path + ".zarr.zip" - elif path.endswith(".nested.zarr") or path.endswith(".nested.zarr/") or store == "ndir": - path = path if path.endswith(".nested.zarr") else path + ".nested.zarr" - elif path.endswith(".zarr") or path.endswith(".zarr/") or store == "dir": - path = path if path.endswith(".zarr") else path + ".zarr" + if self._path.endswith(".zarr.zip") or store == "zip": + self._path = self._path + ".zip" if self._path.endswith(".zarr") else self._path + self._path = self._path if self._path.endswith(".zarr.zip") else self._path + ".zarr.zip" + elif self._path.endswith(".nested.zarr") or self._path.endswith(".nested.zarr/") or store == "ndir": + self._path = self._path if self._path.endswith(".nested.zarr") else self._path + ".nested.zarr" + elif self._path.endswith(".zarr") or self._path.endswith(".zarr/") or store == "dir": + self._path = self._path if self._path.endswith(".zarr") else self._path + ".zarr" # if exists and overwrite then delete! - if exists(path) and mode == "w-": - raise ValueError(f"Storage '{path}' already exists, add option '-w' to force overwrite!") - elif exists(path) and mode == "w": - aprint(f"Deleting '{path}' for overwrite!") - if isdir(path): + if exists(self._path) and mode == "w-": + raise ValueError(f"Storage '{self._path}' already exists, add option '-w' to force overwrite!") + elif exists(self._path) and mode == "w": + aprint(f"Deleting '{self._path}' for overwrite!") + if isdir(self._path): # This is a very dangerous operation, let's double check that the folder really holds a zarr dataset: - _zgroup_file = join(path, ".zgroup") - _zarray_file = join(path, ".zarray") + _zgroup_file = join(self._path, ".zgroup") + _zarray_file = join(self._path, ".zarray") # We check that either of these two hiddwn files are present, and if '.zarr' is part of the name: - if (".zarr" in path) and (exists(_zgroup_file) or exists(_zarray_file)): - shutil.rmtree(path, ignore_errors=True) + if (".zarr" in self._path) and (exists(_zgroup_file) or exists(_zarray_file)): + shutil.rmtree(self._path, ignore_errors=True) else: raise ValueError( "Specified path does not seem to be a zarr dataset, deletion for overwrite" "not performed out of abundance of caution, check path!" 
) - elif isfile(path): - os.remove(path) + elif isfile(self._path): + os.remove(self._path) - if exists(path): - aprint(f"Opening existing Zarr storage: '{path}' with read/write mode: '{mode}' and store type: '{store}'") - if isfile(path) and (path.endswith(".zarr.zip") or store == "zip"): + if exists(self._path): + aprint(f"Opening existing Zarr: '{self._path}' with read/write mode: '{mode}' and store type: '{store}'") + if isfile(self._path) and (self._path.endswith(".zarr.zip") or store == "zip"): aprint("Opening as ZIP store") - self._store = zarr.storage.ZipStore(path) - elif isdir(path) and (path.endswith(".nested.zarr") or path.endswith(".nested.zarr/") or store == "ndir"): + self._store = zarr.storage.ZipStore(self._path) + elif isdir(self._path) and ( + self._path.endswith(".nested.zarr") or self._path.endswith(".nested.zarr/") or store == "ndir" + ): aprint("Opening as Nested Directory store") - self._store = zarr.storage.NestedDirectoryStore(path) - elif isdir(path) and (path.endswith(".zarr") or path.endswith(".zarr/") or store == "dir"): + self._store = zarr.storage.NestedDirectoryStore(self._path) + elif isdir(self._path) and ( + self._path.endswith(".zarr") or self._path.endswith(".zarr/") or store == "dir" + ): aprint("Opening as Directory store") - self._store = zarr.storage.DirectoryStore(path) + self._store = zarr.storage.DirectoryStore(self._path) aprint(f"Opening with mode: {mode}") self._root_group = open_group(self._store, mode=mode) - self._initialise_existing() + elif "a" in mode or "w" in mode: - aprint(f"Creating Zarr storage: '{path}' with read/write mode: '{mode}' and store type: '{store}'") + aprint(f"Creating Zarr storage: '{self._path}' with read/write mode: '{mode}' and store type: '{store}'") if store is None: store = "dir" try: - if path.endswith(".zarr.zip") or store == "zip": + if self._path.endswith(".zarr.zip") or store == "zip": aprint("Opening as ZIP store") - self._store = zarr.storage.ZipStore(path) - elif path.endswith(".nested.zarr") or path.endswith(".nested.zarr/") or store == "ndir": + self._store = zarr.storage.ZipStore(self._path) + elif self._path.endswith(".nested.zarr") or self._path.endswith(".nested.zarr/") or store == "ndir": aprint("Opening as Nested Directory store") - self._store = zarr.storage.NestedDirectoryStore(path) - elif path.endswith(".zarr") or path.endswith(".zarr/") or store == "dir": + self._store = zarr.storage.NestedDirectoryStore(self._path) + elif self._path.endswith(".zarr") or self._path.endswith(".zarr/") or store == "dir": aprint("Opening as Directory store") - self._store = zarr.storage.DirectoryStore(path) + self._store = zarr.storage.DirectoryStore(self._path) else: aprint( - f"Cannot open {path}, needs to be a zarr directory (directory that ends with `.zarr` or " + f"Cannot open {self._path}, needs to be a zarr directory (directory that ends with `.zarr` or " + "`.nested.zarr` for nested folders), or a zipped zarr file (file that ends with `.zarr.zip`)" ) @@ -161,10 +159,10 @@ def __init__( except Exception: raise ValueError( "Problem: can't create target file/directory, most likely the target dataset " - + f"already exists or path incorrect: {path}" + + f"already exists or path incorrect: {self._path}" ) else: - raise ValueError(f"Invalid read/write mode or invalid path: {path} (check path!)") + raise ValueError(f"Invalid read/write mode or invalid path: {self._path} (check path!)") # updating metadata if parent is not None: @@ -175,34 +173,8 @@ def __init__( if mode in ("a", "w", "w-"): 
self.append_cli_history(parent if isinstance(parent, ZDataset) else None) - def _initialise_existing(self): - self._channels = [channel for channel, _ in self._root_group.groups()] - - aprint("Exploring Zarr hierarchy...") - for channel, channel_group in self._root_group.groups(): - aprint(f"Found channel: {channel}") - - channel_items = channel_group.items() - - for item_name, array in channel_items: - aprint(f"Found array: {item_name}") - - if item_name == channel or item_name == "fused": - # print(f'Opening array at {path}:{channel}/{item_name} ') - self._arrays[channel] = array - # self._arrays[channel] = from_zarr(path, component=f"{channel}/{item_name}") - elif (item_name.startswith(channel) or item_name.startswith("fused")) and "_projection_" in item_name: - self._projections[item_name] = array - - def _get_group_for_channel(self, channel: str) -> Union[None, Sequence[Group]]: - groups = [g for c, g in self._root_group.groups() if c == channel] - if len(groups) == 0: - return None - else: - return groups[0] - def chunk_shape(self, channel: str) -> Sequence[int]: - return self._arrays[channel].chunks + return self.get_array(channel).chunks @staticmethod def _default_chunks(shape: Tuple[int], dtype: Union[str, numpy.dtype], max_size: int = 2147483647) -> Tuple[int]: @@ -237,13 +209,13 @@ def check_integrity(self, channels: Sequence[str] = None) -> bool: return True def channels(self) -> Sequence[str]: - return list(self._arrays.keys()) + return list(self._root_group.keys()) def nb_timepoints(self, channel: str) -> int: return self.get_array(channel).shape[0] def shape(self, channel: str) -> Sequence[int]: - return self._arrays[channel].shape + return self.get_array(channel).shape def dtype(self, channel: str): return self.get_array(channel).dtype @@ -255,7 +227,7 @@ def info(self, channel: str = None, cli_history: bool = True) -> str: f"Channel: '{channel}', nb time points: {self.shape(channel)[0]}, shape: {self.shape(channel)[1:]}" ) info_str += ".\n" - info_str += str(self._arrays[channel].info) + info_str += str(self.get_array(channel).info) return info_str else: info_str += f"Dataset at location: {self._path} \n" @@ -264,14 +236,15 @@ def info(self, channel: str = None, cli_history: bool = True) -> str: info_str += str(self._root_group.tree()) info_str += ".\n\n" info_str += "Arrays: \n" - for name, array in self._arrays.items(): + for name in self.channels(): info_str += " │ \n" - info_str += " └──" + name + ":\n" + str(array.info) + "\n\n" + info_str += " └──" + name + ":\n" + str(self.get_array(name).info) + "\n\n" info_str += ".\n\n" info_str += ".\n\n" info_str += "\nMetadata: \n" - for key, value in self.get_metadata().items(): + metadata = compress_dictionary_lists_length(self.get_metadata(), 5) + for key, value in metadata.items(): if "cli_history" not in key: info_str += f"\t{key} : {value} \n" @@ -323,7 +296,7 @@ def _load_tensorstore(self, array: zarr.Array): "driver": "zarr", "kvstore": { "driver": "file", - "path": self._path, + "path": str(Path(self.path).resolve()), }, "path": array.path, "metadata": metadata, @@ -332,9 +305,9 @@ def _load_tensorstore(self, array: zarr.Array): def get_array( self, channel: str, per_z_slice: bool = False, wrap_with_dask: bool = False, wrap_with_tensorstore: bool = False - ): + ) -> Union[zarr.Array, Any]: assert (wrap_with_dask != wrap_with_tensorstore) or not wrap_with_dask - array = self._arrays[channel] + array = self._root_group[channel].get(channel) if wrap_with_dask: return dask.array.from_array(array, chunks=array.chunks) elif 
wrap_with_tensorstore: @@ -345,11 +318,19 @@ def get_stack(self, channel: str, time_point: int, per_z_slice: bool = False, wr stack_array = self.get_array(channel, per_z_slice=per_z_slice, wrap_with_dask=wrap_with_dask)[time_point] return stack_array - def get_projection_array(self, channel: str, axis: int, wrap_with_dask: bool = False) -> Any: - array = self._projections.get(self._projection_name(channel, axis)) + def get_projection_array( + self, channel: str, axis: int, wrap_with_dask: bool = False, wrap_with_tensorstore: bool = False + ) -> Optional[Union[zarr.Array, Any]]: + assert (wrap_with_dask != wrap_with_tensorstore) or not wrap_with_dask + array = self._root_group[channel].get(self._projection_name(channel, axis)) if array is None: - return - return dask.array.from_array(array, chunks=array.chunks) if wrap_with_dask else array + return None + + if wrap_with_dask: + return dask.array.from_array(array, chunks=array.chunks) + elif wrap_with_tensorstore: + return self._load_tensorstore(array) + return array def _projection_name(self, channel: str, axis: int): return f"{channel}_projection_{axis}" @@ -441,8 +422,6 @@ def add_channel( fill_value=fill_value, ) - self._arrays[name] = array - if enable_projections: ndim = len(shape) - 1 for axis in range(ndim): @@ -455,7 +434,7 @@ def add_channel( # chunking along time must be 1 to allow parallelism, but no chunking for each projection (not needed!) proj_chunks = (1,) + (None,) * (len(chunks) - 2) - proj_array = channel_group.full( + channel_group.full( name=proj_name, shape=proj_shape, dtype=dtype, @@ -464,7 +443,6 @@ def add_channel( compressor=compressor, fill_value=fill_value, ) - self._projections[proj_name] = proj_array return array @@ -500,7 +478,7 @@ def add_channels_to( for channel, new_name in zip(channels, rename): try: array = self.get_array(channel, per_z_slice=False, wrap_with_dask=False) - source_group = self._get_group_for_channel(channel) + source_group = self._root_group[channel] source_arrays = source_group.items() aprint(f"Creating group for channel {channel} of new name {new_name}.") @@ -561,7 +539,7 @@ def get_translation(self, channel: Optional[str] = None) -> List[float]: """ Gets channel translation. 
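# A minimal sketch of reading data back through the reworked ZDataset accessors (get_array /
# get_projection_array): channel arrays and their per-axis projections are now resolved lazily
# from the zarr root group (projections live under '<channel>_projection_<axis>') instead of
# being cached at open time. The path 'demo.zarr' and the channel name 'image' are hypothetical,
# and ZDataset is assumed to be importable from dexp.datasets.
from dexp.datasets import ZDataset

ds = ZDataset("demo.zarr", mode="r")
if "image" in ds:  # __contains__ queries the zarr root group directly
    stack = ds.get_array("image", wrap_with_dask=True)   # lazy dask view over the zarr array
    z_proj = ds.get_projection_array("image", axis=0)    # returns None if the projection is absent
    print(stack.shape, ds.chunk_shape("image"))
    if z_proj is not None:
        print(z_proj.shape)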
""" - axes = ("tz", "ty", "tx") + axes = ("tt", "tz", "ty", "tx") metadata = self.get_metadata() if channel is None: translation = [metadata.get(axis, 0) for axis in axes] @@ -633,7 +611,7 @@ def to_ome_zarr(self, path: str, force_dtype: Optional[int] = None, n_scales: in arrays = [] datasets = [] for i in range(n_scales): - factor = 2 ** i + factor = 2**i array_path = f"{i}" shape = ome_zarr_shape[:2] + tuple(int(m.ceil(s / factor)) for s in ome_zarr_shape[2:]) chunks = (1,) + self._default_chunks(shape, dtype) @@ -654,7 +632,7 @@ def to_ome_zarr(self, path: str, force_dtype: Optional[int] = None, n_scales: in arrays[0][t, c] = stack stack = bkd.to_backend(stack) for i in range(1, n_scales): - factors = (2 ** i,) * stack.ndim + factors = (2**i,) * stack.ndim arrays[i][t, c] = bkd.to_numpy(downscale_local_mean(stack, factors)) aprint("Done conversion to OME zarr") @@ -681,3 +659,7 @@ def to_ome_zarr(self, path: str, force_dtype: Optional[int] = None, n_scales: in def __getitem__(self, channel: str) -> StackIterator: return StackIterator(self.get_array(channel, wrap_with_tensorstore=True), self._slicing) + + def __contains__(self, channel: str) -> bool: + """Checks if channels exists, valid for when using multiple process.""" + return channel in self._root_group diff --git a/dexp/io/_test/test_compress_array.py b/dexp/io/_test/test_compress_array.py index a2d61f52..4d1878b2 100644 --- a/dexp/io/_test/test_compress_array.py +++ b/dexp/io/_test/test_compress_array.py @@ -26,7 +26,7 @@ def do_test(array_dc, array_uc, min_num_chunks=0): def test_compress_small_array(): - array_uc = numpy.linspace(0, 1024, 2 ** 15).astype(numpy.uint16) + array_uc = numpy.linspace(0, 1024, 2**15).astype(numpy.uint16) array_dc = numpy.empty_like(array_uc) do_test(array_dc, array_uc) diff --git a/dexp/io/io.py b/dexp/io/io.py index bf3ecb54..7fbee865 100644 --- a/dexp/io/io.py +++ b/dexp/io/io.py @@ -1,4 +1,4 @@ -import numpy +import numpy as np from tifffile import imsave @@ -17,18 +17,20 @@ def tiff_save(file, img, axes="ZYX", compress=0, **imsave_kwargs): """ # convert to imagej-compatible data type + img = np.asarray(img) + t = img.dtype if "float" in t.name: - t_new = numpy.float32 + t_new = np.float32 elif "uint" in t.name: - t_new = numpy.uint16 if t.itemsize >= 2 else numpy.uint8 + t_new = np.uint16 if t.itemsize >= 2 else np.uint8 elif "int" in t.name: - t_new = numpy.int16 + t_new = np.int16 else: t_new = t img = img.astype(t_new, copy=False) if t != t_new: - print(f"Converting data type from '{t}' to ImageJ-compatible '{numpy.dtype(t_new)}'.") + print(f"Converting data type from '{t}' to ImageJ-compatible '{np.dtype(t_new)}'.") imsave_kwargs["imagej"] = True imsave(file, img, **imsave_kwargs, compress=compress, metadata={"axes": axes}) # , diff --git a/dexp/processing/crop/background.py b/dexp/processing/crop/background.py new file mode 100644 index 00000000..bf22efc4 --- /dev/null +++ b/dexp/processing/crop/background.py @@ -0,0 +1,60 @@ +from scipy.signal._signaltools import _centered + +from dexp.utils import xpArray +from dexp.utils.backends import Backend + + +def foreground_mask( + array: xpArray, + intensity_threshold: int = 10, + binary_area_threshold: int = 3000, + downsample: int = 2, + display: bool = False, +) -> xpArray: + """ + Detects foreground by merging nearby components and removing objects below a given area threshold. + + Parameters + ---------- + array : xpArray + Input grayscale image. 
+ intensity_threshold : int, optional + Threshold used to binarized filtered image --- detect objects. + binary_area_threshold : int, optional + Threshold used to remove detected objects. + downsample : int, optional + Downsample factor to speed up computation, it affects the distance between components and the area threshold. + display : bool, optional + Debugging flag, if true displays result using napari, by default False + + Returns + ------- + xpArray + Foreground mask. + """ + + ndi = Backend.get_sp_module(array).ndimage + morph = Backend.get_skimage_submodule("morphology", array) + + ndim = array.ndim + small = ndi.zoom(array, (1 / downsample,) * array.ndim, order=1) + + small = ndi.grey_opening(small, ndim * (3,)) + small = ndi.grey_closing(small, ndim * (7,)) + small = small > intensity_threshold + small = ndi.grey_dilation(small, ndim * (9,)) + + small = morph.remove_small_objects(small, binary_area_threshold) + mask = ndi.zoom(small, (downsample,) * ndim, order=0) + mask = _centered(mask, array.shape) + + if display: + import napari + + v = napari.Viewer() + v.add_image(Backend.to_numpy(array)) + v.add_labels(Backend.to_numpy(mask)) + + napari.run() + + return mask diff --git a/dexp/processing/crop/representative_crop.py b/dexp/processing/crop/representative_crop.py index 7b1e9907..1e8e9bb3 100644 --- a/dexp/processing/crop/representative_crop.py +++ b/dexp/processing/crop/representative_crop.py @@ -16,7 +16,7 @@ def representative_crop( smoothing_size: int = 1.0, equal_sides: bool = False, favour_odd_lengths: bool = False, - fast_mode: bool = False, + fast_mode: bool = True, fast_mode_num_crops: int = 1500, max_time_in_seconds: float = 10, return_slice: bool = False, @@ -192,6 +192,9 @@ def _crop_slice(translation, cropped_shape, downscale: int = 1): aprint("Interrupting crop search because of timeout!") break + if best_slice is None: + raise RuntimeError("Could not find representative crop.") + # We make sure to have the full and original crop! 
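# Hypothetical usage of the new foreground_mask helper on a synthetic volume. The array size,
# thresholds and the BestBackend pattern are illustrative only; none of them come from the test suite.
import numpy as np

from dexp.processing.crop.background import foreground_mask
from dexp.utils.backends import BestBackend

# toy volume: a bright block over a dim, noisy background
volume = np.random.randint(0, 5, size=(64, 128, 128)).astype(np.uint16)
volume[20:40, 40:90, 40:90] += 200

with BestBackend() as bkd:  # CuPy-backed when available, NumPy otherwise
    mask = foreground_mask(
        bkd.to_backend(volume),
        intensity_threshold=10,     # binarisation threshold applied after the grey open/close filtering
        binary_area_threshold=500,  # components smaller than this (in downsampled voxels) are dropped
        downsample=2,
    )
    mask = bkd.to_numpy(mask)

print(mask.shape, mask.sum())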
best_crop = image[best_slice] diff --git a/dexp/processing/deconvolution/blind_deconvolution.py b/dexp/processing/deconvolution/blind_deconvolution.py index e3ed5312..93c58c8e 100644 --- a/dexp/processing/deconvolution/blind_deconvolution.py +++ b/dexp/processing/deconvolution/blind_deconvolution.py @@ -33,6 +33,7 @@ def blind_deconvolution( microscope_params: Dict, n_iterations: int, n_zernikes: int = 15, + psf_output_path: Optional[str] = None, display: bool = False, ) -> xpArray: # Local imports to @@ -91,4 +92,7 @@ def get_psf(coefs: Optional[ArrayLike] = None) -> ArrayLike: if display: napari.run() + if psf_output_path is not None: + np.save(psf_output_path, backend.to_numpy(psf)) + return deconv diff --git a/dexp/processing/deconvolution/demo/demo_lr_deconvolution_3d_sg.py b/dexp/processing/deconvolution/demo/demo_lr_deconvolution_3d_sg.py index fcb8790a..3833c32f 100644 --- a/dexp/processing/deconvolution/demo/demo_lr_deconvolution_3d_sg.py +++ b/dexp/processing/deconvolution/demo/demo_lr_deconvolution_3d_sg.py @@ -52,7 +52,7 @@ def f(_image): _image, psf, num_iterations=iterations, normalise_input=False, padding=17, max_correction=max_correction ) - deconvolved_sg = scatter_gather_i2i(f, blurry, tiles=length_xy // 2, margins=17, normalise=True, to_numpy=True) + deconvolved_sg = scatter_gather_i2i(blurry, f, tiles=length_xy // 2, margins=17, normalise=True, to_numpy=True) from napari import Viewer, gui_qt diff --git a/dexp/processing/denoising/butterworth.py b/dexp/processing/denoising/butterworth.py index 574a10ef..c2265560 100644 --- a/dexp/processing/denoising/butterworth.py +++ b/dexp/processing/denoising/butterworth.py @@ -32,7 +32,7 @@ def calibrate_denoise_butterworth( min_order: float = 0.5, max_order: float = 6.0, num_order: int = 32, - crop_size_in_voxels: Optional[int] = 256 ** 3, + crop_size_in_voxels: Optional[int] = 256**3, display: bool = False, **other_fixed_parameters, ): @@ -364,10 +364,10 @@ def _setup_butterworth_denoiser( def _butterworth_filter(image_f: xpArray, f: xpArray, order: float) -> xpArray: if isinstance(image_f, np.ndarray): - return image_f / np.sqrt(1 + f ** order) + return image_f / np.sqrt(1 + f**order) else: # Faster operation if cupy array is used - return image_f * rsqrt(1 + f ** order) + return image_f * rsqrt(1 + f**order) # "Compiles" to cupy if available diff --git a/dexp/processing/denoising/j_invariance.py b/dexp/processing/denoising/j_invariance.py index 234c5e7d..c1286439 100644 --- a/dexp/processing/denoising/j_invariance.py +++ b/dexp/processing/denoising/j_invariance.py @@ -114,7 +114,7 @@ def _generate_mask(image: xpArray, stride: int = 4): # Generate slice for mask: spatialdims = image.ndim - n_masks = stride ** spatialdims + n_masks = stride**spatialdims phases = np.unravel_index(n_masks // 2, (stride,) * len(image.shape[:spatialdims])) mask = tuple(slice(p, None, stride) for p in phases) diff --git a/dexp/processing/denoising/metrics.py b/dexp/processing/denoising/metrics.py index 96a479ce..a5468a5d 100644 --- a/dexp/processing/denoising/metrics.py +++ b/dexp/processing/denoising/metrics.py @@ -133,7 +133,7 @@ def ssim(image_a, image_b, win_size=None): C1 = (K1 * R) ** 2 C2 = (K2 * R) ** 2 - A1, A2, B1, B2 = (2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2) + A1, A2, B1, B2 = (2 * ux * uy + C1, 2 * vxy + C2, ux**2 + uy**2 + C1, vx + vy + C2) D = B1 * B2 S = (A1 * A2) / D diff --git a/dexp/processing/deskew/__init__.py b/dexp/processing/deskew/__init__.py index e69de29b..2f1c557f 100644 --- 
a/dexp/processing/deskew/__init__.py +++ b/dexp/processing/deskew/__init__.py @@ -0,0 +1,8 @@ +from dexp.processing.deskew.classic_deskew import classic_deskew +from dexp.processing.deskew.utils import skew +from dexp.processing.deskew.yang_deskew import yang_deskew + +deskew_functions = dict( + classic=classic_deskew, + yang=yang_deskew, +) diff --git a/dexp/processing/deskew/utils.py b/dexp/processing/deskew/utils.py new file mode 100644 index 00000000..deb4c7ac --- /dev/null +++ b/dexp/processing/deskew/utils.py @@ -0,0 +1,67 @@ +import math + +import numpy as np + +from dexp.processing.filters.fft_convolve import fft_convolve +from dexp.utils import xpArray +from dexp.utils.backends import Backend + + +def skew(image: xpArray, psf: xpArray, shift: int, angle: float, zoom: int, axis: int = 0) -> xpArray: + """Translates, skews and blurs the input image. + + Parameters + ---------- + image : xpArray + Input n-dim array. + psf : xpArray + Input PSF, must match image dimensions. + shift : int + Translation shift. + angle : float + Skew angle. + zoom : int + Zoom factor. + axis : int, optional + Axis to apply the translation and skewing, by default 0 + + Returns + ------- + xpArray + Skewed image. + """ + xp = Backend.get_xp_module(image) + sp = Backend.get_sp_module(image) + + assert image.ndim == psf.ndim + + # Pad: + pad_width = () + ((0, 0),) * (image.ndim - 1) + pad_width = [(0, 0) for _ in range(image.ndim)] + pad_width[axis] = (int(shift * zoom * image.shape[axis] // 2), int(shift * zoom * image.shape[axis] // 2)) + + image = np.pad(image, pad_width=pad_width) + + # Add blur: + psf = psf.astype(dtype=image.dtype, copy=False) + image = fft_convolve(image, psf) + + # apply skew: + angle = 45 + matrix = xp.eye(image.ndim) + matrix[axis, 0] = math.cos(angle * math.pi / 180) + matrix[axis, 1] = math.sin(angle * math.pi / 180) + offset = np.zeros(image.ndim) + offset[axis] = shift + + # matrix = xp.linalg.inv(matrix) + skewed = sp.ndimage.affine_transform(image, matrix, offset=offset) + + # Add noise and clip + # skewed += xp.random.uniform(-1, 1) + skewed = np.clip(skewed, a_min=0, a_max=None) + + # cast to uint16: + skewed = skewed.astype(dtype=image.dtype) + + return skewed diff --git a/dexp/processing/deskew/yang_deskew.py b/dexp/processing/deskew/yang_deskew.py index f8ef1927..f81e06fc 100644 --- a/dexp/processing/deskew/yang_deskew.py +++ b/dexp/processing/deskew/yang_deskew.py @@ -17,6 +17,7 @@ def yang_deskew( camera_orientation: int = 0, num_split: int = 4, internal_dtype=None, + padding: None = None, ): """'Yang' Deskewing as done in Yang et al. 2019 ( https://www.biorxiv.org/content/10.1101/2020.09.22.309229v2 ) @@ -33,6 +34,7 @@ def yang_deskew( as a number of 90 deg rotations to be performed per 2D image in stack. 
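# Sketch of dispatching on the new deskew_functions registry from dexp/processing/deskew/__init__.py.
# The wrapper name run_deskew and the forwarded keyword arguments are assumptions; only the
# 'classic' and 'yang' keys come from the registry itself, and the kwargs must match the
# signature of whichever implementation gets selected.
from dexp.processing.deskew import deskew_functions


def run_deskew(image, mode: str = "yang", **kwargs):
    try:
        deskew = deskew_functions[mode]
    except KeyError:
        raise ValueError(f"Unknown deskew mode '{mode}', expected one of {sorted(deskew_functions)}") from None
    return deskew(image, **kwargs)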
num_split : number of splits to break down the data into pieces (along y, axis=2) to fit into the memory of GPU internal_dtype : internal dtype to perform computation + padding : dummy argument for compatibility Returns ------- diff --git a/dexp/processing/filters/kernels/butterworth.py b/dexp/processing/filters/kernels/butterworth.py index 5d1ded5d..73f5daf0 100644 --- a/dexp/processing/filters/kernels/butterworth.py +++ b/dexp/processing/filters/kernels/butterworth.py @@ -76,7 +76,7 @@ def butterworth_kernel( + ((z / cz) ** 2)[:, xp.newaxis, xp.newaxis] ) - kernel_fft = 1 / xp.sqrt(1.0 + (epsilon ** 2) * (freq ** order)) + kernel_fft = 1 / xp.sqrt(1.0 + (epsilon**2) * (freq**order)) kernel_fft = xp.squeeze(kernel_fft) diff --git a/dexp/processing/filters/kernels/wiener_butterworth.py b/dexp/processing/filters/kernels/wiener_butterworth.py index 7957106f..499588b8 100644 --- a/dexp/processing/filters/kernels/wiener_butterworth.py +++ b/dexp/processing/filters/kernels/wiener_butterworth.py @@ -61,7 +61,7 @@ def wiener_butterworth_kernel( cutoffs = tuple(float(xp.count_nonzero(b) / b.size) for b in pass_band) # cutoffs = (max(cutoffs),)*psf_f.ndim - epsilon = math.sqrt((beta ** -2) - 1) + epsilon = math.sqrt((beta**-2) - 1) bwk_f = butterworth_kernel( shape=kernel.shape, diff --git a/dexp/processing/filters/sobel_filter.py b/dexp/processing/filters/sobel_filter.py index 76f943d9..8a2e6af4 100644 --- a/dexp/processing/filters/sobel_filter.py +++ b/dexp/processing/filters/sobel_filter.py @@ -77,7 +77,7 @@ def sobel_filter( for i in range(ndim): sobel_one_axis = xp.absolute(sp.ndimage.sobel(image, axis=i)) - sobel_image += sobel_one_axis ** exponent + sobel_image += sobel_one_axis**exponent if exponent == 1: pass diff --git a/dexp/processing/isonet/__init__.py b/dexp/processing/isonet/__init__.py deleted file mode 100644 index 83a51432..00000000 --- a/dexp/processing/isonet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from dexp.processing.isonet.isonet import IsoNet diff --git a/dexp/processing/isonet/demo/demo_isonet.py b/dexp/processing/isonet/demo/demo_isonet.py deleted file mode 100644 index 0530459d..00000000 --- a/dexp/processing/isonet/demo/demo_isonet.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy -from napari import Viewer, gui_qt -from scipy.ndimage import zoom -from tifffile import imread, imwrite - -from dexp.enhance import sharpen -from dexp.processing.isonet import IsoNet -from dexp.utils.timeit import timeit - -# image = imread('data/retina/cropped_farred_RFP_GFP_2109175_2color_sub_10.20.tif') -image = imread("../data/zfish/zfish2.tif") - -image = zoom(image, zoom=(1, 0.5, 0.5), order=0) - -image = sharpen(image, mode="hybrid", min=0, max=1024, margin_pad=False) - -print(f"shape={image.shape}, dtype={image.dtype}, min={image.min()}, max={image.max()}") -image_crop = image # [36-16:36+16,750-128:750+128,750-128:750+128] -print(image_crop.shape) - -dxy = 0.4 * 2 -dz = 6.349 -subsampling = dz / dxy -print(f"subsampling={subsampling}") - -# psf = nikon16x08na(dxy=dxy,dz=dz, size=16)[:,(15-1)//2,:] -psf = numpy.ones((3, 3)) / 9 - -isonet = IsoNet() - -# with timeit("Preparation:"): -# isonet.prepare(image_crop, subsampling=subsampling, psf=psf, threshold=0.97) -# -# with timeit("Training:"): -# isonet.train(max_epochs=40) - -with timeit("Evaluation:"): - restored = isonet.apply_pair(image_crop, subsampling=subsampling, batch_size=1) - -print(restored.shape) - -imwrite( - "../data/isotropic.tif", - restored, - imagej=True, - resolution=(1 / dxy, 1 / dxy), - metadata={"spacing": dz, 
"unit": "um"}, -) - -image = zoom(image, zoom=(subsampling, 1, 1), order=0) - -with gui_qt(): - viewer = Viewer() - - viewer.add_image(image, name="image (all)") - viewer.add_image(image_crop, name="image selection") - viewer.add_image(restored, name="restored") - viewer.add_image(psf, name="psf") diff --git a/dexp/processing/isonet/isonet.py b/dexp/processing/isonet/isonet.py deleted file mode 100644 index e5ecc20d..00000000 --- a/dexp/processing/isonet/isonet.py +++ /dev/null @@ -1,79 +0,0 @@ -import tempfile -import warnings -from os.path import join - -import numpy as np -from csbdeep.data import RawData, create_patches, no_background_patches -from csbdeep.data.transform import anisotropic_distortions -from csbdeep.io import load_training_data, save_training_data -from csbdeep.models import Config, IsotropicCARE - - -class IsoNet: - def __init__(self, context_name="default", subsampling=5): - """ """ - - self.context_folder = join(tempfile.gettempdir(), context_name) - self.training_data_path = join(self.context_folder, "training_data.npz") - self.model_path = join(self.context_folder, "model") - - self.model = None - - self.subsampling = subsampling - - def prepare(self, image, psf=np.ones((3, 3)) / 9, threshold=0.9): - print("image size =", image.shape) - print("Z subsample factor =", self.subsampling) - - raw_data = RawData.from_arrays(image, image, axes="ZYX") - - anisotropic_transform = anisotropic_distortions( - subsample=self.subsampling, - psf=psf, # use the actual PSF here - psf_axes="YX", - ) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - X, Y, XY_axes = create_patches( - raw_data=raw_data, - patch_size=(128, 128), - n_patches_per_image=512, - transforms=[anisotropic_transform], - patch_filter=no_background_patches(threshold=threshold, percentile=100 - 0.1), - ) - - assert X.shape == Y.shape - print("shape of X,Y =", X.shape) - print("axes of X,Y =", XY_axes) - - print("Saving training data...") - save_training_data(self.training_data_path, X, Y, XY_axes) - print("Done saving training data.") - - def train(self, max_epochs=100): - (X, Y), (X_val, Y_val), axes = load_training_data(self.training_data_path, validation_split=0.1, verbose=True) - - n_channel_in, n_channel_out = 1, 1 - - config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=30) - print(config) - vars(config) - - model = IsotropicCARE(config, "my_model", basedir=self.model_path) - - history = model.train(X, Y, validation_data=(X_val, Y_val), epochs=max_epochs) - - print(sorted(list(history.history.keys()))) - - def apply(self, image, batch_size=8): - if self.model is None: - print("Loading model.") - self.model = IsotropicCARE(config=None, name="my_model", basedir=self.model_path) - axes = "ZYX" - - print(f"Applying model to image of shape: {image.shape}...") - restored = self.model.predict(image, axes, self.subsampling, batch_size=batch_size) - print("done!") - - return restored diff --git a/dexp/processing/morphology/_test/test_componet_tree.py b/dexp/processing/morphology/_test/test_componet_tree.py index 3271d132..8e5f67db 100644 --- a/dexp/processing/morphology/_test/test_componet_tree.py +++ b/dexp/processing/morphology/_test/test_componet_tree.py @@ -26,7 +26,7 @@ def test_white_area_top_hat(dexp_nuclei_background_data, display_test: bool): for props in regionprops(labels): max_area = max(max_area, props.area) - max_area = max_area / (sampling ** 3) + 1 + max_area = max_area / (sampling**3) + 1 estimated_cells = area_white_top_hat(both, area_threshold=max_area, 
sampling=sampling) @@ -50,6 +50,8 @@ def test_white_area_top_hat(dexp_nuclei_background_data, display_test: bool): if __name__ == "__main__": + from dexp.utils.testing import test_as_demo + # the same as executing from the CLI # pytest -s --display True - pytest.main([__file__, "-s", "--display", "True"]) + test_as_demo(__file__) diff --git a/dexp/processing/multiview_lightsheet/deconvolution/demo/demo_simview_deconv.py b/dexp/processing/multiview_lightsheet/deconvolution/demo/demo_simview_deconv.py index 364dd77f..3d42f247 100644 --- a/dexp/processing/multiview_lightsheet/deconvolution/demo/demo_simview_deconv.py +++ b/dexp/processing/multiview_lightsheet/deconvolution/demo/demo_simview_deconv.py @@ -78,7 +78,7 @@ def f(_image): return lucy_richardson_deconvolution(image=_image, num_iterations=10, **parameters) view_dehazed_deconvolved = scatter_gather_i2i( - f, view_dehazed, tiles=320, margins=psf_size, clip=False, to_numpy=True + view_dehazed, f, tiles=320, margins=psf_size, clip=False, to_numpy=True ) with timeit("lucy_richardson_deconvolution_wb"): @@ -87,7 +87,7 @@ def f(_image): return lucy_richardson_deconvolution(image=_image, num_iterations=3, **parameters, **parameters_wb) view_dehazed_deconvolved_wb = scatter_gather_i2i( - f, view_dehazed, tiles=320, margins=psf_size, clip=False, to_numpy=True + view_dehazed, f, tiles=320, margins=psf_size, clip=False, to_numpy=True ) with gui_qt(): diff --git a/dexp/processing/multiview_lightsheet/fusion/mvsols.py b/dexp/processing/multiview_lightsheet/fusion/mvsols.py index d21c14f2..56f81e0b 100644 --- a/dexp/processing/multiview_lightsheet/fusion/mvsols.py +++ b/dexp/processing/multiview_lightsheet/fusion/mvsols.py @@ -166,7 +166,7 @@ def msols_fuse_1C2L( raise ValueError("The two views must have same dtype!") if C0L0.shape != C0L1.shape: - raise ValueError("The two views must have same shapes!") + raise ValueError(f"The two views must have same shapes! 
Found {C0L0.shape} and {C0L1.shape}") if type(Backend.current()) is NumpyBackend: internal_dtype = numpy.float32 @@ -355,7 +355,7 @@ def msols_fuse_1C2L( sigma = illumination_correction_sigma length = C1Lx.shape[1] correction = xp.linspace(-length // 2, length // 2, num=length) - correction = xp.exp(-(correction ** 2) / (2 * sigma * sigma)) + correction = xp.exp(-(correction**2) / (2 * sigma * sigma)) correction = 1.0 / correction correction = correction.astype(dtype=internal_dtype) C1Lx *= correction[xp.newaxis, :, xp.newaxis] diff --git a/dexp/processing/registration/translation_nd.py b/dexp/processing/registration/translation_nd.py index 12ae280c..62fbd763 100644 --- a/dexp/processing/registration/translation_nd.py +++ b/dexp/processing/registration/translation_nd.py @@ -126,7 +126,7 @@ def register_translation_nd( # Compute confidence: masked_correlation = correlation.copy() - mask_size = tuple(max(8, int(s ** 0.9) // 8) for s in masked_correlation.shape) + mask_size = tuple(max(8, int(s**0.9) // 8) for s in masked_correlation.shape) masked_correlation[tuple(slice(rs - s, rs + s) for rs, s in zip(rough_shift, mask_size))] = 0 background_correlation_max = xp.max(masked_correlation) epsilon = 1e-6 diff --git a/dexp/processing/registration/warp_multiscale_nd.py b/dexp/processing/registration/warp_multiscale_nd.py index def97ff5..270e4813 100644 --- a/dexp/processing/registration/warp_multiscale_nd.py +++ b/dexp/processing/registration/warp_multiscale_nd.py @@ -78,7 +78,7 @@ def register_warp_multiscale_nd( image = image_b with asection(f"register_warp_nd {i}"): - nb_div = 2 ** i + nb_div = 2**i # Note: the trick below is a 'ceil division' ceil(u/v) == -(-u//v) chunks = tuple(max(min_chunk, -(-s // nb_div)) for s in image_a.shape) margins = tuple(max(min_margin, int(c * r)) for c, r in zip(chunks, margin_ratios)) diff --git a/dexp/processing/registration/warp_nd.py b/dexp/processing/registration/warp_nd.py index caa22584..b73a186b 100644 --- a/dexp/processing/registration/warp_nd.py +++ b/dexp/processing/registration/warp_nd.py @@ -53,7 +53,7 @@ def f(x, y): shift, confidence = model.get_shift_and_confidence() return xp.asarray(shift), xp.asarray(confidence) - vector_field, confidence = scatter_gather_i2v(f, (image_a, image_b), tiles=chunks, margins=margins) + vector_field, confidence = scatter_gather_i2v((image_a, image_b), f, tiles=chunks, margins=margins) model = WarpRegistrationModel(vector_field=vector_field, confidence=confidence, force_numpy=force_numpy) diff --git a/dexp/processing/remove_beads/beadsremover.py b/dexp/processing/remove_beads/beadsremover.py index 0ceedd0f..f0a810e1 100644 --- a/dexp/processing/remove_beads/beadsremover.py +++ b/dexp/processing/remove_beads/beadsremover.py @@ -123,7 +123,10 @@ def detect_beads(self, array: xpArray, estimated_psf: Optional[ArrayLike] = None if np.all(np.equal(bead.shape, self.psf_size)): beads.append(bead) - avg_bead = Bead.from_data(np.median(np.stack(tuple(beads)), axis=0).reshape((self.psf_size,) * array.ndim)) + if len(beads) == 0: + avg_bead = np.zeros((self.psf_size,) * array.ndim) + else: + avg_bead = Bead.from_data(np.median(np.stack(tuple(beads)), axis=0).reshape((self.psf_size,) * array.ndim)) # select regions that are similar to the median estimated bead selected_beads = [] diff --git a/dexp/processing/segmentation/__init__.py b/dexp/processing/segmentation/__init__.py new file mode 100644 index 00000000..e6a55540 --- /dev/null +++ b/dexp/processing/segmentation/__init__.py @@ -0,0 +1 @@ +from 
dexp.processing.segmentation.watershed import roi_watershed_from_minima diff --git a/dexp/processing/segmentation/watershed.py b/dexp/processing/segmentation/watershed.py new file mode 100644 index 00000000..c7ba3190 --- /dev/null +++ b/dexp/processing/segmentation/watershed.py @@ -0,0 +1,40 @@ +import numpy as np +import scipy.ndimage as ndi +from skimage.segmentation import relabel_sequential + + +def roi_watershed_from_minima(image: np.ndarray, mask: np.ndarray, **kwargs) -> np.ndarray: + """Computes the watershed from minima from pyift for each connected component ROI, it's useful to save memory. + + Parameters + ---------- + image : np.ndarray + Input gradient (basins) image. + mask : np.ndarray + Input mask for ROI computation. + + Returns + ------- + np.ndarray + Output watershed labels. + """ + from pyift.shortestpath import watershed_from_minima + + labels = np.zeros(image.shape, dtype=np.int32) + mask, _ = ndi.label(mask) + + offset = 1 + for slicing in ndi.find_objects(mask): + if slicing is None: + continue + _, lb = watershed_from_minima( + image[slicing], + mask[slicing] > 0, + **kwargs, + ) + lb[lb < 0] = 0 # non-masked regions are -1 + lb, _, _ = relabel_sequential(lb, offset=offset) + offset = lb.max() + 1 + labels[slicing] = lb + + return labels diff --git a/dexp/processing/utils/_test/test_scatter_gather_i2i.py b/dexp/processing/utils/_test/test_scatter_gather_i2i.py index 71a1dc14..a95d3a22 100644 --- a/dexp/processing/utils/_test/test_scatter_gather_i2i.py +++ b/dexp/processing/utils/_test/test_scatter_gather_i2i.py @@ -25,7 +25,7 @@ def f(x): result_ref = 0 * image + 17 with timeit("scatter_gather(f)"): - result = scatter_gather_i2i(f, image, tiles=(length_xy // splits,) * ndim, margins=filter_size // 2) + result = scatter_gather_i2i(image, f, tiles=(length_xy // splits,) * ndim, margins=filter_size // 2) image = Backend.to_numpy(image) result_ref = Backend.to_numpy(result_ref) diff --git a/dexp/processing/utils/_test/test_scatter_gather_i2v.py b/dexp/processing/utils/_test/test_scatter_gather_i2v.py index 4e11867f..3970faf1 100644 --- a/dexp/processing/utils/_test/test_scatter_gather_i2v.py +++ b/dexp/processing/utils/_test/test_scatter_gather_i2v.py @@ -20,7 +20,7 @@ def f(x, y): with timeit("scatter_gather(f)"): chunks = (length_xy // splits,) * ndim - result1, result2 = scatter_gather_i2v(f, (image1, image2), tiles=chunks, margins=8) + result1, result2 = scatter_gather_i2v((image1, image2), f, tiles=chunks, margins=8) assert result1.ndim == ndim + 1 assert result2.ndim == ndim + 1 diff --git a/dexp/processing/utils/linear_solver.py b/dexp/processing/utils/linear_solver.py index 3fdcec12..2f36137d 100644 --- a/dexp/processing/utils/linear_solver.py +++ b/dexp/processing/utils/linear_solver.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional, Sequence, Tuple import numpy @@ -68,9 +69,10 @@ def fun(x): aprint(f"Warning: optimisation finished after {result.nit} iterations!") if not result.success: - raise RuntimeError( + warnings.warn( f"Convergence failed: '{result.message}' after {result.nit} " + "iterations and {result.nfev} function evaluations." 
) + return Backend.to_backend(x0) return Backend.to_backend(result.x) diff --git a/dexp/processing/utils/scatter_gather_i2i.py b/dexp/processing/utils/scatter_gather_i2i.py index 81a67c33..23ae98cc 100644 --- a/dexp/processing/utils/scatter_gather_i2i.py +++ b/dexp/processing/utils/scatter_gather_i2i.py @@ -9,8 +9,8 @@ def scatter_gather_i2i( - function: Callable, image: xpArray, + function: Callable, tiles: Union[int, Tuple[int, ...]], margins: Optional[Union[int, Tuple[int, ...]]] = None, normalise: bool = False, @@ -28,8 +28,8 @@ def scatter_gather_i2i( Parameters ---------- - function : unary function image : input image (can be any backend, numpy ) + function : unary function tiles : tile sizes to cut input image into, can be a single integer or a tuple of integers. margins : margins to add to each tile, can be a single integer or a tuple of integers. if None, no margins are added. diff --git a/dexp/processing/utils/scatter_gather_i2v.py b/dexp/processing/utils/scatter_gather_i2v.py index 27cdcf5f..583101cd 100644 --- a/dexp/processing/utils/scatter_gather_i2v.py +++ b/dexp/processing/utils/scatter_gather_i2v.py @@ -9,8 +9,8 @@ def scatter_gather_i2v( - function: Callable, images: Union[xpArray, Tuple[xpArray]], + function: Callable, tiles: Union[int, Tuple[int, ...]], margins: Optional[Union[int, Tuple[int, ...]]] = None, to_numpy: bool = True, @@ -24,9 +24,9 @@ def scatter_gather_i2v( Parameters ---------- + images : sequence of images (can be any backend, numpy) function : n-ary function. Must accept one or more arrays -- of same shape --- and return a _tuple_ of arrays as result. - images : sequence of images (can be any backend, numpy) tiles : tiles to cut input image into, can be a single integer or a tuple of integers. margins : margins to add to each tile, can be a single integer or a tuple of integers. to_numpy : should the result be a numpy array? Very usefull when the compute backend diff --git a/dexp/utils/__init__.py b/dexp/utils/__init__.py index 664812f3..e376e787 100644 --- a/dexp/utils/__init__.py +++ b/dexp/utils/__init__.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict, List, Union import numpy @@ -9,8 +9,6 @@ except ImportError: xpArray = numpy.ndarray -from dexp.utils.backends.cupy_backend import is_cupy_available - def dict_or(lhs: Dict, rhs: Dict) -> Dict: """ @@ -23,3 +21,22 @@ def dict_or(lhs: Dict, rhs: Dict) -> Dict: lhs[k] = v return lhs + + +def overwrite2mode(overwrite: bool) -> str: + """Returns writing mode from overwrite flag""" + return "w" if overwrite else "w-" + + +def compress_dictionary_lists_length(d: Dict, max_length: int) -> Dict: + """Reduces the length of lists in an dictionary and convert to string. 
Useful for printing""" + out = {} + for k, v in d.items(): + if isinstance(v, Dict): + out[k] = compress_dictionary_lists_length(v, max_length) + elif isinstance(v, List) and len(v) > max_length: + out[k] = str(v[:max_length])[:-1] + ", ...]" + else: + out[k] = v + + return out diff --git a/dexp/utils/backends/__init__.py b/dexp/utils/backends/__init__.py index af20df0e..1b605014 100644 --- a/dexp/utils/backends/__init__.py +++ b/dexp/utils/backends/__init__.py @@ -1,4 +1,4 @@ -from dexp.utils.backends.backend import Backend +from dexp.utils.backends.backend import Backend, dispatch_data_to_backend from dexp.utils.backends.best_backend import BestBackend from dexp.utils.backends.cupy_backend import CupyBackend from dexp.utils.backends.numpy_backend import NumpyBackend diff --git a/dexp/utils/backends/_cupy/texture/_test/test_texture.py b/dexp/utils/backends/_cupy/texture/_test/test_texture.py index 06fe274d..c95c5b0c 100644 --- a/dexp/utils/backends/_cupy/texture/_test/test_texture.py +++ b/dexp/utils/backends/_cupy/texture/_test/test_texture.py @@ -251,7 +251,7 @@ def test_basic_cupy_texture_leak(): with CupyBackend(): # allocate input/output arrays length = 512 - tex_data = cupy.arange(length ** 3, dtype=cupy.float32).reshape(length, length, length) + tex_data = cupy.arange(length**3, dtype=cupy.float32).reshape(length, length, length) with asection("loop"): for i in range(100): diff --git a/dexp/utils/backends/backend.py b/dexp/utils/backends/backend.py index 9a2e5255..a82bf586 100644 --- a/dexp/utils/backends/backend.py +++ b/dexp/utils/backends/backend.py @@ -1,9 +1,10 @@ +import importlib import threading import types from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict, Iterable, List, Optional, Tuple -import numpy +import numpy as np from dexp.utils import xpArray @@ -62,7 +63,7 @@ def set(backend: "Backend"): Backend._local.backend_stack.append(backend) @staticmethod - def to_numpy(array: xpArray, dtype=None, force_copy: bool = False) -> numpy.ndarray: + def to_numpy(array: xpArray, dtype=None, force_copy: bool = False) -> np.ndarray: return Backend.current()._to_numpy(array, dtype=dtype, force_copy=force_copy) @staticmethod @@ -70,13 +71,17 @@ def to_backend(array: xpArray, dtype=None, force_copy: bool = False) -> Any: return Backend.current()._to_backend(array, dtype=dtype, force_copy=force_copy) @staticmethod - def get_xp_module(array: xpArray = None) -> Any: + def get_xp_module(array: Optional[xpArray] = None) -> types.ModuleType: return Backend.current()._get_xp_module(array) @staticmethod - def get_sp_module(array: xpArray = None) -> Any: + def get_sp_module(array: Optional[xpArray] = None) -> types.ModuleType: return Backend.current()._get_sp_module(array) + @staticmethod + def get_skimage_submodule(submodule: str, array: Optional[xpArray] = None) -> types.ModuleType: + return Backend.current()._get_skimage_submodule(submodule, array) + def __enter__(self): if not hasattr(Backend._local, "backend_stack"): Backend._local.backend_stack = [] @@ -101,7 +106,7 @@ def close(self): self.clear_memory_pool() @abstractmethod - def _to_numpy(self, array: xpArray, dtype=None, force_copy: bool = False) -> numpy.ndarray: + def _to_numpy(self, array: xpArray, dtype=None, force_copy: bool = False) -> np.ndarray: """Converts backend array to numpy. If array is already a numpy array it is returned unchanged. 
Parameters @@ -129,8 +134,7 @@ def _to_backend(self, array: xpArray, dtype=None, force_copy: bool = False) -> A """ raise NotImplementedError("Method not implemented!") - @abstractmethod - def _get_xp_module(self, array: xpArray = None) -> types.ModuleType: + def _get_xp_module(self, array: xpArray) -> types.ModuleType: """Returns the numpy-like module for a given array Parameters @@ -142,10 +146,14 @@ def _get_xp_module(self, array: xpArray = None) -> types.ModuleType: ------- numpy-like module """ - raise NotImplementedError("Method not implemented!") + try: + import cupy - @abstractmethod - def _get_sp_module(self, array: xpArray = None) -> types.ModuleType: + return cupy.get_array_module(array) + except ModuleNotFoundError: + return np + + def _get_sp_module(self, array: xpArray) -> types.ModuleType: """Returns the scipy-like module for a given array Parameters @@ -156,7 +164,66 @@ def _get_sp_module(self, array: xpArray = None) -> types.ModuleType: ------- scipy-like module """ - raise NotImplementedError("Method not implemented!") + try: + import cupyx + + return cupyx.scipy.get_array_module(array) + except ModuleNotFoundError: + import scipy + + return scipy + + def _get_skimage_submodule(self, submodule: str, array: xpArray) -> types.ModuleType: + """Returns the skimage-like module for a given array (skimage or cucim) + + Parameters + ---------- + array : xpArray, optional + Reference array + Returns + ------- + types.ModuleType + skimage-like module + """ + + # avoiding try-catch + try: + import cupy + + if isinstance(array, cupy.ndarray): + return importlib.import_module(f"cucim.skimage.{submodule}") + except ModuleNotFoundError: + return importlib.import_module(f"skimage.{submodule}") + + return importlib.import_module(f"skimage.{submodule}") def synchronise(self) -> None: pass + + +def _maybe_to_backend(backend: Backend, obj: Any) -> Any: + + if isinstance(obj, np.ndarray): + # this must be first because arrays are iterables + return backend.to_backend(obj) + + if isinstance(obj, Dict): + dispatch_data_to_backend(backend, [], obj) + return obj + + elif isinstance(obj, (List, Tuple)): + obj = list(obj) + dispatch_data_to_backend(backend, obj, {}) + return obj + + else: + return obj + + +def dispatch_data_to_backend(backend: Backend, args: Iterable, kwargs: Dict) -> None: + """Function to move arrays to backend INPLACE!""" + for i, v in enumerate(args): + args[i] = _maybe_to_backend(backend, v) + + for k, v in kwargs.items(): + kwargs[k] = _maybe_to_backend(backend, v) diff --git a/dexp/utils/backends/cupy_backend.py b/dexp/utils/backends/cupy_backend.py index 147f67a0..ca3be19e 100644 --- a/dexp/utils/backends/cupy_backend.py +++ b/dexp/utils/backends/cupy_backend.py @@ -1,8 +1,10 @@ import gc +import importlib import math import os import threading -from typing import Any +import types +from typing import Any, Optional import numpy from arbol import aprint @@ -255,25 +257,24 @@ def _to_backend(self, array: xpArray, dtype=None, force_copy: bool = False) -> A with self.cupy_device: return cupy.asarray(array, dtype=dtype) - def _get_xp_module(self, array: xpArray = None) -> Any: + def _get_xp_module(self, array: Optional[xpArray] = None) -> types.ModuleType: if array is None: import cupy return cupy - else: - import cupy - - return cupy.get_array_module(array) + return super()._get_xp_module(array) - def _get_sp_module(self, array: xpArray = None) -> Any: + def _get_sp_module(self, array: Optional[xpArray] = None) -> types.ModuleType: if array is None: import cupyx return 
cupyx.scipy - else: - import cupyx + return super()._get_sp_module(array) - return cupyx.scipy.get_array_module(array) + def _get_skimage_submodule(self, submodule: str, array: Optional[xpArray] = None) -> types.ModuleType: + if array is None: + return importlib.import_module(f"cucim.skimage.{submodule}") + return super()._get_skimage_submodule(submodule, array) def is_cupy_available() -> bool: diff --git a/dexp/utils/backends/numpy_backend.py b/dexp/utils/backends/numpy_backend.py index d94d9c36..95315609 100644 --- a/dexp/utils/backends/numpy_backend.py +++ b/dexp/utils/backends/numpy_backend.py @@ -1,4 +1,6 @@ -from typing import Any +import importlib +import types +from typing import Any, Optional import numpy import scipy @@ -62,24 +64,17 @@ def _to_backend(self, array: xpArray, dtype=None, force_copy: bool = False) -> A else: return array - def _get_xp_module(self, array: xpArray = None) -> Any: + def _get_xp_module(self, array: Optional[xpArray] = None) -> types.ModuleType: if array is None: return numpy - else: - try: - import cupy - - return cupy.get_array_module(array) - except ModuleNotFoundError: - return numpy + return super()._get_xp_module(array) - def _get_sp_module(self, array: xpArray = None) -> Any: + def _get_sp_module(self, array: Optional[xpArray] = None) -> types.ModuleType: if array is None: return scipy - else: - try: - import cupyx + return super()._get_sp_module(array) - return cupyx.scipy.get_array_module(array) - except ModuleNotFoundError: - return scipy + def _get_skimage_submodule(self, submodule: str, array: Optional[xpArray] = None) -> types.ModuleType: + if array is None: + return importlib.import_module(f"skimage.{submodule}") + return super()._get_skimage_submodule(submodule, array) diff --git a/dexp/utils/lock.py b/dexp/utils/lock.py new file mode 100644 index 00000000..22d4c760 --- /dev/null +++ b/dexp/utils/lock.py @@ -0,0 +1,15 @@ +import uuid +from pathlib import Path + +import fasteners +from arbol import aprint + + +def create_lock(prefix: str = "") -> fasteners.InterProcessLock: + """Creates an multi-processing safe lock using unique unit idenfier.""" + identifier = uuid.uuid4().hex + worker_directory = Path.home() / f".multiprocessing/{identifier}" + worker_directory.mkdir(exist_ok=True, parents=True) + aprint(f"Created lock directory {worker_directory}") + + return fasteners.InterProcessLock(path=worker_directory / "lock.file") diff --git a/dexp/utils/testing/__init__.py b/dexp/utils/testing/__init__.py index 2eef8c2c..77b222b5 100644 --- a/dexp/utils/testing/__init__.py +++ b/dexp/utils/testing/__init__.py @@ -1,6 +1,6 @@ import pytest -from dexp.utils.testing.testing import execute_both_backends +from dexp.utils.testing.testing import cupy_only, execute_both_backends def test_as_demo(file: str) -> None: diff --git a/dexp/utils/testing/testing.py b/dexp/utils/testing/testing.py index 8c9fef5d..4c6fc851 100644 --- a/dexp/utils/testing/testing.py +++ b/dexp/utils/testing/testing.py @@ -1,39 +1,11 @@ import inspect from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Callable -import numpy as np import pytest -from dexp.utils.backends import Backend, CupyBackend, NumpyBackend - - -def _maybe_to_backend(backend: Backend, obj: Any) -> Any: - - if isinstance(obj, np.ndarray): - # this must be first because arrays are iterables - return backend.to_backend(obj) - - if isinstance(obj, Dict): - dispatch_data_to_backend(backend, [], obj) - return obj - - elif isinstance(obj, (List, Tuple)): - obj 
= list(obj) - dispatch_data_to_backend(backend, obj, {}) - return obj - - else: - return obj - - -def dispatch_data_to_backend(backend: Backend, args: Iterable, kwargs: Dict) -> None: - """Function to move arrays to backend.""" - for i, v in enumerate(args): - args[i] = _maybe_to_backend(backend, v) - - for k, v in kwargs.items(): - kwargs[k] = _maybe_to_backend(backend, v) +from dexp.utils.backends import CupyBackend, NumpyBackend, dispatch_data_to_backend +from dexp.utils.backends.cupy_backend import is_cupy_available def _add_cuda_signature(func: Callable) -> Callable: @@ -73,3 +45,14 @@ def wrapper(cuda: bool, *args, **kwargs): func(*args, **kwargs) return _add_cuda_signature(wrapper) + + +def cupy_only(func: Callable) -> Callable: + """Helper function to skip test function is cupy is not found.""" + + @pytest.mark.skipif(not is_cupy_available(), reason=f"Cupy not found. Skipping {func.__name__} gpu test.") + @wraps(func) + def _func(*args, **kwargs): + return func(*args, **kwargs) + + return _func diff --git a/dexp/volumerender/quaternion.py b/dexp/volumerender/quaternion.py index c7b7555d..4678fd32 100644 --- a/dexp/volumerender/quaternion.py +++ b/dexp/volumerender/quaternion.py @@ -54,9 +54,9 @@ def toRotation4(self): a, b, c, d = self.data return np.array( [ - [a ** 2 + b ** 2 - c ** 2 - d ** 2, 2 * (b * c - a * d), 2 * (b * d + a * c), 0], - [2 * (b * c + a * d), a ** 2 - b ** 2 + c ** 2 - d ** 2, 2 * (c * d - a * b), 0], - [2 * (b * d - a * c), 2 * (c * d + a * b), a ** 2 - b ** 2 - c ** 2 + d ** 2, 0], + [a**2 + b**2 - c**2 - d**2, 2 * (b * c - a * d), 2 * (b * d + a * c), 0], + [2 * (b * c + a * d), a**2 - b**2 + c**2 - d**2, 2 * (c * d - a * b), 0], + [2 * (b * d - a * c), 2 * (c * d + a * b), a**2 - b**2 - c**2 + d**2, 0], [0, 0, 0, 1], ] ) @@ -65,9 +65,9 @@ def toRotation3(self): a, b, c, d = self.data return np.array( [ - [a ** 2 + b ** 2 - c ** 2 - d ** 2, 2 * (b * c - a * d), 2 * (b * d + a * c)], - [2 * (b * c + a * d), a ** 2 - b ** 2 + c ** 2 - d ** 2, 2 * (c * d - a * b)], - [2 * (b * d - a * c), 2 * (c * d + a * b), a ** 2 - b ** 2 - c ** 2 + d ** 2], + [a**2 + b**2 - c**2 - d**2, 2 * (b * c - a * d), 2 * (b * d + a * c)], + [2 * (b * c + a * d), a**2 - b**2 + c**2 - d**2, 2 * (c * d - a * b)], + [2 * (b * d - a * c), 2 * (c * d + a * b), a**2 - b**2 - c**2 + d**2], ] ) diff --git a/dexp/volumerender/transform_matrices.py b/dexp/volumerender/transform_matrices.py index 596ed77f..50f98e10 100644 --- a/dexp/volumerender/transform_matrices.py +++ b/dexp/volumerender/transform_matrices.py @@ -22,7 +22,7 @@ def mat4_scale(x=1.0, y=1.0, z=1.0): def mat4_rotation(w=0, x=1, y=0, z=0): n = np.array([x, y, z], np.float32) - n *= 1.0 / np.sqrt(1.0 * np.sum(n ** 2)) + n *= 1.0 / np.sqrt(1.0 * np.sum(n**2)) q = Quaternion(np.cos(0.5 * w), *(np.sin(0.5 * w) * n)) return q.toRotation4() @@ -105,11 +105,11 @@ def mat4_lookat(eye, center, up): _fwd = _center - _eye # normalize - _fwd *= 1.0 / np.sqrt(np.sum(_fwd ** 2)) + _fwd *= 1.0 / np.sqrt(np.sum(_fwd**2)) s = np.cross(_fwd, _up) - s *= 1.0 / np.sqrt(np.sum(s ** 2)) + s *= 1.0 / np.sqrt(np.sum(s**2)) _up = np.cross(s, _fwd) diff --git a/docs/Makefile b/docs/Makefile index a6b8fc07..dee6bf8b 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,6 +3,7 @@ SPHINXOPTS = SPHINXBUILD = sphinx-build BUILDDIR = build +SRCDIR = source # Internal variables. 
PAPEROPT_a4 = -D latex_paper_size=a4 @@ -19,6 +20,7 @@ help: clean: -rm -rf $(BUILDDIR)/* + -rm -rf ${SRCDIR}/api/_autosummary/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html diff --git a/docs/source/commands/dexp_commands.rst b/docs/source/commands/dexp_commands.rst index 571205d0..708c823b 100644 --- a/docs/source/commands/dexp_commands.rst +++ b/docs/source/commands/dexp_commands.rst @@ -3,6 +3,11 @@ DEXP Commands ============= +.. contents:: + :local: + :class: this-will-duplicate-information-and-it-is-still-useful-here + + .. click:: dexp.cli.dexp_commands.add:add :prog: add :nested: full @@ -23,6 +28,11 @@ DEXP Commands :nested: full +.. click:: dexp.cli.dexp_commands.denoise:denoise + :prog: deskew + :nested: full + + .. click:: dexp.cli.dexp_commands.deskew:deskew :prog: deskew :nested: full @@ -38,8 +48,13 @@ DEXP Commands :nested: full -.. click:: dexp.cli.dexp_commands.isonet:isonet - :prog: isonet +.. click:: dexp.cli.dexp_commands.generic:generic + :prog: info + :nested: full + + +.. click:: dexp.cli.dexp_commands.histogram:histogram + :prog: info :nested: full @@ -48,8 +63,8 @@ DEXP Commands :nested: full -.. click:: dexp.cli.dexp_commands.serve:serve - :prog: serve +.. click:: dexp.cli.dexp_commands.segment:segment + :prog: stabilize :nested: full @@ -71,8 +86,3 @@ DEXP Commands .. click:: dexp.cli.dexp_commands.view:view :prog: view :nested: full - - -.. click:: dexp.cli.video_commands.volrender:volrender - :prog: volrender - :nested: full diff --git a/docs/source/commands/video_commands.rst b/docs/source/commands/video_commands.rst index db7abe79..67ccdf8c 100644 --- a/docs/source/commands/video_commands.rst +++ b/docs/source/commands/video_commands.rst @@ -3,6 +3,11 @@ Video Commands ============== +.. contents:: + :local: + :class: this-will-duplicate-information-and-it-is-still-useful-here + + .. click:: dexp.cli.video_commands.blend:blend :prog: blend :nested: full diff --git a/setup.py b/setup.py index 3051ac5e..89c24d9e 100755 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ "higra", "pyotf", "pyift", - "tensorstore", + "tensorstore==0.1.16", ], "gpu": [ "gputools",
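# Closing illustration of the backend dispatch introduced in dexp/utils/backends/backend.py:
# Backend.get_skimage_submodule resolves to skimage for CPU arrays and to cucim.skimage for CuPy
# arrays, so processing code can stay array-type agnostic. The helper name remove_specks and the
# toy mask below are assumptions made for this sketch.
import numpy as np

from dexp.utils.backends import Backend, BestBackend


def remove_specks(binary_mask, min_size: int = 64):
    # picks skimage.morphology or cucim.skimage.morphology based on the array type
    morphology = Backend.get_skimage_submodule("morphology", binary_mask)
    return morphology.remove_small_objects(binary_mask, min_size)


with BestBackend() as bkd:
    mask = np.zeros((256, 256), dtype=bool)
    mask[100:140, 100:140] = True  # large object, kept
    mask[10:12, 10:12] = True      # tiny speck, removed
    cleaned = remove_specks(bkd.to_backend(mask))
    print(bkd.to_numpy(cleaned).sum())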