15 changes: 15 additions & 0 deletions ambrosia/spark_tools/constants.py
@@ -0,0 +1,15 @@
# Copyright 2022 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

EMPTY_VALUE_PARTITION: int = 0
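The new module replaces the per-file EMPTY_VALUE literals below with a single shared constant. As a minimal sketch of the pattern it supports (not part of the diff; assumes PySpark is installed and a local session can be created), partitioning a window by this constant literal puts every row in one window partition, so row_number() enumerates the whole ordered dataframe:

# Illustrative sketch, not part of the diff; assumes a local PySpark session.
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

from ambrosia.spark_tools.constants import EMPTY_VALUE_PARTITION

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(3.0,), (1.0,), (2.0,)], ["value"])

# Partitioning by a constant literal keeps all rows in a single window partition,
# so row_number() numbers the full dataframe in the given order.
window = Window.orderBy("value").partitionBy(F.lit(EMPTY_VALUE_PARTITION))
df.withColumn("__row_number", F.row_number().over(window)).show()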
8 changes: 5 additions & 3 deletions ambrosia/spark_tools/split_tools.py
@@ -16,6 +16,7 @@

import ambrosia.spark_tools.stratification as strat_pkg
from ambrosia import types
from ambrosia.spark_tools.constants import EMPTY_VALUE_PARTITION
from ambrosia.tools import split_tools
from ambrosia.tools.import_tools import spark_installed

@@ -26,7 +27,6 @@
HASH_COLUMN_NAME: str = "__hashed_ambrosia_column"
GROUPS_COLUMN: str = "group"
ROW_NUMBER: str = "__row_number"
EMPTY_VALUE: int = 0


def unite_spark_tables(*dataframes: types.SparkDataFrame) -> types.SparkDataFrame:
@@ -90,7 +90,7 @@ def udf_make_labels(row_number: int) -> str:
label_ind = (row_number - 1) // groups_size
return labels[label_ind]

window = Window.orderBy(HASH_COLUMN_NAME).partitionBy(spark_funcs.lit(EMPTY_VALUE))
window = Window.orderBy(HASH_COLUMN_NAME).partitionBy(spark_funcs.lit(EMPTY_VALUE_PARTITION))
result = hashed_dataframe.withColumn(ROW_NUMBER, spark_funcs.row_number().over(window)).withColumn(
GROUPS_COLUMN, spark_funcs.udf(udf_make_labels)(spark_funcs.col(ROW_NUMBER))
)
@@ -128,7 +128,9 @@ def udf_make_labels_with_find(row_number: int):
not_used_ids.withColumn(
ROW_NUMBER,
spark_funcs.row_number().over(
Window.orderBy(spark_funcs.lit(EMPTY_VALUE)).partitionBy(spark_funcs.lit(EMPTY_VALUE))
Window.orderBy(spark_funcs.lit(EMPTY_VALUE_PARTITION)).partitionBy(
spark_funcs.lit(EMPTY_VALUE_PARTITION)
)
),
)
.withColumn(GROUPS_COLUMN, spark_funcs.udf(udf_make_labels_with_find)(ROW_NUMBER))
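For context on the group-labelling hunk above: udf_make_labels maps a 1-based row number to a group label by integer division. A pure-Python illustration of that arithmetic (the labels list and groups_size value are hypothetical, not taken from the diff):

# Hypothetical values for illustration only.
labels = ["A", "B", "C"]
groups_size = 2  # rows assigned to each label

def make_label(row_number: int) -> str:
    # Mirrors udf_make_labels: 1-based row numbers, groups_size rows per label.
    label_ind = (row_number - 1) // groups_size
    return labels[label_ind]

assert [make_label(n) for n in range(1, 7)] == ["A", "A", "B", "B", "C", "C"]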
131 changes: 102 additions & 29 deletions ambrosia/spark_tools/stat_criteria.py
@@ -19,15 +19,19 @@

import ambrosia.tools.pvalue_tools as pvalue_pkg
import ambrosia.tools.theoretical_tools as theory_pkg
import ambrosia.tools.type_checks as cast_pkg
from ambrosia import types
from ambrosia.spark_tools.constants import EMPTY_VALUE_PARTITION
from ambrosia.spark_tools.theory import get_stats_from_table
from ambrosia.tools.ab_abstract_component import ABStatCriterion
from ambrosia.tools.configs import Effects
from ambrosia.tools.import_tools import spark_installed
from ambrosia.tools.stat_criteria import TtestRelHelpful

if spark_installed():
import pyspark.sql.functions as F
from pyspark.sql.functions import col, row_number
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import col, mean, row_number, variance
from pyspark.sql.window import Window


@@ -88,8 +92,7 @@ class TtestIndCriterionSpark(ABSparkCriterion):
Unit for pyspark independent T-test.
"""

__implemented_effect_types: List = ["absolute", "relative"]
__type_error_msg: str = f"Choose effect type from {__implemented_effect_types}"
implemented_effect_types: List = ["absolute", "relative"]
__data_parameters = ["mean_group_a", "mean_group_b", "std_group_a", "std_group_b", "nobs_group_a", "nobs_group_b"]

def __calc_and_cache_data_parameters(
@@ -127,18 +130,18 @@ def calculate_pvalue(
effect_type: str = "absolute",
**kwargs,
):
if effect_type not in TtestIndCriterionSpark.__implemented_effect_types:
raise ValueError(TtestIndCriterionSpark.__type_error_msg)
if effect_type not in self.implemented_effect_types:
raise ValueError(self._send_type_error_msg())
if not self.parameters_are_cached:
self.__calc_and_cache_data_parameters(group_a, group_b, column)
if effect_type == "absolute":
p_value = sps.ttest_ind_from_stats(
self.data_stats["mean_group_b"],
self.data_stats["std_group_b"],
self.data_stats["nobs_group_b"],
self.data_stats["mean_group_a"],
self.data_stats["std_group_a"],
self.data_stats["nobs_group_a"],
self.data_stats["mean_group_b"],
self.data_stats["std_group_b"],
self.data_stats["nobs_group_b"],
**kwargs,
).pvalue
elif effect_type == "relative":
@@ -163,7 +166,7 @@ def calculate_effect(
"mean_group_a"
]
else:
raise ValueError(TtestIndCriterionSpark.__type_error_msg)
raise ValueError(self._send_type_error_msg())
return effect

def calculate_conf_interval(
@@ -175,10 +178,12 @@ def calculate_conf_interval(
effect_type: str = "absolute",
**kwargs,
):
alpha = cast_pkg.transform_alpha_np(alpha)
if self.parameters_are_cached is not True:
self.__calc_and_cache_data_parameters(group_a, group_b, column)
if effect_type == "absolute":
alpha_corrected: float = pvalue_pkg.corrected_alpha(alpha, kwargs["alternative"])
alternative = "two-sided" if"alternative" not in kwargs else kwargs["alternative"]
alpha_corrected: float = pvalue_pkg.corrected_alpha(alpha, alternative)
quantiles, sd = theory_pkg.get_ttest_info_from_stats(
var_a=self.data_stats["std_group_a"] ** 2,
var_b=self.data_stats["std_group_b"] ** 2,
Expand All @@ -189,15 +194,15 @@ def calculate_conf_interval(
mean = self.data_stats["mean_group_b"] - self.data_stats["mean_group_a"]
left_ci: np.ndarray = mean - quantiles * sd
right_ci: np.ndarray = mean + quantiles * sd
return self._make_ci(left_ci, right_ci, kwargs["alternative"])
return self._make_ci(left_ci, right_ci, alternative)
elif effect_type == "relative":
conf_interval = self._apply_delta_method(alpha, **kwargs)[0]
return conf_interval
else:
raise ValueError(TtestIndCriterionSpark.__type_error_msg)
raise ValueError(self._send_type_error_msg())


class TtestRelativeCriterionSpark(ABSparkCriterion):
class TtestRelativeCriterionSpark(ABSparkCriterion, TtestRelHelpful):
"""
Relative ttest for spark
"""
@@ -213,15 +218,23 @@ def _rename_col(column: str, group: str) -> str:
def _calc_and_cache_data_parameters(
self, group_a: types.SparkDataFrame, group_b: types.SparkDataFrame, column: types.ColumnNameType
) -> None:
a_ = (
col_a: str = self._rename_col(column, "a")
col_b: str = self._rename_col(column, "b")
a_: DataFrame = (
group_a.withColumn(self.__ord_col, F.lit(1))
.withColumn(self.__add_index_name, row_number().over(Window().orderBy(self.__ord_col)))
.withColumnRenamed(column, self._rename_col(column, "a"))
.withColumn(
self.__add_index_name,
row_number().over(Window().orderBy(self.__ord_col).partitionBy(F.lit(EMPTY_VALUE_PARTITION))),
)
.withColumnRenamed(column, col_a)
)
b_ = (
b_: DataFrame = (
group_b.withColumn(self.__ord_col, F.lit(1))
.withColumn(self.__add_index_name, row_number().over(Window().orderBy(self.__ord_col)))
.withColumnRenamed(column, self._rename_col(column, "b"))
.withColumn(
self.__add_index_name,
row_number().over(Window().orderBy(self.__ord_col).partitionBy(F.lit(EMPTY_VALUE_PARTITION))),
)
.withColumnRenamed(column, col_b)
)

n_a_obs: int = group_a.count()
@@ -230,11 +243,25 @@ def _calc_and_cache_data_parameters(
if n_a_obs != n_b_obs:
raise ValueError("Size of group A and B must be equal")

both = a_.join(b_, self.__add_index_name, "inner").withColumn(
self.__diff, col(self._rename_col(column, "b")) - col(self._rename_col(column, "a"))
)
both: DataFrame = a_.join(b_, self.__add_index_name, "inner").withColumn(self.__diff, col(col_b) - col(col_a))

cov: float = both.stat.cov(col_a, col_b)
stats = both.select(
variance(col_a).alias("__var_a"),
variance(col_b).alias("__var_b"),
mean(col_a).alias("__mean_a"),
mean(col_b).alias("__mean_b"),
).first()
var_a: float = theory_pkg.unbiased_to_sufficient(stats["__var_a"], n_a_obs, is_std=False)
var_b: float = theory_pkg.unbiased_to_sufficient(stats["__var_b"], n_a_obs, is_std=False)

self.data_stats["mean"], self.data_stats["std"] = get_stats_from_table(both, self.__diff)
self.data_stats["n_obs"] = n_a_obs
self.data_stats["cov"] = cov
self.data_stats["var_a"] = var_a
self.data_stats["var_b"] = var_b
self.data_stats["mean_a"] = stats["__mean_a"]
self.data_stats["mean_b"] = stats["__mean_b"]
self.parameters_are_cached = True

def calculate_pvalue(
@@ -247,30 +274,76 @@
):
self._recalc_cache(group_a, group_b, column)
if effect_type == Effects.abs.value:
if "alternative" in kwargs:
kwargs["alternative"] = theory_pkg.switch_alternative(kwargs["alternative"])
p_value = theory_pkg.ttest_1samp_from_stats(
mean=self.data_stats["mean"], std=self.data_stats["std"], n_obs=self.data_stats["n_obs"], **kwargs
)
)[
1
] # (stat, pvalue)
elif effect_type == Effects.rel.value:
raise NotImplementedError("Will be implemented later")
_, p_value = theory_pkg.apply_delta_method_by_stats(
size=self.data_stats["n_obs"],
mean_group_a=self.data_stats["mean_a"],
mean_group_b=self.data_stats["mean_b"],
var_group_a=self.data_stats["var_a"],
var_group_b=self.data_stats["var_b"],
cov_groups=self.data_stats["cov"],
transformation="fraction",
**kwargs
)
else:
raise ValueError(self._send_type_error_msg())
self._check_clear_cache()
return p_value

def calculate_conf_interval(
self,
group_a: types.SparkDataFrame,
group_b: types.SparkDataFrame,
alpha: types.StatErrorType,
effect_type: str,
column: str,
alpha: types.StatErrorType = np.array([0.05]),
effect_type: str = Effects.abs.value,
**kwargs,
) -> List[Tuple]:
raise NotImplementedError("Will be implemented later")
self._recalc_cache(group_a, group_b, column)
alpha = cast_pkg.transform_alpha_np(alpha)
if effect_type == Effects.abs.value:
confidence_intervals = self._build_intervals_absolute_from_stats(
center=self.data_stats["mean"],
sd_1=self.data_stats["std"],
n_obs=self.data_stats["n_obs"],
alpha=alpha,
**kwargs,
)
elif effect_type == Effects.rel.value:
confidence_intervals, _ = theory_pkg.apply_delta_method_by_stats(
size=self.data_stats["n_obs"],
mean_group_a=self.data_stats["mean_a"],
mean_group_b=self.data_stats["mean_b"],
var_group_a=self.data_stats["var_a"],
var_group_b=self.data_stats["var_b"],
cov_groups=self.data_stats["cov"],
alpha=alpha,
transformation="fraction",
**kwargs,
)
else:
raise ValueError(self._send_type_error_msg())
return confidence_intervals

def calculate_effect(
self, group_a: types.SparkDataFrame, group_b: types.SparkDataFrame, column: str, effect_type: str
self,
group_a: types.SparkDataFrame,
group_b: types.SparkDataFrame,
column: str,
effect_type: str = Effects.abs.value,
) -> float:
self._recalc_cache(group_a, group_b, column)
if effect_type == Effects.abs.value:
effect: float = self.data_stats["mean"]
elif effect_type == Effects.rel.value:
effect: float = (self.data_stats["mean_b"] - self.data_stats["mean_a"]) / self.data_stats["mean_a"]
else:
raise NotImplementedError("Will be implemented later")
raise ValueError(self._send_type_error_msg())
return effect
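A hedged usage sketch of the now-implemented relative paired t-test (not part of the diff). The column name "metric", the sample values, and the keyword arguments follow the signatures visible above; treat the no-argument construction and the "relative" effect-type string as assumptions:

# Hedged sketch: exercising TtestRelativeCriterionSpark end to end.
import numpy as np
from pyspark.sql import SparkSession

from ambrosia.spark_tools.stat_criteria import TtestRelativeCriterionSpark

spark = SparkSession.builder.master("local[1]").getOrCreate()
# Paired groups must have equal sizes, as enforced in _calc_and_cache_data_parameters.
group_a = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,)], ["metric"])
group_b = spark.createDataFrame([(1.2,), (2.1,), (3.3,), (4.2,)], ["metric"])

criterion = TtestRelativeCriterionSpark()
pvalue = criterion.calculate_pvalue(group_a, group_b, column="metric", effect_type="relative")
effect = criterion.calculate_effect(group_a, group_b, column="metric", effect_type="relative")
intervals = criterion.calculate_conf_interval(
    group_a, group_b, column="metric", alpha=np.array([0.05]), effect_type="relative"
)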
4 changes: 2 additions & 2 deletions ambrosia/spark_tools/stratification.py
@@ -16,14 +16,14 @@

import ambrosia.tools.ab_abstract_component as ab_abstract
from ambrosia import types
from ambrosia.spark_tools.constants import EMPTY_VALUE_PARTITION
from ambrosia.tools.import_tools import spark_installed

if spark_installed():
import pyspark.sql.functions as spark_funcs
from pyspark.sql import Window


EMPTY_VALUE: int = 0
STRAT_GROUPS: str = "__ambrosia_strat"


@@ -38,7 +38,7 @@ def fit(self, dataframe: types.SparkDataFrame, columns: Optional[Iterable[types.
self.strats = {ab_abstract.EmptyStratValue.NO_STRATIFICATION: dataframe}
return

window = Window.orderBy(*columns).partitionBy(spark_funcs.lit(EMPTY_VALUE))
window = Window.orderBy(*columns).partitionBy(spark_funcs.lit(EMPTY_VALUE_PARTITION))
with_groups = dataframe.withColumn(STRAT_GROUPS, spark_funcs.dense_rank().over(window))
amount_of_strats: int = with_groups.select(spark_funcs.max(STRAT_GROUPS)).collect()[0][0]

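For context on the stratification hunk above: dense_rank() over a window ordered by the stratification columns assigns consecutive ids (1, 2, 3, ...) to each distinct combination of column values. A minimal sketch with hypothetical column names, not taken from the diff:

# Illustrative sketch, not part of the diff; assumes a local PySpark session.
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

from ambrosia.spark_tools.constants import EMPTY_VALUE_PARTITION

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("m", 1), ("f", 1), ("m", 2), ("f", 1)], ["gender", "city"]
)

# Each distinct (gender, city) pair receives its own stratum id.
window = Window.orderBy("gender", "city").partitionBy(F.lit(EMPTY_VALUE_PARTITION))
df.withColumn("__ambrosia_strat", F.dense_rank().over(window)).show()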
2 changes: 1 addition & 1 deletion ambrosia/tester/handlers.py
@@ -44,7 +44,7 @@ class PandasCriteria(enum.Enum):

class SparkCriteria(enum.Enum):
ttest: StatCriterion = spark_crit_pkg.TtestIndCriterionSpark
ttest_rel: StatCriterion = None # spark_crit_pkg.TtestRelativeCriterionSpark it's in development now
ttest_rel: StatCriterion = spark_crit_pkg.TtestRelativeCriterionSpark
mw: StatCriterion = None
wilcoxon: StatCriterion = None

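With the enum entry switched from None to the class, the relative t-test can now be resolved by the tester through the Spark criteria mapping. A tiny hedged sketch of that lookup (assuming SparkCriteria is imported directly from ambrosia.tester.handlers):

# Hedged sketch: resolving the newly enabled criterion by enum member name.
from ambrosia.tester.handlers import SparkCriteria

criterion_cls = SparkCriteria["ttest_rel"].value  # TtestRelativeCriterionSpark
criterion = criterion_cls()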
2 changes: 1 addition & 1 deletion ambrosia/tools/configs.py
@@ -39,7 +39,7 @@ def get_all_enum_values(cls) -> tp.List[str]:
@classmethod
def raise_if_value_incorrect_enum(cls, value: tp.Any) -> None:
if not cls.check_value_in_enum(value):
msg: str = f"Choose value from " + ", ".join(cls.get_all_enum_values())
msg: str = f"Choose value from {', '.join(cls.get_all_enum_values())}, your value - {value}"
raise ValueError(msg)


8 changes: 4 additions & 4 deletions ambrosia/tools/pvalue_tools.py
@@ -114,9 +114,9 @@ def calculate_pvalue_by_delta_method(
raise ValueError(f"Got unknown random variable transformation: {ADMISSIBLE_TRANSFORMATIONS}")

if alternative == "less":
pvalue: float = sps.norm.cdf(statistic)
elif alternative == "greater":
pvalue: float = sps.norm.sf(statistic)
elif alternative == "greater":
pvalue: float = sps.norm.cdf(statistic)
elif alternative == "two-sided":
pvalue: float = 2 * min(sps.norm.cdf(statistic), sps.norm.sf(statistic))
else:
@@ -156,9 +156,9 @@ def choose_from_bounds(
"""
cond_many: bool = isinstance(left_ci, Iterable)
amount: int = len(left_ci) if cond_many else 1
if alternative == "greater":
right_ci = np.ones(amount) * right_bound if cond_many else right_bound
if alternative == "less":
right_ci = np.ones(amount) * right_bound if cond_many else right_bound
if alternative == "greater":
left_ci = np.ones(amount) * left_bound if cond_many else left_bound
return left_ci, right_ci

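The two-sided branch shown in the first hunk uses the standard normal convention p = 2 * min(Phi(z), 1 - Phi(z)). A quick numeric sanity check, independent of the library:

# Sanity check of the two-sided normal p-value convention used above.
import scipy.stats as sps

z = 1.959963984540054  # 97.5th percentile of the standard normal
p_two_sided = 2 * min(sps.norm.cdf(z), sps.norm.sf(z))
print(round(p_two_sided, 3))  # ~0.05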