mirror of
https://github.com/domainaware/parsedmarc.git
synced 2026-03-28 01:12:54 +00:00
Improve tests: consolidate imports, use context managers, add subTest, add backward compat and coverage tests
Co-authored-by: seanthegeek <44679+seanthegeek@users.noreply.github.com>
This commit is contained in:
committed by
Sean Whalen
parent
8d23e23099
commit
fcbba3bf6b
212
tests.py
212
tests.py
@@ -3,12 +3,14 @@
|
||||
|
||||
from __future__ import absolute_import, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from base64 import urlsafe_b64encode
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
@@ -60,26 +62,25 @@ class Test(unittest.TestCase):
|
||||
# Example from Wikipedia Base64 article
|
||||
b64_str = "YW55IGNhcm5hbCBwbGVhcw"
|
||||
decoded_str = parsedmarc.utils.decode_base64(b64_str)
|
||||
assert decoded_str == b"any carnal pleas"
|
||||
self.assertEqual(decoded_str, b"any carnal pleas")
|
||||
|
||||
def testPSLDownload(self):
    """Test Public Suffix List domain lookups.

    Each subdomain is passed to ``parsedmarc.utils.get_base_domain`` and the
    result is compared with the expected base domain. Every case runs in its
    own subTest so one failing lookup does not hide the other.
    """
    cases = (
        ("foo.example.com", "example.com"),
        # Test newer PSL entries
        ("e3191.c.akamaiedge.net", "c.akamaiedge.net"),
    )
    for subdomain, expected in cases:
        with self.subTest(subdomain=subdomain):
            result = parsedmarc.utils.get_base_domain(subdomain)
            # assertEqual only: a bare ``assert`` is stripped under ``-O``
            # and gives no diff on failure.
            self.assertEqual(result, expected)
|
||||
|
||||
def testExtractReportXMLComparator(self):
    """Test XML comparator function.

    ``compare_xml`` must report each fixture as equal to itself, and the
    "nice" fixture as different from the minified "changed" fixture.
    """
    # Read each fixture exactly once, via a context manager, so the file
    # handle is closed even if the read raises.
    with open("samples/extract_report/nice-input.xml") as f:
        xmlnice = f.read()
    with open("samples/extract_report/changed-input.xml") as f:
        xmlchanged = minify_xml(f.read())
    self.assertTrue(compare_xml(xmlnice, xmlnice))
    self.assertTrue(compare_xml(xmlchanged, xmlchanged))
    self.assertFalse(compare_xml(xmlnice, xmlchanged))
|
||||
@@ -94,21 +95,19 @@ class Test(unittest.TestCase):
|
||||
data = f.read()
|
||||
print("Testing {0}: ".format(file), end="")
|
||||
xmlout = parsedmarc.extract_report(data)
|
||||
xmlin_file = open("samples/extract_report/nice-input.xml")
|
||||
xmlin = xmlin_file.read()
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/nice-input.xml") as f:
|
||||
xmlin = f.read()
|
||||
self.assertTrue(compare_xml(xmlout, xmlin))
|
||||
print("Passed!")
|
||||
|
||||
def testExtractReportXML(self):
    """Test extract report function for XML input.

    Extracts the plain-XML sample through
    ``parsedmarc.extract_report_from_file_path`` and checks, with
    ``compare_xml``, that the output matches the file's own text.
    """
    print()
    file = "samples/extract_report/nice-input.xml"
    print("Testing {0}: ".format(file), end="")
    xmlout = parsedmarc.extract_report_from_file_path(file)
    # Single read through a context manager; the handle is always closed.
    with open(file) as f:
        xmlin = f.read()
    self.assertTrue(compare_xml(xmlout, xmlin))
    print("Passed!")
