Skip to content

Commit 35d4887

Browse files
committed
[hist] Implement weighted filling of RHistStats
1 parent d9fd616 commit 35d4887

File tree

2 files changed

+246
-3
lines changed

2 files changed

+246
-3
lines changed

hist/histv7/inc/ROOT/RHistStats.hxx

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
#ifndef ROOT_RHistStats
66
#define ROOT_RHistStats
77

8+
#include "RHistUtils.hxx"
89
#include "RLinearizedIndex.hxx"
10+
#include "RWeight.hxx"
911

1012
#include <cmath>
1113
#include <cstdint>
@@ -49,6 +51,14 @@ public:
4951
fSumWX3 += x * x * x;
5052
fSumWX4 += x * x * x * x;
5153
}
54+
55+
void Add(double x, double w)
56+
{
57+
fSumWX += w * x;
58+
fSumWX2 += w * x * x;
59+
fSumWX3 += w * x * x * x;
60+
fSumWX4 += w * x * x * x * x;
61+
}
5262
};
5363

5464
private:
@@ -232,6 +242,15 @@ private:
232242
}
233243
}
234244

245+
template <std::size_t I, std::size_t N, typename... A>
246+
void FillImpl(const std::tuple<A...> &args, double w)
247+
{
248+
fDimensionStats[I].Add(std::get<I>(args), w);
249+
if constexpr (I + 1 < N) {
250+
FillImpl<I + 1, N>(args, w);
251+
}
252+
}
253+
235254
public:
236255
/// Fill an entry into this statistics object.
237256
///
@@ -245,7 +264,8 @@ public:
245264
///
246265
/// \param[in] args the arguments for each dimension
247266
/// \par See also
248-
/// the \ref Fill(const A &... args) "variadic function template overload" accepting arguments directly
267+
/// the \ref Fill(const A &... args) "variadic function template overload" accepting arguments directly and the
268+
/// \ref Fill(const std::tuple<A...> &args, RWeight weight) "overload for weighted filling"
249269
template <typename... A>
250270
void Fill(const std::tuple<A...> &args)
251271
{
@@ -258,22 +278,67 @@ public:
258278
FillImpl<0>(args);
259279
}
260280

281+
/// Fill an entry into this statistics object with a weight.
282+
///
283+
/// \code
284+
/// ROOT::Experimental::RHistStats stats(2);
285+
/// auto args = std::make_tuple(8.5, 10.5);
286+
/// stats.Fill(args, ROOT::Experimental::RWeight(0.8));
287+
/// \endcode
288+
///
289+
/// \param[in] args the arguments for each dimension
290+
/// \param[in] weight the weight for this entry
291+
/// \par See also
292+
/// the \ref Fill(const A &... args) "variadic function template overload" accepting arguments directly and the
293+
/// \ref Fill(const std::tuple<A...> &args) "overload for unweighted filling"
294+
template <typename... A>
295+
void Fill(const std::tuple<A...> &args, RWeight weight)
296+
{
297+
if (sizeof...(A) != fDimensionStats.size()) {
298+
throw std::invalid_argument("invalid number of arguments to Fill");
299+
}
300+
fNEntries++;
301+
double w = weight.fValue;
302+
fSumW += w;
303+
fSumW2 += w * w;
304+
FillImpl<0, sizeof...(A)>(args, w);
305+
}
306+
261307
/// Fill an entry into this statistics object.
262308
///
263309
/// \code
264310
/// ROOT::Experimental::RHistStats stats(2);
265311
/// stats.Fill(8.5, 10.5);
266312
/// \endcode
313+
/// For weighted filling, pass an RWeight as the last argument:
314+
/// \code
315+
/// ROOT::Experimental::RHistStats stats(2);
316+
/// stats.Fill(8.5, 10.5, ROOT::Experimental::RWeight(0.8));
317+
/// \endcode
267318
///
268319
/// Throws an exception if the number of arguments does not match the number of dimensions.
269320
///
270321
/// \param[in] args the arguments for each dimension
271322
/// \par See also
272-
/// the \ref Fill(const std::tuple<A...> &args) "function overload" accepting `std::tuple`
323+
/// the function overloads accepting `std::tuple` \ref Fill(const std::tuple<A...> &args) "for unweighted filling"
324+
/// and \ref Fill(const std::tuple<A...> &args, RWeight) "for weighted filling"
273325
template <typename... A>
274326
void Fill(const A &...args)
275327
{
276-
Fill(std::forward_as_tuple(args...));
328+
auto t = std::forward_as_tuple(args...);
329+
if constexpr (std::is_same_v<typename Internal::LastType<A...>::type, RWeight>) {
330+
static constexpr std::size_t N = sizeof...(A) - 1;
331+
if (N != fDimensionStats.size()) {
332+
throw std::invalid_argument("invalid number of arguments to Fill");
333+
}
334+
fNEntries++;
335+
double w = std::get<N>(t).fValue;
336+
fSumW += w;
337+
fSumW2 += w * w;
338+
FillImpl<0, N>(t, w);
339+
} else {
340+
Fill(t);
341+
}
277342
}
278343

