From 25a071da7ebf046fa24a67f62701f94543da63ae Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Feb 2026 19:56:33 +0000 Subject: [PATCH] Improve tests: consolidate imports, use context managers, add subTest, add backward compat and coverage tests Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com> --- tests.py | 211 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 142 insertions(+), 69 deletions(-) diff --git a/tests.py b/tests.py index 4bb24bd..34a9729 100755 --- a/tests.py +++ b/tests.py @@ -12,6 +12,7 @@ import tempfile import unittest from base64 import urlsafe_b64encode from configparser import ConfigParser +from datetime import datetime, timedelta, timezone from glob import glob from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -63,26 +64,25 @@ class Test(unittest.TestCase): # Example from Wikipedia Base64 article b64_str = "YW55IGNhcm5hbCBwbGVhcw" decoded_str = parsedmarc.utils.decode_base64(b64_str) - assert decoded_str == b"any carnal pleas" + self.assertEqual(decoded_str, b"any carnal pleas") def testPSLDownload(self): + """Test Public Suffix List domain lookups""" subdomain = "foo.example.com" result = parsedmarc.utils.get_base_domain(subdomain) - assert result == "example.com" + self.assertEqual(result, "example.com") # Test newer PSL entries subdomain = "e3191.c.akamaiedge.net" result = parsedmarc.utils.get_base_domain(subdomain) - assert result == "c.akamaiedge.net" + self.assertEqual(result, "c.akamaiedge.net") def testExtractReportXMLComparator(self): """Test XML comparator function""" - xmlnice_file = open("samples/extract_report/nice-input.xml") - xmlnice = xmlnice_file.read() - xmlnice_file.close() - xmlchanged_file = open("samples/extract_report/changed-input.xml") - xmlchanged = minify_xml(xmlchanged_file.read()) - xmlchanged_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlnice = f.read() + with open("samples/extract_report/changed-input.xml") as f: + xmlchanged = minify_xml(f.read()) self.assertTrue(compare_xml(xmlnice, xmlnice)) self.assertTrue(compare_xml(xmlchanged, xmlchanged)) self.assertFalse(compare_xml(xmlnice, xmlchanged)) @@ -97,21 +97,19 @@ class Test(unittest.TestCase): data = f.read() print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report(data) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") def testExtractReportXML(self): """Test extract report function for XML input""" print() - report_path = "samples/extract_report/nice-input.xml" - print("Testing {0}: ".format(report_path), end="") - xmlout = parsedmarc.extract_report_from_file_path(report_path) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + file = "samples/extract_report/nice-input.xml" + print("Testing {0}: ".format(file), end="") + xmlout = parsedmarc.extract_report_from_file_path(file) + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") @@ -129,9 +127,8 @@ class Test(unittest.TestCase): file = "samples/extract_report/nice-input.xml.gz" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") @@ -141,13 +138,11 @@ class Test(unittest.TestCase): file = "samples/extract_report/nice-input.xml.zip" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = minify_xml(xmlin_file.read()) - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = minify_xml(f.read()) self.assertTrue(compare_xml(xmlout, xmlin)) - xmlin_file = open("samples/extract_report/changed-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/changed-input.xml") as f: + xmlin = f.read() self.assertFalse(compare_xml(xmlout, xmlin)) print("Passed!") @@ -181,11 +176,12 @@ class Test(unittest.TestCase): if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") - result = parsedmarc.parse_report_file( - sample_path, always_use_local_files=True, offline=OFFLINE_MODE - ) - assert result["report_type"] == "aggregate" - parsedmarc.parsed_aggregate_reports_to_csv(result["report"]) + with self.subTest(sample=sample_path): + result = parsedmarc.parse_report_file( + sample_path, always_use_local_files=True, offline=OFFLINE_MODE + ) + assert result["report_type"] == "aggregate" + parsedmarc.parsed_aggregate_reports_to_csv(result["report"]) print("Passed!") def testEmptySample(self): @@ -199,17 +195,18 @@ class Test(unittest.TestCase): sample_paths = glob("samples/failure/*.eml") for sample_path in sample_paths: print("Testing {0}: ".format(sample_path), end="") - with open(sample_path) as sample_file: - sample_content = sample_file.read() - email_result = parsedmarc.parse_report_email( - sample_content, offline=OFFLINE_MODE + with self.subTest(sample=sample_path): + with open(sample_path) as sample_file: + sample_content = sample_file.read() + email_result = parsedmarc.parse_report_email( + sample_content, offline=OFFLINE_MODE + ) + assert email_result["report_type"] == "failure" + result = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE ) - assert email_result["report_type"] == "failure" - result = parsedmarc.parse_report_file( - sample_path, offline=OFFLINE_MODE - ) - assert result["report_type"] == "failure" - parsedmarc.parsed_failure_reports_to_csv(result["report"]) + assert result["report_type"] == "failure" + parsedmarc.parsed_failure_reports_to_csv(result["report"]) print("Passed!") def testFailureReportBackwardCompat(self): @@ -369,9 +366,12 @@ class Test(unittest.TestCase): if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") - result = parsedmarc.parse_report_file(sample_path, offline=OFFLINE_MODE) - assert result["report_type"] == "smtp_tls" - parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"]) + with self.subTest(sample=sample_path): + result = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE + ) + assert result["report_type"] == "smtp_tls" + parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"]) print("Passed!") def testOpenSearchSigV4RequiresRegion(self): @@ -3185,7 +3185,6 @@ class TestEnvVarConfig(unittest.TestCase): # ============================================================ def testBucketIntervalBeginAfterEnd(self): """begin > end should raise ValueError""" - from datetime import datetime, timezone begin = datetime(2024, 1, 2, tzinfo=timezone.utc) end = datetime(2024, 1, 1, tzinfo=timezone.utc) with self.assertRaises(ValueError): @@ -3193,7 +3192,6 @@ class TestEnvVarConfig(unittest.TestCase): def testBucketIntervalNaiveDatetime(self): """Non-timezone-aware datetimes should raise ValueError""" - from datetime import datetime begin = datetime(2024, 1, 1) end = datetime(2024, 1, 2) with self.assertRaises(ValueError): @@ -3201,7 +3199,6 @@ class TestEnvVarConfig(unittest.TestCase): def testBucketIntervalDifferentTzinfo(self): """Different tzinfo objects should raise ValueError""" - from datetime import datetime, timezone, timedelta tz1 = timezone.utc tz2 = timezone(timedelta(hours=5)) begin = datetime(2024, 1, 1, tzinfo=tz1) @@ -3211,7 +3208,6 @@ class TestEnvVarConfig(unittest.TestCase): def testBucketIntervalNegativeCount(self): """Negative total_count should raise ValueError""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, tzinfo=timezone.utc) end = datetime(2024, 1, 2, tzinfo=timezone.utc) with self.assertRaises(ValueError): @@ -3219,7 +3215,6 @@ class TestEnvVarConfig(unittest.TestCase): def testBucketIntervalZeroCount(self): """Zero total_count should return empty list""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, tzinfo=timezone.utc) end = datetime(2024, 1, 2, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 0) @@ -3227,14 +3222,12 @@ class TestEnvVarConfig(unittest.TestCase): def testBucketIntervalSameBeginEnd(self): """Same begin and end (zero interval) should return empty list""" - from datetime import datetime, timezone dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(dt, dt, 100) self.assertEqual(result, []) def testBucketIntervalSingleDay(self): """Single day interval should return one bucket with total count""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 1, 23, 59, 59, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 100) @@ -3244,7 +3237,6 @@ class TestEnvVarConfig(unittest.TestCase): def testBucketIntervalMultiDay(self): """Multi-day interval should distribute counts proportionally""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 100) @@ -3257,7 +3249,6 @@ class TestEnvVarConfig(unittest.TestCase): def testBucketIntervalRemainderDistribution(self): """Odd count across equal days distributes remainder correctly""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 4, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 10) @@ -3267,7 +3258,6 @@ class TestEnvVarConfig(unittest.TestCase): def testBucketIntervalPartialDays(self): """Partial days: 12h on day1, 24h on day2 => 1/3 vs 2/3 split""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 90) @@ -3281,7 +3271,6 @@ class TestEnvVarConfig(unittest.TestCase): # ============================================================ def testAppendParsedRecordNoNormalize(self): """No normalization: record appended as-is with interval fields""" - from datetime import datetime, timezone records = [] rec = {"count": 10, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -3294,7 +3283,6 @@ class TestEnvVarConfig(unittest.TestCase): def testAppendParsedRecordNormalize(self): """Normalization: record split into daily buckets""" - from datetime import datetime, timezone records = [] rec = {"count": 100, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -3308,7 +3296,6 @@ class TestEnvVarConfig(unittest.TestCase): def testAppendParsedRecordNormalizeZeroCount(self): """Normalization with zero count: nothing appended""" - from datetime import datetime, timezone records = [] rec = {"count": 0, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -3667,7 +3654,6 @@ class TestEnvVarConfig(unittest.TestCase): # ============================================================ def testParseSmtpTlsReportJsonValid(self): """Valid SMTP TLS JSON report parses correctly""" - import json report = json.dumps({ "organization-name": "Example Corp", "date-range": { @@ -3696,7 +3682,6 @@ class TestEnvVarConfig(unittest.TestCase): def testParseSmtpTlsReportJsonBytes(self): """SMTP TLS report as bytes parses correctly""" - import json report = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"}, @@ -3712,14 +3697,12 @@ class TestEnvVarConfig(unittest.TestCase): def testParseSmtpTlsReportJsonMissingField(self): """Missing required field raises InvalidSMTPTLSReport""" - import json report = json.dumps({"organization-name": "Org"}) with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc.parse_smtp_tls_report_json(report) def testParseSmtpTlsReportJsonPoliciesNotList(self): """Non-list policies raises InvalidSMTPTLSReport""" - import json report = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"}, @@ -4037,7 +4020,6 @@ class TestEnvVarConfig(unittest.TestCase): def testHumanTimestampToDatetime(self): """human_timestamp_to_datetime parses timestamp string""" - from datetime import datetime dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00") self.assertIsInstance(dt, datetime) self.assertEqual(dt.year, 2024) @@ -4046,7 +4028,6 @@ class TestEnvVarConfig(unittest.TestCase): def testHumanTimestampToDatetimeUtc(self): """human_timestamp_to_datetime with to_utc=True returns UTC""" - from datetime import timezone dt = parsedmarc.utils.human_timestamp_to_datetime( "2024-01-01 12:00:00", to_utc=True ) @@ -4114,9 +4095,6 @@ class TestEnvVarConfig(unittest.TestCase): """get_ip_address_info uses cache on second call""" from expiringdict import ExpiringDict cache = ExpiringDict(max_len=100, max_age_seconds=60) - # offline=True means reverse_dns is None, so cache is not populated - # Use offline=False with mock to test cache - from unittest.mock import patch with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"): info1 = parsedmarc.utils.get_ip_address_info( "8.8.8.8", offline=False, cache=cache, @@ -4202,7 +4180,6 @@ class TestEnvVarConfig(unittest.TestCase): def testWebhookClientSaveMethods(self): """WebhookClient save methods call _send_to_webhook""" - from unittest.mock import MagicMock from parsedmarc.webhook import WebhookClient client = WebhookClient("http://a", "http://f", "http://t") client.session = MagicMock() @@ -4347,7 +4324,6 @@ class TestEnvVarConfig(unittest.TestCase): def testSmtpTlsCsvRows(self): """parsed_smtp_tls_reports_to_csv_rows produces correct rows""" - import json report_json = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"}, @@ -4424,6 +4400,103 @@ class TestEnvVarConfig(unittest.TestCase): for r in report["records"]: self.assertTrue(r["normalized_timespan"]) + # =================================================================== + # Additional backward compatibility alias tests + # =================================================================== + + def testGelfBackwardCompatAlias(self): + """GelfClient forensic alias points to failure method""" + from parsedmarc.gelf import GelfClient + self.assertIs( + GelfClient.save_forensic_report_to_gelf, + GelfClient.save_failure_report_to_gelf, + ) + + def testS3BackwardCompatAlias(self): + """S3Client forensic alias points to failure method""" + from parsedmarc.s3 import S3Client + self.assertIs( + S3Client.save_forensic_report_to_s3, + S3Client.save_failure_report_to_s3, + ) + + def testKafkaBackwardCompatAlias(self): + """KafkaClient forensic alias points to failure method""" + from parsedmarc.kafkaclient import KafkaClient + self.assertIs( + KafkaClient.save_forensic_reports_to_kafka, + KafkaClient.save_failure_reports_to_kafka, + ) + + # =================================================================== + # Additional extract/parse tests + # =================================================================== + + def testExtractReportFromFilePathNotFound(self): + """extract_report_from_file_path raises ParserError for missing file""" + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.extract_report_from_file_path("nonexistent_file.xml") + + def testExtractReportInvalidArchive(self): + """extract_report raises ParserError for unrecognized binary content""" + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.extract_report(b"\x00\x01\x02\x03\x04\x05\x06\x07") + + def testParseAggregateReportFile(self): + """parse_aggregate_report_file parses bytes input directly""" + print() + sample_path = "samples/aggregate/dmarcbis-draft-sample.xml" + print("Testing {0}: ".format(sample_path), end="") + with open(sample_path, "rb") as f: + data = f.read() + report = parsedmarc.parse_aggregate_report_file( + data, offline=True, always_use_local_files=True, + ) + self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter") + self.assertEqual(report["policy_published"]["domain"], "example.com") + print("Passed!") + + def testParseInvalidAggregateSample(self): + """Test invalid aggregate samples are handled""" + print() + sample_paths = glob("samples/aggregate_invalid/*") + for sample_path in sample_paths: + if os.path.isdir(sample_path): + continue + print("Testing {0}: ".format(sample_path), end="") + with self.subTest(sample=sample_path): + parsed_report = parsedmarc.parse_report_file( + sample_path, always_use_local_files=True, offline=OFFLINE_MODE + )["report"] + parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) + print("Passed!") + + def testParseReportFileWithBytes(self): + """parse_report_file handles bytes input""" + with open("samples/aggregate/dmarcbis-draft-sample.xml", "rb") as f: + data = f.read() + result = parsedmarc.parse_report_file( + data, always_use_local_files=True, offline=True + ) + self.assertEqual(result["report_type"], "aggregate") + + def testFailureReportCsvRoundtrip(self): + """Failure report CSV generation works on sample reports""" + print() + sample_paths = glob("samples/forensic/*.eml") + for sample_path in sample_paths: + print("Testing CSV for {0}: ".format(sample_path), end="") + with self.subTest(sample=sample_path): + parsed_report = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE + )["report"] + csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report) + self.assertIsNotNone(csv_output) + self.assertIn(",", csv_output) + rows = parsedmarc.parsed_failure_reports_to_csv_rows(parsed_report) + self.assertTrue(len(rows) > 0) + print("Passed!") + if __name__ == "__main__": unittest.main(verbosity=2)