From d017dfcddf6079ca857e66a30216da2dd1f14a1a Mon Sep 17 00:00:00 2001 From: Sean Whalen Date: Tue, 2 Dec 2025 15:17:37 -0500 Subject: [PATCH] Enhance type hints and argument formatting across multiple files for improved clarity and consistency --- parsedmarc/__init__.py | 8 +++++--- parsedmarc/cli.py | 20 ++++++++++---------- parsedmarc/gelf.py | 11 ++++++++--- parsedmarc/kafkaclient.py | 33 +++++++++++++++++++++++++-------- parsedmarc/loganalytics.py | 19 +++++++++++++++++-- parsedmarc/s3.py | 26 ++++++++++++++++---------- parsedmarc/splunk.py | 26 ++++++++++++++++++++++---- parsedmarc/syslog.py | 22 ++++++++++++++++++---- 8 files changed, 121 insertions(+), 44 deletions(-) diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index f14d47f..0b45659 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -623,7 +623,7 @@ def parse_aggregate_report_xml( xml: str, *, ip_db_path: Optional[bool] = None, - always_use_local_files: Optional [bool] = False, + always_use_local_files: Optional[bool] = False, reverse_dns_map_path: Optional[bool] = None, reverse_dns_map_url: Optional[bool] = None, offline: Optional[bool] = False, @@ -949,7 +949,9 @@ def parse_aggregate_report_file( ) -def parsed_aggregate_reports_to_csv_rows(reports: list[OrderedDict[str, Any]]) -> list[dict[str, Any]]: +def parsed_aggregate_reports_to_csv_rows( + reports: list[OrderedDict[str, Any]], +) -> list[dict[str, Any]]: """ Converts one or more parsed aggregate reports to list of dicts in flat CSV format @@ -1747,7 +1749,7 @@ def get_dmarc_reports_from_mailbox( test: Optional[bool] = False, ip_db_path: Optional[str] = None, always_use_local_files: Optional[str] = False, - reverse_dns_map_path:Optional[str] = None, + reverse_dns_map_path: Optional[str] = None, reverse_dns_map_url: Optional[str] = None, offline: Optional[bool] = False, nameservers: Optional[list[str]] = None, diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index ee4a96e..bbed299 100644 --- a/parsedmarc/cli.py +++ 
b/parsedmarc/cli.py @@ -1288,11 +1288,11 @@ def _main(): es_smtp_tls_index = "{0}{1}".format(prefix, es_smtp_tls_index) elastic.set_hosts( opts.elasticsearch_hosts, - opts.elasticsearch_ssl, - opts.elasticsearch_ssl_cert_path, - opts.elasticsearch_username, - opts.elasticsearch_password, - opts.elasticsearch_api_key, + use_ssl=opts.elasticsearch_ssl, + ssl_cert_path=opts.elasticsearch_ssl_cert_path, + username=opts.elasticsearch_username, + password=opts.elasticsearch_password, + api_key=opts.elasticsearch_api_key, timeout=opts.elasticsearch_timeout, ) elastic.migrate_indexes( @@ -1320,11 +1320,11 @@ def _main(): os_smtp_tls_index = "{0}{1}".format(prefix, os_smtp_tls_index) opensearch.set_hosts( opts.opensearch_hosts, - opts.opensearch_ssl, - opts.opensearch_ssl_cert_path, - opts.opensearch_username, - opts.opensearch_password, - opts.opensearch_api_key, + use_ssl=opts.opensearch_ssl, + ssl_cert_path=opts.opensearch_ssl_cert_path, + username=opts.opensearch_username, + password=opts.opensearch_password, + api_key=opts.opensearch_api_key, timeout=opts.opensearch_timeout, ) opensearch.migrate_indexes( diff --git a/parsedmarc/gelf.py b/parsedmarc/gelf.py index 9e5c9da..f7c811b 100644 --- a/parsedmarc/gelf.py +++ b/parsedmarc/gelf.py @@ -1,9 +1,14 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import Any + import logging import logging.handlers import json import threading +from collections import OrderedDict from parsedmarc import ( parsed_aggregate_reports_to_csv_rows, @@ -48,7 +53,7 @@ class GelfClient(object): ) self.logger.addHandler(self.handler) - def save_aggregate_report_to_gelf(self, aggregate_reports): + def save_aggregate_report_to_gelf(self, aggregate_reports: OrderedDict[str, Any]): rows = parsed_aggregate_reports_to_csv_rows(aggregate_reports) for row in rows: log_context_data.parsedmarc = row @@ -56,12 +61,12 @@ class GelfClient(object): log_context_data.parsedmarc = None - def save_forensic_report_to_gelf(self, 
forensic_reports): + def save_forensic_report_to_gelf(self, forensic_reports: OrderedDict[str, Any]): rows = parsed_forensic_reports_to_csv_rows(forensic_reports) for row in rows: self.logger.info(json.dumps(row)) - def save_smtp_tls_report_to_gelf(self, smtp_tls_reports): + def save_smtp_tls_report_to_gelf(self, smtp_tls_reports: OrderedDict[str, Any]): rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports) for row in rows: self.logger.info(json.dumps(row)) diff --git a/parsedmarc/kafkaclient.py b/parsedmarc/kafkaclient.py index 35d1c2d..31191d5 100644 --- a/parsedmarc/kafkaclient.py +++ b/parsedmarc/kafkaclient.py @@ -1,5 +1,10 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import Any, Optional +from ssl import SSLContext +from collections import OrderedDict import json from ssl import create_default_context @@ -18,7 +23,13 @@ class KafkaError(RuntimeError): class KafkaClient(object): def __init__( - self, kafka_hosts, ssl=False, username=None, password=None, ssl_context=None + self, + kafka_hosts: list[str], + *, + ssl: Optional[bool] = False, + username: Optional[str] = None, + password: Optional[str] = None, + ssl_context: Optional[SSLContext] = None, ): """ Initializes the Kafka client @@ -28,7 +39,7 @@ class KafkaClient(object): ssl (bool): Use a SSL/TLS connection username (str): An optional username password (str): An optional password - ssl_context: SSL context options + ssl_context (SSLContext): SSL context options Notes: ``use_ssl=True`` is implied when a username or password are @@ -55,7 +66,7 @@ class KafkaClient(object): raise KafkaError("No Kafka brokers available") @staticmethod - def strip_metadata(report): + def strip_metadata(report: OrderedDict[str, Any]): """ Duplicates org_name, org_email and report_id into JSON root and removes report_metadata key to bring it more inline @@ -69,7 +80,7 @@ class KafkaClient(object): return report @staticmethod - def generate_daterange(report): + def generate_date_range(report: OrderedDict[str, Any]): """ 
Creates a date_range timestamp with format YYYY-MM-DD-T-HH:MM:SS based on begin and end dates for easier parsing in Kibana. @@ -86,7 +97,9 @@ class KafkaClient(object): logger.debug("date_range is {}".format(date_range)) return date_range - def save_aggregate_reports_to_kafka(self, aggregate_reports, aggregate_topic): + def save_aggregate_reports_to_kafka( + self, aggregate_reports: list[OrderedDict[str, Any]], aggregate_topic: str + ): """ Saves aggregate DMARC reports to Kafka @@ -105,7 +118,7 @@ class KafkaClient(object): return for report in aggregate_reports: - report["date_range"] = self.generate_daterange(report) + report["date_range"] = self.generate_date_range(report) report = self.strip_metadata(report) for slice in report["records"]: @@ -129,7 +142,9 @@ class KafkaClient(object): except Exception as e: raise KafkaError("Kafka error: {0}".format(e.__str__())) - def save_forensic_reports_to_kafka(self, forensic_reports, forensic_topic): + def save_forensic_reports_to_kafka( + self, forensic_reports: list[OrderedDict[str, Any]], forensic_topic: str + ): """ Saves forensic DMARC reports to Kafka, sends individual records (slices) since Kafka requires messages to be <= 1MB @@ -159,7 +174,9 @@ class KafkaClient(object): except Exception as e: raise KafkaError("Kafka error: {0}".format(e.__str__())) - def save_smtp_tls_reports_to_kafka(self, smtp_tls_reports, smtp_tls_topic): + def save_smtp_tls_reports_to_kafka( + self, smtp_tls_reports: list[OrderedDict[str, Any]], smtp_tls_topic: str + ): """ Saves SMTP TLS reports to Kafka, sends individual records (slices) since Kafka requires messages to be <= 1MB diff --git a/parsedmarc/loganalytics.py b/parsedmarc/loganalytics.py index 3192f4d..14bc3a9 100644 --- a/parsedmarc/loganalytics.py +++ b/parsedmarc/loganalytics.py @@ -1,4 +1,10 @@ # -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any +from collections import OrderedDict + from parsedmarc.log import logger from azure.core.exceptions 
import HttpResponseError from azure.identity import ClientSecretCredential @@ -102,7 +108,12 @@ class LogAnalyticsClient(object): "Invalid configuration. " + "One or more required settings are missing." ) - def publish_json(self, results, logs_client: LogsIngestionClient, dcr_stream: str): + def publish_json( + self, + results: OrderedDict[str, OrderedDict[str, Any]], + logs_client: LogsIngestionClient, + dcr_stream: str, + ): """ Background function to publish given DMARC report to specific Data Collection Rule. @@ -121,7 +132,11 @@ class LogAnalyticsClient(object): raise LogAnalyticsException("Upload failed: {error}".format(error=e)) def publish_results( - self, results, save_aggregate: bool, save_forensic: bool, save_smtp_tls: bool + self, + results: OrderedDict[str, OrderedDict[str, Any]], + save_aggregate: bool, + save_forensic: bool, + save_smtp_tls: bool, ): """ Function to publish DMARC and/or SMTP TLS reports to Log Analytics diff --git a/parsedmarc/s3.py b/parsedmarc/s3.py index a80331d..4164c67 100644 --- a/parsedmarc/s3.py +++ b/parsedmarc/s3.py @@ -1,8 +1,14 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import Any + import json import boto3 +from collections import OrderedDict + from parsedmarc.log import logger from parsedmarc.utils import human_timestamp_to_datetime @@ -12,12 +18,12 @@ class S3Client(object): def __init__( self, - bucket_name, - bucket_path, - region_name, - endpoint_url, - access_key_id, - secret_access_key, + bucket_name: str, + bucket_path: str, + region_name: str, + endpoint_url: str, + access_key_id: str, + secret_access_key: str, ): """ Initializes the S3Client @@ -49,16 +55,16 @@ class S3Client(object): ) self.bucket = self.s3.Bucket(self.bucket_name) - def save_aggregate_report_to_s3(self, report): + def save_aggregate_report_to_s3(self, report: OrderedDict[str, Any]): self.save_report_to_s3(report, "aggregate") - def save_forensic_report_to_s3(self, report): + def 
save_forensic_report_to_s3(self, report: OrderedDict[str, Any]): self.save_report_to_s3(report, "forensic") - def save_smtp_tls_report_to_s3(self, report): + def save_smtp_tls_report_to_s3(self, report: OrderedDict[str, Any]): self.save_report_to_s3(report, "smtp_tls") - def save_report_to_s3(self, report, report_type): + def save_report_to_s3(self, report: OrderedDict[str, Any], report_type: str): if report_type == "smtp_tls": report_date = report["begin_date"] report_id = report["report_id"] diff --git a/parsedmarc/splunk.py b/parsedmarc/splunk.py index 18307d5..925c502 100644 --- a/parsedmarc/splunk.py +++ b/parsedmarc/splunk.py @@ -1,3 +1,11 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any + +from collections import OrderedDict + from urllib.parse import urlparse import socket import json @@ -23,7 +31,13 @@ class HECClient(object): # http://docs.splunk.com/Documentation/Splunk/latest/RESTREF/RESTinput#services.2Fcollector def __init__( - self, url, access_token, index, source="parsedmarc", verify=True, timeout=60 + self, + url: str, + access_token: str, + index: str, + source: str = "parsedmarc", + verify=True, + timeout=60, ): """ Initializes the HECClient @@ -55,7 +69,9 @@ class HECClient(object): "Authorization": "Splunk {0}".format(self.access_token), } - def save_aggregate_reports_to_splunk(self, aggregate_reports): + def save_aggregate_reports_to_splunk( + self, aggregate_reports: list[OrderedDict[str, Any]] + ): """ Saves aggregate DMARC reports to Splunk @@ -118,7 +134,9 @@ class HECClient(object): if response["code"] != 0: raise SplunkError(response["text"]) - def save_forensic_reports_to_splunk(self, forensic_reports): + def save_forensic_reports_to_splunk( + self, forensic_reports: list[OrderedDict[str, Any]] + ): """ Saves forensic DMARC reports to Splunk @@ -152,7 +170,7 @@ class HECClient(object): if response["code"] != 0: raise SplunkError(response["text"]) - def save_smtp_tls_reports_to_splunk(self, 
reports): + def save_smtp_tls_reports_to_splunk(self, reports: list[OrderedDict[str, Any]]): """ Saves aggregate DMARC reports to Splunk diff --git a/parsedmarc/syslog.py b/parsedmarc/syslog.py index c656aa8..7502c0c 100644 --- a/parsedmarc/syslog.py +++ b/parsedmarc/syslog.py @@ -1,7 +1,15 @@ # -*- coding: utf-8 -*- + +from __future__ import annotations + import logging import logging.handlers + +from typing import Any + +from collections import OrderedDict + import json from parsedmarc import ( @@ -14,7 +22,7 @@ from parsedmarc import ( class SyslogClient(object): """A client for Syslog""" - def __init__(self, server_name, server_port): + def __init__(self, server_name: str, server_port: int): """ Initializes the SyslogClient Args: @@ -28,17 +36,23 @@ class SyslogClient(object): log_handler = logging.handlers.SysLogHandler(address=(server_name, server_port)) self.logger.addHandler(log_handler) - def save_aggregate_report_to_syslog(self, aggregate_reports): + def save_aggregate_report_to_syslog( + self, aggregate_reports: list[OrderedDict[str, Any]] + ): rows = parsed_aggregate_reports_to_csv_rows(aggregate_reports) for row in rows: self.logger.info(json.dumps(row)) - def save_forensic_report_to_syslog(self, forensic_reports): + def save_forensic_report_to_syslog( + self, forensic_reports: list[OrderedDict[str, Any]] + ): rows = parsed_forensic_reports_to_csv_rows(forensic_reports) for row in rows: self.logger.info(json.dumps(row)) - def save_smtp_tls_report_to_syslog(self, smtp_tls_reports): + def save_smtp_tls_report_to_syslog( + self, smtp_tls_reports: list[OrderedDict[str, Any]] + ): rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports) for row in rows: self.logger.info(json.dumps(row))