#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, unicode_literals import json import os import sys import tempfile import unittest from datetime import datetime, timedelta, timezone from glob import glob from base64 import urlsafe_b64encode from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory from types import SimpleNamespace from unittest.mock import MagicMock, patch from lxml import etree from googleapiclient.errors import HttpError from httplib2 import Response from imapclient.exceptions import IMAPClientError import parsedmarc import parsedmarc.cli from parsedmarc.mail.gmail import GmailConnection from parsedmarc.mail.gmail import _get_creds from parsedmarc.mail.graph import MSGraphConnection from parsedmarc.mail.graph import _generate_credential from parsedmarc.mail.graph import _get_cache_args from parsedmarc.mail.graph import _load_token from parsedmarc.mail.imap import IMAPConnection import parsedmarc.mail.gmail as gmail_module import parsedmarc.mail.graph as graph_module import parsedmarc.mail.imap as imap_module import parsedmarc.opensearch as opensearch_module import parsedmarc.utils # Detect if running in GitHub Actions to skip DNS lookups OFFLINE_MODE = os.environ.get("GITHUB_ACTIONS", "false").lower() == "true" def minify_xml(xml_string): parser = etree.XMLParser(remove_blank_text=True) tree = etree.fromstring(xml_string.encode("utf-8"), parser) return etree.tostring(tree, pretty_print=False).decode("utf-8") def compare_xml(xml1, xml2): parser = etree.XMLParser(remove_blank_text=True) tree1 = etree.fromstring(xml1.encode("utf-8"), parser) tree2 = etree.fromstring(xml2.encode("utf-8"), parser) return etree.tostring(tree1) == etree.tostring(tree2) class Test(unittest.TestCase): def testBase64Decoding(self): """Test base64 decoding""" # Example from Wikipedia Base64 article b64_str = "YW55IGNhcm5hbCBwbGVhcw" decoded_str = parsedmarc.utils.decode_base64(b64_str) self.assertEqual(decoded_str, b"any carnal pleas") def testPSLDownload(self): """Test Public Suffix List domain lookups""" subdomain = "foo.example.com" result = parsedmarc.utils.get_base_domain(subdomain) self.assertEqual(result, "example.com") # Test newer PSL entries subdomain = "e3191.c.akamaiedge.net" result = parsedmarc.utils.get_base_domain(subdomain) self.assertEqual(result, "c.akamaiedge.net") def testExtractReportXMLComparator(self): """Test XML comparator function""" with open("samples/extract_report/nice-input.xml") as f: xmlnice = f.read() with open("samples/extract_report/changed-input.xml") as f: xmlchanged = minify_xml(f.read()) self.assertTrue(compare_xml(xmlnice, xmlnice)) self.assertTrue(compare_xml(xmlchanged, xmlchanged)) self.assertFalse(compare_xml(xmlnice, xmlchanged)) self.assertFalse(compare_xml(xmlchanged, xmlnice)) print("Passed!") def testExtractReportBytes(self): """Test extract report function for bytes string input""" print() file = "samples/extract_report/nice-input.xml" with open(file, "rb") as f: data = f.read() print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report(data) with open("samples/extract_report/nice-input.xml") as f: xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") def testExtractReportXML(self): """Test extract report function for XML input""" print() report_path = "samples/extract_report/nice-input.xml" print("Testing {0}: ".format(report_path), end="") xmlout = parsedmarc.extract_report_from_file_path(report_path) xmlin_file = open("samples/extract_report/nice-input.xml") xmlin = xmlin_file.read() xmlin_file.close() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") def testExtractReportXMLFromPath(self): """Test extract report function for pathlib.Path input""" report_path = Path("samples/extract_report/nice-input.xml") xmlout = parsedmarc.extract_report_from_file_path(report_path) with open("samples/extract_report/nice-input.xml") as xmlin_file: xmlin = xmlin_file.read() self.assertTrue(compare_xml(xmlout, xmlin)) def testExtractReportGZip(self): """Test extract report function for gzip input""" print() file = "samples/extract_report/nice-input.xml.gz" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) with open("samples/extract_report/nice-input.xml") as f: xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") def testExtractReportZip(self): """Test extract report function for zip input""" print() file = "samples/extract_report/nice-input.xml.zip" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) with open("samples/extract_report/nice-input.xml") as f: xmlin = minify_xml(f.read()) self.assertTrue(compare_xml(xmlout, xmlin)) with open("samples/extract_report/changed-input.xml") as f: xmlin = f.read() self.assertFalse(compare_xml(xmlout, xmlin)) print("Passed!") def testParseReportFileAcceptsPathForXML(self): report_path = Path( "samples/aggregate/protection.outlook.com!example.com!1711756800!1711843200.xml" ) result = parsedmarc.parse_report_file( report_path, offline=True, ) self.assertEqual(result["report_type"], "aggregate") self.assertEqual(result["report"]["report_metadata"]["org_name"], "outlook.com") def testParseReportFileAcceptsPathForEmail(self): report_path = Path( "samples/aggregate/Report domain- borschow.com Submitter- google.com Report-ID- 949348866075514174.eml" ) result = parsedmarc.parse_report_file( report_path, offline=True, ) self.assertEqual(result["report_type"], "aggregate") self.assertEqual(result["report"]["report_metadata"]["org_name"], "google.com") def testAggregateSamples(self): """Test sample aggregate/rua DMARC reports""" print() sample_paths = glob("samples/aggregate/*") for sample_path in sample_paths: if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") with self.subTest(sample=sample_path): parsed_report = parsedmarc.parse_report_file( sample_path, always_use_local_files=True, offline=OFFLINE_MODE )["report"] parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) print("Passed!") def testEmptySample(self): """Test empty/unparasable report""" with self.assertRaises(parsedmarc.ParserError): parsedmarc.parse_report_file("samples/empty.xml", offline=OFFLINE_MODE) def testForensicSamples(self): """Test sample failure/ruf DMARC reports""" print() sample_paths = glob("samples/failure/*.eml") for sample_path in sample_paths: print("Testing {0}: ".format(sample_path), end="") with self.subTest(sample=sample_path): with open(sample_path) as sample_file: sample_content = sample_file.read() parsed_report = parsedmarc.parse_report_email( sample_content, offline=OFFLINE_MODE )["report"] parsed_report = parsedmarc.parse_report_file( sample_path, offline=OFFLINE_MODE )["report"] parsedmarc.parsed_failure_reports_to_csv(parsed_report) print("Passed!") def testFailureReportBackwardCompat(self): """Test that old forensic function aliases still work""" self.assertIs( parsedmarc.parse_forensic_report, parsedmarc.parse_failure_report, ) self.assertIs( parsedmarc.parsed_forensic_reports_to_csv, parsedmarc.parsed_failure_reports_to_csv, ) self.assertIs( parsedmarc.parsed_forensic_reports_to_csv_rows, parsedmarc.parsed_failure_reports_to_csv_rows, ) self.assertIs( parsedmarc.InvalidForensicReport, parsedmarc.InvalidFailureReport, ) def testDMARCbisDraftSample(self): """Test parsing the sample report from the DMARCbis aggregate draft""" print() sample_path = "samples/aggregate/dmarcbis-draft-sample.xml" print("Testing {0}: ".format(sample_path), end="") result = parsedmarc.parse_report_file( sample_path, always_use_local_files=True, offline=True ) report = result["report"] # Verify report_type self.assertEqual(result["report_type"], "aggregate") # Verify xml_schema self.assertEqual(report["xml_schema"], "1.0") # Verify report_metadata metadata = report["report_metadata"] self.assertEqual(metadata["org_name"], "Sample Reporter") self.assertEqual(metadata["org_email"], "report_sender@example-reporter.com") self.assertEqual(metadata["org_extra_contact_info"], "...") self.assertEqual(metadata["report_id"], "3v98abbp8ya9n3va8yr8oa3ya") self.assertEqual( metadata["generator"], "Example DMARC Aggregate Reporter v1.2", ) # Verify DMARCbis policy_published fields pp = report["policy_published"] self.assertEqual(pp["domain"], "example.com") self.assertEqual(pp["p"], "quarantine") self.assertEqual(pp["sp"], "none") self.assertEqual(pp["np"], "none") self.assertEqual(pp["testing"], "n") self.assertEqual(pp["discovery_method"], "treewalk") # adkim/aspf default when not in XML self.assertEqual(pp["adkim"], "r") self.assertEqual(pp["aspf"], "r") # pct/fo are None on DMARCbis reports (not used) self.assertIsNone(pp["pct"]) self.assertIsNone(pp["fo"]) # Verify record self.assertEqual(len(report["records"]), 1) rec = report["records"][0] self.assertEqual(rec["source"]["ip_address"], "192.0.2.123") self.assertEqual(rec["count"], 123) self.assertEqual(rec["policy_evaluated"]["disposition"], "pass") self.assertEqual(rec["policy_evaluated"]["dkim"], "pass") self.assertEqual(rec["policy_evaluated"]["spf"], "fail") # Verify DKIM auth result with human_result self.assertEqual(len(rec["auth_results"]["dkim"]), 1) dkim = rec["auth_results"]["dkim"][0] self.assertEqual(dkim["domain"], "example.com") self.assertEqual(dkim["selector"], "abc123") self.assertEqual(dkim["result"], "pass") self.assertIsNone(dkim["human_result"]) # Verify SPF auth result with human_result self.assertEqual(len(rec["auth_results"]["spf"]), 1) spf = rec["auth_results"]["spf"][0] self.assertEqual(spf["domain"], "example.com") self.assertEqual(spf["result"], "fail") self.assertIsNone(spf["human_result"]) # Verify CSV output includes new fields csv = parsedmarc.parsed_aggregate_reports_to_csv(report) header = csv.split("\n")[0] self.assertIn("np", header.split(",")) self.assertIn("testing", header.split(",")) self.assertIn("discovery_method", header.split(",")) print("Passed!") def testDMARCbisFieldsWithRFC7489(self): """Test that RFC 7489 reports have None for DMARCbis-only fields""" print() sample_path = ( "samples/aggregate/example.net!example.com!1529366400!1529452799.xml" ) print("Testing {0}: ".format(sample_path), end="") result = parsedmarc.parse_report_file( sample_path, always_use_local_files=True, offline=True ) report = result["report"] pp = report["policy_published"] # RFC 7489 fields present self.assertEqual(pp["pct"], "100") self.assertEqual(pp["fo"], "0") # DMARCbis fields absent (None) self.assertIsNone(pp["np"]) self.assertIsNone(pp["testing"]) self.assertIsNone(pp["discovery_method"]) # generator absent (None) self.assertIsNone(report["report_metadata"]["generator"]) print("Passed!") def testDMARCbisWithExplicitFields(self): """Test DMARCbis report with explicit testing and discovery_method""" print() sample_path = ( "samples/aggregate/" "dmarcbis-example.net!example.com!1700000000!1700086399.xml" ) print("Testing {0}: ".format(sample_path), end="") result = parsedmarc.parse_report_file( sample_path, always_use_local_files=True, offline=True ) report = result["report"] pp = report["policy_published"] self.assertEqual(pp["np"], "reject") self.assertEqual(pp["testing"], "y") self.assertEqual(pp["discovery_method"], "treewalk") print("Passed!") def testSmtpTlsSamples(self): """Test sample SMTP TLS reports""" print() sample_paths = glob("samples/smtp_tls/*") for sample_path in sample_paths: if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") with self.subTest(sample=sample_path): parsed_report = parsedmarc.parse_report_file( sample_path, offline=OFFLINE_MODE )["report"] parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report) print("Passed!") # =================================================================== # New tests for _bucket_interval_by_day # =================================================================== def testBucketIntervalBeginAfterEnd(self): """begin > end should raise ValueError""" begin = datetime(2024, 1, 2, tzinfo=timezone.utc) end = datetime(2024, 1, 1, tzinfo=timezone.utc) with self.assertRaises(ValueError): parsedmarc._bucket_interval_by_day(begin, end, 100) def testBucketIntervalNaiveDatetime(self): """Non-timezone-aware datetimes should raise ValueError""" begin = datetime(2024, 1, 1) end = datetime(2024, 1, 2) with self.assertRaises(ValueError): parsedmarc._bucket_interval_by_day(begin, end, 100) def testBucketIntervalDifferentTzinfo(self): """Different tzinfo objects should raise ValueError""" tz1 = timezone.utc tz2 = timezone(timedelta(hours=5)) begin = datetime(2024, 1, 1, tzinfo=tz1) end = datetime(2024, 1, 2, tzinfo=tz2) with self.assertRaises(ValueError): parsedmarc._bucket_interval_by_day(begin, end, 100) def testBucketIntervalNegativeCount(self): """Negative total_count should raise ValueError""" begin = datetime(2024, 1, 1, tzinfo=timezone.utc) end = datetime(2024, 1, 2, tzinfo=timezone.utc) with self.assertRaises(ValueError): parsedmarc._bucket_interval_by_day(begin, end, -1) def testBucketIntervalZeroCount(self): """Zero total_count should return empty list""" begin = datetime(2024, 1, 1, tzinfo=timezone.utc) end = datetime(2024, 1, 2, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 0) self.assertEqual(result, []) def testBucketIntervalSameBeginEnd(self): """Same begin and end (zero interval) should return empty list""" dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(dt, dt, 100) self.assertEqual(result, []) def testBucketIntervalSingleDay(self): """Single day interval should return one bucket with total count""" begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 1, 23, 59, 59, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 100) self.assertEqual(len(result), 1) self.assertEqual(result[0]["count"], 100) self.assertEqual(result[0]["begin"], begin) def testBucketIntervalMultiDay(self): """Multi-day interval should distribute counts proportionally""" begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 100) self.assertEqual(len(result), 2) total = sum(b["count"] for b in result) self.assertEqual(total, 100) # Equal days => equal distribution self.assertEqual(result[0]["count"], 50) self.assertEqual(result[1]["count"], 50) def testBucketIntervalRemainderDistribution(self): """Odd count across equal days distributes remainder correctly""" begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 4, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 10) total = sum(b["count"] for b in result) self.assertEqual(total, 10) self.assertEqual(len(result), 3) def testBucketIntervalPartialDays(self): """Partial days: 12h on day1, 24h on day2 => 1/3 vs 2/3 split""" begin = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 90) total = sum(b["count"] for b in result) self.assertEqual(total, 90) # day1: 12h, day2: 24h => 1/3 vs 2/3 self.assertEqual(result[0]["count"], 30) self.assertEqual(result[1]["count"], 60) # =================================================================== # Tests for _append_parsed_record # =================================================================== def testAppendParsedRecordNoNormalize(self): """No normalization: record appended as-is with interval fields""" records = [] rec = {"count": 10, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 2, 0, 0, 0, tzinfo=timezone.utc) parsedmarc._append_parsed_record(rec, records, begin, end, False) self.assertEqual(len(records), 1) self.assertFalse(records[0]["normalized_timespan"]) self.assertEqual(records[0]["interval_begin"], "2024-01-01 00:00:00") self.assertEqual(records[0]["interval_end"], "2024-01-02 00:00:00") def testAppendParsedRecordNormalize(self): """Normalization: record split into daily buckets""" records = [] rec = {"count": 100, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) parsedmarc._append_parsed_record(rec, records, begin, end, True) self.assertEqual(len(records), 2) total = sum(r["count"] for r in records) self.assertEqual(total, 100) for r in records: self.assertTrue(r["normalized_timespan"]) def testAppendParsedRecordNormalizeZeroCount(self): """Normalization with zero count: nothing appended""" records = [] rec = {"count": 0, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) parsedmarc._append_parsed_record(rec, records, begin, end, True) self.assertEqual(len(records), 0) # =================================================================== # Tests for _parse_report_record # =================================================================== def testParseReportRecordNoneSourceIP(self): """Record with None source_ip should raise ValueError""" record = { "row": { "source_ip": None, "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "pass", }, }, "identifiers": {"header_from": "example.com"}, "auth_results": {"dkim": [], "spf": []}, } with self.assertRaises(ValueError): parsedmarc._parse_report_record(record, offline=True) def testParseReportRecordMissingDkimSpf(self): """Record with missing dkim/spf auth results defaults correctly""" record = { "row": { "source_ip": "192.0.2.1", "count": "5", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "fail", }, }, "identifiers": {"header_from": "example.com"}, "auth_results": {}, } result = parsedmarc._parse_report_record(record, offline=True) self.assertEqual(result["auth_results"]["dkim"], []) self.assertEqual(result["auth_results"]["spf"], []) def testParseReportRecordReasonHandling(self): """Reasons in policy_evaluated get normalized with comment default""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "pass", "reason": {"type": "forwarded"}, }, }, "identifiers": {"header_from": "example.com"}, "auth_results": {"dkim": [], "spf": []}, } result = parsedmarc._parse_report_record(record, offline=True) reasons = result["policy_evaluated"]["policy_override_reasons"] self.assertEqual(len(reasons), 1) self.assertEqual(reasons[0]["type"], "forwarded") self.assertIsNone(reasons[0]["comment"]) def testParseReportRecordReasonList(self): """Multiple reasons as a list are preserved""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "pass", "reason": [ {"type": "forwarded", "comment": "relay"}, {"type": "local_policy"}, ], }, }, "identifiers": {"header_from": "example.com"}, "auth_results": {"dkim": [], "spf": []}, } result = parsedmarc._parse_report_record(record, offline=True) reasons = result["policy_evaluated"]["policy_override_reasons"] self.assertEqual(len(reasons), 2) self.assertEqual(reasons[0]["comment"], "relay") self.assertIsNone(reasons[1]["comment"]) def testParseReportRecordIdentities(self): """'identities' key is mapped to 'identifiers'""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "pass", }, }, "identities": { "header_from": "Example.COM", "envelope_from": "example.com", }, "auth_results": {"dkim": [], "spf": []}, } result = parsedmarc._parse_report_record(record, offline=True) self.assertIn("identifiers", result) self.assertEqual(result["identifiers"]["header_from"], "example.com") def testParseReportRecordDkimDefaults(self): """DKIM result defaults: selector='none', result='none' when missing""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "fail", "spf": "fail", }, }, "identifiers": {"header_from": "example.com"}, "auth_results": { "dkim": {"domain": "example.com"}, "spf": [], }, } result = parsedmarc._parse_report_record(record, offline=True) dkim = result["auth_results"]["dkim"][0] self.assertEqual(dkim["selector"], "none") self.assertEqual(dkim["result"], "none") self.assertIsNone(dkim["human_result"]) def testParseReportRecordSpfDefaults(self): """SPF result defaults: scope='mfrom', result='none' when missing""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "fail", "spf": "fail", }, }, "identifiers": {"header_from": "example.com"}, "auth_results": { "dkim": [], "spf": {"domain": "example.com"}, }, } result = parsedmarc._parse_report_record(record, offline=True) spf = result["auth_results"]["spf"][0] self.assertEqual(spf["scope"], "mfrom") self.assertEqual(spf["result"], "none") self.assertIsNone(spf["human_result"]) def testParseReportRecordHumanResult(self): """human_result field is included when present""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "pass", }, }, "identifiers": {"header_from": "example.com"}, "auth_results": { "dkim": [ { "domain": "example.com", "selector": "s1", "result": "pass", "human_result": "good key", } ], "spf": [ { "domain": "example.com", "scope": "mfrom", "result": "pass", "human_result": "sender valid", } ], }, } result = parsedmarc._parse_report_record(record, offline=True) self.assertEqual(result["auth_results"]["dkim"][0]["human_result"], "good key") self.assertEqual( result["auth_results"]["spf"][0]["human_result"], "sender valid" ) def testParseReportRecordEnvelopeFromFallback(self): """envelope_from falls back to last SPF domain when missing""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "pass", }, }, "identifiers": {"header_from": "example.com"}, "auth_results": { "dkim": [], "spf": [ {"domain": "Bounce.Example.COM", "scope": "mfrom", "result": "pass"} ], }, } result = parsedmarc._parse_report_record(record, offline=True) self.assertEqual(result["identifiers"]["envelope_from"], "bounce.example.com") def testParseReportRecordEnvelopeFromNullFallback(self): """envelope_from None value falls back to SPF domain""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "pass", }, }, "identifiers": { "header_from": "example.com", "envelope_from": None, }, "auth_results": { "dkim": [], "spf": [ {"domain": "SPF.Example.COM", "scope": "mfrom", "result": "pass"} ], }, } result = parsedmarc._parse_report_record(record, offline=True) self.assertEqual(result["identifiers"]["envelope_from"], "spf.example.com") def testParseReportRecordEnvelopeTo(self): """envelope_to is preserved and moved correctly""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "pass", }, }, "identifiers": { "header_from": "example.com", "envelope_from": "bounce@example.com", "envelope_to": "recipient@example.com", }, "auth_results": {"dkim": [], "spf": []}, } result = parsedmarc._parse_report_record(record, offline=True) self.assertEqual(result["identifiers"]["envelope_to"], "recipient@example.com") def testParseReportRecordAlignment(self): """Alignment fields computed correctly from policy_evaluated""" record = { "row": { "source_ip": "192.0.2.1", "count": "1", "policy_evaluated": { "disposition": "none", "dkim": "pass", "spf": "fail", }, }, "identifiers": {"header_from": "example.com"}, "auth_results": {"dkim": [], "spf": []}, } result = parsedmarc._parse_report_record(record, offline=True) self.assertTrue(result["alignment"]["dkim"]) self.assertFalse(result["alignment"]["spf"]) self.assertTrue(result["alignment"]["dmarc"]) # =================================================================== # Tests for _parse_smtp_tls_failure_details # =================================================================== def testParseSmtpTlsFailureDetailsMinimal(self): """Minimal failure details with just required fields""" details = { "result-type": "certificate-expired", "failed-session-count": 5, } result = parsedmarc._parse_smtp_tls_failure_details(details) self.assertEqual(result["result_type"], "certificate-expired") self.assertEqual(result["failed_session_count"], 5) self.assertNotIn("sending_mta_ip", result) def testParseSmtpTlsFailureDetailsAllOptional(self): """All optional fields included""" details = { "result-type": "starttls-not-supported", "failed-session-count": 3, "sending-mta-ip": "10.0.0.1", "receiving-ip": "10.0.0.2", "receiving-mx-hostname": "mx.example.com", "receiving-mx-helo": "mx.example.com", "additional-info-uri": "https://example.com/info", "failure-reason-code": "TLS_ERROR", } result = parsedmarc._parse_smtp_tls_failure_details(details) self.assertEqual(result["sending_mta_ip"], "10.0.0.1") self.assertEqual(result["receiving_ip"], "10.0.0.2") self.assertEqual(result["receiving_mx_hostname"], "mx.example.com") self.assertEqual(result["receiving_mx_helo"], "mx.example.com") self.assertEqual(result["additional_info_uri"], "https://example.com/info") self.assertEqual(result["failure_reason_code"], "TLS_ERROR") def testParseSmtpTlsFailureDetailsMissingRequired(self): """Missing required field raises InvalidSMTPTLSReport""" with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc._parse_smtp_tls_failure_details({"result-type": "err"}) # =================================================================== # Tests for _parse_smtp_tls_report_policy # =================================================================== def testParseSmtpTlsReportPolicyValid(self): """Valid STS policy parses correctly""" policy = { "policy": { "policy-type": "sts", "policy-domain": "example.com", "policy-string": ["version: STSv1", "mode: enforce"], "mx-host-pattern": ["*.example.com"], }, "summary": { "total-successful-session-count": 100, "total-failure-session-count": 2, }, } result = parsedmarc._parse_smtp_tls_report_policy(policy) self.assertEqual(result["policy_type"], "sts") self.assertEqual(result["policy_domain"], "example.com") self.assertEqual(result["policy_strings"], ["version: STSv1", "mode: enforce"]) self.assertEqual(result["mx_host_patterns"], ["*.example.com"]) self.assertEqual(result["successful_session_count"], 100) self.assertEqual(result["failed_session_count"], 2) def testParseSmtpTlsReportPolicyInvalidType(self): """Invalid policy type raises InvalidSMTPTLSReport""" policy = { "policy": { "policy-type": "invalid", "policy-domain": "example.com", }, "summary": { "total-successful-session-count": 0, "total-failure-session-count": 0, }, } with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc._parse_smtp_tls_report_policy(policy) def testParseSmtpTlsReportPolicyEmptyPolicyString(self): """Empty policy-string list is not included""" policy = { "policy": { "policy-type": "sts", "policy-domain": "example.com", "policy-string": [], "mx-host-pattern": [], }, "summary": { "total-successful-session-count": 50, "total-failure-session-count": 0, }, } result = parsedmarc._parse_smtp_tls_report_policy(policy) self.assertNotIn("policy_strings", result) self.assertNotIn("mx_host_patterns", result) def testParseSmtpTlsReportPolicyWithFailureDetails(self): """Policy with failure-details parses nested details""" policy = { "policy": { "policy-type": "sts", "policy-domain": "example.com", }, "summary": { "total-successful-session-count": 10, "total-failure-session-count": 1, }, "failure-details": [ { "result-type": "certificate-expired", "failed-session-count": 1, } ], } result = parsedmarc._parse_smtp_tls_report_policy(policy) self.assertEqual(len(result["failure_details"]), 1) self.assertEqual( result["failure_details"][0]["result_type"], "certificate-expired" ) def testParseSmtpTlsReportPolicyMissingField(self): """Missing required policy field raises InvalidSMTPTLSReport""" policy = {"policy": {"policy-type": "sts"}, "summary": {}} with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc._parse_smtp_tls_report_policy(policy) # =================================================================== # Tests for parse_smtp_tls_report_json # =================================================================== def testParseSmtpTlsReportJsonValid(self): """Valid SMTP TLS JSON report parses correctly""" report = json.dumps( { "organization-name": "Example Corp", "date-range": { "start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z", }, "contact-info": "admin@example.com", "report-id": "report-123", "policies": [ { "policy": { "policy-type": "sts", "policy-domain": "example.com", }, "summary": { "total-successful-session-count": 50, "total-failure-session-count": 0, }, } ], } ) result = parsedmarc.parse_smtp_tls_report_json(report) self.assertEqual(result["organization_name"], "Example Corp") self.assertEqual(result["report_id"], "report-123") self.assertEqual(len(result["policies"]), 1) def testParseSmtpTlsReportJsonBytes(self): """SMTP TLS report as bytes parses correctly""" report = json.dumps( { "organization-name": "Org", "date-range": { "start-datetime": "2024-01-01", "end-datetime": "2024-01-02", }, "contact-info": "a@b.com", "report-id": "r1", "policies": [ { "policy": {"policy-type": "tlsa", "policy-domain": "a.com"}, "summary": { "total-successful-session-count": 1, "total-failure-session-count": 0, }, } ], } ).encode("utf-8") result = parsedmarc.parse_smtp_tls_report_json(report) self.assertEqual(result["organization_name"], "Org") def testParseSmtpTlsReportJsonMissingField(self): """Missing required field raises InvalidSMTPTLSReport""" report = json.dumps({"organization-name": "Org"}) with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc.parse_smtp_tls_report_json(report) def testParseSmtpTlsReportJsonPoliciesNotList(self): """Non-list policies raises InvalidSMTPTLSReport""" report = json.dumps( { "organization-name": "Org", "date-range": { "start-datetime": "2024-01-01", "end-datetime": "2024-01-02", }, "contact-info": "a@b.com", "report-id": "r1", "policies": "not-a-list", } ) with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc.parse_smtp_tls_report_json(report) # =================================================================== # Tests for aggregate report parsing (validation warnings, etc.) # =================================================================== def testAggregateReportInvalidNpWarning(self): """Invalid np value is preserved but logs warning""" xml = """ 1.0 Test Org test@example.com test-np-invalid 17040672001704153599 example.com

