Improve tests: consolidate imports, use context managers, add subTest, add backward compat and coverage tests

Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-02-21 19:56:33 +00:00
parent 68ccb0eb35
commit 01c2e623bb

201
tests.py
View File

@@ -3,9 +3,12 @@
from __future__ import absolute_import, print_function, unicode_literals
import json
import os
import unittest
from datetime import datetime, timedelta, timezone
from glob import glob
from unittest.mock import MagicMock, patch
from lxml import etree
@@ -35,26 +38,25 @@ class Test(unittest.TestCase):
# Example from Wikipedia Base64 article
b64_str = "YW55IGNhcm5hbCBwbGVhcw"
decoded_str = parsedmarc.utils.decode_base64(b64_str)
assert decoded_str == b"any carnal pleas"
self.assertEqual(decoded_str, b"any carnal pleas")
def testPSLDownload(self):
"""Test Public Suffix List domain lookups"""
subdomain = "foo.example.com"
result = parsedmarc.utils.get_base_domain(subdomain)
assert result == "example.com"
self.assertEqual(result, "example.com")
# Test newer PSL entries
subdomain = "e3191.c.akamaiedge.net"
result = parsedmarc.utils.get_base_domain(subdomain)
assert result == "c.akamaiedge.net"
self.assertEqual(result, "c.akamaiedge.net")
def testExtractReportXMLComparator(self):
"""Test XML comparator function"""
xmlnice_file = open("samples/extract_report/nice-input.xml")
xmlnice = xmlnice_file.read()
xmlnice_file.close()
xmlchanged_file = open("samples/extract_report/changed-input.xml")
xmlchanged = minify_xml(xmlchanged_file.read())
xmlchanged_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlnice = f.read()
with open("samples/extract_report/changed-input.xml") as f:
xmlchanged = minify_xml(f.read())
self.assertTrue(compare_xml(xmlnice, xmlnice))
self.assertTrue(compare_xml(xmlchanged, xmlchanged))
self.assertFalse(compare_xml(xmlnice, xmlchanged))
@@ -69,9 +71,8 @@ class Test(unittest.TestCase):
data = f.read()
print("Testing {0}: ".format(file), end="")
xmlout = parsedmarc.extract_report(data)
xmlin_file = open("samples/extract_report/nice-input.xml")
xmlin = xmlin_file.read()
xmlin_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlin = f.read()
self.assertTrue(compare_xml(xmlout, xmlin))
print("Passed!")
@@ -81,9 +82,8 @@ class Test(unittest.TestCase):
file = "samples/extract_report/nice-input.xml"
print("Testing {0}: ".format(file), end="")
xmlout = parsedmarc.extract_report_from_file_path(file)
xmlin_file = open("samples/extract_report/nice-input.xml")
xmlin = xmlin_file.read()
xmlin_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlin = f.read()
self.assertTrue(compare_xml(xmlout, xmlin))
print("Passed!")
@@ -93,9 +93,8 @@ class Test(unittest.TestCase):
file = "samples/extract_report/nice-input.xml.gz"
print("Testing {0}: ".format(file), end="")
xmlout = parsedmarc.extract_report_from_file_path(file)
xmlin_file = open("samples/extract_report/nice-input.xml")
xmlin = xmlin_file.read()
xmlin_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlin = f.read()
self.assertTrue(compare_xml(xmlout, xmlin))
print("Passed!")
@@ -105,13 +104,11 @@ class Test(unittest.TestCase):
file = "samples/extract_report/nice-input.xml.zip"
print("Testing {0}: ".format(file), end="")
xmlout = parsedmarc.extract_report_from_file_path(file)
xmlin_file = open("samples/extract_report/nice-input.xml")
xmlin = minify_xml(xmlin_file.read())
xmlin_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlin = minify_xml(f.read())
self.assertTrue(compare_xml(xmlout, xmlin))
xmlin_file = open("samples/extract_report/changed-input.xml")
xmlin = xmlin_file.read()
xmlin_file.close()
with open("samples/extract_report/changed-input.xml") as f:
xmlin = f.read()
self.assertFalse(compare_xml(xmlout, xmlin))
print("Passed!")
@@ -123,10 +120,11 @@ class Test(unittest.TestCase):
if os.path.isdir(sample_path):
continue
print("Testing {0}: ".format(sample_path), end="")
parsed_report = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
with self.subTest(sample=sample_path):
parsed_report = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
print("Passed!")
def testEmptySample(self):
@@ -140,15 +138,16 @@ class Test(unittest.TestCase):
sample_paths = glob("samples/forensic/*.eml")
for sample_path in sample_paths:
print("Testing {0}: ".format(sample_path), end="")
with open(sample_path) as sample_file:
sample_content = sample_file.read()
parsed_report = parsedmarc.parse_report_email(
sample_content, offline=OFFLINE_MODE
with self.subTest(sample=sample_path):
with open(sample_path) as sample_file:
sample_content = sample_file.read()
parsed_report = parsedmarc.parse_report_email(
sample_content, offline=OFFLINE_MODE
)["report"]
parsed_report = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
parsed_report = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_failure_reports_to_csv(parsed_report)
parsedmarc.parsed_failure_reports_to_csv(parsed_report)
print("Passed!")
def testFailureReportBackwardCompat(self):
@@ -308,10 +307,11 @@ class Test(unittest.TestCase):
if os.path.isdir(sample_path):
continue
print("Testing {0}: ".format(sample_path), end="")
parsed_report = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report)
with self.subTest(sample=sample_path):
parsed_report = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report)
print("Passed!")
# ===================================================================
@@ -320,7 +320,6 @@ class Test(unittest.TestCase):
def testBucketIntervalBeginAfterEnd(self):
"""begin > end should raise ValueError"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 2, tzinfo=timezone.utc)
end = datetime(2024, 1, 1, tzinfo=timezone.utc)
with self.assertRaises(ValueError):
@@ -328,7 +327,6 @@ class Test(unittest.TestCase):
def testBucketIntervalNaiveDatetime(self):
"""Non-timezone-aware datetimes should raise ValueError"""
from datetime import datetime
begin = datetime(2024, 1, 1)
end = datetime(2024, 1, 2)
with self.assertRaises(ValueError):
@@ -336,7 +334,6 @@ class Test(unittest.TestCase):
def testBucketIntervalDifferentTzinfo(self):
"""Different tzinfo objects should raise ValueError"""
from datetime import datetime, timezone, timedelta
tz1 = timezone.utc
tz2 = timezone(timedelta(hours=5))
begin = datetime(2024, 1, 1, tzinfo=tz1)
@@ -346,7 +343,6 @@ class Test(unittest.TestCase):
def testBucketIntervalNegativeCount(self):
"""Negative total_count should raise ValueError"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, tzinfo=timezone.utc)
end = datetime(2024, 1, 2, tzinfo=timezone.utc)
with self.assertRaises(ValueError):
@@ -354,7 +350,6 @@ class Test(unittest.TestCase):
def testBucketIntervalZeroCount(self):
"""Zero total_count should return empty list"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, tzinfo=timezone.utc)
end = datetime(2024, 1, 2, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 0)
@@ -362,14 +357,12 @@ class Test(unittest.TestCase):
def testBucketIntervalSameBeginEnd(self):
"""Same begin and end (zero interval) should return empty list"""
from datetime import datetime, timezone
dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(dt, dt, 100)
self.assertEqual(result, [])
def testBucketIntervalSingleDay(self):
"""Single day interval should return one bucket with total count"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
end = datetime(2024, 1, 1, 23, 59, 59, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 100)
@@ -379,7 +372,6 @@ class Test(unittest.TestCase):
def testBucketIntervalMultiDay(self):
"""Multi-day interval should distribute counts proportionally"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 100)
@@ -392,7 +384,6 @@ class Test(unittest.TestCase):
def testBucketIntervalRemainderDistribution(self):
"""Odd count across equal days distributes remainder correctly"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
end = datetime(2024, 1, 4, 0, 0, 0, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 10)
@@ -402,7 +393,6 @@ class Test(unittest.TestCase):
def testBucketIntervalPartialDays(self):
"""Partial days: 12h on day1, 24h on day2 => 1/3 vs 2/3 split"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 90)
@@ -418,7 +408,6 @@ class Test(unittest.TestCase):
def testAppendParsedRecordNoNormalize(self):
"""No normalization: record appended as-is with interval fields"""
from datetime import datetime, timezone
records = []
rec = {"count": 10, "source": {"ip_address": "1.2.3.4"}}
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
@@ -431,7 +420,6 @@ class Test(unittest.TestCase):
def testAppendParsedRecordNormalize(self):
"""Normalization: record split into daily buckets"""
from datetime import datetime, timezone
records = []
rec = {"count": 100, "source": {"ip_address": "1.2.3.4"}}
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
@@ -445,7 +433,6 @@ class Test(unittest.TestCase):
def testAppendParsedRecordNormalizeZeroCount(self):
"""Normalization with zero count: nothing appended"""
from datetime import datetime, timezone
records = []
rec = {"count": 0, "source": {"ip_address": "1.2.3.4"}}
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
@@ -812,7 +799,6 @@ class Test(unittest.TestCase):
def testParseSmtpTlsReportJsonValid(self):
"""Valid SMTP TLS JSON report parses correctly"""
import json
report = json.dumps({
"organization-name": "Example Corp",
"date-range": {
@@ -841,7 +827,6 @@ class Test(unittest.TestCase):
def testParseSmtpTlsReportJsonBytes(self):
"""SMTP TLS report as bytes parses correctly"""
import json
report = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
@@ -857,14 +842,12 @@ class Test(unittest.TestCase):
def testParseSmtpTlsReportJsonMissingField(self):
"""Missing required field raises InvalidSMTPTLSReport"""
import json
report = json.dumps({"organization-name": "Org"})
with self.assertRaises(parsedmarc.InvalidSMTPTLSReport):
parsedmarc.parse_smtp_tls_report_json(report)
def testParseSmtpTlsReportJsonPoliciesNotList(self):
"""Non-list policies raises InvalidSMTPTLSReport"""
import json
report = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
@@ -1186,7 +1169,6 @@ class Test(unittest.TestCase):
def testHumanTimestampToDatetime(self):
"""human_timestamp_to_datetime parses timestamp string"""
from datetime import datetime
dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00")
self.assertIsInstance(dt, datetime)
self.assertEqual(dt.year, 2024)
@@ -1195,7 +1177,6 @@ class Test(unittest.TestCase):
def testHumanTimestampToDatetimeUtc(self):
"""human_timestamp_to_datetime with to_utc=True returns UTC"""
from datetime import timezone
dt = parsedmarc.utils.human_timestamp_to_datetime(
"2024-01-01 12:00:00", to_utc=True
)
@@ -1263,9 +1244,6 @@ class Test(unittest.TestCase):
"""get_ip_address_info uses cache on second call"""
from expiringdict import ExpiringDict
cache = ExpiringDict(max_len=100, max_age_seconds=60)
# offline=True means reverse_dns is None, so cache is not populated
# Use offline=False with mock to test cache
from unittest.mock import patch
with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"):
info1 = parsedmarc.utils.get_ip_address_info(
"8.8.8.8", offline=False, cache=cache,
@@ -1353,7 +1331,6 @@ class Test(unittest.TestCase):
def testWebhookClientSaveMethods(self):
"""WebhookClient save methods call _send_to_webhook"""
from unittest.mock import MagicMock
from parsedmarc.webhook import WebhookClient
client = WebhookClient("http://a", "http://f", "http://t")
client.session = MagicMock()
@@ -1498,7 +1475,6 @@ class Test(unittest.TestCase):
def testSmtpTlsCsvRows(self):
"""parsed_smtp_tls_reports_to_csv_rows produces correct rows"""
import json
report_json = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"},
@@ -1575,6 +1551,103 @@ class Test(unittest.TestCase):
for r in report["records"]:
self.assertTrue(r["normalized_timespan"])
# ===================================================================
# Additional backward compatibility alias tests
# ===================================================================
def testGelfBackwardCompatAlias(self):
"""GelfClient forensic alias points to failure method"""
from parsedmarc.gelf import GelfClient
self.assertIs(
GelfClient.save_forensic_report_to_gelf,
GelfClient.save_failure_report_to_gelf,
)
def testS3BackwardCompatAlias(self):
"""S3Client forensic alias points to failure method"""
from parsedmarc.s3 import S3Client
self.assertIs(
S3Client.save_forensic_report_to_s3,
S3Client.save_failure_report_to_s3,
)
def testKafkaBackwardCompatAlias(self):
"""KafkaClient forensic alias points to failure method"""
from parsedmarc.kafkaclient import KafkaClient
self.assertIs(
KafkaClient.save_forensic_reports_to_kafka,
KafkaClient.save_failure_reports_to_kafka,
)
# ===================================================================
# Additional extract/parse tests
# ===================================================================
def testExtractReportFromFilePathNotFound(self):
"""extract_report_from_file_path raises ParserError for missing file"""
with self.assertRaises(parsedmarc.ParserError):
parsedmarc.extract_report_from_file_path("nonexistent_file.xml")
def testExtractReportInvalidArchive(self):
"""extract_report raises ParserError for unrecognized binary content"""
with self.assertRaises(parsedmarc.ParserError):
parsedmarc.extract_report(b"\x00\x01\x02\x03\x04\x05\x06\x07")
def testParseAggregateReportFile(self):
"""parse_aggregate_report_file parses bytes input directly"""
print()
sample_path = "samples/aggregate/dmarcbis-draft-sample.xml"
print("Testing {0}: ".format(sample_path), end="")
with open(sample_path, "rb") as f:
data = f.read()
report = parsedmarc.parse_aggregate_report_file(
data, offline=True, always_use_local_files=True,
)
self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter")
self.assertEqual(report["policy_published"]["domain"], "example.com")
print("Passed!")
def testParseInvalidAggregateSample(self):
"""Test invalid aggregate samples are handled"""
print()
sample_paths = glob("samples/aggregate_invalid/*")
for sample_path in sample_paths:
if os.path.isdir(sample_path):
continue
print("Testing {0}: ".format(sample_path), end="")
with self.subTest(sample=sample_path):
parsed_report = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
print("Passed!")
def testParseReportFileWithBytes(self):
"""parse_report_file handles bytes input"""
with open("samples/aggregate/dmarcbis-draft-sample.xml", "rb") as f:
data = f.read()
result = parsedmarc.parse_report_file(
data, always_use_local_files=True, offline=True
)
self.assertEqual(result["report_type"], "aggregate")
def testFailureReportCsvRoundtrip(self):
"""Failure report CSV generation works on sample reports"""
print()
sample_paths = glob("samples/forensic/*.eml")
for sample_path in sample_paths:
print("Testing CSV for {0}: ".format(sample_path), end="")
with self.subTest(sample=sample_path):
parsed_report = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report)
self.assertIsNotNone(csv_output)
self.assertIn(",", csv_output)
rows = parsedmarc.parsed_failure_reports_to_csv_rows(parsed_report)
self.assertTrue(len(rows) > 0)
print("Passed!")
if __name__ == "__main__":
unittest.main(verbosity=2)