diff --git a/tests.py b/tests.py
index 14138d8..cda613d 100755
--- a/tests.py
+++ b/tests.py
@@ -13,6 +13,7 @@ import unittest
from base64 import urlsafe_b64encode
from configparser import ConfigParser
from datetime import datetime, timedelta, timezone
+from io import BytesIO
from glob import glob
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
@@ -20,13 +21,18 @@ from typing import cast
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
+from expiringdict import ExpiringDict
from lxml import etree # type: ignore[import-untyped]
from googleapiclient.errors import HttpError
from httplib2 import Response
from imapclient.exceptions import IMAPClientError
+import dns.exception
+import requests
+
import parsedmarc
import parsedmarc.cli
+import parsedmarc.webhook
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
from parsedmarc.mail.gmail import GmailConnection
from parsedmarc.mail.gmail import _get_creds
@@ -3175,6 +3181,7 @@ class TestEnvVarConfig(unittest.TestCase):
config.getboolean("general", "debug"),
f"Expected falsy for {false_val!r}",
)
+
# ============================================================ # New tests for _bucket_interval_by_day
# ============================================================
def testBucketIntervalBeginAfterEnd(self):
@@ -4635,5 +4642,852 @@ class TestEnvVarConfig(unittest.TestCase):
print("Passed!")
+class TestExtractReport(unittest.TestCase):
+    """Tests for parsedmarc.extract_report().
+
+    The happy-path tests all feed the same minimal XML document through a
+    different container (raw bytes, base64 text, gzip, zip, streams) and
+    check that the root element appears in the extracted text.
+    """
+
+    # Minimal payload: an XML declaration followed by an empty root element.
+    XML = b'<?xml version="1.0"?><feedback></feedback>'
+    # Substring every successful extraction must contain.
+    MARKER = "<feedback>"
+
+    def testExtractReportFromBytes(self):
+        """extract_report handles raw XML bytes"""
+        result = parsedmarc.extract_report(self.XML)
+        self.assertIn(self.MARKER, result)
+
+    def testExtractReportFromBase64Xml(self):
+        """extract_report handles base64-encoded XML string"""
+        import base64
+
+        b64 = base64.b64encode(self.XML).decode()
+        result = parsedmarc.extract_report(b64)
+        self.assertIn(self.MARKER, result)
+
+    def testExtractReportFromGzip(self):
+        """extract_report handles gzip compressed content"""
+        import gzip
+
+        compressed = gzip.compress(self.XML)
+        result = parsedmarc.extract_report(compressed)
+        self.assertIn(self.MARKER, result)
+
+    def testExtractReportFromZip(self):
+        """extract_report handles zip compressed content"""
+        import zipfile
+
+        buf = BytesIO()
+        with zipfile.ZipFile(buf, "w") as zf:
+            zf.writestr("report.xml", self.XML)
+        result = parsedmarc.extract_report(buf.getvalue())
+        self.assertIn(self.MARKER, result)
+
+    def testExtractReportFromBinaryIO(self):
+        """extract_report handles file-like BinaryIO objects"""
+        bio = BytesIO(self.XML)
+        result = parsedmarc.extract_report(bio)
+        self.assertIn(self.MARKER, result)
+
+    def testExtractReportFromNonSeekableStream(self):
+        """extract_report handles non-seekable streams"""
+
+        class NonSeekable:
+            """Read-only byte stream that reports itself as non-seekable."""
+
+            def __init__(self, data):
+                self._data = data
+                self._pos = 0
+
+            def read(self, n=-1):
+                # A missing or negative size means "read to the end".
+                if n is None or n < 0:
+                    end = len(self._data)
+                else:
+                    # Clamp so the position never runs past the data.
+                    end = min(self._pos + n, len(self._data))
+                chunk = self._data[self._pos : end]
+                self._pos = end
+                return chunk
+
+            def seekable(self):
+                return False
+
+            def close(self):
+                pass
+
+        result = parsedmarc.extract_report(NonSeekable(self.XML))
+        self.assertIn(self.MARKER, result)
+
+    def testExtractReportInvalidContent(self):
+        """extract_report raises ParserError for invalid content"""
+        with self.assertRaises(parsedmarc.ParserError):
+            parsedmarc.extract_report(b"this is not a valid archive")
+
+    def testExtractReportTextModeRaises(self):
+        """extract_report raises ParserError for text-mode streams"""
+
+        class TextStream:
+            """Seekable stream whose read() yields str instead of bytes."""
+
+            def read(self, n=-1):
+                return "text data"
+
+            def seekable(self):
+                return True
+
+            def seek(self, pos):
+                pass
+
+            def close(self):
+                pass
+
+        with self.assertRaises(parsedmarc.ParserError):
+            parsedmarc.extract_report(TextStream())
+
+
+class TestMalformedXmlRecovery(unittest.TestCase):
+    """Tests for XML recovery in parse_aggregate_report_xml"""
+
+    def testRecoversMalformedXml(self):
+        """Malformed XML either parses after recovery or is rejected cleanly"""
+        # The closing </feedback> tag is deliberately missing, so the document
+        # is not well-formed.
+        malformed_xml = """<?xml version="1.0"?>
+<feedback>
+  <report_metadata>
+    <org_name>example.com</org_name>
+    <email>dmarc@example.com</email>
+    <report_id>12345</report_id>
+    <date_range><begin>1680000000</begin><end>1680086400</end></date_range>
+  </report_metadata>
+  <policy_published>
+    <domain>example.com</domain><p>none</p>
+  </policy_published>
+  <record>
+    <row><source_ip>203.0.113.1</source_ip><count>1</count>
+      <policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
+    </row>
+    <identifiers><header_from>example.com</header_from></identifiers>
+    <auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
+  </record>
+"""
+        # Recovery may succeed or fail depending on how broken the XML is.
+        # Either way, no other exception type should escape.
+        try:
+            report = parsedmarc.parse_aggregate_report_xml(
+                malformed_xml, offline=True
+            )
+        except parsedmarc.InvalidAggregateReport:
+            return  # Rejecting the document is also acceptable
+        self.assertIn("report_metadata", report)
+
+    def testBytesXmlInput(self):
+        """XML bytes input is decoded"""
+        xml = b"""<?xml version="1.0"?>
+<feedback>
+  <report_metadata>
+    <org_name>example.com</org_name>
+    <email>dmarc@example.com</email>
+    <report_id>test-bytes-input</report_id>
+    <date_range><begin>1680000000</begin><end>1680086400</end></date_range>
+  </report_metadata>
+  <policy_published>
+    <domain>example.com</domain><p>none</p>
+  </policy_published>
+  <record>
+    <row><source_ip>203.0.113.1</source_ip><count>1</count>
+      <policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
+    </row>
+    <identifiers><header_from>example.com</header_from></identifiers>
+    <auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
+  </record>
+</feedback>"""
+        report = parsedmarc.parse_aggregate_report_xml(xml, offline=True)
+        self.assertEqual(report["report_metadata"]["report_id"], "test-bytes-input")
+
+    def testExpatErrorRaises(self):
+        """Completely invalid XML raises InvalidAggregateReport"""
+        with self.assertRaises(parsedmarc.InvalidAggregateReport):
+            parsedmarc.parse_aggregate_report_xml("not xml at all {}", offline=True)
+
+    def testMissingOrgName(self):
+        """Missing org_name raises InvalidAggregateReport"""
+        # report_metadata intentionally has no <org_name> element.
+        xml = """<?xml version="1.0"?>
+<feedback>
+  <report_metadata>
+    <email>dmarc@example.com</email>
+    <report_id>missing-org</report_id>
+    <date_range><begin>1680000000</begin><end>1680086400</end></date_range>
+  </report_metadata>
+  <policy_published><domain>example.com</domain><p>none</p></policy_published>
+  <record>
+    <row><source_ip>1.2.3.4</source_ip><count>1</count>
+      <policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
+    </row>
+    <identifiers><header_from>example.com</header_from></identifiers>
+    <auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
+  </record>
+</feedback>"""
+        with self.assertRaises(parsedmarc.InvalidAggregateReport):
+            parsedmarc.parse_aggregate_report_xml(xml, offline=True)
+
+
+class TestPolicyPublishedEdgeCases(unittest.TestCase):
+    """Tests for edge cases in policy_published and report_metadata parsing"""
+
+    # Aggregate report skeleton with three str.format() placeholders:
+    #   {tag}            - suffix that keeps each report_id distinct
+    #   {extra_metadata} - extra child elements for <report_metadata>
+    #   {policy_extra}   - extra child elements for <policy_published>
+    VALID_XML_TEMPLATE = """<?xml version="1.0"?>
+<feedback>
+  <report_metadata>
+    <org_name>example.com</org_name>
+    <email>dmarc@example.com</email>
+    <report_id>test-{tag}</report_id>
+    <date_range><begin>1680000000</begin><end>1680086400</end></date_range>
+    {extra_metadata}
+  </report_metadata>
+  <policy_published>
+    <domain>example.com</domain><p>reject</p>
+    {policy_extra}
+  </policy_published>
+  <record>
+    <row><source_ip>203.0.113.1</source_ip><count>1</count>
+      <policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
+    </row>
+    <identifiers><header_from>example.com</header_from></identifiers>
+    <auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
+  </record>
+</feedback>"""
+
+    def _parse(self, tag="default", policy_extra="", extra_metadata=""):
+        """Fill the template and parse it offline."""
+        xml = self.VALID_XML_TEMPLATE.format(
+            tag=tag, policy_extra=policy_extra, extra_metadata=extra_metadata
+        )
+        return parsedmarc.parse_aggregate_report_xml(xml, offline=True)
+
+    def testPolicyPublishedListHandled(self):
+        """np is read from a single policy_published element"""
+        # NOTE(review): this does not feed a list-valued policy_published to
+        # the parser; it only checks that np is present in the parsed policy.
+        report = self._parse(tag="np", policy_extra="<np>quarantine</np>")
+        self.assertEqual(report["policy_published"]["np"], "quarantine")
+
+    def testNpFieldValues(self):
+        """np field is parsed correctly"""
+        for val in ["none", "quarantine", "reject"]:
+            with self.subTest(np=val):
+                report = self._parse(
+                    tag=f"np-{val}", policy_extra=f"<np>{val}</np>"
+                )
+                self.assertEqual(report["policy_published"]["np"], val)
+
+    def testTestingField(self):
+        """testing field is parsed correctly"""
+        for val in ["y", "n"]:
+            with self.subTest(testing=val):
+                report = self._parse(
+                    tag=f"testing-{val}",
+                    policy_extra=f"<testing>{val}</testing>",
+                )
+                self.assertEqual(report["policy_published"]["testing"], val)
+
+    def testDiscoveryMethodField(self):
+        """discovery_method field is parsed correctly"""
+        for val in ["psl", "treewalk"]:
+            with self.subTest(discovery_method=val):
+                report = self._parse(
+                    tag=f"disc-{val}",
+                    policy_extra=f"<discovery_method>{val}</discovery_method>",
+                )
+                self.assertEqual(
+                    report["policy_published"]["discovery_method"], val
+                )
+
+    def testGeneratorField(self):
+        """generator field in report_metadata is parsed"""
+        report = self._parse(
+            tag="gen", extra_metadata="<generator>TestGen/1.0</generator>"
+        )
+        self.assertEqual(report["report_metadata"]["generator"], "TestGen/1.0")
+
+    def testPctFieldNone(self):
+        """pct defaults to None when absent (DMARCbis)"""
+        report = self._parse(tag="no-pct")
+        self.assertIsNone(report["policy_published"]["pct"])
+
+    def testFoFieldNone(self):
+        """fo defaults to None when absent (DMARCbis)"""
+        report = self._parse(tag="no-fo")
+        self.assertIsNone(report["policy_published"]["fo"])
+
+    def testReportMetadataErrors(self):
+        """Report metadata errors are captured"""
+        report = self._parse(
+            tag="errors",
+            extra_metadata="<error>DNS timeout</error>",
+        )
+        self.assertIn("DNS timeout", report["report_metadata"]["errors"])
+
+    def testReportMetadataErrorsList(self):
+        """Report metadata errors as list are captured"""
+        report = self._parse(
+            tag="errors-list",
+            extra_metadata="<error>error1</error><error>error2</error>",
+        )
+        self.assertIn("error1", report["report_metadata"]["errors"])
+        self.assertIn("error2", report["report_metadata"]["errors"])
+
+    def testRecordParseFailureSkipped(self):
+        """Bad records are skipped with a warning, not crashing"""
+        # The second <record> carries an invalid source_ip and count.
+        xml = """<?xml version="1.0"?>
+<feedback>
+  <report_metadata>
+    <org_name>example.com</org_name>
+    <email>dmarc@example.com</email>
+    <report_id>bad-records</report_id>
+    <date_range><begin>1680000000</begin><end>1680086400</end></date_range>
+  </report_metadata>
+  <policy_published><domain>example.com</domain><p>none</p></policy_published>
+  <record>
+    <row><source_ip>203.0.113.1</source_ip><count>1</count>
+      <policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
+    </row>
+    <identifiers><header_from>example.com</header_from></identifiers>
+    <auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
+  </record>
+  <record>
+    <row><source_ip>bad-ip</source_ip><count>not-a-number</count>
+      <policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
+    </row>
+    <identifiers><header_from>example.com</header_from></identifiers>
+    <auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
+  </record>
+</feedback>"""
+        report = parsedmarc.parse_aggregate_report_xml(xml, offline=True)
+        # At least the valid record should be parsed
+        self.assertGreaterEqual(len(report["records"]), 1)
+
+
+class TestParseReportFile(unittest.TestCase):
+    """Exercises parse_report_file() with each kind of input it is given here."""
+
+    # Sample aggregate report shared by the bytes / stream / Path tests.
+    AGGREGATE_SAMPLE = "samples/aggregate/!example.com!1538204542!1538463818.xml"
+
+    def _assert_report_type(self, parsed, expected):
+        """Check the ``report_type`` entry of a parse result."""
+        self.assertEqual(parsed["report_type"], expected)
+
+    def testParseReportFileFromBytes(self):
+        """Raw file bytes are accepted"""
+        with open(self.AGGREGATE_SAMPLE, "rb") as handle:
+            raw = handle.read()
+        parsed = parsedmarc.parse_report_file(raw, offline=True)
+        self._assert_report_type(parsed, "aggregate")
+
+    def testParseReportFileFromBinaryIO(self):
+        """An open binary file object is accepted"""
+        with open(self.AGGREGATE_SAMPLE, "rb") as handle:
+            parsed = parsedmarc.parse_report_file(handle, offline=True)
+        self._assert_report_type(parsed, "aggregate")
+
+    def testParseReportFileFromPathlib(self):
+        """A pathlib.Path is accepted"""
+        sample = Path(self.AGGREGATE_SAMPLE)
+        parsed = parsedmarc.parse_report_file(sample, offline=True)
+        self._assert_report_type(parsed, "aggregate")
+
+    def testParseReportFileSmtpTls(self):
+        """An SMTP TLS JSON sample is reported as smtp_tls"""
+        sample = "samples/smtp_tls/smtp_tls.json"
+        parsed = parsedmarc.parse_report_file(sample, offline=True)
+        self._assert_report_type(parsed, "smtp_tls")
+
+    def testParseReportFileEmail(self):
+        """A failure report stored as .eml is reported as failure"""
+        sample = "samples/failure/dmarc_ruf_report_linkedin.eml"
+        parsed = parsedmarc.parse_report_file(sample, offline=True)
+        self._assert_report_type(parsed, "failure")
+
+    def testParseReportFileInvalid(self):
+        """Content that is not a report raises ParserError"""
+        junk = b"this is not a report"
+        with self.assertRaises(parsedmarc.ParserError):
+            parsedmarc.parse_report_file(junk, offline=True)
+
+
+class TestParseReportEmail(unittest.TestCase):
+    """Edge cases for parse_report_email()."""
+
+    def testSmtpTlsEmailReport(self):
+        """An SMTP TLS report delivered as an email is reported as smtp_tls"""
+        sample = "samples/smtp_tls/google.com_smtp_tls_report.eml"
+        with open(sample, "rb") as handle:
+            raw_message = handle.read()
+        parsed = parsedmarc.parse_report_email(raw_message, offline=True)
+        self.assertEqual(parsed["report_type"], "smtp_tls")
+
+    def testInvalidEmailRaisesError(self):
+        """An ordinary email raises InvalidDMARCReport"""
+        plain_message = """From: test@example.com
+Subject: Hello World
+Content-Type: text/plain
+
+This is not a DMARC report."""
+        self.assertRaises(
+            parsedmarc.InvalidDMARCReport,
+            parsedmarc.parse_report_email,
+            plain_message,
+            offline=True,
+        )
+
+
+class TestFailureReportParsing(unittest.TestCase):
+    """Tests for failure report field defaults and edge cases"""
+
+    def _make_feedback_report(self, **overrides):
+        """Create a minimal feedback report string.
+
+        Keyword arguments add header fields or replace the defaults below.
+        """
+        fields = {
+            "Feedback-Type": "auth-failure",
+            "User-Agent": "test/1.0",
+            "Version": "1",
+            "Original-Mail-From": "sender@example.com",
+            "Arrival-Date": "Thu, 1 Jan 2024 00:00:00 +0000",
+            "Source-IP": "203.0.113.1",
+            "Reported-Domain": "example.com",
+            "Auth-Failure": "dmarc",
+        }
+        fields.update(overrides)
+        return "\n".join(f"{k}: {v}" for k, v in fields.items())
+
+    @staticmethod
+    def _without_field(report_str, field_name):
+        """Return report_str with every ``field_name:`` header line removed."""
+        prefix = f"{field_name}:"
+        return "\n".join(
+            line for line in report_str.split("\n") if not line.startswith(prefix)
+        )
+
+    def _make_sample(self):
+        """Return the sample message passed alongside each feedback report."""
+        return """From: sender@example.com
+To: recipient@example.com
+Subject: Test
+Date: Thu, 1 Jan 2024 00:00:00 +0000
+
+Test body"""
+
+    def _default_msg_date(self):
+        """Return the message date used when a test does not care about it."""
+        return datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
+
+    def _parse(self, report_str, msg_date):
+        """Parse report_str offline together with the shared sample message."""
+        return parsedmarc.parse_failure_report(
+            report_str, self._make_sample(), msg_date, offline=True
+        )
+
+    def testMissingVersion(self):
+        """Missing version defaults to None"""
+        report_str = self._without_field(self._make_feedback_report(), "Version")
+        report = self._parse(report_str, self._default_msg_date())
+        self.assertIsNone(report["version"])
+
+    def testMissingUserAgent(self):
+        """Missing user_agent defaults to None"""
+        report_str = self._without_field(
+            self._make_feedback_report(), "User-Agent"
+        )
+        report = self._parse(report_str, self._default_msg_date())
+        self.assertIsNone(report["user_agent"])
+
+    def testMissingDeliveryResult(self):
+        """Missing delivery_result maps to 'other' when field absent"""
+        report = self._parse(
+            self._make_feedback_report(), self._default_msg_date()
+        )
+        # The default fields contain no Delivery-Result header at all; the
+        # parsed report is still expected to carry "other".
+        self.assertEqual(report["delivery_result"], "other")
+
+    def testDeliveryResultMapped(self):
+        """Known delivery_result values are mapped correctly"""
+        for val in ["delivered", "spam", "policy", "reject"]:
+            with self.subTest(delivery_result=val):
+                report_str = self._make_feedback_report(**{"Delivery-Result": val})
+                report = self._parse(report_str, self._default_msg_date())
+                self.assertEqual(report["delivery_result"], val)
+
+    def testDeliveryResultUnknownMapsToOther(self):
+        """Unknown delivery_result maps to 'other'"""
+        report_str = self._make_feedback_report(
+            **{"Delivery-Result": "unknown-value"}
+        )
+        report = self._parse(report_str, self._default_msg_date())
+        self.assertEqual(report["delivery_result"], "other")
+
+    def testIdentityAlignmentNone(self):
+        """identity_alignment='none' results in empty auth mechanisms"""
+        report_str = self._make_feedback_report(**{"Identity-Alignment": "none"})
+        report = self._parse(report_str, self._default_msg_date())
+        self.assertEqual(report["authentication_mechanisms"], [])
+
+    def testIdentityAlignmentMultiple(self):
+        """identity_alignment with multiple values is split"""
+        report_str = self._make_feedback_report(
+            **{"Identity-Alignment": "dkim,spf"}
+        )
+        report = self._parse(report_str, self._default_msg_date())
+        self.assertEqual(report["authentication_mechanisms"], ["dkim", "spf"])
+
+    def testMissingReportedDomainFallback(self):
+        """Missing reported_domain falls back to sample from domain"""
+        report_str = self._without_field(
+            self._make_feedback_report(), "Reported-Domain"
+        )
+        report = self._parse(report_str, self._default_msg_date())
+        self.assertEqual(report["reported_domain"], "example.com")
+
+    def testMissingArrivalDateWithMsgDate(self):
+        """Missing arrival_date uses msg_date fallback"""
+        report_str = self._without_field(
+            self._make_feedback_report(), "Arrival-Date"
+        )
+        msg_date = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
+        report = self._parse(report_str, msg_date)
+        self.assertIn("2024-06-15", report["arrival_date"])
+
+    def testMissingArrivalDateNoMsgDateRaises(self):
+        """Missing arrival_date with no msg_date raises"""
+        report_str = self._without_field(
+            self._make_feedback_report(), "Arrival-Date"
+        )
+        with self.assertRaises(parsedmarc.InvalidFailureReport):
+            self._parse(report_str, None)
+
+
+class TestWebhookClient(unittest.TestCase):
+    """Tests for webhook client error handling and close"""
+
+    def setUp(self):
+        """Build a client pointed at unreachable URLs; no request is sent."""
+        self.client = parsedmarc.webhook.WebhookClient(
+            aggregate_url="http://invalid.test/agg",
+            failure_url="http://invalid.test/fail",
+            smtp_tls_url="http://invalid.test/tls",
+        )
+        # Capture the real bound method now so the session is released even
+        # when a test swaps session.close for a mock.
+        self.addCleanup(self.client.session.close)
+
+    def testSaveMethodsHandleErrors(self):
+        """Each save method catches and logs errors from _send_to_webhook"""
+        # Mock _send_to_webhook to raise, testing the try/except in each
+        # save method.
+        with patch.object(
+            self.client, "_send_to_webhook", side_effect=Exception("send failed")
+        ) as send_mock:
+            # None of these calls should raise
+            self.client.save_aggregate_report_to_webhook('{"test": true}')
+            self.client.save_failure_report_to_webhook('{"test": true}')
+            self.client.save_smtp_tls_report_to_webhook('{"test": true}')
+        # Without this check the test would pass even if a save method never
+        # reached the failing sender.
+        self.assertEqual(send_mock.call_count, 3)
+
+    def testSendToWebhookLogsPostErrors(self):
+        """_send_to_webhook catches and logs POST errors"""
+        with patch.object(
+            self.client.session, "post", side_effect=Exception("connection refused")
+        ) as post_mock:
+            # Should not raise
+            self.client._send_to_webhook(
+                "http://invalid.test/agg", '{"test": true}'
+            )
+        post_mock.assert_called_once()
+
+    def testClose(self):
+        """WebhookClient.close() closes session"""
+        mock_close = MagicMock()
+        self.client.session.close = mock_close
+        self.client.close()
+        mock_close.assert_called_once()
+
+
+class TestUtilsDnsCaching(unittest.TestCase):
+    """DNS lookup caching and reverse-DNS failure handling in parsedmarc.utils."""
+
+    def testQueryDnsUsesCacheHit(self):
+        """A pre-seeded cache entry is returned as the lookup result"""
+        dns_cache = ExpiringDict(max_len=100, max_age_seconds=60)
+        dns_cache["example.com_A"] = ["1.2.3.4"]
+        answers = parsedmarc.utils.query_dns("example.com", "A", cache=dns_cache)
+        self.assertEqual(answers, ["1.2.3.4"])
+
+    def testQueryDnsCachesResult(self):
+        """A resolved answer is written back into a non-empty cache"""
+        dns_cache = ExpiringDict(max_len=100, max_age_seconds=60)
+        # Seed one entry first so the ExpiringDict is truthy.
+        dns_cache["seed_key"] = ["seed"]
+
+        fake_record = MagicMock()
+        fake_record.to_text.return_value = '"1.2.3.4"'
+        fake_resolver = MagicMock()
+        fake_resolver.resolve.return_value = [fake_record]
+
+        resolver_patch = patch(
+            "parsedmarc.utils.dns.resolver.Resolver", return_value=fake_resolver
+        )
+        with resolver_patch:
+            answers = parsedmarc.utils.query_dns(
+                "test-cache.example.com", "A", cache=dns_cache
+            )
+        self.assertEqual(answers, ["1.2.3.4"])
+        self.assertIn("test-cache.example.com_A", dns_cache)
+
+    def testReverseDnsReturnsNoneOnFailure(self):
+        """get_reverse_dns yields None when query_dns raises DNSException"""
+        failing_lookup = patch(
+            "parsedmarc.utils.query_dns",
+            side_effect=dns.exception.DNSException("timeout"),
+        )
+        with failing_lookup:
+            hostname = parsedmarc.utils.get_reverse_dns("203.0.113.1")
+        self.assertIsNone(hostname)
+
+
+class TestUtilsIpDbPaths(unittest.TestCase):
+    """Country lookups with a custom and with the default IP database path."""
+
+    def testCustomPathFallsBack(self):
+        """A db_path that does not exist still yields None or a string"""
+        country = parsedmarc.utils.get_ip_address_country(
+            "1.1.1.1", db_path="/nonexistent/path.mmdb"
+        )
+        acceptable = country is None or isinstance(country, str)
+        self.assertTrue(acceptable)
+
+    def testBundledDbWorks(self):
+        """Looking up 8.8.8.8 without a db_path returns "US\""""
+        country = parsedmarc.utils.get_ip_address_country("8.8.8.8")
+        self.assertEqual(country, "US")
+
+
+class TestUtilsParseEmail(unittest.TestCase):
+    """Tests for parse_email edge cases"""
+
+    def testMinimalEmail(self):
+        """parse_email handles email with minimal headers"""
+        email_str = """From: test@example.com
+Subject: Test
+
+Body text"""
+        result = parsedmarc.utils.parse_email(email_str)
+        self.assertEqual(result["subject"], "Test")
+        self.assertEqual(result["reply_to"], [])
+
+    def testEmailWithNoSubject(self):
+        """parse_email defaults subject to None when missing"""
+        email_str = """From: test@example.com
+To: other@example.com
+
+Body"""
+        result = parsedmarc.utils.parse_email(email_str)
+        self.assertIsNone(result["subject"])
+
+    def testEmailBytesInput(self):
+        """parse_email handles bytes input"""
+        email_bytes = b"""From: test@example.com
+Subject: Bytes Test
+To: other@example.com
+
+Body"""
+        result = parsedmarc.utils.parse_email(email_bytes)
+        self.assertEqual(result["subject"], "Bytes Test")
+
+    def testEmailWithAttachments(self):
+        """parse_email with strip_attachment_payloads removes payloads"""
+        from email import encoders
+        from email.mime.base import MIMEBase
+        from email.mime.multipart import MIMEMultipart
+        from email.mime.text import MIMEText
+
+        # Multipart message: one text part plus one base64 binary attachment.
+        msg = MIMEMultipart()
+        msg["From"] = "test@example.com"
+        msg["To"] = "other@example.com"
+        msg["Subject"] = "Attachment Test"
+        msg.attach(MIMEText("Body text"))
+
+        attachment = MIMEBase("application", "octet-stream")
+        attachment.set_payload(b"file content here")
+        encoders.encode_base64(attachment)
+        attachment.add_header(
+            "Content-Disposition", "attachment", filename="test.bin"
+        )
+        msg.attach(attachment)
+
+        result = parsedmarc.utils.parse_email(
+            msg.as_string(), strip_attachment_payloads=True
+        )
+        for att in result["attachments"]:
+            self.assertNotIn("payload", att)
+
+
+class TestUtilsOutlookMsg(unittest.TestCase):
+    """Outlook MSG detection and conversion helpers in parsedmarc.utils."""
+
+    def testIsOutlookMsg(self):
+        """Content starting with the MSG magic bytes is detected"""
+        magic = b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1"
+        padding = b"\x00" * 100
+        detected = parsedmarc.utils.is_outlook_msg(magic + padding)
+        self.assertTrue(detected)
+
+    def testIsNotOutlookMsg(self):
+        """Neither arbitrary bytes nor a str is detected as MSG"""
+        for candidate in (b"not an msg file", "string input"):
+            self.assertFalse(parsedmarc.utils.is_outlook_msg(candidate))
+
+    def testConvertOutlookMsgInvalidInput(self):
+        """Converting bytes that are not an MSG file raises ValueError"""
+        not_msg = b"not an msg file"
+        self.assertRaises(ValueError, parsedmarc.utils.convert_outlook_msg, not_msg)
+
+
+class TestUtilsReverseDnsMap(unittest.TestCase):
+    """Loading the reverse DNS map into a caller-supplied dict."""
+
+    def testLoadReverseDnsMapOffline(self):
+        """With offline=True the target dict ends up non-empty"""
+        loaded = {}
+        parsedmarc.utils.load_reverse_dns_map(loaded, offline=True)
+        self.assertGreater(len(loaded), 0)
+
+    def testLoadReverseDnsMapLocalOverride(self):
+        """Rows from a local_file_path CSV appear in the target dict"""
+        with NamedTemporaryFile("w", suffix=".csv", delete=False) as csv_file:
+            csv_path = csv_file.name
+            csv_file.write("base_reverse_dns,name,type\n")
+            csv_file.write("custom.example.com,Custom Service,hosting\n")
+        try:
+            loaded = {}
+            parsedmarc.utils.load_reverse_dns_map(
+                loaded, offline=True, local_file_path=csv_path
+            )
+            self.assertIn("custom.example.com", loaded)
+            entry = loaded["custom.example.com"]
+            self.assertEqual(entry["name"], "Custom Service")
+        finally:
+            os.remove(csv_path)
+
+    def testLoadReverseDnsMapNetworkFailureFallback(self):
+        """A ConnectionError from requests.get still leaves a non-empty map"""
+        loaded = {}
+        offline_network = patch(
+            "parsedmarc.utils.requests.get",
+            side_effect=requests.exceptions.ConnectionError("no network"),
+        )
+        with offline_network:
+            parsedmarc.utils.load_reverse_dns_map(loaded)
+        self.assertGreater(len(loaded), 0)
+
+
+class TestSmtpTlsReportErrors(unittest.TestCase):
+    """Rejection of unusable input by parse_smtp_tls_report_json()."""
+
+    def _assert_rejected(self, payload):
+        """Expect InvalidSMTPTLSReport when payload is parsed."""
+        with self.assertRaises(parsedmarc.InvalidSMTPTLSReport):
+            parsedmarc.parse_smtp_tls_report_json(payload)
+
+    def testMissingRequiredField(self):
+        """A JSON object holding only an empty policies list is rejected"""
+        incomplete = json.dumps({"policies": []})
+        self._assert_rejected(incomplete)
+
+    def testInvalidJson(self):
+        """Text that is not JSON is rejected"""
+        self._assert_rejected("not json {{{")
+
+
+class TestBucketIntervalEdgeCases(unittest.TestCase):
+    """Tests for _bucket_interval_by_day edge cases"""
+
+    def testDayCursorAdjustment(self):
+        """Bucket counts still sum to the total for a UTC-10 interval"""
+        # A fixed UTC-10 offset moves `begin` (23:30 UTC) to 13:30 local time
+        # on the same calendar day. The stdlib timezone is used so the test
+        # needs no third-party timezone package.
+        tz = timezone(timedelta(hours=-10))
+        begin = datetime(2024, 1, 1, 23, 30, 0, tzinfo=timezone.utc).astimezone(tz)
+        end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc).astimezone(tz)
+        buckets = parsedmarc._bucket_interval_by_day(begin, end, 100)
+        # However the interval is split, no messages may be lost or invented.
+        total = sum(b["count"] for b in buckets)
+        self.assertEqual(total, 100)
+
+
+class TestGetDmarcReportsFromMbox(unittest.TestCase):
+    """Tests for mbox parsing"""
+
+    def testEmptyMbox(self):
+        """Empty mbox returns empty results"""
+        with NamedTemporaryFile(suffix=".mbox", delete=False) as f:
+            f.write(b"")
+            path = f.name
+        try:
+            results = parsedmarc.get_dmarc_reports_from_mbox(path, offline=True)
+            self.assertEqual(results["aggregate_reports"], [])
+            self.assertEqual(results["failure_reports"], [])
+            self.assertEqual(results["smtp_tls_reports"], [])
+        finally:
+            os.remove(path)
+
+    def testMboxWithAggregateReport(self):
+        """Mbox with aggregate report email is parsed"""
+        import gzip
+        from email.mime.application import MIMEApplication
+        from email.mime.multipart import MIMEMultipart
+
+        xml = b"""<?xml version="1.0"?>
+<feedback>
+  <report_metadata>
+    <org_name>example.com</org_name>
+    <email>dmarc@example.com</email>
+    <report_id>mbox-test-123</report_id>
+    <date_range><begin>1680000000</begin><end>1680086400</end></date_range>
+  </report_metadata>
+  <policy_published><domain>example.com</domain><p>none</p></policy_published>
+  <record>
+    <row><source_ip>203.0.113.1</source_ip><count>1</count>
+      <policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
+    </row>
+    <identifiers><header_from>example.com</header_from></identifiers>
+    <auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
+  </record>
+</feedback>"""
+        compressed = gzip.compress(xml)
+
+        # Wrap the gzipped report in a multipart email as an attachment.
+        msg = MIMEMultipart()
+        msg["From"] = "dmarc@example.com"
+        msg["To"] = "postmaster@example.com"
+        msg["Subject"] = "DMARC Aggregate Report"
+        msg["Date"] = "Thu, 1 Jan 2024 00:00:00 +0000"
+        att = MIMEApplication(compressed, "gzip")
+        att.add_header(
+            "Content-Disposition", "attachment", filename="report.xml.gz"
+        )
+        msg.attach(att)
+
+        with NamedTemporaryFile(suffix=".mbox", delete=False, mode="w") as f:
+            # mbox format requires "From " line
+            f.write("From dmarc@example.com Thu Jan 1 00:00:00 2024\n")
+            f.write(msg.as_string())
+            f.write("\n")
+            path = f.name
+        try:
+            results = parsedmarc.get_dmarc_reports_from_mbox(path, offline=True)
+            self.assertGreaterEqual(len(results["aggregate_reports"]), 1)
+        finally:
+            os.remove(path)
+
+
+class TestPslOverrides(unittest.TestCase):
+    """Base-domain lookup through parsedmarc.utils.get_base_domain()."""
+
+    def testOverrideMatch(self):
+        """sub.example.com resolves to example.com"""
+        # NOTE(review): this input presumably matches no PSL override entry,
+        # so it only shows the lookup runs cleanly - confirm.
+        base_domain = parsedmarc.utils.get_base_domain("sub.example.com")
+        self.assertEqual(base_domain, "example.com")
+
+
+class TestIsMbox(unittest.TestCase):
+    """Checks parsedmarc.utils.is_mbox() against three kinds of path."""
+
+    def testValidMbox(self):
+        """A file starting with a "From " line and one message is an mbox"""
+        with NamedTemporaryFile(suffix=".mbox", delete=False, mode="w") as handle:
+            mbox_path = handle.name
+            handle.write("From test@example.com Thu Jan 1 00:00:00 2024\n")
+            handle.write("Subject: Test\n\nBody\n\n")
+        try:
+            detected = parsedmarc.utils.is_mbox(mbox_path)
+            self.assertTrue(detected)
+        finally:
+            os.remove(mbox_path)
+
+    def testEmptyFileNotMbox(self):
+        """A zero-length file is not an mbox"""
+        with NamedTemporaryFile(suffix=".mbox", delete=False) as handle:
+            empty_path = handle.name
+        try:
+            detected = parsedmarc.utils.is_mbox(empty_path)
+            self.assertFalse(detected)
+        finally:
+            os.remove(empty_path)
+
+    def testNonExistentNotMbox(self):
+        """A path that does not exist is not an mbox"""
+        missing_path = "/nonexistent/file.mbox"
+        self.assertFalse(parsedmarc.utils.is_mbox(missing_path))
+
+
if __name__ == "__main__":
unittest.main(verbosity=2)