none

banana maybe magic
192.0.2.1 1 none pass pass example.com example.compass
""" report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) # Invalid values are still stored self.assertEqual(report["policy_published"]["np"], "banana") self.assertEqual(report["policy_published"]["testing"], "maybe") self.assertEqual(report["policy_published"]["discovery_method"], "magic") def testAggregateReportPassDisposition(self): """'pass' as valid disposition is preserved""" xml = """ TestOrg test@example.com test-pass 17040672001704153599 example.com

reject

192.0.2.1 1 pass pass pass example.com example.compass
""" report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) self.assertEqual( report["records"][0]["policy_evaluated"]["disposition"], "pass" ) def testAggregateReportMultipleRecords(self): """Reports with multiple records are all parsed""" xml = """ TestOrg test@example.com test-multi 17040672001704153599 example.com

none

192.0.2.1 10 nonepasspass example.com example.compass 192.0.2.2 5 quarantinefailfail example.com example.comfail
""" report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) self.assertEqual(len(report["records"]), 2) self.assertEqual(report["records"][0]["count"], 10) self.assertEqual(report["records"][1]["count"], 5) def testAggregateReportInvalidXmlRecovery(self): """Badly formed XML is recovered via lxml""" xml = 'Testt@e.comr117040672001704153599example.com

none

192.0.2.11nonepasspassexample.comexample.compass
' report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) self.assertEqual(report["report_metadata"]["report_id"], "r1") def testAggregateReportCsvRowsContainDMARCbisFields(self): """CSV rows include np, testing, discovery_method columns""" result = parsedmarc.parse_report_file( "samples/aggregate/dmarcbis-draft-sample.xml", always_use_local_files=True, offline=True, ) report = result["report"] rows = parsedmarc.parsed_aggregate_reports_to_csv_rows(report) self.assertTrue(len(rows) > 0) row = rows[0] self.assertIn("np", row) self.assertIn("testing", row) self.assertIn("discovery_method", row) self.assertIn("source_ip_address", row) self.assertIn("dkim_domains", row) self.assertIn("spf_domains", row) def testAggregateReportSchemaVersion(self): """DMARCbis report with returns correct xml_schema""" xml = """ 1.0 TestOrg test@example.com test-version 17040672001704153599 example.com

