mirror of
https://github.com/domainaware/parsedmarc.git
synced 2026-05-02 10:05:25 +00:00
Add tests for extract_report and parse_aggregate_report_xml functions
This commit is contained in:
@@ -13,6 +13,7 @@ import unittest
|
||||
from base64 import urlsafe_b64encode
|
||||
from configparser import ConfigParser
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from io import BytesIO
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
@@ -20,13 +21,18 @@ from typing import cast
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from expiringdict import ExpiringDict
|
||||
from lxml import etree # type: ignore[import-untyped]
|
||||
from googleapiclient.errors import HttpError
|
||||
from httplib2 import Response
|
||||
from imapclient.exceptions import IMAPClientError
|
||||
|
||||
import dns.exception
|
||||
import requests
|
||||
|
||||
import parsedmarc
|
||||
import parsedmarc.cli
|
||||
import parsedmarc.webhook
|
||||
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
|
||||
from parsedmarc.mail.gmail import GmailConnection
|
||||
from parsedmarc.mail.gmail import _get_creds
|
||||
@@ -3175,6 +3181,7 @@ class TestEnvVarConfig(unittest.TestCase):
|
||||
config.getboolean("general", "debug"),
|
||||
f"Expected falsy for {false_val!r}",
|
||||
)
|
||||
|
||||
# ============================================================
# New tests for _bucket_interval_by_day
|
||||
# ============================================================
|
||||
def testBucketIntervalBeginAfterEnd(self):
|
||||
@@ -4635,5 +4642,852 @@ class TestEnvVarConfig(unittest.TestCase):
|
||||
print("Passed!")
|
||||
|
||||
|
||||
class TestExtractReport(unittest.TestCase):
    """Exercise parsedmarc.extract_report() with each supported input form."""

    def testExtractReportFromBytes(self):
        """Raw XML bytes come back as text containing the root element."""
        payload = b'<?xml version="1.0"?><feedback><report_metadata></report_metadata></feedback>'
        self.assertIn("<feedback>", parsedmarc.extract_report(payload))

    def testExtractReportFromBase64Xml(self):
        """A base64-encoded XML string is decoded before extraction."""
        import base64

        encoded = base64.b64encode(
            b'<?xml version="1.0"?><feedback></feedback>'
        ).decode()
        extracted = parsedmarc.extract_report(encoded)
        self.assertIn("<feedback>", extracted)

    def testExtractReportFromGzip(self):
        """Gzip-compressed bytes are decompressed."""
        import gzip

        gzipped = gzip.compress(b'<?xml version="1.0"?><feedback></feedback>')
        extracted = parsedmarc.extract_report(gzipped)
        self.assertIn("<feedback>", extracted)

    def testExtractReportFromZip(self):
        """A zip archive holding one XML member is unpacked."""
        import zipfile

        archive = BytesIO()
        with zipfile.ZipFile(archive, "w") as zf:
            zf.writestr("report.xml", b'<?xml version="1.0"?><feedback></feedback>')
        # Read the buffer only after the archive is closed so that the
        # zip central directory has been written.
        extracted = parsedmarc.extract_report(archive.getvalue())
        self.assertIn("<feedback>", extracted)

    def testExtractReportFromBinaryIO(self):
        """A binary file-like object is accepted as input."""
        stream = BytesIO(b'<?xml version="1.0"?><feedback></feedback>')
        extracted = parsedmarc.extract_report(stream)
        self.assertIn("<feedback>", extracted)

    def testExtractReportFromNonSeekableStream(self):
        """A stream that reports seekable() == False is still readable."""

        class NonSeekable:
            # Minimal read-only stream exposing read/seekable/close only.
            def __init__(self, data):
                self._buffer = data
                self._offset = 0

            def read(self, n=-1):
                start = self._offset
                if n == -1:
                    self._offset = len(self._buffer)
                    return self._buffer[start:]
                self._offset = start + n
                return self._buffer[start : start + n]

            def seekable(self):
                return False

            def close(self):
                pass

        source = NonSeekable(b'<?xml version="1.0"?><feedback></feedback>')
        extracted = parsedmarc.extract_report(source)
        self.assertIn("<feedback>", extracted)

    def testExtractReportInvalidContent(self):
        """Bytes that are neither XML nor a known archive raise ParserError."""
        garbage = b"this is not a valid archive"
        with self.assertRaises(parsedmarc.ParserError):
            parsedmarc.extract_report(garbage)

    def testExtractReportTextModeRaises(self):
        """A stream whose read() returns str raises ParserError."""

        class TextStream:
            # read() yields str rather than bytes, mimicking a text-mode file.
            def read(self, n=-1):
                return "text data"

            def seekable(self):
                return True

            def seek(self, pos):
                pass

            def close(self):
                pass

        text_stream = TextStream()
        with self.assertRaises(parsedmarc.ParserError):
            parsedmarc.extract_report(text_stream)
|
||||
|
||||
|
||||
class TestMalformedXmlRecovery(unittest.TestCase):
    """Exercise error handling and recovery in parse_aggregate_report_xml."""

    def testRecoversMalformedXml(self):
        """Broken XML either parses via recovery or raises InvalidAggregateReport."""
        # The unterminated <broken_tag makes the document not well-formed.
        broken_document = """<?xml version="1.0"?>
<feedback>
<report_metadata>
<org_name>example.com</org_name>
<email>dmarc@example.com</email>
<report_id>12345</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published>
<domain>example.com</domain><p>none</p>
</policy_published>
<record>
<row><source_ip>203.0.113.1</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
<broken_tag
</feedback>"""
        # Both outcomes are accepted: a recovered report, or the library's
        # own InvalidAggregateReport. Any other exception fails the test.
        try:
            parsed = parsedmarc.parse_aggregate_report_xml(
                broken_document, offline=True
            )
        except parsedmarc.InvalidAggregateReport:
            return
        else:
            self.assertIn("report_metadata", parsed)

    def testBytesXmlInput(self):
        """XML supplied as bytes is parsed like a str document."""
        document = b"""<?xml version="1.0"?>
<feedback>
<report_metadata>
<org_name>example.com</org_name>
<email>dmarc@example.com</email>
<report_id>test-bytes-input</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published>
<domain>example.com</domain><p>none</p>
</policy_published>
<record>
<row><source_ip>203.0.113.1</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""
        parsed = parsedmarc.parse_aggregate_report_xml(document, offline=True)
        metadata = parsed["report_metadata"]
        self.assertEqual(metadata["report_id"], "test-bytes-input")

    def testExpatErrorRaises(self):
        """Input that is not XML at all raises InvalidAggregateReport."""
        not_xml = "not xml at all {}"
        with self.assertRaises(parsedmarc.InvalidAggregateReport):
            parsedmarc.parse_aggregate_report_xml(not_xml, offline=True)

    def testMissingOrgName(self):
        """A report without org_name raises InvalidAggregateReport."""
        document = """<?xml version="1.0"?>
<feedback>
<report_metadata>
<email>dmarc@example.com</email>
<report_id>missing-org</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published><domain>example.com</domain><p>none</p></policy_published>
<record>
<row><source_ip>1.2.3.4</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""
        with self.assertRaises(parsedmarc.InvalidAggregateReport):
            parsedmarc.parse_aggregate_report_xml(document, offline=True)
