Add tests for extract_report and parse_aggregate_report_xml functions

This commit is contained in:
Sean Whalen
2026-03-26 02:17:42 -04:00
parent a55d9b3010
commit 3bdf795006
+854
View File
@@ -13,6 +13,7 @@ import unittest
from base64 import urlsafe_b64encode
from configparser import ConfigParser
from datetime import datetime, timedelta, timezone
from io import BytesIO
from glob import glob
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
@@ -20,13 +21,18 @@ from typing import cast
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from expiringdict import ExpiringDict
from lxml import etree # type: ignore[import-untyped]
from googleapiclient.errors import HttpError
from httplib2 import Response
from imapclient.exceptions import IMAPClientError
import dns.exception
import requests
import parsedmarc
import parsedmarc.cli
import parsedmarc.webhook
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
from parsedmarc.mail.gmail import GmailConnection
from parsedmarc.mail.gmail import _get_creds
@@ -3175,6 +3181,7 @@ class TestEnvVarConfig(unittest.TestCase):
config.getboolean("general", "debug"),
f"Expected falsy for {false_val!r}",
)
# ============================================================
# New tests for _bucket_interval_by_day
# ============================================================
def testBucketIntervalBeginAfterEnd(self):
@@ -4635,5 +4642,852 @@ class TestEnvVarConfig(unittest.TestCase):
print("Passed!")
class TestExtractReport(unittest.TestCase):
    """Exercise parsedmarc.extract_report() with each accepted input form."""

    # Smallest document shared by the container-format tests.
    _MINIMAL_XML = b'<?xml version="1.0"?><feedback></feedback>'

    def testExtractReportFromBytes(self):
        """Raw XML bytes come back as text containing the root tag."""
        payload = b'<?xml version="1.0"?><feedback><report_metadata></report_metadata></feedback>'
        extracted = parsedmarc.extract_report(payload)
        self.assertIn("<feedback>", extracted)

    def testExtractReportFromBase64Xml(self):
        """A base64 string wrapping XML is decoded."""
        import base64

        encoded = base64.b64encode(self._MINIMAL_XML).decode()
        self.assertIn("<feedback>", parsedmarc.extract_report(encoded))

    def testExtractReportFromGzip(self):
        """Gzip-compressed bytes are decompressed."""
        import gzip

        gzipped = gzip.compress(self._MINIMAL_XML)
        self.assertIn("<feedback>", parsedmarc.extract_report(gzipped))

    def testExtractReportFromZip(self):
        """A zip archive holding one XML member is unpacked."""
        import zipfile

        archive = BytesIO()
        with zipfile.ZipFile(archive, "w") as zf:
            zf.writestr("report.xml", self._MINIMAL_XML)
        # The archive must be closed before its bytes form a complete zip.
        extracted = parsedmarc.extract_report(archive.getvalue())
        self.assertIn("<feedback>", extracted)

    def testExtractReportFromBinaryIO(self):
        """A seekable binary file object is accepted."""
        stream = BytesIO(self._MINIMAL_XML)
        self.assertIn("<feedback>", parsedmarc.extract_report(stream))

    def testExtractReportFromNonSeekableStream(self):
        """A binary stream that reports seekable() == False is accepted."""

        class NonSeekable:
            """Minimal read-only stream with no seek() support."""

            def __init__(self, data):
                self._data = data
                self._pos = 0

            def read(self, n=-1):
                start = self._pos
                if n == -1:
                    self._pos = len(self._data)
                    return self._data[start:]
                self._pos = start + n
                return self._data[start : start + n]

            def seekable(self):
                return False

            def close(self):
                pass

        extracted = parsedmarc.extract_report(NonSeekable(self._MINIMAL_XML))
        self.assertIn("<feedback>", extracted)

    def testExtractReportInvalidContent(self):
        """Bytes that are neither XML nor an archive raise ParserError."""
        self.assertRaises(
            parsedmarc.ParserError,
            parsedmarc.extract_report,
            b"this is not a valid archive",
        )

    def testExtractReportTextModeRaises(self):
        """A stream whose read() yields str raises ParserError."""

        class TextStream:
            """Seekable stream that returns text instead of bytes."""

            def read(self, n=-1):
                return "text data"

            def seekable(self):
                return True

            def seek(self, pos):
                pass

            def close(self):
                pass

        self.assertRaises(
            parsedmarc.ParserError, parsedmarc.extract_report, TextStream()
        )