none

192.0.2.1 1 nonepasspass example.com example.compass
""" report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) self.assertEqual(report["xml_schema"], "1.0") def testAggregateReportDraftSchema(self): """Report without defaults to 'draft' schema""" xml = """ TestOrg test@example.com test-draft 17040672001704153599 example.com

none

192.0.2.1 1 nonepasspass example.com example.compass
""" report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) self.assertEqual(report["xml_schema"], "draft") def testAggregateReportGeneratorField(self): """Generator field is correctly extracted""" xml = """ TestOrg test@example.com test-gen My Reporter v1.0 17040672001704153599 example.com

none

192.0.2.1 1 nonepasspass example.com example.compass
""" report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) self.assertEqual(report["report_metadata"]["generator"], "My Reporter v1.0") def testAggregateReportReportErrors(self): """Report errors in metadata are captured""" xml = """ TestOrg test@example.com test-err Some error 17040672001704153599 example.com

none

192.0.2.1 1 nonepasspass example.com example.compass
""" report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) self.assertIn("Some error", report["report_metadata"]["errors"]) def testAggregateReportPolicyDefaults(self): """Policy defaults: adkim/aspf='r', sp=p, pct/fo=None""" xml = """ TestOrg test@example.com test-defaults 17040672001704153599 example.com

reject

