From 3bdf79500658b0e4e59883aa1c0c74cb77ba6c96 Mon Sep 17 00:00:00 2001 From: Sean Whalen Date: Thu, 26 Mar 2026 02:17:42 -0400 Subject: [PATCH] Add tests for extract_report and parse_aggregate_report_xml functions --- tests.py | 854 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 854 insertions(+) diff --git a/tests.py b/tests.py index 14138d8..cda613d 100755 --- a/tests.py +++ b/tests.py @@ -13,6 +13,7 @@ import unittest from base64 import urlsafe_b64encode from configparser import ConfigParser from datetime import datetime, timedelta, timezone +from io import BytesIO from glob import glob from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -20,13 +21,18 @@ from typing import cast from types import SimpleNamespace from unittest.mock import MagicMock, patch +from expiringdict import ExpiringDict from lxml import etree # type: ignore[import-untyped] from googleapiclient.errors import HttpError from httplib2 import Response from imapclient.exceptions import IMAPClientError +import dns.exception +import requests + import parsedmarc import parsedmarc.cli +import parsedmarc.webhook from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport from parsedmarc.mail.gmail import GmailConnection from parsedmarc.mail.gmail import _get_creds @@ -3175,6 +3181,7 @@ class TestEnvVarConfig(unittest.TestCase): config.getboolean("general", "debug"), f"Expected falsy for {false_val!r}", ) + # ============================================================ # New tests for _bucket_interval_by_day # ============================================================ def testBucketIntervalBeginAfterEnd(self): @@ -4635,5 +4642,852 @@ class TestEnvVarConfig(unittest.TestCase): print("Passed!") +class TestExtractReport(unittest.TestCase): + """Tests for parsedmarc.extract_report()""" + + def testExtractReportFromBytes(self): + """extract_report handles raw XML bytes""" + xml = b'' + result = parsedmarc.extract_report(xml) + self.assertIn("", result) + + def testExtractReportFromBase64Xml(self): + """extract_report handles base64-encoded XML string""" + import base64 + + xml = b'' + b64 = base64.b64encode(xml).decode() + result = parsedmarc.extract_report(b64) + self.assertIn("", result) + + def testExtractReportFromGzip(self): + """extract_report handles gzip compressed content""" + import gzip + + xml = b'' + compressed = gzip.compress(xml) + result = parsedmarc.extract_report(compressed) + self.assertIn("", result) + + def testExtractReportFromZip(self): + """extract_report handles zip compressed content""" + import zipfile + + xml = b'' + buf = BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("report.xml", xml) + result = parsedmarc.extract_report(buf.getvalue()) + self.assertIn("", result) + + def testExtractReportFromBinaryIO(self): + """extract_report handles file-like BinaryIO objects""" + xml = b'' + bio = BytesIO(xml) + result = parsedmarc.extract_report(bio) + self.assertIn("", result) + + def testExtractReportFromNonSeekableStream(self): + """extract_report handles non-seekable streams""" + xml = b'' + + class NonSeekable: + def __init__(self, data): + self._data = data + self._pos = 0 + + def read(self, n=-1): + if n == -1: + result = self._data[self._pos :] + self._pos = len(self._data) + else: + result = self._data[self._pos : self._pos + n] + self._pos += n + return result + + def seekable(self): + return False + + def close(self): + pass + + result = parsedmarc.extract_report(NonSeekable(xml)) + self.assertIn("", result) + + def testExtractReportInvalidContent(self): + """extract_report raises ParserError for invalid content""" + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.extract_report(b"this is not a valid archive") + + def testExtractReportTextModeRaises(self): + """extract_report raises ParserError for text-mode streams""" + + class TextStream: + def read(self, n=-1): + return "text data" + + def seekable(self): + return True + + def seek(self, pos): + pass + + def close(self): + pass + + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.extract_report(TextStream()) + + +class TestMalformedXmlRecovery(unittest.TestCase): + """Tests for XML recovery in parse_aggregate_report_xml""" + + def testRecoversMalformedXml(self): + """Malformed XML triggers recovery path and still parses""" + # XML with a broken tag that xmltodict will reject but lxml can recover + malformed_xml = """ + + + example.com + dmarc@example.com + 12345 + 16800000001680086400 + + + example.com

