Increase unit test coverage for Gmail/Graph/IMAP connectors (#664)

* Increase coverage for Gmail, Graph, and IMAP mail connectors

* Make testLoadTokenMissing use guaranteed-missing temp path

* Expand coverage for Gmail token refresh and Graph pagination error paths
This commit is contained in:
Kili
2026-03-09 16:54:43 +01:00
committed by GitHub
parent adb0d31382
commit d49ce6a13f
+447
View File
@@ -5,11 +5,29 @@ from __future__ import absolute_import, print_function, unicode_literals
import os
import unittest
from base64 import urlsafe_b64encode
from glob import glob
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from unittest.mock import MagicMock
from unittest.mock import patch
from lxml import etree
from googleapiclient.errors import HttpError
from httplib2 import Response
from imapclient.exceptions import IMAPClientError
import parsedmarc
from parsedmarc.mail.gmail import GmailConnection
from parsedmarc.mail.gmail import _get_creds
from parsedmarc.mail.graph import MSGraphConnection
from parsedmarc.mail.graph import _generate_credential
from parsedmarc.mail.graph import _get_cache_args
from parsedmarc.mail.graph import _load_token
from parsedmarc.mail.imap import IMAPConnection
import parsedmarc.mail.gmail as gmail_module
import parsedmarc.mail.graph as graph_module
import parsedmarc.mail.imap as imap_module
import parsedmarc.utils
# Detect if running in GitHub Actions to skip DNS lookups
@@ -166,5 +184,434 @@ class Test(unittest.TestCase):
print("Passed!")
class _FakeGraphResponse:
def __init__(self, status_code, payload=None, text=""):
self.status_code = status_code
self._payload = payload or {}
self.text = text
def json(self):
return self._payload
class _BreakLoop(BaseException):
pass
class TestGmailConnection(unittest.TestCase):
def _build_connection(self, *, paginate=True):
connection = GmailConnection.__new__(GmailConnection)
connection.include_spam_trash = False
connection.reports_label_id = "REPORTS"
connection.paginate_messages = paginate
connection.service = MagicMock()
return connection
def testFindLabelId(self):
connection = self._build_connection()
labels_api = connection.service.users.return_value.labels.return_value
labels_api.list.return_value.execute.return_value = {
"labels": [
{"id": "INBOX", "name": "INBOX"},
{"id": "REPORTS", "name": "Reports"},
]
}
self.assertEqual(connection._find_label_id_for_label("Reports"), "REPORTS")
self.assertEqual(connection._find_label_id_for_label("MISSING"), "")
def testFetchMessagesWithPagination(self):
connection = self._build_connection(paginate=True)
messages_api = connection.service.users.return_value.messages.return_value
def list_side_effect(**kwargs):
response = MagicMock()
if kwargs.get("pageToken") is None:
response.execute.return_value = {
"messages": [{"id": "a"}, {"id": "b"}],
"nextPageToken": "n1",
}
else:
response.execute.return_value = {"messages": [{"id": "c"}]}
return response
messages_api.list.side_effect = list_side_effect
connection._find_label_id_for_label = MagicMock(return_value="REPORTS")
self.assertEqual(connection.fetch_messages("Reports"), ["a", "b", "c"])
def testFetchMessageDecoding(self):
connection = self._build_connection()
messages_api = connection.service.users.return_value.messages.return_value
raw = urlsafe_b64encode(b"Subject: test\n\nbody").decode()
messages_api.get.return_value.execute.return_value = {"raw": raw}
content = connection.fetch_message("m1")
self.assertIn("Subject: test", content)
def testMoveAndDeleteMessage(self):
connection = self._build_connection()
connection._find_label_id_for_label = MagicMock(return_value="ARCHIVE")
messages_api = connection.service.users.return_value.messages.return_value
messages_api.modify.return_value.execute.return_value = {}
connection.move_message("m1", "Archive")
messages_api.modify.assert_called_once()
connection.delete_message("m1")
messages_api.delete.assert_called_once_with(userId="me", id="m1")
def testGetCredsFromTokenFile(self):
creds = MagicMock()
creds.valid = True
with NamedTemporaryFile("w", delete=False) as token_file:
token_file.write("{}")
token_path = token_file.name
try:
with patch.object(
gmail_module.Credentials,
"from_authorized_user_file",
return_value=creds,
):
returned = _get_creds(
token_path, "credentials.json", ["scope"], 8080
)
finally:
os.remove(token_path)
self.assertEqual(returned, creds)
def testGetCredsWithOauthFlow(self):
expired_creds = MagicMock()
expired_creds.valid = False
expired_creds.expired = False
expired_creds.refresh_token = None
new_creds = MagicMock()
new_creds.valid = True
new_creds.to_json.return_value = '{"token":"x"}'
flow = MagicMock()
flow.run_local_server.return_value = new_creds
with NamedTemporaryFile("w", delete=False) as token_file:
token_file.write("{}")
token_path = token_file.name
try:
with patch.object(
gmail_module.Credentials,
"from_authorized_user_file",
return_value=expired_creds,
):
with patch.object(
gmail_module.InstalledAppFlow,
"from_client_secrets_file",
return_value=flow,
):
returned = _get_creds(
token_path, "credentials.json", ["scope"], 8080
)
finally:
os.remove(token_path)
self.assertEqual(returned, new_creds)
flow.run_local_server.assert_called_once()
def testGetCredsRefreshesExpiredToken(self):
expired_creds = MagicMock()
expired_creds.valid = False
expired_creds.expired = True
expired_creds.refresh_token = "rt"
expired_creds.to_json.return_value = '{"token":"refreshed"}'
with NamedTemporaryFile("w", delete=False) as token_file:
token_file.write("{}")
token_path = token_file.name
try:
with patch.object(
gmail_module.Credentials,
"from_authorized_user_file",
return_value=expired_creds,
):
returned = _get_creds(
token_path, "credentials.json", ["scope"], 8080
)
finally:
os.remove(token_path)
self.assertEqual(returned, expired_creds)
expired_creds.refresh.assert_called_once()
def testCreateFolderConflictIgnored(self):
connection = self._build_connection()
labels_api = connection.service.users.return_value.labels.return_value
conflict = HttpError(Response({"status": "409"}), b"conflict")
labels_api.create.return_value.execute.side_effect = conflict
connection.create_folder("Existing")
class TestGraphConnection(unittest.TestCase):
def testLoadTokenMissing(self):
with TemporaryDirectory() as temp_dir:
missing_path = Path(temp_dir) / "missing-token-file"
self.assertIsNone(_load_token(missing_path))
def testLoadTokenExisting(self):
with NamedTemporaryFile("w", delete=False) as token_file:
token_file.write("serialized-auth-record")
token_path = token_file.name
try:
self.assertEqual(_load_token(Path(token_path)), "serialized-auth-record")
finally:
os.remove(token_path)
def testGetAllMessagesPagination(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
first_response = _FakeGraphResponse(
200, {"value": [{"id": "1"}], "@odata.nextLink": "next-url"}
)
second_response = _FakeGraphResponse(200, {"value": [{"id": "2"}]})
connection._client = MagicMock()
connection._client.get.side_effect = [first_response, second_response]
messages = connection._get_all_messages("/url", batch_size=0, since=None)
self.assertEqual([msg["id"] for msg in messages], ["1", "2"])
def testGetAllMessagesInitialRequestFailure(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection._client = MagicMock()
connection._client.get.return_value = _FakeGraphResponse(500, text="boom")
with self.assertRaises(RuntimeError):
connection._get_all_messages("/url", batch_size=0, since=None)
def testGetAllMessagesNextPageFailure(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
first_response = _FakeGraphResponse(
200, {"value": [{"id": "1"}], "@odata.nextLink": "next-url"}
)
second_response = _FakeGraphResponse(500, text="page-fail")
connection._client = MagicMock()
connection._client.get.side_effect = [first_response, second_response]
with self.assertRaises(RuntimeError):
connection._get_all_messages("/url", batch_size=0, since=None)
def testGetAllMessagesHonorsBatchSizeLimit(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
first_response = _FakeGraphResponse(
200,
{
"value": [{"id": "1"}, {"id": "2"}],
"@odata.nextLink": "next-url",
},
)
connection._client = MagicMock()
connection._client.get.return_value = first_response
messages = connection._get_all_messages("/url", batch_size=2, since=None)
self.assertEqual([msg["id"] for msg in messages], ["1", "2"])
connection._client.get.assert_called_once()
def testFetchMessagesPassesSinceAndBatchSize(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection.mailbox_name = "mailbox@example.com"
connection._find_folder_id_from_folder_path = MagicMock(return_value="folder-id")
connection._get_all_messages = MagicMock(return_value=[{"id": "1"}])
self.assertEqual(
connection.fetch_messages("Inbox", since="2026-03-01", batch_size=5), ["1"]
)
connection._get_all_messages.assert_called_once_with(
"/users/mailbox@example.com/mailFolders/folder-id/messages",
5,
"2026-03-01",
)
def testFetchMessageMarksRead(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection.mailbox_name = "mailbox@example.com"
connection._client = MagicMock()
connection._client.get.return_value = _FakeGraphResponse(
200, text="email-content"
)
connection.mark_message_read = MagicMock()
content = connection.fetch_message("123", mark_read=True)
self.assertEqual(content, "email-content")
connection.mark_message_read.assert_called_once_with("123")
def testFindFolderIdNotFound(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection.mailbox_name = "mailbox@example.com"
connection._client = MagicMock()
connection._client.get.return_value = _FakeGraphResponse(200, {"value": []})
with self.assertRaises(RuntimeError):
connection._find_folder_id_with_parent("Missing", None)
def testGetCacheArgsWithAuthRecord(self):
with NamedTemporaryFile("w", delete=False) as token_file:
token_file.write("serialized")
token_path = Path(token_file.name)
try:
with patch.object(
graph_module.AuthenticationRecord,
"deserialize",
return_value="auth_record",
):
args = _get_cache_args(token_path, allow_unencrypted_storage=False)
self.assertIn("authentication_record", args)
finally:
os.remove(token_path)
def testGenerateCredentialInvalid(self):
with self.assertRaises(RuntimeError):
_generate_credential(
"Nope",
Path("/tmp/token"),
client_id="x",
client_secret="y",
username="u",
password="p",
tenant_id="t",
allow_unencrypted_storage=False,
)
def testGenerateCredentialDeviceCode(self):
fake_credential = object()
with patch.object(graph_module, "_get_cache_args", return_value={"cached": True}):
with patch.object(
graph_module,
"DeviceCodeCredential",
return_value=fake_credential,
) as mocked:
result = _generate_credential(
graph_module.AuthMethod.DeviceCode.name,
Path("/tmp/token"),
client_id="cid",
client_secret="secret",
username="user",
password="pass",
tenant_id="tenant",
allow_unencrypted_storage=True,
)
self.assertIs(result, fake_credential)
mocked.assert_called_once()
def testGenerateCredentialClientSecret(self):
fake_credential = object()
with patch.object(
graph_module, "ClientSecretCredential", return_value=fake_credential
) as mocked:
result = _generate_credential(
graph_module.AuthMethod.ClientSecret.name,
Path("/tmp/token"),
client_id="cid",
client_secret="secret",
username="user",
password="pass",
tenant_id="tenant",
allow_unencrypted_storage=False,
)
self.assertIs(result, fake_credential)
mocked.assert_called_once_with(
client_id="cid", tenant_id="tenant", client_secret="secret"
)
def testInitUsesSharedMailboxScopes(self):
class FakeCredential:
def __init__(self):
self.authenticate = MagicMock(return_value="auth-record")
fake_credential = FakeCredential()
with patch.object(
graph_module, "_generate_credential", return_value=fake_credential
):
with patch.object(graph_module, "_cache_auth_record") as cache_auth:
with patch.object(graph_module, "GraphClient") as graph_client:
MSGraphConnection(
auth_method=graph_module.AuthMethod.DeviceCode.name,
mailbox="shared@example.com",
graph_url="https://graph.microsoft.com",
client_id="cid",
client_secret="secret",
username="owner@example.com",
password="pass",
tenant_id="tenant",
token_file="/tmp/token-file",
allow_unencrypted_storage=True,
)
fake_credential.authenticate.assert_called_once_with(
scopes=["Mail.ReadWrite.Shared"]
)
cache_auth.assert_called_once()
graph_client.assert_called_once()
self.assertEqual(
graph_client.call_args.kwargs.get("scopes"), ["Mail.ReadWrite.Shared"]
)
def testCreateFolderAndMoveErrors(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection.mailbox_name = "mailbox@example.com"
connection._client = MagicMock()
connection._client.post.return_value = _FakeGraphResponse(500, {"error": "x"})
connection._find_folder_id_from_folder_path = MagicMock(return_value="dest")
with self.assertRaises(RuntimeWarning):
connection.move_message("m1", "Archive")
connection._client.post.return_value = _FakeGraphResponse(409, {})
connection.create_folder("Archive")
def testMarkReadDeleteFailures(self):
connection = MSGraphConnection.__new__(MSGraphConnection)
connection.mailbox_name = "mailbox@example.com"
connection._client = MagicMock()
connection._client.patch.return_value = _FakeGraphResponse(500, {"error": "x"})
with self.assertRaises(RuntimeWarning):
connection.mark_message_read("m1")
connection._client.delete.return_value = _FakeGraphResponse(500, {"error": "x"})
with self.assertRaises(RuntimeWarning):
connection.delete_message("m1")
class TestImapConnection(unittest.TestCase):
def testDelegatesToImapClient(self):
with patch.object(imap_module, "IMAPClient") as mocked_client_cls:
mocked_client = MagicMock()
mocked_client_cls.return_value = mocked_client
connection = IMAPConnection(
"imap.example.com", user="user", password="pass"
)
connection.create_folder("Archive")
mocked_client.create_folder.assert_called_once_with("Archive")
mocked_client.search.return_value = [1, 2]
self.assertEqual(connection.fetch_messages("INBOX"), [1, 2])
mocked_client.select_folder.assert_called_with("INBOX")
connection.fetch_messages("INBOX", since="2026-03-01")
mocked_client.search.assert_called_with("SINCE 2026-03-01")
mocked_client.fetch_message.return_value = "raw-message"
self.assertEqual(connection.fetch_message(1), "raw-message")
connection.delete_message(7)
mocked_client.delete_messages.assert_called_once_with([7])
connection.move_message(8, "Archive")
mocked_client.move_messages.assert_called_once_with([8], "Archive")
connection.keepalive()
mocked_client.noop.assert_called_once()
def testWatchReconnectPath(self):
with patch.object(imap_module, "IMAPClient") as mocked_client_cls:
base_client = MagicMock()
base_client.host = "imap.example.com"
base_client.port = 993
base_client.ssl = True
mocked_client_cls.return_value = base_client
connection = IMAPConnection(
"imap.example.com", user="user", password="pass"
)
calls = {"count": 0}
def fake_imap_constructor(*args, **kwargs):
idle_callback = kwargs.get("idle_callback")
if calls["count"] == 0:
calls["count"] += 1
raise IMAPClientError("timeout")
if idle_callback is not None:
idle_callback(base_client)
raise _BreakLoop()
callback = MagicMock()
with patch.object(imap_module, "sleep", return_value=None):
with patch.object(
imap_module, "IMAPClient", side_effect=fake_imap_constructor
):
with self.assertRaises(_BreakLoop):
connection.watch(callback, check_timeout=1)
callback.assert_called_once_with(connection)
if __name__ == "__main__":
unittest.main(verbosity=2)