Fix ruff formatting errors, duplicate import, and test mock key names

Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-03-09 21:40:28 +00:00
parent b4b90e763d
commit fce8e2247b
6 changed files with 291 additions and 161 deletions

View File

@@ -825,18 +825,14 @@ def parse_aggregate_report_xml(
if policy_published["np"] is not None:
np_ = policy_published["np"]
if np_ not in ("none", "quarantine", "reject"):
logger.warning(
"Invalid np value: {0}".format(np_)
)
logger.warning("Invalid np value: {0}".format(np_))
new_policy_published["np"] = np_
testing = None
if "testing" in policy_published:
if policy_published["testing"] is not None:
testing = policy_published["testing"]
if testing not in ("n", "y"):
logger.warning(
"Invalid testing value: {0}".format(testing)
)
logger.warning("Invalid testing value: {0}".format(testing))
new_policy_published["testing"] = testing
discovery_method = None
if "discovery_method" in policy_published:
@@ -844,9 +840,7 @@ def parse_aggregate_report_xml(
discovery_method = policy_published["discovery_method"]
if discovery_method not in ("psl", "treewalk"):
logger.warning(
"Invalid discovery_method value: {0}".format(
discovery_method
)
"Invalid discovery_method value: {0}".format(discovery_method)
)
new_policy_published["discovery_method"] = discovery_method
new_report["policy_published"] = new_policy_published
@@ -1107,9 +1101,7 @@ def parsed_aggregate_reports_to_csv_rows(
fo = report["policy_published"]["fo"]
np_ = report["policy_published"].get("np", None)
testing = report["policy_published"].get("testing", None)
discovery_method = report["policy_published"].get(
"discovery_method", None
)
discovery_method = report["policy_published"].get("discovery_method", None)
report_dict: dict[str, Any] = dict(
xml_schema=xml_schema,
@@ -2377,9 +2369,7 @@ def save_output(
parsed_aggregate_reports_to_csv(aggregate_reports),
)
append_json(
os.path.join(output_directory, failure_json_filename), failure_reports
)
append_json(os.path.join(output_directory, failure_json_filename), failure_reports)
append_csv(
os.path.join(output_directory, failure_csv_filename),

View File

@@ -1325,9 +1325,7 @@ def _main():
opts.la_dcr_aggregate_stream = log_analytics_config.get(
"dcr_aggregate_stream"
)
opts.la_dcr_failure_stream = log_analytics_config.get(
"dcr_failure_stream"
)
opts.la_dcr_failure_stream = log_analytics_config.get("dcr_failure_stream")
if opts.la_dcr_failure_stream is None:
opts.la_dcr_failure_stream = log_analytics_config.get(
"dcr_forensic_stream"

View File

@@ -105,23 +105,33 @@ class _AggregateReportDoc(Document):
self.policy_overrides.append(_PolicyOverride(type=type_, comment=comment)) # pyright: ignore[reportCallIssue]
def add_dkim_result(
self, domain: str, selector: str, result: _DKIMResult,
self,
domain: str,
selector: str,
result: _DKIMResult,
human_result: str = None,
):
self.dkim_results.append(
_DKIMResult(
domain=domain, selector=selector, result=result,
domain=domain,
selector=selector,
result=result,
human_result=human_result,
)
) # pyright: ignore[reportCallIssue]
def add_spf_result(
self, domain: str, scope: str, result: _SPFResult,
self,
domain: str,
scope: str,
result: _SPFResult,
human_result: str = None,
):
self.spf_results.append(
_SPFResult(
domain=domain, scope=scope, result=result,
domain=domain,
scope=scope,
result=result,
human_result=human_result,
)
) # pyright: ignore[reportCallIssue]
@@ -480,9 +490,7 @@ def save_aggregate_report_to_elasticsearch(
fo=aggregate_report["policy_published"]["fo"],
np=aggregate_report["policy_published"].get("np"),
testing=aggregate_report["policy_published"].get("testing"),
discovery_method=aggregate_report["policy_published"].get(
"discovery_method"
),
discovery_method=aggregate_report["policy_published"].get("discovery_method"),
)
for record in aggregate_report["records"]:
@@ -612,15 +620,12 @@ def save_failure_report_to_elasticsearch(
arrival_date_epoch_milliseconds = int(arrival_date.timestamp() * 1000)
if index_suffix is not None:
search_index = "dmarc_failure_{0}*,dmarc_forensic_{0}*".format(
index_suffix
)
search_index = "dmarc_failure_{0}*,dmarc_forensic_{0}*".format(index_suffix)
else:
search_index = "dmarc_failure*,dmarc_forensic*"
if index_prefix is not None:
search_index = ",".join(
"{0}{1}".format(index_prefix, part)
for part in search_index.split(",")
"{0}{1}".format(index_prefix, part) for part in search_index.split(",")
)
search = Search(index=search_index)
q = Q(dict(match=dict(arrival_date=arrival_date_epoch_milliseconds))) # pyright: ignore[reportArgumentType]

View File

@@ -105,23 +105,33 @@ class _AggregateReportDoc(Document):
self.policy_overrides.append(_PolicyOverride(type=type_, comment=comment))
def add_dkim_result(
self, domain: str, selector: str, result: _DKIMResult,
self,
domain: str,
selector: str,
result: _DKIMResult,
human_result: str = None,
):
self.dkim_results.append(
_DKIMResult(
domain=domain, selector=selector, result=result,
domain=domain,
selector=selector,
result=result,
human_result=human_result,
)
)
def add_spf_result(
self, domain: str, scope: str, result: _SPFResult,
self,
domain: str,
scope: str,
result: _SPFResult,
human_result: str = None,
):
self.spf_results.append(
_SPFResult(
domain=domain, scope=scope, result=result,
domain=domain,
scope=scope,
result=result,
human_result=human_result,
)
)
@@ -480,9 +490,7 @@ def save_aggregate_report_to_opensearch(
fo=aggregate_report["policy_published"]["fo"],
np=aggregate_report["policy_published"].get("np"),
testing=aggregate_report["policy_published"].get("testing"),
discovery_method=aggregate_report["policy_published"].get(
"discovery_method"
),
discovery_method=aggregate_report["policy_published"].get("discovery_method"),
)
for record in aggregate_report["records"]:
@@ -612,15 +620,12 @@ def save_failure_report_to_opensearch(
arrival_date_epoch_milliseconds = int(arrival_date.timestamp() * 1000)
if index_suffix is not None:
search_index = "dmarc_failure_{0}*,dmarc_forensic_{0}*".format(
index_suffix
)
search_index = "dmarc_failure_{0}*,dmarc_forensic_{0}*".format(index_suffix)
else:
search_index = "dmarc_failure*,dmarc_forensic*"
if index_prefix is not None:
search_index = ",".join(
"{0}{1}".format(index_prefix, part)
for part in search_index.split(",")
"{0}{1}".format(index_prefix, part) for part in search_index.split(",")
)
search = Search(index=search_index)
q = Q(dict(match=dict(arrival_date=arrival_date_epoch_milliseconds)))
@@ -665,9 +670,7 @@ def save_failure_report_to_opensearch(
"A failure sample to {0} from {1} "
"with a subject of {2} and arrival date of {3} "
"already exists in "
"OpenSearch".format(
to_, from_, subject, failure_report["arrival_date_utc"]
)
"OpenSearch".format(to_, from_, subject, failure_report["arrival_date_utc"])
)
parsed_sample = failure_report["parsed_sample"]

View File

@@ -66,4 +66,6 @@ class WebhookClient(object):
# Backward-compatible aliases
WebhookClient.save_forensic_report_to_webhook = WebhookClient.save_failure_report_to_webhook
WebhookClient.save_forensic_report_to_webhook = (
WebhookClient.save_failure_report_to_webhook
)

362
tests.py
View File

@@ -12,7 +12,6 @@ import unittest
from datetime import datetime, timedelta, timezone
from glob import glob
from base64 import urlsafe_b64encode
from glob import glob
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from unittest.mock import MagicMock, patch
@@ -193,9 +192,7 @@ class Test(unittest.TestCase):
def testDMARCbisDraftSample(self):
"""Test parsing the sample report from the DMARCbis aggregate draft"""
print()
sample_path = (
"samples/aggregate/dmarcbis-draft-sample.xml"
)
sample_path = "samples/aggregate/dmarcbis-draft-sample.xml"
print("Testing {0}: ".format(sample_path), end="")
result = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=True
@@ -211,15 +208,9 @@ class Test(unittest.TestCase):
# Verify report_metadata
metadata = report["report_metadata"]
self.assertEqual(metadata["org_name"], "Sample Reporter")
self.assertEqual(
metadata["org_email"], "report_sender@example-reporter.com"
)
self.assertEqual(
metadata["org_extra_contact_info"], "..."
)
self.assertEqual(
metadata["report_id"], "3v98abbp8ya9n3va8yr8oa3ya"
)
self.assertEqual(metadata["org_email"], "report_sender@example-reporter.com")
self.assertEqual(metadata["org_extra_contact_info"], "...")
self.assertEqual(metadata["report_id"], "3v98abbp8ya9n3va8yr8oa3ya")
self.assertEqual(
metadata["generator"],
"Example DMARC Aggregate Reporter v1.2",
@@ -245,9 +236,7 @@ class Test(unittest.TestCase):
rec = report["records"][0]
self.assertEqual(rec["source"]["ip_address"], "192.0.2.123")
self.assertEqual(rec["count"], 123)
self.assertEqual(
rec["policy_evaluated"]["disposition"], "pass"
)
self.assertEqual(rec["policy_evaluated"]["disposition"], "pass")
self.assertEqual(rec["policy_evaluated"]["dkim"], "pass")
self.assertEqual(rec["policy_evaluated"]["spf"], "fail")
@@ -278,8 +267,7 @@ class Test(unittest.TestCase):
"""Test that RFC 7489 reports have None for DMARCbis-only fields"""
print()
sample_path = (
"samples/aggregate/"
"example.net!example.com!1529366400!1529452799.xml"
"samples/aggregate/example.net!example.com!1529366400!1529452799.xml"
)
print("Testing {0}: ".format(sample_path), end="")
result = parsedmarc.parse_report_file(
@@ -471,7 +459,11 @@ class Test(unittest.TestCase):
"row": {
"source_ip": None,
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
},
},
"identifiers": {"header_from": "example.com"},
"auth_results": {"dkim": [], "spf": []},
@@ -485,7 +477,11 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "5",
"policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "fail"},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "fail",
},
},
"identifiers": {"header_from": "example.com"},
"auth_results": {},
@@ -547,9 +543,16 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
},
},
"identities": {
"header_from": "Example.COM",
"envelope_from": "example.com",
},
"identities": {"header_from": "Example.COM", "envelope_from": "example.com"},
"auth_results": {"dkim": [], "spf": []},
}
result = parsedmarc._parse_report_record(record, offline=True)
@@ -562,7 +565,11 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "fail", "spf": "fail"},
"policy_evaluated": {
"disposition": "none",
"dkim": "fail",
"spf": "fail",
},
},
"identifiers": {"header_from": "example.com"},
"auth_results": {
@@ -582,7 +589,11 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "fail", "spf": "fail"},
"policy_evaluated": {
"disposition": "none",
"dkim": "fail",
"spf": "fail",
},
},
"identifiers": {"header_from": "example.com"},
"auth_results": {
@@ -602,19 +613,37 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
},
},
"identifiers": {"header_from": "example.com"},
"auth_results": {
"dkim": [{"domain": "example.com", "selector": "s1",
"result": "pass", "human_result": "good key"}],
"spf": [{"domain": "example.com", "scope": "mfrom",
"result": "pass", "human_result": "sender valid"}],
"dkim": [
{
"domain": "example.com",
"selector": "s1",
"result": "pass",
"human_result": "good key",
}
],
"spf": [
{
"domain": "example.com",
"scope": "mfrom",
"result": "pass",
"human_result": "sender valid",
}
],
},
}
result = parsedmarc._parse_report_record(record, offline=True)
self.assertEqual(result["auth_results"]["dkim"][0]["human_result"], "good key")
self.assertEqual(result["auth_results"]["spf"][0]["human_result"], "sender valid")
self.assertEqual(
result["auth_results"]["spf"][0]["human_result"], "sender valid"
)
def testParseReportRecordEnvelopeFromFallback(self):
"""envelope_from falls back to last SPF domain when missing"""
@@ -622,12 +651,18 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
},
},
"identifiers": {"header_from": "example.com"},
"auth_results": {
"dkim": [],
"spf": [{"domain": "Bounce.Example.COM", "scope": "mfrom", "result": "pass"}],
"spf": [
{"domain": "Bounce.Example.COM", "scope": "mfrom", "result": "pass"}
],
},
}
result = parsedmarc._parse_report_record(record, offline=True)
@@ -639,7 +674,11 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
},
},
"identifiers": {
"header_from": "example.com",
@@ -647,7 +686,9 @@ class Test(unittest.TestCase):
},
"auth_results": {
"dkim": [],
"spf": [{"domain": "SPF.Example.COM", "scope": "mfrom", "result": "pass"}],
"spf": [
{"domain": "SPF.Example.COM", "scope": "mfrom", "result": "pass"}
],
},
}
result = parsedmarc._parse_report_record(record, offline=True)
@@ -659,7 +700,11 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
},
},
"identifiers": {
"header_from": "example.com",
@@ -677,7 +722,11 @@ class Test(unittest.TestCase):
"row": {
"source_ip": "192.0.2.1",
"count": "1",
"policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "fail"},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "fail",
},
},
"identifiers": {"header_from": "example.com"},
"auth_results": {"dkim": [], "spf": []},
@@ -806,7 +855,9 @@ class Test(unittest.TestCase):
}
result = parsedmarc._parse_smtp_tls_report_policy(policy)
self.assertEqual(len(result["failure_details"]), 1)
self.assertEqual(result["failure_details"][0]["result_type"], "certificate-expired")
self.assertEqual(
result["failure_details"][0]["result_type"], "certificate-expired"
)
def testParseSmtpTlsReportPolicyMissingField(self):
"""Missing required policy field raises InvalidSMTPTLSReport"""
@@ -820,27 +871,29 @@ class Test(unittest.TestCase):
def testParseSmtpTlsReportJsonValid(self):
"""Valid SMTP TLS JSON report parses correctly"""
report = json.dumps({
"organization-name": "Example Corp",
"date-range": {
"start-datetime": "2024-01-01T00:00:00Z",
"end-datetime": "2024-01-02T00:00:00Z",
},
"contact-info": "admin@example.com",
"report-id": "report-123",
"policies": [
{
"policy": {
"policy-type": "sts",
"policy-domain": "example.com",
},
"summary": {
"total-successful-session-count": 50,
"total-failure-session-count": 0,
},
}
],
})
report = json.dumps(
{
"organization-name": "Example Corp",
"date-range": {
"start-datetime": "2024-01-01T00:00:00Z",
"end-datetime": "2024-01-02T00:00:00Z",
},
"contact-info": "admin@example.com",
"report-id": "report-123",
"policies": [
{
"policy": {
"policy-type": "sts",
"policy-domain": "example.com",
},
"summary": {
"total-successful-session-count": 50,
"total-failure-session-count": 0,
},
}
],
}
)
result = parsedmarc.parse_smtp_tls_report_json(report)
self.assertEqual(result["organization_name"], "Example Corp")
self.assertEqual(result["report_id"], "report-123")
@@ -848,16 +901,26 @@ class Test(unittest.TestCase):
def testParseSmtpTlsReportJsonBytes(self):
"""SMTP TLS report as bytes parses correctly"""
report = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
"contact-info": "a@b.com",
"report-id": "r1",
"policies": [{
"policy": {"policy-type": "tlsa", "policy-domain": "a.com"},
"summary": {"total-successful-session-count": 1, "total-failure-session-count": 0},
}],
}).encode("utf-8")
report = json.dumps(
{
"organization-name": "Org",
"date-range": {
"start-datetime": "2024-01-01",
"end-datetime": "2024-01-02",
},
"contact-info": "a@b.com",
"report-id": "r1",
"policies": [
{
"policy": {"policy-type": "tlsa", "policy-domain": "a.com"},
"summary": {
"total-successful-session-count": 1,
"total-failure-session-count": 0,
},
}
],
}
).encode("utf-8")
result = parsedmarc.parse_smtp_tls_report_json(report)
self.assertEqual(result["organization_name"], "Org")
@@ -869,13 +932,18 @@ class Test(unittest.TestCase):
def testParseSmtpTlsReportJsonPoliciesNotList(self):
"""Non-list policies raises InvalidSMTPTLSReport"""
report = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
"contact-info": "a@b.com",
"report-id": "r1",
"policies": "not-a-list",
})
report = json.dumps(
{
"organization-name": "Org",
"date-range": {
"start-datetime": "2024-01-01",
"end-datetime": "2024-01-02",
},
"contact-info": "a@b.com",
"report-id": "r1",
"policies": "not-a-list",
}
)
with self.assertRaises(parsedmarc.InvalidSMTPTLSReport):
parsedmarc.parse_smtp_tls_report_json(report)
@@ -954,7 +1022,9 @@ class Test(unittest.TestCase):
</record>
</feedback>"""
report = parsedmarc.parse_aggregate_report_xml(xml, offline=True)
self.assertEqual(report["records"][0]["policy_evaluated"]["disposition"], "pass")
self.assertEqual(
report["records"][0]["policy_evaluated"]["disposition"], "pass"
)
def testAggregateReportMultipleRecords(self):
"""Reports with multiple records are all parsed"""
@@ -1004,7 +1074,8 @@ class Test(unittest.TestCase):
"""CSV rows include np, testing, discovery_method columns"""
result = parsedmarc.parse_report_file(
"samples/aggregate/dmarcbis-draft-sample.xml",
always_use_local_files=True, offline=True,
always_use_local_files=True,
offline=True,
)
report = result["report"]
rows = parsedmarc.parsed_aggregate_reports_to_csv_rows(report)
@@ -1178,6 +1249,7 @@ class Test(unittest.TestCase):
def testTimestampToDatetime(self):
"""timestamp_to_datetime converts UNIX timestamp to datetime"""
from datetime import datetime
dt = parsedmarc.utils.timestamp_to_datetime(0)
self.assertIsInstance(dt, datetime)
# Epoch 0 should be Jan 1 1970 in local time
@@ -1212,9 +1284,7 @@ class Test(unittest.TestCase):
def testHumanTimestampToDatetimeNegativeZero(self):
"""-0000 timezone is handled"""
dt = parsedmarc.utils.human_timestamp_to_datetime(
"2024-01-01 00:00:00 -0000"
)
dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00 -0000")
self.assertEqual(dt.year, 2024)
def testHumanTimestampToUnixTimestamp(self):
@@ -1264,14 +1334,19 @@ class Test(unittest.TestCase):
def testGetIpAddressInfoCache(self):
"""get_ip_address_info uses cache on second call"""
from expiringdict import ExpiringDict
cache = ExpiringDict(max_len=100, max_age_seconds=60)
with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"):
info1 = parsedmarc.utils.get_ip_address_info(
"8.8.8.8", offline=False, cache=cache,
"8.8.8.8",
offline=False,
cache=cache,
always_use_local_files=True,
)
self.assertIn("8.8.8.8", cache)
info2 = parsedmarc.utils.get_ip_address_info("8.8.8.8", offline=False, cache=cache)
info2 = parsedmarc.utils.get_ip_address_info(
"8.8.8.8", offline=False, cache=cache
)
self.assertEqual(info1["ip_address"], info2["ip_address"])
self.assertEqual(info2["reverse_dns"], "dns.google")
@@ -1340,6 +1415,7 @@ class Test(unittest.TestCase):
def testWebhookClientInit(self):
"""WebhookClient initializes with correct attributes"""
from parsedmarc.webhook import WebhookClient
client = WebhookClient(
aggregate_url="http://agg.example.com",
failure_url="http://fail.example.com",
@@ -1353,18 +1429,26 @@ class Test(unittest.TestCase):
def testWebhookClientSaveMethods(self):
"""WebhookClient save methods call _send_to_webhook"""
from parsedmarc.webhook import WebhookClient
client = WebhookClient("http://a", "http://f", "http://t")
client.session = MagicMock()
client.save_aggregate_report_to_webhook('{"test": 1}')
client.session.post.assert_called_with("http://a", data='{"test": 1}', timeout=60)
client.session.post.assert_called_with(
"http://a", data='{"test": 1}', timeout=60
)
client.save_failure_report_to_webhook('{"fail": 1}')
client.session.post.assert_called_with("http://f", data='{"fail": 1}', timeout=60)
client.session.post.assert_called_with(
"http://f", data='{"fail": 1}', timeout=60
)
client.save_smtp_tls_report_to_webhook('{"tls": 1}')
client.session.post.assert_called_with("http://t", data='{"tls": 1}', timeout=60)
client.session.post.assert_called_with(
"http://t", data='{"tls": 1}', timeout=60
)
def testWebhookBackwardCompatAlias(self):
"""WebhookClient forensic alias points to failure method"""
from parsedmarc.webhook import WebhookClient
self.assertIs(
WebhookClient.save_forensic_report_to_webhook,
WebhookClient.save_failure_report_to_webhook,
@@ -1373,6 +1457,7 @@ class Test(unittest.TestCase):
def testKafkaStripMetadata(self):
"""KafkaClient.strip_metadata extracts metadata to root"""
from parsedmarc.kafkaclient import KafkaClient
report = {
"report_metadata": {
"org_name": "TestOrg",
@@ -1392,6 +1477,7 @@ class Test(unittest.TestCase):
def testKafkaGenerateDateRange(self):
"""KafkaClient.generate_date_range generates date range list"""
from parsedmarc.kafkaclient import KafkaClient
report = {
"report_metadata": {
"begin_date": "2024-01-01 00:00:00",
@@ -1406,6 +1492,7 @@ class Test(unittest.TestCase):
def testSplunkHECClientInit(self):
"""HECClient initializes with correct URL and headers"""
from parsedmarc.splunk import HECClient
client = HECClient(
url="https://splunk.example.com:8088",
access_token="my-token",
@@ -1420,6 +1507,7 @@ class Test(unittest.TestCase):
def testSplunkHECClientStripTokenPrefix(self):
"""HECClient strips 'Splunk ' prefix from token"""
from parsedmarc.splunk import HECClient
client = HECClient(
url="https://splunk.example.com",
access_token="Splunk my-token",
@@ -1430,6 +1518,7 @@ class Test(unittest.TestCase):
def testSplunkBackwardCompatAlias(self):
"""HECClient forensic alias points to failure method"""
from parsedmarc.splunk import HECClient
self.assertIs(
HECClient.save_forensic_reports_to_splunk,
HECClient.save_failure_reports_to_splunk,
@@ -1438,6 +1527,7 @@ class Test(unittest.TestCase):
def testSyslogClientUdpInit(self):
"""SyslogClient creates UDP handler"""
from parsedmarc.syslog import SyslogClient
client = SyslogClient("localhost", 514, protocol="udp")
self.assertEqual(client.server_name, "localhost")
self.assertEqual(client.server_port, 514)
@@ -1446,12 +1536,14 @@ class Test(unittest.TestCase):
def testSyslogClientInvalidProtocol(self):
"""SyslogClient with invalid protocol raises ValueError"""
from parsedmarc.syslog import SyslogClient
with self.assertRaises(ValueError):
SyslogClient("localhost", 514, protocol="invalid")
def testSyslogBackwardCompatAlias(self):
"""SyslogClient forensic alias points to failure method"""
from parsedmarc.syslog import SyslogClient
self.assertIs(
SyslogClient.save_forensic_report_to_syslog,
SyslogClient.save_failure_report_to_syslog,
@@ -1460,6 +1552,7 @@ class Test(unittest.TestCase):
def testLogAnalyticsConfig(self):
"""LogAnalyticsConfig stores all fields"""
from parsedmarc.loganalytics import LogAnalyticsConfig
config = LogAnalyticsConfig(
client_id="cid",
client_secret="csec",
@@ -1482,6 +1575,7 @@ class Test(unittest.TestCase):
def testLogAnalyticsClientValidationError(self):
"""LogAnalyticsClient raises on missing required config"""
from parsedmarc.loganalytics import LogAnalyticsClient, LogAnalyticsException
with self.assertRaises(LogAnalyticsException):
LogAnalyticsClient(
client_id="",
@@ -1496,18 +1590,34 @@ class Test(unittest.TestCase):
def testSmtpTlsCsvRows(self):
"""parsed_smtp_tls_reports_to_csv_rows produces correct rows"""
report_json = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"},
"contact-info": "a@b.com",
"report-id": "r1",
"policies": [{
"policy": {"policy-type": "sts", "policy-domain": "example.com",
"policy-string": ["v: STSv1"], "mx-host-pattern": ["*.example.com"]},
"summary": {"total-successful-session-count": 10, "total-failure-session-count": 1},
"failure-details": [{"result-type": "cert-expired", "failed-session-count": 1}],
}],
})
report_json = json.dumps(
{
"organization-name": "Org",
"date-range": {
"start-datetime": "2024-01-01T00:00:00Z",
"end-datetime": "2024-01-02T00:00:00Z",
},
"contact-info": "a@b.com",
"report-id": "r1",
"policies": [
{
"policy": {
"policy-type": "sts",
"policy-domain": "example.com",
"policy-string": ["v: STSv1"],
"mx-host-pattern": ["*.example.com"],
},
"summary": {
"total-successful-session-count": 10,
"total-failure-session-count": 1,
},
"failure-details": [
{"result-type": "cert-expired", "failed-session-count": 1}
],
}
],
}
)
parsed = parsedmarc.parse_smtp_tls_report_json(report_json)
rows = parsedmarc.parsed_smtp_tls_reports_to_csv_rows(parsed)
self.assertTrue(len(rows) >= 2)
@@ -1518,7 +1628,8 @@ class Test(unittest.TestCase):
"""parsed_aggregate_reports_to_csv_rows handles list of reports"""
result = parsedmarc.parse_report_file(
"samples/aggregate/dmarcbis-draft-sample.xml",
always_use_local_files=True, offline=True,
always_use_local_files=True,
offline=True,
)
report = result["report"]
# Pass as a list
@@ -1532,10 +1643,18 @@ class Test(unittest.TestCase):
def testExceptionHierarchy(self):
"""Exception class hierarchy is correct"""
self.assertTrue(issubclass(parsedmarc.ParserError, RuntimeError))
self.assertTrue(issubclass(parsedmarc.InvalidDMARCReport, parsedmarc.ParserError))
self.assertTrue(issubclass(parsedmarc.InvalidAggregateReport, parsedmarc.InvalidDMARCReport))
self.assertTrue(issubclass(parsedmarc.InvalidFailureReport, parsedmarc.InvalidDMARCReport))
self.assertTrue(issubclass(parsedmarc.InvalidSMTPTLSReport, parsedmarc.ParserError))
self.assertTrue(
issubclass(parsedmarc.InvalidDMARCReport, parsedmarc.ParserError)
)
self.assertTrue(
issubclass(parsedmarc.InvalidAggregateReport, parsedmarc.InvalidDMARCReport)
)
self.assertTrue(
issubclass(parsedmarc.InvalidFailureReport, parsedmarc.InvalidDMARCReport)
)
self.assertTrue(
issubclass(parsedmarc.InvalidSMTPTLSReport, parsedmarc.ParserError)
)
self.assertIs(parsedmarc.InvalidForensicReport, parsedmarc.InvalidFailureReport)
def testAggregateReportNormalization(self):
@@ -1579,6 +1698,7 @@ class Test(unittest.TestCase):
def testGelfBackwardCompatAlias(self):
"""GelfClient forensic alias points to failure method"""
from parsedmarc.gelf import GelfClient
self.assertIs(
GelfClient.save_forensic_report_to_gelf,
GelfClient.save_failure_report_to_gelf,
@@ -1587,6 +1707,7 @@ class Test(unittest.TestCase):
def testS3BackwardCompatAlias(self):
"""S3Client forensic alias points to failure method"""
from parsedmarc.s3 import S3Client
self.assertIs(
S3Client.save_forensic_report_to_s3,
S3Client.save_failure_report_to_s3,
@@ -1595,6 +1716,7 @@ class Test(unittest.TestCase):
def testKafkaBackwardCompatAlias(self):
"""KafkaClient forensic alias points to failure method"""
from parsedmarc.kafkaclient import KafkaClient
self.assertIs(
KafkaClient.save_forensic_reports_to_kafka,
KafkaClient.save_failure_reports_to_kafka,
@@ -1622,7 +1744,9 @@ class Test(unittest.TestCase):
with open(sample_path, "rb") as f:
data = f.read()
report = parsedmarc.parse_aggregate_report_file(
data, offline=True, always_use_local_files=True,
data,
offline=True,
always_use_local_files=True,
)
self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter")
self.assertEqual(report["policy_published"]["domain"], "example.com")
@@ -1755,9 +1879,7 @@ class TestGmailConnection(unittest.TestCase):
"from_authorized_user_file",
return_value=creds,
):
returned = _get_creds(
token_path, "credentials.json", ["scope"], 8080
)
returned = _get_creds(token_path, "credentials.json", ["scope"], 8080)
finally:
os.remove(token_path)
self.assertEqual(returned, creds)
@@ -1811,9 +1933,7 @@ class TestGmailConnection(unittest.TestCase):
"from_authorized_user_file",
return_value=expired_creds,
):
returned = _get_creds(
token_path, "credentials.json", ["scope"], 8080
)
returned = _get_creds(token_path, "credentials.json", ["scope"], 8080)
finally:
os.remove(token_path)
@@ -1890,7 +2010,9 @@ class TestGraphConnection(unittest.TestCase):
def testFetchMessagesPassesSinceAndBatchSize(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection.mailbox_name = "mailbox@example.com"
connection._find_folder_id_from_folder_path = MagicMock(return_value="folder-id")
connection._find_folder_id_from_folder_path = MagicMock(
return_value="folder-id"
)
connection._get_all_messages = MagicMock(return_value=[{"id": "1"}])
self.assertEqual(
connection.fetch_messages("Inbox", since="2026-03-01", batch_size=5), ["1"]
@@ -1951,7 +2073,9 @@ class TestGraphConnection(unittest.TestCase):
def testGenerateCredentialDeviceCode(self):
fake_credential = object()
with patch.object(graph_module, "_get_cache_args", return_value={"cached": True}):
with patch.object(
graph_module, "_get_cache_args", return_value={"cached": True}
):
with patch.object(
graph_module,
"DeviceCodeCredential",
@@ -2098,8 +2222,12 @@ class TestImapConnection(unittest.TestCase):
with self.assertRaises(_BreakLoop):
connection.watch(callback, check_timeout=1)
callback.assert_called_once_with(connection)
class TestGmailAuthModes(unittest.TestCase):
@patch("parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file")
@patch(
"parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file"
)
def testGetCredsServiceAccountWithoutSubject(self, mock_from_service_account_file):
service_creds = MagicMock()
service_creds.with_subject.return_value = MagicMock()
@@ -2121,7 +2249,9 @@ class TestGmailAuthModes(unittest.TestCase):
)
service_creds.with_subject.assert_not_called()
@patch("parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file")
@patch(
"parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file"
)
def testGetCredsServiceAccountWithSubject(self, mock_from_service_account_file):
base_creds = MagicMock()
delegated_creds = MagicMock()
@@ -2212,7 +2342,7 @@ class TestGmailAuthModes(unittest.TestCase):
mock_gmail_connection.return_value = MagicMock()
mock_get_mailbox_reports.return_value = {
"aggregate_reports": [],
"forensic_reports": [],
"failure_reports": [],
"smtp_tls_reports": [],
}
config = """[general]
@@ -2246,7 +2376,7 @@ scopes = https://www.googleapis.com/auth/gmail.modify
mock_gmail_connection.return_value = MagicMock()
mock_get_reports.return_value = {
"aggregate_reports": [],
"forensic_reports": [],
"failure_reports": [],
"smtp_tls_reports": [],
}
config = """[general]
@@ -2270,5 +2400,7 @@ scopes = https://www.googleapis.com/auth/gmail.modify
mock_gmail_connection.call_args.kwargs.get("service_account_user"),
"delegated@example.com",
)
if __name__ == "__main__":
unittest.main(verbosity=2)