Skip to content

Commit e7d4dde

Browse files
committed
Apply desired precision to floating point value of arithmetic op result
1 parent 3b2a27b commit e7d4dde

File tree

4 files changed

+92
-63
lines changed

4 files changed

+92
-63
lines changed

internal/test/test_util.go

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,16 @@ func jsonNumberEqual(n1 json.Number, v2 interface{}, fpArithSpec *nosqldb.FPArit
374374
prec := binaryPrecision(fpArithSpec.Precision)
375375
bf1 := big.NewFloat(f1).SetPrec(prec).SetMode(fpArithSpec.RoundingMode)
376376
bf2 := big.NewFloat(f2).SetPrec(prec).SetMode(fpArithSpec.RoundingMode)
377-
return bf1.Cmp(bf2) == 0
377+
if bf1.Cmp(bf2) == 0 {
378+
return true
379+
}
380+
381+
decimalPrec := int(fpArithSpec.Precision)
382+
if bf1.Text('G', decimalPrec) == bf2.Text('G', decimalPrec) {
383+
return true
384+
}
385+
386+
return false
378387

379388
case uint8, uint16, uint32, uint64:
380389
i64, err := n1.Int64()
@@ -389,7 +398,7 @@ func jsonNumberEqual(n1 json.Number, v2 interface{}, fpArithSpec *nosqldb.FPArit
389398
if !ok {
390399
return false
391400
}
392-
return ratValueEqual(rat1, v2, &nosqldb.Decimal32)
401+
return ratValueEqual(rat1, v2, fpArithSpec)
393402

394403
case json.Number:
395404
rat1, ok := new(big.Rat).SetString(n1.String())
@@ -402,42 +411,27 @@ func jsonNumberEqual(n1 json.Number, v2 interface{}, fpArithSpec *nosqldb.FPArit
402411
return false
403412
}
404413

405-
return ratValueEqual(rat1, rat2, &nosqldb.Decimal32)
414+
return ratValueEqual(rat1, rat2, fpArithSpec)
406415

407416
default:
408417
return false
409418
}
410419
}
411420