279344
/// %ROOT Streamer function to throw when trying to store an object of this class.

hist/histv7/test/hist_stats.cxx

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,40 @@ TEST(RHistStats, GetDimensionStats)
4848
}
4949
}
5050

51+
TEST(RHistStats, GetDimensionStatsWeighted)
52+
{
53+
RHistStats stats(3);
54+
ASSERT_EQ(stats.GetNEntries(), 0);
55+
56+
static constexpr std::size_t Entries = 20;
57+
for (std::size_t i = 0; i < Entries; i++) {
58+
stats.Fill(i, 2 * i, i * i, RWeight(0.1 + 0.03 * i));
59+
}
60+
61+
ASSERT_EQ(stats.GetNEntries(), Entries);
62+
{
63+
const auto &dimensionStats = stats.GetDimensionStats(/*=0*/);
64+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX, 93.1);
65+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX2, 1330.0);
66+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX3, 20489.98);
67+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX4, 330265.6);
68+
}
69+
{
70+
const auto &dimensionStats = stats.GetDimensionStats(1);
71+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX, 2 * 93.1);
72+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX2, 4 * 1330.0);
73+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX3, 8 * 20489.98);
74+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX4, 16 * 330265.6);
75+
}
76+
{
77+
const auto &dimensionStats = stats.GetDimensionStats(2);
78+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX, 1330.0);
79+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX2, 330265.6);
80+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX3, 93164182.0);
81+
EXPECT_DOUBLE_EQ(dimensionStats.fSumWX4, 28108731464.8);
82+
}
83+
}
84+
5185
TEST(RHistStats, ComputeNEffectiveEntries)
5286
{
5387
RHistStats stats(1);
@@ -65,6 +99,24 @@ TEST(RHistStats, ComputeNEffectiveEntries)
6599
EXPECT_DOUBLE_EQ(stats.ComputeNEffectiveEntries(), Entries);
66100
}
67101

102+
TEST(RHistStats, ComputeNEffectiveEntriesWeighted)
103+
{
104+
RHistStats stats(1);
105+
ASSERT_EQ(stats.GetNEntries(), 0);
106+
EXPECT_EQ(stats.ComputeNEffectiveEntries(), 0);
107+
108+
static constexpr std::size_t Entries = 20;
109+
for (std::size_t i = 0; i < Entries; i++) {
110+
stats.Fill(1, RWeight(0.1 + 0.03 * i));
111+
}
112+
113+
ASSERT_EQ(stats.GetNEntries(), Entries);
114+
EXPECT_DOUBLE_EQ(stats.GetSumW(), 7.7);
115+
EXPECT_DOUBLE_EQ(stats.GetSumW2(), 3.563);
116+
// Cross-checked with TH1
117+
EXPECT_DOUBLE_EQ(stats.ComputeNEffectiveEntries(), 16.640471512770137);
118+
}
119+
68120
TEST(RHistStats, ComputeMean)
69121
{
70122
RHistStats stats(3);
@@ -84,6 +136,26 @@ TEST(RHistStats, ComputeMean)
84136
EXPECT_DOUBLE_EQ(stats.ComputeMean(2), 123.5);
85137
}
86138