class TestMalformedXmlRecovery(unittest.TestCase):
    """Cover malformed, bytes and incomplete input to parse_aggregate_report_xml."""

    def testRecoversMalformedXml(self):
        """A document with an unterminated tag parses or is rejected cleanly."""
        # The trailing <broken_tag is never closed, so strict parsing fails.
        broken_doc = """<?xml version="1.0"?>
<feedback>
<report_metadata>
<org_name>example.com</org_name>
<email>dmarc@example.com</email>
<report_id>12345</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published>
<domain>example.com</domain><p>none</p>
</policy_published>
<record>
<row><source_ip>203.0.113.1</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
<broken_tag
</feedback>"""
        # Recovery may or may not succeed; the only requirement is that the
        # sole exception allowed to surface is InvalidAggregateReport.
        try:
            parsed = parsedmarc.parse_aggregate_report_xml(broken_doc, offline=True)
        except parsedmarc.InvalidAggregateReport:
            return
        self.assertIn("report_metadata", parsed)

    def testBytesXmlInput(self):
        """Bytes input is accepted and yields the expected report_id."""
        raw = b"""<?xml version="1.0"?>
<feedback>
<report_metadata>
<org_name>example.com</org_name>
<email>dmarc@example.com</email>
<report_id>test-bytes-input</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published>
<domain>example.com</domain><p>none</p>
</policy_published>
<record>
<row><source_ip>203.0.113.1</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""
        parsed = parsedmarc.parse_aggregate_report_xml(raw, offline=True)
        metadata = parsed["report_metadata"]
        self.assertEqual(metadata["report_id"], "test-bytes-input")

    def testExpatErrorRaises(self):
        """Text that is not XML at all raises InvalidAggregateReport."""
        self.assertRaises(
            parsedmarc.InvalidAggregateReport,
            parsedmarc.parse_aggregate_report_xml,
            "not xml at all {}",
            offline=True,
        )

    def testMissingOrgName(self):
        """A report without org_name raises InvalidAggregateReport."""
        doc = """<?xml version="1.0"?>
<feedback>
<report_metadata>
<email>dmarc@example.com</email>
<report_id>missing-org</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published><domain>example.com</domain><p>none</p></policy_published>
<record>
<row><source_ip>1.2.3.4</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""
        self.assertRaises(
            parsedmarc.InvalidAggregateReport,
            parsedmarc.parse_aggregate_report_xml,
            doc,
            offline=True,
        )
