Skip to content

Commit

Permalink
update show function
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-shalloo committed Dec 19, 2024
1 parent 3255057 commit f26ebb0
Showing 1 changed file with 88 additions and 22 deletions.
110 changes: 88 additions & 22 deletions lasy/laser.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import numpy as np
from axiprop.lib import PropagatorFFT2, PropagatorResampling
from scipy.constants import c
from scipy.constants import c,epsilon_0

from lasy.utils.grid import Grid, time_axis_indx
from lasy.utils.laser_utils import (
normalize_energy,
normalize_peak_field_amplitude,
normalize_peak_intensity,
get_duration,
get_w0
)
from lasy.utils.openpmd_output import write_to_openpmd_file

from mpl_toolkits.axes_grid1 import make_axes_locatable


class Laser:
"""
Expand Down Expand Up @@ -391,46 +395,108 @@ def write_to_file(
)
self.output_iteration += 1

def show(self, **kw):
def show(self,show_intensity=False,**kw):
"""
Show a 2D image of the laser amplitude.
Show a 2D image of the laser amplitude or intensity.
Parameters
----------
show_intensity : bool
if False the laser amplitude is plotted
if True then the intensity of the laser is plotted along with lineouts
and a measure of the pulse duration and spot size
**kw : additional arguments to be passed to matplotlib's imshow command
"""
temporal_field = self.grid.get_temporal_field()

if show_intensity:
F = epsilon_0 * c /2 * np.abs(self.grid.get_temporal_field())**2 /1e4
cbar_label = r'I (W/cm$^2$)'
else:
F = np.abs(self.grid.get_temporal_field())
cbar_label = r'$|E_{envelope}|$ (V/m)'

# Calculate spatial scales for the axes
if self.grid.hi[0] > 1:
# scale is meters
spatial_scale = 1
spatial_unit = r'(m)'
elif self.grid.hi[0] > 1e-3:
# scale is millimeters
spatial_scale = 1e-3
spatial_unit = r'(mm)'
else:
# scale is microns
spatial_scale = 1e-6
spatial_unit = r'($\mu m$)'

# Calculate temporal scales for the axes
if self.grid.hi[-1] > 1e-9:
# scale is nanoseconds
temporal_scale = 1e-9
temporal_unit = r'(ns)'
elif self.grid.hi[-1] > 1e-12:
# scale is picoseconds
temporal_scale = 1e-12
temporal_unit = r'(ps)'
else:
# scale is femtoseconds
temporal_scale = 1e-15
temporal_unit = r'(fs)'


if self.dim == "rt":
# Show field in the plane y=0, above and below axis, with proper sign for each mode
E = [
F_plot = [
np.concatenate(
((-1.0) ** m * temporal_field[m, ::-1], temporal_field[m])
((-1.0) ** m * F[m, ::-1], F[m])
)
for m in self.grid.azimuthal_modes
]
E = sum(E) # Sum all the modes
F_plot = sum(F_plot) # Sum all the modes
extent = [
self.grid.lo[-1],
self.grid.hi[-1],
-self.grid.hi[0],
self.grid.hi[0],
self.grid.lo[-1]/temporal_scale,
self.grid.hi[-1]/temporal_scale,
-self.grid.hi[0]/spatial_scale,
self.grid.hi[0]/spatial_scale,
]

else:
# In 3D show an image in the xt plane
i_slice = int(temporal_field.shape[1] // 2)
E = temporal_field[:, i_slice, :]
i_slice = int(F.shape[1] // 2)
F_plot = F[:, i_slice, :]
extent = [
self.grid.lo[-1],
self.grid.hi[-1],
self.grid.lo[0],
self.grid.hi[0],
self.grid.lo[-1]/temporal_scale,
self.grid.hi[-1]/temporal_scale,
self.grid.lo[0]/spatial_scale,
self.grid.hi[0]/spatial_scale,
]

import matplotlib.pyplot as plt

plt.imshow(abs(E), extent=extent, aspect="auto", origin="lower", **kw)
cb = plt.colorbar()
cb.set_label("$|E_{envelope}|$ (V/m)")
plt.xlabel("t (s)")
plt.ylabel("x (m)")
fig,ax = plt.subplots()
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
im = ax.imshow(F_plot, extent=extent, cmap='Reds',aspect="auto", origin="lower", **kw)
cb = fig.colorbar(im,cax=cax)
cb.set_label(cbar_label)
ax.set_xlabel(r"t "+temporal_unit)
ax.set_ylabel(r"x "+spatial_unit)

if show_intensity:
# Create projected lineouts along time and space
temporal_lineout = np.sum(F_plot,axis=0)/np.sum(F_plot,axis=0).max()
ax.plot(self.grid.axes[-1]/temporal_scale,
0.15*temporal_lineout * (extent[3]-extent[2]) + extent[2],c=(.3,.3,.3))

spatial_lineout = np.sum(F_plot,axis=1)/np.sum(F_plot,axis=1).max()
ax.plot(0.15*spatial_lineout * (extent[1]-extent[0]) + extent[0],
np.linspace(extent[2],extent[3],F_plot.shape[0]),c=(.3,.3,.3))

# Get the pulse duration
tau = 2 * get_duration(self.grid, self.dim) /temporal_scale
ax.text(0.75,0.9,r'$\tau$ = %.2f '%(tau) +temporal_unit[1:-1] ,transform=ax.transAxes)

# Get the spot size
w0 = get_w0(self.grid, self.dim) / spatial_scale
ax.text(0.75,0.85,r'$w_0$ = %.2f '%(w0) +spatial_unit[1:-1] ,transform=ax.transAxes)

0 comments on commit f26ebb0

Please sign in to comment.