Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement 1D unwrapping of laser envelope phase #145

Merged
merged 2 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions wake_t/physics_models/laser/envelope_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from wake_t.utilities.numba import njit_serial
from .tdma import TDMA
from .utils import unwrap


@njit_serial(fastmath=True)
Expand Down Expand Up @@ -91,7 +92,7 @@ def evolve_envelope(

# Getting the phase of the envelope on axis.
if use_phase:
phases = np.angle(a[:, 0])
phases = unwrap(np.angle(a[:, 0]))

# Loop over z.
for j in range(nz - 1, -1, -1):
Expand All @@ -101,16 +102,6 @@ def evolve_envelope(
d_theta1 = phases[j + 1] - phases[j]
d_theta2 = phases[j + 2] - phases[j + 1]

# Prevent phase jumps bigger than 1.5*pi.
if d_theta1 < -1.5 * np.pi:
d_theta1 += 2 * np.pi
if d_theta2 < -1.5 * np.pi:
d_theta2 += 2 * np.pi
if d_theta1 > 1.5 * np.pi:
d_theta1 -= 2 * np.pi
if d_theta2 > 1.5 * np.pi:
d_theta2 -= 2 * np.pi

# Calculate D factor [Eq. (6)].
D_jkn = (1.5 * d_theta1 - 0.5 * d_theta2) * inv_dz

Expand Down
13 changes: 2 additions & 11 deletions wake_t/physics_models/laser/envelope_solver_non_centered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from wake_t.utilities.numba import njit_serial
from .tdma import TDMA
from .utils import unwrap


@njit_serial(fastmath=True)
Expand Down Expand Up @@ -91,7 +92,7 @@ def evolve_envelope_non_centered(

# Getting the phase of the envelope on axis.
if use_phase:
phases = np.angle(a[:, 0])
phases = unwrap(np.angle(a[:, 0]))

# Loop over z.
for j in range(nz - 1, -1, -1):
Expand All @@ -101,16 +102,6 @@ def evolve_envelope_non_centered(
d_theta1 = phases[j + 1] - phases[j]
d_theta2 = phases[j + 2] - phases[j + 1]

# Prevent phase jumps bigger than 1.5*pi.
if d_theta1 < -1.5 * np.pi:
d_theta1 += 2 * np.pi
if d_theta2 < -1.5 * np.pi:
d_theta2 += 2 * np.pi
if d_theta1 > 1.5 * np.pi:
d_theta1 -= 2 * np.pi
if d_theta2 > 1.5 * np.pi:
d_theta2 -= 2 * np.pi

# Calculate D factor [Eq. (6)].
D_jkn = (1.5 * d_theta1 - 0.5 * d_theta2) * inv_dz

Expand Down
59 changes: 59 additions & 0 deletions wake_t/physics_models/laser/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Utilities for the laser envelope solver."""
import numpy as np
from wake_t.utilities.numba import njit_serial


@njit_serial
def unwrap(p, discont=None, axis=-1, period=6.283185307179586):
"""Numba version of numpy.unwrap.

The implementation is taken from
https://github.com/numba/numba/blob/main/numba/np/arraymath.py,
which currently is not yet included in the latest Numba release.
"""
if axis != -1:
msg = 'Value for argument "axis" is not supported'
raise ValueError(msg)
# Flatten to a 2D array, keeping axis -1
p_init = np.asarray(p).astype(np.float64)
init_shape = p_init.shape
last_axis = init_shape[-1]
p_new = p_init.reshape((p_init.size // last_axis, last_axis))
# Manipulate discont and period
if discont is None:
discont = period / 2
interval_high = period / 2
boundary_ambiguous = True
interval_low = -interval_high

slice1 = (slice(1, None, None),)

# Work on each row separately
for i in range(p_init.size // last_axis):
row = p_new[i]
dd = np.diff(row)
ddmod = np.mod(dd - interval_low, period) + interval_low
if boundary_ambiguous:
ddmod = np.where(
(ddmod == interval_low) & (dd > 0),
interval_high,
ddmod
)
ph_correct = ddmod - dd

ph_correct = np.where(
np.array([abs(x) for x in dd]) < discont,
0,
ph_correct
)
ph_ravel = np.where(
np.array([abs(x) for x in dd]) < discont,
0,
ph_correct
)
ph_correct = np.reshape(ph_ravel, ph_correct.shape)
up = np.copy(row)
up[slice1] = row[slice1] + ph_correct.cumsum()
p_new[i] = up

return p_new.reshape(init_shape)
Loading