Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General numeric type support #97

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/instrumented-jetreco.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ function parse_command_line(args)
arg_type = RecoStrategy.Strategy
default = RecoStrategy.Best

"--type", "-T"
help = """Numerical type to use for the reconstruction (Float32, Float64)"""
arg_type = Symbol
default = :Float64

"--nsamples", "-m"
help = "Number of measurement points to acquire."
arg_type = Int
Expand Down Expand Up @@ -331,9 +336,9 @@ function main()

# Try to read events into the correct type!
if JetReconstruction.is_ee(args[:algorithm])
jet_type = EEjet
jet_type = EEjet{eval(args[:type])}
else
jet_type = PseudoJet
jet_type = PseudoJet{eval(args[:type])}
end
events::Vector{Vector{jet_type}} = read_final_state_particles(args[:file],
maxevents = args[:maxevents],
Expand Down
10 changes: 6 additions & 4 deletions examples/jetreco.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,19 @@ function main()
logger = ConsoleLogger(stdout, Logging.Info)
global_logger(logger)
# Try to read events into the correct type!
if JetReconstruction.is_ee(args[:algorithm])
jet_type = EEjet
# If we don't have an algorithm we default to PseudoJet
if !isnothing(args[:algorithm])
JetReconstruction.is_ee(args[:algorithm])
jet_type = EEjet{Float64}
graeme-a-stewart marked this conversation as resolved.
Show resolved Hide resolved
else
jet_type = PseudoJet
jet_type = PseudoJet{Float64}
graeme-a-stewart marked this conversation as resolved.
Show resolved Hide resolved
end
events::Vector{Vector{jet_type}} = read_final_state_particles(args[:file],
maxevents = args[:maxevents],
skipevents = args[:skip],
T = jet_type)
if isnothing(args[:algorithm]) && isnothing(args[:power])
@warn "Neither algorithm nor power specified, defaulting to AntiKt"
@warn "Neither algorithm nor power specified, defaulting to pp event AntiKt"
args[:algorithm] = JetAlgorithm.AntiKt
end
jet_process(events, distance = args[:distance], algorithm = args[:algorithm],
Expand Down
7 changes: 7 additions & 0 deletions src/ClusterSequence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ ClusterSequence(algorithm::JetAlgorithm.Algorithm, p::Real, R::Float64, strategy
Qtot)
end

function ClusterSequence{T}(algorithm::JetAlgorithm.Algorithm, p::Real, R::Float64,
strategy::RecoStrategy.Strategy, jets::Vector{T}, history,
Qtot) where {T <: FourMomentum}
ClusterSequence{T}(algorithm, Float64(p), R, strategy, jets, length(jets), history,
Qtot)
end

"""
add_step_to_history!(clusterseq::ClusterSequence, parent1, parent2, jetp_index, dij)

Expand Down
23 changes: 15 additions & 8 deletions src/EEAlgorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,15 @@ function ee_genkt_algorithm(particles::AbstractArray{T, 1}; p = 1, R = 4.0,
recombination_particles = copy(particles)
sizehint!(recombination_particles, length(particles) * 2)
else
recombination_particles = EEjet[]
# We don't really know what element type we have here, so we need to
# drill down to a component to get that underlying type
ParticleType = typeof(px(particles[1]))
recombination_particles = EEjet{ParticleType}[]
sizehint!(recombination_particles, length(particles) * 2)
for i in eachindex(particles)
push!(recombination_particles,
EEjet(px(particles[i]), py(particles[i]), pz(particles[i]),
energy(particles[i])))
EEjet{ParticleType}(px(particles[i]), py(particles[i]), pz(particles[i]),
energy(particles[i])))
end
end

Expand All @@ -264,15 +267,18 @@ end

This function is the actual implementation of the e+e- jet clustering algorithm.
"""
function _ee_genkt_algorithm(; particles::Vector{EEjet}, p = 1, R = 4.0,
function _ee_genkt_algorithm(; particles::Vector{EEjet{T}}, p = 1, R = 4.0,
algorithm::JetAlgorithm.Algorithm = JetAlgorithm.Durham,
recombine = +)
recombine = +) where {T <: Real}
# Bounds
N::Int = length(particles)

# R squared
R2 = R^2

# Numerical type?
ParticleType = T

# Constant factor for the dij metric and the beam distance function
if algorithm == JetAlgorithm.Durham
dij_factor = 2.0
Expand All @@ -289,14 +295,15 @@ function _ee_genkt_algorithm(; particles::Vector{EEjet}, p = 1, R = 4.0,
# For optimised reconstruction generate an SoA containing the necessary
# jet information and populate it accordingly
# We need N slots for this array
eereco = StructArray{EERecoJet}(undef, N)
eereco = StructArray{EERecoJet{ParticleType}}(undef, N)
fill_reco_array!(eereco, particles, R2, p)

# Setup the initial history and get the total energy
history, Qtot = initial_history(particles)

clusterseq = ClusterSequence(algorithm, p, R, RecoStrategy.N2Plain, particles, history,
Qtot)
clusterseq = ClusterSequence{EEjet{ParticleType}}(algorithm, p, R, RecoStrategy.N2Plain,
particles, history,
Qtot)

# Run over initial pairs of jets to find nearest neighbours
get_angular_nearest_neighbours!(eereco, algorithm, dij_factor)
Expand Down
156 changes: 130 additions & 26 deletions src/EEjet.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,123 @@
"""
struct EEjet
struct EEjet{T <: Real} <: FourMomentum

The `EEjet` struct is a 4-momentum object used for the e+e jet reconstruction routines.
The `EEjet` struct is a 4-momentum object used for the e+e jet reconstruction
routines. Internal fields are used to track the reconstruction and to cache
values needed during the execution of the algorithm.

# Fields
- `px::Float64`: The x-component of the jet momentum.
- `py::Float64`: The y-component of the jet momentum.
- `pz::Float64`: The z-component of the jet momentum.
- `E::Float64`: The energy of the jet.
- `px::T`: The x-component of the jet momentum.
- `py::T`: The y-component of the jet momentum.
- `pz::T`: The z-component of the jet momentum.
- `E::T`: The energy of the jet.
- `_cluster_hist_index::Int`: The index of the cluster histogram.
- `_p2::Float64`: The squared momentum of the jet.
- `_inv_p::Float64`: The inverse momentum of the jet.
- `_p2::T`: The squared momentum of the jet.
- `_inv_p::T`: The inverse momentum of the jet.

# Type Parameters
- `T <: Real`: The type of the numerical values.
"""
mutable struct EEjet <: FourMomentum
px::Float64
py::Float64
pz::Float64
E::Float64
_p2::Float64
_inv_p::Float64
mutable struct EEjet{T <: Real} <: FourMomentum
px::T
py::T
pz::T
E::T
_p2::T
_inv_p::T
_cluster_hist_index::Int
end

function EEjet(px::Real, py::Real, pz::Real, E::Real, _cluster_hist_index::Int)
"""
Base.eltype(::Type{EEjet{T}}) where T

Return the element type of the `EEjet` struct.
"""
Base.eltype(::Type{EEjet{T}}) where {T} = T

Check warning on line 35 in src/EEjet.jl

View check run for this annotation

Codecov / codecov/patch

src/EEjet.jl#L35

Added line #L35 was not covered by tests

"""
EEjet(px::T, py::T, pz::T, E::T, _cluster_hist_index::Integer) where {T <: Real}

Constructs an `EEjet` object with the given momentum components `px`, `py`,
`pz`, energy `E`, and cluster histogram index `_cluster_hist_index`.

The constructed EEjet object will be parametrised by the type `T`.

# Arguments
- `px::T`: The x-component of the momentum.
- `py::T`: The y-component of the momentum.
- `pz::T`: The z-component of the momentum.
- `E::T`: The energy of the jet.
- `_cluster_hist_index::Integer`: The index of the cluster histogram.

# Returns
- The initialised `EEjet` object.

# Note
- `T` must be a subtype of `Real`.
- The `@muladd` macro is used to perform fused multiply-add operations for
computing `p2`.
- The `@fastmath` macro is used to allow the compiler to perform optimizations
for computing `inv_p`.
"""
function EEjet(px::T, py::T, pz::T, E::T, _cluster_hist_index::Integer) where {T <: Real}
@muladd p2 = px * px + py * py + pz * pz
inv_p = @fastmath 1.0 / sqrt(p2)
EEjet(px, py, pz, E, p2, inv_p, _cluster_hist_index)
EEjet{T}(px, py, pz, E, p2, inv_p, _cluster_hist_index)
end

EEjet(px::Real, py::Real, pz::Real, E::Real) = EEjet(px, py, pz, E, 0)
"""
EEjet(px::T, py::T, pz::T, E::T) where {T <: Real}

Constructs an `EEjet` object with the given momentum components `px`, `py`,
`pz`, energy `E`, and the cluster histogram index set to zero.

The constructed EEjet object will be parametrised by the type `T`.

# Arguments
- `px::T`: The x-component of the momentum.
- `py::T`: The y-component of the momentum.
- `pz::T`: The z-component of the momentum.
- `E::T`: The energy of the jet.

# Returns
- The initialised `EEjet` object.
"""
EEjet(px::T, py::T, pz::T, E::T) where {T <: Real} = EEjet(px, py, pz, E, 0)

"""
EEjet{U}(px::T, py::T, pz::T, E::T) where {T <: Real, U <: Real}

Constructs an `EEjet` object with conversion of the given momentum components
(`px`, `py`, `pz`) and energy (`E`) from type `T` to type `U`.

# Arguments
- `px::T`: The x-component of the momentum.
- `py::T`: The y-component of the momentum.
- `pz::T`: The z-component of the momentum.
- `E::T`: The energy.

# Type Parameters
- `T <: Real`: The type of the input momentum components and energy.
- `U <: Real`: The type to which the input values will be converted

# Returns
An `EEjet` object with the momentum components and energy parametrised to type
`U`.
"""
EEjet{U}(px::T, py::T, pz::T, E::T) where {T <: Real, U <: Real} = EEjet(U(px), U(py),
U(pz), U(E), 0)

"""
EEjet(pj::PseudoJet) -> EEjet

Constructs an `EEjet` object from a given `PseudoJet` object `pj`.

# Arguments
- `pj::PseudoJet`: A `PseudoJet` object used to create the `EEjet`.

# Returns
- An `EEjet` object initialized with the same properties of the given `PseudoJet`.
"""
EEjet(pj::PseudoJet) = EEjet(px(pj), py(pj), pz(pj), energy(pj), cluster_hist_index(pj))

p2(eej::EEjet) = eej._p2
Expand Down Expand Up @@ -87,15 +175,31 @@
" cluster_hist_index: ", eej._cluster_hist_index, ")")
end

# Optimised reconstruction struct for e+e jets
"""
mutable struct EERecoJet{T <: Real}

Optimised struct for e+e jets reconstruction, to be used with StructArrays.

# Fields
- `index::Int`: The index of the jet.
- `nni::Int`: The nearest neighbour index.
- `nndist::T`: The distance to the nearest neighbour.
- `dijdist::T`: The distance between jets.
- `nx::T`: The x-component of the jet's momentum.
- `ny::T`: The y-component of the jet's momentum.
- `nz::T`: The z-component of the jet's momentum.
- `E2p::T`: The energy raised to the power of 2p for this jet.

mutable struct EERecoJet
# Type Parameters
- `T <: Real`: The type of the numerical values.
"""
mutable struct EERecoJet{T <: Real}
index::Int
nni::Int
nndist::Float64
dijdist::Float64
nx::Float64
ny::Float64
nz::Float64
E2p::Float64
nndist::T
dijdist::T
nx::T
ny::T
nz::T
E2p::T
end
30 changes: 19 additions & 11 deletions src/PlainAlgo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,16 @@ function plain_jet_reconstruct(particles::AbstractArray{T, 1}; p::Union{Real, No
# Integer p if possible
p = (round(p) == p) ? Int(p) : p

if T == PseudoJet
if T isa PseudoJet
# recombination_particles will become part of the cluster sequence, so size it for
# the starting particles and all N recombinations
recombination_particles = copy(particles)
sizehint!(recombination_particles, length(particles) * 2)
else
recombination_particles = PseudoJet[]
# We don't really know what element type we have here, so we need to
# drill down to a component to get that underlying type
ParticleType = typeof(px(particles[1]))
recombination_particles = PseudoJet{ParticleType}[]
sizehint!(recombination_particles, length(particles) * 2)
for i in eachindex(particles)
push!(recombination_particles,
Expand Down Expand Up @@ -278,23 +281,26 @@ generalised k_t algorithm.
- `clusterseq`: The resulting `ClusterSequence` object representing the
reconstructed jets.
"""
function _plain_jet_reconstruct(; particles::Vector{PseudoJet}, p = -1, R = 1.0,
function _plain_jet_reconstruct(; particles::Vector{PseudoJet{T}}, p = -1, R = 1.0,
algorithm::JetAlgorithm.Algorithm = JetAlgorithm.AntiKt,
recombine = +)
recombine = +) where {T <: Real}
# Bounds
N::Int = length(particles)
# Parameters
R2 = R^2

# Numerical type for this reconstruction
ParticleType = T

# Optimised compact arrays for determining the next merge step
# We make sure these arrays are type stable - have seen issues where, depending on the values
# returned by the methods, they can become unstable and performance degrades
kt2_array::Vector{Float64} = pt2.(particles) .^ p
phi_array::Vector{Float64} = phi.(particles)
rapidity_array::Vector{Float64} = rapidity.(particles)
kt2_array::Vector{ParticleType} = pt2.(particles) .^ p
phi_array::Vector{ParticleType} = phi.(particles)
rapidity_array::Vector{ParticleType} = rapidity.(particles)
nn::Vector{Int} = Vector(1:N) # nearest neighbours
nndist::Vector{Float64} = fill(float(R2), N) # geometric distances to the nearest neighbour
nndij::Vector{Float64} = zeros(N) # dij metric distance
nndist::Vector{ParticleType} = fill(float(R2), N) # geometric distances to the nearest neighbour
nndij::Vector{ParticleType} = zeros(N) # dij metric distance

# Maps index from the compact array to the clusterseq jet vector
clusterseq_index::Vector{Int} = collect(1:N)
Expand All @@ -304,8 +310,10 @@ function _plain_jet_reconstruct(; particles::Vector{PseudoJet}, p = -1, R = 1.0,
# Current implementation mutates the particles vector, so need to copy it
# for the cluster sequence (there is too much copying happening, so this
# needs to be rethought and reoptimised)
clusterseq = ClusterSequence(algorithm, p, R, RecoStrategy.N2Plain, particles, history,
Qtot)
clusterseq = ClusterSequence{PseudoJet{ParticleType}}(algorithm, p, R,
RecoStrategy.N2Plain, particles,
history,
Qtot)

# Initialize nearest neighbours
@simd for i in 1:N
Expand Down
Loading
Loading