From 0425a8c925e7300433e0103bdfd1be1fd43e88ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81ngel=20Ferran=20Pousa?= Date: Fri, 10 May 2024 16:06:25 +0200 Subject: [PATCH] Separate q into q and w, make q_species a float Potentially, this fixes a buf when `free_electrons_per_ion > `1` due to an inconsistency between `q_species_i` and `q`. --- .../qs_rz_baxevanis_ion/b_theta.py | 29 ++--- .../qs_rz_baxevanis_ion/plasma_particles.py | 102 +++++++++--------- .../qs_rz_baxevanis_ion/plasma_push/ab2.py | 13 ++- .../psi_and_derivatives.py | 61 ++++++----- .../qs_rz_baxevanis_ion/utils.py | 20 ++-- 5 files changed, 121 insertions(+), 104 deletions(-) diff --git a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/b_theta.py b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/b_theta.py index 9afb361..6a2ff11 100644 --- a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/b_theta.py +++ b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/b_theta.py @@ -9,7 +9,7 @@ @njit_serial() def calculate_b_theta_at_particles( - r_e, pr_e, q_e, q_center_e, gamma_e, + r_e, pr_e, w_e, w_center_e, gamma_e, q_e, r_i, ion_motion, psi_e, dr_psi_e, dxi_psi_e, @@ -67,9 +67,11 @@ def calculate_b_theta_at_particles( Parameters ---------- - r_e, pr_e, q_e, gamma_e : ndarray - Radial position, momentum, charge and Lorenz factor of the plasma + r_e, pr_e, w_e, w_center_e, gamma_e : ndarray + Radial position, momentum, weight and Lorenz factor of the plasma electrons. + q_e : float + Charge of the plasma electron species. r_i : ndarray Radial position of the plasma ions. i_sort_e, i_sort_i : ndarray @@ -100,6 +102,9 @@ def calculate_b_theta_at_particles( Arrays where azimuthal magnetic field at the plasma electrons and ions will be stored. """ + # Only the magnetic field from the electrons is computed, so the equations + # below assume that q_i/m_i = 1. + # Calculate the A_i, B_i, C_i coefficients in Eq. (26). calculate_ABC( r_e, pr_e, gamma_e, @@ -108,8 +113,8 @@ def calculate_b_theta_at_particles( ) # Calculate the a_i, b_i coefficients in Eq. (27). - calculate_KU(r_e, q_e, q_center_e, A, K, U) - calculate_ai_bi_from_axis(r_e, q_e, q_center_e, A, B, C, K, U, a_0, a, b) + calculate_KU(r_e, q_e, w_e, w_center_e, A, K, U) + calculate_ai_bi_from_axis(r_e, q_e, w_e, w_center_e, A, B, C, K, U, a_0, a, b) # Calculate b_theta at plasma particles. calculate_b_theta_at_particle_centers(a, b, r_e, b_t_e) @@ -178,7 +183,7 @@ def calculate_b_theta_at_particle_centers(a, b, r, b_theta): @njit_serial(error_model='numpy') -def calculate_ai_bi_from_axis(r, q, q_center, A, B, C, K, U, a_0, a, b): +def calculate_ai_bi_from_axis(r, q, w, w_center, A, B, C, K, U, a_0, a, b): """ Calculate the values of a_i and b_i which are needed to determine b_theta at any r position. @@ -202,8 +207,8 @@ def calculate_ai_bi_from_axis(r, q, q_center, A, B, C, K, U, a_0, a, b): # Iterate over particles for i in range(i_start, n_part): r_i = r[i] - q_i = q[i] - q_center_i = q_center[i] + q_i = q * w[i] + q_center_i = q * w_center[i] A_i = A[i] B_i = B[i] C_i = C[i] @@ -277,7 +282,7 @@ def calculate_ABC(r, pr, gamma, psi, dr_psi, dxi_psi, b_theta_0, nabla_a2, A, B, C): """Calculate the A_i, B_i and C_i coefficients of the linear system. - The coefficients are missing the q_i term. They are multiplied by it + The coefficients are missing the q_i * w_i term. They are multiplied by it in following functions. """ n_part = r.shape[0] @@ -312,7 +317,7 @@ def calculate_ABC(r, pr, gamma, psi, dr_psi, dxi_psi, b_theta_0, @njit_serial(error_model='numpy') -def calculate_KU(r, q, q_center, A, K, U): +def calculate_KU(r, q, w, w_center, A, K, U): """Calculate the K_i and U_i values of the linear system.""" n_part = r.shape[0] @@ -322,8 +327,8 @@ def calculate_KU(r, q, q_center, A, K, U): for i in range(n_part): r_i = r[i] - q_i = q[i] - q_center_i = q_center[i] + q_i = q * w[i] + q_center_i = q * w_center[i] A_i = A[i] A_inv_r_i = A_i / r_i A_r_i = A_i * r_i diff --git a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/plasma_particles.py b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/plasma_particles.py index cb21f61..d7ce275 100644 --- a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/plasma_particles.py +++ b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/plasma_particles.py @@ -136,15 +136,15 @@ def initialize(self): pr = np.zeros(self.n_elec) pz = np.zeros(self.n_elec) gamma = np.ones(self.n_elec) - q = dr_p * r * self.radial_density(r) - q_center = q / 2 - dr_p ** 2 / 8 - q *= self.free_electrons_per_ion - q_center *= self.free_electrons_per_ion - m_e = np.ones(self.n_elec) - m_i = np.ones(self.n_elec) * self.ion_mass / ct.m_e - q_species_e = np.ones(self.n_elec) - q_species_i = - np.ones(self.n_elec) * self.free_electrons_per_ion tag = np.arange(self.n_elec, dtype=np.int32) + w = dr_p * r * self.radial_density(r) + w_center = w / 2 - dr_p ** 2 / 8 + + # Charge and mass of the macroparticles of each species. + self.m_elec = self.free_electrons_per_ion + self.m_ion = self.ion_mass / ct.m_e + self.q_species_elec = self.free_electrons_per_ion + self.q_species_ion = - self.free_electrons_per_ion # Combine arrays of both species. self.r = np.concatenate((r, r)) @@ -152,10 +152,8 @@ def initialize(self): self.pr = np.concatenate((pr, pr)) self.pz = np.concatenate((pz, pz)) self.gamma = np.concatenate((gamma, gamma)) - self.q = np.concatenate((q, -q)) - self.q_center = np.concatenate((q_center, -q_center)) - self.q_species = np.concatenate((q_species_e, q_species_i)) - self.m = np.concatenate((m_e, m_i)) + self.w = np.concatenate((w, w)) + self.w_center = np.concatenate((w_center, w_center)) self.r_to_x = np.ones(self.n_part, dtype=np.int32) self.tag = np.concatenate((tag, tag)) @@ -201,8 +199,8 @@ def sort(self): self.pr_elec, self.pz_elec, self.gamma_elec, - self.q_elec, - self.q_center_elec, + self.w_elec, + self.w_center_elec, self.r_to_x_elec, self.tag_elec, self._dr_e, @@ -217,8 +215,8 @@ def sort(self): self.pr_ion, self.pz_ion, self.gamma_ion, - self.q_ion, - self.q_center_ion, + self.w_ion, + self.w_center_ion, self.r_to_x_ion, self.tag_ion, self._dr_i, @@ -267,10 +265,10 @@ def calculate_fields(self): log(self.r_ion, self.log_r_ion) calculate_psi_and_derivatives_at_particles( - self.r_elec, self.log_r_elec, self.pr_elec, self.q_elec, - self.q_center_elec, - self.r_ion, self.log_r_ion, self.pr_ion, self.q_ion, - self.q_center_ion, + self.r_elec, self.log_r_elec, self.pr_elec, self.w_elec, + self.w_center_elec, self.q_species_elec, + self.r_ion, self.log_r_ion, self.pr_ion, self.w_ion, + self.w_center_ion, self.q_species_ion, self.ion_motion, self.ions_computed, self._sum_1_e, self._sum_2_e, self._sum_3_e, self._sum_1_i, self._sum_2_i, self._sum_3_i, @@ -278,21 +276,20 @@ def calculate_fields(self): self._psi_i, self._dr_psi_i, self._dxi_psi_i, self._psi, self._dr_psi, self._dxi_psi ) + update_gamma_and_pz( + self.gamma_elec, self.pz_elec, self.pr_elec, + self._a2_e, self._psi_e, self.q_species_elec, self.m_elec + ) if self.ion_motion: update_gamma_and_pz( - self.gamma, self.pz, self.pr, - self._a2, self._psi, self.q_species, self.m - ) - else: - update_gamma_and_pz( - self.gamma_elec, self.pz_elec, self.pr_elec, - self._a2_e, self._psi_e, self.q_species_elec, self.m_elec + self.gamma_ion, self.pz_ion, self.pr_ion, + self._a2_i, self._psi_i, self.q_species_ion, self.m_ion ) check_gamma(self.gamma_elec, self.pz_elec, self.pr_elec, self.max_gamma) calculate_b_theta_at_particles( - self.r_elec, self.pr_elec, self.q_elec, self.q_center_elec, - self.gamma_elec, + self.r_elec, self.pr_elec, self.w_elec, self.w_center_elec, + self.gamma_elec, self.q_species_elec, self.r_ion, self.ion_motion, self._psi_e, self._dr_psi_e, self._dxi_psi_e, @@ -323,18 +320,18 @@ def calculate_b_theta_at_grid(self, r_eval, b_theta): def evolve(self, dxi): """Evolve plasma particles to next longitudinal slice.""" + evolve_plasma_ab2( + dxi, self.r_elec, self.pr_elec, self.gamma_elec, self.m_elec, + self.q_species_elec, self.r_to_x_elec, + self._nabla_a2_e, self._b_t_0_e, + self._b_t_e, self._psi_e, self._dr_psi_e, self._dr_e, self._dpr_e + ) if self.ion_motion: evolve_plasma_ab2( - dxi, self.r, self.pr, self.gamma, self.m, self.q_species, - self.r_to_x, self._nabla_a2, self._b_t_0, self._b_t, - self._psi, self._dr_psi, self._dr, self._dpr - ) - else: - evolve_plasma_ab2( - dxi, self.r_elec, self.pr_elec, self.gamma_elec, self.m_elec, - self.r_to_x_elec, self.q_species_elec, - self._nabla_a2_e, self._b_t_0_e, - self._b_t_e, self._psi_e, self._dr_psi_e, self._dr, self._dpr + dxi, self.r_ion, self.pr_ion, self.gamma_ion, self.m_ion, + self.q_species_ion, self.r_to_x_ion, + self._nabla_a2_i, self._b_t_0_i, + self._b_t_i, self._psi_i, self._dr_psi_i, self._dr_i, self._dpr_i ) if self.store_history: @@ -344,7 +341,13 @@ def evolve(self, dxi): def calculate_weights(self): """Calculate the plasma density weights of each particle.""" - calculate_rho(self.q, self.pz, self.gamma, self._rho) + calculate_rho( + self.q_species_elec, self.w_elec, self.pz_elec, self.gamma_elec, + self._rho_e) + if self.ion_motion or not self.ions_computed: + calculate_rho( + self.q_species_ion, self.w_ion, self.pz_ion, self.gamma_ion, + self._rho_i) def deposit_rho(self, rho, rho_e, rho_i, r_fld, nr, dr): @@ -364,7 +367,9 @@ def deposit_rho(self, rho, rho_e, rho_i, r_fld, nr, dr): def deposit_chi(self, chi, r_fld, nr, dr): """Deposit plasma susceptibility on a grid slice.""" - calculate_chi(self.q_elec, self.pz_elec, self.gamma_elec, self._chi_e) + calculate_chi( + self.q_species_elec, self.w_elec, self.pz_elec, self.gamma_elec, + self._chi_e) deposit_plasma_particles( self.r_elec, self._chi_e, r_fld[0], nr, dr, chi, self.shape ) @@ -464,10 +469,8 @@ def _make_species_views(self): self.pr_elec = self.pr[:self.n_elec] self.pz_elec = self.pz[:self.n_elec] self.gamma_elec = self.gamma[:self.n_elec] - self.q_elec = self.q[:self.n_elec] - self.q_center_elec = self.q_center[:self.n_elec] - self.q_species_elec = self.q_species[:self.n_elec] - self.m_elec = self.m[:self.n_elec] + self.w_elec = self.w[:self.n_elec] + self.w_center_elec = self.w_center[:self.n_elec] self.r_to_x_elec = self.r_to_x[:self.n_elec] self.tag_elec = self.tag[:self.n_elec] @@ -477,10 +480,8 @@ def _make_species_views(self): self.pr_ion = self.pr[self.n_elec:] self.pz_ion = self.pz[self.n_elec:] self.gamma_ion = self.gamma[self.n_elec:] - self.q_ion = self.q[self.n_elec:] - self.q_center_ion = self.q_center[self.n_elec:] - self.q_species_ion = self.q_species[self.n_elec:] - self.m_ion = self.m[self.n_elec:] + self.w_ion = self.w[self.n_elec:] + self.w_center_ion = self.w_center[self.n_elec:] self.r_to_x_ion = self.r_to_x[self.n_elec:] self.tag_ion = self.tag[self.n_elec:] @@ -493,8 +494,11 @@ def _make_species_views(self): self._b_t_e = self._b_t[:self.n_elec] self._b_t_i = self._b_t[self.n_elec:] self._b_t_0_e = self._b_t_0[:self.n_elec] + self._b_t_0_i = self._b_t_0[self.n_elec:] self._nabla_a2_e = self._nabla_a2[:self.n_elec] + self._nabla_a2_i = self._nabla_a2[self.n_elec:] self._a2_e = self._a2[:self.n_elec] + self._a2_i = self._a2[self.n_elec:] self._sum_1_e = self._sum_1[:self.n_elec + 1] self._sum_2_e = self._sum_2[:self.n_elec + 1] self._sum_1_i = self._sum_1[self.n_elec + 1:] diff --git a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/plasma_push/ab2.py b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/plasma_push/ab2.py index aabf797..62d3309 100644 --- a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/plasma_push/ab2.py +++ b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/plasma_push/ab2.py @@ -20,10 +20,11 @@ def evolve_plasma_ab2( ---------- dxi : float Longitudinal step. - r, pr, gamma, m, q, r_to_x : ndarray - Radial position, radial momentum, Lorentz factor, mass and charge of - the plasma particles as well an array that keeps track of axis crosses - to convert from r to x. + r, pr, gamma, r_to_x : ndarray + Radial position, radial momentum, Lorentz factor as well an array that + keeps track of axis crosses to convert from r to x. + m, q : float + Mass and charge of the plasma species. nabla_a2, b_theta_0, b_theta, psi, dr_psi : ndarray Arrays with the value of the fields at the particle positions. dr, dpr : ndarray @@ -63,6 +64,8 @@ def calculate_derivatives( pr, gamma : ndarray Arrays containing the radial momentum and Lorentz factor of the plasma particles. + m, q : float + Mass and charge of the plasma species. b_theta_0 : ndarray Array containing the value of the azimuthal magnetic field from the beam distribution at the position of each plasma particle. @@ -80,8 +83,8 @@ def calculate_derivatives( radial momentum will be stored. """ # Calculate derivatives of r and pr. + q_over_m = q / m for i in range(pr.shape[0]): - q_over_m = q[i] / m[i] inv_psi_i = 1. / (1. + psi[i] * q_over_m) dpr[i] = (gamma[i] * dr_psi[i] * inv_psi_i - b_theta_bar[i] diff --git a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/psi_and_derivatives.py b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/psi_and_derivatives.py index c25ff9a..36c0376 100644 --- a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/psi_and_derivatives.py +++ b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/psi_and_derivatives.py @@ -11,8 +11,8 @@ @njit_serial(fastmath=True, error_model="numpy") def calculate_psi_and_derivatives_at_particles( - r_e, log_r_e, pr_e, q_e, q_center_e, - r_i, log_r_i, pr_i, q_i, q_center_i, + r_e, log_r_e, pr_e, w_e, w_center_e, q_e, + r_i, log_r_i, pr_i, w_i, w_center_i, q_i, ion_motion, calculate_ion_sums, sum_1_e, sum_2_e, sum_3_e, sum_1_i, sum_2_i, sum_3_i, @@ -24,13 +24,16 @@ def calculate_psi_and_derivatives_at_particles( Parameters ---------- - r_e, log_r_e, pr_e, q_e, q_center_e : ndarray - Radial position (and log), momentum, charge (and central charge) + r_e, log_r_e, pr_e, w_e, w_center_e : ndarray + Radial position (and log), momentum, weight (and central weight) of the plasma electrons. - - r_i, log_r_i, pr_i, q_i, q_center_i, dr_p_i : ndarray - Radial position (and log), momentum, charge (and central charge) + q_e : float + Charge of the plasma electron species. + r_i, log_r_i, pr_i, w_i, w_center_i, dr_p_i : ndarray + Radial position (and log), momentum, weight (and central weight) of the plasma ions. + q_i : float + Charge of the plasma ion species. ion_motion : bool Whether the ions can move. If `True`, the potential and its derivatives will also be calculated at the ions. @@ -53,11 +56,11 @@ def calculate_psi_and_derivatives_at_particles( """ # Calculate cumulative sums 1 and 2 (Eqs. (29) and (31)). - calculate_cumulative_sum_1(q_e, q_center_e, sum_1_e) - calculate_cumulative_sum_2(log_r_e, q_e, q_center_e, sum_2_e) + calculate_cumulative_sum_1(q_e, w_e, w_center_e, sum_1_e) + calculate_cumulative_sum_2(q_e, log_r_e, w_e, w_center_e, sum_2_e) if ion_motion or not calculate_ion_sums: - calculate_cumulative_sum_1(q_i, q_center_i, sum_1_i) - calculate_cumulative_sum_2(log_r_i, q_i, q_center_i, sum_2_i) + calculate_cumulative_sum_1(q_i, w_i, w_center_i, sum_1_i) + calculate_cumulative_sum_2(q_i, log_r_i, w_i, w_center_i, sum_2_i) # Calculate the psi and dr_psi background at the neighboring points. # For the electrons, compute the psi and dr_psi due to the ions at @@ -83,9 +86,9 @@ def calculate_psi_and_derivatives_at_particles( check_psi_derivative(dr_psi) # Calculate cumulative sum 3 (Eq. (32)). - calculate_cumulative_sum_3(r_e, pr_e, q_e, q_center_e, psi_e, sum_3_e) + calculate_cumulative_sum_3(q_e, r_e, pr_e, w_e, w_center_e, psi_e, sum_3_e) if ion_motion or not calculate_ion_sums: - calculate_cumulative_sum_3(r_i, pr_i, q_i, q_center_i, psi_i, sum_3_i) + calculate_cumulative_sum_3(q_i, r_i, pr_i, w_i, w_center_i, psi_i, sum_3_i) # Calculate the dxi_psi background at the neighboring points. # For the electrons, compute the psi and dr_psi due to the ions at @@ -104,50 +107,50 @@ def calculate_psi_and_derivatives_at_particles( @njit_serial(fastmath=True) -def calculate_cumulative_sum_1(q, q_center, sum_1_arr): +def calculate_cumulative_sum_1(q, w, w_center, sum_1_arr): """Calculate the cumulative sum in Eq. (29).""" sum_1 = 0. - for i in range(q.shape[0]): - q_i = q[i] - q_center_i = q_center[i] + for i in range(w.shape[0]): + w_i = w[i] + w_center_i = w_center[i] # Integrate up to particle centers. - sum_1_arr[i] = sum_1 + q_center_i + sum_1_arr[i] = sum_1 + q * w_center_i # And add all charge for next iteration. - sum_1 += q_i + sum_1 += q * w_i # Total sum after last particle. sum_1_arr[-1] = sum_1 @njit_serial(fastmath=True) -def calculate_cumulative_sum_2(log_r, q, q_center, sum_2_arr): +def calculate_cumulative_sum_2(q, log_r, w, w_center, sum_2_arr): """Calculate the cumulative sum in Eq. (31).""" sum_2 = 0. for i in range(log_r.shape[0]): log_r_i = log_r[i] - q_i = q[i] - q_center_i = q_center[i] + w_i = w[i] + w_center_i = w_center[i] # Integrate up to particle centers. - sum_2_arr[i] = sum_2 + q_center_i * log_r_i + sum_2_arr[i] = sum_2 + q * w_center_i * log_r_i # And add all charge for next iteration. - sum_2 += q_i * log_r_i + sum_2 += q * w_i * log_r_i # Total sum after last particle. sum_2_arr[-1] = sum_2 @njit_serial(fastmath=True, error_model="numpy") -def calculate_cumulative_sum_3(r, pr, q, q_center, psi, sum_3_arr): +def calculate_cumulative_sum_3(q, r, pr, w, w_center, psi, sum_3_arr): """Calculate the cumulative sum in Eq. (32).""" sum_3 = 0. for i in range(r.shape[0]): r_i = r[i] pr_i = pr[i] - q_i = q[i] - q_center_i = q_center[i] + w_i = w[i] + w_center_i = w_center[i] psi_i = psi[i] # Integrate up to particle centers. - sum_3_arr[i] = sum_3 + (q_center_i * pr_i) / (r_i * (1 + psi_i)) + sum_3_arr[i] = sum_3 + (q * w_center_i * pr_i) / (r_i * (1 + psi_i)) # And add all charge for next iteration. - sum_3 += (q_i * pr_i) / (r_i * (1 + psi_i)) + sum_3 += (q * w_i * pr_i) / (r_i * (1 + psi_i)) # Total sum after last particle. sum_3_arr[-1] = sum_3 diff --git a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/utils.py b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/utils.py index bb82d42..6db637a 100644 --- a/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/utils.py +++ b/wake_t/physics_models/plasma_wakefields/qs_rz_baxevanis_ion/utils.py @@ -15,23 +15,23 @@ def log(input, output): @njit_serial(error_model="numpy") -def calculate_chi(q, pz, gamma, chi): +def calculate_chi(q, w, pz, gamma, chi): """Calculate the contribution of each particle to `chi`.""" - for i in range(q.shape[0]): - q_i = q[i] + for i in range(w.shape[0]): + w_i = w[i] pz_i = pz[i] inv_gamma_i = 1. / gamma[i] - chi[i] = q_i / (1. - pz_i * inv_gamma_i) * inv_gamma_i + chi[i] = q * w_i / (1. - pz_i * inv_gamma_i) * inv_gamma_i @njit_serial(error_model="numpy") -def calculate_rho(q, pz, gamma, chi): +def calculate_rho(q, w, pz, gamma, rho): """Calculate the contribution of each particle to `rho`.""" - for i in range(q.shape[0]): - q_i = q[i] + for i in range(w.shape[0]): + w_i = w[i] pz_i = pz[i] inv_gamma_i = 1. / gamma[i] - chi[i] = q_i / (1. - pz_i * inv_gamma_i) + rho[i] = q * w_i / (1. - pz_i * inv_gamma_i) @njit_serial() @@ -175,10 +175,12 @@ def update_gamma_and_pz(gamma, pz, pr, a2, psi, q, m): pr, a2, psi : ndarray Arrays containing the radial momentum of the particles and the value of a2 and psi at the position of the particles. + q, m : float + Charge and mass of the plasma species. """ + q_over_m = q / m for i in range(pr.shape[0]): - q_over_m = q[i] / m[i] psi_i = psi[i] * q_over_m pz_i = ( (1 + pr[i] ** 2 + q_over_m ** 2 * a2[i] - (1 + psi_i) ** 2) /