Source code for api.views

"""Views for the Huntsman API."""
from django.db import transaction
from django.db import connection
from django.db.utils import OperationalError
from django.views.generic import TemplateView
from django.http import HttpRequest, JsonResponse
from django.contrib.auth.decorators import login_required
from django.utils.decorators import method_decorator
from rest_framework import generics, status
from rest_framework.views import APIView
from rest_framework.response import Response
from huntsman.celery import app as celery_app
from typing import Dict, Any
import functools
import json
import logging
from django.conf import settings

from .config import (
    load_api_recipes, load_internal_services_recipes, 
    load_scraping_recipes, load_ioc_patterns, 
    load_predefined_queries
)
from .models import AnalysisTask
from .serializers import (
    AnalysisTaskSerializer, BulkTaskSubmissionSerializer,
    TaskSubmissionSerializer, SuperDBQuerySerializer,
    BulkTaskStatusRequestSerializer, CorrelationEngineSerializer,
    CreatePoolSerializer, LoadDataToBranchSerializer,
    STIXReportSerializer, BulkSTIXReportSerializer,
    AIAnalysisSerializer
)
from .tasks import (
    run_analysis_task, run_superdb_query_task, run_create_pool_task, 
    run_load_data_to_branch_task, run_stix_report_creation_task,
    run_bulk_stix_report_creation_task, run_ai_analysis_task
)
from .correlation_engine import create_correlation_engine_instance

logger = logging.getLogger(__name__)


@method_decorator(login_required, name='dispatch')
class DashboardView(TemplateView):
    """A view to render the main dashboard."""

    template_name = "api/dashboard.html"

    def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Get the context data for rendering the dashboard.

        This method loads IOC patterns and adds them to the context.

        Parameters
        ----------
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        dict
            The context data for the template.
        """
        context = super().get_context_data(**kwargs)
        context['ioc_patterns'] = json.dumps(load_ioc_patterns())
        context['max_ioc_submission_limit'] = settings.MAX_IOC_SUBMISSION_LIMIT
        return context


class AnalysisTriggerView(APIView):
    """An API view to trigger a single analysis task."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to trigger an analysis task.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object.
        """
        serializer = TaskSubmissionSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        data = serializer.validated_data
        try:
            with transaction.atomic():
                task_record = AnalysisTask.objects.create(
                    service_name=data['service_name'],
                    identifier=data['identifier'],
                    identifier_type=data['identifier_type'],
                    status=AnalysisTask.Status.PENDING
                )
                # Queue the Celery task only after the transaction commits, so a
                # worker can never receive an ID for a row that is not yet visible.
                transaction.on_commit(
                    functools.partial(run_analysis_task.delay, task_db_id=str(task_record.id))
                )
        except Exception as e:
            return Response(
                {"error": "Failed to create task in database.", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

        return Response(
            {"message": "Analysis task has been queued.", "task_id": task_record.id},
            status=status.HTTP_202_ACCEPTED
        )
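
# Usage sketch for AnalysisTriggerView (the URL path is an assumption; actual
# routes are defined in urls.py, which is not part of this module):
#
#   POST /api/analyze/
#   {"service_name": "virustotal", "identifier": "1.2.3.4", "identifier_type": "ip"}
#
# A 202 response carries {"message": ..., "task_id": <id>}; the ID can then be
# polled through TaskStatusView below.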


class BulkAnalysisTriggerView(APIView):
    """An API view to trigger multiple analysis tasks in bulk."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to trigger bulk analysis tasks.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object.
        """
        serializer = BulkTaskSubmissionSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        tasks_data = serializer.validated_data['tasks']
        task_ids = []
        try:
            with transaction.atomic():
                for task_data in tasks_data:
                    task_record = AnalysisTask.objects.create(
                        service_name=task_data['service_name'],
                        identifier=task_data['identifier'],
                        identifier_type=task_data['identifier_type'],
                        status=AnalysisTask.Status.PENDING
                    )
                    transaction.on_commit(
                        functools.partial(run_analysis_task.delay, task_db_id=str(task_record.id))
                    )
                    task_ids.append(str(task_record.id))
        except Exception as e:
            logger.error(f"Bulk task creation failed: {str(e)}")
            return Response(
                {"error": "Failed to create bulk tasks in database.", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

        return Response(
            {"message": f"{len(task_ids)} analysis tasks have been queued.", "task_ids": task_ids},
            status=status.HTTP_202_ACCEPTED
        )
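
# Note: the loop above runs inside a single transaction.atomic() block, so a
# failure on any row rolls back the entire batch; Django then discards the
# registered on_commit callbacks, and no Celery task fires for a rolled-back
# submission.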


class TaskStatusView(generics.RetrieveAPIView):
    """An API view to retrieve the status of a single analysis task."""

    queryset = AnalysisTask.objects.all()
    serializer_class = AnalysisTaskSerializer
    lookup_field = 'id'


class TaskListView(generics.ListAPIView):
    """An API view to list all analysis tasks, with optional filtering by status."""

    queryset = AnalysisTask.objects.all().order_by('-created_at')
    serializer_class = AnalysisTaskSerializer

    def get_queryset(self) -> Any:
        """
        Get the queryset for listing tasks.

        This method filters the queryset based on the 'status' query parameter.

        Returns
        -------
        QuerySet
            The filtered queryset of AnalysisTask objects.
        """
        queryset = super().get_queryset()
        status_filter = self.request.query_params.get('status')
        if status_filter:
            queryset = queryset.filter(status__iexact=status_filter)
        return queryset
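
# Usage sketch (paths are assumptions; actual routes live in urls.py):
#
#   GET /api/tasks/                 -> all tasks, newest first
#   GET /api/tasks/?status=pending  -> case-insensitive filter on status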


class SuperDBQueryView(APIView):
    """An API view to execute a query on SuperDB."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to execute a SuperDB query.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object.
        """
        serializer = SuperDBQuerySerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        data = serializer.validated_data
        if data.get('query'):
            query_string = data['query']
        else:
            from_pool = data['from_pool']
            commands = data.get('commands', [])
            query_parts = [f"from {from_pool}"]
            query_parts.extend(commands)
            query_string = " | ".join(query_parts)

        try:
            with transaction.atomic():
                task_record = AnalysisTask.objects.create(
                    service_name='superdb_query',
                    identifier=query_string,
                    identifier_type='etl_query',
                    status=AnalysisTask.Status.PENDING
                )
                transaction.on_commit(
                    functools.partial(run_superdb_query_task.delay, task_db_id=str(task_record.id))
                )
        except Exception as e:
            return Response(
                {"error": "Failed to create query task in database.", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

        return Response(
            {
                "message": "Database query task has been queued.",
                "task_id": task_record.id,
                "query": query_string,
            },
            status=status.HTTP_202_ACCEPTED
        )
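
# Composition sketch: with a hypothetical payload
# {"from_pool": "threat_intel", "commands": ["head 10"]}, the branch above
# builds the pipeline string "from threat_intel | head 10". When a raw "query"
# field is supplied, it is used verbatim and the composition is skipped.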


class PredefinedQueriesListView(APIView):
    """An API view to list all predefined SuperDB queries."""

    def get(self, request: HttpRequest, format: Any = None) -> Response:
        """
        Handle GET requests to list predefined queries.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        format : Any, optional
            The format of the response.

        Returns
        -------
        Response
            A DRF Response object containing the predefined queries.
        """
        queries = load_predefined_queries()
        return Response(queries, status=status.HTTP_200_OK)


class CreatePoolView(APIView):
    """An API view to create a new pool in SuperDB."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to create a SuperDB pool.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object.
        """
        serializer = CreatePoolSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        data = serializer.validated_data
        pool_name = data['name']
        layout_order = data.get('layout_order', 'asc')
        layout_keys = data.get('layout_keys', [['ts']])
        thresh = data.get('thresh')

        try:
            with transaction.atomic():
                task_record = AnalysisTask.objects.create(
                    service_name='superdb_create_pool',
                    identifier=pool_name,
                    identifier_type='pool_creation',
                    status=AnalysisTask.Status.PENDING
                )
                transaction.on_commit(
                    functools.partial(
                        run_create_pool_task.delay,
                        task_db_id=str(task_record.id),
                        name=pool_name,
                        layout_order=layout_order,
                        layout_keys=layout_keys,
                        thresh=thresh
                    )
                )
        except Exception as e:
            return Response(
                {"error": "Failed to create pool task in database.", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

        return Response(
            {
                "message": f"SuperDB pool creation task for '{pool_name}' has been queued.",
                "task_id": task_record.id,
                "pool_name": pool_name
            },
            status=status.HTTP_202_ACCEPTED
        )
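
# Example payload (a sketch; the pool name is hypothetical and field semantics
# follow CreatePoolSerializer):
#
#   {"name": "threat_intel", "layout_order": "desc", "layout_keys": [["ts"]]}
#
# layout_keys defaults to [['ts']], i.e. pools are keyed on a "ts" field unless
# the caller says otherwise.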


class LoadDataToBranchView(APIView):
    """An API view to load data into a branch of a SuperDB pool."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to load data into a SuperDB branch.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object.
        """
        serializer = LoadDataToBranchSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        data = serializer.validated_data
        pool_id_or_name = data['pool_id_or_name']
        branch_name = data['branch_name']
        load_data = data['data']
        csv_delim = data.get('csv_delim', ',')

        try:
            with transaction.atomic():
                task_record = AnalysisTask.objects.create(
                    service_name='superdb_load_data',
                    identifier=f"Load data to {branch_name} in {pool_id_or_name}",
                    identifier_type='data_loading',
                    status=AnalysisTask.Status.PENDING
                )
                transaction.on_commit(
                    functools.partial(
                        run_load_data_to_branch_task.delay,
                        task_db_id=str(task_record.id),
                        pool_id_or_name=pool_id_or_name,
                        branch_name=branch_name,
                        data=load_data,
                        csv_delim=csv_delim
                    )
                )
        except Exception as e:
            return Response(
                {"error": "Failed to create load data task in database.", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

        return Response(
            {
                "message": f"Data loading task for branch '{branch_name}' in pool '{pool_id_or_name}' has been queued.",
                "task_id": task_record.id,
                "pool_id_or_name": pool_id_or_name,
                "branch_name": branch_name
            },
            status=status.HTTP_202_ACCEPTED
        )


class ServiceListView(APIView):
    """An API view to list all available services and their supported types."""

    def get(self, request: HttpRequest, format: Any = None) -> Response:
        """
        Handle GET requests to list available services.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        format : Any, optional
            The format of the response.

        Returns
        -------
        Response
            A DRF Response object containing the list of services.
        """
        api_recipes = load_api_recipes()
        internal_recipes = load_internal_services_recipes()
        scraping_recipes = load_scraping_recipes()
        combined_recipes = {**api_recipes, **internal_recipes, **scraping_recipes}

        service_list = []
        for name, recipe in combined_recipes.items():
            if not recipe.get('enabled', True):
                continue
            if "endpoints" not in recipe or not isinstance(recipe["endpoints"], dict):
                continue
            supported_types = list(recipe["endpoints"].keys())
            service_list.append({
                "name": name,
                "label": recipe.get("label", name.replace("_", " ").title()),
                "supported_types": supported_types
            })
        return Response(service_list)
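
# Note on the merge above: dict unpacking gives later mappings precedence, so a
# scraping recipe overrides an internal or API recipe with the same name. The
# expression is equivalent to:
#
#   combined_recipes = dict(api_recipes)
#   combined_recipes.update(internal_recipes)
#   combined_recipes.update(scraping_recipes)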


class HealthCheckView(APIView):
    """An API view to perform health checks on the system components."""

    def get(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle GET requests to perform health checks.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object with the health check results.
        """
        checks = {
            "database": self.check_database(),
            "celery_redis": self.check_celery(),
        }
        if any(check["status"] == "error" for check in checks.values()):
            return Response(checks, status=status.HTTP_503_SERVICE_UNAVAILABLE)
        return Response(checks, status=status.HTTP_200_OK)

    def check_database(self) -> Dict[str, str]:
        """
        Check the status of the database connection.

        Returns
        -------
        dict
            A dictionary with the status and a message.
        """
        try:
            connection.ensure_connection()
            return {"status": "ok", "message": "Database connection successful."}
        except OperationalError as e:
            return {"status": "error", "message": f"Database connection failed: {str(e)}"}

    def check_celery(self) -> Dict[str, str]:
        """
        Check the status of the Celery workers.

        Returns
        -------
        dict
            A dictionary with the status and a message.
        """
        try:
            ping = celery_app.control.ping(timeout=3)
            if ping:
                worker_count = len(ping)
                return {"status": "ok", "message": f"{worker_count} Celery worker(s) responded."}
            else:
                return {"status": "error", "message": "No Celery workers responded to ping."}
        except Exception as e:
            return {"status": "error", "message": f"Celery check failed: {str(e)}"}
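
# Shape of a healthy response (a sketch of the dictionaries built above):
#
#   {
#       "database": {"status": "ok", "message": "Database connection successful."},
#       "celery_redis": {"status": "ok", "message": "2 Celery worker(s) responded."}
#   }
#
# Any component reporting "error" switches the HTTP status from 200 to 503.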


class BulkTaskStatusView(APIView):
    """An API view to retrieve the status of multiple tasks in bulk."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to get the status of bulk tasks.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object with the status of the requested tasks.
        """
        logger.info(f"Received bulk task status request with data: {request.data}")
        serializer = BulkTaskStatusRequestSerializer(data=request.data)
        if not serializer.is_valid():
            logger.error(f"Bulk task status request validation failed: {serializer.errors}")
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        task_ids = serializer.validated_data['task_ids']
        tasks = AnalysisTask.objects.filter(id__in=task_ids)
        logger.info(f"Found {tasks.count()} tasks for IDs: {task_ids}")
        results_serializer = AnalysisTaskSerializer(tasks, many=True)
        return Response(results_serializer.data, status=status.HTTP_200_OK)


class CorrelationEngineView(APIView):
    """An API view to interact with the Correlation Engine."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to run the correlation engine.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object with the correlation analysis results.
        """
        serializer = CorrelationEngineSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        data = serializer.validated_data
        task_id = data['task_id']
        try:
            task = AnalysisTask.objects.get(id=task_id)
        except AnalysisTask.DoesNotExist:
            return Response(
                {"error": f"Task with ID {task_id} not found"},
                status=status.HTTP_404_NOT_FOUND
            )

        if task.status not in [AnalysisTask.Status.SUCCESS, AnalysisTask.Status.FAILURE]:
            return Response(
                {
                    "warning": f"Task {task_id} is in '{task.status}' status. Results may be incomplete.",
                    "task_status": task.status
                },
                status=status.HTTP_202_ACCEPTED
            )

        try:
            engine = create_correlation_engine_instance()
            analysis_result = engine.run_correlation_analysis(
                task_id=str(task_id),
                service_name=task.service_name,
                rule_titles=data.get('rules'),
                tags_filter=data.get('tags_filter')
            )
            analysis_result['task_metadata'] = {
                'task_id': str(task.id),
                'service_name': task.service_name,
                'identifier': task.identifier,
                'identifier_type': task.identifier_type,
                'task_status': task.status,
                'created_at': task.created_at,
                'completed_at': task.completed_at
            }
            return Response(analysis_result, status=status.HTTP_200_OK)
        except FileNotFoundError as e:
            return Response(
                {"error": "Rules configuration file not found", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )
        except ValueError as e:
            return Response(
                {"error": "Invalid rules configuration", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )
        except Exception as e:
            return Response(
                {"error": "Correlation analysis failed", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

    def get(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle GET requests to list correlation rules and tags.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object with the list of rules and tags.
        """
        try:
            engine = create_correlation_engine_instance()
            rules_info = []
            for rule in engine.rules:
                rules_info.append({
                    "title": rule.title,
                    "description": rule.description,
                    "tags": rule.tags,
                    "syntax_template": rule.syntax
                })
            return Response({
                "total_rules": len(rules_info),
                "rules": rules_info,
                "available_tags": list(set(tag for rule in engine.rules for tag in rule.tags))
            }, status=status.HTTP_200_OK)
        except Exception as e:
            return Response(
                {"error": "Failed to load rules information", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )


class STIXReportView(APIView):
    """An API view to create a STIX report."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to create a STIX report.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object.
        """
        serializer = STIXReportSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        data = serializer.validated_data
        source_task_id = data['task_id']
        try:
            source_task = AnalysisTask.objects.get(id=source_task_id)
        except AnalysisTask.DoesNotExist:
            return Response(
                {"error": f"Source task {source_task_id} not found"},
                status=status.HTTP_404_NOT_FOUND
            )

        try:
            with transaction.atomic():
                task_record = AnalysisTask.objects.create(
                    service_name='stix_report_creation',
                    identifier=f"STIX Report for {source_task.identifier}",
                    identifier_type='report_generation',
                    status=AnalysisTask.Status.PENDING
                )
                transaction.on_commit(
                    functools.partial(
                        run_stix_report_creation_task.delay,
                        task_db_id=str(task_record.id),
                        source_task_id=str(source_task_id)
                    )
                )
        except Exception as e:
            return Response(
                {"error": "Failed to create STIX report task.", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

        return Response(
            {"message": "STIX report creation task has been queued.", "task_id": task_record.id},
            status=status.HTTP_202_ACCEPTED
        )


class BulkSTIXReportView(APIView):
    """An API view to create a bulk STIX report from multiple tasks."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to create a bulk STIX report.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object.
        """
        serializer = BulkSTIXReportSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        report_mappings = serializer.validated_data['reports']
        try:
            with transaction.atomic():
                task_record = AnalysisTask.objects.create(
                    service_name='stix_bulk_report_creation',
                    identifier=f"Bulk report job with {len(report_mappings)} reports",
                    identifier_type='bulk_report',
                    status=AnalysisTask.Status.PENDING
                )
                transaction.on_commit(
                    functools.partial(
                        run_bulk_stix_report_creation_task.delay,
                        task_db_id=str(task_record.id),
                        report_mappings=report_mappings
                    )
                )
        except Exception as e:
            return Response(
                {"error": "Failed to create bulk STIX report task.", "details": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

        return Response(
            {"message": "Bulk STIX report creation task has been queued.", "task_id": task_record.id},
            status=status.HTTP_202_ACCEPTED
        )


class AIAnalysisView(APIView):
    """An API view to perform AI analysis on a given dataset."""

    def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
        """
        Handle POST requests to perform AI analysis.

        Parameters
        ----------
        request : HttpRequest
            The HTTP request object.
        *args : Any
            Variable length argument list.
        **kwargs : Any
            Arbitrary keyword arguments.

        Returns
        -------
        Response
            A DRF Response object.
        """
        serializer = AIAnalysisSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        data = serializer.validated_data['data']
        prompt = serializer.validated_data['prompt']
        system_prompt = serializer.validated_data.get(
            'system_prompt',
            'You are a world class CTI Analyst. Give brief summary and provide actionable insights.'
        )

        try:
            with transaction.atomic():
                task_record = AnalysisTask.objects.create(
                    service_name='ai_analysis',
                    identifier=f"AI Analysis - {len(data)} records",
                    identifier_type='ai_insight',
                    status=AnalysisTask.Status.PENDING
                )
                transaction.on_commit(
                    functools.partial(
                        run_ai_analysis_task.delay,
                        task_db_id=str(task_record.id),
                        data=data,
                        prompt=prompt,
                        system_prompt=system_prompt
                    )
                )
            return Response(
                {"message": "AI Analysis task queued.", "task_id": task_record.id},
                status=status.HTTP_202_ACCEPTED
            )
        except Exception as e:
            return Response(
                {"error": f"Failed to queue AI task: {str(e)}"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )
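

# End-to-end usage sketch with DRF's test client (hedged: the URL paths below
# are assumptions; the real routes are declared in urls.py, and authentication
# requirements depend on the project's REST_FRAMEWORK settings):
#
#   from rest_framework.test import APIClient
#
#   client = APIClient()
#   resp = client.post(
#       "/api/analyze/",
#       {"service_name": "virustotal", "identifier": "example.com",
#        "identifier_type": "domain"},
#       format="json",
#   )
#   assert resp.status_code == 202
#   task_id = resp.json()["task_id"]
#
#   # Poll until a Celery worker moves the task out of PENDING:
#   status_resp = client.get(f"/api/tasks/{task_id}/")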