412-
func showRat(rat *big.Rat) string {
413-
bf1, _, _ := big.ParseFloat(rat.RatString(), 10, 24, big.ToNearestEven)
414-
return bf1.Text('G', 7)
415-
}
416-
417421
func ratValueEqual(rat1, rat2 *big.Rat, fpArithSpec *nosqldb.FPArithSpec) bool {
418422
if rat1.Cmp(rat2) == 0 {
419423
return true
420424
}
421425

422-
var err error
423426
var bf1, bf2 *big.Float
424-
425-
prec := binaryPrecision(fpArithSpec.Precision)
426-
427-
bf1, _, err = big.ParseFloat(rat1.RatString(), 10, prec, fpArithSpec.RoundingMode)
428-
if err != nil {
429-
return false
430-
}
431-
432-
bf2, _, err = big.ParseFloat(rat2.RatString(), 10, prec, fpArithSpec.RoundingMode)
433-
if err != nil {
434-
return false
435-
}
427+
bf1 = new(big.Float).SetRat(rat1)
428+
bf2 = new(big.Float).SetRat(rat2)
436429

437430
if bf1.Cmp(bf2) == 0 {
438431
return true
439432
}
440433

434+
// as last resort, compare the string representation.
441435
decimalPrec := int(fpArithSpec.Precision)
442436
if bf1.Text('G', decimalPrec) == bf2.Text('G', decimalPrec) {
443437
return true

nosqldb/arith_op_iter.go

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ import (
1818
)
1919

2020
var _ planIter = (*arithOpIter)(nil)
21+
var (
22+
divByZeroErr = nosqlerr.NewIllegalState("arithmetic operation failed: divide by zero")
23+
ratZero = new(big.Rat).SetInt64(0)
24+
)
2125

2226
// arithOpIter represents a plan iterator that implements the arithmetic operations.
2327
//
@@ -159,13 +163,7 @@ func (iter *arithOpIter) next(rcb *runtimeControlBlock) (more bool, err error) {
159163
opResult = int(1)
160164
}
161165

162-
// Recover from panics such as "divide by zero".
163-
defer func() {
164-
if e := recover(); e != nil {
165-
rcb.trace(1, "arithOpIter.next(): %v", e)
166-
err = fmt.Errorf("arithmetic operation failed: %v", e)
167-
}
168-
}()
166+
fpArithSpec := rcb.getRequest().getFPArithSpec()
169167

170168
for i, argIter := range iter.argIters {
171169
more, err = argIter.next(rcb)
@@ -180,7 +178,7 @@ func (iter *arithOpIter) next(rcb *runtimeControlBlock) (more bool, err error) {
180178

181179
argVal := argIter.getResult(rcb)
182180
if argVal == types.NullValueInstance {
183-
rcb.trace(1, "argVal: %v, ================= got nullValue", argVal)
181+
rcb.trace(1, "argVal: %v, got nullValue", argVal)
184182
rcb.setRegValue(iter.resultReg, types.NullValueInstance)
185183
state.done()
186184
return true, nil
@@ -189,13 +187,25 @@ func (iter *arithOpIter) next(rcb *runtimeControlBlock) (more bool, err error) {
189187
op := iter.ops[i]
190188
switch argVal := argVal.(type) {
191189
case int:
192-
opResult, err = calcInt(op, opResult, argVal)
190+
if (op == '/' || op == 'd') && argVal == 0 {
191+
return false, divByZeroErr
192+
}
193+
opResult, err = calcInt(op, opResult, argVal, fpArithSpec)
193194
case int64:
194-
opResult, err = calcInt64(op, opResult, argVal)
195+
if (op == '/' || op == 'd') && argVal == 0 {
196+
return false, divByZeroErr
197+
}
198+
opResult, err = calcInt64(op, opResult, argVal, fpArithSpec)
195199
case float64:
196-
opResult, err = calcFloat64(op, opResult, argVal)
200+
if (op == '/' || op == 'd') && argVal == 0 {
201+
return false, divByZeroErr
202+
}
203+
opResult, err = calcFloat64(op, opResult, argVal, fpArithSpec)
197204
case *big.Rat:
198-
opResult, err = calcBigRat(op, opResult, argVal, rcb.getRequest().getFPArithSpec())
205+
if (op == '/' || op == 'd') && argVal.Cmp(ratZero) == 0 {
206+
return false, divByZeroErr
207+
}
208+
opResult, err = calcBigRat(op, opResult, argVal, fpArithSpec)
199209
default:
200210
return false, nosqlerr.NewIllegalState("operand in "+
201211
"arithmetic operation has illegal type. "+
@@ -207,14 +217,14 @@ func (iter *arithOpIter) next(rcb *runtimeControlBlock) (more bool, err error) {
207217
}
208218
}
209219

210-
rcb.trace(1, "============ arithOp result: %v", opResult)
220+
rcb.trace(1, "arithOp result: %v", opResult)
211221
rcb.setRegValue(iter.resultReg, opResult)
212222
state.done()
213223
return true, nil
214224
}
215225

216226
// caclInt calculates the result of applying the specified operation to the operand of int type.
217-
func calcInt(op byte, currRes interface{}, v int) (res interface{}, err error) {
227+
func calcInt(op byte, currRes interface{}, v int, fpSpec *FPArithSpec) (res interface{}, err error) {
218228
switch r := currRes.(type) {
219229
case int:
220230
switch op {
@@ -258,15 +268,17 @@ func calcInt(op byte, currRes interface{}, v int) (res interface{}, err error) {
258268

259269
switch op {
260270
case '+':
261-
res = r.Add(r, newVal)
271+
r = r.Add(r, newVal)
262272
case '-':
263-
res = r.Sub(r, newVal)
273+
r = r.Sub(r, newVal)
264274
case '*':
265-
res = r.Mul(r, newVal)
275+
r = r.Mul(r, newVal)
266276
case '/', 'd':
267-
res = r.Quo(r, newVal)
277+
r = r.Quo(r, newVal)
268278
}
269279

280+
res = setPrecAndRounding(r, fpSpec)
281+
270282
default:
271283
return currRes, fmt.Errorf("unsupported result type for arithmetic operation: %T", currRes)
272284
}
@@ -279,7 +291,7 @@ func calcInt(op byte, currRes interface{}, v int) (res interface{}, err error) {
279291
}
280292

281293
// calcInt64 calculates the result of applying the specified operation to the operand of int64 type.
282-
func calcInt64(op byte, currRes interface{}, v int64) (res interface{}, err error) {
294+
func calcInt64(op byte, currRes interface{}, v int64, fpSpec *FPArithSpec) (res interface{}, err error) {
283295
switch r := currRes.(type) {
284296
case int:
285297
switch op {
@@ -323,15 +335,17 @@ func calcInt64(op byte, currRes interface{}, v int64) (res interface{}, err erro
323335

324336
switch op {
325337
case '+':
326-
res = r.Add(r, newVal)
338+
r = r.Add(r, newVal)
327339
case '-':
328-
res = r.Sub(r, newVal)
340+
r = r.Sub(r, newVal)
329341
case '*':
330-
res = r.Mul(r, newVal)
342+
r = r.Mul(r, newVal)
331343
case '/', 'd':
332-
res = r.Quo(r, newVal)
344+
r = r.Quo(r, newVal)
333345
}
334346

347+
res = setPrecAndRounding(r, fpSpec)
348+
335349
default:
336350
return currRes, fmt.Errorf("unsupported result type for arithmetic operation: %T", currRes)
337351
}
@@ -344,7 +358,7 @@ func calcInt64(op byte, currRes interface{}, v int64) (res interface{}, err erro
344358
}
345359

346360
// calcFloat64 calculates the result of applying the specified operation to the operand of float64 type.
347-
func calcFloat64(op byte, currRes interface{}, v float64) (res interface{}, err error) {
361+
func calcFloat64(op byte, currRes interface{}, v float64, fpSpec *FPArithSpec) (res interface{}, err error) {
348362
switch r := currRes.(type) {
349363
case int:
350364
switch op {
@@ -388,15 +402,17 @@ func calcFloat64(op byte, currRes interface{}, v float64) (res interface{}, err
388402

389403
switch op {
390404
case '+':
391-
res = r.Add(r, newVal)
405+
r = r.Add(r, newVal)
392406
case '-':
393-
res = r.Sub(r, newVal)
407+
r = r.Sub(r, newVal)
394408
case '*':
395-
res = r.Mul(r, newVal)
409+
r = r.Mul(r, newVal)
396410
case '/', 'd':
397-
res = r.Quo(r, newVal)
411+
r = r.Quo(r, newVal)
398412
}
399413

414+
res = setPrecAndRounding(r, fpSpec)
415+
400416
default:
401417
return currRes, fmt.Errorf("unsupported result type for arithmetic operation: %T", currRes)
402418
}
@@ -430,23 +446,46 @@ func calcBigRat(op byte, currRes interface{}, v *big.Rat, fpSpec *FPArithSpec) (
430446

431447
switch op {
432448
case '+':
433-
rat.Add(rat, v)
449+
rat = rat.Add(rat, v)
434450
case '-':
435-
rat.Sub(rat, v)
451+
rat = rat.Sub(rat, v)
436452
case '*':
437-
rat.Mul(rat, v)
453+
rat = rat.Mul(rat, v)
438454
case '/', 'd':
439-
rat.Quo(rat, v)
455+
rat = rat.Quo(rat, v)
440456
default:
441457
return 0, fmt.Errorf("unsupported operation: %q", string(op))
442458
}
443459

444-
// TODO: apply the specified FPArithSpec to the result
445-
446-
res = rat
460+
res = setPrecAndRounding(rat, fpSpec)
447461
return res, nil
448462
}
449463

464+
// binaryPrecision returns the closest binary precision for the specified decimal precision.
465+
func binaryPrecision(decimalPrec uint) (prec uint) {
466+
switch decimalPrec {
467+
case 7:
468+
return 24
469+
case 16:
470+
return 53
471+
case 34:
472+
return 113
473+
default:
474+
return 0
475+
}
476+
}
477+
478+
func setPrecAndRounding(v *big.Rat, fpSpec *FPArithSpec) *big.Rat {
479+
if fpSpec == nil {
480+
return v
481+
}
482+
483+
prec := binaryPrecision(fpSpec.Precision)
484+
bf := new(big.Float).SetRat(v).SetPrec(prec).SetMode(fpSpec.RoundingMode)
485+
res, _ := bf.Rat(v)
486+
return res
487+
}
488+
450489
func (iter *arithOpIter) getPlan() string {
451490
return iter.planIterDelegate.getExecPlan(iter)
452491
}

nosqldb/group_iter.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@ func (aggr *aggrValue) add(rcb *runtimeControlBlock, countMemory bool, val inter
134134
aggr.gotNumericInput = true
135135
switch v := val.(type) {
136136
case int:
137-
aggr.value, err = calcInt('+', aggr.value, v)
137+
aggr.value, err = calcInt('+', aggr.value, v, fpSpec)
138138
case int64:
139-
aggr.value, err = calcInt64('+', aggr.value, v)
139+
aggr.value, err = calcInt64('+', aggr.value, v, fpSpec)
140140
case float64:
141-
aggr.value, err = calcFloat64('+', aggr.value, v)
141+
aggr.value, err = calcFloat64('+', aggr.value, v, fpSpec)
142142
case *big.Rat:
143143
aggr.value, err = calcBigRat('+', aggr.value, v, fpSpec)
144144
}

nosqldb/request.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,10 +1101,6 @@ func (r *QueryRequest) getShardID() int {
11011101
}
11021102

11031103
func (r *QueryRequest) getFPArithSpec() *FPArithSpec {
1104-
if r.FPArithSpec == nil {
1105-
r.FPArithSpec = &Decimal32
1106-
}
1107-
11081104
return r.FPArithSpec
11091105
}
11101106

0 commit comments

Comments
 (0)