From 2e34eff62a34ddd3bb8de8c6603f77818e0d92f0 Mon Sep 17 00:00:00 2001 From: Sean Whalen Date: Thu, 26 Mar 2026 02:22:37 -0400 Subject: [PATCH] Refactor type casting in tests for improved clarity and consistency --- tests.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests.py b/tests.py index cda613d..6fff648 100755 --- a/tests.py +++ b/tests.py @@ -17,7 +17,7 @@ from io import BytesIO from glob import glob from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import cast +from typing import BinaryIO, cast from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -2915,16 +2915,15 @@ class TestMaildirUidHandling(unittest.TestCase): os.makedirs(os.path.join(d, subdir)) original_stat = os.stat + call_count = [0] def stat_that_fails_once(path, *args, **kwargs): """Fail on the first call (UID check), pass through after.""" - stat_that_fails_once.calls += 1 - if stat_that_fails_once.calls == 1: + call_count[0] += 1 + if call_count[0] == 1: raise OSError("no stat") return original_stat(path, *args, **kwargs) - stat_that_fails_once.calls = 0 - with patch( "parsedmarc.mail.maildir.os.stat", side_effect=stat_that_fails_once ): @@ -4711,7 +4710,7 @@ class TestExtractReport(unittest.TestCase): def close(self): pass - result = parsedmarc.extract_report(NonSeekable(xml)) + result = parsedmarc.extract_report(cast(BinaryIO, NonSeekable(xml))) self.assertIn("", result) def testExtractReportInvalidContent(self): @@ -4736,7 +4735,7 @@ class TestExtractReport(unittest.TestCase): pass with self.assertRaises(parsedmarc.ParserError): - parsedmarc.extract_report(TextStream()) + parsedmarc.extract_report(cast(BinaryIO, TextStream())) class TestMalformedXmlRecovery(unittest.TestCase): @@ -4794,7 +4793,7 @@ class TestMalformedXmlRecovery(unittest.TestCase): example.compass """ - report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) + report = parsedmarc.parse_aggregate_report_xml(xml.decode(), offline=True) self.assertEqual(report["report_metadata"]["report_id"], "test-bytes-input") def testExpatErrorRaises(self): @@ -5048,7 +5047,7 @@ Test body""" def testMissingVersion(self): """Missing version defaults to None""" report_str = self._make_feedback_report() - lines = [l for l in report_str.split("\n") if not l.startswith("Version:")] + lines = [ln for ln in report_str.split("\n") if not ln.startswith("Version:")] report_str = "\n".join(lines) report = parsedmarc.parse_failure_report( report_str, self._make_sample(), self._default_msg_date(), offline=True @@ -5058,7 +5057,7 @@ Test body""" def testMissingUserAgent(self): """Missing user_agent defaults to None""" report_str = self._make_feedback_report() - lines = [l for l in report_str.split("\n") if not l.startswith("User-Agent:")] + lines = [ln for ln in report_str.split("\n") if not ln.startswith("User-Agent:")] report_str = "\n".join(lines) report = parsedmarc.parse_failure_report( report_str, self._make_sample(), self._default_msg_date(), offline=True @@ -5112,7 +5111,7 @@ Test body""" """Missing reported_domain falls back to sample from domain""" report_str = self._make_feedback_report() lines = [ - l for l in report_str.split("\n") if not l.startswith("Reported-Domain:") + ln for ln in report_str.split("\n") if not ln.startswith("Reported-Domain:") ] report_str = "\n".join(lines) report = parsedmarc.parse_failure_report( @@ -5123,7 +5122,7 @@ Test body""" def testMissingArrivalDateWithMsgDate(self): """Missing arrival_date uses msg_date fallback""" report_str = self._make_feedback_report() - lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")] + lines = [ln for ln in report_str.split("\n") if not ln.startswith("Arrival-Date:")] report_str = "\n".join(lines) msg_date = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc) report = parsedmarc.parse_failure_report( @@ -5134,11 +5133,14 @@ Test body""" def testMissingArrivalDateNoMsgDateRaises(self): """Missing arrival_date with no msg_date raises""" report_str = self._make_feedback_report() - lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")] + lines = [ln for ln in report_str.split("\n") if not ln.startswith("Arrival-Date:")] report_str = "\n".join(lines) with self.assertRaises(parsedmarc.InvalidFailureReport): parsedmarc.parse_failure_report( - report_str, self._make_sample(), None, offline=True + report_str, + self._make_sample(), + cast(datetime, None), # intentionally None to test error path + offline=True, ) @@ -5275,7 +5277,6 @@ Body""" def testEmailWithAttachments(self): """parse_email with strip_attachment_payloads removes payloads""" - import email as email_mod from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from email.mime.base import MIMEBase @@ -5404,7 +5405,6 @@ class TestGetDmarcReportsFromMbox(unittest.TestCase): def testMboxWithAggregateReport(self): """Mbox with aggregate report email is parsed""" - import email as email_mod from email.mime.multipart import MIMEMultipart from email.mime.application import MIMEApplication import gzip