|
||||
|
||||
@@ -126,9 +125,8 @@ class Test(unittest.TestCase):
|
||||
file = "samples/extract_report/nice-input.xml.gz"
|
||||
print("Testing {0}: ".format(file), end="")
|
||||
xmlout = parsedmarc.extract_report_from_file_path(file)
|
||||
xmlin_file = open("samples/extract_report/nice-input.xml")
|
||||
xmlin = xmlin_file.read()
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/nice-input.xml") as f:
|
||||
xmlin = f.read()
|
||||
self.assertTrue(compare_xml(xmlout, xmlin))
|
||||
print("Passed!")
|
||||
|
||||
@@ -138,13 +136,11 @@ class Test(unittest.TestCase):
|
||||
file = "samples/extract_report/nice-input.xml.zip"
|
||||
print("Testing {0}: ".format(file), end="")
|
||||
xmlout = parsedmarc.extract_report_from_file_path(file)
|
||||
xmlin_file = open("samples/extract_report/nice-input.xml")
|
||||
xmlin = minify_xml(xmlin_file.read())
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/nice-input.xml") as f:
|
||||
xmlin = minify_xml(f.read())
|
||||
self.assertTrue(compare_xml(xmlout, xmlin))
|
||||
xmlin_file = open("samples/extract_report/changed-input.xml")
|
||||
xmlin = xmlin_file.read()
|
||||
xmlin_file.close()
|
||||
with open("samples/extract_report/changed-input.xml") as f:
|
||||
xmlin = f.read()
|
||||
self.assertFalse(compare_xml(xmlout, xmlin))
|
||||
print("Passed!")
|
||||
|
||||
@@ -178,11 +174,12 @@ class Test(unittest.TestCase):
|
||||
if os.path.isdir(sample_path):
|
||||
continue
|
||||
print("Testing {0}: ".format(sample_path), end="")
|
||||
result = parsedmarc.parse_report_file(
|
||||
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
|
||||
)
|
||||
assert result["report_type"] == "aggregate"
|
||||
parsedmarc.parsed_aggregate_reports_to_csv(result["report"])
|
||||
with self.subTest(sample=sample_path):
|
||||
result = parsedmarc.parse_report_file(
|
||||
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
|
||||
)
|
||||
assert result["report_type"] == "aggregate"
|
||||
parsedmarc.parsed_aggregate_reports_to_csv(result["report"])
|
||||
print("Passed!")
|
||||
|
||||
def testEmptySample(self):
|
||||
@@ -196,17 +193,18 @@ class Test(unittest.TestCase):
|
||||
sample_paths = glob("samples/failure/*.eml")
|
||||
for sample_path in sample_paths:
|
||||
print("Testing {0}: ".format(sample_path), end="")
|
||||
with open(sample_path) as sample_file:
|
||||
sample_content = sample_file.read()
|
||||
email_result = parsedmarc.parse_report_email(
|
||||
sample_content, offline=OFFLINE_MODE
|
||||
with self.subTest(sample=sample_path):
|
||||
with open(sample_path) as sample_file:
|
||||
sample_content = sample_file.read()
|
||||
email_result = parsedmarc.parse_report_email(
|
||||
sample_content, offline=OFFLINE_MODE
|
||||
)
|
||||
assert email_result["report_type"] == "failure"
|
||||
result = parsedmarc.parse_report_file(
|
||||
sample_path, offline=OFFLINE_MODE
|
||||
)
|
||||
assert email_result["report_type"] == "failure"
|
||||
result = parsedmarc.parse_report_file(
|
||||
sample_path, offline=OFFLINE_MODE
|
||||
)
|
||||
assert result["report_type"] == "failure"
|
||||
parsedmarc.parsed_failure_reports_to_csv(result["report"])
|
||||
assert result["report_type"] == "failure"
|
||||
parsedmarc.parsed_failure_reports_to_csv(result["report"])
|
||||
print("Passed!")
|
||||
|
||||
def testFailureReportBackwardCompat(self):
|
||||
@@ -366,9 +364,12 @@ class Test(unittest.TestCase):
|
||||
if os.path.isdir(sample_path):
|
||||
continue
|
||||
print("Testing {0}: ".format(sample_path), end="")
|
||||
result = parsedmarc.parse_report_file(sample_path, offline=OFFLINE_MODE)
|
||||
assert result["report_type"] == "smtp_tls"
|
||||
parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"])
|
||||
with self.subTest(sample=sample_path):
|
||||
result = parsedmarc.parse_report_file(
|
||||
sample_path, offline=OFFLINE_MODE
|
||||
)
|
||||
assert result["report_type"] == "smtp_tls"
|
||||
parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"])
|
||||
print("Passed!")
|
||||
|
||||
def testOpenSearchSigV4RequiresRegion(self):
|
||||
@@ -2386,7 +2387,6 @@ watch = true
|
||||
|
||||
def testBucketIntervalBeginAfterEnd(self):
|
||||
"""begin > end should raise ValueError"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 2, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -2394,7 +2394,6 @@ watch = true
|
||||
|
||||
def testBucketIntervalNaiveDatetime(self):
|
||||
"""Non-timezone-aware datetimes should raise ValueError"""
|
||||
from datetime import datetime
|
||||
begin = datetime(2024, 1, 1)
|
||||
end = datetime(2024, 1, 2)
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -2402,7 +2401,6 @@ watch = true
|
||||
|
||||
def testBucketIntervalDifferentTzinfo(self):
|
||||
"""Different tzinfo objects should raise ValueError"""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
tz1 = timezone.utc
|
||||
tz2 = timezone(timedelta(hours=5))
|
||||
begin = datetime(2024, 1, 1, tzinfo=tz1)
|
||||
@@ -2412,7 +2410,6 @@ watch = true
|
||||
|
||||
def testBucketIntervalNegativeCount(self):
|
||||
"""Negative total_count should raise ValueError"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 2, tzinfo=timezone.utc)
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -2420,7 +2417,6 @@ watch = true
|
||||
|
||||
def testBucketIntervalZeroCount(self):
|
||||
"""Zero total_count should return empty list"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 2, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 0)
|
||||
@@ -2428,14 +2424,12 @@ watch = true
|
||||
|
||||
def testBucketIntervalSameBeginEnd(self):
    """Same begin and end (zero interval) should return empty list"""
    # ``datetime`` and ``timezone`` come from the module-level import; no
    # function-local import is needed.
    dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    result = parsedmarc._bucket_interval_by_day(dt, dt, 100)
    self.assertEqual(result, [])
|
||||
|
||||
def testBucketIntervalSingleDay(self):
|
||||
"""Single day interval should return one bucket with total count"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 1, 23, 59, 59, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 100)
|
||||
@@ -2445,7 +2439,6 @@ watch = true
|
||||
|
||||
def testBucketIntervalMultiDay(self):
|
||||
"""Multi-day interval should distribute counts proportionally"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 100)
|
||||
@@ -2458,7 +2451,6 @@ watch = true
|
||||
|
||||
def testBucketIntervalRemainderDistribution(self):
|
||||
"""Odd count across equal days distributes remainder correctly"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 4, 0, 0, 0, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 10)
|
||||
@@ -2468,7 +2460,6 @@ watch = true
|
||||
|
||||
def testBucketIntervalPartialDays(self):
|
||||
"""Partial days: 12h on day1, 24h on day2 => 1/3 vs 2/3 split"""
|
||||
from datetime import datetime, timezone
|
||||
begin = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
end = datetime(2024, 1, 3, 0, 0, 0, tzinfo=timezone.utc)
|
||||
result = parsedmarc._bucket_interval_by_day(begin, end, 90)
|
||||
@@ -2484,7 +2475,6 @@ watch = true
|
||||
|
||||
def testAppendParsedRecordNoNormalize(self):
|
||||
"""No normalization: record appended as-is with interval fields"""
|
||||
from datetime import datetime, timezone
|
||||
records = []
|
||||
rec = {"count": 10, "source": {"ip_address": "1.2.3.4"}}
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
@@ -2497,7 +2487,6 @@ watch = true
|
||||
|
||||
def testAppendParsedRecordNormalize(self):
|
||||
"""Normalization: record split into daily buckets"""
|
||||
from datetime import datetime, timezone
|
||||
records = []
|
||||
rec = {"count": 100, "source": {"ip_address": "1.2.3.4"}}
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
@@ -2511,7 +2500,6 @@ watch = true
|
||||
|
||||
def testAppendParsedRecordNormalizeZeroCount(self):
|
||||
"""Normalization with zero count: nothing appended"""
|
||||
from datetime import datetime, timezone
|
||||
records = []
|
||||
rec = {"count": 0, "source": {"ip_address": "1.2.3.4"}}
|
||||
begin = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
@@ -2878,7 +2866,6 @@ watch = true
|
||||
|
||||
def testParseSmtpTlsReportJsonValid(self):
|
||||
"""Valid SMTP TLS JSON report parses correctly"""
|
||||
import json
|
||||
report = json.dumps({
|
||||
"organization-name": "Example Corp",
|
||||
"date-range": {
|
||||
@@ -2907,7 +2894,6 @@ watch = true
|
||||
|
||||
def testParseSmtpTlsReportJsonBytes(self):
|
||||
"""SMTP TLS report as bytes parses correctly"""
|
||||
import json
|
||||
report = json.dumps({
|
||||
"organization-name": "Org",
|
||||
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
|
||||
@@ -2923,14 +2909,12 @@ watch = true
|
||||
|
||||
def testParseSmtpTlsReportJsonMissingField(self):
    """Missing required field raises InvalidSMTPTLSReport"""
    # ``json`` comes from the module-level import. The payload carries only
    # "organization-name", so the other fields are absent.
    report = json.dumps({"organization-name": "Org"})
    with self.assertRaises(parsedmarc.InvalidSMTPTLSReport):
        parsedmarc.parse_smtp_tls_report_json(report)
|
||||
|
||||
def testParseSmtpTlsReportJsonPoliciesNotList(self):
|
||||
"""Non-list policies raises InvalidSMTPTLSReport"""
|
||||
import json
|
||||
report = json.dumps({
|
||||
"organization-name": "Org",
|
||||
"date-range": {"start-datetime": "2024-01-01", "end-datetime": "2024-01-02"},
|
||||
@@ -3252,7 +3236,6 @@ watch = true
|
||||
|
||||
def testHumanTimestampToDatetime(self):
|
||||
"""human_timestamp_to_datetime parses timestamp string"""
|
||||
from datetime import datetime
|
||||
dt = parsedmarc.utils.human_timestamp_to_datetime("2024-01-01 00:00:00")
|
||||
self.assertIsInstance(dt, datetime)
|
||||
self.assertEqual(dt.year, 2024)
|
||||
@@ -3261,7 +3244,6 @@ watch = true
|
||||
|
||||
def testHumanTimestampToDatetimeUtc(self):
|
||||
"""human_timestamp_to_datetime with to_utc=True returns UTC"""
|
||||
from datetime import timezone
|
||||
dt = parsedmarc.utils.human_timestamp_to_datetime(
|
||||
"2024-01-01 12:00:00", to_utc=True
|
||||
)
|
||||
@@ -3329,9 +3311,6 @@ watch = true
|
||||
"""get_ip_address_info uses cache on second call"""
|
||||
from expiringdict import ExpiringDict
|
||||
cache = ExpiringDict(max_len=100, max_age_seconds=60)
|
||||
# offline=True means reverse_dns is None, so cache is not populated
|
||||
# Use offline=False with mock to test cache
|
||||
from unittest.mock import patch
|
||||
with patch("parsedmarc.utils.get_reverse_dns", return_value="dns.google"):
|
||||
info1 = parsedmarc.utils.get_ip_address_info(
|
||||
"8.8.8.8", offline=False, cache=cache,
|
||||
@@ -3419,7 +3398,6 @@ watch = true
|
||||
|
||||
def testWebhookClientSaveMethods(self):
|
||||
"""WebhookClient save methods call _send_to_webhook"""
|
||||
from unittest.mock import MagicMock
|
||||
from parsedmarc.webhook import WebhookClient
|
||||
client = WebhookClient("http://a", "http://f", "http://t")
|
||||
client.session = MagicMock()
|
||||
@@ -3564,7 +3542,6 @@ watch = true
|
||||
|
||||
def testSmtpTlsCsvRows(self):
|
||||
"""parsed_smtp_tls_reports_to_csv_rows produces correct rows"""
|
||||
import json
|
||||
report_json = json.dumps({
|
||||
"organization-name": "Org",
|
||||
"date-range": {"start-datetime": "2024-01-01T00:00:00Z", "end-datetime": "2024-01-02T00:00:00Z"},
|
||||
@@ -3641,6 +3618,103 @@ watch = true
|
||||
for r in report["records"]:
|
||||
self.assertTrue(r["normalized_timespan"])
|
||||
|
||||
# ===================================================================
|
||||
# Additional backward compatibility alias tests
|
||||
# ===================================================================
|
||||
|
||||
def testGelfBackwardCompatAlias(self):
    """The legacy forensic method name on GelfClient is the failure method itself"""
    from parsedmarc.gelf import GelfClient

    legacy_method = GelfClient.save_forensic_report_to_gelf
    current_method = GelfClient.save_failure_report_to_gelf
    # Identity, not equality: the alias must be the very same object.
    self.assertIs(legacy_method, current_method)
|
||||
|
||||
def testS3BackwardCompatAlias(self):
    """The legacy forensic method name on S3Client is the failure method itself"""
    from parsedmarc.s3 import S3Client

    legacy_method = S3Client.save_forensic_report_to_s3
    current_method = S3Client.save_failure_report_to_s3
    # Identity, not equality: the alias must be the very same object.
    self.assertIs(legacy_method, current_method)
|
||||
|
||||
def testKafkaBackwardCompatAlias(self):
    """The legacy forensic method name on KafkaClient is the failure method itself"""
    from parsedmarc.kafkaclient import KafkaClient

    legacy_method = KafkaClient.save_forensic_reports_to_kafka
    current_method = KafkaClient.save_failure_reports_to_kafka
    # Identity, not equality: the alias must be the very same object.
    self.assertIs(legacy_method, current_method)
|
||||
|
||||
# ===================================================================
|
||||
# Additional extract/parse tests
|
||||
# ===================================================================
|
||||
|
||||
def testExtractReportFromFilePathNotFound(self):
    """A path that does not exist makes extract_report_from_file_path raise ParserError"""
    missing_path = "nonexistent_file.xml"
    # Callable form of assertRaises: the call happens inside the check.
    self.assertRaises(
        parsedmarc.ParserError,
        parsedmarc.extract_report_from_file_path,
        missing_path,
    )
|
||||
|
||||
def testExtractReportInvalidArchive(self):
    """Binary content that extract_report does not recognize raises ParserError"""
    unrecognized_bytes = b"\x00\x01\x02\x03\x04\x05\x06\x07"
    # Callable form of assertRaises: the call happens inside the check.
    self.assertRaises(
        parsedmarc.ParserError,
        parsedmarc.extract_report,
        unrecognized_bytes,
    )
|
||||
|
||||
def testParseAggregateReportFile(self):
    """Raw sample bytes go straight into parse_aggregate_report_file"""
    print()
    sample_path = "samples/aggregate/dmarcbis-draft-sample.xml"
    print("Testing {0}: ".format(sample_path), end="")
    with open(sample_path, "rb") as sample_file:
        raw_bytes = sample_file.read()
    parse_options = {"offline": True, "always_use_local_files": True}
    report = parsedmarc.parse_aggregate_report_file(raw_bytes, **parse_options)
    # Check two fields of the parsed result, in the same order as before.
    org_name = report["report_metadata"]["org_name"]
    self.assertEqual(org_name, "Sample Reporter")
    domain = report["policy_published"]["domain"]
    self.assertEqual(domain, "example.com")
    print("Passed!")
|
||||
|
||||
def testParseInvalidAggregateSample(self):
    """Test invalid aggregate samples are handled.

    Every non-directory entry under ``samples/aggregate_invalid/`` is parsed
    with ``parsedmarc.parse_report_file`` and its ``"report"`` value is
    rendered with ``parsedmarc.parsed_aggregate_reports_to_csv``. Each
    sample runs in its own subTest so one failing sample does not stop the
    remaining ones.
    """
    print()
    sample_paths = glob("samples/aggregate_invalid/*")
    for sample_path in sample_paths:
        if os.path.isdir(sample_path):
            continue
        print("Testing {0}: ".format(sample_path), end="")
        with self.subTest(sample=sample_path):
            parsed_report = parsedmarc.parse_report_file(
                sample_path, always_use_local_files=True, offline=OFFLINE_MODE
            )["report"]
            parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
            # Kept inside the subTest on purpose: subTest records a failure
            # and lets the loop continue, so a print placed after the block
            # would announce "Passed!" for samples that actually failed.
            print("Passed!")
|
||||
|
||||
def testParseReportFileWithBytes(self):
    """parse_report_file accepts the sample's raw bytes instead of a path"""
    sample_path = "samples/aggregate/dmarcbis-draft-sample.xml"
    with open(sample_path, "rb") as sample_file:
        raw_bytes = sample_file.read()
    parse_options = {"always_use_local_files": True, "offline": True}
    result = parsedmarc.parse_report_file(raw_bytes, **parse_options)
    report_type = result["report_type"]
    self.assertEqual(report_type, "aggregate")
|
||||
|
||||
def testFailureReportCsvRoundtrip(self):
    """Failure report CSV generation works on sample reports.

    Each ``.eml`` sample is parsed with ``parsedmarc.parse_report_file``;
    the resulting report must produce CSV text containing at least one
    comma and a non-empty list of CSV rows. Each sample runs in its own
    subTest.
    """
    print()
    # The existing failure-sample test in this file globs samples/failure/,
    # so look there first. samples/forensic/ is kept as a fallback.
    # NOTE(review): confirm which directory the repository actually ships.
    sample_paths = glob("samples/failure/*.eml") or glob("samples/forensic/*.eml")
    # Without samples the loop below is a no-op and the test would pass
    # without checking anything.
    self.assertTrue(sample_paths, "no failure report samples found")
    for sample_path in sample_paths:
        print("Testing CSV for {0}: ".format(sample_path), end="")
        with self.subTest(sample=sample_path):
            parsed_report = parsedmarc.parse_report_file(
                sample_path, offline=OFFLINE_MODE
            )["report"]
            csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report)
            self.assertIsNotNone(csv_output)
            self.assertIn(",", csv_output)
            rows = parsedmarc.parsed_failure_reports_to_csv_rows(parsed_report)
            # assertGreater reports the actual length on failure.
            self.assertGreater(len(rows), 0)
            # Inside the subTest so it is only printed when the checks
            # above succeeded; subTest swallows failures and continues.
            print("Passed!")
|
||||
|
||||
|
||||
# Script entry point: when this file is executed directly, run the tests in
# this module with verbose (per-test) output.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user