class TestPolicyPublishedEdgeCases(unittest.TestCase):
    """Tests for optional policy_published / report_metadata fields.

    Most tests render VALID_XML_TEMPLATE with extra elements injected and
    check the corresponding key in the parsed report.
    """

    # Placeholders: {tag} makes report_id unique, {extra_metadata} is injected
    # into <report_metadata>, {policy_extra} into <policy_published>.
    VALID_XML_TEMPLATE = """<?xml version="1.0"?>
<feedback>
<report_metadata>
<org_name>example.com</org_name>
<email>dmarc@example.com</email>
<report_id>test-{tag}</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
{extra_metadata}
</report_metadata>
<policy_published>
<domain>example.com</domain><p>reject</p>
{policy_extra}
</policy_published>
<record>
<row><source_ip>203.0.113.1</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""

    def _parse(self, tag="default", policy_extra="", extra_metadata=""):
        """Render the template and parse it offline.

        Args:
            tag: Suffix used to build a unique report_id.
            policy_extra: XML inserted inside <policy_published>.
            extra_metadata: XML inserted inside <report_metadata>.

        Returns:
            The value returned by parsedmarc.parse_aggregate_report_xml.
        """
        xml = self.VALID_XML_TEMPLATE.format(
            tag=tag, policy_extra=policy_extra, extra_metadata=extra_metadata
        )
        return parsedmarc.parse_aggregate_report_xml(xml, offline=True)

    def testPolicyPublishedListHandled(self):
        """np is read from policy_published.

        NOTE(review): despite the name, the template renders a single
        <policy_published> element, so the list-handling branch is presumably
        not exercised by this test - confirm and extend if needed.
        """
        report = self._parse(tag="np", policy_extra="<np>quarantine</np>")
        self.assertEqual(report["policy_published"]["np"], "quarantine")

    def testNpFieldValues(self):
        """np field is parsed correctly for each allowed value"""
        for val in ["none", "quarantine", "reject"]:
            # subTest reports which value failed and keeps checking the rest.
            with self.subTest(np=val):
                report = self._parse(tag=f"np-{val}", policy_extra=f"<np>{val}</np>")
                self.assertEqual(report["policy_published"]["np"], val)

    def testTestingField(self):
        """testing field is parsed correctly for each allowed value"""
        for val in ["y", "n"]:
            with self.subTest(testing=val):
                report = self._parse(
                    tag=f"testing-{val}", policy_extra=f"<testing>{val}</testing>"
                )
                self.assertEqual(report["policy_published"]["testing"], val)

    def testDiscoveryMethodField(self):
        """discovery_method field is parsed correctly for each allowed value"""
        for val in ["psl", "treewalk"]:
            with self.subTest(discovery_method=val):
                report = self._parse(
                    tag=f"disc-{val}",
                    policy_extra=f"<discovery_method>{val}</discovery_method>",
                )
                self.assertEqual(report["policy_published"]["discovery_method"], val)

    def testGeneratorField(self):
        """generator field in report_metadata is parsed"""
        report = self._parse(
            tag="gen", extra_metadata="<generator>TestGen/1.0</generator>"
        )
        self.assertEqual(report["report_metadata"]["generator"], "TestGen/1.0")

    def testPctFieldNone(self):
        """pct defaults to None when absent (DMARCbis)"""
        report = self._parse(tag="no-pct")
        self.assertIsNone(report["policy_published"]["pct"])

    def testFoFieldNone(self):
        """fo defaults to None when absent (DMARCbis)"""
        report = self._parse(tag="no-fo")
        self.assertIsNone(report["policy_published"]["fo"])

    def testReportMetadataErrors(self):
        """A single <error> element is captured in report_metadata errors"""
        report = self._parse(
            tag="errors",
            extra_metadata="<error>DNS timeout</error>",
        )
        self.assertIn("DNS timeout", report["report_metadata"]["errors"])

    def testReportMetadataErrorsList(self):
        """Repeated <error> elements are all captured"""
        report = self._parse(
            tag="errors-list",
            extra_metadata="<error>error1</error><error>error2</error>",
        )
        errors = report["report_metadata"]["errors"]
        self.assertIn("error1", errors)
        self.assertIn("error2", errors)

    def testRecordParseFailureSkipped(self):
        """A record with a bad IP and count does not abort the whole report"""
        xml = """<?xml version="1.0"?>
<feedback>
<report_metadata>
<org_name>example.com</org_name>
<email>dmarc@example.com</email>
<report_id>bad-records</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published><domain>example.com</domain><p>none</p></policy_published>
<record>
<row><source_ip>203.0.113.1</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
<record>
<row><source_ip>bad-ip</source_ip><count>not-a-number</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""
        report = parsedmarc.parse_aggregate_report_xml(xml, offline=True)
        # At least the valid record should be parsed; assertGreaterEqual
        # reports the actual count on failure.
        self.assertGreaterEqual(len(report["records"]), 1)
