From 2b10adaaf4ecc2e7f416df537ea77c93e26bbc8c Mon Sep 17 00:00:00 2001 From: Sean Whalen Date: Sat, 21 Mar 2026 16:06:41 -0400 Subject: [PATCH] Refactor tests to use assertions consistently and improve type hints --- tests.py | 54 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/tests.py b/tests.py index 02d76fb..1b70cb0 100755 --- a/tests.py +++ b/tests.py @@ -12,10 +12,11 @@ from base64 import urlsafe_b64encode from glob import glob from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import cast from types import SimpleNamespace from unittest.mock import MagicMock, patch -from lxml import etree +from lxml import etree # type: ignore[import-untyped] from googleapiclient.errors import HttpError from httplib2 import Response from imapclient.exceptions import IMAPClientError @@ -32,6 +33,7 @@ from parsedmarc.mail.imap import IMAPConnection import parsedmarc.mail.gmail as gmail_module import parsedmarc.mail.graph as graph_module import parsedmarc.mail.imap as imap_module +import parsedmarc.elastic import parsedmarc.opensearch as opensearch_module import parsedmarc.utils @@ -154,7 +156,7 @@ class Test(unittest.TestCase): report_path, offline=True, ) - self.assertEqual(result["report_type"], "aggregate") + assert result["report_type"] == "aggregate" self.assertEqual(result["report"]["report_metadata"]["org_name"], "outlook.com") def testParseReportFileAcceptsPathForEmail(self): @@ -165,7 +167,7 @@ class Test(unittest.TestCase): report_path, offline=True, ) - self.assertEqual(result["report_type"], "aggregate") + assert result["report_type"] == "aggregate" self.assertEqual(result["report"]["report_metadata"]["org_name"], "google.com") def testAggregateSamples(self): @@ -176,10 +178,11 @@ class Test(unittest.TestCase): if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") - parsed_report = parsedmarc.parse_report_file( + result = parsedmarc.parse_report_file( sample_path, always_use_local_files=True, offline=OFFLINE_MODE - )["report"] - parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) + ) + assert result["report_type"] == "aggregate" + parsedmarc.parsed_aggregate_reports_to_csv(result["report"]) print("Passed!") def testEmptySample(self): @@ -195,13 +198,15 @@ class Test(unittest.TestCase): print("Testing {0}: ".format(sample_path), end="") with open(sample_path) as sample_file: sample_content = sample_file.read() - parsed_report = parsedmarc.parse_report_email( + email_result = parsedmarc.parse_report_email( sample_content, offline=OFFLINE_MODE - )["report"] - parsed_report = parsedmarc.parse_report_file( + ) + assert email_result["report_type"] == "forensic" + result = parsedmarc.parse_report_file( sample_path, offline=OFFLINE_MODE - )["report"] - parsedmarc.parsed_forensic_reports_to_csv(parsed_report) + ) + assert result["report_type"] == "forensic" + parsedmarc.parsed_forensic_reports_to_csv(result["report"]) print("Passed!") def testSmtpTlsSamples(self): @@ -212,10 +217,11 @@ class Test(unittest.TestCase): if os.path.isdir(sample_path): continue print("Testing {0}: ".format(sample_path), end="") - parsed_report = parsedmarc.parse_report_file( + result = parsedmarc.parse_report_file( sample_path, offline=OFFLINE_MODE - )["report"] - parsedmarc.parsed_smtp_tls_reports_to_csv(parsed_report) + ) + assert result["report_type"] == "smtp_tls" + parsedmarc.parsed_smtp_tls_reports_to_csv(result["report"]) print("Passed!") def testOpenSearchSigV4RequiresRegion(self): @@ -1289,7 +1295,7 @@ class TestMailboxWatchSince(unittest.TestCase): ) as mocked: with self.assertRaises(_BreakLoop): parsedmarc.watch_inbox( - mailbox_connection=mailbox_connection, + mailbox_connection=cast(parsedmarc.MailboxConnection, mailbox_connection), callback=callback, check_timeout=1, batch_size=10, @@ -1337,30 +1343,30 @@ since = 2d self.assertEqual(mock_watch_inbox.call_args.kwargs.get("since"), "2d") -class _DummyMailboxConnection: +class _DummyMailboxConnection(parsedmarc.MailboxConnection): def __init__(self): - self.fetch_calls = [] + self.fetch_calls: list[dict[str, object]] = [] - def create_folder(self, folder_name): + def create_folder(self, folder_name: str): return None - def fetch_messages(self, reports_folder, **kwargs): + def fetch_messages(self, reports_folder: str, **kwargs): self.fetch_calls.append({"reports_folder": reports_folder, **kwargs}) return [] - def fetch_message(self, message_id, **kwargs): + def fetch_message(self, message_id) -> str: return "" def delete_message(self, message_id): return None - def move_message(self, message_id, folder_name): + def move_message(self, message_id, folder_name: str): return None def keepalive(self): return None - def watch(self, check_callback, check_timeout): + def watch(self, check_callback, check_timeout, config_reloading=None): return None @@ -1559,7 +1565,7 @@ class TestMSGraphFolderFallback(unittest.TestCase): def testWellKnownFolderFallback(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "shared@example.com" - connection._client = _FakeGraphClient() + connection._client = _FakeGraphClient() # type: ignore[assignment] connection._request_with_retries = MagicMock( side_effect=lambda method_name, *args, **kwargs: getattr( connection._client, method_name @@ -1579,7 +1585,7 @@ class TestMSGraphFolderFallback(unittest.TestCase): def testUnknownFolderStillFails(self): connection = MSGraphConnection.__new__(MSGraphConnection) connection.mailbox_name = "shared@example.com" - connection._client = _FakeGraphClient() + connection._client = _FakeGraphClient() # type: ignore[assignment] connection._request_with_retries = MagicMock( side_effect=lambda method_name, *args, **kwargs: getattr( connection._client, method_name