139+
TEST(RHistStats, ComputeMeanWeighted)
140+
{
141+
RHistStats stats(3);
142+
ASSERT_EQ(stats.GetNEntries(), 0);
143+
EXPECT_EQ(stats.ComputeMean(/*=0*/), 0);
144+
EXPECT_EQ(stats.ComputeMean(1), 0);
145+
EXPECT_EQ(stats.ComputeMean(2), 0);
146+
147+
static constexpr std::size_t Entries = 20;
148+
for (std::size_t i = 0; i < Entries; i++) {
149+
stats.Fill(i, 2 * i, i * i, RWeight(0.1 + 0.03 * i));
150+
}
151+
152+
ASSERT_EQ(stats.GetNEntries(), Entries);
153+
// Cross-checked with TH1
154+
EXPECT_DOUBLE_EQ(stats.ComputeMean(/*=0*/), 12.090909090909090);
155+
EXPECT_DOUBLE_EQ(stats.ComputeMean(1), 24.181818181818180);
156+
EXPECT_DOUBLE_EQ(stats.ComputeMean(2), 172.72727272727272);
157+
}
158+
87159
TEST(RHistStats, ComputeVariance)
88160
{
89161
RHistStats stats(3);
@@ -103,6 +175,26 @@ TEST(RHistStats, ComputeVariance)
103175
EXPECT_DOUBLE_EQ(stats.ComputeVariance(2), 12881.05);
104176
}
105177

178+
TEST(RHistStats, ComputeVarianceWeighted)
179+
{
180+
RHistStats stats(3);
181+
ASSERT_EQ(stats.GetNEntries(), 0);
182+
EXPECT_EQ(stats.ComputeVariance(/*=0*/), 0);
183+
EXPECT_EQ(stats.ComputeVariance(1), 0);
184+
EXPECT_EQ(stats.ComputeVariance(2), 0);
185+
186+
static constexpr std::size_t Entries = 20;
187+
for (std::size_t i = 0; i < Entries; i++) {
188+
stats.Fill(i, 2 * i, i * i, RWeight(0.1 + 0.03 * i));
189+
}
190+
191+
ASSERT_EQ(stats.GetNEntries(), Entries);
192+
// Cross-checked with TH1::GetStdDev squared, numerical differences with EXPECT_DOUBLE_EQ
193+
EXPECT_FLOAT_EQ(stats.ComputeVariance(/*=0*/), 26.5371900);
194+
EXPECT_FLOAT_EQ(stats.ComputeVariance(1), 106.148760);
195+
EXPECT_FLOAT_EQ(stats.ComputeVariance(2), 13056.9256);
196+
}
197+
106198
TEST(RHistStats, ComputeStdDev)
107199
{
108200
RHistStats stats(3);
@@ -122,6 +214,26 @@ TEST(RHistStats, ComputeStdDev)
122214
EXPECT_DOUBLE_EQ(stats.ComputeStdDev(2), std::sqrt(12881.05));
123215
}
124216

217+
TEST(RHistStats, ComputeStdDevWeighted)
218+
{
219+
RHistStats stats(3);
220+
ASSERT_EQ(stats.GetNEntries(), 0);
221+
EXPECT_EQ(stats.ComputeStdDev(/*=0*/), 0);
222+
EXPECT_EQ(stats.ComputeStdDev(1), 0);
223+
EXPECT_EQ(stats.ComputeStdDev(2), 0);
224+
225+
static constexpr std::size_t Entries = 20;
226+
for (std::size_t i = 0; i < Entries; i++) {
227+
stats.Fill(i, 2 * i, i * i, RWeight(0.1 + 0.03 * i));
228+
}
229+
230+
ASSERT_EQ(stats.GetNEntries(), Entries);
231+
// Cross-checked with TH1, numerical differences with EXPECT_DOUBLE_EQ
232+
EXPECT_FLOAT_EQ(stats.ComputeStdDev(/*=0*/), 5.15142602);
233+
EXPECT_FLOAT_EQ(stats.ComputeStdDev(1), 10.3028520);
234+
EXPECT_FLOAT_EQ(stats.ComputeStdDev(2), 114.266905);
235+
}
236+
125237
TEST(RHistStats, ComputeSkewness)
126238
{
127239
RHistStats stats(3);
@@ -142,6 +254,26 @@ TEST(RHistStats, ComputeSkewness)
142254
EXPECT_FLOAT_EQ(stats.ComputeSkewness(2), 0.66125456);
143255
}
144256

