Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch speed optimization #92

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import net.imagej.Dataset;
import net.imagej.DatasetService;
import net.imagej.ImgPlus;
import net.imglib2.parallel.TaskExecutor;
import net.imglib2.parallel.TaskExecutors;
import net.imglib2.type.Type;
import net.imglib2.util.Cast;
import org.apache.commons.io.FilenameUtils;
Expand All @@ -42,13 +44,19 @@
import org.scijava.command.Command;
import org.scijava.log.Logger;
import org.scijava.plugin.Parameter;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuPool;
import sc.fiji.labkit.ui.segmentation.SegmentationTool;
import sc.fiji.labkit.ui.utils.progress.StatusServiceProgressWriter;

import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

abstract class AbstractProcessFilesInDirectoryPlugin implements Command, Cancelable {

Expand Down Expand Up @@ -90,17 +98,40 @@ public void run() {
segmenter.setProgressWriter(new StatusServiceProgressWriter(statusService));
FileFilter wildcardFileFilter = new WildcardFileFilter(file_filter);
File[] files = input_directory.listFiles(wildcardFileFilter);
Arrays.sort(files);
for (int i = 0; i < files.length; i++) {
AtomicInteger counter = new AtomicInteger();
TaskExecutor taskExecutor = getTaskExecutor();
taskExecutor.forEach(Arrays.asList(files), file -> {
try {
processFile(segmenter, files, i);
processFile(segmenter, files, counter.getAndIncrement());
}
catch (Exception e) {
logger.error(e);
}
});
}

private TaskExecutor getTaskExecutor() {
if (use_gpu) {
TaskExecutor blocksTaskExecutor = fixedNumThreadsTaskExecuter(GpuPool.size(),
TaskExecutors::multiThreaded);
TaskExecutor imagesTaskExecutor = fixedNumThreadsTaskExecuter(2, () -> blocksTaskExecutor);
return imagesTaskExecutor;
}
else {
return fixedNumThreadsTaskExecuter(1, TaskExecutors::multiThreaded);
// return TaskExecutors.nestedFixedThreadPool( 2, (Runtime.getRuntime().availableProcessors() + 1)/ 2 );
}
}

private TaskExecutor fixedNumThreadsTaskExecuter(int numThreads,
Supplier<TaskExecutor> multiThreaded)
{
ThreadFactory threadFactory = TaskExecutors.threadFactory(multiThreaded);
ExecutorService executorService = Executors.newFixedThreadPool(numThreads, threadFactory);
TaskExecutor taskExecutor = TaskExecutors.forExecutorService(executorService);
return taskExecutor;
}

private <T extends Type<T>> void processFile(SegmentationTool segmenter, File[] files, int i)
throws IOException
{
Expand Down
89 changes: 89 additions & 0 deletions src/test/java/sc/fiji/labkit/ui/segmentation/BatchBenchmark.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@

package sc.fiji.labkit.ui.segmentation;

import bdv.export.ProgressWriterConsole;
import net.imglib2.util.StopWatch;
import org.apache.commons.io.filefilter.WildcardFileFilter;
import org.scijava.Context;
import org.scijava.command.CommandService;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;
import sc.fiji.labkit.ui.BatchSegmenter;
import sc.fiji.labkit.ui.plugin.LabkitSegmentImagesInDirectoryPlugin;
import sc.fiji.labkit.ui.segmentation.weka.TrainableSegmentationSegmenter;

import java.io.File;
import java.io.FileFilter;
import java.util.concurrent.ExecutionException;

public class BatchBenchmark {

private static final String base = "/home/arzt/tmp/labkit-memory-test/";

private static final String inputDirectory = base + "input/";

private static final String segmenterFile = base + "normal.classifier";

private static final String outputDirectory = base + "output/";

private static final String fileFilter = "Untitled_1*.tif";

private static final boolean useGpu = false;

public static void main(String... args) throws Exception {
runNewVersion();
}

private static void runOldVersion() throws Exception {
Context context = SingletonContext.getInstance();
StopWatch sw = StopWatch.createAndStart();
TrainableSegmentationSegmenter segmenter = new TrainableSegmentationSegmenter(context);
segmenter.setUseGpu(useGpu);
segmenter.openModel(segmenterFile);
BatchSegmenter batch = new BatchSegmenter(segmenter, new ProgressWriterConsole());
for (File input : listFiles()) {
System.out.println(input);
batch.segment(input, new File(outputDirectory, input.getName()));
}
System.out.println(sw);
context.close();
System.exit(0);
}

private static File[] listFiles() {
File baseFile = new File(inputDirectory);
File[] files = baseFile.listFiles((FileFilter) new WildcardFileFilter(fileFilter));
if (files != null)
return files;
return new File[0];
}

private static void runNewVersion()
throws InterruptedException, ExecutionException
{
Context context = SingletonContext.getInstance();
CommandService cmd = context.service(CommandService.class);
System.out.println("start");
removeImageFromDirectory(outputDirectory);
StopWatch stopWatch = StopWatch.createAndStart();
cmd.run(LabkitSegmentImagesInDirectoryPlugin.class, true,
"input_directory", inputDirectory,
"file_filter", fileFilter,
"output_directory", outputDirectory,
"output_file_suffix", "segmentation.tif",
"segmenter_file", segmenterFile,
"use_gpu", useGpu).get();
System.out.println("stop");
System.out.println(stopWatch);
System.exit(0);
}

private static void removeImageFromDirectory(String directory) {
File d = new File(directory);
FileFilter fileFilter = new WildcardFileFilter("*.tif");
File[] files = d.listFiles(fileFilter);
if (files == null)
return;
for (File file : files)
file.delete();
}
}