diff --git a/parsedmarc/mail/graph.py b/parsedmarc/mail/graph.py index 6f3dbcc..05154f7 100644 --- a/parsedmarc/mail/graph.py +++ b/parsedmarc/mail/graph.py @@ -90,10 +90,15 @@ def _generate_credential(auth_method: str, token_path: Path, **kwargs): client_secret=kwargs["client_secret"], ) elif auth_method == AuthMethod.Certificate.name: + cert_path = kwargs.get("certificate_path") + if not cert_path: + raise ValueError( + "certificate_path is required when auth_method is 'Certificate'" + ) credential = CertificateCredential( client_id=kwargs["client_id"], tenant_id=kwargs["tenant_id"], - certificate_path=kwargs["certificate_path"], + certificate_path=cert_path, password=kwargs.get("certificate_password"), ) else: @@ -117,9 +122,9 @@ class MSGraphConnection(MailboxConnection): mailbox: str, graph_url: str, client_id: str, - client_secret: str, - username: str, - password: str, + client_secret: Optional[str], + username: Optional[str], + password: Optional[str], tenant_id: str, token_file: str, allow_unencrypted_storage: bool, @@ -146,7 +151,7 @@ class MSGraphConnection(MailboxConnection): if not isinstance(credential, (ClientSecretCredential, CertificateCredential)): scopes = ["Mail.ReadWrite"] # Detect if mailbox is shared - if mailbox and username != mailbox: + if mailbox and username and username != mailbox: scopes = ["Mail.ReadWrite.Shared"] auth_record = credential.authenticate(scopes=scopes) _cache_auth_record(auth_record, token_path) diff --git a/tests.py b/tests.py index 0b9d3cf..88826d9 100755 --- a/tests.py +++ b/tests.py @@ -840,6 +840,24 @@ class TestGraphConnection(unittest.TestCase): password="secret-pass", ) + def testGenerateCredentialCertificateRequiresPath(self): + with self.assertRaisesRegex( + ValueError, + "certificate_path is required when auth_method is 'Certificate'", + ): + _generate_credential( + graph_module.AuthMethod.Certificate.name, + Path("/tmp/token"), + client_id="cid", + client_secret=None, + certificate_path=None, + certificate_password="secret-pass", + username=None, + password=None, + tenant_id="tenant", + allow_unencrypted_storage=False, + ) + def testInitUsesSharedMailboxScopes(self): class FakeCredential: def __init__(self): @@ -872,6 +890,34 @@ class TestGraphConnection(unittest.TestCase): graph_client.call_args.kwargs.get("scopes"), ["Mail.ReadWrite.Shared"] ) + def testInitWithoutUsernameUsesDefaultMailReadWriteScope(self): + class FakeCredential: + def __init__(self): + self.authenticate = MagicMock(return_value="auth-record") + + fake_credential = FakeCredential() + with patch.object( + graph_module, "_generate_credential", return_value=fake_credential + ): + with patch.object(graph_module, "_cache_auth_record") as cache_auth: + with patch.object(graph_module, "GraphClient") as graph_client: + MSGraphConnection( + auth_method=graph_module.AuthMethod.DeviceCode.name, + mailbox="owner@example.com", + graph_url="https://graph.microsoft.com", + client_id="cid", + client_secret="secret", + username=None, + password=None, + tenant_id="tenant", + token_file="/tmp/token-file", + allow_unencrypted_storage=True, + ) + fake_credential.authenticate.assert_called_once_with(scopes=["Mail.ReadWrite"]) + cache_auth.assert_called_once() + graph_client.assert_called_once() + self.assertEqual(graph_client.call_args.kwargs.get("scopes"), ["Mail.ReadWrite"]) + def testInitCertificateAuthSkipsInteractiveAuthenticate(self): class DummyCertificateCredential: pass @@ -1311,6 +1357,39 @@ certificate_password = cert-pass mock_graph_connection.call_args.kwargs.get("certificate_password"), "cert-pass", ) + + @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") + @patch("parsedmarc.cli.MSGraphConnection") + @patch("parsedmarc.cli.logger") + def testCliRequiresMsGraphCertificatePath( + self, mock_logger, mock_graph_connection, mock_get_mailbox_reports + ): + config_text = """[general] +silent = true + +[msgraph] +auth_method = Certificate +client_id = client-id +tenant_id = tenant-id +mailbox = shared@example.com +""" + + with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg: + cfg.write(config_text) + cfg_path = cfg.name + self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path)) + + with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]): + with self.assertRaises(SystemExit) as system_exit: + parsedmarc.cli._main() + + self.assertEqual(system_exit.exception.code, -1) + mock_logger.critical.assert_called_once_with( + "certificate_path setting missing from the msgraph config section" + ) + mock_graph_connection.assert_not_called() + mock_get_mailbox_reports.assert_not_called() + class _FakeGraphClient: def get(self, url, params=None): if "/mailFolders/inbox?$select=id,displayName" in url: