Enhance type hints for improved clarity and consistency in __init__.py, elastic.py, and opensearch.py

This commit is contained in:
Sean Whalen
2025-12-02 14:14:06 -05:00
parent ba57368ac3
commit 5fae99aacc
3 changed files with 167 additions and 156 deletions
+157 -146
View File
@@ -4,7 +4,7 @@
from __future__ import annotations
from typing import Dict, List, Any, Union, IO, Callable
from typing import Dict, List, Any, Union, Optional, IO, Callable
import binascii
import email
@@ -220,8 +220,8 @@ def _bucket_interval_by_day(
def _append_parsed_record(
parsed_record: Dict[str, Any],
records: List[Dict[str, Any]],
parsed_record: OrderedDict[str, Any],
records: OrderedDict[str, Any],
begin_dt: datetime,
end_dt: datetime,
normalize: bool,
@@ -264,15 +264,16 @@ def _append_parsed_record(
def _parse_report_record(
record: dict,
ip_db_path: str = None,
record: OrderedDict,
*,
ip_db_path: Optional[str] = None,
always_use_local_files: bool = False,
reverse_dns_map_path: str = None,
reverse_dns_map_url: str = None,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: bool = False,
nameservers: list[str] = None,
dns_timeout: float = 2.0,
):
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
) -> OrderedDict[str, Any]:
"""
Converts a record from a DMARC aggregate report into a more consistent
format
@@ -426,7 +427,7 @@ def _parse_report_record(
return new_record
def _parse_smtp_tls_failure_details(failure_details: dict):
def _parse_smtp_tls_failure_details(failure_details: dict[str, Any]):
try:
new_failure_details = OrderedDict(
result_type=failure_details["result-type"],
@@ -462,7 +463,7 @@ def _parse_smtp_tls_failure_details(failure_details: dict):
raise InvalidSMTPTLSReport(str(e))
def _parse_smtp_tls_report_policy(policy: dict):
def _parse_smtp_tls_report_policy(policy: dict[str, Any]):
policy_types = ["tlsa", "sts", "no-policy-found"]
try:
policy_domain = policy["policy"]["policy-domain"]
@@ -499,7 +500,7 @@ def _parse_smtp_tls_report_policy(policy: dict):
raise InvalidSMTPTLSReport(str(e))
def parse_smtp_tls_report_json(report: dict):
def parse_smtp_tls_report_json(report: dict[str, Any]):
"""Parses and validates an SMTP TLS report"""
required_fields = [
"organization-name",
@@ -538,7 +539,7 @@ def parse_smtp_tls_report_json(report: dict):
raise InvalidSMTPTLSReport(str(e))
def parsed_smtp_tls_reports_to_csv_rows(reports: dict):
def parsed_smtp_tls_reports_to_csv_rows(reports: OrderedDict[str, Any]):
"""Converts one or more parsed SMTP TLS reports into a list of single
layer OrderedDict objects suitable for use in a CSV"""
if type(reports) is OrderedDict:
@@ -573,7 +574,7 @@ def parsed_smtp_tls_reports_to_csv_rows(reports: dict):
return rows
def parsed_smtp_tls_reports_to_csv(reports: dict):
def parsed_smtp_tls_reports_to_csv(reports: OrderedDict[str, Any]) -> str:
"""
Converts one or more parsed SMTP TLS reports to flat CSV format, including
headers
@@ -620,16 +621,17 @@ def parsed_smtp_tls_reports_to_csv(reports: dict):
def parse_aggregate_report_xml(
xml: str,
ip_db_path: bool = None,
always_use_local_files: bool = False,
reverse_dns_map_path: bool = None,
reverse_dns_map_url: bool = None,
offline: bool = False,
nameservers: bool = None,
timeout: float = 2.0,
keep_alive: callable = None,
normalize_timespan_threshold_hours: float = 24.0,
):
*,
ip_db_path: Optional[bool] = None,
always_use_local_files: Optional [bool] = False,
reverse_dns_map_path: Optional[bool] = None,
reverse_dns_map_url: Optional[bool] = None,
offline: Optional[bool] = False,
nameservers: Optional[list[str]] = None,
timeout: Optional[float] = 2.0,
keep_alive: Optional[callable] = None,
normalize_timespan_threshold_hours: Optional[float] = 24.0,
) -> OrderedDict[str, Any]:
"""Parses a DMARC XML report string and returns a consistent OrderedDict
Args:
@@ -832,7 +834,7 @@ def parse_aggregate_report_xml(
raise InvalidAggregateReport("Unexpected error: {0}".format(error.__str__()))
def extract_report(content: Union[bytes, str, IO[Any]]):
def extract_report(content: Union[bytes, str, IO[Any]]) -> str:
"""
Extracts text from a zip or gzip file, as a base64-encoded string,
file-like object, or bytes.
@@ -886,7 +888,7 @@ def extract_report(content: Union[bytes, str, IO[Any]]):
return report
def extract_report_from_file_path(file_path):
def extract_report_from_file_path(file_path: str):
"""Extracts report from a file at the given file_path"""
try:
with open(file_path, "rb") as report_file:
@@ -897,21 +899,22 @@ def extract_report_from_file_path(file_path):
def parse_aggregate_report_file(
_input: Union[str, bytes, IO[Any]],
offline: bool = False,
always_use_local_files: bool = None,
reverse_dns_map_path: str = None,
reverse_dns_map_url: str = None,
ip_db_path: str = None,
nameservers: list[str] = None,
dns_timeout: float = 2.0,
keep_alive: Callable = None,
normalize_timespan_threshold_hours: float = 24.0,
):
*,
offline: Optional[bool] = False,
always_use_local_files: Optional[bool] = None,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
ip_db_path: Optional[str] = None,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: Optional[float] = 24.0,
) -> OrderedDict[str, any]:
"""Parses a file at the given path, a file-like object, or bytes as an
aggregate DMARC report
Args:
_input: A path to a file, a file like object, or bytes
_input (str | bytes | IO): A path to a file, a file like object, or bytes
offline (bool): Do not query online for geolocation or DNS
always_use_local_files (bool): Do not download files
reverse_dns_map_path (str): Path to a reverse DNS map file
@@ -946,7 +949,7 @@ def parse_aggregate_report_file(
)
def parsed_aggregate_reports_to_csv_rows(reports: list[dict]):
def parsed_aggregate_reports_to_csv_rows(reports: list[OrderedDict[str, Any]]) -> list[dict[str, Any]]:
"""
Converts one or more parsed aggregate reports to list of dicts in flat CSV
format
@@ -1070,7 +1073,7 @@ def parsed_aggregate_reports_to_csv_rows(reports: list[dict]):
return rows
def parsed_aggregate_reports_to_csv(reports: list[OrderedDict]):
def parsed_aggregate_reports_to_csv(reports: list[OrderedDict[str, Any]]) -> str:
"""
Converts one or more parsed aggregate reports to flat CSV format, including
headers
@@ -1140,15 +1143,16 @@ def parse_forensic_report(
feedback_report: str,
sample: str,
msg_date: datetime,
always_use_local_files: bool = False,
reverse_dns_map_path: str = None,
*,
always_use_local_files: Optional[bool] = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: str = None,
offline: bool = False,
ip_db_path: str = None,
nameservers: list[str] = None,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
):
offline: Optional[bool] = False,
ip_db_path: Optional[str] = None,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
strip_attachment_payloads: Optional[bool] = False,
) -> OrderedDict[str, Any]:
"""
Converts a DMARC forensic report and sample to an ``OrderedDict``
@@ -1276,7 +1280,7 @@ def parse_forensic_report(
raise InvalidForensicReport("Unexpected error: {0}".format(error.__str__()))
def parsed_forensic_reports_to_csv_rows(reports: list[OrderedDict]):
def parsed_forensic_reports_to_csv_rows(reports: list[OrderedDict[str, Any]]):
"""
Converts one or more parsed forensic reports to a list of dicts in flat CSV
format
@@ -1312,7 +1316,7 @@ def parsed_forensic_reports_to_csv_rows(reports: list[OrderedDict]):
return rows
def parsed_forensic_reports_to_csv(reports: list[dict]):
def parsed_forensic_reports_to_csv(reports: list[dict[str, Any]]) -> str:
"""
Converts one or more parsed forensic reports to flat CSV format, including
headers
@@ -1366,17 +1370,18 @@ def parsed_forensic_reports_to_csv(reports: list[dict]):
def parse_report_email(
input_: Union[bytes, str],
offline: bool = False,
ip_db_path: str = None,
always_use_local_files: bool = False,
reverse_dns_map_path: str = None,
reverse_dns_map_url: str = None,
*,
offline: Optional[bool] = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
nameservers: list[str] = None,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
keep_alive: callable = None,
normalize_timespan_threshold_hours: float = 24.0,
):
dns_timeout: Optional[float] = 2.0,
strip_attachment_payloads: Optional[bool] = False,
keep_alive: Optional[callable] = None,
normalize_timespan_threshold_hours: Optional[float] = 24.0,
) -> OrderedDict[str, Any]:
"""
Parses a DMARC report from an email
@@ -1563,22 +1568,23 @@ def parse_report_email(
def parse_report_file(
input_: Union[bytes, str, IO[Any]],
nameservers: list[str] = None,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
ip_db_path: str = None,
always_use_local_files: bool = False,
reverse_dns_map_path: str = None,
reverse_dns_map_url: str = None,
offline: bool = False,
keep_alive: Callable = None,
normalize_timespan_threshold_hours: float = 24,
):
*,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
strip_attachment_payloads: Optional[bool] = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: Optional[float] = 24,
) -> OrderedDict[str, Any]:
"""Parses a DMARC aggregate or forensic file at the given path, a
file-like object, or bytes
Args:
input_: A path to a file, a file like object, or bytes
input_ (str | bytes | IO): A path to a file, a file like object, or bytes
nameservers (list): A list of one or more nameservers to use
(Cloudflare's public DNS resolvers by default)
dns_timeout (float): Sets the DNS timeout in seconds
@@ -1645,21 +1651,22 @@ def parse_report_file(
def get_dmarc_reports_from_mbox(
input_: str,
nameservers: list[str] = None,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
ip_db_path: str = None,
always_use_local_files: bool = False,
reverse_dns_map_path: str = None,
reverse_dns_map_url: str = None,
offline: bool = False,
normalize_timespan_threshold_hours: float = 24.0,
):
*,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
strip_attachment_payloads: Optional[bool] = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
normalize_timespan_threshold_hours: Optional[float] = 24.0,
) -> OrderedDict[str, OrderedDict[str, Any]]:
"""Parses a mailbox in mbox format containing e-mails with attached
DMARC reports
Args:
input_: A path to a mbox file
input_ (str): A path to a mbox file
nameservers (list): A list of one or more nameservers to use
(Cloudflare's public DNS resolvers by default)
dns_timeout (float): Sets the DNS timeout in seconds
@@ -1673,7 +1680,7 @@ def get_dmarc_reports_from_mbox(
normalize_timespan_threshold_hours (float): Normalize timespans beyond this
Returns:
OrderedDict: Lists of ``aggregate_reports`` and ``forensic_reports``
OrderedDict: Lists of ``aggregate_reports``, ``forensic_reports``, and ``smtp_tls_reports``
"""
aggregate_reports = []
@@ -1733,31 +1740,32 @@ def get_dmarc_reports_from_mbox(
def get_dmarc_reports_from_mailbox(
connection: MailboxConnection,
reports_folder: str = "INBOX",
archive_folder: str = "Archive",
delete: bool = False,
test: bool = False,
ip_db_path: str = None,
always_use_local_files: str = False,
reverse_dns_map_path: str = None,
reverse_dns_map_url: str = None,
offline: bool = False,
nameservers: list[str] = None,
dns_timeout: float = 6.0,
strip_attachment_payloads: bool = False,
results: dict = None,
batch_size: int = 10,
since: datetime = None,
create_folders: bool = True,
normalize_timespan_threshold_hours: float = 24,
):
*,
reports_folder: Optional[str] = "INBOX",
archive_folder: Optional[str] = "Archive",
delete: Optional[bool] = False,
test: Optional[bool] = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[str] = False,
reverse_dns_map_path:Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 6.0,
strip_attachment_payloads: Optional[bool] = False,
results: Optional[OrderedDict[str, any]] = None,
batch_size: Optional[int] = 10,
since: Optional[datetime] = None,
create_folders: Optional[bool] = True,
normalize_timespan_threshold_hours: Optional[float] = 24,
) -> OrderedDict[str, OrderedDict[str, Any]]:
"""
Fetches and parses DMARC reports from a mailbox
Args:
connection: A Mailbox connection object
reports_folder: The folder where reports can be found
archive_folder: The folder to move processed mail to
reports_folder (str): The folder where reports can be found
archive_folder (str): The folder to move processed mail to
delete (bool): Delete messages after processing them
test (bool): Do not move or delete messages after processing them
ip_db_path (str): Path to a MMDB file from MaxMind or DBIP
@@ -1779,7 +1787,7 @@ def get_dmarc_reports_from_mailbox(
normalize_timespan_threshold_hours (float): Normalize timespans beyond this
Returns:
OrderedDict: Lists of ``aggregate_reports`` and ``forensic_reports``
OrderedDict: Lists of ``aggregate_reports``, ``forensic_reports``, and ``smtp_tls_reports``
"""
if delete and test:
raise ValueError("delete and test options are mutually exclusive")
@@ -2055,21 +2063,22 @@ def get_dmarc_reports_from_mailbox(
def watch_inbox(
mailbox_connection: MailboxConnection,
callback: Callable,
reports_folder: str = "INBOX",
archive_folder: str = "Archive",
delete: bool = False,
test: bool = False,
check_timeout: int = 30,
ip_db_path: str = None,
always_use_local_files: bool = False,
reverse_dns_map_path: str = None,
reverse_dns_map_url: str = None,
offline: bool = False,
nameservers: list[str] = None,
dns_timeout: float = 6.0,
strip_attachment_payloads: bool = False,
batch_size: int = None,
normalize_timespan_threshold_hours: float = 24,
*,
reports_folder: Optional[str] = "INBOX",
archive_folder: Optional[str] = "Archive",
delete: Optional[bool] = False,
test: Optional[bool] = False,
check_timeout: Optional[int] = 30,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 6.0,
strip_attachment_payloads: Optional[bool] = False,
batch_size: Optional[int] = None,
normalize_timespan_threshold_hours: Optional[float] = 24,
):
"""
Watches the mailbox for new messages and
@@ -2078,8 +2087,8 @@ def watch_inbox(
Args:
mailbox_connection: The mailbox connection object
callback: The callback function to receive the parsing results
reports_folder: The IMAP folder where reports can be found
archive_folder: The folder to move processed mail to
reports_folder (str): The IMAP folder where reports can be found
archive_folder (str): The folder to move processed mail to
delete (bool): Delete messages after processing them
test (bool): Do not move or delete messages after processing them
check_timeout (int): Number of seconds to wait for an IMAP IDLE response
@@ -2159,14 +2168,15 @@ def append_csv(filename, csv):
def save_output(
results: OrderedDict,
output_directory: str = "output",
aggregate_json_filename: str = "aggregate.json",
forensic_json_filename: str = "forensic.json",
smtp_tls_json_filename: str = "smtp_tls.json",
aggregate_csv_filename: str = "aggregate.csv",
forensic_csv_filename: str = "forensic.csv",
smtp_tls_csv_filename: str = "smtp_tls.csv",
results: OrderedDict[str, Any],
*,
output_directory: Optional[str] = "output",
aggregate_json_filename: Optional[str] = "aggregate.json",
forensic_json_filename: Optional[str] = "forensic.json",
smtp_tls_json_filename: Optional[str] = "smtp_tls.json",
aggregate_csv_filename: Optional[str] = "aggregate.csv",
forensic_csv_filename: Optional[str] = "forensic.csv",
smtp_tls_csv_filename: Optional[str] = "smtp_tls.csv",
):
"""
Save report data in the given directory
@@ -2244,7 +2254,7 @@ def save_output(
sample_file.write(sample)
def get_report_zip(results: OrderedDict):
def get_report_zip(results: OrderedDict[str, Any]) -> bytes:
"""
Creates a zip file of parsed report output
@@ -2290,27 +2300,28 @@ def get_report_zip(results: OrderedDict):
def email_results(
results,
host,
mail_from,
mail_to,
mail_cc=None,
mail_bcc=None,
port=0,
require_encryption=False,
verify=True,
username=None,
password=None,
subject=None,
attachment_filename=None,
message=None,
results: OrderedDict,
*,
host: str,
mail_from: str,
mail_to: str,
mail_cc: list = None,
mail_bcc: list = None,
port: int = 0,
require_encryption: bool = False,
verify: bool = True,
username: str = None,
password: str = None,
subject: str = None,
attachment_filename: str = None,
message: str = None,
):
"""
Emails parsing results as a zip file
Args:
results (OrderedDict): Parsing results
host: Mail server hostname or IP address
host (str): Mail server hostname or IP address
mail_from: The value of the message from header
mail_to (list): A list of addresses to mail to
mail_cc (list): A list of addresses to CC
+5 -5
View File
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Optional, Union
from typing import Optional, Union, Any
from collections import OrderedDict
@@ -304,7 +304,7 @@ def set_hosts(
connections.create_connection(**conn_params)
def create_indexes(names: list[str], settings: Optional[dict[str, any]] = None):
def create_indexes(names: list[str], settings: Optional[dict[str, Any]] = None):
"""
Create Elasticsearch indexes
@@ -377,7 +377,7 @@ def migrate_indexes(
def save_aggregate_report_to_elasticsearch(
aggregate_report: OrderedDict[str, any],
aggregate_report: OrderedDict[str, Any],
index_suffix: Optional[str] = None,
index_prefix: Optional[str] = None,
monthly_indexes: Optional[bool] = False,
@@ -539,7 +539,7 @@ def save_aggregate_report_to_elasticsearch(
def save_forensic_report_to_elasticsearch(
forensic_report: OrderedDict[str, any],
forensic_report: OrderedDict[str, Any],
index_suffix: Optional[any] = None,
index_prefix: Optional[str] = None,
monthly_indexes: Optional[bool] = False,
@@ -706,7 +706,7 @@ def save_forensic_report_to_elasticsearch(
def save_smtp_tls_report_to_elasticsearch(
report: OrderedDict[str, any],
report: OrderedDict[str, Any],
index_suffix: str = None,
index_prefix: str = None,
monthly_indexes: Optional[bool] = False,
+5 -5
View File
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Optional, Union
from typing import Optional, Union, Any
from collections import OrderedDict
@@ -304,7 +304,7 @@ def set_hosts(
connections.create_connection(**conn_params)
def create_indexes(names: list[str], settings: Optional[dict[str, any]] = None):
def create_indexes(names: list[str], settings: Optional[dict[str, Any]] = None):
"""
Create OpenSearch indexes
@@ -377,7 +377,7 @@ def migrate_indexes(
def save_aggregate_report_to_elasticsearch(
aggregate_report: OrderedDict[str, any],
aggregate_report: OrderedDict[str, Any],
index_suffix: Optional[str] = None,
index_prefix: Optional[str] = None,
monthly_indexes: Optional[bool] = False,
@@ -539,7 +539,7 @@ def save_aggregate_report_to_elasticsearch(
def save_forensic_report_to_elasticsearch(
forensic_report: OrderedDict[str, any],
forensic_report: OrderedDict[str, Any],
index_suffix: Optional[any] = None,
index_prefix: Optional[str] = None,
monthly_indexes: Optional[bool] = False,
@@ -706,7 +706,7 @@ def save_forensic_report_to_elasticsearch(
def save_smtp_tls_report_to_elasticsearch(
report: OrderedDict[str, any],
report: OrderedDict[str, Any],
index_suffix: str = None,
index_prefix: str = None,
monthly_indexes: Optional[bool] = False,