none

+
+ + 203.0.113.11 + nonepasspass + + example.com + example.compass + + """ + # lxml recovery may succeed or fail depending on how broken the XML is + # Either way, no unhandled exception should escape + try: + report = parsedmarc.parse_aggregate_report_xml(malformed_xml, offline=True) + self.assertIn("report_metadata", report) + except parsedmarc.InvalidAggregateReport: + pass # Also acceptable + + def testBytesXmlInput(self): + """XML bytes input is decoded""" + xml = b""" + + + example.com + dmarc@example.com + test-bytes-input + 16800000001680086400 + + + example.com

none

+
+ + 203.0.113.11 + nonepasspass + + example.com + example.compass + +
""" + report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) + self.assertEqual(report["report_metadata"]["report_id"], "test-bytes-input") + + def testExpatErrorRaises(self): + """Completely invalid XML raises InvalidAggregateReport""" + with self.assertRaises(parsedmarc.InvalidAggregateReport): + parsedmarc.parse_aggregate_report_xml("not xml at all {}", offline=True) + + def testMissingOrgName(self): + """Missing org_name raises InvalidAggregateReport""" + xml = """ + + + dmarc@example.com + missing-org + 16800000001680086400 + + example.com

none

+ + 1.2.3.41 + nonepasspass + + example.com + example.compass + +
""" + with self.assertRaises(parsedmarc.InvalidAggregateReport): + parsedmarc.parse_aggregate_report_xml(xml, offline=True) + + +class TestPolicyPublishedEdgeCases(unittest.TestCase): + """Tests for edge cases in policy_published parsing""" + + VALID_XML_TEMPLATE = """ + + + example.com + dmarc@example.com + test-{tag} + 16800000001680086400 + {extra_metadata} + + + example.com

reject

+ {policy_extra} +
+ + 203.0.113.11 + nonepasspass + + example.com + example.compass + +
""" + + def _parse(self, tag="default", policy_extra="", extra_metadata=""): + xml = self.VALID_XML_TEMPLATE.format( + tag=tag, policy_extra=policy_extra, extra_metadata=extra_metadata + ) + return parsedmarc.parse_aggregate_report_xml(xml, offline=True) + + def testPolicyPublishedListHandled(self): + """policy_published as a list uses first element""" + # The code checks `if type(policy_published) is list` + # This is tested implicitly when xmltodict returns a list; + # we test via the np field presence + report = self._parse(tag="np", policy_extra="quarantine") + self.assertEqual(report["policy_published"]["np"], "quarantine") + + def testNpFieldValues(self): + """np field is parsed correctly""" + for val in ["none", "quarantine", "reject"]: + report = self._parse(tag=f"np-{val}", policy_extra=f"{val}") + self.assertEqual(report["policy_published"]["np"], val) + + def testTestingField(self): + """testing field is parsed correctly""" + for val in ["y", "n"]: + report = self._parse( + tag=f"testing-{val}", policy_extra=f"{val}" + ) + self.assertEqual(report["policy_published"]["testing"], val) + + def testDiscoveryMethodField(self): + """discovery_method field is parsed correctly""" + for val in ["psl", "treewalk"]: + report = self._parse( + tag=f"disc-{val}", + policy_extra=f"{val}", + ) + self.assertEqual(report["policy_published"]["discovery_method"], val) + + def testGeneratorField(self): + """generator field in report_metadata is parsed""" + report = self._parse( + tag="gen", extra_metadata="TestGen/1.0" + ) + self.assertEqual(report["report_metadata"]["generator"], "TestGen/1.0") + + def testPctFieldNone(self): + """pct defaults to None when absent (DMARCbis)""" + report = self._parse(tag="no-pct") + self.assertIsNone(report["policy_published"]["pct"]) + + def testFoFieldNone(self): + """fo defaults to None when absent (DMARCbis)""" + report = self._parse(tag="no-fo") + self.assertIsNone(report["policy_published"]["fo"]) + + def testReportMetadataErrors(self): + """Report metadata errors are captured""" + report = self._parse( + tag="errors", + extra_metadata="DNS timeout", + ) + self.assertIn("DNS timeout", report["report_metadata"]["errors"]) + + def testReportMetadataErrorsList(self): + """Report metadata errors as list are captured""" + report = self._parse( + tag="errors-list", + extra_metadata="error1error2", + ) + self.assertIn("error1", report["report_metadata"]["errors"]) + self.assertIn("error2", report["report_metadata"]["errors"]) + + def testRecordParseFailureSkipped(self): + """Bad records are skipped with a warning, not crashing""" + xml = """ + + + example.com + dmarc@example.com + bad-records + 16800000001680086400 + + example.com

