From fce8e2247b0d9cb4c85139b4e43bb5d2bcac7b37 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 21:40:28 +0000 Subject: [PATCH] Fix ruff formatting errors, duplicate import, and test mock key names Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com> --- parsedmarc/__init__.py | 20 +-- parsedmarc/cli.py | 4 +- parsedmarc/elastic.py | 29 ++-- parsedmarc/opensearch.py | 33 ++-- parsedmarc/webhook.py | 4 +- tests.py | 362 ++++++++++++++++++++++++++------------- 6 files changed, 291 insertions(+), 161 deletions(-) diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index df6078a..0bbb447 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -825,18 +825,14 @@ def parse_aggregate_report_xml( if policy_published["np"] is not None: np_ = policy_published["np"] if np_ not in ("none", "quarantine", "reject"): - logger.warning( - "Invalid np value: {0}".format(np_) - ) + logger.warning("Invalid np value: {0}".format(np_)) new_policy_published["np"] = np_ testing = None if "testing" in policy_published: if policy_published["testing"] is not None: testing = policy_published["testing"] if testing not in ("n", "y"): - logger.warning( - "Invalid testing value: {0}".format(testing) - ) + logger.warning("Invalid testing value: {0}".format(testing)) new_policy_published["testing"] = testing discovery_method = None if "discovery_method" in policy_published: @@ -844,9 +840,7 @@ def parse_aggregate_report_xml( discovery_method = policy_published["discovery_method"] if discovery_method not in ("psl", "treewalk"): logger.warning( - "Invalid discovery_method value: {0}".format( - discovery_method - ) + "Invalid discovery_method value: {0}".format(discovery_method) ) new_policy_published["discovery_method"] = discovery_method new_report["policy_published"] = new_policy_published @@ -1107,9 +1101,7 @@ def parsed_aggregate_reports_to_csv_rows( fo = report["policy_published"]["fo"] np_ = report["policy_published"].get("np", None) testing = report["policy_published"].get("testing", None) - discovery_method = report["policy_published"].get( - "discovery_method", None - ) + discovery_method = report["policy_published"].get("discovery_method", None) report_dict: dict[str, Any] = dict( xml_schema=xml_schema, @@ -2377,9 +2369,7 @@ def save_output( parsed_aggregate_reports_to_csv(aggregate_reports), ) - append_json( - os.path.join(output_directory, failure_json_filename), failure_reports - ) + append_json(os.path.join(output_directory, failure_json_filename), failure_reports) append_csv( os.path.join(output_directory, failure_csv_filename), diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index 6250892..420897f 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -1325,9 +1325,7 @@ def _main(): opts.la_dcr_aggregate_stream = log_analytics_config.get( "dcr_aggregate_stream" ) - opts.la_dcr_failure_stream = log_analytics_config.get( - "dcr_failure_stream" - ) + opts.la_dcr_failure_stream = log_analytics_config.get("dcr_failure_stream") if opts.la_dcr_failure_stream is None: opts.la_dcr_failure_stream = log_analytics_config.get( "dcr_forensic_stream" diff --git a/parsedmarc/elastic.py b/parsedmarc/elastic.py index cc1f162..ec41e7d 100644 --- a/parsedmarc/elastic.py +++ b/parsedmarc/elastic.py @@ -105,23 +105,33 @@ class _AggregateReportDoc(Document): self.policy_overrides.append(_PolicyOverride(type=type_, comment=comment)) # pyright: ignore[reportCallIssue] def add_dkim_result( - self, domain: str, selector: str, result: _DKIMResult, + self, + domain: str, + selector: str, + result: _DKIMResult, human_result: str = None, ): self.dkim_results.append( _DKIMResult( - domain=domain, selector=selector, result=result, + domain=domain, + selector=selector, + result=result, human_result=human_result, ) ) # pyright: ignore[reportCallIssue] def add_spf_result( - self, domain: str, scope: str, result: _SPFResult, + self, + domain: str, + scope: str, + result: _SPFResult, human_result: str = None, ): self.spf_results.append( _SPFResult( - domain=domain, scope=scope, result=result, + domain=domain, + scope=scope, + result=result, human_result=human_result, ) ) # pyright: ignore[reportCallIssue] @@ -480,9 +490,7 @@ def save_aggregate_report_to_elasticsearch( fo=aggregate_report["policy_published"]["fo"], np=aggregate_report["policy_published"].get("np"), testing=aggregate_report["policy_published"].get("testing"), - discovery_method=aggregate_report["policy_published"].get( - "discovery_method" - ), + discovery_method=aggregate_report["policy_published"].get("discovery_method"), ) for record in aggregate_report["records"]: @@ -612,15 +620,12 @@ def save_failure_report_to_elasticsearch( arrival_date_epoch_milliseconds = int(arrival_date.timestamp() * 1000) if index_suffix is not None: - search_index = "dmarc_failure_{0}*,dmarc_forensic_{0}*".format( - index_suffix - ) + search_index = "dmarc_failure_{0}*,dmarc_forensic_{0}*".format(index_suffix) else: search_index = "dmarc_failure*,dmarc_forensic*" if index_prefix is not None: search_index = ",".join( - "{0}{1}".format(index_prefix, part) - for part in search_index.split(",") + "{0}{1}".format(index_prefix, part) for part in search_index.split(",") ) search = Search(index=search_index) q = Q(dict(match=dict(arrival_date=arrival_date_epoch_milliseconds))) # pyright: ignore[reportArgumentType] diff --git a/parsedmarc/opensearch.py b/parsedmarc/opensearch.py index 9d77555..24a38a6 100644 --- a/parsedmarc/opensearch.py +++ b/parsedmarc/opensearch.py @@ -105,23 +105,33 @@ class _AggregateReportDoc(Document): self.policy_overrides.append(_PolicyOverride(type=type_, comment=comment)) def add_dkim_result( - self, domain: str, selector: str, result: _DKIMResult, + self, + domain: str, + selector: str, + result: _DKIMResult, human_result: str = None, ): self.dkim_results.append( _DKIMResult( - domain=domain, selector=selector, result=result, + domain=domain, + selector=selector, + result=result, human_result=human_result, ) ) def add_spf_result( - self, domain: str, scope: str, result: _SPFResult, + self, + domain: str, + scope: str, + result: _SPFResult, human_result: str = None, ): self.spf_results.append( _SPFResult( - domain=domain, scope=scope, result=result, + domain=domain, + scope=scope, + result=result, human_result=human_result, ) ) @@ -480,9 +490,7 @@ def save_aggregate_report_to_opensearch( fo=aggregate_report["policy_published"]["fo"], np=aggregate_report["policy_published"].get("np"), testing=aggregate_report["policy_published"].get("testing"), - discovery_method=aggregate_report["policy_published"].get( - "discovery_method" - ), + discovery_method=aggregate_report["policy_published"].get("discovery_method"), ) for record in aggregate_report["records"]: @@ -612,15 +620,12 @@ def save_failure_report_to_opensearch( arrival_date_epoch_milliseconds = int(arrival_date.timestamp() * 1000) if index_suffix is not None: - search_index = "dmarc_failure_{0}*,dmarc_forensic_{0}*".format( - index_suffix - ) + search_index = "dmarc_failure_{0}*,dmarc_forensic_{0}*".format(index_suffix) else: search_index = "dmarc_failure*,dmarc_forensic*" if index_prefix is not None: search_index = ",".join( - "{0}{1}".format(index_prefix, part) - for part in search_index.split(",") + "{0}{1}".format(index_prefix, part) for part in search_index.split(",") ) search = Search(index=search_index) q = Q(dict(match=dict(arrival_date=arrival_date_epoch_milliseconds))) @@ -665,9 +670,7 @@ def save_failure_report_to_opensearch( "A failure sample to {0} from {1} " "with a subject of {2} and arrival date of {3} " "already exists in " - "OpenSearch".format( - to_, from_, subject, failure_report["arrival_date_utc"] - ) + "OpenSearch".format(to_, from_, subject, failure_report["arrival_date_utc"]) ) parsed_sample = failure_report["parsed_sample"] diff --git a/parsedmarc/webhook.py b/parsedmarc/webhook.py index 685de63..61c1f35 100644 --- a/parsedmarc/webhook.py +++ b/parsedmarc/webhook.py @@ -66,4 +66,6 @@ class WebhookClient(object): # Backward-compatible aliases -WebhookClient.save_forensic_report_to_webhook = WebhookClient.save_failure_report_to_webhook +WebhookClient.save_forensic_report_to_webhook = ( + WebhookClient.save_failure_report_to_webhook +) diff --git a/tests.py b/tests.py index ee7d0a1..7d1bc2a 100755 --- a/tests.py +++ b/tests.py @@ -12,7 +12,6 @@ import unittest from datetime import datetime, timedelta, timezone from glob import glob from base64 import urlsafe_b64encode -from glob import glob from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory from unittest.mock import MagicMock, patch @@ -193,9 +192,7 @@ class Test(unittest.TestCase): def testDMARCbisDraftSample(self): """Test parsing the sample report from the DMARCbis aggregate draft""" print() - sample_path = ( - "samples/aggregate/dmarcbis-draft-sample.xml" - ) + sample_path = "samples/aggregate/dmarcbis-draft-sample.xml" print("Testing {0}: ".format(sample_path), end="") result = parsedmarc.parse_report_file( sample_path, always_use_local_files=True, offline=True @@ -211,15 +208,9 @@ class Test(unittest.TestCase): # Verify report_metadata metadata = report["report_metadata"] self.assertEqual(metadata["org_name"], "Sample Reporter") - self.assertEqual( - metadata["org_email"], "report_sender@example-reporter.com" - ) - self.assertEqual( - metadata["org_extra_contact_info"], "..." - ) - self.assertEqual( - metadata["report_id"], "3v98abbp8ya9n3va8yr8oa3ya" - ) + self.assertEqual(metadata["org_email"], "report_sender@example-reporter.com") + self.assertEqual(metadata["org_extra_contact_info"], "...") + self.assertEqual(metadata["report_id"], "3v98abbp8ya9n3va8yr8oa3ya") self.assertEqual( metadata["generator"], "Example DMARC Aggregate Reporter v1.2", @@ -245,9 +236,7 @@ class Test(unittest.TestCase): rec = report["records"][0] self.assertEqual(rec["source"]["ip_address"], "192.0.2.123") self.assertEqual(rec["count"], 123) - self.assertEqual( - rec["policy_evaluated"]["disposition"], "pass" - ) + self.assertEqual(rec["policy_evaluated"]["disposition"], "pass") self.assertEqual(rec["policy_evaluated"]["dkim"], "pass") self.assertEqual(rec["policy_evaluated"]["spf"], "fail") @@ -278,8 +267,7 @@ class Test(unittest.TestCase): """Test that RFC 7489 reports have None for DMARCbis-only fields""" print() sample_path = ( - "samples/aggregate/" - "example.net!example.com!1529366400!1529452799.xml" + "samples/aggregate/example.net!example.com!1529366400!1529452799.xml" ) print("Testing {0}: ".format(sample_path), end="") result = parsedmarc.parse_report_file( @@ -471,7 +459,11 @@ class Test(unittest.TestCase): "row": { "source_ip": None, "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "pass", + "spf": "pass", + }, }, "identifiers": {"header_from": "example.com"}, "auth_results": {"dkim": [], "spf": []}, @@ -485,7 +477,11 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "5", - "policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "fail"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "pass", + "spf": "fail", + }, }, "identifiers": {"header_from": "example.com"}, "auth_results": {}, @@ -547,9 +543,16 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "pass", + "spf": "pass", + }, + }, + "identities": { + "header_from": "Example.COM", + "envelope_from": "example.com", }, - "identities": {"header_from": "Example.COM", "envelope_from": "example.com"}, "auth_results": {"dkim": [], "spf": []}, } result = parsedmarc._parse_report_record(record, offline=True) @@ -562,7 +565,11 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "fail", "spf": "fail"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "fail", + "spf": "fail", + }, }, "identifiers": {"header_from": "example.com"}, "auth_results": { @@ -582,7 +589,11 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "fail", "spf": "fail"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "fail", + "spf": "fail", + }, }, "identifiers": {"header_from": "example.com"}, "auth_results": { @@ -602,19 +613,37 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "pass", + "spf": "pass", + }, }, "identifiers": {"header_from": "example.com"}, "auth_results": { - "dkim": [{"domain": "example.com", "selector": "s1", - "result": "pass", "human_result": "good key"}], - "spf": [{"domain": "example.com", "scope": "mfrom", - "result": "pass", "human_result": "sender valid"}], + "dkim": [ + { + "domain": "example.com", + "selector": "s1", + "result": "pass", + "human_result": "good key", + } + ], + "spf": [ + { + "domain": "example.com", + "scope": "mfrom", + "result": "pass", + "human_result": "sender valid", + } + ], }, } result = parsedmarc._parse_report_record(record, offline=True) self.assertEqual(result["auth_results"]["dkim"][0]["human_result"], "good key") - self.assertEqual(result["auth_results"]["spf"][0]["human_result"], "sender valid") + self.assertEqual( + result["auth_results"]["spf"][0]["human_result"], "sender valid" + ) def testParseReportRecordEnvelopeFromFallback(self): """envelope_from falls back to last SPF domain when missing""" @@ -622,12 +651,18 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "pass", + "spf": "pass", + }, }, "identifiers": {"header_from": "example.com"}, "auth_results": { "dkim": [], - "spf": [{"domain": "Bounce.Example.COM", "scope": "mfrom", "result": "pass"}], + "spf": [ + {"domain": "Bounce.Example.COM", "scope": "mfrom", "result": "pass"} + ], }, } result = parsedmarc._parse_report_record(record, offline=True) @@ -639,7 +674,11 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "pass", + "spf": "pass", + }, }, "identifiers": { "header_from": "example.com", @@ -647,7 +686,9 @@ class Test(unittest.TestCase): }, "auth_results": { "dkim": [], - "spf": [{"domain": "SPF.Example.COM", "scope": "mfrom", "result": "pass"}], + "spf": [ + {"domain": "SPF.Example.COM", "scope": "mfrom", "result": "pass"} + ], }, } result = parsedmarc._parse_report_record(record, offline=True) @@ -659,7 +700,11 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "pass"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "pass", + "spf": "pass", + }, }, "identifiers": { "header_from": "example.com", @@ -677,7 +722,11 @@ class Test(unittest.TestCase): "row": { "source_ip": "192.0.2.1", "count": "1", - "policy_evaluated": {"disposition": "none", "dkim": "pass", "spf": "fail"}, + "policy_evaluated": { + "disposition": "none", + "dkim": "pass", + "spf": "fail", + }, }, "identifiers": {"header_from": "example.com"}, "auth_results": {"dkim": [], "spf": []}, @@ -806,7 +855,9 @@ class Test(unittest.TestCase): } result = parsedmarc._parse_smtp_tls_report_policy(policy) self.assertEqual(len(result["failure_details"]), 1) - self.assertEqual(result["failure_details"][0]["result_type"], "certificate-expired") + self.assertEqual( + result["failure_details"][0]["result_type"], "certificate-expired" + ) def testParseSmtpTlsReportPolicyMissingField(self): """Missing required policy field raises InvalidSMTPTLSReport""" @@ -820,27 +871,29 @@ class Test(unittest.TestCase): def testParseSmtpTlsReportJsonValid(self): """Valid SMTP TLS JSON report parses correctly""" - report = json.dumps({ - "organization-name": "Example Corp", - "date-range": { - "start-datetime": "2024-01-01T00:00:00Z", - "end-datetime": "2024-01-02T00:00:00Z", - }, - "contact-info": "admin@example.com", - "report-id": "report-123", - "policies": [ - { - "policy": { - "policy-type": "sts", - "policy-domain": "example.com", - }, - "summary": { - "total-successful-session-count": 50, - "total-failure-session-count": 0, - }, - } - ], - }) + report = json.dumps( + { + "organization-name": "Example Corp", + "date-range": { + "start-datetime": "2024-01-01T00:00:00Z", + "end-datetime": "2024-01-02T00:00:00Z", + }, + "contact-info": "admin@example.com", + "report-id": "report-123", + "policies": [ + { + "policy": { + "policy-type": "sts", + "policy-domain": "example.com", + }, + "summary": { + "total-successful-session-count": 50, + "total-failure-session-count": 0, + }, + } + ], + } + ) result = parsedmarc.parse_smtp_tls_report_json(report) self.assertEqual(result["organization_name"], "Example Corp") self.assertEqual(result["report_id"], "report-123") @@ -848,16 +901,26 @@ class Test(unittest.TestCase): def testParseSmtpTlsReportJsonBytes(self): """SMTP TLS report as bytes parses correctly""" - report = json.dumps({ - "organization-name": "Org", - "date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"}, - "contact-info": "a@b.com", - "report-id": "r1", - "policies": [{ - "policy": {"policy-type": "tlsa", "policy-domain": "a.com"}, - "summary": {"total-successful-session-count": 1, "total-failure-session-count": 0}, - }], - }).encode("utf-8") + report = json.dumps( + { + "organization-name": "Org", + "date-range": { + "start-datetime": "2024-01-01", + "end-datetime": "2024-01-02", + }, + "contact-info": "a@b.com", + "report-id": "r1", + "policies": [ + { + "policy": {"policy-type": "tlsa", "policy-domain": "a.com"}, + "summary": { + "total-successful-session-count": 1, + "total-failure-session-count": 0, + }, + } + ], + } + ).encode("utf-8") result = parsedmarc.parse_smtp_tls_report_json(report) self.assertEqual(result["organization_name"], "Org") @@ -869,13 +932,18 @@ class Test(unittest.TestCase): def testParseSmtpTlsReportJsonPoliciesNotList(self): """Non-list policies raises InvalidSMTPTLSReport""" - report = json.dumps({ - "organization-name": "Org", - "date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"}, - "contact-info": "a@b.com", - "report-id": "r1", - "policies": "not-a-list", - }) + report = json.dumps( + { + "organization-name": "Org", + "date-range": { + "start-datetime": "2024-01-01", + "end-datetime": "2024-01-02", + }, + "contact-info": "a@b.com", + "report-id": "r1", + "policies": "not-a-list", + } + ) with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc.parse_smtp_tls_report_json(report) @@ -954,7 +1022,9 @@ class Test(unittest.TestCase): """ report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) - self.assertEqual(report["records"][0]["policy_evaluated"]["disposition"], "pass") + self.assertEqual( + report["records"][0]["policy_evaluated"]["disposition"], "pass" + ) def testAggregateReportMultipleRecords(self): """Reports with multiple records are all parsed""" @@ -1004,7 +1074,8 @@ class Test(unittest.TestCase): """CSV rows include np, testing, discovery_method columns""" result = parsedmarc.parse_report_file( "samples/aggregate/dmarcbis-draft-sample.xml", - always_use_local_files=True, offline=True, + always_use_local_files=True, + offline=True, ) report = result["report"] rows = parsedmarc.parsed_aggregate_reports_to_csv_rows(report) @@ -1178,6 +1249,7 @@ class Test(unittest.TestCase): def testTimestampToDatetime(self): """timestamp_to_datetime converts UNIX timestamp to datetime""" from datetime import datetime + dt = parsedmarc.utils.timestamp_to_datetime(0) self.assertIsInstance(dt, datetime) # Epoch 0 should be Jan 1 1970 in local time @@ -1212,9 +1284,7 @@ class Test(unittest.TestCase): def testHumanTimestampToDatetimeNegativeZero(self): """-0000 timezone is handled""" - dt = parsedmarc.utils.human_timestamp_to_datetime( - "2024-01-01 00:00:00 -0000" - ) + dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00 -0000") self.assertEqual(dt.year, 2024) def testHumanTimestampToUnixTimestamp(self): @@ -1264,14 +1334,19 @@ class Test(unittest.TestCase): def testGetIpAddressInfoCache(self): """get_ip_address_info uses cache on second call""" from expiringdict import ExpiringDict + cache = ExpiringDict(max_len=100, max_age_seconds=60) with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"): info1 = parsedmarc.utils.get_ip_address_info( - "8.8.8.8", offline=False, cache=cache, + "8.8.8.8", + offline=False, + cache=cache, always_use_local_files=True, ) self.assertIn("8.8.8.8", cache) - info2 = parsedmarc.utils.get_ip_address_info("8.8.8.8", offline=False, cache=cache) + info2 = parsedmarc.utils.get_ip_address_info( + "8.8.8.8", offline=False, cache=cache + ) self.assertEqual(info1["ip_address"], info2["ip_address"]) self.assertEqual(info2["reverse_dns"], "dns.google") @@ -1340,6 +1415,7 @@ class Test(unittest.TestCase): def testWebhookClientInit(self): """WebhookClient initializes with correct attributes""" from parsedmarc.webhook import WebhookClient + client = WebhookClient( aggregate_url="http://agg.example.com", failure_url="http://fail.example.com", @@ -1353,18 +1429,26 @@ class Test(unittest.TestCase): def testWebhookClientSaveMethods(self): """WebhookClient save methods call _send_to_webhook""" from parsedmarc.webhook import WebhookClient + client = WebhookClient("http://a", "http://f", "http://t") client.session = MagicMock() client.save_aggregate_report_to_webhook('{"test": 1}') - client.session.post.assert_called_with("http://a", data='{"test": 1}', timeout=60) + client.session.post.assert_called_with( + "http://a", data='{"test": 1}', timeout=60 + ) client.save_failure_report_to_webhook('{"fail": 1}') - client.session.post.assert_called_with("http://f", data='{"fail": 1}', timeout=60) + client.session.post.assert_called_with( + "http://f", data='{"fail": 1}', timeout=60 + ) client.save_smtp_tls_report_to_webhook('{"tls": 1}') - client.session.post.assert_called_with("http://t", data='{"tls": 1}', timeout=60) + client.session.post.assert_called_with( + "http://t", data='{"tls": 1}', timeout=60 + ) def testWebhookBackwardCompatAlias(self): """WebhookClient forensic alias points to failure method""" from parsedmarc.webhook import WebhookClient + self.assertIs( WebhookClient.save_forensic_report_to_webhook, WebhookClient.save_failure_report_to_webhook, @@ -1373,6 +1457,7 @@ class Test(unittest.TestCase): def testKafkaStripMetadata(self): """KafkaClient.strip_metadata extracts metadata to root""" from parsedmarc.kafkaclient import KafkaClient + report = { "report_metadata": { "org_name": "TestOrg", @@ -1392,6 +1477,7 @@ class Test(unittest.TestCase): def testKafkaGenerateDateRange(self): """KafkaClient.generate_date_range generates date range list""" from parsedmarc.kafkaclient import KafkaClient + report = { "report_metadata": { "begin_date": "2024-01-01 00:00:00", @@ -1406,6 +1492,7 @@ class Test(unittest.TestCase): def testSplunkHECClientInit(self): """HECClient initializes with correct URL and headers""" from parsedmarc.splunk import HECClient + client = HECClient( url="https://splunk.example.com:8088", access_token="my-token", @@ -1420,6 +1507,7 @@ class Test(unittest.TestCase): def testSplunkHECClientStripTokenPrefix(self): """HECClient strips 'Splunk ' prefix from token""" from parsedmarc.splunk import HECClient + client = HECClient( url="https://splunk.example.com", access_token="Splunk my-token", @@ -1430,6 +1518,7 @@ class Test(unittest.TestCase): def testSplunkBackwardCompatAlias(self): """HECClient forensic alias points to failure method""" from parsedmarc.splunk import HECClient + self.assertIs( HECClient.save_forensic_reports_to_splunk, HECClient.save_failure_reports_to_splunk, @@ -1438,6 +1527,7 @@ class Test(unittest.TestCase): def testSyslogClientUdpInit(self): """SyslogClient creates UDP handler""" from parsedmarc.syslog import SyslogClient + client = SyslogClient("localhost", 514, protocol="udp") self.assertEqual(client.server_name, "localhost") self.assertEqual(client.server_port, 514) @@ -1446,12 +1536,14 @@ class Test(unittest.TestCase): def testSyslogClientInvalidProtocol(self): """SyslogClient with invalid protocol raises ValueError""" from parsedmarc.syslog import SyslogClient + with self.assertRaises(ValueError): SyslogClient("localhost", 514, protocol="invalid") def testSyslogBackwardCompatAlias(self): """SyslogClient forensic alias points to failure method""" from parsedmarc.syslog import SyslogClient + self.assertIs( SyslogClient.save_forensic_report_to_syslog, SyslogClient.save_failure_report_to_syslog, @@ -1460,6 +1552,7 @@ class Test(unittest.TestCase): def testLogAnalyticsConfig(self): """LogAnalyticsConfig stores all fields""" from parsedmarc.loganalytics import LogAnalyticsConfig + config = LogAnalyticsConfig( client_id="cid", client_secret="csec", @@ -1482,6 +1575,7 @@ class Test(unittest.TestCase): def testLogAnalyticsClientValidationError(self): """LogAnalyticsClient raises on missing required config""" from parsedmarc.loganalytics import LogAnalyticsClient, LogAnalyticsException + with self.assertRaises(LogAnalyticsException): LogAnalyticsClient( client_id="", @@ -1496,18 +1590,34 @@ class Test(unittest.TestCase): def testSmtpTlsCsvRows(self): """parsed_smtp_tls_reports_to_csv_rows produces correct rows""" - report_json = json.dumps({ - "organization-name": "Org", - "date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"}, - "contact-info": "a@b.com", - "report-id": "r1", - "policies": [{ - "policy": {"policy-type": "sts", "policy-domain": "example.com", - "policy-string": ["v: STSv1"], "mx-host-pattern": ["*.example.com"]}, - "summary": {"total-successful-session-count": 10, "total-failure-session-count": 1}, - "failure-details": [{"result-type": "cert-expired", "failed-session-count": 1}], - }], - }) + report_json = json.dumps( + { + "organization-name": "Org", + "date-range": { + "start-datetime": "2024-01-01T00:00:00Z", + "end-datetime": "2024-01-02T00:00:00Z", + }, + "contact-info": "a@b.com", + "report-id": "r1", + "policies": [ + { + "policy": { + "policy-type": "sts", + "policy-domain": "example.com", + "policy-string": ["v: STSv1"], + "mx-host-pattern": ["*.example.com"], + }, + "summary": { + "total-successful-session-count": 10, + "total-failure-session-count": 1, + }, + "failure-details": [ + {"result-type": "cert-expired", "failed-session-count": 1} + ], + } + ], + } + ) parsed = parsedmarc.parse_smtp_tls_report_json(report_json) rows = parsedmarc.parsed_smtp_tls_reports_to_csv_rows(parsed) self.assertTrue(len(rows) >= 2) @@ -1518,7 +1628,8 @@ class Test(unittest.TestCase): """parsed_aggregate_reports_to_csv_rows handles list of reports""" result = parsedmarc.parse_report_file( "samples/aggregate/dmarcbis-draft-sample.xml", - always_use_local_files=True, offline=True, + always_use_local_files=True, + offline=True, ) report = result["report"] # Pass as a list @@ -1532,10 +1643,18 @@ class Test(unittest.TestCase): def testExceptionHierarchy(self): """Exception class hierarchy is correct""" self.assertTrue(issubclass(parsedmarc.ParserError, RuntimeError)) - self.assertTrue(issubclass(parsedmarc.InvalidDMARCReport, parsedmarc.ParserError)) - self.assertTrue(issubclass(parsedmarc.InvalidAggregateReport, parsedmarc.InvalidDMARCReport)) - self.assertTrue(issubclass(parsedmarc.InvalidFailureReport, parsedmarc.InvalidDMARCReport)) - self.assertTrue(issubclass(parsedmarc.InvalidSMTPTLSReport, parsedmarc.ParserError)) + self.assertTrue( + issubclass(parsedmarc.InvalidDMARCReport, parsedmarc.ParserError) + ) + self.assertTrue( + issubclass(parsedmarc.InvalidAggregateReport, parsedmarc.InvalidDMARCReport) + ) + self.assertTrue( + issubclass(parsedmarc.InvalidFailureReport, parsedmarc.InvalidDMARCReport) + ) + self.assertTrue( + issubclass(parsedmarc.InvalidSMTPTLSReport, parsedmarc.ParserError) + ) self.assertIs(parsedmarc.InvalidForensicReport, parsedmarc.InvalidFailureReport) def testAggregateReportNormalization(self): @@ -1579,6 +1698,7 @@ class Test(unittest.TestCase): def testGelfBackwardCompatAlias(self): """GelfClient forensic alias points to failure method""" from parsedmarc.gelf import GelfClient + self.assertIs( GelfClient.save_forensic_report_to_gelf, GelfClient.save_failure_report_to_gelf, @@ -1587,6 +1707,7 @@ class Test(unittest.TestCase): def testS3BackwardCompatAlias(self): """S3Client forensic alias points to failure method""" from parsedmarc.s3 import S3Client + self.assertIs( S3Client.save_forensic_report_to_s3, S3Client.save_failure_report_to_s3, @@ -1595,6 +1716,7 @@ class Test(unittest.TestCase): def testKafkaBackwardCompatAlias(self): """KafkaClient forensic alias points to failure method""" from parsedmarc.kafkaclient import KafkaClient + self.assertIs( KafkaClient.save_forensic_reports_to_kafka, KafkaClient.save_failure_reports_to_kafka, @@ -1622,7 +1744,9 @@ class Test(unittest.TestCase): with open(sample_path, "rb") as f: data = f.read() report = parsedmarc.parse_aggregate_report_file( - data, offline=True, always_use_local_files=True, + data, + offline=True, + always_use_local_files=True, ) self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter") self.assertEqual(report["policy_published"]["domain"], "example.com") @@ -1755,9 +1879,7 @@ class TestGmailConnection(unittest.TestCase): "from_authorized_user_file", return_value=creds, ): - returned = _get_creds( - token_path, "credentials.json", ["scope"], 8080 - ) + returned = _get_creds(token_path, "credentials.json", ["scope"], 8080) finally: os.remove(token_path) self.assertEqual(returned, creds) @@ -1811,9 +1933,7 @@ class TestGmailConnection(unittest.TestCase): "from_authorized_user_file", return_value=expired_creds, ): - returned = _get_creds( - token_path, "credentials.json", ["scope"], 8080 - ) + returned = _get_creds(token_path, "credentials.json", ["scope"], 8080) finally: os.remove(token_path) @@ -1890,7 +2010,9 @@ class TestGraphConnection(unittest.TestCase): def testFetchMessagesPassesSinceAndBatchSize(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "mailbox@example.com" - connection._find_folder_id_from_folder_path = MagicMock(return_value="folder-id") + connection._find_folder_id_from_folder_path = MagicMock( + return_value="folder-id" + ) connection._get_all_messages = MagicMock(return_value=[{"id": "1"}]) self.assertEqual( connection.fetch_messages("Inbox", since="2026-03-01", batch_size=5), ["1"] @@ -1951,7 +2073,9 @@ class TestGraphConnection(unittest.TestCase): def testGenerateCredentialDeviceCode(self): fake_credential = object() - with patch.object(graph_module, "_get_cache_args", return_value={"cached": True}): + with patch.object( + graph_module, "_get_cache_args", return_value={"cached": True} + ): with patch.object( graph_module, "DeviceCodeCredential", @@ -2098,8 +2222,12 @@ class TestImapConnection(unittest.TestCase): with self.assertRaises(_BreakLoop): connection.watch(callback, check_timeout=1) callback.assert_called_once_with(connection) + + class TestGmailAuthModes(unittest.TestCase): - @patch("parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file") + @patch( + "parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file" + ) def testGetCredsServiceAccountWithoutSubject(self, mock_from_service_account_file): service_creds = MagicMock() service_creds.with_subject.return_value = MagicMock() @@ -2121,7 +2249,9 @@ class TestGmailAuthModes(unittest.TestCase): ) service_creds.with_subject.assert_not_called() - @patch("parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file") + @patch( + "parsedmarc.mail.gmail.service_account.Credentials.from_service_account_file" + ) def testGetCredsServiceAccountWithSubject(self, mock_from_service_account_file): base_creds = MagicMock() delegated_creds = MagicMock() @@ -2212,7 +2342,7 @@ class TestGmailAuthModes(unittest.TestCase): mock_gmail_connection.return_value = MagicMock() mock_get_mailbox_reports.return_value = { "aggregate_reports": [], - "forensic_reports": [], + "failure_reports": [], "smtp_tls_reports": [], } config = """[general] @@ -2246,7 +2376,7 @@ scopes = https://www.googleapis.com/auth/gmail.modify mock_gmail_connection.return_value = MagicMock() mock_get_reports.return_value = { "aggregate_reports": [], - "forensic_reports": [], + "failure_reports": [], "smtp_tls_reports": [], } config = """[general] @@ -2270,5 +2400,7 @@ scopes = https://www.googleapis.com/auth/gmail.modify mock_gmail_connection.call_args.kwargs.get("service_account_user"), "delegated@example.com", ) + + if __name__ == "__main__": unittest.main(verbosity=2)