Add type annotations for SMTP TLS and forensic report structures

This commit is contained in:
Sean Whalen
2025-12-25 16:39:33 -05:00
parent 7b842740f5
commit abf4bdba13
2 changed files with 288 additions and 45 deletions

View File

@@ -27,6 +27,7 @@ from typing import (
Dict,
List,
Optional,
Sequence,
Union,
cast,
)
@@ -45,6 +46,13 @@ from parsedmarc.mail import (
MailboxConnection,
MSGraphConnection,
)
from parsedmarc.types import (
AggregateReport,
ForensicReport,
ParsedReport,
ParsingResults,
SMTPTLSReport,
)
from parsedmarc.utils import (
convert_outlook_msg,
get_base_domain,
@@ -522,7 +530,7 @@ def _parse_smtp_tls_report_policy(policy: dict[str, Any]):
raise InvalidSMTPTLSReport(str(e))
def parse_smtp_tls_report_json(report: Union[str, bytes]):
def parse_smtp_tls_report_json(report: Union[str, bytes]) -> SMTPTLSReport:
"""Parses and validates an SMTP TLS report"""
required_fields = [
"organization-name",
@@ -547,7 +555,7 @@ def parse_smtp_tls_report_json(report: Union[str, bytes]):
for policy in report_dict["policies"]:
policies.append(_parse_smtp_tls_report_policy(policy))
new_report: dict[str, Any] = {
new_report: SMTPTLSReport = {
"organization_name": report_dict["organization-name"],
"begin_date": report_dict["date-range"]["start-datetime"],
"end_date": report_dict["date-range"]["end-datetime"],
@@ -565,8 +573,8 @@ def parse_smtp_tls_report_json(report: Union[str, bytes]):
def parsed_smtp_tls_reports_to_csv_rows(
reports: Union[dict[str, Any], list[dict[str, Any]]],
):
reports: Union[SMTPTLSReport, list[SMTPTLSReport]],
) -> list[dict[str, Any]]:
"""Converts one or more parsed SMTP TLS reports into a list of single
layer dict objects suitable for use in a CSV"""
if isinstance(reports, dict):
@@ -574,13 +582,13 @@ def parsed_smtp_tls_reports_to_csv_rows(
rows = []
for report in reports:
common_fields = {
common_fields: dict[str, Any] = {
"organization_name": report["organization_name"],
"begin_date": report["begin_date"],
"end_date": report["end_date"],
"report_id": report["report_id"],
}
record = common_fields.copy()
record: dict[str, Any] = common_fields.copy()
for policy in report["policies"]:
if "policy_strings" in policy:
record["policy_strings"] = "|".join(policy["policy_strings"])
@@ -601,7 +609,9 @@ def parsed_smtp_tls_reports_to_csv_rows(
return rows
def parsed_smtp_tls_reports_to_csv(reports: dict[str, Any]) -> str:
def parsed_smtp_tls_reports_to_csv(
reports: Union[SMTPTLSReport, list[SMTPTLSReport]],
) -> str:
"""
Converts one or more parsed SMTP TLS reports to flat CSV format, including
headers
@@ -658,7 +668,7 @@ def parse_aggregate_report_xml(
timeout: float = 2.0,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: float = 24.0,
) -> dict[str, Any]:
) -> AggregateReport:
"""Parses a DMARC XML report string and returns a consistent dict
Args:
@@ -847,7 +857,7 @@ def parse_aggregate_report_xml(
new_report["records"] = records
return new_report
return cast(AggregateReport, new_report)
except expat.ExpatError as error:
raise InvalidAggregateReport("Invalid XML: {0}".format(error.__str__()))
@@ -966,7 +976,7 @@ def parse_aggregate_report_file(
dns_timeout: float = 2.0,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: float = 24.0,
) -> dict[str, Any]:
) -> AggregateReport:
"""Parses a file at the given path, a file-like object, or bytes as an
aggregate DMARC report
@@ -1007,7 +1017,7 @@ def parse_aggregate_report_file(
def parsed_aggregate_reports_to_csv_rows(
reports: Union[dict[str, Any], list[dict[str, Any]]],
reports: Union[AggregateReport, list[AggregateReport]],
) -> list[dict[str, Any]]:
"""
Converts one or more parsed aggregate reports to list of dicts in flat CSV
@@ -1049,7 +1059,7 @@ def parsed_aggregate_reports_to_csv_rows(
pct = report["policy_published"]["pct"]
fo = report["policy_published"]["fo"]
report_dict = dict(
report_dict: dict[str, Any] = dict(
xml_schema=xml_schema,
org_name=org_name,
org_email=org_email,
@@ -1069,7 +1079,7 @@ def parsed_aggregate_reports_to_csv_rows(
)
for record in report["records"]:
row = report_dict.copy()
row: dict[str, Any] = report_dict.copy()
row["begin_date"] = record["interval_begin"]
row["end_date"] = record["interval_end"]
row["source_ip_address"] = record["source"]["ip_address"]
@@ -1133,7 +1143,7 @@ def parsed_aggregate_reports_to_csv_rows(
def parsed_aggregate_reports_to_csv(
reports: Union[dict[str, Any], list[dict[str, Any]]],
reports: Union[AggregateReport, list[AggregateReport]],
) -> str:
"""
Converts one or more parsed aggregate reports to flat CSV format, including
@@ -1213,7 +1223,7 @@ def parse_forensic_report(
nameservers: Optional[list[str]] = None,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
) -> dict[str, Any]:
) -> ForensicReport:
"""
Converts a DMARC forensic report and sample to a dict
@@ -1332,7 +1342,7 @@ def parse_forensic_report(
parsed_report["sample"] = sample
parsed_report["parsed_sample"] = parsed_sample
return parsed_report
return cast(ForensicReport, parsed_report)
except KeyError as error:
raise InvalidForensicReport("Missing value: {0}".format(error.__str__()))
@@ -1342,8 +1352,8 @@ def parse_forensic_report(
def parsed_forensic_reports_to_csv_rows(
reports: Union[dict[str, Any], list[dict[str, Any]]],
):
reports: Union[ForensicReport, list[ForensicReport]],
) -> list[dict[str, Any]]:
"""
Converts one or more parsed forensic reports to a list of dicts in flat CSV
format
@@ -1360,7 +1370,7 @@ def parsed_forensic_reports_to_csv_rows(
rows = []
for report in reports:
row = report.copy()
row: dict[str, Any] = dict(report)
row["source_ip_address"] = report["source"]["ip_address"]
row["source_reverse_dns"] = report["source"]["reverse_dns"]
row["source_base_domain"] = report["source"]["base_domain"]
@@ -1368,7 +1378,7 @@ def parsed_forensic_reports_to_csv_rows(
row["source_type"] = report["source"]["type"]
row["source_country"] = report["source"]["country"]
del row["source"]
row["subject"] = report["parsed_sample"]["subject"]
row["subject"] = report["parsed_sample"].get("subject")
row["auth_failure"] = ",".join(report["auth_failure"])
authentication_mechanisms = report["authentication_mechanisms"]
row["authentication_mechanisms"] = ",".join(authentication_mechanisms)
@@ -1379,7 +1389,9 @@ def parsed_forensic_reports_to_csv_rows(
return rows
def parsed_forensic_reports_to_csv(reports: list[dict[str, Any]]) -> str:
def parsed_forensic_reports_to_csv(
reports: Union[ForensicReport, list[ForensicReport]],
) -> str:
"""
Converts one or more parsed forensic reports to flat CSV format, including
headers
@@ -1423,9 +1435,9 @@ def parsed_forensic_reports_to_csv(reports: list[dict[str, Any]]) -> str:
rows = parsed_forensic_reports_to_csv_rows(reports)
for row in rows:
new_row = {}
for key in new_row.keys():
new_row[key] = row[key]
new_row: dict[str, Any] = {}
for key in fields:
new_row[key] = row.get(key)
csv_writer.writerow(new_row)
return csv_file.getvalue()
@@ -1444,7 +1456,7 @@ def parse_report_email(
strip_attachment_payloads: bool = False,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: float = 24.0,
) -> dict[str, Any]:
) -> ParsedReport:
"""
Parses a DMARC report from an email
@@ -1467,7 +1479,7 @@ def parse_report_email(
* ``report_type``: ``aggregate`` or ``forensic``
* ``report``: The parsed report
"""
result: Optional[dict[str, Any]] = None
result: Optional[ParsedReport] = None
msg_date: datetime = datetime.now(timezone.utc)
try:
@@ -1653,7 +1665,7 @@ def parse_report_file(
offline: bool = False,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: float = 24,
) -> dict[str, Any]:
) -> ParsedReport:
"""Parses a DMARC aggregate or forensic file at the given path, a
file-like object, or bytes
@@ -1688,7 +1700,7 @@ def parse_report_file(
if content.startswith(MAGIC_ZIP) or content.startswith(MAGIC_GZIP):
content = extract_report(content)
results: Optional[dict[str, Any]] = None
results: Optional[ParsedReport] = None
try:
report = parse_aggregate_report_file(
@@ -1743,7 +1755,7 @@ def get_dmarc_reports_from_mbox(
reverse_dns_map_url: Optional[str] = None,
offline: bool = False,
normalize_timespan_threshold_hours: float = 24.0,
) -> dict[str, list[dict[str, Any]]]:
) -> ParsingResults:
"""Parses a mailbox in mbox format containing e-mails with attached
DMARC reports
@@ -1765,9 +1777,9 @@ def get_dmarc_reports_from_mbox(
dict: Lists of ``aggregate_reports``, ``forensic_reports``, and ``smtp_tls_reports``
"""
aggregate_reports = []
forensic_reports = []
smtp_tls_reports = []
aggregate_reports: list[AggregateReport] = []
forensic_reports: list[ForensicReport] = []
smtp_tls_reports: list[SMTPTLSReport] = []
try:
mbox = mailbox.mbox(input_)
message_keys = mbox.keys()
@@ -1833,12 +1845,12 @@ def get_dmarc_reports_from_mailbox(
nameservers: Optional[list[str]] = None,
dns_timeout: float = 6.0,
strip_attachment_payloads: bool = False,
results: Optional[dict[str, Any]] = None,
results: Optional[ParsingResults] = None,
batch_size: int = 10,
since: Optional[Union[datetime, date, str]] = None,
create_folders: bool = True,
normalize_timespan_threshold_hours: float = 24,
) -> dict[str, list[dict[str, Any]]]:
) -> ParsingResults:
"""
Fetches and parses DMARC reports from a mailbox
@@ -1878,9 +1890,9 @@ def get_dmarc_reports_from_mailbox(
# current_time useful to fetch_messages later in the program
current_time: Optional[Union[datetime, date, str]] = None
aggregate_reports = []
forensic_reports = []
smtp_tls_reports = []
aggregate_reports: list[AggregateReport] = []
forensic_reports: list[ForensicReport] = []
smtp_tls_reports: list[SMTPTLSReport] = []
aggregate_report_msg_uids = []
forensic_report_msg_uids = []
smtp_tls_msg_uids = []
@@ -2218,7 +2230,14 @@ def watch_inbox(
mailbox_connection.watch(check_callback=check_callback, check_timeout=check_timeout)
def append_json(filename, reports):
def append_json(
filename: str,
reports: Union[
Sequence[AggregateReport],
Sequence[ForensicReport],
Sequence[SMTPTLSReport],
],
) -> None:
with open(filename, "a+", newline="\n", encoding="utf-8") as output:
output_json = json.dumps(reports, ensure_ascii=False, indent=2)
if output.seek(0, os.SEEK_END) != 0:
@@ -2241,7 +2260,7 @@ def append_json(filename, reports):
output.write(output_json)
def append_csv(filename, csv):
def append_csv(filename: str, csv: str) -> None:
with open(filename, "a+", newline="\n", encoding="utf-8") as output:
if output.seek(0, os.SEEK_END) != 0:
# strip the headers from the CSV
@@ -2254,7 +2273,7 @@ def append_csv(filename, csv):
def save_output(
results: dict[str, Any],
results: ParsingResults,
*,
output_directory: str = "output",
aggregate_json_filename: str = "aggregate.json",
@@ -2268,7 +2287,7 @@ def save_output(
Save report data in the given directory
Args:
results (dict): Parsing results
results: Parsing results
output_directory (str): The path to the directory to save in
aggregate_json_filename (str): Filename for the aggregate JSON file
forensic_json_filename (str): Filename for the forensic JSON file
@@ -2325,7 +2344,11 @@ def save_output(
sample = forensic_report["sample"]
message_count = 0
parsed_sample = forensic_report["parsed_sample"]
subject = parsed_sample["filename_safe_subject"]
subject = (
parsed_sample.get("filename_safe_subject")
or parsed_sample.get("subject")
or "sample"
)
filename = subject
while filename in sample_filenames:
@@ -2340,12 +2363,12 @@ def save_output(
sample_file.write(sample)
def get_report_zip(results: dict[str, Any]) -> bytes:
def get_report_zip(results: ParsingResults) -> bytes:
"""
Creates a zip file of parsed report output
Args:
results (dict): The parsed results
results: The parsed results
Returns:
bytes: zip file bytes
@@ -2386,7 +2409,7 @@ def get_report_zip(results: dict[str, Any]) -> bytes:
def email_results(
results: dict[str, Any],
results: ParsingResults,
host: str,
mail_from: str,
mail_to: Optional[list[str]],

220
parsedmarc/types.py Normal file
View File

@@ -0,0 +1,220 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
# NOTE: This module is intentionally Python 3.9 compatible.
# - No PEP 604 unions (A | B)
# - No typing.NotRequired / Required (3.11+) to avoid an extra dependency.
# For optional keys, use total=False TypedDicts.
# Discriminator used throughout parsedmarc to tag which kind of report a
# parsed payload represents (see the *ParsedReport TypedDicts below).
ReportType = Literal["aggregate", "forensic", "smtp_tls"]
class AggregateReportMetadata(TypedDict):
    """Metadata block of a parsed DMARC aggregate report.

    Mirrors the ``<report_metadata>`` element of the DMARC XML schema,
    flattened to snake_case keys.
    """

    org_name: str
    org_email: str
    org_extra_contact_info: Optional[str]
    report_id: str
    # begin_date/end_date are string timestamps; exact format is produced by
    # the parser (presumably ISO 8601 — confirm against parse_aggregate_report_xml).
    begin_date: str
    end_date: str
    # True when the report's timespan exceeded the normalization threshold
    # and was adjusted by the parser.
    timespan_requires_normalization: bool
    original_timespan_seconds: int
    # Error strings reported inside the aggregate report itself.
    errors: List[str]
class AggregatePolicyPublished(TypedDict):
    """The published DMARC policy (``<policy_published>``) of an aggregate report.

    All values are kept as strings, matching how the XML parser emits them.
    """

    domain: str
    adkim: str  # DKIM alignment mode ("r" relaxed / "s" strict)
    aspf: str  # SPF alignment mode ("r" relaxed / "s" strict)
    p: str  # policy for the domain
    sp: str  # policy for subdomains
    pct: str  # percentage of messages the policy applies to
    fo: str  # failure reporting options
class IPSourceInfo(TypedDict):
    """Enriched information about a sending IP address.

    Shared by aggregate records and forensic reports; fields other than the
    address itself may be ``None`` when enrichment (DNS / GeoIP / reverse DNS
    map) could not resolve them.
    """

    ip_address: str
    country: Optional[str]
    reverse_dns: Optional[str]
    base_domain: Optional[str]
    name: Optional[str]
    type: Optional[str]
class AggregateAlignment(TypedDict):
    """Computed alignment results for one aggregate record."""

    spf: bool
    dkim: bool
    # Overall DMARC pass/fail derived from SPF/DKIM alignment.
    dmarc: bool
class AggregateIdentifiers(TypedDict):
    """Message identifiers (``<identifiers>``) for one aggregate record."""

    header_from: str
    envelope_from: Optional[str]
    envelope_to: Optional[str]
class AggregatePolicyOverrideReason(TypedDict):
    """A single policy-override reason attached to an evaluated policy."""

    type: Optional[str]
    comment: Optional[str]
class AggregateAuthResultDKIM(TypedDict):
    """One DKIM authentication result from an aggregate record."""

    domain: str
    result: str
    selector: str
class AggregateAuthResultSPF(TypedDict):
    """One SPF authentication result from an aggregate record."""

    domain: str
    result: str
    scope: str  # e.g. "mfrom" or "helo" — confirm against the parser
class AggregateAuthResults(TypedDict):
    """Collection of raw authentication results for one aggregate record."""

    dkim: List[AggregateAuthResultDKIM]
    spf: List[AggregateAuthResultSPF]
class AggregatePolicyEvaluated(TypedDict):
    """The receiver's evaluated policy (``<policy_evaluated>``) for a record."""

    disposition: str
    dkim: str
    spf: str
    policy_override_reasons: List[AggregatePolicyOverrideReason]
class AggregateRecord(TypedDict):
    """One ``<record>`` of a DMARC aggregate report, after parsing/enrichment."""

    # Per-record interval; may differ from the report-level dates when the
    # parser normalizes overly long timespans.
    interval_begin: str
    interval_end: str
    source: IPSourceInfo
    count: int
    alignment: AggregateAlignment
    policy_evaluated: AggregatePolicyEvaluated
    # Duplicated from policy_evaluated for convenience — TODO confirm in parser.
    disposition: str
    identifiers: AggregateIdentifiers
    auth_results: AggregateAuthResults
class AggregateReport(TypedDict):
    """A fully parsed DMARC aggregate report (the dict returned by
    ``parse_aggregate_report_xml`` / ``parse_aggregate_report_file``)."""

    xml_schema: str
    report_metadata: AggregateReportMetadata
    policy_published: AggregatePolicyPublished
    records: List[AggregateRecord]
class EmailAddress(TypedDict):
    """A parsed e-mail address as emitted by the mail parsing layer."""

    display_name: Optional[str]
    address: str
    local: Optional[str]  # local part (before the "@")
    domain: Optional[str]  # domain part (after the "@")
class EmailAttachment(TypedDict, total=False):
    """An attachment entry in a parsed e-mail sample.

    ``total=False``: all keys are optional because upstream parsers do not
    guarantee every field for every attachment.
    """

    filename: Optional[str]
    mail_content_type: Optional[str]
    sha256: Optional[str]
# Functional TypedDict form is required here because "from" is a Python
# keyword and cannot be a class attribute name.
ParsedEmail = TypedDict(
    "ParsedEmail",
    {
        # This is a lightly-specified version of mailsuite/mailparser JSON.
        # It focuses on the fields parsedmarc uses in forensic handling.
        "headers": Dict[str, Any],
        "subject": Optional[str],
        # Subject sanitized for use as a filename (see save_output).
        "filename_safe_subject": Optional[str],
        "date": Optional[str],
        "from": EmailAddress,
        "to": List[EmailAddress],
        "cc": List[EmailAddress],
        "bcc": List[EmailAddress],
        "attachments": List[EmailAttachment],
        "body": Optional[str],
        "has_defects": bool,
        "defects": Any,
        "defects_categories": Any,
    },
    # All keys optional: upstream mail parsers omit fields freely.
    total=False,
)
class ForensicReport(TypedDict):
    """A parsed DMARC forensic (failure) report, as returned by
    ``parse_forensic_report``.

    Fields typed ``Optional[...]`` are always present as keys but may hold
    ``None`` when the source report omitted them.
    """

    feedback_type: Optional[str]
    user_agent: Optional[str]
    version: Optional[str]
    original_envelope_id: Optional[str]
    original_mail_from: Optional[str]
    original_rcpt_to: Optional[str]
    arrival_date: str
    arrival_date_utc: str
    authentication_results: Optional[str]
    delivery_result: Optional[str]
    auth_failure: List[str]
    authentication_mechanisms: List[str]
    dkim_domain: Optional[str]
    reported_domain: str
    # True when only the headers of the offending message were included.
    sample_headers_only: bool
    source: IPSourceInfo
    # Raw message sample and its parsed representation.
    sample: str
    parsed_sample: ParsedEmail
class SMTPTLSFailureDetails(TypedDict):
    """Required fields of one failure-details entry in an SMTP TLS report."""

    result_type: str
    failed_session_count: int
class SMTPTLSFailureDetailsOptional(SMTPTLSFailureDetails, total=False):
    """Failure-details entry with the optional RFC 8460 fields.

    Inherits the required keys from :class:`SMTPTLSFailureDetails`; the keys
    declared here are optional (``total=False``).
    """

    sending_mta_ip: str
    receiving_ip: str
    receiving_mx_hostname: str
    receiving_mx_helo: str
    additional_info_uri: str
    failure_reason_code: str
    ip_address: str
class SMTPTLSPolicySummary(TypedDict):
    """Required fields of one policy entry in a parsed SMTP TLS report."""

    policy_domain: str
    policy_type: str
    successful_session_count: int
    failed_session_count: int
class SMTPTLSPolicy(SMTPTLSPolicySummary, total=False):
    """Policy entry with optional detail fields layered on the summary."""

    policy_strings: List[str]
    mx_host_patterns: List[str]
    failure_details: List[SMTPTLSFailureDetailsOptional]
class SMTPTLSReport(TypedDict):
    """A parsed SMTP TLS report (RFC 8460), as returned by
    ``parse_smtp_tls_report_json``."""

    organization_name: str
    begin_date: str
    end_date: str
    # RFC 8460 allows contact-info to be a single string or a list.
    contact_info: Union[str, List[str]]
    report_id: str
    policies: List[SMTPTLSPolicy]
class AggregateParsedReport(TypedDict):
    """Tagged wrapper for an aggregate report (``report_type`` discriminant)."""

    report_type: Literal["aggregate"]
    report: AggregateReport
class ForensicParsedReport(TypedDict):
    """Tagged wrapper for a forensic report (``report_type`` discriminant)."""

    report_type: Literal["forensic"]
    report: ForensicReport
class SMTPTLSParsedReport(TypedDict):
    """Tagged wrapper for an SMTP TLS report (``report_type`` discriminant)."""

    report_type: Literal["smtp_tls"]
    report: SMTPTLSReport
# Discriminated union of all tagged report wrappers; narrow on the
# "report_type" key (Literal tags above) to recover the concrete type.
ParsedReport = Union[AggregateParsedReport, ForensicParsedReport, SMTPTLSParsedReport]
class ParsingResults(TypedDict):
    """Accumulated results of parsing a mailbox or batch of report files.

    This is the shape returned by ``get_dmarc_reports_from_mbox`` /
    ``get_dmarc_reports_from_mailbox`` and consumed by ``save_output``,
    ``get_report_zip``, and ``email_results``.
    """

    aggregate_reports: List[AggregateReport]
    forensic_reports: List[ForensicReport]
    smtp_tls_reports: List[SMTPTLSReport]