Skip to content

Commit

Permalink
Move WarpByMap to external SOFIMA repository.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700414638
  • Loading branch information
timblakely authored and copybara-github committed Nov 26, 2024
1 parent 77e3a3b commit 141f5af
Showing 1 changed file with 285 additions and 0 deletions.
285 changes: 285 additions & 0 deletions processor/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@
"""Processors for image warping and rendering."""

from concurrent import futures
import dataclasses
import json
from typing import Any, Sequence

from absl import logging
from connectomics.common import bounding_box
from connectomics.common import box_generator
from connectomics.common import file
from connectomics.common import geom_utils
from connectomics.common import utils
from connectomics.volume import mask as mask_lib
from connectomics.volume import metadata
from connectomics.volume import subvolume
from connectomics.volume import subvolume_processor
import edt
Expand Down Expand Up @@ -333,3 +340,281 @@ def process(
ret = ret.astype(self.output_type(subvol.data.dtype))

return self.crop_box_and_data(box, ret[None, ...])


class WarpByMap(subvolume_processor.SubvolumeProcessor):
  """Warps data according to an inverse coordinate map.

  Optionally performs on-the-fly downsampling of the warped data. This
  is useful when the source data is available at a higher resolution
  than that required for all downstream processing steps.

  This processor is run over a template output volume, and loads the
  coordinate maps and the data to warp from other volumes, as configured
  in the constructor.
  """

  crop_at_borders = False
  output_num = subvolume_processor.OutputNums.MULTI
  ignores_input_data = True
  # Class-level default so that the attribute exists even when no mask is
  # configured; overridden per instance in the constructor.
  _mask_configs: mask_lib.MaskConfigs | None = None

  @dataclasses.dataclass(eq=True)
  class Config(utils.NPDataClassJsonMixin):
    """Configuration for WarpByMap.

    Attributes:
      stride: pixel distance between nearest neighbors of the coordinate map,
        in pixels of the output volume
      map_volinfo: path to the volinfo of the coordinate map volume
      data_volinfo: path to the volinfo of the source data volume
      map_decorator_specs: decorator specs of the coordinate map volume, as a
        JSON string or Python object
      data_decorator_specs: decorator specs of the source data volume, as a
        JSON string or Python object
      map_scale: multiplicative factor for the coordinate map values
      interpolation: interpolation scheme to use (see warp.warp_subvolume)
      downsample: downsampling factor to apply to the data after warping;
        downsampling is performed in XY via area-averaging
      offset: (deprecated, do not use)
      mask_configs: MaskConfigs proto in text format; pixels corresponding to
        the positive entries of the mask will be zeroed out, and warping of
        completely masked subvolumes will be skipped
      source_cache_bytes: max. number of bytes to use for the raw data chunk
        cache; it is recommended for the buffer to be large enough to hold a
        single layer of chunks covering the complete subvolume laterally
    """

    # TODO(blakely): Convert stride to Sequence[float]
    stride: float
    map_volinfo: str
    data_volinfo: str
    map_decorator_specs: str | dict[str, Any] | None = None
    data_decorator_specs: str | dict[str, Any] | None = None
    map_scale: float = 1.0
    interpolation: str | None = None
    downsample: int = 1
    offset: float = 0.0
    mask_configs: str | mask_lib.MaskConfigs | None = None
    # Annotated so that this is a real dataclass field (settable through the
    # constructor and included in serialization) rather than a fixed class
    # attribute.
    source_cache_bytes: int = int(1e9)

  def __init__(
      self,
      config: Config,
      input_volinfo=None,
  ):
    """Constructor.

    Args:
      config: Configuration for this processor.
      input_volinfo: VolumeInfo proto of the input. Not used by this
        constructor.
    """
    self._map_volinfo = config.map_volinfo
    self._scale = config.map_scale
    self._interpolation = config.interpolation
    self._data_volinfo = config.data_volinfo

    # TODO(blakely): Convert these to dataclasses.
    def _get_specs(specs):
      """Normalizes decorator specs: None -> [], JSON string -> parsed."""
      if specs is None:
        return []
      if isinstance(specs, str):
        return json.loads(specs)
      return specs

    self._data_decorator_specs = _get_specs(config.data_decorator_specs)
    self._map_decorator_specs = _get_specs(config.map_decorator_specs)

    # Downsampling is only applied laterally; the Z factor is always 1.
    self._downsample = np.array(
        [config.downsample, config.downsample, 1]
    )  # xyz
    self._target_stride = config.stride
    # The map stride expressed in pixels of the (non-downsampled) warped data.
    self._source_stride = config.stride * config.downsample
    self._offset = config.offset
    self._source_cache_bytes = config.source_cache_bytes

    self._mask_configs = None
    if config.mask_configs is not None:
      mask_configs = config.mask_configs
      if isinstance(mask_configs, str):
        # NOTE(review): _get_mask_configs is not defined in this class;
        # presumably provided by SubvolumeProcessor -- confirm.
        mask_configs = self._get_mask_configs(config.mask_configs)
      self._mask_configs = mask_configs

  def _load_and_warp(
      self,
      data_box: bounding_box.BoundingBox,
      data_vol,
      map_data: np.ndarray,
      map_box: bounding_box.BoundingBox,
      out_box: bounding_box.BoundingBox,
  ) -> np.ndarray | None:
    """Produces warped data for 'out_box'.

    Args:
      data_box: box of source data to load from 'data_vol'
      data_vol: opened source data volume, indexable with a 4d slice
      map_data: coordinate map values covering 'map_box'
      map_box: box within the coordinate map volume for 'map_data'
      out_box: box for which warped data should be generated

    Returns:
      Result of warp.warp_subvolume, or None if a mask is configured and
      all of the loaded data is masked.
    """
    data = data_vol[data_box.to_slice4d()]
    if self._mask_configs is not None:
      mask = self._build_mask(self._mask_configs, data_box)
      # Zero out masked pixels in every channel.
      for c in range(data.shape[0]):
        data[c, ...][mask] = 0

      # Nothing to warp if the whole source region is masked.
      if mask is not None and np.all(mask):
        return None

    return warp.warp_subvolume(
        data,
        data_box,
        map_data,
        map_box,
        self._source_stride,
        out_box,
        self._interpolation,
        self._offset,
    )

  def _get_map_for_box(
      self, box: bounding_box.BoundingBox
  ) -> tuple[bounding_box.BoundingBox | None, np.ndarray | None]:
    """Loads coordinate map corresponding to the output subvolume.

    Args:
      box: output box, in pixels of the output volume

    Returns:
      Tuple of (map box, scaled map data as float64), or (None, None) if the
      map box does not intersect the map volume or the map contains only NaNs.
    """
    s = 1.0 / self._target_stride
    # Convert to map coordinates in XY and pad by 2 map nodes on every
    # lateral side.
    map_box = box.scale([s, s, 1.0]).adjusted_by(
        start=(-2, -2, 0), end=(2, 2, 0)
    )
    map_vol = self._map_volinfo
    if self._map_decorator_specs:
      map_vol = metadata.DecoratedVolume(
          path=self._map_volinfo,
          decorator_specs=self._map_decorator_specs,
      )
    map_vol = self._open_volume(map_vol)
    map_box = map_vol.clip_box_to_volume(map_box)
    if map_box is None or np.all(map_box.size == 0):
      logging.debug('Clipped map box is None for box %r', box)
      return None, None
    rel_map = map_vol[map_box.to_slice4d()].astype(np.float64) * self._scale

    if np.all(np.isnan(rel_map)):
      logging.debug('Map is invalid for %r.', map_box)
      return None, None

    return map_box, rel_map

  def _generate_boxes_to_warp(self, data_vol, box: bounding_box.BoundingBox):
    """Generates work items for which to perform warping.

    Recursively splits 'box' into a 2x2 lateral grid until the source data
    required for every piece is small enough to be warped in one call.

    Args:
      data_vol: opened source data volume
      box: output box for which to generate work items

    Yields:
      Tuples of (output box, source data box, map data, map box).
    """
    map_box, rel_map = self._get_map_for_box(box)
    if map_box is None:
      logging.debug('No map found for %r.', box)
      return

    data_box = map_utils.outer_box(rel_map, map_box, self._source_stride, 1)
    data_box = data_vol.clip_box_to_volume(data_box)
    # Check the clipped *data* box; the map box is already known to be
    # non-empty at this point.
    if data_box is None or np.all(data_box.size == 0):
      logging.debug('Data out of bounds for map: %r.', map_box)
      return

    # 32k is the max input size supported by OpenCV's remap() function
    # used in warp.warp_subvolume. If the input is smaller, we can generate
    # output for 'box' directly.
    if np.all(data_box.size < 2**15):
      yield box, data_box, rel_map, map_box
      return

    # Bail out if the output box is small, yet the input data remains large.
    if np.any(box.size[:2] < self._target_stride * 3):
      logging.debug('Output box too small: %r', box)
      # TODO(blakely): Re-enable counters
      # beam.metrics.Metrics.counter('warp-by-map', 'map-failures').inc()
      return

    # Otherwise, divide the current output box into a 2x2 grid of smaller
    # boxes, and try again.
    # -(-a // b) is ceiling division: halve XY rounding up, keep Z as is.
    subvol_size = np.array(list(-(-box.size[:2] // 2)) + [box.size[2]])
    # Round up to a multiple of the downsampling factor.
    subvol_size = -(-subvol_size // self._downsample) * self._downsample

    calc = box_generator.BoxGenerator(box, subvol_size, box_overlap=(0, 0, 0))
    for sub_box in calc.boxes:
      yield from self._generate_boxes_to_warp(data_vol, sub_box)

  def process(self, subvol: subvolume.Subvolume) -> subvolume.SubvolumeOrMany:
    """Warps (and optionally downsamples) source data into the output box.

    Args:
      subvol: template subvolume; only its bounding box, channel count and
        dtype are used

    Returns:
      Single-element list with the cropped subvolume of warped data. Regions
      for which no warped data could be generated are left as zeros.

    Raises:
      NotImplementedError: if downsampling is requested for warped data of a
        dtype other than uint8, uint32 or float32.
    """
    box = subvol.bbox
    input_ndarray = subvol.data

    data_vol = self._data_volinfo
    if self._data_decorator_specs:
      data_vol = metadata.DecoratedVolume(
          path=self._data_volinfo,
          decorator_specs=self._data_decorator_specs,
      )
    data_vol = self._open_volume(data_vol)

    # TODO(blakely): Re-enable this cache size check.
    # min_cache_size = (
    #     -(-box.size[0] // data_vol.info.chunksize.x)
    #     * data_vol.info.chunksize.x
    #     * -(-box.size[1] // data_vol.info.chunksize.y)
    #     * data_vol.info.chunksize.y
    #     * data_vol.info.chunksize.z
    #     * ndarray.data_type_to_size[data_vol.info.channel_type]
    # )
    # if min_cache_size > self._source_cache_bytes:
    #   logging.warning(
    #       (
    #           'Cache size needs to be at least %d bytes in order '
    #           'to optimally utilize data loaded from the source volume.'
    #       ),
    #       min_cache_size,
    #   )

    # Output buffer in (channel, z, y, x) order.
    warped = np.zeros(
        [input_ndarray.shape[0]] + box.size[::-1].tolist(),
        dtype=input_ndarray.dtype,
    )
    del input_ndarray

    # Warp data section-wise.
    for z in range(warped.shape[1]):
      curr_box = bounding_box.BoundingBox(
          start=box.start + [0, 0, z], size=[box.size[0], box.size[1], 1]
      )
      logging.debug('warping z=%d', z)

      for out_box, data_box, map_data, map_box in self._generate_boxes_to_warp(
          data_vol, curr_box
      ):
        logging.debug('warp %r -> %r via map %r', data_box, out_box, map_box)
        # Box to warp into, at the resolution before downsampling.
        warp_box = out_box.scale(self._downsample)
        warped_sec = self._load_and_warp(
            data_box, data_vol, map_data, map_box, warp_box
        )
        if warped_sec is None:
          continue

        if warp_box != out_box:
          downsampled = []

          # Cast to a wider type to eliminate overflow or reduce precision
          # loss in integral image computation.
          if warped_sec.dtype in (np.uint8, np.uint32):
            warped_sec = warped_sec.astype(np.int64)
          elif warped_sec.dtype == np.float32:
            warped_sec = warped_sec.astype(np.float64)
            warped_sec = np.nan_to_num(warped_sec)
          else:
            raise NotImplementedError(
                f'Downsampling of {warped_sec.dtype} not supported.'
            )

          for chan in range(warped_sec.shape[0]):
            svt = geom_utils.integral_image(warped_sec[chan, ...])
            down_box, down_data = geom_utils.downsample_area(
                svt, warp_box, self._downsample, warped.dtype
            )
            downsampled.append(down_data)

          # Express the destination relative to the output buffer origin.
          write_box = down_box.translate(-box.start)
          warped[write_box.to_slice4d()] = np.concatenate(
              downsampled, axis=0
          ).astype(warped.dtype)
        else:
          write_box = out_box.translate(-box.start)
          warped[write_box.to_slice4d()] = warped_sec

    return [self.crop_box_and_data(box, warped)]

0 comments on commit 141f5af

Please sign in to comment.