Refactor tests to use assertions consistently and improve type hints

This commit is contained in:
Sean Whalen
2026-03-21 16:06:41 -04:00
parent 49edcb98ec
commit 2b10adaaf4

View File

@@ -12,10 +12,11 @@ from base64 import urlsafe_b64encode
from glob import glob
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import cast
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from lxml import etree
from lxml import etree # type: ignore[import-untyped]
from googleapiclient.errors import HttpError
from httplib2 import Response
from imapclient.exceptions import IMAPClientError
@@ -32,6 +33,7 @@ from parsedmarc.mail.imap import IMAPConnection
import parsedmarc.mail.gmail as gmail_module
import parsedmarc.mail.graph as graph_module
import parsedmarc.mail.imap as imap_module
import parsedmarc.elastic
import parsedmarc.opensearch as opensearch_module
import parsedmarc.utils
@@ -154,7 +156,7 @@ class Test(unittest.TestCase):
report_path,
offline=True,
)
self.assertEqual(result["report_type"], "aggregate")
assert result["report_type"] == "aggregate"
self.assertEqual(result["report"]["report_metadata"]["org_name"], "outlook.com")
def testParseReportFileAcceptsPathForEmail(self):
@@ -165,7 +167,7 @@ class Test(unittest.TestCase):
report_path,
offline=True,
)
self.assertEqual(result["report_type"], "aggregate")
assert result["report_type"] == "aggregate"
self.assertEqual(result["report"]["report_metadata"]["org_name"], "google.com")
def testAggregateSamples(self):
@@ -176,10 +178,11 @@ class Test(unittest.TestCase):
if os.path.isdir(sample_path):
continue
print("Testing {0}: ".format(sample_path), end="")
parsed_report = parsedmarc.parse_report_file(
result = parsedmarc.parse_report_file(
sample_path, always_use_local_files=True, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_aggregate_reports_to_csv(parsed_report)
)
assert result["report_type"] == "aggregate"
parsedmarc.parsed_aggregate_reports_to_csv(result["report"])
print("Passed!")
def testEmptySample(self):
@@ -195,13 +198,15 @@ class Test(unittest.TestCase):
print("Testing {0}: ".format(sample_path), end="")
with open(sample_path) as sample_file:
sample_content = sample_file.read()
parsed_report = parsedmarc.parse_report_email(
email_result = parsedmarc.parse_report_email(
sample_content, offline=OFFLINE_MODE
)["report"]
parsed_report = parsedmarc.parse_report_file(
)
assert email_result["report_type"] == "forensic"
result = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_forensic_reports_to_csv(parsed_report)
)
assert result["report_type"] == "forensic"
parsedmarc.parsed_forensic_reports_to_csv(result["report"])
print("Passed!")
def testSmtpTlsSamples(self):
@@ -212,10 +217,11 @@ class Test(unittest.TestCase):
if os.path.isdir(sample_path):
continue
print("Testing {0}: ".format(sample_path), end="")
parsed_report = parsedmarc.parse_report_file(
result = parsedmarc.parse_report_file(
sample_path, offline=OFFLINE_MODE
)["report"]
parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report)
)
assert result["report_type"] == "smtp_tls"
parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"])
print("Passed!")
def testOpenSearchSigV4RequiresRegion(self):
@@ -1289,7 +1295,7 @@ class TestMailboxWatchSince(unittest.TestCase):
) as mocked:
with self.assertRaises(_BreakLoop):
parsedmarc.watch_inbox(
mailbox_connection=mailbox_connection,
mailbox_connection=cast(parsedmarc.MailboxConnection, mailbox_connection),
callback=callback,
check_timeout=1,
batch_size=10,
@@ -1337,30 +1343,30 @@ since = 2d
self.assertEqual(mock_watch_inbox.call_args.kwargs.get("since"), "2d")
class _DummyMailboxConnection:
class _DummyMailboxConnection(parsedmarc.MailboxConnection):
def __init__(self):
self.fetch_calls = []
self.fetch_calls: list[dict[str, object]] = []
def create_folder(self, folder_name):
def create_folder(self, folder_name: str):
return None
def fetch_messages(self, reports_folder, **kwargs):
def fetch_messages(self, reports_folder: str, **kwargs):
self.fetch_calls.append({"reports_folder": reports_folder, **kwargs})
return []
def fetch_message(self, message_id, **kwargs):
def fetch_message(self, message_id) -> str:
return ""
def delete_message(self, message_id):
return None
def move_message(self, message_id, folder_name):
def move_message(self, message_id, folder_name: str):
return None
def keepalive(self):
return None
def watch(self, check_callback, check_timeout):
def watch(self, check_callback, check_timeout, config_reloading=None):
return None
@@ -1559,7 +1565,7 @@ class TestMSGraphFolderFallback(unittest.TestCase):
def testWellKnownFolderFallback(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection.mailbox_name = "shared@example.com"
connection._client = _FakeGraphClient()
connection._client = _FakeGraphClient() # type: ignore[assignment]
connection._request_with_retries = MagicMock(
side_effect=lambda method_name, *args, **kwargs: getattr(
connection._client, method_name
@@ -1579,7 +1585,7 @@ class TestMSGraphFolderFallback(unittest.TestCase):
def testUnknownFolderStillFails(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection.mailbox_name = "shared@example.com"
connection._client = _FakeGraphClient()
connection._client = _FakeGraphClient() # type: ignore[assignment]
connection._request_with_retries = MagicMock(
side_effect=lambda method_name, *args, **kwargs: getattr(
connection._client, method_name