diff --git a/bolt/spark/array.py b/bolt/spark/array.py index 8f34eb5..bab7519 100644 --- a/bolt/spark/array.py +++ b/bolt/spark/array.py @@ -1,6 +1,7 @@ from __future__ import print_function from numpy import asarray, unravel_index, prod, mod, ndarray, ceil, where, \ - r_, sort, argsort, array, random, arange, ones, expand_dims, sum + r_, sort, argsort, array, random, arange, ones, expand_dims, sum, add, \ + subtract, multiply, true_divide, isscalar from itertools import groupby from bolt.base import BoltArray @@ -1012,3 +1013,85 @@ def display(self): """ for x in self._rdd.take(10): print(x) + + def elementwise_binary(self, other, op): + """ + Apply an elementwise binary operation between two arrays or an array and scalar. + + Self and other must have the same shape, or other must be a scalar. + + Parameters + ---------- + other : scalar, ndarray, BoltArrayLocal, or BoltArraySpark + Value or array to perform a binary operation with element-wise + + op : function + Binary operator to use for elementwise operations, e.g. add, subtract + + Returns + ------- + BoltArraySpark + """ + if not isscalar(other) and not self.shape == other.shape: + raise ValueError("Shapes %s and %s must be equal" % (self.shape, other.shape)) + + if isinstance(other, ndarray): + from bolt.spark.construct import ConstructSpark + other = ConstructSpark.array(other, self._rdd.context, axis=range(0, self.split)) + + if isscalar(other): + return self.map(lambda x: op(x, other)) + + elif isinstance(other, BoltArraySpark): + + def func(record): + (k1, x), (k2, y) = record + return k1, op(x, y) + + rdd = self.tordd().zip(other.tordd()).map(func) + return self._constructor(rdd).__finalize__(self) + + else: + raise ValueError("other must be ndarray, BoltArrayLocal, or BoltArraySpark. Got %s" % type(other)) + + def __add__(self, other): + """ + Provides element-wise "+" functionality for BoltArraySpark. + """ + return self.elementwise_binary(other, add) + + def __radd__(self, other): + """ + Provides element-wise "+" functionality for BoltArraySpark if scalar provided first. + """ + return self.elementwise_binary(other, add) + + def __sub__(self, other): + """ + Provides element-wise "-" functionality for BoltArraySpark. + """ + return self.elementwise_binary(other, subtract) + + def __mul__(self, other): + """ + Provides element-wise "*" functionality for BoltArraySpark. + """ + return self.elementwise_binary(other, multiply) + + def __rmul__(self, other): + """ + Provides element-wise "*" functionality for BoltArraySpark if scalar provided first. + """ + return self.elementwise_binary(other, multiply) + + def __div__(self, other): + """ + Provides element-wise "/" functionality for BoltArraySpark. + """ + return self.elementwise_binary(other, true_divide) + + def __truediv__(self, other): + """ + Provides element-wise "/" functionality for BoltArraySpark for Python 3 or __future__ division. + """ + return self.elementwise_binary(other, true_divide) diff --git a/test/spark/test_spark_binaryoperators.py b/test/spark/test_spark_binaryoperators.py new file mode 100644 index 0000000..40b5425 --- /dev/null +++ b/test/spark/test_spark_binaryoperators.py @@ -0,0 +1,98 @@ +import pytest +from numpy import arange, true_divide +from bolt import array +from bolt.utils import allclose + +#from pyspark import SparkConf, SparkContext +# +#@pytest.fixture(scope="session", +# params=[pytest.mark.spark_yarn('yarn'), +# pytest.mark.spark_local('local')]) +#def sc(request): +# if request.param == 'local': +# conf = (SparkConf() +# .setMaster("local[2]") +# .setAppName("pytest-pyspark-local-testing") +# ) +# elif request.param == 'yarn': +# conf = (SparkConf() +# .setMaster("yarn-client") +# .setAppName("pytest-pyspark-yarn-testing") +# .set("spark.executor.memory", "1g") +# .set("spark.executor.instances", 2) +# ) +# request.addfinalizer(lambda: spark_context.stop()) +# +# spark_context = SparkContext(conf=conf) +# return spark_context + +def test_elementwise_spark(sc): + x = arange(1, 2*3*4+1).reshape(2, 3, 4) + y = 5*x + bx = array(x, sc, axis=(0,)) + by = array(y, sc, axis=(0,)) + + bxyadd = bx+by + bxyaddarr = bxyadd.toarray() + + bxysub = bx-by + bxysubarr = bxysub.toarray() + + bxymul = bx*by + bxymularr = bxymul.toarray() + + bxydiv = bx/by + bxydivarr = bxydiv.toarray() + + assert allclose(bxyaddarr, x+y) + assert allclose(bxysubarr, x-y) + assert allclose(bxymularr, x*y) + assert allclose(bxydivarr, true_divide(x,y)) + +def test_elementwise_mix(sc): + x = arange(1, 2*3*4+1).reshape(2, 3, 4) + y = x*3 + bx = array(x, sc, axis=(0,)) + by = array(y, sc, axis=(0,)) + + bxyadd = bx+by + bxyaddarr = bxyadd.toarray() + bxyaddloc = bx+y + bxyaddlocarr = bxyaddloc.toarray() + + bxysub = bx-by + bxysubarr = bxysub.toarray() + bxysubloc = bx-y + bxysublocarr = bxysubloc.toarray() + + bxymul = bx*by + bxymularr = bxymul.toarray() + bxymulloc = bx*y + bxymullocarr = bxymulloc.toarray() + + bxydiv = bx/by + bxydivarr = bxydiv.toarray() + bxydivloc = bx/y + bxydivlocarr = bxydivloc.toarray() + + assert allclose(bxyadd, bxyaddlocarr) + assert allclose(bxysub, bxysublocarr) + assert allclose(bxymul, bxymullocarr) + assert allclose(bxydiv, bxydivlocarr) + +def test_elementwise_scalar(sc): + x = arange(2*3*4).reshape(2, 3, 4) + bx = array(x, sc, axis=(0,)) + + bxfive = bx + 5 + bxfivearr = bxfive.toarray() + fivebx = 5 + bx + fivebxarr = fivebx.toarray() + + bxten = bx * 10 + bxtenarr = bxten.toarray() + tenbx = 10 * bx + tenbxarr = tenbx.toarray() + + assert allclose(bxfivearr, fivebxarr) + assert allclose(bxtenarr, tenbxarr)