Skip to content

Commit

Permalink
Small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Dec 16, 2024
1 parent e50bc78 commit b3ea2e5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
8 changes: 1 addition & 7 deletions examples/evaluate_distributed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@ use rlst::{
};

fn main() {
// Ensure that there is only one Rayon thread per process

rayon::ThreadPoolBuilder::new()
.num_threads(1)
.build_global()
.unwrap();

// Create the MPI communicator
let universe = mpi::initialize().unwrap();
let world = universe.world();
Expand Down Expand Up @@ -62,6 +55,7 @@ fn main() {
&targets,
&charges,
&mut result,
false,
);

// We now check the result with an evaluation only on the first rank.
Expand Down
32 changes: 24 additions & 8 deletions src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ pub trait Kernel: Sync {
// again once we move to the better C interface in `c-api-tools`.

/// Distributed evaluation of a Green's function kernel.
///
/// If `use_multithreaded` is set to true, the evaluation uses Rayon multi-threading on each rank.
/// Otherwise, the evaluation on each rank is single-threaded.
#[cfg(feature = "mpi")]
pub trait DistributedKernelEvaluator: Kernel {
fn evaluate_distributed<
Expand All @@ -135,6 +138,7 @@ pub trait DistributedKernelEvaluator: Kernel {
targets: &DistributedVector<'_, TargetLayout, <Self::T as RlstScalar>::Real>,
charges: &DistributedVector<'_, ChargeLayout, Self::T>,
result: &mut DistributedVector<'_, ResultLayout, Self::T>,
use_multithreaded: bool,
) where
Self::T: Equivalence,
<Self::T as RlstScalar>::Real: Equivalence,
Expand All @@ -155,7 +159,8 @@ pub trait DistributedKernelEvaluator: Kernel {

// Check that the output vector has the correct size.
assert_eq!(
targets.index_layout().number_of_local_indices(),
self.range_component_count(eval_type)
* targets.index_layout().number_of_local_indices(),
3 * result.index_layout().number_of_local_indices()
);

Expand Down Expand Up @@ -186,13 +191,24 @@ pub trait DistributedKernelEvaluator: Kernel {
root_process.broadcast_into(&mut root_charges.data_mut()[..]);

// We now have the sources and charges on all ranks. We can now simply evaluate.
self.evaluate_mt(
eval_type,
&root_sources.data()[..],
targets.local().data(),
&root_charges.data()[..],
result.local_mut().data_mut(),
);

if use_multithreaded {
self.evaluate_mt(
eval_type,
&root_sources.data()[..],
targets.local().data(),
&root_charges.data()[..],
result.local_mut().data_mut(),
);
} else {
self.evaluate_st(
eval_type,
&root_sources.data()[..],
targets.local().data(),
&root_charges.data()[..],
result.local_mut().data_mut(),
);
}
}
}
}
Expand Down

0 comments on commit b3ea2e5

Please sign in to comment.