Skip to content

Commit

Permalink
Manage SciJava Context lifecycle
Browse files Browse the repository at this point in the history
And fix failing tests on Windows.

This changes all of labkit-ui to eschew usage of the SingletonContext
class in favor of passing the context between components as appropriate,
creating new contexts in relevant situations, and disposing of contexts
when they are finished being used in all cases except for demo code.

It also escapes backslashes of Windows paths hardcoded into macros,
to account for ImageJ macro using backslash as an escape character.

Some of these changes change existing public API, so this commit
increments the major/minor digit of the development version.
  • Loading branch information
ctrueden committed Oct 10, 2024
1 parent d585f1c commit c7d2cbb
Show file tree
Hide file tree
Showing 30 changed files with 316 additions and 215 deletions.
3 changes: 2 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

<groupId>sc.fiji</groupId>
<artifactId>labkit-ui</artifactId>
<version>0.4.1-SNAPSHOT</version>
<version>0.5.0-SNAPSHOT</version>

<name>Labkit</name>
<description>The Labkit image segmentation tool for Fiji.</description>
Expand Down Expand Up @@ -95,6 +95,7 @@
<imglib2-algorithm.version>0.15.3</imglib2-algorithm.version>
<bigdataviewer-core.version>10.6.0</bigdataviewer-core.version>
<labkit-pixel-classification.version>0.1.18</labkit-pixel-classification.version>
<spim_data.version>2.3.4</spim_data.version>
</properties>

<dependencies>
Expand Down
5 changes: 0 additions & 5 deletions src/main/java/sc/fiji/labkit/ui/LabkitFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import sc.fiji.labkit.ui.models.DefaultSegmentationModel;
import sc.fiji.labkit.ui.models.SegmentationModel;
import sc.fiji.labkit.ui.utils.Notifier;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;
import org.scijava.Context;

