diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py
index 8ef96b6..0011a87 100644
--- a/parsedmarc/__init__.py
+++ b/parsedmarc/__init__.py
@@ -27,6 +27,7 @@ from typing import (
     Dict,
     List,
     Optional,
+    Sequence,
     Union,
     cast,
 )
@@ -45,6 +46,13 @@ from parsedmarc.mail import (
     MailboxConnection,
     MSGraphConnection,
 )
+from parsedmarc.types import (
+    AggregateReport,
+    ForensicReport,
+    ParsedReport,
+    ParsingResults,
+    SMTPTLSReport,
+)
 from parsedmarc.utils import (
     convert_outlook_msg,
     get_base_domain,
@@ -522,7 +530,7 @@ def _parse_smtp_tls_report_policy(policy: dict[str, Any]):
         raise InvalidSMTPTLSReport(str(e))


-def parse_smtp_tls_report_json(report: Union[str, bytes]):
+def parse_smtp_tls_report_json(report: Union[str, bytes]) -> SMTPTLSReport:
     """Parses and validates an SMTP TLS report"""
     required_fields = [
         "organization-name",
@@ -547,7 +555,7 @@ def parse_smtp_tls_report_json(report: Union[str, bytes]):
         for policy in report_dict["policies"]:
             policies.append(_parse_smtp_tls_report_policy(policy))

-        new_report: dict[str, Any] = {
+        new_report: SMTPTLSReport = {
             "organization_name": report_dict["organization-name"],
             "begin_date": report_dict["date-range"]["start-datetime"],
             "end_date": report_dict["date-range"]["end-datetime"],
@@ -565,8 +573,8 @@ def parse_smtp_tls_report_json(report: Union[str, bytes]):


 def parsed_smtp_tls_reports_to_csv_rows(
-    reports: Union[dict[str, Any], list[dict[str, Any]]],
-):
+    reports: Union[SMTPTLSReport, list[SMTPTLSReport]],
+) -> list[dict[str, Any]]:
     """Converts one or more parsed SMTP TLS reports into a list of single
     layer dict objects suitable for use in a CSV"""
     if isinstance(reports, dict):
@@ -574,13 +582,13 @@ def parsed_smtp_tls_reports_to_csv_rows(
     rows = []

     for report in reports:
-        common_fields = {
+        common_fields: dict[str, Any] = {
             "organization_name": report["organization_name"],
             "begin_date": report["begin_date"],
             "end_date": report["end_date"],
             "report_id": report["report_id"],
         }
-        record = common_fields.copy()
+        record: dict[str, Any] = common_fields.copy()
         for policy in report["policies"]:
             if "policy_strings" in policy:
                 record["policy_strings"] = "|".join(policy["policy_strings"])
@@ -601,7 +609,9 @@ def parsed_smtp_tls_reports_to_csv_rows(
     return rows


-def parsed_smtp_tls_reports_to_csv(reports: dict[str, Any]) -> str:
+def parsed_smtp_tls_reports_to_csv(
+    reports: Union[SMTPTLSReport, list[SMTPTLSReport]],
+) -> str:
     """
     Converts one or more parsed SMTP TLS reports to flat CSV format, including
     headers
@@ -658,7 +668,7 @@ def parse_aggregate_report_xml(
     timeout: float = 2.0,
     keep_alive: Optional[Callable] = None,
     normalize_timespan_threshold_hours: float = 24.0,
-) -> dict[str, Any]:
+) -> AggregateReport:
     """Parses a DMARC XML report string and returns a consistent dict

     Args:
@@ -847,7 +857,7 @@ def parse_aggregate_report_xml(

         new_report["records"] = records

-        return new_report
+        return cast(AggregateReport, new_report)

     except expat.ExpatError as error:
         raise InvalidAggregateReport("Invalid XML: {0}".format(error.__str__()))
@@ -966,7 +976,7 @@ def parse_aggregate_report_file(
     dns_timeout: float = 2.0,
     keep_alive: Optional[Callable] = None,
     normalize_timespan_threshold_hours: float = 24.0,
-) -> dict[str, Any]:
+) -> AggregateReport:
     """Parses a file at the given path, a file-like object,
     or bytes as an aggregate DMARC report

@@ -1007,7 +1017,7 @@ def parse_aggregate_report_file(


 def parsed_aggregate_reports_to_csv_rows(
-    reports: Union[dict[str, Any], list[dict[str, Any]]],
+    reports: Union[AggregateReport, list[AggregateReport]],
 ) -> list[dict[str, Any]]:
     """
     Converts one or more parsed aggregate reports to a list of dicts in flat CSV
@@ -1049,7 +1059,7 @@ def parsed_aggregate_reports_to_csv_rows(
         pct = report["policy_published"]["pct"]
         fo = report["policy_published"]["fo"]

-        report_dict = dict(
+        report_dict: dict[str, Any] = dict(
             xml_schema=xml_schema,
             org_name=org_name,
             org_email=org_email,
@@ -1069,7 +1079,7 @@ def parsed_aggregate_reports_to_csv_rows(
         )

         for record in report["records"]:
-            row = report_dict.copy()
+            row: dict[str, Any] = report_dict.copy()
             row["begin_date"] = record["interval_begin"]
             row["end_date"] = record["interval_end"]
             row["source_ip_address"] = record["source"]["ip_address"]
@@ -1133,7 +1143,7 @@ def parsed_aggregate_reports_to_csv_rows(


 def parsed_aggregate_reports_to_csv(
-    reports: Union[dict[str, Any], list[dict[str, Any]]],
+    reports: Union[AggregateReport, list[AggregateReport]],
 ) -> str:
     """
     Converts one or more parsed aggregate reports to flat CSV format, including
@@ -1213,7 +1223,7 @@ def parse_forensic_report(
     nameservers: Optional[list[str]] = None,
     dns_timeout: float = 2.0,
     strip_attachment_payloads: bool = False,
-) -> dict[str, Any]:
+) -> ForensicReport:
     """
     Converts a DMARC forensic report and sample to a dict

@@ -1332,7 +1342,7 @@ def parse_forensic_report(
         parsed_report["sample"] = sample
         parsed_report["parsed_sample"] = parsed_sample

-        return parsed_report
+        return cast(ForensicReport, parsed_report)

     except KeyError as error:
         raise InvalidForensicReport("Missing value: {0}".format(error.__str__()))
@@ -1342,8 +1352,8 @@ def parse_forensic_report(


 def parsed_forensic_reports_to_csv_rows(
-    reports: Union[dict[str, Any], list[dict[str, Any]]],
-):
+    reports: Union[ForensicReport, list[ForensicReport]],
+) -> list[dict[str, Any]]:
     """
     Converts one or more parsed forensic reports to a list of dicts in flat
     CSV format
@@ -1360,7 +1370,7 @@ def parsed_forensic_reports_to_csv_rows(
     rows = []

     for report in reports:
-        row = report.copy()
+        row: dict[str, Any] = dict(report)
         row["source_ip_address"] = report["source"]["ip_address"]
         row["source_reverse_dns"] = report["source"]["reverse_dns"]
         row["source_base_domain"] = report["source"]["base_domain"]
@@ -1368,7 +1378,7 @@ def parsed_forensic_reports_to_csv_rows(
         row["source_type"] = report["source"]["type"]
         row["source_country"] = report["source"]["country"]
         del row["source"]
-        row["subject"] = report["parsed_sample"]["subject"]
+        row["subject"] = report["parsed_sample"].get("subject")
         row["auth_failure"] = ",".join(report["auth_failure"])
         authentication_mechanisms = report["authentication_mechanisms"]
         row["authentication_mechanisms"] = ",".join(authentication_mechanisms)
@@ -1379,7 +1389,9 @@ def parsed_forensic_reports_to_csv_rows(

     return rows


-def parsed_forensic_reports_to_csv(reports: list[dict[str, Any]]) -> str:
+def parsed_forensic_reports_to_csv(
+    reports: Union[ForensicReport, list[ForensicReport]],
+) -> str:
     """
     Converts one or more parsed forensic reports to flat CSV format, including
     headers
@@ -1423,9 +1435,9 @@ def parsed_forensic_reports_to_csv(reports: list[dict[str, Any]]) -> str:
     rows = parsed_forensic_reports_to_csv_rows(reports)

     for row in rows:
-        new_row = {}
-        for key in new_row.keys():
-            new_row[key] = row[key]
+        new_row: dict[str, Any] = {}
+        for key in fields:
+            new_row[key] = row.get(key)
         csv_writer.writerow(new_row)

     return csv_file.getvalue()
@@ -1444,7 +1456,7 @@ def parse_report_email(
     strip_attachment_payloads: bool = False,
     keep_alive: Optional[Callable] = None,
     normalize_timespan_threshold_hours: float = 24.0,
-) -> dict[str, Any]:
+) -> ParsedReport:
     """
     Parses a DMARC report from an email

@@ -1467,7 +1479,7 @@ def parse_report_email(
     * ``report_type``: ``aggregate``, ``forensic``, or ``smtp_tls``
     * ``report``: The parsed report
     """
-    result: Optional[dict[str, Any]] = None
+    result: Optional[ParsedReport] = None
     msg_date: datetime = datetime.now(timezone.utc)

     try:
@@ -1653,7 +1665,7 @@ def parse_report_file(
     offline: bool = False,
     keep_alive: Optional[Callable] = None,
     normalize_timespan_threshold_hours: float = 24,
-) -> dict[str, Any]:
+) -> ParsedReport:
     """Parses a DMARC aggregate or forensic file at the given path, a
     file-like object, or bytes

@@ -1688,7 +1700,7 @@ def parse_report_file(
     if content.startswith(MAGIC_ZIP) or content.startswith(MAGIC_GZIP):
         content = extract_report(content)

-    results: Optional[dict[str, Any]] = None
+    results: Optional[ParsedReport] = None

     try:
         report = parse_aggregate_report_file(
@@ -1743,7 +1755,7 @@ def get_dmarc_reports_from_mbox(
     reverse_dns_map_url: Optional[str] = None,
     offline: bool = False,
     normalize_timespan_threshold_hours: float = 24.0,
-) -> dict[str, list[dict[str, Any]]]:
+) -> ParsingResults:
     """Parses a mailbox in mbox format containing e-mails with attached
     DMARC reports

@@ -1765,9 +1777,9 @@ def get_dmarc_reports_from_mbox(
         dict: Lists of ``aggregate_reports``, ``forensic_reports``, and
         ``smtp_tls_reports``
     """
-    aggregate_reports = []
-    forensic_reports = []
-    smtp_tls_reports = []
+    aggregate_reports: list[AggregateReport] = []
+    forensic_reports: list[ForensicReport] = []
+    smtp_tls_reports: list[SMTPTLSReport] = []
     try:
         mbox = mailbox.mbox(input_)
         message_keys = mbox.keys()
@@ -1833,12 +1845,12 @@ def get_dmarc_reports_from_mailbox(
     nameservers: Optional[list[str]] = None,
     dns_timeout: float = 6.0,
     strip_attachment_payloads: bool = False,
-    results: Optional[dict[str, Any]] = None,
+    results: Optional[ParsingResults] = None,
     batch_size: int = 10,
     since: Optional[Union[datetime, date, str]] = None,
     create_folders: bool = True,
     normalize_timespan_threshold_hours: float = 24,
-) -> dict[str, list[dict[str, Any]]]:
+) -> ParsingResults:
     """
     Fetches and parses DMARC reports from a mailbox

@@ -1878,9 +1890,9 @@ def get_dmarc_reports_from_mailbox(
     # current_time useful to fetch_messages later in the program
     current_time: Optional[Union[datetime, date, str]] = None

-    aggregate_reports = []
-    forensic_reports = []
-    smtp_tls_reports = []
+    aggregate_reports: list[AggregateReport] = []
+    forensic_reports: list[ForensicReport] = []
+    smtp_tls_reports: list[SMTPTLSReport] = []
     aggregate_report_msg_uids = []
     forensic_report_msg_uids = []
     smtp_tls_msg_uids = []
@@ -2218,7 +2230,14 @@ def watch_inbox(
     mailbox_connection.watch(check_callback=check_callback, check_timeout=check_timeout)


-def append_json(filename, reports):
+def append_json(
+    filename: str,
+    reports: Union[
+        Sequence[AggregateReport],
+        Sequence[ForensicReport],
+        Sequence[SMTPTLSReport],
+    ],
+) -> None:
     with open(filename, "a+", newline="\n", encoding="utf-8") as output:
         output_json = json.dumps(reports, ensure_ascii=False, indent=2)
         if output.seek(0, os.SEEK_END) != 0:
@@ -2241,7 +2260,7 @@ def append_json(filename, reports):
         output.write(output_json)


-def append_csv(filename, csv):
+def append_csv(filename: str, csv: str) -> None:
     with open(filename, "a+", newline="\n", encoding="utf-8") as output:
         if output.seek(0, os.SEEK_END) != 0:
             # strip the headers from the CSV
@@ -2254,7 +2273,7 @@ def append_csv(filename, csv):


 def save_output(
-    results: dict[str, Any],
+    results: ParsingResults,
     *,
     output_directory: str = "output",
     aggregate_json_filename: str = "aggregate.json",
@@ -2268,7 +2287,7 @@ def save_output(
     Save report data in the given directory

     Args:
-        results (dict): Parsing results
+        results: Parsing results
         output_directory (str): The path to the directory to save in
         aggregate_json_filename (str): Filename for the aggregate JSON file
         forensic_json_filename (str): Filename for the forensic JSON file
@@ -2325,7 +2344,11 @@ def save_output(
             sample = forensic_report["sample"]
             message_count = 0
             parsed_sample = forensic_report["parsed_sample"]
-            subject = parsed_sample["filename_safe_subject"]
+            subject = (
+                parsed_sample.get("filename_safe_subject")
+                or parsed_sample.get("subject")
+                or "sample"
+            )
             filename = subject

             while filename in sample_filenames:
@@ -2340,12 +2363,12 @@ def save_output(
             sample_file.write(sample)


-def get_report_zip(results: dict[str, Any]) -> bytes:
+def get_report_zip(results: ParsingResults) -> bytes:
     """
     Creates a zip file of parsed report output

     Args:
-        results (dict): The parsed results
+        results: The parsed results

     Returns:
         bytes: zip file bytes
@@ -2386,7 +2409,7 @@ def get_report_zip(results: dict[str, Any]) -> bytes:


 def email_results(
-    results: dict[str, Any],
+    results: ParsingResults,
     host: str,
     mail_from: str,
     mail_to: Optional[list[str]],
diff --git a/parsedmarc/types.py b/parsedmarc/types.py
new file mode 100644
index 0000000..54af485
--- /dev/null
+++ b/parsedmarc/types.py
@@ -0,0 +1,220 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
+
+# NOTE: This module is intentionally Python 3.9 compatible.
+# - No PEP 604 unions (A | B)
+# - No typing.NotRequired / Required (3.11+) to avoid an extra dependency.
+#   For optional keys, use total=False TypedDicts.
+
+
+ReportType = Literal["aggregate", "forensic", "smtp_tls"]
+
+
+class AggregateReportMetadata(TypedDict):
+    org_name: str
+    org_email: str
+    org_extra_contact_info: Optional[str]
+    report_id: str
+    begin_date: str
+    end_date: str
+    timespan_requires_normalization: bool
+    original_timespan_seconds: int
+    errors: List[str]
+
+
+class AggregatePolicyPublished(TypedDict):
+    domain: str
+    adkim: str
+    aspf: str
+    p: str
+    sp: str
+    pct: str
+    fo: str
+
+
+class IPSourceInfo(TypedDict):
+    ip_address: str
+    country: Optional[str]
+    reverse_dns: Optional[str]
+    base_domain: Optional[str]
+    name: Optional[str]
+    type: Optional[str]
+
+
+class AggregateAlignment(TypedDict):
+    spf: bool
+    dkim: bool
+    dmarc: bool
+
+
+class AggregateIdentifiers(TypedDict):
+    header_from: str
+    envelope_from: Optional[str]
+    envelope_to: Optional[str]
+
+
+class AggregatePolicyOverrideReason(TypedDict):
+    type: Optional[str]
+    comment: Optional[str]
+
+
+class AggregateAuthResultDKIM(TypedDict):
+    domain: str
+    result: str
+    selector: str
+
+
+class AggregateAuthResultSPF(TypedDict):
+    domain: str
+    result: str
+    scope: str
+
+
+class AggregateAuthResults(TypedDict):
+    dkim: List[AggregateAuthResultDKIM]
+    spf: List[AggregateAuthResultSPF]
+
+
+class AggregatePolicyEvaluated(TypedDict):
+    disposition: str
+    dkim: str
+    spf: str
+    policy_override_reasons: List[AggregatePolicyOverrideReason]
+
+
+class AggregateRecord(TypedDict):
+    interval_begin: str
+    interval_end: str
+    source: IPSourceInfo
+    count: int
+    alignment: AggregateAlignment
+    policy_evaluated: AggregatePolicyEvaluated
+    disposition: str
+    identifiers: AggregateIdentifiers
+    auth_results: AggregateAuthResults
+
+
+class AggregateReport(TypedDict):
+    xml_schema: str
+    report_metadata: AggregateReportMetadata
+    policy_published: AggregatePolicyPublished
+    records: List[AggregateRecord]
+
+
+class EmailAddress(TypedDict):
+    display_name: Optional[str]
+    address: str
+    local: Optional[str]
+    domain: Optional[str]
+
+
+class EmailAttachment(TypedDict, total=False):
+    filename: Optional[str]
+    mail_content_type: Optional[str]
+    sha256: Optional[str]
+
+
+ParsedEmail = TypedDict(
+    "ParsedEmail",
+    {
+        # This is a lightly-specified version of mailsuite/mailparser JSON.
+        # It focuses on the fields parsedmarc uses in forensic handling.
+ "headers": Dict[str, Any], + "subject": Optional[str], + "filename_safe_subject": Optional[str], + "date": Optional[str], + "from": EmailAddress, + "to": List[EmailAddress], + "cc": List[EmailAddress], + "bcc": List[EmailAddress], + "attachments": List[EmailAttachment], + "body": Optional[str], + "has_defects": bool, + "defects": Any, + "defects_categories": Any, + }, + total=False, +) + + +class ForensicReport(TypedDict): + feedback_type: Optional[str] + user_agent: Optional[str] + version: Optional[str] + original_envelope_id: Optional[str] + original_mail_from: Optional[str] + original_rcpt_to: Optional[str] + arrival_date: str + arrival_date_utc: str + authentication_results: Optional[str] + delivery_result: Optional[str] + auth_failure: List[str] + authentication_mechanisms: List[str] + dkim_domain: Optional[str] + reported_domain: str + sample_headers_only: bool + source: IPSourceInfo + sample: str + parsed_sample: ParsedEmail + + +class SMTPTLSFailureDetails(TypedDict): + result_type: str + failed_session_count: int + + +class SMTPTLSFailureDetailsOptional(SMTPTLSFailureDetails, total=False): + sending_mta_ip: str + receiving_ip: str + receiving_mx_hostname: str + receiving_mx_helo: str + additional_info_uri: str + failure_reason_code: str + ip_address: str + + +class SMTPTLSPolicySummary(TypedDict): + policy_domain: str + policy_type: str + successful_session_count: int + failed_session_count: int + + +class SMTPTLSPolicy(SMTPTLSPolicySummary, total=False): + policy_strings: List[str] + mx_host_patterns: List[str] + failure_details: List[SMTPTLSFailureDetailsOptional] + + +class SMTPTLSReport(TypedDict): + organization_name: str + begin_date: str + end_date: str + contact_info: Union[str, List[str]] + report_id: str + policies: List[SMTPTLSPolicy] + + +class AggregateParsedReport(TypedDict): + report_type: Literal["aggregate"] + report: AggregateReport + + +class ForensicParsedReport(TypedDict): + report_type: Literal["forensic"] + report: ForensicReport + + +class SMTPTLSParsedReport(TypedDict): + report_type: Literal["smtp_tls"] + report: SMTPTLSReport + + +ParsedReport = Union[AggregateParsedReport, ForensicParsedReport, SMTPTLSParsedReport] + + +class ParsingResults(TypedDict): + aggregate_reports: List[AggregateReport] + forensic_reports: List[ForensicReport] + smtp_tls_reports: List[SMTPTLSReport]