192.0.2.1 1 nonepasspass example.com example.compass
""" report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) pp = report["policy_published"] self.assertEqual(pp["adkim"], "r") self.assertEqual(pp["aspf"], "r") self.assertEqual(pp["sp"], "reject") # defaults to p self.assertIsNone(pp["pct"]) self.assertIsNone(pp["fo"]) self.assertIsNone(pp["np"]) self.assertIsNone(pp["testing"]) self.assertIsNone(pp["discovery_method"]) def testMagicXmlTagDetection(self): """XML without declaration (starting with '<') is extracted""" xml_no_decl = b"Ta@b.comr117040672001704153599example.com

none

192.0.2.11nonepasspassexample.comexample.compass
" self.assertTrue(xml_no_decl.startswith(parsedmarc.MAGIC_XML_TAG)) # Ensure it extracts as XML result = parsedmarc.extract_report(xml_no_decl) self.assertIn("", result) # =================================================================== # Tests for parsedmarc/utils.py # =================================================================== def testTimestampToDatetime(self): """timestamp_to_datetime converts UNIX timestamp to datetime""" from datetime import datetime dt = parsedmarc.utils.timestamp_to_datetime(0) self.assertIsInstance(dt, datetime) # Epoch 0 should be Jan 1 1970 in local time self.assertEqual(dt.year, 1970) def testTimestampToHuman(self): """timestamp_to_human returns formatted string""" result = parsedmarc.utils.timestamp_to_human(1704067200) self.assertRegex(result, r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}") def testHumanTimestampToDatetime(self): """human_timestamp_to_datetime parses timestamp string""" dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00") self.assertIsInstance(dt, datetime) self.assertEqual(dt.year, 2024) self.assertEqual(dt.month, 1) self.assertEqual(dt.day, 1) def testHumanTimestampToDatetimeUtc(self): """human_timestamp_to_datetime with to_utc=True returns UTC""" dt = parsedmarc.utils.human_timestamp_to_datetime( "2024-01-01 12:00:00", to_utc=True ) self.assertEqual(dt.tzinfo, timezone.utc) def testHumanTimestampToDatetimeParenthesisStripping(self): """Parenthesized content is stripped from timestamps""" dt = parsedmarc.utils.human_timestamp_to_datetime( "Mon, 01 Jan 2024 00:00:00 +0000 (UTC)" ) self.assertEqual(dt.year, 2024) def testHumanTimestampToDatetimeNegativeZero(self): """-0000 timezone is handled""" dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00 -0000") self.assertEqual(dt.year, 2024) def testHumanTimestampToUnixTimestamp(self): """human_timestamp_to_unix_timestamp converts to int""" ts = parsedmarc.utils.human_timestamp_to_unix_timestamp("2024-01-01 00:00:00") self.assertIsInstance(ts, int) def testHumanTimestampToUnixTimestampWithT(self): """T separator in timestamp is handled""" ts = parsedmarc.utils.human_timestamp_to_unix_timestamp("2024-01-01T00:00:00") self.assertIsInstance(ts, int) def testGetIpAddressCountry(self): """get_ip_address_country returns country code using bundled DBIP""" # 8.8.8.8 is a well-known Google DNS IP in US country = parsedmarc.utils.get_ip_address_country("8.8.8.8") self.assertEqual(country, "US") def testGetIpAddressCountryNotFound(self): """get_ip_address_country returns None for reserved IP""" country = parsedmarc.utils.get_ip_address_country("127.0.0.1") self.assertIsNone(country) def testGetServiceFromReverseDnsBaseDomainOffline(self): """get_service_from_reverse_dns_base_domain in offline mode""" result = parsedmarc.utils.get_service_from_reverse_dns_base_domain( "google.com", offline=True ) self.assertIn("Google", result["name"]) self.assertIsNotNone(result["type"]) def testGetServiceFromReverseDnsBaseDomainUnknown(self): """Unknown base domain returns domain as name and None as type""" result = parsedmarc.utils.get_service_from_reverse_dns_base_domain( "unknown-domain-xyz.example", offline=True ) self.assertEqual(result["name"], "unknown-domain-xyz.example") self.assertIsNone(result["type"]) def testGetIpAddressInfoOffline(self): """get_ip_address_info in offline mode returns country but no DNS""" info = parsedmarc.utils.get_ip_address_info("8.8.8.8", offline=True) self.assertEqual(info["ip_address"], "8.8.8.8") self.assertEqual(info["country"], "US") self.assertIsNone(info["reverse_dns"]) def testGetIpAddressInfoCache(self): """get_ip_address_info uses cache on second call""" from expiringdict import ExpiringDict cache = ExpiringDict(max_len=100, max_age_seconds=60) with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"): info1 = parsedmarc.utils.get_ip_address_info( "8.8.8.8", offline=False, cache=cache, always_use_local_files=True, ) self.assertIn("8.8.8.8", cache) info2 = parsedmarc.utils.get_ip_address_info( "8.8.8.8", offline=False, cache=cache ) self.assertEqual(info1["ip_address"], info2["ip_address"]) self.assertEqual(info2["reverse_dns"], "dns.google") def testParseEmailAddressWithDisplayName(self): """parse_email_address with display name""" result = parsedmarc.utils.parse_email_address(("John Doe", "john@example.com")) self.assertEqual(result["display_name"], "John Doe") self.assertEqual(result["address"], "john@example.com") self.assertEqual(result["local"], "john") self.assertEqual(result["domain"], "example.com") def testParseEmailAddressWithoutDisplayName(self): """parse_email_address with empty display name""" result = parsedmarc.utils.parse_email_address(("", "john@example.com")) self.assertIsNone(result["display_name"]) self.assertEqual(result["address"], "john@example.com") def testParseEmailAddressNoAt(self): """parse_email_address with no @ returns None local/domain""" result = parsedmarc.utils.parse_email_address(("", "localonly")) self.assertIsNone(result["local"]) self.assertIsNone(result["domain"]) def testGetFilenameSafeString(self): """get_filename_safe_string removes invalid chars""" result = parsedmarc.utils.get_filename_safe_string('file/name:with"bad*chars') self.assertNotIn("/", result) self.assertNotIn(":", result) self.assertNotIn('"', result) self.assertNotIn("*", result) def testGetFilenameSafeStringNone(self): """get_filename_safe_string with None returns 'None'""" result = parsedmarc.utils.get_filename_safe_string(None) self.assertEqual(result, "None") def testGetFilenameSafeStringLong(self): """get_filename_safe_string truncates to 100 chars""" result = parsedmarc.utils.get_filename_safe_string("a" * 200) self.assertEqual(len(result), 100) def testGetFilenameSafeStringTrailingDot(self): """get_filename_safe_string strips trailing dots""" result = parsedmarc.utils.get_filename_safe_string("filename...") self.assertFalse(result.endswith(".")) def testIsMboxNonMbox(self): """is_mbox returns False for non-mbox file""" result = parsedmarc.utils.is_mbox("samples/empty.xml") self.assertFalse(result) def testIsOutlookMsgNonMsg(self): """is_outlook_msg returns False for non-MSG content""" self.assertFalse(parsedmarc.utils.is_outlook_msg(b"not an outlook msg")) self.assertFalse(parsedmarc.utils.is_outlook_msg("string content")) def testIsOutlookMsgMagic(self): """is_outlook_msg returns True for correct magic bytes""" magic = b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1" + b"\x00" * 100 self.assertTrue(parsedmarc.utils.is_outlook_msg(magic)) # =================================================================== # Tests for output modules (mocked) # =================================================================== def testWebhookClientInit(self): """WebhookClient initializes with correct attributes""" from parsedmarc.webhook import WebhookClient client = WebhookClient( aggregate_url="http://agg.example.com", failure_url="http://fail.example.com", smtp_tls_url="http://tls.example.com", ) self.assertEqual(client.aggregate_url, "http://agg.example.com") self.assertEqual(client.failure_url, "http://fail.example.com") self.assertEqual(client.smtp_tls_url, "http://tls.example.com") self.assertEqual(client.timeout, 60) def testWebhookClientSaveMethods(self): """WebhookClient save methods call _send_to_webhook""" from parsedmarc.webhook import WebhookClient client = WebhookClient("http://a", "http://f", "http://t") client.session = MagicMock() client.save_aggregate_report_to_webhook('{"test": 1}') client.session.post.assert_called_with( "http://a", data='{"test": 1}', timeout=60 ) client.save_failure_report_to_webhook('{"fail": 1}') client.session.post.assert_called_with( "http://f", data='{"fail": 1}', timeout=60 ) client.save_smtp_tls_report_to_webhook('{"tls": 1}') client.session.post.assert_called_with( "http://t", data='{"tls": 1}', timeout=60 ) def testWebhookBackwardCompatAlias(self): """WebhookClient forensic alias points to failure method""" from parsedmarc.webhook import WebhookClient self.assertIs( WebhookClient.save_forensic_report_to_webhook, WebhookClient.save_failure_report_to_webhook, ) def testKafkaStripMetadata(self): """KafkaClient.strip_metadata extracts metadata to root""" from parsedmarc.kafkaclient import KafkaClient report = { "report_metadata": { "org_name": "TestOrg", "org_email": "test@example.com", "report_id": "r-123", "begin_date": "2024-01-01", "end_date": "2024-01-02", }, "records": [], } result = KafkaClient.strip_metadata(report) self.assertEqual(result["org_name"], "TestOrg") self.assertEqual(result["org_email"], "test@example.com") self.assertEqual(result["report_id"], "r-123") self.assertNotIn("report_metadata", result) def testKafkaGenerateDateRange(self): """KafkaClient.generate_date_range generates date range list""" from parsedmarc.kafkaclient import KafkaClient report = { "report_metadata": { "begin_date": "2024-01-01 00:00:00", "end_date": "2024-01-02 00:00:00", } } result = KafkaClient.generate_date_range(report) self.assertEqual(len(result), 2) self.assertIn("2024-01-01", result[0]) self.assertIn("2024-01-02", result[1]) def testSplunkHECClientInit(self): """HECClient initializes with correct URL and headers""" from parsedmarc.splunk import HECClient client = HECClient( url="https://splunk.example.com:8088", access_token="my-token", index="main", ) self.assertIn("/services/collector/event/1.0", client.url) self.assertEqual(client.access_token, "my-token") self.assertEqual(client.index, "main") self.assertEqual(client.source, "parsedmarc") self.assertIn("Splunk my-token", client.session.headers["Authorization"]) def testSplunkHECClientStripTokenPrefix(self): """HECClient strips 'Splunk ' prefix from token""" from parsedmarc.splunk import HECClient client = HECClient( url="https://splunk.example.com", access_token="Splunk my-token", index="main", ) self.assertEqual(client.access_token, "my-token") def testSplunkBackwardCompatAlias(self): """HECClient forensic alias points to failure method""" from parsedmarc.splunk import HECClient self.assertIs( HECClient.save_forensic_reports_to_splunk, HECClient.save_failure_reports_to_splunk, ) def testSyslogClientUdpInit(self): """SyslogClient creates UDP handler""" from parsedmarc.syslog import SyslogClient client = SyslogClient("localhost", 514, protocol="udp") self.assertEqual(client.server_name, "localhost") self.assertEqual(client.server_port, 514) self.assertEqual(client.protocol, "udp") def testSyslogClientInvalidProtocol(self): """SyslogClient with invalid protocol raises ValueError""" from parsedmarc.syslog import SyslogClient with self.assertRaises(ValueError): SyslogClient("localhost", 514, protocol="invalid") def testSyslogBackwardCompatAlias(self): """SyslogClient forensic alias points to failure method""" from parsedmarc.syslog import SyslogClient self.assertIs( SyslogClient.save_forensic_report_to_syslog, SyslogClient.save_failure_report_to_syslog, ) def testLogAnalyticsConfig(self): """LogAnalyticsConfig stores all fields""" from parsedmarc.loganalytics import LogAnalyticsConfig config = LogAnalyticsConfig( client_id="cid", client_secret="csec", tenant_id="tid", dce="https://dce.example.com", dcr_immutable_id="dcr-123", dcr_aggregate_stream="agg-stream", dcr_failure_stream="fail-stream", dcr_smtp_tls_stream="tls-stream", ) self.assertEqual(config.client_id, "cid") self.assertEqual(config.client_secret, "csec") self.assertEqual(config.tenant_id, "tid") self.assertEqual(config.dce, "https://dce.example.com") self.assertEqual(config.dcr_immutable_id, "dcr-123") self.assertEqual(config.dcr_aggregate_stream, "agg-stream") self.assertEqual(config.dcr_failure_stream, "fail-stream") self.assertEqual(config.dcr_smtp_tls_stream, "tls-stream") def testLogAnalyticsClientValidationError(self): """LogAnalyticsClient raises on missing required config""" from parsedmarc.loganalytics import LogAnalyticsClient, LogAnalyticsException with self.assertRaises(LogAnalyticsException): LogAnalyticsClient( client_id="", client_secret="csec", tenant_id="tid", dce="https://dce.example.com", dcr_immutable_id="dcr-123", dcr_aggregate_stream="agg", dcr_failure_stream="fail", dcr_smtp_tls_stream="tls", ) def testSmtpTlsCsvRows(self): """parsed_smtp_tls_reports_to_csv_rows produces correct rows""" report_json = json.dumps( { "organization-name": "Org", "date-range": { "start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z", }, "contact-info": "a@b.com", "report-id": "r1", "policies": [ { "policy": { "policy-type": "sts", "policy-domain": "example.com", "policy-string": ["v: STSv1"], "mx-host-pattern": ["*.example.com"], }, "summary": { "total-successful-session-count": 10, "total-failure-session-count": 1, }, "failure-details": [ {"result-type": "cert-expired", "failed-session-count": 1} ], } ], } ) parsed = parsedmarc.parse_smtp_tls_report_json(report_json) rows = parsedmarc.parsed_smtp_tls_reports_to_csv_rows(parsed) self.assertTrue(len(rows) >= 2) self.assertEqual(rows[0]["organization_name"], "Org") self.assertEqual(rows[0]["policy_domain"], "example.com") def testParsedAggregateReportsToCsvRowsList(self): """parsed_aggregate_reports_to_csv_rows handles list of reports""" result = parsedmarc.parse_report_file( "samples/aggregate/dmarcbis-draft-sample.xml", always_use_local_files=True, offline=True, ) report = result["report"] # Pass as a list rows = parsedmarc.parsed_aggregate_reports_to_csv_rows([report]) self.assertTrue(len(rows) > 0) # Verify non-str/int/bool values are cleaned for row in rows: for v in row.values(): self.assertIn(type(v), [str, int, bool]) def testExceptionHierarchy(self): """Exception class hierarchy is correct""" self.assertTrue(issubclass(parsedmarc.ParserError, RuntimeError)) self.assertTrue( issubclass(parsedmarc.InvalidDMARCReport, parsedmarc.ParserError) ) self.assertTrue( issubclass(parsedmarc.InvalidAggregateReport, parsedmarc.InvalidDMARCReport) ) self.assertTrue( issubclass(parsedmarc.InvalidFailureReport, parsedmarc.InvalidDMARCReport) ) self.assertTrue( issubclass(parsedmarc.InvalidSMTPTLSReport, parsedmarc.ParserError) ) self.assertIs(parsedmarc.InvalidForensicReport, parsedmarc.InvalidFailureReport) def testAggregateReportNormalization(self): """Reports spanning >24h get normalized per day""" xml = """ TestOrg test@example.com test-norm 17040672001704326400 example.com

