Enhance type hints and argument formatting across multiple files for improved clarity and consistency

This commit is contained in:
Sean Whalen
2025-12-02 15:17:37 -05:00
parent 5fae99aacc
commit d017dfcddf
8 changed files with 121 additions and 44 deletions
+5 -3
View File
@@ -623,7 +623,7 @@ def parse_aggregate_report_xml(
xml: str,
*,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional [bool] = False,
always_use_local_files: Optional[bool] = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
@@ -949,7 +949,9 @@ def parse_aggregate_report_file(
)
def parsed_aggregate_reports_to_csv_rows(reports: list[OrderedDict[str, Any]]) -> list[dict[str, Any]]:
def parsed_aggregate_reports_to_csv_rows(
reports: list[OrderedDict[str, Any]],
) -> list[dict[str, Any]]:
"""
Converts one or more parsed aggregate reports to list of dicts in flat CSV
format
@@ -1747,7 +1749,7 @@ def get_dmarc_reports_from_mailbox(
test: Optional[bool] = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
reverse_dns_map_path:Optional[str] = None,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
nameservers: Optional[list[str]] = None,
+10 -10
View File
@@ -1288,11 +1288,11 @@ def _main():
es_smtp_tls_index = "{0}{1}".format(prefix, es_smtp_tls_index)
elastic.set_hosts(
opts.elasticsearch_hosts,
opts.elasticsearch_ssl,
opts.elasticsearch_ssl_cert_path,
opts.elasticsearch_username,
opts.elasticsearch_password,
opts.elasticsearch_api_key,
use_ssl=opts.elasticsearch_ssl,
ssl_cert_path=opts.elasticsearch_ssl_cert_path,
username=opts.elasticsearch_username,
password=opts.elasticsearch_password,
api_key=opts.elasticsearch_api_key,
timeout=opts.elasticsearch_timeout,
)
elastic.migrate_indexes(
@@ -1320,11 +1320,11 @@ def _main():
os_smtp_tls_index = "{0}{1}".format(prefix, os_smtp_tls_index)
opensearch.set_hosts(
opts.opensearch_hosts,
opts.opensearch_ssl,
opts.opensearch_ssl_cert_path,
opts.opensearch_username,
opts.opensearch_password,
opts.opensearch_api_key,
use_ssl=opts.opensearch_ssl,
ssl_cert_path=opts.opensearch_ssl_cert_path,
username=opts.opensearch_username,
password=opts.opensearch_password,
api_key=opts.opensearch_api_key,
timeout=opts.opensearch_timeout,
)
opensearch.migrate_indexes(
+8 -3
View File
@@ -1,9 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any
import logging
import logging.handlers
import json
import threading
from collections import OrderedDict
from parsedmarc import (
parsed_aggregate_reports_to_csv_rows,
@@ -48,7 +53,7 @@ class GelfClient(object):
)
self.logger.addHandler(self.handler)
def save_aggregate_report_to_gelf(self, aggregate_reports):
def save_aggregate_report_to_gelf(self, aggregate_reports: list[OrderedDict[str, Any]]):
rows = parsed_aggregate_reports_to_csv_rows(aggregate_reports)
for row in rows:
log_context_data.parsedmarc = row
@@ -56,12 +61,12 @@ class GelfClient(object):
log_context_data.parsedmarc = None
def save_forensic_report_to_gelf(self, forensic_reports):
def save_forensic_report_to_gelf(self, forensic_reports: list[OrderedDict[str, Any]]):
rows = parsed_forensic_reports_to_csv_rows(forensic_reports)
for row in rows:
self.logger.info(json.dumps(row))
def save_smtp_tls_report_to_gelf(self, smtp_tls_reports):
def save_smtp_tls_report_to_gelf(self, smtp_tls_reports: list[OrderedDict[str, Any]]):
rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports)
for row in rows:
self.logger.info(json.dumps(row))
+25 -8
View File
@@ -1,5 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any, Optional
from ssl import SSLContext
import json
from ssl import create_default_context
@@ -18,7 +23,13 @@ class KafkaError(RuntimeError):
class KafkaClient(object):
def __init__(
self, kafka_hosts, ssl=False, username=None, password=None, ssl_context=None
self,
kafka_hosts: list[str],
*,
ssl: Optional[bool] = False,
username: Optional[str] = None,
password: Optional[str] = None,
ssl_context: Optional[SSLContext] = None,
):
"""
Initializes the Kafka client
@@ -28,7 +39,7 @@ class KafkaClient(object):
ssl (bool): Use a SSL/TLS connection
username (str): An optional username
password (str): An optional password
ssl_context: SSL context options
ssl_context (SSLContext): SSL context options
Notes:
``use_ssl=True`` is implied when a username or password are
@@ -55,7 +66,7 @@ class KafkaClient(object):
raise KafkaError("No Kafka brokers available")
@staticmethod
def strip_metadata(report):
def strip_metadata(report: OrderedDict[str, Any]):
"""
Duplicates org_name, org_email and report_id into JSON root
and removes report_metadata key to bring it more inline
@@ -69,7 +80,7 @@ class KafkaClient(object):
return report
@staticmethod
def generate_daterange(report):
def generate_date_range(report: OrderedDict[str, Any]):
"""
Creates a date_range timestamp with format YYYY-MM-DD-T-HH:MM:SS
based on begin and end dates for easier parsing in Kibana.
@@ -86,7 +97,9 @@ class KafkaClient(object):
logger.debug("date_range is {}".format(date_range))
return date_range
def save_aggregate_reports_to_kafka(self, aggregate_reports, aggregate_topic):
def save_aggregate_reports_to_kafka(
self, aggregate_reports: list[OrderedDict[str, Any]], aggregate_topic: str
):
"""
Saves aggregate DMARC reports to Kafka
@@ -105,7 +118,7 @@ class KafkaClient(object):
return
for report in aggregate_reports:
report["date_range"] = self.generate_daterange(report)
report["date_range"] = self.generate_date_range(report)
report = self.strip_metadata(report)
for slice in report["records"]:
@@ -129,7 +142,9 @@ class KafkaClient(object):
except Exception as e:
raise KafkaError("Kafka error: {0}".format(e.__str__()))
def save_forensic_reports_to_kafka(self, forensic_reports, forensic_topic):
def save_forensic_reports_to_kafka(
self, forensic_reports: list[OrderedDict[str, Any]], forensic_topic: str
):
"""
Saves forensic DMARC reports to Kafka, sends individual
records (slices) since Kafka requires messages to be <= 1MB
@@ -159,7 +174,9 @@ class KafkaClient(object):
except Exception as e:
raise KafkaError("Kafka error: {0}".format(e.__str__()))
def save_smtp_tls_reports_to_kafka(self, smtp_tls_reports, smtp_tls_topic):
def save_smtp_tls_reports_to_kafka(
self, smtp_tls_reports: list[OrderedDict[str, Any]], smtp_tls_topic: str
):
"""
Saves SMTP TLS reports to Kafka, sends individual
records (slices) since Kafka requires messages to be <= 1MB
+17 -2
View File
@@ -1,4 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any
from collections import OrderedDict
from parsedmarc.log import logger
from azure.core.exceptions import HttpResponseError
from azure.identity import ClientSecretCredential
@@ -102,7 +108,12 @@ class LogAnalyticsClient(object):
"Invalid configuration. " + "One or more required settings are missing."
)
def publish_json(self, results, logs_client: LogsIngestionClient, dcr_stream: str):
def publish_json(
self,
results: OrderedDict[str, OrderedDict[str, Any]],
logs_client: LogsIngestionClient,
dcr_stream: str,
):
"""
Background function to publish given
DMARC report to specific Data Collection Rule.
@@ -121,7 +132,11 @@ class LogAnalyticsClient(object):
raise LogAnalyticsException("Upload failed: {error}".format(error=e))
def publish_results(
self, results, save_aggregate: bool, save_forensic: bool, save_smtp_tls: bool
self,
results: OrderedDict[str, OrderedDict[str, Any]],
save_aggregate: bool,
save_forensic: bool,
save_smtp_tls: bool,
):
"""
Function to publish DMARC and/or SMTP TLS reports to Log Analytics
+16 -10
View File
@@ -1,8 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any
import json
import boto3
from collections import OrderedDict
from parsedmarc.log import logger
from parsedmarc.utils import human_timestamp_to_datetime
@@ -12,12 +18,12 @@ class S3Client(object):
def __init__(
self,
bucket_name,
bucket_path,
region_name,
endpoint_url,
access_key_id,
secret_access_key,
bucket_name: str,
bucket_path: str,
region_name: str,
endpoint_url: str,
access_key_id: str,
secret_access_key: str,
):
"""
Initializes the S3Client
@@ -49,16 +55,16 @@ class S3Client(object):
)
self.bucket = self.s3.Bucket(self.bucket_name)
def save_aggregate_report_to_s3(self, report):
def save_aggregate_report_to_s3(self, report: OrderedDict[str, Any]):
self.save_report_to_s3(report, "aggregate")
def save_forensic_report_to_s3(self, report):
def save_forensic_report_to_s3(self, report: OrderedDict[str, Any]):
self.save_report_to_s3(report, "forensic")
def save_smtp_tls_report_to_s3(self, report):
def save_smtp_tls_report_to_s3(self, report: OrderedDict[str, Any]):
self.save_report_to_s3(report, "smtp_tls")
def save_report_to_s3(self, report, report_type):
def save_report_to_s3(self, report: OrderedDict[str, Any], report_type: str):
if report_type == "smtp_tls":
report_date = report["begin_date"]
report_id = report["report_id"]
+22 -4
View File
@@ -1,3 +1,11 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any
from collections import OrderedDict
from urllib.parse import urlparse
import socket
import json
@@ -23,7 +31,13 @@ class HECClient(object):
# http://docs.splunk.com/Documentation/Splunk/latest/RESTREF/RESTinput#services.2Fcollector
def __init__(
self, url, access_token, index, source="parsedmarc", verify=True, timeout=60
self,
url: str,
access_token: str,
index: str,
source: str = "parsedmarc",
verify=True,
timeout=60,
):
"""
Initializes the HECClient
@@ -55,7 +69,9 @@ class HECClient(object):
"Authorization": "Splunk {0}".format(self.access_token),
}
def save_aggregate_reports_to_splunk(self, aggregate_reports):
def save_aggregate_reports_to_splunk(
self, aggregate_reports: list[OrderedDict[str, Any]]
):
"""
Saves aggregate DMARC reports to Splunk
@@ -118,7 +134,9 @@ class HECClient(object):
if response["code"] != 0:
raise SplunkError(response["text"])
def save_forensic_reports_to_splunk(self, forensic_reports):
def save_forensic_reports_to_splunk(
self, forensic_reports: list[OrderedDict[str, Any]]
):
"""
Saves forensic DMARC reports to Splunk
@@ -152,7 +170,7 @@ class HECClient(object):
if response["code"] != 0:
raise SplunkError(response["text"])
def save_smtp_tls_reports_to_splunk(self, reports):
def save_smtp_tls_reports_to_splunk(self, reports: list[OrderedDict[str, Any]]):
"""
Saves SMTP TLS reports to Splunk
+18 -4
View File
@@ -1,7 +1,15 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import logging
import logging.handlers
from typing import Any
from collections import OrderedDict
import json
from parsedmarc import (
@@ -14,7 +22,7 @@ from parsedmarc import (
class SyslogClient(object):
"""A client for Syslog"""
def __init__(self, server_name, server_port):
def __init__(self, server_name: str, server_port: int):
"""
Initializes the SyslogClient
Args:
@@ -28,17 +36,23 @@ class SyslogClient(object):
log_handler = logging.handlers.SysLogHandler(address=(server_name, server_port))
self.logger.addHandler(log_handler)
def save_aggregate_report_to_syslog(self, aggregate_reports):
def save_aggregate_report_to_syslog(
self, aggregate_reports: list[OrderedDict[str, Any]]
):
rows = parsed_aggregate_reports_to_csv_rows(aggregate_reports)
for row in rows:
self.logger.info(json.dumps(row))
def save_forensic_report_to_syslog(self, forensic_reports):
def save_forensic_report_to_syslog(
self, forensic_reports: list[OrderedDict[str, Any]]
):
rows = parsed_forensic_reports_to_csv_rows(forensic_reports)
for row in rows:
self.logger.info(json.dumps(row))
def save_smtp_tls_report_to_syslog(self, smtp_tls_reports):
def save_smtp_tls_report_to_syslog(
self, smtp_tls_reports: list[OrderedDict[str, Any]]
):
rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports)
for row in rows:
self.logger.info(json.dumps(row))