class TestParseReportFile(unittest.TestCase):
    """Check parse_report_file type detection across input kinds."""

    # Aggregate sample reused by the bytes / file-object / Path variants.
    _AGGREGATE_SAMPLE = "samples/aggregate/!example.com!1538204542!1538463818.xml"

    def _assert_report_type(self, parsed, expected):
        """Assert that the parsed result carries the expected report_type."""
        self.assertEqual(parsed["report_type"], expected)

    def testParseReportFileFromBytes(self):
        """The file's bytes are parsed as an aggregate report."""
        with open(self._AGGREGATE_SAMPLE, "rb") as handle:
            raw = handle.read()
        self._assert_report_type(
            parsedmarc.parse_report_file(raw, offline=True), "aggregate"
        )

    def testParseReportFileFromBinaryIO(self):
        """An open binary file object is parsed as an aggregate report."""
        with open(self._AGGREGATE_SAMPLE, "rb") as handle:
            parsed = parsedmarc.parse_report_file(handle, offline=True)
        self._assert_report_type(parsed, "aggregate")

    def testParseReportFileFromPathlib(self):
        """A pathlib.Path is parsed as an aggregate report."""
        sample = Path(self._AGGREGATE_SAMPLE)
        self._assert_report_type(
            parsedmarc.parse_report_file(sample, offline=True), "aggregate"
        )

    def testParseReportFileSmtpTls(self):
        """The SMTP TLS JSON sample is detected as smtp_tls."""
        parsed = parsedmarc.parse_report_file(
            "samples/smtp_tls/smtp_tls.json", offline=True
        )
        self._assert_report_type(parsed, "smtp_tls")

    def testParseReportFileEmail(self):
        """The .eml failure sample is detected as failure."""
        parsed = parsedmarc.parse_report_file(
            "samples/failure/dmarc_ruf_report_linkedin.eml", offline=True
        )
        self._assert_report_type(parsed, "failure")

    def testParseReportFileInvalid(self):
        """Bytes that are not a report raise ParserError."""
        self.assertRaises(
            parsedmarc.ParserError,
            parsedmarc.parse_report_file,
            b"this is not a report",
            offline=True,
        )
class TestParseReportEmail(unittest.TestCase):
    """Tests for parse_report_email edge cases"""

    def testSmtpTlsEmailReport(self):
        """parse_report_email handles SMTP TLS reports in email format"""
        eml_path = "samples/smtp_tls/google.com_smtp_tls_report.eml"
        with open(eml_path, "rb") as f:
            content = f.read()
        result = parsedmarc.parse_report_email(content, offline=True)
        self.assertEqual(result["report_type"], "smtp_tls")

    def testInvalidEmailRaisesError(self):
        """parse_report_email raises InvalidDMARCReport for a non-DMARC email"""
        # Built line by line so the empty line that separates headers from
        # the body (required by RFC 5322) is explicit and cannot be lost.
        email_str = (
            "From: test@example.com\n"
            "Subject: Hello World\n"
            "Content-Type: text/plain\n"
            "\n"
            "This is not a DMARC report."
        )
        with self.assertRaises(parsedmarc.InvalidDMARCReport):
            parsedmarc.parse_report_email(email_str, offline=True)