257+
TEST(RHistStats, ComputeSkewnessWeighted)
258+
{
259+
RHistStats stats(3);
260+
ASSERT_EQ(stats.GetNEntries(), 0);
261+
EXPECT_EQ(stats.ComputeSkewness(/*=0*/), 0);
262+
EXPECT_EQ(stats.ComputeSkewness(1), 0);
263+
EXPECT_EQ(stats.ComputeSkewness(2), 0);
264+
265+
static constexpr std::size_t Entries = 20;
266+
for (std::size_t i = 0; i < Entries; i++) {
267+
stats.Fill(i, 2 * i, i * i, RWeight(0.1 + 0.03 * i));
268+
}
269+
270+
ASSERT_EQ(stats.GetNEntries(), Entries);
271+
// Cross-checked with TH1, numerical differences with EXPECT_DOUBLE_EQ
272+
EXPECT_FLOAT_EQ(stats.ComputeSkewness(/*=0*/), -0.50554999);
273+
EXPECT_FLOAT_EQ(stats.ComputeSkewness(1), -0.50554999);
274+
EXPECT_FLOAT_EQ(stats.ComputeSkewness(2), 0.12072240);
275+
}
276+
145277
TEST(RHistStats, ComputeKurtosis)
146278
{
147279
RHistStats stats(3);
@@ -162,6 +294,26 @@ TEST(RHistStats, ComputeKurtosis)
162294
EXPECT_FLOAT_EQ(stats.ComputeKurtosis(2), -0.84198253);
163295
}
164296

297+
TEST(RHistStats, ComputeKurtosisWeighted)
298+
{
299+
RHistStats stats(3);
300+
ASSERT_EQ(stats.GetNEntries(), 0);
301+
EXPECT_EQ(stats.ComputeKurtosis(/*=0*/), 0);
302+
EXPECT_EQ(stats.ComputeKurtosis(1), 0);
303+
EXPECT_EQ(stats.ComputeKurtosis(2), 0);
304+
305+
static constexpr std::size_t Entries = 20;
306+
for (std::size_t i = 0; i < Entries; i++) {
307+
stats.Fill(i, 2 * i, i * i, RWeight(0.1 + 0.03 * i));
308+
}
309+
310+
ASSERT_EQ(stats.GetNEntries(), Entries);
311+
// Cross-checked with TH1, numerical differences with EXPECT_DOUBLE_EQ
312+
EXPECT_FLOAT_EQ(stats.ComputeKurtosis(/*=0*/), -0.74828797);
313+
EXPECT_FLOAT_EQ(stats.ComputeKurtosis(1), -0.74828797);
314+
EXPECT_FLOAT_EQ(stats.ComputeKurtosis(2), -1.2483086);
315+
}
316+
165317
TEST(RHistStats, FillInvalidNumberOfArguments)
166318
{
167319
RHistStats stats1(1);
@@ -174,3 +326,29 @@ TEST(RHistStats, FillInvalidNumberOfArguments)
174326
EXPECT_NO_THROW(stats2.Fill(1, 2));
175327
EXPECT_THROW(stats2.Fill(1, 2, 3), std::invalid_argument);
176328
}
329+
330+
TEST(RHistStats, FillWeightInvalidNumberOfArguments)
331+
{
332+
RHistStats stats1(1);
333+
RHistStats stats2(2);
334+
335+
EXPECT_NO_THROW(stats1.Fill(1, RWeight(1)));
336+
EXPECT_THROW(stats1.Fill(1, 2, RWeight(1)), std::invalid_argument);
337+
338+
EXPECT_THROW(stats2.Fill(1, RWeight(1)), std::invalid_argument);
339+
EXPECT_NO_THROW(stats2.Fill(1, 2, RWeight(1)));
340+
EXPECT_THROW(stats2.Fill(1, 2, 3, RWeight(1)), std::invalid_argument);
341+
}
342+
343+
TEST(RHistStats, FillTupleWeightInvalidNumberOfArguments)
344+
{
345+
RHistStats stats1(1);
346+
RHistStats stats2(2);
347+
348+
EXPECT_NO_THROW(stats1.Fill(std::make_tuple(1), RWeight(1)));
349+
EXPECT_THROW(stats1.Fill(std::make_tuple(1, 2), RWeight(1)), std::invalid_argument);
350+
351+
EXPECT_THROW(stats2.Fill(std::make_tuple(1), RWeight(1)), std::invalid_argument);
352+
EXPECT_NO_THROW(stats2.Fill(std::make_tuple(1, 2), RWeight(1)));
353+
EXPECT_THROW(stats2.Fill(std::make_tuple(1, 2, 3), RWeight(1)), std::invalid_argument);
354+
}

0 commit comments

Comments
 (0)