Expand honest test coverage from 59% to 83%; fix two latent bugs (#775)

* Expand honest test coverage from 59% to 83%; fix two latent bugs

271 new tests across the output modules, ES/OS clients, CLI config
parsing, and the top-level parsing surface. Coverage measured against
shipped code only (see [tool.coverage.run] source = ["parsedmarc"]
omit = ["*/parsedmarc/resources/maps/*.py"] in pyproject.toml).

Per-module results:

  s3.py             38% → 100%   (also fixes SMTP-TLS-to-S3 bug below)
  gelf.py           40% → 100%
  syslog.py         46% → 100%
  kafkaclient.py    34% → 100%
  splunk.py         24% → 100%
  loganalytics.py   56% → 100%
  webhook.py        78% → 100%   (also removes redundant try/except)
  elastic.py        36% →  99%
  opensearch.py     40% →  99%
  cli.py            52% →  69%
  __init__.py       74% →  76%   (also fixes append_json bug below)
  utils.py          84% (unchanged in this PR)
  TOTAL             59% →  83%

The remaining 17% is honest. The biggest unreached blocks are
_main() in cli.py and the watch-mode mailbox iteration in __init__.py,
both of which would require either standing up live subsystems (real
Elasticsearch, real IMAP) or mocking deep enough that the test would
verify the mock rather than the code. The PR-A AGENTS.md guidance —
"if 90% requires faking it, ship 85% honestly" — applies here.

Bugs fixed (items 1 and 2) and dead code removed (item 3) while writing tests:

1. parsedmarc/s3.py — SMTP-TLS-to-S3 was completely broken.
   save_report_to_s3 unconditionally read report["report_metadata"]
   when building S3 object metadata, but RFC 8460 §4.3 SMTP TLS
   reports are flat (no report_metadata sub-object). The CLI's
   surrounding try/except silently swallowed the KeyError, so every
   SMTP-TLS report quietly failed to upload. Also fixes a related
   issue: parse_smtp_tls_report_json stores begin_date as the raw
   ISO-8601 string from the report (per the SMTPTLSReport TypedDict
   and RFC 8460 §4.3), but the S3 code path assumed a datetime
   with .year / .month / .day attributes. Both fixed; the broken
   metadata-extraction branch now uses the flat-report fields, and
   the date branch normalizes via human_timestamp_to_datetime.

2. parsedmarc/__init__.py — append_json corrupted JSON output files
   on the second write. The original implementation opened files in
   "a+" mode, then seek()ed backwards to overwrite the trailing "]"
   with ",\n" before appending more elements. Python's docs are
   explicit (https://docs.python.org/3/library/functions.html#open):
   on POSIX, writes in "a"/"a+" mode always go to EOF regardless of
   seek() position. The result was that the second call produced
   [...]\n],\n[...] -style corrupted output instead of a single
   merged array. Replaced with a read-merge-write pattern: load the
   existing array (if any), append the new elements, rewrite the
   whole file. The CSV cousin append_csv was not affected — it
   doesn't seek backwards.

3. parsedmarc/webhook.py — removed redundant try/except blocks in
   save_aggregate_report_to_webhook / save_failure_report_to_webhook
   / save_smtp_tls_report_to_webhook. _send_to_webhook already
   catches every Exception itself, so the outer except blocks were
   unreachable dead code (covered nothing, defended against nothing,
   and inflated the source-line count without testing value).

Testing approach: mocks at SDK boundaries (boto3 resource, kafka
producer, requests session, opensearch/elasticsearch Document/Search,
azure LogsIngestionClient). Tests verify the parsedmarc-side
transformation logic — document/event construction, index/topic
naming, dedup queries, error wrapping — rather than asserting on
mock invocations as a proxy for behaviour. Where a branch is
defensive against a caller that doesn't exist in the codebase, the
test is omitted (commented in code rather than hidden behind a
pragma).

547 tests total (was 276), all passing. ruff check + format clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Document the two bug fixes from this PR in the 10.0.0 changelog

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Document testing standards in AGENTS.md

Adds a "Testing standards" section covering the principles applied in
PR-A (split) and PR-B (coverage expansion):

- Coverage measures shipped code only — don't reintroduce tests/* to
  the scope, don't expand omit, don't use # pragma: no cover.
- Honest tests assert on observable behaviour, not "the mock was called".
  Mock at SDK boundaries; parse the payload that gets sent.
- "If 90% requires faking it, ship 85% honestly" — coverage is a tool,
  not a goal. PR-B's deliberate stops at cli.py 69% and __init__.py 76%
  are the documented precedent for when to halt.
- Verify bug claims against the relevant RFC, internal types, installed
  SDK source, or upstream docs before changing code. Cite the source in
  the commit message and test docstring (RFC 8460 §4.3 and the Python
  open() docs for #775's two bug fixes are the pattern to follow).
- Bugs found while writing tests are fixed in the same PR; the test
  doubles as the regression guard.
- File layout (tests/test_<module>.py) is non-negotiable; module-level
  test loggers need fresh-handler setup so test ordering doesn't break
  assertLogs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Cover the corrupt-file fallback in append_json

Codecov flagged 2 missing patch-coverage lines on PR #775: the
except (json.JSONDecodeError, OSError) branch in append_json, which
falls back to overwriting when the existing file isn't a parseable
JSON array. Two new tests in tests/test_init.py:TestAppendJson
exercise both paths:

- test_corrupt_existing_file_is_overwritten_cleanly: existing file
  contains invalid JSON; append_json overwrites with the new array.
- test_existing_file_with_non_list_root_is_overwritten: existing
  file parses as {"foo": ...} (dict, not list); the isinstance guard
  rejects it and we overwrite cleanly.

Patch coverage now 100% on the bug fix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Sean Whalen
2026-05-20 20:35:22 -04:00
committed by GitHub
parent 5b08627eaa
commit b7b8383fa4
16 changed files with 4734 additions and 205 deletions
+42
@@ -116,6 +116,48 @@ IP address info cached for 4 hours, seen aggregate report IDs cached for 1 hour
- Token file writes must create parent directories before opening for write.
- Store natively numeric values as numbers, not pre-formatted strings. Example: ASN is stored as `int 15169`, not `"AS15169"`; Elasticsearch / OpenSearch mappings for such fields use `Integer()` so consumers can do range queries and numeric sorts. Display layers format with a prefix at render time.
## Testing standards
These rules govern *every* test added to `tests/`. They exist because the project has been burned by tests that looked like coverage but caught nothing, and by bug claims that turned out to be wrong about the spec. Both failure modes erode trust faster than missing coverage does.
### Coverage measures shipped code only
`[tool.coverage.run]` in `pyproject.toml` sets `source = ["parsedmarc"]` and omits `*/parsedmarc/resources/maps/*.py` (maintainer scripts that ship out of the wheel). Counting the test files in the denominator inflates the headline by ~8 percentage points without telling anyone anything useful — pytest discovers test files and runs them, so they're trivially "covered". The number that matters is "what fraction of the installed library does the test suite actually exercise". Don't reintroduce `tests/*` to the coverage scope, don't expand the `omit` list to hide gaps, don't add `# pragma: no cover` to dodge ugly branches. If a branch is genuinely unreachable, delete it; if it's reachable but hard to test, write the test.
### Honest tests assert on observable behaviour
A test that mocks every dependency and asserts that the mocks were invoked is testing the mocks, not the code. The benchmark for a good test is: *would this test fail if the code under test were silently wrong?* If the answer is no — if the test would pass regardless of whether the function does what its docstring claims — it isn't a test, it's coverage-padding.
Concrete patterns:
- **Mock at SDK boundaries, not at internal helpers.** Patch `boto3.resource`, `kafka.KafkaProducer`, `requests.Session.post`, `elasticsearch_dsl.Document.save`, `azure.monitor.ingestion.LogsIngestionClient` — the seams where the project's code stops and an external system begins. Don't patch our own functions just to make a test "easier"; that hides bugs in the function instead of testing it.
- **Assert on what gets sent, not that something was sent.** For an output module, parse the body that was passed to the mocked transport (`json.loads(call.kwargs["data"])`, `kafka.send.call_args.args[1]`, `bucket.put_object.call_args.kwargs["Key"]`) and verify the *fields and values a dashboard or downstream consumer would actually filter on*. A test that only checks `mock.assert_called_once()` would pass even if the payload were `{}`.
- **No trivial passthrough tests.** A test that calls a getter and asserts it returns the value just set isn't testing the code; it's testing Python's attribute machinery.
- **No `# pragma: no cover`.** If a branch is unreachable, the right fix is to delete the branch, not to hide it.
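As a rough illustration of "assert on what gets sent", the sketch below mocks only the transport and then decodes the body handed to it. `send_summary`, `sent_payload`, the URL, and the payload fields are hypothetical stand-ins invented for this example, not parsedmarc functions; only `json` and `unittest.mock` are real APIs.

```python
import json
from unittest import mock


def send_summary(session, url, report):
    """Hypothetical output function: posts a trimmed-down report as JSON."""
    payload = {
        "org_name": report["org_name"],
        "report_id": report["report_id"],
        "message_count": sum(r["count"] for r in report["records"]),
    }
    session.post(url, data=json.dumps(payload), timeout=10)


def sent_payload(session):
    """Decode the JSON body that was handed to the mocked transport."""
    return json.loads(session.post.call_args.kwargs["data"])


# Mock at the boundary: the session is the seam to the outside world.
session = mock.Mock()
report = {
    "org_name": "example.net",
    "report_id": "abc123",
    "records": [{"count": 2}, {"count": 3}],
}
send_summary(session, "https://hooks.example.com/dmarc", report)

# Assert on the fields a consumer would filter on, not on "was it called".
body = sent_payload(session)
assert body["report_id"] == "abc123"
assert body["message_count"] == 5
```

A test that stopped at `session.post.assert_called_once()` would still pass if `message_count` were computed incorrectly; decoding the body is what makes the assertion sensitive to the transformation logic.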
### "If 90% requires faking it, ship 85% honestly"
Coverage targets are a tool, not a goal. Coverage is only worth as much as the regressions it would actually catch; chasing a percentage by writing low-signal tests degrades the suite. When the next available coverage point would cost test integrity (typically the deep orchestration paths in `_main()` and the watch-mode mailbox iteration, both of which need either a live ES/IMAP cluster or mocks so deep they verify the mock rather than the code), stop, and call out the modules where you stopped in the PR description. PR-B (#775) explicitly halted `cli.py` at 69% and `__init__.py` at 76% for this reason; the floor for the rest of the suite is 99–100%.
### Verify bug claims against authoritative sources before fixing
If a test surfaces something that looks like a bug, cite the spec before changing code. Intuition isn't enough; "this code looks wrong" has been wrong often enough in this codebase that the project requires verification. In order of authority:
1. **The relevant RFC** for protocol or report-format questions (RFC 9989 for DMARC policy, RFC 9990 for aggregate reports, RFC 9991 for failure reports, RFC 8460 for SMTP TLS reports, RFC 6591 for legacy ARF).
2. **The internal type contract** (`parsedmarc/types.py` TypedDicts) for project-internal data shapes.
3. **The installed SDK source in the venv** for third-party API questions where the docs are inaccessible — `find venv -name '*.py' -path '*<package>*'` and grep, rather than asking a subagent to synthesize an answer.
4. **The official upstream documentation** (Python docs, vendor docs) for language- or platform-level behaviour. The `append_json` bug fix in #775 cited the explicit "writes in `a`/`a+` mode always go to EOF regardless of seek" line from <https://docs.python.org/3/library/functions.html#open>.
Cite the source in the commit message and the test docstring. A reviewer should be able to look at the test and confirm both *what* changed and *why the prior behaviour was wrong*. Two examples worth pattern-matching are #775's SMTP-TLS-to-S3 fix (RFC 8460 §4.3 cited) and the `append_json` fix (Python docs quoted).
### Bugs found while writing tests are fixed in the same PR
When a test for the documented behaviour fails because the code is wrong, the right move is to fix the code, not to lock in the broken behaviour. Don't write `self.assertRaises(KeyError)` to make a passing test out of a known bug, and don't skip the test with a "TODO: file separately". If the fix is small and clearly correct against the cited authority above, it belongs in the same PR as the test that found it — the test then doubles as the regression guard. List each fix in `CHANGELOG.md` under the in-progress version's **Bug fixes** section (introducing the heading if it's not there yet).
### File layout is non-negotiable
Tests live under `tests/` as `tests/test_<module>.py`, one per top-level `parsedmarc/*` module. The split is documented in [Code Style](#code-style) above. New tests go in the file whose module they exercise — don't create cross-module kitchen-sink test files, and don't reintroduce a monolithic `tests.py`. Module-level test logger handlers should be reset in `setUp` / a `_fresh_logger()` helper (see `tests/test_gelf.py` and `tests/test_syslog.py`) so that test ordering doesn't cause stale handlers from a prior test to accumulate on the module's logger and break `assertLogs` capture.
## Local dev secrets
If a config file is listed in `.gitignore`, treat its contents as secret. Do not paste its literal values into any tracked file — READMEs, docs, code comments, commit messages, PR descriptions, sample/test fixtures. Reference the variable name (e.g. `$SOME_PASSWORD`) or show a placeholder (`...`) instead, and tell the reader to pick their own values. This is both a real-leak hedge and a way to keep secret scanners (GitHub secret scanning, push protection, third-party scanners) from firing false positives on the repo. Defer to `.gitignore` as the source of truth on what's secret — the rule applies to any gitignored config file the project ever adds, not just the ones present today (currently `.env` and `parsedmarc*.ini`).
+6
@@ -27,6 +27,12 @@ Several elements that became `langAttrString` in RFC 9990 (`extra_contact_info`,
Backwards compatibility to RFC 7489 is maintained.
### Bug fixes
- **`save_smtp_tls_report_to_s3` was completely broken.** `parsedmarc/s3.py:save_report_to_s3` unconditionally read `report["report_metadata"]` when assembling S3 object metadata, but SMTP TLS reports are flat per RFC 8460 §4.3 — they have no `report_metadata` sub-object — and `parse_smtp_tls_report_json` correctly stores `begin_date` as the raw ISO-8601 string from the report. The S3 path branch also assumed `begin_date` was a `datetime` and did `.year` / `.month` / `.day` on it. The CLI's surrounding `try/except` silently swallowed the resulting `KeyError`, so every SMTP-TLS report quietly failed to upload to S3 in production. Both issues are fixed: SMTP-TLS metadata is now built from the flat report fields directly, and the date is normalized via `human_timestamp_to_datetime`.
- **`append_json` corrupted JSON output files on the second write.** The original implementation opened files in `"a+"` mode, then `seek()`ed backwards to overwrite the trailing `]` with `,\n` before appending more elements. [Python's docs are explicit](https://docs.python.org/3/library/functions.html#open): on POSIX, writes in `"a"`/`"a+"` mode always go to EOF regardless of seek position. The result was that every second call onto an existing file produced `[...]\n],\n[...]`-style corrupted output instead of a single merged JSON array. Anyone running parsedmarc in watch mode with JSON output enabled had `aggregate.json` / `failure.json` / `smtp_tls.json` quietly turning into invalid JSON after the first overlap. Replaced with a read-merge-write pattern: load the existing array (if any), append the new elements, rewrite the whole file. `append_csv` was not affected — it doesn't seek backwards.
- **Removed redundant try/except in `parsedmarc/webhook.py`.** `save_aggregate_report_to_webhook` / `save_failure_report_to_webhook` / `save_smtp_tls_report_to_webhook` each wrapped `self._send_to_webhook(...)` in a try/except, but `_send_to_webhook` already catches every `Exception` itself, so the outer except blocks were unreachable dead code.
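The `append_json` entry above can be reproduced in a few lines. This is a simplified sketch, not the shipped implementation: `append_json_array` is a hypothetical name, and the file names are arbitrary. The first part shows that a `seek()` before a write in `"a+"` mode does not move where the data lands (behaviour described for POSIX); the second part shows the read-merge-write pattern.

```python
import json
import os
import tempfile


def append_json_array(path, items):
    """Read-merge-write: keep the file a single valid JSON array."""
    existing = []
    if os.path.isfile(path) and os.path.getsize(path) > 0:
        try:
            with open(path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if isinstance(loaded, list):
                existing = loaded
        except (json.JSONDecodeError, OSError):
            existing = []  # unreadable or corrupt: start over
    with open(path, "w", encoding="utf-8") as f:
        json.dump(existing + list(items), f)


with tempfile.TemporaryDirectory() as tmp:
    # 1) The pitfall: seek() does not change where "a+" writes land.
    broken = os.path.join(tmp, "broken.json")
    with open(broken, "w", encoding="utf-8") as f:
        f.write("[1]")
    with open(broken, "a+", encoding="utf-8") as f:
        f.seek(0)
        f.write("X")  # intended to overwrite "[", but is appended instead
    with open(broken, encoding="utf-8") as f:
        broken_text = f.read()

    # 2) Read-merge-write: two calls still yield one valid array.
    merged_path = os.path.join(tmp, "merged.json")
    append_json_array(merged_path, [1, 2])
    append_json_array(merged_path, [3])
    with open(merged_path, encoding="utf-8") as f:
        merged = json.load(f)

print(repr(broken_text))
print(merged)
```

The trade-off is that every call rewrites the whole file, which costs more I/O than a true append but is the price of keeping the file parseable after each write.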
### Breaking changes
#### Forensic reports have been renamed to failure reports
+32 -19
@@ -2500,26 +2500,39 @@ def append_json(
         Sequence[SMTPTLSReport],
     ],
 ) -> None:
-    with open(filename, "a+", newline="\n", encoding="utf-8") as output:
-        output_json = json.dumps(reports, ensure_ascii=False, indent=2)
-        if output.seek(0, os.SEEK_END) != 0:
-            if len(reports) == 0:
-                # not appending anything, don't do any dance to append it
-                # correctly
-                return
-            output.seek(output.tell() - 1)
-            last_char = output.read(1)
-            if last_char == "]":
-                # remove the trailing "\n]", leading "[\n", and replace with
-                # ",\n"
-                output.seek(output.tell() - 2)
-                output.write(",\n")
-                output_json = output_json[2:]
-            else:
-                output.seek(0)
-                output.truncate()
+    """Append ``reports`` to a JSON array on disk, creating the file
+    if needed.

-        output.write(output_json)
+    Reads the existing array (if the file exists and parses cleanly),
+    merges the new reports onto the end, and rewrites the file as a
+    single valid JSON array. An earlier version of this used an
+    ``open(..., "a+")`` + ``seek()`` + overwrite pattern, but Python's
+    documentation is explicit that on POSIX, ``a`` / ``a+`` writes
+    *always* go to EOF regardless of seek position — so the second
+    call onto an existing file produced ``[...],\\n[...]``-style
+    corrupted output. Read-merge-write is the only way to get a valid
+    JSON array out of repeated appends.
+    """
+    if len(reports) == 0:
+        # Don't create an empty-array file for an empty input; if a
+        # file already exists, leave it alone.
+        return
+
+    existing: list = []
+    if os.path.isfile(filename) and os.path.getsize(filename) > 0:
+        try:
+            with open(filename, "r", encoding="utf-8") as f:
+                loaded = json.loads(f.read())
+            if isinstance(loaded, list):
+                existing = loaded
+        except (json.JSONDecodeError, OSError):
+            # Corrupted or unreadable: overwrite cleanly rather than
+            # silently fail to record.
+            existing = []
+
+    merged = existing + list(reports)
+    with open(filename, "w", newline="\n", encoding="utf-8") as output:
+        json.dump(merged, output, ensure_ascii=False, indent=2)


 def append_csv(filename: str, csv: str) -> None:
+17 -4
@@ -64,13 +64,24 @@ class S3Client(object):
     def save_report_to_s3(self, report: dict[str, Any], report_type: str):
         if report_type == "smtp_tls":
-            report_date = report["begin_date"]
+            # SMTP TLS reports (RFC 8460) are flat — they have no
+            # `report_metadata` sub-object — and parse_smtp_tls_report_json
+            # stores begin_date as the ISO string from the report JSON
+            # (per SMTPTLSReport's TypedDict).
+            report_date = human_timestamp_to_datetime(report["begin_date"])
             report_id = report["report_id"]
+            metadata_source = {
+                "org_name": report.get("organization_name"),
+                "report_id": report.get("report_id"),
+                "begin_date": str(report.get("begin_date")),
+                "end_date": str(report.get("end_date")),
+            }
         else:
             report_date = human_timestamp_to_datetime(
                 report["report_metadata"]["begin_date"]
             )
             report_id = report["report_metadata"]["report_id"]
+            metadata_source = report["report_metadata"]

         path_template = "{0}/{1}/year={2}/month={3:02d}/day={4:02d}/{5}.json"
         object_path = path_template.format(
             self.bucket_path,
@@ -87,11 +98,13 @@ class S3Client(object):
         )
         object_metadata = {
             k: v
-            for k, v in report["report_metadata"].items()
-            if k in self.metadata_keys
+            for k, v in metadata_source.items()
+            if k in self.metadata_keys and v is not None
         }
         self.bucket.put_object(
-            Body=json.dumps(report), Key=object_path, Metadata=object_metadata
+            Body=json.dumps(report, default=str),
+            Key=object_path,
+            Metadata=object_metadata,
         )

     def close(self):
+8 -12
@@ -39,26 +39,22 @@ class WebhookClient(object):
         }

     def save_failure_report_to_webhook(self, report: str):
-        try:
-            self._send_to_webhook(self.failure_url, report)
-        except Exception as error_:
-            logger.error("Webhook Error: {0}".format(error_.__str__()))
+        self._send_to_webhook(self.failure_url, report)

     def save_smtp_tls_report_to_webhook(self, report: str):
-        try:
-            self._send_to_webhook(self.smtp_tls_url, report)
-        except Exception as error_:
-            logger.error("Webhook Error: {0}".format(error_.__str__()))
+        self._send_to_webhook(self.smtp_tls_url, report)

     def save_aggregate_report_to_webhook(self, report: str):
-        try:
-            self._send_to_webhook(self.aggregate_url, report)
-        except Exception as error_:
-            logger.error("Webhook Error: {0}".format(error_.__str__()))
+        self._send_to_webhook(self.aggregate_url, report)

     def _send_to_webhook(
         self, webhook_url: str, payload: Union[bytes, str, dict[str, Any]]
     ):
+        # All HTTP / network errors are swallowed and logged: a failing
+        # webhook should never abort the surrounding parse-and-output
+        # batch. The outer save_* methods previously wrapped this in a
+        # redundant try/except — removed because _send_to_webhook
+        # already catches every Exception itself.
         try:
             self.session.post(webhook_url, data=payload, timeout=self.timeout)
         except Exception as error_:
+853
@@ -1805,5 +1805,858 @@ class TestExpandPath(unittest.TestCase):
self.assertEqual(_expand_path("relative/path"), "relative/path")
# ---------------------------------------------------------------------------
# _parse_config: per-section INI → opts mapping
#
# Each section of the INI is consumed by a different branch of
# _parse_config. The tests below build a minimal config for one
# section at a time and verify every documented key lands on the right
# opts attribute. A rename, typo, or dropped backwards-compat alias
# would be caught here.
# ---------------------------------------------------------------------------
class _StrToListTests(unittest.TestCase):
def test_str_to_list_strips_leading_whitespace_per_element(self):
from parsedmarc.cli import _str_to_list
self.assertEqual(_str_to_list("a, b ,c"), ["a", "b ", "c"])
def test_str_to_list_single_value(self):
from parsedmarc.cli import _str_to_list
self.assertEqual(_str_to_list("solo"), ["solo"])
def _opts():
"""A fresh Namespace with no attributes — _parse_config sets fields
via attribute assignment on whatever it's given."""
from argparse import Namespace
return Namespace()
def _config_with(section: str, settings: dict) -> "ConfigParser":
"""Build a ConfigParser holding exactly one section."""
from configparser import ConfigParser
cp = ConfigParser()
cp.add_section(section)
for k, v in settings.items():
cp.set(section, k, str(v))
return cp
class TestParseConfigGeneral(unittest.TestCase):
"""The [general] section sets dozens of flags. Hit a representative
subset: filenames, save-toggles, DNS settings, output dir."""
def test_general_filenames_and_output(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"general",
{
"silent": "false",
"output": "/tmp/dmarc-out",
"aggregate_json_filename": "agg.json",
"failure_json_filename": "fail.json",
"smtp_tls_json_filename": "tls.json",
"aggregate_csv_filename": "agg.csv",
"failure_csv_filename": "fail.csv",
"smtp_tls_csv_filename": "tls.csv",
"save_aggregate": "true",
"save_failure": "true",
"save_smtp_tls": "true",
"debug": "false",
"verbose": "false",
"warnings": "false",
"fail_on_output_error": "false",
"offline": "true",
"strip_attachment_payloads": "true",
"n_procs": "4",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.output, "/tmp/dmarc-out")
self.assertEqual(opts.aggregate_json_filename, "agg.json")
self.assertEqual(opts.failure_json_filename, "fail.json")
self.assertEqual(opts.smtp_tls_csv_filename, "tls.csv")
self.assertTrue(opts.save_aggregate)
self.assertTrue(opts.save_failure)
self.assertTrue(opts.save_smtp_tls)
self.assertTrue(opts.offline)
self.assertTrue(opts.strip_attachment_payloads)
self.assertEqual(opts.n_procs, 4)
self.assertFalse(opts.silent)
self.assertFalse(opts.debug)
def test_general_save_forensic_alias_sets_save_failure(self):
"""Backwards compat: save_forensic in INI sets opts.save_failure."""
from parsedmarc.cli import _parse_config
cp = _config_with("general", {"save_forensic": "true"})
opts = _opts()
_parse_config(cp, opts)
self.assertTrue(opts.save_failure)
def test_general_forensic_filename_aliases_set_failure(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"general",
{
"forensic_json_filename": "fa.json",
"forensic_csv_filename": "fa.csv",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.failure_json_filename, "fa.json")
self.assertEqual(opts.failure_csv_filename, "fa.csv")
def test_general_dns_settings_with_defaults(self):
from parsedmarc.cli import _parse_config
# dns_timeout/dns_retries are typed via getfloat/getint which
# return non-None values for any valid input.
cp = _config_with(
"general",
{
"dns_timeout": "5.0",
"dns_retries": "2",
"dns_test_address": "1.1.1.1",
"nameservers": "1.1.1.1, 8.8.8.8",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.dns_timeout, 5.0)
self.assertEqual(opts.dns_retries, 2)
self.assertEqual(opts.nameservers, ["1.1.1.1", "8.8.8.8"])
def test_general_normalize_timespan_threshold(self):
from parsedmarc.cli import _parse_config
cp = _config_with("general", {"normalize_timespan_threshold_hours": "48"})
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.normalize_timespan_threshold_hours, 48.0)
class TestParseConfigElasticsearch(unittest.TestCase):
def test_elasticsearch_basic(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"elasticsearch",
{
"hosts": "es1:9200, es2:9200",
"timeout": "30.0",
"number_of_shards": "3",
"number_of_replicas": "1",
"index_suffix": "tenant_a",
"index_prefix": "cust_",
"monthly_indexes": "true",
"ssl": "true",
"cert_path": "/etc/ca.pem",
"skip_certificate_verification": "true",
"user": "alice",
"password": "secret",
"api_key": "base64key",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.elasticsearch_hosts, ["es1:9200", "es2:9200"])
self.assertEqual(opts.elasticsearch_timeout, 30.0)
self.assertEqual(opts.elasticsearch_number_of_shards, 3)
self.assertEqual(opts.elasticsearch_number_of_replicas, 1)
self.assertEqual(opts.elasticsearch_index_suffix, "tenant_a")
self.assertEqual(opts.elasticsearch_index_prefix, "cust_")
self.assertTrue(opts.elasticsearch_monthly_indexes)
self.assertTrue(opts.elasticsearch_ssl)
self.assertEqual(opts.elasticsearch_ssl_cert_path, "/etc/ca.pem")
self.assertTrue(opts.elasticsearch_skip_certificate_verification)
self.assertEqual(opts.elasticsearch_username, "alice")
self.assertEqual(opts.elasticsearch_password, "secret")
self.assertEqual(opts.elasticsearch_api_key, "base64key")
def test_elasticsearch_apikey_camelcase_alias_pre_8_20(self):
"""`apiKey` (camelCase) is the legacy 8.20-and-earlier name."""
from parsedmarc.cli import _parse_config
cp = _config_with("elasticsearch", {"hosts": "es:9200", "apiKey": "legacy"})
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.elasticsearch_api_key, "legacy")
def test_elasticsearch_missing_hosts_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("elasticsearch", {"timeout": "30"})
with self.assertRaises(ConfigurationError) as ctx:
_parse_config(cp, _opts())
self.assertIn("hosts", str(ctx.exception))
class TestParseConfigOpenSearch(unittest.TestCase):
def test_opensearch_basic(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"opensearch",
{
"hosts": "os1:9200",
"timeout": "45.0",
"number_of_shards": "2",
"number_of_replicas": "0",
"index_suffix": "x",
"index_prefix": "y_",
"monthly_indexes": "true",
"ssl": "true",
"cert_path": "/etc/ca.pem",
"skip_certificate_verification": "true",
"user": "u",
"password": "p",
"api_key": "k",
"auth_type": "BASIC",
"aws_region": "us-east-1",
"aws_service": "es",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.opensearch_hosts, ["os1:9200"])
self.assertEqual(opts.opensearch_timeout, 45.0)
self.assertEqual(opts.opensearch_number_of_shards, 2)
self.assertEqual(opts.opensearch_number_of_replicas, 0)
self.assertEqual(opts.opensearch_index_suffix, "x")
self.assertEqual(opts.opensearch_index_prefix, "y_")
self.assertTrue(opts.opensearch_monthly_indexes)
self.assertTrue(opts.opensearch_ssl)
self.assertEqual(opts.opensearch_ssl_cert_path, "/etc/ca.pem")
self.assertTrue(opts.opensearch_skip_certificate_verification)
self.assertEqual(opts.opensearch_username, "u")
self.assertEqual(opts.opensearch_password, "p")
self.assertEqual(opts.opensearch_api_key, "k")
# auth_type is lowercased/stripped.
self.assertEqual(opts.opensearch_auth_type, "basic")
self.assertEqual(opts.opensearch_aws_region, "us-east-1")
self.assertEqual(opts.opensearch_aws_service, "es")
def test_opensearch_authentication_type_legacy_alias(self):
"""`authentication_type` is the legacy spelling of `auth_type`."""
from parsedmarc.cli import _parse_config
cp = _config_with(
"opensearch",
{"hosts": "os:9200", "authentication_type": "AWSSigV4"},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.opensearch_auth_type, "awssigv4")
def test_opensearch_apikey_camelcase_alias(self):
from parsedmarc.cli import _parse_config
cp = _config_with("opensearch", {"hosts": "os:9200", "apiKey": "legacy"})
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.opensearch_api_key, "legacy")
def test_opensearch_missing_hosts_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("opensearch", {"timeout": "30"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
class TestParseConfigSplunkHec(unittest.TestCase):
def test_splunk_hec_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"splunk_hec",
{
"url": "https://splunk:8088",
"token": "abc-token",
"index": "dmarc",
"skip_certificate_verification": "true",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.hec, "https://splunk:8088")
self.assertEqual(opts.hec_token, "abc-token")
self.assertEqual(opts.hec_index, "dmarc")
self.assertTrue(opts.hec_skip_certificate_verification)
def test_splunk_hec_missing_url_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("splunk_hec", {"token": "t", "index": "i"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_splunk_hec_missing_token_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("splunk_hec", {"url": "https://splunk:8088", "index": "i"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_splunk_hec_missing_index_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("splunk_hec", {"url": "https://splunk:8088", "token": "t"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
class TestParseConfigKafka(unittest.TestCase):
def test_kafka_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"kafka",
{
"hosts": "kafka1:9092, kafka2:9092",
"user": "u",
"password": "p",
"ssl": "true",
"skip_certificate_verification": "true",
"aggregate_topic": "dmarc-aggregate",
"failure_topic": "dmarc-failure",
"smtp_tls_topic": "smtp-tls",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.kafka_hosts, ["kafka1:9092", "kafka2:9092"])
self.assertEqual(opts.kafka_username, "u")
self.assertEqual(opts.kafka_password, "p")
self.assertTrue(opts.kafka_ssl)
self.assertTrue(opts.kafka_skip_certificate_verification)
self.assertEqual(opts.kafka_aggregate_topic, "dmarc-aggregate")
self.assertEqual(opts.kafka_failure_topic, "dmarc-failure")
self.assertEqual(opts.kafka_smtp_tls_topic, "smtp-tls")
def test_kafka_forensic_topic_alias_sets_failure_topic(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"kafka",
{
"hosts": "k:9092",
"aggregate_topic": "agg",
"forensic_topic": "old-fail",
"smtp_tls_topic": "tls",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.kafka_failure_topic, "old-fail")
def test_kafka_missing_hosts_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with(
"kafka",
{
"aggregate_topic": "a",
"failure_topic": "f",
"smtp_tls_topic": "t",
},
)
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_kafka_missing_aggregate_topic_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with(
"kafka",
{"hosts": "k:9092", "failure_topic": "f", "smtp_tls_topic": "t"},
)
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_kafka_missing_failure_topic_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with(
"kafka",
{"hosts": "k:9092", "aggregate_topic": "a", "smtp_tls_topic": "t"},
)
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_kafka_missing_smtp_tls_topic_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with(
"kafka",
{"hosts": "k:9092", "aggregate_topic": "a", "failure_topic": "f"},
)
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
class TestParseConfigSmtp(unittest.TestCase):
def test_smtp_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"smtp",
{
"host": "smtp.example.com",
"port": "587",
"ssl": "true",
"skip_certificate_verification": "true",
"user": "u",
"password": "p",
"from": "dmarc@example.com",
"to": "admin@example.com, alert@example.com",
"subject": "DMARC Report",
"attachment": "/tmp/dmarc.zip",
"message": "See attached",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.smtp_host, "smtp.example.com")
self.assertEqual(opts.smtp_port, 587)
self.assertTrue(opts.smtp_ssl)
self.assertTrue(opts.smtp_skip_certificate_verification)
self.assertEqual(opts.smtp_user, "u")
self.assertEqual(opts.smtp_password, "p")
self.assertEqual(opts.smtp_from, "dmarc@example.com")
self.assertEqual(opts.smtp_to, ["admin@example.com", "alert@example.com"])
self.assertEqual(opts.smtp_subject, "DMARC Report")
self.assertEqual(opts.smtp_attachment, "/tmp/dmarc.zip")
self.assertEqual(opts.smtp_message, "See attached")
def test_smtp_missing_host_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("smtp", {"user": "u", "password": "p"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_smtp_missing_user_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("smtp", {"host": "smtp.example.com", "password": "p"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_smtp_missing_password_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("smtp", {"host": "smtp.example.com", "user": "u"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
class TestParseConfigS3(unittest.TestCase):
def test_s3_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"s3",
{
"bucket": "my-bucket",
"path": "/dmarc/",
"region_name": "us-east-1",
"endpoint_url": "https://s3.example.com",
"access_key_id": "AKIA-x",
"secret_access_key": "secret",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.s3_bucket, "my-bucket")
# Leading and trailing slashes are stripped.
self.assertEqual(opts.s3_path, "dmarc")
self.assertEqual(opts.s3_region_name, "us-east-1")
self.assertEqual(opts.s3_endpoint_url, "https://s3.example.com")
self.assertEqual(opts.s3_access_key_id, "AKIA-x")
self.assertEqual(opts.s3_secret_access_key, "secret")
def test_s3_default_path_is_empty(self):
from parsedmarc.cli import _parse_config
cp = _config_with("s3", {"bucket": "b"})
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.s3_path, "")
def test_s3_missing_bucket_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("s3", {"path": "x"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
class TestParseConfigSyslog(unittest.TestCase):
def test_syslog_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"syslog",
{
"server": "syslog.example.com",
"port": "6514",
"protocol": "tls",
"cafile_path": "/etc/ca.pem",
"certfile_path": "/etc/c.pem",
"keyfile_path": "/etc/k.pem",
"timeout": "10.0",
"retry_attempts": "5",
"retry_delay": "2",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.syslog_server, "syslog.example.com")
self.assertEqual(opts.syslog_port, "6514")
self.assertEqual(opts.syslog_protocol, "tls")
self.assertEqual(opts.syslog_cafile_path, "/etc/ca.pem")
self.assertEqual(opts.syslog_certfile_path, "/etc/c.pem")
self.assertEqual(opts.syslog_keyfile_path, "/etc/k.pem")
self.assertEqual(opts.syslog_timeout, 10.0)
self.assertEqual(opts.syslog_retry_attempts, 5)
self.assertEqual(opts.syslog_retry_delay, 2)
def test_syslog_defaults(self):
from parsedmarc.cli import _parse_config
cp = _config_with("syslog", {"server": "s"})
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.syslog_port, 514)
self.assertEqual(opts.syslog_protocol, "udp")
self.assertEqual(opts.syslog_timeout, 5.0)
self.assertEqual(opts.syslog_retry_attempts, 3)
self.assertEqual(opts.syslog_retry_delay, 5)
def test_syslog_missing_server_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("syslog", {"port": "514"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
class TestParseConfigGmailApi(unittest.TestCase):
def test_gmail_api_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"gmail_api",
{
"credentials_file": "/etc/gmail-creds.json",
"token_file": "/var/lib/parsedmarc/gmail.token",
"include_spam_trash": "true",
"paginate_messages": "false",
"scopes": "https://www.googleapis.com/auth/gmail.readonly",
"oauth2_port": "8888",
"auth_mode": "device_code",
"service_account_user": "user@example.com",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.gmail_api_credentials_file, "/etc/gmail-creds.json")
self.assertEqual(opts.gmail_api_token_file, "/var/lib/parsedmarc/gmail.token")
self.assertTrue(opts.gmail_api_include_spam_trash)
self.assertFalse(opts.gmail_api_paginate_messages)
self.assertEqual(
opts.gmail_api_scopes,
["https://www.googleapis.com/auth/gmail.readonly"],
)
self.assertEqual(opts.gmail_api_oauth2_port, 8888)
self.assertEqual(opts.gmail_api_auth_mode, "device_code")
self.assertEqual(opts.gmail_api_service_account_user, "user@example.com")
def test_gmail_api_delegated_user_alias(self):
"""`delegated_user` is the legacy spelling of `service_account_user`."""
from parsedmarc.cli import _parse_config
cp = _config_with(
"gmail_api",
{
"credentials_file": "/c",
"delegated_user": "legacy@example.com",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.gmail_api_service_account_user, "legacy@example.com")
def test_gmail_api_default_scope(self):
from parsedmarc.cli import _parse_config
cp = _config_with("gmail_api", {"credentials_file": "/c"})
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(
opts.gmail_api_scopes,
["https://www.googleapis.com/auth/gmail.modify"],
)
class TestParseConfigLogAnalytics(unittest.TestCase):
def test_log_analytics_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"log_analytics",
{
"client_id": "cid",
"client_secret": "csec",
"tenant_id": "tid",
"dce": "https://dce.example.com",
"dcr_immutable_id": "dcr-1",
"dcr_aggregate_stream": "Custom-Aggregate_CL",
"dcr_failure_stream": "Custom-Failure_CL",
"dcr_smtp_tls_stream": "Custom-SMTPTLS_CL",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.la_client_id, "cid")
self.assertEqual(opts.la_client_secret, "csec")
self.assertEqual(opts.la_tenant_id, "tid")
self.assertEqual(opts.la_dce, "https://dce.example.com")
self.assertEqual(opts.la_dcr_immutable_id, "dcr-1")
self.assertEqual(opts.la_dcr_aggregate_stream, "Custom-Aggregate_CL")
self.assertEqual(opts.la_dcr_failure_stream, "Custom-Failure_CL")
self.assertEqual(opts.la_dcr_smtp_tls_stream, "Custom-SMTPTLS_CL")
def test_log_analytics_forensic_stream_alias(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"log_analytics",
{
"client_id": "c",
"dcr_forensic_stream": "Old-Forensic_CL",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.la_dcr_failure_stream, "Old-Forensic_CL")
class TestParseConfigGelf(unittest.TestCase):
def test_gelf_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"gelf", {"host": "graylog.example.com", "port": "12201", "mode": "tls"}
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.gelf_host, "graylog.example.com")
self.assertEqual(opts.gelf_port, "12201")
self.assertEqual(opts.gelf_mode, "tls")
def test_gelf_missing_host_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("gelf", {"port": "12201", "mode": "udp"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_gelf_missing_port_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("gelf", {"host": "g", "mode": "udp"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
def test_gelf_missing_mode_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("gelf", {"host": "g", "port": "12201"})
with self.assertRaises(ConfigurationError):
_parse_config(cp, _opts())
class TestParseConfigWebhook(unittest.TestCase):
def test_webhook_complete(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"webhook",
{
"aggregate_url": "https://hooks.example.com/agg",
"failure_url": "https://hooks.example.com/fail",
"smtp_tls_url": "https://hooks.example.com/tls",
"timeout": "30",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.webhook_aggregate_url, "https://hooks.example.com/agg")
self.assertEqual(opts.webhook_failure_url, "https://hooks.example.com/fail")
self.assertEqual(opts.webhook_smtp_tls_url, "https://hooks.example.com/tls")
self.assertEqual(opts.webhook_timeout, 30)
def test_webhook_forensic_url_alias_sets_failure_url(self):
from parsedmarc.cli import _parse_config
cp = _config_with("webhook", {"forensic_url": "https://old.example.com/fail"})
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.webhook_failure_url, "https://old.example.com/fail")
class TestConfigureLogging(unittest.TestCase):
"""_configure_logging is called in every child process for parallel
parsing — if it stops attaching a handler, log output goes dark in
multiprocessing mode."""
def setUp(self):
from parsedmarc.log import logger as plog
self._saved_handlers = list(plog.handlers)
self._saved_level = plog.level
def tearDown(self):
from parsedmarc.log import logger as plog
plog.handlers[:] = self._saved_handlers
plog.setLevel(self._saved_level)
def test_sets_log_level(self):
import logging as _logging
from parsedmarc.cli import _configure_logging
from parsedmarc.log import logger as plog
_configure_logging(_logging.DEBUG)
self.assertEqual(plog.level, _logging.DEBUG)
def test_adds_stream_handler_when_none_present(self):
import logging as _logging
from parsedmarc.cli import _configure_logging
from parsedmarc.log import logger as plog
# Clear any existing StreamHandler so we know addHandler runs.
plog.handlers[:] = [
h for h in plog.handlers if type(h) is not _logging.StreamHandler
]
_configure_logging(_logging.INFO)
self.assertTrue(any(type(h) is _logging.StreamHandler for h in plog.handlers))
def test_does_not_duplicate_stream_handler(self):
import logging as _logging
from parsedmarc.cli import _configure_logging
from parsedmarc.log import logger as plog
# Start with a single StreamHandler attached.
plog.handlers[:] = [_logging.StreamHandler()]
before = len(plog.handlers)
_configure_logging(_logging.INFO)
after = len(plog.handlers)
self.assertEqual(before, after)
def test_adds_file_handler_when_log_file_given(self):
import logging as _logging
import tempfile
from parsedmarc.cli import _configure_logging
from parsedmarc.log import logger as plog
with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as tf:
path = tf.name
try:
_configure_logging(_logging.INFO, log_file=path)
self.assertTrue(
any(isinstance(h, _logging.FileHandler) for h in plog.handlers)
)
finally:
for h in list(plog.handlers):
if isinstance(h, _logging.FileHandler):
plog.removeHandler(h)
h.close()
os.remove(path)
def test_unwritable_log_file_logs_warning_does_not_raise(self):
"""If the log file can't be opened, we warn and continue. A
regression that raised would crash the whole parse pipeline."""
import logging as _logging
from parsedmarc.cli import _configure_logging
with self.assertLogs("parsedmarc.log", level="WARNING") as cm:
_configure_logging(_logging.INFO, log_file="/proc/nonexistent/x.log")
self.assertTrue(any("Unable to write to log file" in m for m in cm.output))
class TestCliParse(unittest.TestCase):
"""cli_parse is the multiprocessing worker: it calls
parse_report_file, then sends the result (or error) back over a
pipe. Both branches matter because a regression in either would
silently drop results in parallel mode."""
def test_cli_parse_sends_results_on_success(self):
from multiprocessing import Pipe
from unittest.mock import patch
from parsedmarc.cli import cli_parse
parent_conn, child_conn = Pipe()
with patch("parsedmarc.cli.parse_report_file") as mock_parse:
mock_parse.return_value = {"report_type": "aggregate", "report": {}}
cli_parse(
"/path/to/report.xml",
False,
None,
2.0,
0,
None,
True,
True,
None,
None,
24.0,
child_conn,
)
sent = parent_conn.recv()
self.assertEqual(sent[0], {"report_type": "aggregate", "report": {}})
self.assertEqual(sent[1], "/path/to/report.xml")
def test_cli_parse_sends_error_on_parser_error(self):
from multiprocessing import Pipe
from unittest.mock import patch
from parsedmarc.cli import cli_parse
from parsedmarc import ParserError
parent_conn, child_conn = Pipe()
with patch("parsedmarc.cli.parse_report_file") as mock_parse:
err = ParserError("bad report")
mock_parse.side_effect = err
cli_parse(
"/bad.xml",
False,
None,
2.0,
0,
None,
True,
True,
None,
None,
24.0,
child_conn,
)
sent = parent_conn.recv()
self.assertIsInstance(sent[0], ParserError)
self.assertEqual(sent[1], "/bad.xml")
if __name__ == "__main__":
unittest.main(verbosity=2)
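The two TestCliParse cases above exercise one pattern: a worker either parses a file or catches the parser's error, then sends a (result-or-error, path) pair back over a pipe. A minimal standalone sketch of that pattern follows. It is not parsedmarc's cli_parse: `sketch_worker` and `run_once` are hypothetical names, `ValueError` stands in for `ParserError`, and the parser is injected as an argument rather than patched.

```python
from multiprocessing import Pipe
from unittest.mock import MagicMock


def sketch_worker(path, parse, conn):
    """Hypothetical worker: send (result_or_error, path) over the pipe."""
    try:
        outcome = parse(path)
    except ValueError as error:  # stand-in for ParserError (assumption)
        outcome = error
    conn.send((outcome, path))


def run_once(path, parse):
    """Run the worker in-process and return what the parent end receives."""
    parent_conn, child_conn = Pipe()
    try:
        sketch_worker(path, parse, child_conn)
        return parent_conn.recv()
    finally:
        parent_conn.close()
        child_conn.close()


# Success branch: the parsed result travels back unchanged.
ok = run_once(
    "/path/to/report.xml",
    MagicMock(return_value={"report_type": "aggregate"}),
)

# Error branch: the exception object itself is sent, not raised.
bad = run_once("/bad.xml", MagicMock(side_effect=ValueError("bad report")))
```

Running the worker in the same process keeps the test fast and deterministic. The pipe still pickles the payload, so an unpicklable result or error would surface here as well.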
@@ -0,0 +1,842 @@
"""Tests for parsedmarc.elastic
Mocks at the elasticsearch-dsl SDK boundary (connections.create_connection,
Index, Search, Document.save) so the tests verify the parsedmarc-side
transformation logic — document construction, index naming, deduplication
queries, error wrapping — without needing a running Elasticsearch cluster.
"""
import unittest
from unittest.mock import MagicMock, call, patch
import parsedmarc.elastic as elastic_module
from parsedmarc import InvalidFailureReport
from parsedmarc.elastic import (
AlreadySaved,
ElasticsearchError,
create_indexes,
migrate_indexes,
save_aggregate_report_to_elasticsearch,
save_failure_report_to_elasticsearch,
save_smtp_tls_report_to_elasticsearch,
set_hosts,
)
# ---------------------------------------------------------------------------
# Sample report fixtures
# ---------------------------------------------------------------------------
def _aggregate_report(**overrides):
base = {
"xml_schema": "draft",
"xml_namespace": None,
"report_metadata": {
"org_name": "TestOrg",
"org_email": "dmarc@example.com",
"org_extra_contact_info": None,
"report_id": "agg-1",
"begin_date": "2024-01-15 00:00:00",
"end_date": "2024-01-16 00:00:00",
"timespan_requires_normalization": False,
"original_timespan_seconds": 86400,
"errors": [],
"generator": "TestGen/1.0",
},
"policy_published": {
"domain": "example.com",
"adkim": "r",
"aspf": "r",
"p": "none",
"sp": "none",
"pct": None,
"fo": None,
"np": "reject",
"testing": "n",
"discovery_method": "treewalk",
},
"records": [
{
"interval_begin": "2024-01-15 00:00:00",
"interval_end": "2024-01-16 00:00:00",
"normalized_timespan": False,
"source": {
"ip_address": "192.0.2.1",
"country": "US",
"reverse_dns": None,
"base_domain": None,
"name": None,
"type": None,
"asn": 64496,
"as_name": "Example AS",
"as_domain": "example.net",
},
"count": 4,
"alignment": {"spf": True, "dkim": True, "dmarc": True},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
"policy_override_reasons": [
{"type": "local_policy", "comment": "approved"}
],
},
"identifiers": {
"header_from": "example.com",
"envelope_from": "example.com",
"envelope_to": "rcpt@example.com",
},
"auth_results": {
"dkim": [
{
"domain": "example.com",
"selector": "s",
"result": "pass",
"human_result": None,
}
],
"spf": [
{
"domain": "example.com",
"scope": "mfrom",
"result": "pass",
"human_result": None,
}
],
},
}
],
}
base.update(overrides)
return base
def _failure_report(**overrides):
base = {
"feedback_type": "auth-failure",
"user_agent": "test/1.0",
"version": "1",
"original_envelope_id": None,
"original_mail_from": "x@example.com",
"original_rcpt_to": None,
"arrival_date": "Mon, 1 Jan 2024 00:00:00 +0000",
"arrival_date_utc": "2024-01-01 00:00:00",
"authentication_results": None,
"delivery_result": "other",
"auth_failure": ["dmarc"],
"authentication_mechanisms": [],
"dkim_domain": None,
"reported_domain": "example.com",
"sample_headers_only": True,
"source": {
"ip_address": "192.0.2.5",
"country": "US",
"reverse_dns": None,
"base_domain": None,
"name": None,
"type": None,
"asn": 64496,
"as_name": "Example AS",
"as_domain": "example.net",
},
"sample": "raw",
"parsed_sample": {
"headers": {
# mailparser emits headers as [[display_name, address]]
# lists; an empty display becomes [["", address]].
"From": [["Sender Name", "sender@example.com"]],
"To": [["", "rcpt@example.com"]],
"Subject": "Test",
},
"subject": "Test",
"filename_safe_subject": "Test",
"body": "body",
"date": "Mon, 1 Jan 2024 00:00:00 +0000",
"to": [{"display_name": None, "address": "rcpt@example.com"}],
"reply_to": [],
"cc": [],
"bcc": [],
"attachments": [],
},
}
base.update(overrides)
return base
def _smtp_tls_report(**overrides):
base = {
"organization_name": "TestOrg",
"begin_date": "2024-02-03T00:00:00Z",
"end_date": "2024-02-04T00:00:00Z",
"contact_info": "tls@example.com",
"report_id": "tls-1",
"policies": [
{
"policy_domain": "example.com",
"policy_type": "sts",
"successful_session_count": 100,
"failed_session_count": 1,
"policy_strings": ["version: STSv1"],
"mx_host_patterns": ["*.example.com"],
"failure_details": [
{
"result_type": "certificate-expired",
"failed_session_count": 1,
"receiving_mx_hostname": "mx.example.com",
"sending_mta_ip": "10.0.0.1",
}
],
}
],
}
base.update(overrides)
return base
def _empty_search():
"""A Search() mock whose .execute() returns an empty hit list."""
search = MagicMock()
search.execute.return_value = []
return search
def _populated_search():
"""A Search() mock whose .execute() returns a non-empty hit list."""
search = MagicMock()
search.execute.return_value = [MagicMock()]
return search
# ---------------------------------------------------------------------------
# set_hosts: connection-parameter assembly
# ---------------------------------------------------------------------------
class TestSetHosts(unittest.TestCase):
"""Verify the conn_params dict handed to elasticsearch-dsl
matches each documented option. Each branch corresponds to a
real-world deployment shape (TLS, basic auth, API key, custom CA)."""
def test_single_host_string_normalized_to_list(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("https://es:9200")
kwargs = mock_conn.call_args.kwargs
self.assertEqual(kwargs["hosts"], ["https://es:9200"])
def test_host_list_preserved(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts(["es1:9200", "es2:9200"])
kwargs = mock_conn.call_args.kwargs
self.assertEqual(kwargs["hosts"], ["es1:9200", "es2:9200"])
def test_timeout_default_60s(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("es:9200")
self.assertEqual(mock_conn.call_args.kwargs["timeout"], 60.0)
def test_timeout_custom(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("es:9200", timeout=30.0)
self.assertEqual(mock_conn.call_args.kwargs["timeout"], 30.0)
def test_use_ssl_enables_verify_by_default(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("es:9200", use_ssl=True)
kwargs = mock_conn.call_args.kwargs
self.assertEqual(kwargs["use_ssl"], True)
self.assertEqual(kwargs["verify_certs"], True)
self.assertNotIn("ca_certs", kwargs)
def test_use_ssl_with_custom_ca(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("es:9200", use_ssl=True, ssl_cert_path="/etc/ca.pem")
kwargs = mock_conn.call_args.kwargs
self.assertEqual(kwargs["ca_certs"], "/etc/ca.pem")
def test_skip_certificate_verification_sets_verify_false(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("es:9200", use_ssl=True, skip_certificate_verification=True)
self.assertEqual(mock_conn.call_args.kwargs["verify_certs"], False)
def test_username_password_sets_http_auth(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("es:9200", username="u", password="p")
self.assertEqual(mock_conn.call_args.kwargs["http_auth"], ("u", "p"))
def test_username_without_password_not_set(self):
"""A username without a password is half-configured auth, so
http_auth is not sent at all."""
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("es:9200", username="u")
self.assertNotIn("http_auth", mock_conn.call_args.kwargs)
def test_api_key_set(self):
with patch("parsedmarc.elastic.connections.create_connection") as mock_conn:
set_hosts("es:9200", api_key="base64key==")
self.assertEqual(mock_conn.call_args.kwargs["api_key"], "base64key==")
# ---------------------------------------------------------------------------
# create_indexes
# ---------------------------------------------------------------------------
class TestCreateIndexes(unittest.TestCase):
def test_creates_missing_index_with_default_settings(self):
with patch("parsedmarc.elastic.Index") as mock_index_cls:
mock_index = mock_index_cls.return_value
mock_index.exists.return_value = False
create_indexes(["dmarc_aggregate-2024-01-15"])
mock_index.settings.assert_called_once_with(
number_of_shards=1, number_of_replicas=0
)
mock_index.create.assert_called_once()
def test_creates_with_custom_settings(self):
with patch("parsedmarc.elastic.Index") as mock_index_cls:
mock_index = mock_index_cls.return_value
mock_index.exists.return_value = False
create_indexes(
["idx"], settings={"number_of_shards": 3, "refresh_interval": "5s"}
)
mock_index.settings.assert_called_once_with(
number_of_shards=3, refresh_interval="5s"
)
def test_skips_existing_index(self):
with patch("parsedmarc.elastic.Index") as mock_index_cls:
mock_index = mock_index_cls.return_value
mock_index.exists.return_value = True
create_indexes(["idx"])
mock_index.create.assert_not_called()
def test_wraps_sdk_error(self):
with patch("parsedmarc.elastic.Index") as mock_index_cls:
mock_index_cls.return_value.exists.side_effect = RuntimeError(
"cluster down"
)
with self.assertRaises(ElasticsearchError) as ctx:
create_indexes(["idx"])
self.assertIn("cluster down", str(ctx.exception))
# ---------------------------------------------------------------------------
# migrate_indexes
# ---------------------------------------------------------------------------
class TestMigrateIndexes(unittest.TestCase):
"""The legacy `published_policy.fo` field was mapped as `long` in
older indexes. migrate_indexes detects that and rebuilds the index
with the text/keyword shape. The branch is gnarly; a regression
would silently leave old data un-migrated."""
def test_no_indexes_is_noop(self):
migrate_indexes() # Should not raise
def test_skips_non_existent_index(self):
with patch("parsedmarc.elastic.Index") as mock_index_cls:
mock_index_cls.return_value.exists.return_value = False
migrate_indexes(aggregate_indexes=["missing"])
# exists() returned False — no field_mapping fetch.
mock_index_cls.return_value.get_field_mapping.assert_not_called()
def test_skips_when_doc_mapping_absent(self):
"""An index whose field mapping has no 'doc' type
(e.g., an empty index with the default mapping) is left alone."""
with patch("parsedmarc.elastic.Index") as mock_index_cls:
idx = mock_index_cls.return_value
idx.exists.return_value = True
idx.get_field_mapping.return_value = {"some_key": {"mappings": {}}}
with patch("parsedmarc.elastic.reindex") as mock_reindex:
migrate_indexes(aggregate_indexes=["dmarc_aggregate-2023-01-01"])
mock_reindex.assert_not_called()
def test_migrates_when_fo_is_long(self):
"""The actual migration path: when fo is mapped as 'long',
a v2 index is created with the corrected mapping, data is
reindexed, and the old index is deleted."""
with (
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch("parsedmarc.elastic.reindex") as mock_reindex,
patch("parsedmarc.elastic.connections.get_connection") as mock_get_conn,
):
idx = mock_index_cls.return_value
idx.exists.return_value = True
idx.get_field_mapping.return_value = {
"dmarc_aggregate-2023-01-01": {
"mappings": {
"doc": {
"published_policy.fo": {"mapping": {"fo": {"type": "long"}}}
}
}
}
}
migrate_indexes(aggregate_indexes=["dmarc_aggregate-2023-01-01"])
# reindex called from old → new (v2) index.
mock_reindex.assert_called_once()
# connections.get_connection consulted to get the ES client.
mock_get_conn.assert_called_once()
def test_skips_when_fo_already_text(self):
with (
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch("parsedmarc.elastic.reindex") as mock_reindex,
):
idx = mock_index_cls.return_value
idx.exists.return_value = True
idx.get_field_mapping.return_value = {
"dmarc_aggregate-2024-01-01": {
"mappings": {
"doc": {
"published_policy.fo": {"mapping": {"fo": {"type": "text"}}}
}
}
}
}
migrate_indexes(aggregate_indexes=["dmarc_aggregate-2024-01-01"])
mock_reindex.assert_not_called()
# ---------------------------------------------------------------------------
# save_aggregate_report_to_elasticsearch
# ---------------------------------------------------------------------------
class TestSaveAggregateReport(unittest.TestCase):
"""The aggregate-report save fans out across multiple SDK calls:
Search (for dedup), Index.create (for the daily/monthly index),
Document.save. Each test patches the boundary it needs and
leaves the rest alone."""
def _patches(self, search_factory=_empty_search):
return [
patch("parsedmarc.elastic.Search", return_value=search_factory()),
patch(
"parsedmarc.elastic.Index",
return_value=MagicMock(exists=MagicMock(return_value=True)),
),
patch.object(elastic_module._AggregateReportDoc, "save"),
]
def test_save_emits_one_document_per_record(self):
report = _aggregate_report()
report["records"].append(report["records"][0].copy())
patches = self._patches()
with patches[0], patches[1], patches[2] as mock_save:
save_aggregate_report_to_elasticsearch(report)
# Two records → two saves.
self.assertEqual(mock_save.call_count, 2)
def test_already_saved_raises_when_search_returns_hit(self):
"""The dedup query is the only thing preventing
double-indexing on re-run. A regression would silently
re-save reports, inflating Kibana counts."""
with (
patch("parsedmarc.elastic.Search", return_value=_populated_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._AggregateReportDoc, "save") as mock_save,
):
with self.assertRaises(AlreadySaved):
save_aggregate_report_to_elasticsearch(_aggregate_report())
mock_save.assert_not_called()
def test_search_exception_wraps_to_elasticsearch_error(self):
bad_search = MagicMock()
bad_search.execute.side_effect = RuntimeError("network")
with (
patch("parsedmarc.elastic.Search", return_value=bad_search),
patch("parsedmarc.elastic.Index"),
):
with self.assertRaises(ElasticsearchError) as ctx:
save_aggregate_report_to_elasticsearch(_aggregate_report())
self.assertIn("network", str(ctx.exception))
def test_save_exception_wraps_to_elasticsearch_error(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(
elastic_module._AggregateReportDoc,
"save",
side_effect=RuntimeError("disk"),
),
):
with self.assertRaises(ElasticsearchError) as ctx:
save_aggregate_report_to_elasticsearch(_aggregate_report())
self.assertIn("disk", str(ctx.exception))
def test_index_name_uses_daily_format_by_default(self):
"""Index naming: dmarc_aggregate-YYYY-MM-DD by default."""
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch.object(elastic_module._AggregateReportDoc, "save"),
):
mock_index_cls.return_value.exists.return_value = True
save_aggregate_report_to_elasticsearch(_aggregate_report())
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("dmarc_aggregate-2024-01-15", index_calls)
def test_index_name_uses_monthly_format_when_flag_set(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch.object(elastic_module._AggregateReportDoc, "save"),
):
mock_index_cls.return_value.exists.return_value = True
save_aggregate_report_to_elasticsearch(
_aggregate_report(), monthly_indexes=True
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("dmarc_aggregate-2024-01", index_calls)
def test_index_name_honours_suffix_and_prefix(self):
"""Prefix/suffix support multi-tenant setups where one ES
cluster serves several DMARC owners."""
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch.object(elastic_module._AggregateReportDoc, "save"),
):
mock_index_cls.return_value.exists.return_value = True
save_aggregate_report_to_elasticsearch(
_aggregate_report(),
index_suffix="tenant_a",
index_prefix="customer1_",
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("customer1_dmarc_aggregate_tenant_a-2024-01-15", index_calls)
def test_dedup_search_pattern_uses_suffix_wildcard(self):
"""Existing-report search uses '*' so it matches both
daily and monthly index buckets."""
with (
patch("parsedmarc.elastic.Search") as mock_search_cls,
patch(
"parsedmarc.elastic.Index",
return_value=MagicMock(exists=MagicMock(return_value=True)),
),
patch.object(elastic_module._AggregateReportDoc, "save"),
):
mock_search_cls.return_value.execute.return_value = []
save_aggregate_report_to_elasticsearch(
_aggregate_report(), index_suffix="tenant_a", index_prefix="cust_"
)
# The search index pattern is prefix + name + suffix plus a trailing wildcard.
search_index = mock_search_cls.call_args.kwargs["index"]
self.assertIn("cust_dmarc_aggregate_tenant_a*", search_index)
# ---------------------------------------------------------------------------
# save_failure_report_to_elasticsearch
# ---------------------------------------------------------------------------
class TestSaveFailureReport(unittest.TestCase):
def test_save_emits_one_document(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._FailureReportDoc, "save") as mock_save,
):
save_failure_report_to_elasticsearch(_failure_report())
mock_save.assert_called_once()
def test_already_saved_raises_on_dedup_hit(self):
"""Failure-report dedup uses arrival_date + From/To/Subject
from the parsed sample. A hit means we've already indexed
this exact failure sample."""
with (
patch("parsedmarc.elastic.Search", return_value=_populated_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._FailureReportDoc, "save") as mock_save,
):
with self.assertRaises(AlreadySaved):
save_failure_report_to_elasticsearch(_failure_report())
mock_save.assert_not_called()
def test_save_exception_wraps_to_elasticsearch_error(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(
elastic_module._FailureReportDoc,
"save",
side_effect=RuntimeError("disk"),
),
):
with self.assertRaises(ElasticsearchError) as ctx:
save_failure_report_to_elasticsearch(_failure_report())
self.assertIn("disk", str(ctx.exception))
def test_keyerror_wraps_to_invalid_failure_report(self):
"""A malformed failure report (missing a required field) is
surfaced as InvalidFailureReport so the caller can route it
differently from infra errors."""
report = _failure_report()
del report["feedback_type"]
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._FailureReportDoc, "save"),
):
with self.assertRaises(InvalidFailureReport):
save_failure_report_to_elasticsearch(report)
def test_index_dedup_pattern_searches_both_old_and_new_names(self):
"""The split-PR rename forensic→failure left existing data
in dmarc_forensic*; the dedup search must check both names
so re-runs don't double-index."""
with (
patch("parsedmarc.elastic.Search") as mock_search_cls,
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._FailureReportDoc, "save"),
):
mock_search_cls.return_value.execute.return_value = []
save_failure_report_to_elasticsearch(_failure_report())
search_index = mock_search_cls.call_args.kwargs["index"]
self.assertIn("dmarc_failure*", search_index)
self.assertIn("dmarc_forensic*", search_index)
def test_index_name_uses_arrival_date_for_monthly_partition(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch.object(elastic_module._FailureReportDoc, "save"),
):
save_failure_report_to_elasticsearch(
_failure_report(), monthly_indexes=True
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("dmarc_failure-2024-01", index_calls)
def test_failure_search_index_with_suffix_and_prefix(self):
"""When both suffix and prefix are set, the dedup search
pattern joins them onto BOTH dmarc_failure* and
dmarc_forensic* (the rename back-compat)."""
with (
patch("parsedmarc.elastic.Search") as mock_search_cls,
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._FailureReportDoc, "save"),
):
mock_search_cls.return_value.execute.return_value = []
save_failure_report_to_elasticsearch(
_failure_report(),
index_suffix="tenant_a",
index_prefix="cust_",
)
search_index = mock_search_cls.call_args.kwargs["index"]
self.assertIn("cust_dmarc_failure_tenant_a*", search_index)
self.assertIn("cust_dmarc_forensic_tenant_a*", search_index)
def test_failure_index_honours_suffix_and_prefix(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch.object(elastic_module._FailureReportDoc, "save"),
):
save_failure_report_to_elasticsearch(
_failure_report(),
index_suffix="tenant_a",
index_prefix="cust_",
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("cust_dmarc_failure_tenant_a-2024-01-01", index_calls)
def test_from_header_with_empty_display_name(self):
"""When the From display name is empty, the code uses the
address alone (covers the early-return branch in the
display-name handling)."""
report = _failure_report()
report["parsed_sample"]["headers"]["From"] = [["", "sender@example.com"]]
report["parsed_sample"]["headers"]["To"] = [["", "rcpt@example.com"]]
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._FailureReportDoc, "save") as mock_save,
):
save_failure_report_to_elasticsearch(report)
mock_save.assert_called_once()
def test_to_header_with_non_empty_display_joins_with_brackets(self):
"""The other branch: non-empty display joins display+addr
with " <" and appends ">", e.g. 'RT <rcpt@example.com>'."""
report = _failure_report()
report["parsed_sample"]["headers"]["To"] = [["RT", "rcpt@example.com"]]
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._FailureReportDoc, "save") as mock_save,
):
save_failure_report_to_elasticsearch(report)
mock_save.assert_called_once()
def test_sample_address_lists_indexed_for_reply_to_cc_bcc_attachments(self):
"""A failure report sample can carry reply_to / cc / bcc /
attachments. Each populates a nested InnerDoc on the sample —
if the add_* helpers regress, those nested docs would be
silently empty in Elasticsearch."""
report = _failure_report()
report["parsed_sample"]["reply_to"] = [
{"display_name": "RT", "address": "rt@example.com"}
]
report["parsed_sample"]["cc"] = [
{"display_name": "CC", "address": "cc@example.com"}
]
report["parsed_sample"]["bcc"] = [
{"display_name": "", "address": "bcc@example.com"}
]
report["parsed_sample"]["attachments"] = [
{
"filename": "a.pdf",
"mail_content_type": "application/pdf",
"sha256": "deadbeef",
}
]
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._FailureReportDoc, "save") as mock_save,
):
save_failure_report_to_elasticsearch(report)
mock_save.assert_called_once()
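The dedup-search tests above pin down which index patterns the failure-report search has to cover. A minimal sketch of how such a pattern could be assembled is shown here. `build_failure_search_index` is a hypothetical helper written for illustration, and the comma-separated return format is an assumption; it is not taken from parsedmarc's source.

```python
def build_failure_search_index(index_prefix=None, index_suffix=None):
    """Hypothetical helper (not parsedmarc code): build a dedup search
    pattern covering the new and the legacy failure index names."""
    patterns = []
    # Both names are searched so data indexed before the
    # forensic -> failure rename is still seen by the dedup check.
    for base in ("dmarc_failure", "dmarc_forensic"):
        name = base
        if index_suffix:
            name = f"{name}_{index_suffix}"
        if index_prefix:
            name = f"{index_prefix}{name}"
        patterns.append(f"{name}*")
    return ",".join(patterns)


print(build_failure_search_index())
print(build_failure_search_index("cust_", "tenant_a"))
```

The suffix is joined with an underscore and the prefix is prepended verbatim, which matches the names the assertions above look for.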
# ---------------------------------------------------------------------------
# save_smtp_tls_report_to_elasticsearch
# ---------------------------------------------------------------------------
class TestSaveSmtpTlsReport(unittest.TestCase):
def test_save_emits_one_document(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._SMTPTLSReportDoc, "save") as mock_save,
):
save_smtp_tls_report_to_elasticsearch(_smtp_tls_report())
mock_save.assert_called_once()
def test_already_saved_raises_on_dedup_hit(self):
with (
patch("parsedmarc.elastic.Search", return_value=_populated_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._SMTPTLSReportDoc, "save") as mock_save,
):
with self.assertRaises(AlreadySaved):
save_smtp_tls_report_to_elasticsearch(_smtp_tls_report())
mock_save.assert_not_called()
def test_search_exception_wraps_to_elasticsearch_error(self):
bad = MagicMock()
bad.execute.side_effect = RuntimeError("network")
with (
patch("parsedmarc.elastic.Search", return_value=bad),
patch("parsedmarc.elastic.Index"),
):
with self.assertRaises(ElasticsearchError):
save_smtp_tls_report_to_elasticsearch(_smtp_tls_report())
def test_save_exception_wraps_to_elasticsearch_error(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(
elastic_module._SMTPTLSReportDoc,
"save",
side_effect=RuntimeError("disk"),
),
):
with self.assertRaises(ElasticsearchError):
save_smtp_tls_report_to_elasticsearch(_smtp_tls_report())
def test_index_name_uses_begin_date_for_monthly_partition(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch.object(elastic_module._SMTPTLSReportDoc, "save"),
):
save_smtp_tls_report_to_elasticsearch(
_smtp_tls_report(), monthly_indexes=True
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("smtp_tls-2024-02", index_calls)
def test_index_name_honours_suffix_and_prefix(self):
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index") as mock_index_cls,
patch.object(elastic_module._SMTPTLSReportDoc, "save"),
):
save_smtp_tls_report_to_elasticsearch(
_smtp_tls_report(), index_suffix="t1", index_prefix="cust_"
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("cust_smtp_tls_t1-2024-02-03", index_calls)
def test_policy_without_strings_or_mx_patterns(self):
"""policy_strings / mx_host_patterns / failure_details are all
optional in the report shape; verify the branch where they're absent."""
report = _smtp_tls_report()
for policy in report["policies"]:
policy.pop("policy_strings", None)
policy.pop("mx_host_patterns", None)
policy.pop("failure_details", None)
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._SMTPTLSReportDoc, "save") as mock_save,
):
save_smtp_tls_report_to_elasticsearch(report)
mock_save.assert_called_once()
def test_failure_details_all_optional_fields_populated(self):
"""Exercise every optional field in failure_details so the
full set of `if "x" in failure_detail` branches runs."""
report = _smtp_tls_report()
report["policies"][0]["failure_details"] = [
{
"result_type": "certificate-expired",
"failed_session_count": 1,
"receiving_mx_hostname": "mx.example.com",
"additional_information_uri": "https://example.com/why",
"failure_reason_code": "ERR_CERT",
"ip_address": "10.0.0.5",
"receiving_ip": "10.0.0.2",
"receiving_mx_helo": "mx.helo.example.com",
"sending_mta_ip": "10.0.0.1",
}
]
with (
patch("parsedmarc.elastic.Search", return_value=_empty_search()),
patch("parsedmarc.elastic.Index"),
patch.object(elastic_module._SMTPTLSReportDoc, "save") as mock_save,
):
save_smtp_tls_report_to_elasticsearch(report)
mock_save.assert_called_once()
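The index-name tests above expect a daily partition by default and a monthly one when `monthly_indexes=True`. The sketch below shows one way such a name could be derived from an ISO-8601 `begin_date`. `partitioned_index_name` and its parameters are hypothetical and only illustrate the naming scheme the assertions check.

```python
from datetime import datetime


def partitioned_index_name(base, begin_date, monthly_indexes=False,
                           index_prefix=None, index_suffix=None):
    """Hypothetical helper (not parsedmarc code): derive a dated index
    name from an ISO-8601 date string."""
    # datetime.fromisoformat() rejects a trailing "Z" before Python 3.11,
    # so normalise it to an explicit UTC offset first.
    parsed = datetime.fromisoformat(begin_date.replace("Z", "+00:00"))
    date_format = "%Y-%m" if monthly_indexes else "%Y-%m-%d"
    name = base
    if index_suffix:
        name = f"{name}_{index_suffix}"
    if index_prefix:
        name = f"{index_prefix}{name}"
    return f"{name}-{parsed.strftime(date_format)}"


print(partitioned_index_name("smtp_tls", "2024-02-03T00:00:00Z",
                             monthly_indexes=True))
```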
class TestBackwardCompatAlias(unittest.TestCase):
def test_save_forensic_alias_points_to_save_failure(self):
self.assertIs(
elastic_module.save_forensic_report_to_elasticsearch,
elastic_module.save_failure_report_to_elasticsearch,
)
def test_forensic_doc_alias_points_to_failure_doc(self):
self.assertIs(
elastic_module._ForensicReportDoc, elastic_module._FailureReportDoc
)
self.assertIs(
elastic_module._ForensicSampleDoc, elastic_module._FailureSampleDoc
)
# Silence unused-import lint in the test module preamble.
_ = call
if __name__ == "__main__":
unittest.main(verbosity=2)
+319 -8
@@ -1,18 +1,329 @@
"""Tests for parsedmarc.gelf"""
import logging
import unittest
from unittest.mock import MagicMock, patch
from parsedmarc.gelf import ContextFilter, GelfClient, log_context_data
class Test(unittest.TestCase):
"""Kitchen-sink tests redistributed from the original
tests.py monolith. Future PRs should split these further
into purpose-specific TestCase subclasses as natural
groupings emerge."""
def _sample_aggregate_report():
"""Minimal aggregate report shape acceptable to
parsed_aggregate_reports_to_csv_rows."""
return {
"xml_schema": "draft",
"xml_namespace": None,
"report_metadata": {
"org_name": "example.com",
"org_email": "dmarc@example.com",
"org_extra_contact_info": None,
"report_id": "agg-1",
"begin_date": "2024-01-01 00:00:00",
"end_date": "2024-01-02 00:00:00",
"timespan_requires_normalization": False,
"original_timespan_seconds": 86400,
"errors": [],
"generator": None,
},
"policy_published": {
"domain": "example.com",
"adkim": "r",
"aspf": "r",
"p": "none",
"sp": "none",
"pct": None,
"fo": None,
"np": None,
"testing": None,
"discovery_method": None,
},
"records": [
{
"interval_begin": "2024-01-01 00:00:00",
"interval_end": "2024-01-02 00:00:00",
"normalized_timespan": False,
"source": {
"ip_address": "192.0.2.1",
"country": "US",
"reverse_dns": None,
"base_domain": None,
"name": None,
"type": None,
"asn": 64496,
"as_name": "Example AS",
"as_domain": "example.net",
},
"count": 7,
"alignment": {"spf": True, "dkim": True, "dmarc": True},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
"policy_override_reasons": [],
},
"identifiers": {
"header_from": "example.com",
"envelope_from": "example.com",
"envelope_to": None,
},
"auth_results": {
"dkim": [
{
"domain": "example.com",
"selector": "s1",
"result": "pass",
"human_result": None,
}
],
"spf": [
{
"domain": "example.com",
"scope": "mfrom",
"result": "pass",
"human_result": None,
}
],
},
}
],
}
def testGelfBackwardCompatAlias(self):
"""GelfClient forensic alias points to failure method"""
from parsedmarc.gelf import GelfClient
class _Handler(logging.Handler):
"""Capture the (message, payload) pair of every log emit, so tests can
assert on what GelfClient actually pushed."""
def __init__(self):
super().__init__()
self.records: list[tuple[str, dict]] = []
def emit(self, record):
# ContextFilter has run by this point so `record.parsedmarc` is
# whatever payload GelfClient set via log_context_data.
self.records.append((record.getMessage(), getattr(record, "parsedmarc", None)))
class TestGelfClientInit(unittest.TestCase):
"""GelfClient.__init__ wires a pygelf handler for the requested
transport. The mode lookup is a real failure surface: a typo in the
config (`udb` instead of `udp`) should KeyError loudly, not silently
pick the wrong transport."""
def test_init_udp_picks_udp_handler(self):
with (
patch("parsedmarc.gelf.GelfUdpHandler") as mock_udp,
patch("parsedmarc.gelf.GelfTcpHandler"),
patch("parsedmarc.gelf.GelfTlsHandler"),
):
GelfClient(host="graylog.example.com", port=12201, mode="udp")
mock_udp.assert_called_once_with(
host="graylog.example.com", port=12201, include_extra_fields=True
)
def test_init_tcp_picks_tcp_handler(self):
with (
patch("parsedmarc.gelf.GelfTcpHandler") as mock_tcp,
patch("parsedmarc.gelf.GelfUdpHandler"),
patch("parsedmarc.gelf.GelfTlsHandler"),
):
GelfClient(host="g", port=12201, mode="tcp")
mock_tcp.assert_called_once_with(
host="g", port=12201, include_extra_fields=True
)
def test_init_tls_picks_tls_handler(self):
with (
patch("parsedmarc.gelf.GelfTlsHandler") as mock_tls,
patch("parsedmarc.gelf.GelfUdpHandler"),
patch("parsedmarc.gelf.GelfTcpHandler"),
):
GelfClient(host="g", port=12201, mode="tls")
mock_tls.assert_called_once_with(
host="g", port=12201, include_extra_fields=True
)
def test_init_unknown_mode_raises_keyerror(self):
"""An unknown mode in config should be a loud failure, not silent."""
with (
patch("parsedmarc.gelf.GelfUdpHandler"),
patch("parsedmarc.gelf.GelfTcpHandler"),
patch("parsedmarc.gelf.GelfTlsHandler"),
):
with self.assertRaises(KeyError):
GelfClient(host="g", port=12201, mode="udb")
def _install_capturing_handler(client):
"""Replace the real pygelf handler with one that records emitted
log records and their `parsedmarc` payload. Returns the handler
so the test can inspect captured records."""
client.logger.removeHandler(client.handler)
h = _Handler()
client.logger.addHandler(h)
client.handler = h
return h
def _gelf_client():
# The parsedmarc_gelf logger is module-level — each new client adds
# another handler. Clear stale handlers from prior tests so the
# logger only carries this client's handler.
logging.getLogger("parsedmarc_gelf").handlers.clear()
with (
patch("parsedmarc.gelf.GelfUdpHandler"),
patch("parsedmarc.gelf.GelfTcpHandler"),
patch("parsedmarc.gelf.GelfTlsHandler"),
):
return GelfClient(host="g", port=12201, mode="udp")
class TestGelfClientSaveAggregate(unittest.TestCase):
"""save_aggregate_report_to_gelf emits one log record per
aggregate CSV row, with the row payload on `record.parsedmarc`.
Verifying the payload — not just "log was called" — catches future
regressions in the row-builder or filter wiring."""
def test_emits_one_record_per_csv_row_with_payload(self):
client = _gelf_client()
handler = _install_capturing_handler(client)
client.save_aggregate_report_to_gelf([_sample_aggregate_report()])
# One row in the sample report → one log record.
self.assertEqual(len(handler.records), 1)
message, payload = handler.records[0]
self.assertEqual(message, "parsedmarc aggregate report")
# The payload is the flattened CSV row; verify the key fields a
# Graylog dashboard would actually filter on.
self.assertEqual(payload["source_ip_address"], "192.0.2.1")
self.assertEqual(payload["header_from"], "example.com")
self.assertEqual(payload["count"], 7)
def test_clears_context_after_emit(self):
"""The thread-local payload is reset to None after the loop so
a later unrelated log call on the same thread doesn't carry
stale DMARC data."""
client = _gelf_client()
_install_capturing_handler(client)
client.save_aggregate_report_to_gelf([_sample_aggregate_report()])
self.assertIsNone(log_context_data.parsedmarc)
class TestGelfClientSaveFailure(unittest.TestCase):
"""save_failure_report_to_gelf operates on already-parsed failure
reports. Build one through the CSV-row helper to verify GelfClient
surfaces the right fields."""
def _sample_failure_report(self):
return {
"feedback_type": "auth-failure",
"user_agent": "test/1.0",
"version": "1",
"original_envelope_id": None,
"original_mail_from": "x@example.com",
"original_rcpt_to": None,
"arrival_date": "Mon, 1 Jan 2024 00:00:00 +0000",
"arrival_date_utc": "2024-01-01 00:00:00",
"authentication_results": None,
"delivery_result": "other",
"auth_failure": ["dmarc"],
"authentication_mechanisms": [],
"dkim_domain": None,
"reported_domain": "example.com",
"sample_headers_only": True,
"source": {
"ip_address": "192.0.2.5",
"country": "US",
"reverse_dns": None,
"base_domain": None,
"name": None,
"type": None,
"asn": 64496,
"as_name": "Example AS",
"as_domain": "example.net",
},
"sample": "...",
"parsed_sample": {"subject": "Test"},
}
def test_emits_one_record_per_failure_report(self):
client = _gelf_client()
handler = _install_capturing_handler(client)
client.save_failure_report_to_gelf([self._sample_failure_report()])
self.assertEqual(len(handler.records), 1)
message, payload = handler.records[0]
self.assertEqual(message, "parsedmarc failure report")
self.assertEqual(payload["source_ip_address"], "192.0.2.5")
self.assertEqual(payload["reported_domain"], "example.com")
class TestGelfClientSaveSmtpTls(unittest.TestCase):
def _sample_smtp_tls(self):
return {
"organization_name": "example.com",
"begin_date": "2024-02-03T00:00:00Z",
"end_date": "2024-02-04T00:00:00Z",
"contact_info": "tls@example.com",
"report_id": "tls-1",
"policies": [
{
"policy_domain": "example.com",
"policy_type": "sts",
"successful_session_count": 100,
"failed_session_count": 0,
}
],
}
def test_emits_one_record_per_policy(self):
client = _gelf_client()
handler = _install_capturing_handler(client)
client.save_smtp_tls_report_to_gelf([self._sample_smtp_tls()])
self.assertEqual(len(handler.records), 1)
message, payload = handler.records[0]
self.assertEqual(message, "parsedmarc smtptls report")
self.assertEqual(payload["policy_domain"], "example.com")
self.assertEqual(payload["successful_session_count"], 100)
class TestContextFilter(unittest.TestCase):
"""ContextFilter copies log_context_data.parsedmarc onto the log
record so pygelf can include it as an extra field. Failure mode:
if the filter raises (or removes itself), GELF output goes dark."""
def test_filter_copies_thread_local_onto_record(self):
log_context_data.parsedmarc = {"hello": "world"}
try:
f = ContextFilter()
record = logging.LogRecord(
name="x",
level=logging.INFO,
pathname=__file__,
lineno=1,
msg="msg",
args=(),
exc_info=None,
)
result = f.filter(record)
self.assertTrue(result)
self.assertEqual(record.parsedmarc, {"hello": "world"}) # type: ignore[attr-defined]
finally:
log_context_data.parsedmarc = None
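The pattern this test guards, a thread-local payload copied onto each log record by a filter, can be reproduced with the standard library alone. The sketch below is an illustration under assumed names (`context`, `PayloadFilter`, `ListHandler`); it is not the parsedmarc implementation.

```python
import logging
import threading

# Hypothetical stand-in for a module-level thread-local payload holder.
context = threading.local()


class PayloadFilter(logging.Filter):
    def filter(self, record):
        # Copy the current thread's payload onto the record; returning
        # True keeps the record flowing to the handlers.
        record.payload = getattr(context, "payload", None)
        return True


class ListHandler(logging.Handler):
    """Collects (message, payload) pairs instead of sending them anywhere."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def emit(self, record):
        self.seen.append((record.getMessage(), record.payload))


logger = logging.getLogger("payload_filter_sketch")
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addFilter(PayloadFilter())
handler = ListHandler()
logger.addHandler(handler)

context.payload = {"hello": "world"}
try:
    logger.info("with payload")
finally:
    # Reset so later log calls on this thread carry no stale data.
    context.payload = None
logger.info("after reset")
```

Resetting the payload in a `finally` block mirrors the cleanup the test performs, so a failure mid-emit cannot leak data into later records.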
class TestGelfClientClose(unittest.TestCase):
def test_close_removes_and_closes_handler(self):
client = _gelf_client()
handler = MagicMock()
client.logger.removeHandler(client.handler)
client.logger.addHandler(handler)
client.handler = handler
client.close()
handler.close.assert_called_once()
# Handler should no longer be attached after close().
self.assertNotIn(handler, client.logger.handlers)
class TestGelfClientBackwardCompatAlias(unittest.TestCase):
def test_forensic_alias_points_to_failure_method(self):
self.assertIs(
GelfClient.save_forensic_report_to_gelf, # type: ignore[attr-defined]
GelfClient.save_failure_report_to_gelf,
+173
@@ -2306,5 +2306,178 @@ class TestGetDmarcReportsFromMbox(unittest.TestCase):
os.remove(path)
class TestGetDmarcReportsFromMailboxValidation(unittest.TestCase):
"""Input validation on get_dmarc_reports_from_mailbox.
These guards prevent two real footguns. The test/delete combo
would otherwise delete every message after parsing, which is
silently destructive. A None connection would raise an
AttributeError deep in the iteration loop with a confusing
traceback. Fail fast at the door instead."""
def test_delete_and_test_combination_raises(self):
from unittest.mock import MagicMock
with self.assertRaises(ValueError) as ctx:
parsedmarc.get_dmarc_reports_from_mailbox(
connection=MagicMock(), delete=True, test=True
)
self.assertIn("mutually exclusive", str(ctx.exception))
def test_none_connection_raises(self):
with self.assertRaises(ValueError) as ctx:
parsedmarc.get_dmarc_reports_from_mailbox(connection=None)
self.assertIn("connection", str(ctx.exception).lower())
class TestEmailResultsErrorBranches(unittest.TestCase):
"""email_results requires mail_to to be a list — this is enforced
by an assert. A regression that dropped the assert would mean the
SMTP code further down would silently iterate over the characters
of a string."""
def test_mail_to_must_be_list(self):
with self.assertRaises(AssertionError):
parsedmarc.email_results(
{
"aggregate_reports": [],
"failure_reports": [],
"smtp_tls_reports": [],
},
host="smtp.example.com",
mail_from="from@example.com",
mail_to="admin@example.com", # str, not list — triggers assert
)
class TestAppendJson(unittest.TestCase):
"""append_json writes new files cleanly and merges into existing
JSON arrays without breaking valid JSON."""
def test_writes_new_file(self):
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
path = tf.name
os.remove(path) # ensure file is fresh
try:
parsedmarc.append_json(path, [{"a": 1}])
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"a": 1}])
finally:
if os.path.exists(path):
os.remove(path)
def test_appends_to_existing_file(self):
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
path = tf.name
try:
parsedmarc.append_json(path, [{"a": 1}])
parsedmarc.append_json(path, [{"b": 2}])
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"a": 1}, {"b": 2}])
finally:
if os.path.exists(path):
os.remove(path)
def test_empty_list_on_existing_file_is_noop(self):
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
path = tf.name
try:
parsedmarc.append_json(path, [{"a": 1}])
parsedmarc.append_json(path, [])
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"a": 1}])
finally:
if os.path.exists(path):
os.remove(path)
def test_corrupt_existing_file_is_overwritten_cleanly(self):
"""If the existing JSON file is corrupt (e.g. truncated by a
prior crash, or hit the pre-fix `append_json` bug), the
read-merge-write path falls back to overwriting with the new
content rather than silently failing to record.
Recording at the cost of losing prior corrupt data is the
lesser evil — those bytes are already unparseable, so no
downstream consumer can read them anyway."""
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
tf.write("{ this is not valid json at all")
path = tf.name
try:
parsedmarc.append_json(path, [{"new": "data"}])
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"new": "data"}])
finally:
if os.path.exists(path):
os.remove(path)
def test_existing_file_with_non_list_root_is_overwritten(self):
"""If the existing file parses cleanly but the root isn't a
list (e.g. someone wrote {"foo": 1} by hand), the
isinstance(loaded, list) guard kicks in and we overwrite
rather than concatenating a dict and a list."""
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
tf.write('{"not": "a list"}')
path = tf.name
try:
parsedmarc.append_json(path, [{"new": "data"}])
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"new": "data"}])
finally:
if os.path.exists(path):
os.remove(path)
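Taken together, the TestAppendJson cases describe a read-merge-write contract: merge into a valid JSON array, and overwrite when the existing file is corrupt, empty, or not a list. A minimal sketch that satisfies those cases is given below. `append_json_sketch` is a hypothetical name and this is not the code shipped in parsedmarc.

```python
import json
import os
import tempfile


def append_json_sketch(path, new_items):
    """Hypothetical read-merge-write append (not parsedmarc code)."""
    existing = []
    if os.path.exists(path):
        try:
            with open(path, encoding="utf-8") as f:
                loaded = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Corrupt or empty file: nothing recoverable, so overwrite.
            loaded = None
        # Only a list root can be merged; anything else is replaced.
        if isinstance(loaded, list):
            existing = loaded
    merged = existing + list(new_items)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(merged, f, indent=2)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "reports.json")
    append_json_sketch(target, [{"a": 1}])
    append_json_sketch(target, [{"b": 2}])
    with open(target, encoding="utf-8") as f:
        print(json.load(f))
```

Rewriting the whole file keeps the output valid JSON after every call, at the cost of rereading it each time.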
class TestAppendCsv(unittest.TestCase):
def test_writes_new_file_with_header(self):
with NamedTemporaryFile("w", suffix=".csv", delete=False) as tf:
path = tf.name
os.remove(path)
try:
parsedmarc.append_csv(path, "h1,h2\nv1,v2\n")
with open(path) as f:
content = f.read()
self.assertEqual(content, "h1,h2\nv1,v2\n")
finally:
if os.path.exists(path):
os.remove(path)
def test_appends_strips_header_on_existing_file(self):
"""Second append must not re-emit the header line."""
with NamedTemporaryFile("w", suffix=".csv", delete=False) as tf:
path = tf.name
try:
parsedmarc.append_csv(path, "h1,h2\nv1,v2\n")
parsedmarc.append_csv(path, "h1,h2\nv3,v4\n")
with open(path) as f:
content = f.read()
# Only one header line in the merged output.
self.assertEqual(content.count("h1,h2"), 1)
self.assertIn("v3,v4", content)
finally:
if os.path.exists(path):
os.remove(path)
def test_append_empty_csv_on_existing_file_is_noop(self):
"""append_csv with just a header row (no data) should not
rewrite the file when one already exists."""
with NamedTemporaryFile("w", suffix=".csv", delete=False) as tf:
path = tf.name
try:
parsedmarc.append_csv(path, "h1,h2\nv1,v2\n")
parsedmarc.append_csv(path, "h1,h2\n")
with open(path) as f:
content = f.read()
# File unchanged.
self.assertEqual(content, "h1,h2\nv1,v2\n")
finally:
if os.path.exists(path):
os.remove(path)
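The three TestAppendCsv cases can be met by a small function that writes the full text to a new or empty file and appends only the data rows otherwise. The sketch below uses the hypothetical name `append_csv_sketch` and assumes the header is always the first line of the incoming text; parsedmarc's own implementation may differ.

```python
import os
import tempfile


def append_csv_sketch(path, csv_text):
    """Hypothetical header-aware CSV append (not parsedmarc code)."""
    if os.path.exists(path) and os.path.getsize(path) > 0:
        # The file already has a header: keep only the data rows.
        _header, _, rows = csv_text.partition("\n")
        if not rows.strip():
            return  # header-only input, nothing to append
        with open(path, "a", newline="", encoding="utf-8") as f:
            f.write(rows)
    else:
        # New or empty file: write header and rows as given.
        with open(path, "w", newline="", encoding="utf-8") as f:
            f.write(csv_text)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "out.csv")
    append_csv_sketch(target, "h1,h2\nv1,v2\n")
    append_csv_sketch(target, "h1,h2\nv3,v4\n")
    with open(target, newline="", encoding="utf-8") as f:
        print(f.read())
```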
if __name__ == "__main__":
unittest.main(verbosity=2)
+256 -34
@@ -1,53 +1,275 @@
"""Tests for parsedmarc.kafkaclient"""
import json
import unittest
from unittest.mock import MagicMock, patch
from kafka.errors import NoBrokersAvailable, UnknownTopicOrPartitionError
from parsedmarc.kafkaclient import KafkaClient, KafkaError
class Test(unittest.TestCase):
"""Kitchen-sink tests redistributed from the original
tests.py monolith. Future PRs should split these further
into purpose-specific TestCase subclasses as natural
groupings emerge."""
def _aggregate_report():
return {
"report_metadata": {
"org_name": "TestOrg",
"org_email": "test@example.com",
"report_id": "r-123",
"begin_date": "2024-01-01 00:00:00",
"end_date": "2024-01-02 00:00:00",
},
"policy_published": {"domain": "example.com", "p": "none"},
"records": [
{"source": {"ip_address": "192.0.2.1"}, "count": 1},
{"source": {"ip_address": "192.0.2.2"}, "count": 2},
],
}
def testKafkaStripMetadata(self):
"""KafkaClient.strip_metadata extracts metadata to root"""
from parsedmarc.kafkaclient import KafkaClient
report = {
"report_metadata": {
"org_name": "TestOrg",
"org_email": "test@example.com",
"report_id": "r-123",
"begin_date": "2024-01-01",
"end_date": "2024-01-02",
},
"records": [],
}
class TestKafkaClientInit(unittest.TestCase):
"""KafkaProducer config wiring: SSL, SASL, plain — each branch has
user-facing security consequences if it's wrong."""
def test_init_plain_no_ssl(self):
"""No SSL, no auth: just bootstrap_servers and serializer."""
with patch("parsedmarc.kafkaclient.KafkaProducer") as mock_producer:
KafkaClient(kafka_hosts=["broker:9092"])
kwargs = mock_producer.call_args.kwargs
self.assertEqual(kwargs["bootstrap_servers"], ["broker:9092"])
self.assertNotIn("security_protocol", kwargs)
self.assertNotIn("sasl_plain_username", kwargs)
def test_init_ssl_enables_ssl_security_protocol(self):
with (
patch("parsedmarc.kafkaclient.KafkaProducer") as mock_producer,
patch("parsedmarc.kafkaclient.create_default_context") as mock_ctx,
):
KafkaClient(kafka_hosts=["broker:9093"], ssl=True)
kwargs = mock_producer.call_args.kwargs
self.assertEqual(kwargs["security_protocol"], "SSL")
self.assertIs(kwargs["ssl_context"], mock_ctx.return_value)
def test_init_username_implies_ssl(self):
"""Doc says ssl=True is implied when a username/password is supplied."""
with (
patch("parsedmarc.kafkaclient.KafkaProducer") as mock_producer,
patch("parsedmarc.kafkaclient.create_default_context"),
):
KafkaClient(kafka_hosts=["broker:9093"], username="user", password="pass")
kwargs = mock_producer.call_args.kwargs
self.assertEqual(kwargs["security_protocol"], "SSL")
self.assertEqual(kwargs["sasl_plain_username"], "user")
self.assertEqual(kwargs["sasl_plain_password"], "pass")
def test_init_uses_provided_ssl_context(self):
"""A caller-supplied SSLContext takes precedence over the
default context — this lets ops pin to a private CA."""
custom_ctx = MagicMock()
with (
patch("parsedmarc.kafkaclient.KafkaProducer") as mock_producer,
patch("parsedmarc.kafkaclient.create_default_context") as mock_default,
):
KafkaClient(kafka_hosts=["b:9093"], ssl=True, ssl_context=custom_ctx)
self.assertIs(mock_producer.call_args.kwargs["ssl_context"], custom_ctx)
mock_default.assert_not_called()
def test_init_value_serializer_emits_utf8_json(self):
"""The value_serializer turns Python objects into UTF-8 JSON
bytes. A regression here would corrupt every event sent."""
with patch("parsedmarc.kafkaclient.KafkaProducer") as mock_producer:
KafkaClient(kafka_hosts=["b"])
serializer = mock_producer.call_args.kwargs["value_serializer"]
result = serializer({"hello": "world", "n": 1})
self.assertEqual(json.loads(result.decode("utf-8")), {"hello": "world", "n": 1})
def test_init_no_brokers_available_raises_kafka_error(self):
with patch(
"parsedmarc.kafkaclient.KafkaProducer",
side_effect=NoBrokersAvailable(),
):
with self.assertRaises(KafkaError) as ctx:
KafkaClient(kafka_hosts=["unreachable:9092"])
self.assertIn("No Kafka brokers", str(ctx.exception))
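The serializer test above checks a round trip from Python object to UTF-8 JSON bytes and back. A one-function sketch of a serializer with that property is shown here. `serialize_value` is a hypothetical name, and the real `value_serializer` passed to `KafkaProducer` may be written differently.

```python
import json


def serialize_value(value):
    """Hypothetical serializer: Python object -> UTF-8 encoded JSON bytes."""
    return json.dumps(value).encode("utf-8")


payload = serialize_value({"hello": "world", "n": 1})
print(json.loads(payload.decode("utf-8")))
```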
class TestKafkaClientHelpers(unittest.TestCase):
"""Static helpers used by save_aggregate."""
def test_strip_metadata_lifts_keys_to_root_and_drops_metadata(self):
report = _aggregate_report()
result = KafkaClient.strip_metadata(report)
self.assertEqual(result["org_name"], "TestOrg")
self.assertEqual(result["org_email"], "test@example.com")
self.assertEqual(result["report_id"], "r-123")
self.assertNotIn("report_metadata", result)
def testKafkaGenerateDateRange(self):
"""KafkaClient.generate_date_range generates date range list"""
from parsedmarc.kafkaclient import KafkaClient
def test_generate_date_range_iso_format(self):
report = _aggregate_report()
date_range = KafkaClient.generate_date_range(report)
self.assertEqual(date_range, ["2024-01-01T00:00:00", "2024-01-02T00:00:00"])
report = {
"report_metadata": {
"begin_date": "2024-01-01 00:00:00",
"end_date": "2024-01-02 00:00:00",
}
}
result = KafkaClient.generate_date_range(report)
self.assertEqual(len(result), 2)
self.assertIn("2024-01-01", result[0])
self.assertIn("2024-01-02", result[1])
def testKafkaBackwardCompatAlias(self):
"""KafkaClient forensic alias points to failure method"""
from parsedmarc.kafkaclient import KafkaClient
class TestSaveAggregateReportsToKafka(unittest.TestCase):
"""save_aggregate sends one Kafka message per record (slice), with
the metadata + policy duplicated onto each slice for Kibana parity."""
def _client(self):
with patch("parsedmarc.kafkaclient.KafkaProducer"):
return KafkaClient(kafka_hosts=["b:9092"])
def test_sends_one_message_per_record(self):
client = self._client()
client.save_aggregate_reports_to_kafka(_aggregate_report(), "dmarc-aggregate")
# 2 records in the sample report → 2 producer.send calls.
self.assertEqual(client.producer.send.call_count, 2)
# Topic is forwarded verbatim.
for call in client.producer.send.call_args_list:
self.assertEqual(call.args[0], "dmarc-aggregate")
def test_each_slice_carries_metadata(self):
client = self._client()
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
sent = [call.args[1] for call in client.producer.send.call_args_list]
for slice_ in sent:
self.assertEqual(slice_["org_name"], "TestOrg")
self.assertEqual(slice_["org_email"], "test@example.com")
self.assertEqual(slice_["report_id"], "r-123")
self.assertEqual(
slice_["date_range"], ["2024-01-01T00:00:00", "2024-01-02T00:00:00"]
)
self.assertEqual(
slice_["policy_published"], {"domain": "example.com", "p": "none"}
)
def test_empty_list_is_a_noop(self):
client = self._client()
client.save_aggregate_reports_to_kafka([], "topic")
client.producer.send.assert_not_called()
def test_dict_input_normalized_to_list(self):
"""Single-report dict input is wrapped to a list."""
client = self._client()
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
# 2 records still sent (one report with 2 records, not multiple reports).
self.assertEqual(client.producer.send.call_count, 2)
def test_unknown_topic_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = UnknownTopicOrPartitionError()
with self.assertRaises(KafkaError) as ctx:
client.save_aggregate_reports_to_kafka(_aggregate_report(), "missing")
self.assertIn("Unknown topic or partition", str(ctx.exception))
def test_generic_send_exception_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = RuntimeError("transport failure")
with self.assertRaises(KafkaError) as ctx:
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
self.assertIn("transport failure", str(ctx.exception))
def test_flush_exception_translates_to_kafka_error(self):
client = self._client()
client.producer.flush.side_effect = RuntimeError("flush failure")
with self.assertRaises(KafkaError) as ctx:
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
self.assertIn("flush failure", str(ctx.exception))
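Review note: the three error tests above pin down one translation pattern. Below is a minimal sketch of that pattern, not parsedmarc's actual implementation. `SketchKafkaError`, `SketchUnknownTopicError`, and `send_records` are hypothetical stand-ins, chosen so the example runs without kafka-python installed.

```python
from unittest.mock import MagicMock


class SketchKafkaError(RuntimeError):
    """Hypothetical stand-in for the module's own KafkaError."""


class SketchUnknownTopicError(Exception):
    """Hypothetical stand-in for UnknownTopicOrPartitionError."""


def send_records(producer, records, topic):
    """Send each record, then flush; re-raise any failure as SketchKafkaError."""
    try:
        for record in records:
            producer.send(topic, record)
    except SketchUnknownTopicError as error:
        # The unknown-topic case gets a fixed, recognisable message.
        raise SketchKafkaError("Unknown topic or partition on broker") from error
    except Exception as error:
        # Any other send failure keeps its original message.
        raise SketchKafkaError(str(error)) from error
    try:
        producer.flush()
    except Exception as error:
        # Flush failures are translated the same way as send failures.
        raise SketchKafkaError(str(error)) from error


# Usage: a mocked producer whose flush() fails.
producer = MagicMock()
producer.flush.side_effect = RuntimeError("flush failure")
try:
    send_records(producer, [{"a": 1}], "topic")
except SketchKafkaError as error:
    flush_message = str(error)
```

Callers then need to handle only one exception type, which is what the tests assert with `assertRaises(KafkaError)`.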
class TestSaveFailureReportsToKafka(unittest.TestCase):
def _client(self):
with patch("parsedmarc.kafkaclient.KafkaProducer"):
return KafkaClient(kafka_hosts=["b:9092"])
def test_sends_full_list_in_one_message(self):
"""Failure reports go in a single Kafka message — the comment
in the source code documents the 1 MB per-message default."""
client = self._client()
reports = [{"id": "f1"}, {"id": "f2"}]
client.save_failure_reports_to_kafka(reports, "dmarc-failure")
client.producer.send.assert_called_once_with("dmarc-failure", reports)
def test_dict_input_normalized_to_list(self):
client = self._client()
client.save_failure_reports_to_kafka({"id": "single"}, "topic")
# The send payload is wrapped to a single-element list.
args = client.producer.send.call_args.args
self.assertEqual(args[1], [{"id": "single"}])
def test_empty_list_is_a_noop(self):
client = self._client()
client.save_failure_reports_to_kafka([], "topic")
client.producer.send.assert_not_called()
def test_unknown_topic_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = UnknownTopicOrPartitionError()
with self.assertRaises(KafkaError):
client.save_failure_reports_to_kafka([{"a": 1}], "missing")
def test_generic_send_error_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = OSError("net")
with self.assertRaises(KafkaError):
client.save_failure_reports_to_kafka([{"a": 1}], "topic")
def test_flush_error_translates_to_kafka_error(self):
client = self._client()
client.producer.flush.side_effect = OSError("flush")
with self.assertRaises(KafkaError):
client.save_failure_reports_to_kafka([{"a": 1}], "topic")
class TestSaveSmtpTlsReportsToKafka(unittest.TestCase):
def _client(self):
with patch("parsedmarc.kafkaclient.KafkaProducer"):
return KafkaClient(kafka_hosts=["b:9092"])
def test_sends_full_list_in_one_message(self):
client = self._client()
reports = [{"organization_name": "x"}]
client.save_smtp_tls_reports_to_kafka(reports, "smtp-tls")
client.producer.send.assert_called_once_with("smtp-tls", reports)
def test_dict_input_normalized_to_list(self):
client = self._client()
client.save_smtp_tls_reports_to_kafka({"organization_name": "x"}, "topic")
args = client.producer.send.call_args.args
self.assertEqual(args[1], [{"organization_name": "x"}])
def test_empty_list_is_a_noop(self):
client = self._client()
client.save_smtp_tls_reports_to_kafka([], "topic")
client.producer.send.assert_not_called()
def test_unknown_topic_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = UnknownTopicOrPartitionError()
with self.assertRaises(KafkaError):
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "missing")
def test_generic_send_error_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = RuntimeError("oops")
with self.assertRaises(KafkaError):
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "topic")
def test_flush_error_translates_to_kafka_error(self):
client = self._client()
client.producer.flush.side_effect = RuntimeError("flush")
with self.assertRaises(KafkaError):
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "topic")
class TestKafkaClientClose(unittest.TestCase):
def test_close_calls_underlying_producer_close(self):
with patch("parsedmarc.kafkaclient.KafkaProducer"):
client = KafkaClient(kafka_hosts=["b"])
client.close()
client.producer.close.assert_called_once()
class TestKafkaBackwardCompatAlias(unittest.TestCase):
def test_forensic_alias_points_to_failure_method(self):
self.assertIs(
KafkaClient.save_forensic_reports_to_kafka, # type: ignore[attr-defined]
KafkaClient.save_failure_reports_to_kafka,
)
+254 -30
@@ -1,28 +1,39 @@
"""Tests for parsedmarc.loganalytics"""
import unittest
from unittest.mock import MagicMock, patch
from azure.core.exceptions import HttpResponseError
from parsedmarc.loganalytics import (
LogAnalyticsClient,
LogAnalyticsConfig,
LogAnalyticsException,
)
class Test(unittest.TestCase):
"""Kitchen-sink tests redistributed from the original
tests.py monolith. Future PRs should split these further
into purpose-specific TestCase subclasses as natural
groupings emerge."""
def _valid_kwargs(**overrides):
base = dict(
client_id="cid",
client_secret="csec",
tenant_id="tid",
dce="https://dce.example.com",
dcr_immutable_id="dcr-123",
dcr_aggregate_stream="agg-stream",
dcr_failure_stream="fail-stream",
dcr_smtp_tls_stream="tls-stream",
)
base.update(overrides)
return base
def testLogAnalyticsConfig(self):
"""LogAnalyticsConfig stores all fields"""
from parsedmarc.loganalytics import LogAnalyticsConfig
config = LogAnalyticsConfig(
client_id="cid",
client_secret="csec",
tenant_id="tid",
dce="https://dce.example.com",
dcr_immutable_id="dcr-123",
dcr_aggregate_stream="agg-stream",
dcr_failure_stream="fail-stream",
dcr_smtp_tls_stream="tls-stream",
)
class TestLogAnalyticsConfig(unittest.TestCase):
"""The config dataclass holds every credential and stream needed
to push to Log Analytics. A typo on any attribute would silently
drop data into the wrong stream."""
def test_config_stores_every_field(self):
config = LogAnalyticsConfig(**_valid_kwargs())
self.assertEqual(config.client_id, "cid")
self.assertEqual(config.client_secret, "csec")
self.assertEqual(config.tenant_id, "tid")
@@ -32,21 +43,234 @@ class Test(unittest.TestCase):
self.assertEqual(config.dcr_failure_stream, "fail-stream")
self.assertEqual(config.dcr_smtp_tls_stream, "tls-stream")
def testLogAnalyticsClientValidationError(self):
"""LogAnalyticsClient raises on missing required config"""
from parsedmarc.loganalytics import LogAnalyticsClient, LogAnalyticsException
class TestLogAnalyticsClientInit(unittest.TestCase):
"""The constructor's validation guards against a half-configured
deployment that would otherwise fail late inside Azure SDK calls
with confusing errors."""
def test_init_accepts_complete_config(self):
client = LogAnalyticsClient(**_valid_kwargs())
self.assertEqual(client.conf.client_id, "cid")
self.assertEqual(client.conf.dcr_immutable_id, "dcr-123")
def test_missing_client_id_raises(self):
with self.assertRaises(LogAnalyticsException):
LogAnalyticsClient(
client_id="",
client_secret="csec",
tenant_id="tid",
dce="https://dce.example.com",
dcr_immutable_id="dcr-123",
dcr_aggregate_stream="agg",
dcr_failure_stream="fail",
dcr_smtp_tls_stream="tls",
LogAnalyticsClient(**_valid_kwargs(client_id=""))
def test_missing_client_secret_raises(self):
with self.assertRaises(LogAnalyticsException):
LogAnalyticsClient(**_valid_kwargs(client_secret=""))
def test_missing_tenant_id_raises(self):
with self.assertRaises(LogAnalyticsException):
LogAnalyticsClient(**_valid_kwargs(tenant_id=""))
def test_missing_dce_raises(self):
with self.assertRaises(LogAnalyticsException):
LogAnalyticsClient(**_valid_kwargs(dce=""))
def test_missing_dcr_immutable_id_raises(self):
with self.assertRaises(LogAnalyticsException):
LogAnalyticsClient(**_valid_kwargs(dcr_immutable_id=""))
class TestPublishJson(unittest.TestCase):
"""publish_json wraps logs_client.upload and translates Azure
HttpResponseError into the module's own exception type so the CLI
error reporter can handle it uniformly."""
def test_publish_json_forwards_to_logs_client(self):
client = LogAnalyticsClient(**_valid_kwargs())
logs_client = MagicMock()
client.publish_json([{"a": 1}], logs_client, "agg-stream")
logs_client.upload.assert_called_once_with("dcr-123", "agg-stream", [{"a": 1}])
def test_publish_json_translates_http_error(self):
client = LogAnalyticsClient(**_valid_kwargs())
logs_client = MagicMock()
logs_client.upload.side_effect = HttpResponseError("forbidden")
with self.assertRaises(LogAnalyticsException) as ctx:
client.publish_json([{"a": 1}], logs_client, "stream")
self.assertIn("forbidden", str(ctx.exception))
class TestPublishResults(unittest.TestCase):
"""publish_results gates each report type behind both a config flag
(save_aggregate / save_failure / save_smtp_tls) and a configured
stream name. Both gates must hold: an unset stream name is a config
gap that is skipped silently, while an explicit save_*=False means
the operator opted out."""
def _publish_with(self, results, **flags):
flags.setdefault("save_aggregate", True)
flags.setdefault("save_failure", True)
flags.setdefault("save_smtp_tls", True)
client = LogAnalyticsClient(**_valid_kwargs())
with (
patch("parsedmarc.loganalytics.ClientSecretCredential"),
patch("parsedmarc.loganalytics.LogsIngestionClient") as mock_client_cls,
):
mock_logs_client = mock_client_cls.return_value
client.publish_results(results, **flags)
return mock_logs_client
def test_aggregate_published_to_aggregate_stream(self):
logs_client = self._publish_with(
{
"aggregate_reports": [{"id": "a"}],
"failure_reports": [],
"smtp_tls_reports": [],
}
)
logs_client.upload.assert_called_once_with(
"dcr-123", "agg-stream", [{"id": "a"}]
)
def test_failure_published_to_failure_stream(self):
logs_client = self._publish_with(
{
"aggregate_reports": [],
"failure_reports": [{"id": "f"}],
"smtp_tls_reports": [],
}
)
logs_client.upload.assert_called_once_with(
"dcr-123", "fail-stream", [{"id": "f"}]
)
def test_smtp_tls_published_to_smtp_tls_stream(self):
logs_client = self._publish_with(
{
"aggregate_reports": [],
"failure_reports": [],
"smtp_tls_reports": [{"id": "t"}],
}
)
logs_client.upload.assert_called_once_with(
"dcr-123", "tls-stream", [{"id": "t"}]
)
def test_all_three_published_together(self):
logs_client = self._publish_with(
{
"aggregate_reports": [{"id": "a"}],
"failure_reports": [{"id": "f"}],
"smtp_tls_reports": [{"id": "t"}],
}
)
self.assertEqual(logs_client.upload.call_count, 3)
streams_uploaded = {call.args[1] for call in logs_client.upload.call_args_list}
self.assertEqual(streams_uploaded, {"agg-stream", "fail-stream", "tls-stream"})
def test_save_aggregate_false_skips_aggregate(self):
logs_client = self._publish_with(
{
"aggregate_reports": [{"id": "a"}],
"failure_reports": [],
"smtp_tls_reports": [],
},
save_aggregate=False,
)
logs_client.upload.assert_not_called()
def test_save_failure_false_skips_failure(self):
logs_client = self._publish_with(
{
"aggregate_reports": [],
"failure_reports": [{"id": "f"}],
"smtp_tls_reports": [],
},
save_failure=False,
)
logs_client.upload.assert_not_called()
def test_save_smtp_tls_false_skips_smtp_tls(self):
logs_client = self._publish_with(
{
"aggregate_reports": [],
"failure_reports": [],
"smtp_tls_reports": [{"id": "t"}],
},
save_smtp_tls=False,
)
logs_client.upload.assert_not_called()
def test_empty_results_publishes_nothing(self):
logs_client = self._publish_with(
{
"aggregate_reports": [],
"failure_reports": [],
"smtp_tls_reports": [],
}
)
logs_client.upload.assert_not_called()
def test_missing_aggregate_stream_skips_aggregate(self):
"""If the operator hasn't configured a stream for one of the
report types, the corresponding publish branch is skipped
silently — matching the existing CLI deployment pattern where
a single client object handles whatever streams are set."""
client = LogAnalyticsClient(**_valid_kwargs(dcr_aggregate_stream=""))
with (
patch("parsedmarc.loganalytics.ClientSecretCredential"),
patch("parsedmarc.loganalytics.LogsIngestionClient") as mock_client_cls,
):
mock_logs_client = mock_client_cls.return_value
client.publish_results(
{
"aggregate_reports": [{"id": "a"}],
"failure_reports": [],
"smtp_tls_reports": [],
},
save_aggregate=True,
save_failure=True,
save_smtp_tls=True,
)
mock_logs_client.upload.assert_not_called()
def test_credential_built_from_config(self):
"""ClientSecretCredential is constructed with the conf's three
identity fields; a rename or order shuffle would authenticate as
the wrong principal."""
client = LogAnalyticsClient(**_valid_kwargs())
with (
patch("parsedmarc.loganalytics.ClientSecretCredential") as mock_cred,
patch("parsedmarc.loganalytics.LogsIngestionClient"),
):
client.publish_results(
{
"aggregate_reports": [],
"failure_reports": [],
"smtp_tls_reports": [],
},
save_aggregate=True,
save_failure=True,
save_smtp_tls=True,
)
mock_cred.assert_called_once_with(
tenant_id="tid", client_id="cid", client_secret="csec"
)
def test_logs_ingestion_client_built_from_dce_and_credential(self):
client = LogAnalyticsClient(**_valid_kwargs())
with (
patch("parsedmarc.loganalytics.ClientSecretCredential") as mock_cred,
patch("parsedmarc.loganalytics.LogsIngestionClient") as mock_client_cls,
):
client.publish_results(
{
"aggregate_reports": [],
"failure_reports": [],
"smtp_tls_reports": [],
},
save_aggregate=True,
save_failure=True,
save_smtp_tls=True,
)
mock_client_cls.assert_called_once_with(
"https://dce.example.com", credential=mock_cred.return_value
)
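Review note: the two gates exercised by TestPublishResults can be sketched as a pure function. This is an illustration under stated assumptions, not the module's code. `streams_to_publish` and its dict-shaped `streams` and `flags` arguments are hypothetical.

```python
def streams_to_publish(results, streams, flags):
    """Return (stream, reports) pairs that pass both gates.

    Gate 1: the save_<type> flag is true.
    Gate 2: a stream name is configured for that report type.
    Empty report lists are skipped as well.
    """
    selected = []
    for report_type in ("aggregate", "failure", "smtp_tls"):
        reports = results.get(f"{report_type}_reports", [])
        stream = streams.get(report_type, "")
        enabled = flags.get(f"save_{report_type}", False)
        if reports and stream and enabled:
            selected.append((stream, reports))
    return selected


results = {
    "aggregate_reports": [{"id": "a"}],
    "failure_reports": [{"id": "f"}],
    "smtp_tls_reports": [],
}
streams = {"aggregate": "agg-stream", "failure": "", "smtp_tls": "tls-stream"}
flags = {"save_aggregate": True, "save_failure": True, "save_smtp_tls": True}
selected = streams_to_publish(results, streams, flags)
```

Here the failure reports are dropped because their stream is unset, and the SMTP TLS branch is skipped because there is nothing to send.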
if __name__ == "__main__":
+886
@@ -0,0 +1,886 @@
"""Tests for parsedmarc.opensearch
Mocks at the opensearch-dsl SDK boundary (connections.create_connection,
Index, Search, Document.save) so the tests verify the parsedmarc-side
transformation logic — document construction, index naming, deduplication
queries, error wrapping — without needing a running OpenSearch cluster.
"""
import unittest
from unittest.mock import MagicMock, call, patch
import parsedmarc.opensearch as opensearch_module
from parsedmarc import InvalidFailureReport
from parsedmarc.opensearch import (
AlreadySaved,
OpenSearchError,
create_indexes,
migrate_indexes,
save_aggregate_report_to_opensearch,
save_failure_report_to_opensearch,
save_smtp_tls_report_to_opensearch,
set_hosts,
)
# ---------------------------------------------------------------------------
# Sample report fixtures
# ---------------------------------------------------------------------------
def _aggregate_report(**overrides):
base = {
"xml_schema": "draft",
"xml_namespace": None,
"report_metadata": {
"org_name": "TestOrg",
"org_email": "dmarc@example.com",
"org_extra_contact_info": None,
"report_id": "agg-1",
"begin_date": "2024-01-15 00:00:00",
"end_date": "2024-01-16 00:00:00",
"timespan_requires_normalization": False,
"original_timespan_seconds": 86400,
"errors": [],
"generator": "TestGen/1.0",
},
"policy_published": {
"domain": "example.com",
"adkim": "r",
"aspf": "r",
"p": "none",
"sp": "none",
"pct": None,
"fo": None,
"np": "reject",
"testing": "n",
"discovery_method": "treewalk",
},
"records": [
{
"interval_begin": "2024-01-15 00:00:00",
"interval_end": "2024-01-16 00:00:00",
"normalized_timespan": False,
"source": {
"ip_address": "192.0.2.1",
"country": "US",
"reverse_dns": None,
"base_domain": None,
"name": None,
"type": None,
"asn": 64496,
"as_name": "Example AS",
"as_domain": "example.net",
},
"count": 4,
"alignment": {"spf": True, "dkim": True, "dmarc": True},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
"policy_override_reasons": [
{"type": "local_policy", "comment": "approved"}
],
},
"identifiers": {
"header_from": "example.com",
"envelope_from": "example.com",
"envelope_to": "rcpt@example.com",
},
"auth_results": {
"dkim": [
{
"domain": "example.com",
"selector": "s",
"result": "pass",
"human_result": None,
}
],
"spf": [
{
"domain": "example.com",
"scope": "mfrom",
"result": "pass",
"human_result": None,
}
],
},
}
],
}
base.update(overrides)
return base
def _failure_report(**overrides):
base = {
"feedback_type": "auth-failure",
"user_agent": "test/1.0",
"version": "1",
"original_envelope_id": None,
"original_mail_from": "x@example.com",
"original_rcpt_to": None,
"arrival_date": "Thu, 1 Jan 2024 00:00:00 +0000",
"arrival_date_utc": "2024-01-01 00:00:00",
"authentication_results": None,
"delivery_result": "other",
"auth_failure": ["dmarc"],
"authentication_mechanisms": [],
"dkim_domain": None,
"reported_domain": "example.com",
"sample_headers_only": True,
"source": {
"ip_address": "192.0.2.5",
"country": "US",
"reverse_dns": None,
"base_domain": None,
"name": None,
"type": None,
"asn": 64496,
"as_name": "Example AS",
"as_domain": "example.net",
},
"sample": "raw",
"parsed_sample": {
"headers": {
# mailparser emits headers as [[display_name, address]]
# lists; an empty display becomes [["", address]].
"From": [["Sender Name", "sender@example.com"]],
"To": [["", "rcpt@example.com"]],
"Subject": "Test",
},
"subject": "Test",
"filename_safe_subject": "Test",
"body": "body",
"date": "Thu, 1 Jan 2024 00:00:00 +0000",
"to": [{"display_name": None, "address": "rcpt@example.com"}],
"reply_to": [],
"cc": [],
"bcc": [],
"attachments": [],
},
}
base.update(overrides)
return base
def _smtp_tls_report(**overrides):
base = {
"organization_name": "TestOrg",
"begin_date": "2024-02-03T00:00:00Z",
"end_date": "2024-02-04T00:00:00Z",
"contact_info": "tls@example.com",
"report_id": "tls-1",
"policies": [
{
"policy_domain": "example.com",
"policy_type": "sts",
"successful_session_count": 100,
"failed_session_count": 1,
"policy_strings": ["version: STSv1"],
"mx_host_patterns": ["*.example.com"],
"failure_details": [
{
"result_type": "certificate-expired",
"failed_session_count": 1,
"receiving_mx_hostname": "mx.example.com",
"sending_mta_ip": "10.0.0.1",
}
],
}
],
}
base.update(overrides)
return base
def _empty_search():
"""A Search() mock whose .execute() returns an empty hit list."""
search = MagicMock()
search.execute.return_value = []
return search
def _populated_search():
"""A Search() mock whose .execute() returns a non-empty hit list."""
search = MagicMock()
search.execute.return_value = [MagicMock()]
return search
# ---------------------------------------------------------------------------
# set_hosts: connection-parameter assembly
# ---------------------------------------------------------------------------
class TestSetHosts(unittest.TestCase):
"""Verify the conn_params dict handed to opensearch-dsl
matches each documented option. Each branch corresponds to a
real-world deployment shape (TLS, basic auth, API key, custom CA)."""
def test_single_host_string_normalized_to_list(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("https://es:9200")
kwargs = mock_conn.call_args.kwargs
self.assertEqual(kwargs["hosts"], ["https://es:9200"])
def test_host_list_preserved(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts(["es1:9200", "es2:9200"])
kwargs = mock_conn.call_args.kwargs
self.assertEqual(kwargs["hosts"], ["es1:9200", "es2:9200"])
def test_timeout_default_60s(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("es:9200")
self.assertEqual(mock_conn.call_args.kwargs["timeout"], 60.0)
def test_timeout_custom(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("es:9200", timeout=30.0)
self.assertEqual(mock_conn.call_args.kwargs["timeout"], 30.0)
def test_use_ssl_enables_verify_by_default(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("es:9200", use_ssl=True)
kwargs = mock_conn.call_args.kwargs
self.assertEqual(kwargs["use_ssl"], True)
self.assertEqual(kwargs["verify_certs"], True)
self.assertNotIn("ca_certs", kwargs)
def test_use_ssl_with_custom_ca(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("es:9200", use_ssl=True, ssl_cert_path="/etc/ca.pem")
kwargs = mock_conn.call_args.kwargs
self.assertEqual(kwargs["ca_certs"], "/etc/ca.pem")
def test_skip_certificate_verification_sets_verify_false(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("es:9200", use_ssl=True, skip_certificate_verification=True)
self.assertEqual(mock_conn.call_args.kwargs["verify_certs"], False)
def test_username_password_sets_http_auth(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("es:9200", username="u", password="p")
self.assertEqual(mock_conn.call_args.kwargs["http_auth"], ("u", "p"))
def test_username_without_password_not_set(self):
"""Half-configured auth is suspicious enough not to send."""
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("es:9200", username="u")
self.assertNotIn("http_auth", mock_conn.call_args.kwargs)
def test_api_key_set(self):
with patch("parsedmarc.opensearch.connections.create_connection") as mock_conn:
set_hosts("es:9200", api_key="base64key==")
self.assertEqual(mock_conn.call_args.kwargs["api_key"], "base64key==")
def test_awssigv4_requires_aws_region(self):
"""SigV4 needs an AWS region to sign requests; missing it
must fail loudly, not silently fall back to unsigned auth."""
with self.assertRaises(OpenSearchError) as ctx:
set_hosts("es.amazonaws.com:443", auth_type="awssigv4")
self.assertIn("aws_region", str(ctx.exception))
def test_awssigv4_uses_boto3_credentials_and_signer(self):
"""SigV4 path resolves AWS credentials via boto3 and wires
an AWSV4SignerAuth into the connection params, plus the
RequestsHttpConnection class required by the signer."""
with (
patch("parsedmarc.opensearch.boto3.Session") as mock_session,
patch("parsedmarc.opensearch.AWSV4SignerAuth") as mock_signer,
patch("parsedmarc.opensearch.connections.create_connection") as mock_conn,
):
mock_session.return_value.get_credentials.return_value = MagicMock()
set_hosts(
"es.amazonaws.com:443",
auth_type="awssigv4",
aws_region="us-west-2",
)
kwargs = mock_conn.call_args.kwargs
self.assertIs(kwargs["http_auth"], mock_signer.return_value)
mock_signer.assert_called_once()
# connection_class must be set so opensearch-py uses the
# requests-based transport AWSV4SignerAuth requires.
self.assertIn("connection_class", kwargs)
def test_awssigv4_no_credentials_raises(self):
"""If boto3 can't find credentials, fail with a clear error
rather than letting OpenSearch raise an opaque auth error later."""
with patch("parsedmarc.opensearch.boto3.Session") as mock_session:
mock_session.return_value.get_credentials.return_value = None
with self.assertRaises(OpenSearchError) as ctx:
set_hosts(
"es.amazonaws.com:443",
auth_type="awssigv4",
aws_region="us-west-2",
)
self.assertIn("credentials", str(ctx.exception).lower())
def test_unsupported_auth_type_raises(self):
with self.assertRaises(OpenSearchError) as ctx:
set_hosts("es:9200", auth_type="kerberos")
self.assertIn("Unsupported", str(ctx.exception))
self.assertIn("kerberos", str(ctx.exception))
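Review note: the non-SigV4 cases in TestSetHosts describe how the connection parameters are assembled. The sketch below reconstructs that assembly purely from the assertions above. `build_conn_params` is a hypothetical helper, not a function in parsedmarc.opensearch, and the real `set_hosts` may differ in detail.

```python
def build_conn_params(
    hosts,
    use_ssl=False,
    ssl_cert_path=None,
    username=None,
    password=None,
    api_key=None,
    timeout=60.0,
    skip_certificate_verification=False,
):
    """Assemble a connection-parameter dict the way the tests expect."""
    if not isinstance(hosts, list):
        hosts = [hosts]  # a single host string becomes a one-element list
    params = {"hosts": hosts, "timeout": timeout}
    if use_ssl:
        params["use_ssl"] = True
        params["verify_certs"] = not skip_certificate_verification
        if ssl_cert_path:
            params["ca_certs"] = ssl_cert_path
    if username and password:
        # Both halves are required; a lone username sets nothing.
        params["http_auth"] = (username, password)
    if api_key:
        params["api_key"] = api_key
    return params


tls_params = build_conn_params("es:9200", use_ssl=True, ssl_cert_path="/etc/ca.pem")
half_auth = build_conn_params("es:9200", username="u")
```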
# ---------------------------------------------------------------------------
# create_indexes
# ---------------------------------------------------------------------------
class TestCreateIndexes(unittest.TestCase):
def test_creates_missing_index_with_default_settings(self):
with patch("parsedmarc.opensearch.Index") as mock_index_cls:
mock_index = mock_index_cls.return_value
mock_index.exists.return_value = False
create_indexes(["dmarc_aggregate-2024-01-15"])
mock_index.settings.assert_called_once_with(
number_of_shards=1, number_of_replicas=0
)
mock_index.create.assert_called_once()
def test_creates_with_custom_settings(self):
with patch("parsedmarc.opensearch.Index") as mock_index_cls:
mock_index = mock_index_cls.return_value
mock_index.exists.return_value = False
create_indexes(
["idx"], settings={"number_of_shards": 3, "refresh_interval": "5s"}
)
mock_index.settings.assert_called_once_with(
number_of_shards=3, refresh_interval="5s"
)
def test_skips_existing_index(self):
with patch("parsedmarc.opensearch.Index") as mock_index_cls:
mock_index = mock_index_cls.return_value
mock_index.exists.return_value = True
create_indexes(["idx"])
mock_index.create.assert_not_called()
def test_wraps_sdk_error(self):
with patch("parsedmarc.opensearch.Index") as mock_index_cls:
mock_index_cls.return_value.exists.side_effect = RuntimeError(
"cluster down"
)
with self.assertRaises(OpenSearchError) as ctx:
create_indexes(["idx"])
self.assertIn("cluster down", str(ctx.exception))
# ---------------------------------------------------------------------------
# migrate_indexes
# ---------------------------------------------------------------------------
class TestMigrateIndexes(unittest.TestCase):
"""The legacy `published_policy.fo` field was mapped as `long` in
older indexes. migrate_indexes detects that and rebuilds the index
with the text/keyword shape. The branch is gnarly; a regression
would silently leave old data un-migrated."""
def test_no_indexes_is_noop(self):
migrate_indexes() # Should not raise
def test_skips_non_existent_index(self):
with patch("parsedmarc.opensearch.Index") as mock_index_cls:
mock_index_cls.return_value.exists.return_value = False
migrate_indexes(aggregate_indexes=["missing"])
# exists() returned False — no field_mapping fetch.
mock_index_cls.return_value.get_field_mapping.assert_not_called()
def test_skips_when_doc_mapping_absent(self):
"""An index that has 'fo' but not under the 'doc' type
(e.g., empty index with default mapping) is left alone."""
with patch("parsedmarc.opensearch.Index") as mock_index_cls:
idx = mock_index_cls.return_value
idx.exists.return_value = True
idx.get_field_mapping.return_value = {"some_key": {"mappings": {}}}
with patch("parsedmarc.opensearch.reindex") as mock_reindex:
migrate_indexes(aggregate_indexes=["dmarc_aggregate-2023-01-01"])
mock_reindex.assert_not_called()
def test_migrates_when_fo_is_long(self):
"""The actual migration path: when fo is mapped as 'long',
a v2 index is created with the corrected mapping, data is
reindexed, and the old index is deleted."""
with (
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch("parsedmarc.opensearch.reindex") as mock_reindex,
patch("parsedmarc.opensearch.connections.get_connection") as mock_get_conn,
):
idx = mock_index_cls.return_value
idx.exists.return_value = True
idx.get_field_mapping.return_value = {
"dmarc_aggregate-2023-01-01": {
"mappings": {
"doc": {
"published_policy.fo": {"mapping": {"fo": {"type": "long"}}}
}
}
}
}
migrate_indexes(aggregate_indexes=["dmarc_aggregate-2023-01-01"])
# reindex called from old → new (v2) index.
mock_reindex.assert_called_once()
# connections.get_connection consulted to get the OpenSearch client.
mock_get_conn.assert_called_once()
def test_skips_when_fo_already_text(self):
with (
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch("parsedmarc.opensearch.reindex") as mock_reindex,
):
idx = mock_index_cls.return_value
idx.exists.return_value = True
idx.get_field_mapping.return_value = {
"dmarc_aggregate-2024-01-01": {
"mappings": {
"doc": {
"published_policy.fo": {"mapping": {"fo": {"type": "text"}}}
}
}
}
}
migrate_indexes(aggregate_indexes=["dmarc_aggregate-2024-01-01"])
mock_reindex.assert_not_called()
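Review note: the fixtures in TestMigrateIndexes share one nested mapping shape. A minimal sketch of the detection step they exercise follows. It assumes that shape and uses a hypothetical helper name, `fo_is_mapped_as_long`. It does not reproduce the reindex or delete steps.

```python
def fo_is_mapped_as_long(field_mapping, doc_type="doc", field="published_policy.fo"):
    """Return True if any index in the mapping response maps fo as 'long'."""
    for index_mapping in field_mapping.values():
        doc_mapping = index_mapping.get("mappings", {}).get(doc_type)
        if doc_mapping is None:
            continue  # no 'doc' mapping: leave the index alone
        fo_type = (
            doc_mapping.get(field, {}).get("mapping", {}).get("fo", {}).get("type")
        )
        if fo_type == "long":
            return True
    return False


def _mapping(index, fo_type):
    """Build a get_field_mapping-style response for one index."""
    return {
        index: {
            "mappings": {
                "doc": {"published_policy.fo": {"mapping": {"fo": {"type": fo_type}}}}
            }
        }
    }


legacy = fo_is_mapped_as_long(_mapping("dmarc_aggregate-2023-01-01", "long"))
current = fo_is_mapped_as_long(_mapping("dmarc_aggregate-2024-01-01", "text"))
no_doc = fo_is_mapped_as_long({"some_key": {"mappings": {}}})
```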
# ---------------------------------------------------------------------------
# save_aggregate_report_to_opensearch
# ---------------------------------------------------------------------------
class TestSaveAggregateReport(unittest.TestCase):
"""The aggregate-report save fans out across multiple SDK calls:
Search (for dedup), Index.create (for the daily/monthly index),
Document.save. Each test patches the boundary it needs and
leaves the rest alone."""
def _patches(self, search_factory=_empty_search):
return [
patch("parsedmarc.opensearch.Search", return_value=search_factory()),
patch(
"parsedmarc.opensearch.Index",
return_value=MagicMock(exists=MagicMock(return_value=True)),
),
patch.object(opensearch_module._AggregateReportDoc, "save"),
]
def test_save_emits_one_document_per_record(self):
report = _aggregate_report()
report["records"].append(report["records"][0].copy())
patches = self._patches()
with patches[0], patches[1], patches[2] as mock_save:
save_aggregate_report_to_opensearch(report)
# Two records → two saves.
self.assertEqual(mock_save.call_count, 2)
def test_already_saved_raises_when_search_returns_hit(self):
"""The dedup query is the only thing preventing
double-indexing on re-run. A regression would silently
re-save reports, inflating dashboard counts."""
with (
patch("parsedmarc.opensearch.Search", return_value=_populated_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._AggregateReportDoc, "save") as mock_save,
):
with self.assertRaises(AlreadySaved):
save_aggregate_report_to_opensearch(_aggregate_report())
mock_save.assert_not_called()
def test_search_exception_wraps_to_opensearch_error(self):
bad_search = MagicMock()
bad_search.execute.side_effect = RuntimeError("network")
with (
patch("parsedmarc.opensearch.Search", return_value=bad_search),
patch("parsedmarc.opensearch.Index"),
):
with self.assertRaises(OpenSearchError) as ctx:
save_aggregate_report_to_opensearch(_aggregate_report())
self.assertIn("network", str(ctx.exception))
def test_save_exception_wraps_to_opensearch_error(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(
opensearch_module._AggregateReportDoc,
"save",
side_effect=RuntimeError("disk"),
),
):
with self.assertRaises(OpenSearchError) as ctx:
save_aggregate_report_to_opensearch(_aggregate_report())
self.assertIn("disk", str(ctx.exception))
def test_index_name_uses_daily_format_by_default(self):
"""Index naming: dmarc_aggregate-YYYY-MM-DD by default."""
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch.object(opensearch_module._AggregateReportDoc, "save"),
):
mock_index_cls.return_value.exists.return_value = True
save_aggregate_report_to_opensearch(_aggregate_report())
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("dmarc_aggregate-2024-01-15", index_calls)
def test_index_name_uses_monthly_format_when_flag_set(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch.object(opensearch_module._AggregateReportDoc, "save"),
):
mock_index_cls.return_value.exists.return_value = True
save_aggregate_report_to_opensearch(
_aggregate_report(), monthly_indexes=True
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("dmarc_aggregate-2024-01", index_calls)
def test_index_name_honours_suffix_and_prefix(self):
"""Prefix/suffix support multi-tenant setups where one ES
cluster serves several DMARC owners."""
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch.object(opensearch_module._AggregateReportDoc, "save"),
):
mock_index_cls.return_value.exists.return_value = True
save_aggregate_report_to_opensearch(
_aggregate_report(),
index_suffix="tenant_a",
index_prefix="customer1_",
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("customer1_dmarc_aggregate_tenant_a-2024-01-15", index_calls)
def test_dedup_search_pattern_uses_suffix_wildcard(self):
"""Existing-report search uses '*' so it matches both
daily and monthly index buckets."""
with (
patch("parsedmarc.opensearch.Search") as mock_search_cls,
patch(
"parsedmarc.opensearch.Index",
return_value=MagicMock(exists=MagicMock(return_value=True)),
),
patch.object(opensearch_module._AggregateReportDoc, "save"),
):
mock_search_cls.return_value.execute.return_value = []
save_aggregate_report_to_opensearch(
_aggregate_report(), index_suffix="tenant_a", index_prefix="cust_"
)
# Search index pattern wraps prefix+name+suffix with trailing wildcard.
search_index = mock_search_cls.call_args.kwargs["index"]
self.assertIn("cust_dmarc_aggregate_tenant_a*", search_index)
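Review note: the index-naming tests above imply a simple composition rule: prefix, base name, optional suffix, then a daily or monthly date. The sketch below is inferred from the expected strings in those tests only. `index_name` and `search_pattern` are hypothetical helpers, not parsedmarc functions.

```python
from datetime import datetime


def _base_name(base, prefix=None, suffix=None):
    """Join prefix + base + '_' + suffix, skipping unset parts."""
    name = base
    if suffix:
        name = f"{name}_{suffix}"
    if prefix:
        name = f"{prefix}{name}"  # the prefix carries its own separator
    return name


def index_name(begin_date, base="dmarc_aggregate", prefix=None, suffix=None, monthly=False):
    """Build the dated index name for a report's begin_date string."""
    date_format = "%Y-%m" if monthly else "%Y-%m-%d"
    begin = datetime.strptime(begin_date, "%Y-%m-%d %H:%M:%S")
    return f"{_base_name(base, prefix, suffix)}-{begin.strftime(date_format)}"


def search_pattern(base="dmarc_aggregate", prefix=None, suffix=None):
    """Wildcard pattern that matches both daily and monthly indexes."""
    return f"{_base_name(base, prefix, suffix)}*"


daily = index_name("2024-01-15 00:00:00")
monthly = index_name("2024-01-15 00:00:00", monthly=True)
tenant = index_name("2024-01-15 00:00:00", prefix="customer1_", suffix="tenant_a")
pattern = search_pattern(prefix="cust_", suffix="tenant_a")
```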
# ---------------------------------------------------------------------------
# save_failure_report_to_opensearch
# ---------------------------------------------------------------------------
class TestSaveFailureReport(unittest.TestCase):
def test_save_emits_one_document(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._FailureReportDoc, "save") as mock_save,
):
save_failure_report_to_opensearch(_failure_report())
mock_save.assert_called_once()
def test_already_saved_raises_on_dedup_hit(self):
"""Failure-report dedup uses arrival_date + From/To/Subject
from the parsed sample. A hit means we've already indexed
this exact failure sample."""
with (
patch("parsedmarc.opensearch.Search", return_value=_populated_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._FailureReportDoc, "save") as mock_save,
):
with self.assertRaises(AlreadySaved):
save_failure_report_to_opensearch(_failure_report())
mock_save.assert_not_called()
def test_save_exception_wraps_to_opensearch_error(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(
opensearch_module._FailureReportDoc,
"save",
side_effect=RuntimeError("disk"),
),
):
with self.assertRaises(OpenSearchError) as ctx:
save_failure_report_to_opensearch(_failure_report())
self.assertIn("disk", str(ctx.exception))
def test_keyerror_wraps_to_invalid_failure_report(self):
"""A malformed failure report (missing a required field) is
surfaced as InvalidFailureReport so the caller can route it
differently from infra errors."""
report = _failure_report()
del report["feedback_type"]
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._FailureReportDoc, "save"),
):
with self.assertRaises(InvalidFailureReport):
save_failure_report_to_opensearch(report)
def test_index_dedup_pattern_searches_both_old_and_new_names(self):
"""The split-PR rename forensic→failure left existing data
in dmarc_forensic*; the dedup search must check both names
so re-runs don't double-index."""
with (
patch("parsedmarc.opensearch.Search") as mock_search_cls,
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._FailureReportDoc, "save"),
):
mock_search_cls.return_value.execute.return_value = []
save_failure_report_to_opensearch(_failure_report())
search_index = mock_search_cls.call_args.kwargs["index"]
self.assertIn("dmarc_failure*", search_index)
self.assertIn("dmarc_forensic*", search_index)
def test_index_name_uses_arrival_date_for_monthly_partition(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch.object(opensearch_module._FailureReportDoc, "save"),
):
save_failure_report_to_opensearch(_failure_report(), monthly_indexes=True)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("dmarc_failure-2024-01", index_calls)
def test_failure_search_index_with_suffix_and_prefix(self):
"""When both suffix and prefix are set, the dedup search
pattern joins them onto BOTH dmarc_failure* and
dmarc_forensic* (the rename back-compat)."""
with (
patch("parsedmarc.opensearch.Search") as mock_search_cls,
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._FailureReportDoc, "save"),
):
mock_search_cls.return_value.execute.return_value = []
save_failure_report_to_opensearch(
_failure_report(),
index_suffix="tenant_a",
index_prefix="cust_",
)
search_index = mock_search_cls.call_args.kwargs["index"]
self.assertIn("cust_dmarc_failure_tenant_a*", search_index)
self.assertIn("cust_dmarc_forensic_tenant_a*", search_index)
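The two dedup-pattern tests can be read as: build one wildcard pattern per index base name, old and new, and search them together. A rough sketch, assuming a hypothetical `failure_search_pattern` helper and a comma-joined multi-index string (the tests only require that both patterns appear in the `index` argument, so the join format is an assumption):

```python
from typing import Optional

# Current name first, then the pre-rename name kept for back-compat.
FAILURE_INDEX_BASES = ("dmarc_failure", "dmarc_forensic")


def failure_search_pattern(
    prefix: Optional[str] = None, suffix: Optional[str] = None
) -> str:
    """Hypothetical helper: one wildcard pattern per base name."""
    patterns = []
    for base in FAILURE_INDEX_BASES:
        name = base
        if suffix:
            name = f"{name}_{suffix}"
        if prefix:
            name = f"{prefix}{name}"
        # The trailing wildcard matches daily and monthly buckets alike.
        patterns.append(f"{name}*")
    return ",".join(patterns)


plain = failure_search_pattern()
tenant = failure_search_pattern(prefix="cust_", suffix="tenant_a")
```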
def test_failure_index_honours_suffix_and_prefix(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch.object(opensearch_module._FailureReportDoc, "save"),
):
save_failure_report_to_opensearch(
_failure_report(),
index_suffix="tenant_a",
index_prefix="cust_",
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("cust_dmarc_failure_tenant_a-2024-01-01", index_calls)
def test_from_header_with_empty_display_name(self):
"""When the From display name is empty, the code uses the
address alone (covers the early-return branch in the
display-name handling)."""
report = _failure_report()
report["parsed_sample"]["headers"]["From"] = [["", "sender@example.com"]]
report["parsed_sample"]["headers"]["To"] = [["", "rcpt@example.com"]]
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._FailureReportDoc, "save") as mock_save,
):
save_failure_report_to_opensearch(report)
mock_save.assert_called_once()
def test_to_header_with_non_empty_display_joins_with_brackets(self):
"""The other branch: non-empty display joins display+addr
with " <" and appends ">", e.g. 'RT <rcpt@example.com>'."""
report = _failure_report()
report["parsed_sample"]["headers"]["To"] = [["RT", "rcpt@example.com"]]
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._FailureReportDoc, "save") as mock_save,
):
save_failure_report_to_opensearch(report)
mock_save.assert_called_once()
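The two display-name branches covered above amount to the following. This is a sketch only; `format_address` is a hypothetical name, not the function used in parsedmarc.opensearch:

```python
def format_address(display_name: str, address: str) -> str:
    """Hypothetical helper showing the two branches the tests cover."""
    if not display_name:
        # Empty display name: use the bare address.
        return address
    # Non-empty display name: join with " <" and close with ">".
    return f"{display_name} <{address}>"


bare = format_address("", "sender@example.com")
named = format_address("RT", "rcpt@example.com")
```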
def test_sample_address_lists_indexed_for_reply_to_cc_bcc_attachments(self):
"""A failure report sample can carry reply_to / cc / bcc /
attachments. Each populates a nested InnerDoc on the sample —
if the add_* helpers regress, those nested docs would be
silently empty in OpenSearch."""
report = _failure_report()
report["parsed_sample"]["reply_to"] = [
{"display_name": "RT", "address": "rt@example.com"}
]
report["parsed_sample"]["cc"] = [
{"display_name": "CC", "address": "cc@example.com"}
]
report["parsed_sample"]["bcc"] = [
{"display_name": "", "address": "bcc@example.com"}
]
report["parsed_sample"]["attachments"] = [
{
"filename": "a.pdf",
"mail_content_type": "application/pdf",
"sha256": "deadbeef",
}
]
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._FailureReportDoc, "save") as mock_save,
):
save_failure_report_to_opensearch(report)
mock_save.assert_called_once()
# ---------------------------------------------------------------------------
# save_smtp_tls_report_to_opensearch
# ---------------------------------------------------------------------------
class TestSaveSmtpTlsReport(unittest.TestCase):
def test_save_emits_one_document(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._SMTPTLSReportDoc, "save") as mock_save,
):
save_smtp_tls_report_to_opensearch(_smtp_tls_report())
mock_save.assert_called_once()
def test_already_saved_raises_on_dedup_hit(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_populated_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._SMTPTLSReportDoc, "save") as mock_save,
):
with self.assertRaises(AlreadySaved):
save_smtp_tls_report_to_opensearch(_smtp_tls_report())
mock_save.assert_not_called()
def test_search_exception_wraps_to_opensearch_error(self):
bad = MagicMock()
bad.execute.side_effect = RuntimeError("network")
with (
patch("parsedmarc.opensearch.Search", return_value=bad),
patch("parsedmarc.opensearch.Index"),
):
with self.assertRaises(OpenSearchError):
save_smtp_tls_report_to_opensearch(_smtp_tls_report())
def test_save_exception_wraps_to_opensearch_error(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(
opensearch_module._SMTPTLSReportDoc,
"save",
side_effect=RuntimeError("disk"),
),
):
with self.assertRaises(OpenSearchError):
save_smtp_tls_report_to_opensearch(_smtp_tls_report())
def test_index_name_uses_begin_date_for_monthly_partition(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch.object(opensearch_module._SMTPTLSReportDoc, "save"),
):
save_smtp_tls_report_to_opensearch(_smtp_tls_report(), monthly_indexes=True)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("smtp_tls-2024-02", index_calls)
def test_index_name_honours_suffix_and_prefix(self):
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index") as mock_index_cls,
patch.object(opensearch_module._SMTPTLSReportDoc, "save"),
):
save_smtp_tls_report_to_opensearch(
_smtp_tls_report(), index_suffix="t1", index_prefix="cust_"
)
index_calls = [c.args[0] for c in mock_index_cls.call_args_list]
self.assertIn("cust_smtp_tls_t1-2024-02-03", index_calls)
def test_policy_without_strings_or_mx_patterns(self):
"""policy_strings / mx_host_patterns are optional in the
report shape — verify the branch where they're absent."""
report = _smtp_tls_report()
for policy in report["policies"]:
policy.pop("policy_strings", None)
policy.pop("mx_host_patterns", None)
policy.pop("failure_details", None)
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._SMTPTLSReportDoc, "save") as mock_save,
):
save_smtp_tls_report_to_opensearch(report)
mock_save.assert_called_once()
def test_failure_details_all_optional_fields_populated(self):
"""Exercise every optional field in failure_details so the
full set of `if "x" in failure_detail` branches runs."""
report = _smtp_tls_report()
report["policies"][0]["failure_details"] = [
{
"result_type": "certificate-expired",
"failed_session_count": 1,
"receiving_mx_hostname": "mx.example.com",
"additional_information_uri": "https://example.com/why",
"failure_reason_code": "ERR_CERT",
"ip_address": "10.0.0.5",
"receiving_ip": "10.0.0.2",
"receiving_mx_helo": "mx.helo.example.com",
"sending_mta_ip": "10.0.0.1",
}
]
with (
patch("parsedmarc.opensearch.Search", return_value=_empty_search()),
patch("parsedmarc.opensearch.Index"),
patch.object(opensearch_module._SMTPTLSReportDoc, "save") as mock_save,
):
save_smtp_tls_report_to_opensearch(report)
mock_save.assert_called_once()
class TestBackwardCompatAlias(unittest.TestCase):
def test_save_forensic_alias_points_to_save_failure(self):
self.assertIs(
opensearch_module.save_forensic_report_to_opensearch,
opensearch_module.save_failure_report_to_opensearch,
)
def test_forensic_doc_alias_points_to_failure_doc(self):
self.assertIs(
opensearch_module._ForensicReportDoc, opensearch_module._FailureReportDoc
)
self.assertIs(
opensearch_module._ForensicSampleDoc, opensearch_module._FailureSampleDoc
)
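The alias tests rely on `assertIs`, which only passes when the old name is bound to the very same object as the new one, not to a wrapper. A minimal illustration with a made-up class (`ReportClient` and its methods are hypothetical):

```python
class ReportClient:
    """Made-up class used only to illustrate a rename with an alias."""

    def save_failure_report(self, report: dict) -> str:
        return f"saved {report['report_id']}"

    # Old name kept as a class-level alias to the same function object.
    save_forensic_report = save_failure_report


same_object = ReportClient.save_forensic_report is ReportClient.save_failure_report
result = ReportClient().save_forensic_report({"report_id": "fail-789"})
```

A wrapper function that merely calls the new method would behave the same at runtime but would fail an identity check.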
# Silence unused-import lint in the test module preamble.
_ = call
if __name__ == "__main__":
unittest.main(verbosity=2)
+205 -8
@@ -1,18 +1,215 @@
"""Tests for parsedmarc.s3"""
import json
import unittest
from unittest.mock import MagicMock, patch
from parsedmarc.s3 import S3Client
def _sample_aggregate_report():
"""Minimal aggregate report shape used by S3Client.save_*_to_s3."""
return {
"report_metadata": {
"org_name": "example.com",
"org_email": "dmarc@example.com",
"report_id": "agg-123",
"begin_date": "2024-01-15 00:00:00",
"end_date": "2024-01-16 00:00:00",
# not in S3Client.metadata_keys; should NOT appear on the S3 object
"errors": [],
},
"policy_published": {"domain": "example.com", "p": "none"},
"records": [],
}
def _sample_smtp_tls_report():
"""Minimal SMTP TLS report shape as parse_smtp_tls_report_json
produces it — flat, with ISO-string begin_date / end_date pulled
directly from the report JSON."""
return {
"organization_name": "example.com",
"begin_date": "2024-02-03T00:00:00Z",
"end_date": "2024-02-04T00:00:00Z",
"report_id": "tls-456",
"contact_info": "tls-admin@example.com",
"policies": [],
}
class TestS3ClientInit(unittest.TestCase):
"""S3Client.__init__ delegates to boto3.resource() with the supplied
credentials and endpoint. A regression in argument names or order
would silently send reports to the wrong bucket or auth as the wrong
principal."""
def test_init_forwards_credentials_to_boto3(self):
with patch("parsedmarc.s3.boto3.resource") as mock_resource:
S3Client(
bucket_name="my-bucket",
bucket_path="dmarc",
region_name="us-east-1",
endpoint_url="https://s3.example.com",
access_key_id="AKIA-test",
secret_access_key="secret-test",
)
mock_resource.assert_called_once_with(
"s3",
region_name="us-east-1",
endpoint_url="https://s3.example.com",
aws_access_key_id="AKIA-test",
aws_secret_access_key="secret-test",
)
def test_init_caches_bucket_handle(self):
"""self.bucket is the Bucket(bucket_name) on the boto3 resource,
so subsequent save_* calls go to the right bucket."""
with patch("parsedmarc.s3.boto3.resource") as mock_resource:
mock_resource.return_value.Bucket.return_value = "bucket-handle"
client = S3Client(
bucket_name="my-bucket",
bucket_path="dmarc",
region_name="us-east-1",
endpoint_url="https://s3.example.com",
access_key_id="k",
secret_access_key="s",
)
mock_resource.return_value.Bucket.assert_called_once_with("my-bucket")
self.assertEqual(client.bucket, "bucket-handle")
class TestS3ClientSavePathsAndMetadata(unittest.TestCase):
"""The S3 key is built from the report's begin_date and report_id.
Wrong format = unfindable reports; wrong metadata filtering = secret
leakage onto the S3 object."""
def _client_with_mock_bucket(self):
with patch("parsedmarc.s3.boto3.resource"):
client = S3Client(
bucket_name="b",
bucket_path="dmarc",
region_name="us-east-1",
endpoint_url="https://s3.example.com",
access_key_id="k",
secret_access_key="s",
)
client.bucket = MagicMock()
return client
def test_aggregate_dispatches_with_aggregate_in_key_path(self):
"""save_aggregate_report_to_s3 puts the object under
<bucket_path>/aggregate/year=YYYY/month=MM/day=DD/<report_id>.json."""
client = self._client_with_mock_bucket()
client.save_aggregate_report_to_s3(_sample_aggregate_report())
client.bucket.put_object.assert_called_once()
call = client.bucket.put_object.call_args
self.assertEqual(
call.kwargs["Key"],
"dmarc/aggregate/year=2024/month=01/day=15/agg-123.json",
)
def test_failure_dispatches_with_failure_in_key_path(self):
client = self._client_with_mock_bucket()
report = _sample_aggregate_report()
report["report_metadata"]["report_id"] = "fail-789"
client.save_failure_report_to_s3(report)
key = client.bucket.put_object.call_args.kwargs["Key"]
self.assertEqual(key, "dmarc/failure/year=2024/month=01/day=15/fail-789.json")
def test_smtp_tls_uses_report_begin_date(self):
"""SMTP TLS reports are flat — no report_metadata — and
begin_date is the ISO string produced by parse_smtp_tls_report_json.
The S3 path-builder parses that string into a datetime for the
year=/month=/day= key segments.
Regression: an earlier version assumed ALL reports carried a
report_metadata sub-object, which crashed with KeyError on every
SMTP TLS save. The CLI swallowed the error and only logged it,
so the bug was invisible in production."""
client = self._client_with_mock_bucket()
client.save_smtp_tls_report_to_s3(_sample_smtp_tls_report())
key = client.bucket.put_object.call_args.kwargs["Key"]
self.assertEqual(key, "dmarc/smtp_tls/year=2024/month=02/day=03/tls-456.json")
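The key layout asserted above can be sketched as follows. `build_s3_key` is a hypothetical helper, not the code in parsedmarc/s3.py; it assumes `begin_date` arrives as an ISO-8601 string, as the docstring describes:

```python
from datetime import datetime


def build_s3_key(
    bucket_path: str, report_type: str, begin_date: str, report_id: str
) -> str:
    """Hypothetical helper: derive the partitioned key from a date string."""
    # fromisoformat() only accepts a trailing "Z" from Python 3.11 on,
    # so normalize it to an explicit UTC offset first.
    parsed = datetime.fromisoformat(begin_date.replace("Z", "+00:00"))
    return (
        f"{bucket_path}/{report_type}/"
        f"year={parsed.year}/month={parsed.month:02d}/day={parsed.day:02d}/"
        f"{report_id}.json"
    )


tls_key = build_s3_key("dmarc", "smtp_tls", "2024-02-03T00:00:00Z", "tls-456")
agg_key = build_s3_key("dmarc", "aggregate", "2024-01-15 00:00:00", "agg-123")
```

Parsing the string before reading `.year`, `.month` and `.day` is what avoids the attribute error described in the commit message.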
def test_smtp_tls_metadata_comes_from_flat_report_fields(self):
"""SMTP TLS object metadata is built from the flat report
instead of report_metadata. organization_name is renamed to
org_name (the S3 metadata key) for consistency with DMARC."""
client = self._client_with_mock_bucket()
client.save_smtp_tls_report_to_s3(_sample_smtp_tls_report())
meta = client.bucket.put_object.call_args.kwargs["Metadata"]
self.assertEqual(meta["org_name"], "example.com")
self.assertEqual(meta["report_id"], "tls-456")
self.assertEqual(meta["begin_date"], "2024-02-03T00:00:00Z")
self.assertEqual(meta["end_date"], "2024-02-04T00:00:00Z")
def test_object_body_is_json_serialized_report(self):
client = self._client_with_mock_bucket()
report = _sample_aggregate_report()
client.save_aggregate_report_to_s3(report)
body = client.bucket.put_object.call_args.kwargs["Body"]
# Round-trip the JSON to make sure it actually deserializes and
# carries every top-level key the source report had.
self.assertEqual(json.loads(body), report)
def test_metadata_filtered_to_documented_keys_only(self):
"""report_metadata fields outside `metadata_keys` must not be
attached to the S3 object — they could leak large or sensitive
payloads (errors lists, internal IDs) into object metadata."""
client = self._client_with_mock_bucket()
report = _sample_aggregate_report()
report["report_metadata"]["errors"] = ["a", "b"]
report["report_metadata"]["internal_diag"] = "secret"
client.save_aggregate_report_to_s3(report)
meta = client.bucket.put_object.call_args.kwargs["Metadata"]
self.assertEqual(
set(meta.keys()),
{"org_name", "org_email", "report_id", "begin_date", "end_date"},
)
self.assertNotIn("errors", meta)
self.assertNotIn("internal_diag", meta)
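The allow-list behaviour checked by the last test might look like this. The tuple below copies the five keys from the assertion; the helper name `filter_metadata` and the `str()` conversion are assumptions made for the sketch:

```python
METADATA_KEYS = ("org_name", "org_email", "report_id", "begin_date", "end_date")


def filter_metadata(report_metadata: dict) -> dict:
    """Hypothetical helper: keep only allow-listed keys, as strings."""
    return {
        key: str(value)
        for key, value in report_metadata.items()
        if key in METADATA_KEYS
    }


meta = filter_metadata(
    {
        "org_name": "example.com",
        "report_id": "agg-123",
        "errors": ["a", "b"],
        "internal_diag": "secret",
    }
)
```

An allow-list fails closed: a field added to `report_metadata` later stays off the S3 object until someone adds it to the tuple.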
class TestS3ClientClose(unittest.TestCase):
"""close() must release the underlying boto3 client; a slow leak
here matters for long-running watch-mode processes."""
def test_close_calls_underlying_client_close(self):
with patch("parsedmarc.s3.boto3.resource") as mock_resource:
client = S3Client(
bucket_name="b",
bucket_path="p",
region_name="r",
endpoint_url="https://s3.example.com",
access_key_id="k",
secret_access_key="s",
)
client.close()
mock_resource.return_value.meta.client.close.assert_called_once()
def test_close_swallows_exceptions_from_underlying_client(self):
"""close() is called during shutdown/reload; if boto3 raises
from the close path, we don't want it to propagate and prevent
clean exit. The except is defensive but deliberate."""
with patch("parsedmarc.s3.boto3.resource") as mock_resource:
mock_resource.return_value.meta.client.close.side_effect = RuntimeError(
"boom"
)
client = S3Client(
bucket_name="b",
bucket_path="p",
region_name="r",
endpoint_url="https://s3.example.com",
access_key_id="k",
secret_access_key="s",
)
# Should not raise.
client.close()
class TestS3ClientBackwardCompatAlias(unittest.TestCase):
def test_forensic_alias_points_to_failure_method(self):
self.assertIs(
S3Client.save_forensic_report_to_s3, # type: ignore[attr-defined]
S3Client.save_failure_report_to_s3,
)
+406 -26
@@ -1,44 +1,424 @@
"""Tests for parsedmarc.splunk"""
import json
import unittest
from unittest.mock import MagicMock
from parsedmarc.splunk import HECClient, SplunkError
def _aggregate_report():
return {
"report_metadata": {
"org_name": "TestOrg",
"org_email": "dmarc@example.com",
"report_id": "agg-1",
"begin_date": "2024-01-01 00:00:00",
"end_date": "2024-01-02 00:00:00",
},
"policy_published": {"domain": "example.com", "p": "none"},
"records": [
{
"interval_begin": "2024-01-01 00:00:00",
"interval_end": "2024-01-02 00:00:00",
"normalized_timespan": False,
"source": {
"ip_address": "192.0.2.1",
"country": "US",
"reverse_dns": None,
"base_domain": None,
"name": None,
"type": None,
"asn": 64496,
"as_name": "Example AS",
"as_domain": "example.net",
},
"count": 4,
"alignment": {"spf": True, "dkim": True, "dmarc": True},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
"policy_override_reasons": [],
},
"identifiers": {
"header_from": "example.com",
"envelope_from": "example.com",
"envelope_to": None,
},
"auth_results": {
"dkim": [
{
"domain": "example.com",
"selector": "s",
"result": "pass",
"human_result": None,
}
],
"spf": [
{
"domain": "example.com",
"scope": "mfrom",
"result": "pass",
"human_result": None,
}
],
},
}
],
}
def _failure_report():
return {
"feedback_type": "auth-failure",
"user_agent": "test/1.0",
"version": "1",
"original_envelope_id": None,
"original_mail_from": "x@example.com",
"original_rcpt_to": None,
"arrival_date": "Mon, 1 Jan 2024 00:00:00 +0000",
"arrival_date_utc": "2024-01-01 00:00:00",
"authentication_results": None,
"delivery_result": "other",
"auth_failure": ["dmarc"],
"authentication_mechanisms": [],
"dkim_domain": None,
"reported_domain": "example.com",
"sample_headers_only": True,
"source": {
"ip_address": "192.0.2.5",
"country": "US",
"reverse_dns": None,
"base_domain": None,
"name": None,
"type": None,
"asn": 64496,
"as_name": "Example AS",
"as_domain": "example.net",
},
"sample": "...",
"parsed_sample": {"subject": "Test"},
}
def _smtp_tls_report():
return {
"organization_name": "example.com",
"begin_date": "2024-02-03T00:00:00Z",
"end_date": "2024-02-04T00:00:00Z",
"contact_info": "tls@example.com",
"report_id": "tls-1",
"policies": [
{
"policy_domain": "example.com",
"policy_type": "sts",
"successful_session_count": 100,
"failed_session_count": 0,
}
],
}
def _ok_response():
"""Splunk HEC success response shape: {"code": 0, ...}."""
r = MagicMock()
r.json.return_value = {"code": 0, "text": "Success"}
return r
def _client():
return HECClient(
url="https://splunk.example.com:8088",
access_token="abc-token-uuid",
index="dmarc",
)
class TestHECClientInit(unittest.TestCase):
"""The HEC URL is rebuilt from the user-supplied URL into the
/services/collector/event/1.0 endpoint, and the Authorization
header is set to `Splunk <token>`."""
def test_url_rewritten_to_collector_endpoint(self):
"""A user may supply any URL on the Splunk host; the client
rewrites to the documented HEC path."""
client = HECClient(
url="https://splunk.example.com:8088/some/random/path",
access_token="t",
index="dmarc",
)
self.assertEqual(
client.url, "https://splunk.example.com:8088/services/collector/event/1.0"
)
def test_authorization_header_uses_splunk_prefix(self):
client = HECClient(url="https://h:8088", access_token="my-token", index="dmarc")
self.assertEqual(client.session.headers["Authorization"], "Splunk my-token")
def test_user_agent_header_is_set(self):
client = HECClient(url="https://h:8088", access_token="my-token", index="dmarc")
self.assertIn("parsedmarc", client.session.headers["User-Agent"])
def test_token_with_splunk_prefix_is_normalized(self):
"""If a user pastes `Splunk <token>` from the Splunk UI into
config, the constructor strips the prefix so the resulting
Authorization header isn't `Splunk Splunk <token>`."""
client = HECClient(
url="https://h:8088",
access_token="Splunk abc-token-uuid",
index="dmarc",
)
self.assertEqual(client.access_token, "abc-token-uuid")
def test_token_without_prefix_is_unchanged(self):
"""The lstrip("Splunk ") implementation has character-set
semantics, not prefix semantics: it strips any leading run of
S/p/l/u/n/k/space characters. It happens to work for the
UUID-shaped tokens HEC issues because none of those characters
appear in a UUID's hex character set. A token that starts with
a hex digit or a dash is unchanged."""
client = HECClient(
url="https://h:8088",
access_token="abc-token-uuid",
index="dmarc",
)
self.assertEqual(client.access_token, "abc-token-uuid")
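The difference between character-set and prefix semantics is easy to see in isolation. The snippet below uses only built-in `str` methods (`str.removeprefix` needs Python 3.9 or later); the token values are made up, and nothing here claims to be what parsedmarc.splunk does beyond the `lstrip("Splunk ")` call the docstring names:

```python
def strip_with_lstrip(token: str) -> str:
    # Strips any leading run of the characters S, p, l, u, n, k and space.
    return token.lstrip("Splunk ")


def strip_with_removeprefix(token: str) -> str:
    # Strips the literal prefix "Splunk " once, if present.
    return token.removeprefix("Splunk ")


uuid_like = strip_with_lstrip("Splunk abc-token-uuid")
over_stripped = strip_with_lstrip("Splunk keys-0001")
prefix_only = strip_with_removeprefix("Splunk keys-0001")
no_prefix = strip_with_lstrip("keys-0001")
```

With a token that begins with one of the stripped characters, `lstrip` eats into the token itself (`over_stripped` loses its leading "k"), even when no prefix was present (`no_prefix`).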
def test_common_data_carries_host_source_and_index(self):
"""Splunk events inherit these three top-level fields. A
regression here would mis-route events to the wrong index."""
client = HECClient(
url="https://h:8088", access_token="t", index="dmarc", source="my-source"
)
self.assertEqual(client._common_data["index"], "dmarc")
self.assertEqual(client._common_data["source"], "my-source")
# host defaults to socket.getfqdn(); non-empty is enough.
self.assertTrue(client._common_data["host"])
class TestSaveAggregateReportsToSplunk(unittest.TestCase):
"""Each record is emitted as a separate Splunk event, with the
record's interval_begin as the event timestamp, the report's
metadata flattened onto the event, and sourcetype dmarc:aggregate."""
def test_sends_one_event_per_record(self):
"""Two-record report → two newline-separated events in the POST body."""
client = _client()
report = _aggregate_report()
report["records"].append(report["records"][0].copy())
client.session = MagicMock()
client.session.post.return_value = _ok_response()
client.save_aggregate_reports_to_splunk(report)
body = client.session.post.call_args.kwargs["data"]
events = [json.loads(line) for line in body.strip().split("\n")]
self.assertEqual(len(events), 2)
for event in events:
self.assertEqual(event["sourcetype"], "dmarc:aggregate")
self.assertEqual(event["index"], "dmarc")
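The body shape this test parses, one JSON object per line, can be sketched as below. `build_hec_body` is a hypothetical helper and the envelope is reduced to the three fields the tests read (`index`, `sourcetype`, `event`); the real client adds more:

```python
import json


def build_hec_body(records: list, index: str, sourcetype: str) -> str:
    """Hypothetical helper: serialize one HEC event per record."""
    lines = []
    for record in records:
        envelope = {"index": index, "sourcetype": sourcetype, "event": record}
        lines.append(json.dumps(envelope))
    # Newline-separated objects, so the body can be split and parsed per line.
    return "\n".join(lines)


body = build_hec_body(
    [{"message_count": 4}, {"message_count": 9}], "dmarc", "dmarc:aggregate"
)
events = [json.loads(line) for line in body.strip().split("\n")]
```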
def test_event_payload_carries_source_metadata(self):
"""The flattened event includes source attribution fields a
Splunk dashboard would filter on."""
client = _client()
client.session = MagicMock()
client.session.post.return_value = _ok_response()
client.save_aggregate_reports_to_splunk(_aggregate_report())
body = client.session.post.call_args.kwargs["data"]
event = json.loads(body.strip())["event"]
self.assertEqual(event["source_ip_address"], "192.0.2.1")
self.assertEqual(event["header_from"], "example.com")
self.assertEqual(event["message_count"], 4)
self.assertEqual(event["passed_dmarc"], True)
self.assertEqual(event["org_name"], "TestOrg")
def test_event_includes_published_policy(self):
client = _client()
client.session = MagicMock()
client.session.post.return_value = _ok_response()
client.save_aggregate_reports_to_splunk(_aggregate_report())
event = json.loads(client.session.post.call_args.kwargs["data"].strip())[
"event"
]
self.assertEqual(
event["published_policy"], {"domain": "example.com", "p": "none"}
)
def test_dict_input_normalized_to_list(self):
client = _client()
client.session = MagicMock()
client.session.post.return_value = _ok_response()
client.save_aggregate_reports_to_splunk(_aggregate_report())
client.session.post.assert_called_once()
def test_empty_list_is_a_noop(self):
client = _client()
client.session = MagicMock()
client.save_aggregate_reports_to_splunk([])
client.session.post.assert_not_called()
def test_post_uses_session_verify_and_timeout(self):
client = HECClient(
url="https://h:8088",
access_token="t",
index="dmarc",
verify=False,
timeout=15,
)
client.session = MagicMock()
client.session.post.return_value = _ok_response()
client.save_aggregate_reports_to_splunk(_aggregate_report())
kwargs = client.session.post.call_args.kwargs
self.assertEqual(kwargs["verify"], False)
self.assertEqual(kwargs["timeout"], 15)
def test_non_zero_response_code_raises_splunk_error(self):
"""HEC returns code=0 on success and non-zero codes for
token/index/format errors. The error text from HEC carries
the diagnosis and is propagated."""
client = _client()
client.session = MagicMock()
bad = MagicMock()
bad.json.return_value = {"code": 4, "text": "Invalid token"}
client.session.post.return_value = bad
with self.assertRaises(SplunkError) as ctx:
client.save_aggregate_reports_to_splunk(_aggregate_report())
self.assertIn("Invalid token", str(ctx.exception))
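The success check described in the docstring reduces to inspecting the `code` field of the parsed response. A sketch under that assumption, with a stand-in exception class (`HECResponseError` is hypothetical; the tests above use parsedmarc's `SplunkError`):

```python
class HECResponseError(RuntimeError):
    """Stand-in for the real SplunkError, for illustration only."""


def check_hec_response(payload: dict) -> None:
    """Raise when the parsed HEC response reports a non-zero code."""
    if payload.get("code") != 0:
        # Propagate HEC's own text so the caller sees the diagnosis.
        raise HECResponseError(payload.get("text", "unknown HEC error"))


check_hec_response({"code": 0, "text": "Success"})  # returns without raising

try:
    check_hec_response({"code": 4, "text": "Invalid token"})
    message = ""
except HECResponseError as error:
    message = str(error)
```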
def test_post_exception_translates_to_splunk_error(self):
client = _client()
client.session = MagicMock()
client.session.post.side_effect = OSError("network")
with self.assertRaises(SplunkError) as ctx:
client.save_aggregate_reports_to_splunk(_aggregate_report())
self.assertIn("network", str(ctx.exception))
class TestSaveFailureReportsToSplunk(unittest.TestCase):
def test_sends_one_event_per_report(self):
client = _client()
client.session = MagicMock()
client.session.post.return_value = _ok_response()
client.save_failure_reports_to_splunk([_failure_report(), _failure_report()])
events = [
json.loads(line)
for line in client.session.post.call_args.kwargs["data"].strip().split("\n")
]
self.assertEqual(len(events), 2)
for event in events:
self.assertEqual(event["sourcetype"], "dmarc:failure")
def test_event_payload_is_the_report_dict(self):
client = _client()
client.session = MagicMock()
client.session.post.return_value = _ok_response()
client.save_failure_reports_to_splunk(_failure_report())
event = json.loads(client.session.post.call_args.kwargs["data"].strip())[
"event"
]
self.assertEqual(event["reported_domain"], "example.com")
def test_empty_list_is_a_noop(self):
client = _client()
client.session = MagicMock()
client.save_failure_reports_to_splunk([])
client.session.post.assert_not_called()
def test_non_zero_response_code_raises_splunk_error(self):
client = _client()
client.session = MagicMock()
bad = MagicMock()
bad.json.return_value = {"code": 6, "text": "Invalid data format"}
client.session.post.return_value = bad
with self.assertRaises(SplunkError):
client.save_failure_reports_to_splunk(_failure_report())
def test_post_exception_translates_to_splunk_error(self):
client = _client()
client.session = MagicMock()
client.session.post.side_effect = RuntimeError("conn refused")
with self.assertRaises(SplunkError):
client.save_failure_reports_to_splunk(_failure_report())
def test_verify_false_logs_skip_message(self):
"""verify=False should leave a debug breadcrumb so operators
can spot misconfigured TLS in their logs."""
client = HECClient(
url="https://h:8088", access_token="t", index="dmarc", verify=False
)
client.session = MagicMock()
client.session.post.return_value = _ok_response()
with self.assertLogs("parsedmarc.log", level="DEBUG") as cm:
client.save_failure_reports_to_splunk(_failure_report())
self.assertTrue(
any("Skipping certificate verification" in m for m in cm.output)
)
class TestSaveSmtpTlsReportsToSplunk(unittest.TestCase):
    def test_sends_one_event_per_report(self):
        client = _client()
        client.session = MagicMock()
        client.session.post.return_value = _ok_response()
        client.save_smtp_tls_reports_to_splunk([_smtp_tls_report()])
        events = [
            json.loads(line)
            for line in client.session.post.call_args.kwargs["data"].strip().split("\n")
        ]
        self.assertEqual(len(events), 1)
        self.assertEqual(events[0]["sourcetype"], "smtp:tls")
    def test_dict_input_normalized_to_list(self):
        client = _client()
        client.session = MagicMock()
        client.session.post.return_value = _ok_response()
        client.save_smtp_tls_reports_to_splunk(_smtp_tls_report())
        client.session.post.assert_called_once()
    def test_empty_list_is_a_noop(self):
        client = _client()
        client.session = MagicMock()
        client.save_smtp_tls_reports_to_splunk([])
        client.session.post.assert_not_called()
    def test_non_zero_response_code_raises_splunk_error(self):
        client = _client()
        client.session = MagicMock()
        bad = MagicMock()
        bad.json.return_value = {"code": 7, "text": "Incorrect index"}
        client.session.post.return_value = bad
        with self.assertRaises(SplunkError):
            client.save_smtp_tls_reports_to_splunk(_smtp_tls_report())
    def test_post_exception_translates_to_splunk_error(self):
        client = _client()
        client.session = MagicMock()
        client.session.post.side_effect = RuntimeError("conn refused")
        with self.assertRaises(SplunkError):
            client.save_smtp_tls_reports_to_splunk(_smtp_tls_report())
    def test_verify_false_logs_skip_message(self):
        client = HECClient(
            url="https://h:8088", access_token="t", index="dmarc", verify=False
        )
        client.session = MagicMock()
        client.session.post.return_value = _ok_response()
        with self.assertLogs("parsedmarc.log", level="DEBUG") as cm:
            client.save_smtp_tls_reports_to_splunk(_smtp_tls_report())
        self.assertTrue(
            any("Skipping certificate verification" in m for m in cm.output)
        )
class TestHECClientClose(unittest.TestCase):
    def test_close_closes_session(self):
        client = _client()
        client.session = MagicMock()
        client.close()
        client.session.close.assert_called_once()
class TestSplunkBackwardCompatAlias(unittest.TestCase):
    def test_forensic_alias_points_to_failure_method(self):
        self.assertIs(
            HECClient.save_forensic_reports_to_splunk,  # type: ignore[attr-defined]
            HECClient.save_failure_reports_to_splunk,
        )
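The Splunk tests above read the posted body back as newline-delimited JSON and expect a non-zero `code` in the response to raise. The sketch below illustrates that wire shape under stated assumptions: `build_hec_payload`, `check_hec_response`, and `HecError` are hypothetical names, and the event layout is a guess at a typical HEC event rather than parsedmarc's actual payload.

```python
import json


class HecError(RuntimeError):
    """Hypothetical stand-in for parsedmarc's SplunkError."""


def build_hec_payload(reports, sourcetype, index):
    """Serialize one JSON event per report, separated by newlines."""
    lines = []
    for report in reports:
        # Assumed event layout: sourcetype, index, and the report as the event body.
        event = {"sourcetype": sourcetype, "index": index, "event": report}
        lines.append(json.dumps(event))
    return "\n".join(lines)


def check_hec_response(body):
    """Raise if the parsed response body carries a non-zero status code."""
    if body.get("code", 0) != 0:
        raise HecError(body.get("text", "unknown error"))


payload = build_hec_payload([{"report_id": "tls-1"}], "smtp:tls", "dmarc")
events = [json.loads(line) for line in payload.strip().split("\n")]
```

Splitting on newlines, as the test does, only works if each event is serialized on a single line, which `json.dumps` does by default.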
+346 -20
@@ -1,34 +1,360 @@
"""Tests for parsedmarc.syslog"""
import json
import logging
import socket
import unittest
from unittest.mock import MagicMock, patch
from parsedmarc.syslog import SyslogClient
class Test(unittest.TestCase):
    """Kitchen-sink tests redistributed from the original
    tests.py monolith. Future PRs should split these further
    into purpose-specific TestCase subclasses as natural
    groupings emerge."""
def _sample_aggregate_report():
    return {
        "xml_schema": "draft",
        "xml_namespace": None,
        "report_metadata": {
            "org_name": "example.com",
            "org_email": "dmarc@example.com",
            "org_extra_contact_info": None,
            "report_id": "agg-1",
            "begin_date": "2024-01-01 00:00:00",
            "end_date": "2024-01-02 00:00:00",
            "timespan_requires_normalization": False,
            "original_timespan_seconds": 86400,
            "errors": [],
            "generator": None,
        },
        "policy_published": {
            "domain": "example.com",
            "adkim": "r",
            "aspf": "r",
            "p": "none",
            "sp": "none",
            "pct": None,
            "fo": None,
            "np": None,
            "testing": None,
            "discovery_method": None,
        },
        "records": [
            {
                "interval_begin": "2024-01-01 00:00:00",
                "interval_end": "2024-01-02 00:00:00",
                "normalized_timespan": False,
                "source": {
                    "ip_address": "192.0.2.1",
                    "country": "US",
                    "reverse_dns": None,
                    "base_domain": None,
                    "name": None,
                    "type": None,
                    "asn": 64496,
                    "as_name": "Example AS",
                    "as_domain": "example.net",
                },
                "count": 9,
                "alignment": {"spf": True, "dkim": True, "dmarc": True},
                "policy_evaluated": {
                    "disposition": "none",
                    "dkim": "pass",
                    "spf": "pass",
                    "policy_override_reasons": [],
                },
                "identifiers": {
                    "header_from": "example.com",
                    "envelope_from": "example.com",
                    "envelope_to": None,
                },
                "auth_results": {"dkim": [], "spf": []},
            }
        ],
    }
    def testSyslogClientUdpInit(self):
        """SyslogClient creates UDP handler"""
        from parsedmarc.syslog import SyslogClient
        client = SyslogClient("localhost", 514, protocol="udp")
        self.assertEqual(client.server_name, "localhost")
        self.assertEqual(client.server_port, 514)
        self.assertEqual(client.protocol, "udp")
class _CapturingHandler(logging.Handler):
    """Records the messages emitted by SyslogClient.logger."""
    def testSyslogClientInvalidProtocol(self):
        """SyslogClient with invalid protocol raises ValueError"""
        from parsedmarc.syslog import SyslogClient
    def __init__(self):
        super().__init__()
        self.messages: list[str] = []
        with self.assertRaises(ValueError):
            SyslogClient("localhost", 514, protocol="invalid")
    def emit(self, record):
        self.messages.append(record.getMessage())
    def testSyslogBackwardCompatAlias(self):
        """SyslogClient forensic alias points to failure method"""
        from parsedmarc.syslog import SyslogClient
def _fresh_logger():
    """Reset the module-level parsedmarc_syslog logger before each test."""
    logging.getLogger("parsedmarc_syslog").handlers.clear()
class TestSyslogClientInitUdp(unittest.TestCase):
    """UDP is the default protocol — back-compat for every existing
    deployment. The handler must be SOCK_DGRAM, not SOCK_STREAM."""
    def test_udp_uses_dgram_socket(self):
        _fresh_logger()
        with patch("parsedmarc.syslog.logging.handlers.SysLogHandler") as mock_handler:
            SyslogClient(server_name="syslog.example.com", server_port=514)
        mock_handler.assert_called_once_with(
            address=("syslog.example.com", 514),
            socktype=socket.SOCK_DGRAM,
        )
    def test_udp_is_default(self):
        """Explicit protocol='udp' and default produce the same call."""
        _fresh_logger()
        with patch("parsedmarc.syslog.logging.handlers.SysLogHandler") as mock_handler:
            SyslogClient("s", 514, protocol="udp")
        kwargs = mock_handler.call_args.kwargs
        self.assertEqual(kwargs["socktype"], socket.SOCK_DGRAM)
class TestSyslogClientInitTcp(unittest.TestCase):
    """TCP path applies the configured timeout to the underlying socket
    and uses SOCK_STREAM. Wrong socket type would silently fail to
    deliver messages."""
    def test_tcp_uses_stream_socket(self):
        _fresh_logger()
        with patch("parsedmarc.syslog.logging.handlers.SysLogHandler") as mock_handler:
            mock_handler.return_value.socket = MagicMock()
            SyslogClient("s", 6514, protocol="tcp")
        kwargs = mock_handler.call_args.kwargs
        self.assertEqual(kwargs["socktype"], socket.SOCK_STREAM)
    def test_tcp_applies_timeout_to_socket(self):
        _fresh_logger()
        sock = MagicMock()
        with patch("parsedmarc.syslog.logging.handlers.SysLogHandler") as mock_handler:
            mock_handler.return_value.socket = sock
            SyslogClient("s", 6514, protocol="tcp", timeout=12.5)
        sock.settimeout.assert_called_once_with(12.5)
class TestSyslogClientInitTls(unittest.TestCase):
    """TLS path: TLS ≥1.2 minimum, optional CA + client cert, retry on
    connection failure. Each has user-facing security consequences."""
    def _patch_handler_and_ssl(self):
        handler_patch = patch("parsedmarc.syslog.logging.handlers.SysLogHandler")
        ssl_patch = patch("parsedmarc.syslog.ssl.create_default_context")
        return handler_patch, ssl_patch
    def test_tls_enforces_tls_1_2_minimum(self):
        """The lowest version security teams accept is TLS 1.2."""
        _fresh_logger()
        import ssl
        handler_p, ssl_p = self._patch_handler_and_ssl()
        with handler_p as mock_h, ssl_p as mock_ctx_factory:
            mock_h.return_value.socket = MagicMock()
            ctx = mock_ctx_factory.return_value
            SyslogClient("s", 6514, protocol="tls")
        mock_ctx_factory.assert_called_once_with()
        self.assertEqual(ctx.minimum_version, ssl.TLSVersion.TLSv1_2)
    def test_tls_loads_ca_when_cafile_provided(self):
        _fresh_logger()
        handler_p, ssl_p = self._patch_handler_and_ssl()
        with handler_p as mock_h, ssl_p as mock_ctx_factory:
            mock_h.return_value.socket = MagicMock()
            SyslogClient("s", 6514, protocol="tls", cafile_path="/etc/ca.pem")
        mock_ctx_factory.return_value.load_verify_locations.assert_called_once_with(
            cafile="/etc/ca.pem"
        )
    def test_tls_loads_client_cert_when_both_paths_provided(self):
        _fresh_logger()
        handler_p, ssl_p = self._patch_handler_and_ssl()
        with handler_p as mock_h, ssl_p as mock_ctx_factory:
            mock_h.return_value.socket = MagicMock()
            SyslogClient(
                "s",
                6514,
                protocol="tls",
                certfile_path="/etc/c.pem",
                keyfile_path="/etc/k.pem",
            )
        mock_ctx_factory.return_value.load_cert_chain.assert_called_once_with(
            certfile="/etc/c.pem",
            keyfile="/etc/k.pem",
        )
    def test_tls_warns_when_only_certfile_provided(self):
        """Half-configured client auth (cert without key, or vice
        versa) is a config bug that disables client auth silently.
        The code warns instead."""
        _fresh_logger()
        handler_p, ssl_p = self._patch_handler_and_ssl()
        with handler_p as mock_h, ssl_p:
            mock_h.return_value.socket = MagicMock()
            with self.assertLogs("parsedmarc_syslog", level="WARNING") as cm:
                SyslogClient("s", 6514, protocol="tls", certfile_path="/etc/c.pem")
        self.assertTrue(
            any("Both certfile_path and keyfile_path" in m for m in cm.output)
        )
    def test_tls_wraps_socket_with_server_hostname(self):
        """server_name is used as TLS SNI / certificate-verification hostname."""
        _fresh_logger()
        wrapped_sock = MagicMock()
        handler_p, ssl_p = self._patch_handler_and_ssl()
        with handler_p as mock_h, ssl_p as mock_ctx_factory:
            raw_sock = MagicMock()
            mock_h.return_value.socket = raw_sock
            mock_ctx_factory.return_value.wrap_socket.return_value = wrapped_sock
            SyslogClient("syslog.example.com", 6514, protocol="tls")
        mock_ctx_factory.return_value.wrap_socket.assert_called_once_with(
            raw_sock, server_hostname="syslog.example.com"
        )
    def test_tls_retries_then_succeeds(self):
        """Transient connection failures should retry up to
        retry_attempts before raising."""
        _fresh_logger()
        attempts = {"n": 0}
        def flaky_handler(*args, **kwargs):
            attempts["n"] += 1
            if attempts["n"] < 2:
                raise OSError("network down")
            h = MagicMock()
            h.socket = MagicMock()
            return h
        with (
            patch(
                "parsedmarc.syslog.logging.handlers.SysLogHandler",
                side_effect=flaky_handler,
            ),
            patch("parsedmarc.syslog.ssl.create_default_context"),
            patch("parsedmarc.syslog.time.sleep") as mock_sleep,
        ):
            SyslogClient("s", 6514, protocol="tls", retry_attempts=3, retry_delay=1)
        self.assertEqual(attempts["n"], 2)
        mock_sleep.assert_called_with(1)
    def test_tls_raises_after_exhausting_retries(self):
        _fresh_logger()
        with (
            patch(
                "parsedmarc.syslog.logging.handlers.SysLogHandler",
                side_effect=OSError("network down"),
            ),
            patch("parsedmarc.syslog.ssl.create_default_context"),
            patch("parsedmarc.syslog.time.sleep"),
        ):
            with self.assertRaises(OSError):
                SyslogClient("s", 6514, protocol="tls", retry_attempts=2, retry_delay=0)
class TestSyslogClientInitInvalidProtocol(unittest.TestCase):
    """Typos in the protocol field should fail loudly."""
    def test_invalid_protocol_raises_value_error(self):
        _fresh_logger()
        with self.assertRaises(ValueError) as ctx:
            SyslogClient("s", 514, protocol="udb")
        self.assertIn("udb", str(ctx.exception))
        self.assertIn("'udp', 'tcp', or 'tls'", str(ctx.exception))
class TestSyslogClientSave(unittest.TestCase):
    """save_* methods emit one syslog message per CSV row, each as a
    JSON-encoded payload. Wrong format would break downstream parsers."""
    def _client_with_capture(self):
        _fresh_logger()
        with patch("parsedmarc.syslog.logging.handlers.SysLogHandler"):
            client = SyslogClient("s", 514)
        client.logger.removeHandler(client.log_handler)
        cap = _CapturingHandler()
        client.logger.addHandler(cap)
        return client, cap
    def test_save_aggregate_emits_json_per_row(self):
        client, cap = self._client_with_capture()
        client.save_aggregate_report_to_syslog([_sample_aggregate_report()])
        self.assertEqual(len(cap.messages), 1)
        payload = json.loads(cap.messages[0])
        self.assertEqual(payload["source_ip_address"], "192.0.2.1")
        self.assertEqual(payload["count"], 9)
        self.assertEqual(payload["org_name"], "example.com")
    def test_save_failure_emits_json_per_report(self):
        client, cap = self._client_with_capture()
        failure_report = {
            "feedback_type": "auth-failure",
            "user_agent": "test/1.0",
            "version": "1",
            "original_envelope_id": None,
            "original_mail_from": "x@example.com",
            "original_rcpt_to": None,
            "arrival_date": "Thu, 1 Jan 2024 00:00:00 +0000",
            "arrival_date_utc": "2024-01-01 00:00:00",
            "authentication_results": None,
            "delivery_result": "other",
            "auth_failure": ["dmarc"],
            "authentication_mechanisms": [],
            "dkim_domain": None,
            "reported_domain": "example.com",
            "sample_headers_only": True,
            "source": {
                "ip_address": "192.0.2.5",
                "country": "US",
                "reverse_dns": None,
                "base_domain": None,
                "name": None,
                "type": None,
                "asn": 64496,
                "as_name": "Example AS",
                "as_domain": "example.net",
            },
            "sample": "...",
            "parsed_sample": {"subject": "Test"},
        }
        client.save_failure_report_to_syslog([failure_report])
        self.assertEqual(len(cap.messages), 1)
        payload = json.loads(cap.messages[0])
        self.assertEqual(payload["reported_domain"], "example.com")
        self.assertEqual(payload["source_ip_address"], "192.0.2.5")
    def test_save_smtp_tls_emits_json_per_policy(self):
        client, cap = self._client_with_capture()
        report = {
            "organization_name": "example.com",
            "begin_date": "2024-02-03T00:00:00Z",
            "end_date": "2024-02-04T00:00:00Z",
            "contact_info": "tls@example.com",
            "report_id": "tls-1",
            "policies": [
                {
                    "policy_domain": "example.com",
                    "policy_type": "sts",
                    "successful_session_count": 100,
                    "failed_session_count": 0,
                }
            ],
        }
        client.save_smtp_tls_report_to_syslog([report])
        self.assertEqual(len(cap.messages), 1)
        payload = json.loads(cap.messages[0])
        self.assertEqual(payload["policy_domain"], "example.com")
class TestSyslogClientClose(unittest.TestCase):
    def test_close_removes_and_closes_handler(self):
        _fresh_logger()
        with patch("parsedmarc.syslog.logging.handlers.SysLogHandler") as mock_handler:
            client = SyslogClient("s", 514)
            client.close()
        mock_handler.return_value.close.assert_called_once()
        self.assertNotIn(mock_handler.return_value, client.logger.handlers)
class TestSyslogBackwardCompatAlias(unittest.TestCase):
    def test_forensic_alias_points_to_failure_method(self):
        self.assertIs(
            SyslogClient.save_forensic_report_to_syslog,  # type: ignore[attr-defined]
            SyslogClient.save_failure_report_to_syslog,
        )
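The TLS retry tests above expect two attempts and one sleep when the first connection fails, and an OSError once the attempts are exhausted. A retry loop with that behavior might look like the sketch below. It is an illustration only, not the code in parsedmarc/syslog.py: `connect_with_retry` and its injectable `sleep` parameter are hypothetical.

```python
import time


def connect_with_retry(connect, retry_attempts=3, retry_delay=5, sleep=time.sleep):
    """Call connect() until it succeeds or retry_attempts is exhausted."""
    for attempt in range(1, retry_attempts + 1):
        try:
            return connect()
        except OSError:
            if attempt == retry_attempts:
                raise  # out of attempts: surface the last error to the caller
            sleep(retry_delay)  # wait before the next attempt


calls = {"n": 0}


def flaky_connect():
    """Fail on the first call, succeed on the second."""
    calls["n"] += 1
    if calls["n"] < 2:
        raise OSError("network down")
    return "connected"


delays = []
result = connect_with_retry(
    flaky_connect, retry_attempts=3, retry_delay=1, sleep=delays.append
)
```

Passing `sleep` as a parameter plays the same role as patching `time.sleep` in the tests: the retry path runs without any real delay.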
+89 -44
@@ -3,74 +3,119 @@
import unittest
from unittest.mock import MagicMock
import parsedmarc
import parsedmarc.webhook
from parsedmarc.webhook import WebhookClient
class Test(unittest.TestCase):
    """Kitchen-sink tests redistributed from the original
    tests.py monolith. Future PRs should split these further
    into purpose-specific TestCase subclasses as natural
    groupings emerge."""
def _client():
    return WebhookClient(
        aggregate_url="http://agg.example.com",
        failure_url="http://fail.example.com",
        smtp_tls_url="http://tls.example.com",
    )
    def testWebhookClientInit(self):
        """WebhookClient initializes with correct attributes"""
        from parsedmarc.webhook import WebhookClient
        client = WebhookClient(
            aggregate_url="http://agg.example.com",
            failure_url="http://fail.example.com",
            smtp_tls_url="http://tls.example.com",
        )
class TestWebhookClientInit(unittest.TestCase):
    """The constructor stores URLs per report type. A mix-up here
    would route reports to the wrong endpoint silently."""
    def test_urls_and_timeout_stored(self):
        client = _client()
        self.assertEqual(client.aggregate_url, "http://agg.example.com")
        self.assertEqual(client.failure_url, "http://fail.example.com")
        self.assertEqual(client.smtp_tls_url, "http://tls.example.com")
        self.assertEqual(client.timeout, 60)
    def testWebhookClientSaveMethods(self):
        """WebhookClient save methods call _send_to_webhook"""
        from parsedmarc.webhook import WebhookClient
    def test_custom_timeout_respected(self):
        client = WebhookClient(
            aggregate_url="a", failure_url="f", smtp_tls_url="t", timeout=120
        )
        self.assertEqual(client.timeout, 120)
        client = WebhookClient("http://a", "http://f", "http://t")
    def test_session_headers_set(self):
        """The Content-Type is required by virtually every webhook
        receiver to know how to deserialize the body."""
        client = _client()
        self.assertEqual(client.session.headers["Content-Type"], "application/json")
        self.assertIn("parsedmarc", client.session.headers["User-Agent"])
class TestWebhookClientSaveMethods(unittest.TestCase):
    """Each save_* sends the payload to the URL configured for that
    report type. A typo on which URL each method uses would
    permanently mis-route reports of that type."""
    def test_aggregate_posts_to_aggregate_url(self):
        client = _client()
        client.session = MagicMock()
        client.save_aggregate_report_to_webhook('{"test": 1}')
        client.session.post.assert_called_with(
            "http://a", data='{"test": 1}', timeout=60
        client.save_aggregate_report_to_webhook('{"agg": 1}')
        client.session.post.assert_called_once_with(
            "http://agg.example.com", data='{"agg": 1}', timeout=60
        )
    def test_failure_posts_to_failure_url(self):
        client = _client()
        client.session = MagicMock()
        client.save_failure_report_to_webhook('{"fail": 1}')
        client.session.post.assert_called_with(
            "http://f", data='{"fail": 1}', timeout=60
        client.session.post.assert_called_once_with(
            "http://fail.example.com", data='{"fail": 1}', timeout=60
        )
    def test_smtp_tls_posts_to_smtp_tls_url(self):
        client = _client()
        client.session = MagicMock()
        client.save_smtp_tls_report_to_webhook('{"tls": 1}')
        client.session.post.assert_called_with(
            "http://t", data='{"tls": 1}', timeout=60
        )
    def testWebhookBackwardCompatAlias(self):
        """WebhookClient forensic alias points to failure method"""
        from parsedmarc.webhook import WebhookClient
        self.assertIs(
            WebhookClient.save_forensic_report_to_webhook,  # type: ignore[attr-defined]
            WebhookClient.save_failure_report_to_webhook,
        client.session.post.assert_called_once_with(
            "http://tls.example.com", data='{"tls": 1}', timeout=60
        )
class TestWebhookClient(unittest.TestCase):
    """Tests for webhook client initialization and close"""
class TestWebhookErrorHandling(unittest.TestCase):
    """HTTP / network failures from the webhook receiver must NOT
    abort the surrounding parse-and-output batch — they're logged
    and swallowed. Misbehaving webhooks shouldn't take down DMARC
    processing."""
    def testClose(self):
        """WebhookClient.close() closes session"""
        client = parsedmarc.webhook.WebhookClient(
            aggregate_url="http://invalid.test/agg",
            failure_url="http://invalid.test/fail",
            smtp_tls_url="http://invalid.test/tls",
        )
    def test_network_error_is_logged_and_swallowed(self):
        client = _client()
        client.session = MagicMock()
        client.session.post.side_effect = OSError("connection refused")
        with self.assertLogs("parsedmarc.log", level="ERROR") as cm:
            # Should NOT raise.
            client.save_aggregate_report_to_webhook('{"a": 1}')
        self.assertTrue(any("Webhook Error" in m for m in cm.output))
        self.assertTrue(any("connection refused" in m for m in cm.output))
    def test_error_in_failure_save_is_swallowed(self):
        client = _client()
        client.session = MagicMock()
        client.session.post.side_effect = RuntimeError("timeout")
        with self.assertLogs("parsedmarc.log", level="ERROR"):
            client.save_failure_report_to_webhook('{"f": 1}')
    def test_error_in_smtp_tls_save_is_swallowed(self):
        client = _client()
        client.session = MagicMock()
        client.session.post.side_effect = RuntimeError("boom")
        with self.assertLogs("parsedmarc.log", level="ERROR"):
            client.save_smtp_tls_report_to_webhook('{"t": 1}')
class TestWebhookClientClose(unittest.TestCase):
    def test_close_closes_session(self):
        client = _client()
        mock_close = MagicMock()
        client.session.close = mock_close
        client.close()
        mock_close.assert_called_once()
class TestWebhookBackwardCompatAlias(unittest.TestCase):
    def test_forensic_alias_points_to_failure_method(self):
        self.assertIs(
            WebhookClient.save_forensic_report_to_webhook,  # type: ignore[attr-defined]
            WebhookClient.save_failure_report_to_webhook,
        )
if __name__ == "__main__":
    unittest.main(verbosity=2)
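The webhook error-handling tests above pin down a log-and-swallow contract: a failing POST is reported at ERROR level with a "Webhook Error" message and never propagates to the caller. The sketch below shows one way that contract could be met. It is an assumption-laden illustration, not parsedmarc's implementation: `send_to_webhook`, `FailingSession`, `ListHandler`, and the logger name `webhook_sketch` are all hypothetical.

```python
import logging

logger = logging.getLogger("webhook_sketch")  # hypothetical logger name
logger.propagate = False  # keep the demo output off the root logger


class FailingSession:
    """Hypothetical session whose POST always fails."""

    def post(self, url, data=None, timeout=None):
        raise OSError("connection refused")


class ListHandler(logging.Handler):
    """Collects formatted log messages in a list for inspection."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def send_to_webhook(session, url, payload, timeout=60):
    """POST the payload; log any failure instead of raising it."""
    try:
        session.post(url, data=payload, timeout=timeout)
    except Exception as error:
        logger.error("Webhook Error: %s", error)


capture = ListHandler()
logger.addHandler(capture)
send_to_webhook(FailingSession(), "http://agg.example.com", '{"a": 1}')
```

Catching a broad `Exception` here is deliberate in the sketch: the tests inject both OSError and RuntimeError and expect neither to escape.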