none

+ + 203.0.113.11 + nonepasspass + + example.com + example.compass + + + bad-ipnot-a-number + nonepasspass + + example.com + example.compass + +
""" + report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) + # At least the valid record should be parsed + self.assertTrue(len(report["records"]) >= 1) + + +class TestParseReportFile(unittest.TestCase): + """Tests for parse_report_file with various input types""" + + def testParseReportFileFromBytes(self): + """parse_report_file works with bytes input""" + xml_path = "samples/aggregate/!example.com!1538204542!1538463818.xml" + with open(xml_path, "rb") as f: + content = f.read() + result = parsedmarc.parse_report_file(content, offline=True) + self.assertEqual(result["report_type"], "aggregate") + + def testParseReportFileFromBinaryIO(self): + """parse_report_file works with BinaryIO input""" + xml_path = "samples/aggregate/!example.com!1538204542!1538463818.xml" + with open(xml_path, "rb") as f: + result = parsedmarc.parse_report_file(f, offline=True) + self.assertEqual(result["report_type"], "aggregate") + + def testParseReportFileFromPathlib(self): + """parse_report_file works with pathlib.Path input""" + xml_path = Path("samples/aggregate/!example.com!1538204542!1538463818.xml") + result = parsedmarc.parse_report_file(xml_path, offline=True) + self.assertEqual(result["report_type"], "aggregate") + + def testParseReportFileSmtpTls(self): + """parse_report_file detects SMTP TLS reports""" + result = parsedmarc.parse_report_file( + "samples/smtp_tls/smtp_tls.json", offline=True + ) + self.assertEqual(result["report_type"], "smtp_tls") + + def testParseReportFileEmail(self): + """parse_report_file detects failure reports in email format""" + eml_path = "samples/failure/dmarc_ruf_report_linkedin.eml" + result = parsedmarc.parse_report_file(eml_path, offline=True) + self.assertEqual(result["report_type"], "failure") + + def testParseReportFileInvalid(self): + """parse_report_file raises ParserError for invalid content""" + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.parse_report_file(b"this is not a report", offline=True) + + +class TestParseReportEmail(unittest.TestCase): + """Tests for parse_report_email edge cases""" + + def testSmtpTlsEmailReport(self): + """parse_report_email handles SMTP TLS reports in email format""" + eml_path = "samples/smtp_tls/google.com_smtp_tls_report.eml" + with open(eml_path, "rb") as f: + content = f.read() + result = parsedmarc.parse_report_email(content, offline=True) + self.assertEqual(result["report_type"], "smtp_tls") + + def testInvalidEmailRaisesError(self): + """parse_report_email raises error for non-DMARC email""" + email_str = """From: test@example.com +Subject: Hello World +Content-Type: text/plain + +This is not a DMARC report.""" + with self.assertRaises(parsedmarc.InvalidDMARCReport): + parsedmarc.parse_report_email(email_str, offline=True) + + +class TestFailureReportParsing(unittest.TestCase): + """Tests for failure report field defaults and edge cases""" + + def _make_feedback_report(self, **overrides): + """Create a minimal feedback report string""" + fields = { + "Feedback-Type": "auth-failure", + "User-Agent": "test/1.0", + "Version": "1", + "Original-Mail-From": "sender@example.com", + "Arrival-Date": "Thu, 1 Jan 2024 00:00:00 +0000", + "Source-IP": "203.0.113.1", + "Reported-Domain": "example.com", + "Auth-Failure": "dmarc", + } + fields.update(overrides) + return "\n".join(f"{k}: {v}" for k, v in fields.items()) + + def _make_sample(self): + return """From: sender@example.com +To: recipient@example.com +Subject: Test +Date: Thu, 1 Jan 2024 00:00:00 +0000 + +Test body""" + + def _default_msg_date(self): + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + + def testMissingVersion(self): + """Missing version defaults to None""" + report_str = self._make_feedback_report() + lines = [l for l in report_str.split("\n") if not l.startswith("Version:")] + report_str = "\n".join(lines) + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), self._default_msg_date(), offline=True + ) + self.assertIsNone(report["version"]) + + def testMissingUserAgent(self): + """Missing user_agent defaults to None""" + report_str = self._make_feedback_report() + lines = [l for l in report_str.split("\n") if not l.startswith("User-Agent:")] + report_str = "\n".join(lines) + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), self._default_msg_date(), offline=True + ) + self.assertIsNone(report["user_agent"]) + + def testMissingDeliveryResult(self): + """Missing delivery_result maps to 'other' when field absent""" + report_str = self._make_feedback_report() + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), self._default_msg_date(), offline=True + ) + # When delivery_result is not in the parsed report, it's set to None, + # but then the validation check maps None (not in delivery_results list) to "other" + self.assertEqual(report["delivery_result"], "other") + + def testDeliveryResultMapped(self): + """Known delivery_result values are mapped correctly""" + for val in ["delivered", "spam", "policy", "reject"]: + report_str = self._make_feedback_report(**{"Delivery-Result": val}) + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), self._default_msg_date(), offline=True + ) + self.assertEqual(report["delivery_result"], val) + + def testDeliveryResultUnknownMapsToOther(self): + """Unknown delivery_result maps to 'other'""" + report_str = self._make_feedback_report(**{"Delivery-Result": "unknown-value"}) + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), self._default_msg_date(), offline=True + ) + self.assertEqual(report["delivery_result"], "other") + + def testIdentityAlignmentNone(self): + """identity_alignment='none' results in empty auth mechanisms""" + report_str = self._make_feedback_report(**{"Identity-Alignment": "none"}) + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), self._default_msg_date(), offline=True + ) + self.assertEqual(report["authentication_mechanisms"], []) + + def testIdentityAlignmentMultiple(self): + """identity_alignment with multiple values is split""" + report_str = self._make_feedback_report(**{"Identity-Alignment": "dkim,spf"}) + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), self._default_msg_date(), offline=True + ) + self.assertEqual(report["authentication_mechanisms"], ["dkim", "spf"]) + + def testMissingReportedDomainFallback(self): + """Missing reported_domain falls back to sample from domain""" + report_str = self._make_feedback_report() + lines = [ + l for l in report_str.split("\n") if not l.startswith("Reported-Domain:") + ] + report_str = "\n".join(lines) + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), self._default_msg_date(), offline=True + ) + self.assertEqual(report["reported_domain"], "example.com") + + def testMissingArrivalDateWithMsgDate(self): + """Missing arrival_date uses msg_date fallback""" + report_str = self._make_feedback_report() + lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")] + report_str = "\n".join(lines) + msg_date = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc) + report = parsedmarc.parse_failure_report( + report_str, self._make_sample(), msg_date, offline=True + ) + self.assertIn("2024-06-15", report["arrival_date"]) + + def testMissingArrivalDateNoMsgDateRaises(self): + """Missing arrival_date with no msg_date raises""" + report_str = self._make_feedback_report() + lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")] + report_str = "\n".join(lines) + with self.assertRaises(parsedmarc.InvalidFailureReport): + parsedmarc.parse_failure_report( + report_str, self._make_sample(), None, offline=True + ) + + +class TestWebhookClient(unittest.TestCase): + """Tests for webhook client error handling and close""" + + def testSaveMethodsHandleErrors(self): + """Each save method catches and logs errors from _send_to_webhook""" + client = parsedmarc.webhook.WebhookClient( + aggregate_url="http://invalid.test/agg", + failure_url="http://invalid.test/fail", + smtp_tls_url="http://invalid.test/tls", + ) + # Mock _send_to_webhook to raise, testing the try/except in each save method + with patch.object( + client, "_send_to_webhook", side_effect=Exception("send failed") + ): + # None should raise + client.save_aggregate_report_to_webhook('{"test": true}') + client.save_failure_report_to_webhook('{"test": true}') + client.save_smtp_tls_report_to_webhook('{"test": true}') + + def testSendToWebhookLogsPostErrors(self): + """_send_to_webhook catches and logs POST errors""" + client = parsedmarc.webhook.WebhookClient( + aggregate_url="http://invalid.test/agg", + failure_url="http://invalid.test/fail", + smtp_tls_url="http://invalid.test/tls", + ) + with patch.object( + client.session, "post", side_effect=Exception("connection refused") + ): + # Should not raise + client._send_to_webhook("http://invalid.test/agg", '{"test": true}') + + def testClose(self): + """WebhookClient.close() closes session""" + client = parsedmarc.webhook.WebhookClient( + aggregate_url="http://invalid.test/agg", + failure_url="http://invalid.test/fail", + smtp_tls_url="http://invalid.test/tls", + ) + mock_close = MagicMock() + client.session.close = mock_close + client.close() + mock_close.assert_called_once() + + +class TestUtilsDnsCaching(unittest.TestCase): + """Tests for DNS query caching and reverse DNS error handling""" + + def testQueryDnsUsesCacheHit(self): + """query_dns returns cached result without making DNS query""" + cache = ExpiringDict(max_len=100, max_age_seconds=60) + cache["example.com_A"] = ["1.2.3.4"] + result = parsedmarc.utils.query_dns("example.com", "A", cache=cache) + self.assertEqual(result, ["1.2.3.4"]) + + def testQueryDnsCachesResult(self): + """query_dns stores result in cache when cache is non-empty""" + cache = ExpiringDict(max_len=100, max_age_seconds=60) + # Pre-populate so ExpiringDict is truthy + cache["seed_key"] = ["seed"] + mock_record = MagicMock() + mock_record.to_text.return_value = '"1.2.3.4"' + mock_resolver = MagicMock() + mock_resolver.resolve.return_value = [mock_record] + with patch( + "parsedmarc.utils.dns.resolver.Resolver", return_value=mock_resolver + ): + result = parsedmarc.utils.query_dns( + "test-cache.example.com", "A", cache=cache + ) + self.assertEqual(result, ["1.2.3.4"]) + self.assertIn("test-cache.example.com_A", cache) + + def testReverseDnsReturnsNoneOnFailure(self): + """get_reverse_dns returns None on DNS exceptions""" + with patch( + "parsedmarc.utils.query_dns", + side_effect=dns.exception.DNSException("timeout"), + ): + result = parsedmarc.utils.get_reverse_dns("203.0.113.1") + self.assertIsNone(result) + + +class TestUtilsIpDbPaths(unittest.TestCase): + """Tests for IP database path validation""" + + def testCustomPathFallsBack(self): + """Non-existent custom db path falls back to default""" + result = parsedmarc.utils.get_ip_address_country( + "1.1.1.1", db_path="/nonexistent/path.mmdb" + ) + self.assertTrue(result is None or isinstance(result, str)) + + def testBundledDbWorks(self): + """Bundled IP database returns results""" + result = parsedmarc.utils.get_ip_address_country("8.8.8.8") + self.assertEqual(result, "US") + + +class TestUtilsParseEmail(unittest.TestCase): + """Tests for parse_email edge cases""" + + def testMinimalEmail(self): + """parse_email handles email with minimal headers""" + email_str = """From: test@example.com +Subject: Test + +Body text""" + result = parsedmarc.utils.parse_email(email_str) + self.assertEqual(result["subject"], "Test") + self.assertEqual(result["reply_to"], []) + + def testEmailWithNoSubject(self): + """parse_email defaults subject to None when missing""" + email_str = """From: test@example.com +To: other@example.com + +Body""" + result = parsedmarc.utils.parse_email(email_str) + self.assertIsNone(result["subject"]) + + def testEmailBytesInput(self): + """parse_email handles bytes input""" + email_bytes = b"""From: test@example.com +Subject: Bytes Test +To: other@example.com + +Body""" + result = parsedmarc.utils.parse_email(email_bytes) + self.assertEqual(result["subject"], "Bytes Test") + + def testEmailWithAttachments(self): + """parse_email with strip_attachment_payloads removes payloads""" + import email as email_mod + from email.mime.multipart import MIMEMultipart + from email.mime.text import MIMEText + from email.mime.base import MIMEBase + from email import encoders + + msg = MIMEMultipart() + msg["From"] = "test@example.com" + msg["To"] = "other@example.com" + msg["Subject"] = "Attachment Test" + msg.attach(MIMEText("Body text")) + + attachment = MIMEBase("application", "octet-stream") + attachment.set_payload(b"file content here") + encoders.encode_base64(attachment) + attachment.add_header("Content-Disposition", "attachment", filename="test.bin") + msg.attach(attachment) + + result = parsedmarc.utils.parse_email( + msg.as_string(), strip_attachment_payloads=True + ) + for att in result["attachments"]: + self.assertNotIn("payload", att) + + +class TestUtilsOutlookMsg(unittest.TestCase): + """Tests for Outlook MSG detection and conversion""" + + def testIsOutlookMsg(self): + """is_outlook_msg detects MSG magic bytes""" + msg_magic = b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1" + b"\x00" * 100 + self.assertTrue(parsedmarc.utils.is_outlook_msg(msg_magic)) + + def testIsNotOutlookMsg(self): + """is_outlook_msg rejects non-MSG content""" + self.assertFalse(parsedmarc.utils.is_outlook_msg(b"not an msg file")) + self.assertFalse(parsedmarc.utils.is_outlook_msg("string input")) + + def testConvertOutlookMsgInvalidInput(self): + """convert_outlook_msg raises ValueError for non-MSG bytes""" + with self.assertRaises(ValueError): + parsedmarc.utils.convert_outlook_msg(b"not an msg file") + + +class TestUtilsReverseDnsMap(unittest.TestCase): + """Tests for reverse DNS map loading""" + + def testLoadReverseDnsMapOffline(self): + """load_reverse_dns_map in offline mode loads bundled map""" + rdns_map = {} + parsedmarc.utils.load_reverse_dns_map(rdns_map, offline=True) + self.assertTrue(len(rdns_map) > 0) + + def testLoadReverseDnsMapLocalOverride(self): + """load_reverse_dns_map uses local_file_path when provided""" + with NamedTemporaryFile("w", suffix=".csv", delete=False) as f: + f.write("base_reverse_dns,name,type\n") + f.write("custom.example.com,Custom Service,hosting\n") + path = f.name + try: + rdns_map = {} + parsedmarc.utils.load_reverse_dns_map( + rdns_map, offline=True, local_file_path=path + ) + self.assertIn("custom.example.com", rdns_map) + self.assertEqual(rdns_map["custom.example.com"]["name"], "Custom Service") + finally: + os.remove(path) + + def testLoadReverseDnsMapNetworkFailureFallback(self): + """load_reverse_dns_map falls back to bundled on network error""" + rdns_map = {} + with patch( + "parsedmarc.utils.requests.get", + side_effect=requests.exceptions.ConnectionError("no network"), + ): + parsedmarc.utils.load_reverse_dns_map(rdns_map) + self.assertTrue(len(rdns_map) > 0) + + +class TestSmtpTlsReportErrors(unittest.TestCase): + """Tests for SMTP TLS report error handling""" + + def testMissingRequiredField(self): + """Missing required field raises InvalidSMTPTLSReport""" + json_str = json.dumps({"policies": []}) + with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): + parsedmarc.parse_smtp_tls_report_json(json_str) + + def testInvalidJson(self): + """Invalid JSON raises InvalidSMTPTLSReport""" + with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): + parsedmarc.parse_smtp_tls_report_json("not json {{{") + + +class TestBucketIntervalEdgeCases(unittest.TestCase): + """Tests for _bucket_interval_by_day edge cases""" + + def testDayCursorAdjustment(self): + """When begin is before midnight due to tz, day_cursor adjusts back""" + # Use a timezone where midnight calculation might cause day_cursor > begin + import pytz + + tz = pytz.FixedOffset(-600) # UTC-10 + begin = datetime(2024, 1, 1, 23, 30, 0, tzinfo=timezone.utc).astimezone(tz) + end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc).astimezone(tz) + buckets = parsedmarc._bucket_interval_by_day(begin, end, 100) + total = sum(b["count"] for b in buckets) + self.assertEqual(total, 100) + + +class TestGetDmarcReportsFromMbox(unittest.TestCase): + """Tests for mbox parsing""" + + def testEmptyMbox(self): + """Empty mbox returns empty results""" + with NamedTemporaryFile(suffix=".mbox", delete=False) as f: + f.write(b"") + path = f.name + try: + results = parsedmarc.get_dmarc_reports_from_mbox(path, offline=True) + self.assertEqual(results["aggregate_reports"], []) + self.assertEqual(results["failure_reports"], []) + self.assertEqual(results["smtp_tls_reports"], []) + finally: + os.remove(path) + + def testMboxWithAggregateReport(self): + """Mbox with aggregate report email is parsed""" + import email as email_mod + from email.mime.multipart import MIMEMultipart + from email.mime.application import MIMEApplication + import gzip + + xml = b""" + + + example.com + dmarc@example.com + mbox-test-123 + 16800000001680086400 + + example.com