class TestFailureReportParsing(unittest.TestCase):
    """Tests for failure report field defaults and edge cases"""

    def _make_feedback_report(self, **overrides):
        """Create a minimal feedback report string.

        Args:
            **overrides: Header name -> value pairs that replace or extend
                the default fields.

        Returns:
            The fields rendered as newline-joined "Name: value" lines.
        """
        fields = {
            "Feedback-Type": "auth-failure",
            "User-Agent": "test/1.0",
            "Version": "1",
            "Original-Mail-From": "sender@example.com",
            # 1 Jan 2024 was a Monday.
            "Arrival-Date": "Mon, 1 Jan 2024 00:00:00 +0000",
            "Source-IP": "203.0.113.1",
            "Reported-Domain": "example.com",
            "Auth-Failure": "dmarc",
        }
        fields.update(overrides)
        return "\n".join(f"{k}: {v}" for k, v in fields.items())

    def _feedback_report_without(self, header):
        """Return the default feedback report with the `header` line removed."""
        prefix = f"{header}:"
        return "\n".join(
            line
            for line in self._make_feedback_report().split("\n")
            if not line.startswith(prefix)
        )

    def _make_sample(self):
        """Return a minimal sample message with a header/body separator."""
        return (
            "From: sender@example.com\n"
            "To: recipient@example.com\n"
            "Subject: Test\n"
            "Date: Mon, 1 Jan 2024 00:00:00 +0000\n"
            "\n"
            "Test body"
        )

    def _default_msg_date(self):
        """Return the message date used when a test does not supply one."""
        return datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)

    def _parse(self, report_str):
        """Parse `report_str` offline with the default sample and date."""
        return parsedmarc.parse_failure_report(
            report_str, self._make_sample(), self._default_msg_date(), offline=True
        )

    def testMissingVersion(self):
        """Missing version defaults to None"""
        report = self._parse(self._feedback_report_without("Version"))
        self.assertIsNone(report["version"])

    def testMissingUserAgent(self):
        """Missing user_agent defaults to None"""
        report = self._parse(self._feedback_report_without("User-Agent"))
        self.assertIsNone(report["user_agent"])

    def testMissingDeliveryResult(self):
        """Missing delivery_result maps to 'other' when field absent"""
        report = self._parse(self._make_feedback_report())
        # The default fields contain no Delivery-Result header; the parser
        # is expected to report "other" in that case.
        self.assertEqual(report["delivery_result"], "other")

    def testDeliveryResultMapped(self):
        """Known delivery_result values are mapped correctly"""
        for val in ["delivered", "spam", "policy", "reject"]:
            # subTest reports which value failed and keeps checking the rest.
            with self.subTest(delivery_result=val):
                report = self._parse(
                    self._make_feedback_report(**{"Delivery-Result": val})
                )
                self.assertEqual(report["delivery_result"], val)

    def testDeliveryResultUnknownMapsToOther(self):
        """Unknown delivery_result maps to 'other'"""
        report = self._parse(
            self._make_feedback_report(**{"Delivery-Result": "unknown-value"})
        )
        self.assertEqual(report["delivery_result"], "other")

    def testIdentityAlignmentNone(self):
        """identity_alignment='none' results in empty auth mechanisms"""
        report = self._parse(
            self._make_feedback_report(**{"Identity-Alignment": "none"})
        )
        self.assertEqual(report["authentication_mechanisms"], [])

    def testIdentityAlignmentMultiple(self):
        """identity_alignment with multiple values is split"""
        report = self._parse(
            self._make_feedback_report(**{"Identity-Alignment": "dkim,spf"})
        )
        self.assertEqual(report["authentication_mechanisms"], ["dkim", "spf"])

    def testMissingReportedDomainFallback(self):
        """Missing reported_domain falls back to sample from domain"""
        report = self._parse(self._feedback_report_without("Reported-Domain"))
        self.assertEqual(report["reported_domain"], "example.com")

    def testMissingArrivalDateWithMsgDate(self):
        """Missing arrival_date uses msg_date fallback"""
        report_str = self._feedback_report_without("Arrival-Date")
        msg_date = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
        report = parsedmarc.parse_failure_report(
            report_str, self._make_sample(), msg_date, offline=True
        )
        self.assertIn("2024-06-15", report["arrival_date"])

    def testMissingArrivalDateNoMsgDateRaises(self):
        """Missing arrival_date with no msg_date raises"""
        report_str = self._feedback_report_without("Arrival-Date")
        with self.assertRaises(parsedmarc.InvalidFailureReport):
            parsedmarc.parse_failure_report(
                report_str, self._make_sample(), None, offline=True
            )
class TestWebhookClient(unittest.TestCase):
    """Check WebhookClient error swallowing and session shutdown."""

    @staticmethod
    def _new_client():
        """Build a client whose three URLs point at an unresolvable host."""
        return parsedmarc.webhook.WebhookClient(
            aggregate_url="http://invalid.test/agg",
            failure_url="http://invalid.test/fail",
            smtp_tls_url="http://invalid.test/tls",
        )

    def testSaveMethodsHandleErrors(self):
        """No save_* method lets an exception from _send_to_webhook escape."""
        webhook = self._new_client()
        failing_send = patch.object(
            webhook, "_send_to_webhook", side_effect=Exception("send failed")
        )
        with failing_send:
            # The test passes as long as none of these calls raises.
            webhook.save_aggregate_report_to_webhook('{"test": true}')
            webhook.save_failure_report_to_webhook('{"test": true}')
            webhook.save_smtp_tls_report_to_webhook('{"test": true}')

    def testSendToWebhookLogsPostErrors(self):
        """_send_to_webhook does not propagate an exception from session.post."""
        webhook = self._new_client()
        failing_post = patch.object(
            webhook.session, "post", side_effect=Exception("connection refused")
        )
        with failing_post:
            # The test passes as long as this call does not raise.
            webhook._send_to_webhook("http://invalid.test/agg", '{"test": true}')

    def testClose(self):
        """close() calls session.close() exactly once."""
        webhook = self._new_client()
        close_spy = MagicMock()
        webhook.session.close = close_spy
        webhook.close()
        close_spy.assert_called_once()
