From 01c2e623bbcb88ceff7f761f181c05eec3146b47 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Feb 2026 19:56:33 +0000 Subject: [PATCH] Improve tests: consolidate imports, use context managers, add subTest, add backward compat and coverage tests Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com> --- tests.py | 201 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 137 insertions(+), 64 deletions(-) diff --git a/tests.py b/tests.py index 6d629c7..58304ac 100755 --- a/tests.py +++ b/tests.py @@ -3,9 +3,12 @@ from __future__ import absolute_import, print_function, unicode_literals +import json import os import unittest +from datetime import datetime, timedelta, timezone from glob import glob +from unittest.mock import MagicMock, patch from lxml import etree @@ -35,26 +38,25 @@ class Test(unittest.TestCase): # Example from Wikipedia Base64 article b64_str = "YW55IGNhcm5hbCBwbGVhcw" decoded_str = parsedmarc.utils.decode_base64(b64_str) - assert decoded_str == b"any carnal pleas" + self.assertEqual(decoded_str, b"any carnal pleas") def testPSLDownload(self): + """Test Public Suffix List domain lookups""" subdomain = "foo.example.com" result = parsedmarc.utils.get_base_domain(subdomain) - assert result == "example.com" + self.assertEqual(result, "example.com") # Test newer PSL entries subdomain = "e3191.c.akamaiedge.net" result = parsedmarc.utils.get_base_domain(subdomain) - assert result == "c.akamaiedge.net" + self.assertEqual(result, "c.akamaiedge.net") def testExtractReportXMLComparator(self): """Test XML comparator function""" - xmlnice_file = open("samples/extract_report/nice-input.xml") - xmlnice = xmlnice_file.read() - xmlnice_file.close() - xmlchanged_file = open("samples/extract_report/changed-input.xml") - xmlchanged = minify_xml(xmlchanged_file.read()) - xmlchanged_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlnice = f.read() + with open("samples/extract_report/changed-input.xml") as f: + xmlchanged = minify_xml(f.read()) self.assertTrue(compare_xml(xmlnice, xmlnice)) self.assertTrue(compare_xml(xmlchanged, xmlchanged)) self.assertFalse(compare_xml(xmlnice, xmlchanged)) @@ -69,9 +71,8 @@ class Test(unittest.TestCase): data = f.read() print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report(data) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") @@ -81,9 +82,8 @@ class Test(unittest.TestCase): file = "samples/extract_report/nice-input.xml" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") @@ -93,9 +93,8 @@ class Test(unittest.TestCase): file = "samples/extract_report/nice-input.xml.gz" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") @@ -105,13 +104,11 @@ class Test(unittest.TestCase): file = "samples/extract_report/nice-input.xml.zip" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = minify_xml(xmlin_file.read()) - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = minify_xml(f.read()) self.assertTrue(compare_xml(xmlout, xmlin)) - xmlin_file = open("samples/extract_report/changed-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/changed-input.xml") as f: + xmlin = f.read() self.assertFalse(compare_xml(xmlout, xmlin)) print("Passed!") @@ -123,10 +120,11 @@ class Test(unittest.TestCase): if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") - parsed_report = parsedmarc.parse_report_file( - sample_path, always_use_local_files=True, offline=OFFLINE_MODE - )["report"] - parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) + with self.subTest(sample=sample_path): + parsed_report = parsedmarc.parse_report_file( + sample_path, always_use_local_files=True, offline=OFFLINE_MODE + )["report"] + parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) print("Passed!") def testEmptySample(self): @@ -140,15 +138,16 @@ class Test(unittest.TestCase): sample_paths = glob("samples/forensic/*.eml") for sample_path in sample_paths: print("Testing {0}: ".format(sample_path), end="") - with open(sample_path) as sample_file: - sample_content = sample_file.read() - parsed_report = parsedmarc.parse_report_email( - sample_content, offline=OFFLINE_MODE + with self.subTest(sample=sample_path): + with open(sample_path) as sample_file: + sample_content = sample_file.read() + parsed_report = parsedmarc.parse_report_email( + sample_content, offline=OFFLINE_MODE + )["report"] + parsed_report = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE )["report"] - parsed_report = parsedmarc.parse_report_file( - sample_path, offline=OFFLINE_MODE - )["report"] - parsedmarc.parsed_failure_reports_to_csv(parsed_report) + parsedmarc.parsed_failure_reports_to_csv(parsed_report) print("Passed!") def testFailureReportBackwardCompat(self): @@ -308,10 +307,11 @@ class Test(unittest.TestCase): if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") - parsed_report = parsedmarc.parse_report_file( - sample_path, offline=OFFLINE_MODE - )["report"] - parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report) + with self.subTest(sample=sample_path): + parsed_report = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE + )["report"] + parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report) print("Passed!") # =================================================================== @@ -320,7 +320,6 @@ class Test(unittest.TestCase): def testBucketIntervalBeginAfterEnd(self): """begin > end should raise ValueError""" - from datetime import datetime, timezone begin = datetime(2024, 1, 2, tzinfo=timezone.utc) end = datetime(2024, 1, 1, tzinfo=timezone.utc) with self.assertRaises(ValueError): @@ -328,7 +327,6 @@ class Test(unittest.TestCase): def testBucketIntervalNaiveDatetime(self): """Non-timezone-aware datetimes should raise ValueError""" - from datetime import datetime begin = datetime(2024, 1, 1) end = datetime(2024, 1, 2) with self.assertRaises(ValueError): @@ -336,7 +334,6 @@ class Test(unittest.TestCase): def testBucketIntervalDifferentTzinfo(self): """Different tzinfo objects should raise ValueError""" - from datetime import datetime, timezone, timedelta tz1 = timezone.utc tz2 = timezone(timedelta(hours=5)) begin = datetime(2024, 1, 1, tzinfo=tz1) @@ -346,7 +343,6 @@ class Test(unittest.TestCase): def testBucketIntervalNegativeCount(self): """Negative total_count should raise ValueError""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, tzinfo=timezone.utc) end = datetime(2024, 1, 2, tzinfo=timezone.utc) with self.assertRaises(ValueError): @@ -354,7 +350,6 @@ class Test(unittest.TestCase): def testBucketIntervalZeroCount(self): """Zero total_count should return empty list""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, tzinfo=timezone.utc) end = datetime(2024, 1, 2, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 0) @@ -362,14 +357,12 @@ class Test(unittest.TestCase): def testBucketIntervalSameBeginEnd(self): """Same begin and end (zero interval) should return empty list""" - from datetime import datetime, timezone dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(dt, dt, 100) self.assertEqual(result, []) def testBucketIntervalSingleDay(self): """Single day interval should return one bucket with total count""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 1, 23, 59, 59, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 100) @@ -379,7 +372,6 @@ class Test(unittest.TestCase): def testBucketIntervalMultiDay(self): """Multi-day interval should distribute counts proportionally""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 100) @@ -392,7 +384,6 @@ class Test(unittest.TestCase): def testBucketIntervalRemainderDistribution(self): """Odd count across equal days distributes remainder correctly""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 4, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 10) @@ -402,7 +393,6 @@ class Test(unittest.TestCase): def testBucketIntervalPartialDays(self): """Partial days: 12h on day1, 24h on day2 => 1/3 vs 2/3 split""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 90) @@ -418,7 +408,6 @@ class Test(unittest.TestCase): def testAppendParsedRecordNoNormalize(self): """No normalization: record appended as-is with interval fields""" - from datetime import datetime, timezone records = [] rec = {"count": 10, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -431,7 +420,6 @@ class Test(unittest.TestCase): def testAppendParsedRecordNormalize(self): """Normalization: record split into daily buckets""" - from datetime import datetime, timezone records = [] rec = {"count": 100, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -445,7 +433,6 @@ class Test(unittest.TestCase): def testAppendParsedRecordNormalizeZeroCount(self): """Normalization with zero count: nothing appended""" - from datetime import datetime, timezone records = [] rec = {"count": 0, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -812,7 +799,6 @@ class Test(unittest.TestCase): def testParseSmtpTlsReportJsonValid(self): """Valid SMTP TLS JSON report parses correctly""" - import json report = json.dumps({ "organization-name": "Example Corp", "date-range": { @@ -841,7 +827,6 @@ class Test(unittest.TestCase): def testParseSmtpTlsReportJsonBytes(self): """SMTP TLS report as bytes parses correctly""" - import json report = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"}, @@ -857,14 +842,12 @@ class Test(unittest.TestCase): def testParseSmtpTlsReportJsonMissingField(self): """Missing required field raises InvalidSMTPTLSReport""" - import json report = json.dumps({"organization-name": "Org"}) with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc.parse_smtp_tls_report_json(report) def testParseSmtpTlsReportJsonPoliciesNotList(self): """Non-list policies raises InvalidSMTPTLSReport""" - import json report = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"}, @@ -1186,7 +1169,6 @@ class Test(unittest.TestCase): def testHumanTimestampToDatetime(self): """human_timestamp_to_datetime parses timestamp string""" - from datetime import datetime dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00") self.assertIsInstance(dt, datetime) self.assertEqual(dt.year, 2024) @@ -1195,7 +1177,6 @@ class Test(unittest.TestCase): def testHumanTimestampToDatetimeUtc(self): """human_timestamp_to_datetime with to_utc=True returns UTC""" - from datetime import timezone dt = parsedmarc.utils.human_timestamp_to_datetime( "2024-01-01 12:00:00", to_utc=True ) @@ -1263,9 +1244,6 @@ class Test(unittest.TestCase): """get_ip_address_info uses cache on second call""" from expiringdict import ExpiringDict cache = ExpiringDict(max_len=100, max_age_seconds=60) - # offline=True means reverse_dns is None, so cache is not populated - # Use offline=False with mock to test cache - from unittest.mock import patch with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"): info1 = parsedmarc.utils.get_ip_address_info( "8.8.8.8", offline=False, cache=cache, @@ -1353,7 +1331,6 @@ class Test(unittest.TestCase): def testWebhookClientSaveMethods(self): """WebhookClient save methods call _send_to_webhook""" - from unittest.mock import MagicMock from parsedmarc.webhook import WebhookClient client = WebhookClient("http://a", "http://f", "http://t") client.session = MagicMock() @@ -1498,7 +1475,6 @@ class Test(unittest.TestCase): def testSmtpTlsCsvRows(self): """parsed_smtp_tls_reports_to_csv_rows produces correct rows""" - import json report_json = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"}, @@ -1575,6 +1551,103 @@ class Test(unittest.TestCase): for r in report["records"]: self.assertTrue(r["normalized_timespan"]) + # =================================================================== + # Additional backward compatibility alias tests + # =================================================================== + + def testGelfBackwardCompatAlias(self): + """GelfClient forensic alias points to failure method""" + from parsedmarc.gelf import GelfClient + self.assertIs( + GelfClient.save_forensic_report_to_gelf, + GelfClient.save_failure_report_to_gelf, + ) + + def testS3BackwardCompatAlias(self): + """S3Client forensic alias points to failure method""" + from parsedmarc.s3 import S3Client + self.assertIs( + S3Client.save_forensic_report_to_s3, + S3Client.save_failure_report_to_s3, + ) + + def testKafkaBackwardCompatAlias(self): + """KafkaClient forensic alias points to failure method""" + from parsedmarc.kafkaclient import KafkaClient + self.assertIs( + KafkaClient.save_forensic_reports_to_kafka, + KafkaClient.save_failure_reports_to_kafka, + ) + + # =================================================================== + # Additional extract/parse tests + # =================================================================== + + def testExtractReportFromFilePathNotFound(self): + """extract_report_from_file_path raises ParserError for missing file""" + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.extract_report_from_file_path("nonexistent_file.xml") + + def testExtractReportInvalidArchive(self): + """extract_report raises ParserError for unrecognized binary content""" + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.extract_report(b"\x00\x01\x02\x03\x04\x05\x06\x07") + + def testParseAggregateReportFile(self): + """parse_aggregate_report_file parses bytes input directly""" + print() + sample_path = "samples/aggregate/dmarcbis-draft-sample.xml" + print("Testing {0}: ".format(sample_path), end="") + with open(sample_path, "rb") as f: + data = f.read() + report = parsedmarc.parse_aggregate_report_file( + data, offline=True, always_use_local_files=True, + ) + self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter") + self.assertEqual(report["policy_published"]["domain"], "example.com") + print("Passed!") + + def testParseInvalidAggregateSample(self): + """Test invalid aggregate samples are handled""" + print() + sample_paths = glob("samples/aggregate_invalid/*") + for sample_path in sample_paths: + if os.path.isdir(sample_path): + continue + print("Testing {0}: ".format(sample_path), end="") + with self.subTest(sample=sample_path): + parsed_report = parsedmarc.parse_report_file( + sample_path, always_use_local_files=True, offline=OFFLINE_MODE + )["report"] + parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) + print("Passed!") + + def testParseReportFileWithBytes(self): + """parse_report_file handles bytes input""" + with open("samples/aggregate/dmarcbis-draft-sample.xml", "rb") as f: + data = f.read() + result = parsedmarc.parse_report_file( + data, always_use_local_files=True, offline=True + ) + self.assertEqual(result["report_type"], "aggregate") + + def testFailureReportCsvRoundtrip(self): + """Failure report CSV generation works on sample reports""" + print() + sample_paths = glob("samples/forensic/*.eml") + for sample_path in sample_paths: + print("Testing CSV for {0}: ".format(sample_path), end="") + with self.subTest(sample=sample_path): + parsed_report = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE + )["report"] + csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report) + self.assertIsNotNone(csv_output) + self.assertIn(",", csv_output) + rows = parsedmarc.parsed_failure_reports_to_csv_rows(parsed_report) + self.assertTrue(len(rows) > 0) + print("Passed!") + if __name__ == "__main__": unittest.main(verbosity=2)