diff --git a/data/samples/sample_log_schema.json b/data/samples/sample_log_schema.json new file mode 100644 index 00000000..418e440c --- /dev/null +++ b/data/samples/sample_log_schema.json @@ -0,0 +1,24 @@ +{ + "name": "CustomLogs", + "properties": { + "timestamp": { + "type": "string", + "format": "date", + "pattern": "(\\d\\d\\d\\d-([0-2])?\\d-([0-3])?\\dT?([0-2])?\\d:([0-5])?\\d:([0-5])?\\d\\.\\d?\\d?\\d?Z?)" + }, + "some_property": { + "type": "string" + }, + "some_other_property": { + "type": "string" + }, + "an_integer_property": { + "type": "integer" + }, + "a_numeric_property": { + "type": "string" + } + }, + "required": ["timestamp", "some_property"], + "additionalProperties": false +} \ No newline at end of file diff --git a/src/baskerville/main.py b/src/baskerville/main.py index e37cd999..0e3684fa 100644 --- a/src/baskerville/main.py +++ b/src/baskerville/main.py @@ -62,20 +62,6 @@ def run_simulation(conf, spark=None): print('Set up Simulation...') -def populate_with_test_data(database_config): - """ - Load the test data and save them in the database - :param dict[str, T] database_config: - :return: - """ - global logger - from baskerville.util.model_serialization import import_pickled_model - path = os.path.join(get_default_data_path(), 'samples', 'sample_model') - test_model_path = os.path.join(get_default_data_path(), 'samples', 'test_model') - logger.info(f'Loading test model from: {test_model_path}') - import_pickled_model(database_config, path, test_model_path) - - def main(): """ Baskerville commandline arguments @@ -99,13 +85,6 @@ def main(): "in the configuration port", ) - parser.add_argument( - "-t", "--testmodel", dest="test_model", - help="Add a test model in the models table", - default=False, - action="store_true" - ) - parser.add_argument( "-c", "--conf", action="store", dest="conf_file", default=os.path.join(src_dir, '..', 'conf', 'baskerville.yaml'), @@ -143,10 +122,6 @@ def main(): logger.info(f'Starting Baskerville Exporter at ' f'http://localhost:{port}') - # populate with test data if specified - if args.test_model: - populate_with_test_data(conf['database']) - for p in PROCESS_LIST[::-1]: print(f"{p.name} starting...") p.start() diff --git a/src/baskerville/models/anomaly_detector.py b/src/baskerville/models/anomaly_detector.py deleted file mode 100644 index ea09b9e9..00000000 --- a/src/baskerville/models/anomaly_detector.py +++ /dev/null @@ -1,215 +0,0 @@ -import abc -from collections import defaultdict -from typing import NamedTuple -import os - -import _pickle as cPickle - -from baskerville.util.helpers import get_classifier_load_path, \ - get_scaler_load_path, class_from_str -from pyspark.ml.feature import StandardScalerModel -from pyspark_iforest.ml.iforest import IForestModel - -from baskerville.features.helpers import extract_features_in_order -from baskerville.spark import get_spark_session -from baskerville.spark.helpers import load_model_from_path -from baskerville.spark.schemas import prediction_schema -from pyspark.sql import functions as F - - -class AnomalyDetector(NamedTuple): - model: any - features: list - host_encoder: any - scaler: any - threshold: float - features_col: str = 'features' - prediction_col: str = 'prediction' - score_col: str = 'score' - - -class AnomalyDetectorManagerBase(object, metaclass=abc.ABCMeta): - """ - todo: might raise memory consumption a bit - V - """ - anomaly_detector = None - anomaly_detector_bc = None - - @abc.abstractmethod - def predict(self, df): - pass - - def load_classifier(self, classifier, 
algorithm): - return classifier - - def load_scaler(self, scaler, scaler_type): - return scaler - - def load_features(self, features): - return features - - def load_threshold(self, threshold): - return threshold - - def load_host_encoder(self, host_encoder): - return host_encoder - - @abc.abstractmethod - def load(self, model): - """ - Instantiates an AnomalyDetector object, using the - baskerville.db.models.Model - :param baskerville.db.models.Model model: - :param register_metrics: - :return: - """ - self.anomaly_detector = AnomalyDetector( - model=self.load_classifier(model.classifier, model.algorithm), - features=self.load_features(model.features), - scaler=self.load_scaler(model.scaler, model.scaler_type), - threshold=self.load_threshold(model.threshold), - host_encoder=self.load_host_encoder(model.host_encoder), - ) - - def broadcast(self): - self.anomaly_detector_bc = get_spark_session().sparkContext.broadcast( - self.anomaly_detector - ) - - -class ScikitAnomalyDetectorManager(AnomalyDetectorManagerBase): - - def predict(self, df): - from baskerville.models.metrics.helpers import \ - CLIENT_PREDICTION_ACCUMULATOR, CLIENT_REQUEST_SET_COUNT - global CPA, CRSC - - CPA = CLIENT_PREDICTION_ACCUMULATOR - CRSC = CLIENT_PREDICTION_ACCUMULATOR - - def predict_dict(target, dict_features, update_metrics=False): - """ - Scale the feature values and use the model to predict - :param dict[str, float] dict_features: the feature dictionary - :param update_metrics: - :return: 0 if normal, 1 if abnormal, -1 if something went wrong - """ - global CPA, CRSC - import json - prediction = 0, 0. - threshold = None - score = None - - if isinstance(dict_features, str): - dict_features = json.loads(dict_features) - try: - x = dict_features - detector = self.anomaly_detector_bc.value - x = [extract_features_in_order(x, detector.features)] - x = detector.scaler.transform(x) - y = 0.5 - detector.model.decision_function(x) - threshold = detector.threshold if detector.threshold \ - else 0.5 - detector.model.threshold_ - prediction = float((y > threshold)[0]) - score = float(y[0]) - # because of net.razorvine.pickle.PickleException: expected - # zero arguments for construction of ClassDict - # (for numpy.core.multiarray._reconstruct): - threshold = float(threshold) - - except ValueError: - import traceback - traceback.print_exc() - print('Cannot predict:', dict_features) - - if update_metrics: - CRSC += {target: 1} - CPA += {target: prediction} - - return prediction, score, threshold - - udf_predict_dict = F.udf(predict_dict, prediction_schema) - - df = df.withColumn( - 'y', - udf_predict_dict( - 'target', - 'features', - F.lit(False) # todo - ) - ) - - df = df.withColumn( - self.anomaly_detector.prediction_col, F.col('y.prediction') - ).withColumn( - self.anomaly_detector.score_col, F.col('y.score') - ).withColumn( - 'threshold', F.col('y.threshold') - ).drop('y') - return df - - def load_classifier(self, classifier, algorithm): - return cPickle.loads(classifier) - - def load_scaler(self, scaler, scaler_type): - return cPickle.loads(scaler) if scaler else scaler - - def load_host_encoder(self, host_encoder): - return cPickle.loads(host_encoder) if host_encoder else host_encoder - - def load(self, model): - super().load(model) - self.broadcast() - - -class SparkAnomalyDetectorManager(AnomalyDetectorManagerBase): - - def predict(self, df): - from baskerville.spark.udfs import to_dense_vector_udf - - df = df.withColumn( - 'vectorized_features', - to_dense_vector_udf('vectorized_features') - ) - - df = 
self.anomaly_detector.scaler.transform(df) - df = self.anomaly_detector.model.transform(df) - df = df.withColumnRenamed('anomalyScore', 'score') - return df - - def load_classifier(self, classifier, algorithm): - return load_model_from_path( - f'{algorithm}Model', - bytes.decode(classifier, 'utf8') - ) - - def load_scaler(self, scaler, scaler_type): - if scaler: - return load_model_from_path( - f'{scaler_type}Model', - bytes.decode(scaler, 'utf8') - ) - return scaler - - def load_host_encoder(self, host_encoder): - """ - # todo - :param host_encoder: - :return: - """ - return host_encoder - - def load(self, model): - model_path = bytes.decode(model.classifier, 'utf8') - self.anomaly_detector = AnomalyDetector( - model=class_from_str(f'{model.algorithm}Model').load( - get_classifier_load_path(model_path) - ), - features=model.features, - scaler=class_from_str(f'{model.scaler_type}Model').load( - get_scaler_load_path(model_path) - ), - threshold=self.load_threshold(model.threshold), - host_encoder=None, - ) diff --git a/src/baskerville/models/anomaly_model.py b/src/baskerville/models/anomaly_model.py new file mode 100644 index 00000000..b0f36977 --- /dev/null +++ b/src/baskerville/models/anomaly_model.py @@ -0,0 +1,162 @@ +from pyspark.ml.feature import StandardScaler, StandardScalerModel, StringIndexer, StringIndexerModel +from pyspark.ml.linalg import Vectors, VectorUDT +from pyspark.sql import functions as F +from pyspark.sql.functions import array + +from baskerville.models.model_interface import ModelInterface +from baskerville.spark.helpers import map_to_array, StorageLevelFactory +from baskerville.spark.udfs import udf_to_dense_vector, udf_add_to_dense_vector +from pyspark_iforest.ml.iforest import IForest, IForestModel +import os +import numpy as np + +from baskerville.util.file_manager import FileManager + + +class AnomalyModel(ModelInterface): + + def __init__(self, feature_map_column='features', + features=None, + categorical_features=[], + prediction_column="prediction", + threshold=0.5, + score_column="score", + num_trees=100, max_samples=1.0, max_features=1.0, max_depth=10, + contamination=0.1, bootstrap=False, approximate_quantile_relative_error=0., + seed=777, + scaler_with_mean=False, scaler_with_std=True, + storage_level='OFF_HEAP'): + super().__init__() + self.prediction_column = prediction_column + self.score_column = score_column + self.num_trees = num_trees + self.max_samples = max_samples + self.max_features = max_features + self.max_depth = max_depth + self.contamination = contamination + self.bootstrap = bootstrap + self.approximate_quantile_relative_error = approximate_quantile_relative_error + self.seed = seed + self.scaler_with_mean = scaler_with_mean + self.scaler_with_std = scaler_with_std + self.features = features + self.categorical_features = categorical_features + self.feature_map_column = feature_map_column + self.storage_level = storage_level + + self.scaler_model = None + self.iforest_model = None + self.threshold = threshold + self.indexes = None + self.features_values_column = 'features_values' + self.features_values_scaled = 'features_values_scaled' + + def build_features_vectors(self, df): + res = map_to_array( + df, + map_col=self.feature_map_column, + array_col=self.features_values_column, + map_keys=self.features + ).persist(StorageLevelFactory.get_storage_level(self.storage_level)) + df.unpersist() + + return res.withColumn( + self.features_values_column, + udf_to_dense_vector(self.features_values_column) + ) + + def _create_indexes(self, df): + 
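+ """
+ Fit one StringIndexer per categorical feature: unseen labels are kept
+ ('keep') and categories are ordered alphabetically, so the fitted index
+ models stay stable between training and prediction.
+ """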
self.indexes = [] + for c in self.categorical_features: + indexer = StringIndexer(inputCol=c, outputCol=f'{c}_index') \ + .setHandleInvalid('keep') \ + .setStringOrderType('alphabetAsc') + self.indexes.append(indexer.fit(df)) + + def _add_categorical_features(self, df, feature_column): + index_columns = [] + for index_model in self.indexes: + df = index_model.transform(df) + index_columns.append(index_model.getOutputCol()) + + df = df.withColumn('features_all', udf_add_to_dense_vector(feature_column, array(*index_columns))) \ + .drop(*index_columns, feature_column) \ + .withColumnRenamed('features_all', feature_column) + return df + + def train(self, df): + df = self.build_features_vectors(df) + + scaler = StandardScaler() + scaler.setInputCol(self.features_values_column) + scaler.setOutputCol(self.features_values_scaled) + scaler.setWithMean(self.scaler_with_mean) + scaler.setWithStd(self.scaler_with_std) + self.scaler_model = scaler.fit(df) + df = self.scaler_model.transform(df).persist( + StorageLevelFactory.get_storage_level(self.storage_level) + ) + if len(self.categorical_features): + self._create_indexes(df) + self._add_categorical_features(df, self.features_values_scaled) + + iforest = IForest( + featuresCol=self.features_values_scaled, + predictionCol=self.prediction_column, + # anomalyScore=self.score_column, + numTrees=self.num_trees, + maxSamples=self.max_samples, + maxFeatures=self.max_features, + maxDepth=self.max_depth, + contamination=self.contamination, + bootstrap=self.bootstrap, + approxQuantileRelativeError=self.approximate_quantile_relative_error, + numCategoricalFeatures=len(self.categorical_features) + ) + iforest.setSeed(self.seed) + params = {'threshold': self.threshold} + self.iforest_model = iforest.fit(df, params) + df.unpersist() + + def predict(self, df): + df = self.build_features_vectors(df) + df = self.scaler_model.transform(df) + if len(self.categorical_features): + df = self._add_categorical_features(df, self.features_values_scaled) + df = self.iforest_model.transform(df) + df = df.withColumnRenamed('anomalyScore', self.score_column) + return df + + def _get_params_path(self, path): + return os.path.join(path, 'params.json') + + def _get_iforest_path(self, path): + return os.path.join(path, 'iforest') + + def _get_scaler_path(self, path): + return os.path.join(path, 'scaler') + + def _get_index_path(self, path, feature): + return os.path.join(path, 'indexes', feature) + + def save(self, path, spark_session=None): + file_manager = FileManager(path, spark_session) + file_manager.save_to_file(self.get_params(), self._get_params_path(path), format='json') + self.iforest_model.write().overwrite().save(self._get_iforest_path(path)) + self.scaler_model.write().overwrite().save(self._get_scaler_path(path)) + + if len(self.categorical_features): + for feature, index in zip(self.categorical_features, self.indexes): + index.write().overwrite().save(self._get_index_path(path, feature)) + + def load(self, path, spark_session=None): + self.iforest_model = IForestModel.load(self._get_iforest_path(path)) + self.scaler_model = StandardScalerModel.load(self._get_scaler_path(path)) + + file_manager = FileManager(path, spark_session) + params = file_manager.load_from_file(self._get_params_path(path), format='json') + self.set_params(**params) + + self.indexes = [] + for feature in self.categorical_features: + self.indexes.append(StringIndexerModel.load(self._get_index_path(path, feature))) diff --git a/src/baskerville/models/anomaly_model_sklearn.py 
b/src/baskerville/models/anomaly_model_sklearn.py new file mode 100644 index 00000000..51f9fd1f --- /dev/null +++ b/src/baskerville/models/anomaly_model_sklearn.py @@ -0,0 +1,179 @@ +from baskerville.models.model_interface import ModelInterface +from baskerville.spark import get_spark_session +import os + +from sklearn.preprocessing import StandardScaler +from sklearn.ensemble import IsolationForest +import pandas as pd +from typing import NamedTuple +from pyspark.sql import functions as F +from pyspark.sql import types as T + +from baskerville.util.file_manager import FileManager + +prediction_schema = T.StructType([ + T.StructField("prediction", T.FloatType(), False), + T.StructField("score", T.FloatType(), True) +]) + + +class AnomalyDetector(NamedTuple): + model: any + features: list + scaler: any + threshold: float + features_col: str = 'features' + prediction_col: str = 'prediction' + score_col: str = 'score' + + +def extract_features_in_order(feature_dict, model_features): + """ + Returns the model features in the order the model requires them. + """ + return [feature_dict[feature] for feature in model_features] + + +class AnomalyModelSklearn(ModelInterface): + + def __init__(self, feature_map_column='features', features=None, prediction_column="prediction", + score_column="score", + num_trees=100, max_samples="auto", max_features=1.0, n_jobs=1, verbose=10, + contamination=0.1, bootstrap=False, + seed=777, + scaler_with_mean=False, scaler_with_std=True): + super().__init__() + self.prediction_column = prediction_column + self.score_column = score_column + self.num_trees = num_trees + self.max_samples = max_samples + self.max_features = max_features + self.contamination = contamination + self.bootstrap = bootstrap + self.seed = seed + self.n_jobs = n_jobs + self.verbose = verbose + + self.scaler_with_mean = scaler_with_mean + self.scaler_with_std = scaler_with_std + + self.scaler_model = None + self.iforest_model = None + + self.features = features + self.feature_map_column = feature_map_column + self.anomaly_detector_broadcast = None + + def train(self, df): + df = df.toPandas() + features = self.features + if not features or len(features) == 0: + self.features = [*df[self.feature_map_column][0].keys()] + + df[self.features] = df[self.feature_map_column].apply(pd.Series) + df.drop(self.feature_map_column, axis=1, inplace=True) + + self.scaler_model = StandardScaler(with_mean=self.scaler_with_mean, with_std=self.scaler_with_std) + x_train = self.scaler_model.fit_transform( + df[self.features].values + ) + + self.iforest_model = IsolationForest( + n_estimators=self.num_trees, + max_samples=self.max_samples, + contamination=self.contamination, + max_features=self.max_features, + bootstrap=self.bootstrap, + n_jobs=self.n_jobs, + random_state=self.seed, + verbose=self.verbose + ) + self.iforest_model.fit(x_train) + + def predict(self, df): + from baskerville.models.metrics.helpers import CLIENT_PREDICTION_ACCUMULATOR + global CPA, CRSC + + CPA = CLIENT_PREDICTION_ACCUMULATOR + CRSC = CLIENT_PREDICTION_ACCUMULATOR + + def predict_dict(target, dict_features, update_metrics=False): + """ + Scale the feature values and use the model to predict + :param dict[str, float] dict_features: the feature dictionary + :param update_metrics: + :return: 0 if normal, 1 if abnormal, -1 if something went wrong + """ + global CPA, CRSC + import json + prediction = 0, 0. 
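+ # the feature map may arrive as a JSON string; it is decoded, ordered,
+ # scaled with the broadcast scaler and scored via decision_function below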
+ score = None + + if isinstance(dict_features, str): + dict_features = json.loads(dict_features) + try: + x = dict_features + detector = self.anomaly_detector_broadcast.value + x = [extract_features_in_order(x, detector.features)] + x = detector.scaler.transform(x) + y = 0.5 - detector.model.decision_function(x) + threshold = detector.threshold if detector.threshold > 0 \ + else 0.5 - detector.model.threshold_ + prediction = float((y > threshold)[0]) + score = float(y[0]) + # because of net.razorvine.pickle.PickleException: expected + # zero arguments for construction of ClassDict + # (for numpy.core.multiarray._reconstruct): + + except ValueError: + import traceback + traceback.print_exc() + print('Cannot predict:', dict_features) + + if update_metrics: + CRSC += {target: 1} + CPA += {target: prediction} + + return prediction, score + + udf_predict_dict = F.udf(predict_dict, prediction_schema) + + df = df.withColumn( + 'y', + udf_predict_dict( + 'target', + 'features', + F.lit(False) # todo + ) + ) + + df = df.withColumn( + self.prediction_column, F.col('y.prediction') + ).withColumn( + self.score_column, F.col('y.score') + ).drop('y') + return df + + def save(self, path, spark_session=None): + file_manager = FileManager(path, spark_session) + file_manager.save_to_file(self.get_params(), os.path.join(path, 'params.json'), format='json') + file_manager.save_to_file(self.iforest_model, os.path.join(path, 'iforest.pickle'), format='pickle') + file_manager.save_to_file(self.scaler_model, os.path.join(path, 'scaler.pickle'), format='pickle') + + def load(self, path, spark_session=None): + file_manager = FileManager(path, spark_session) + params = file_manager.load_from_file(os.path.join(path, 'params.json'), format='json') + self.set_params(**params) + self.iforest_model = file_manager.load_from_file(os.path.join(path, 'iforest.pickle'), format='pickle') + self.scaler_model = file_manager.load_from_file(os.path.join(path, 'scaler.pickle'), format='pickle') + + anomaly_detector = AnomalyDetector( + model=self.iforest_model, + features=self.features, + scaler=self.scaler_model, + threshold=0.0 + ) + + self.anomaly_detector_broadcast = get_spark_session().sparkContext.broadcast( + anomaly_detector + ) diff --git a/src/baskerville/models/base.py b/src/baskerville/models/base.py index f87589af..1e751833 100644 --- a/src/baskerville/models/base.py +++ b/src/baskerville/models/base.py @@ -1,7 +1,6 @@ import abc from baskerville.db import get_jdbc_url -from baskerville.models.feature_manager import FeatureManager from baskerville.util.helpers import get_logger @@ -72,35 +71,3 @@ def initialize(self): @abc.abstractmethod def finish_up(self): pass - - -class TrainingPipelineBase(PipelineBase, metaclass=abc.ABCMeta): - def __init__(self, db_conf, engine_conf, spark_conf, clean_up=True): - super().__init__(db_conf, engine_conf, clean_up) - self.data = None - self.training_conf = self.engine_conf.training - self.spark_conf = spark_conf - self.feature_manager = FeatureManager(self.engine_conf) - - @abc.abstractmethod - def get_data(self): - pass - - @abc.abstractmethod - def train(self): - pass - - @abc.abstractmethod - def test(self): - pass - - @abc.abstractmethod - def evaluate(self): - pass - - @abc.abstractmethod - def save(self, *args, **kwargs): - pass - - def run(self, *args, **kwargs): - super().run() diff --git a/src/baskerville/models/base_spark.py b/src/baskerville/models/base_spark.py index 17815ac5..5c022ef3 100644 --- a/src/baskerville/models/base_spark.py +++ 
b/src/baskerville/models/base_spark.py @@ -5,12 +5,9 @@ from baskerville.models.base import PipelineBase from baskerville.models.feature_manager import FeatureManager -from baskerville.models.model_manager import ModelManager -from baskerville.spark.helpers import save_df_to_table, map_to_array, \ - reset_spark_storage +from baskerville.spark.helpers import save_df_to_table, reset_spark_storage, set_unknown_prediction from baskerville.spark.schemas import get_cache_schema -from baskerville.spark.udfs import to_dense_vector_udf -from baskerville.util.helpers import TimeBucket, FOLDER_CACHE +from baskerville.util.helpers import TimeBucket, FOLDER_CACHE, instantiate_from_str from pyspark.sql import types as T, DataFrame from baskerville.spark import get_or_create_spark_session @@ -94,8 +91,9 @@ def __init__(self, self.remaining_steps = list(self.step_to_action.keys()) self.time_bucket = TimeBucket(self.engine_conf.time_bucket) - self.model_manager = ModelManager(self.db_conf, self.engine_conf) self.feature_manager = FeatureManager(self.engine_conf) + self.model_index = None + self.model = None def load_test(self): """ @@ -151,15 +149,7 @@ def initialize(self): # initialize spark session self.spark = self.instantiate_spark_session() - - # set the model and feature related stuff - self.model_manager.initialize(self.spark, self.tools) - self.feature_manager.initialize(self.model_manager) - self.model_manager.can_predict = self.feature_manager.\ - feature_config_is_valid() - - self._can_predict = self.feature_manager.feature_config_is_valid() \ - and self.model_manager.ml_model + self.feature_manager.initialize() self.drop_if_missing_filter = self.data_parser.drop_if_missing_filter() # set up cache @@ -175,6 +165,13 @@ def initialize(self): self.feature_manager.update_feature_cols ).difference(RequestSet.columns) + if self.engine_conf.model_id: + self.model_index = self.tools.get_ml_model_from_db(self.engine_conf.model_id) + self.model = instantiate_from_str(self.model_index.algorithm) + self.model.load(bytes.decode(self.model_index.classifier, 'utf8'), self.spark) + else: + self.model = None + self._is_initialized = True def get_columns_to_filter_by(self): @@ -277,9 +274,9 @@ def get_post_group_by_calculations(self): ).otherwise(F.lit(0)) } - if self.model_manager.ml_model: + if self.model_index: post_group_by_columns['model_version'] = F.lit( - self.model_manager.ml_model.id + self.model_index.id ) # todo: what if a feature defines a column name that already exists? @@ -801,13 +798,6 @@ def feature_update(self): ) ) )) - self.logs_df = map_to_array( - self.logs_df, - 'features', - 'vectorized_features', - self.feature_manager.active_feature_names - ) - self.remove_feature_columns() # older way with a udf: # self.logs_df = self.logs_df.withColumn( @@ -872,31 +862,21 @@ def cross_reference(self): F.col('cross_reference.id_attribute')).otherwise(None) ) - def predict_sparkml(self): - """ - Predict using the Spark ML implementation of the algorithms - :return: - """ - self.logs_df = self.logs_df.withColumn( - 'vectorized_features', - to_dense_vector_udf('vectorized_features') - ) - - self.logs_df = self.model_manager.ml_model.scaler_model.transform( - self.logs_df - ) - self.logs_df = self.model_manager.ml_model.classifier_model.transform( - self.logs_df - ) - self.logs_df = self.logs_df.withColumnRenamed('anomalyScore', 'score') - def predict(self): """ Predict on the request_sets. Prediction on request_sets requires feature averaging where there is an existing request_set. 
` :return: None """ - self.logs_df = self.model_manager.predict(self.logs_df) + if self.model: + self.logs_df = self.model.predict(self.logs_df) + else: + self.logs_df = set_unknown_prediction(self.logs_df).withColumn( + 'prediction', F.col('prediction').cast(T.IntegerType()) + ).withColumn( + 'score', F.col('score').cast(T.FloatType()) + ).withColumn( + 'threshold', F.col('threshold').cast(T.FloatType())) def save_df_to_table( self, df, table_name, json_cols=('features',), mode='append' diff --git a/src/baskerville/models/config.py b/src/baskerville/models/config.py index dc62965a..3ef572b8 100644 --- a/src/baskerville/models/config.py +++ b/src/baskerville/models/config.py @@ -5,7 +5,7 @@ from functools import wraps import dateutil -from baskerville.util.enums import AlgorithmEnum, ScalerEnum +from baskerville.util.enums import ModelEnum from baskerville.util.helpers import get_logger, get_default_data_path, \ SerializableMixin from dateutil.tz import tzutc @@ -439,41 +439,26 @@ class TrainingConfig(Config): - training days - from - to date - other filters, like hosts - Classifier Parameters: (optional) - - max_samples - - contamination - - n_estimators - Scaler Parameters (optional) + Model Parameters: (optional) + - """ classifier: str scaler: str - data_parameters: dict - classifier_parameters: dict - scaler_parameters: dict - model_options: dict - host_feature: list = [] + model_parameters: dict n_jobs: int = -1 - use_host_feature: bool = False max_number_of_records_to_read_from_db: int = None def __init__(self, config, parent=None): super(TrainingConfig, self).__init__(config, parent) - self.allowed_algorithms = list(vars(AlgorithmEnum)['_value2member_map_'].keys()) - self.allowed_scalers = list(vars(ScalerEnum)['_value2member_map_'].keys()) + self.allowed_models = list(vars(ModelEnum)['_value2member_map_'].keys()) def validate(self): logger.debug('Validating TrainingConfig...') - if self.classifier: - if self.classifier not in self.allowed_algorithms: + if self.model: + if self.model not in self.allowed_models: raise ValueError( - f'{self.classifier} is not in allowed algorithms: ' - f'{",".join(self.allowed_algorithms)}' - ) - if self.scaler: - if self.scaler not in self.allowed_scalers: - raise ValueError( - f'{self.scaler} is not in allowed algorithms: ' - f'{",".join(self.allowed_scalers)}' + f'{self.model} is not in allowed models: ' + f'{",".join(self.allowed_models)}' ) if self.data_parameters: @@ -484,11 +469,8 @@ def validate(self): f'Either training days or from-to date should be specified' ) - if not self.classifier_parameters: - self.classifier_parameters = {} - - if not self.scaler_parameters: - self.scaler_parameters = {} + if not self.model_parameters: + self.model_parameters = {} self._is_validated = True return self diff --git a/src/baskerville/models/feature_manager.py b/src/baskerville/models/feature_manager.py index bd8c0d87..e801bc64 100644 --- a/src/baskerville/models/feature_manager.py +++ b/src/baskerville/models/feature_manager.py @@ -8,12 +8,10 @@ class FeatureManager(object): def __init__( self, - engine_conf, - model_manager=None, + engine_conf ): self.all_features = engine_conf.all_features self.extra_features = engine_conf.extra_features - self.model_manager = model_manager self.active_features = None self.active_feature_names = None self.updateable_active_features = None @@ -29,8 +27,7 @@ def __init__( output_file=engine_conf.logpath ) - def initialize(self, model_manager=None): - self.model_manager = model_manager + def initialize(self,): 
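+ """
+ Resolve the active features from the engine configuration and build the
+ derived feature-name and column lists used by the pipeline.
+ """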
self.active_features = self.get_active_features() self.active_feature_names = self.get_active_feature_names() self.updateable_active_features = self.get_updateable_active_features() @@ -46,8 +43,6 @@ def get_active_features(self): :return: """ feature_list = self.extra_features - if self.model_manager and self.model_manager.ml_model: - feature_list += self.model_manager.ml_model.features if not feature_list: raise RuntimeError('No features specified! Either input model ' 'or specify features in config.') @@ -120,22 +115,10 @@ def feature_config_is_valid(self) -> bool: :return: """ checks = [] - if self.model_manager.ml_model: - checks.append(self.features_subset_of_model_features()) - checks.append(self.feature_dependencies_met()) + checks.append(self.feature_dependencies_met()) return False not in checks - def features_subset_of_model_features(self) -> bool: - """ - Checks whether the features required for the model are a subset of the - active feature names - :return: - """ - return set(self.model_manager.ml_model.features).issubset( - set(self.active_feature_names) - ) - def feature_dependencies_met(self) -> bool: """ Checks that the features defined as dependencies are included in the diff --git a/src/baskerville/models/model_interface.py b/src/baskerville/models/model_interface.py new file mode 100644 index 00000000..1772e19e --- /dev/null +++ b/src/baskerville/models/model_interface.py @@ -0,0 +1,39 @@ +import inspect + + +class ModelInterface(object): + def __init__(self): + super().__init__() + + def get_param_names(self): + return list(inspect.signature(self.__init__).parameters.keys()) + + def set_params(self, **params): + param_names = self.get_param_names() + for key, value in params.items(): + if key not in param_names: + raise RuntimeError(f'Class {self.__class__.__name__} does not have {key} attribute') + setattr(self, key, value) + + def get_params(self): + params = {} + for name in self.get_param_names(): + params[name] = getattr(self, name) + return params + + def _get_class_path(self): + return f'{self.__class__.__module__}.{self.__class__.__name__}' + + def train(self, df): + pass + + def predict(self, df): + pass + + def save(self, path, spark_session=None): + pass + + def load(self, path, spark_session=None): + pass + + diff --git a/src/baskerville/models/model_manager.py b/src/baskerville/models/model_manager.py deleted file mode 100644 index 17b1a3fb..00000000 --- a/src/baskerville/models/model_manager.py +++ /dev/null @@ -1,86 +0,0 @@ -from baskerville.spark.helpers import set_unknown_prediction -from baskerville.util.enums import ANOMALY_MODEL_MANAGER -from baskerville.util.helpers import instantiate_from_str, get_logger -from pyspark.sql import functions as F, types as T - - -class ModelManager(object): - def __init__(self, db_conf, engine_conf, spark_session=None, db_tools=None): - self.db_conf = db_conf - self.engine_conf = engine_conf - self.ml_model = None - self.can_predict = True - self.db_tools = db_tools - self.spark_session = spark_session - self.anomaly_model_manager = None - self.logger = get_logger( - self.__class__.__name__, - logging_level=self.engine_conf.log_level, - output_file=self.engine_conf.logpath - ) - - def initialize(self, spark_session, db_tools): - self.spark_session = spark_session - self.db_tools = db_tools - self.load() - - def load(self): - """ - Loads the model from db if model_id is defined, - :return: - """ - if self.engine_conf.model_id: - if self.engine_conf.model_id == -1: - self.logger.debug('Loading latest Model from db') - 
# get the latest models from db - self.ml_model = self.db_tools.get_latest_ml_model_from_db() - elif self.engine_conf.model_id > 0: - self.ml_model = self.db_tools.get_ml_model_from_db( - self.engine_conf.model_id - ) - elif self.engine_conf.model_path: - self.ml_model = self.db_tools.get_ml_model_from_file( - self.engine_conf.model_path - ) - if self.ml_model: - self.logger.debug(f'Loaded Model with id: {self.ml_model.id}') - self.anomaly_model_manager = instantiate_from_str( - ANOMALY_MODEL_MANAGER[self.ml_model.algorithm] - ) - self.anomaly_model_manager.load(self.ml_model) - else: - self.logger.info('No Model loaded.') - - return self.ml_model - - def predict(self, df): - """ - Use the anomaly model manager to predict on the dataframe or set - the default values for prediction and score - :param df: - :return: - """ - if self.can_predict and self.ml_model: - df = self.anomaly_model_manager.predict(df) - else: - if not self.can_predict: - self.logger.warn( - 'Active features do not match model features, ' - 'skipping prediction' - ) - elif not self.ml_model: - self.logger.warn( - 'No ml model specified, ' - 'skipping prediction' - ) - df = set_unknown_prediction( - df, columns=('score', 'prediction', 'threshold') - ).withColumn( - 'prediction', F.col('prediction').cast(T.IntegerType()) - ).withColumn( - 'score', F.col('score').cast(T.FloatType()) - ).withColumn( - 'threshold', F.col('threshold').cast(T.FloatType()) - ) - - return df diff --git a/src/baskerville/models/pipeline_factory.py b/src/baskerville/models/pipeline_factory.py index 5e83f40c..8f46fad2 100644 --- a/src/baskerville/models/pipeline_factory.py +++ b/src/baskerville/models/pipeline_factory.py @@ -1,8 +1,6 @@ -from baskerville.models.pipeline_training import TrainingPipeline, \ - TrainingSparkMLPipeline -from baskerville.models.pipelines import RawLogPipeline, ElasticsearchPipeline, \ - KafkaPipeline -from baskerville.util.enums import RunType, SPARK_ML_MODELS, SKLEARN_MODELS +from baskerville.models.pipeline_training import TrainingPipeline +from baskerville.models.pipelines import RawLogPipeline, ElasticsearchPipeline, KafkaPipeline +from baskerville.util.enums import RunType class PipelineFactory(object): @@ -28,19 +26,11 @@ def get_pipeline(self, run_type, config): config.spark ) elif run_type == RunType.training: - if config.engine.training.classifier in SPARK_ML_MODELS: - return TrainingSparkMLPipeline( - config.database, - config.engine, - config.spark - ) - elif config.engine.training.classifier in SKLEARN_MODELS: - - return TrainingPipeline( - config.database, - config.engine, - config.spark - ) + return TrainingPipeline( + config.database, + config.engine, + config.spark + ) raise RuntimeError( 'Cannot set up a pipeline with the current configuration.' 
) diff --git a/src/baskerville/models/pipeline_training.py b/src/baskerville/models/pipeline_training.py index eddea186..804993a1 100644 --- a/src/baskerville/models/pipeline_training.py +++ b/src/baskerville/models/pipeline_training.py @@ -1,177 +1,29 @@ -import json import os from collections import OrderedDict - +import json import pyspark -import _pickle as cPickle -from baskerville.models.config import EngineConfig, \ - DatabaseConfig, SparkConfig +from baskerville.models.config import EngineConfig, DatabaseConfig, SparkConfig from baskerville.spark import get_or_create_spark_session -from baskerville.spark.helpers import save_df_to_table, map_to_array, \ - reset_spark_storage -from baskerville.spark.schemas import get_models_schema -from baskerville.spark.udfs import to_dense_vector_udf +from baskerville.spark.helpers import reset_spark_storage from baskerville.util.enums import Step -from baskerville.util.helpers import instantiate_from_str, get_model_path, \ - get_scaler_load_path, get_classifier_load_path, RANDOM_SEED -from dateutil.tz import tzutc +from baskerville.util.helpers import instantiate_from_str, get_model_path +from baskerville.db.models import Model -from baskerville.models.base import TrainingPipelineBase +from baskerville.models.base import PipelineBase import datetime -import numpy as np -import pandas as pd +from dateutil.tz import tzutc from baskerville.util.baskerville_tools import BaskervilleDBTools -from pyspark.ml.feature import VectorAssembler -from sklearn.preprocessing import StandardScaler # todo -from sklearn.ensemble import IsolationForest # todo -from baskerville.db.models import Model -from sklearn.preprocessing import LabelBinarizer -import pyspark.sql.functions as F - - -class TrainingPipeline(TrainingPipelineBase): - - def __init__( - self, - db_conf, - engine_conf, - clean_up=True - ): - super(TrainingPipeline, self).__init__( - db_conf, engine_conf, clean_up - ) - self.step_to_action = OrderedDict( - zip([ - Step.get_data, - Step.train, - Step.test, - Step.evaluate, - Step.save, - ], [ - self.get_data, - self.train, - self.test, - self.evaluate, - self.save, - ])) - self.model = None - self.scaler = None - self.host_encoder = None - self.db_tools = None - self.db_conf = db_conf - self.remaining_steps = list(self.step_to_action.keys()) - - def finish_up(self): - if self.db_tools: - self.db_tools.disconnect_from_db() - - def initialize(self): - conf = self.db_conf - conf.maintenance = None - self.active_features = self.engine_conf.extra_features - self.db_tools = BaskervilleDBTools(conf) - self.db_tools.connect_to_db() - - def get_date_filter(self): - training_days = self.training_conf.data_parameters.get("training_days") - return f'created_at > CURRENT_DATE - INTERVAL \'{training_days} days\'' - - def get_model_parameters(self): - return dict({ - 'verbose': 10, - 'n_jobs': self.training_conf.n_jobs, - }, **self.training_conf.classifier_parameters or {}) - - def get_query(self): - limit = '' - if self.engine_conf.training.max_number_of_records_to_read_from_db: - limit = f'limit {self.engine_conf.training.max_number_of_records_to_read_from_db}' - return F'select target, features from request_sets where {self.get_date_filter()} {limit}' - - def get_data(self): - self.data = pd.read_sql(self.get_query(), self.db_tools.engine) - self.logger.info(f'{len(self.data)} records retrieved.') - self.logger.info(f'Unwrapping features...') - if len(self.data) > 0: - self.data[[*self.data['features'][0].keys()]] = self.data[ - 'features' - ].apply(pd.Series) - 
self.data.drop('features', axis=1, inplace=True) - self.logger.info(f'Unwrapping features complete.') - - return self.data - - def train(self): - self.scaler = StandardScaler() - x_train = self.scaler.fit_transform( - self.data[self.active_features].values - ) - - if self.engine_conf.training.use_host_feature > 0: - self.host_encoder = LabelBinarizer() - self.host_encoder.fit(self.training_conf.host_feature) - self.logger.info('Host feature one-hot encoding...') - host_features = self.host_encoder.transform(self.data['target']) - self.logger.info('Host feature one-hot encoding concatenating...') - x_train = np.concatenate((x_train, host_features), axis=1) - - self.model = IsolationForest() - self.model.set_params(**self.get_model_parameters()) - self.logger.info('Model.fit()...') - self.model.fit(x_train) - self.logger.info('Model.fit() done.') - return self.model, self.scaler, self.host_encoder - - def test(self): - """ - # todo - :return: - """ - self.logger.debug('Testing: Coming soon-ish...') - - def evaluate(self): - """ - # todo - :return: - """ - self.logger.debug('Evaluating: Coming soon-ish...') - def save(self, recall=0, precision=0, f1_score=0): - - model = Model() - model.created_at = datetime.datetime.now(tz=tzutc()) - model.features = self.active_features - model.algorithm = self.training_conf.classifier - model.scaler_type = self.training_conf.scaler - model.parameters = str(self.get_model_parameters()) - model.recall = float(recall) - model.precision = float(precision) - model.f1_score = float(f1_score) - model.classifier = cPickle.dumps(self.model) - model.scaler = cPickle.dumps(self.scaler) - model.n_training = 0 - model.n_testing = 0 - model.host_encoder = cPickle.dumps(self.host_encoder) - model_dict = {} - for k in list(model.__dict__)[1:]: - model_dict[k] = getattr(model, k) - - # save to db - self.db_tools.session.add(model) - self.db_tools.session.commit() +import pyspark.sql.functions as F -class TrainingSparkMLPipeline(TrainingPipelineBase): +class TrainingPipeline(PipelineBase): """ - Training Pipeline for the Spark ML Estimators - todo: use a pipeline for the steps (scaling etc) + Training Pipeline """ - classifier: pyspark.ml.wrapper.JavaEstimator - classifier_model: object - scaler: pyspark.ml.feature.StandardScaler - scaler_model: object + model: object evaluation_results: dict data: pyspark.sql.DataFrame spark: pyspark.sql.SparkSession @@ -183,14 +35,16 @@ def __init__( spark_conf: SparkConfig, clean_up: bool = True ): - super().__init__(db_conf, engine_conf, spark_conf, clean_up) + super().__init__(db_conf, engine_conf, clean_up) + self.data = None + self.training_conf = self.engine_conf.training + self.spark_conf = spark_conf + self.logger.debug(f'{self.__class__.__name__} initiated') self.columns_to_keep = [ 'ip', 'target', 'created_at', 'features', ] - self.model_path = get_model_path(self.engine_conf.storage_path, self.__class__.__name__) - self.step_to_action = OrderedDict( zip([ Step.get_data, @@ -207,7 +61,7 @@ def __init__( ])) self.training_row_n = 0 self.testing_row_n = 0 - self.fit_params = {} + self.db_tools = None self.conn_properties = { 'user': self.db_conf.user, 'password': self.db_conf.password, @@ -215,30 +69,25 @@ def __init__( } self.remaining_steps = list(self.step_to_action.keys()) - if self.training_conf.threshold: - self.fit_params = {'threshold': self.training_conf.threshold} def initialize(self): """ Get a spark session - Create the classifier instance - Create the scaler instance + Create the model instance Set the appropriate 
parameters as set up in configuration :return: """ self.spark = get_or_create_spark_session(self.spark_conf) - self.feature_manager.initialize(None) - self.classifier = instantiate_from_str(self.training_conf.classifier) - self.scaler = instantiate_from_str(self.training_conf.scaler) + self.model = instantiate_from_str(self.training_conf.model) + self.model.set_params(**self.engine_conf.training.model_parameters) - self.classifier.setParams( - **self.engine_conf.training.classifier_parameters - ) - self.scaler.setParams( - **self.engine_conf.training.scaler_parameters - ) - self.classifier.setSeed(RANDOM_SEED) - # self.scaler.setSeed(RANDOM_SEED) + conf = self.db_conf + conf.maintenance = None + self.db_tools = BaskervilleDBTools(conf) + self.db_tools.connect_to_db() + + def run(self, *args, **kwargs): + super().run() def get_data(self): """ @@ -260,14 +109,7 @@ def get_data(self): # get the active feature names and transform the features to list self.active_features = json_schema.fieldNames() - data = map_to_array( - self.data, - 'features', - 'features', - self.active_features - ).persist(self.spark_conf.storage_level) - self.data.unpersist() - self.data = data + self.training_row_n = self.data.count() self.logger.debug(f'Loaded #{self.training_row_n} of request sets...') @@ -280,31 +122,10 @@ def train(self): # ) :return: None """ - # currently does not work with IForest: - # https://github.com/titicaca/spark-iforest/issues/24 - # assembler = VectorAssembler( - # inputCols=self.active_features[:2], - # outputCol="vectorized_features" - # ) - # self.data = assembler.transform(self.data) - - self.data = self.data.withColumnRenamed( - 'features', 'vectorized_features' - ).withColumn( - 'vectorized_features', - to_dense_vector_udf('vectorized_features') - ) - - self.scaler.setInputCol('vectorized_features') - self.scaler.setOutputCol('scaled_features') - self.classifier.setParams(featuresCol='scaled_features') - - self.scaler_model = self.scaler.fit(self.data) - self.data = self.scaler_model.transform(self.data).persist( - self.spark_conf.storage_level - ) - - self.classifier_model = self.classifier.fit(self.data, self.fit_params) + if not self.model.features: + self.model.features = self.active_features + self.model.train(self.data) + self.data.unpersist() def test(self): """ @@ -325,42 +146,20 @@ def save(self): Save the models on disc and add a baskerville.db.Model in the database :return: None """ - self.logger.info( - f'Saving model (classifier, scaler) in: {self.model_path}' - ) - self.classifier_model.write().overwrite().save( - get_classifier_load_path(self.model_path) - ) - self.scaler_model.write().overwrite().save( - get_scaler_load_path(self.model_path) - ) - data = [ - [ - self.active_features, - self.training_conf.classifier, - self.training_conf.scaler, - json.dumps(self.training_conf.to_dict()), - 0., - 0., - 0., - bytearray(self.model_path.encode('utf8')), - bytearray([]), - self.training_row_n, - self.testing_row_n, - float(self.training_conf.threshold) - ] - ] - self.db_conf.conn_str = self.db_url + model_path = get_model_path(self.engine_conf.storage_path, self.model.__class__.__name__) + self.model.save(path=model_path, spark_session=self.spark) - model_df = self.spark.createDataFrame(data, schema=get_models_schema()) - save_df_to_table( - model_df, - Model.__tablename__, - self.db_conf.__dict__, - self.spark_conf.storage_level - ) + db_model = Model() + db_model.created_at = datetime.datetime.now(tz=tzutc()) + db_model.algorithm = self.training_conf.model + 
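+ # hyper-parameters are stored as JSON; the classifier column holds only
+ # the path where self.model.save() wrote its artifacts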
db_model.parameters = json.dumps(self.model.get_params()) + db_model.classifier = bytearray(model_path.encode('utf8')) + + # save to db + self.db_tools.session.add(db_model) + self.db_tools.session.commit() - def get_bounds(self, from_date, to_date=None, field='created_at'): + def get_bounds(self, from_date, to_date=None, field='stop'): """ Get the lower and upper limit :param str from_date: lower date bound @@ -382,7 +181,7 @@ def get_bounds(self, from_date, to_date=None, field='created_at'): properties=self.conn_properties ) - def load(self, extra_filters=None) -> pyspark.sql.DataFrame: + def load(self) -> pyspark.sql.DataFrame: """ Loads the request_sets already in the database :return: @@ -409,28 +208,22 @@ def load(self, extra_filters=None) -> pyspark.sql.DataFrame: f'Fetching {bounds.rows} rows. ' f'min: {bounds.min_id} max: {bounds.max_id}' ) - if not bounds.min_id: - raise RuntimeError( - 'No data to train. Please, check your training configuration' - ) q = f'(select id, {",".join(self.columns_to_keep)} ' \ f'from request_sets where id >= {bounds.min_id} ' \ - f'and id <= {bounds.max_id} and created_at >= \'{from_date}\' ' \ - f'and created_at <=\'{to_date}\') as request_sets' + f'and id <= {bounds.max_id} and stop >= \'{from_date}\' ' \ + f'and stop <=\'{to_date}\') as request_sets' - if not extra_filters: - return self.spark.read.jdbc( - url=self.db_url, - table=q, - numPartitions=int(self.spark.conf.get( - 'spark.sql.shuffle.partitions' - )) or os.cpu_count()*2, - column='id', - lowerBound=bounds.min_id, - upperBound=bounds.max_id + 1, - properties=self.conn_properties - ) - raise NotImplementedError(f'No implementation for "extra_filters"') + return self.spark.read.jdbc( + url=self.db_url, + table=q, + numPartitions=int(self.spark.conf.get( + 'spark.sql.shuffle.partitions' + )) or os.cpu_count()*2, + column='id', + lowerBound=bounds.min_id, + upperBound=bounds.max_id + 1, + properties=self.conn_properties + ) def finish_up(self): """ @@ -438,3 +231,5 @@ def finish_up(self): :return: """ reset_spark_storage() + if self.db_tools: + self.db_tools.disconnect_from_db() diff --git a/src/baskerville/models/request_set_cache.py b/src/baskerville/models/request_set_cache.py index cece0c98..f3f566bd 100644 --- a/src/baskerville/models/request_set_cache.py +++ b/src/baskerville/models/request_set_cache.py @@ -1,12 +1,12 @@ import datetime import gc import os -import shutil from baskerville.spark import get_spark_session from baskerville.spark.helpers import StorageLevel from pyspark.sql import functions as F +from baskerville.util.file_manager import FileManager from baskerville.util.helpers import get_logger @@ -22,7 +22,7 @@ def __init__( session_getter=get_spark_session, group_by_fields=('target', 'ip'), format_='parquet', - path='' + path='request_set_cache' ): self.__cache = None self.__persistent_cache = None @@ -45,53 +45,15 @@ def __init__( self._count = 0 self._last_updated = datetime.datetime.utcnow() self._changed = False + self.file_manager = FileManager(path, self.session_getter()) - if path.startswith('hdfs://'): - # find the third occurrence of '/' - slash = path.find('/', path.find('/', path.find('/') + 1) + 1) - if slash == -1: - raise RuntimeError(f'HDFS path "{path}" must contain at least one folder. ' - f'For example "hdfs://xxx:8020/baskerville"). 
') - self.hdfs = path[:slash] - self.path = path[slash+1:] - sc = self.session_getter()._sc - self.hdfs_path = sc._gateway.jvm.org.apache.hadoop.fs.Path - self.hdfs_file_system = sc._gateway.jvm.org.apache.hadoop.fs.FileSystem.get( - sc._gateway.jvm.java.net.URI(self.hdfs), - sc._jsc.hadoopConfiguration()) - else: - self.hdfs = None - self.path = path - - self.file_name = os.path.join(self.path, f'{self.__class__.__name__}.{self.format_}') - self.temp_file_name = os.path.join(self.path, f'{self.__class__.__name__}temp.{self.format_}') - - if self.path_exists(self.file_name): - self.delete_path(self.file_name) - if self.path_exists(self.temp_file_name): - self.delete_path(self.temp_file_name) - - def path_exists(self, path): - if self.hdfs: - return self.hdfs_file_system.exists(self.hdfs_path(path)) - return os.path.exists(path) - - def delete_path(self, path): - if self.hdfs: - self.hdfs_file_system.delete(self.hdfs_path(path), True) - return - shutil.rmtree(path) - - def rename_path(self, source, destination): - if self.hdfs: - self.hdfs_file_system.rename(self.hdfs_path(source), self.hdfs_path(destination)) - return - os.rename(source, destination) - - def get_full_path(self, path): - if self.hdfs: - return os.path.join(self.hdfs, path) - return path + self.file_name = os.path.join(path, f'{self.__class__.__name__}.{self.format_}') + self.temp_file_name = os.path.join(path, f'{self.__class__.__name__}temp.{self.format_}') + + if self.file_manager.path_exists(self.file_name): + self.file_manager.delete_path(self.file_name) + if self.file_manager.path_exists(self.temp_file_name): + self.file_manager.delete_path(self.temp_file_name) @property def cache(self): @@ -321,11 +283,11 @@ def update_self( self.logger.debug(f'Source_df count = {source_df.count()}') # read the whole thing again - if self.path_exists(self.file_name): + if self.file_manager.path_exists(self.file_name): self.__persistent_cache = self.session_getter().read.format( self.format_ ).load( - self.get_full_path(self.file_name) + self.file_name ).persist(self.storage_level) # http://www.learnbymarketing.com/1100/pyspark-joins-by-example/ @@ -377,14 +339,14 @@ def update_self( # write back to parquet - different file/folder though # because self.parquet_name is already in use # rename temp to self.parquet_name - if self.path_exists(self.temp_file_name): - self.delete_path(self.temp_file_name) + if self.file_manager.path_exists(self.temp_file_name): + self.file_manager.delete_path(self.temp_file_name) self.__persistent_cache.write.mode( 'overwrite' ).format( self.format_ - ).save(self.get_full_path(self.temp_file_name)) + ).save(self.temp_file_name) self.logger.debug( f'# Number of rows in persistent cache: ' @@ -398,10 +360,10 @@ def update_self( self.empty_all() # rename temp to self.parquet_name - if self.path_exists(self.file_name): - self.delete_path(self.file_name) + if self.file_manager.path_exists(self.file_name): + self.file_manager.delete_path(self.file_name) - self.rename_path(self.temp_file_name, self.file_name) + self.file_manager.rename_path(self.temp_file_name, self.file_name) def refresh(self, update_date, hosts, extra_filters=None): df = self._load( diff --git a/src/baskerville/spark/helpers.py b/src/baskerville/spark/helpers.py index c1557f3c..cd0e5ee7 100644 --- a/src/baskerville/spark/helpers.py +++ b/src/baskerville/spark/helpers.py @@ -187,7 +187,7 @@ def save_model(model, path, mode='overwrite'): writer.save(path) -def set_unknown_prediction(df, columns=('prediction', 'score')): +def 
set_unknown_prediction(df, columns=('prediction', 'score', 'threshold')): """ Sets the preset unknown value for prediction and score :param pyspark.sql.Dataframe df: diff --git a/src/baskerville/spark/schemas.py b/src/baskerville/spark/schemas.py index 75f9a97c..c7296bde 100644 --- a/src/baskerville/spark/schemas.py +++ b/src/baskerville/spark/schemas.py @@ -13,23 +13,6 @@ ]) -def get_models_schema(): - return T.StructType([ - T.StructField("features", T.ArrayType(T.StringType()), True), - T.StructField("algorithm", T.StringType(), True), - T.StructField("scaler_type", T.StringType(), True), - T.StructField("parameters", T.StringType(), True), - T.StructField("recall", T.DoubleType(), True), - T.StructField("precision", T.DoubleType(), True), - T.StructField("f1_score", T.DoubleType(), True), - T.StructField("classifier", T.BinaryType(), True), - T.StructField("scaler", T.BinaryType(), True), - T.StructField("n_training", T.IntegerType(), True), - T.StructField("n_testing", T.IntegerType(), True), - T.StructField("threshold", T.DoubleType(), True) - ]) - - def get_cache_schema(): return T.StructType([ T.StructField("id", T.IntegerType(), False), diff --git a/src/baskerville/spark/udfs.py b/src/baskerville/spark/udfs.py index 7eb769c1..b7c4edcb 100644 --- a/src/baskerville/spark/udfs.py +++ b/src/baskerville/spark/udfs.py @@ -8,6 +8,7 @@ from pyspark.sql import functions as F from pyspark.sql import types as T from tzwhere import tzwhere +import numpy as np def normalize_host_name(host): @@ -251,4 +252,5 @@ def cross_reference_misp(ip, db_conf): update_features, T.MapType(T.StringType(), T.FloatType()) ) udf_bulk_update_request_sets = F.udf(bulk_update_request_sets, T.BooleanType()) -to_dense_vector_udf = F.udf(lambda l: Vectors.dense(l), VectorUDT()) +udf_to_dense_vector = F.udf(lambda l: Vectors.dense(l), VectorUDT()) +udf_add_to_dense_vector = F.udf(lambda features, arr: Vectors.dense(np.append(features, [v for v in arr])), VectorUDT()) diff --git a/src/baskerville/util/baskerville_tools.py b/src/baskerville/util/baskerville_tools.py index e4041423..8b7848f7 100644 --- a/src/baskerville/util/baskerville_tools.py +++ b/src/baskerville/util/baskerville_tools.py @@ -2,12 +2,7 @@ import pickle from baskerville.util.crypto import encrypt, decrypt import os -import pandas as pd -from sqlalchemy import func, and_ - from baskerville.db import set_up_db -# import models here to be registered to the db -from baskerville.db.models import * # noqa from baskerville.db.models import RequestSet, Runtime, Model # Utility class that holds commonly used functions. 
@@ -90,8 +85,11 @@ def get_request_sets(self): def get_latest_ml_model_from_db(self): return self.session.query(Model).order_by(Model.id.desc()).first() - def get_ml_model_from_db(self, version_id): - return self.session.query(Model).filter(Model.id == version_id).first() + def get_ml_model_from_db(self, model_id): + if model_id < 0: + return self.get_latest_ml_model_from_db() + + return self.session.query(Model).filter(Model.id == model_id).first() def get_ml_model_from_file(self, model_path): @@ -108,7 +106,7 @@ def get_ml_model_from_file(self, model_path): return model def get_ml_clf_from_file(self, model_dir, version_id): - path_to_model = f'{model_dir}/model_{version_id}.sav' + path_to_model = f'{model_dir}/model_version{version_id}.sav' if not os.path.lexists(path_to_model): clf = None else: diff --git a/src/baskerville/util/enums.py b/src/baskerville/util/enums.py index 95936d8f..f14e953e 100644 --- a/src/baskerville/util/enums.py +++ b/src/baskerville/util/enums.py @@ -117,30 +117,6 @@ class PartitionByEnum(BaseStrEnum): m = 'month' -class AlgorithmEnum(BaseStrEnum): - isolation_forest_sklearn = "sklearn.ensemble.IsolationForest" - isolation_forest_pyspark = "pyspark_iforest.ml.iforest.IForest" - - -class ScalerEnum(BaseStrEnum): - scaler_sklearn = "sklearn.preprocessing.StandardScaler" - scaler_pyspark = "pyspark.ml.feature.StandardScaler" - - -SPARK_ML_MODELS = { - AlgorithmEnum.isolation_forest_pyspark, - ScalerEnum.scaler_pyspark -} -SKLEARN_MODELS = { - AlgorithmEnum.isolation_forest_sklearn, - ScalerEnum.scaler_sklearn -} - -anomaly_detector_root = 'baskerville.models.anomaly_detector' - -ANOMALY_MODEL_MANAGER = { - AlgorithmEnum.isolation_forest_sklearn: f'{anomaly_detector_root}.' - f'ScikitAnomalyDetectorManager', - AlgorithmEnum.isolation_forest_pyspark: f'{anomaly_detector_root}.' - f'SparkAnomalyDetectorManager', -} \ No newline at end of file +class ModelEnum(BaseStrEnum): + isolation_forest_sklearn = "baskerville.models.anomaly_model_sklearn.AnomalyModelSklearn" + isolation_forest = "baskerville.models.anomaly_model.AnomalyModel" diff --git a/src/baskerville/util/file_manager.py b/src/baskerville/util/file_manager.py new file mode 100644 index 00000000..613948c7 --- /dev/null +++ b/src/baskerville/util/file_manager.py @@ -0,0 +1,116 @@ +import os +import shutil +import tempfile +import json +import _pickle as cPickle + +import errno + + +class FileManager(object): + + def __init__(self, path, spark_session=None): + """ + FileManager supports both local and distributed file systems. + + :param path: the root path of the storage. It must start with 'hdfs://ip:port/' for HDFS + :param spark_session: the spark session + """ + super().__init__() + self.spark_session = spark_session + + if path.startswith('hdfs://'): + if not self.spark_session: + raise RuntimeError('You must pass a valid spark session if you use distributed storage') + + # find the third occurrence of '/' + slash = path.find('/', path.find('/', path.find('/') + 1) + 1) + if slash == -1: + raise RuntimeError(f'Path "{path}" must contain at least one folder. ' + f'For example "hdfs://xxx:8020/baskerville"). 
') + connection_URI = path[:slash] + self.jvm_path_class = self.spark_session._sc._gateway.jvm.org.apache.hadoop.fs.Path + self.jvm_file_system = self.spark_session._sc._gateway.jvm.org.apache.hadoop.fs.FileSystem.get( + self.spark_session._sc._gateway.jvm.java.net.URI(connection_URI), + self.spark_session._sc._jsc.hadoopConfiguration()) + else: + self.jvm_path_class = None + self.jvm_file_system = None + + def path_exists(self, path): + if self.jvm_file_system: + return self.jvm_file_system.exists(self.jvm_path_class(path)) + return os.path.exists(path) + + def delete_path(self, path): + if self.jvm_file_system: + return self.jvm_file_system.delete(self.jvm_path_class(path), True) + + try: + if os.path.isdir(path): + shutil.rmtree(path) + else: + os.remove(path) + return True + except OSError: + return False + + def rename_path(self, source, destination): + if self.jvm_file_system: + self.jvm_file_system.rename(self.jvm_path_class(source), + self.jvm_path_class(destination)) + return + os.rename(source, destination) + + def _save_to_local_file(self, file, value, format='json'): + if format == 'json': + json.dump(value, file) + elif format == 'pickle': + cPickle.dump(value, file) + else: + raise RuntimeError(f'Unsupported file format {format}') + + def save_to_file(self, value, file_name, format='json'): + mode = 'w' if format == 'json' else 'wb' + + if not self.jvm_file_system: + if not os.path.exists(os.path.dirname(file_name)): + try: + os.makedirs(os.path.dirname(file_name)) + except OSError as exc: # Guard against race condition + if exc.errno != errno.EEXIST: + raise + + with open(file_name, mode=mode) as f: + self._save_to_local_file(f, value, format=format) + return + + with tempfile.NamedTemporaryFile(mode=mode) as f: + self._save_to_local_file(f, value, format=format) + f.flush() + + self.jvm_file_system.copyFromLocalFile( + self.jvm_path_class(f.name), + self.jvm_path_class(file_name)) + + def _load_from_local_file(self, file, format='json'): + if format == 'json': + return json.load(file) + elif format == 'pickle': + return cPickle.load(file) + else: + raise RuntimeError(f'Unsupported file format {format}') + + def load_from_file(self, path, format='json'): + mode = 'r' if format == 'json' else 'rb' + + if not self.jvm_file_system: + with open(path, mode=mode) as f: + return self._load_from_local_file(f, format=format) + + with tempfile.NamedTemporaryFile(mode=mode) as f: + self.jvm_file_system.copyToLocalFile( + self.jvm_path_class(path), + self.jvm_path_class(f.name)) + return self._load_from_local_file(f, format=format) + diff --git a/src/baskerville/util/helpers.py b/src/baskerville/util/helpers.py index 81cb13e8..e866ead4 100644 --- a/src/baskerville/util/helpers.py +++ b/src/baskerville/util/helpers.py @@ -9,7 +9,6 @@ FOLDER_MODELS = 'models' FOLDER_CACHE = 'cache' -RANDOM_SEED = 42 def parse_config(path=None, data=None, tag='!ENV'): @@ -261,7 +260,6 @@ def get_default_data_path() -> str: os.path.dirname(os.path.realpath(__file__)), '..', '..', '..', 'data' ) - def get_days_in_year(year): """ Returns the number of days in a specific year @@ -407,10 +405,3 @@ def get_model_path(storage_path, model_name='model'): FOLDER_MODELS, f'{model_name}__{get_timestamp()}') - -def get_classifier_load_path(path): - return os.path.join(path, 'classifier') - - -def get_scaler_load_path(path): - return os.path.join(path, 'scaler') \ No newline at end of file diff --git a/src/baskerville/util/model_serialization.py b/src/baskerville/util/model_serialization.py deleted file mode 100644
index aee5042c..00000000 --- a/src/baskerville/util/model_serialization.py +++ /dev/null @@ -1,120 +0,0 @@ -import _pickle as cPickle -import os - -from baskerville.models.config import DatabaseConfig, SparkConfig -from baskerville.db import set_up_db -from baskerville.db.models import Model -from baskerville.spark import get_or_create_spark_session -from baskerville.util.enums import AlgorithmEnum -from baskerville.util.helpers import get_default_data_path, \ - get_classifier_load_path, get_scaler_load_path -from pyspark.ml.feature import StandardScalerModel - - -def pickle_model(model_id, db_config, out_path, ml_model_out_path): - """ - Pickle the baskerville.db.models.Model instance - :param int model_id: which model to get from database - :param dict db_config: the databse configuration - :param str out_path: where to store the pickled db model - :param str ml_model_out_path: where to store the actual ml model - (and scaler) - :return: - """ - db_cfg = DatabaseConfig(db_config) - session, _ = set_up_db(db_cfg.__dict__) - - model = session.query(Model).filter_by(id=model_id).first() - - with open(out_path, 'wb') as out_f: - cPickle.dump(model, out_f) - - print(f'Pickled Model with id: {model_id} to {out_path}') - if model.algorithm == AlgorithmEnum.isolation_forest_pyspark.value: - model_path = model.classifier.decode('utf-8') - if os.path.exists(model_path): - from pyspark_iforest.ml.iforest import IForestModel - spark = get_or_create_spark_session( - SparkConfig( - {'jars': f'{get_default_data_path()}/jars/spark-iforest-2.4.0.jar'} - ).validate()) - IForestModel.load(get_classifier_load_path(model_path)).write( - ).overwrite().save( - get_classifier_load_path(ml_model_out_path) - ) - StandardScalerModel.load(get_scaler_load_path(model_path)).write( - ).overwrite().save( - get_scaler_load_path(ml_model_out_path) - ) - print(f'Copied {model_path} to {ml_model_out_path}') - - -def import_pickled_model(db_config, model_path, current_storage_path): - """ - Loads the pickled model and imports it into the database - :param dict db_config: the database configuration - :return: - """ - db_cfg = DatabaseConfig(db_config).validate() - session, _ = set_up_db(db_cfg.__dict__, partition=False) - - with open(model_path, 'rb') as f: - model = cPickle.load(f) - - model_out = Model() - model_out.scaler = model.scaler - model_out.scaler_type = model.scaler_type - model_out.host_encoder = model.host_encoder - model_out.classifier = model.classifier - model_out.threshold = model.threshold - model_out.features = model.features - model_out.algorithm = model.algorithm - model_out.analysis_notebook = model.analysis_notebook - model_out.created_at = model.created_at - model_out.f1_score = model.f1_score - model_out.n_training = model.n_training - model_out.n_testing = model.n_testing - model_out.notes = model.notes - model_out.parameters = model.parameters - model_out.precision = model.precision - model_out.recall = model.recall - # todo: once the model changes are done: - # model_out.scaler = bytearray(current_storage_path).encode('utf-8') - # restore_model_path( - # current_storage_path, - # model.scaler.decode('utf-8')).encode('utf-8') - # ) - model_out.classifier = bytearray(current_storage_path.encode('utf-8')) - # bytearray( - # restore_model_path( - # current_storage_path, - # model.classifier.decode('utf-8')).encode('utf-8') - # ) - session.add(model_out) - session.commit() - session.close() - - -def restore_model_path(current_storage_path, previous_path: str): - import os - - filename = 
os.path.basename(previous_path) - print('>> filename, path', filename, previous_path) - if not os.path.isdir(previous_path): - return os.path.join(current_storage_path, filename) - return previous_path - - -if __name__ == '__main__': - db_cfg = { - 'name': 'baskerville_test', - 'user': 'postgres', - 'password': 'secret', - 'host': '127.0.0.1', - 'port': 5432, - 'type': 'postgres', - } - db_model_path = f'{get_default_data_path()}/samples/sample_model' - ml_model_path = f'{get_default_data_path()}/samples/test_model' - pickle_model(56, db_cfg, db_model_path, ml_model_path) - # import_pickled_model(db_cfg, db_model_path, ml_model_path) diff --git a/tests/unit/baskerville_tests/models_tests/test_base_spark.py b/tests/unit/baskerville_tests/models_tests/test_base_spark.py index 517d45ee..67624c16 100644 --- a/tests/unit/baskerville_tests/models_tests/test_base_spark.py +++ b/tests/unit/baskerville_tests/models_tests/test_base_spark.py @@ -7,8 +7,8 @@ from unittest import mock from baskerville.db.models import RequestSet -from baskerville.models.anomaly_detector import SparkAnomalyDetectorManager from baskerville.spark.helpers import StorageLevelFactory +from baskerville.util.helpers import get_default_data_path from tests.unit.baskerville_tests.helpers.spark_testing_base import \ SQLTestCaseLatestSpark @@ -18,7 +18,7 @@ from baskerville.models.config import ( Config, DatabaseConfig, SparkConfig, DataParsingConfig ) -from baskerville.util.enums import Step, AlgorithmEnum, LabelEnum +from baskerville.util.enums import Step, LabelEnum from pyspark.sql.types import (StructType, StructField, IntegerType, StringType, TimestampType) @@ -36,19 +36,11 @@ def setUp(self): self.request_set_cache_patcher = mock.patch( 'baskerville.models.base_spark.SparkPipelineBase.set_up_request_set_cache' ) - self.ANOMALY_MODEL_MANAGER_patcher = mock.patch( - 'baskerville.util.enums.ANOMALY_MODEL_MANAGER' - ) self.mock_baskerville_tools = self.db_tools_patcher.start() self.mock_spark_session = self.spark_patcher.start() self.mock_request_set_cache = self.request_set_cache_patcher.start() - self.mock_ANOMALY_MODEL_MANAGER = self.ANOMALY_MODEL_MANAGER_patcher.start() self.dummy_conf = Config({}) - self.db_conf = DatabaseConfig({ - 'user': 'postgres', - 'password': '***', - 'host': 'localhost' - }).validate() + self.db_conf = DatabaseConfig({'user': 'postgres', 'password': '***', 'host': 'localhost'}) self.engine_conf = mock.MagicMock() self.engine_conf.log_level = 'DEBUG' self.engine_conf.time_bucket = 10 @@ -78,9 +70,6 @@ def tearDown(self): if mock._is_started(self.request_set_cache_patcher): self.request_set_cache_patcher.stop() self.mock_request_set_cache.reset_mock() - if mock._is_started(self.mock_ANOMALY_MODEL_MANAGER): - self.mock_ANOMALY_MODEL_MANAGER.stop() - self.mock_ANOMALY_MODEL_MANAGER.reset_mock() def test_instance(self): self.assertTrue(hasattr(self.spark_pipeline, 'db_conf')) @@ -113,121 +102,119 @@ def test_instantiate_spark_session(self, mock_get_or_create_spark_session): ) self.db_tools_patcher.start() - def test_initialize(self): - # to call get_latest_ml_model_from_db + @mock.patch('baskerville.models.base_spark.instantiate_from_str') + @mock.patch('baskerville.models.base_spark.bytes') + def test_initialize(self, mock_bytes, mock_instantiate_from_str): + # to call get_ml_model_from_db self.engine_conf.model_id = -1 + model_index = mock.MagicMock() + model_index.id = 1 db_tools = self.mock_baskerville_tools.return_value - db_tools.get_latest_ml_model_from_db.return_value.algorithm = \ - 
AlgorithmEnum.isolation_forest_pyspark - with mock.patch.object(SparkAnomalyDetectorManager, 'load') \ - as mock_load: - self.spark_pipeline.initialize() - self.assertEqual( - self.spark_pipeline.time_bucket.sec, - self.engine_conf.time_bucket - ) - self.assertEqual( - self.spark_pipeline.time_bucket.td, - timedelta(seconds=self.engine_conf.time_bucket) - ) - - self.mock_baskerville_tools.assert_called_once_with( - self.spark_pipeline.db_conf - ) - db_tools = self.mock_baskerville_tools.return_value - db_tools.connect_to_db.assert_called_once() - - self.spark_pipeline.instantiate_spark_session.assert_called_once() - self.spark_pipeline.set_up_request_set_cache.assert_called_once() - - self.assertEqual(len(self.spark_pipeline.group_by_aggs), 3) - self.assertTrue('first_request' in self.spark_pipeline.group_by_aggs) - self.assertTrue('last_request' in self.spark_pipeline.group_by_aggs) - self.assertTrue('num_requests' in self.spark_pipeline.group_by_aggs) - self.assertEqual( - str(self.spark_pipeline.group_by_aggs['first_request']._jc), - 'min(@timestamp) AS `first_request`' - ) - self.assertEqual( - str(self.spark_pipeline.group_by_aggs['last_request']._jc), - 'max(@timestamp) AS `last_request`' - ) - self.assertEqual( - str(self.spark_pipeline.group_by_aggs['num_requests']._jc), - 'count(@timestamp) AS `num_requests`' - ) - - self.assertEqual(len(self.spark_pipeline.feature_manager.column_renamings), 0) - self.assertEqual(len(self.spark_pipeline.feature_manager.active_features), 0) - self.assertEqual(len(self.spark_pipeline.feature_manager.active_feature_names), 0) - self.assertEqual(len(self.spark_pipeline.feature_manager.active_columns), 0) - self.assertEqual(len(self.spark_pipeline.columns_to_filter_by), 3) - self.assertSetEqual( - self.spark_pipeline.columns_to_filter_by, - {'client_request_host', 'client_ip', '@timestamp'} - ) - db_tools.get_latest_ml_model_from_db.assert_called_once() - mock_load.assert_called_once() - - def test_initialize_model_path(self): + db_tools.get_ml_model_from_db.return_value = model_index - # to call get_ml_model_from_file - self.engine_conf.model_id = None - self.engine_conf.model_path = 'some test path' - db_tools = self.mock_baskerville_tools.return_value - db_tools.get_ml_model_from_file.return_value.algorithm = \ - AlgorithmEnum.isolation_forest_pyspark - with mock.patch.object(SparkAnomalyDetectorManager, - 'load') as mock_load: - self.spark_pipeline.initialize() - self.assertEqual( - self.spark_pipeline.time_bucket.sec, - self.engine_conf.time_bucket - ) - self.assertEqual( - self.spark_pipeline.time_bucket.td, - timedelta(seconds=self.engine_conf.time_bucket) - ) + self.spark_pipeline.initialize() + self.assertEqual( + self.spark_pipeline.time_bucket.sec, + self.engine_conf.time_bucket + ) + self.assertEqual( + self.spark_pipeline.time_bucket.td, + timedelta(seconds=self.engine_conf.time_bucket) + ) - db_tools = self.spark_pipeline.tools - db_tools.connect_to_db.assert_called_once() + self.mock_baskerville_tools.assert_called_once_with( + self.spark_pipeline.db_conf + ) - self.spark_pipeline.instantiate_spark_session.assert_called_once() - self.spark_pipeline.set_up_request_set_cache.assert_called_once() + db_tools.connect_to_db.assert_called_once() - self.assertEqual(len(self.spark_pipeline.group_by_aggs), 3) - self.assertTrue('first_request' in self.spark_pipeline.group_by_aggs) - self.assertTrue('last_request' in self.spark_pipeline.group_by_aggs) - self.assertTrue('num_requests' in self.spark_pipeline.group_by_aggs) - self.assertEqual( - 
str(self.spark_pipeline.group_by_aggs['first_request']._jc), - 'min(@timestamp) AS `first_request`' - ) - self.assertEqual( - str(self.spark_pipeline.group_by_aggs['last_request']._jc), - 'max(@timestamp) AS `last_request`' - ) - self.assertEqual( - str(self.spark_pipeline.group_by_aggs['num_requests']._jc), - 'count(@timestamp) AS `num_requests`' - ) + self.spark_pipeline.instantiate_spark_session.assert_called_once() + self.spark_pipeline.set_up_request_set_cache.assert_called_once() - self.assertEqual(len(self.spark_pipeline.feature_manager.column_renamings), 0) - self.assertEqual(len(self.spark_pipeline.feature_manager.active_features), 0) - self.assertEqual(len(self.spark_pipeline.feature_manager.active_feature_names), 0) - self.assertEqual(len(self.spark_pipeline.feature_manager.active_columns), 0) - self.assertEqual(len(self.spark_pipeline.columns_to_filter_by), 3) - self.assertSetEqual( - self.spark_pipeline.columns_to_filter_by, - {'client_request_host', 'client_ip', '@timestamp'} - ) - db_tools = self.spark_pipeline.tools - db_tools.get_ml_model_from_file.assert_called_once_with( - self.engine_conf.model_path - ) - mock_load.assert_called_once() + self.assertEqual(len(self.spark_pipeline.group_by_aggs), 3) + self.assertTrue('first_request' in self.spark_pipeline.group_by_aggs) + self.assertTrue('last_request' in self.spark_pipeline.group_by_aggs) + self.assertTrue('num_requests' in self.spark_pipeline.group_by_aggs) + self.assertEqual( + str(self.spark_pipeline.group_by_aggs['first_request']._jc), + 'min(@timestamp) AS `first_request`' + ) + self.assertEqual( + str(self.spark_pipeline.group_by_aggs['last_request']._jc), + 'max(@timestamp) AS `last_request`' + ) + self.assertEqual( + str(self.spark_pipeline.group_by_aggs['num_requests']._jc), + 'count(@timestamp) AS `num_requests`' + ) - def test_initialize_no_model_register_metrics(self): + self.assertEqual(len(self.spark_pipeline.feature_manager.column_renamings), 0) + self.assertEqual(len(self.spark_pipeline.feature_manager.active_features), 0) + self.assertEqual(len(self.spark_pipeline.feature_manager.active_feature_names), 0) + self.assertEqual(len(self.spark_pipeline.feature_manager.active_columns), 0) + self.assertEqual(len(self.spark_pipeline.columns_to_filter_by), 3) + self.assertSetEqual( + self.spark_pipeline.columns_to_filter_by, + {'client_request_host', 'client_ip', '@timestamp'} + ) + mock_bytes.decode.assert_called_once() + mock_instantiate_from_str.assert_called_once() + + # def test_initialize_model_path(self): + # + # # to call get_ml_model_from_file + # self.engine_conf.model_id = None + # self.engine_conf.model_path = 'some test path' + # self.spark_pipeline.model_manager.set_anomaly_detector_broadcast = mock.MagicMock() + # self.spark_pipeline.initialize() + # self.assertEqual( + # self.spark_pipeline.time_bucket.sec, + # self.engine_conf.time_bucket + # ) + # self.assertEqual( + # self.spark_pipeline.time_bucket.td, + # timedelta(seconds=self.engine_conf.time_bucket) + # ) + # + # db_tools = self.spark_pipeline.tools + # db_tools.connect_to_db.assert_called_once() + # + # self.spark_pipeline.instantiate_spark_session.assert_called_once() + # self.spark_pipeline.set_up_request_set_cache.assert_called_once() + # + # self.assertEqual(len(self.spark_pipeline.group_by_aggs), 3) + # self.assertTrue('first_request' in self.spark_pipeline.group_by_aggs) + # self.assertTrue('last_request' in self.spark_pipeline.group_by_aggs) + # self.assertTrue('num_requests' in self.spark_pipeline.group_by_aggs) + # 
self.assertEqual( + # str(self.spark_pipeline.group_by_aggs['first_request']._jc), + # 'min(@timestamp) AS `first_request`' + # ) + # self.assertEqual( + # str(self.spark_pipeline.group_by_aggs['last_request']._jc), + # 'max(@timestamp) AS `last_request`' + # ) + # self.assertEqual( + # str(self.spark_pipeline.group_by_aggs['num_requests']._jc), + # 'count(@timestamp) AS `num_requests`' + # ) + # + # self.assertEqual(len(self.spark_pipeline.feature_manager.column_renamings), 0) + # self.assertEqual(len(self.spark_pipeline.feature_manager.active_features), 0) + # self.assertEqual(len(self.spark_pipeline.feature_manager.active_feature_names), 0) + # self.assertEqual(len(self.spark_pipeline.feature_manager.active_columns), 0) + # self.assertEqual(len(self.spark_pipeline.columns_to_filter_by), 3) + # self.assertSetEqual( + # self.spark_pipeline.columns_to_filter_by, + # {'client_request_host', 'client_ip', '@timestamp'} + # ) + # db_tools = self.spark_pipeline.tools + # db_tools.get_ml_model_from_file.assert_called_once_with( + # self.engine_conf.model_path + # ) + + @mock.patch('baskerville.models.base_spark.instantiate_from_str') + def test_initialize_no_model_register_metrics(self, mock_instantiate_from_str): # for an empty model: self.engine_conf.model_id = None @@ -236,7 +223,6 @@ def test_initialize_no_model_register_metrics(self): self.engine_conf.metrics = mock.MagicMock() self.engine_conf.metrics.progress = True self.spark_pipeline.register_metrics = mock.MagicMock() - self.spark_pipeline.model_manager.load = mock.MagicMock() self.spark_pipeline.initialize() self.assertEqual( self.spark_pipeline.time_bucket.sec, @@ -280,8 +266,6 @@ def test_initialize_no_model_register_metrics(self): {'client_request_host', 'client_ip', '@timestamp'} ) - self.assertTrue(self.spark_pipeline.model_manager.ml_model is None) - def test_add_calc_columns(self): mock_feature = mock.MagicMock() now = datetime.utcnow() @@ -319,7 +303,9 @@ def test_add_calc_columns(self): self.assertDataFrameEqual(self.spark_pipeline.logs_df, df_expected) - def test_add_post_group_columns(self): + @mock.patch('baskerville.models.base_spark.instantiate_from_str') + @mock.patch('baskerville.models.base_spark.bytes') + def test_add_post_group_columns(self, mock_bytes, mock_instantiate_from_str): # spark saves binary as byte array pickled = bytearray(pickle.dumps(-1)) @@ -362,7 +348,12 @@ def test_add_post_group_columns(self): ]) self.spark_pipeline.runtime = mock.MagicMock() self.spark_pipeline.runtime.id = -1 - self.spark_pipeline.model_manager = mock.MagicMock() + + model_index = mock.MagicMock() + model_index.id = 1 + db_tools = self.mock_baskerville_tools.return_value + db_tools.get_ml_model_from_db.return_value = model_index + self.spark_pipeline.initialize() df = self.session.createDataFrame(logs, schema=schema) @@ -375,20 +366,17 @@ def test_add_post_group_columns(self): 'last_request', 'old_subset_count', ) - self.spark_pipeline.model_manager.ml_model.id = 1 - df = df.withColumn( - 'model_version', F.lit( - self.spark_pipeline.model_manager.ml_model.id - ) - ) + df = df.withColumn('classifier', F.lit(pickled)) df = df.withColumn('model', F.lit(pickled)) df = df.withColumn('scaler', F.lit(pickled)) df = df.withColumn('model_features', F.lit(pickled)) df = df.withColumn('subset_count', F.lit(10)) + df = df.withColumn('model_version', F.lit(1)) df = df.withColumn('start', F.col('first_ever_request')) df = df.withColumn('stop', F.col('last_request')) + self.spark_pipeline.model = None 
self.spark_pipeline.engine_conf.time_bucket = 600 self.spark_pipeline.add_cache_columns = mock.MagicMock() self.spark_pipeline.add_post_groupby_columns() @@ -411,7 +399,10 @@ def test_add_post_group_columns(self): df ) - def test_add_post_group_columns_no_ml_model(self): + + @mock.patch('baskerville.models.base_spark.instantiate_from_str') + @mock.patch('baskerville.models.base_spark.bytes') + def test_add_post_group_columns_no_ml_model(self, mock_bytes, mock_instantiate_from_str): logs = [ { 'client_request_host': 'testhost', @@ -449,9 +440,11 @@ def test_add_post_group_columns_no_ml_model(self): StructField('first_request', TimestampType(), True), StructField('last_request', TimestampType(), True), ]) - self.spark_pipeline.model_manager.load = mock.MagicMock() - self.spark_pipeline.model_manager.load.return_value = None - self.spark_pipeline.model_manager.ml_model = None + + model_index = mock.MagicMock() + model_index.id = 1 + db_tools = self.mock_baskerville_tools.return_value + db_tools.get_ml_model_from_db.return_value = model_index self.spark_pipeline.initialize() @@ -467,6 +460,11 @@ def test_add_post_group_columns_no_ml_model(self): 'last_request', 'old_subset_count', ) + df = df.withColumn( + 'model_version', F.lit( + 1 + ) + ) df = df.withColumn('subset_count', F.lit(10)) df = df.withColumn('start', F.col('first_ever_request')) df = df.withColumn('stop', F.col('last_request')) @@ -493,7 +491,9 @@ def test_add_post_group_columns_no_ml_model(self): df ) - def test_group_by(self): + @mock.patch('baskerville.models.base_spark.instantiate_from_str') + @mock.patch('baskerville.models.base_spark.bytes') + def test_group_by(self, mock_bytes, mock_instantiate_from_str): logs = [ { 'client_request_host': 'testhost', @@ -530,11 +530,14 @@ def test_group_by(self): self.spark_pipeline.feature_manager.get_active_features.return_value = [mock_feature] self.spark_pipeline.add_post_groupby_columns = mock.MagicMock() self.spark_pipeline.feature_manager.active_features = [mock_feature] - self.spark_pipeline.model_manager.load = mock.MagicMock() - self.spark_pipeline.model_manager.load.return_value = None - self.spark_pipeline.model_manager.ml_model = None + self.spark_pipeline.model = None self.spark_pipeline.active_columns = self.spark_pipeline.feature_manager.get_active_columns() + model_index = mock.MagicMock() + model_index.id = 1 + db_tools = self.mock_baskerville_tools.return_value + db_tools.get_ml_model_from_db.return_value = model_index + self.spark_pipeline.initialize() self.spark_pipeline.group_by() @@ -552,7 +555,9 @@ def test_group_by(self): self.assertTrue('num_requests' in self.spark_pipeline.logs_df.columns) self.assertDataFrameEqual(grouped_df, self.spark_pipeline.logs_df) - def test_group_by_empty_feature_group_by_aggs(self): + @mock.patch('baskerville.models.base_spark.instantiate_from_str') + @mock.patch('baskerville.models.base_spark.bytes') + def test_group_by_empty_feature_group_by_aggs(self, mock_bytes, mock_instantiate_from_str): logs = [ { 'client_request_host': 'testhost', @@ -585,8 +590,14 @@ def test_group_by_empty_feature_group_by_aggs(self): mock_feature = mock.MagicMock() mock_feature.columns = [] self.spark_pipeline.add_post_groupby_columns = mock.MagicMock() - self.spark_pipeline.model_manager = mock.MagicMock() + self.spark_pipeline.set_broadcasts = mock.MagicMock() self.spark_pipeline.feature_manager.active_features = [mock_feature] + + model_index = mock.MagicMock() + model_index.id = 1 + db_tools = self.mock_baskerville_tools.return_value + 
db_tools.get_ml_model_from_db.return_value = model_index + self.spark_pipeline.initialize() self.spark_pipeline.group_by() @@ -665,7 +676,9 @@ def test_features_to_dict(self): for k, v in cldf[i].features.items(): self.assertAlmostEqual(getattr(row, k), v, 1) - def test_save(self): + @mock.patch('baskerville.models.base_spark.instantiate_from_str') + @mock.patch('baskerville.models.base_spark.bytes') + def test_save(self, mock_bytes, mock_instantiate_from_str): now = datetime.now() logs = [ @@ -720,7 +733,7 @@ def test_save(self): 'model_version': 'test', } ] - self.spark_pipeline.model_manager = mock.MagicMock() + self.spark_pipeline.set_broadcasts = mock.MagicMock() self.spark_pipeline.initialize() df = self.session.createDataFrame(logs) @@ -793,7 +806,6 @@ def test_predict_no_ml_model(self, mock_udf): df = df.withColumn('features', F.lit(0)) df = df.withColumn('prediction', F.lit(LabelEnum.unknown.value).cast('int')) df = df.withColumn('score', F.lit(LabelEnum.unknown.value).cast('float')) - df = df.withColumn('threshold', F.lit(LabelEnum.unknown.value).cast('float')) df = self.fix_schema( df, self.spark_pipeline.logs_df.schema, @@ -950,7 +962,9 @@ def test_save_df_to_table_json_cols(self, col_to_json): self.assertTupleEqual(tuple(actual_json_col), test_json_cols) - def test_filter_columns(self): + @mock.patch('baskerville.models.base_spark.instantiate_from_str') + @mock.patch('baskerville.models.base_spark.bytes') + def test_filter_columns(self, mock_bytes, mock_instantiate_from_str): logs = [ { 'client_ip': '1', @@ -979,8 +993,14 @@ def test_filter_columns(self): self.spark_pipeline.feature_manager.get_active_columns.return_value = [ 'client_ip', 'client_request_host', 'afeature', 'drop_if_null' ] - self.spark_pipeline.model_manager = mock.MagicMock() + self.spark_pipeline.set_broadcasts = mock.MagicMock() + model_index = mock.MagicMock() + model_index.id = 1 + db_tools = self.mock_baskerville_tools.return_value + db_tools.get_ml_model_from_db.return_value = model_index + self.spark_pipeline.initialize() + self.spark_pipeline.logs_df = df self.spark_pipeline.filter_columns() @@ -1001,10 +1021,9 @@ def test_filter_columns(self): ) def get_data_parser_helper(self): - from tests.unit.baskerville_tests.helpers.utils import get_default_data_path data_conf = DataParsingConfig({ 'parser': 'JSONLogSparkParser', - 'schema': f'{get_default_data_path()}/sample_log_schema.json', + 'schema': f'{get_default_data_path()}/samples/log_schema.json', 'timestamp_column': '@timestamp' }) data_conf.validate() diff --git a/tests/unit/baskerville_tests/models_tests/test_feature_manager.py b/tests/unit/baskerville_tests/models_tests/test_feature_manager.py index 34197cef..d8c77a83 100644 --- a/tests/unit/baskerville_tests/models_tests/test_feature_manager.py +++ b/tests/unit/baskerville_tests/models_tests/test_feature_manager.py @@ -38,8 +38,6 @@ def _helper_check_features_instantiated_once(self, mock_features): v.assert_called_once() def test_initialize(self): - mock_model_manager = mock.MagicMock() - mock_model_manager.ml_model = None mock_features = self._helper_get_mock_features() for k, f in mock_features.items(): mock_features[k].DEPENDENCIES = [] @@ -60,11 +58,11 @@ def test_initialize(self): self.mock_engine_conf.extra_features = list(mock_features.keys()) self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager + self.mock_engine_conf ) self.feature_manager.all_features = mock_features - self.feature_manager.initialize(mock_model_manager) + self.feature_manager.initialize() 
self.assertTrue( len(self.feature_manager.active_features), @@ -91,11 +89,9 @@ def test_initialize(self): ) def test_get_active_features_no_ml_model(self): - mock_model_manager = mock.MagicMock() - mock_model_manager.ml_model = None mock_features = self._helper_get_mock_features() self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager + self.mock_engine_conf ) self.feature_manager.all_features = mock_features self.feature_manager.extra_features = mock_features.keys() @@ -109,11 +105,9 @@ def test_get_active_features_no_ml_model(self): ) def test_get_active_features_no_ml_model_unknown_features(self): - mock_model_manager = mock.MagicMock() - mock_model_manager.ml_model = None mock_features = self._helper_get_mock_features() self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager + self.mock_engine_conf ) self.feature_manager.all_features = mock_features self.feature_manager.extra_features = list(mock_features.keys()) + [ @@ -129,44 +123,6 @@ def test_get_active_features_no_ml_model_unknown_features(self): 4 ) - def test_get_active_features_all_features(self): - mock_model_manager = mock.MagicMock() - mock_features = self._helper_get_mock_features() - mock_model_manager.ml_model.features = ['mock_feature1'] - - self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager - ) - self.feature_manager.all_features = mock_features - self.feature_manager.extra_features = list(mock_features.keys()) - - actual_active_features = self.feature_manager.get_active_features() - - self._helper_check_features_instantiated_once(mock_features) - self.assertEqual( - len(actual_active_features), - 4 - ) - - def test_get_active_features_extra_features(self): - mock_model_manager = mock.MagicMock() - mock_features = self._helper_get_mock_features() - mock_model_manager.ml_model.features = ['mock_feature1'] - - self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager - ) - self.feature_manager.all_features = mock_features - self.feature_manager.extra_features = list(mock_features.keys())[1:] - - actual_active_features = self.feature_manager.get_active_features() - - self._helper_check_features_instantiated_once(mock_features) - self.assertEqual( - len(actual_active_features), - 4 - ) - def test_get_active_feature_names(self): mock_features = self._helper_get_mock_features() for k, f in mock_features.items(): @@ -233,61 +189,28 @@ def test_get_update_feature_columns_no_updateable_features(self): self.assertEqual(feature_columns, []) def test_feature_config_is_valid_true(self): - mock_model_manager = mock.MagicMock() mock_features = self._helper_get_mock_features() for k, f in mock_features.items(): f.feature_name_from_class.return_value = k f.dependencies = list(mock_features.values()) self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager + self.mock_engine_conf ) self.feature_manager.active_features = mock_features.values() self.feature_manager.active_feature_names = list(mock_features.keys()) - mock_model_manager.ml_model.features = list(mock_features.keys()) - - is_config_valid = self.feature_manager.feature_config_is_valid() - self.assertEqual(is_config_valid, True) - - def test_feature_config_is_valid_true_no_ml_model(self): - mock_model_manager = mock.MagicMock() - mock_model_manager.ml_model = None - self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager - ) is_config_valid = self.feature_manager.feature_config_is_valid() 
self.assertEqual(is_config_valid, True) def test_feature_config_is_valid_false_feature_dependencies_not_met(self): - mock_model_manager = mock.MagicMock() mock_features = self._helper_get_mock_features() for k, f in mock_features.items(): f.feature_name_from_class.return_value = 'test' f.dependencies = list(mock_features.values()) self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager - ) - self.feature_manager.active_features = mock_features.values() - self.feature_manager.active_feature_names = list(mock_features.keys()) - mock_model_manager.ml_model.features = list(mock_features.keys()) - - is_config_valid = self.feature_manager.feature_config_is_valid() - self.assertEqual(is_config_valid, False) - - def test_feature_config_is_valid_false_feature_different_ml_features(self): - mock_model_manager = mock.MagicMock() - mock_features = self._helper_get_mock_features() - for k, f in mock_features.items(): - f.feature_name_from_class.return_value = k - f.dependencies = list(mock_features.values()) - - mock_model_manager.ml_model.features = list(mock_features.keys()) + [ - 'unknown_feature'] - - self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager + self.mock_engine_conf ) self.feature_manager.active_features = mock_features.values() self.feature_manager.active_feature_names = list(mock_features.keys()) @@ -295,43 +218,14 @@ def test_feature_config_is_valid_false_feature_different_ml_features(self): is_config_valid = self.feature_manager.feature_config_is_valid() self.assertEqual(is_config_valid, False) - def test_features_subset_of_model_features_true(self): - mock_model_manager = mock.MagicMock() - mock_features = self._helper_get_mock_features() - mock_model_manager.ml_model.features = list(mock_features.keys()) - self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager - ) - self.feature_manager.active_features = mock_features.values() - self.feature_manager.active_feature_names = list(mock_features.keys()) - - is_subset = self.feature_manager.features_subset_of_model_features() - self.assertEqual(is_subset, True) - - def test_features_subset_of_model_features_true_false(self): - mock_model_manager = mock.MagicMock() - mock_features = self._helper_get_mock_features() - mock_model_manager.ml_model.features = list(mock_features.keys()) + [ - 'unknown_feature'] - - self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager - ) - self.feature_manager.active_features = mock_features.values() - self.feature_manager.active_feature_names = list(mock_features.keys()) - - is_subset = self.feature_manager.features_subset_of_model_features() - self.assertEqual(is_subset, False) - def test_feature_dependencies_met_true(self): - mock_model_manager = mock.MagicMock() mock_features = self._helper_get_mock_features() for k, f in mock_features.items(): f.feature_name_from_class.return_value = k f.dependencies = list(mock_features.values()) self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager + self.mock_engine_conf ) self.feature_manager.active_features = mock_features.values() self.feature_manager.active_feature_names = list(mock_features.keys()) @@ -340,14 +234,13 @@ def test_feature_dependencies_met_true(self): self.assertEqual(is_config_valid, True) def test_feature_dependencies_met_false(self): - mock_model_manager = mock.MagicMock() mock_features = self._helper_get_mock_features() for k, f in mock_features.items(): f.feature_name_from_class.return_value = 'other 
feature' f.dependencies = list(mock_features.values()) self.feature_manager = FeatureManager( - self.mock_engine_conf, mock_model_manager + self.mock_engine_conf ) self.feature_manager.active_features = mock_features.values() self.feature_manager.active_feature_names = list(mock_features.keys()) diff --git a/tests/unit/baskerville_tests/models_tests/test_model_manager.py b/tests/unit/baskerville_tests/models_tests/test_model_manager.py deleted file mode 100644 index 977600db..00000000 --- a/tests/unit/baskerville_tests/models_tests/test_model_manager.py +++ /dev/null @@ -1,133 +0,0 @@ -import _pickle as cPickle -from unittest import mock - -import pyspark -from baskerville.models.anomaly_detector import AnomalyDetector, \ - ScikitAnomalyDetectorManager, SparkAnomalyDetectorManager -from baskerville.models.config import DatabaseConfig, EngineConfig -from baskerville.models.model_manager import ModelManager -from baskerville.util.enums import AlgorithmEnum - -from tests.unit.baskerville_tests.helpers.spark_testing_base import \ - SQLTestCaseLatestSpark - - -class TestModelManager(SQLTestCaseLatestSpark): - def setUp(self): - super(TestModelManager, self).setUp() - self.db_tools_patcher = mock.patch( - 'baskerville.util.baskerville_tools.BaskervilleDBTools' - ) - self.spark_patcher = mock.patch( - 'baskerville.models.base_spark.SparkPipelineBase.' - 'instantiate_spark_session' - ) - self.mock_db_tools = self.db_tools_patcher.start() - self.mock_spark = self.spark_patcher.start() - self.mock_db_conf = DatabaseConfig({}).validate() - self.mock_engine_conf = EngineConfig({ - 'log_level': 'INFO' - }).validate() - self.model_manager = ModelManager( - self.mock_db_conf, self.mock_engine_conf, self.session - ) - - def test_initialize(self): - test_session = 'test session' - test_db_tools = 'test db_tools' - - with mock.patch.object( - ModelManager, 'load' - ) as mock_load: - - self.model_manager = ModelManager( - self.mock_db_conf, self.mock_engine_conf - ) - self.model_manager.initialize(test_session, test_db_tools) - self.assertEqual(self.model_manager.spark_session, test_session) - self.assertEqual(self.model_manager.db_tools, test_db_tools) - mock_load.assert_called_once() - - def test_get_active_model_model_id_scikit(self): - with mock.patch.object( - ScikitAnomalyDetectorManager, 'load' - ) as mock_manager: - # mock_engine_conf = mock.MagicMock() - mock_db_tools = mock.MagicMock() - self.mock_engine_conf.model_id = 1 - mock_model_value = mock.MagicMock() - mock_model_value.algorithm = AlgorithmEnum.isolation_forest_sklearn - mock_db_tools.get_ml_model_from_db.return_value = mock_model_value - self.model_manager = ModelManager( - self.mock_db_conf, self.mock_engine_conf, db_tools=mock_db_tools - ) - - active_model = self.model_manager.load() - self.assertEqual(active_model, mock_model_value) - mock_db_tools.get_ml_model_from_db.assert_called_once_with( - self.mock_engine_conf.model_id - ) - mock_manager.assert_called_once() - - def test_get_active_model_model_id_sparkml(self): - with mock.patch.object( - SparkAnomalyDetectorManager, 'load' - ) as mock_manager: - # mock_engine_conf = mock.MagicMock() - mock_db_tools = mock.MagicMock() - self.mock_engine_conf.model_id = 1 - mock_model_value = mock.MagicMock() - mock_model_value.algorithm = AlgorithmEnum.isolation_forest_pyspark - mock_db_tools.get_ml_model_from_db.return_value = mock_model_value - self.model_manager = ModelManager( - self.mock_db_conf, self.mock_engine_conf, db_tools=mock_db_tools - ) - - active_model = self.model_manager.load() 
- self.assertEqual(active_model, mock_model_value) - mock_db_tools.get_ml_model_from_db.assert_called_once_with( - self.mock_engine_conf.model_id - ) - mock_manager.assert_called_once() - - def test_get_active_model_model_path_scikit(self): - with mock.patch.object( - ScikitAnomalyDetectorManager, 'load' - ) as mock_manager: - self.mock_engine_conf.model_id = None - self.mock_engine_conf.model_path = 'a path' - mock_model_value = mock.MagicMock() - mock_model_value.algorithm = AlgorithmEnum.isolation_forest_sklearn - mock_model_value.classifier = cPickle.dumps({}) - mock_model_value.scaler = cPickle.dumps({}) - mock_model_value.host_encoder = cPickle.dumps({}) - self.mock_db_tools.get_ml_model_from_file.return_value = mock_model_value - self.model_manager = ModelManager( - self.mock_db_conf, self.mock_engine_conf, db_tools=self.mock_db_tools - ) - active_model = self.model_manager.load() - self.assertEqual(active_model, mock_model_value) - self.mock_db_tools.get_ml_model_from_file.assert_called_once_with( - self.mock_engine_conf.model_path - ) - - def test_get_active_model_model_path_spark_ml(self): - with mock.patch.object( - SparkAnomalyDetectorManager, 'load' - ) as mock_manager: - self.mock_engine_conf.model_id = None - self.mock_engine_conf.model_path = 'a path' - mock_model_value = mock.MagicMock() - mock_model_value.algorithm = AlgorithmEnum.isolation_forest_pyspark - mock_model_value.classifier = cPickle.dumps({}) - mock_model_value.scaler = cPickle.dumps({}) - mock_model_value.host_encoder = cPickle.dumps({}) - self.mock_db_tools.get_ml_model_from_file.return_value = mock_model_value - self.model_manager = ModelManager( - self.mock_db_conf, self.mock_engine_conf, db_tools=self.mock_db_tools - ) - active_model = self.model_manager.load() - self.assertEqual(active_model, mock_model_value) - self.mock_db_tools.get_ml_model_from_file.assert_called_once_with( - self.mock_engine_conf.model_path - ) \ No newline at end of file diff --git a/tests/unit/baskerville_tests/models_tests/test_training_pipelines.py b/tests/unit/baskerville_tests/models_tests/test_training_pipelines.py deleted file mode 100644 index 7a199b41..00000000 --- a/tests/unit/baskerville_tests/models_tests/test_training_pipelines.py +++ /dev/null @@ -1,107 +0,0 @@ -import sys -from unittest import mock - -from baskerville.models.config import DatabaseConfig, EngineConfig, SparkConfig -from tests.unit.baskerville_tests.helpers.spark_testing_base import \ - SQLTestCaseLatestSpark - - -class TestTrainingSparkMLPipeline(SQLTestCaseLatestSpark): - def setUp(self): - if 'baskerville.util.helpers' in sys.modules: - del sys.modules['baskerville.util.helpers'] - if 'baskerville.models.pipeline_training' in sys.modules: - del sys.modules['baskerville.models.pipeline_training'] - if 'baskerville.models.model_manager' in sys.modules: - del sys.modules['baskerville.models.model_manager'] - if 'tests.unit.baskerville_tests.utils_tests.test_helpers' in sys.modules: - del sys.modules['tests.unit.baskerville_tests.utils_tests.test_helpers'] - - self.db_tools_patcher = mock.patch( - 'baskerville.util.baskerville_tools.BaskervilleDBTools' - ) - self.instantiate_from_str_patcher = mock.patch( - 'baskerville.util.helpers.instantiate_from_str' - ) - self.mock_baskerville_tools = self.db_tools_patcher.start() - self.mock_instantiate_from_str = self.instantiate_from_str_patcher.start() - self.training_conf = { - 'model_name': 'test_model', - 'scaler_name': 'test_scaler', - 'path': 'test_scaler', - 'classifier': 
'pyspark_iforest.ml.iforest.IForest', - 'scaler': 'pyspark.ml.feature.StandardScaler', - 'data_parameters': { - 'training_days': 30, - 'from_date': '2020-02-25', - 'to_date': '2020-03-15', - }, - 'classifier_parameters': { - 'maxSamples': 10000, - 'contamination': 0.1, - 'numTrees': 1000 - } - } - self.db_conf = DatabaseConfig({}).validate() - self.engine_conf = EngineConfig({ - 'log_level': 'INFO', - 'training': self.training_conf, - 'extra_features': ['a', 'b', 'c'] - }).validate() - - # from baskerville.util.helpers import get_default_data_path - # iforest_jar = f'{get_default_data_path()}/jars/spark-iforest-2.4.0.jar' - self.spark_conf = SparkConfig({ - 'db_driver': 'test', - # 'jars': iforest_jar, - # 'driver_extra_class_path': iforest_jar - }).validate() - - from baskerville.models.pipeline_training import \ - TrainingSparkMLPipeline - - self.pipeline = TrainingSparkMLPipeline( - self.db_conf, - self.engine_conf, - self.spark_conf - ) - super().setUp() - - @mock.patch('baskerville.util.helpers.instantiate_from_str') - def test_initialize(self, mock_instantiate_from_str): - # todo: some reference to instantiate_from_str prevents the correct - # mocking - thus this test fails when ran with all other tests - pass - # self.pipeline.initialize() - # self.assertTrue(self.pipeline.spark is not None) - # self.assertTrue(self.pipeline.feature_manager is not None) - # self.assertTrue(self.pipeline.scaler is not None) - # self.assertTrue(self.pipeline.classifier is not None) - # self.pipeline.classifier.setParams.assert_called() - # self.pipeline.scaler.setParams.assert_called() - # self.pipeline.classifier.setSeed.assert_called_once_with(42) - - # def test_get_data(self): - # raise NotImplementedError() - # - # def test_train(self): - # raise NotImplementedError() - # - # def test_test(self): - # raise NotImplementedError() - # - # def test_evaluate(self): - # raise NotImplementedError() - # - # def test_save(self): - # raise NotImplementedError() - # - # def test_get_bounds(self): - # raise NotImplementedError() - # - # def test_load(self): - # raise NotImplementedError() - # - # def test_finish_up(self): - # raise NotImplementedError() - diff --git a/tests/unit/baskerville_tests/utils_tests/test_file_manager.py b/tests/unit/baskerville_tests/utils_tests/test_file_manager.py new file mode 100644 index 00000000..ca9a9b2f --- /dev/null +++ b/tests/unit/baskerville_tests/utils_tests/test_file_manager.py @@ -0,0 +1,75 @@ +import unittest +import tempfile +import os + +from pyspark import SparkConf +from pyspark.sql import SparkSession + +from baskerville.util.file_manager import FileManager + + +class TestFileManager(unittest.TestCase): + + def test_json(self): + temp_dir = tempfile.gettempdir() + fm = FileManager(path=temp_dir) + + some_dict = {'A': 777} + file_name = os.path.join(temp_dir, 'file_manager_test.pickle') + + fm.save_to_file(value=some_dict, file_name=file_name, format='json') + assert(os.path.exists(os.path.join(temp_dir, file_name))) + + res = fm.load_from_file(file_name) + self.assertDictEqual(res, some_dict) + + def test_pickle(self): + temp_dir = tempfile.gettempdir() + fm = FileManager(path=temp_dir) + + some_dict = {'A': 777} + file_name = os.path.join(temp_dir, 'file_manager_test.pickle') + + fm.save_to_file(value=some_dict, file_name=file_name, format='pickle') + res = fm.load_from_file(file_name, format='pickle') + self.assertDictEqual(res, some_dict) + + # HDFS tests are commented out since they should not be executed with all the unit tests + # def test_json_hdfs(self): + # 
temp_dir = 'hdfs://hadoop-01:8020/anton2' + # + # conf = SparkConf() + # conf.set('spark.hadoop.dfs.client.use.datanode.hostname', 'true') + # + # spark = SparkSession \ + # .builder.config(conf=conf) \ + # .appName("aaa") \ + # .getOrCreate() + # + # fm = FileManager(path=temp_dir, spark_session=spark) + # + # some_dict = {'A': 777} + # file_name = os.path.join(temp_dir, 'file_manager_test7.pickle') + # + # fm.save_to_file(value=some_dict, file_name=file_name, format='json') + # res = fm.load_from_file(file_name, format='json') + # self.assertDictEqual(res, some_dict) + # + # def test_pickle_hdfs(self): + # temp_dir = 'hdfs://hadoop-01:8020/anton2' + # + # conf = SparkConf() + # conf.set('spark.hadoop.dfs.client.use.datanode.hostname', 'true') + # + # spark = SparkSession \ + # .builder.config(conf=conf) \ + # .appName("aaa") \ + # .getOrCreate() + # + # fm = FileManager(path=temp_dir, spark_session=spark) + # some_dict = {'A': 777} + # file_name = os.path.join(temp_dir, 'file_manager_test7.pickle') + # + # fm.save_to_file(value=some_dict, file_name=file_name, format='pickle') + # res = fm.load_from_file(file_name, format='pickle') + # self.assertDictEqual(res, some_dict)
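Editor's note — a minimal usage sketch for the FileManager introduced in src/baskerville/util/file_manager.py, mirroring the local tests above and the commented-out HDFS tests. The HDFS namenode host, port and file names below are placeholders, not values taken from this patch.

    import os
    import tempfile

    from pyspark.sql import SparkSession

    from baskerville.util.file_manager import FileManager

    # Local storage: no Spark session is needed.
    local_fm = FileManager(path=tempfile.gettempdir())
    target = os.path.join(tempfile.gettempdir(), 'file_manager_example.json')
    local_fm.save_to_file({'A': 777}, target, format='json')
    print(local_fm.load_from_file(target, format='json'))

    # Distributed storage: a Spark session is required and the root path must
    # look like 'hdfs://<namenode>:<port>/<folder>'.
    spark = SparkSession.builder.appName('file-manager-sketch').getOrCreate()
    hdfs_fm = FileManager(path='hdfs://namenode:8020/baskerville', spark_session=spark)
    hdfs_fm.save_to_file({'A': 777},
                         'hdfs://namenode:8020/baskerville/example.pickle',
                         format='pickle')
    print(hdfs_fm.load_from_file('hdfs://namenode:8020/baskerville/example.pickle',
                                 format='pickle'))

The same save_to_file/load_from_file calls cover both backends, so callers presumably do not need to know where artifacts are stored.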