|
||||
|
||||
|
||||
class TestPolicyPublishedEdgeCases(unittest.TestCase):
|
||||
"""Tests for edge cases in policy_published parsing"""
|
||||
|
||||
VALID_XML_TEMPLATE = """<?xml version="1.0"?>
|
||||
<feedback>
|
||||
<report_metadata>
|
||||
<org_name>example.com</org_name>
|
||||
<email>dmarc@example.com</email>
|
||||
<report_id>test-{tag}</report_id>
|
||||
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
|
||||
{extra_metadata}
|
||||
</report_metadata>
|
||||
<policy_published>
|
||||
<domain>example.com</domain><p>reject</p>
|
||||
{policy_extra}
|
||||
</policy_published>
|
||||
<record>
|
||||
<row><source_ip>203.0.113.1</source_ip><count>1</count>
|
||||
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
|
||||
</row>
|
||||
<identifiers><header_from>example.com</header_from></identifiers>
|
||||
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
|
||||
</record>
|
||||
</feedback>"""
|
||||
|
||||
def _parse(self, tag="default", policy_extra="", extra_metadata=""):
|
||||
xml = self.VALID_XML_TEMPLATE.format(
|
||||
tag=tag, policy_extra=policy_extra, extra_metadata=extra_metadata
|
||||
)
|
||||
return parsedmarc.parse_aggregate_report_xml(xml, offline=True)
|
||||
|
||||
def testPolicyPublishedListHandled(self):
|
||||
"""policy_published as a list uses first element"""
|
||||
# The code checks `if type(policy_published) is list`
|
||||
# This is tested implicitly when xmltodict returns a list;
|
||||
# we test via the np field presence
|
||||
report = self._parse(tag="np", policy_extra="<np>quarantine</np>")
|
||||
self.assertEqual(report["policy_published"]["np"], "quarantine")
|
||||
|
||||
def testNpFieldValues(self):
|
||||
"""np field is parsed correctly"""
|
||||
for val in ["none", "quarantine", "reject"]:
|
||||
report = self._parse(tag=f"np-{val}", policy_extra=f"<np>{val}</np>")
|
||||
self.assertEqual(report["policy_published"]["np"], val)
|
||||
|
||||
def testTestingField(self):
|
||||
"""testing field is parsed correctly"""
|
||||
for val in ["y", "n"]:
|
||||
report = self._parse(
|
||||
tag=f"testing-{val}", policy_extra=f"<testing>{val}</testing>"
|
||||
)
|
||||
self.assertEqual(report["policy_published"]["testing"], val)
|
||||
|
||||
def testDiscoveryMethodField(self):
|
||||
"""discovery_method field is parsed correctly"""
|
||||
for val in ["psl", "treewalk"]:
|
||||
report = self._parse(
|
||||
tag=f"disc-{val}",
|
||||
policy_extra=f"<discovery_method>{val}</discovery_method>",
|
||||
)
|
||||
self.assertEqual(report["policy_published"]["discovery_method"], val)
|
||||
|
||||
def testGeneratorField(self):
|
||||
"""generator field in report_metadata is parsed"""
|
||||
report = self._parse(
|
||||
tag="gen", extra_metadata="<generator>TestGen/1.0</generator>"
|
||||
)
|
||||
self.assertEqual(report["report_metadata"]["generator"], "TestGen/1.0")
|
||||
|
||||
def testPctFieldNone(self):
|
||||
"""pct defaults to None when absent (DMARCbis)"""
|
||||
report = self._parse(tag="no-pct")
|
||||
self.assertIsNone(report["policy_published"]["pct"])
|
||||
|
||||
def testFoFieldNone(self):
|
||||
"""fo defaults to None when absent (DMARCbis)"""
|
||||
report = self._parse(tag="no-fo")
|
||||
self.assertIsNone(report["policy_published"]["fo"])
|
||||
|
||||
def testReportMetadataErrors(self):
|
||||
"""Report metadata errors are captured"""
|
||||
report = self._parse(
|
||||
tag="errors",
|
||||
extra_metadata="<error>DNS timeout</error>",
|
||||
)
|
||||
self.assertIn("DNS timeout", report["report_metadata"]["errors"])
|
||||
|
||||
def testReportMetadataErrorsList(self):
|
||||
"""Report metadata errors as list are captured"""
|
||||
report = self._parse(
|
||||
tag="errors-list",
|
||||
extra_metadata="<error>error1</error><error>error2</error>",
|
||||
)
|
||||
self.assertIn("error1", report["report_metadata"]["errors"])
|
||||
self.assertIn("error2", report["report_metadata"]["errors"])
|
||||
|
||||
def testRecordParseFailureSkipped(self):
|
||||
"""Bad records are skipped with a warning, not crashing"""
|
||||
xml = """<?xml version="1.0"?>
|
||||
<feedback>
|
||||
<report_metadata>
|
||||
<org_name>example.com</org_name>
|
||||
<email>dmarc@example.com</email>
|
||||
<report_id>bad-records</report_id>
|
||||
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
|
||||
</report_metadata>
|
||||
<policy_published><domain>example.com</domain><p>none</p></policy_published>
|
||||
<record>
|
||||
<row><source_ip>203.0.113.1</source_ip><count>1</count>
|
||||
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
|
||||
</row>
|
||||
<identifiers><header_from>example.com</header_from></identifiers>
|
||||
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
|
||||
</record>
|
||||
<record>
|
||||
<row><source_ip>bad-ip</source_ip><count>not-a-number</count>
|
||||
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
|
||||
</row>
|
||||
<identifiers><header_from>example.com</header_from></identifiers>
|
||||
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
|
||||
</record>
|
||||
</feedback>"""
|
||||
report = parsedmarc.parse_aggregate_report_xml(xml, offline=True)
|
||||
# At least the valid record should be parsed
|
||||
self.assertTrue(len(report["records"]) >= 1)
|
||||
|
||||
|
||||
class TestParseReportFile(unittest.TestCase):
    """Exercise parse_report_file with each accepted input type."""

    # Shared aggregate sample, relative to the working directory.
    _AGGREGATE_SAMPLE = "samples/aggregate/!example.com!1538204542!1538463818.xml"

    def testParseReportFileFromBytes(self):
        """File contents passed as bytes are detected as an aggregate report."""
        raw = Path(self._AGGREGATE_SAMPLE).read_bytes()
        parsed = parsedmarc.parse_report_file(raw, offline=True)
        self.assertEqual(parsed["report_type"], "aggregate")

    def testParseReportFileFromBinaryIO(self):
        """An open binary file handle is accepted."""
        with open(self._AGGREGATE_SAMPLE, "rb") as handle:
            parsed = parsedmarc.parse_report_file(handle, offline=True)
        self.assertEqual(parsed["report_type"], "aggregate")

    def testParseReportFileFromPathlib(self):
        """A pathlib.Path is accepted."""
        sample = Path(self._AGGREGATE_SAMPLE)
        parsed = parsedmarc.parse_report_file(sample, offline=True)
        self.assertEqual(parsed["report_type"], "aggregate")

    def testParseReportFileSmtpTls(self):
        """An SMTP TLS JSON sample is reported as type smtp_tls."""
        sample = "samples/smtp_tls/smtp_tls.json"
        parsed = parsedmarc.parse_report_file(sample, offline=True)
        self.assertEqual(parsed["report_type"], "smtp_tls")

    def testParseReportFileEmail(self):
        """A failure report stored as .eml is reported as type failure."""
        parsed = parsedmarc.parse_report_file(
            "samples/failure/dmarc_ruf_report_linkedin.eml", offline=True
        )
        self.assertEqual(parsed["report_type"], "failure")

    def testParseReportFileInvalid(self):
        """Content that is not any known report raises ParserError."""
        junk = b"this is not a report"
        with self.assertRaises(parsedmarc.ParserError):
            parsedmarc.parse_report_file(junk, offline=True)
