mirror of
https://github.com/domainaware/parsedmarc.git
synced 2026-04-20 12:29:28 +00:00
Refactor type casting in tests for improved clarity and consistency
This commit is contained in:
32
tests.py
32
tests.py
@@ -17,7 +17,7 @@ from io import BytesIO
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
from typing import cast
|
||||
from typing import BinaryIO, cast
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -2915,16 +2915,15 @@ class TestMaildirUidHandling(unittest.TestCase):
|
||||
os.makedirs(os.path.join(d, subdir))
|
||||
|
||||
original_stat = os.stat
|
||||
call_count = [0]
|
||||
|
||||
def stat_that_fails_once(path, *args, **kwargs):
|
||||
"""Fail on the first call (UID check), pass through after."""
|
||||
stat_that_fails_once.calls += 1
|
||||
if stat_that_fails_once.calls == 1:
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
raise OSError("no stat")
|
||||
return original_stat(path, *args, **kwargs)
|
||||
|
||||
stat_that_fails_once.calls = 0
|
||||
|
||||
with patch(
|
||||
"parsedmarc.mail.maildir.os.stat", side_effect=stat_that_fails_once
|
||||
):
|
||||
@@ -4858,7 +4857,7 @@ class TestExtractReport(unittest.TestCase):
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
result = parsedmarc.extract_report(NonSeekable(xml))
|
||||
result = parsedmarc.extract_report(cast(BinaryIO, NonSeekable(xml)))
|
||||
self.assertIn("<feedback>", result)
|
||||
|
||||
def testExtractReportInvalidContent(self):
|
||||
@@ -4883,7 +4882,7 @@ class TestExtractReport(unittest.TestCase):
|
||||
pass
|
||||
|
||||
with self.assertRaises(parsedmarc.ParserError):
|
||||
parsedmarc.extract_report(TextStream())
|
||||
parsedmarc.extract_report(cast(BinaryIO, TextStream()))
|
||||
|
||||
|
||||
class TestMalformedXmlRecovery(unittest.TestCase):
|
||||
@@ -4941,7 +4940,7 @@ class TestMalformedXmlRecovery(unittest.TestCase):
|
||||
<auth_results><spf><domain>example.com</domain><result>pass</result></spf></auth_results>
|
||||
</record>
|
||||
</feedback>"""
|
||||
report = parsedmarc.parse_aggregate_report_xml(xml, offline=True)
|
||||
report = parsedmarc.parse_aggregate_report_xml(xml.decode(), offline=True)
|
||||
self.assertEqual(report["report_metadata"]["report_id"], "test-bytes-input")
|
||||
|
||||
def testExpatErrorRaises(self):
|
||||
@@ -5195,7 +5194,7 @@ Test body"""
|
||||
def testMissingVersion(self):
|
||||
"""Missing version defaults to None"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [l for l in report_str.split("\n") if not l.startswith("Version:")]
|
||||
lines = [ln for ln in report_str.split("\n") if not ln.startswith("Version:")]
|
||||
report_str = "\n".join(lines)
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
@@ -5205,7 +5204,7 @@ Test body"""
|
||||
def testMissingUserAgent(self):
|
||||
"""Missing user_agent defaults to None"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [l for l in report_str.split("\n") if not l.startswith("User-Agent:")]
|
||||
lines = [ln for ln in report_str.split("\n") if not ln.startswith("User-Agent:")]
|
||||
report_str = "\n".join(lines)
|
||||
report = parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), self._default_msg_date(), offline=True
|
||||
@@ -5259,7 +5258,7 @@ Test body"""
|
||||
"""Missing reported_domain falls back to sample from domain"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [
|
||||
l for l in report_str.split("\n") if not l.startswith("Reported-Domain:")
|
||||
ln for ln in report_str.split("\n") if not ln.startswith("Reported-Domain:")
|
||||
]
|
||||
report_str = "\n".join(lines)
|
||||
report = parsedmarc.parse_failure_report(
|
||||
@@ -5270,7 +5269,7 @@ Test body"""
|
||||
def testMissingArrivalDateWithMsgDate(self):
|
||||
"""Missing arrival_date uses msg_date fallback"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")]
|
||||
lines = [ln for ln in report_str.split("\n") if not ln.startswith("Arrival-Date:")]
|
||||
report_str = "\n".join(lines)
|
||||
msg_date = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
|
||||
report = parsedmarc.parse_failure_report(
|
||||
@@ -5281,11 +5280,14 @@ Test body"""
|
||||
def testMissingArrivalDateNoMsgDateRaises(self):
|
||||
"""Missing arrival_date with no msg_date raises"""
|
||||
report_str = self._make_feedback_report()
|
||||
lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")]
|
||||
lines = [ln for ln in report_str.split("\n") if not ln.startswith("Arrival-Date:")]
|
||||
report_str = "\n".join(lines)
|
||||
with self.assertRaises(parsedmarc.InvalidFailureReport):
|
||||
parsedmarc.parse_failure_report(
|
||||
report_str, self._make_sample(), None, offline=True
|
||||
report_str,
|
||||
self._make_sample(),
|
||||
cast(datetime, None), # intentionally None to test error path
|
||||
offline=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -5422,7 +5424,6 @@ Body"""
|
||||
|
||||
def testEmailWithAttachments(self):
|
||||
"""parse_email with strip_attachment_payloads removes payloads"""
|
||||
import email as email_mod
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.base import MIMEBase
|
||||
@@ -5551,7 +5552,6 @@ class TestGetDmarcReportsFromMbox(unittest.TestCase):
|
||||
|
||||
def testMboxWithAggregateReport(self):
|
||||
"""Mbox with aggregate report email is parsed"""
|
||||
import email as email_mod
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.application import MIMEApplication
|
||||
import gzip
|
||||
|
||||
Reference in New Issue
Block a user