From fcbba3bf6b3575fe8983b5563abb8ff731fac4d2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Feb 2026 19:56:33 +0000 Subject: [PATCH] Improve tests: consolidate imports, use context managers, add subTest, add backward compat and coverage tests Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com> --- tests.py | 212 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 143 insertions(+), 69 deletions(-) diff --git a/tests.py b/tests.py index 9378985..fd45831 100755 --- a/tests.py +++ b/tests.py @@ -3,12 +3,14 @@ from __future__ import absolute_import, print_function, unicode_literals +import json import os import signal import sys import tempfile import unittest from base64 import urlsafe_b64encode +from datetime import datetime, timedelta, timezone from glob import glob from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -60,26 +62,25 @@ class Test(unittest.TestCase): # Example from Wikipedia Base64 article b64_str = "YW55IGNhcm5hbCBwbGVhcw" decoded_str = parsedmarc.utils.decode_base64(b64_str) - assert decoded_str == b"any carnal pleas" + self.assertEqual(decoded_str, b"any carnal pleas") def testPSLDownload(self): + """Test Public Suffix List domain lookups""" subdomain = "foo.example.com" result = parsedmarc.utils.get_base_domain(subdomain) - assert result == "example.com" + self.assertEqual(result, "example.com") # Test newer PSL entries subdomain = "e3191.c.akamaiedge.net" result = parsedmarc.utils.get_base_domain(subdomain) - assert result == "c.akamaiedge.net" + self.assertEqual(result, "c.akamaiedge.net") def testExtractReportXMLComparator(self): """Test XML comparator function""" - xmlnice_file = open("samples/extract_report/nice-input.xml") - xmlnice = xmlnice_file.read() - xmlnice_file.close() - xmlchanged_file = open("samples/extract_report/changed-input.xml") - xmlchanged = minify_xml(xmlchanged_file.read()) - xmlchanged_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlnice = f.read() + with open("samples/extract_report/changed-input.xml") as f: + xmlchanged = minify_xml(f.read()) self.assertTrue(compare_xml(xmlnice, xmlnice)) self.assertTrue(compare_xml(xmlchanged, xmlchanged)) self.assertFalse(compare_xml(xmlnice, xmlchanged)) @@ -94,21 +95,19 @@ class Test(unittest.TestCase): data = f.read() print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report(data) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") def testExtractReportXML(self): """Test extract report function for XML input""" print() - report_path = "samples/extract_report/nice-input.xml" - print("Testing {0}: ".format(report_path), end="") - xmlout = parsedmarc.extract_report_from_file_path(report_path) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + file = "samples/extract_report/nice-input.xml" + print("Testing {0}: ".format(file), end="") + xmlout = parsedmarc.extract_report_from_file_path(file) + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") @@ -126,9 +125,8 @@ class Test(unittest.TestCase): file = "samples/extract_report/nice-input.xml.gz" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = f.read() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") @@ -138,13 +136,11 @@ class Test(unittest.TestCase): file = "samples/extract_report/nice-input.xml.zip" print("Testing {0}: ".format(file), end="") xmlout = parsedmarc.extract_report_from_file_path(file) - xmlin_file = open("samples/extract_report/nice-input.xml") - xmlin = minify_xml(xmlin_file.read()) - xmlin_file.close() + with open("samples/extract_report/nice-input.xml") as f: + xmlin = minify_xml(f.read()) self.assertTrue(compare_xml(xmlout, xmlin)) - xmlin_file = open("samples/extract_report/changed-input.xml") - xmlin = xmlin_file.read() - xmlin_file.close() + with open("samples/extract_report/changed-input.xml") as f: + xmlin = f.read() self.assertFalse(compare_xml(xmlout, xmlin)) print("Passed!") @@ -178,11 +174,12 @@ class Test(unittest.TestCase): if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") - result = parsedmarc.parse_report_file( - sample_path, always_use_local_files=True, offline=OFFLINE_MODE - ) - assert result["report_type"] == "aggregate" - parsedmarc.parsed_aggregate_reports_to_csv(result["report"]) + with self.subTest(sample=sample_path): + result = parsedmarc.parse_report_file( + sample_path, always_use_local_files=True, offline=OFFLINE_MODE + ) + assert result["report_type"] == "aggregate" + parsedmarc.parsed_aggregate_reports_to_csv(result["report"]) print("Passed!") def testEmptySample(self): @@ -196,17 +193,18 @@ class Test(unittest.TestCase): sample_paths = glob("samples/failure/*.eml") for sample_path in sample_paths: print("Testing {0}: ".format(sample_path), end="") - with open(sample_path) as sample_file: - sample_content = sample_file.read() - email_result = parsedmarc.parse_report_email( - sample_content, offline=OFFLINE_MODE + with self.subTest(sample=sample_path): + with open(sample_path) as sample_file: + sample_content = sample_file.read() + email_result = parsedmarc.parse_report_email( + sample_content, offline=OFFLINE_MODE + ) + assert email_result["report_type"] == "failure" + result = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE ) - assert email_result["report_type"] == "failure" - result = parsedmarc.parse_report_file( - sample_path, offline=OFFLINE_MODE - ) - assert result["report_type"] == "failure" - parsedmarc.parsed_failure_reports_to_csv(result["report"]) + assert result["report_type"] == "failure" + parsedmarc.parsed_failure_reports_to_csv(result["report"]) print("Passed!") def testFailureReportBackwardCompat(self): @@ -366,9 +364,12 @@ class Test(unittest.TestCase): if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") - result = parsedmarc.parse_report_file(sample_path, offline=OFFLINE_MODE) - assert result["report_type"] == "smtp_tls" - parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"]) + with self.subTest(sample=sample_path): + result = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE + ) + assert result["report_type"] == "smtp_tls" + parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"]) print("Passed!") def testOpenSearchSigV4RequiresRegion(self): @@ -2386,7 +2387,6 @@ watch = true def testBucketIntervalBeginAfterEnd(self): """begin > end should raise ValueError""" - from datetime import datetime, timezone begin = datetime(2024, 1, 2, tzinfo=timezone.utc) end = datetime(2024, 1, 1, tzinfo=timezone.utc) with self.assertRaises(ValueError): @@ -2394,7 +2394,6 @@ watch = true def testBucketIntervalNaiveDatetime(self): """Non-timezone-aware datetimes should raise ValueError""" - from datetime import datetime begin = datetime(2024, 1, 1) end = datetime(2024, 1, 2) with self.assertRaises(ValueError): @@ -2402,7 +2401,6 @@ watch = true def testBucketIntervalDifferentTzinfo(self): """Different tzinfo objects should raise ValueError""" - from datetime import datetime, timezone, timedelta tz1 = timezone.utc tz2 = timezone(timedelta(hours=5)) begin = datetime(2024, 1, 1, tzinfo=tz1) @@ -2412,7 +2410,6 @@ watch = true def testBucketIntervalNegativeCount(self): """Negative total_count should raise ValueError""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, tzinfo=timezone.utc) end = datetime(2024, 1, 2, tzinfo=timezone.utc) with self.assertRaises(ValueError): @@ -2420,7 +2417,6 @@ watch = true def testBucketIntervalZeroCount(self): """Zero total_count should return empty list""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, tzinfo=timezone.utc) end = datetime(2024, 1, 2, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 0) @@ -2428,14 +2424,12 @@ watch = true def testBucketIntervalSameBeginEnd(self): """Same begin and end (zero interval) should return empty list""" - from datetime import datetime, timezone dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(dt, dt, 100) self.assertEqual(result, []) def testBucketIntervalSingleDay(self): """Single day interval should return one bucket with total count""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 1, 23, 59, 59, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 100) @@ -2445,7 +2439,6 @@ watch = true def testBucketIntervalMultiDay(self): """Multi-day interval should distribute counts proportionally""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 100) @@ -2458,7 +2451,6 @@ watch = true def testBucketIntervalRemainderDistribution(self): """Odd count across equal days distributes remainder correctly""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 4, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 10) @@ -2468,7 +2460,6 @@ watch = true def testBucketIntervalPartialDays(self): """Partial days: 12h on day1, 24h on day2 => 1/3 vs 2/3 split""" - from datetime import datetime, timezone begin = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc) result = parsedmarc._bucket_interval_by_day(begin, end, 90) @@ -2484,7 +2475,6 @@ watch = true def testAppendParsedRecordNoNormalize(self): """No normalization: record appended as-is with interval fields""" - from datetime import datetime, timezone records = [] rec = {"count": 10, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -2497,7 +2487,6 @@ watch = true def testAppendParsedRecordNormalize(self): """Normalization: record split into daily buckets""" - from datetime import datetime, timezone records = [] rec = {"count": 100, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -2511,7 +2500,6 @@ watch = true def testAppendParsedRecordNormalizeZeroCount(self): """Normalization with zero count: nothing appended""" - from datetime import datetime, timezone records = [] rec = {"count": 0, "source": {"ip_address": "1.2.3.4"}} begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -2878,7 +2866,6 @@ watch = true def testParseSmtpTlsReportJsonValid(self): """Valid SMTP TLS JSON report parses correctly""" - import json report = json.dumps({ "organization-name": "Example Corp", "date-range": { @@ -2907,7 +2894,6 @@ watch = true def testParseSmtpTlsReportJsonBytes(self): """SMTP TLS report as bytes parses correctly""" - import json report = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"}, @@ -2923,14 +2909,12 @@ watch = true def testParseSmtpTlsReportJsonMissingField(self): """Missing required field raises InvalidSMTPTLSReport""" - import json report = json.dumps({"organization-name": "Org"}) with self.assertRaises(parsedmarc.InvalidSMTPTLSReport): parsedmarc.parse_smtp_tls_report_json(report) def testParseSmtpTlsReportJsonPoliciesNotList(self): """Non-list policies raises InvalidSMTPTLSReport""" - import json report = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"}, @@ -3252,7 +3236,6 @@ watch = true def testHumanTimestampToDatetime(self): """human_timestamp_to_datetime parses timestamp string""" - from datetime import datetime dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00") self.assertIsInstance(dt, datetime) self.assertEqual(dt.year, 2024) @@ -3261,7 +3244,6 @@ watch = true def testHumanTimestampToDatetimeUtc(self): """human_timestamp_to_datetime with to_utc=True returns UTC""" - from datetime import timezone dt = parsedmarc.utils.human_timestamp_to_datetime( "2024-01-01 12:00:00", to_utc=True ) @@ -3329,9 +3311,6 @@ watch = true """get_ip_address_info uses cache on second call""" from expiringdict import ExpiringDict cache = ExpiringDict(max_len=100, max_age_seconds=60) - # offline=True means reverse_dns is None, so cache is not populated - # Use offline=False with mock to test cache - from unittest.mock import patch with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"): info1 = parsedmarc.utils.get_ip_address_info( "8.8.8.8", offline=False, cache=cache, @@ -3419,7 +3398,6 @@ watch = true def testWebhookClientSaveMethods(self): """WebhookClient save methods call _send_to_webhook""" - from unittest.mock import MagicMock from parsedmarc.webhook import WebhookClient client = WebhookClient("http://a", "http://f", "http://t") client.session = MagicMock() @@ -3564,7 +3542,6 @@ watch = true def testSmtpTlsCsvRows(self): """parsed_smtp_tls_reports_to_csv_rows produces correct rows""" - import json report_json = json.dumps({ "organization-name": "Org", "date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"}, @@ -3641,6 +3618,103 @@ watch = true for r in report["records"]: self.assertTrue(r["normalized_timespan"]) + # =================================================================== + # Additional backward compatibility alias tests + # =================================================================== + + def testGelfBackwardCompatAlias(self): + """GelfClient forensic alias points to failure method""" + from parsedmarc.gelf import GelfClient + self.assertIs( + GelfClient.save_forensic_report_to_gelf, + GelfClient.save_failure_report_to_gelf, + ) + + def testS3BackwardCompatAlias(self): + """S3Client forensic alias points to failure method""" + from parsedmarc.s3 import S3Client + self.assertIs( + S3Client.save_forensic_report_to_s3, + S3Client.save_failure_report_to_s3, + ) + + def testKafkaBackwardCompatAlias(self): + """KafkaClient forensic alias points to failure method""" + from parsedmarc.kafkaclient import KafkaClient + self.assertIs( + KafkaClient.save_forensic_reports_to_kafka, + KafkaClient.save_failure_reports_to_kafka, + ) + + # =================================================================== + # Additional extract/parse tests + # =================================================================== + + def testExtractReportFromFilePathNotFound(self): + """extract_report_from_file_path raises ParserError for missing file""" + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.extract_report_from_file_path("nonexistent_file.xml") + + def testExtractReportInvalidArchive(self): + """extract_report raises ParserError for unrecognized binary content""" + with self.assertRaises(parsedmarc.ParserError): + parsedmarc.extract_report(b"\x00\x01\x02\x03\x04\x05\x06\x07") + + def testParseAggregateReportFile(self): + """parse_aggregate_report_file parses bytes input directly""" + print() + sample_path = "samples/aggregate/dmarcbis-draft-sample.xml" + print("Testing {0}: ".format(sample_path), end="") + with open(sample_path, "rb") as f: + data = f.read() + report = parsedmarc.parse_aggregate_report_file( + data, offline=True, always_use_local_files=True, + ) + self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter") + self.assertEqual(report["policy_published"]["domain"], "example.com") + print("Passed!") + + def testParseInvalidAggregateSample(self): + """Test invalid aggregate samples are handled""" + print() + sample_paths = glob("samples/aggregate_invalid/*") + for sample_path in sample_paths: + if os.path.isdir(sample_path): + continue + print("Testing {0}: ".format(sample_path), end="") + with self.subTest(sample=sample_path): + parsed_report = parsedmarc.parse_report_file( + sample_path, always_use_local_files=True, offline=OFFLINE_MODE + )["report"] + parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) + print("Passed!") + + def testParseReportFileWithBytes(self): + """parse_report_file handles bytes input""" + with open("samples/aggregate/dmarcbis-draft-sample.xml", "rb") as f: + data = f.read() + result = parsedmarc.parse_report_file( + data, always_use_local_files=True, offline=True + ) + self.assertEqual(result["report_type"], "aggregate") + + def testFailureReportCsvRoundtrip(self): + """Failure report CSV generation works on sample reports""" + print() + sample_paths = glob("samples/forensic/*.eml") + for sample_path in sample_paths: + print("Testing CSV for {0}: ".format(sample_path), end="") + with self.subTest(sample=sample_path): + parsed_report = parsedmarc.parse_report_file( + sample_path, offline=OFFLINE_MODE + )["report"] + csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report) + self.assertIsNotNone(csv_output) + self.assertIn(",", csv_output) + rows = parsedmarc.parsed_failure_reports_to_csv_rows(parsed_report) + self.assertTrue(len(rows) > 0) + print("Passed!") + if __name__ == "__main__": unittest.main(verbosity=2)