mirror of
https://github.com/domainaware/parsedmarc.git
synced 2026-03-20 21:45:59 +00:00
Compare commits
7 Commits
copilot/su
...
copilot/su
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32e85399bf | ||
|
|
c3876bb9d9 | ||
|
|
a05eb0807a | ||
|
|
6fceb3e2ce | ||
|
|
510d5d05a9 | ||
|
|
3445438684 | ||
|
|
17defb75b0 |
@@ -2249,11 +2249,14 @@ def watch_inbox(
|
||||
)
|
||||
callback(res)
|
||||
|
||||
mailbox_connection.watch(
|
||||
check_callback=check_callback,
|
||||
check_timeout=check_timeout,
|
||||
should_reload=should_reload,
|
||||
)
|
||||
watch_kwargs: dict = {
|
||||
"check_callback": check_callback,
|
||||
"check_timeout": check_timeout,
|
||||
}
|
||||
if should_reload is not None:
|
||||
watch_kwargs["should_reload"] = should_reload
|
||||
|
||||
mailbox_connection.watch(**watch_kwargs)
|
||||
|
||||
|
||||
def append_json(
|
||||
|
||||
@@ -201,8 +201,21 @@ def _parse_config_file(config_file, opts):
|
||||
"normalize_timespan_threshold_hours"
|
||||
)
|
||||
if "index_prefix_domain_map" in general_config:
|
||||
with open(general_config["index_prefix_domain_map"]) as f:
|
||||
index_prefix_domain_map = yaml.safe_load(f)
|
||||
map_path = general_config["index_prefix_domain_map"]
|
||||
try:
|
||||
with open(map_path) as f:
|
||||
index_prefix_domain_map = yaml.safe_load(f)
|
||||
except OSError as exc:
|
||||
raise ConfigurationError(
|
||||
"Failed to read index_prefix_domain_map file '{0}': {1}".format(
|
||||
map_path, exc
|
||||
)
|
||||
) from exc
|
||||
except yaml.YAMLError as exc:
|
||||
raise ConfigurationError(
|
||||
"Failed to parse YAML in index_prefix_domain_map "
|
||||
"file '{0}': {1}".format(map_path, exc)
|
||||
) from exc
|
||||
if "offline" in general_config:
|
||||
opts.offline = bool(general_config.getboolean("offline"))
|
||||
if "strip_attachment_payloads" in general_config:
|
||||
@@ -242,7 +255,7 @@ def _parse_config_file(config_file, opts):
|
||||
except Exception as ns_error:
|
||||
raise ConfigurationError(
|
||||
"DNS pre-flight check failed: {}".format(ns_error)
|
||||
)
|
||||
) from ns_error
|
||||
if not dummy_hostname:
|
||||
raise ConfigurationError(
|
||||
"DNS pre-flight check failed: no PTR record for {} from {}".format(
|
||||
@@ -259,8 +272,6 @@ def _parse_config_file(config_file, opts):
|
||||
opts.debug = bool(general_config.getboolean("debug"))
|
||||
if "verbose" in general_config:
|
||||
opts.verbose = bool(general_config.getboolean("verbose"))
|
||||
if "silent" in general_config:
|
||||
opts.silent = bool(general_config.getboolean("silent"))
|
||||
if "warnings" in general_config:
|
||||
opts.warnings = bool(general_config.getboolean("warnings"))
|
||||
if "fail_on_output_error" in general_config:
|
||||
@@ -588,9 +599,9 @@ def _parse_config_file(config_file, opts):
|
||||
"index setting missing from the splunk_hec config section"
|
||||
)
|
||||
if "skip_certificate_verification" in hec_config:
|
||||
opts.hec_skip_certificate_verification = hec_config[
|
||||
"skip_certificate_verification"
|
||||
]
|
||||
opts.hec_skip_certificate_verification = bool(
|
||||
hec_config.getboolean("skip_certificate_verification", fallback=False)
|
||||
)
|
||||
|
||||
if "kafka" in config.sections():
|
||||
kafka_config = config["kafka"]
|
||||
@@ -620,14 +631,14 @@ def _parse_config_file(config_file, opts):
|
||||
if "forensic_topic" in kafka_config:
|
||||
opts.kafka_forensic_topic = kafka_config["forensic_topic"]
|
||||
else:
|
||||
logger.critical(
|
||||
raise ConfigurationError(
|
||||
"forensic_topic setting missing from the kafka config section"
|
||||
)
|
||||
if "smtp_tls_topic" in kafka_config:
|
||||
opts.kafka_smtp_tls_topic = kafka_config["smtp_tls_topic"]
|
||||
else:
|
||||
logger.critical(
|
||||
"forensic_topic setting missing from the splunk_hec config section"
|
||||
raise ConfigurationError(
|
||||
"smtp_tls_topic setting missing from the kafka config section"
|
||||
)
|
||||
|
||||
if "smtp" in config.sections():
|
||||
@@ -894,31 +905,39 @@ def _init_output_clients(opts):
|
||||
)
|
||||
|
||||
if opts.s3_bucket:
|
||||
clients["s3_client"] = s3.S3Client(
|
||||
bucket_name=opts.s3_bucket,
|
||||
bucket_path=opts.s3_path,
|
||||
region_name=opts.s3_region_name,
|
||||
endpoint_url=opts.s3_endpoint_url,
|
||||
access_key_id=opts.s3_access_key_id,
|
||||
secret_access_key=opts.s3_secret_access_key,
|
||||
)
|
||||
try:
|
||||
clients["s3_client"] = s3.S3Client(
|
||||
bucket_name=opts.s3_bucket,
|
||||
bucket_path=opts.s3_path,
|
||||
region_name=opts.s3_region_name,
|
||||
endpoint_url=opts.s3_endpoint_url,
|
||||
access_key_id=opts.s3_access_key_id,
|
||||
secret_access_key=opts.s3_secret_access_key,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to initialize S3 client; skipping", exc_info=True)
|
||||
|
||||
if opts.syslog_server:
|
||||
clients["syslog_client"] = syslog.SyslogClient(
|
||||
server_name=opts.syslog_server,
|
||||
server_port=int(opts.syslog_port),
|
||||
protocol=opts.syslog_protocol or "udp",
|
||||
cafile_path=opts.syslog_cafile_path,
|
||||
certfile_path=opts.syslog_certfile_path,
|
||||
keyfile_path=opts.syslog_keyfile_path,
|
||||
timeout=opts.syslog_timeout if opts.syslog_timeout is not None else 5.0,
|
||||
retry_attempts=opts.syslog_retry_attempts
|
||||
if opts.syslog_retry_attempts is not None
|
||||
else 3,
|
||||
retry_delay=opts.syslog_retry_delay
|
||||
if opts.syslog_retry_delay is not None
|
||||
else 5,
|
||||
)
|
||||
try:
|
||||
clients["syslog_client"] = syslog.SyslogClient(
|
||||
server_name=opts.syslog_server,
|
||||
server_port=int(opts.syslog_port),
|
||||
protocol=opts.syslog_protocol or "udp",
|
||||
cafile_path=opts.syslog_cafile_path,
|
||||
certfile_path=opts.syslog_certfile_path,
|
||||
keyfile_path=opts.syslog_keyfile_path,
|
||||
timeout=opts.syslog_timeout if opts.syslog_timeout is not None else 5.0,
|
||||
retry_attempts=opts.syslog_retry_attempts
|
||||
if opts.syslog_retry_attempts is not None
|
||||
else 3,
|
||||
retry_delay=opts.syslog_retry_delay
|
||||
if opts.syslog_retry_delay is not None
|
||||
else 5,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize syslog client; skipping", exc_info=True
|
||||
)
|
||||
|
||||
if opts.hec:
|
||||
if opts.hec_token is None or opts.hec_index is None:
|
||||
@@ -928,46 +947,81 @@ def _init_output_clients(opts):
|
||||
verify = True
|
||||
if opts.hec_skip_certificate_verification:
|
||||
verify = False
|
||||
clients["hec_client"] = splunk.HECClient(
|
||||
opts.hec, opts.hec_token, opts.hec_index, verify=verify
|
||||
)
|
||||
try:
|
||||
clients["hec_client"] = splunk.HECClient(
|
||||
opts.hec, opts.hec_token, opts.hec_index, verify=verify
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize Splunk HEC client; skipping", exc_info=True
|
||||
)
|
||||
|
||||
if opts.kafka_hosts:
|
||||
ssl_context = None
|
||||
if opts.kafka_skip_certificate_verification:
|
||||
logger.debug("Skipping Kafka certificate verification")
|
||||
if opts.kafka_ssl:
|
||||
ssl_context = create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = CERT_NONE
|
||||
clients["kafka_client"] = kafkaclient.KafkaClient(
|
||||
opts.kafka_hosts,
|
||||
username=opts.kafka_username,
|
||||
password=opts.kafka_password,
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
if opts.kafka_skip_certificate_verification:
|
||||
logger.debug("Skipping Kafka certificate verification")
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = CERT_NONE
|
||||
try:
|
||||
clients["kafka_client"] = kafkaclient.KafkaClient(
|
||||
opts.kafka_hosts,
|
||||
username=opts.kafka_username,
|
||||
password=opts.kafka_password,
|
||||
ssl=opts.kafka_ssl,
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to initialize Kafka client; skipping", exc_info=True)
|
||||
|
||||
if opts.gelf_host:
|
||||
clients["gelf_client"] = gelf.GelfClient(
|
||||
host=opts.gelf_host,
|
||||
port=int(opts.gelf_port),
|
||||
mode=opts.gelf_mode,
|
||||
)
|
||||
try:
|
||||
clients["gelf_client"] = gelf.GelfClient(
|
||||
host=opts.gelf_host,
|
||||
port=int(opts.gelf_port),
|
||||
mode=opts.gelf_mode,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to initialize GELF client; skipping", exc_info=True)
|
||||
|
||||
if (
|
||||
opts.webhook_aggregate_url
|
||||
or opts.webhook_forensic_url
|
||||
or opts.webhook_smtp_tls_url
|
||||
):
|
||||
clients["webhook_client"] = webhook.WebhookClient(
|
||||
aggregate_url=opts.webhook_aggregate_url,
|
||||
forensic_url=opts.webhook_forensic_url,
|
||||
smtp_tls_url=opts.webhook_smtp_tls_url,
|
||||
timeout=opts.webhook_timeout,
|
||||
)
|
||||
try:
|
||||
clients["webhook_client"] = webhook.WebhookClient(
|
||||
aggregate_url=opts.webhook_aggregate_url,
|
||||
forensic_url=opts.webhook_forensic_url,
|
||||
smtp_tls_url=opts.webhook_smtp_tls_url,
|
||||
timeout=opts.webhook_timeout,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize webhook client; skipping", exc_info=True
|
||||
)
|
||||
|
||||
return clients
|
||||
|
||||
|
||||
def _close_output_clients(clients):
|
||||
"""Close output clients that hold persistent connections.
|
||||
|
||||
Clients that do not expose a ``close`` method are silently skipped.
|
||||
Errors during closing are logged as warnings and do not propagate.
|
||||
|
||||
Args:
|
||||
clients: dict of client instances returned by :func:`_init_output_clients`.
|
||||
"""
|
||||
for name, client in clients.items():
|
||||
if hasattr(client, "close"):
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.warning("Error closing %s", name, exc_info=True)
|
||||
|
||||
|
||||
def _main():
|
||||
"""Called when the module is executed"""
|
||||
|
||||
@@ -1561,7 +1615,11 @@ def _main():
|
||||
normalize_timespan_threshold_hours=24.0,
|
||||
fail_on_output_error=False,
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
# Snapshot opts as set from CLI args / hardcoded defaults, before any config
|
||||
# file is applied. On each SIGHUP reload we restore this baseline first so
|
||||
# that sections removed from the config file actually take effect.
|
||||
opts_from_cli = Namespace(**vars(opts))
|
||||
|
||||
index_prefix_domain_map = None
|
||||
|
||||
@@ -1937,7 +1995,6 @@ def _main():
|
||||
logger.info("Watching for email - Quit with ctrl-c")
|
||||
|
||||
while True:
|
||||
_reload_requested = False
|
||||
try:
|
||||
watch_inbox(
|
||||
mailbox_connection=mailbox_connection,
|
||||
@@ -1970,12 +2027,26 @@ def _main():
|
||||
if not _reload_requested:
|
||||
break
|
||||
|
||||
# Reload configuration
|
||||
# Reload configuration — clear the flag first so that any new
|
||||
# SIGHUP arriving while we reload will be captured for the next
|
||||
# iteration rather than being silently dropped.
|
||||
_reload_requested = False
|
||||
logger.info("Reloading configuration...")
|
||||
old_opts_snapshot = Namespace(**vars(opts))
|
||||
try:
|
||||
index_prefix_domain_map = _parse_config_file(args.config_file, opts)
|
||||
clients = _init_output_clients(opts)
|
||||
# Build a fresh opts starting from CLI-only defaults so that
|
||||
# sections removed from the config file actually take effect.
|
||||
new_opts = Namespace(**vars(opts_from_cli))
|
||||
new_index_prefix_domain_map = _parse_config_file(
|
||||
args.config_file, new_opts
|
||||
)
|
||||
new_clients = _init_output_clients(new_opts)
|
||||
|
||||
# All steps succeeded — commit the changes atomically.
|
||||
_close_output_clients(clients)
|
||||
clients = new_clients
|
||||
index_prefix_domain_map = new_index_prefix_domain_map
|
||||
for k, v in vars(new_opts).items():
|
||||
setattr(opts, k, v)
|
||||
|
||||
# Update watch parameters from reloaded config
|
||||
mailbox_batch_size_value = (
|
||||
@@ -2003,14 +2074,36 @@ def _main():
|
||||
if opts.debug:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Refresh FileHandler if log_file changed
|
||||
old_log_file = getattr(opts, "active_log_file", None)
|
||||
new_log_file = opts.log_file
|
||||
if old_log_file != new_log_file:
|
||||
# Remove old FileHandlers
|
||||
for h in list(logger.handlers):
|
||||
if isinstance(h, logging.FileHandler):
|
||||
h.close()
|
||||
logger.removeHandler(h)
|
||||
# Add new FileHandler if configured
|
||||
if new_log_file:
|
||||
try:
|
||||
fh = logging.FileHandler(new_log_file, "a")
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(levelname)s"
|
||||
" - [%(filename)s:%(lineno)d] - %(message)s"
|
||||
)
|
||||
fh.setFormatter(file_formatter)
|
||||
logger.addHandler(fh)
|
||||
except Exception as log_error:
|
||||
logger.warning(
|
||||
"Unable to write to log file: {}".format(log_error)
|
||||
)
|
||||
opts.active_log_file = new_log_file
|
||||
|
||||
logger.info("Configuration reloaded successfully")
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Config reload failed, continuing with previous config"
|
||||
)
|
||||
# Restore old opts
|
||||
for k, v in vars(old_opts_snapshot).items():
|
||||
setattr(opts, k, v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -69,3 +69,8 @@ class GelfClient(object):
|
||||
for row in rows:
|
||||
log_context_data.parsedmarc = row
|
||||
self.logger.info("parsedmarc smtptls report")
|
||||
|
||||
def close(self):
|
||||
"""Remove and close the GELF handler, releasing its connection."""
|
||||
self.logger.removeHandler(self.handler)
|
||||
self.handler.close()
|
||||
|
||||
@@ -62,6 +62,10 @@ class KafkaClient(object):
|
||||
except NoBrokersAvailable:
|
||||
raise KafkaError("No Kafka brokers available")
|
||||
|
||||
def close(self):
|
||||
"""Close the Kafka producer, releasing background threads and sockets."""
|
||||
self.producer.close()
|
||||
|
||||
@staticmethod
|
||||
def strip_metadata(report: dict[str, Any]):
|
||||
"""
|
||||
|
||||
@@ -178,10 +178,12 @@ class GmailConnection(MailboxConnection):
|
||||
def watch(self, check_callback, check_timeout, should_reload=None):
|
||||
"""Checks the mailbox for new messages every n seconds"""
|
||||
while True:
|
||||
sleep(check_timeout)
|
||||
check_callback(self)
|
||||
if should_reload and should_reload():
|
||||
return
|
||||
sleep(check_timeout)
|
||||
if should_reload and should_reload():
|
||||
return
|
||||
check_callback(self)
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def _find_label_id_for_label(self, label_name: str) -> str:
|
||||
|
||||
@@ -281,10 +281,10 @@ class MSGraphConnection(MailboxConnection):
|
||||
def watch(self, check_callback, check_timeout, should_reload=None):
|
||||
"""Checks the mailbox for new messages every n seconds"""
|
||||
while True:
|
||||
sleep(check_timeout)
|
||||
check_callback(self)
|
||||
if should_reload and should_reload():
|
||||
return
|
||||
sleep(check_timeout)
|
||||
check_callback(self)
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def _find_folder_id_from_folder_path(self, folder_name: str) -> str:
|
||||
|
||||
@@ -57,7 +57,7 @@ class SyslogClient(object):
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
# Create the appropriate syslog handler based on protocol
|
||||
log_handler = self._create_syslog_handler(
|
||||
self.log_handler = self._create_syslog_handler(
|
||||
server_name,
|
||||
server_port,
|
||||
self.protocol,
|
||||
@@ -69,7 +69,7 @@ class SyslogClient(object):
|
||||
retry_delay,
|
||||
)
|
||||
|
||||
self.logger.addHandler(log_handler)
|
||||
self.logger.addHandler(self.log_handler)
|
||||
|
||||
def _create_syslog_handler(
|
||||
self,
|
||||
@@ -179,3 +179,8 @@ class SyslogClient(object):
|
||||
rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports)
|
||||
for row in rows:
|
||||
self.logger.info(json.dumps(row))
|
||||
|
||||
def close(self):
|
||||
"""Remove and close the syslog handler, releasing its socket."""
|
||||
self.logger.removeHandler(self.log_handler)
|
||||
self.log_handler.close()
|
||||
|
||||
@@ -63,3 +63,7 @@ class WebhookClient(object):
|
||||
self.session.post(webhook_url, data=payload, timeout=self.timeout)
|
||||
except Exception as error_:
|
||||
logger.error("Webhook Error: {0}".format(error_.__str__()))
|
||||
|
||||
def close(self):
|
||||
"""Close the underlying HTTP session."""
|
||||
self.session.close()
|
||||
|
||||
304
tests.py
304
tests.py
@@ -4,6 +4,7 @@
|
||||
from __future__ import absolute_import, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -1910,5 +1911,308 @@ certificate_path = /tmp/msgraph-cert.pem
|
||||
mock_get_mailbox_reports.assert_not_called()
|
||||
|
||||
|
||||
class TestSighupReload(unittest.TestCase):
|
||||
"""Tests for SIGHUP-driven configuration reload in watch mode."""
|
||||
|
||||
_BASE_CONFIG = """[general]
|
||||
silent = true
|
||||
|
||||
[imap]
|
||||
host = imap.example.com
|
||||
user = user
|
||||
password = pass
|
||||
|
||||
[mailbox]
|
||||
watch = true
|
||||
"""
|
||||
|
||||
@unittest.skipUnless(hasattr(signal, "SIGHUP"), "SIGHUP not available on this platform")
|
||||
@patch("parsedmarc.cli._init_output_clients")
|
||||
@patch("parsedmarc.cli._parse_config_file")
|
||||
@patch("parsedmarc.cli.get_dmarc_reports_from_mailbox")
|
||||
@patch("parsedmarc.cli.watch_inbox")
|
||||
@patch("parsedmarc.cli.IMAPConnection")
|
||||
def testSighupTriggersReloadAndWatchRestarts(
|
||||
self,
|
||||
mock_imap,
|
||||
mock_watch,
|
||||
mock_get_reports,
|
||||
mock_parse_config,
|
||||
mock_init_clients,
|
||||
):
|
||||
"""SIGHUP causes watch to return, config is re-parsed, and watch restarts."""
|
||||
import signal as signal_module
|
||||
|
||||
mock_imap.return_value = object()
|
||||
mock_get_reports.return_value = {
|
||||
"aggregate_reports": [],
|
||||
"forensic_reports": [],
|
||||
"smtp_tls_reports": [],
|
||||
}
|
||||
|
||||
def parse_side_effect(config_file, opts):
|
||||
opts.imap_host = "imap.example.com"
|
||||
opts.imap_user = "user"
|
||||
opts.imap_password = "pass"
|
||||
opts.mailbox_watch = True
|
||||
return None
|
||||
|
||||
mock_parse_config.side_effect = parse_side_effect
|
||||
mock_init_clients.return_value = {}
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def watch_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
# Simulate SIGHUP arriving while watch is running
|
||||
if hasattr(signal_module, "SIGHUP"):
|
||||
import os
|
||||
|
||||
os.kill(os.getpid(), signal_module.SIGHUP)
|
||||
return # Normal return — reload loop will continue
|
||||
else:
|
||||
raise FileExistsError("stop-watch-loop")
|
||||
|
||||
mock_watch.side_effect = watch_side_effect
|
||||
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg:
|
||||
cfg.write(self._BASE_CONFIG)
|
||||
cfg_path = cfg.name
|
||||
self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path))
|
||||
|
||||
with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]):
|
||||
with self.assertRaises(SystemExit) as cm:
|
||||
parsedmarc.cli._main()
|
||||
|
||||
# Exited with code 1 (from FileExistsError handler)
|
||||
self.assertEqual(cm.exception.code, 1)
|
||||
# watch_inbox was called twice: initial run + after reload
|
||||
self.assertEqual(mock_watch.call_count, 2)
|
||||
# _parse_config_file called for initial load + reload
|
||||
self.assertGreaterEqual(mock_parse_config.call_count, 2)
|
||||
|
||||
@unittest.skipUnless(hasattr(signal, "SIGHUP"), "SIGHUP not available on this platform")
|
||||
@patch("parsedmarc.cli._init_output_clients")
|
||||
@patch("parsedmarc.cli._parse_config_file")
|
||||
@patch("parsedmarc.cli.get_dmarc_reports_from_mailbox")
|
||||
@patch("parsedmarc.cli.watch_inbox")
|
||||
@patch("parsedmarc.cli.IMAPConnection")
|
||||
def testInvalidConfigOnReloadKeepsPreviousState(
|
||||
self,
|
||||
mock_imap,
|
||||
mock_watch,
|
||||
mock_get_reports,
|
||||
mock_parse_config,
|
||||
mock_init_clients,
|
||||
):
|
||||
"""A failing reload leaves opts and clients unchanged."""
|
||||
import signal as signal_module
|
||||
|
||||
mock_imap.return_value = object()
|
||||
mock_get_reports.return_value = {
|
||||
"aggregate_reports": [],
|
||||
"forensic_reports": [],
|
||||
"smtp_tls_reports": [],
|
||||
}
|
||||
|
||||
# Initial parse sets required opts; reload parse raises
|
||||
initial_map = {"prefix_": ["example.com"]}
|
||||
call_count = [0]
|
||||
|
||||
def parse_side_effect(config_file, opts):
|
||||
call_count[0] += 1
|
||||
opts.imap_host = "imap.example.com"
|
||||
opts.imap_user = "user"
|
||||
opts.imap_password = "pass"
|
||||
opts.mailbox_watch = True
|
||||
if call_count[0] == 1:
|
||||
return initial_map
|
||||
raise RuntimeError("bad config")
|
||||
|
||||
mock_parse_config.side_effect = parse_side_effect
|
||||
|
||||
initial_clients = {"s3_client": MagicMock()}
|
||||
mock_init_clients.return_value = initial_clients
|
||||
|
||||
watch_calls = [0]
|
||||
|
||||
def watch_side_effect(*args, **kwargs):
|
||||
watch_calls[0] += 1
|
||||
if watch_calls[0] == 1:
|
||||
if hasattr(signal_module, "SIGHUP"):
|
||||
import os
|
||||
|
||||
os.kill(os.getpid(), signal_module.SIGHUP)
|
||||
return
|
||||
else:
|
||||
raise FileExistsError("stop")
|
||||
|
||||
mock_watch.side_effect = watch_side_effect
|
||||
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg:
|
||||
cfg.write(self._BASE_CONFIG)
|
||||
cfg_path = cfg.name
|
||||
self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path))
|
||||
|
||||
with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]):
|
||||
with self.assertRaises(SystemExit) as cm:
|
||||
parsedmarc.cli._main()
|
||||
|
||||
self.assertEqual(cm.exception.code, 1)
|
||||
# watch was still called twice (reload loop continued after failed reload)
|
||||
self.assertEqual(mock_watch.call_count, 2)
|
||||
# The failed reload must not have closed the original clients
|
||||
initial_clients["s3_client"].close.assert_not_called()
|
||||
|
||||
@unittest.skipUnless(hasattr(signal, "SIGHUP"), "SIGHUP not available on this platform")
|
||||
@patch("parsedmarc.cli._init_output_clients")
|
||||
@patch("parsedmarc.cli._parse_config_file")
|
||||
@patch("parsedmarc.cli.get_dmarc_reports_from_mailbox")
|
||||
@patch("parsedmarc.cli.watch_inbox")
|
||||
@patch("parsedmarc.cli.IMAPConnection")
|
||||
def testReloadClosesOldClients(
|
||||
self,
|
||||
mock_imap,
|
||||
mock_watch,
|
||||
mock_get_reports,
|
||||
mock_parse_config,
|
||||
mock_init_clients,
|
||||
):
|
||||
"""Successful reload closes the old output clients before replacing them."""
|
||||
import signal as signal_module
|
||||
|
||||
mock_imap.return_value = object()
|
||||
mock_get_reports.return_value = {
|
||||
"aggregate_reports": [],
|
||||
"forensic_reports": [],
|
||||
"smtp_tls_reports": [],
|
||||
}
|
||||
|
||||
def parse_side_effect(config_file, opts):
|
||||
opts.imap_host = "imap.example.com"
|
||||
opts.imap_user = "user"
|
||||
opts.imap_password = "pass"
|
||||
opts.mailbox_watch = True
|
||||
return None
|
||||
|
||||
mock_parse_config.side_effect = parse_side_effect
|
||||
|
||||
old_client = MagicMock()
|
||||
new_client = MagicMock()
|
||||
init_call = [0]
|
||||
|
||||
def init_side_effect(opts):
|
||||
init_call[0] += 1
|
||||
if init_call[0] == 1:
|
||||
return {"kafka_client": old_client}
|
||||
return {"kafka_client": new_client}
|
||||
|
||||
mock_init_clients.side_effect = init_side_effect
|
||||
|
||||
watch_calls = [0]
|
||||
|
||||
def watch_side_effect(*args, **kwargs):
|
||||
watch_calls[0] += 1
|
||||
if watch_calls[0] == 1:
|
||||
if hasattr(signal_module, "SIGHUP"):
|
||||
import os
|
||||
|
||||
os.kill(os.getpid(), signal_module.SIGHUP)
|
||||
return
|
||||
else:
|
||||
raise FileExistsError("stop")
|
||||
|
||||
mock_watch.side_effect = watch_side_effect
|
||||
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg:
|
||||
cfg.write(self._BASE_CONFIG)
|
||||
cfg_path = cfg.name
|
||||
self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path))
|
||||
|
||||
with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]):
|
||||
with self.assertRaises(SystemExit):
|
||||
parsedmarc.cli._main()
|
||||
|
||||
# Old client must have been closed when reload succeeded
|
||||
old_client.close.assert_called_once()
|
||||
|
||||
@unittest.skipUnless(hasattr(signal, "SIGHUP"), "SIGHUP not available on this platform")
|
||||
@patch("parsedmarc.cli._init_output_clients")
|
||||
@patch("parsedmarc.cli.get_dmarc_reports_from_mailbox")
|
||||
@patch("parsedmarc.cli.watch_inbox")
|
||||
@patch("parsedmarc.cli.IMAPConnection")
|
||||
def testRemovedConfigSectionTakesEffectOnReload(
|
||||
self,
|
||||
mock_imap,
|
||||
mock_watch,
|
||||
mock_get_reports,
|
||||
mock_init_clients,
|
||||
):
|
||||
"""Removing a config section on reload resets that option to its default."""
|
||||
import signal as signal_module
|
||||
|
||||
mock_imap.return_value = object()
|
||||
mock_get_reports.return_value = {
|
||||
"aggregate_reports": [],
|
||||
"forensic_reports": [],
|
||||
"smtp_tls_reports": [],
|
||||
}
|
||||
mock_init_clients.return_value = {}
|
||||
|
||||
# First config sets kafka_hosts (with required topics); second removes it.
|
||||
config_v1 = (
|
||||
self._BASE_CONFIG
|
||||
+ "\n[kafka]\nhosts = kafka.example.com:9092\n"
|
||||
+ "aggregate_topic = dmarc_agg\n"
|
||||
+ "forensic_topic = dmarc_forensic\n"
|
||||
+ "smtp_tls_topic = smtp_tls\n"
|
||||
)
|
||||
config_v2 = self._BASE_CONFIG # no [kafka] section
|
||||
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as cfg:
|
||||
cfg.write(config_v1)
|
||||
cfg_path = cfg.name
|
||||
self.addCleanup(lambda: os.path.exists(cfg_path) and os.remove(cfg_path))
|
||||
|
||||
watch_calls = [0]
|
||||
|
||||
def watch_side_effect(*args, **kwargs):
|
||||
watch_calls[0] += 1
|
||||
if watch_calls[0] == 1:
|
||||
# Rewrite config to remove kafka before triggering reload
|
||||
with open(cfg_path, "w") as f:
|
||||
f.write(config_v2)
|
||||
if hasattr(signal_module, "SIGHUP"):
|
||||
import os
|
||||
|
||||
os.kill(os.getpid(), signal_module.SIGHUP)
|
||||
return
|
||||
else:
|
||||
raise FileExistsError("stop")
|
||||
|
||||
mock_watch.side_effect = watch_side_effect
|
||||
|
||||
# Capture opts used on each _init_output_clients call
|
||||
init_opts_captures = []
|
||||
|
||||
def init_side_effect(opts):
|
||||
from argparse import Namespace as NS
|
||||
|
||||
init_opts_captures.append(NS(**vars(opts)))
|
||||
return {}
|
||||
|
||||
mock_init_clients.side_effect = init_side_effect
|
||||
|
||||
with patch.object(sys, "argv", ["parsedmarc", "-c", cfg_path]):
|
||||
with self.assertRaises(SystemExit):
|
||||
parsedmarc.cli._main()
|
||||
|
||||
# First init: kafka_hosts should be set from v1 config
|
||||
self.assertIsNotNone(init_opts_captures[0].kafka_hosts)
|
||||
# Second init (after reload with v2 config): kafka_hosts should be None
|
||||
self.assertIsNone(init_opts_captures[1].kafka_hosts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user