diff --git a/.gitignore b/.gitignore
index b6e47617..07a41974 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+ip_cache/
diff --git a/alembic/versions/88eb5854154f_add_id_group_in_request_sets.py b/alembic/versions/88eb5854154f_add_id_group_in_request_sets.py
index 1b50c306..3d55b97b 100644
--- a/alembic/versions/88eb5854154f_add_id_group_in_request_sets.py
+++ b/alembic/versions/88eb5854154f_add_id_group_in_request_sets.py
@@ -1,4 +1,4 @@
-"""add id_request_sets in request_sets
+"""add uuid_request_set in request_sets
 
 Revision ID: 88eb5854154f
 Revises:
@@ -16,8 +16,8 @@
 
 
 def upgrade():
-    op.add_column('request_sets', sa.Column('id_request_sets', sa.TEXT))
+    op.add_column('request_sets', sa.Column('uuid_request_set', sa.TEXT))
 
 
 def downgrade():
-    op.op.drop_column('request_sets', 'id_request_sets')
+    op.drop_column('request_sets', 'uuid_request_set')
diff --git a/data/Baskerville ER Diagram.png b/data/Baskerville ER Diagram.png
new file mode 100644
index 00000000..a1da0db6
Binary files /dev/null and b/data/Baskerville ER Diagram.png differ
diff --git a/data/jars/aws-java-sdk-1.7.4.jar b/data/jars/aws-java-sdk-1.7.4.jar
new file mode 100644
index 00000000..02233a84
Binary files /dev/null and b/data/jars/aws-java-sdk-1.7.4.jar differ
diff --git a/data/jars/hadoop-aws-2.7.1.jar b/data/jars/hadoop-aws-2.7.1.jar
new file mode 100644
index 00000000..ea6581c1
Binary files /dev/null and b/data/jars/hadoop-aws-2.7.1.jar differ
diff --git a/data/samples/sample_feedback_schema.json b/data/samples/sample_feedback_schema.json
new file mode 100644
index 00000000..9f41f4c9
--- /dev/null
+++ b/data/samples/sample_feedback_schema.json
@@ -0,0 +1,19 @@
+{
+  "name": "FeedbackSchema",
+  "properties": {
+    "id_context": {
+      "type": "string"
+    },
+    "uuid_organization": {
+      "type": "string"
+    },
+    "feedback_context": {
+      "type": "object"
+    },
+    "feedback": {
+      "type": "object"
+    }
+  },
+  "required": ["id_context", "uuid_organization", "feedback_context", "feedback"],
+  "additionalProperties": false
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 5ce8c90e..db6b09fe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
-jinja2==2.10
-pgpubsub
+jinja2>=2.10.1
 numpy==1.14.3
 PyYAML==3.12
 cryptography==2.2.2
@@ -8,20 +7,19 @@ python-geoip==1.2
 python-geoip-geolite2==2015.303
 certifi==2018.4.16
 ua-parser==0.8.0
-bokeh==0.12.16
+# bokeh==0.12.16
 pandas==0.23.0
 pycountry==18.2.23
 scipy==1.1.0
 matplotlib==2.2.2
 seaborn==0.8.1
-hdbscan==0.8.13
 alembic==1.0.8
 enum34==1.1.6
 tzwhere==3.0.3
 pytz==2014.10
 sqlalchemy_utils==0.33.3
 pyspark==2.4.4
-es_retriever==1.0.0
+# es_retriever==1.0.0
 psutil==5.4.6
 psycopg2==2.7.5
 yellowbrick==0.8
@@ -29,7 +27,7 @@ dateparser==0.7.0
 pymisp==2.4.93
 attrs==18.1.0
 warlock==1.3.0
-jsonschema==2.6
+jsonschema==2.6.0
 stringcase==1.2.0
 prometheus_client==0.5.0
 grafanalib==0.5.3
diff --git a/requirements_unit_tests.txt b/requirements_unit_tests.txt
index 503c92af..e14217ec 100644
--- a/requirements_unit_tests.txt
+++ b/requirements_unit_tests.txt
@@ -1,4 +1,3 @@
-jinja2==2.10
 pgpubsub
 numpy==1.14.3
 PyYAML==3.12
diff --git a/src/baskerville/db/dashboard_models.py b/src/baskerville/db/dashboard_models.py
new file mode 100644
index 00000000..601a0eb5
--- /dev/null
+++ b/src/baskerville/db/dashboard_models.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2020, eQualit.ie inc.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from baskerville.db import Base +from baskerville.db.models import utcnow, SerializableMixin +from sqlalchemy import Column, Integer, ForeignKey, DateTime, Enum, String, \ + Boolean, BigInteger, Float, JSON, Text, TEXT +from sqlalchemy.orm import relationship +from passlib.apps import custom_app_context as pwd_context + +from baskerville.util.enums import UserCategoryEnum, FeedbackEnum, \ + FeedbackContextTypeEnum + + +class UserCategory(Base, SerializableMixin): + __tablename__ = 'user_categories' + id = Column(Integer, primary_key=True, autoincrement=True) + category = Column(Enum(UserCategoryEnum)) + # a user can belong to more than one category + users = relationship( + 'User', uselist=True, back_populates='category' + ) + + +class Organization(Base, SerializableMixin): + __tablename__ = 'organizations' + id = Column(BigInteger(), primary_key=True, autoincrement=True, unique=True) + uuid = Column(String(300), primary_key=True, unique=True) + name = Column(String(200), index=True) + details = Column(TEXT()) + registered = Column(Boolean(), default=False) + created_at = Column(DateTime(timezone=True), server_default=utcnow()) + updated_at = Column( + DateTime(timezone=True), nullable=True, onupdate=utcnow() + ) + users = relationship( + 'User', uselist=False, back_populates='organization' + ) + + +class User(Base, SerializableMixin): + __tablename__ = 'users' + id = Column(BigInteger(), primary_key=True, autoincrement=True, unique=True) + id_organization = Column(BigInteger(), ForeignKey('organizations.id')) + id_category = Column(Integer, ForeignKey('user_categories.id'), nullable=False) + username = Column(String(200), index=True) + first_name = Column(String(200), index=True) + last_name = Column(String(200), index=True) + email = Column(String(256), unique=True, nullable=False) + password_hash = Column(String(128)) + is_active = Column(Boolean()) + is_gitlab_login = Column(Boolean(), default=False) + is_admin = Column(Boolean(), default=False) + created_at = Column(DateTime(timezone=True), server_default=utcnow()) + updated_at = Column( + DateTime(timezone=True), nullable=True, onupdate=utcnow() + ) + + # users * - 1 category + category = relationship( + 'UserCategory', + foreign_keys=id_category, back_populates='users' + ) + organization = relationship( + 'Organization', + foreign_keys=id_organization, back_populates='users' + ) + runtimes = relationship( + 'Runtime', + uselist=False, + # back_populates='user' + ) + + _remove = ['password_hash'] + + def hash_password(self, password): + self.password_hash = pwd_context.encrypt(password) + return self.password_hash + + def verify_password(self, password): + return pwd_context.verify(password, self.password_hash) + + +class FeedbackContext(Base, SerializableMixin): + __tablename__ = 'feedback_contexts' + id = Column(BigInteger, primary_key=True, autoincrement=True, unique=True) + uuid_organization = Column(String(300), nullable=False) + reason = Column(Enum(FeedbackContextTypeEnum)) + reason_descr = Column(TEXT()) + start = Column(DateTime(timezone=True)) + stop = Column(DateTime(timezone=True)) + ip_count = Column(Integer) + notes = Column(TEXT) + progress_report = Column(TEXT) + pending = Column(Boolean(), default=True) + + +class Feedback(Base, SerializableMixin): + __tablename__ = 'feedback' + + id = Column(BigInteger, primary_key=True, autoincrement=True, unique=True) + id_feedback_context = 
Column(BigInteger(), ForeignKey('feedback_contexts.id'), nullable=False) + id_user = Column(BigInteger(), ForeignKey('users.id'), nullable=False) + uuid_request_set = Column(TEXT(), nullable=False) + prediction = Column(Integer, nullable=False) + score = Column(Float, nullable=False) + attack_prediction = Column(Float, nullable=False) + low_rate = Column(Boolean(), nullable=True) + ip = Column(String, nullable=False) + target = Column(String, nullable=False) + features = Column(JSON, nullable=False) + feedback = Column(Enum(FeedbackEnum)) + start = Column(DateTime(timezone=True), nullable=False) + stop = Column(DateTime(timezone=True), nullable=False) + submitted = Column(Boolean(), default=False) + created_at = Column(DateTime(timezone=True), server_default=utcnow()) + updated_at = Column( + DateTime(timezone=True), nullable=True, onupdate=utcnow() + ) + + user = relationship( + 'User', + foreign_keys=id_user + ) + request_set = relationship( + 'RequestSet', + primaryjoin='foreign(Feedback.uuid_request_set) == remote(RequestSet.uuid_request_set)' + ) + feedback_context = relationship( + 'FeedbackContext', + foreign_keys=id_feedback_context + ) + + +class SubmittedFeedback(Base, SerializableMixin): + __tablename__ = 'submitted_feedback' + + id = Column(BigInteger, primary_key=True, autoincrement=True, unique=True) + # not all feedback is part of an attack + id_context = Column(BigInteger(), ForeignKey('feedback_contexts.id'), nullable=False) + uuid_organization = Column(String(300), nullable=False) + uuid_request_set = Column(TEXT(), nullable=False) + prediction = Column(Integer, nullable=False) + score = Column(Float, nullable=False) + attack_prediction = Column(Float, nullable=False) + low_rate = Column(Boolean(), nullable=True) + features = Column(JSON, nullable=True) + feedback = Column(Enum(FeedbackEnum)) + start = Column(DateTime(timezone=True), nullable=True) + stop = Column(DateTime(timezone=True), nullable=True) + submitted_at = Column(DateTime(timezone=True)) + created_at = Column(DateTime(timezone=True), server_default=utcnow()) + updated_at = Column( + DateTime(timezone=True), nullable=True, onupdate=utcnow() + ) + + organization = relationship( + 'Organization', + primaryjoin='foreign(SubmittedFeedback.uuid_organization) == remote(Organization.uuid)' + ) + request_set = relationship( + 'RequestSet', + primaryjoin='foreign(SubmittedFeedback.uuid_request_set) == remote(RequestSet.uuid_request_set)' + ) + columns = [ + 'id', + 'id_context', + 'uuid_organization', + 'uuid_request_set', + 'prediction', + 'score', + 'attack_prediction', + 'low_rate', + 'features', + 'feedback', + 'start', + 'submitted_at', + 'updated_at' + ] \ No newline at end of file diff --git a/src/baskerville/db/models.py b/src/baskerville/db/models.py index 8d82fcc1..a8be4d8a 100644 --- a/src/baskerville/db/models.py +++ b/src/baskerville/db/models.py @@ -13,6 +13,7 @@ from sqlalchemy.sql import expression from baskerville.db import Base +from baskerville.util.helpers import SerializableMixin LONG_TEXT_LEN = 4294000000 @@ -28,33 +29,6 @@ def pg_utcnow(element, compiler, **kw): return "TIMEZONE('utc', CURRENT_TIMESTAMP)" -class SerializableMixin(object): - def as_dict(self, extra_cols=(), remove=()): - """ - - :param set extra_cols: - :param set remove: - :return: - :rtype: dict[str, T] - """ - basic_attrs = {c.name: getattr(self, c.name) - for c in self.__table__.columns - if c not in remove} - extra_attrs = {} - if len(extra_cols) > 0: - for attr in extra_cols: - d = getattr(self, attr) - if d is None: - 
continue - if isinstance(d, list): - extra_attrs[attr] = [each.as_dict() for each in d] - else: - extra_attrs[attr] = d.as_dict() - basic_attrs.update(extra_attrs) - - return basic_attrs - - class Encryption(Base, SerializableMixin): __tablename__ = 'encryption' @@ -69,6 +43,7 @@ class Runtime(Base, SerializableMixin): id = Column(BigInteger, primary_key=True) id_encryption = Column(BigInteger, ForeignKey('encryption.id')) + id_user = Column(BigInteger, ForeignKey('users.id')) start = Column(DateTime(timezone=True)) stop = Column(DateTime(timezone=True)) target = Column(TEXT(), nullable=True) @@ -89,6 +64,15 @@ class Runtime(Base, SerializableMixin): 'Encryption', foreign_keys=id_encryption, back_populates='runtimes' ) + # runtimes * - 1 users + try: + from baskerville.db.dashboard_models import User + except: + pass + user = relationship( + 'User', + foreign_keys=id_user, back_populates='runtimes' + ) class RequestSet(Base, SerializableMixin): @@ -96,7 +80,7 @@ class RequestSet(Base, SerializableMixin): id = Column(BigInteger, primary_key=True) id_runtime = Column(BigInteger, ForeignKey('runtimes.id'), nullable=True) - id_request_sets = Column(TEXT()) + uuid_request_set = Column(TEXT()) target = Column(TEXT()) target_original = Column(TEXT()) ip = Column(String(45)) @@ -152,7 +136,7 @@ class RequestSet(Base, SerializableMixin): ) columns = [ - 'id_request_sets', + 'uuid_request_set', 'ip', 'target', 'target_original', @@ -226,8 +210,9 @@ class ModelTrainingSetLink(Base, SerializableMixin): class Attack(Base, SerializableMixin): __tablename__ = 'attacks' - id = Column(BigInteger, primary_key=True) + id = Column(BigInteger, primary_key=True, autoincrement=True) id_misp = Column(BigInteger) + uuid_org = Column(TEXT()) date = Column(DateTime(timezone=True)) start = Column(DateTime(timezone=True)) stop = Column(DateTime(timezone=True)) @@ -240,6 +225,7 @@ class Attack(Base, SerializableMixin): sync_stop = Column(DateTime(timezone=True)) processed = Column(Integer) notes = Column(TEXT) + progress_report = Column(TEXT) analysis_notebook = Column(TEXT) request_sets = relationship( @@ -250,6 +236,10 @@ class Attack(Base, SerializableMixin): 'Attribute', secondary='attribute_attack_link', back_populates='attacks' ) + organization = relationship( + 'Organization', + primaryjoin='foreign(Attack.uuid_org) == remote(Organization.uuid)' + ) class Attribute(Base, SerializableMixin): diff --git a/src/baskerville/models/anomaly_model.py b/src/baskerville/models/anomaly_model.py index 29db791c..5db2c3ab 100644 --- a/src/baskerville/models/anomaly_model.py +++ b/src/baskerville/models/anomaly_model.py @@ -240,4 +240,4 @@ def load(self, path, spark_session=None): self.indexes = {} for feature in self.categorical_string_features(): self.indexes[feature] = StringIndexerModel.load(self._get_index_path(path, feature)) - return self \ No newline at end of file + return self diff --git a/src/baskerville/models/config.py b/src/baskerville/models/config.py index fa6eea3b..34102299 100644 --- a/src/baskerville/models/config.py +++ b/src/baskerville/models/config.py @@ -182,12 +182,13 @@ class BaskervilleConfig(Config): - kafka : optional - depends on the chosen pipeline """ - database = None - elastic = None - misp = None - engine = None - kafka = None - spark = None + database: 'DatabaseConfig' = None + elastic: 'ElasticConfig' = None + misp: 'MispConfig' = None + engine: 'EngineConfig' = None + kafka: 'KafkaConfig' = None + spark: 'SparkConfig' = None + user_details: 'UserDetailsConfig' = None def __init__(self, 
config): super(BaskervilleConfig, self).__init__(config) @@ -203,6 +204,8 @@ def __init__(self, config): self.kafka = KafkaConfig(self.kafka) if self.spark: self.spark = SparkConfig(self.spark) + if self.user_details: + self.user_details = UserDetailsConfig(self.user_details) def validate(self): logger.debug('Validating BaskervilleConfig...') @@ -226,6 +229,10 @@ def validate(self): self.spark.validate() else: logger.debug('No spark config') + if self.user_details: + self.user_details.validate() + else: + logger.error('No user_details config') self._is_validated = True self._is_valid = len(self.errors) == 0 @@ -246,6 +253,7 @@ class EngineConfig(Config): simulation = None datetime_format = '%Y-%m-%d %H:%M:%S' cache_path = None + save_cache_to_storage = False storage_path = None cache_expire_time = None cache_load_past = False @@ -275,7 +283,7 @@ class EngineConfig(Config): ip_cache_passed_challenge_size = 100000 ip_cache_pending_ttl = 60 * 60 * 1 # 1h ip_cache_pending_size = 100000 - + save_to_storage = True white_list_ips = [] white_list_hosts = [] banjax_sql_update_filter_minutes = 90 @@ -285,6 +293,7 @@ class EngineConfig(Config): url_origin_ips = '' new_model_check_in_seconds = 300 kafka_send_by_partition = True + use_storage_for_request_cache = False def __init__(self, config, parent=None): super(EngineConfig, self).__init__(config, parent) @@ -724,9 +733,13 @@ class KafkaConfig(Config): """ bootstrap_servers = '0.0.0.0:9092' zookeeper = 'localhost:2181' - logs_topic = 'deflect.logs' + data_topic = 'deflect.logs' features_topic = 'features' + feedback_topic = 'feedback' + feedback_response_topic = '' predictions_topic = 'predictions' + register_topic = 'register' + auto_offset_reset = 'largest' banjax_command_topic = 'banjax_command_topic' banjax_report_topic = 'banjax_report_topic' security_protocol = 'PLAINTEXT' @@ -755,8 +768,12 @@ def validate(self): if not self.zookeeper: # kafka client can be used without zookeeper warnings.warn('Zookeeper url is empty.') - if not self.logs_topic: - warnings.warn('Logs topic is empty.') + if not self.data_topic: + warnings.warn('Data topic is empty.') + if not self.feedback_topic: + warnings.warn('Feedback topic is empty.') + if not self.feedback_response_topic: + warnings.warn('Feedback response topic is empty.') if not self.features_topic: warnings.warn('Features topic is empty') if not self.predictions_topic: @@ -808,6 +825,9 @@ class SparkConfig(Config): ssl_ui_enabled = False ssl_standalone_enabled = False ssl_history_server_enabled = False + s3_endpoint: None + s3_access_key: None + s3_secret_key: None def __init__(self, config): super(SparkConfig, self).__init__(config) @@ -962,9 +982,6 @@ class DataParsingConfig(Config): group_by_cols = ('client_request_host', 'client_ip') timestamp_column = '@timestamp' - def __init__(self, config_dict): - super().__init__(config_dict) - def validate(self): logger.debug('Validating DataParsingConfig...') from baskerville.models.log_parsers import LOG_PARSERS @@ -995,3 +1012,31 @@ def validate(self): self._is_validated = True return self + + +class UserDetailsConfig(Config): + username = '' + password = '' + organization_uuid = '' + + def validate(self): + logger.debug('Validating UserDetailsConfig...') + if not self.username: + self.add_error(ConfigError( + f'Please, provide a username', + ['username'], + exception_type=ValueError + )) + if not self.password: + self.add_error(ConfigError( + f'Please, provide a password', + ['password'], + exception_type=ValueError + )) + if not self.organization_uuid: + 
self.add_error(ConfigError( + f'Please, provide an organization_uuid', + ['organization_uuid'], + exception_type=ValueError + )) + diff --git a/src/baskerville/models/log_parsers.py b/src/baskerville/models/log_parsers.py index cb672ee7..4863189b 100644 --- a/src/baskerville/models/log_parsers.py +++ b/src/baskerville/models/log_parsers.py @@ -11,6 +11,7 @@ import warlock import pyspark.sql.types as T import pyspark.sql.functions as F +from baskerville.spark.schemas import NAME_TO_SCHEMA from baskerville.util.helpers import SerializableMixin @@ -125,13 +126,16 @@ def drop_if_missing_filter(self): class JSONLogSparkParser(JSONLogParser, SerializableMixin): + _name = '' def __init__(self, schema, drop_row_if_missing=None, sample=None): + self._name = schema['name'] self.str_to_type = { - 'string': (lambda sample: T.StringType()), - 'number': (lambda sample: T.FloatType()), - 'integer': (lambda sample: T.IntegerType()), + 'string': (lambda sample, name: T.StringType()), + 'number': (lambda sample, name: T.FloatType()), + 'integer': (lambda sample, name: T.IntegerType()), + 'object': (lambda sample, name: NAME_TO_SCHEMA[self._name].get(name)), } self.sample = sample @@ -144,7 +148,7 @@ def add_json_item(parent, name, value): if 'type' in value.keys(): parent.add(T.StructField( name, - self.str_to_type[value['type']](value.get('default')), + self.str_to_type[value['type']](value.get('default'), name), not value.get('required', False) )) return diff --git a/src/baskerville/models/pipeline_factory.py b/src/baskerville/models/pipeline_factory.py index 80ee076e..067ece42 100644 --- a/src/baskerville/models/pipeline_factory.py +++ b/src/baskerville/models/pipeline_factory.py @@ -63,6 +63,22 @@ def get_pipeline(self, run_type, config): from baskerville.models.pipeline_tasks.training_pipeline \ import set_up_training_pipeline return set_up_training_pipeline(config) + # elif run_type == RunType.dashboard_preprocessing: + # from baskerville.models.pipeline_tasks.dashboard_pipeline import \ + # set_up_dashboard_preprocessing_pipeline + # return set_up_dashboard_preprocessing_pipeline(config) + # elif run_type == RunType.dashboard: + # from baskerville.models.pipeline_tasks.dashboard_pipeline import \ + # set_up_dashboard_pipeline + # return set_up_dashboard_pipeline(config) + elif run_type == RunType.client_rawlog: + from baskerville.models.pipeline_tasks.client_pipeline import \ + set_up_client_rawlog_pipeline + return set_up_client_rawlog_pipeline(config) + elif run_type == RunType.feedback: + from baskerville.models.pipeline_tasks.feedback_pipeline import \ + set_up_feedback_pipeline + return set_up_feedback_pipeline(config) raise RuntimeError( 'Cannot set up a pipeline with the current configuration.' 
diff --git a/src/baskerville/models/pipeline_tasks/client_pipeline.py b/src/baskerville/models/pipeline_tasks/client_pipeline.py index 6c086a5a..c9aae597 100644 --- a/src/baskerville/models/pipeline_tasks/client_pipeline.py +++ b/src/baskerville/models/pipeline_tasks/client_pipeline.py @@ -10,7 +10,8 @@ from baskerville.models.pipeline_tasks.tasks import GetDataKafka, \ GenerateFeatures, \ Save, CacheSensitiveData, SendToKafka, \ - GetPredictions, MergeWithSensitiveData, RefreshCache, AttackDetection + GetPredictions, MergeWithSensitiveData, RefreshCache, AttackDetection, \ + GetDataLog, Predict, Challenge def set_up_preprocessing_pipeline(config: BaskervilleConfig): @@ -22,7 +23,7 @@ def set_up_preprocessing_pipeline(config: BaskervilleConfig): CacheSensitiveData(config), SendToKafka( config=config, - columns=('id_client', 'id_request_sets', 'features'), + columns=('id_client', 'uuid_request_set', 'features'), topic=config.kafka.features_topic, ), RefreshCache(config) @@ -34,6 +35,35 @@ def set_up_preprocessing_pipeline(config: BaskervilleConfig): return main_task +def set_up_client_rawlog_pipeline(config: BaskervilleConfig): + """ + Reads from raw log and sends features to Kafka for prediction + Note: this is mostly set up for testing + """ + task = [ + GetDataLog( + config, + steps=[ + GenerateFeatures(config), + # CacheSensitiveData(config), + SendToKafka( + config=config, + columns=('id_client', 'uuid_request_set', 'features'), + topic=config.kafka.features_topic, + ), + Predict(config), + AttackDetection(config), + Challenge(config), + # Save(config, json_cols=[]), + # RefreshCache(config), + ]), + ] + + main_task = Task(config, task) + main_task.name = 'Preprocessing Pipeline' + return main_task + + def set_up_postprocessing_pipeline(config: BaskervilleConfig): tasks = [ GetPredictions( @@ -41,6 +71,7 @@ def set_up_postprocessing_pipeline(config: BaskervilleConfig): steps=[ MergeWithSensitiveData(config), AttackDetection(config), + Challenge(config), Save(config, json_cols=[]), ]), ] diff --git a/src/baskerville/models/pipeline_tasks/feedback_pipeline.py b/src/baskerville/models/pipeline_tasks/feedback_pipeline.py new file mode 100644 index 00000000..aa496c98 --- /dev/null +++ b/src/baskerville/models/pipeline_tasks/feedback_pipeline.py @@ -0,0 +1,46 @@ +# Copyright (c) 2020, eQualit.ie inc. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from baskerville.db.dashboard_models import SubmittedFeedback +from baskerville.models.pipeline_tasks.tasks_base import Task +from baskerville.models.config import BaskervilleConfig +from baskerville.models.pipeline_tasks.tasks import GetDataKafka, SaveFeedback, \ + SendToKafka + + +def set_up_feedback_pipeline(config: BaskervilleConfig): + """ + Feedback Pipeline listens to a kafka topic and every time_bucket time, + it gathers all the feedback and saves it to the Feedback table. + It is considered a low needs pipeline, since feedback is not going to be + coming in constantly, but rarely. 
+ In config, kafka-> data_topic should be configured for the data input + """ + tasks = [ + GetDataKafka( + config, + steps=[ + SaveFeedback( + config, + table_model=SubmittedFeedback, + not_common=( + 'id_context', + 'top_uuid_organization' + ), + ), + SendToKafka( + config, + ('uuid_organization', 'id_context', 'success'), + 'feedback', + cmd='feedback_center', + cc_to_client=True, + client_only=True + ) + ]), + ] + + main_task = Task(config, tasks) + main_task.name = 'Feedback Pipeline' + return main_task diff --git a/src/baskerville/models/pipeline_tasks/prediction_pipeline.py b/src/baskerville/models/pipeline_tasks/prediction_pipeline.py index 647a8a32..2459d8d9 100644 --- a/src/baskerville/models/pipeline_tasks/prediction_pipeline.py +++ b/src/baskerville/models/pipeline_tasks/prediction_pipeline.py @@ -19,7 +19,7 @@ def set_up_prediction_pipeline(config: BaskervilleConfig): Predict(config), SendToKafka( config=config, - columns=('id_client', 'id_request_sets', 'prediction', 'score'), + columns=('id_client', 'uuid_request_set', 'prediction', 'score'), topic=config.kafka.predictions_topic, cc_to_client=True, ), diff --git a/src/baskerville/models/pipeline_tasks/rawlog_pipeline.py b/src/baskerville/models/pipeline_tasks/rawlog_pipeline.py index 6c91401e..7e688ff9 100644 --- a/src/baskerville/models/pipeline_tasks/rawlog_pipeline.py +++ b/src/baskerville/models/pipeline_tasks/rawlog_pipeline.py @@ -9,7 +9,7 @@ from baskerville.models.config import BaskervilleConfig from baskerville.models.pipeline_tasks.tasks import GenerateFeatures, \ Save, \ - Predict, GetDataLog, AttackDetection, RefreshCache + Predict, GetDataLog, AttackDetection, RefreshCache, Challenge def set_up_isac_rawlog_pipeline(config: BaskervilleConfig): @@ -19,9 +19,10 @@ def set_up_isac_rawlog_pipeline(config: BaskervilleConfig): steps=[ GenerateFeatures(config), Predict(config), - Save(config), AttackDetection(config), - RefreshCache(config), + Challenge(config), + Save(config), + # RefreshCache(config), ]), ] diff --git a/src/baskerville/models/pipeline_tasks/service_provider.py b/src/baskerville/models/pipeline_tasks/service_provider.py index cfabdb27..e9e1b906 100644 --- a/src/baskerville/models/pipeline_tasks/service_provider.py +++ b/src/baskerville/models/pipeline_tasks/service_provider.py @@ -78,9 +78,23 @@ def refresh_model(self): self.load_model_from_db() def create_runtime(self): + from baskerville.db.dashboard_models import User, Organization + org = self.tools.session.query(Organization).filter_by( + uuid=self.config.user_details.organization_uuid + ).first() + if not org: + raise ValueError(f'No such organization.') + + user = self.tools.session.query(User).filter_by( + username=self.config.user_details.username).filter_by( + id_organization=org.id + ).first() + if not user: + raise ValueError(f'No such user.') self.runtime = self.tools.create_runtime( start=self.start_time, - conf=self.config.engine + conf=self.config.engine, + id_user=user.id ) self.logger.info(f'Created runtime {self.runtime.id}') @@ -106,7 +120,8 @@ def initialize_request_set_cache_service(self): expire_if_longer_than=self.config.engine.cache_expire_time, path=os.path.join(self.config.engine.storage_path, FOLDER_CACHE), - logger=self.logger + logger=self.logger, + use_storage=self.config.engine.use_storage_for_request_cache ) if self.config.engine.cache_load_past: self.request_set_cache = self.request_set_cache.load( diff --git a/src/baskerville/models/pipeline_tasks/setup_pipeline.py 
b/src/baskerville/models/pipeline_tasks/setup_pipeline.py new file mode 100644 index 00000000..27424f09 --- /dev/null +++ b/src/baskerville/models/pipeline_tasks/setup_pipeline.py @@ -0,0 +1,102 @@ +# Copyright (c) 2020, eQualit.ie inc. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from baskerville.models.pipeline_tasks.tasks_base import Task +from baskerville.models.config import BaskervilleConfig +from baskerville.models.pipeline_tasks.tasks import GetDataKafka, \ + GenerateFeatures, \ + Save, CacheSensitiveData, SendToKafka, \ + GetPredictions, MergeWithSensitiveData, RefreshCache, AttackDetection, \ + GetDataLog, Predict, Challenge + + +def set_up_registration_pipeline(config: BaskervilleConfig): + task = [ + GetDataKafka( + config, + steps=[ + Register(config), + CacheSensitiveData(config), + SendToKafka( + config=config, + columns=('id_client', 'uuid_request_set', 'features'), + topic=config.kafka.features_topic, + ), + RefreshCache(config) + ]), + ] + + main_task = Task(config, task) + main_task.name = 'Preprocessing Pipeline' + return main_task + + +def set_up_user_creation_pipeline(config: BaskervilleConfig): + task = [ + GetDataKafka( + config, + steps=[ + Register(config), + CacheSensitiveData(config), + SendToKafka( + config=config, + columns=('id_client', 'uuid_request_set', 'features'), + topic=config.kafka.features_topic, + ), + RefreshCache(config) + ]), + ] + + main_task = Task(config, task) + main_task.name = 'Preprocessing Pipeline' + return main_task + + +def set_up_client_rawlog_pipeline(config: BaskervilleConfig): + """ + Reads from raw log and sends features to Kafka for prediction + Note: this is mostly set up for testing + """ + task = [ + GetDataLog( + config, + steps=[ + GenerateFeatures(config), + # CacheSensitiveData(config), + SendToKafka( + config=config, + columns=('id_client', 'uuid_request_set', 'features'), + topic=config.kafka.features_topic, + ), + Predict(config), + AttackDetection(config), + Challenge(config), + # Save(config, json_cols=[]), + # RefreshCache(config), + ]), + ] + + main_task = Task(config, task) + main_task.name = 'Preprocessing Pipeline' + return main_task + + +def set_up_postprocessing_pipeline(config: BaskervilleConfig): + tasks = [ + GetPredictions( + config, + steps=[ + MergeWithSensitiveData(config), + AttackDetection(config), + Challenge(config), + Save(config, json_cols=[]), + ]), + ] + + main_task = Task(config, tasks) + main_task.name = 'Postprocessing Pipeline' + return main_task diff --git a/src/baskerville/models/pipeline_tasks/tasks.py b/src/baskerville/models/pipeline_tasks/tasks.py index 2501a8b9..34050663 100644 --- a/src/baskerville/models/pipeline_tasks/tasks.py +++ b/src/baskerville/models/pipeline_tasks/tasks.py @@ -15,14 +15,16 @@ import traceback import pyspark +from baskerville.db.dashboard_models import FeedbackContext from pyspark.sql import functions as F, types as T from pyspark.sql.types import StringType, StructField, StructType, DoubleType from pyspark.streaming import StreamingContext from functools import reduce from pyspark.sql import DataFrame +from sqlalchemy.exc import SQLAlchemyError from baskerville.db import get_jdbc_url -from baskerville.db.models import RequestSet, Model +from baskerville.db.models import RequestSet, Model, Attack from baskerville.models.banjax_report_consumer import BanjaxReportConsumer from baskerville.models.ip_cache import IPCache from 
baskerville.models.metrics.registry import metrics_registry @@ -31,9 +33,10 @@ from baskerville.models.config import BaskervilleConfig from baskerville.spark.helpers import map_to_array, load_test, \ save_df_to_table, columns_to_dict, get_window, set_unknown_prediction, \ - send_to_kafka_by_partition_id + send_to_kafka_by_partition_id, df_has_rows, get_dtype_for_col from baskerville.spark.schemas import features_schema, \ - prediction_schema + prediction_schema, get_features_schema, get_data_schema, \ + get_feedback_context_schema from kafka import KafkaProducer from dateutil.tz import tzutc @@ -68,11 +71,11 @@ def __init__( self.data_parser = self.config.engine.data_config.parser self.kafka_params = { 'metadata.broker.list': self.config.kafka.bootstrap_servers, - 'auto.offset.reset': 'largest', + 'auto.offset.reset': self.config.kafka.auto_offset_reset, 'group.id': self.config.kafka.consume_group, 'auto.create.topics.enable': 'true' } - self.consume_topic = self.config.kafka.logs_topic + self.consume_topic = self.config.kafka.data_topic def initialize(self): super(GetDataKafka, self).initialize() @@ -151,28 +154,8 @@ class GetFeatures(GetDataKafka): def __init__(self, config: BaskervilleConfig, steps: list = ()): super().__init__(config, steps) self.consume_topic = self.config.kafka.features_topic - self.data_schema = self.get_data_schema() - self.features_schema = self.get_features_schema() - - def get_data_schema(self) -> T.StructType: - return T.StructType( - [T.StructField('key', T.StringType()), - T.StructField('message', T.StringType())] - ) - - def get_features_schema(self) -> T.StructType: - schema = T.StructType([ - T.StructField("id_client", T.StringType(), True), - T.StructField("id_request_sets", T.StringType(), False) - ]) - features = T.StructType() - for feature in self.config.engine.all_features.keys(): - features.add(T.StructField( - name=feature, - dataType=T.StringType(), - nullable=True)) - schema.add(T.StructField("features", features)) - return schema + self.data_schema = get_data_schema() + self.features_schema = get_features_schema(self.config.engine.all_features) def get_data(self): self.df = self.spark.createDataFrame( @@ -188,7 +171,7 @@ def get_data(self): self.df = self.df.where(F.col('message.id_client').isNotNull()) \ .withColumn('features', F.col('message.features')) \ .withColumn('id_client', F.col('message.id_client')) \ - .withColumn('id_request_sets', F.col('message.id_request_sets')) \ + .withColumn('uuid_request_set', F.col('message.uuid_request_set')) \ .drop('message', 'key').persist(self.config.spark.storage_level) @@ -288,11 +271,10 @@ def initialize(self): step.initialize() def create_runtime(self): - self.runtime = self.tools.create_runtime( - file_name=self.current_log_path, - conf=self.config.engine, - comment=f'batch runtime {self.batch_i} of {self.batch_n}', - ) + self.service_provider.create_runtime() + self.runtime.file_name = self.current_log_path + self.runtime.comment=f'batch runtime {self.batch_i} of {self.batch_n}' + self.db_tools.session.commit() self.logger.info('Created runtime {}'.format(self.runtime.id)) def get_data(self): @@ -809,13 +791,13 @@ def add_ids(self): self.df = self.df.withColumn( 'id_client', F.lit(self.config.engine.id_client) ).withColumn( - 'id_request_sets', F.monotonically_increasing_id() + 'uuid_request_set', F.monotonically_increasing_id() ).withColumn( - 'id_request_sets', + 'uuid_request_set', F.concat_ws( '_', F.col('id_client'), - F.col('id_request_sets'), + F.col('uuid_request_set'), 
F.col('start').cast('long').cast('string')) ) # todo: monotonically_increasing_id guarantees uniqueness within @@ -880,12 +862,12 @@ def __init__( self, config, steps=(), - table_name=RequestSet.__tablename__, + table_model=RequestSet, json_cols=('features',), mode='append' ): super().__init__(config, steps) - self.table_name = table_name + self.table_model = table_model self.json_cols = json_cols self.mode = mode @@ -895,7 +877,7 @@ def run(self): if self.df.count() > 0: save_df_to_table( self.df, - self.table_name, + self.table_model.__tablename__, self.config.database.__dict__, json_cols=self.json_cols, storage_level=self.config.spark.storage_level, @@ -910,25 +892,36 @@ class Save(SaveDfInPostgres): """ Saves dataframe in Postgres (current backend) """ + def __init__(self, config, + steps=(), + table_model=RequestSet, + json_cols=('features',), + mode='append', + not_common=( + 'prediction', + 'model_version', + 'label', + 'id_attribute', + 'updated_at') + ): + self.not_common = set(not_common) + super().__init__(config, steps, table_model, json_cols, mode) def prepare_to_save(self): - request_set_columns = RequestSet.columns[:] - not_common = { - 'prediction', 'model_version', 'label', 'id_attribute', - 'updated_at' - }.difference(self.df.columns) + table_columns = self.table_model.columns[:] + not_common = self.not_common.difference(self.df.columns) for c in not_common: - request_set_columns.remove(c) + table_columns.remove(c) - if len(self.df.columns) < len(request_set_columns): + if len(self.df.columns) < len(table_columns): # log and let it blow up; we need to know that we cannot save self.logger.error( 'The input df columns are different than ' 'the actual table columns' ) - self.df = self.df.select(request_set_columns) + self.df = self.df.select(table_columns) self.df = self.df.withColumn( 'created_at', F.current_timestamp() @@ -942,6 +935,108 @@ def run(self): return self.df +class SaveFeedback(SaveDfInPostgres): + def __init__(self, config, + steps=(), + table_model=RequestSet, + json_cols=('features',), + mode='append', + not_common=( + 'prediction', + 'model_version', + 'label', + 'id_attribute', + 'updated_at') + ): + self.not_common = set(not_common) + super().__init__(config, steps, table_model, json_cols, mode) + + def upsert_feedback_context(self): + new_ = False + success = False + try: + self.df = self.df.withColumn('id_fc', F.lit(None)) + uuid_organization, feedback_context = self.df.select( + 'uuid_organization', 'feedback_context' + ).collect()[0] + print(">>>>>>>> feedback_context", feedback_context) + if feedback_context: + fc = self.db_tools.session.query(FeedbackContext).filter_by( + uuid_organization=uuid_organization + ).filter_by( + start=feedback_context.start + ).filter_by(stop=feedback_context.stop).first() + if not fc: + fc = FeedbackContext() + new_ = True + # fc.uuid_org = feedback_context.uuid_org + fc.reason = feedback_context.reason + fc.reason_descr = feedback_context.reason_descr + fc.start = feedback_context.start + fc.stop = feedback_context.stop + fc.ip_count = feedback_context.ip_count + fc.notes = feedback_context.notes + fc.progress_report = feedback_context.progress_report + if new_: + self.db_tools.session.add(fc) + self.db_tools.session.commit() + self.df = self.df.withColumn('id_fc', F.lit(fc.id)) + success = True + else: + self.logger.info('No feedback context.') + success = True + except SQLAlchemyError as sqle: + traceback.print_exc() + self.db_tools.session.rollback() + success = False + self.logger.error(str(sqle)) + # todo: what 
should the handling be? + except Exception as e: + traceback.print_exc() + success = False + self.logger.error(str(e)) + # todo: what should the handling be? + return success + + def prepare_to_save(self): + try: + success = self.upsert_feedback_context() + self.df.show() + success = True + if success: + # explode submitted feedback first + # updated feedback will be inserted and identical uuid_request_set + # can be filtered out with created_at or max(id) + self.df = self.df.select( + 'uuid_organization', + 'id_context', + F.col('id_fc').alias('sumbitted_context_id'), + F.explode('feedback').alias('feedback') + ).cache() + self.df.show() + self.df = self.df.select( + F.col('uuid_organization').alias('top_uuid_organization'), + F.col('id_context').alias('client_id_context'), + F.col('sumbitted_context_id'), + *[F.col(f'feedback.{c}').alias(c) for c in self.table_model.columns] + ).cache() + self.df = self.df.withColumnRenamed('id_context', 'client_id_context') + self.df = self.df.drop('updated_at') + self.df = self.df.withColumn('id_context', F.col('sumbitted_context_id')).drop('sumbitted_context_id') + self.df.show() + Save.prepare_to_save(self) + self.df = SaveDfInPostgres.run(self) + self.df = self.df.groupBy('uuid_organization', 'id_context').count().toDF() + self.df = self.df.withColumn('success', F.lit(True)) + except: + self.df = self.df.withColumn('success', F.lit(False)) + + def run(self): + self.upsert_feedback_context() + self.prepare_to_save() + return self.df + + class RefreshCache(CacheTask): def run(self): self.service_provider.refresh_cache(self.df) @@ -981,7 +1076,7 @@ def run(self): ).option( 'ttl', self.ttl ).option( - 'key.column', 'id_request_sets' + 'key.column', 'uuid_request_set' ).save() self.df = super().run() return self.df @@ -1004,7 +1099,7 @@ def run(self): ).option( 'table', self.table_name ).option( - 'key.column', 'id_request_sets' + 'key.column', 'uuid_request_set' ).load().alias('redis_df') count = self.df.count() @@ -1013,8 +1108,8 @@ def run(self): self.df = self.df.alias('df') self.df = self.redis_df.join( - self.df, on=['id_client', 'id_request_sets'] - ).drop('df.id_client', 'df.id_request_sets') + self.df, on=['id_client', 'uuid_request_set'] + ).drop('df.id_client', 'df.uuid_request_set') if self.df and self.df.head(1): merge_count = self.df.count() @@ -1043,13 +1138,17 @@ def __init__( config: BaskervilleConfig, columns, topic, + cmd='prediction_center', cc_to_client=False, + client_only=True, steps: list = (), ): super().__init__(config, steps) self.columns = columns self.topic = topic + self.cmd = cmd self.cc_to_client = cc_to_client + self.client_only = client_only def run(self): self.logger.info(f'Sending to kafka topic \'{self.topic}\'...') @@ -1141,6 +1240,7 @@ def load_dataset(self, df, features): for feature in features: column = f'features.{feature}' feature_class = self.engine_conf.all_features[feature] + # fixme: bug: .alias(column) will give feature.feature_name dataset = dataset.withColumn(column, F.col(column).cast(feature_class.spark_type()).alias(column)) self.logger.debug(f'Loaded {dataset.count()} rows dataset.') @@ -1206,7 +1306,7 @@ def save(self): for f_name in self.feature_manager.active_feature_names: df = df.withColumn(f_name, F.col('features').getItem(f_name)) df.select( - 'id_request_sets', + 'uuid_request_set', 'prediction', 'score', 'stop', @@ -1214,7 +1314,7 @@ def save(self): ).write.format('io.tiledb.spark').option( 'uri', f'{get_default_data_path()}/tiledbstorage' ).option( - 'schema.dim.0.name', 'id_request_sets' + 
'schema.dim.0.name', 'uuid_request_set' ).save() def run(self): @@ -1233,9 +1333,9 @@ def __init__(self, config, steps=()): super().__init__(config, steps) self.catalog = { 'table': {'namespace': 'default', 'name': 'request_sets'}, - 'rowkey': 'id_request_sets', + 'rowkey': 'uuid_request_set', 'columns': { - 'id_request_sets': {'cf': 'rowkey', 'col': 'id_request_sets', 'type': 'string'}, + 'uuid_request_set': {'cf': 'rowkey', 'col': 'uuid_request_set', 'type': 'string'}, 'prediction': {'cf': 'cf1', 'col': 'prediction', 'type': 'int'}, 'score': {'cf': 'cf1', 'col': 'score', 'type': 'double'}, 'stop': {'cf': 'cf1', 'col': 'stop', 'type': 'timestamp'}, @@ -1257,7 +1357,7 @@ def save(self): for f_name in self.feature_manager.active_feature_names: df = df.withColumn(f_name, F.col('features').getItem(f_name)) df.select( - 'id_request_sets', + 'uuid_request_set', 'prediction', 'score', 'stop', @@ -1281,7 +1381,7 @@ def save(self): for f_name in self.feature_manager.active_feature_names: df = df.withColumn(f_name, F.col('features').getItem(f_name)) df.select( - 'id_request_sets', + 'uuid_request_set', 'prediction', 'score', 'stop', @@ -1289,7 +1389,7 @@ def save(self): ).write.format('io.tiledb.spark').option( 'uri', f'{get_default_data_path()}/tiledbstorage' ).option( - 'schema.dim.0.name', 'id_request_sets' + 'schema.dim.0.name', 'uuid_request_set' ).save() def run(self): @@ -1306,17 +1406,14 @@ class AttackDetection(Task): def __init__(self, config, steps=()): super().__init__(config, steps) self.df_chunks = [] - self.white_list_ips = set(self.config.engine.white_list_ips) - self.df_white_list_hosts = None self.ip_cache = IPCache(config, self.logger) self.report_consumer = None self.banjax_thread = None self.register_metrics = config.engine.register_banjax_metrics - self.low_rate_attack_schema = T.StructType([T.StructField( - name='request_total', dataType=StringType(), nullable=True - )]) - self.producer = KafkaProducer( - bootstrap_servers=self.config.kafka.bootstrap_servers) + self.low_rate_attack_schema = None + self.time_filter = None + self.lra_condition = None + self.features_schema = get_features_schema(self.config.engine.all_features) self.origin_ips = OriginIPs( url=config.engine.url_origin_ips, logger=self.logger, @@ -1324,58 +1421,28 @@ def __init__(self, config, steps=()): ) def initialize(self): - global IP_ACC - # super(SaveStats, self).initialize() - if self.config.engine.white_list_hosts: - self.df_white_list_hosts = self.spark.createDataFrame( - [ - [host] for host in - set(self.config.engine.white_list_hosts) - ], ['target'])\ - .withColumn('white_list_host', F.lit(1)) - - from baskerville.spark.helpers import DictAccumulatorParam - IP_ACC = self.spark.sparkContext.accumulator(defaultdict(int), - DictAccumulatorParam( - defaultdict(int))) - - def send_to_kafka( - kafka_servers, topic, rows, cmd_name='challenge_host', - id_client=None - ): - """ - Creates a kafka producer and sends the rows one by one, - along with the specified command (challenge_[host, ip]) - :returns: False if something went wrong, true otherwise - """ - # global IP_ACC - try: - from kafka import KafkaProducer - producer = KafkaProducer( - bootstrap_servers=kafka_servers - ) - for row in rows: - from baskerville.spark.udfs import get_msg - message = get_msg(row, cmd_name) - producer.send(topic, get_msg(row, cmd_name)) - if id_client: - producer.send(f'{topic}.{id_client}', message) - # if cmd_name == 'challenge_ip': - # IP_ACC += {row: 1} - producer.flush() - except Exception: - import traceback - 
traceback.print_exc() - return False - return True - - self.udf_send_to_kafka = F.udf(send_to_kafka, T.BooleanType()) - + lr_attack_period = self.config.engine.low_rate_attack_period + lra_total_req = self.config.engine.low_rate_attack_total_request + # initialize these here to make sure spark session has been initialized + self.low_rate_attack_schema = T.StructType([T.StructField( + name='request_total', dataType=StringType(), nullable=True + )]) + self.time_filter = ( + F.abs(F.unix_timestamp(F.col('stop'))) - + F.abs(F.unix_timestamp(F.col('start'))) + ) + self.lra_condition = ( + ((F.col('features.request_total') > lr_attack_period[0]) & + (self.time_filter > lra_total_req[0])) | + ((F.col('features.request_total') > lr_attack_period[1]) & + (self.time_filter > lra_total_req[1])) + ) self.report_consumer = BanjaxReportConsumer(self.config, self.logger) if self.register_metrics: self.register_banjax_metrics() self.banjax_thread = threading.Thread(target=self.report_consumer.run) self.banjax_thread.start() + pass def finish_up(self): if self.banjax_thread: @@ -1476,43 +1543,25 @@ def get_attack_score(self): # ppp.persist(self.config.spark.storage_level) return df - def detect_low_rate_attack(self, df): + def detect_low_rate_attack(self): self.logger.info('Low rate attack detecting...') - lr_attack_period = self.config.engine.low_rate_attack_period - lra_total_req = self.config.engine.low_rate_attack_total_request - time_filter = ( - F.abs(F.unix_timestamp(df.stop)) - F.abs(F.unix_timestamp(df.start)) - ) # todo check features dtype and use from_json if necessary - df = df.withColumn('f', F.from_json('features', self.low_rate_attack_schema)) - df = df.withColumn( - 'f.request_total', - F.col('f.request_total').cast( + if get_dtype_for_col(self.df, 'features') == 'string': + self.df = self.df.withColumn( + 'features', + F.from_json('features', self.low_rate_attack_schema) + ) + self.df = self.df.withColumn( + 'features.request_total', + F.col('features.request_total').cast( T.DoubleType() - ).alias('f.request_total') + ).alias('features.request_total') + ).persist(self.config.spark.storage_level) + self.df = self.df.withColumn( + 'low_rate_attack', + F.when(self.lra_condition, 1.0).otherwise(0.0) ) - df_attackers = df.filter( - ((F.col('f.request_total') > lr_attack_period[0]) & - (time_filter > lra_total_req[0])) - | - ((F.col('f.request_total') > lr_attack_period[1]) & - (time_filter > lra_total_req[1])) - ).select( - 'ip', 'target', 'f.request_total', 'start' - ).withColumn('low_rate_attack', F.lit(1)) - - if df_attackers and df_attackers.head(1): - self.logger.info('Low rate attack -------------- ') - # fails with Null Pointer Exception - testing with head(1): - self.logger.info(df_attackers.show()) - df = df.join(df_attackers.select('ip', 'low_rate_attack'), on='ip', how='left') - df = df.fillna({'low_rate_attack': 0}) - else: - df = df.withColumn('low_rate_attack', F.lit(0)) - - return df - def apply_white_list_ips(self, ips): if not self.white_list_ips: return ips @@ -1545,52 +1594,147 @@ def detect_attack(self): (F.col('attack_score') > self.config.engine.attack_threshold) & (F.col('total') > self.config.engine.minimum_number_attackers), F.lit(1)).otherwise(F.lit(0))) - + # todo is this right? don't we need to join on ip too?? 
self.df = self.df.join( df_attack.select( ['target', 'attack_prediction'] ), on='target', how='left') - self.df = self.detect_low_rate_attack(self.df) - return df_attack + self.detect_low_rate_attack() + # return df_attack + return self.df - def send_challenge(self, df_attack): - df_ips = self.df.select('ip', 'target').where( - (F.col('attack_prediction') == 1) & (F.col('prediction') == 1) | - (F.col('low_rate_attack') == 1) - ).cache() - if self.config.engine.challenge == 'ip': - if not df_ips or not df_ips.head(1): - self.df = self.df.withColumn('challenged', F.lit(0)) - return + def updated_df_with_attacks(self, df_attack): + self.df = self.df.join(df_attack, on=[df_attack.uuid_request_set == self.df.uuid_request_set], how='left') - if self.df_white_list_hosts: - df_ips = df_ips.join(self.df_white_list_hosts, on='target', how='left').persist() - df_ips = df_ips.where(F.col('white_list_host').isNull()) + def run(self): + if get_dtype_for_col(self.df, 'features') == 'string': + # this can be true when running the raw log pipeline + self.df = self.df.withColumn( + "features", + F.from_json("features", self.features_schema) + ) + self.df = self.df.repartition('target').persist( + self.config.spark.storage_level + ) + self.classify_anomalies() + df_attack = self.detect_attack() + if not df_has_rows(df_attack): + self.updated_df_with_attacks(df_attack) + self.logger.info('No attacks detected...') + self.df = super().run() + return self.df + + +class Challenge(Task): + def __init__( + self, + config: BaskervilleConfig, steps=(), + attack_cols=('prediction', 'attack_prediction', 'low_rate_attack') + ): + super().__init__(config, steps) + self.attack_cols = attack_cols + self.white_list_ips = set(self.config.engine.white_list_ips) + self.df_white_list_hosts = None + self.attack_filter = None + self.producer = None + self.udf_send_to_kafka = None + + def initialize(self): + # global IP_ACC + # from baskerville.spark.helpers import DictAccumulatorParam + # IP_ACC = self.spark.sparkContext.accumulator(defaultdict(int), + # DictAccumulatorParam( + # defaultdict(int))) + self.attack_filter = self.get_attack_filter() + # self.producer = KafkaProducer( + # bootstrap_servers=self.config.kafka.bootstrap_servers) + if self.config.engine.white_list_hosts: + self.df_white_list_hosts = self.spark.createDataFrame( + [ + [host] for host in + set(self.config.engine.white_list_hosts) + ], ['target'])\ + .withColumn('white_list_host', F.lit(1)) - ips = [r['ip'] for r in df_ips.collect()] - ips = self.apply_white_list_ips(ips) - ips = self.apply_white_list_origin_ips(ips) - ips = self.ip_cache.update(ips) - num_records = len(ips) - if num_records > 0: - challenged_ips = self.spark.createDataFrame( - [[ip, 1] for ip in ips], ['ip', 'challenged'] + def send_to_kafka( + kafka_servers, topic, rows, cmd_name='challenge_host', + id_client=None + ): + """ + Creates a kafka producer and sends the rows one by one, + along with the specified command (challenge_[host, ip]) + :returns: False if something went wrong, true otherwise + """ + # global IP_ACC + try: + from kafka import KafkaProducer + producer = KafkaProducer( + bootstrap_servers=kafka_servers ) - self.df = self.df.join(challenged_ips, on='ip', how='left') - self.df = self.df.fillna({'challenged': 0}) - - self.logger.info( - f'Sending {num_records} IP challenge commands to ' - f'kafka topic \'{self.config.kafka.banjax_command_topic}\'...') - for ip in ips: - message = json.dumps( - {'name': 'challenge_ip', 'value': ip} - ).encode('utf-8') - 
self.producer.send(self.config.kafka.banjax_command_topic, message) - self.producer.flush() + for row in rows: + from baskerville.spark.udfs import get_msg + message = get_msg(row, cmd_name) + producer.send(topic, get_msg(row, cmd_name)) + if id_client: + producer.send(f'{topic}.{id_client}', message) + # if cmd_name == 'challenge_ip': + # IP_ACC += {row: 1} + producer.flush() + except Exception: + import traceback + traceback.print_exc() + return False + return True + + self.udf_send_to_kafka = F.udf(send_to_kafka, T.BooleanType()) + + def get_attack_filter(self): + filter_ = None + for f_ in [(F.col(a) == 1) for a in self.attack_cols]: + if filter_ is None: + filter_ = f_ else: - self.df = self.df.withColumn('challenged', F.lit(0)) + filter_ = filter_ & f_ + return filter_ + + def send_challenge(self): + df_ips = self.get_attack_df() + if self.config.engine.challenge == 'ip': + if not df_has_rows(df_ips): + self.logger.debug('No attacks to be challenged...') + return + if self.df_white_list_hosts: + df_ips = df_ips.join( + self.df_white_list_hosts, on='target', how='left' + ).persist() + df_ips = df_ips.where(F.col('white_list_host').isNull()) + if df_has_rows(df_ips): + ips = [r['ip'] for r in df_ips.collect()] + ips = self.apply_white_list_ips(ips) + ips = self.apply_white_list_origin_ips(ips) + ips = self.ip_cache.update(ips) + num_records = len(ips) + if num_records > 0: + # challenged_ips = self.spark.createDataFrame( + # [[ip, 1] for ip in ips], ['ip', 'challenged'] + # ) + self.df = self.df.withColumn( + 'challenged', + F.when(F.col('ip').isin(F.lit(ips)), 1).otherwise(0) + ) + # self.df = self.df.join(challenged_ips, on='ip', how='left') + # self.df = self.df.fillna({'challenged': 0}) + + self.logger.info( + f'Sending {num_records} IP challenge commands to ' + f'kafka topic \'{self.config.kafka.banjax_command_topic}\'...') + for ip in ips: + message = json.dumps( + {'name': 'challenge_ip', 'value': ip} + ).encode('utf-8') + # self.producer.send(self.config.kafka.banjax_command_topic, message) + # self.producer.flush() # # return @@ -1670,17 +1814,30 @@ def send_challenge(self, df_attack): # else: # self.logger.debug('No challenge flag is set, moving on...') + def get_attack_df(self): + return self.df.select('ip', 'target').where(self.attack_filter).cache() + + def filter_out_load_test(self): + if self.config.engine.load_test: + self.df = self.df.select( + "*" + ).where( + ~F.col('ip').contains('_load_test') + ).persist(self.config.spark.storage_level) + self.logger.debug( + 'Filtering out the load test duplications before challenging..' 
+ ) + def run(self): - # self.df = self.df.withColumn("features", F.to_json("features")) - self.df = self.df.repartition('target').persist( - self.config.spark.storage_level - ) - self.classify_anomalies() - df_attack = self.detect_attack() - if df_attack and df_attack.head(1): - self.send_challenge(df_attack) + if df_has_rows(self.df): + self.df = self.df.withColumn('challenged', F.lit(0)) + self.filter_out_load_test() + self.df.select( + F.col('target').contains('_load_test') + ).show() + self.send_challenge() else: - self.logger.info('No attacks detected...') + self.logger.info('Nothing to be challenged...') self.df = super().run() return self.df diff --git a/src/baskerville/models/pipelines.py b/src/baskerville/models/pipelines.py index 5bc8df9e..33c34004 100644 --- a/src/baskerville/models/pipelines.py +++ b/src/baskerville/models/pipelines.py @@ -12,7 +12,10 @@ from baskerville.models.base_spark import SparkPipelineBase from pyspark.streaming import StreamingContext -from pyspark.streaming.kafka import KafkaUtils +try: + from pyspark.streaming.kafka import KafkaUtils +except ImportError: + print('Cannot import KafkaUtils - check pyspark version') class ElasticsearchPipeline(SparkPipelineBase): @@ -296,7 +299,7 @@ def run(self): kafkaStream = KafkaUtils.createDirectStream( self.ssc, - [self.kafka_conf.logs_topic], + [self.kafka_conf.data_topic], kafkaParams=kafkaParams, # fromOffsets={TopicAndPartition( # self.kafka_conf.consume_topic, 0): 0} @@ -463,7 +466,7 @@ def run(self): kafkaStream = KafkaUtils.createDirectStream( self.ssc, - [self.kafka_conf.logs_topic], + [self.kafka_conf.data_topic], { # 'bootstrap.servers': self.kafka_conf.zookeeper, 'metadata.broker.list': self.kafka_conf.url, diff --git a/src/baskerville/models/request_set_cache.py b/src/baskerville/models/request_set_cache.py index b168e47d..25fbfb1d 100644 --- a/src/baskerville/models/request_set_cache.py +++ b/src/baskerville/models/request_set_cache.py @@ -31,7 +31,8 @@ def __init__( session_getter=get_spark_session, group_by_fields=('target', 'ip'), format_='parquet', - path='request_set_cache' + path='request_set_cache', + use_storage=False, ): self.__cache = None self.__persistent_cache = None @@ -54,17 +55,21 @@ def __init__( self._count = 0 self._last_updated = datetime.datetime.utcnow() self._changed = False - self.file_manager = FileManager(path, self.session_getter()) - - self.file_name = os.path.join( - path, f'{self.__class__.__name__}.{self.format_}') - self.temp_file_name = os.path.join( - path, f'{self.__class__.__name__}temp.{self.format_}') - - if self.file_manager.path_exists(self.file_name): - self.file_manager.delete_path(self.file_name) - if self.file_manager.path_exists(self.temp_file_name): - self.file_manager.delete_path(self.temp_file_name) + self._use_storage = use_storage + if self._use_storage: + self.file_manager = FileManager(path, self.session_getter()) + + self.file_names = [ + os.path.join(path, f'{self.__class__.__name__}_A.{self.format_}'), + os.path.join(path, f'{self.__class__.__name__}_B.{self.format_}') + ] + self.file_name_index = 0 + + for f in self.file_names: + if self.file_manager.path_exists(f): + self.file_manager.delete_path(f) + else: + self.storage_df = None @property def cache(self): @@ -76,7 +81,20 @@ def persistent_cache(self): @property def persistent_cache_file(self): - return self.file_name + return self.file_names[self.file_name_index] + + @property + def next_persistent_cache_file(self): + if self.file_name_index == 0: + return self.file_names[1] + else: + return 
self.file_names[0] + + def alternate_persistent_cache_file(self): + if self.file_name_index == 0: + self.file_name_index = 1 + else: + self.file_name_index = 0 def _get_load_q(self): return f'''(SELECT * @@ -242,27 +260,49 @@ def filter_by(self, df, columns=None): if not columns: columns = df.columns - if self.file_manager.path_exists(self.persistent_cache_file): - self.__cache = self.session_getter().read.format( - self.format_ - ).load(self.persistent_cache_file).join( - df, - on=columns, - how='inner' - ).drop( - 'a.ip' - ) #.persist(self.storage_level) + if self._use_storage: + if self.file_manager.path_exists(self.persistent_cache_file): + self.__cache = self.session_getter().read.format( + self.format_ + ).load(self.persistent_cache_file).join( + df, + on=columns, + how='inner' + ).drop( + 'a.ip' + ) #.persist(self.storage_level) + else: + if self.__cache: + self.__cache = self.__cache.join( + df, + on=columns, + how='inner' + ).drop( + 'a.ip' + )# .persist(self.storage_level) + else: + self.load_empty(self.schema) else: - if self.__cache: - self.__cache = self.__cache.join( + if self.storage_df: + self.__cache = self.storage_df.join( df, on=columns, how='inner' ).drop( 'a.ip' - )# .persist(self.storage_level) + ) #.persist(self.storage_level) else: - self.load_empty(self.schema) + if self.__cache: + self.__cache = self.__cache.join( + df, + on=columns, + how='inner' + ).drop( + 'a.ip' + )# .persist(self.storage_level) + else: + self.load_empty(self.schema) + # if self.__persistent_cache: # self.__cache = self.__persistent_cache.join( @@ -298,7 +338,7 @@ def update_self( 'prediction', 'r', 'score', 'to_update', 'id', 'id_runtime', 'features', 'start', 'stop', 'subset_count', 'num_requests', 'total_seconds', 'time_bucket', 'model_version', 'to_update', - 'label', 'id_attribute', 'id_request_sets', 'created_at', + 'label', 'id_attribute', 'uuid_request_set', 'created_at', 'dt', 'id_client' ] now = datetime.datetime.utcnow() @@ -311,14 +351,20 @@ def update_self( source_df = source_df.select(columns) # read the whole thing again - if self.file_manager.path_exists(self.file_name): - if self.__persistent_cache: - self.__persistent_cache.unpersist() - self.__persistent_cache = self.session_getter().read.format( - self.format_ - ).load( - self.file_name - )# .persist(self.storage_level) + if self._use_storage: + if self.file_manager.path_exists(self.persistent_cache_file): + if self.__persistent_cache: + self.__persistent_cache.unpersist() + self.__persistent_cache = self.session_getter().read.format( + self.format_ + ).load( + self.persistent_cache_file + )# .persist(self.storage_level) + else: + if self.storage_df: + if self.__persistent_cache: + self.__persistent_cache.unpersist() + self.__persistent_cache = self.storage_df # http://www.learnbymarketing.com/1100/pyspark-joins-by-example/ self.__persistent_cache = source_df.rdd.toDF(source_df.schema).join( @@ -364,17 +410,17 @@ def update_self( '*' ).where(F.col('updated_at') >= update_date) - # write back to parquet - different file/folder though - # because self.parquet_name is already in use - # rename temp to self.parquet_name - if self.file_manager.path_exists(self.temp_file_name): - self.file_manager.delete_path(self.temp_file_name) - - self.__persistent_cache.write.mode( - 'overwrite' - ).format( - self.format_ - ).save(self.temp_file_name) + if self._use_storage: + # write back to parquet - different file/folder though + self.__persistent_cache.write.mode( + 'overwrite' + ).format( + self.format_ + 
).save(self.next_persistent_cache_file) + else: + spark = self.session_getter() + self.storage_df = spark.createDataFrame( + self.__persistent_cache.collect(), self.__persistent_cache.schema) # we don't need anything in memory anymore source_df.unpersist(blocking=True) @@ -382,11 +428,10 @@ def update_self( del source_df self.empty_all() - # rename temp to self.parquet_name - if self.file_manager.path_exists(self.file_name): - self.file_manager.delete_path(self.file_name) - - self.file_manager.rename_path(self.temp_file_name, self.file_name) + # if self.file_manager.path_exists(self.persistent_cache_file): + # self.file_manager.delete_path(self.persistent_cache_file) + if self._use_storage: + self.alternate_persistent_cache_file() def refresh(self, update_date, hosts, extra_filters=None): df = self._load( @@ -452,11 +497,12 @@ def empty(self): def empty_all(self): if self.__cache is not None: self.__cache.unpersist(blocking=True) + self.__cache = None + if self.__persistent_cache is not None: self.__persistent_cache.unpersist(blocking=True) - - self.__cache = None self.__persistent_cache = None + gc.collect() self.session_getter().sparkContext._jvm.System.gc() diff --git a/src/baskerville/simulation/real_timeish_feature_vector_simulation.py b/src/baskerville/simulation/real_timeish_feature_vector_simulation.py index a2adf317..b0c350b2 100644 --- a/src/baskerville/simulation/real_timeish_feature_vector_simulation.py +++ b/src/baskerville/simulation/real_timeish_feature_vector_simulation.py @@ -28,7 +28,7 @@ def simulation( spark=None, ): """ - Loads feature vectors with id_client and id_request_sets and publishes them one by + Loads feature vectors with id_client and uuid_request_set and publishes them one by one with some random delay to the defined topic. 
This is used :param str path: the path to feature vector samples :param str kafka_url: the url to kafka, defaults to '0.0.0.0:9092' @@ -38,7 +38,7 @@ def simulation( """ if topic_name: - active_columns = ['id_client', 'id_request_sets', 'features'] + active_columns = ['id_client', 'uuid_request_set', 'features'] if not spark: from baskerville.spark import get_spark_session @@ -62,7 +62,7 @@ def send_to_kafka(id_client, id_request_sets, features): json.dumps( { 'id_client': id_client, - 'id_request_sets': id_request_sets, + 'uuid_request_set': id_request_sets, 'features': features }).encode('utf-8') ) @@ -104,7 +104,7 @@ def send_to_kafka(id_client, id_request_sets, features): # 'id_client', # F.monotonically_increasing_id() # ).withColumn( -# 'id_request_sets', F.monotonically_increasing_id() +# 'uuid_request_set', F.monotonically_increasing_id() # ) -# data.select('id_client', 'id_request_sets', 'features').write.format('json').save( +# data.select('id_client', 'uuid_request_set', 'features').write.format('json').save( # '/path/to/baskerville/data/samples/sample_vectors') diff --git a/src/baskerville/spark/__init__.py b/src/baskerville/spark/__init__.py index bbf73bb2..4dd3b44e 100644 --- a/src/baskerville/spark/__init__.py +++ b/src/baskerville/spark/__init__.py @@ -176,6 +176,16 @@ def get_or_create_spark_session(spark_conf): .appName(spark_conf.app_name) \ .getOrCreate() + if spark_conf.s3_endpoint: + hadoop_config = spark._jsc.hadoopConfiguration() + hadoop_config.set('fs.s3n.impl', 'org.apache.hadoop.fs.s3native.NativeS3FileSystem') + hadoop_config.set('fs.s3a.impl', 'org.apache.hadoop.fs.s3a.S3AFileSystem') + hadoop_config.set('fs.s3a.aws.credentials.provider', 'org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider') + hadoop_config.set('com.amazonaws.services.s3.enableV4', 'true') + hadoop_config.set('fs.s3a.endpoint', spark_conf.s3_endpoint) + hadoop_config.set('fs.s3a.access.key', spark_conf.s3_access_key) + hadoop_config.set('fs.s3a.secret.key', spark_conf.s3_secret_key) + if spark_conf.log_level: spark.sparkContext.setLogLevel(spark_conf.log_level) return spark diff --git a/src/baskerville/spark/helpers.py b/src/baskerville/spark/helpers.py index d83b271a..1db8a23e 100644 --- a/src/baskerville/spark/helpers.py +++ b/src/baskerville/spark/helpers.py @@ -9,7 +9,7 @@ from baskerville.spark import get_spark_session from baskerville.util.enums import LabelEnum -from baskerville.util.helpers import TimeBucket +from baskerville.util.helpers import TimeBucket, get_logger from pyspark import AccumulatorParam from pyspark import StorageLevel from pyspark.sql import functions as F @@ -123,6 +123,28 @@ def save_df_to_table( ).mode(mode).save() +def load_df_from_table( + table_name_or_q, + db_config, + db_driver='org.postgresql.Driver', + where=None, + columns_to_keep=('*',) +): + spark = get_spark_session() + df = spark.read.format('jdbc').options( + url=db_config['db_url'], + driver=db_driver, + dbtable=table_name_or_q, + user=db_config['user'], + password=db_config['password'], + fetchsize=1000, + max_connections=200, + ).load() + if where: + df = df.where(where).select(*columns_to_keep) + return df + + def map_to_array(df, map_col, array_col, map_keys): """ Transforms map_col to array_col @@ -195,10 +217,13 @@ def load_test(df, load_test_num, storage_level): """ if load_test_num > 0: df = df.persist(storage_level) - + # print(f'-------- Initial df count {df.count()}') + initial_df = df for i in range(load_test_num - 1): - temp_df = df.withColumn( - 'client_ip', 
F.round(F.rand(42)).cast('string') + temp_df = initial_df.withColumn( + 'client_ip', F.concat( + F.round(F.rand(42)).cast('string'), F.lit('_load_test') + ) ) df = df.union(temp_df).persist(storage_level) @@ -262,7 +287,13 @@ def get_window(df, time_bucket: TimeBucket, storage_level: str): def send_to_kafka_by_partition_id( - df_to_send, bootstrap_servers, cmd_topic, cmd, id_client=None, udf_=None + df_to_send, + bootstrap_servers, + cmd_topic, + cmd, + id_client=None, + client_only=False, + udf_=None ): from baskerville.spark.udfs import udf_send_to_kafka df_to_send = df_to_send.withColumn('pid', F.spark_partition_id()).cache() @@ -282,7 +313,8 @@ def send_to_kafka_by_partition_id( F.lit(cmd_topic), F.col('rows'), F.lit(cmd), - F.lit('id_client') if id_client else F.lit(None) + F.lit('id_client') if id_client else F.lit(None), + F.lit(client_only) ) ) # False means something went wrong: @@ -290,3 +322,11 @@ def send_to_kafka_by_partition_id( F.col('sent_to_kafka') == False # noqa ).head(1)) return g_records + + +def df_has_rows(df): + return df and df.head(1) + + +def get_dtype_for_col(df, col): + return dict(df.dtypes).get(col) \ No newline at end of file diff --git a/src/baskerville/spark/schemas.py b/src/baskerville/spark/schemas.py index 733b3dcd..0da7de55 100644 --- a/src/baskerville/spark/schemas.py +++ b/src/baskerville/spark/schemas.py @@ -7,7 +7,6 @@ from pyspark.sql import types as T - cross_reference_schema = T.StructType([ T.StructField("label", T.IntegerType(), False), T.StructField("id_attribute", T.IntegerType(), False) @@ -15,13 +14,13 @@ features_schema = T.StructType([ T.StructField("id_client", T.StringType(), True), - T.StructField("id_request_sets", T.StringType(), False), + T.StructField("uuid_request_set", T.StringType(), False), T.StructField("features", T.StringType(), False) ]) prediction_schema = T.StructType([ T.StructField("id_client", T.StringType(), False), - T.StructField("id_request_sets", T.StringType(), False), + T.StructField("uuid_request_set", T.StringType(), False), T.StructField("prediction", T.FloatType(), False), T.StructField("score", T.FloatType(), False) ]) @@ -41,3 +40,73 @@ T.StructField("old_num_requests", T.IntegerType(), True), T.StructField("updated_at", T.TimestampType(), True) ]) + + +def get_features_schema(all_features: dict) -> T.StructType: + schema = T.StructType([ + T.StructField("id_client", T.StringType(), True), + T.StructField("uuid_request_set", T.StringType(), False) + ]) + features = T.StructType() + for feature in all_features.keys(): + features.add(T.StructField( + name=feature, + dataType=T.StringType(), + nullable=True)) + schema.add(T.StructField("features", features)) + return schema + + +def get_data_schema() -> T.StructType: + """ + Return the kafka data schema + """ + return T.StructType( + [T.StructField('key', T.StringType()), + T.StructField('message', T.StringType())] + ) + + +def get_feedback_context_schema() -> T.StructType: + return T.StructType( + [T.StructField('id', T.IntegerType()), + T.StructField('uuid_organization', T.StringType()), + T.StructField('reason', T.StringType()), + T.StructField('reason_descr', T.StringType()), + T.StructField('start', T.StringType()), + T.StructField('stop', T.StringType()), + T.StructField('ip_count', T.IntegerType()), + T.StructField('notes', T.StringType()), + T.StructField('progress_report', T.StringType()), + T.StructField('pending', T.BooleanType()), + ] + ) + + +def get_submitted_feedback_schema() -> T.ArrayType: + return T.ArrayType( + T.StructType([ + 
T.StructField('id', T.IntegerType()),
+            T.StructField('id_context', T.IntegerType()),
+            T.StructField('uuid_organization', T.StringType()),
+            T.StructField('uuid_request_set', T.StringType()),
+            T.StructField('prediction', T.IntegerType()),
+            T.StructField('score', T.FloatType()),
+            T.StructField('attack_prediction', T.FloatType()),
+            T.StructField('low_rate', T.BooleanType()),
+            feature_vectors_schema,
+            T.StructField('feedback', T.StringType()),
+            T.StructField('start', T.StringType()),
+            T.StructField('submitted_at', T.StringType()),
+            T.StructField('created_at', T.StringType()),
+            T.StructField('updated_at', T.StringType())
+        ])
+    )
+
+
+NAME_TO_SCHEMA = {
+    'FeedbackSchema': {
+        'feedback_context': get_feedback_context_schema(),
+        'feedback': get_submitted_feedback_schema(),
+    }
+}
diff --git a/src/baskerville/spark/udfs.py b/src/baskerville/spark/udfs.py
index d7f4fccc..9029acc0 100644
--- a/src/baskerville/spark/udfs.py
+++ b/src/baskerville/spark/udfs.py
@@ -14,7 +14,7 @@
 from pyspark.ml.linalg import Vectors, VectorUDT
 from pyspark.sql import functions as F
 from pyspark.sql import types as T
-from tzwhere import tzwhere
+# from tzwhere import tzwhere
 
 import numpy as np
 
@@ -252,12 +252,17 @@ def get_msg(row, cmd_name):
         return json.dumps(
             {'name': cmd_name, 'value': row}
         ).encode('utf-8')
-    elif cmd_name == 'prediction_center':
+    elif cmd_name in ('prediction_center', 'feedback_center'):
         return json.dumps(row.asDict()).encode('utf-8')
 
 
 def send_to_kafka(
-    kafka_servers, topic, rows, cmd_name='challenge_host', id_client=None
+    kafka_servers,
+    topic,
+    rows,
+    cmd_name='challenge_host',
+    id_client=None,
+    client_only=False,
 ):
     """
     Creates a kafka producer and sends the rows one by one,
@@ -267,11 +272,12 @@
     try:
         from kafka import KafkaProducer
         producer = KafkaProducer(
-            bootstrap_servers=kafka_servers
+            bootstrap_servers=kafka_servers,
         )
         for row in rows:
             message = get_msg(row, cmd_name)
-            producer.send(topic, get_msg(row, cmd_name))
+            if not client_only:
+                producer.send(topic, get_msg(row, cmd_name))
             if id_client:
                 producer.send(f'{topic}.{id_client}', message)
             producer.flush()
diff --git a/src/baskerville/util/baskerville_tools.py b/src/baskerville/util/baskerville_tools.py
index 122d9b3d..95d997b2 100644
--- a/src/baskerville/util/baskerville_tools.py
+++ b/src/baskerville/util/baskerville_tools.py
@@ -39,7 +39,8 @@ def create_runtime(
         file_name=None,
         processed=None,
         comment=None,
-        conf=None
+        conf=None,
+        id_user=None
     ):
         """
         Create a record in runtimes table.
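For context on the dynamic schema added above, here is a minimal sketch (not part of the patch) of how `get_features_schema` could be used to parse a Kafka-style JSON feature message; the feature names and the single-row DataFrame are made up for illustration, and the baskerville package is assumed to be importable:

# Illustrative sketch only, not part of this patch. Feature names and the
# sample payload are assumptions.
from pyspark.sql import SparkSession, functions as F
from baskerville.spark.schemas import get_features_schema

spark = SparkSession.builder.appName('feature-schema-sketch').getOrCreate()

# hypothetical feature map; only the keys matter for the schema
all_features = {'request_rate': None, 'unique_paths': None}
schema = get_features_schema(all_features)

raw = spark.createDataFrame(
    [('{"id_client": "client-1", "uuid_request_set": "abc-123", '
      '"features": {"request_rate": "0.4", "unique_paths": "12"}}',)],
    ['value']
)

# parse the JSON payload with the dynamically built schema
parsed = raw.withColumn('parsed', F.from_json('value', schema)).select('parsed.*')
parsed.show(truncate=False)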
@@ -54,9 +55,10 @@ def create_runtime( runtime.processed = processed runtime.comment = comment runtime.config = str(conf) - + runtime.id_user = id_user self.session.add(runtime) self.session.commit() + return runtime except Exception: self.session.rollback() diff --git a/src/baskerville/util/enums.py b/src/baskerville/util/enums.py index a0b3de85..8215e20e 100644 --- a/src/baskerville/util/enums.py +++ b/src/baskerville/util/enums.py @@ -50,6 +50,9 @@ class RunType(BaseStrEnum): preprocessing = 'preprocessing' postprocessing = 'postprocessing' predicting = 'predicting' + dashboard = 'dashboard' + client_rawlog = 'client_rawlog' + feedback = 'feedback' class Step(BaseStrEnum): @@ -133,3 +136,36 @@ class PartitionByEnum(BaseStrEnum): class ModelEnum(BaseStrEnum): AnomalyModelSklearn = "baskerville.models.anomaly_model_sklearn.AnomalyModelSklearn" AnomalyModel = "baskerville.models.anomaly_model.AnomalyModel" + + +class UserCategoryEnum(BaseStrEnum): + admin = 'Administrator' + guest = 'Guest' + user = 'User' + + +class FeedbackEnum(BaseStrEnum): + correct = 'correct' + incorrect = 'incorrect' + bot = 'bot' + not_bot = 'notbot' + none = '' + + +class FeedbackContextTypeEnum(BaseStrEnum): + attack = 'attack' + false_positive = 'false positive' + false_negative = 'false negative' + true_positive = 'true positive' + true_negative = 'true negative' + other = 'other' + + +FEEDBACK_CONTEXT_TO_DESCRIPTION = { + FeedbackContextTypeEnum.attack: 'Label request sets that were part of an attack', + FeedbackContextTypeEnum.false_positive: 'We marked something as bot, when it was not', + FeedbackContextTypeEnum.false_negative: 'We marked something as not bot, when it was', + FeedbackContextTypeEnum.true_positive: 'We did well (marked the bots correctly) and you want to tell us!', + FeedbackContextTypeEnum.true_negative: 'We did well (marked the normal traffic correctly) and you want to tell us!', + FeedbackContextTypeEnum.other: 'Anything else :)' +} \ No newline at end of file diff --git a/src/baskerville/util/file_manager.py b/src/baskerville/util/file_manager.py index 7a571fef..d289cda5 100644 --- a/src/baskerville/util/file_manager.py +++ b/src/baskerville/util/file_manager.py @@ -26,7 +26,7 @@ def __init__(self, path, spark_session=None): super().__init__() self.spark_session = spark_session - if path.startswith('hdfs://'): + if path.startswith('hdfs://') or path.startswith('s3a://'): if not self.spark_session: raise RuntimeError( 'You must pass a valid spark session if you use distributed storage') diff --git a/src/baskerville/util/helpers.py b/src/baskerville/util/helpers.py index 12067fa1..18045653 100644 --- a/src/baskerville/util/helpers.py +++ b/src/baskerville/util/helpers.py @@ -14,7 +14,7 @@ import yaml -from baskerville.util.enums import ModelEnum +from baskerville.util.enums import ModelEnum, BaseStrEnum FOLDER_MODELS = 'models' FOLDER_CACHE = 'cache' @@ -320,7 +320,7 @@ def to_dict(self, cols=()): :rtype: dict[str, T] """ if not cols: - if getattr(self, '__table__', None): + if getattr(self, '__table__') is not None: cols = self.__table__.columns basic_attrs = {c.name: getattr(self, c.name) for c in cols @@ -346,9 +346,9 @@ def to_dict(self, cols=()): if d is None: continue if isinstance(d, (list, tuple, set)): - extra_attrs[attr] = [each.as_dict() for each in d] + extra_attrs[attr] = [each.to_dict() for each in d] else: - extra_attrs[attr] = d.as_dict() + extra_attrs[attr] = d.to_dict() basic_attrs.update(extra_attrs) for k, v in basic_attrs.items(): @@ -356,6 +356,8 @@ def to_dict(self, 
cols=()): if hasattr(v, 'parent') and 'parent' not in v._remove: v._remove += ('parent') basic_attrs[k] = v.to_dict() + if isinstance(v, BaseStrEnum): + basic_attrs[k] = str(v) return basic_attrs diff --git a/src/baskerville/util/ksql_example.py b/src/baskerville/util/ksql_example.py index e69de29b..2ae75904 100644 --- a/src/baskerville/util/ksql_example.py +++ b/src/baskerville/util/ksql_example.py @@ -0,0 +1,22 @@ +import logging +from ksql import KSQLAPI +logging.basicConfig(level=logging.DEBUG) +client = KSQLAPI('http://0.0.0.0:8088') + +df = None +table_name = 'sensitive_data' +topic = 'predictions' +column_type = ['uuid_request_set bigint','ip varchar','target varchar', 'stop varchar'] +print(client.ksql('show tables')) +client.create_stream(table_name, column_type, topic) +print(client.query(f'select * from {table_name}', use_http2=True)) +print(client.ksql('show tables')) + +# client.create_stream_as(table_name='sensitive_data', +# select_columns=df.columns, +# src_table=src_table, +# kafka_topic='id_client.sensitive_data', +# value_format='json', +# conditions=conditions, +# partition_by='target', +# ) \ No newline at end of file diff --git a/src/baskerville/util/model_transfer.py b/src/baskerville/util/model_transfer.py index 1ffb0f52..399fde97 100644 --- a/src/baskerville/util/model_transfer.py +++ b/src/baskerville/util/model_transfer.py @@ -75,20 +75,23 @@ def model_transfer( for model in models_in: print(f'Getting model with id: {model.id}') model_out = Model() - model_out.scaler = model.scaler - model_out.classifier = model.classifier model_out.features = model.features model_out.algorithm = model.algorithm - model_out.analysis_notebook = model.analysis_notebook - model_out.created_at = model.created_at + model_out.scaler_type = model.scaler_type + model_out.parameters = model.parameters + model_out.recall = model.recall + model_out.precision = model.precision model_out.f1_score = model.f1_score + model_out.classifier = model.classifier + model_out.scaler = model.scaler + model_out.host_encoder = model.host_encoder model_out.n_training = model.n_training model_out.n_testing = model.n_testing + model_out.analysis_notebook = model.analysis_notebook model_out.notes = model.notes - model_out.parameters = model.parameters - model_out.precision = model.precision - model_out.recall = model.recall - model_out.request_sets = model.request_sets + model_out.threshold = model.threshold + model_out.request_sets = [] + model_out.created_at = model.created_at out_session.add(model_out) out_session.commit() except Exception: diff --git a/src/baskerville/util/origin_ips.py b/src/baskerville/util/origin_ips.py index 279de6b1..de5dea6a 100644 --- a/src/baskerville/util/origin_ips.py +++ b/src/baskerville/util/origin_ips.py @@ -20,6 +20,8 @@ def __init__(self, url, logger, refresh_period_in_seconds=300): self.refresh() def refresh(self): + if not self.url: + return if not self.last_timestamp or int(time.time() - self.last_timestamp) > self.refresh_period_in_seconds: self.last_timestamp = time.time() diff --git a/tests/unit/baskerville_tests/models_tests/test_base_spark.py b/tests/unit/baskerville_tests/models_tests/test_base_spark.py index ef98dea7..81cd5cc7 100644 --- a/tests/unit/baskerville_tests/models_tests/test_base_spark.py +++ b/tests/unit/baskerville_tests/models_tests/test_base_spark.py @@ -703,7 +703,7 @@ def test_save(self, mock_bytes, mock_instantiate_from_str): 'client_request_host': 'testhost', 'client_ip': '1', 'ip': '1', - 'id_request_sets': -1, + 'uuid_request_set': -1, 
'id_attribute': '',
             'id_runtime': -1,
             'request_set_prediction': -1,
@@ -732,7 +732,7 @@ def test_save(self, mock_bytes, mock_instantiate_from_str):
             'client_request_host': 'other testhost',
             'client_ip': '1',
             'ip': '1',
-            'id_request_sets': None,
+            'uuid_request_set': None,
             'id_attribute': '',
             'id_runtime': -1,
             'request_set_prediction': -1,
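The request-set cache changes above replace the single parquet file plus temp-file-and-rename scheme with two alternating files (`_A`/`_B`). A minimal standalone sketch of that double-buffering pattern follows; the paths and data are made up, and the rationale is that overwriting the same parquet path a DataFrame was read from can fail or lose data, so each update is written to the other file and the index is flipped:

# Illustrative sketch only, not part of this patch: the A/B file alternation
# used by the persistent request-set cache. Paths and data are assumptions.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('ab-cache-sketch').getOrCreate()

file_names = ['/tmp/request_set_cache_A.parquet', '/tmp/request_set_cache_B.parquet']
file_name_index = 0  # file currently holding the cache

# initial write to the active file
cache = spark.createDataFrame([(1, 'example.com')], ['id', 'target'])
cache.write.mode('overwrite').parquet(file_names[file_name_index])

# one update cycle: read the active file, transform it, write the result to
# the *other* file, then flip the index so the next cycle reads the new file
current = spark.read.parquet(file_names[file_name_index])
updated = current.withColumn('id', current['id'] + 1)
updated.write.mode('overwrite').parquet(file_names[1 - file_name_index])
file_name_index = 1 - file_name_index

Because the write never targets the file it is reading from, the patch can keep both copies around and simply call `alternate_persistent_cache_file` instead of deleting and renaming a temp path as before.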