|
||||
|
||||
|
||||
class TestParseReportEmail(unittest.TestCase):
    """Exercise parse_report_email edge cases."""

    def testSmtpTlsEmailReport(self):
        """An SMTP TLS report delivered as an email is typed smtp_tls."""
        raw_message = Path(
            "samples/smtp_tls/google.com_smtp_tls_report.eml"
        ).read_bytes()
        parsed = parsedmarc.parse_report_email(raw_message, offline=True)
        self.assertEqual(parsed["report_type"], "smtp_tls")

    def testInvalidEmailRaisesError(self):
        """An ordinary plain-text email raises InvalidDMARCReport."""
        # Headers must start at column 0 to be read as headers.
        plain_message = """From: test@example.com
Subject: Hello World
Content-Type: text/plain

This is not a DMARC report."""
        with self.assertRaises(parsedmarc.InvalidDMARCReport):
            parsedmarc.parse_report_email(plain_message, offline=True)
|
||||
|
||||
|
||||
class TestFailureReportParsing(unittest.TestCase):
|
||||
"""Tests for failure report field defaults and edge cases"""
|
||||
|
||||
def _make_feedback_report(self, **overrides):
|
||||
"""Create a minimal feedback report string"""
|
||||
fields = {
|
||||
"Feedback-Type": "auth-failure",
|
||||
"User-Agent": "test/1.0",
|
||||
"Version": "1",
|
||||
"Original-Mail-From": "sender@example.com",
|
||||
"Arrival-Date": "Thu, 1 Jan 2024 00:00:00 +0000",
|
||||
"Source-IP": "203.0.113.1",
|
||||
"Reported-Domain": "example.com",
|
||||
"Auth-Failure": "dmarc",
|
||||
}
|
||||
fields.update(overrides)
|
||||
return "\n".join(f"{k}: {v}" for k, v in fields.items())
|
||||
|
||||
def _make_sample(self):
|
||||
return """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test
|
||||
Date: Thu, 1 Jan 2024 00:00:00 +0000
|
||||
|
||||
Test body"""
|
||||
|
||||
def _default_msg_date(self):
|
||||
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
def testMissingVersion(self):
|
||||
"""Missing version defaults to None"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [l for l in report_str.split("\n") if not l.startswith("Version:")]
|
||||
report_str = "\n".join(lines)
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
)
|
||||
self.assertIsNone(report["version"])
|
||||
|
||||
def testMissingUserAgent(self):
|
||||
"""Missing user_agent defaults to None"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [l for l in report_str.split("\n") if not l.startswith("User-Agent:")]
|
||||
report_str = "\n".join(lines)
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
)
|
||||
self.assertIsNone(report["user_agent"])
|
||||
|
||||
def testMissingDeliveryResult(self):
|
||||
"""Missing delivery_result maps to 'other' when field absent"""
|
||||
report_str = self._make_feedback_report()
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
)
|
||||
# When delivery_result is not in the parsed report, it's set to None,
|
||||
# but then the validation check maps None (not in delivery_results list) to "other"
|
||||
self.assertEqual(report["delivery_result"], "other")
|
||||
|
||||
def testDeliveryResultMapped(self):
|
||||
"""Known delivery_result values are mapped correctly"""
|
||||
for val in ["delivered", "spam", "policy", "reject"]:
|
||||
report_str = self._make_feedback_report(**{"Delivery-Result": val})
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
)
|
||||
self.assertEqual(report["delivery_result"], val)
|
||||
|
||||
def testDeliveryResultUnknownMapsToOther(self):
|
||||
"""Unknown delivery_result maps to 'other'"""
|
||||
report_str = self._make_feedback_report(**{"Delivery-Result": "unknown-value"})
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
)
|
||||
self.assertEqual(report["delivery_result"], "other")
|
||||
|
||||
def testIdentityAlignmentNone(self):
|
||||
"""identity_alignment='none' results in empty auth mechanisms"""
|
||||
report_str = self._make_feedback_report(**{"Identity-Alignment": "none"})
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
)
|
||||
self.assertEqual(report["authentication_mechanisms"], [])
|
||||
|
||||
def testIdentityAlignmentMultiple(self):
|
||||
"""identity_alignment with multiple values is split"""
|
||||
report_str = self._make_feedback_report(**{"Identity-Alignment": "dkim,spf"})
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
)
|
||||
self.assertEqual(report["authentication_mechanisms"], ["dkim", "spf"])
|
||||
|
||||
def testMissingReportedDomainFallback(self):
|
||||
"""Missing reported_domain falls back to sample from domain"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [
|
||||
l for l in report_str.split("\n") if not l.startswith("Reported-Domain:")
|
||||
]
|
||||
report_str = "\n".join(lines)
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
)
|
||||
self.assertEqual(report["reported_domain"], "example.com")
|
||||
|
||||
def testMissingArrivalDateWithMsgDate(self):
|
||||
"""Missing arrival_date uses msg_date fallback"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")]
|
||||
report_str = "\n".join(lines)
|
||||
msg_date = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), msg_date, offline=True
|
||||
)
|
||||
self.assertIn("2024-06-15", report["arrival_date"])
|
||||
|
||||
def testMissingArrivalDateNoMsgDateRaises(self):
|
||||
"""Missing arrival_date with no msg_date raises"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")]
|
||||
report_str = "\n".join(lines)
|
||||
with self.assertRaises(parsedmarc.InvalidFailureReport):
|
||||
parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), None, offline=True
|
||||
)
|
||||
|
||||
|
||||
class TestWebhookClient(unittest.TestCase):
    """Exercise WebhookClient error handling and close()."""

    @staticmethod
    def _new_client():
        """Build a WebhookClient pointed at unroutable .test URLs."""
        return parsedmarc.webhook.WebhookClient(
            aggregate_url="http://invalid.test/agg",
            failure_url="http://invalid.test/fail",
            smtp_tls_url="http://invalid.test/tls",
        )

    def testSaveMethodsHandleErrors(self):
        """No save_* method propagates an exception from _send_to_webhook."""
        client = self._new_client()
        failing_send = patch.object(
            client, "_send_to_webhook", side_effect=Exception("send failed")
        )
        with failing_send:
            # None of these calls should raise; any exception fails the test.
            client.save_aggregate_report_to_webhook('{"test": true}')
            client.save_failure_report_to_webhook('{"test": true}')
            client.save_smtp_tls_report_to_webhook('{"test": true}')

    def testSendToWebhookLogsPostErrors(self):
        """_send_to_webhook does not propagate an exception from session.post."""
        client = self._new_client()
        failing_post = patch.object(
            client.session, "post", side_effect=Exception("connection refused")
        )
        with failing_post:
            # Should not raise
            client._send_to_webhook("http://invalid.test/agg", '{"test": true}')

    def testClose(self):
        """close() calls session.close exactly once."""
        client = self._new_client()
        session_close = MagicMock()
        client.session.close = session_close
        client.close()
        session_close.assert_called_once()
|
||||
|
||||
|
||||
class TestUtilsDnsCaching(unittest.TestCase):
    """Exercise DNS query caching and reverse DNS error handling."""

    def testQueryDnsUsesCacheHit(self):
        """A pre-seeded cache entry is returned by query_dns."""
        dns_cache = ExpiringDict(max_len=100, max_age_seconds=60)
        dns_cache["example.com_A"] = ["1.2.3.4"]
        answer = parsedmarc.utils.query_dns("example.com", "A", cache=dns_cache)
        self.assertEqual(answer, ["1.2.3.4"])

    def testQueryDnsCachesResult(self):
        """A resolved answer is stored in the cache under '<domain>_<type>'."""
        dns_cache = ExpiringDict(max_len=100, max_age_seconds=60)
        # Pre-populate so ExpiringDict is truthy
        dns_cache["seed_key"] = ["seed"]

        fake_record = MagicMock()
        fake_record.to_text.return_value = '"1.2.3.4"'
        fake_resolver = MagicMock()
        fake_resolver.resolve.return_value = [fake_record]

        resolver_patch = patch(
            "parsedmarc.utils.dns.resolver.Resolver", return_value=fake_resolver
        )
        with resolver_patch:
            answer = parsedmarc.utils.query_dns(
                "test-cache.example.com", "A", cache=dns_cache
            )

        self.assertEqual(answer, ["1.2.3.4"])
        self.assertIn("test-cache.example.com_A", dns_cache)

    def testReverseDnsReturnsNoneOnFailure(self):
        """get_reverse_dns yields None when query_dns raises DNSException."""
        failing_query = patch(
            "parsedmarc.utils.query_dns",
            side_effect=dns.exception.DNSException("timeout"),
        )
        with failing_query:
            hostname = parsedmarc.utils.get_reverse_dns("203.0.113.1")
        self.assertIsNone(hostname)
|
||||
|
||||
|
||||
class TestUtilsIpDbPaths(unittest.TestCase):
    """Exercise IP database path handling in get_ip_address_country."""

    def testCustomPathFallsBack(self):
        """A non-existent db_path still yields None or a str, not an error."""
        country = parsedmarc.utils.get_ip_address_country(
            "1.1.1.1", db_path="/nonexistent/path.mmdb"
        )
        if country is not None:
            self.assertIsInstance(country, str)

    def testBundledDbWorks(self):
        """With no db_path given, 8.8.8.8 resolves to US."""
        country = parsedmarc.utils.get_ip_address_country("8.8.8.8")
        self.assertEqual(country, "US")
|
||||
|
||||
|
||||
class TestUtilsParseEmail(unittest.TestCase):
    """Tests for parse_email edge cases."""

    def testMinimalEmail(self):
        """parse_email handles email with minimal headers."""
        # Headers must start at column 0 to be read as headers.
        email_str = """From: test@example.com