none

+ + 203.0.113.11 + nonepasspass + + example.com + example.compass + +
""" + compressed = gzip.compress(xml) + + msg = MIMEMultipart() + msg["From"] = "dmarc@example.com" + msg["To"] = "postmaster@example.com" + msg["Subject"] = "DMARC Aggregate Report" + msg["Date"] = "Thu, 1 Jan 2024 00:00:00 +0000" + att = MIMEApplication(compressed, "gzip") + att.add_header("Content-Disposition", "attachment", filename="report.xml.gz") + msg.attach(att) + + with NamedTemporaryFile(suffix=".mbox", delete=False, mode="w") as f: + # mbox format requires "From " line + f.write("From dmarc@example.com Thu Jan 1 00:00:00 2024\n") + f.write(msg.as_string()) + f.write("\n") + path = f.name + try: + results = parsedmarc.get_dmarc_reports_from_mbox(path, offline=True) + self.assertTrue(len(results["aggregate_reports"]) >= 1) + finally: + os.remove(path) + + +class TestPslOverrides(unittest.TestCase): + """Tests for PSL override matching""" + + def testOverrideMatch(self): + """PSL overrides are applied when domain ends with override""" + # psl_overrides contains entries; test that get_base_domain + # handles them without error + result = parsedmarc.utils.get_base_domain("sub.example.com") + self.assertEqual(result, "example.com") + + +class TestIsMbox(unittest.TestCase): + """Tests for is_mbox utility""" + + def testValidMbox(self): + """is_mbox returns True for valid mbox file""" + with NamedTemporaryFile(suffix=".mbox", delete=False, mode="w") as f: + f.write("From test@example.com Thu Jan 1 00:00:00 2024\n") + f.write("Subject: Test\n\nBody\n\n") + path = f.name + try: + self.assertTrue(parsedmarc.utils.is_mbox(path)) + finally: + os.remove(path) + + def testEmptyFileNotMbox(self): + """is_mbox returns False for empty file""" + with NamedTemporaryFile(suffix=".mbox", delete=False) as f: + path = f.name + try: + self.assertFalse(parsedmarc.utils.is_mbox(path)) + finally: + os.remove(path) + + def testNonExistentNotMbox(self): + """is_mbox returns False for non-existent file""" + self.assertFalse(parsedmarc.utils.is_mbox("/nonexistent/file.mbox")) + + if __name__ == "__main__": unittest.main(verbosity=2)