# Source code for parsedmarc.splunk

# -*- coding: utf-8 -*-

from __future__ import annotations

import json
import socket
from typing import Any, Union
from urllib.parse import urlparse

import requests
import urllib3

from parsedmarc.constants import USER_AGENT
from parsedmarc.log import logger
from parsedmarc.utils import human_timestamp_to_unix_timestamp

# Silence urllib3's InsecureRequestWarning (emitted for HTTPS requests made
# without certificate verification) so that HECClient(verify=False) does not
# emit a warning per request; the client logs its own debug message instead.
# NOTE(review): this runs at import time and affects every urllib3/requests
# user in the process, not only the Splunk client.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class SplunkError(RuntimeError):
    """Exception raised when a request to the Splunk API fails."""
class HECClient(object):
    """A client for a Splunk HTTP Events Collector (HEC)"""

    # http://docs.splunk.com/Documentation/Splunk/latest/Data/AboutHEC
    # http://docs.splunk.com/Documentation/Splunk/latest/RESTREF/RESTinput#services.2Fcollector

    def __init__(
        self,
        url: str,
        access_token: str,
        index: str,
        source: str = "parsedmarc",
        verify: bool = True,
        timeout: float = 60,
    ):
        """
        Initializes the HECClient

        Args:
            url (str): The URL of the HEC. Only its scheme and network
                location are used; the path is always replaced by the HEC
                event endpoint.
            access_token (str): The HEC access token, with or without a
                leading ``Splunk `` authorization scheme
            index (str): The name of the index
            source (str): The source name
            verify (bool): Verify SSL certificates
            timeout (float): Number of seconds to wait for the server to send
                data before giving up
        """
        parsed_url = urlparse(url)
        self.url = "{0}://{1}/services/collector/event/1.0".format(
            parsed_url.scheme, parsed_url.netloc
        )
        self.access_token = self._strip_auth_scheme(access_token)
        self.index = index
        # The local FQDN is sent as the ``host`` field of every event.
        self.host = socket.getfqdn()
        self.source = source
        self.session = requests.Session()
        self.timeout = timeout
        self.session.verify = verify
        # Envelope fields shared by every event this client sends.
        self._common_data: dict[str, Union[str, int, float, dict]] = dict(
            host=self.host, source=self.source, index=self.index
        )
        # Update (rather than replace) the session headers so requests'
        # default headers and its case-insensitive header mapping are kept.
        self.session.headers.update(
            {
                "User-Agent": USER_AGENT,
                "Authorization": "Splunk {0}".format(self.access_token),
            }
        )

    @staticmethod
    def _strip_auth_scheme(access_token: str) -> str:
        """
        Returns the bare HEC token, removing a leading ``Splunk `` scheme

        ``str.lstrip("Splunk ")`` must not be used here: it strips any
        leading characters from the set {S, p, l, u, n, k, space}, which
        corrupts tokens that happen to start with one of them.

        Args:
            access_token (str): The token, optionally prefixed with
                ``Splunk `` and/or surrounded by whitespace

        Returns:
            str: The token without the scheme prefix or surrounding whitespace
        """
        prefix = "Splunk "
        token = access_token.strip()
        if token.startswith(prefix):
            token = token[len(prefix):].strip()
        return token

    @staticmethod
    def _flatten_aggregate_record(
        report: dict[str, Any], record: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Builds one flat aggregate event from a report and one of its records

        The event starts with every key of ``report["report_metadata"]``,
        followed by the record's interval, source, count, disposition,
        alignment and identifier fields and the report's published policy.
        ``dkim_results`` / ``spf_results`` are included only when present in
        the record's ``auth_results``.

        Args:
            report: An aggregate report dictionary
            record: One entry of ``report["records"]``

        Returns:
            dict: The flattened event
        """
        event: dict[str, Any] = dict(report["report_metadata"])
        source = record["source"]
        alignment = record["alignment"]
        identifiers = record["identifiers"]
        event["interval_begin"] = record["interval_begin"]
        event["interval_end"] = record["interval_end"]
        event["normalized_timespan"] = record["normalized_timespan"]
        event["published_policy"] = report["policy_published"]
        event["source_ip_address"] = source["ip_address"]
        event["source_country"] = source["country"]
        event["source_reverse_dns"] = source["reverse_dns"]
        event["source_base_domain"] = source["base_domain"]
        event["source_type"] = source["type"]
        event["source_name"] = source["name"]
        event["message_count"] = record["count"]
        event["disposition"] = record["policy_evaluated"]["disposition"]
        event["spf_aligned"] = alignment["spf"]
        event["dkim_aligned"] = alignment["dkim"]
        event["passed_dmarc"] = alignment["dmarc"]
        event["header_from"] = identifiers["header_from"]
        event["envelope_from"] = identifiers["envelope_from"]
        auth_results = record["auth_results"]
        if "dkim" in auth_results:
            event["dkim_results"] = auth_results["dkim"]
        if "spf" in auth_results:
            event["spf_results"] = auth_results["spf"]
        return event

    def _make_event(
        self, sourcetype: str, timestamp: Any, event: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Wraps ``event`` in a HEC envelope carrying this client's host,
        source and index

        Args:
            sourcetype: The Splunk sourcetype of the event
            timestamp: A human-readable timestamp, converted with
                ``human_timestamp_to_unix_timestamp``
            event: The event payload

        Returns:
            dict: The envelope to send to the HEC
        """
        data = self._common_data.copy()
        data["sourcetype"] = sourcetype
        data["time"] = human_timestamp_to_unix_timestamp(timestamp)
        data["event"] = event
        return data

    def _post_events(self, events: list[dict[str, Any]]) -> None:
        """
        Sends event envelopes to the HEC in a single request

        The envelopes are serialized as newline-terminated JSON objects and
        concatenated into one request body.

        Args:
            events: The envelopes to send; no request is made when empty

        Raises:
            SplunkError: If the request fails, the response is not a JSON
                object, or Splunk reports a non-zero ``code``
        """
        if not events:
            # Nothing to index (e.g. aggregate reports without records);
            # don't post an empty body.
            return
        payload = "".join("{0}\n".format(json.dumps(event)) for event in events)
        if not self.session.verify:
            logger.debug("Skipping certificate verification for Splunk HEC")
        try:
            response = self.session.post(
                self.url, data=payload, timeout=self.timeout
            )
            result = response.json()
        except Exception as e:
            # Boundary: surface any transport or decoding failure as the
            # module's own error type, keeping the original as the cause.
            raise SplunkError(str(e)) from e
        if not isinstance(result, dict):
            raise SplunkError("Unexpected Splunk HEC response: {0}".format(result))
        if result.get("code") != 0:
            raise SplunkError(result.get("text", str(result)))

    def save_aggregate_reports_to_splunk(
        self,
        aggregate_reports: Union[list[dict[str, Any]], dict[str, Any]],
    ):
        """
        Saves aggregate DMARC reports to Splunk

        Each record of each report is sent as its own ``dmarc:aggregate``
        event, timestamped with the record's ``interval_begin``.

        Args:
            aggregate_reports: A single aggregate report dictionary, or a
                list of them, to save in Splunk

        Raises:
            SplunkError: If the request to Splunk fails or Splunk rejects
                the events
        """
        logger.debug("Saving aggregate reports to Splunk")
        if isinstance(aggregate_reports, dict):
            aggregate_reports = [aggregate_reports]

        if len(aggregate_reports) < 1:
            return

        events = []
        for report in aggregate_reports:
            for record in report["records"]:
                event = self._flatten_aggregate_record(report, record)
                events.append(
                    self._make_event(
                        "dmarc:aggregate", event["interval_begin"], event
                    )
                )
        self._post_events(events)

    def save_forensic_reports_to_splunk(
        self,
        forensic_reports: Union[list[dict[str, Any]], dict[str, Any]],
    ):
        """
        Saves forensic DMARC reports to Splunk

        Each report is sent as one ``dmarc:forensic`` event, timestamped with
        its ``arrival_date_utc``.

        Args:
            forensic_reports: A single forensic report dictionary, or a list
                of them, to save in Splunk

        Raises:
            SplunkError: If the request to Splunk fails or Splunk rejects
                the events
        """
        logger.debug("Saving forensic reports to Splunk")
        if isinstance(forensic_reports, dict):
            forensic_reports = [forensic_reports]

        if len(forensic_reports) < 1:
            return

        events = [
            self._make_event(
                "dmarc:forensic", report["arrival_date_utc"], report.copy()
            )
            for report in forensic_reports
        ]
        self._post_events(events)

    def save_smtp_tls_reports_to_splunk(
        self, reports: Union[list[dict[str, Any]], dict[str, Any]]
    ):
        """
        Saves SMTP TLS reports to Splunk

        Each report is sent as one ``smtp:tls`` event, timestamped with its
        ``begin_date``.

        Args:
            reports: A single SMTP TLS report dictionary, or a list of them,
                to save in Splunk

        Raises:
            SplunkError: If the request to Splunk fails or Splunk rejects
                the events
        """
        logger.debug("Saving SMTP TLS reports to Splunk")
        if isinstance(reports, dict):
            reports = [reports]

        if len(reports) < 1:
            return

        events = [
            self._make_event("smtp:tls", report["begin_date"], report.copy())
            for report in reports
        ]
        self._post_events(events)