Subject: Test

Body text"""
        result = parsedmarc.utils.parse_email(email_str)
        self.assertEqual(result["subject"], "Test")
        self.assertEqual(result["reply_to"], [])

    def testEmailWithNoSubject(self):
        """parse_email defaults subject to None when missing."""
        email_str = """From: test@example.com
To: other@example.com

Body"""
        result = parsedmarc.utils.parse_email(email_str)
        self.assertIsNone(result["subject"])

    def testEmailBytesInput(self):
        """parse_email handles bytes input."""
        email_bytes = b"""From: test@example.com
Subject: Bytes Test
To: other@example.com

Body"""
        result = parsedmarc.utils.parse_email(email_bytes)
        self.assertEqual(result["subject"], "Bytes Test")

    def testEmailWithAttachments(self):
        """parse_email with strip_attachment_payloads removes payloads."""
        # The unused `import email as email_mod` was dropped; only the MIME
        # builders below are needed.
        from email import encoders
        from email.mime.base import MIMEBase
        from email.mime.multipart import MIMEMultipart
        from email.mime.text import MIMEText

        msg = MIMEMultipart()
        msg["From"] = "test@example.com"
        msg["To"] = "other@example.com"
        msg["Subject"] = "Attachment Test"
        msg.attach(MIMEText("Body text"))

        attachment = MIMEBase("application", "octet-stream")
        attachment.set_payload(b"file content here")
        encoders.encode_base64(attachment)
        attachment.add_header("Content-Disposition", "attachment", filename="test.bin")
        msg.attach(attachment)

        result = parsedmarc.utils.parse_email(
            msg.as_string(), strip_attachment_payloads=True
        )
        # NOTE(review): this loop passes vacuously if no attachment is
        # returned — consider also asserting the list is non-empty.
        for att in result["attachments"]:
            self.assertNotIn("payload", att)
|
||||
|
||||
|
||||
class TestUtilsOutlookMsg(unittest.TestCase):
    """Exercise Outlook MSG detection and conversion helpers."""

    def testIsOutlookMsg(self):
        """Bytes starting with the 8-byte signature below are detected as MSG."""
        signature = b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1"
        candidate = signature + b"\x00" * 100
        self.assertTrue(parsedmarc.utils.is_outlook_msg(candidate))

    def testIsNotOutlookMsg(self):
        """Plain bytes and str inputs are both rejected."""
        is_msg = parsedmarc.utils.is_outlook_msg
        self.assertFalse(is_msg(b"not an msg file"))
        self.assertFalse(is_msg("string input"))

    def testConvertOutlookMsgInvalidInput(self):
        """convert_outlook_msg raises ValueError for non-MSG bytes."""
        not_msg = b"not an msg file"
        with self.assertRaises(ValueError):
            parsedmarc.utils.convert_outlook_msg(not_msg)
|
||||
|
||||
|
||||
class TestUtilsReverseDnsMap(unittest.TestCase):
    """Exercise load_reverse_dns_map in offline, local-file and fallback modes."""

    def testLoadReverseDnsMapOffline(self):
        """Offline loading fills the supplied dict with at least one entry."""
        loaded = {}
        parsedmarc.utils.load_reverse_dns_map(loaded, offline=True)
        self.assertGreater(len(loaded), 0)

    def testLoadReverseDnsMapLocalOverride(self):
        """Entries from the CSV given as local_file_path end up in the map."""
        csv_rows = (
            "base_reverse_dns,name,type\n"
            "custom.example.com,Custom Service,hosting\n"
        )
        # delete=False: the file is reopened by path after this handle is
        # closed, and removed explicitly in the finally clause.
        with NamedTemporaryFile("w", suffix=".csv", delete=False) as tmp:
            tmp.write(csv_rows)
            csv_path = tmp.name
        try:
            loaded = {}
            parsedmarc.utils.load_reverse_dns_map(
                loaded, offline=True, local_file_path=csv_path
            )
            self.assertIn("custom.example.com", loaded)
            entry = loaded["custom.example.com"]
            self.assertEqual(entry["name"], "Custom Service")
        finally:
            os.remove(csv_path)

    def testLoadReverseDnsMapNetworkFailureFallback(self):
        """When requests.get raises ConnectionError the map is still filled."""
        loaded = {}
        no_network = patch(
            "parsedmarc.utils.requests.get",
            side_effect=requests.exceptions.ConnectionError("no network"),
        )
        with no_network:
            parsedmarc.utils.load_reverse_dns_map(loaded)
        self.assertGreater(len(loaded), 0)
|
||||
|
||||
|
||||
class TestSmtpTlsReportErrors(unittest.TestCase):
    """Exercise error handling in parse_smtp_tls_report_json."""

    def testMissingRequiredField(self):
        """A JSON object holding only an empty 'policies' list is rejected."""
        incomplete = json.dumps({"policies": []})
        with self.assertRaises(parsedmarc.InvalidSMTPTLSReport):
            parsedmarc.parse_smtp_tls_report_json(incomplete)

    def testInvalidJson(self):
        """Text that is not valid JSON raises InvalidSMTPTLSReport."""
        malformed = "not json {{{"
        with self.assertRaises(parsedmarc.InvalidSMTPTLSReport):
            parsedmarc.parse_smtp_tls_report_json(malformed)
|
||||
|
||||
|
||||
class TestBucketIntervalEdgeCases(unittest.TestCase):
    """Tests for _bucket_interval_by_day edge cases"""

    def testDayCursorAdjustment(self):
        """Bucket counts still sum to the total for a UTC-10 interval

        Both endpoints are defined in UTC and then converted to a fixed
        UTC-10 offset: begin becomes 2024-01-01 13:30 and end becomes
        2024-01-02 14:00 local time, so the interval crosses one local
        midnight. The test asserts that the per-bucket counts add up to the
        100 that was passed in.
        """
        # A fixed offset needs no third-party timezone package: the standard
        # library's timezone(timedelta(...)) is enough. The same tzinfo
        # object is attached to both endpoints.
        utc_minus_10 = timezone(timedelta(hours=-10))
        begin = datetime(2024, 1, 1, 23, 30, 0, tzinfo=timezone.utc).astimezone(
            utc_minus_10
        )
        end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc).astimezone(
            utc_minus_10
        )
        buckets = parsedmarc._bucket_interval_by_day(begin, end, 100)
        total = sum(b["count"] for b in buckets)
        self.assertEqual(total, 100)
|
||||
|
||||
|
||||
class TestGetDmarcReportsFromMbox(unittest.TestCase):
    """Tests for mbox parsing"""

    def testEmptyMbox(self):
        """Empty mbox returns empty results"""
        with NamedTemporaryFile(suffix=".mbox", delete=False) as f:
            f.write(b"")
            path = f.name
        try:
            results = parsedmarc.get_dmarc_reports_from_mbox(path, offline=True)
            # All three report lists must be present and empty.
            self.assertEqual(results["aggregate_reports"], [])
            self.assertEqual(results["failure_reports"], [])
            self.assertEqual(results["smtp_tls_reports"], [])
        finally:
            # delete=False leaves the file behind, so remove it explicitly.
            os.remove(path)

    def testMboxWithAggregateReport(self):
        """Mbox with aggregate report email is parsed

        Builds a multipart message whose only attachment is a gzip-compressed
        aggregate report XML, writes it to a temporary mbox file and asserts
        that at least one aggregate report comes back.
        """
        import gzip
        from email.mime.application import MIMEApplication
        from email.mime.multipart import MIMEMultipart

        xml = b"""<?xml version="1.0"?>