class TestUtilsDnsCaching(unittest.TestCase):
    """Check query_dns cache use and get_reverse_dns failure handling."""

    @staticmethod
    def _new_cache():
        """Return a fresh ExpiringDict with the settings every test uses."""
        return ExpiringDict(max_len=100, max_age_seconds=60)

    def testQueryDnsUsesCacheHit(self):
        """A pre-populated cache entry is returned as the result."""
        cache = self._new_cache()
        cache["example.com_A"] = ["1.2.3.4"]
        answer = parsedmarc.utils.query_dns("example.com", "A", cache=cache)
        self.assertEqual(answer, ["1.2.3.4"])

    def testQueryDnsCachesResult(self):
        """A resolved answer is written back into a non-empty cache."""
        cache = self._new_cache()
        # Seed one entry so the cache object is truthy.
        cache["seed_key"] = ["seed"]

        record = MagicMock()
        record.to_text.return_value = '"1.2.3.4"'
        resolver = MagicMock()
        resolver.resolve.return_value = [record]

        resolver_patch = patch(
            "parsedmarc.utils.dns.resolver.Resolver", return_value=resolver
        )
        with resolver_patch:
            answer = parsedmarc.utils.query_dns(
                "test-cache.example.com", "A", cache=cache
            )
        self.assertEqual(answer, ["1.2.3.4"])
        self.assertIn("test-cache.example.com_A", cache)

    def testReverseDnsReturnsNoneOnFailure(self):
        """get_reverse_dns yields None when query_dns raises DNSException."""
        failing_query = patch(
            "parsedmarc.utils.query_dns",
            side_effect=dns.exception.DNSException("timeout"),
        )
        with failing_query:
            hostname = parsedmarc.utils.get_reverse_dns("203.0.113.1")
        self.assertIsNone(hostname)
class TestUtilsIpDbPaths(unittest.TestCase):
    """Check get_ip_address_country with a bad custom path and the default."""

    def testCustomPathFallsBack(self):
        """A non-existent db_path still yields None or a string."""
        country = parsedmarc.utils.get_ip_address_country(
            "1.1.1.1", db_path="/nonexistent/path.mmdb"
        )
        if country is not None:
            self.assertIsInstance(country, str)

    def testBundledDbWorks(self):
        """Without db_path, 8.8.8.8 resolves to US."""
        country = parsedmarc.utils.get_ip_address_country("8.8.8.8")
        self.assertEqual(country, "US")
class TestUtilsParseEmail(unittest.TestCase):
    """Tests for parse_email edge cases"""

    # The inline messages below are built line by line so the empty line
    # separating headers from the body (required by RFC 5322) is explicit.

    def testMinimalEmail(self):
        """parse_email handles email with minimal headers"""
        email_str = (
            "From: test@example.com\n"
            "Subject: Test\n"
            "\n"
            "Body text"
        )
        result = parsedmarc.utils.parse_email(email_str)
        self.assertEqual(result["subject"], "Test")
        self.assertEqual(result["reply_to"], [])

    def testEmailWithNoSubject(self):
        """parse_email defaults subject to None when missing"""
        email_str = (
            "From: test@example.com\n"
            "To: other@example.com\n"
            "\n"
            "Body"
        )
        result = parsedmarc.utils.parse_email(email_str)
        self.assertIsNone(result["subject"])

    def testEmailBytesInput(self):
        """parse_email handles bytes input"""
        email_bytes = (
            b"From: test@example.com\n"
            b"Subject: Bytes Test\n"
            b"To: other@example.com\n"
            b"\n"
            b"Body"
        )
        result = parsedmarc.utils.parse_email(email_bytes)
        self.assertEqual(result["subject"], "Bytes Test")

    def testEmailWithAttachments(self):
        """parse_email with strip_attachment_payloads removes payloads"""
        from email import encoders
        from email.mime.base import MIMEBase
        from email.mime.multipart import MIMEMultipart
        from email.mime.text import MIMEText

        msg = MIMEMultipart()
        msg["From"] = "test@example.com"
        msg["To"] = "other@example.com"
        msg["Subject"] = "Attachment Test"
        msg.attach(MIMEText("Body text"))

        attachment = MIMEBase("application", "octet-stream")
        attachment.set_payload(b"file content here")
        encoders.encode_base64(attachment)
        attachment.add_header("Content-Disposition", "attachment", filename="test.bin")
        msg.attach(attachment)

        result = parsedmarc.utils.parse_email(
            msg.as_string(), strip_attachment_payloads=True
        )
        for att in result["attachments"]:
            self.assertNotIn("payload", att)
