Skip to content

Commit

Permalink
Merge pull request #149 from AngelFP/feature/push_before_diags
Browse files Browse the repository at this point in the history
Push bunches before saving to diagnostics
  • Loading branch information
AngelFP authored Apr 30, 2024
2 parents e7e4624 + a7edced commit c75ce37
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 24 deletions.
4 changes: 2 additions & 2 deletions tests/test_active_plasma_lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_active_plasma_lens():
apl.track(bunch)
bunch_params = analyze_bunch(bunch)
gamma_x = bunch_params['gamma_x']
assert approx(gamma_x, rel=1e-10) == 92.38646379897074
assert approx(gamma_x, rel=1e-10) == 92.38407675999406


def test_active_plasma_lens_with_wakefields():
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_active_plasma_lens_with_wakefields():
# Analyze and check results.
bunch_params = analyze_bunch(bunch)
gamma_x = bunch_params['gamma_x']
assert approx(gamma_x, rel=1e-10) == 77.31995824746237
assert approx(gamma_x, rel=1e-10) == 77.32021188373825


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion tests/test_custom_blowout.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_custom_blowout_wakefield(make_plots=False):

bunch_params = analyze_bunch(bunch)
rel_ene_sp = bunch_params['rel_ene_spread']
assert approx(rel_ene_sp, rel=1e-10) == 0.21192488237458038
assert approx(rel_ene_sp, rel=1e-10) == 0.21192494153185745

if make_plots:
# Analyze bunch evolution.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fluid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_fluid_model(plot=False):

# Check final parameters.
ene_sp = params_evolution['rel_ene_spread'][-1]
assert approx(ene_sp, rel=1e-10) == 0.024179998095119972
assert approx(ene_sp, rel=1e-10) == 0.024157374564016194

# Quick plot of results.
if plot:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_multibunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def test_multibunch_plasma_simulation(plot=False):
# Assert final parameters are correct.
final_energy_driver = driver_params['avg_ene'][-1]
final_energy_witness = witness_params['avg_ene'][-1]
assert approx(final_energy_driver, rel=1e-10) == 1700.3927190416732
assert approx(final_energy_witness, rel=1e-10) == 636.330857261392
assert approx(final_energy_driver, rel=1e-10) == 1700.3843657635728
assert approx(final_energy_witness, rel=1e-10) == 636.3260426124102

if plot:
z = driver_params['prop_dist'] * 1e2
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ramps.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_downramp():
downramp.track(bunch)
bunch_params = analyze_bunch(bunch)
beta_x = bunch_params['beta_x']
assert beta_x == 0.009750309290619276
assert beta_x == 0.009750308724018872


def test_upramp():
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_upramp():
downramp.track(bunch)
bunch_params = analyze_bunch(bunch)
beta_x = bunch_params['beta_x']
assert beta_x == 0.0007631600676104024
assert beta_x == 0.000763155045965493


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simple_blowout.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_simple_blowout_wakefield(make_plots=False):

bunch_params = analyze_bunch(bunch)
rel_ene_sp = bunch_params['rel_ene_spread']
assert approx(rel_ene_sp, rel=1e-10) == 0.3637648484576557
assert approx(rel_ene_sp, rel=1e-10) == 0.3637651769109033