import javax.swing.*;
Expand All @@ -65,8 +64,6 @@ public class LabkitFrame {
public static LabkitFrame showForFile(Context context,
final String filename)
{
if (context == null)
context = SingletonContext.getInstance();
Dataset dataset = openDataset(context, filename);
return showForImage(context, new DatasetInputImage(dataset));
}
Expand All @@ -83,8 +80,6 @@ private static Dataset openDataset(Context context, String filename) {
public static LabkitFrame showForImage(Context context,
final InputImage inputImage)
{
if (context == null)
context = SingletonContext.getInstance();
final SegmentationModel model = new DefaultSegmentationModel(context, inputImage);
model.imageLabelingModel().labeling().set(InitialLabeling.initialLabeling(context, inputImage));
return show(model, inputImage.imageForSegmentation().getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ List<RandomAccessibleInterval<T>> getSegmentations(T type)
Stream<Segmenter> trainedSegmenters = getTrainedSegmenters();
return trainedSegmenters
.map(segmenter -> {
SegmentationTool segmentationTool = new SegmentationTool(segmenter);
SegmentationTool segmentationTool = new SegmentationTool(context, segmenter);
segmentationTool.setProgressWriter(new DummyProgressWriter());
return segmentationTool.segment(image, type);
})
Expand All @@ -105,7 +105,7 @@ public List<RandomAccessibleInterval<FloatType>> getPredictions() {
Stream<Segmenter> trainedSegmenters = getTrainedSegmenters();
return trainedSegmenters
.map(segmenter -> {
SegmentationTool segmentationTool = new SegmentationTool(segmenter);
SegmentationTool segmentationTool = new SegmentationTool(context, segmenter);
segmentationTool.setProgressWriter(new DummyProgressWriter());
return segmentationTool.probabilityMap(image);
})
Expand Down
11 changes: 9 additions & 2 deletions src/main/java/sc/fiji/labkit/ui/panel/AddSegmenterPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
import sc.fiji.labkit.ui.models.SegmenterListModel;
import sc.fiji.labkit.ui.segmentation.SegmentationPlugin;
import sc.fiji.labkit.ui.segmentation.SegmentationPluginService;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;
import net.miginfocom.swing.MigLayout;
import org.scijava.Context;

import javax.swing.*;
import java.awt.*;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;

/**
* Panel that shows a list of available segmentation algorithms. This panel is
Expand Down Expand Up @@ -69,10 +70,16 @@ private void addButtons(SegmenterListModel segmenterListModel, JPanel list) {

public static void main(String... args) {
JFrame frame = new JFrame("Select Segmentation Algorithm");
Context context = SingletonContext.getInstance();
Context context = new Context();
SegmenterListModel slm = new SegmenterListModel(context, new ExtensionPoints());
frame.add(new AddSegmenterPanel(slm));
frame.setSize(300, 300);
frame.setVisible(true);
frame.addWindowListener(new WindowAdapter() {
@Override
public void windowClosed(WindowEvent e) {
context.dispose();
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.filefilter.WildcardFileFilter;
import org.scijava.Cancelable;
import org.scijava.Context;
import org.scijava.app.StatusService;
import org.scijava.command.Command;
import org.scijava.log.Logger;
Expand All @@ -52,6 +53,9 @@

abstract class AbstractProcessFilesInDirectoryPlugin implements Command, Cancelable {

@Parameter
private Context context;

@Parameter
private DatasetIOService io;

Expand Down Expand Up @@ -84,7 +88,7 @@ abstract class AbstractProcessFilesInDirectoryPlugin implements Command, Cancela

@Override
public void run() {
SegmentationTool segmenter = new SegmentationTool();
SegmentationTool segmenter = new SegmentationTool(context);
segmenter.setUseGpu(use_gpu);
segmenter.openModel(segmenter_file.getAbsolutePath());
segmenter.setProgressWriter(new StatusServiceProgressWriter(statusService));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ public class CalculateProbabilityMapWithLabkitIJ1Plugin implements Command, Canc

@Override
public void run() {
SegmentationTool segmenter = new SegmentationTool();
segmenter.setContext(context);
SegmentationTool segmenter = new SegmentationTool(context);
segmenter.openModel(segmenter_file.getAbsolutePath());
segmenter.setUseGpu(use_gpu);
segmenter.setProgressWriter(new StatusServiceProgressWriter(statusService));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ public class CalculateProbabilityMapWithLabkitPlugin implements Command, Cancela

@Override
public void run() {
SegmentationTool segmenter = new SegmentationTool();
segmenter.setContext(context);
SegmentationTool segmenter = new SegmentationTool(context);
segmenter.openModel(segmenter_file.getAbsolutePath());
segmenter.setUseGpu(use_gpu);
segmenter.setProgressWriter(new StatusServiceProgressWriter(statusService));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ public class SegmentImageWithLabkitIJ1Plugin implements Command, Cancelable {

@Override
public void run() {
SegmentationTool segmenter = new SegmentationTool();
segmenter.setContext(context);
SegmentationTool segmenter = new SegmentationTool(context);
segmenter.setUseGpu(use_gpu);
segmenter.setProgressWriter(new StatusServiceProgressWriter(statusService));
segmenter.openModel(segmenter_file.getAbsolutePath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ public class SegmentImageWithLabkitPlugin implements Command, Cancelable {

@Override
public void run() {
SegmentationTool segmenter = new SegmentationTool();
segmenter.setContext(context);
SegmentationTool segmenter = new SegmentationTool(context);
segmenter.setUseGpu(use_gpu);
segmenter.setProgressWriter(new StatusServiceProgressWriter(statusService));
segmenter.openModel(segmenter_file.getAbsolutePath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import net.imglib2.util.Intervals;
import org.apache.commons.lang3.ArrayUtils;
import org.scijava.Context;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;
import sc.fiji.labkit.ui.inputimage.DatasetInputImage;
import sc.fiji.labkit.ui.inputimage.ImgPlusViewsOld;
import sc.fiji.labkit.ui.models.CachedImageFactory;
Expand All @@ -71,16 +70,16 @@ public class SegmentationTool {

private final CachedImageFactory cachedImageFactory = DefaultCachedImageFactory.getInstance();

public SegmentationTool() {

public SegmentationTool(Context context) {
this(context, null);
}

public SegmentationTool(Segmenter segmenter) {
public SegmentationTool(Context context, Segmenter segmenter) {
this.context = Objects.requireNonNull(context);
this.segmenter = segmenter;
}

public void openModel(String classifierFile) {
Context context = this.context != null ? this.context : SingletonContext.getInstance();
Segmenter segmenter = new TrainableSegmentationSegmenter(context);
segmenter.openModel(classifierFile);
setSegmenter(segmenter);
Expand All @@ -94,10 +93,6 @@ public void setSegmenter(Segmenter segmenter) {
this.segmenter.setUseGpu(useGpu);
}

public void setContext(Context context) {
this.context = Objects.requireNonNull(context);
}

public void setProgressWriter(ProgressWriter progressWriter) {
this.progressWriter = Objects.requireNonNull(progressWriter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

import sc.fiji.labkit.ui.segmentation.SegmentationPlugin;
import sc.fiji.labkit.ui.segmentation.Segmenter;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;
import org.scijava.Context;
import org.scijava.plugin.Parameter;
import org.scijava.plugin.Plugin;
Expand Down Expand Up @@ -66,8 +65,7 @@ public boolean canOpenFile(String filename) {
}
}

public static SegmentationPlugin create() {
Context context = SingletonContext.getInstance();
public static SegmentationPlugin create(Context context) {
PixelClassificationPlugin plugin = new PixelClassificationPlugin();
context.inject(plugin);
return plugin;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import sc.fiji.labkit.ui.models.DefaultSegmentationModel;
import sc.fiji.labkit.ui.models.ImageLabelingModel;
import sc.fiji.labkit.ui.models.SegmentationModel;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.view.Views;
import net.miginfocom.swing.MigLayout;
Expand Down Expand Up @@ -93,8 +92,7 @@ private static void onChangeImageButtonClicked(SegmentationModel segmentationMod
DatasetInputImage datasetInputImage = new DatasetInputImage(image);
model.showable().set(datasetInputImage.showable());
model.imageForSegmentation().set(datasetInputImage.imageForSegmentation());
model.labeling().set(InitialLabeling.initialLabeling(SingletonContext.getInstance(),
datasetInputImage));
model.labeling().set(InitialLabeling.initialLabeling(new Context(), datasetInputImage));
}

private JPanel getBottomPanel() {
Expand Down
6 changes: 5 additions & 1 deletion src/test/java/demo/MultiChannelMovieDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@
*/
public class MultiChannelMovieDemo {

private static Context context;

public static void main(String... args) {
context = new Context();
main2();
// TODO: Dispose context when relevant window is closed.
}

private static void main1() {
Expand Down Expand Up @@ -120,7 +124,7 @@ public void testInputImageImageForSegmentation() {
SegmentationModel segmentationModel = new DefaultSegmentationModel(
new Context(), inputImage);
SegmentationItem segmenter = segmentationModel.segmenterList().addSegmenter(
PixelClassificationPlugin.create());
PixelClassificationPlugin.create(context));
Labeling labeling1 = labeling5d();
segmentationModel.imageLabelingModel().labeling().set(labeling1);
segmenter.train(Collections.singletonList(new ValuePair<>(inputImage
Expand Down
20 changes: 17 additions & 3 deletions src/test/java/sc/fiji/labkit/ui/InitialLabelingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
import net.imagej.axis.EnumeratedAxis;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.scijava.Context;
import sc.fiji.labkit.ui.inputimage.DatasetInputImage;
import sc.fiji.labkit.ui.labeling.Labeling;
import net.imglib2.type.numeric.integer.UnsignedByteType;
Expand All @@ -49,6 +51,18 @@

public class InitialLabelingTest {

private static Context context;

@BeforeClass
public static void setUp() {
context = new Context();
}

@AfterClass
public static void tearDown() {
context.dispose();
}

@Test
public void testDoNotCrashWhenLabelingFileIsEmpty() throws IOException {
File empty = File.createTempFile("labkit-InitialLabelingTest-",
Expand All @@ -59,7 +73,7 @@ public void testDoNotCrashWhenLabelingFileIsEmpty() throws IOException {
DatasetInputImage inputImage = new DatasetInputImage(image);
List<String> defaultLabels = Collections.emptyList();
Labeling result = InitialLabeling.initLabeling(inputImage,
SingletonContext.getInstance(),
context,
defaultLabels);
assertNotNull(result);
}
Expand All @@ -75,7 +89,7 @@ public void testEnumeratedAxis() {
EnumeratedAxis yAxis = new EnumeratedAxis(Axes.Y, "mm", new double[] { 0, 0.3 });
ImgPlus<UnsignedByteType> image = new ImgPlus<>(img, "test", xAxis, yAxis);
DatasetInputImage inputImage = new DatasetInputImage(image);
Labeling result = InitialLabeling.initialLabeling(SingletonContext.getInstance(), inputImage);
Labeling result = InitialLabeling.initialLabeling(context, inputImage);
assertNotNull(result);
}
}
29 changes: 21 additions & 8 deletions src/test/java/sc/fiji/labkit/ui/SegmentationUseCaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.IntegerType;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import sc.fiji.labkit.ui.bdv.BdvShowable;
import sc.fiji.labkit.ui.inputimage.DatasetInputImage;
import sc.fiji.labkit.ui.inputimage.InputImage;
Expand All @@ -51,7 +53,6 @@
import sc.fiji.labkit.ui.segmentation.weka.PixelClassificationPlugin;
import sc.fiji.labkit.ui.segmentation.SegmentationPlugin;
import net.imglib2.roi.labeling.LabelingType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.util.Intervals;
import net.imglib2.util.ValuePair;
Expand All @@ -70,22 +71,34 @@

public class SegmentationUseCaseTest {

private static Context context;

@BeforeClass
public static void setUp() {
context = new Context();
}

@AfterClass
public static void tearDown() {
context.dispose();
}

@Test
public void test() {
ImgPlus<UnsignedByteType> image = new ImgPlus<>(ArrayImgs.unsignedBytes(new byte[] { 1, 1, 2,
2 }, 2, 2));
InputImage inputImage = new DatasetInputImage(image);
SegmentationModel segmentationModel = new DefaultSegmentationModel(
new Context(), inputImage);
context, inputImage);
addLabels(segmentationModel.imageLabelingModel());
SegmentationPlugin plugin = PixelClassificationPlugin.create();
SegmentationPlugin plugin = PixelClassificationPlugin.create(context);
SegmentationItem segmenter = segmentationModel.segmenterList().addSegmenter(plugin);
segmenter.train(Collections.singletonList(new ValuePair<>(image,
segmentationModel.imageLabelingModel().labeling().get())));
RandomAccessibleInterval<? extends IntegerType<?>> result =
segmenter.results(segmentationModel.imageLabelingModel()).segmentation();
List<Integer> list = new ArrayList<>();
Views.iterable(result).forEach(x -> list.add(x.getInteger()));
result.forEach(x -> list.add(x.getInteger()));
assertEquals(Arrays.asList(1, 1, 0, 0), list);
}

Expand All @@ -99,7 +112,7 @@ private void addLabels(ImageLabelingModel imageLabelingModel) {
}

@Test
public void testMultiChannel() throws InterruptedException {
public void testMultiChannel() {
Img<UnsignedByteType> img = ArrayImgs.unsignedBytes(new byte[] { -1, 0, -1,
0, -1, -1, 0, 0 }, 2, 2, 2);
ImgPlus<UnsignedByteType> imgPlus = new ImgPlus<>(img, "Image",
Expand All @@ -108,17 +121,17 @@ public void testMultiChannel() throws InterruptedException {
.wrap(Views.hyperSlice(img, 2, 0)));

Labeling labeling = getLabeling();
SegmentationModel segmentationModel = new DefaultSegmentationModel(new Context(),
SegmentationModel segmentationModel = new DefaultSegmentationModel(context,
inputImage);
ImageLabelingModel imageLabelingModel = segmentationModel.imageLabelingModel();
imageLabelingModel.labeling().set(labeling);
SegmentationItem segmenter = segmentationModel.segmenterList().addSegmenter(
PixelClassificationPlugin.create());
PixelClassificationPlugin.create(context));
segmenter.train(Collections.singletonList(new ValuePair<>(imgPlus,
imageLabelingModel.labeling().get())));
RandomAccessibleInterval<? extends IntegerType<?>> result =
segmenter.results(imageLabelingModel).segmentation();
Iterator<? extends IntegerType<?>> it = Views.iterable(result).iterator();
Iterator<? extends IntegerType<?>> it = result.iterator();
assertEquals(1, it.next().getInteger());
assertEquals(0, it.next().getInteger());
assertEquals(0, it.next().getInteger());
Expand Down
Loading

0 comments on commit c7d2cbb

Please sign in to comment.