Skip to content

Commit 128aab1

Browse files
committed
Calculation of identical functions with different coefficiants calculated in broadcast for efficiency
1 parent 3be63a1 commit 128aab1

File tree

2 files changed

+76
-21
lines changed

2 files changed

+76
-21
lines changed

src/combination.jl

+15-10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@ function calculate_covariancematrix(
2323
return covariance
2424
end
2525

26+
function eval_block_array(params,BA::Array{ObservableBlock})
27+
vcat(eval_block.(Ref(params),BA)...)
28+
end
29+
30+
function eval_block(params,B::ObservableBlock)
31+
if(B.ObsandCoeffs.Coeff == [[] for i in 1:length(B.ObsandCoeffs.Coeff)])
32+
vcat(B.f.(Ref(params))...)
33+
else
34+
vcat(B.f.(Ref(params),B.ObsandCoeffs.Coeff)...)
35+
end
36+
end
2637

2738

2839
function combinemeasurements(
@@ -35,16 +46,10 @@ function combinemeasurements(
3546
result::Float64=0.0
3647

3748
nmeas=length(m.measurement_values)
38-
39-
for i in 1:nmeas
40-
r1 = m.measurement_values[i] - m.observable_functions[m.measurement_observables[i]].f.obj.x(parameters)[1]
41-
for j in (i+1):nmeas
42-
r2 = m.measurement_values[j] - m.observable_functions[m.measurement_observables[j]].f.obj.x(parameters)[1]
43-
result += r1 * invcov[i,j] * r2
44-
end
45-
result += 0.5 * r1 * invcov[i,i] * r1
46-
end
4749

48-
final_result::Float64 = -result #- log((2*π)^(0.5*nmeas) * sqrtdetcov) TODO:delete factor?
50+
r1 = eval_block_array(parameters,m.observable_blocks) - m.measurement_values
51+
52+
final_result = -0.5 * transpose(r1)* invcov * r1
53+
4954
return final_result
5055
end

src/datahandling.jl

+61-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ export Observable
55
export Uncertainties
66
export Correlation
77
export createmodel
8+
export ObservableBlock
89

910
using FunctionWrappers
1011
import FunctionWrappers: FunctionWrapper
@@ -16,10 +17,16 @@ struct RealRealFunc
1617
end
1718
(cb::RealRealFunc)(v) = cb.f(v)
1819

20+
struct ObservableBlock
21+
f::Function
22+
ObsandCoeffs::NamedTuple
23+
end
24+
1925

2026
struct Model
2127
observable_names::Vector{String}
2228
observable_functions::Vector{RealRealFunc}
29+
observable_blocks::Array{ObservableBlock}
2330

2431
measurement_names::Vector{String}
2532
measurement_observables::Vector{Int}
@@ -33,13 +40,22 @@ struct Model
3340
end
3441

3542

43+
3644
struct Observable
3745
name::String
3846
func::Function
47+
coeff::Array{Real}
3948
# TODO min::Real
4049
# TODO max::Real
4150
end
4251

52+
function Observable(
53+
name::String,
54+
func::Function
55+
)
56+
57+
Observable(name, func, [])
58+
end
4359

4460
struct Measurement
4561
name::String
@@ -85,14 +101,17 @@ function createmodel(
85101
correlations::AbstractArray{Correlation}
86102
)
87103

88-
observable_names, observable_functions = createobservables(observables)
104+
observable_names, observable_functions, unique_funcs = createobservables(observables)
89105

90106
measurement_names, measurement_observables, measurement_values, active_measurements = createmeasurements(measurements, observable_names)
91107

108+
observable_block = createblock(unique_funcs,observables,measurement_observables)
109+
92110
uncertainty_names, uncertainties, correlationmatrices = createuncertainties(measurements, correlations, active_measurements)
93111

94112
Model(observable_names,
95113
observable_functions,
114+
observable_block,
96115
measurement_names,
97116
measurement_observables,
98117
measurement_values,
@@ -102,25 +121,55 @@ function createmodel(
102121
correlationmatrices)
103122
end
104123

124+
function createblock(unique_funcs::Vector{Function},observables::AbstractArray{Observable},measuredobs::Vector{Int})
125+
126+
nunique_funcs = length(unique_funcs)
127+
blocks = Array{ObservableBlock}(undef,nunique_funcs)
128+
129+
130+
for i_unique_function in 1:nunique_funcs
131+
nobs = sum([observables[i].func == unique_funcs[i_unique_function] for i in measuredobs])
132+
names = Vector{String}(undef, nobs)
133+
coeffs = Vector{Array{Real}}(undef,nobs)
134+
i_obs = 1
135+
for i in measuredobs
136+
if(observables[i].func == unique_funcs[i_unique_function])
137+
names[i_obs] = observables[i].name
138+
coeffs[i_obs] = observables[i].coeff
139+
i_obs += 1
140+
end
141+
end
142+
blocks[i_unique_function] = ObservableBlock(unique_funcs[i_unique_function],(Names = names, Coeff = coeffs))
143+
end
144+
blocks
145+
end
105146

106147

107148
function createobservables(observables::AbstractArray{Observable})
108149
nobs = length(observables)
109150

110151
names = Vector{String}(undef, nobs)
111152
funcs = Vector{RealRealFunc}(undef, nobs)
153+
112154

113-
for i in 1:nobs
114-
names[i] = observables[i].name
115-
funcs[i] = RealRealFunc(observables[i].func)
155+
unique_funcs = convert(Array{Function,1},unique([observables[i].func for i in 1:nobs]))
156+
157+
for j in unique_funcs
158+
for k in 1:nobs
159+
if observables[k].func == j
160+
names[k] = observables[k].name
161+
funcs[k] = RealRealFunc(observables[k].func)
162+
end
163+
end
116164
end
165+
117166

118167
duplicates = findfirstduplicate(names)
119168
if(duplicates[1])
120169
throw(ArgumentError("Observable with the name \"" * duplicates[2] * "\" already exists."))
121170
end
122171

123-
names, funcs
172+
names, funcs, unique_funcs
124173
end
125174

126175

@@ -140,12 +189,7 @@ function createmeasurements(
140189

141190
nmeas = length(measurements)
142191

143-
nactives=0
144-
for i in 1:nmeas
145-
if(measurements[i].activity)
146-
nactives += 1
147-
end
148-
end
192+
nactives=sum([measurements[i].activity for i in 1:nmeas])
149193

150194
names = Vector{String}(undef, nactives)
151195
measuredobs = Vector{Int}(undef, nactives)
@@ -163,6 +207,12 @@ function createmeasurements(
163207
end
164208
end
165209

210+
sort_pattern = sortperm(measuredobs)
211+
names = names[sort_pattern]
212+
values = values[sort_pattern]
213+
actives = actives[sort_pattern]
214+
sort!(measuredobs)
215+
166216
names, measuredobs, values, actives
167217
end
168218

0 commit comments

Comments
 (0)