diff --git a/tests.py b/tests.py index 3f84b4b..886e1a0 100755 --- a/tests.py +++ b/tests.py @@ -17,7 +17,7 @@ from io import BytesIO from glob import glob from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import cast +from typing import BinaryIO, cast from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -2915,16 +2915,15 @@ class TestMaildirUidHandling(unittest.TestCase): os.makedirs(os.path.join(d, subdir)) original_stat = os.stat + call_count = [0] def stat_that_fails_once(path, *args, **kwargs): """Fail on the first call (UID check), pass through after.""" - stat_that_fails_once.calls += 1 - if stat_that_fails_once.calls == 1: + call_count[0] += 1 + if call_count[0] == 1: raise OSError("no stat") return original_stat(path, *args, **kwargs) - stat_that_fails_once.calls = 0 - with patch( "parsedmarc.mail.maildir.os.stat", side_effect=stat_that_fails_once ): @@ -4858,7 +4857,7 @@ class TestExtractReport(unittest.TestCase): def close(self): pass - result = parsedmarc.extract_report(NonSeekable(xml)) + result = parsedmarc.extract_report(cast(BinaryIO, NonSeekable(xml))) self.assertIn("", result) def testExtractReportInvalidContent(self): @@ -4883,7 +4882,7 @@ class TestExtractReport(unittest.TestCase): pass with self.assertRaises(parsedmarc.ParserError): - parsedmarc.extract_report(TextStream()) + parsedmarc.extract_report(cast(BinaryIO, TextStream())) class TestMalformedXmlRecovery(unittest.TestCase): @@ -4941,7 +4940,7 @@ class TestMalformedXmlRecovery(unittest.TestCase): example.compass """ - report = parsedmarc.parse_aggregate_report_xml(xml, offline=True) + report = parsedmarc.parse_aggregate_report_xml(xml.decode(), offline=True) self.assertEqual(report["report_metadata"]["report_id"], "test-bytes-input") def testExpatErrorRaises(self): @@ -5195,7 +5194,7 @@ Test body""" def testMissingVersion(self): """Missing version defaults to None""" report_str = self._make_feedback_report() - lines = [l for l in report_str.split("\n") if not l.startswith("Version:")] + lines = [ln for ln in report_str.split("\n") if not ln.startswith("Version:")] report_str = "\n".join(lines) report = parsedmarc.parse_failure_report( report_str, self._make_sample(), self._default_msg_date(), offline=True @@ -5205,7 +5204,7 @@ Test body""" def testMissingUserAgent(self): """Missing user_agent defaults to None""" report_str = self._make_feedback_report() - lines = [l for l in report_str.split("\n") if not l.startswith("User-Agent:")] + lines = [ln for ln in report_str.split("\n") if not ln.startswith("User-Agent:")] report_str = "\n".join(lines) report = parsedmarc.parse_failure_report( report_str, self._make_sample(), self._default_msg_date(), offline=True @@ -5259,7 +5258,7 @@ Test body""" """Missing reported_domain falls back to sample from domain""" report_str = self._make_feedback_report() lines = [ - l for l in report_str.split("\n") if not l.startswith("Reported-Domain:") + ln for ln in report_str.split("\n") if not ln.startswith("Reported-Domain:") ] report_str = "\n".join(lines) report = parsedmarc.parse_failure_report( @@ -5270,7 +5269,7 @@ Test body""" def testMissingArrivalDateWithMsgDate(self): """Missing arrival_date uses msg_date fallback""" report_str = self._make_feedback_report() - lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")] + lines = [ln for ln in report_str.split("\n") if not ln.startswith("Arrival-Date:")] report_str = "\n".join(lines) msg_date = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc) report = parsedmarc.parse_failure_report( @@ -5281,11 +5280,14 @@ Test body""" def testMissingArrivalDateNoMsgDateRaises(self): """Missing arrival_date with no msg_date raises""" report_str = self._make_feedback_report() - lines = [l for l in report_str.split("\n") if not l.startswith("Arrival-Date:")] + lines = [ln for ln in report_str.split("\n") if not ln.startswith("Arrival-Date:")] report_str = "\n".join(lines) with self.assertRaises(parsedmarc.InvalidFailureReport): parsedmarc.parse_failure_report( - report_str, self._make_sample(), None, offline=True + report_str, + self._make_sample(), + cast(datetime, None), # intentionally None to test error path + offline=True, ) @@ -5422,7 +5424,6 @@ Body""" def testEmailWithAttachments(self): """parse_email with strip_attachment_payloads removes payloads""" - import email as email_mod from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from email.mime.base import MIMEBase @@ -5551,7 +5552,6 @@ class TestGetDmarcReportsFromMbox(unittest.TestCase): def testMboxWithAggregateReport(self): """Mbox with aggregate report email is parsed""" - import email as email_mod from email.mime.multipart import MIMEMultipart from email.mime.application import MIMEApplication import gzip