diff --git a/src/main/java/sc/fiji/labkit/ui/plugin/AbstractProcessFilesInDirectoryPlugin.java b/src/main/java/sc/fiji/labkit/ui/plugin/AbstractProcessFilesInDirectoryPlugin.java index 13f1b713..8a867360 100644 --- a/src/main/java/sc/fiji/labkit/ui/plugin/AbstractProcessFilesInDirectoryPlugin.java +++ b/src/main/java/sc/fiji/labkit/ui/plugin/AbstractProcessFilesInDirectoryPlugin.java @@ -33,6 +33,8 @@ import net.imagej.Dataset; import net.imagej.DatasetService; import net.imagej.ImgPlus; +import net.imglib2.parallel.TaskExecutor; +import net.imglib2.parallel.TaskExecutors; import net.imglib2.type.Type; import net.imglib2.util.Cast; import org.apache.commons.io.FilenameUtils; @@ -42,6 +44,7 @@ import org.scijava.command.Command; import org.scijava.log.Logger; import org.scijava.plugin.Parameter; +import sc.fiji.labkit.pixel_classification.gpu.api.GpuPool; import sc.fiji.labkit.ui.segmentation.SegmentationTool; import sc.fiji.labkit.ui.utils.progress.StatusServiceProgressWriter; @@ -49,6 +52,11 @@ import java.io.FileFilter; import java.io.IOException; import java.util.Arrays; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; abstract class AbstractProcessFilesInDirectoryPlugin implements Command, Cancelable { @@ -90,17 +98,40 @@ public void run() { segmenter.setProgressWriter(new StatusServiceProgressWriter(statusService)); FileFilter wildcardFileFilter = new WildcardFileFilter(file_filter); File[] files = input_directory.listFiles(wildcardFileFilter); - Arrays.sort(files); - for (int i = 0; i < files.length; i++) { + AtomicInteger counter = new AtomicInteger(); + TaskExecutor taskExecutor = getTaskExecutor(); + taskExecutor.forEach(Arrays.asList(files), file -> { try { - processFile(segmenter, files, i); + processFile(segmenter, files, counter.getAndIncrement()); } catch (Exception e) { logger.error(e); } + }); + } + + private TaskExecutor getTaskExecutor() { + if (use_gpu) { + TaskExecutor blocksTaskExecutor = fixedNumThreadsTaskExecuter(GpuPool.size(), + TaskExecutors::multiThreaded); + TaskExecutor imagesTaskExecutor = fixedNumThreadsTaskExecuter(2, () -> blocksTaskExecutor); + return imagesTaskExecutor; + } + else { + return fixedNumThreadsTaskExecuter(1, TaskExecutors::multiThreaded); +// return TaskExecutors.nestedFixedThreadPool( 2, (Runtime.getRuntime().availableProcessors() + 1)/ 2 ); } } + private TaskExecutor fixedNumThreadsTaskExecuter(int numThreads, + Supplier multiThreaded) + { + ThreadFactory threadFactory = TaskExecutors.threadFactory(multiThreaded); + ExecutorService executorService = Executors.newFixedThreadPool(numThreads, threadFactory); + TaskExecutor taskExecutor = TaskExecutors.forExecutorService(executorService); + return taskExecutor; + } + private > void processFile(SegmentationTool segmenter, File[] files, int i) throws IOException { diff --git a/src/test/java/sc/fiji/labkit/ui/segmentation/BatchBenchmark.java b/src/test/java/sc/fiji/labkit/ui/segmentation/BatchBenchmark.java new file mode 100644 index 00000000..93b85249 --- /dev/null +++ b/src/test/java/sc/fiji/labkit/ui/segmentation/BatchBenchmark.java @@ -0,0 +1,89 @@ + +package sc.fiji.labkit.ui.segmentation; + +import bdv.export.ProgressWriterConsole; +import net.imglib2.util.StopWatch; +import org.apache.commons.io.filefilter.WildcardFileFilter; +import org.scijava.Context; +import org.scijava.command.CommandService; +import sc.fiji.labkit.pixel_classification.utils.SingletonContext; +import sc.fiji.labkit.ui.BatchSegmenter; +import sc.fiji.labkit.ui.plugin.LabkitSegmentImagesInDirectoryPlugin; +import sc.fiji.labkit.ui.segmentation.weka.TrainableSegmentationSegmenter; + +import java.io.File; +import java.io.FileFilter; +import java.util.concurrent.ExecutionException; + +public class BatchBenchmark { + + private static final String base = "/home/arzt/tmp/labkit-memory-test/"; + + private static final String inputDirectory = base + "input/"; + + private static final String segmenterFile = base + "normal.classifier"; + + private static final String outputDirectory = base + "output/"; + + private static final String fileFilter = "Untitled_1*.tif"; + + private static final boolean useGpu = false; + + public static void main(String... args) throws Exception { + runNewVersion(); + } + + private static void runOldVersion() throws Exception { + Context context = SingletonContext.getInstance(); + StopWatch sw = StopWatch.createAndStart(); + TrainableSegmentationSegmenter segmenter = new TrainableSegmentationSegmenter(context); + segmenter.setUseGpu(useGpu); + segmenter.openModel(segmenterFile); + BatchSegmenter batch = new BatchSegmenter(segmenter, new ProgressWriterConsole()); + for (File input : listFiles()) { + System.out.println(input); + batch.segment(input, new File(outputDirectory, input.getName())); + } + System.out.println(sw); + context.close(); + System.exit(0); + } + + private static File[] listFiles() { + File baseFile = new File(inputDirectory); + File[] files = baseFile.listFiles((FileFilter) new WildcardFileFilter(fileFilter)); + if (files != null) + return files; + return new File[0]; + } + + private static void runNewVersion() + throws InterruptedException, ExecutionException + { + Context context = SingletonContext.getInstance(); + CommandService cmd = context.service(CommandService.class); + System.out.println("start"); + removeImageFromDirectory(outputDirectory); + StopWatch stopWatch = StopWatch.createAndStart(); + cmd.run(LabkitSegmentImagesInDirectoryPlugin.class, true, + "input_directory", inputDirectory, + "file_filter", fileFilter, + "output_directory", outputDirectory, + "output_file_suffix", "segmentation.tif", + "segmenter_file", segmenterFile, + "use_gpu", useGpu).get(); + System.out.println("stop"); + System.out.println(stopWatch); + System.exit(0); + } + + private static void removeImageFromDirectory(String directory) { + File d = new File(directory); + FileFilter fileFilter = new WildcardFileFilter("*.tif"); + File[] files = d.listFiles(fileFilter); + if (files == null) + return; + for (File file : files) + file.delete(); + } +}