diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..4c2e2b8 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,11 @@ +codecov: + require_ci_to_pass: true + +coverage: + status: + project: + default: + informational: true + patch: + default: + informational: false diff --git a/docs/source/usage.md b/docs/source/usage.md index b3c6bf7..426d98f 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -146,6 +146,9 @@ The full set of configuration options are: - `dns_timeout` - float: DNS timeout period - `debug` - bool: Print debugging messages - `silent` - bool: Only print errors (Default: `True`) + - `fail_on_output_error` - bool: Exit with a non-zero status code if + any configured output destination fails while saving/publishing + reports (Default: `False`) - `log_file` - str: Write log messages to a file at this path - `n_procs` - int: Number of process to run in parallel when parsing in CLI mode (Default: `1`) @@ -281,6 +284,10 @@ The full set of configuration options are: - `user` - str: Basic auth username - `password` - str: Basic auth password - `api_key` - str: API key + - `auth_type` - str: Authentication type: `basic` (default) or `awssigv4` (the key `authentication_type` is accepted as an alias for this option) + - `aws_region` - str: AWS region for SigV4 authentication + (required when `auth_type = awssigv4`) + - `aws_service` - str: AWS service for SigV4 signing (Default: `es`) - `ssl` - bool: Use an encrypted SSL/TLS connection (Default: `True`) - `timeout` - float: Timeout in seconds (Default: 60) @@ -511,6 +518,33 @@ PUT _cluster/settings Increasing this value increases resource usage. ::: +## Performance tuning + +For large mailbox imports or backfills, parsedmarc can consume a noticeable amount +of memory, especially when it runs on the same host as Elasticsearch or +OpenSearch. 
The following settings can reduce peak memory usage and make long +imports more predictable: + +- Reduce `mailbox.batch_size` to a smaller value, for example between `100` and `500`, instead of + processing a very large message set at once. Smaller batches trade throughput + for lower peak memory use and less pressure on the configured output destinations. +- Keep `n_procs` low for mailbox-heavy runs. In practice, `1` or `2` workers are often + a safer starting point for large backfills than aggressive parallelism. +- Use `mailbox.since` to process reports in smaller time windows such as `1d`, + `7d`, or another interval that fits the backlog. This makes it easier to catch + up incrementally instead of loading an entire mailbox history in one run. +- Set `strip_attachment_payloads = True` when forensic reports contain large + attachments and you do not need to retain the raw payloads in the parsed + output. +- Prefer running parsedmarc separately from Elasticsearch or OpenSearch, or + reserve enough RAM for both services if they must share a host. +- For very large imports, prefer incremental supervised runs, such as a + scheduler or systemd service, over infrequent massive backfills. + +These are operational tuning recommendations rather than hard requirements, but +they are often enough to avoid memory pressure and reduce failures during +high-volume mailbox processing. + ## Multi-tenant support Starting in `8.19.0`, ParseDMARC provides multi-tenant support by placing data into separate OpenSearch or Elasticsearch index prefixes. 
To set this up, create a YAML file that is formatted where each key is a tenant name, and the value is a list of domains related to that tenant, not including subdomains, like this: diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index 0bbb447..afbefca 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -2229,6 +2229,7 @@ def watch_inbox( dns_timeout: float = 6.0, strip_attachment_payloads: bool = False, batch_size: int = 10, + since: Optional[Union[datetime, date, str]] = None, normalize_timespan_threshold_hours: float = 24, ): """ @@ -2255,6 +2256,7 @@ def watch_inbox( strip_attachment_payloads (bool): Replace attachment payloads in failure report samples with None batch_size (int): Number of messages to read and process before saving + since: Search for messages since certain time normalize_timespan_threshold_hours (float): Normalize timespans beyond this """ @@ -2274,6 +2276,7 @@ def watch_inbox( dns_timeout=dns_timeout, strip_attachment_payloads=strip_attachment_payloads, batch_size=batch_size, + since=since, create_folders=False, normalize_timespan_threshold_hours=normalize_timespan_threshold_hours, ) diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index 420897f..77a8a3b 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -194,6 +194,13 @@ def _main(): return None def process_reports(reports_): + output_errors = [] + + def log_output_error(destination, error): + message = f"{destination} Error: {error}" + logger.error(message) + output_errors.append(message) + indent_value = 2 if opts.prettify_json else None output_str = "{0}\n".format( json.dumps(reports_, ensure_ascii=False, indent=indent_value) @@ -230,11 +237,9 @@ def _main(): except elastic.AlreadySaved as warning: logger.warning(warning.__str__()) except elastic.ElasticsearchError as error_: - logger.error("Elasticsearch Error: {0}".format(error_.__str__())) + log_output_error("Elasticsearch", error_.__str__()) except Exception as error_: - logger.error( - 
"Elasticsearch exception error: {}".format(error_.__str__()) - ) + log_output_error("Elasticsearch exception", error_.__str__()) try: if opts.opensearch_hosts: @@ -252,11 +257,9 @@ def _main(): except opensearch.AlreadySaved as warning: logger.warning(warning.__str__()) except opensearch.OpenSearchError as error_: - logger.error("OpenSearch Error: {0}".format(error_.__str__())) + log_output_error("OpenSearch", error_.__str__()) except Exception as error_: - logger.error( - "OpenSearch exception error: {}".format(error_.__str__()) - ) + log_output_error("OpenSearch exception", error_.__str__()) try: if opts.kafka_hosts: @@ -264,25 +267,25 @@ def _main(): report, kafka_aggregate_topic ) except Exception as error_: - logger.error("Kafka Error: {0}".format(error_.__str__())) + log_output_error("Kafka", error_.__str__()) try: if opts.s3_bucket: s3_client.save_aggregate_report_to_s3(report) except Exception as error_: - logger.error("S3 Error: {0}".format(error_.__str__())) + log_output_error("S3", error_.__str__()) try: if opts.syslog_server: syslog_client.save_aggregate_report_to_syslog(report) except Exception as error_: - logger.error("Syslog Error: {0}".format(error_.__str__())) + log_output_error("Syslog", error_.__str__()) try: if opts.gelf_host: gelf_client.save_aggregate_report_to_gelf(report) except Exception as error_: - logger.error("GELF Error: {0}".format(error_.__str__())) + log_output_error("GELF", error_.__str__()) try: if opts.webhook_aggregate_url: @@ -291,7 +294,7 @@ def _main(): json.dumps(report, ensure_ascii=False, indent=indent_value) ) except Exception as error_: - logger.error("Webhook Error: {0}".format(error_.__str__())) + log_output_error("Webhook", error_.__str__()) if opts.hec: try: @@ -299,7 +302,7 @@ def _main(): if len(aggregate_reports_) > 0: hec_client.save_aggregate_reports_to_splunk(aggregate_reports_) except splunk.SplunkError as e: - logger.error("Splunk HEC error: {0}".format(e.__str__())) + log_output_error("Splunk HEC", 
e.__str__()) if opts.save_failure: for report in reports_["failure_reports"]: @@ -319,9 +322,9 @@ def _main(): except elastic.AlreadySaved as warning: logger.warning(warning.__str__()) except elastic.ElasticsearchError as error_: - logger.error("Elasticsearch Error: {0}".format(error_.__str__())) + log_output_error("Elasticsearch", error_.__str__()) except InvalidDMARCReport as error_: - logger.error(error_.__str__()) + log_output_error("Invalid DMARC report", error_.__str__()) try: shards = opts.opensearch_number_of_shards @@ -339,9 +342,9 @@ def _main(): except opensearch.AlreadySaved as warning: logger.warning(warning.__str__()) except opensearch.OpenSearchError as error_: - logger.error("OpenSearch Error: {0}".format(error_.__str__())) + log_output_error("OpenSearch", error_.__str__()) except InvalidDMARCReport as error_: - logger.error(error_.__str__()) + log_output_error("Invalid DMARC report", error_.__str__()) try: if opts.kafka_hosts: @@ -349,25 +352,25 @@ def _main(): report, kafka_failure_topic ) except Exception as error_: - logger.error("Kafka Error: {0}".format(error_.__str__())) + log_output_error("Kafka", error_.__str__()) try: if opts.s3_bucket: s3_client.save_failure_report_to_s3(report) except Exception as error_: - logger.error("S3 Error: {0}".format(error_.__str__())) + log_output_error("S3", error_.__str__()) try: if opts.syslog_server: syslog_client.save_failure_report_to_syslog(report) except Exception as error_: - logger.error("Syslog Error: {0}".format(error_.__str__())) + log_output_error("Syslog", error_.__str__()) try: if opts.gelf_host: gelf_client.save_failure_report_to_gelf(report) except Exception as error_: - logger.error("GELF Error: {0}".format(error_.__str__())) + log_output_error("GELF", error_.__str__()) try: if opts.webhook_failure_url: @@ -376,7 +379,7 @@ def _main(): json.dumps(report, ensure_ascii=False, indent=indent_value) ) except Exception as error_: - logger.error("Webhook Error: {0}".format(error_.__str__())) + 
log_output_error("Webhook", error_.__str__()) if opts.hec: try: @@ -384,7 +387,7 @@ def _main(): if len(failure_reports_) > 0: hec_client.save_failure_reports_to_splunk(failure_reports_) except splunk.SplunkError as e: - logger.error("Splunk HEC error: {0}".format(e.__str__())) + log_output_error("Splunk HEC", e.__str__()) if opts.save_smtp_tls: for report in reports_["smtp_tls_reports"]: @@ -404,9 +407,9 @@ def _main(): except elastic.AlreadySaved as warning: logger.warning(warning.__str__()) except elastic.ElasticsearchError as error_: - logger.error("Elasticsearch Error: {0}".format(error_.__str__())) + log_output_error("Elasticsearch", error_.__str__()) except InvalidDMARCReport as error_: - logger.error(error_.__str__()) + log_output_error("Invalid DMARC report", error_.__str__()) try: shards = opts.opensearch_number_of_shards @@ -424,9 +427,9 @@ def _main(): except opensearch.AlreadySaved as warning: logger.warning(warning.__str__()) except opensearch.OpenSearchError as error_: - logger.error("OpenSearch Error: {0}".format(error_.__str__())) + log_output_error("OpenSearch", error_.__str__()) except InvalidDMARCReport as error_: - logger.error(error_.__str__()) + log_output_error("Invalid DMARC report", error_.__str__()) try: if opts.kafka_hosts: @@ -434,25 +437,25 @@ def _main(): smtp_tls_reports, kafka_smtp_tls_topic ) except Exception as error_: - logger.error("Kafka Error: {0}".format(error_.__str__())) + log_output_error("Kafka", error_.__str__()) try: if opts.s3_bucket: s3_client.save_smtp_tls_report_to_s3(report) except Exception as error_: - logger.error("S3 Error: {0}".format(error_.__str__())) + log_output_error("S3", error_.__str__()) try: if opts.syslog_server: syslog_client.save_smtp_tls_report_to_syslog(report) except Exception as error_: - logger.error("Syslog Error: {0}".format(error_.__str__())) + log_output_error("Syslog", error_.__str__()) try: if opts.gelf_host: gelf_client.save_smtp_tls_report_to_gelf(report) except Exception as error_: - 
logger.error("GELF Error: {0}".format(error_.__str__())) + log_output_error("GELF", error_.__str__()) try: if opts.webhook_smtp_tls_url: @@ -461,7 +464,7 @@ def _main(): json.dumps(report, ensure_ascii=False, indent=indent_value) ) except Exception as error_: - logger.error("Webhook Error: {0}".format(error_.__str__())) + log_output_error("Webhook", error_.__str__()) if opts.hec: try: @@ -469,7 +472,7 @@ def _main(): if len(smtp_tls_reports_) > 0: hec_client.save_smtp_tls_reports_to_splunk(smtp_tls_reports_) except splunk.SplunkError as e: - logger.error("Splunk HEC error: {0}".format(e.__str__())) + log_output_error("Splunk HEC", e.__str__()) if opts.la_dce: try: @@ -490,14 +493,16 @@ def _main(): opts.save_smtp_tls, ) except loganalytics.LogAnalyticsException as e: - logger.error("Log Analytics error: {0}".format(e.__str__())) + log_output_error("Log Analytics", e.__str__()) except Exception as e: - logger.error( - "Unknown error occurred" - + " during the publishing" - + " to Log Analytics: " - + e.__str__() + log_output_error("Log Analytics", f"Unknown publishing error: {e}") + + if opts.fail_on_output_error and output_errors: + raise ParserError( + "Output destination failures detected: {0}".format( + " | ".join(output_errors) ) + ) arg_parser = ArgumentParser(description="Parses DMARC reports") arg_parser.add_argument( @@ -671,6 +676,9 @@ def _main(): opensearch_username=None, opensearch_password=None, opensearch_api_key=None, + opensearch_auth_type="basic", + opensearch_aws_region=None, + opensearch_aws_service="es", kafka_hosts=None, kafka_username=None, kafka_password=None, @@ -736,6 +744,7 @@ def _main(): webhook_smtp_tls_url=None, webhook_timeout=60, normalize_timespan_threshold_hours=24.0, + fail_on_output_error=False, ) args = arg_parser.parse_args() @@ -824,6 +833,10 @@ def _main(): opts.silent = bool(general_config.getboolean("silent")) if "warnings" in general_config: opts.warnings = bool(general_config.getboolean("warnings")) + if 
"fail_on_output_error" in general_config: + opts.fail_on_output_error = bool( + general_config.getboolean("fail_on_output_error") + ) if "log_file" in general_config: opts.log_file = general_config["log_file"] if "n_procs" in general_config: @@ -1110,6 +1123,16 @@ def _main(): # Since 8.20 if "api_key" in opensearch_config: opts.opensearch_api_key = opensearch_config["api_key"] + if "auth_type" in opensearch_config: + opts.opensearch_auth_type = opensearch_config["auth_type"].strip().lower() + elif "authentication_type" in opensearch_config: + opts.opensearch_auth_type = ( + opensearch_config["authentication_type"].strip().lower() + ) + if "aws_region" in opensearch_config: + opts.opensearch_aws_region = opensearch_config["aws_region"].strip() + if "aws_service" in opensearch_config: + opts.opensearch_aws_service = opensearch_config["aws_service"].strip() if "splunk_hec" in config.sections(): hec_config = config["splunk_hec"] @@ -1462,6 +1485,9 @@ def _main(): password=opts.opensearch_password, api_key=opts.opensearch_api_key, timeout=opensearch_timeout_value, + auth_type=opts.opensearch_auth_type, + aws_region=opts.opensearch_aws_region, + aws_service=opts.opensearch_aws_service, ) opensearch.migrate_indexes( aggregate_indexes=[os_aggregate_index], @@ -1830,7 +1856,11 @@ def _main(): "smtp_tls_reports": smtp_tls_reports, } - process_reports(parsing_results) + try: + process_reports(parsing_results) + except ParserError as error: + logger.error(error.__str__()) + exit(1) if opts.smtp_host: try: @@ -1875,6 +1905,7 @@ def _main(): dns_timeout=opts.dns_timeout, strip_attachment_payloads=opts.strip_attachment_payloads, batch_size=mailbox_batch_size_value, + since=opts.mailbox_since, ip_db_path=opts.ip_db_path, always_use_local_files=opts.always_use_local_files, reverse_dns_map_path=opts.reverse_dns_map_path, @@ -1885,6 +1916,9 @@ def _main(): except FileExistsError as error: logger.error("{0}".format(error.__str__())) exit(1) + except ParserError as error: + 
logger.error(error.__str__()) + exit(1) if __name__ == "__main__": diff --git a/parsedmarc/mail/imap.py b/parsedmarc/mail/imap.py index 94279d3..3252807 100644 --- a/parsedmarc/mail/imap.py +++ b/parsedmarc/mail/imap.py @@ -55,10 +55,28 @@ class IMAPConnection(MailboxConnection): return cast(str, self._client.fetch_message(message_id, parse=False)) def delete_message(self, message_id: int): - self._client.delete_messages([message_id]) + try: + self._client.delete_messages([message_id]) + except IMAPClientError as error: + logger.warning( + "IMAP delete fallback for message %s due to server error: %s", + message_id, + error, + ) + self._client.add_flags([message_id], [r"\Deleted"], silent=True) + self._client.expunge() def move_message(self, message_id: int, folder_name: str): - self._client.move_messages([message_id], folder_name) + try: + self._client.move_messages([message_id], folder_name) + except IMAPClientError as error: + logger.warning( + "IMAP move fallback for message %s due to server error: %s", + message_id, + error, + ) + self._client.copy([message_id], folder_name) + self.delete_message(message_id) def keepalive(self): self._client.noop() diff --git a/parsedmarc/opensearch.py b/parsedmarc/opensearch.py index 24a38a6..5da977e 100644 --- a/parsedmarc/opensearch.py +++ b/parsedmarc/opensearch.py @@ -4,7 +4,9 @@ from __future__ import annotations from typing import Any, Optional, Union +import boto3 from opensearchpy import ( + AWSV4SignerAuth, Boolean, Date, Document, @@ -16,6 +18,7 @@ from opensearchpy import ( Nested, Object, Q, + RequestsHttpConnection, Search, Text, connections, @@ -306,6 +309,9 @@ def set_hosts( password: Optional[str] = None, api_key: Optional[str] = None, timeout: Optional[float] = 60.0, + auth_type: str = "basic", + aws_region: Optional[str] = None, + aws_service: str = "es", ): """ Sets the OpenSearch hosts to use @@ -318,6 +324,9 @@ def set_hosts( password (str): The password to use for authentication api_key (str): The Base64 
encoded API key to use for authentication timeout (float): Timeout in seconds + auth_type (str): OpenSearch auth mode: basic (default) or awssigv4 + aws_region (str): AWS region for SigV4 auth (required for awssigv4) + aws_service (str): AWS service for SigV4 signing (default: es) """ if not isinstance(hosts, list): hosts = [hosts] @@ -329,10 +338,32 @@ def set_hosts( conn_params["ca_certs"] = ssl_cert_path else: conn_params["verify_certs"] = False - if username and password: - conn_params["http_auth"] = username + ":" + password - if api_key: - conn_params["api_key"] = api_key + normalized_auth_type = (auth_type or "basic").strip().lower() + if normalized_auth_type == "awssigv4": + if not aws_region: + raise OpenSearchError( + "OpenSearch AWS SigV4 auth requires 'aws_region' to be set" + ) + session = boto3.Session() + credentials = session.get_credentials() + if credentials is None: + raise OpenSearchError( + "Unable to load AWS credentials for OpenSearch SigV4 authentication" + ) + conn_params["http_auth"] = AWSV4SignerAuth( + credentials, aws_region, aws_service + ) + conn_params["connection_class"] = RequestsHttpConnection + elif normalized_auth_type == "basic": + if username and password: + conn_params["http_auth"] = username + ":" + password + if api_key: + conn_params["api_key"] = api_key + else: + raise OpenSearchError( + f"Unsupported OpenSearch auth_type '{auth_type}'. " + "Expected 'basic' or 'awssigv4'." 
+ ) connections.create_connection(**conn_params) diff --git a/tests.py b/tests.py index 7d1bc2a..94303d1 100755 --- a/tests.py +++ b/tests.py @@ -14,6 +14,7 @@ from glob import glob from base64 import urlsafe_b64encode from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory +from types import SimpleNamespace from unittest.mock import MagicMock, patch from lxml import etree @@ -33,6 +34,7 @@ from parsedmarc.mail.imap import IMAPConnection import parsedmarc.mail.gmail as gmail_module import parsedmarc.mail.graph as graph_module import parsedmarc.mail.imap as imap_module +import parsedmarc.opensearch as opensearch_module import parsedmarc.utils # Detect if running in GitHub Actions to skip DNS lookups @@ -1793,7 +1795,257 @@ class Test(unittest.TestCase): self.assertTrue(len(rows) > 0) print("Passed!") + def testOpenSearchSigV4RequiresRegion(self): + with self.assertRaises(opensearch_module.OpenSearchError): + opensearch_module.set_hosts( + "https://example.org:9200", + auth_type="awssigv4", + ) + def testOpenSearchSigV4ConfiguresConnectionClass(self): + fake_credentials = object() + with patch.object(opensearch_module.boto3, "Session") as session_cls: + session_cls.return_value.get_credentials.return_value = fake_credentials + with patch.object( + opensearch_module, "AWSV4SignerAuth", return_value="auth" + ) as signer: + with patch.object( + opensearch_module.connections, "create_connection" + ) as create_connection: + opensearch_module.set_hosts( + "https://example.org:9200", + use_ssl=True, + auth_type="awssigv4", + aws_region="eu-west-1", + ) + signer.assert_called_once_with(fake_credentials, "eu-west-1", "es") + create_connection.assert_called_once() + self.assertEqual( + create_connection.call_args.kwargs.get("connection_class"), + opensearch_module.RequestsHttpConnection, + ) + self.assertEqual(create_connection.call_args.kwargs.get("http_auth"), "auth") + + def testOpenSearchSigV4RejectsUnknownAuthType(self): + with 
self.assertRaises(opensearch_module.OpenSearchError): + opensearch_module.set_hosts( + "https://example.org:9200", + auth_type="kerberos", + ) + + def testOpenSearchSigV4RequiresAwsCredentials(self): + with patch.object(opensearch_module.boto3, "Session") as session_cls: + session_cls.return_value.get_credentials.return_value = None + with self.assertRaises(opensearch_module.OpenSearchError): + opensearch_module.set_hosts( + "https://example.org:9200", + auth_type="awssigv4", + aws_region="eu-west-1", + ) + + @patch("parsedmarc.cli.opensearch.migrate_indexes") + @patch("parsedmarc.cli.opensearch.set_hosts") + @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") + @patch("parsedmarc.cli.IMAPConnection") + def testCliPassesOpenSearchSigV4Settings( + self, + mock_imap_connection, + mock_get_reports, + mock_set_hosts, + _mock_migrate_indexes, + ): + mock_imap_connection.return_value = object() + mock_get_reports.return_value = { + "aggregate_reports": [], + "forensic_reports": [], + "smtp_tls_reports": [], + } + + config = """[general] +save_aggregate = true +silent = true + +[imap] +host = imap.example.com +user = test-user +password = test-password + +[opensearch] +hosts = localhost +authentication_type = awssigv4 +aws_region = eu-west-1 +aws_service = aoss +""" + with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as config_file: + config_file.write(config) + config_path = config_file.name + self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) + + with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): + parsedmarc.cli._main() + + self.assertEqual(mock_set_hosts.call_args.kwargs.get("auth_type"), "awssigv4") + self.assertEqual(mock_set_hosts.call_args.kwargs.get("aws_region"), "eu-west-1") + self.assertEqual(mock_set_hosts.call_args.kwargs.get("aws_service"), "aoss") + + @patch("parsedmarc.cli.elastic.save_aggregate_report_to_elasticsearch") + @patch("parsedmarc.cli.elastic.migrate_indexes") + 
@patch("parsedmarc.cli.elastic.set_hosts") + @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") + @patch("parsedmarc.cli.IMAPConnection") + def testFailOnOutputErrorExits( + self, + mock_imap_connection, + mock_get_reports, + _mock_set_hosts, + _mock_migrate_indexes, + mock_save_aggregate, + ): + """CLI should exit with code 1 when fail_on_output_error is enabled""" + mock_imap_connection.return_value = object() + mock_get_reports.return_value = { + "aggregate_reports": [{"policy_published": {"domain": "example.com"}}], + "forensic_reports": [], + "smtp_tls_reports": [], + } + mock_save_aggregate.side_effect = parsedmarc.elastic.ElasticsearchError( + "simulated output failure" + ) + + config = """[general] +save_aggregate = true +fail_on_output_error = true +silent = true + +[imap] +host = imap.example.com +user = test-user +password = test-password + +[elasticsearch] +hosts = localhost +""" + with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as config_file: + config_file.write(config) + config_path = config_file.name + self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) + + with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): + with self.assertRaises(SystemExit) as ctx: + parsedmarc.cli._main() + + self.assertEqual(ctx.exception.code, 1) + mock_save_aggregate.assert_called_once() + + @patch("parsedmarc.cli.elastic.save_aggregate_report_to_elasticsearch") + @patch("parsedmarc.cli.elastic.migrate_indexes") + @patch("parsedmarc.cli.elastic.set_hosts") + @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") + @patch("parsedmarc.cli.IMAPConnection") + def testOutputErrorDoesNotExitWhenDisabled( + self, + mock_imap_connection, + mock_get_reports, + _mock_set_hosts, + _mock_migrate_indexes, + mock_save_aggregate, + ): + mock_imap_connection.return_value = object() + mock_get_reports.return_value = { + "aggregate_reports": [{"policy_published": {"domain": "example.com"}}], + "forensic_reports": [], + 
"smtp_tls_reports": [], + } + mock_save_aggregate.side_effect = parsedmarc.elastic.ElasticsearchError( + "simulated output failure" + ) + + config = """[general] +save_aggregate = true +fail_on_output_error = false +silent = true + +[imap] +host = imap.example.com +user = test-user +password = test-password + +[elasticsearch] +hosts = localhost +""" + with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as config_file: + config_file.write(config) + config_path = config_file.name + self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) + + with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): + parsedmarc.cli._main() + + mock_save_aggregate.assert_called_once() + + @patch("parsedmarc.cli.opensearch.save_forensic_report_to_opensearch") + @patch("parsedmarc.cli.opensearch.migrate_indexes") + @patch("parsedmarc.cli.opensearch.set_hosts") + @patch("parsedmarc.cli.elastic.save_forensic_report_to_elasticsearch") + @patch("parsedmarc.cli.elastic.save_aggregate_report_to_elasticsearch") + @patch("parsedmarc.cli.elastic.migrate_indexes") + @patch("parsedmarc.cli.elastic.set_hosts") + @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") + @patch("parsedmarc.cli.IMAPConnection") + def testFailOnOutputErrorExitsWithMultipleSinkErrors( + self, + mock_imap_connection, + mock_get_reports, + _mock_es_set_hosts, + _mock_es_migrate, + mock_save_aggregate, + _mock_save_forensic_elastic, + _mock_os_set_hosts, + _mock_os_migrate, + mock_save_forensic_opensearch, + ): + mock_imap_connection.return_value = object() + mock_get_reports.return_value = { + "aggregate_reports": [{"policy_published": {"domain": "example.com"}}], + "forensic_reports": [{"reported_domain": "example.com"}], + "smtp_tls_reports": [], + } + mock_save_aggregate.side_effect = parsedmarc.elastic.ElasticsearchError( + "aggregate sink failed" + ) + mock_save_forensic_opensearch.side_effect = parsedmarc.cli.opensearch.OpenSearchError( + "forensic sink failed" + ) + + 
config = """[general] +save_aggregate = true +save_forensic = true +fail_on_output_error = true +silent = true + +[imap] +host = imap.example.com +user = test-user +password = test-password + +[elasticsearch] +hosts = localhost + +[opensearch] +hosts = localhost +""" + with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as config_file: + config_file.write(config) + config_path = config_file.name + self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path)) + + with patch.object(sys, "argv", ["parsedmarc", "-c", config_path]): + with self.assertRaises(SystemExit) as ctx: + parsedmarc.cli._main() + + self.assertEqual(ctx.exception.code, 1) + mock_save_aggregate.assert_called_once() + mock_save_forensic_opensearch.assert_called_once() class _FakeGraphResponse: def __init__(self, status_code, payload=None, text=""): self.status_code = status_code @@ -1803,7 +2055,6 @@ class _FakeGraphResponse: def json(self): return self._payload - class _BreakLoop(BaseException): pass @@ -2402,5 +2653,121 @@ scopes = https://www.googleapis.com/auth/gmail.modify ) +class TestImapFallbacks(unittest.TestCase): + def testDeleteSuccessDoesNotUseFallback(self): + connection = IMAPConnection.__new__(IMAPConnection) + connection._client = MagicMock() + connection.delete_message(42) + connection._client.delete_messages.assert_called_once_with([42]) + connection._client.add_flags.assert_not_called() + connection._client.expunge.assert_not_called() + + def testDeleteFallbackUsesFlagsAndExpunge(self): + connection = IMAPConnection.__new__(IMAPConnection) + connection._client = MagicMock() + connection._client.delete_messages.side_effect = IMAPClientError("uid expunge") + connection.delete_message(42) + connection._client.add_flags.assert_called_once_with( + [42], [r"\Deleted"], silent=True + ) + connection._client.expunge.assert_called_once_with() + + def testDeleteFallbackErrorPropagates(self): + connection = IMAPConnection.__new__(IMAPConnection) + 
connection._client = MagicMock() + connection._client.delete_messages.side_effect = IMAPClientError("uid expunge") + connection._client.add_flags.side_effect = IMAPClientError("flag failed") + with self.assertRaises(IMAPClientError): + connection.delete_message(42) + + def testMoveSuccessDoesNotUseFallback(self): + connection = IMAPConnection.__new__(IMAPConnection) + connection._client = MagicMock() + with patch.object(connection, "delete_message") as delete_mock: + connection.move_message(99, "Archive") + connection._client.move_messages.assert_called_once_with([99], "Archive") + connection._client.copy.assert_not_called() + delete_mock.assert_not_called() + + def testMoveFallbackCopiesThenDeletes(self): + connection = IMAPConnection.__new__(IMAPConnection) + connection._client = MagicMock() + connection._client.move_messages.side_effect = IMAPClientError("move failed") + with patch.object(connection, "delete_message") as delete_mock: + connection.move_message(99, "Archive") + connection._client.copy.assert_called_once_with([99], "Archive") + delete_mock.assert_called_once_with(99) + + def testMoveFallbackCopyErrorPropagates(self): + connection = IMAPConnection.__new__(IMAPConnection) + connection._client = MagicMock() + connection._client.move_messages.side_effect = IMAPClientError("move failed") + connection._client.copy.side_effect = IMAPClientError("copy failed") + with patch.object(connection, "delete_message") as delete_mock: + with self.assertRaises(IMAPClientError): + connection.move_message(99, "Archive") + delete_mock.assert_not_called() + +class TestMailboxWatchSince(unittest.TestCase): + def testWatchInboxPassesSinceToMailboxFetch(self): + mailbox_connection = SimpleNamespace() + + def fake_watch(check_callback, check_timeout): + check_callback(mailbox_connection) + raise _BreakLoop() + + mailbox_connection.watch = fake_watch + callback = MagicMock() + with patch.object( + parsedmarc, "get_dmarc_reports_from_mailbox", return_value={} + ) as mocked: + 
with self.assertRaises(_BreakLoop): + parsedmarc.watch_inbox( + mailbox_connection=mailbox_connection, + callback=callback, + check_timeout=1, + batch_size=10, + since="1d", + ) + self.assertEqual(mocked.call_args.kwargs.get("since"), "1d") + + @patch("parsedmarc.cli.get_dmarc_reports_from_mailbox") + @patch("parsedmarc.cli.watch_inbox") + @patch("parsedmarc.cli.IMAPConnection") + def testCliPassesSinceToWatchInbox( + self, mock_imap_connection, mock_watch_inbox, mock_get_mailbox_reports + ): + mock_imap_connection.return_value = object() + mock_get_mailbox_reports.return_value = { + "aggregate_reports": [], + "forensic_reports": [], + "smtp_tls_reports": [], + } + mock_watch_inbox.side_effect = FileExistsError("stop-watch-loop") + + config_text = """[general] +silent = true + +[imap] +host = imap.example.com +user = user +password = pass + +[mailbox] +watch = true +since = 2d +""" + + with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg: + cfg.write(config_text) + cfg_path = cfg.name + self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path)) + + with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]): + with self.assertRaises(SystemExit) as system_exit: + parsedmarc.cli._main() + + self.assertEqual(system_exit.exception.code, 1) + self.assertEqual(mock_watch_inbox.call_args.kwargs.get("since"), "2d") if __name__ == "__main__": unittest.main(verbosity=2)