diff --git a/stix_taxii/CHANGELOG.md b/stix_taxii/CHANGELOG.md index 4c9d9ea0..d172811b 100644 --- a/stix_taxii/CHANGELOG.md +++ b/stix_taxii/CHANGELOG.md @@ -1,3 +1,9 @@ +# 3.2.0 +## Added +- Added support for IOC Retraction for STIX/TAXII version 1.x and 2.x. +- Added support for Batch Size configuration parameter for 2.x. +- Added support for fetching dynamic configuration parameters based on the version selected. To use the dynamic field population feature update your CE version to 6.0.0. + # 3.1.0 ## Added - Added support for pulling ipv4 and ipv6 and Domain for STIX/TAXII version 2.x. @@ -23,7 +29,7 @@ ## Added - Added log statements for better understanding. -# 2.0.1-beta +# 2.0.1 ## Fixed - Fixed an issue related to the SSL verification. diff --git a/stix_taxii/main.py b/stix_taxii/main.py index e486c37f..78982941 100644 --- a/stix_taxii/main.py +++ b/stix_taxii/main.py @@ -34,13 +34,14 @@ import pytz from cabby import create_client, exceptions from datetime import datetime, timedelta -from typing import Dict, List +from typing import Any, Callable, Dict, List, Optional from stix.core import STIXPackage from cybox.objects.file_object import File from cybox.objects.uri_object import URI from cybox.objects.domain_name_object import DomainName from urllib.parse import urlparse import re +import time import requests import tempfile import traceback @@ -60,6 +61,7 @@ get_configuration_parameters, STIXTAXIIException, add_ce_user_agent, + ensure_utc_aware, ) from .utils.constants import ( CONFIDENCE_TO_REPUTATION_MAPPINGS, @@ -68,13 +70,33 @@ MODULE_NAME, PLATFORM_NAME, PLUGIN_VERSION, - LIMIT, - BUNDLE_LIMIT, STIX_VERSION_1, STIX_VERSION_20, STIX_VERSION_21, SERVICE_TYPE, DATE_CONVERSION_STRING, + DISCOVERY_URL_V1, + DISCOVERY_URL_V2, + USERNAME_CONFIG, + PASSWORD_CONFIG, + COLLECTION_NAMES_CONFIG, + PAGINATION_METHOD_CONFIG_V2, + INITIAL_RANGE_CONFIG, + LOOK_BACK_CONFIG, + TYPE_V1, + TYPE_V2, + SEVERITY_V1, + SEVERITY_V2, + REPUTATION_CONFIG, + 
BATCH_SIZE_CONFIG_V20, + BATCH_SIZE_CONFIG_V21, + RETRACTION_INTERVAL_CONFIG, + RETRACTION, + IN_EXECUTION_MAX_RETRIES, + IN_EXECUTION_SLEEP_TIME, + VALIDITY_DISPLAY_FORMAT, + PROXY_ERROR_RESOLUTION, + CONNECTION_ERROR_RESOLUTION, ) @@ -124,6 +146,50 @@ def _get_plugin_info(self): ) return (PLATFORM_NAME, PLUGIN_VERSION) + def get_dynamic_fields(self): + """Get the dynamic fields based on STIX/TAXII version. + + Returns: + list: List of dynamic configuration fields. + """ + version = self.configuration.get("version", None) + + # Version-specific fields + if version == STIX_VERSION_1: + discovery_url = DISCOVERY_URL_V1 + type_config = TYPE_V1 + severity_config = SEVERITY_V1 + else: + discovery_url = DISCOVERY_URL_V2 + type_config = TYPE_V2 + severity_config = SEVERITY_V2 + + # Build configuration fields list + fields = [ + discovery_url, + USERNAME_CONFIG, + PASSWORD_CONFIG, + COLLECTION_NAMES_CONFIG, + INITIAL_RANGE_CONFIG, + LOOK_BACK_CONFIG, + type_config, + severity_config, + REPUTATION_CONFIG, + ] + + # Batch Size and Pagination Method - only applicable for version 2.x + if version in [STIX_VERSION_20, STIX_VERSION_21]: + fields.append(PAGINATION_METHOD_CONFIG_V2) + + if version == STIX_VERSION_20: + fields.append(BATCH_SIZE_CONFIG_V20) + if version == STIX_VERSION_21: + fields.append(BATCH_SIZE_CONFIG_V21) + + fields.append(RETRACTION_INTERVAL_CONFIG) + + return fields + def _filter_collections(self, all_collections, selected_collections): """Create or filter collection names. Args: @@ -144,8 +210,14 @@ def _filter_collections(self, all_collections, selected_collections): ) if missing_collections: self.logger.error( - f"{self.log_prefix}: Following collections could not be " - f"found - {', '.join(missing_collections)}." + message=( + f"{self.log_prefix}: Following collections could not be " + f"found - {', '.join(missing_collections)}." + ), + resolution=( + "Ensure the collection names are correct and " + "available on the server." 
+ ) ) return list( set(selected_collections).intersection(set(all_collections)) @@ -193,175 +265,375 @@ def _extract_fields_from_indicator(self, indicator, observable): ) return self._ids[observable.idref] - def _extract_from_indicator(self, package_indicators): + def _is_indicator_expired_1x(self, indicator): + """Check if a STIX 1.x indicator is expired based on valid_time_positions. + + Args: + indicator: STIX 1.x indicator object. + + Returns: + tuple: (is_expired: bool, validity_times_str: str) + - is_expired: True if the indicator is expired, False otherwise. + - validity_times_str: Formatted string of validity time windows. + """ + # Check if valid_time_positions exists and is not empty + if not getattr(indicator, "valid_time_positions", None): + # No validity window means it's valid indefinitely + return False, "" + + current_time = pytz.utc.localize(datetime.now()) + + # Phase 1: Gather all validity windows + # and determine expiration status + validity_pairs = [] # List of (start_dt, end_dt, is_open_ended) + all_windows_expired = True + + for window in indicator.valid_time_positions: + if not window: + continue + + # Extract start_time + start_time_obj = getattr(window, "start_time", None) + start_dt = None + if start_time_obj and getattr(start_time_obj, "value", None): + start_dt = start_time_obj.value + if isinstance(start_dt, datetime) and start_dt.tzinfo is None: + start_dt = pytz.utc.localize(start_dt) + + # Extract end_time + end_time_obj = getattr(window, "end_time", None) + end_dt = None + is_open_ended = False + + if not end_time_obj or not getattr(end_time_obj, "value", None): + # Missing end_time implies validity "forever" + is_open_ended = True + all_windows_expired = False + else: + end_dt = end_time_obj.value + if isinstance(end_dt, datetime): + if end_dt.tzinfo is None: + end_dt = pytz.utc.localize(end_dt) + # Check if this window is still valid + if end_dt > current_time: + all_windows_expired = False + else: + # Non-datetime end_dt - 
treat as valid (don't expire) + all_windows_expired = False + + validity_pairs.append((start_dt, end_dt, is_open_ended)) + + # Phase 2: Build the formatted validity times string + validity_windows = [] + for start_dt, end_dt, is_open_ended in validity_pairs: + # Format start_dt + if isinstance(start_dt, datetime): + start_str = start_dt.strftime(VALIDITY_DISPLAY_FORMAT) + else: + start_str = "N/A" + + # Format end_dt + if is_open_ended: + end_str = "N/A" + elif isinstance(end_dt, datetime): + end_str = end_dt.strftime(VALIDITY_DISPLAY_FORMAT) + else: + end_str = str(end_dt) if end_dt else "N/A" + + validity_windows.append(f"Valid From: {start_str}, Valid Until: {end_str}") + + validity_times_str = ", ".join(validity_windows) if validity_windows else "" + + return all_windows_expired, validity_times_str + + def _extract_from_indicator_1x( + self, package_indicators, is_retraction: bool = False + ): """Extract ioc from indicators. Args: package_indicators (list): List of indicators. + is_retraction (bool): If True, return set of values instead of + Indicator objects. Returns: - list: List of indicators. - bool: True if all indicators are skipped. + tuple: (indicators, skipped_count) + - indicators: List of Indicator objects or set of values. + - skipped_count: Count of skipped indicators/observables. 
""" - indicators = [] - is_skipped_final = False + indicators = set() if is_retraction else [] + skipped_indicators = {} for indicator in package_indicators: - for observable in indicator.observables: - data = self._extract_fields_from_indicator( - indicator, observable + # Skip expired indicators (valid_time_positions check) + try: + expired, validity_times_str = self._is_indicator_expired_1x(indicator) + except Exception: + expired = False + validity_times_str = "" + if expired: + skipped_indicators["Expired"] = ( + skipped_indicators.get("Expired", 0) + 1 ) - if not observable.object_: - is_skipped_final = True - continue - properties = observable.object_.properties - if not properties: - is_skipped_final = True - continue + continue + for observable in indicator.observables: try: + data = self._extract_fields_from_indicator( + indicator, observable + ) + if not observable.object_: + skipped_indicators["Missing object"] = ( + skipped_indicators.get("Missing object", 0) + 1 + ) + continue + properties = observable.object_.properties + if not properties: + skipped_indicators["Missing properties"] = ( + skipped_indicators.get("Missing properties", 0) + 1 + ) + continue + # Build base comments from description + base_comment = str( + observable.description + or indicator.description + or "" + ) + # Append validity times if available + if validity_times_str: + if base_comment: + full_comment = f"{base_comment}, {validity_times_str}" + else: + full_comment = validity_times_str + else: + full_comment = base_comment + if ( type(properties) is File and properties.hashes and properties.hashes.md5 ): - indicators.append( - Indicator( - value=str(properties.hashes.md5), - type=IndicatorType.MD5, - **data, - comments=str( - observable.description - or indicator.description - or "" - ), + if is_retraction: + indicators.add(str(properties.hashes.md5)) + else: + indicators.append( + Indicator( + value=str(properties.hashes.md5), + type=IndicatorType.MD5, + **data, + 
comments=full_comment, + ) ) - ) elif ( type(properties) is File and properties.hashes and properties.hashes.sha256 ): - indicators.append( - Indicator( - value=str(properties.hashes.sha256), - type=IndicatorType.SHA256, - **data, - comments=str( - observable.description - or indicator.description - or "" - ), + if is_retraction: + indicators.add(str(properties.hashes.sha256)) + else: + indicators.append( + Indicator( + value=str(properties.hashes.sha256), + type=IndicatorType.SHA256, + **data, + comments=full_comment, + ) ) - ) elif type(properties) is URI and properties.value: - indicators.append( - Indicator( - value=str(properties.value), - type=IndicatorType.URL, - **data, - comments=str( - observable.description - or indicator.description - or "" - ), + if is_retraction: + indicators.add(str(properties.value)) + else: + indicators.append( + Indicator( + value=str(properties.value), + type=IndicatorType.URL, + **data, + comments=full_comment, + ) ) - ) elif type(properties) is DomainName and properties.value: - indicators.append( - Indicator( - value=str(properties.value), - type=getattr( - IndicatorType, "DOMAIN", IndicatorType.URL - ), - **data, - comments=str( - observable.description - or indicator.description - or "" - ), + if is_retraction: + indicators.add(str(properties.value)) + else: + indicators.append( + Indicator( + value=str(properties.value), + type=getattr( + IndicatorType, "DOMAIN", IndicatorType.URL + ), + **data, + comments=full_comment, + ) ) - ) else: - is_skipped_final = True - except Exception: - is_skipped_final = True - return indicators, is_skipped_final + prop_type = type(properties).__name__ + reason = f"Unsupported properties type '{prop_type}'" + skipped_indicators[reason] = ( + skipped_indicators.get(reason, 0) + 1 + ) + except Exception as e: + skipped_indicators["Exception"] = ( + skipped_indicators.get("Exception", 0) + 1 + ) + self.logger.error( + message=( + f"{self.log_prefix}: Skipping indicator. 
" + f"Exception: {e}" + ), + details=str(traceback.format_exc()), + ) + skipped_count = sum(skipped_indicators.values()) + if skipped_indicators: + skip_reasons_stats = ", ".join( + f"{reason}: {count}" for reason, count in skipped_indicators.items() + ) + self.logger.debug( + message=( + f"{self.log_prefix}:" + "Some indicators were skipped due to expiration, " + "missing fields or exceptions. Skip Stats: " + f"{skip_reasons_stats}" + ) + ) + return indicators, skipped_count - def _extract_from_observables(self, observables): + def _extract_from_observables_1x( + self, observables, is_retraction: bool = False + ): """Extract iocs from observables. Args: observables (list): List of observables. + is_retraction (bool): If True, return set of values instead of + Indicator objects. Returns: - list: List of indicators. - bool: True if all indicators are skipped. + tuple: (indicators, skipped_count) + - indicators: List of Indicator objects or set of values. + - skipped_count: Count of skipped observables. 
""" - indicators = [] - is_skipped = False + indicators = set() if is_retraction else [] + skipped_observables = {} for observable in observables: - if not observable.object_: - is_skipped = True - continue - properties = observable.object_.properties - if not properties: - is_skipped = True - continue try: + if not observable.object_: + skipped_observables["Missing object"] = ( + skipped_observables.get("Missing object", 0) + 1 + ) + continue + properties = observable.object_.properties + if not properties: + skipped_observables["Missing properties"] = ( + skipped_observables.get("Missing properties", 0) + 1 + ) + continue if ( type(properties) is File and properties.hashes and properties.hashes.md5 ): - indicators.append( - Indicator( - value=str(properties.hashes.md5), - type=IndicatorType.MD5, - **self._ids.get(observable.id_, {}), - comments=str(observable.description or ""), + if is_retraction: + indicators.add(str(properties.hashes.md5)) + else: + indicators.append( + Indicator( + value=str(properties.hashes.md5), + type=IndicatorType.MD5, + **self._ids.get(observable.id_, {}), + comments=str(observable.description or ""), + ) ) - ) elif ( type(properties) is File and properties.hashes and properties.hashes.sha256 ): - indicators.append( - Indicator( - value=str(properties.hashes.sha256), - type=IndicatorType.SHA256, - **self._ids.get(observable.id_, {}), - comments=str(observable.description or ""), + if is_retraction: + indicators.add(str(properties.hashes.sha256)) + else: + indicators.append( + Indicator( + value=str(properties.hashes.sha256), + type=IndicatorType.SHA256, + **self._ids.get(observable.id_, {}), + comments=str(observable.description or ""), + ) ) - ) elif type(properties) is URI and properties.value: - indicators.append( - Indicator( - value=str(properties.value), - type=IndicatorType.URL, - comments=str(observable.description or ""), + if is_retraction: + indicators.add(str(properties.value)) + else: + indicators.append( + Indicator( + 
value=str(properties.value), + type=IndicatorType.URL, + comments=str(observable.description or ""), + ) ) - ) elif type(properties) is DomainName and properties.value: - indicators.append( - Indicator( - value=str(properties.value), - type=getattr( - IndicatorType, "DOMAIN", IndicatorType.URL - ), - comments=str(observable.description or ""), + if is_retraction: + indicators.add(str(properties.value)) + else: + indicators.append( + Indicator( + value=str(properties.value), + type=getattr( + IndicatorType, "DOMAIN", IndicatorType.URL + ), + comments=str(observable.description or ""), + ) ) - ) else: - is_skipped = True - except Exception: - is_skipped = True - return indicators, is_skipped + prop_type = type(properties).__name__ + reason = f"Unsupported properties type '{prop_type}'" + skipped_observables[reason] = ( + skipped_observables.get(reason, 0) + 1 + ) + except Exception as e: + skipped_observables["Exception"] = ( + skipped_observables.get("Exception", 0) + 1 + ) + self.logger.error( + message=( + f"{self.log_prefix}: Skipping observable. " + f"Exception: {e}" + ), + details=str(traceback.format_exc()), + ) + skipped_count = sum(skipped_observables.values()) + if skipped_observables: + skip_reasons_stats = ", ".join( + f"{reason}: {count}" for reason, count in skipped_observables.items() + ) + self.logger.debug( + message=( + f"{self.log_prefix}:" + "Some observables were skipped due to " + "missing fields or exceptions. Skip Stats: " + f"{skip_reasons_stats}" + ) + ) + return indicators, skipped_count + + def _extract_indicators_1x(self, package, is_retraction: bool = False): + """Extract iocs from a STIX package. - def _extract_indicators(self, package): - """Extract iocs from a STIX package.""" + Args: + package: STIX package object. + is_retraction (bool): If True, return set of values instead of + Indicator objects. + + Returns: + tuple: (indicators, skipped_count) + - indicators: List of Indicator objects or set of values. 
+ - skipped_count: Count of skipped indicators/observables. + """ if package.indicators: - return self._extract_from_indicator(package.indicators) + return self._extract_from_indicator_1x( + package.indicators, is_retraction + ) elif package.observables: - return self._extract_from_observables(package.observables) + return self._extract_from_observables_1x( + package.observables, is_retraction + ) else: - return [], True + return set() if is_retraction else [], 0 def _build_client(self, configuration): """Build client for TAXII. @@ -372,18 +644,12 @@ def _build_client(self, configuration): client: Client object. """ ( - _, discovery_url, username, password, - _, - _, - _, - _, - _, - _, - _, - ) = get_configuration_parameters(configuration) + ) = get_configuration_parameters( + configuration, keys=["discovery_url", "username", "password"] + ) parsed_url = urlparse(discovery_url) discovery_url = parsed_url.path if len(parsed_url.netloc.split(":")) > 1: @@ -446,10 +712,9 @@ def convert_string_to_datetime(self, collections_dict): """ try: if collections_dict and isinstance(collections_dict, dict): - for ( - collection_name, - str_datetime_value, - ) in collections_dict.items(): + for collection_name, str_datetime_value in ( + collections_dict.items() + ): if isinstance(str_datetime_value, str): collections_dict[collection_name] = str_to_datetime( string=str_datetime_value, @@ -477,10 +742,9 @@ def convert_datetime_to_string(self, collections_dict): """ try: if collections_dict and isinstance(collections_dict, dict): - for ( - collection_name, - datetime_value, - ) in collections_dict.items(): + for collection_name, datetime_value in ( + collections_dict.items() + ): if isinstance(datetime_value, datetime): collections_dict[collection_name] = ( datetime_value.strftime(DATE_CONVERSION_STRING) @@ -502,8 +766,8 @@ def handle_and_raise( err: Exception, err_msg: str, details_msg: str = "", - exc_type: Exception = STIXTAXIIException, if_raise: bool = True, + resolution: str 
= "", ): """Handle and raise an exception. @@ -515,37 +779,149 @@ def handle_and_raise( STIXTAXIIException. if_raise (bool, optional): Whether to raise the exception. Defaults to True. + resolution (str, optional): Resolution message for the error. + Defaults to empty string. """ self.logger.error( message=f"{self.log_prefix}: {err_msg} Error: {err}", details=details_msg, + resolution=resolution if resolution else None, ) if if_raise: - raise exc_type(err_msg) + raise STIXTAXIIException(err_msg) + + def _format_type_breakdown(self, type_counts, type_to_pull): + """Format type counts into a readable breakdown string. + + Args: + type_counts (dict): Dictionary mapping type values to counts. + type_to_pull (list): List of indicator types to pull. + Empty list means include all types. + + Returns: + str: Formatted breakdown string like "X SHA256, Y MD5, Z Domain(s)" + """ + type_labels = { + "sha256": "SHA256", + "md5": "MD5", + "url": "URL", + "ipv4": "IPv4", + "ipv6": "IPv6", + "domain": "Domain" + } + + # Determine which types to include in breakdown + types_to_show = type_to_pull if type_to_pull else list(type_labels.keys()) + + type_breakdown_parts = [] + for t in types_to_show: + count = type_counts.get(t, 0) + if count > 0: + label = type_labels.get(t, t.title()) + type_breakdown_parts.append(f"{count} {label}") + + if not type_breakdown_parts: + return "0" + + if len(type_breakdown_parts) == 1: + return type_breakdown_parts[0] + + return ", ".join(type_breakdown_parts[:-1]) + f" and {type_breakdown_parts[-1]}" + + def _filter_indicators_by_config( + self, indicators, type_to_pull, severity, reputation, type_counts + ): + """Filter indicators based on configuration parameters. + + Args: + indicators (list): List of Indicator objects. + type_to_pull (list): List of indicator types to pull. + Empty list means include all types. + severity (list): List of severity values to include. + Empty list means include all severities. 
+ reputation (int): Minimum reputation value. + type_counts (dict): Dictionary to accumulate type counts into. + Updated in place. + + Returns: + tuple: (filtered_indicators, skipped_count) + - filtered_indicators: List of indicators that passed filters. + - skipped_count: Number of indicators filtered out. + """ + + def matches_type(indicator_type): + """Check if indicator type matches the configured types.""" + # Empty list means include all types + if not type_to_pull: + return True + return ( + (indicator_type is IndicatorType.SHA256 and "sha256" in type_to_pull) + or (indicator_type is IndicatorType.MD5 and "md5" in type_to_pull) + or (indicator_type is IndicatorType.URL and "url" in type_to_pull) + or ( + indicator_type is getattr(IndicatorType, "IPV4", IndicatorType.URL) + and "ipv4" in type_to_pull + ) + or ( + indicator_type is getattr(IndicatorType, "IPV6", IndicatorType.URL) + and "ipv6" in type_to_pull + ) + or ( + indicator_type is getattr(IndicatorType, "DOMAIN", IndicatorType.URL) + and "domain" in type_to_pull + ) + ) - def pull_1x(self, configuration, start_time): + def matches_severity(indicator_severity): + """Check if indicator severity matches the configured severities.""" + # Empty list means include all severities + if not severity: + return True + return indicator_severity.value in severity + + filtered_list = [] + for ind in indicators: + if ( + matches_severity(ind.severity) + and ind.reputation >= int(reputation) + and matches_type(ind.type) + ): + filtered_list.append(ind) + # Count by type (lowercase to match config keys) + type_key = ind.type.value.lower() + if type_key in type_counts: + type_counts[type_key] += 1 + + skipped_count = len(indicators) - len(filtered_list) + return filtered_list, skipped_count + + def pull_1x(self, configuration, start_time, is_retraction: bool = False): """Pull implementation for version 1.x. Args: configuration (dict): Configuration dictionary. start_time (datetime): Start time. 
+ is_retraction (bool): If True, yield sets of values instead of + Indicator objects. Also filters out expired indicators. - Returns: - ValidationResult: Validation result. + Yields: + tuple: (indicators_batch, sub_checkpoint_dict) for each block. + indicators_batch is list of Indicators or set of values. """ ( - _, - _, - _, - _, collection_names, - _, - _, delay_config, - _, - _, - _, - ) = get_configuration_parameters(configuration) + type_to_pull, + severity, + reputation, + ) = get_configuration_parameters( + configuration, + keys=["collection_names", "delay", "type_to_pull", "severity", "reputation"] + ) + if delay_config and isinstance(delay_config, int): + delay_config = int(delay_config) + else: + delay_config = 0 self._ids = {} try: @@ -557,6 +933,7 @@ def pull_1x(self, configuration, start_time): err=err, err_msg=err_msg, details_msg=str(traceback.format_exc()), + resolution=PROXY_ERROR_RESOLUTION, ) except requests.exceptions.ConnectionError as err: err_msg = ( @@ -567,6 +944,7 @@ def pull_1x(self, configuration, start_time): err=err, err_msg=err_msg, details_msg=str(traceback.format_exc()), + resolution=CONNECTION_ERROR_RESOLUTION, ) except requests.exceptions.RequestException as err: err_msg = "Request Exception occurred." @@ -592,15 +970,20 @@ def pull_1x(self, configuration, start_time): f"{self.log_prefix}: Following collections will be" f" fetched - {', '.join(filtered_collections)}." ) - indicators = [] delay_time = int(delay_config) - - start_time = pytz.utc.localize( + start_time = ensure_utc_aware( start_time - timedelta(minutes=delay_time) ) - for collection in filtered_collections: + total_indicators = 0 + total_skipped = 0 + total_type_counts = { + "sha256": 0, "md5": 0, "url": 0, + "ipv4": 0, "ipv6": 0, "domain": 0 + } + + for collection_idx, collection in enumerate(filtered_collections): self.logger.debug( f"{self.log_prefix}: Parsing collection - " f"'{collection}'. Start time: {start_time}." 
@@ -618,6 +1001,7 @@ def pull_1x(self, configuration, start_time): err=err, err_msg=err_msg, details_msg=str(traceback.format_exc()), + resolution=PROXY_ERROR_RESOLUTION, ) except requests.exceptions.ConnectionError as err: err_msg = ( @@ -628,6 +1012,7 @@ def pull_1x(self, configuration, start_time): err=err, err_msg=err_msg, details_msg=str(traceback.format_exc()), + resolution=CONNECTION_ERROR_RESOLUTION, ) except requests.exceptions.RequestException as err: err_msg = "Request Exception occurred." @@ -649,35 +1034,71 @@ def pull_1x(self, configuration, start_time): block_id = 1 collection_indicator_count = 0 + collection_skip_count = 0 for block in content_blocks: try: temp = tempfile.TemporaryFile() temp.write(block.content) temp.seek(0) stix_package = STIXPackage.from_xml(temp) - extracted, is_skipped = self._extract_indicators( - stix_package - ) - indicators += extracted - collection_indicator_count += len(extracted) - total_log = ( - f"Total {collection_indicator_count} " - "indicator(s) pulled till now." + extracted, skipped_count = self._extract_indicators_1x( + stix_package, is_retraction ) - if is_skipped is True: + + if is_retraction: + # For retraction: yield set of values, no filtering + collection_indicator_count += len(extracted) + total_indicators += len(extracted) + collection_skip_count += skipped_count + total_skipped += skipped_count + self.logger.info( - f"{self.log_prefix}: Pulled {len(extracted)} " - f"indicator(s) from Block-{block_id}, some " - "indicator(s) might have been discarded." - f" {total_log}" + f"{self.log_prefix}: Extracted {len(extracted)} " + f"valid indicator(s) from Block-{block_id} " + f"for retraction check. Skipped {skipped_count} " + "indicator(s)." 
) + temp.close() + + if extracted: + yield extracted, None else: + # Apply filtering for normal pull + filtered_batch, filter_skipped = ( + self._filter_indicators_by_config( + extracted, type_to_pull, severity, reputation, + total_type_counts + ) + ) + + collection_indicator_count += len(filtered_batch) + total_indicators += len(filtered_batch) + collection_skip_count += skipped_count + filter_skipped + total_skipped += skipped_count + filter_skipped + + total_log = ( + f"Total {collection_indicator_count} " + "indicator(s) pulled till now." + ) self.logger.info( - f"{self.log_prefix}: Pulled {len(extracted)} " - f"indicator(s) from Block-{block_id}." + f"{self.log_prefix}: Pulled {len(filtered_batch)} " + f"indicator(s) from Block-{block_id}. Skipped " + f"{skipped_count} indicator(s), filtered " + f"{filter_skipped} indicator(s)." f" {total_log}" ) - temp.close() + temp.close() + + # Build sub_checkpoint + sub_checkpoint = { + "collection": collection, + "collection_idx": collection_idx, + "block_id": block_id, + } + + if filtered_batch: + yield filtered_batch, sub_checkpoint + except Exception as e: err_msg = ( "Error occurred while extracting indicator(s)" @@ -689,163 +1110,262 @@ def pull_1x(self, configuration, start_time): details_msg=str(traceback.format_exc()), if_raise=False, ) - block_id += 1 - continue block_id += 1 + self.logger.info( - f"{self.log_prefix}: Completed pulling of" + f"{self.log_prefix}: Completed pulling of " f"indicator(s) from collection - '{collection}'." f" Total {collection_indicator_count}" " indicator(s) pulled." + f" {collection_skip_count} indicator(s) skipped/filtered." ) + type_breakdown_str = self._format_type_breakdown(total_type_counts, type_to_pull) + self.logger.info( f"{self.log_prefix}: Completed pulling of" " indicator(s) from collection(s) - " - f"{', '.join(filtered_collections)}." - f" Total {len(indicators)} indicator(s) pulled." + f"{', '.join(filtered_collections)}. 
" + f"Total {total_indicators} indicator(s) pulled, " + f"{total_skipped} skipped/filtered. " + f"Pull Stats: {type_breakdown_str} indicator(s) were fetched." ) - return indicators - def _extract_observables_2x(self, pattern: str, data: dict): + def _extract_observables_2x( + self, + pattern: str, + data: dict, + is_retraction: bool = False, + ): """Extract observables from a pattern. Args: pattern (str): The pattern to extract observables from. data (dict): The data to extract observables from. + is_retraction (bool): If True, return set of values instead of + Indicator objects. Returns: - list: List of observables. + tuple: (observables, skipped_count, sha256_count, md5_count, + skip_reason) + - observables: List of observables (or set of values if + is_retraction). + - skipped_count: Count of skipped indicators for this pattern. + - skip_reason: Reason for skipping (None if not skipped). """ sha256_count = 0 md5_count = 0 - observables = [] - is_skipped = False + observables = set() if is_retraction else [] + match_count = 0 + exception_count = 0 for kind in OBSERVABLE_REGEXES: matches = re.findall(kind["regex"], pattern, re.IGNORECASE) - if len(matches) == 0: - is_skipped = is_skipped or False - else: - is_skipped = is_skipped or True + match_count += len(matches) for match in matches: try: if ( kind["type"] == IndicatorType.SHA256 or kind["type"] == IndicatorType.MD5 ): - observables.append( - Indicator( - value=match.replace("'", ""), - type=kind["type"], - **data, + value = match.replace("'", "") + if is_retraction: + observables.add(value) + else: + observables.append( + Indicator( + value=value, + type=kind["type"], + **data, + ) ) - ) if kind["type"] == IndicatorType.SHA256: sha256_count += 1 elif kind["type"] == IndicatorType.MD5: md5_count += 1 else: if "ipv4" in pattern or "ipv6" in pattern: - observables.append( - Indicator( - value=match.replace("'", ""), - type=kind["type"], - **data, - ) - ) + value = match.replace("'", "") + else: + value = 
match[1].replace("'", "") + if is_retraction: + observables.add(value) else: observables.append( Indicator( - value=match[1].replace("'", ""), + value=value, type=kind["type"], **data, ) ) - except Exception: - is_skipped = True - return observables, not (is_skipped), sha256_count, md5_count + except Exception as e: + exception_count += 1 + self.logger.debug( + message=( + f"{self.log_prefix}: Skipping observable in " + f"indicator. Exception: {e}" + ), + details=str(traceback.format_exc()), + ) + skipped_count = 0 + skip_reason = None + if len(observables) == 0: + skipped_count = 1 + if exception_count > 0: + skip_reason = "Exception" + elif match_count == 0: + skip_reason = "No supported observables found" + else: + skip_reason = "Unsupported pattern" + return observables, skipped_count, sha256_count, md5_count, skip_reason - def _extract_indicators_2x(self, objects): + def _extract_indicators_2x(self, objects, is_retraction: bool = False): """Extract indicators from a list of objects. Args: objects (list): List of objects. + is_retraction (bool): If True, return set of values instead of + Indicator objects. Returns: - list: List of indicators. + tuple: (indicators, skipped_count, sha256_count, md5_count) + indicators is a list of Indicator objects or set of values + if is_retraction is True. 
""" - indicators = [] - is_skipped_final = False + indicators = set() if is_retraction else [] + skipped_indicators = {} modified_time = None total_sha256_count = 0 total_md5_count = 0 + current_time = datetime.now() + for o in objects: - if o.get("type").lower() != "indicator": - is_skipped_final = True - continue - created_time = str_to_datetime(o.get("created")) - modified_time = str_to_datetime(o.get("modified")) - data = { - "comments": o.get("description") or o.get("pattern") or "", - "reputation": int(o.get("confidence", 50) / 10), - "firstSeen": created_time, - "lastSeen": modified_time, - } - sha256 = 0 - md5 = 0 - extracted, is_skipped, sha256, md5 = ( - self._extract_observables_2x(o.get("pattern", ""), data) - ) - total_sha256_count += sha256 - total_md5_count += md5 - if is_skipped: - is_skipped_final = True - indicators += extracted - return ( - indicators, - is_skipped_final, - total_sha256_count, - total_md5_count - ) + indicator_type = o.get("type", "") + try: + if indicator_type.lower() != "indicator": + reason = f"Unsupported type '{indicator_type}'" + skipped_indicators[reason] = ( + skipped_indicators.get(reason, 0) + 1 + ) + continue - def update_storage( - self, - configuration, - bundle, - last_added_date, - storage, - collection, - execution_details, - bundle_id, - ): - """Update storage with new indicators. + # Skip revoked indicators + if o.get("revoked", False): + skipped_indicators["Revoked"] = ( + skipped_indicators.get("Revoked", 0) + 1 + ) + continue - Args: - configuration (dict): Configuration dictionary. - bundle (dict): Bundle dictionary. - last_added_date (datetime): Last added date. - storage (dict): Storage dictionary. - collection (str): Collection name. - execution_details (dict): Execution details. - bundle_id (str): Bundle ID. 
+ # Skip expired indicators (valid_until < now) + # format 2025-12-25T09:35:07.502007Z + valid_until_str = o.get("valid_until") + if valid_until_str: + valid_until = str_to_datetime( + valid_until_str, + date_format=DATE_CONVERSION_STRING, + replace_dot=False, + return_now_on_error=False, + ) + if valid_until and valid_until < current_time: + skipped_indicators["Expired"] = ( + skipped_indicators.get("Expired", 0) + 1 + ) + continue + + created_time = str_to_datetime(o.get("created")) + modified_time = str_to_datetime(o.get("modified")) + + # Build base comments and append validity times + base_comment = o.get("description") or o.get("pattern") or "" + valid_from_str = o.get("valid_from") + + # Build validity times string + validity_parts = [] + if valid_from_str: + validity_parts.append(f"Valid From: {valid_from_str}") + if valid_until_str: + validity_parts.append(f"Valid Until: {valid_until_str}") + + if validity_parts: + validity_times_str = ", ".join(validity_parts) + if base_comment: + full_comment = f"{base_comment}, {validity_times_str}" + else: + full_comment = validity_times_str + else: + full_comment = base_comment + + data = { + "comments": full_comment, + "reputation": int(o.get("confidence", 50) / 10), + "firstSeen": created_time, + "lastSeen": modified_time, + } + sha256 = 0 + md5 = 0 + extracted, observables_skipped, sha256, md5, skip_reason = ( + self._extract_observables_2x( + o.get("pattern", ""), + data, + is_retraction, + ) + ) + total_sha256_count += sha256 + total_md5_count += md5 + if observables_skipped and skip_reason: + skipped_indicators[skip_reason] = ( + skipped_indicators.get(skip_reason, 0) + 1 + ) - Returns: - dict: Storage dictionary. 
- """ - ( - version, - _, - _, - _, - _, - pagination_method, - _, - _, - _, - _, - _, - ) = get_configuration_parameters(configuration) + if is_retraction: + indicators.update(extracted) + else: + indicators += extracted + except Exception as e: + skipped_indicators["Exception"] = ( + skipped_indicators.get("Exception", 0) + 1 + ) + self.logger.error( + message=( + f"{self.log_prefix}: Skipping indicator. " + f"Exception: {e}" + ), + details=str(traceback.format_exc()), + ) + + skipped_count = sum(skipped_indicators.values()) + if skipped_indicators: + skip_reasons_stats = ", ".join( + f"{reason}: {count}" for reason, count in skipped_indicators.items() + ) + self.logger.debug( + message=( + f"{self.log_prefix}: " + "Some indicators were skipped due to expiration, " + "missing fields or exceptions. Skip Stats: " + f"{skip_reasons_stats}" + ) + ) + return ( + indicators, + skipped_count, + total_sha256_count, + total_md5_count + ) + def update_storage( + self, + bundle, + last_added_date, + storage, + collection, + execution_details, + bundle_id, + start_offset, + batch_size, + version, + pagination_method, + ): + """Update storage with new pagination details.""" next_value_21 = bundle.get("next") objects = bundle.get("objects", []) @@ -863,14 +1383,12 @@ def update_storage( } elif ( version == STIX_VERSION_20 - and len(objects) >= LIMIT + and len(objects) >= batch_size ): try: - next_value_20 = int( - storage.get("in_execution", {}) - .get(collection, {}) - .get("next", 0) - ) + (LIMIT * bundle_id) + next_value_20 = int(start_offset) + ( + batch_size * bundle_id + ) except Exception: next_value_20 = 0 storage["in_execution"] = { @@ -881,94 +1399,109 @@ def update_storage( } else: storage["in_execution"] = {} - execution_details[collection] = pytz.utc.localize( + execution_details[collection] = ensure_utc_aware( datetime.now() ) else: + resume_added_date = last_added_date + if not resume_added_date: + fallback_date = execution_details.get(collection) + if 
isinstance(fallback_date, datetime): + resume_added_date = fallback_date.strftime( + DATE_CONVERSION_STRING + ) + elif isinstance(fallback_date, str): + resume_added_date = fallback_date + if ( version == STIX_VERSION_21 and bundle.get("more") - and last_added_date + and resume_added_date ): storage["in_execution"] = { collection: { "next": next_value_21, - "last_added_date": last_added_date, + "last_added_date": resume_added_date, } } elif ( version == STIX_VERSION_20 - and len(objects) >= LIMIT - and last_added_date + and len(objects) >= batch_size + and resume_added_date ): try: - next_value_20 = int( - storage.get("in_execution", {}) - .get(collection, {}) - .get("next", 0) - ) + (LIMIT * bundle_id) + next_value_20 = int(start_offset) + ( + batch_size * bundle_id + ) except Exception: next_value_20 = 0 storage["in_execution"] = { collection: { "next": next_value_20, - "last_added_date": last_added_date, + "last_added_date": resume_added_date, } } else: storage["in_execution"] = {} - execution_details[collection] = pytz.utc.localize( + execution_details[collection] = ensure_utc_aware( datetime.now() ) else: storage["in_execution"] = {} - execution_details[collection] = pytz.utc.localize(datetime.now()) - - return + execution_details[collection] = ensure_utc_aware(datetime.now()) def paginate( self, configuration, pages, collection, - storage, - execution_details, - indicators, + storage=None, + execution_details=None, + start_offset: int = 0, + is_retraction: bool = False, ): - """Paginate through the collection. + """Paginate through the collection and yield batches. Args: configuration (dict): Configuration dictionary. - pages (list): List of pages. + pages (generator): Generator of (bundle, last_added_date) tuples. collection (str): Collection name. - storage (dict): Storage dictionary. - execution_details (dict): Execution details. - indicators (list): List of indicators. - - Returns: - list: List of indicators. 
+ storage (dict): Storage dictionary for resume details. + execution_details (dict): Per-collection execution timestamps. + start_offset (int): Starting offset for TAXII 2.0 pagination. + is_retraction (bool): If True, yield sets of values instead of + Indicator objects. + + Yields: + tuple: (extracted_indicators, skipped_count, sub_checkpoint_dict) + for each bundle. extracted_indicators is a list of Indicator + objects or set of values if is_retraction is True. """ + ( + version, + batch_size, + pagination_method, + ) = get_configuration_parameters( + configuration, keys=["version", "batch_size", "pagination_method"] + ) + bundle_id = 1 collection_indicator_count = 0 collection_skip_count = 0 total_sha256_count = 0 total_md5_count = 0 - sha256_count = 0 - md5_count = 0 + for bundle, last_added_date in pages: objects = bundle.get("objects", []) - extracted, is_skipped, sha256_count, md5_count = ( - self._extract_indicators_2x(objects) + extracted, skipped_count, sha256_count, md5_count = ( + self._extract_indicators_2x(objects, is_retraction) ) total_sha256_count += sha256_count total_md5_count += md5_count - indicators += extracted + extracted_count = len(extracted) - skip_count = len(objects) - extracted_count collection_indicator_count += extracted_count - self.total_indicators += extracted_count - collection_skip_count += skip_count - self.total_skipped += skip_count + collection_skip_count += skipped_count incremental_hash_msg = "" if sha256_count > 0 or md5_count > 0: @@ -979,56 +1512,66 @@ def paginate( f"Total {collection_indicator_count} indicators" f" pulled from '{collection}' collection till now." ) - if is_skipped is True: - self.logger.info( + if skipped_count > 0: + self.logger.debug( f"{self.log_prefix}: Pulled {extracted_count} " f"{incremental_hash_msg}" f"indicator(s) from '{collection}' collection " - f"Bundle-{bundle_id}, {skip_count} indicator(s) " + f"Bundle-{bundle_id}, {skipped_count} indicator(s) " "have been discarded." 
f" {total_log}" ) else: - self.logger.info( + self.logger.debug( f"{self.log_prefix}: Pulled {extracted_count} " f"{incremental_hash_msg}" f"indicator(s) from '{collection}' collection " f"Bundle-{bundle_id}. {total_log}" ) - self.total_bundle_count += 1 + # Build sub_checkpoint for resumption + next_value = bundle.get("next") + sub_checkpoint = { + "collection": collection, + "bundle_id": bundle_id, + "last_added_date": last_added_date, + } - if self.total_bundle_count == BUNDLE_LIMIT: - self.logger.info( - f"{self.log_prefix}: Bundle limit of {BUNDLE_LIMIT}" - f" is reached while executing '{collection}' collection." - "The execution will be continued in the next cycle." - f" Total {collection_indicator_count} indicators " - f"pulled and {collection_skip_count} indicators" - f" skipped for '{collection}' collection." - ) - self.logger.debug( - f"{self.log_prefix}: Updating the collection" - " execution details with the next page details" - ) + # For TAXII 2.1, use next token; for 2.0, calculate offset + if version == STIX_VERSION_21: + sub_checkpoint["next"] = next_value + sub_checkpoint["has_more"] = bundle.get("more", False) + else: + # TAXII 2.0: calculate offset for next page + # Offset = start_offset + (batch_size * bundles_processed) + sub_checkpoint["next"] = start_offset + (batch_size * bundle_id) + sub_checkpoint["has_more"] = len(objects) >= batch_size + + if ( + not is_retraction + and storage is not None + and execution_details is not None + ): self.update_storage( - configuration=configuration, bundle=bundle, last_added_date=last_added_date, storage=storage, collection=collection, execution_details=execution_details, bundle_id=bundle_id, + start_offset=start_offset, + batch_size=batch_size, + version=version, + pagination_method=pagination_method, ) storage["collections"] = self.convert_datetime_to_string( - execution_details + execution_details.copy() ) - self.logger.debug( - f"{self.log_prefix}: Updated the collection execution" - " details 
successfully. Collection execution" - f" details: {storage}." + sub_checkpoint["in_execution"] = storage.get( + "in_execution", {} ) - return indicators + + yield extracted, skipped_count, sub_checkpoint bundle_id += 1 hash_msg = "" @@ -1038,17 +1581,12 @@ def paginate( ) self.logger.info( f"{self.log_prefix}: Completed pulling of" - f" indicator(s) from collection(s) - " + f" indicator(s) from collection - " f"'{collection}'." f" Total {collection_indicator_count} {hash_msg}" "indicator(s) pulled" f" and {collection_skip_count} skipped." ) - self.logger.debug( - f"{self.log_prefix}: Successfully pulled " - f"{self.total_bundle_count} bundle(s)." - ) - return def get_page(self, func, configuration, start_time, next=None, start=0): """Get a page of indicators. @@ -1063,9 +1601,13 @@ def get_page(self, func, configuration, start_time, next=None, start=0): Returns: list: List of indicators. """ - version, _, _, _, _, _, _, _, _, _, _, = get_configuration_parameters( - configuration + ( + version, + batch_size, + ) = get_configuration_parameters( + configuration, keys=["version", "batch_size"] ) + headers = add_ce_user_agent( plugin_name=self.plugin_name, plugin_version=self.plugin_version ) @@ -1073,7 +1615,7 @@ def get_page(self, func, configuration, start_time, next=None, start=0): pages = as_pages21( func, plugin=self, - per_request=LIMIT, + per_request=batch_size, added_after=start_time, next=next, with_header=True, @@ -1083,7 +1625,7 @@ def get_page(self, func, configuration, start_time, next=None, start=0): pages = as_pages20( func, plugin=self, - per_request=LIMIT, + per_request=batch_size, added_after=start_time, start=start, with_header=True, @@ -1092,15 +1634,18 @@ def get_page(self, func, configuration, start_time, next=None, start=0): return pages - def pull_2x(self, configuration, start_time): + def pull_2x(self, configuration, start_time, is_retraction: bool = False): """Pull implementation for version 2.x. 
Args: configuration (dict): Configuration dictionary. start_time (datetime): Start time. + is_retraction (bool): If True, yield sets of values for retraction. - Returns: - list: List of indicators. + Yields: + tuple: (indicators_batch, sub_checkpoint_dict) for each bundle. + indicators_batch is a list of Indicator objects or set of + values if is_retraction is True. """ ( version, @@ -1109,31 +1654,44 @@ def pull_2x(self, configuration, start_time): password, collection_names, pagination_method, - _, delay, - _, - _, - _, - ) = get_configuration_parameters(configuration) - indicators = [] - collection_execution_details = {} - new_collection_details = {} + type_to_pull, + severity, + reputation, + ) = get_configuration_parameters( + configuration, + keys=[ + "version", + "discovery_url", + "username", + "password", + "collection_names", + "pagination_method", + "delay", + "type_to_pull", + "severity", + "reputation", + ], + ) + if delay and isinstance(delay, int): + delay = int(delay) + else: + delay = 0 + collection_name_object = {} + delay_time = int(delay) storage = {} - self.total_indicators = 0 - self.total_skipped = 0 - self.total_bundle_count = 0 + collection_execution_details = {} + new_collection_details = {} - if self.storage is not None: + if not is_retraction and self.storage is not None: storage = self.storage if storage.get("collections", {}): collection_execution_details = self.convert_string_to_datetime( - storage.get("collections", {}) + storage.get("collections", {}).copy() ) - else: - storage = {} - delay_time = int(delay) + # Initialize API root based on version if version == STIX_VERSION_21: apiroot = ApiRoot21( discovery_url, @@ -1150,6 +1708,8 @@ def pull_2x(self, configuration, start_time): verify=self.ssl_validation, proxies=self.proxy, ) + + # Build collection mapping all_collections = [] for c in apiroot.collections: all_collections.append(c.title) @@ -1163,128 +1723,235 @@ def pull_2x(self, configuration, start_time): f" be fetched - {', 
'.join(filtered_collections)}." ) - self.logger.debug( - f"{self.log_prefix}: Collection execution details - {storage}." - ) + if not is_retraction: + self.logger.debug( + f"{self.log_prefix}: Collection execution details - {storage}." + ) + + total_indicators = 0 + total_skipped = 0 + total_type_counts = { + "sha256": 0, "md5": 0, "url": 0, + "ipv4": 0, "ipv6": 0, "domain": 0 + } + + def _process_pages(pages, collection, execution_details, start_offset): + nonlocal total_indicators, total_skipped + + for extracted, skipped_count, bundle_checkpoint in self.paginate( + configuration, + pages, + collection, + storage=None if is_retraction else storage, + execution_details=( + None if is_retraction else execution_details + ), + start_offset=start_offset, + is_retraction=is_retraction, + ): + if is_retraction: + yield extracted, None + continue + + total_skipped += skipped_count + if isinstance(extracted, list) and extracted: + filtered_batch, filter_skipped = ( + self._filter_indicators_by_config( + extracted, type_to_pull, severity, reputation, + total_type_counts + ) + ) + total_indicators += len(filtered_batch) + total_skipped += filter_skipped - if storage.get("in_execution", {}): + self.logger.info( + f"{self.log_prefix}: Pulled total " + f"{len(filtered_batch)} indicator(s) from " + f"'{collection}' collection's Bundle-" + f"{bundle_checkpoint.get('bundle_id')}, " + f"{filter_skipped} indicator(s) filtered, and " + f"{skipped_count} indicator(s) skipped." 
+ ) + + if filtered_batch: + yield filtered_batch, bundle_checkpoint + + if not is_retraction and storage.get("in_execution", {}): for collection, next_page_details in storage.get( - "in_execution" + "in_execution", {} ).items(): if collection not in filtered_collections: break collection_object = collection_name_object[collection] + next_start_time = collection_execution_details.get( + collection, start_time + ) + next_val = None + start_offset = 0 + if pagination_method == "next": - next_start_time = collection_execution_details[collection] - next_value = next_page_details.get("next") - next_val = next_value - start_val = next_value + next_val = next_page_details.get("next") + if version == STIX_VERSION_20: + try: + start_offset = int(next_val) if next_val else 0 + except Exception: + start_offset = 0 else: - next_value = next_page_details.get("last_added_date") - next_start_time = str_to_datetime(next_value) + last_added_date = next_page_details.get("last_added_date") + if last_added_date: + next_start_time = str_to_datetime( + string=last_added_date, + date_format=DATE_CONVERSION_STRING, + replace_dot=False, + ) + else: + next_start_time = collection_execution_details.get( + collection, start_time + ) next_val = None - start_val = 0 + start_offset = 0 - try: - next_start_time = next_start_time - timedelta( - minutes=delay_time - ) - next_start_time = pytz.utc.localize(next_start_time) - except Exception: - pass - self.logger.debug( - f"{self.log_prefix}: Executing the collection " - f"'{collection}' with start time {next_start_time}." 
+ next_start_time = next_start_time - timedelta( + minutes=delay_time ) - try: - pages = self.get_page( - func=collection_object.get_objects, - configuration=configuration, - start_time=next_start_time, - next=next_val, - start=start_val, - ) + next_start_time = ensure_utc_aware(next_start_time) + + resume_msg = "" + if pagination_method == "next": + if version == STIX_VERSION_21 and next_val: + resume_msg = ", resuming with next token" + elif version == STIX_VERSION_20 and start_offset: + resume_msg = f", resuming from offset={start_offset}" - fetched_indicators = self.paginate( - configuration, - pages, - collection, - storage, - collection_execution_details, - indicators, + for attempt in range(IN_EXECUTION_MAX_RETRIES): + self.logger.info( + f"{self.log_prefix}: Parsing collection - " + f"'{collection}'. Start time: {next_start_time} (UTC)" + f"{resume_msg}. Attempt {attempt + 1} of {IN_EXECUTION_MAX_RETRIES}." ) - if fetched_indicators is not None: - self.logger.debug( - f"{self.log_prefix}: Successfully pulled" - f" {self.total_indicators} indicator(s)" - f" and {self.total_skipped} " - "indicator(s) were skipped." + + try: + pages = self.get_page( + func=collection_object.get_objects, + configuration=configuration, + start_time=next_start_time, + next=next_val, + start=start_offset, ) - return indicators, self.total_skipped - except requests.exceptions.ProxyError as err: - err_msg = "Invalid proxy configuration." - self.handle_and_raise( - err=err, - err_msg=err_msg, - details_msg=str(traceback.format_exc()), - ) - except requests.exceptions.ConnectionError as err: - err_msg = ( - "Connection Error occurred. Check the " - "Discovery URL/API Root URL provided." - ) - self.handle_and_raise( - err=err, - err_msg=err_msg, - details_msg=str(traceback.format_exc()), - ) - except requests.exceptions.RequestException as err: - if not ( - "416" in str(err) - or "request range not satisfiable" in str(err).lower() - ): - err_msg = "Request Exception occurred." 
+ + yield from _process_pages( + pages, + collection, + collection_execution_details, + start_offset, + ) + + storage["in_execution"] = {} + collection_execution_details[collection] = ( + ensure_utc_aware(datetime.now()) + ) + break + except KeyError: + storage["in_execution"] = {} + collection_execution_details[collection] = ( + ensure_utc_aware(datetime.now()) + ) + break + except requests.exceptions.ProxyError as err: + err_msg = f"Invalid proxy configuration. Retry attempt: {attempt}." self.handle_and_raise( err=err, err_msg=err_msg, details_msg=str(traceback.format_exc()), + if_raise=False, + resolution=PROXY_ERROR_RESOLUTION, ) - self.logger.info( - f"{self.log_prefix}: Received status code 416," - f" exiting the pulling of '{collection}' collection. " - f"Response: {str(err)}." - ) - except Exception as err: - if not ( - "416" in str(err) - or "request range not satisfiable" in str(err).lower() - ): + except requests.exceptions.ConnectionError as err: err_msg = ( - "Exception occurred while fetching the" - " objects of collection." + "Connection Error occurred. Check the " + "Discovery URL/API Root URL provided. " + f"Retry attempt: {attempt}." ) self.handle_and_raise( err=err, err_msg=err_msg, details_msg=str(traceback.format_exc()), + if_raise=False, + resolution=CONNECTION_ERROR_RESOLUTION, ) - self.logger.info( - f"{self.log_prefix}: Received status code 416, " - f"exiting the pulling of '{collection}' collection. " - f"Response: {str(err)}." - ) + except requests.exceptions.RequestException as err: + if ( + "416" in str(err) + or "request range not satisfiable" in str(err).lower() + ): + storage["in_execution"] = {} + collection_execution_details[collection] = ( + ensure_utc_aware(datetime.now()) + ) + self.logger.info( + f"{self.log_prefix}: Received status code 416, " + f"exiting the pulling of '{collection}' " + f"collection. Response: {str(err)}." 
+ ) + break - collection_execution_details[collection] = datetime.now() - storage["in_execution"] = {} + err_msg = ( + "Exception occurred while fetching the " + "objects of collection. " + f"Retry attempt: {attempt}." + ) + self.handle_and_raise( + err=err, + err_msg=err_msg, + details_msg=str(traceback.format_exc()), + if_raise=False, + ) + except Exception as err: + if ( + "416" in str(err) + or "request range not satisfiable" in str(err).lower() + ): + storage["in_execution"] = {} + collection_execution_details[collection] = ( + ensure_utc_aware(datetime.now()) + ) + self.logger.info( + f"{self.log_prefix}: Received status code 416, " + f"exiting the pulling of '{collection}' " + f"collection. Response: {str(err)}." + ) + break - for collection in apiroot.collections: - collection_name = collection.title - if collection_name not in filtered_collections: - continue + err_msg = ( + "Exception occurred while fetching the " + "objects of collection. " + f"Retry attempt: {attempt}." + ) + self.handle_and_raise( + err=err, + err_msg=err_msg, + details_msg=str(traceback.format_exc()), + if_raise=False, + ) + + if attempt >= (IN_EXECUTION_MAX_RETRIES - 1): + storage["in_execution"] = {} + collection_execution_details[collection] = ( + ensure_utc_aware(datetime.now()) + ) + self.logger.info( + f"{self.log_prefix}: Exhausted retries while " + f"resuming '{collection}'. Skipping this " + f"collection till current time." 
+ ) + break + + time.sleep(IN_EXECUTION_SLEEP_TIME) - new_collection_details[collection_name] = pytz.utc.localize( - collection_execution_details.get(collection_name, start_time) + for collection in filtered_collections: + new_collection_details[collection] = ensure_utc_aware( + collection_execution_details.get(collection, start_time) ) sorted_collection = sorted( @@ -1293,45 +1960,43 @@ def pull_2x(self, configuration, start_time): for collection in sorted_collection: collection_object = collection_name_object[collection] - start_time = new_collection_details[collection] - timedelta( - minutes=delay_time + collection_start_time = new_collection_details[ + collection + ] - timedelta(minutes=delay_time) + + self.logger.info( + f"{self.log_prefix}: Parsing collection - " + f"'{collection}'. Start time: {collection_start_time} (UTC)" ) + try: - self.logger.debug( - f"{self.log_prefix}: Parsing collection - " - f"'{collection}'. Start time: {start_time} (UTC)" - ) pages = self.get_page( func=collection_object.get_objects, configuration=configuration, - start_time=start_time, + start_time=collection_start_time, ) - fetched_indicators = self.paginate( - configuration, + yield from _process_pages( pages, collection, - storage, new_collection_details, - indicators, + start_offset=0, ) - if fetched_indicators is not None: - self.logger.debug( - f"{self.log_prefix}: Successfully pulled" - f" {self.total_indicators} indicator(s)" - f" and {self.total_skipped} indicator(s) were skipped." 
- ) - return indicators, self.total_skipped - storage["in_execution"] = {} - new_collection_details[collection] = pytz.utc.localize( - datetime.now() - ) + if not is_retraction: + storage["in_execution"] = {} + new_collection_details[collection] = ensure_utc_aware( + datetime.now() + ) except KeyError: - # if there is no data in a collection - storage["in_execution"] = {} - new_collection_details[collection] = pytz.utc.localize( - datetime.now() + if not is_retraction: + storage["in_execution"] = {} + new_collection_details[collection] = ensure_utc_aware( + datetime.now() + ) + self.logger.info( + f"{self.log_prefix}: No data in collection " + f"'{collection}', continuing." ) except requests.exceptions.ProxyError as err: err_msg = "Invalid proxy configuration." @@ -1339,6 +2004,7 @@ def pull_2x(self, configuration, start_time): err=err, err_msg=err_msg, details_msg=str(traceback.format_exc()), + resolution=PROXY_ERROR_RESOLUTION, ) except requests.exceptions.ConnectionError as err: err_msg = ( @@ -1349,21 +2015,23 @@ def pull_2x(self, configuration, start_time): err=err, err_msg=err_msg, details_msg=str(traceback.format_exc()), + resolution=CONNECTION_ERROR_RESOLUTION, ) except requests.exceptions.RequestException as err: if ( "416" in str(err) or "request range not satisfiable" in str(err).lower() ): - storage["in_execution"] = {} + if not is_retraction: + storage["in_execution"] = {} + new_collection_details[collection] = ( + ensure_utc_aware(datetime.now()) + ) self.logger.info( f"{self.log_prefix}: Received status code 416, " f"exiting the pulling of '{collection}' " f"collection. Response: {str(err)}." 
) - new_collection_details[collection] = pytz.utc.localize( - datetime.now() - ) else: err_msg = ( "Exception occurred while fetching the " @@ -1379,15 +2047,16 @@ def pull_2x(self, configuration, start_time): "416" in str(err) or "request range not satisfiable" in str(err).lower() ): - storage["in_execution"] = {} + if not is_retraction: + storage["in_execution"] = {} + new_collection_details[collection] = ( + ensure_utc_aware(datetime.now()) + ) self.logger.info( f"{self.log_prefix}: Received status code 416, " f"exiting the pulling of '{collection}' " f"collection. Response: {str(err)}." ) - new_collection_details[collection] = pytz.utc.localize( - datetime.now() - ) else: err_msg = ( "Exception occurred while fetching the " @@ -1399,19 +2068,25 @@ def pull_2x(self, configuration, start_time): details_msg=str(traceback.format_exc()), ) - storage["collections"] = self.convert_datetime_to_string( - new_collection_details - ) - self.logger.debug( - f"{self.log_prefix}: Storage value after" - f" completion of the pull cycle: {storage}." - ) - self.logger.debug( - f"{self.log_prefix}: Successfully pulled {self.total_indicators}" - f" indicator(s) and {self.total_skipped} " - "indicator(s) were skipped." + if not is_retraction: + storage["collections"] = self.convert_datetime_to_string( + new_collection_details.copy() + ) + self.logger.debug( + f"{self.log_prefix}: Storage value after" + f" completion of the pull cycle: {storage['collections']}." + ) + + type_breakdown_str = self._format_type_breakdown(total_type_counts, type_to_pull) + + self.logger.info( + f"{self.log_prefix}: Completed pulling of " + f"indicator(s) from all collection(s) - " + f"{', '.join(filtered_collections)}. " + f"Total {total_indicators} indicator(s) " + f"pulled, {total_skipped} skipped due to filters. " + f"Pull Stats: {type_breakdown_str} indicator(s) were fetched." ) - return indicators, self.total_skipped def _pull(self, configuration, last_run_at): """Pull implementation. 
@@ -1420,24 +2095,21 @@ def _pull(self, configuration, last_run_at): configuration (dict): Configuration dictionary. last_run_at (datetime): Last run time. - Returns: - list: List of indicators. + Yields: + tuple: (indicators_batch, sub_checkpoint_dict) for each batch. """ ( version, discovery_url, - _, - _, - _, - _, initial_range, - _, type_to_pull, severity, reputation, - ) = get_configuration_parameters(configuration) + ) = get_configuration_parameters( + configuration, + keys=["version", "discovery_url", "days", "type_to_pull", "severity", "reputation"] + ) - skipped = 0 if not last_run_at: start_time = datetime.now() - timedelta( days=int(initial_range) @@ -1455,8 +2127,7 @@ def _pull(self, configuration, last_run_at): f"{self.log_prefix}: Starting the pull execution for " f"Discovery URL: " f"{discovery_url}," - f" Version: {version} and" - f" start time: {start_time}." + f" Version: {version}." ) self.logger.debug( @@ -1467,78 +2138,18 @@ def _pull(self, configuration, last_run_at): ) if version == STIX_VERSION_1: - indicators = self.pull_1x(configuration, start_time) + yield from self.pull_1x(configuration, start_time) else: - indicators, skipped = self.pull_2x(configuration, start_time) - - filtered_list = list( - filter( - lambda x: x.severity.value in severity - and x.reputation >= int(reputation) - and ( - ( - x.type is IndicatorType.SHA256 - and "sha256" in type_to_pull - ) - or ( - x.type is IndicatorType.MD5 - and "md5" in type_to_pull - ) - or ( - x.type is IndicatorType.URL - and "url" in type_to_pull - ) - or ( - x.type - is getattr(IndicatorType, "IPV4", IndicatorType.URL) - and "ipv4" in type_to_pull - ) - or ( - x.type - is getattr(IndicatorType, "IPV6", IndicatorType.URL) - and "ipv6" in type_to_pull - ) - or ( - x.type - is getattr(IndicatorType, "DOMAIN", IndicatorType.URL) - and "domain" in type_to_pull - ) - ), - indicators, - ) - ) - skipped_filtered = len(indicators) - len(filtered_list) - - log_msg_without_skip = ( - 
f"{self.log_prefix}: Pulled {len(filtered_list)}" - " indicator(s) successfully." - ) - - log_msg_with_skip = ( - f"{self.log_prefix}: Pulled {len(filtered_list)}" - " indicator(s) successfully, " - f"skipped {skipped_filtered} indicator(s) " - "due to filter(s) in configuration" - ) - if version == STIX_VERSION_21 or version == STIX_VERSION_20: - if skipped > 0 or skipped_filtered > 0: - self.logger.info( - log_msg_with_skip + ", " - f"and {skipped} indicator(s) were skipped " - "due to invalid or unsupported type." - ) - else: - self.logger.info(log_msg_without_skip) - else: - if skipped_filtered > 0: - self.logger.info(log_msg_with_skip + ".") - else: - self.logger.info(log_msg_without_skip) - - return filtered_list + yield from self.pull_2x(configuration, start_time) def pull(self): - """Pull indicators from TAXII server.""" + """Pull indicators from TAXII server. + + Returns: + Generator or List: If sub_checkpoint is available, returns the + generator directly. Otherwise, consumes the generator and + returns a list of all indicators. + """ try: return self._pull(self.configuration, self.last_run_at) except STIXTAXIIException as err: @@ -1551,6 +2162,120 @@ def pull(self): details_msg=str(traceback.format_exc()), ) + def get_modified_indicators(self, source_indicators: List[List[Indicator]]): + """Get all modified indicators status for retraction. + + This method identifies indicators that should be retracted because they + either no longer exist on the TAXII server, have been revoked, or have + expired (valid_until/valid_time_positions < current time). + + Applicable for both STIX/TAXII 1.x and 2.x. + + Args: + source_indicators (List[List[Indicator]]): Batches of source indicators + currently stored in Cloud Exchange. 
+ + Yields: + tuple: (list_of_ioc_values_to_retract, is_done_flag) + """ + if RETRACTION not in self.log_prefix: + self.log_prefix = self.log_prefix + f" [{RETRACTION}]" + + # Get configuration parameters + ( + version, + retraction_interval, + ) = get_configuration_parameters( + self.configuration, + keys=["version", "retraction_interval"] + ) + + if not (retraction_interval and isinstance(retraction_interval, int)): + log_msg = ( + "Retraction Interval is not configured for " + f'"{self.config_name}". Skipping retraction.' + ) + self.logger.info(f"{self.log_prefix}: {log_msg}") + yield [], True + return + + retraction_interval = int(retraction_interval) + self.logger.info( + f"{self.log_prefix}: Pulling modified indicators from " + f"{PLATFORM_NAME} for retraction. Looking back " + f"{retraction_interval} days. Version: {version}." + ) + + # Calculate start time for retraction + start_time = datetime.now() - timedelta(days=retraction_interval) + + try: + # Fetch currently valid indicators from TAXII server + # using is_retraction=True to get set of values + fetched_iocs = set() + + if version == STIX_VERSION_1: + # Use pull_1x for version 1.x + for ioc_values_set, _ in self.pull_1x( + self.configuration, start_time, is_retraction=True + ): + if isinstance(ioc_values_set, set): + fetched_iocs.update(ioc_values_set) + else: + # Use pull_2x for version 2.x + for ioc_values_set, _ in self.pull_2x( + self.configuration, start_time, is_retraction=True + ): + if isinstance(ioc_values_set, set): + fetched_iocs.update(ioc_values_set) + + self.logger.info( + f"{self.log_prefix}: Fetched {len(fetched_iocs)} valid " + f"indicator(s) from {PLATFORM_NAME}." 
+ ) + + # Compare source indicators with fetched indicators + for source_ioc_list in source_indicators: + try: + total_iocs = len(source_ioc_list) + + # Find indicators NOT in fetched set = should be retracted + iocs_to_retract = [ + ioc.value for ioc in source_ioc_list + if ioc and ioc.value not in fetched_iocs + ] + + self.logger.info( + f"{self.log_prefix}: {len(iocs_to_retract)} indicator(s) " + f"will be marked as retracted out of {total_iocs} " + f"total indicator(s) from {PLATFORM_NAME}." + ) + yield iocs_to_retract, False + + except Exception as err: + err_msg = ( + f"Error while processing source indicators for " + f"retraction from {PLATFORM_NAME}." + ) + self.logger.error( + message=f"{self.log_prefix}: {err_msg} Error: {err}", + details=traceback.format_exc(), + ) + raise STIXTAXIIException(err_msg) + + except STIXTAXIIException: + raise + except Exception as err: + err_msg = ( + f"Error occurred while fetching modified indicators " + f"from {PLATFORM_NAME}." + ) + self.logger.error( + message=f"{self.log_prefix}: {err_msg} Error: {err}", + details=traceback.format_exc(), + ) + raise STIXTAXIIException(err_msg) + def _validate_collections(self, configuration): """Validate collections. 
Args: @@ -1565,14 +2290,12 @@ def _validate_collections(self, configuration): username, password, collection_names, - _, - _, - _, - _, - _, - _, - ) = get_configuration_parameters(configuration) + ) = get_configuration_parameters( + configuration, + keys=["version", "discovery_url", "username", "password", "collection_names"] + ) + # Build clients if version == STIX_VERSION_1: client = self._build_client(configuration) all_collections = self._get_collections(client) @@ -1584,7 +2307,6 @@ def _validate_collections(self, configuration): verify=self.ssl_validation, proxies=self.proxy, ) - all_collections = [c.title for c in apiroot.collections] elif version == STIX_VERSION_21: apiroot = ApiRoot21( discovery_url, @@ -1593,7 +2315,15 @@ def _validate_collections(self, configuration): verify=self.ssl_validation, proxies=self.proxy, ) - all_collections = [c.title for c in apiroot.collections] + + # Gather collections + collection_name_object = {} + if version in [STIX_VERSION_20, STIX_VERSION_21]: + collection_objects = list(apiroot.collections) + all_collections = [c.title for c in collection_objects] + collection_name_object = { + c.title: c for c in collection_objects + } collections = [c.strip() for c in collection_names.split(",")] collections = list(filter(lambda x: len(x) > 0, collections)) if collections and set(collections) - set(all_collections): @@ -1604,6 +2334,30 @@ def _validate_collections(self, configuration): f"{', '.join(set(collections) - set(all_collections))}" ), ) + collections_to_validate = collections or all_collections + + # Validate collections + if collections_to_validate: + validation_collection = collections_to_validate[0] + start_time = ensure_utc_aware(datetime.now()) + if version == STIX_VERSION_1: + content_blocks = client.poll( + collection_name=validation_collection, + begin_date=start_time, + ) + next(iter(content_blocks), None) + else: + collection_object = collection_name_object.get( + validation_collection + ) + if 
collection_object: + pages = self.get_page( + func=collection_object.get_objects, + configuration=configuration, + start_time=start_time, + ) + next(pages, None) + return ValidationResult( success=True, message="Validated successfully." ) @@ -1614,6 +2368,7 @@ def _validate_collections(self, configuration): err_msg=err_msg, details_msg=str(traceback.format_exc()), if_raise=False, + resolution=PROXY_ERROR_RESOLUTION, ) return ValidationResult(success=False, message=err_msg) except requests.exceptions.ConnectionError as err: @@ -1626,6 +2381,7 @@ def _validate_collections(self, configuration): err_msg=err_msg, details_msg=str(traceback.format_exc()), if_raise=False, + resolution=CONNECTION_ERROR_RESOLUTION, ) return ValidationResult(success=False, message=err_msg) except requests.exceptions.RequestException as ex: @@ -1642,7 +2398,11 @@ def _validate_collections(self, configuration): ) except exceptions.UnsuccessfulStatusError as ex: self.logger.error( - f"{self.log_prefix}: {str(ex)}", details=traceback.format_exc() + message=( + f"{self.log_prefix}: Error occurred while " + f"validating credentials. Error: {str(ex)}" + ), + details=traceback.format_exc() ) if ex.status == "UNAUTHORIZED": return ValidationResult( @@ -1678,7 +2438,7 @@ def _validate_collections(self, configuration): ) return ValidationResult( success=False, - message=err_msg + ". Check all of the parameters.", + message=err_msg + " Check all of the parameters.", ) def validate(self, configuration: Dict) -> ValidationResult: @@ -1702,198 +2462,162 @@ def validate(self, configuration: Dict) -> ValidationResult: type_to_pull, severity, reputation, + batch_size, + retraction_interval, ) = get_configuration_parameters(configuration, is_validation=True) # Discovery URL - if not discovery_url: - err_msg = ( - "Discovery URL/API Root URL is a " - "required configuration parameter." 
- ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult( - success=False, - message=err_msg, - ) - elif not isinstance(discovery_url, str): - err_msg = ( - "Invalid Discovery URL/API Root URL Provided " - "in configuration parameters." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult( - success=False, - message=err_msg, - ) + if validation_failure := self._validate_configuration_parameters( + field_name="Discovery URL/API Root URL", + field_value=discovery_url, + field_type=str, + ): + return validation_failure # Username - if not isinstance(username, str): - err_msg = "Invalid Username Provided in configuration parameters." - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult( - success=False, - message=err_msg, - ) + if validation_failure := self._validate_configuration_parameters( + field_name="Username", + field_value=username, + field_type=str, + is_required=False, + ): + return validation_failure # Password - if not isinstance(password, str): - err_msg = "Invalid Password Provided in configuration parameters." - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult( - success=False, - message=err_msg, - ) + if validation_failure := self._validate_configuration_parameters( + field_name="Password", + field_value=password, + field_type=str, + is_required=False, + ): + return validation_failure # STIX/TAXII Version - if not version: - err_msg = ( - "STIX/TAXII Version is a required configuration parameter." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) - elif not isinstance(version, str) or version not in [ - STIX_VERSION_1, - STIX_VERSION_20, - STIX_VERSION_21, - ]: - err_msg = ( - "Invalid value for STIX/TAXII Version provided." - " Available values are '1', '2.0', or '2.1'." 
- ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) - elif version == STIX_VERSION_1 and ( - "ipv4" in type_to_pull or "ipv6" in type_to_pull + if validation_failure := self._validate_configuration_parameters( + field_name="STIX/TAXII Version", + field_value=version, + field_type=str, + allowed_values=[ + STIX_VERSION_1, + STIX_VERSION_20, + STIX_VERSION_21, + ], ): - err_msg = ( - "IPv4/IPv6 is not supported in the plugin for" - " STIX Version 1.x." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) + return validation_failure # Collection Names - if not isinstance(collection_names, str): - err_msg = ( - "Invalid Collection Names provided in" - " configuration parameters." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) + if validation_failure := self._validate_configuration_parameters( + field_name="Collection Names", + field_value=collection_names, + field_type=str, + is_required=False, + ): + return validation_failure # Type of Threat data to pull - if not type_to_pull: - err_msg = ( - "Type of Threat data to pull is a required" - " configuration parameter." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) - elif not isinstance(type_to_pull, list) or not all( - item in ["sha256", "md5", "url", "ipv4", "ipv6", "domain"] - for item in type_to_pull + valid_types = ( + [c["value"] for c in TYPE_V1["choices"]] + if version == STIX_VERSION_1 + else [c["value"] for c in TYPE_V2["choices"]] + ) + if validation_failure := self._validate_configuration_parameters( + field_name="Type of Threat data to pull", + field_value=type_to_pull, + field_type=list, + allowed_values=valid_types, + is_required=False ): - err_msg = ( - "Invalid value for Type of Threat data to pull" - " provided in configuration parameters. 
" - "Available values are 'sha256', 'md5'," - " 'url', 'ipv4', 'ipv6', 'domain'." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) - - # Pagination Method - if not pagination_method: - err_msg = ( - "Pagination Method is a required configuration parameter." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) - elif not isinstance( - pagination_method, str - ) or pagination_method not in ["next", "last_added_date"]: - err_msg = ( - "Invalid value for Pagination Method provided. Available" - " values are 'next' or 'last_added_date'." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) + return validation_failure # Reputation - if not reputation: - err_msg = "Reputation is a required configuration parameter." - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) - elif ( - not isinstance(reputation, int) - or int(reputation) < 1 - or int(reputation) > 10 + if validation_failure := self._validate_configuration_parameters( + field_name="Reputation", + field_value=reputation, + field_type=int, + min_value=1, + max_value=10, ): - err_msg = ( - "Invalid value for Reputation provided. " - "Must be an integer in range 1 - 10." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) + return validation_failure # Initial Range - if initial_range != 0 and not initial_range: - err_msg = ( - "Initial Range (in days) is a required " - "configuration parameter." 
- ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) - elif ( - not isinstance(initial_range, int) - or int(initial_range) < 1 - or int(initial_range) > 365 + if validation_failure := self._validate_configuration_parameters( + field_name="Initial Range (in days)", + field_value=initial_range, + field_type=int, + min_value=1, + max_value=365, ): - err_msg = ( - "Invalid value for Initial Range (in days) provided" - " in configuration parameters. " - "Must be an integer in range 1 - 365." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) + return validation_failure + + # Retraction logic + if validation_failure := self._validate_configuration_parameters( + field_name="Retraction Interval (in days)", + field_value=retraction_interval, + field_type=int, + min_value=1, + max_value=365, + is_required=False, + ): + return validation_failure # Delay - if ( - not isinstance(delay_config, int) - or int(delay_config) < 0 - or int(delay_config) > 1440 + if validation_failure := self._validate_configuration_parameters( + field_name="Look Back", + field_value=delay_config, + field_type=int, + min_value=1, + max_value=1440, + is_required=False, ): - err_msg = ( - "Invalid value for Look Back provided" - " in configuration parameters. " - "Must be an integer in range 0 - 1440." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) + return validation_failure - # severity - if not severity: - err_msg = ( - "Severity is a required configuration parameter." 
- ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) - elif not isinstance(severity, list) or not ( - all( - sev in ["unknown", "low", "medium", "high", "critical"] - for sev in severity - ) + # Severity + valid_severity = ( + [c["value"] for c in SEVERITY_V1["choices"]] + if version == STIX_VERSION_1 + else [c["value"] for c in SEVERITY_V2["choices"]] + ) + if validation_failure := self._validate_configuration_parameters( + field_name="Severity", + field_value=severity, + field_type=list, + allowed_values=valid_severity, + is_required=False ): - err_msg = ( - "Invalid value for Severity provided in " - "configuration parameters. " - "Available values are 'unknown', " - "'low', 'medium', 'high' and 'critical'." - ) - self.logger.error(f"{self.log_prefix}: {err_msg}") - return ValidationResult(success=False, message=err_msg) + return validation_failure + + # Pagination method - only for version 2.x + if version in [STIX_VERSION_20, STIX_VERSION_21]: + if validation_failure := self._validate_configuration_parameters( + field_name="Pagination Method", + field_value=pagination_method, + field_type=str, + allowed_values=["next", "last_added_date"], + ): + return validation_failure + + # Batch Size - only for version 2.x + # Range for 2.0 is 2 to 1000 + if version == STIX_VERSION_20: + if validation_failure := self._validate_configuration_parameters( + field_name="Batch Size", + field_value=batch_size, + field_type=int, + min_value=2, + max_value=1000, + ): + return validation_failure + # Range for 2.1 is 1 to 1000 + if version == STIX_VERSION_21: + if validation_failure := self._validate_configuration_parameters( + field_name="Batch Size", + field_value=batch_size, + field_type=int, + min_value=1, + max_value=1000, + ): + return validation_failure - # validating the configuration parameters # Validate collections validate_collections = self._validate_collections(configuration) if validate_collections.success 
is False: @@ -1907,6 +2631,91 @@ def validate(self, configuration: Dict) -> ValidationResult: success=True, message="Validated successfully." ) + def _validate_configuration_parameters( + self, + field_name: str, + field_value: Any, + field_type: type, + allowed_values: Optional[List] = None, + min_value: Optional[int] = None, + max_value: Optional[int] = None, + custom_validation_func: Optional[Callable] = None, + is_required: bool = True, + validation_err_msg: str = "", + ) -> Optional[ValidationResult]: + """Validate a configuration field value.""" + if field_type is str and isinstance(field_value, str): + field_value = field_value.strip() + + if is_required and ( + ( + not isinstance(field_value, int) + and not field_value + ) + or ( + isinstance(field_value, int) + and field_value is None + ) + ): + err_msg = ( + f"{field_name} is a required configuration parameter." + ) + self.logger.error(f"{self.log_prefix}: {validation_err_msg}{err_msg}") + return ValidationResult(success=False, message=err_msg) + + if field_value and not isinstance(field_value, field_type) or ( + custom_validation_func + and not custom_validation_func(field_value) + ): + err_msg = ( + "Invalid value provided for the configuration" + f" parameter '{field_name}', expecting {field_type} type value." + ) + self.logger.error(f"{self.log_prefix}: {validation_err_msg}{err_msg}") + return ValidationResult(success=False, message=err_msg) + + if allowed_values: + allowed_values_list = ( + list(allowed_values.values()) + if isinstance(allowed_values, dict) + else list(allowed_values) + ) + allowed_values_str = ", ".join( + [str(value) for value in allowed_values_list] + ) + err_msg = ( + f"Invalid value for {field_name} provided in configuration " + f"parameters. Available values are {allowed_values_str}." 
+ ) + if field_type is str and field_value not in allowed_values_list: + self.logger.error(f"{self.log_prefix}: {validation_err_msg}{err_msg}") + return ValidationResult(success=False, message=err_msg) + if field_type is list and any( + value not in allowed_values_list for value in field_value + ): + self.logger.error(f"{self.log_prefix}: {validation_err_msg}{err_msg}") + return ValidationResult(success=False, message=err_msg) + + if isinstance(field_value, int): + if min_value is not None and field_value < min_value: + err_msg = ( + f"Invalid value for {field_name} provided in " + "configuration parameters. Must be an integer greater " + f"than or equal to {min_value}." + ) + self.logger.error(f"{self.log_prefix}: {validation_err_msg}{err_msg}") + return ValidationResult(success=False, message=err_msg) + if max_value is not None and field_value > max_value: + err_msg = ( + f"Invalid value for {field_name} provided in " + "configuration parameters. Must be an integer less " + f"than or equal to {max_value}." + ) + self.logger.error(f"{self.log_prefix}: {validation_err_msg}{err_msg}") + return ValidationResult(success=False, message=err_msg) + + return None + def get_actions(self) -> List[ActionWithoutParams]: """Get available actions.""" return [] diff --git a/stix_taxii/manifest.json b/stix_taxii/manifest.json index d519ee55..4406ab6e 100644 --- a/stix_taxii/manifest.json +++ b/stix_taxii/manifest.json @@ -1,174 +1,39 @@ { - "name": "STIX/TAXII", - "id": "stix_taxii", - "version": "3.1.0", - "module": "CTE", - "description": "This plugin is used to fetch the indicators of type Domain, IPv4, IPv6, URL and Hash (MD5 and SHA256) from the TAXII feeds and extracts observables from them. 
This plugin does not support sharing of indicators to TAXII feeds.", - "patch_supported": false, - "push_supported": false, - "configuration": [ - { - "label": "STIX/TAXII Version", - "key": "version", - "type": "choice", - "choices": [ - { - "key": "1.1", - "value": "1" - }, - { - "key": "2.0", - "value": "2.0" - }, - { - "key": "2.1", - "value": "2.1" - } - ], - "mandatory": true, - "description": "STIX/TAXII Version.", - "default": "1" - }, - { - "label": "Discovery URL/API Root URL", - "key": "discovery_url", - "type": "text", - "default": "", - "mandatory": true, - "description": "Discovery/Feed URL of TAXII server for version 1.x and API Root URL for version 2.x. Contact your STIX/TAXII support to get the appropriate URL." - }, - { - "label": "Username", - "key": "username", - "type": "text", - "mandatory": false, - "description": "Username required for authentication if any." - }, - { - "label": "Password", - "key": "password", - "type": "password", - "mandatory": false, - "description": "Password required for authentication if any." - }, - { - "label": "Collection Names", - "key": "collection_names", - "type": "text", - "default": "", - "mandatory": false, - "description": "Comma separated collection names from which data needs to be fetched. Leave empty to fetch data from all of the collections." - }, - { - "label": "Pagination Method", - "key": "pagination_method", - "description": "Pagination Method to use while pulling the indicators. Contact your STIX/TAXII support to choose the appropriate option.", - "type": "choice", - "choices": [ - { - "key": "Next", - "value": "next" - }, - { - "key": "X-TAXII-Date-Added-Last", - "value": "last_added_date" - } - ], - "mandatory": true, - "default": "next" - }, - { - "label": "Initial Range (in days)", - "key": "days", - "type": "number", - "mandatory": true, - "default": 7, - "description": "Number of days to pull the data for the initial run." 
- },
- {
- "label": "Look Back (in minutes)",
- "key": "delay",
- "type": "number",
- "mandatory": false,
- "description": "Number of minutes to backdate the start time for pulling the data. Valid value is anything between 0 to 1440."
- },
- {
- "label": "Type of Threat data to pull",
- "key": "type",
- "type": "multichoice",
- "choices": [
- {
- "key": "SHA-256",
- "value": "sha256"
- },
- {
- "key": "MD5",
- "value": "md5"
- },
- {
- "key": "URL",
- "value": "url"
- },
- {
- "key": "IPv4",
- "value": "ipv4"
- },
- {
- "key": "IPv6",
- "value": "ipv6"
- },
- {
- "key": "Domain",
- "value": "domain"
- }
- ],
- "default": ["sha256", "md5", "url", "ipv4", "ipv6", "domain"],
- "mandatory": true,
- "description": "Type of Threat data to pull. Note: IPv4/IPv6 is supported for STIX/TAXII version 2.x."
- },
- {
- "label": "Severity",
- "key": "severity",
- "type": "multichoice",
- "choices": [
- {
- "key": "Unknown",
- "value": "unknown"
- },
- {
- "key": "Low",
- "value": "low"
- },
- {
- "key": "Medium",
- "value": "medium"
- },
- {
- "key": "High",
- "value": "high"
- },
- {
- "key": "Critical",
- "value": "critical"
+ "name": "STIX/TAXII",
+ "id": "stix_taxii",
+ "version": "3.2.0",
+ "module": "CTE",
+ "description": "This plugin is used to fetch IOCs of type Domain, URL and Hash (MD5 and SHA256) for version 1.1 and IOCs of type Domain, URL, IPv4, IPv6, and Hash (MD5 and SHA256) for version 2.0/2.1 from the TAXII feeds and extracts observables from them. This plugin supports retraction of IOCs pulled from TAXII feeds. 
This plugin does not support sharing of indicators to TAXII feeds.", + "patch_supported": false, + "push_supported": false, + "fetch_retraction_info": true, + "minimum_version": "6.0.0", + "configuration": [ + { + "label": "STIX/TAXII Version", + "key": "version", + "type": "choice", + "choices": [ + { + "key": "1.1", + "value": "1" + }, + { + "key": "2.0", + "value": "2.0" + }, + { + "key": "2.1", + "value": "2.1" + } + ], + "mandatory": true, + "description": "STIX/TAXII Version.", + "default": "1", + "has_api_call": true, + "payload_fields": [ + "version" + ] } - ], - "default": [ - "critical", - "high", - "medium", - "low", - "unknown" - ], - "mandatory": false, - "description": "Only indicators with matching severity will be fetched. For STIX/TAXII version 2.x, Unknown should be selected because for all the indicators fetched from these versions would have Unknown severity." - }, - { - "label": "Reputation", - "key": "reputation", - "type": "number", - "mandatory": true, - "default": 5, - "description": "Only indicators with reputation equal to or greater than this will be saved." 
- } - ] + ] } \ No newline at end of file diff --git a/stix_taxii/utils/constants.py b/stix_taxii/utils/constants.py index 2807e9af..97871c34 100644 --- a/stix_taxii/utils/constants.py +++ b/stix_taxii/utils/constants.py @@ -35,6 +35,9 @@ from netskope.integrations.cte.models.indicator import SeverityType # user agent format +MODULE_NAME = "CTE" +PLUGIN_VERSION = "3.2.0" +PLATFORM_NAME = "STIX/TAXII" USER_AGENT_FORMAT = "{}-{}-{}-v{}" USER_AGENT_KEY = "User-Agent" DEFAULT_USER_AGENT = "netskope-ce" @@ -42,6 +45,14 @@ STIX_VERSION_20 = "2.0" STIX_VERSION_21 = "2.1" SERVICE_TYPE = "COLLECTION_MANAGEMENT" +RETRACTION = "Retraction" +IN_EXECUTION_MAX_RETRIES = 3 +IN_EXECUTION_SLEEP_TIME = 60 # Sleep time in seconds +DATE_CONVERSION_STRING = "%Y-%m-%dT%H:%M:%S.%fZ" +DATE_FORMAT_STRING = "%Y-%m-%dT%H:%M:%S%fZ" + +# Display format for validity times in comments +VALIDITY_DISPLAY_FORMAT = "%Y-%m-%d %H:%M:%S UTC" CONFIDENCE_TO_REPUTATION_MAPPINGS = { "High": 10, @@ -100,14 +111,219 @@ ), }, ] -DATE_CONVERSION_STRING = "%Y-%m-%dT%H:%M:%S.%fZ" -DATE_FORMAT_STRING = "%Y-%m-%dT%H:%M:%S%fZ" -# page size -LIMIT = 1000 -# page -BUNDLE_LIMIT = 100 +# Error resolution messages +PROXY_ERROR_RESOLUTION = ( + "Ensure the proxy server address, port, username, and " + "password are correctly configured and the proxy server is " + "accessible from the network." +) -MODULE_NAME = "CTE" -PLUGIN_VERSION = "3.1.0" -PLATFORM_NAME = "STIX/TAXII" +CONNECTION_ERROR_RESOLUTION = ( + "Ensure the Discovery URL/API Root URL is correct, " + "accessible from your network, and the TAXII server is " + "running and reachable." +) + +# Configuration field constants +DISCOVERY_URL_V1 = { + "label": "Discovery URL", + "key": "discovery_url", + "type": "text", + "default": "", + "mandatory": True, + "description": ( + "Discovery/Feed URL of TAXII server for version 1.x. " + "Contact your STIX/TAXII support to get the appropriate URL." 
+ ) +} + +DISCOVERY_URL_V2 = { + "label": "API Root URL", + "key": "discovery_url", + "type": "text", + "default": "", + "mandatory": True, + "description": ( + "API Root URL of TAXII server for version 2.x. " + "Contact your STIX/TAXII support to get the appropriate URL." + ) +} + +USERNAME_CONFIG = { + "label": "Username", + "key": "username", + "type": "text", + "mandatory": False, + "description": "Username required for authentication if any." +} + +PASSWORD_CONFIG = { + "label": "Password", + "key": "password", + "type": "password", + "mandatory": False, + "description": "Password required for authentication if any." +} + +COLLECTION_NAMES_CONFIG = { + "label": "Collection Names", + "key": "collection_names", + "type": "text", + "default": "", + "mandatory": False, + "description": ( + "Comma separated collection names from which data needs to be " + "fetched. Leave empty to fetch data from all of the collections." + ) +} + +PAGINATION_METHOD_CONFIG_V2 = { + "label": "Pagination Method", + "key": "pagination_method", + "description": ( + "Pagination Method to use while pulling the indicators. " + "Contact your STIX/TAXII support to choose the appropriate option." + ), + "type": "choice", + "choices": [ + {"key": "Next", "value": "next"}, + {"key": "X-TAXII-Date-Added-Last", "value": "last_added_date"} + ], + "mandatory": True, + "default": "next" +} + +INITIAL_RANGE_CONFIG = { + "label": "Initial Range (in days)", + "key": "days", + "type": "number", + "mandatory": True, + "default": 7, + "description": "Number of days to pull the data for the initial run. Must be an integer in range 1 to 365." +} + +LOOK_BACK_CONFIG = { + "label": "Look Back (in minutes)", + "key": "delay", + "type": "number", + "mandatory": False, + "description": ( + "Number of minutes to backdate the start time for pulling the data. " + "Valid value is anything between 1 to 1440." 
+ ) +} + +TYPE_V1 = { + "label": "Type of Threat data to pull", + "key": "type", + "type": "multichoice", + "choices": [ + {"key": "SHA-256", "value": "sha256"}, + {"key": "MD5", "value": "md5"}, + {"key": "URL", "value": "url"}, + {"key": "Domain", "value": "domain"} + ], + "default": ["sha256", "md5", "url", "domain"], + "mandatory": False, + "description": "Type of Threat data to pull. Keep empty to fetch indicators of all types." +} + +TYPE_V2 = { + "label": "Type of Threat data to pull", + "key": "type", + "type": "multichoice", + "choices": [ + {"key": "SHA-256", "value": "sha256"}, + {"key": "MD5", "value": "md5"}, + {"key": "URL", "value": "url"}, + {"key": "IPv4", "value": "ipv4"}, + {"key": "IPv6", "value": "ipv6"}, + {"key": "Domain", "value": "domain"} + ], + "default": ["sha256", "md5", "url", "ipv4", "ipv6", "domain"], + "mandatory": False, + "description": ( + "Type of Threat data to pull. " + "IPv4/IPv6 is supported for STIX/TAXII version 2.x. " + "Keep empty to fetch indicators of all types." + ) +} + +SEVERITY_V1 = { + "label": "Severity", + "key": "severity", + "type": "multichoice", + "choices": [ + {"key": "Unknown", "value": "unknown"}, + {"key": "Low", "value": "low"}, + {"key": "Medium", "value": "medium"}, + {"key": "High", "value": "high"}, + {"key": "Critical", "value": "critical"} + ], + "default": ["critical", "high", "medium", "low", "unknown"], + "mandatory": False, + "description": ( + "Only indicators with matching severity will be fetched. " + "Keep empty to fetch indicators of all severity." + ) +} + +SEVERITY_V2 = { + "label": "Severity", + "key": "severity", + "type": "multichoice", + "choices": [ + {"key": "Unknown", "value": "unknown"}, + ], + "default": ["unknown"], + "mandatory": False, + "description": ( + "Only indicators with matching severity will be fetched. " + "For STIX/TAXII version 2.x, Unknown should be selected because " + "for all the indicators fetched from these versions would have " + "Unknown severity. 
" + "Keep empty to fetch indicators of all severity." + ) +} + +REPUTATION_CONFIG = { + "label": "Reputation", + "key": "reputation", + "type": "number", + "mandatory": True, + "default": 5, + "description": ( + "Only indicators with reputation equal to or greater than " + "this will be stored. Must be an integer in range 1 to 10." + ) +} + +BATCH_SIZE_CONFIG_V20 = { + "label": "Batch Size", + "key": "batch_size", + "type": "number", + "mandatory": True, + "default": 1000, + "description": "Number of indicators to fetch per bundle. Must be an integer in range 2 to 1000." +} + +BATCH_SIZE_CONFIG_V21 = { + "label": "Batch Size", + "key": "batch_size", + "type": "number", + "mandatory": True, + "default": 1000, + "description": "Number of indicators to fetch per bundle. Must be an integer in range 1 to 1000." +} + +RETRACTION_INTERVAL_CONFIG = { + "label": "Retraction Interval (in days)", + "key": "retraction_interval", + "type": "number", + "mandatory": False, + "description": ( + "Number of days to look back for retraction checks. " + "Leave empty to disable retraction. Must be an integer in range 1 to 365." + ) +} diff --git a/stix_taxii/utils/helper.py b/stix_taxii/utils/helper.py index b53b8eee..40fea9f7 100644 --- a/stix_taxii/utils/helper.py +++ b/stix_taxii/utils/helper.py @@ -34,6 +34,7 @@ from datetime import datetime from typing import Dict, Any, Union from netskope.common.utils import add_user_agent +import pytz from .constants import ( STIX_VERSION_1, USER_AGENT_FORMAT, @@ -53,16 +54,27 @@ class STIXTAXIIException(Exception): def get_configuration_parameters( - configuration: Dict[str, Any], is_validation: bool = False + configuration: Dict[str, Any], + is_validation: bool = False, + keys: list = [] ): """ Get configuration parameters. Args: configuration (Dict[str, Any]): Configuration dictionary. + is_validation (bool): Whether this is for validation. + keys (list, optional): List of specific keys to return. 
If provided, + returns tuple of only those values in the order specified. Returns: + tuple: Tuple of configuration parameters. + If keys is specified, returns only those values. + Available keys: + version, discovery_url, username, password, collection_names, + pagination_method, days, delay, type_to_pull, severity, + reputation, batch_size, retraction_interval """ version = configuration.get( "version", STIX_VERSION_1 if not is_validation else "" @@ -75,7 +87,7 @@ def get_configuration_parameters( "pagination_method", "next" if not is_validation else "" ).strip() days = configuration.get("days", 7 if not is_validation else None) - delay = configuration.get("delay", 0) or 0 + delay = configuration.get("delay", 0 if not is_validation else None) type_to_pull = configuration.get( "type", ( @@ -88,6 +100,32 @@ def get_configuration_parameters( reputation = configuration.get( "reputation", 5 if not is_validation else None ) + batch_size = configuration.get( + "batch_size", 1000 if not is_validation else None + ) + retraction_interval = configuration.get( + "retraction_interval", + 0 if not is_validation else None + ) + + all_params = { + "version": version, + "discovery_url": discovery_url, + "username": username, + "password": password, + "collection_names": collection_names, + "pagination_method": pagination_method, + "days": days, + "delay": delay, + "type_to_pull": type_to_pull, + "severity": severity, + "reputation": reputation, + "batch_size": batch_size, + "retraction_interval": retraction_interval, + } + + if keys: + return tuple(all_params[key] for key in keys) return ( version, @@ -101,13 +139,33 @@ def get_configuration_parameters( type_to_pull, severity, reputation, + batch_size, + retraction_interval, ) +def ensure_utc_aware(dt) -> datetime: + """Ensure datetime is UTC-aware (timezone-aware, converted to UTC). + + Args: + dt (datetime): A datetime object, either naive or aware. + + Returns: + datetime: A timezone-aware datetime object in UTC. 
+ """ + if dt is None: + return pytz.utc.localize(datetime.now()) + # If dt is timezone-naive, assume UTC. If it is already timezone-aware, + # normalize it to UTC for consistent comparisons and serialization. + if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None: + return pytz.utc.localize(dt) + # If dt is timezone aware, ensure it is converted to UTC. + return dt.astimezone(pytz.utc) def str_to_datetime( string: str, date_format: str = DATE_FORMAT_STRING, replace_dot: bool = True, + return_now_on_error: bool = True, ) -> datetime: """Convert ISO formatted string to datetime object. @@ -122,8 +180,7 @@ def str_to_datetime( string.replace(".", "") if replace_dot else string, date_format ) except ValueError: - return datetime.now() - + return datetime.now() if return_now_on_error else None def add_ce_user_agent( headers: Union[Dict, None] = None,