<feedback>
<report_metadata>
<org_name>example.com</org_name>
<email>dmarc@example.com</email>
<report_id>mbox-test-123</report_id>
<date_range><begin>1680000000</begin><end>1680086400</end></date_range>
</report_metadata>
<policy_published><domain>example.com</domain><p>none</p></policy_published>
<record>
<row><source_ip>203.0.113.1</source_ip><count>1</count>
<policy_evaluated><disposition>none</disposition><dkim>pass</dkim><spf>pass</spf></policy_evaluated>
</row>
<identifiers><header_from>example.com</header_from></identifiers>
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""
        compressed = gzip.compress(xml)

        msg = MIMEMultipart()
        msg["From"] = "dmarc@example.com"
        msg["To"] = "postmaster@example.com"
        msg["Subject"] = "DMARC Aggregate Report"
        msg["Date"] = "Thu, 1 Jan 2024 00:00:00 +0000"
        att = MIMEApplication(compressed, "gzip")
        att.add_header("Content-Disposition", "attachment", filename="report.xml.gz")
        msg.attach(att)

        with NamedTemporaryFile(suffix=".mbox", delete=False, mode="w") as f:
            # mbox format requires "From " line
            f.write("From dmarc@example.com Thu Jan 1 00:00:00 2024\n")
            f.write(msg.as_string())
            f.write("\n")
            path = f.name
        try:
            results = parsedmarc.get_dmarc_reports_from_mbox(path, offline=True)
            # assertGreaterEqual reports the actual count if this fails.
            self.assertGreaterEqual(len(results["aggregate_reports"]), 1)
        finally:
            os.remove(path)
