Skip to content

Commit

Permalink
use StructArrays, remove unnecessary structs, add quaternion compatib…
Browse files Browse the repository at this point in the history
…ility, example Drift tests
  • Loading branch information
mattsignorelli committed Nov 9, 2024
1 parent 4d9160a commit b990add
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 125 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,23 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
ReferenceFrameRotations = "74f56ac7-18b3-5285-802d-d4bd4f104033"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
Distributions = "0.25"
GTPSA = "1.2.0"
ReferenceFrameRotations = "3"
StaticArrays = "1"
StructArrays = "0.6.18"
Unitful = "1.21.0"
julia = "1.9"

[extras]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"

[targets]
test = ["Test", "Distributions"]
test = ["Test", "Distributions", "JET", "GTPSA"]
89 changes: 3 additions & 86 deletions src/BeamTracking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module BeamTracking
using GTPSA,
ReferenceFrameRotations,
StaticArrays,
StructArrays,
Distributions,
Unitful

Expand All @@ -28,96 +29,12 @@ export Beam,
sinhc,
track!

# SoA ----------------------------------
Base.@kwdef struct Coords{T}
x::Vector{T}
px::Vector{T}
y::Vector{T}
py::Vector{T}
z::Vector{T}
pz::Vector{T}
end

function Coords(
n::Integer;
d_x::Distribution=Normal(0,0), d_px::Distribution=Normal(0,0),
d_y::Distribution=Normal(0,0), d_py::Distribution=Normal(0,0),
d_z::Distribution=Normal(0,0), d_pz::Distribution=Normal(0,0)
)
x = rand(d_x , n)
px = rand(d_px, n)
y = rand(d_y , n)
py = rand(d_py, n)
z = rand(d_z , n)
pz = rand(d_pz, n)

return Coords(x, px, y, py, z, pz)
end



struct Beam{S,T}
species::Species
beta_gamma_0::S
z::Coords{T}
end

function Beam(
n::Integer; species::Species=Species("electron"), beta_gamma_0=1,
d_x::Distribution=Normal(0,0), d_px::Distribution=Normal(0,0),
d_y::Distribution=Normal(0,0), d_py::Distribution=Normal(0,0),
d_z::Distribution=Normal(0,0), d_pz::Distribution=Normal(0,0),
)

coords = Coords(n; d_x=d_x, d_px=d_px, d_y=d_y, d_py=d_py, d_z=d_z, d_pz=d_pz)

return Beam(species, beta_gamma_0, coords)
end

# Creates a Beam as identity GTPSA
function Beam(d::Descriptor; species::Species=Species("electron"), beta_gamma_0=1)
GTPSA.numvars(d) == 6 || error("Invalid GTPSA Descriptor! Number of variables must be equal to 6.")
z = vars(d)
return Beam(species, beta_gamma_0, Coords([z[1]], [z[2]], [z[3]], [z[4]], [z[5]], [z[6]]))
end

#Create a Beam with single particle for testing
function Beam(
;species::Species=Species("electron"), beta_gamma_0=1,
x=0.0, px=0.0, y=0.0, py=0.0, z=0.0, pz=0.0,
)

coords = Coords(x=[x], px=[px], y=[y], py=[py], z=[z], pz=[pz])

return Beam(species, beta_gamma_0, coords)
end

#AoS from SoA for single particle
# Extract the phase space coord of a particle in a beam
Base.@kwdef struct Coord{T} <: FieldVector{6, T} # Just like Coords but only 1 Coord
x::T = 0.0
px::T = 0.0
y::T = 0.0
py::T = 0.0
z::T = 0.0
pz::T = 0.0
end

struct Particle{S,T}
species::Species
beta_gamma_0::S
z::Coord{T}
end

function Particle(b::Beam, n::Integer=1)
z = b.z
coord = Coord(z.x[n],z.px[n],z.y[n],z.py[n],z.z[n],z.pz[n])

return Particle(b.species, b.beta_gamma_0, coord)
end
include("types.jl")

# Empty tracking method ----------------
track!(beamf::Beam, ::Nothing, beami::Beam) = (beamf.z .= beami.z; return)
track!(beam::Beam, ::Nothing) = beam

# --------------------------------------
include("utils.jl")
Expand Down
45 changes: 22 additions & 23 deletions src/linear/Linear.jl → src/Linear/Linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,31 @@ end


