Refactor type casting in tests for improved clarity and consistency

This commit is contained in:
Sean Whalen
2026-03-26 02:22:37 -04:00
parent 78f03ffded
commit cc664540aa

View File

@@ -17,7 +17,7 @@ from io import BytesIO
from glob import glob
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import cast
from typing import BinaryIO, cast
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
@@ -2915,16 +2915,15 @@ class TestMaildirUidHandling(unittest.TestCase):
os.makedirs(os.path.join(d, subdir))
original_stat = os.stat
call_count = [0]
def stat_that_fails_once(path, *args, **kwargs):
"""Fail on the first call (UID check), pass through after."""
stat_that_fails_once.calls += 1
if stat_that_fails_once.calls == 1:
call_count[0] += 1
if call_count[0] == 1:
raise OSError("no stat")
return original_stat(path, *args, **kwargs)
stat_that_fails_once.calls = 0
with patch(
"parsedmarc.mail.maildir.os.stat", side_effect=stat_that_fails_once
):
@@ -4858,7 +4857,7 @@ class TestExtractReport(unittest.TestCase):
def close(self):
pass
result = parsedmarc.extract_report(NonSeekable(xml))
result = parsedmarc.extract_report(cast(BinaryIO, NonSeekable(xml)))
self.assertIn("<feedback>", result)
def testExtractReportInvalidContent(self):
@@ -4883,7 +4882,7 @@ class TestExtractReport(unittest.TestCase):
pass
with self.assertRaises(parsedmarc.ParserError):
parsedmarc.extract_report(TextStream())
parsedmarc.extract_report(cast(BinaryIO, TextStream()))
class TestMalformedXmlRecovery(unittest.TestCase):
@@ -4941,7 +4940,7 @@ class TestMalformedXmlRecovery(unittest.TestCase):
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
</record>
</feedback>"""
report = parsedmarc.parse_aggregate_report_xml(xml, offline=True)
report = parsedmarc.parse_aggregate_report_xml(xml.decode(), offline=True)
self.assertEqual(report["report_metadata"]["report_id"], "test-bytes-input")
def testExpatErrorRaises(self):
@@ -5195,7 +5194,7 @@ Test body"""
def testMissingVersion(self):
"""Missing version defaults to None"""
report_str = self._make_feedback_report()
lines = [l for l in report_str.split("\n") if not l.startswith("Version:")]
lines = [ln for ln in report_str.split("\n") if not ln.startswith("Version:")]
report_str = "\n".join(lines)
report = parsedmarc.parse_failure_report(
report_str, self._make_sample(), self._default_msg_date(), offline=True
@@ -5205,7 +5204,7 @@ Test body"""
def testMissingUserAgent(self):
"""Missing user_agent defaults to None"""
report_str = self._make_feedback_report()
lines = [l for l in report_str.split("\n") if not l.startswith("User-Agent:")]
lines = [ln for ln in report_str.split("\n") if not ln.startswith("User-Agent:")]
report_str = "\n".join(lines)
report = parsedmarc.parse_failure_report(
report_str, self._make_sample(), self._default_msg_date(), offline=True
@@ -5259,7 +5258,7 @@ Test body"""
"""Missing reported_domain falls back to sample from domain"""
report_str = self._make_feedback_report()
lines = [
l for l in report_str.split("\n") if not l.startswith("Reported-Domain:")
ln for ln in report_str.split("\n") if not ln.startswith("Reported-Domain:")
]
report_str = "\n".join(lines)
report = parsedmarc.parse_failure_report(
@@ -5270,7 +5269,7 @@ Test body"""
def testMissingArrivalDateWithMsgDate(self):
"""Missing arrival_date uses msg_date fallback"""
report_str = self._make_feedback_report()
lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")]
lines = [ln for ln in report_str.split("\n") if not ln.startswith("Arrival-Date:")]
report_str = "\n".join(lines)
msg_date = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
report = parsedmarc.parse_failure_report(
@@ -5281,11 +5280,14 @@ Test body"""
def testMissingArrivalDateNoMsgDateRaises(self):
"""Missing arrival_date with no msg_date raises"""
report_str = self._make_feedback_report()
lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")]
lines = [ln for ln in report_str.split("\n") if not ln.startswith("Arrival-Date:")]
report_str = "\n".join(lines)
with self.assertRaises(parsedmarc.InvalidFailureReport):
parsedmarc.parse_failure_report(
report_str, self._make_sample(), None, offline=True
report_str,
self._make_sample(),
cast(datetime, None), # intentionally None to test error path
offline=True,
)
@@ -5422,7 +5424,6 @@ Body"""
def testEmailWithAttachments(self):
"""parse_email with strip_attachment_payloads removes payloads"""
import email as email_mod
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.base import MIMEBase
@@ -5551,7 +5552,6 @@ class TestGetDmarcReportsFromMbox(unittest.TestCase):
def testMboxWithAggregateReport(self):
"""Mbox with aggregate report email is parsed"""
import email as email_mod
from email.mime.multipart import MIMEMultipart
from email.mime.application import MIMEApplication
import gzip