none

192.0.2.1 90 nonepasspass example.com example.compass
""" # Span is 259200 seconds (3 days), exceeds default 24h threshold report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) self.assertTrue(report["report_metadata"]["timespan_requires_normalization"]) # Records should be split across days self.assertTrue(len(report["records"]) > 1) total = sum(r["count"] for r in report["records"]) self.assertEqual(total, 90) for r in report["records"]: self.assertTrue(r["normalized_timespan"]) # =================================================================== # Additional backward compatibility alias tests # =================================================================== def testGelfBackwardCompatAlias(self): """GelfClient forensic alias points to failure method""" from parsedmarc.gelf import GelfClient self.assertIs( GelfClient.save_forensic_report_to_gelf, GelfClient.save_failure_report_to_gelf, ) def testS3BackwardCompatAlias(self): """S3Client forensic alias points to failure method""" from parsedmarc.s3 import S3Client self.assertIs( S3Client.save_forensic_report_to_s3, S3Client.save_failure_report_to_s3, ) def testKafkaBackwardCompatAlias(self): """KafkaClient forensic alias points to failure method""" from parsedmarc.kafkaclient import KafkaClient self.assertIs( KafkaClient.save_forensic_reports_to_kafka, KafkaClient.save_failure_reports_to_kafka, ) # =================================================================== # Additional extract/parse tests # =================================================================== def testExtractReportFromFilePathNotFound(self): """extract_report_from_file_path raises ParserError for missing file""" with self.assertRaises(parsedmarc.ParserError): parsedmarc.extract_report_from_file_path("nonexistent_file.xml") def testExtractReportInvalidArchive(self): """extract_report raises ParserError for unrecognized binary content""" with self.assertRaises(parsedmarc.ParserError): parsedmarc.extract_report(b"\x00\x01\x02\x03\x04\x05\x06\x07") def testParseAggregateReportFile(self): """parse_aggregate_report_file parses bytes input directly""" print() sample_path = "samples/aggregate/dmarcbis-draft-sample.xml" print("Testing {0}: ".format(sample_path), end="") with open(sample_path, "rb") as f: data = f.read() report = parsedmarc.parse_aggregate_report_file( data, offline=True, always_use_local_files=True, ) self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter") self.assertEqual(report["policy_published"]["domain"], "example.com") print("Passed!") def testParseInvalidAggregateSample(self): """Test invalid aggregate samples are handled""" print() sample_paths = glob("samples/aggregate_invalid/*") for sample_path in sample_paths: if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") with self.subTest(sample=sample_path): parsed_report = parsedmarc.parse_report_file( sample_path, always_use_local_files=True, offline=OFFLINE_MODE )["report"] parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) print("Passed!") def testParseReportFileWithBytes(self): """parse_report_file handles bytes input""" with open("samples/aggregate/dmarcbis-draft-sample.xml", "rb") as f: data = f.read() result = parsedmarc.parse_report_file( data, always_use_local_files=True, offline=True ) self.assertEqual(result["report_type"], "aggregate") def testFailureReportCsvRoundtrip(self): """Failure report CSV generation works on sample reports""" print() sample_paths = glob("samples/failure/*.eml") for sample_path in sample_paths: print("Testing CSV for {0}: ".format(sample_path), end="") with self.subTest(sample=sample_path): parsed_report = parsedmarc.parse_report_file( sample_path, offline=OFFLINE_MODE )["report"] csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report) self.assertIsNotNone(csv_output) self.assertIn(",", csv_output) rows = parsedmarc.parsed_failure_reports_to_csv_rows(parsed_report) self.assertTrue(len(rows) > 0) print("Passed!") def testOpenSearchSigV4RequiresRegion(self): with self.assertRaises(opensearch_module.OpenSearchError): opensearch_module.set_hosts( "https://example.org:9200", auth_type="awssigv4", ) def testOpenSearchSigV4ConfiguresConnectionClass(self): fake_credentials = object() with patch.object(opensearch_module.boto3, "Session") as session_cls: session_cls.return_value.get_credentials.return_value = fake_credentials with patch.object( opensearch_module, "AWSV4SignerAuth", return_value="auth" ) as signer: with patch.object( opensearch_module.connections, "create_connection" ) as create_connection: opensearch_module.set_hosts( "https://example.org:9200", use_ssl=True, auth_type="awssigv4", aws_region="eu-west-1", ) signer.assert_called_once_with(fake_credentials, "eu-west-1", "es") create_connection.assert_called_once() self.assertEqual( create_connection.call_args.kwargs.get("connection_class"), opensearch_module.RequestsHttpConnection, ) self.assertEqual(create_connection.call_args.kwargs.get("http_auth"), "auth") def testOpenSearchSigV4RejectsUnknownAuthType(self): with self.assertRaises(opensearch_module.OpenSearchError): opensearch_module.set_hosts( "https://example.org:9200", auth_type="kerberos", ) def testOpenSearchSigV4RequiresAwsCredentials(self): with patch.object(opensearch_module.boto3, "Session") as session_cls: session_cls.return_value.get_credentials.return_value = None with self.assertRaises(opensearch_module.OpenSearchError): opensearch_module.set_hosts( "https://example.org:9200", auth_type="awssigv4", aws_region="eu-west-1", ) @patch("parsedmarc.cli.opensearch.migrate_indexes") @patch("parsedmarc.cli.opensearch.set_hosts") @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.IMAPConnection") def testCliPassesOpenSearchSigV4Settings( self, mock_imap_connection, mock_get_reports, mock_set_hosts, _mock_migrate_indexes, ): mock_imap_connection.return_value = object() mock_get_reports.return_value = { "aggregate_reports": [], "failure_reports": [], "smtp_tls_reports": [], } config = """[general] save_aggregate = true silent = true [imap] host = imap.example.com user = test-user password = test-password [opensearch] hosts = localhost authentication_type = awssigv4 aws_region = eu-west-1 aws_service = aoss """ with tempfile.NamedTemporaryFile( "w", suffix=".ini", delete=False ) as config_file: config_file.write(config) config_path = config_file.name self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): parsedmarc.cli._main() self.assertEqual(mock_set_hosts.call_args.kwargs.get("auth_type"), "awssigv4") self.assertEqual(mock_set_hosts.call_args.kwargs.get("aws_region"), "eu-west-1") self.assertEqual(mock_set_hosts.call_args.kwargs.get("aws_service"), "aoss") @patch("parsedmarc.cli.elastic.save_aggregate_report_to_elasticsearch") @patch("parsedmarc.cli.elastic.migrate_indexes") @patch("parsedmarc.cli.elastic.set_hosts") @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.IMAPConnection") def testFailOnOutputErrorExits( self, mock_imap_connection, mock_get_reports, _mock_set_hosts, _mock_migrate_indexes, mock_save_aggregate, ): """CLI should exit with code 1 when fail_on_output_error is enabled""" mock_imap_connection.return_value = object() mock_get_reports.return_value = { "aggregate_reports": [{"policy_published": {"domain": "example.com"}}], "failure_reports": [], "smtp_tls_reports": [], } mock_save_aggregate.side_effect = parsedmarc.elastic.ElasticsearchError( "simulated output failure" ) config = """[general] save_aggregate = true fail_on_output_error = true silent = true [imap] host = imap.example.com user = test-user password = test-password [elasticsearch] hosts = localhost """ with tempfile.NamedTemporaryFile( "w", suffix=".ini", delete=False ) as config_file: config_file.write(config) config_path = config_file.name self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): with self.assertRaises(SystemExit) as ctx: parsedmarc.cli._main() self.assertEqual(ctx.exception.code, 1) mock_save_aggregate.assert_called_once() @patch("parsedmarc.cli.elastic.save_aggregate_report_to_elasticsearch") @patch("parsedmarc.cli.elastic.migrate_indexes") @patch("parsedmarc.cli.elastic.set_hosts") @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.IMAPConnection") def testOutputErrorDoesNotExitWhenDisabled( self, mock_imap_connection, mock_get_reports, _mock_set_hosts, _mock_migrate_indexes, mock_save_aggregate, ): mock_imap_connection.return_value = object() mock_get_reports.return_value = { "aggregate_reports": [{"policy_published": {"domain": "example.com"}}], "failure_reports": [], "smtp_tls_reports": [], } mock_save_aggregate.side_effect = parsedmarc.elastic.ElasticsearchError( "simulated output failure" ) config = """[general] save_aggregate = true fail_on_output_error = false silent = true [imap] host = imap.example.com user = test-user password = test-password [elasticsearch] hosts = localhost """ with tempfile.NamedTemporaryFile( "w", suffix=".ini", delete=False ) as config_file: config_file.write(config) config_path = config_file.name self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): parsedmarc.cli._main() mock_save_aggregate.assert_called_once() @patch("parsedmarc.cli.opensearch.save_failure_report_to_opensearch") @patch("parsedmarc.cli.opensearch.migrate_indexes") @patch("parsedmarc.cli.opensearch.set_hosts") @patch("parsedmarc.cli.elastic.save_failure_report_to_elasticsearch") @patch("parsedmarc.cli.elastic.save_aggregate_report_to_elasticsearch") @patch("parsedmarc.cli.elastic.migrate_indexes") @patch("parsedmarc.cli.elastic.set_hosts") @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.IMAPConnection") def testFailOnOutputErrorExitsWithMultipleSinkErrors( self, mock_imap_connection, mock_get_reports, _mock_es_set_hosts, _mock_es_migrate, mock_save_aggregate, _mock_save_failure_elastic, _mock_os_set_hosts, _mock_os_migrate, mock_save_failure_opensearch, ): mock_imap_connection.return_value = object() mock_get_reports.return_value = { "aggregate_reports": [{"policy_published": {"domain": "example.com"}}], "failure_reports": [{"reported_domain": "example.com"}], "smtp_tls_reports": [], } mock_save_aggregate.side_effect = parsedmarc.elastic.ElasticsearchError( "aggregate sink failed" ) mock_save_failure_opensearch.side_effect = ( parsedmarc.cli.opensearch.OpenSearchError("failure sink failed") ) config = """[general] save_aggregate = true save_forensic = true fail_on_output_error = true silent = true [imap] host = imap.example.com user = test-user password = test-password [elasticsearch] hosts = localhost [opensearch] hosts = localhost """ with tempfile.NamedTemporaryFile( "w", suffix=".ini", delete=False ) as config_file: config_file.write(config) config_path = config_file.name self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): with self.assertRaises(SystemExit) as ctx: parsedmarc.cli._main() self.assertEqual(ctx.exception.code, 1) mock_save_aggregate.assert_called_once() mock_save_failure_opensearch.assert_called_once() class _FakeGraphResponse: def __init__(self, status_code, payload=None, text=""): self.status_code = status_code self._payload = payload or {} self.text = text def json(self): return self._payload class _BreakLoop(BaseException): pass class TestGmailConnection(unittest.TestCase): def _build_connection(self, *, paginate=True): connection = GmailConnection.__new__(GmailConnection) connection.include_spam_trash = False connection.reports_label_id = "REPORTS" connection.paginate_messages = paginate connection.service = MagicMock() return connection def testFindLabelId(self): connection = self._build_connection() labels_api = connection.service.users.return_value.labels.return_value labels_api.list.return_value.execute.return_value = { "labels": [ {"id": "INBOX", "name": "INBOX"}, {"id": "REPORTS", "name": "Reports"}, ] } self.assertEqual(connection._find_label_id_for_label("Reports"), "REPORTS") self.assertEqual(connection._find_label_id_for_label("MISSING"), "") def testFetchMessagesWithPagination(self): connection = self._build_connection(paginate=True) messages_api = connection.service.users.return_value.messages.return_value def list_side_effect(**kwargs): response = MagicMock() if kwargs.get("pageToken") is None: response.execute.return_value = { "messages": [{"id": "a"}, {"id": "b"}], "nextPageToken": "n1", } else: response.execute.return_value = {"messages": [{"id": "c"}]} return response messages_api.list.side_effect = list_side_effect connection._find_label_id_for_label = MagicMock(return_value="REPORTS") self.assertEqual(connection.fetch_messages("Reports"), ["a", "b", "c"]) def testFetchMessageDecoding(self): connection = self._build_connection() messages_api = connection.service.users.return_value.messages.return_value raw = urlsafe_b64encode(b"Subject: test\n\nbody").decode() messages_api.get.return_value.execute.return_value = {"raw": raw} content = connection.fetch_message("m1") self.assertIn("Subject: test", content) def testMoveAndDeleteMessage(self): connection = self._build_connection() connection._find_label_id_for_label = MagicMock(return_value="ARCHIVE") messages_api = connection.service.users.return_value.messages.return_value messages_api.modify.return_value.execute.return_value = {} connection.move_message("m1", "Archive") messages_api.modify.assert_called_once() connection.delete_message("m1") messages_api.delete.assert_called_once_with(userId="me", id="m1") messages_api.delete.return_value.execute.assert_called_once() def testGetCredsFromTokenFile(self): creds = MagicMock() creds.valid = True with NamedTemporaryFile("w", delete=False) as token_file: token_file.write("{}") token_path = token_file.name try: with patch.object( gmail_module.Credentials, "from_authorized_user_file", return_value=creds, ): returned = _get_creds(token_path, "credentials.json", ["scope"], 8080) finally: os.remove(token_path) self.assertEqual(returned, creds) def testGetCredsWithOauthFlow(self): expired_creds = MagicMock() expired_creds.valid = False expired_creds.expired = False expired_creds.refresh_token = None new_creds = MagicMock() new_creds.valid = True new_creds.to_json.return_value = '{"token":"x"}' flow = MagicMock() flow.run_local_server.return_value = new_creds with NamedTemporaryFile("w", delete=False) as token_file: token_file.write("{}") token_path = token_file.name try: with patch.object( gmail_module.Credentials, "from_authorized_user_file", return_value=expired_creds, ): with patch.object( gmail_module.InstalledAppFlow, "from_client_secrets_file", return_value=flow, ): returned = _get_creds( token_path, "credentials.json", ["scope"], 8080 ) finally: os.remove(token_path) self.assertEqual(returned, new_creds) flow.run_local_server.assert_called_once() def testGetCredsRefreshesExpiredToken(self): expired_creds = MagicMock() expired_creds.valid = False expired_creds.expired = True expired_creds.refresh_token = "rt" expired_creds.to_json.return_value = '{"token":"refreshed"}' with NamedTemporaryFile("w", delete=False) as token_file: token_file.write("{}") token_path = token_file.name try: with patch.object( gmail_module.Credentials, "from_authorized_user_file", return_value=expired_creds, ): returned = _get_creds(token_path, "credentials.json", ["scope"], 8080) finally: os.remove(token_path) self.assertEqual(returned, expired_creds) expired_creds.refresh.assert_called_once() def testCreateFolderConflictIgnored(self): connection = self._build_connection() labels_api = connection.service.users.return_value.labels.return_value conflict = HttpError(Response({"status": "409"}), b"conflict") labels_api.create.return_value.execute.side_effect = conflict connection.create_folder("Existing") class TestGraphConnection(unittest.TestCase): def testLoadTokenMissing(self): with TemporaryDirectory() as temp_dir: missing_path = Path(temp_dir) / "missing-token-file" self.assertIsNone(_load_token(missing_path)) def testLoadTokenExisting(self): with NamedTemporaryFile("w", delete=False) as token_file: token_file.write("serialized-auth-record") token_path = token_file.name try: self.assertEqual(_load_token(Path(token_path)), "serialized-auth-record") finally: os.remove(token_path) def testGetAllMessagesPagination(self): connection = MSGraphConnection.__new__(MSGraphConnection) first_response = _FakeGraphResponse( 200, {"value": [{"id": "1"}], "@odata.nextLink": "next-url"} ) second_response = _FakeGraphResponse(200, {"value": [{"id": "2"}]}) connection._client = MagicMock() connection._client.get.side_effect = [first_response, second_response] messages = connection._get_all_messages("/url", batch_size=0, since=None) self.assertEqual([msg["id"] for msg in messages], ["1", "2"]) def testGetAllMessagesInitialRequestFailure(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection._client = MagicMock() connection._client.get.return_value = _FakeGraphResponse(500, text="boom") with self.assertRaises(RuntimeError): connection._get_all_messages("/url", batch_size=0, since=None) def testGetAllMessagesRetriesTransientRequestErrors(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection._client = MagicMock() connection._client.get.side_effect = [ graph_module.RequestException("connection reset"), _FakeGraphResponse(200, {"value": [{"id": "1"}]}), ] with patch.object(graph_module, "sleep") as mocked_sleep: messages = connection._get_all_messages("/url", batch_size=0, since=None) self.assertEqual([msg["id"] for msg in messages], ["1"]) mocked_sleep.assert_called_once_with(graph_module.GRAPH_REQUEST_RETRY_DELAY_SECONDS) def testGetAllMessagesRaisesAfterRetryExhaustion(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection._client = MagicMock() connection._client.get.side_effect = graph_module.RequestException( "connection reset" ) with patch.object(graph_module, "sleep") as mocked_sleep: with self.assertRaises(graph_module.RequestException): connection._get_all_messages("/url", batch_size=0, since=None) self.assertEqual( mocked_sleep.call_count, graph_module.GRAPH_REQUEST_RETRY_ATTEMPTS - 1 ) def testGetAllMessagesNextPageFailure(self): connection = MSGraphConnection.__new__(MSGraphConnection) first_response = _FakeGraphResponse( 200, {"value": [{"id": "1"}], "@odata.nextLink": "next-url"} ) second_response = _FakeGraphResponse(500, text="page-fail") connection._client = MagicMock() connection._client.get.side_effect = [first_response, second_response] with self.assertRaises(RuntimeError): connection._get_all_messages("/url", batch_size=0, since=None) def testGetAllMessagesHonorsBatchSizeLimit(self): connection = MSGraphConnection.__new__(MSGraphConnection) first_response = _FakeGraphResponse( 200, { "value": [{"id": "1"}, {"id": "2"}], "@odata.nextLink": "next-url", }, ) connection._client = MagicMock() connection._client.get.return_value = first_response messages = connection._get_all_messages("/url", batch_size=2, since=None) self.assertEqual([msg["id"] for msg in messages], ["1", "2"]) connection._client.get.assert_called_once() def testFetchMessagesPassesSinceAndBatchSize(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "mailbox@example.com" connection._find_folder_id_from_folder_path = MagicMock( return_value="folder-id" ) connection._get_all_messages = MagicMock(return_value=[{"id": "1"}]) self.assertEqual( connection.fetch_messages("Inbox", since="2026-03-01", batch_size=5), ["1"] ) connection._get_all_messages.assert_called_once_with( "/users/mailbox@example.com/mailFolders/folder-id/messages", 5, "2026-03-01", ) def testFetchMessageMarksRead(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "mailbox@example.com" connection._client = MagicMock() connection._client.get.return_value = _FakeGraphResponse( 200, text="email-content" ) connection.mark_message_read = MagicMock() content = connection.fetch_message("123", mark_read=True) self.assertEqual(content, "email-content") connection.mark_message_read.assert_called_once_with("123") def testFindFolderIdNotFound(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "mailbox@example.com" connection._client = MagicMock() connection._client.get.return_value = _FakeGraphResponse(200, {"value": []}) with self.assertRaises(RuntimeError): connection._find_folder_id_with_parent("Missing", None) def testGetCacheArgsWithAuthRecord(self): with NamedTemporaryFile("w", delete=False) as token_file: token_file.write("serialized") token_path = Path(token_file.name) try: with patch.object( graph_module.AuthenticationRecord, "deserialize", return_value="auth_record", ): args = _get_cache_args(token_path, allow_unencrypted_storage=False) self.assertIn("authentication_record", args) finally: os.remove(token_path) def testGenerateCredentialInvalid(self): with self.assertRaises(RuntimeError): _generate_credential( "Nope", Path("/tmp/token"), client_id="x", client_secret="y", username="u", password="p", tenant_id="t", allow_unencrypted_storage=False, ) def testGenerateCredentialDeviceCode(self): fake_credential = object() with patch.object( graph_module, "_get_cache_args", return_value={"cached": True} ): with patch.object( graph_module, "DeviceCodeCredential", return_value=fake_credential, ) as mocked: result = _generate_credential( graph_module.AuthMethod.DeviceCode.name, Path("/tmp/token"), client_id="cid", client_secret="secret", username="user", password="pass", tenant_id="tenant", allow_unencrypted_storage=True, ) self.assertIs(result, fake_credential) mocked.assert_called_once() def testGenerateCredentialClientSecret(self): fake_credential = object() with patch.object( graph_module, "ClientSecretCredential", return_value=fake_credential ) as mocked: result = _generate_credential( graph_module.AuthMethod.ClientSecret.name, Path("/tmp/token"), client_id="cid", client_secret="secret", username="user", password="pass", tenant_id="tenant", allow_unencrypted_storage=False, ) self.assertIs(result, fake_credential) mocked.assert_called_once_with( client_id="cid", tenant_id="tenant", client_secret="secret" ) def testGenerateCredentialCertificate(self): fake_credential = object() with patch.object( graph_module, "CertificateCredential", return_value=fake_credential ) as mocked: result = _generate_credential( graph_module.AuthMethod.Certificate.name, Path("/tmp/token"), client_id="cid", client_secret="secret", certificate_path="/tmp/cert.pem", certificate_password="secret-pass", username="user", password="pass", tenant_id="tenant", allow_unencrypted_storage=False, ) self.assertIs(result, fake_credential) mocked.assert_called_once_with( client_id="cid", tenant_id="tenant", certificate_path="/tmp/cert.pem", password="secret-pass", ) def testGenerateCredentialCertificateRequiresPath(self): with self.assertRaisesRegex( ValueError, "certificate_path is required when auth_method is 'Certificate'", ): _generate_credential( graph_module.AuthMethod.Certificate.name, Path("/tmp/token"), client_id="cid", client_secret=None, certificate_path=None, certificate_password="secret-pass", username=None, password=None, tenant_id="tenant", allow_unencrypted_storage=False, ) def testInitUsesSharedMailboxScopes(self): class FakeCredential: def __init__(self): self.authenticate = MagicMock(return_value="auth-record") fake_credential = FakeCredential() with patch.object( graph_module, "_generate_credential", return_value=fake_credential ): with patch.object(graph_module, "_cache_auth_record") as cache_auth: with patch.object(graph_module, "GraphClient") as graph_client: MSGraphConnection( auth_method=graph_module.AuthMethod.DeviceCode.name, mailbox="shared@example.com", graph_url="https://graph.microsoft.com", client_id="cid", client_secret="secret", username="owner@example.com", password="pass", tenant_id="tenant", token_file="/tmp/token-file", allow_unencrypted_storage=True, ) fake_credential.authenticate.assert_called_once_with( scopes=["Mail.ReadWrite.Shared"] ) cache_auth.assert_called_once() graph_client.assert_called_once() self.assertEqual( graph_client.call_args.kwargs.get("scopes"), ["Mail.ReadWrite.Shared"] ) def testInitWithoutUsernameUsesDefaultMailReadWriteScope(self): class FakeCredential: def __init__(self): self.authenticate = MagicMock(return_value="auth-record") fake_credential = FakeCredential() with patch.object( graph_module, "_generate_credential", return_value=fake_credential ): with patch.object(graph_module, "_cache_auth_record") as cache_auth: with patch.object(graph_module, "GraphClient") as graph_client: MSGraphConnection( auth_method=graph_module.AuthMethod.DeviceCode.name, mailbox="owner@example.com", graph_url="https://graph.microsoft.com", client_id="cid", client_secret="secret", username=None, password=None, tenant_id="tenant", token_file="/tmp/token-file", allow_unencrypted_storage=True, ) fake_credential.authenticate.assert_called_once_with(scopes=["Mail.ReadWrite"]) cache_auth.assert_called_once() graph_client.assert_called_once() self.assertEqual(graph_client.call_args.kwargs.get("scopes"), ["Mail.ReadWrite"]) def testInitCertificateAuthSkipsInteractiveAuthenticate(self): class DummyCertificateCredential: pass fake_credential = DummyCertificateCredential() with patch.object(graph_module, "CertificateCredential", DummyCertificateCredential): with patch.object( graph_module, "_generate_credential", return_value=fake_credential ): with patch.object(graph_module, "_cache_auth_record") as cache_auth: with patch.object(graph_module, "GraphClient") as graph_client: MSGraphConnection( auth_method=graph_module.AuthMethod.Certificate.name, mailbox="shared@example.com", graph_url="https://graph.microsoft.com", client_id="cid", client_secret=None, certificate_path="/tmp/cert.pem", certificate_password="secret-pass", username=None, password=None, tenant_id="tenant", token_file="/tmp/token-file", allow_unencrypted_storage=False, ) cache_auth.assert_not_called() graph_client.assert_called_once() self.assertNotIn("scopes", graph_client.call_args.kwargs) def testCreateFolderAndMoveErrors(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "mailbox@example.com" connection._client = MagicMock() connection._client.post.return_value = _FakeGraphResponse(500, {"error": "x"}) connection._find_folder_id_from_folder_path = MagicMock(return_value="dest") with self.assertRaises(RuntimeWarning): connection.move_message("m1", "Archive") connection._client.post.return_value = _FakeGraphResponse(409, {}) connection.create_folder("Archive") def testMarkReadDeleteFailures(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "mailbox@example.com" connection._client = MagicMock() connection._client.patch.return_value = _FakeGraphResponse(500, {"error": "x"}) with self.assertRaises(RuntimeWarning): connection.mark_message_read("m1") connection._client.delete.return_value = _FakeGraphResponse(500, {"error": "x"}) with self.assertRaises(RuntimeWarning): connection.delete_message("m1") class TestImapConnection(unittest.TestCase): def testDelegatesToImapClient(self): with patch.object(imap_module, "IMAPClient") as mocked_client_cls: mocked_client = MagicMock() mocked_client_cls.return_value = mocked_client connection = IMAPConnection( "imap.example.com", user="user", password="pass" ) connection.create_folder("Archive") mocked_client.create_folder.assert_called_once_with("Archive") mocked_client.search.return_value = [1, 2] self.assertEqual(connection.fetch_messages("INBOX"), [1, 2]) mocked_client.select_folder.assert_called_with("INBOX") connection.fetch_messages("INBOX", since="2026-03-01") mocked_client.search.assert_called_with("SINCE 2026-03-01") mocked_client.fetch_message.return_value = "raw-message" self.assertEqual(connection.fetch_message(1), "raw-message") connection.delete_message(7) mocked_client.delete_messages.assert_called_once_with([7]) connection.move_message(8, "Archive") mocked_client.move_messages.assert_called_once_with([8], "Archive") connection.keepalive() mocked_client.noop.assert_called_once() def testWatchReconnectPath(self): with patch.object(imap_module, "IMAPClient") as mocked_client_cls: base_client = MagicMock() base_client.host = "imap.example.com" base_client.port = 993 base_client.ssl = True mocked_client_cls.return_value = base_client connection = IMAPConnection( "imap.example.com", user="user", password="pass" ) calls = {"count": 0} def fake_imap_constructor(*args, **kwargs): idle_callback = kwargs.get("idle_callback") if calls["count"] == 0: calls["count"] += 1 raise IMAPClientError("timeout") if idle_callback is not None: idle_callback(base_client) raise _BreakLoop() callback = MagicMock() with patch.object(imap_module, "sleep", return_value=None): with patch.object( imap_module, "IMAPClient", side_effect=fake_imap_constructor ): with self.assertRaises(_BreakLoop): connection.watch(callback, check_timeout=1) callback.assert_called_once_with(connection) class TestGmailAuthModes(unittest.TestCase): @patch( "parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file" ) def testGetCredsServiceAccountWithoutSubject(self, mock_from_service_account_file): service_creds = MagicMock() service_creds.with_subject.return_value = MagicMock() mock_from_service_account_file.return_value = service_creds creds = gmail_module._get_creds( token_file=".token", credentials_file="service-account.json", scopes=["https://www.googleapis.com/auth/gmail.readonly"], oauth2_port=8080, auth_mode="service_account", service_account_user=None, ) self.assertIs(creds, service_creds) mock_from_service_account_file.assert_called_once_with( "service-account.json", scopes=["https://www.googleapis.com/auth/gmail.readonly"], ) service_creds.with_subject.assert_not_called() @patch( "parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file" ) def testGetCredsServiceAccountWithSubject(self, mock_from_service_account_file): base_creds = MagicMock() delegated_creds = MagicMock() base_creds.with_subject.return_value = delegated_creds mock_from_service_account_file.return_value = base_creds creds = gmail_module._get_creds( token_file=".token", credentials_file="service-account.json", scopes=["https://www.googleapis.com/auth/gmail.modify"], oauth2_port=8080, auth_mode="service_account", service_account_user="dmarc@example.com", ) self.assertIs(creds, delegated_creds) base_creds.with_subject.assert_called_once_with("dmarc@example.com") def testGetCredsRejectsUnsupportedAuthMode(self): with self.assertRaises(ValueError): gmail_module._get_creds( token_file=".token", credentials_file="client-secret.json", scopes=["https://www.googleapis.com/auth/gmail.modify"], oauth2_port=8080, auth_mode="unsupported", ) @patch("parsedmarc.mail.gmail.Path.exists", return_value=True) @patch("parsedmarc.mail.gmail.Credentials.from_authorized_user_file") def testGetCredsInstalledAppStillUsesTokenFile( self, mock_from_authorized_user_file, _mock_exists ): token_creds = MagicMock() token_creds.valid = True mock_from_authorized_user_file.return_value = token_creds creds = gmail_module._get_creds( token_file=".token", credentials_file="client-secret.json", scopes=["https://www.googleapis.com/auth/gmail.modify"], oauth2_port=8080, auth_mode="installed_app", ) self.assertIs(creds, token_creds) mock_from_authorized_user_file.assert_called_once_with( ".token", ["https://www.googleapis.com/auth/gmail.modify"], ) @patch("parsedmarc.mail.gmail.GmailConnection._find_label_id_for_label") @patch("parsedmarc.mail.gmail.build") @patch("parsedmarc.mail.gmail._get_creds") def testGmailConnectionPassesAuthModeAndDelegatedUser( self, mock_get_creds, mock_build, mock_find_label ): mock_get_creds.return_value = MagicMock() mock_build.return_value = MagicMock() mock_find_label.return_value = "INBOX" gmail_module.GmailConnection( token_file=".token", credentials_file="service-account.json", scopes=["https://www.googleapis.com/auth/gmail.modify"], include_spam_trash=False, reports_folder="INBOX", oauth2_port=8080, paginate_messages=True, auth_mode="service_account", service_account_user="dmarc@example.com", ) mock_get_creds.assert_called_once_with( ".token", "service-account.json", ["https://www.googleapis.com/auth/gmail.modify"], 8080, auth_mode="service_account", service_account_user="dmarc@example.com", ) @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.GmailConnection") def testCliPassesGmailServiceAccountAuthSettings( self, mock_gmail_connection, mock_get_mailbox_reports ): mock_gmail_connection.return_value = MagicMock() mock_get_mailbox_reports.return_value = { "aggregate_reports": [], "failure_reports": [], "smtp_tls_reports": [], } config = """[general] silent = true [gmail_api] credentials_file = /tmp/service-account.json auth_mode = service_account service_account_user = dmarc@example.com scopes = https://www.googleapis.com/auth/gmail.modify """ with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg_file: cfg_file.write(config) config_path = cfg_file.name self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): parsedmarc.cli._main() self.assertEqual( mock_gmail_connection.call_args.kwargs.get("auth_mode"), "service_account" ) self.assertEqual( mock_gmail_connection.call_args.kwargs.get("service_account_user"), "dmarc@example.com", ) @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.GmailConnection") def testCliAcceptsDelegatedUserAlias(self, mock_gmail_connection, mock_get_reports): mock_gmail_connection.return_value = MagicMock() mock_get_reports.return_value = { "aggregate_reports": [], "failure_reports": [], "smtp_tls_reports": [], } config = """[general] silent = true [gmail_api] credentials_file = /tmp/service-account.json auth_mode = service_account delegated_user = delegated@example.com scopes = https://www.googleapis.com/auth/gmail.modify """ with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg_file: cfg_file.write(config) config_path = cfg_file.name self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): parsedmarc.cli._main() self.assertEqual( mock_gmail_connection.call_args.kwargs.get("service_account_user"), "delegated@example.com", ) class TestImapFallbacks(unittest.TestCase): def testDeleteSuccessDoesNotUseFallback(self): connection = IMAPConnection.__new__(IMAPConnection) connection._client = MagicMock() connection.delete_message(42) connection._client.delete_messages.assert_called_once_with([42]) connection._client.add_flags.assert_not_called() connection._client.expunge.assert_not_called() def testDeleteFallbackUsesFlagsAndExpunge(self): connection = IMAPConnection.__new__(IMAPConnection) connection._client = MagicMock() connection._client.delete_messages.side_effect = IMAPClientError("uid expunge") connection.delete_message(42) connection._client.add_flags.assert_called_once_with( [42], [r"\Deleted"], silent=True ) connection._client.expunge.assert_called_once_with() def testDeleteFallbackErrorPropagates(self): connection = IMAPConnection.__new__(IMAPConnection) connection._client = MagicMock() connection._client.delete_messages.side_effect = IMAPClientError("uid expunge") connection._client.add_flags.side_effect = IMAPClientError("flag failed") with self.assertRaises(IMAPClientError): connection.delete_message(42) def testMoveSuccessDoesNotUseFallback(self): connection = IMAPConnection.__new__(IMAPConnection) connection._client = MagicMock() with patch.object(connection, "delete_message") as delete_mock: connection.move_message(99, "Archive") connection._client.move_messages.assert_called_once_with([99], "Archive") connection._client.copy.assert_not_called() delete_mock.assert_not_called() def testMoveFallbackCopiesThenDeletes(self): connection = IMAPConnection.__new__(IMAPConnection) connection._client = MagicMock() connection._client.move_messages.side_effect = IMAPClientError("move failed") with patch.object(connection, "delete_message") as delete_mock: connection.move_message(99, "Archive") connection._client.copy.assert_called_once_with([99], "Archive") delete_mock.assert_called_once_with(99) def testMoveFallbackCopyErrorPropagates(self): connection = IMAPConnection.__new__(IMAPConnection) connection._client = MagicMock() connection._client.move_messages.side_effect = IMAPClientError("move failed") connection._client.copy.side_effect = IMAPClientError("copy failed") with patch.object(connection, "delete_message") as delete_mock: with self.assertRaises(IMAPClientError): connection.move_message(99, "Archive") delete_mock.assert_not_called() class TestMailboxWatchSince(unittest.TestCase): def testWatchInboxPassesSinceToMailboxFetch(self): mailbox_connection = SimpleNamespace() def fake_watch(check_callback, check_timeout): check_callback(mailbox_connection) raise _BreakLoop() mailbox_connection.watch = fake_watch callback = MagicMock() with patch.object( parsedmarc, "get_dmarc_reports_from_mailbox", return_value={} ) as mocked: with self.assertRaises(_BreakLoop): parsedmarc.watch_inbox( mailbox_connection=mailbox_connection, callback=callback, check_timeout=1, batch_size=10, since="1d", ) self.assertEqual(mocked.call_args.kwargs.get("since"), "1d") @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.watch_inbox") @patch("parsedmarc.cli.IMAPConnection") def testCliPassesSinceToWatchInbox( self, mock_imap_connection, mock_watch_inbox, mock_get_mailbox_reports ): mock_imap_connection.return_value = object() mock_get_mailbox_reports.return_value = { "aggregate_reports": [], "failure_reports": [], "smtp_tls_reports": [], } mock_watch_inbox.side_effect = FileExistsError("stop-watch-loop") config_text = """[general] silent = true [imap] host = imap.example.com user = user password = pass [mailbox] watch = true since = 2d """ with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg: cfg.write(config_text) cfg_path = cfg.name self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]): with self.assertRaises(SystemExit) as system_exit: parsedmarc.cli._main() self.assertEqual(system_exit.exception.code, 1) self.assertEqual(mock_watch_inbox.call_args.kwargs.get("since"), "2d") class _DummyMailboxConnection: def __init__(self): self.fetch_calls = [] def create_folder(self, folder_name): return None def fetch_messages(self, reports_folder, **kwargs): self.fetch_calls.append({"reports_folder": reports_folder, **kwargs}) return [] def fetch_message(self, message_id, **kwargs): return "" def delete_message(self, message_id): return None def move_message(self, message_id, folder_name): return None def keepalive(self): return None def watch(self, check_callback, check_timeout): return None class TestMailboxPerformance(unittest.TestCase): def testBatchModeAvoidsExtraFullFetch(self): connection = _DummyMailboxConnection() parsedmarc.get_dmarc_reports_from_mailbox( connection=connection, reports_folder="INBOX", test=True, batch_size=10, create_folders=False, ) self.assertEqual(len(connection.fetch_calls), 1) @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.MSGraphConnection") def testCliPassesMsGraphCertificateAuthSettings( self, mock_graph_connection, mock_get_mailbox_reports ): mock_graph_connection.return_value = object() mock_get_mailbox_reports.return_value = { "aggregate_reports": [], "forensic_reports": [], "smtp_tls_reports": [], } config_text = """[general] silent = true [msgraph] auth_method = Certificate client_id = client-id tenant_id = tenant-id mailbox = shared@example.com certificate_path = /tmp/msgraph-cert.pem certificate_password = cert-pass """ with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg: cfg.write(config_text) cfg_path = cfg.name self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]): parsedmarc.cli._main() self.assertEqual( mock_graph_connection.call_args.kwargs.get("auth_method"), "Certificate" ) self.assertEqual( mock_graph_connection.call_args.kwargs.get("certificate_path"), "/tmp/msgraph-cert.pem", ) self.assertEqual( mock_graph_connection.call_args.kwargs.get("certificate_password"), "cert-pass", ) @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") @patch("parsedmarc.cli.MSGraphConnection") @patch("parsedmarc.cli.logger") def testCliRequiresMsGraphCertificatePath( self, mock_logger, mock_graph_connection, mock_get_mailbox_reports ): config_text = """[general] silent = true [msgraph] auth_method = Certificate client_id = client-id tenant_id = tenant-id mailbox = shared@example.com """ with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg: cfg.write(config_text) cfg_path = cfg.name self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path)) with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]): with self.assertRaises(SystemExit) as system_exit: parsedmarc.cli._main() self.assertEqual(system_exit.exception.code, -1) mock_logger.critical.assert_called_once_with( "certificate_path setting missing from the msgraph config section" ) mock_graph_connection.assert_not_called() mock_get_mailbox_reports.assert_not_called() class _FakeGraphClient: def get(self, url, params=None): if "/mailFolders/inbox?$select=id,displayName" in url: return _FakeGraphResponse(200, {"id": "inbox-id", "displayName": "Inbox"}) if "/mailFolders?$filter=displayName eq 'Inbox'" in url: return _FakeGraphResponse( 404, { "error": { "code": "ErrorItemNotFound", "message": "Default folder Root not found.", } }, ) if "/mailFolders?$filter=displayName eq 'Custom'" in url: return _FakeGraphResponse( 404, { "error": { "code": "ErrorItemNotFound", "message": "Default folder Root not found.", } }, ) return _FakeGraphResponse(404, {"error": {"code": "NotFound"}}) class TestMSGraphFolderFallback(unittest.TestCase): def testWellKnownFolderFallback(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "shared@example.com" connection._client = _FakeGraphClient() connection._request_with_retries = MagicMock( side_effect=lambda method_name, *args, **kwargs: getattr( connection._client, method_name )( *args, **kwargs ) ) folder_id = connection._find_folder_id_with_parent("Inbox", None) self.assertEqual(folder_id, "inbox-id") connection._request_with_retries.assert_any_call( "get", "/users/shared@example.com/mailFolders?$filter=displayName eq 'Inbox'" ) connection._request_with_retries.assert_any_call( "get", "/users/shared@example.com/mailFolders/inbox?$select=id,displayName" ) def testUnknownFolderStillFails(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "shared@example.com" connection._client = _FakeGraphClient() connection._request_with_retries = MagicMock( side_effect=lambda method_name, *args, **kwargs: getattr( connection._client, method_name )( *args, **kwargs ) ) with self.assertRaises(RuntimeWarning): connection._find_folder_id_from_folder_path("Custom") def testSingleSegmentPathAvoidsExtraWellKnownLookupWhenListingSucceeds(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "shared@example.com" connection._find_folder_id_with_parent = MagicMock(return_value="custom-id") connection._get_well_known_folder_id = MagicMock(return_value="inbox-id") folder_id = connection._find_folder_id_from_folder_path("Inbox") self.assertEqual(folder_id, "custom-id") connection._find_folder_id_with_parent.assert_called_once_with("Inbox", None) connection._get_well_known_folder_id.assert_not_called() if __name__ == "__main__": unittest.main(verbosity=2)