"""
track!(beamf::Beam, ele::Linear.Drift, beami::Beam) -> beamf
track!(beam::Beam, ele::Linear.Drift) -> beam
Routine to tracking through a drift using the approximation and
including higher-order energy effects.
### Arguments
- `beam` -- Input/output beam before/after tracking through
- `ele` -- `Drift` type element
- `beamf` -- Output beam after tracking through
- `beami` -- Input beam before tracking through
"""
function track!(beamf::Beam, ele::Linear.Drift, beami::Beam)
@assert !(beamf === beami) "Aliasing beamf === beami not allowed!"
function track!(beam::Beam, ele::Linear.Drift)
L = ele.L
zi = beami.z
zf = beamf.z
v = beam.v

gamma_ref = sr_gamma(beam.beta_gamma_ref)

@FastGTPSA! begin
@. zf.x = zi.x + zi.px * L
@. zf.px = zi.px
@. zf.y = zi.y + zi.py * L
@. zf.py = zi.py
@. zf.z = zi.z + zf.pz * L
@. zf.pz = zi.pz
@. v.x = v.x + v.px * L
@. v.px = v.px
@. v.y = v.y + v.py * L
@. v.py = v.py
@. v.z = v.z + v.pz*L/gamma_ref^2
@. v.pz = v.pz
end

return beamf
return beam
end


Expand All @@ -72,12 +71,12 @@ Routine to linearly tracking through a quadrupole
"""
function track!(beamf::Beam, ele::Linear.Quadrupole, beami::Beam)
@assert !(beamf === beami) "Aliasing beamf === beami not allowed!"
zi = beami.z
zf = beamf.z
zi = beami.vec
zf = beamf.vec
L = ele.L
q = chargeof(beami.species)

k1 = ele.B1 / brho(massof(beami.species),beami.beta_gamma_0,q) #quadrupole strengh
k1 = ele.B1 / brho(massof(beami.species),beami.beta_gamma_ref,q) #quadrupole strengh

k = sqrt(abs(k1))

Expand Down Expand Up @@ -121,13 +120,13 @@ Routine to linearly tracking through a Sector Magnet
"""
function track!(beamf::Beam, ele::Linear.SBend, beami::Beam)
@assert !(beamf === beami) "Aliasing beamf === beami not allowed!"
zi = beami.z
zf = beamf.z
zi = beami.vec
zf = beamf.vec
L = ele.L
q = chargeof(beami.species)

#curvature of B field
k = ele.B0 / brho(massof(beami.species),beami.beta_gamma_0,q)
k = ele.B0 / brho(massof(beami.species),beami.beta_gamma_ref,q)
kl = k * L


Expand All @@ -147,11 +146,11 @@ Routine to linearly tracking through a Combined Magnet
"""
function track!(beamf::Beam, ele::Linear.Combined, beami::Beam)
@assert !(beamf === beami) "Aliasing beamf === beami not allowed!"
zi = beami.z
zf = beamf.z
zi = beami.vec
zf = beamf.vec
L = ele.L

brho = brho(massof(beami.species),beami.beta_gamma_0,q)
brho = brho(massof(beami.species),beami.beta_gamma_ref,q)
ka = ele.B0 / brho #curvature
k1 = ele.B1 / brho #quad strength

Expand Down
12 changes: 6 additions & 6 deletions src/MatrixKick/MatrixKick.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ function track!(beamf::Beam, ele::Drift, beami::Beam)
@assert !(beamf === beami) "Aliasing beamf === beami not allowed!"
@assert !(beamf.species === beami.species) "Input species must be equal to output!"
L = ele.L
zi = beami.z
zf = beamf.z
zi = beami.vec
zf = beamf.vec

tilde_m = mass(beami.species) / pc_ref(beami.species, beami.beta_0)

Expand All @@ -52,8 +52,8 @@ track quadrupole
function track!(beamf::Beam, ele::Quadrupole, beami::Beam)
@assert !(beamf === beami) "Aliasing beamf === beami not allowed!"
L = ele.L
zi = beami.z
zf = beamf.z
zi = beami.vec
zf = beamf.vec

tilde_m = massof(ele.species_ref) / ele.pc_ref
β0 = ele.pc_ref / ele.E_tot_ref
Expand All @@ -74,8 +74,8 @@ end # function track!(::Quadrupole)
track "matrix part" of quadrupole
"""
function trackQuadMf!(beamf::Beam, beami::Beam, s::Float64, kappa_num::Float64)
zi = beami.z
zf = beamf.z
zi = beami.vec
zf = beamf.vec

tilde_m = mass(beami.species) / pc_ref(beami.species, beami.beta_0)

