diff --git a/examples/evaluate_distributed.rs b/examples/evaluate_distributed.rs index b8d8a1d..bf3fdf1 100644 --- a/examples/evaluate_distributed.rs +++ b/examples/evaluate_distributed.rs @@ -12,13 +12,6 @@ use rlst::{ }; fn main() { - // Ensure that there is only one Rayon thread per process - - rayon::ThreadPoolBuilder::new() - .num_threads(1) - .build_global() - .unwrap(); - // Create the MPI communicator let universe = mpi::initialize().unwrap(); let world = universe.world(); @@ -62,6 +55,7 @@ fn main() { &targets, &charges, &mut result, + false, ); // We now check the result with an evaluation only on the first rank. diff --git a/src/traits.rs b/src/traits.rs index c1f457e..3c7a7bc 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -121,6 +121,9 @@ pub trait Kernel: Sync { // again once we move to the better C interface in `c-api-tools`. /// Distributed evaluation of a Green's function kernel. +/// +/// If `use_multithreaded` is set to true, the evaluation uses Rayon multi-threading on each rank. +/// Otherwise, the evaluation on each rank is single-threaded. #[cfg(feature = "mpi")] pub trait DistributedKernelEvaluator: Kernel { fn evaluate_distributed< @@ -135,6 +138,7 @@ pub trait DistributedKernelEvaluator: Kernel { targets: &DistributedVector<'_, TargetLayout, ::Real>, charges: &DistributedVector<'_, ChargeLayout, Self::T>, result: &mut DistributedVector<'_, ResultLayout, Self::T>, + use_multithreaded: bool, ) where Self::T: Equivalence, ::Real: Equivalence, @@ -155,7 +159,8 @@ pub trait DistributedKernelEvaluator: Kernel { // Check that the output vector has the correct size. assert_eq!( - targets.index_layout().number_of_local_indices(), + self.range_component_count(eval_type) + * targets.index_layout().number_of_local_indices(), 3 * result.index_layout().number_of_local_indices() ); @@ -186,13 +191,24 @@ pub trait DistributedKernelEvaluator: Kernel { root_process.broadcast_into(&mut root_charges.data_mut()[..]); // We now have the sources and charges on all ranks. We can now simply evaluate. - self.evaluate_mt( - eval_type, - &root_sources.data()[..], - targets.local().data(), - &root_charges.data()[..], - result.local_mut().data_mut(), - ); + + if use_multithreaded { + self.evaluate_mt( + eval_type, + &root_sources.data()[..], + targets.local().data(), + &root_charges.data()[..], + result.local_mut().data_mut(), + ); + } else { + self.evaluate_st( + eval_type, + &root_sources.data()[..], + targets.local().data(), + &root_charges.data()[..], + result.local_mut().data_mut(), + ); + } } } }