|
||||
|
||||
|
||||
class TestPslOverrides(unittest.TestCase):
    """Checks for get_base_domain in the presence of PSL overrides"""

    def testOverrideMatch(self):
        """get_base_domain reduces sub.example.com to example.com"""
        # NOTE(review): the input is an ordinary subdomain, so this looks
        # like a check that the lookup runs cleanly with the override list
        # in place rather than a test of a specific override entry - confirm.
        self.assertEqual(
            parsedmarc.utils.get_base_domain("sub.example.com"),
            "example.com",
        )
|
||||
|
||||
|
||||
class TestIsMbox(unittest.TestCase):
    """Exercise parsedmarc.utils.is_mbox() with valid, empty and missing files"""

    def testValidMbox(self):
        """A file holding one "From "-delimited message is reported as mbox"""
        mbox_lines = (
            "From test@example.com Thu Jan 1 00:00:00 2024\n",
            "Subject: Test\n\nBody\n\n",
        )
        with NamedTemporaryFile(suffix=".mbox", delete=False, mode="w") as handle:
            handle.writelines(mbox_lines)
            mbox_path = handle.name
        # The file outlives the handle (delete=False); remove it after the
        # test regardless of the outcome.
        self.addCleanup(os.remove, mbox_path)

        detected = parsedmarc.utils.is_mbox(mbox_path)
        self.assertTrue(detected)

    def testEmptyFileNotMbox(self):
        """A zero-byte file is not reported as mbox"""
        with NamedTemporaryFile(suffix=".mbox", delete=False) as handle:
            empty_path = handle.name
        self.addCleanup(os.remove, empty_path)

        detected = parsedmarc.utils.is_mbox(empty_path)
        self.assertFalse(detected)

    def testNonExistentNotMbox(self):
        """A path that does not exist is not reported as mbox"""
        missing_path = "/nonexistent/file.mbox"
        detected = parsedmarc.utils.is_mbox(missing_path)
        self.assertFalse(detected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Running the file directly executes the whole suite; verbosity 2 makes
    # unittest print one line per test instead of a row of dots.
    unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user