Expand Down
6 changes: 3 additions & 3 deletions src/aapc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Temp code for AAPC
Base.@kwdef struct Species
name::String
end
massof(::Species) = 510998.95069
chargeof(::Species) = -1
massof(::Species) = 510998.95069*u"eV/c^2"
chargeof(::Species) = -1*u"q"

c_light = 2.99792458e8
c_light = 2.99792458e8*u"m/s"
112 changes: 112 additions & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#=
Beam, Coord, and Particle type definitions. StructArrays is
used to handle an internal SoA layout of memory which also
allows us to mutate, so both the phase space Coord struct
(defined here) and Quaternion struct (defined in
ReferenceFrameRotations) can be static.
Particle struct simply goes from StructArrays SoA to AoS.
=#


# Static phase space coordinate vector
Base.@kwdef struct Coord{T} <: FieldVector{6, T}
x::T = 0.0
px::T = 0.0
y::T = 0.0
py::T = 0.0
z::T = 0.0
pz::T = 0.0
end

# Static quaternion type defined by ReferenceFrameRotations

struct Beam{T<:StructVector{<:Coord}, U<:Union{Nothing, StructVector{<:Quaternion}}}
species::Species
beta_gamma_ref::Float64
v::T
q::U
end

# Initialize a Beam with a single particle
function Beam(; species::Species=Species("electron"), beta_gamma_ref=1.0, spin::Union{Bool,Nothing}=nothing,
x::Number=0.0, px::Number=0.0, y::Number=0.0, py::Number=0.0, z::Number=0.0, pz::Number=0.0)

T = promote_type(typeof(x), typeof(px), typeof(y),typeof(py), typeof(z), typeof(pz))
v = StructArray{Coord{T}}((T[x], T[px], T[y], T[py], T[z], T[pz]))

if !isnothing(spin)
if spin == true
# Use "one(first(..))" etc for GTPSA - descriptor is implicit
q0 = T[one(first(v.x))]
q1 = T[zero(first(v.x))]
q2 = T[zero(first(v.x))]
q3 = T[zero(first(v.x))]
q = StructArray{Quaternion{T}}((q0, q1, q2, q3))
else
error("For no spin tracking, please omit the spin kwarg or set spin=nothing. This is to ensure type stability.")
end
else
q = nothing
end

return Beam(species, Float64(beta_gamma_ref), v, q)
end


# Initialize Beam as identity GTPSA
# Creates a Beam as identity GTPSA
function Beam(d::Descriptor; species::Species=Species("electron"), beta_gamma_ref=1.0, spin::Union{Bool,Nothing}=nothing)
GTPSA.numvars(d) == 6 || error("Invalid GTPSA Descriptor! Number of variables must be equal to 6.")
v = vars(d)
return Beam(species=species, beta_gamma_ref=beta_gamma_ref, spin=spin, x=v[1], px=v[2], y=v[3], py=v[4], z=v[5], pz=v[6])
end


# Initialize a Beam given some distributions in each coordinate
function Beam(n::Integer; species::Species=Species("electron"),
beta_gamma_ref=1.0, spin::Union{Bool,Nothing}=nothing,
d_x::Distribution=Normal(0,0), d_px::Distribution=Normal(0,0),
d_y::Distribution=Normal(0,0), d_py::Distribution=Normal(0,0),
d_z::Distribution=Normal(0,0), d_pz::Distribution=Normal(0,0) )

x = rand(d_x , n)
px = rand(d_px, n)
y = rand(d_y , n)
py = rand(d_py, n)
z = rand(d_z , n)
pz = rand(d_pz, n)

v = StructArray{Coord{eltype(x)}}((x, px, y, py, z, pz))

if !isnothing(spin)
if spin == true
q0 = ones(eltype(x), n)
q1 = zeros(eltype(x), n)
q2 = zeros(eltype(x), n)
q3 = zeros(eltype(x), n)
q = StructArray{Quaternion{eltype(x)}}((q0, q1, q2, q3))
else
error("For no spin tracking, please omit the spin kwarg or set spin=nothing. This is to ensure type stability.")
end
else
q = nothing
end

return Beam(species, Float64(beta_gamma_ref), v, q)
end

struct Particle{T,U<:Union{Nothing,Quaternion{T}}}
species::Species
beta_gamma_ref::Float64
v::Coord{T}
q::U
end

function Particle(beam::Beam, idx::Integer=1)
v = beam.v[idx] # StructArrays handles this!
q = isnothing(beam.q) ? nothing : beam.q[idx]
return Particle(beam.species, beam.beta_gamma_ref, v, q)
end
Loading

0 comments on commit b990add

Please sign in to comment.