class TestUtilsOutlookMsg(unittest.TestCase):
    """Check Outlook MSG detection and conversion input validation."""

    def testIsOutlookMsg(self):
        """Content starting with the MSG magic bytes is detected."""
        signature = b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1"
        blob = signature + b"\x00" * 100
        self.assertTrue(parsedmarc.utils.is_outlook_msg(blob))

    def testIsNotOutlookMsg(self):
        """Non-MSG bytes and str input are both rejected."""
        for candidate in (b"not an msg file", "string input"):
            self.assertFalse(parsedmarc.utils.is_outlook_msg(candidate))

    def testConvertOutlookMsgInvalidInput(self):
        """convert_outlook_msg raises ValueError for non-MSG bytes."""
        self.assertRaises(
            ValueError, parsedmarc.utils.convert_outlook_msg, b"not an msg file"
        )
class TestUtilsReverseDnsMap(unittest.TestCase):
    """Tests for reverse DNS map loading"""

    def testLoadReverseDnsMapOffline(self):
        """load_reverse_dns_map in offline mode populates the map"""
        rdns_map = {}
        parsedmarc.utils.load_reverse_dns_map(rdns_map, offline=True)
        # assertGreater reports the actual size on failure.
        self.assertGreater(len(rdns_map), 0)

    def testLoadReverseDnsMapLocalOverride(self):
        """load_reverse_dns_map uses local_file_path when provided"""
        # delete=False so the file can be reopened by path after closing.
        with NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
            f.write("base_reverse_dns,name,type\n")
            f.write("custom.example.com,Custom Service,hosting\n")
            path = f.name
        try:
            rdns_map = {}
            parsedmarc.utils.load_reverse_dns_map(
                rdns_map, offline=True, local_file_path=path
            )
            self.assertIn("custom.example.com", rdns_map)
            self.assertEqual(rdns_map["custom.example.com"]["name"], "Custom Service")
        finally:
            os.remove(path)

    def testLoadReverseDnsMapNetworkFailureFallback(self):
        """load_reverse_dns_map still populates the map on a network error"""
        rdns_map = {}
        with patch(
            "parsedmarc.utils.requests.get",
            side_effect=requests.exceptions.ConnectionError("no network"),
        ):
            parsedmarc.utils.load_reverse_dns_map(rdns_map)
        self.assertGreater(len(rdns_map), 0)
class TestSmtpTlsReportErrors(unittest.TestCase):
    """Check that bad SMTP TLS JSON raises InvalidSMTPTLSReport."""

    def testMissingRequiredField(self):
        """A JSON object with only an empty policies list is rejected."""
        incomplete = json.dumps({"policies": []})
        self.assertRaises(
            parsedmarc.InvalidSMTPTLSReport,
            parsedmarc.parse_smtp_tls_report_json,
            incomplete,
        )

    def testInvalidJson(self):
        """Text that is not valid JSON is rejected."""
        self.assertRaises(
            parsedmarc.InvalidSMTPTLSReport,
            parsedmarc.parse_smtp_tls_report_json,
            "not json {{{",
        )
class TestBucketIntervalEdgeCases(unittest.TestCase):
    """Tests for _bucket_interval_by_day edge cases"""

    def testDayCursorAdjustment(self):
        """Bucket counts sum to the total for an interval in a UTC-10 offset"""
        # A fixed UTC-10 offset puts `begin` on a different local day than
        # its UTC date; the original intent was to hit the day_cursor
        # adjustment path. The stdlib fixed-offset timezone replaces
        # pytz.FixedOffset(-600) so the test needs no third-party import.
        tz = timezone(timedelta(hours=-10))
        begin = datetime(2024, 1, 1, 23, 30, 0, tzinfo=timezone.utc).astimezone(tz)
        end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc).astimezone(tz)
        buckets = parsedmarc._bucket_interval_by_day(begin, end, 100)
        total = sum(b["count"] for b in buckets)
        self.assertEqual(total, 100)
