Skip to content

Commit

Permalink
Add default alignment configurations for 2D EM.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704440015
  • Loading branch information
timblakely authored and copybara-github committed Dec 9, 2024
1 parent 302c1ac commit 82020da
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 121 deletions.
136 changes: 15 additions & 121 deletions pipeline/mesh_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,14 @@
"""

import dataclasses
import enum
from typing import Any

from connectomics.common import utils
from connectomics.volume import subvolume_processor
import dataclasses_json
from sofima import mesh as mesh_lib
from sofima.processor import maps
from sofima.processor import mesh


# TODO(blakely): Combine with flow_config.DefaultPipeline
class DefaultPipeline(enum.Enum):
EM_2D = 'em_2d'
from sofima.processor.defaults import em_2d


@dataclasses.dataclass(frozen=True)
Expand All @@ -45,127 +40,26 @@ class MeshRelaxationConfig(dataclasses_json.DataClassJsonMixin):
reconcile_cross_block_config: maps.ReconcileCrossBlockMaps.Config


def default_em_2d_relax_mesh_config(
overrides: dict[str, Any] | None = None,
) -> mesh.RelaxMesh.Config:
"""Default mesh relaxation configuration for EM 2D data."""

config = mesh.RelaxMesh.Config(
output_dir='NONE',
integration_config=mesh_lib.IntegrationConfig(
dt=0.001,
gamma=0.0,
k0=0.01,
k=0.1,
stride=(40, 40),
num_iters=1000,
max_iters=100000,
stop_v_max=0.005,
dt_max=1000,
start_cap=0.01,
final_cap=10,
prefer_orig_order=True,
),
mesh=None,
flows=[],
sections_to_skip=[],
ranges_to_skip=[],
mask=None,
block_starts=[],
block_ends=[],
backward=False,
mesh_min_frac=0.5,
mesh_max_frac=2.0,
coming_in=[],
options=mesh.MeshOptions(
irregular_mask_radius=5,
),
)
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d_within_block_config(
overrides: dict[str, Any] | None = None,
) -> mesh.RelaxMesh.Config:
"""Default configuration for within-block mesh relaxation."""
config = default_em_2d_relax_mesh_config()
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d_last_section_config(
overrides: dict[str, Any] | None = None,
) -> mesh.RelaxMesh.Config:
"""Default configuration for relaxation of last section of blockwise mesh."""
config = default_em_2d_relax_mesh_config()
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d_cross_block_config(
overrides: dict[str, Any] | None = None,
) -> mesh.RelaxMesh.Config:
"""Default cross-block mesh relaxation configuration for EM 2D data."""
config = default_em_2d_relax_mesh_config({
'integration_config': {
'k0': 0.001,
'stride': (320, 320),
'stop_v_max': 0.001,
},
'options': {
'init_state': mesh.MeshInitState.PREV_MEDIAN,
},
})
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d_reconcile_config(
overrides: dict[str, Any] | None = None,
) -> maps.ReconcileCrossBlockMaps.Config:
"""Default cross-block reconciliation configuration for EM 2D data."""
config = maps.ReconcileCrossBlockMaps.Config(
cross_block='NONE',
cross_block_inv='NONE',
last_inv='NONE',
main_inv='NONE',
z_map={},
stride=40,
xy_overlap=128,
backward=False,
)
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d(
overrides: dict[str, Any] | None = None,
) -> MeshRelaxationConfig:
"""Default mesh relaxation configuration for EM 2D data."""
within_block = em_2d.within_block_config()
last_section = em_2d.last_section_config()
cross_block = em_2d.cross_block_config()
reconcile_cross_block = em_2d.default_em_2d_reconcile_config()
config = MeshRelaxationConfig(
within_block_config=default_em_2d_within_block_config(),
last_section_config=default_em_2d_last_section_config(),
cross_block_config=default_em_2d_cross_block_config(),
reconcile_cross_block_config=default_em_2d_reconcile_config(),
within_block_config=within_block,
last_section_config=last_section,
cross_block_config=cross_block,
reconcile_cross_block_config=reconcile_cross_block,
)
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


_DEFAULT_CONFIG_TYPE_DISPATCH = {
DefaultPipeline.EM_2D: default_em_2d,
}


def default(
default_type: DefaultPipeline, overrides: dict[str, Any] | None = None
) -> MeshRelaxationConfig:
"""Default mesh relaxation configuration for a given data type."""
return _DEFAULT_CONFIG_TYPE_DISPATCH[default_type](overrides)
subvolume_processor.register_default_config(
subvolume_processor.DefaultConfigType.EM_2D,
MeshRelaxationConfig,
default_em_2d,
)
109 changes: 109 additions & 0 deletions processor/defaults/em_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

from connectomics.common import utils
from connectomics.volume import subvolume_processor
from sofima import mesh as mesh_lib
from sofima.processor import flow
from sofima.processor import maps
from sofima.processor import mesh


def estimate_flow_config(
Expand Down Expand Up @@ -103,3 +106,109 @@ def estimate_missing_flow_config(
flow.EstimateMissingFlow.Config,
estimate_missing_flow_config,
)


def relax_mesh_config(
overrides: dict[str, Any] | None = None,
) -> mesh.RelaxMesh.Config:
"""Default mesh relaxation configuration for EM 2D data."""

config = mesh.RelaxMesh.Config(
output_dir='NONE',
integration_config=mesh_lib.IntegrationConfig(
dt=0.001,
gamma=0.0,
k0=0.01,
k=0.1,
stride=(40, 40),
num_iters=1000,
max_iters=100000,
stop_v_max=0.005,
dt_max=1000,
start_cap=0.01,
final_cap=10,
prefer_orig_order=True,
),
mesh=None,
flows=[],
sections_to_skip=[],
ranges_to_skip=[],
mask=None,
block_starts=[],
block_ends=[],
backward=False,
mesh_min_frac=0.5,
mesh_max_frac=2.0,
coming_in=[],
options=mesh.MeshOptions(
irregular_mask_radius=5,
),
)
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


subvolume_processor.register_default_config(
subvolume_processor.DefaultConfigType.EM_2D,
flow.EstimateFlow.Config,
estimate_flow_config,
)


def within_block_config(
overrides: dict[str, Any] | None = None,
) -> mesh.RelaxMesh.Config:
"""Default configuration for within-block mesh relaxation."""
config = relax_mesh_config()
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def last_section_config(
overrides: dict[str, Any] | None = None,
) -> mesh.RelaxMesh.Config:
"""Default configuration for relaxation of last section of blockwise mesh."""
config = relax_mesh_config()
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def cross_block_config(
overrides: dict[str, Any] | None = None,
) -> mesh.RelaxMesh.Config:
"""Default cross-block mesh relaxation configuration for EM 2D data."""
config = relax_mesh_config({
'integration_config': {
'k0': 0.001,
'stride': (320, 320),
'stop_v_max': 0.001,
},
'options': {
'init_state': mesh.MeshInitState.PREV_MEDIAN,
},
})
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d_reconcile_config(
overrides: dict[str, Any] | None = None,
) -> maps.ReconcileCrossBlockMaps.Config:
"""Default cross-block reconciliation configuration for EM 2D data."""
config = maps.ReconcileCrossBlockMaps.Config(
cross_block='NONE',
cross_block_inv='NONE',
last_inv='NONE',
main_inv='NONE',
z_map={},
stride=40,
xy_overlap=128,
backward=False,
)
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config

0 comments on commit 82020da

Please sign in to comment.