if make_plots:
# Analyze bunch evolution.
Expand Down
18 changes: 17 additions & 1 deletion wake_t/beamline_elements/field_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def track(
self,
bunches: Optional[Union[ParticleBunch, List[ParticleBunch]]] = [],
opmd_diag: Optional[Union[bool, OpenPMDDiagnostics]] = False,
diag_dir: Optional[str] = None
diag_dir: Optional[str] = None,
push_bunches_before_diags: Optional[bool] = True,
) -> Union[List[ParticleBunch], List[List[ParticleBunch]]]:
"""
Track bunch through element.
Expand All @@ -83,6 +84,20 @@ def track(
Directory into which the openPMD output will be written. By default
this is a 'diags' folder in the current directory. Only needed if
`opmd_diag=True`.
push_bunches_before_diags : bool, optional
Whether to push the bunches before saving them to the diagnostics.
Since the time step of the diagnostics can be different from that
of the bunches, it could happen that the bunches appear in the
diagnostics as they were at the last push, but not at the actual
time of the diagnostics. Setting this parameter to ``True``
(default) ensures that an additional push is given to all bunches
to evolve them to the diagnostics time before saving.
This additional push will always have a time step smaller than
the the time step of the bunch, so it has no detrimental impact
on the accuracy of the simulation. However, it could make
convergence studies more difficult to interpret,
since the number of pushes will depend on `n_diags`. Therefore,
it is exposed as an option so that it can be disabled if needed.
Returns
-------
Expand Down Expand Up @@ -119,6 +134,7 @@ def track(
opmd_diags=opmd_diag,
bunch_pusher=self.bunch_pusher,
auto_dt_bunch_f=self.auto_dt_bunch,
push_bunches_before_diags=push_bunches_before_diags,
section_name=self.name
)

Expand Down
92 changes: 78 additions & 14 deletions wake_t/tracking/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@ class Tracker():
bunch_pusher : str, optional
The particle pusher used to evolve the bunches. Possible values
are `'boris'` or `'rk4'`.
push_bunches_before_diags : bool, optional
Whether to push the bunches before saving them to the diagnostics.
Since the time step of the diagnostics can be different from that
of the bunches, it could happen that the bunches appear in the
diagnostics as they were at the last push, but not at the actual
time of the diagnostics. Setting this parameter to ``True``
(default) ensures that an additional push is given to all bunches
to evolve them to the diagnostics time before saving.
This additional push will always have a time step smaller than
the the time step of the bunch, so it has no detrimental impact
on the accuracy of the simulation. However, it could make
convergence studies more difficult to interpret,
since the number of pushes will depend on `n_diags`. Therefore,
it is exposed as an option so that it can be disabled if needed.
section_name : str, optional
Name of the section to be tracked. This will be appended to the
beginning of the progress bar.
Expand All @@ -78,6 +92,7 @@ def __init__(
opmd_diags: Optional[OpenPMDDiagnostics] = None,
auto_dt_bunch_f: Optional[Callable[[ParticleBunch], float]] = None,
bunch_pusher: Optional[Literal['boris', 'rk4']] = 'boris',
push_bunches_before_diags: Optional[bool] = True,
section_name: Optional[str] = 'Simulation'
) -> None:
self.t_final = t_final
Expand All @@ -88,6 +103,7 @@ def __init__(
self.n_diags = n_diags
self.auto_dt_bunch_f = auto_dt_bunch_f
self.bunch_pusher = bunch_pusher
self.push_bunches_before_diags = push_bunches_before_diags
self.section_name = section_name

# Get all numerical fields and their time steps.
Expand Down Expand Up @@ -187,27 +203,35 @@ def do_tracking(self) -> List[List[ParticleBunch]]:

# If next object is a ParticleBunch, update it.
if isinstance(obj_next, ParticleBunch):
obj_next.evolve(
self.fields, t_current, dt_next, self.bunch_pusher)
# Update the time step if set to `'auto'`.
if obj_next in self.auto_dt_bunches:
dt_objects[i_next] = self.auto_dt_bunch_f(obj_next)
# Determine if this was the last push.
final_push = np.float32(t_next) == np.float32(self.t_final)
# Determine if next push brings the bunch beyond `t_final`.
next_push_beyond_final_time = (
t_next + dt_objects[i_next] > self.t_final)
# Make sure the last push of the bunch advances it to exactly
# `t_final`.
if not final_push and next_push_beyond_final_time:
dt_objects[i_next] = self.t_final - t_next
self.evolve_bunch(
bunch=obj_next,
t_current=t_current,
t_next=t_next,
dt_next=dt_next,
i_next=i_next,
dt_objects=dt_objects,
)

# If next object is a NumericalField, update it.
elif isinstance(obj_next, NumericalField):
obj_next.update(self.bunches)

# If next object are the diagnostics, generate them.
elif obj_next == 'diags':
# Evolve all bunches to the diagnostics time.
if self.push_bunches_before_diags:
for i, obj in enumerate(self.objects_to_track):
if isinstance(obj, ParticleBunch):
dt_bunch = t_next - t_objects[i]
self.evolve_bunch(
bunch=obj,
t_current=t_objects[i],
t_next=t_next,
dt_next=dt_bunch,
i_next=i,
dt_objects=dt_objects,
)
t_objects[i] += dt_bunch
self.generate_diagnostics()

# Advance current time of the update object.
Expand All @@ -230,6 +254,46 @@ def do_tracking(self) -> List[List[ParticleBunch]]:

return self.bunch_list

def evolve_bunch(
self,
bunch: ParticleBunch,
t_current: float,
t_next: float,
dt_next: float,
i_next: int,
dt_objects: List,
):
"""Evolve particle bunch to next time step.
Parameters
----------
bunch : ParticleBunch
The particle bunch to evolve.
t_current : float
The current time of the simulation.
t_next : float
The time of the next step.
dt_next : float
The time step by which to advance the bunch.
i_next : int
The index of the bunch in the list of objects to track.
dt_objects : List
The time steps of all objects to track.
"""
bunch.evolve(self.fields, t_current, dt_next, self.bunch_pusher)
# Update the time step if set to `'auto'`.
if bunch in self.auto_dt_bunches:
dt_objects[i_next] = self.auto_dt_bunch_f(bunch)
# Determine if this was the last push.
final_push = np.float32(t_next) == np.float32(self.t_final)
# Determine if next push brings the bunch beyond `t_final`.
next_push_beyond_final_time = (
t_next + dt_objects[i_next] > self.t_final)
# Make sure the last push of the bunch advances it to exactly
# `t_final`.
if not final_push and next_push_beyond_final_time:
dt_objects[i_next] = self.t_final - t_next

def generate_diagnostics(self) -> None:
"""Generate tracking diagnostics."""
# Make copy of current bunches and store in output list.
Expand Down

0 comments on commit c75ce37

Please sign in to comment.