Source code for api.correlation_engine

"""The correlation engine for Huntsman."""
import os
import yaml
from typing import List, Dict, Any, Optional, Set
from django.conf import settings
from .db.superdb_client import SuperDBClient


class CorrelationRule:
    """
    Represents a single correlation rule.

    Attributes
    ----------
    id : str
        The unique identifier of the rule.
    title : str
        The title of the rule.
    description : str
        A description of what the rule does.
    syntax : str
        The query syntax for the rule (whitespace-stripped).
    author : str
        The author of the rule (``'Unknown'`` if not given).
    source : str
        The source of the rule (``'Unknown'`` if not given).
    tags : List[str]
        A list of tags associated with the rule (empty if not given or null).
    file_path : str
        The file path where the rule is defined.
    pool_name : str
        The pool the rule queries, derived from the first stage of ``syntax``;
        ``"default"`` when no pool can be determined.
    """

    def __init__(self, rule_data: Dict[str, Any], file_path: str) -> None:
        """
        Initialize a CorrelationRule object.

        Parameters
        ----------
        rule_data : dict
            The dictionary containing the rule data.
        file_path : str
            The file path where the rule is defined.

        Raises
        ------
        ValueError
            If a required field is missing in the rule data.
        """
        required_fields = ['id', 'title', 'description', 'syntax']
        if not all(field in rule_data for field in required_fields):
            raise ValueError(f"Rule file {file_path} is missing one or more required fields: {required_fields}")
        self.id: str = rule_data['id']
        self.title: str = rule_data['title']
        self.description: str = rule_data['description']
        self.syntax: str = rule_data['syntax'].strip()
        self.author: str = rule_data.get('author', 'Unknown')
        self.source: str = rule_data.get('source', 'Unknown')
        # A YAML key with no value (``tags:``) loads as None; normalise to a list
        # so tag filtering can always iterate it.
        self.tags: List[str] = rule_data.get('tags') or []
        self.file_path: str = file_path
        self.pool_name: str = self._extract_pool_name()

    def _extract_pool_name(self) -> str:
        """
        Extract the pool name from the rule's syntax.

        The pool is read from the first pipeline stage (the text before the
        first ``" | "``). A leading ``from`` keyword is removed, along with any
        surrounding single or double quotes. A single-stage rule is only
        treated as naming a pool when it starts with ``from``. Otherwise the
        result is ``"default"``.
        """
        head, sep, _ = self.syntax.partition(" | ")
        head = head.strip()
        has_from = head.startswith("from ")
        if not sep and not has_from:
            return "default"
        if has_from:
            # Only strip the keyword as a prefix, never from inside the pool name.
            head = head[len("from "):]
        return head.strip().strip("'\"").strip()

    @staticmethod
    def _quote_literal(value: Any) -> str:
        """Render *value* as a single-quoted query string literal, escaping backslashes and quotes."""
        text = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{text}'"

    def _build_query(self, task_id: str) -> str:
        """
        Insert a ``task_id`` filter directly after the rule's first stage.

        For a multi-stage rule the filter goes between the first and second
        stages. For a single-stage rule it is appended at the end.
        """
        task_filter = f"task_id == {self._quote_literal(task_id)}"
        source, sep, rest = self.syntax.partition(" | ")
        if sep:
            return f"{source} | {task_filter} | {rest}"
        return f"{source} | {task_filter}"

    def execute_query(self, task_id: str, superdb_client: "SuperDBClient") -> Dict[str, Any]:
        """
        Execute the rule's query.

        Parameters
        ----------
        task_id : str
            The ID of the task to correlate. It is escaped before being
            embedded in the query as a string literal.
        superdb_client : SuperDBClient
            The SuperDB client to use for executing the query.

        Returns
        -------
        dict
            A dictionary containing the results of the query execution. It has
            ``status`` set to ``"success"``, or to ``"error"`` plus an ``error``
            message. Errors are reported in the result and never raised.
        """
        query = self._build_query(task_id)
        outcome: Dict[str, Any] = {
            "rule_id": self.id,
            "rule_title": self.title,
            "tags": self.tags,
            "description": self.description,
            "query": query,
        }
        try:
            result = superdb_client.execute_query(query=query, pool=self.pool_name)
            # The client signals failure by returning None rather than raising.
            if result is None:
                raise RuntimeError("Query execution failed - likely pool not found or query syntax error")
            matches = len(result) if result else 0
        except Exception as e:
            outcome.update(matches=0, results=[], status="error", error=str(e))
            return outcome
        outcome.update(matches=matches, results=result, status="success")
        return outcome
