diff --git a/wake_t/beamline_elements/field_element.py b/wake_t/beamline_elements/field_element.py index 31159d9..891bae6 100644 --- a/wake_t/beamline_elements/field_element.py +++ b/wake_t/beamline_elements/field_element.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, List, Literal +from typing import Optional, Union, Callable, List, Literal import scipy.constants as ct @@ -64,7 +64,7 @@ def __init__( n_out: Optional[int] = 1, name: Optional[str] = 'field element', fields: Optional[List[Field]] = [], - auto_dt_bunch: Optional[str] = None, + auto_dt_bunch: Optional[Callable[[ParticleBunch], float]] = None, push_bunches_before_diags: Optional[bool] = True, ) -> None: self.length = length diff --git a/wake_t/beamline_elements/plasma_stage.py b/wake_t/beamline_elements/plasma_stage.py index fd58c86..16fc96e 100644 --- a/wake_t/beamline_elements/plasma_stage.py +++ b/wake_t/beamline_elements/plasma_stage.py @@ -9,6 +9,7 @@ import wake_t.physics_models.plasma_wakefields as wf from wake_t.fields.base import Field from .field_element import FieldElement +from wake_t.particles.particle_bunch import ParticleBunch DtBunchType = Union[float, str, List[Union[float, str]]] @@ -50,6 +51,10 @@ class PlasmaStage(FieldElement): stage. A list of values can also be provided. In this case, the list should have the same order as the list of bunches given to the ``track`` method. + auto_dt_bunch : callable, optional + Function used to determine the adaptive time step for bunches in + which the time step is set to ``'auto'``. The function should take + solely a ``ParticleBunch`` as argument. push_bunches_before_diags : bool, optional Whether to push the bunches before saving them to the diagnostics. Since the time step of the diagnostics can be different from that @@ -93,6 +98,7 @@ def __init__( wakefield_model: Optional[str] = 'simple_blowout', bunch_pusher: Optional[Literal['boris', 'rk4']] = 'boris', dt_bunch: Optional[DtBunchType] = 'auto', + auto_dt_bunch: Optional[Callable[[ParticleBunch], float]] = None, push_bunches_before_diags: Optional[bool] = True, n_out: Optional[int] = 1, name: Optional[str] = 'Plasma stage', @@ -106,6 +112,9 @@ def __init__( if self.wakefield is not None: fields.append(self.wakefield) fields.extend(self.external_fields) + self.auto_dt_bunch = auto_dt_bunch + if self.auto_dt_bunch is None: + self.auto_dt_bunch = self._get_optimized_dt super().__init__( length=length, dt_bunch=dt_bunch, @@ -113,7 +122,7 @@ def __init__( n_out=n_out, name=name, fields=fields, - auto_dt_bunch=self._get_optimized_dt, + auto_dt_bunch=self.auto_dt_bunch, push_bunches_before_diags=push_bunches_before_diags, )