Refactor tests to use type casting for report objects in parsedmarc

This commit is contained in:
Sean Whalen
2026-03-21 17:26:06 -04:00
parent 901bda384f
commit a4b741ce12
+30 -27
View File
@@ -27,6 +27,7 @@ from imapclient.exceptions import IMAPClientError
import parsedmarc
import parsedmarc.cli
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
from parsedmarc.mail.gmail import GmailConnection
from parsedmarc.mail.gmail import _get_creds
from parsedmarc.mail.graph import MSGraphConnection
@@ -155,7 +156,8 @@ class Test(unittest.TestCase):
offline=True,
)
assert result["report_type"] == "aggregate"
self.assertEqual(result["report"]["report_metadata"]["org_name"], "outlook.com")
report = cast(AggregateReport, result["report"])
self.assertEqual(report["report_metadata"]["org_name"], "outlook.com")
def testParseReportFileAcceptsPathForEmail(self):
report_path = Path(
@@ -166,7 +168,8 @@ class Test(unittest.TestCase):
offline=True,
)
assert result["report_type"] == "aggregate"
self.assertEqual(result["report"]["report_metadata"]["org_name"], "google.com")
report = cast(AggregateReport, result["report"])
self.assertEqual(report["report_metadata"]["org_name"], "google.com")
def testAggregateSamples(self):
"""Test sample aggregate/rua DMARC reports"""
@@ -181,7 +184,7 @@ class Test(unittest.TestCase):
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)
assert result["report_type"] == "aggregate"
parsedmarc.parsed_aggregate_reports_to_csv(result["report"])
parsedmarc.parsed_aggregate_reports_to_csv(cast(AggregateReport, result["report"]))
print("Passed!")
def testEmptySample(self):
@@ -206,7 +209,7 @@ class Test(unittest.TestCase):
sample_path, offline=OFFLINE_MODE
)
assert result["report_type"] == "failure"
parsedmarc.parsed_failure_reports_to_csv(result["report"])
parsedmarc.parsed_failure_reports_to_csv(cast(FailureReport, result["report"]))
print("Passed!")
def testFailureReportBackwardCompat(self):
@@ -238,7 +241,7 @@ class Test(unittest.TestCase):
result = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=True
)
report = result["report"]
report = cast(AggregateReport, result["report"])
# Verify report_type
self.assertEqual(result["report_type"], "aggregate")
@@ -323,7 +326,7 @@ class Test(unittest.TestCase):
result = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=True
)
report = result["report"]
report = cast(AggregateReport, result["report"])
pp = report["policy_published"]
# RFC 7489 fields present
@@ -350,7 +353,7 @@ class Test(unittest.TestCase):
result = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=True
)
report = result["report"]
report = cast(AggregateReport, result["report"])
pp = report["policy_published"]
self.assertEqual(pp["np"], "reject")
@@ -371,7 +374,7 @@ class Test(unittest.TestCase):
sample_path, offline=OFFLINE_MODE
)
assert result["report_type"] == "smtp_tls"
parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"])
parsedmarc.parsed_smtp_tls_reports_to_csv(cast(SMTPTLSReport, result["report"]))
print("Passed!")
def testOpenSearchSigV4RequiresRegion(self):
@@ -3277,7 +3280,7 @@ class TestEnvVarConfig(unittest.TestCase):
end = datetime(2024, 1, 2, 0, 0, 0, tzinfo=timezone.utc)
parsedmarc._append_parsed_record(rec, records, begin, end, False)
self.assertEqual(len(records), 1)
self.assertFalse(records[0]["normalized_timespan"])
self.assertFalse(records[0]["normalized_timespan"]) # type: ignore[typeddict-item]
self.assertEqual(records[0]["interval_begin"], "2024-01-01 00:00:00")
self.assertEqual(records[0]["interval_end"], "2024-01-02 00:00:00")
@@ -3292,7 +3295,7 @@ class TestEnvVarConfig(unittest.TestCase):
total = sum(r["count"] for r in records)
self.assertEqual(total, 100)
for r in records:
self.assertTrue(r["normalized_timespan"])
self.assertTrue(r["normalized_timespan"]) # type: ignore[typeddict-item]
def testAppendParsedRecordNormalizeZeroCount(self):
"""Normalization with zero count: nothing appended"""
@@ -3838,7 +3841,7 @@ class TestEnvVarConfig(unittest.TestCase):
"samples/aggregate/dmarcbis-draft-sample.xml",
always_use_local_files=True, offline=True,
)
report = result["report"]
report = cast(AggregateReport, result["report"])
rows = parsedmarc.parsed_aggregate_reports_to_csv_rows(report)
self.assertTrue(len(rows) > 0)
row = rows[0]
@@ -4109,7 +4112,7 @@ class TestEnvVarConfig(unittest.TestCase):
def testParseEmailAddressWithDisplayName(self):
"""parse_email_address with display name"""
result = parsedmarc.utils.parse_email_address(("John Doe", "john@example.com"))
result = parsedmarc.utils.parse_email_address(("John Doe", "john@example.com")) # type: ignore[arg-type]
self.assertEqual(result["display_name"], "John Doe")
self.assertEqual(result["address"], "john@example.com")
self.assertEqual(result["local"], "john")
@@ -4117,13 +4120,13 @@ class TestEnvVarConfig(unittest.TestCase):
def testParseEmailAddressWithoutDisplayName(self):
"""parse_email_address with empty display name"""
result = parsedmarc.utils.parse_email_address(("", "john@example.com"))
result = parsedmarc.utils.parse_email_address(("", "john@example.com")) # type: ignore[arg-type]
self.assertIsNone(result["display_name"])
self.assertEqual(result["address"], "john@example.com")
def testParseEmailAddressNoAt(self):
"""parse_email_address with no @ returns None local/domain"""
result = parsedmarc.utils.parse_email_address(("", "localonly"))
result = parsedmarc.utils.parse_email_address(("", "localonly")) # type: ignore[arg-type]
self.assertIsNone(result["local"])
self.assertIsNone(result["domain"])
@@ -4137,7 +4140,7 @@ class TestEnvVarConfig(unittest.TestCase):
def testGetFilenameSafeStringNone(self):
"""get_filename_safe_string with None returns 'None'"""
result = parsedmarc.utils.get_filename_safe_string(None)
result = parsedmarc.utils.get_filename_safe_string(None) # type: ignore[arg-type]
self.assertEqual(result, "None")
def testGetFilenameSafeStringLong(self):
@@ -4196,7 +4199,7 @@ class TestEnvVarConfig(unittest.TestCase):
"""WebhookClient forensic alias points to failure method"""
from parsedmarc.webhook import WebhookClient
self.assertIs(
WebhookClient.save_forensic_report_to_webhook,
WebhookClient.save_forensic_report_to_webhook, # type: ignore[attr-defined]
WebhookClient.save_failure_report_to_webhook,
)
@@ -4261,7 +4264,7 @@ class TestEnvVarConfig(unittest.TestCase):
"""HECClient forensic alias points to failure method"""
from parsedmarc.splunk import HECClient
self.assertIs(
HECClient.save_forensic_reports_to_splunk,
HECClient.save_forensic_reports_to_splunk, # type: ignore[attr-defined]
HECClient.save_failure_reports_to_splunk,
)
@@ -4283,7 +4286,7 @@ class TestEnvVarConfig(unittest.TestCase):
"""SyslogClient forensic alias points to failure method"""
from parsedmarc.syslog import SyslogClient
self.assertIs(
SyslogClient.save_forensic_report_to_syslog,
SyslogClient.save_forensic_report_to_syslog, # type: ignore[attr-defined]
SyslogClient.save_failure_report_to_syslog,
)
@@ -4350,7 +4353,7 @@ class TestEnvVarConfig(unittest.TestCase):
"samples/aggregate/dmarcbis-draft-sample.xml",
always_use_local_files=True, offline=True,
)
report = result["report"]
report = cast(AggregateReport, result["report"])
# Pass as a list
rows = parsedmarc.parsed_aggregate_reports_to_csv_rows([report])
self.assertTrue(len(rows) > 0)
@@ -4400,7 +4403,7 @@ class TestEnvVarConfig(unittest.TestCase):
total = sum(r["count"] for r in report["records"])
self.assertEqual(total, 90)
for r in report["records"]:
self.assertTrue(r["normalized_timespan"])
self.assertTrue(r["normalized_timespan"]) # type: ignore[typeddict-item]
# ===================================================================
# Additional backward compatibility alias tests
@@ -4410,7 +4413,7 @@ class TestEnvVarConfig(unittest.TestCase):
"""GelfClient forensic alias points to failure method"""
from parsedmarc.gelf import GelfClient
self.assertIs(
GelfClient.save_forensic_report_to_gelf,
GelfClient.save_forensic_report_to_gelf, # type: ignore[attr-defined]
GelfClient.save_failure_report_to_gelf,
)
@@ -4418,7 +4421,7 @@ class TestEnvVarConfig(unittest.TestCase):
"""S3Client forensic alias points to failure method"""
from parsedmarc.s3 import S3Client
self.assertIs(
S3Client.save_forensic_report_to_s3,
S3Client.save_forensic_report_to_s3, # type: ignore[attr-defined]
S3Client.save_failure_report_to_s3,
)
@@ -4426,7 +4429,7 @@ class TestEnvVarConfig(unittest.TestCase):
"""KafkaClient forensic alias points to failure method"""
from parsedmarc.kafkaclient import KafkaClient
self.assertIs(
KafkaClient.save_forensic_reports_to_kafka,
KafkaClient.save_forensic_reports_to_kafka, # type: ignore[attr-defined]
KafkaClient.save_failure_reports_to_kafka,
)
@@ -4467,9 +4470,9 @@ class TestEnvVarConfig(unittest.TestCase):
continue
print("Testing {0}: ".format(sample_path), end="")
with self.subTest(sample=sample_path):
parsed_report = parsedmarc.parse_report_file(
parsed_report = cast(AggregateReport, parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)["report"]
)["report"])
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
print("Passed!")
@@ -4489,9 +4492,9 @@ class TestEnvVarConfig(unittest.TestCase):
for sample_path in sample_paths:
print("Testing CSV for {0}: ".format(sample_path), end="")
with self.subTest(sample=sample_path):
parsed_report = parsedmarc.parse_report_file(
parsed_report = cast(FailureReport, parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
)["report"])
csv_output = parsedmarc.parsed_failure_reports_to_csv(parsed_report)
self.assertIsNotNone(csv_output)
self.assertIn(",", csv_output)