Improve tests: consolidate imports, use context managers, add subTest, add backward compat and coverage tests

Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-02-21 19:56:33 +00:00
committed by Sean Whalen
parent 99a5962d93
commit 25a071da7e

211
tests.py
View File

@@ -12,6 +12,7 @@ import tempfile
import unittest
from base64 import urlsafe_b64encode
from configparser import ConfigParser
from datetime import datetime, timedelta, timezone
from glob import glob
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
@@ -63,26 +64,25 @@ class Test(unittest.TestCase):
# Example from Wikipedia Base64 article
b64_str = "YW55IGNhcm5hbCBwbGVhcw"
decoded_str = parsedmarc.utils.decode_base64(b64_str)
assert decoded_str == b"any carnal pleas"
self.assertEqual(decoded_str, b"any carnal pleas")
def testPSLDownload(self):
"""Test Public Suffix List domain lookups"""
subdomain = "foo.example.com"
result = parsedmarc.utils.get_base_domain(subdomain)
assert result == "example.com"
self.assertEqual(result, "example.com")
# Test newer PSL entries
subdomain = "e3191.c.akamaiedge.net"
result = parsedmarc.utils.get_base_domain(subdomain)
assert result == "c.akamaiedge.net"
self.assertEqual(result, "c.akamaiedge.net")
def testExtractReportXMLComparator(self):
"""Test XML comparator function"""
xmlnice_file = open("samples/extract_report/nice-input.xml")
xmlnice = xmlnice_file.read()
xmlnice_file.close()
xmlchanged_file = open("samples/extract_report/changed-input.xml")
xmlchanged = minify_xml(xmlchanged_file.read())
xmlchanged_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlnice = f.read()
with open("samples/extract_report/changed-input.xml") as f:
xmlchanged = minify_xml(f.read())
self.assertTrue(compare_xml(xmlnice, xmlnice))
self.assertTrue(compare_xml(xmlchanged, xmlchanged))
self.assertFalse(compare_xml(xmlnice, xmlchanged))
@@ -97,21 +97,19 @@ class Test(unittest.TestCase):
data = f.read()
print("Testing {0}: ".format(file), end="")
xmlout = parsedmarc.extract_report(data)
xmlin_file = open("samples/extract_report/nice-input.xml")
xmlin = xmlin_file.read()
xmlin_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlin = f.read()
self.assertTrue(compare_xml(xmlout, xmlin))
print("Passed!")
def testExtractReportXML(self):
"""Test extract report function for XML input"""
print()
report_path = "samples/extract_report/nice-input.xml"
print("Testing {0}: ".format(report_path), end="")
xmlout = parsedmarc.extract_report_from_file_path(report_path)
xmlin_file = open("samples/extract_report/nice-input.xml")
xmlin = xmlin_file.read()
xmlin_file.close()
file = "samples/extract_report/nice-input.xml"
print("Testing {0}: ".format(file), end="")
xmlout = parsedmarc.extract_report_from_file_path(file)
with open("samples/extract_report/nice-input.xml") as f:
xmlin = f.read()
self.assertTrue(compare_xml(xmlout, xmlin))
print("Passed!")
@@ -129,9 +127,8 @@ class Test(unittest.TestCase):
file = "samples/extract_report/nice-input.xml.gz"
print("Testing {0}: ".format(file), end="")
xmlout = parsedmarc.extract_report_from_file_path(file)
xmlin_file = open("samples/extract_report/nice-input.xml")
xmlin = xmlin_file.read()
xmlin_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlin = f.read()
self.assertTrue(compare_xml(xmlout, xmlin))
print("Passed!")
@@ -141,13 +138,11 @@ class Test(unittest.TestCase):
file = "samples/extract_report/nice-input.xml.zip"
print("Testing {0}: ".format(file), end="")
xmlout = parsedmarc.extract_report_from_file_path(file)
xmlin_file = open("samples/extract_report/nice-input.xml")
xmlin = minify_xml(xmlin_file.read())
xmlin_file.close()
with open("samples/extract_report/nice-input.xml") as f:
xmlin = minify_xml(f.read())
self.assertTrue(compare_xml(xmlout, xmlin))
xmlin_file = open("samples/extract_report/changed-input.xml")
xmlin = xmlin_file.read()
xmlin_file.close()
with open("samples/extract_report/changed-input.xml") as f:
xmlin = f.read()
self.assertFalse(compare_xml(xmlout, xmlin))
print("Passed!")
@@ -181,11 +176,12 @@ class Test(unittest.TestCase):
if os.path.isdir(sample_path):
continue
print("Testing {0}: ".format(sample_path), end="")
result = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)
assert result["report_type"] == "aggregate"
parsedmarc.parsed_aggregate_reports_to_csv(result["report"])
with self.subTest(sample=sample_path):
result = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)
assert result["report_type"] == "aggregate"
parsedmarc.parsed_aggregate_reports_to_csv(result["report"])
print("Passed!")
def testEmptySample(self):
@@ -199,17 +195,18 @@ class Test(unittest.TestCase):
sample_paths = glob("samples/failure/*.eml")
for sample_path in sample_paths:
print("Testing {0}: ".format(sample_path), end="")
with open(sample_path) as sample_file:
sample_content = sample_file.read()
email_result = parsedmarc.parse_report_email(
sample_content, offline=OFFLINE_MODE
with self.subTest(sample=sample_path):
with open(sample_path) as sample_file:
sample_content = sample_file.read()
email_result = parsedmarc.parse_report_email(
sample_content, offline=OFFLINE_MODE
)
assert email_result["report_type"] == "failure"
result = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)
assert email_result["report_type"] == "failure"
result = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)
assert result["report_type"] == "failure"
parsedmarc.parsed_failure_reports_to_csv(result["report"])
assert result["report_type"] == "failure"
parsedmarc.parsed_failure_reports_to_csv(result["report"])
print("Passed!")
def testFailureReportBackwardCompat(self):
@@ -369,9 +366,12 @@ class Test(unittest.TestCase):
if os.path.isdir(sample_path):
continue
print("Testing {0}: ".format(sample_path), end="")
result = parsedmarc.parse_report_file(sample_path, offline=OFFLINE_MODE)
assert result["report_type"] == "smtp_tls"
parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"])
with self.subTest(sample=sample_path):
result = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)
assert result["report_type"] == "smtp_tls"
parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"])
print("Passed!")
def testOpenSearchSigV4RequiresRegion(self):
@@ -3185,7 +3185,6 @@ class TestEnvVarConfig(unittest.TestCase):
# ============================================================
def testBucketIntervalBeginAfterEnd(self):
"""begin > end should raise ValueError"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 2, tzinfo=timezone.utc)
end = datetime(2024, 1, 1, tzinfo=timezone.utc)
with self.assertRaises(ValueError):
@@ -3193,7 +3192,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testBucketIntervalNaiveDatetime(self):
"""Non-timezone-aware datetimes should raise ValueError"""
from datetime import datetime
begin = datetime(2024, 1, 1)
end = datetime(2024, 1, 2)
with self.assertRaises(ValueError):
@@ -3201,7 +3199,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testBucketIntervalDifferentTzinfo(self):
"""Different tzinfo objects should raise ValueError"""
from datetime import datetime, timezone, timedelta
tz1 = timezone.utc
tz2 = timezone(timedelta(hours=5))
begin = datetime(2024, 1, 1, tzinfo=tz1)
@@ -3211,7 +3208,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testBucketIntervalNegativeCount(self):
"""Negative total_count should raise ValueError"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, tzinfo=timezone.utc)
end = datetime(2024, 1, 2, tzinfo=timezone.utc)
with self.assertRaises(ValueError):
@@ -3219,7 +3215,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testBucketIntervalZeroCount(self):
"""Zero total_count should return empty list"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, tzinfo=timezone.utc)
end = datetime(2024, 1, 2, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 0)
@@ -3227,14 +3222,12 @@ class TestEnvVarConfig(unittest.TestCase):
def testBucketIntervalSameBeginEnd(self):
"""Same begin and end (zero interval) should return empty list"""
from datetime import datetime, timezone
dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(dt, dt, 100)
self.assertEqual(result, [])
def testBucketIntervalSingleDay(self):
"""Single day interval should return one bucket with total count"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
end = datetime(2024, 1, 1, 23, 59, 59, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 100)
@@ -3244,7 +3237,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testBucketIntervalMultiDay(self):
"""Multi-day interval should distribute counts proportionally"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 100)
@@ -3257,7 +3249,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testBucketIntervalRemainderDistribution(self):
"""Odd count across equal days distributes remainder correctly"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
end = datetime(2024, 1, 4, 0, 0, 0, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 10)
@@ -3267,7 +3258,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testBucketIntervalPartialDays(self):
"""Partial days: 12h on day1, 24h on day2 => 1/3 vs 2/3 split"""
from datetime import datetime, timezone
begin = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc)
result = parsedmarc._bucket_interval_by_day(begin, end, 90)
@@ -3281,7 +3271,6 @@ class TestEnvVarConfig(unittest.TestCase):
# ============================================================
def testAppendParsedRecordNoNormalize(self):
"""No normalization: record appended as-is with interval fields"""
from datetime import datetime, timezone
records = []
rec = {"count": 10, "source": {"ip_address": "1.2.3.4"}}
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
@@ -3294,7 +3283,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testAppendParsedRecordNormalize(self):
"""Normalization: record split into daily buckets"""
from datetime import datetime, timezone
records = []
rec = {"count": 100, "source": {"ip_address": "1.2.3.4"}}
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
@@ -3308,7 +3296,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testAppendParsedRecordNormalizeZeroCount(self):
"""Normalization with zero count: nothing appended"""
from datetime import datetime, timezone
records = []
rec = {"count": 0, "source": {"ip_address": "1.2.3.4"}}
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
@@ -3667,7 +3654,6 @@ class TestEnvVarConfig(unittest.TestCase):
# ============================================================
def testParseSmtpTlsReportJsonValid(self):
"""Valid SMTP TLS JSON report parses correctly"""
import json
report = json.dumps({
"organization-name": "Example Corp",
"date-range": {
@@ -3696,7 +3682,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testParseSmtpTlsReportJsonBytes(self):
"""SMTP TLS report as bytes parses correctly"""
import json
report = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
@@ -3712,14 +3697,12 @@ class TestEnvVarConfig(unittest.TestCase):
def testParseSmtpTlsReportJsonMissingField(self):
"""Missing required field raises InvalidSMTPTLSReport"""
import json
report = json.dumps({"organization-name": "Org"})
with self.assertRaises(parsedmarc.InvalidSMTPTLSReport):
parsedmarc.parse_smtp_tls_report_json(report)
def testParseSmtpTlsReportJsonPoliciesNotList(self):
"""Non-list policies raises InvalidSMTPTLSReport"""
import json
report = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
@@ -4037,7 +4020,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testHumanTimestampToDatetime(self):
"""human_timestamp_to_datetime parses timestamp string"""
from datetime import datetime
dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00")
self.assertIsInstance(dt, datetime)
self.assertEqual(dt.year, 2024)
@@ -4046,7 +4028,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testHumanTimestampToDatetimeUtc(self):
"""human_timestamp_to_datetime with to_utc=True returns UTC"""
from datetime import timezone
dt = parsedmarc.utils.human_timestamp_to_datetime(
"2024-01-01 12:00:00", to_utc=True
)
@@ -4114,9 +4095,6 @@ class TestEnvVarConfig(unittest.TestCase):
"""get_ip_address_info uses cache on second call"""
from expiringdict import ExpiringDict
cache = ExpiringDict(max_len=100, max_age_seconds=60)
# offline=True means reverse_dns is None, so cache is not populated
# Use offline=False with mock to test cache
from unittest.mock import patch
with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"):
info1 = parsedmarc.utils.get_ip_address_info(
"8.8.8.8", offline=False, cache=cache,
@@ -4202,7 +4180,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testWebhookClientSaveMethods(self):
"""WebhookClient save methods call _send_to_webhook"""
from unittest.mock import MagicMock
from parsedmarc.webhook import WebhookClient
client = WebhookClient("http://a", "http://f", "http://t")
client.session = MagicMock()
@@ -4347,7 +4324,6 @@ class TestEnvVarConfig(unittest.TestCase):
def testSmtpTlsCsvRows(self):
"""parsed_smtp_tls_reports_to_csv_rows produces correct rows"""
import json
report_json = json.dumps({
"organization-name": "Org",
"date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"},
@@ -4424,6 +4400,103 @@ class TestEnvVarConfig(unittest.TestCase):
for r in report["records"]:
self.assertTrue(r["normalized_timespan"])
# ===================================================================
# Additional backward compatibility alias tests
# ===================================================================
def testGelfBackwardCompatAlias(self):
"""GelfClient forensic alias points to failure method"""
from parsedmarc.gelf import GelfClient
self.assertIs(
GelfClient.save_forensic_report_to_gelf,
GelfClient.save_failure_report_to_gelf,
)
def testS3BackwardCompatAlias(self):
"""S3Client forensic alias points to failure method"""
from parsedmarc.s3 import S3Client
self.assertIs(
S3Client.save_forensic_report_to_s3,
S3Client.save_failure_report_to_s3,
)
def testKafkaBackwardCompatAlias(self):
"""KafkaClient forensic alias points to failure method"""
from parsedmarc.kafkaclient import KafkaClient
self.assertIs(
KafkaClient.save_forensic_reports_to_kafka,
KafkaClient.save_failure_reports_to_kafka,
)
# ===================================================================
# Additional extract/parse tests
# ===================================================================
def testExtractReportFromFilePathNotFound(self):
"""extract_report_from_file_path raises ParserError for missing file"""
with self.assertRaises(parsedmarc.ParserError):
parsedmarc.extract_report_from_file_path("nonexistent_file.xml")
def testExtractReportInvalidArchive(self):
"""extract_report raises ParserError for unrecognized binary content"""
with self.assertRaises(parsedmarc.ParserError):
parsedmarc.extract_report(b"\x00\x01\x02\x03\x04\x05\x06\x07")
def testParseAggregateReportFile(self):
"""parse_aggregate_report_file parses bytes input directly"""
print()
sample_path = "samples/aggregate/dmarcbis-draft-sample.xml"
print("Testing {0}: ".format(sample_path), end="")
with open(sample_path, "rb") as f:
data = f.read()
report = parsedmarc.parse_aggregate_report_file(
data, offline=True, always_use_local_files=True,
)
self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter")
self.assertEqual(report["policy_published"]["domain"], "example.com")
print("Passed!")
def testParseInvalidAggregateSample(self):
"""Test invalid aggregate samples are handled"""
print()
sample_paths = glob("samples/aggregate_invalid/*")
for sample_path in sample_paths:
if os.path.isdir(sample_path):
continue
print("Testing {0}: ".format(sample_path), end="")
with self.subTest(sample=sample_path):
parsed_report = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
print("Passed!")
def testParseReportFileWithBytes(self):
"""parse_report_file handles bytes input"""
with open("samples/aggregate/dmarcbis-draft-sample.xml", "rb") as f:
data = f.read()
result = parsedmarc.parse_report_file(
data, always_use_local_files=True, offline=True
)
self.assertEqual(result["report_type"], "aggregate")
def testFailureReportCsvRoundtrip(self):
"""Failure report CSV generation works on sample reports"""
print()
sample_paths = glob("samples/forensic/*.eml")
for sample_path in sample_paths:
print("Testing CSV for {0}: ".format(sample_path), end="")
with self.subTest(sample=sample_path):
parsed_report = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report)
self.assertIsNotNone(csv_output)
self.assertIn(",", csv_output)
rows = parsedmarc.parsed_failure_reports_to_csv_rows(parsed_report)
self.assertTrue(len(rows) > 0)
print("Passed!")
if __name__ == "__main__":
unittest.main(verbosity=2)