Skip to content

Commit

Permalink
sse approx
Browse files Browse the repository at this point in the history
  • Loading branch information
bassmang committed Nov 10, 2023
1 parent 3f88060 commit f60ce8f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 57 deletions.
38 changes: 8 additions & 30 deletions test/train-sets/ref/active_interactive.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,6 @@ Connecting to localhost:12353...
Done
Sending unlabeled examples...
Request for example 0: tag="", prediction=0.0:
Provide? [0/1/skip]: Request for example 1: tag="", prediction=0.0:
Provide? [0/1/skip]: Request for example 2: tag="", prediction=0.076144:
Provide? [0/1/skip]: Request for example 3: tag="", prediction=0.147194:
Provide? [0/1/skip]: Request for example 4: tag="", prediction=0.060332:
Provide? [0/1/skip]: Request for example 5: tag="", prediction=0.060998:
Provide? [0/1/skip]: Request for example 6: tag="", prediction=0.090767:
Provide? [0/1/skip]: Request for example 7: tag="", prediction=0.049165:
Provide? [0/1/skip]: Request for example 8: tag="", prediction=0.147871:
Provide? [0/1/skip]: Request for example 9: tag="", prediction=0.099147:
Provide? [0/1/skip]: Request for example 10: tag="", prediction=0.006728:
Provide? [0/1/skip]: Request for example 11: tag="", prediction=0.16017:
Provide? [0/1/skip]: Request for example 12: tag="", prediction=0.117594:
Provide? [0/1/skip]: Request for example 13: tag="", prediction=0.052956:
Provide? [0/1/skip]: Request for example 14: tag="", prediction=0.17299:
Provide? [0/1/skip]: Request for example 15: tag="", prediction=0.521811:
Provide? [0/1/skip]: Request for example 16: tag="", prediction=0.761672:
Provide? [0/1/skip]: Request for example 17: tag="", prediction=0.992553:
Provide? [0/1/skip]: Request for example 18: tag="", prediction=0.360711:
Provide? [0/1/skip]: Request for example 19: tag="", prediction=0.195309:
Provide? [0/1/skip]:
active_interactor stderr:

Expand All @@ -44,17 +25,14 @@ average since example example current current cur
loss last counter weight label predict features
n.a. n.a. 1 1.0 unknown 0.0000 41
0.000000 0.000000 2 2.0 0.0000 0.0000 41
0.500000 1.000000 4 4.0 1.0000 0.0000 75
0.256866 0.013732 8 8.0 0.0000 0.1472 44
0.130685 0.004504 16 16.0 0.0000 0.0492 42
0.182888 0.235091 32 32.0 1.0000 0.5218 34
0.000000 n.a. 4 4.0 unknown 0.0000 62
0.000000 n.a. 8 8.0 unknown 0.0000 108
0.000000 n.a. 16 16.0 unknown 0.0000 31

finished run
number of examples = 40
weighted example sum = 40.000000
weighted label sum = 6.000000
average loss = 0.157566
best constant = 0.300000
best constant's loss = 0.210000
total feature number = 2522
number of examples = 21
weighted example sum = 21.000000
weighted label sum = 0.000000
average loss = 0.000000
total feature number = 1302

1 change: 1 addition & 0 deletions test/train-sets/ref/big_feature_poison.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
[warning] NAN prediction in example 2, forcing 0
[error] The features have too much magnitude
[warning] update is NAN, replacing with 0
[warning] update is NAN, replacing with 0
54 changes: 27 additions & 27 deletions vowpalwabbit/core/src/reductions/gd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,33 +153,33 @@ VW_WARNING_STATE_POP

static inline float inv_sqrt(float x)
{
return 1.0f / std::sqrt(x);
/*
#if !defined(VW_NO_INLINE_SIMD)
# if defined(__ARM_NEON__)
// Propagate into vector
float32x2_t v1 = vdup_n_f32(x);
// Estimate
float32x2_t e1 = vrsqrte_f32(v1);
// N-R iteration 1
float32x2_t e2 = vmul_f32(e1, vrsqrts_f32(v1, vmul_f32(e1, e1)));
// N-R iteration 2
float32x2_t e3 = vmul_f32(e2, vrsqrts_f32(v1, vmul_f32(e2, e2)));
// Extract result
return vget_lane_f32(e3, 0);
# elif defined(__SSE2__)
__m128 eta = _mm_load_ss(&x);
eta = _mm_rsqrt_ss(eta);
_mm_store_ss(&x, eta);
# else
x = quake_inv_sqrt(x);
# endif
#else
x = quake_inv_sqrt(x);
#endif
return x;
*/
#if !defined(VW_NO_INLINE_SIMD)
# if defined(__ARM_NEON__)
// Propagate into vector
float32x2_t v1 = vdup_n_f32(x);
// Estimate
float32x2_t e1 = vrsqrte_f32(v1);
// N-R iteration 1
float32x2_t e2 = vmul_f32(e1, vrsqrts_f32(v1, vmul_f32(e1, e1)));
// N-R iteration 2
float32x2_t e3 = vmul_f32(e2, vrsqrts_f32(v1, vmul_f32(e2, e2)));
// Extract result
return vget_lane_f32(e3, 0);
# elif defined(__SSE2__)
__m128 eta = _mm_load_ss(&x);
eta = _mm_rsqrt_ss(eta); // Fast approximate inverse square root
// One iteration of Newton-Raphson refinement:
__m128 half_x = _mm_set_ss(0.5f * x);
eta = _mm_mul_ss(eta, _mm_sub_ss(_mm_set_ss(1.5f), _mm_mul_ss(half_x, _mm_mul_ss(eta, eta))));
_mm_store_ss(&x, eta);
# else
x = quake_inv_sqrt(x);
# endif
#else
x = quake_inv_sqrt(x);
#endif

return x;
}

VW_WARNING_STATE_PUSH
Expand Down

0 comments on commit f60ce8f

Please sign in to comment.