class CorrelationEngine:
    """
    Main correlation engine for loading and executing rules.

    The class is a singleton. ``__new__`` always returns the same instance,
    and ``__init__`` only performs its setup (and rule loading) the first time.

    Attributes
    ----------
    rules_directory_path : str
        The path to the directory containing the correlation rules.
    rules : List[CorrelationRule]
        A list of loaded correlation rules.
    superdb_client : SuperDBClient
        The SuperDB client used for executing queries.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Hand back the one shared instance; constructor args are only used by __init__.
        if cls._instance is None:
            cls._instance = super(CorrelationEngine, cls).__new__(cls)
        return cls._instance

    def __init__(self, rules_directory_path: Optional[str] = None) -> None:
        """
        Initialize a CorrelationEngine object.

        Parameters
        ----------
        rules_directory_path : str, optional
            The path to the directory containing the correlation rules.
            If not provided, it will be loaded from the Django settings
            (``CORRELATION_RULES_PATH``). Only honoured on the first successful
            construction; later constructions return the already-initialised
            instance unchanged.

        Raises
        ------
        FileNotFoundError
            If the rules directory is missing (see ``load_rules``).
        """
        if not hasattr(self, 'initialized'):
            self.rules_directory_path: Optional[str] = rules_directory_path or getattr(settings, 'CORRELATION_RULES_PATH', None)
            self.rules: List[CorrelationRule] = []
            self.superdb_client: SuperDBClient = SuperDBClient()
            self.load_rules()
            # Set last so a failed load is retried on the next construction.
            self.initialized = True

    def load_rules(self) -> None:
        """
        Load correlation rules from the specified directory.

        Every ``.yml``/``.yaml`` file under the directory is read. Directories
        and files are visited in sorted order, so rule order and the choice of
        which duplicate ID wins are deterministic. An invalid document is
        reported and skipped without discarding the other documents in the
        same file. ``self.rules`` is replaced only once loading completes.

        Raises
        ------
        FileNotFoundError
            If the rules directory is not found or is not a directory.
        """
        if not self.rules_directory_path or not os.path.isdir(self.rules_directory_path):
            raise FileNotFoundError(f"Rules directory not found or is not a directory: {self.rules_directory_path}")
        rules: List["CorrelationRule"] = []
        loaded_ids: Set[str] = set()
        print(f"--- Loading correlation rules from: {self.rules_directory_path} ---")
        for root, dirs, files in os.walk(self.rules_directory_path):
            # Sorting dirs in place makes os.walk descend in a stable order.
            dirs.sort()
            for file in sorted(files):
                if file.endswith((".yml", ".yaml")):
                    self._load_rule_file(os.path.join(root, file), rules, loaded_ids)
        self.rules = rules
        print(f"--- Successfully loaded {len(self.rules)} rules. ---")

    def _load_rule_file(self, file_path: str, rules: List["CorrelationRule"], loaded_ids: Set[str]) -> None:
        """Parse every YAML document in *file_path*, appending valid, non-duplicate rules to *rules*."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for rule_data in yaml.safe_load_all(f):
                    if not rule_data:
                        continue
                    if not isinstance(rule_data, dict):
                        print(f" [!] ERROR validating rule data in {file_path}: expected a mapping, got {type(rule_data).__name__}")
                        continue
                    rule_id = rule_data.get('id')
                    if rule_id and rule_id in loaded_ids:
                        print(f" [-] Skipping duplicate rule ID {rule_id} in {file_path}")
                        continue
                    try:
                        rule = CorrelationRule(rule_data=rule_data, file_path=file_path)
                    except (ValueError, TypeError, AttributeError) as e:
                        # Bad fields in one document must not drop the rest of the file.
                        print(f" [!] ERROR validating rule data in {file_path}: {e}")
                        continue
                    rules.append(rule)
                    loaded_ids.add(rule.id)
                    print(f" [+] Loaded rule: {rule.title}")
        except yaml.YAMLError as e:
            print(f" [!] ERROR parsing YAML file {file_path}: {e}")
        except Exception as e:
            print(f" [!] ERROR loading rule file {file_path}: {e}")

    def get_rules(self, rule_titles: Optional[List[str]] = None, tags_filter: Optional[List[str]] = None) -> List["CorrelationRule"]:
        """
        Get a filtered list of correlation rules.

        Parameters
        ----------
        rule_titles : list of str, optional
            A list of rule titles to filter by (exact match).
        tags_filter : list of str, optional
            A list of tags to filter by; a rule matches if it has any of them.

        Returns
        -------
        list of CorrelationRule
            A new list of the correlation rules that match the filter
            criteria. It is always a copy, so callers cannot mutate
            ``self.rules``.
        """
        selected = list(self.rules)
        if rule_titles:
            wanted_titles = set(rule_titles)
            selected = [rule for rule in selected if rule.title in wanted_titles]
        if tags_filter:
            wanted_tags = set(tags_filter)
            selected = [rule for rule in selected if wanted_tags.intersection(rule.tags or ())]
        return selected

    def run_correlation_analysis(self, task_id: str, service_name: Optional[str] = None, rule_titles: Optional[List[str]] = None, tags_filter: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Run correlation analysis for a specific task ID.

        Parameters
        ----------
        task_id : str
            The ID of the task to analyze.
        service_name : str, optional
            The name of the service to run the rules against. Only rules whose
            ``pool_name`` equals it (or is ``"*"``) are run.
        rule_titles : list of str, optional
            A list of rule titles to run. If not provided, all rules will be run.
        tags_filter : list of str, optional
            A list of tags to filter the rules to be run.

        Returns
        -------
        dict
            A dictionary containing the results of the correlation analysis.
            Per-rule results are sorted by match count, highest first. The
            dictionary always includes a ``summary`` entry.
        """
        rules_to_run = self.get_rules(rule_titles, tags_filter)
        if service_name:
            rules_to_run = [rule for rule in rules_to_run if rule.pool_name in (service_name, "*")]

        if not rules_to_run:
            return {
                "task_id": task_id,
                "total_rules": len(self.rules),
                "rules_executed": 0,
                "matches_found": 0,
                "results": [],
                "status": "no_rules_matched_filters",
                "summary": self._generate_summary([]),
            }

        results = [self._execute_rule(rule, task_id) for rule in rules_to_run]
        total_matches = sum(result.get('matches', 0) for result in results)
        results.sort(key=lambda result: result['matches'], reverse=True)
        return {
            "task_id": task_id,
            "total_rules": len(self.rules),
            "rules_executed": len(rules_to_run),
            "matches_found": total_matches,
            "results": results,
            "status": "completed",
            "summary": self._generate_summary(results),
        }

    def _execute_rule(self, rule: "CorrelationRule", task_id: str) -> Dict[str, Any]:
        """Run one rule, turning any unexpected exception into an error-status result."""
        try:
            return rule.execute_query(task_id, self.superdb_client)
        except Exception as e:
            # One misbehaving rule must not abort the whole analysis.
            return {
                "rule_id": rule.id,
                "rule_title": rule.title,
                "tags": rule.tags,
                "description": rule.description,
                "query": rule.syntax,
                "matches": 0,
                "results": [],
                "status": "error",
                "error": str(e),
            }

    def _generate_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Generate a summary of the correlation analysis results.

        Parameters
        ----------
        results : list of dict
            A list of dictionaries, where each dictionary is the result of a single rule execution.

        Returns
        -------
        dict
            Total match count, number of rules with at least one match, and
            number of rules whose status is ``"error"``.
        """
        summary = {
            "total_matches": 0,
            "rules_with_matches": 0,
            "failed_rules": 0,
        }
        for result in results:
            matches = result.get('matches', 0)
            summary["total_matches"] += matches
            if matches > 0:
                summary["rules_with_matches"] += 1
            if result.get('status', 'unknown') == "error":
                summary["failed_rules"] += 1
        return summary
def create_correlation_engine_instance() -> CorrelationEngine:
    """
    Return the shared CorrelationEngine instance.

    ``CorrelationEngine`` is a singleton. Despite this function's name, it
    does not build a fresh engine on every call. The first call constructs the
    engine and loads its rules. Later calls return that same object without
    reloading.

    Returns
    -------
    CorrelationEngine
        The single shared engine instance.

    Raises
    ------
    FileNotFoundError
        On first construction, if the configured rules directory does not
        exist (propagated from ``CorrelationEngine.load_rules``).
    """
    return CorrelationEngine()