class TestGetDmarcReportsFromMbox(unittest.TestCase):
    """Tests for mbox parsing"""

    def testEmptyMbox(self):
        """Empty mbox returns empty results"""
        # delete=False so the file can be reopened by path after closing.
        with NamedTemporaryFile(suffix=".mbox", delete=False) as f:
            f.write(b"")
            path = f.name
        try:
            results = parsedmarc.get_dmarc_reports_from_mbox(path, offline=True)
            self.assertEqual(results["aggregate_reports"], [])
            self.assertEqual(results["failure_reports"], [])
            self.assertEqual(results["smtp_tls_reports"], [])
        finally:
            os.remove(path)

    def testMboxWithAggregateReport(self):
        """Mbox with aggregate report email is parsed"""
        import gzip
        from email.mime.application import MIMEApplication
        from email.mime.multipart import MIMEMultipart

        xml = b"""<?xml version="1.0"?>
<feedback>
<report_metadata>
<org_name>example.com</org_name>
<email>dmarc@example.com</email>
<report_id>mbox-test-123</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published><domain>example.com</domain><p>none</p></policy_published>
<record>
<row><source_ip>203.0.113.1</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""
        compressed = gzip.compress(xml)

        msg = MIMEMultipart()
        msg["From"] = "dmarc@example.com"
        msg["To"] = "postmaster@example.com"
        msg["Subject"] = "DMARC Aggregate Report"
        # 1 Jan 2024 was a Monday.
        msg["Date"] = "Mon, 1 Jan 2024 00:00:00 +0000"
        att = MIMEApplication(compressed, "gzip")
        att.add_header("Content-Disposition", "attachment", filename="report.xml.gz")
        msg.attach(att)

        with NamedTemporaryFile(suffix=".mbox", delete=False, mode="w") as f:
            # mbox format requires "From " line
            f.write("From dmarc@example.com Thu Jan 1 00:00:00 2024\n")
            f.write(msg.as_string())
            f.write("\n")
            path = f.name
        try:
            results = parsedmarc.get_dmarc_reports_from_mbox(path, offline=True)
            # assertGreaterEqual reports the actual count on failure.
            self.assertGreaterEqual(len(results["aggregate_reports"]), 1)
        finally:
            os.remove(path)
class TestPslOverrides(unittest.TestCase):
    """Smoke test for get_base_domain."""

    def testOverrideMatch(self):
        """A subdomain of example.com reduces to example.com."""
        # Only checks that get_base_domain returns the expected base domain
        # for an ordinary name; no specific override entry is asserted.
        base_domain = parsedmarc.utils.get_base_domain("sub.example.com")
        self.assertEqual(base_domain, "example.com")
class TestIsMbox(unittest.TestCase):
    """Check parsedmarc.utils.is_mbox on valid, empty and missing files."""

    def testValidMbox(self):
        """A file with one From_-delimited message is recognised."""
        with NamedTemporaryFile(suffix=".mbox", delete=False, mode="w") as tmp:
            tmp.write(
                "From test@example.com Thu Jan 1 00:00:00 2024\n"
                "Subject: Test\n\nBody\n\n"
            )
            mbox_path = tmp.name
        # Cleanup runs whether or not the assertion passes.
        self.addCleanup(os.remove, mbox_path)
        self.assertTrue(parsedmarc.utils.is_mbox(mbox_path))

    def testEmptyFileNotMbox(self):
        """A zero-length file is not an mbox."""
        with NamedTemporaryFile(suffix=".mbox", delete=False) as tmp:
            mbox_path = tmp.name
        self.addCleanup(os.remove, mbox_path)
        self.assertFalse(parsedmarc.utils.is_mbox(mbox_path))

    def testNonExistentNotMbox(self):
        """A path that does not exist is not an mbox."""
        missing = "/nonexistent/file.mbox"
        self.assertFalse(parsedmarc.utils.is_mbox(missing))
# When executed as a script, run every test in this module with verbose
# per-test output.
if __name__ == "__main__":
    unittest.main(verbosity=2)