mirror of
https://github.com/domainaware/parsedmarc.git
synced 2026-03-22 14:32:46 +00:00
Improve tests: consolidate imports, use context managers, add subTest, add backward compat and coverage tests
Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com>
This commit is contained in:
201
tests.py
201
tests.py
@@ -3,9 +3,12 @@
|
||||
|
||||
from __future__ import absolute_import, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from glob import glob
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from lxml import etree
|
||||
|
||||
@@ -35,26 +38,25 @@ class Test(unittest.TestCase):
|
||||
# Example from Wikipedia Base64 article
|
||||
b64_str = "YW55IGNhcm5hbCBwbGVhcw"
|
||||
decoded_str = parsedmarc.utils.decode_base64(b64_str)
|
||||
assert decoded_str == b"any carnal pleas"
|
||||
self.assertEqual(decoded_str, b"any carnal pleas")
|
||||
|
||||
def testPSLDownload(self):
|
||||
"""Test Public Suffix List domain lookups"""
|
||||
subdomain = "foo.example.com"
|
||||
result = parsedmarc.utils.get_base_domain(subdomain)
|
||||
assert result == "example.com"
|
||||
self.assertEqual(result, "example.com")
|
||||
|
||||
# Test newer PSL entries
|
||||
subdomain = "e3191.c.akamaiedge.net"
|
||||
result = parsedmarc.utils.get_base_domain(subdomain)
|
||||
assert result == "c.akamaiedge.net"
|
||||
self.assertEqual(result, "c.akamaiedge.net")
|
||||
|
||||
def testExtractReportXMLComparator(self):
|
||||
"""Test XML comparator function"""
|
||||
xmlnice_file = open("samples/extract_report/nice-input.xml")
|
||||
xmlnice = xmlnice_file.read()
|
||||
xmlnice_file.close()
|
||||
xmlchanged_file = open("samples/extract_report/changed-input.xml")
|
||||
xmlchanged = minify_xml(xmlchanged_file.read())
|
||||
xmlchanged_file.close()
|
||||
with open("samples/extract_report/nice-input.xml") as f:
|
||||
xmlnice = f.read()
|
||||
with open("samples/extract_report/changed-input.xml") as f:
|
||||
xmlchanged = minify_xml(f.read())
|
||||
self.assertTrue(compare_xml(xmlnice, xmlnice))
|
||||
self.assertTrue(compare_xml(xmlchanged, xmlchanged))
|
||||
self.assertFalse(compare_xml(xmlnice, xmlchanged))
|
||||
@@ -69,9 +71,8 @@ class Test(unittest.TestCase):
|
||||
data = f.read()
|
||||
print("Testing {0}: ".format(file), end="")
|
||||
xmlout = parsedmarc.extract_report(data)
|
||||
xmlin_file = open("samples/extract_report/nice-input.xml")
|
||||
xmlin = xmlin_file.read()
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/nice-input.xml") as f:
|
||||
xmlin = f.read()
|
||||
self.assertTrue(compare_xml(xmlout, xmlin))
|
||||
print("Passed!")
|
||||
|
||||
@@ -81,9 +82,8 @@ class Test(unittest.TestCase):
|
||||
file = "samples/extract_report/nice-input.xml"
|
||||
print("Testing {0}: ".format(file), end="")
|
||||
xmlout = parsedmarc.extract_report_from_file_path(file)
|
||||
xmlin_file = open("samples/extract_report/nice-input.xml")
|
||||
xmlin = xmlin_file.read()
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/nice-input.xml") as f:
|
||||
xmlin = f.read()
|
||||
self.assertTrue(compare_xml(xmlout, xmlin))
|
||||
print("Passed!")
|
||||
|
||||
@@ -93,9 +93,8 @@ class Test(unittest.TestCase):
|
||||
file = "samples/extract_report/nice-input.xml.gz"
|
||||
print("Testing {0}: ".format(file), end="")
|
||||
xmlout = parsedmarc.extract_report_from_file_path(file)
|
||||
xmlin_file = open("samples/extract_report/nice-input.xml")
|
||||
xmlin = xmlin_file.read()
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/nice-input.xml") as f:
|
||||
xmlin = f.read()
|
||||
self.assertTrue(compare_xml(xmlout, xmlin))
|
||||
print("Passed!")
|
||||
|
||||
@@ -105,13 +104,11 @@ class Test(unittest.TestCase):
|
||||
file = "samples/extract_report/nice-input.xml.zip"
|
||||
print("Testing {0}: ".format(file), end="")
|
||||
xmlout = parsedmarc.extract_report_from_file_path(file)
|
||||
xmlin_file = open("samples/extract_report/nice-input.xml")
|
||||
xmlin = minify_xml(xmlin_file.read())
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/nice-input.xml") as f:
|
||||
xmlin = minify_xml(f.read())
|
||||
self.assertTrue(compare_xml(xmlout, xmlin))
|
||||
xmlin_file = open("samples/extract_report/changed-input.xml")
|
||||
xmlin = xmlin_file.read()
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/changed-input.xml") as f:
|
||||
xmlin = f.read()
|
||||
self.assertFalse(compare_xml(xmlout, xmlin))
|
||||
print("Passed!")
|
||||
|
||||
@@ -123,10 +120,11 @@ class Test(unittest.TestCase):
|
||||
if os.path.isdir(sample_path):
|
||||
continue
|
||||
print("Testing {0}: ".format(sample_path), end="")
|
||||
parsed_report = parsedmarc.parse_report_file(
|
||||
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
|
||||
with self.subTest(sample=sample_path):
|
||||
parsed_report = parsedmarc.parse_report_file(
|
||||
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
|
||||
print("Passed!")
|
||||
|
||||
def testEmptySample(self):
|
||||
@@ -140,15 +138,16 @@ class Test(unittest.TestCase):
|
||||
sample_paths = glob("samples/forensic/*.eml")
|
||||
for sample_path in sample_paths:
|
||||
print("Testing {0}: ".format(sample_path), end="")
|
||||
with open(sample_path) as sample_file:
|
||||
sample_content = sample_file.read()
|
||||
parsed_report = parsedmarc.parse_report_email(
|
||||
sample_content, offline=OFFLINE_MODE
|
||||
with self.subTest(sample=sample_path):
|
||||
with open(sample_path) as sample_file:
|
||||
sample_content = sample_file.read()
|
||||
parsed_report = parsedmarc.parse_report_email(
|
||||
sample_content, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
parsed_report = parsedmarc.parse_report_file(
|
||||
sample_path, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
parsed_report = parsedmarc.parse_report_file(
|
||||
sample_path, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
parsedmarc.parsed_failure_reports_to_csv(parsed_report)
|
||||
parsedmarc.parsed_failure_reports_to_csv(parsed_report)
|
||||
print("Passed!")
|
||||
|
||||
def testFailureReportBackwardCompat(self):
|
||||
@@ -308,10 +307,11 @@ class Test(unittest.TestCase):
|
||||
if os.path.isdir(sample_path):
|
||||
continue
|
||||
print("Testing {0}: ".format(sample_path), end="")
|
||||
parsed_report = parsedmarc.parse_report_file(
|
||||
sample_path, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report)
|
||||
with self.subTest(sample=sample_path):
|
||||
parsed_report = parsedmarc.parse_report_file(
|
||||
sample_path, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report)
|
||||
print("Passed!")
|
||||
|
||||
# ===================================================================
|
||||
@@ -320,7 +320,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalBeginAfterEnd(self):
|
||||
"""begin > end should raise ValueError"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 2, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -328,7 +327,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalNaiveDatetime(self):
|
||||
"""Non-timezone-aware datetimes should raise ValueError"""
|
||||
from datetime import datetime
|
||||
begin = datetime(2024, 1, 1)
|
||||
end = datetime(2024, 1, 2)
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -336,7 +334,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalDifferentTzinfo(self):
|
||||
"""Different tzinfo objects should raise ValueError"""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
tz1 = timezone.utc
|
||||
tz2 = timezone(timedelta(hours=5))
|
||||
begin = datetime(2024, 1, 1, tzinfo=tz1)
|
||||
@@ -346,7 +343,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalNegativeCount(self):
|
||||
"""Negative total_count should raise ValueError"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 2, tzinfo=timezone.utc)
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -354,7 +350,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalZeroCount(self):
|
||||
"""Zero total_count should return empty list"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 2, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 0)
|
||||
@@ -362,14 +357,12 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalSameBeginEnd(self):
|
||||
"""Same begin and end (zero interval) should return empty list"""
|
||||
from datetime import datetime, timezone
|
||||
dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(dt, dt, 100)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def testBucketIntervalSingleDay(self):
|
||||
"""Single day interval should return one bucket with total count"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 1, 23, 59, 59, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 100)
|
||||
@@ -379,7 +372,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalMultiDay(self):
|
||||
"""Multi-day interval should distribute counts proportionally"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 100)
|
||||
@@ -392,7 +384,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalRemainderDistribution(self):
|
||||
"""Odd count across equal days distributes remainder correctly"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 4, 0, 0, 0, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 10)
|
||||
@@ -402,7 +393,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testBucketIntervalPartialDays(self):
|
||||
"""Partial days: 12h on day1, 24h on day2 => 1/3 vs 2/3 split"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 90)
|
||||
@@ -418,7 +408,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testAppendParsedRecordNoNormalize(self):
|
||||
"""No normalization: record appended as-is with interval fields"""
|
||||
from datetime import datetime, timezone
|
||||
records = []
|
||||
rec = {"count": 10, "source": {"ip_address": "1.2.3.4"}}
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
@@ -431,7 +420,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testAppendParsedRecordNormalize(self):
|
||||
"""Normalization: record split into daily buckets"""
|
||||
from datetime import datetime, timezone
|
||||
records = []
|
||||
rec = {"count": 100, "source": {"ip_address": "1.2.3.4"}}
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
@@ -445,7 +433,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testAppendParsedRecordNormalizeZeroCount(self):
|
||||
"""Normalization with zero count: nothing appended"""
|
||||
from datetime import datetime, timezone
|
||||
records = []
|
||||
rec = {"count": 0, "source": {"ip_address": "1.2.3.4"}}
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
@@ -812,7 +799,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testParseSmtpTlsReportJsonValid(self):
|
||||
"""Valid SMTP TLS JSON report parses correctly"""
|
||||
import json
|
||||
report = json.dumps({
|
||||
"organization-name": "Example Corp",
|
||||
"date-range": {
|
||||
@@ -841,7 +827,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testParseSmtpTlsReportJsonBytes(self):
|
||||
"""SMTP TLS report as bytes parses correctly"""
|
||||
import json
|
||||
report = json.dumps({
|
||||
"organization-name": "Org",
|
||||
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
|
||||
@@ -857,14 +842,12 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testParseSmtpTlsReportJsonMissingField(self):
|
||||
"""Missing required field raises InvalidSMTPTLSReport"""
|
||||
import json
|
||||
report = json.dumps({"organization-name": "Org"})
|
||||
with self.assertRaises(parsedmarc.InvalidSMTPTLSReport):
|
||||
parsedmarc.parse_smtp_tls_report_json(report)
|
||||
|
||||
def testParseSmtpTlsReportJsonPoliciesNotList(self):
|
||||
"""Non-list policies raises InvalidSMTPTLSReport"""
|
||||
import json
|
||||
report = json.dumps({
|
||||
"organization-name": "Org",
|
||||
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
|
||||
@@ -1186,7 +1169,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testHumanTimestampToDatetime(self):
|
||||
"""human_timestamp_to_datetime parses timestamp string"""
|
||||
from datetime import datetime
|
||||
dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00")
|
||||
self.assertIsInstance(dt, datetime)
|
||||
self.assertEqual(dt.year, 2024)
|
||||
@@ -1195,7 +1177,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testHumanTimestampToDatetimeUtc(self):
|
||||
"""human_timestamp_to_datetime with to_utc=True returns UTC"""
|
||||
from datetime import timezone
|
||||
dt = parsedmarc.utils.human_timestamp_to_datetime(
|
||||
"2024-01-01 12:00:00", to_utc=True
|
||||
)
|
||||
@@ -1263,9 +1244,6 @@ class Test(unittest.TestCase):
|
||||
"""get_ip_address_info uses cache on second call"""
|
||||
from expiringdict import ExpiringDict
|
||||
cache = ExpiringDict(max_len=100, max_age_seconds=60)
|
||||
# offline=True means reverse_dns is None, so cache is not populated
|
||||
# Use offline=False with mock to test cache
|
||||
from unittest.mock import patch
|
||||
with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"):
|
||||
info1 = parsedmarc.utils.get_ip_address_info(
|
||||
"8.8.8.8", offline=False, cache=cache,
|
||||
@@ -1353,7 +1331,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testWebhookClientSaveMethods(self):
|
||||
"""WebhookClient save methods call _send_to_webhook"""
|
||||
from unittest.mock import MagicMock
|
||||
from parsedmarc.webhook import WebhookClient
|
||||
client = WebhookClient("http://a", "http://f", "http://t")
|
||||
client.session = MagicMock()
|
||||
@@ -1498,7 +1475,6 @@ class Test(unittest.TestCase):
|
||||
|
||||
def testSmtpTlsCsvRows(self):
|
||||
"""parsed_smtp_tls_reports_to_csv_rows produces correct rows"""
|
||||
import json
|
||||
report_json = json.dumps({
|
||||
"organization-name": "Org",
|
||||
"date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"},
|
||||
@@ -1575,6 +1551,103 @@ class Test(unittest.TestCase):
|
||||
for r in report["records"]:
|
||||
self.assertTrue(r["normalized_timespan"])
|
||||
|
||||
# ===================================================================
|
||||
# Additional backward compatibility alias tests
|
||||
# ===================================================================
|
||||
|
||||
def testGelfBackwardCompatAlias(self):
|
||||
"""GelfClient forensic alias points to failure method"""
|
||||
from parsedmarc.gelf import GelfClient
|
||||
self.assertIs(
|
||||
GelfClient.save_forensic_report_to_gelf,
|
||||
GelfClient.save_failure_report_to_gelf,
|
||||
)
|
||||
|
||||
def testS3BackwardCompatAlias(self):
|
||||
"""S3Client forensic alias points to failure method"""
|
||||
from parsedmarc.s3 import S3Client
|
||||
self.assertIs(
|
||||
S3Client.save_forensic_report_to_s3,
|
||||
S3Client.save_failure_report_to_s3,
|
||||
)
|
||||
|
||||
def testKafkaBackwardCompatAlias(self):
|
||||
"""KafkaClient forensic alias points to failure method"""
|
||||
from parsedmarc.kafkaclient import KafkaClient
|
||||
self.assertIs(
|
||||
KafkaClient.save_forensic_reports_to_kafka,
|
||||
KafkaClient.save_failure_reports_to_kafka,
|
||||
)
|
||||
|
||||
# ===================================================================
|
||||
# Additional extract/parse tests
|
||||
# ===================================================================
|
||||
|
||||
def testExtractReportFromFilePathNotFound(self):
|
||||
"""extract_report_from_file_path raises ParserError for missing file"""
|
||||
with self.assertRaises(parsedmarc.ParserError):
|
||||
parsedmarc.extract_report_from_file_path("nonexistent_file.xml")
|
||||
|
||||
def testExtractReportInvalidArchive(self):
|
||||
"""extract_report raises ParserError for unrecognized binary content"""
|
||||
with self.assertRaises(parsedmarc.ParserError):
|
||||
parsedmarc.extract_report(b"\x00\x01\x02\x03\x04\x05\x06\x07")
|
||||
|
||||
def testParseAggregateReportFile(self):
|
||||
"""parse_aggregate_report_file parses bytes input directly"""
|
||||
print()
|
||||
sample_path = "samples/aggregate/dmarcbis-draft-sample.xml"
|
||||
print("Testing {0}: ".format(sample_path), end="")
|
||||
with open(sample_path, "rb") as f:
|
||||
data = f.read()
|
||||
report = parsedmarc.parse_aggregate_report_file(
|
||||
data, offline=True, always_use_local_files=True,
|
||||
)
|
||||
self.assertEqual(report["report_metadata"]["org_name"], "Sample Reporter")
|
||||
self.assertEqual(report["policy_published"]["domain"], "example.com")
|
||||
print("Passed!")
|
||||
|
||||
def testParseInvalidAggregateSample(self):
|
||||
"""Test invalid aggregate samples are handled"""
|
||||
print()
|
||||
sample_paths = glob("samples/aggregate_invalid/*")
|
||||
for sample_path in sample_paths:
|
||||
if os.path.isdir(sample_path):
|
||||
continue
|
||||
print("Testing {0}: ".format(sample_path), end="")
|
||||
with self.subTest(sample=sample_path):
|
||||
parsed_report = parsedmarc.parse_report_file(
|
||||
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
|
||||
print("Passed!")
|
||||
|
||||
def testParseReportFileWithBytes(self):
|
||||
"""parse_report_file handles bytes input"""
|
||||
with open("samples/aggregate/dmarcbis-draft-sample.xml", "rb") as f:
|
||||
data = f.read()
|
||||
result = parsedmarc.parse_report_file(
|
||||
data, always_use_local_files=True, offline=True
|
||||
)
|
||||
self.assertEqual(result["report_type"], "aggregate")
|
||||
|
||||
def testFailureReportCsvRoundtrip(self):
|
||||
"""Failure report CSV generation works on sample reports"""
|
||||
print()
|
||||
sample_paths = glob("samples/forensic/*.eml")
|
||||
for sample_path in sample_paths:
|
||||
print("Testing CSV for {0}: ".format(sample_path), end="")
|
||||
with self.subTest(sample=sample_path):
|
||||
parsed_report = parsedmarc.parse_report_file(
|
||||
sample_path, offline=OFFLINE_MODE
|
||||
)["report"]
|
||||
csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report)
|
||||
self.assertIsNotNone(csv_output)
|
||||
self.assertIn(",", csv_output)
|
||||
rows = parsedmarc.parsed_failure_reports_to_csv_rows(parsed_report)
|
||||
self.assertTrue(len(rows) > 0)
|
||||
print("Passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user