Add optional PostgreSQL storage backend (#667)

Adds a PostgreSQL output backend as a lighter-weight alternative to
Elasticsearch/OpenSearch, configured via a [postgresql] section
(host/port/user/password/database or a libpq connection_string). Tables
are created automatically on first run; a Grafana dashboard is included.

- psycopg is an optional extra (pip install parsedmarc[postgresql]); the
  import is guarded so `import parsedmarc` works without it, and
  PostgreSQLClient raises a clear install hint when constructed without
  the driver. Binary wheels aren't available for every platform.
- Schema captures the RFC 9990 / DMARCbis aggregate fields: np, testing,
  discovery_method, generator, xml_namespace, and per-result human_result
  on the DKIM/SPF auth-result tables.
- forensic -> failure naming throughout (table dmarc_failure_report,
  save_failure_report_to_postgresql, dashboard, docs) to match #659.
- Failure-report de-duplication mirrors the Elasticsearch backend exactly:
  arrival date + From + To + Subject (NULL-safe via IS NOT DISTINCT FROM;
  semantic JSONB equality). Aggregate and SMTP-TLS use ON CONFLICT.
- PostgreSQLClient.close() for clean CLI shutdown; comment documents why
  the two timestamp helpers must stay distinct (report dates are local,
  record/SMTP-TLS dates are UTC).
- CLI: config parse raises ConfigurationError on missing
  host/connection_string; wired into _init_output_clients + save loops.
- Tests in tests/test_postgres.py (helpers, mocked-DB save assertions,
  create_tables, connect/error wrapping, dedup, real-sample round trip)
  and tests/test_cli.py (config parse + end-to-end save wiring incl.
  AlreadySaved/PostgreSQLError handling). postgres.py at 99% line
  coverage; only _main's output-client-init retry path is left.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Fabio Scaccabarozzi
2026-05-21 14:17:49 +01:00
committed by GitHub
parent 0a703172de
commit 327fcff2b9
8 changed files with 3900 additions and 0 deletions
+6
View File
@@ -27,6 +27,12 @@ Several elements that became `langAttrString` in RFC 9990 (`extra_contact_info`,
Backwards compatibility to RFC 7489 is maintained.
#### PostgreSQL storage backend
New optional PostgreSQL output backend as a lighter-weight alternative to Elasticsearch/OpenSearch, configured via a `[postgresql]` section (host/port/user/password/database or a libpq `connection_string`). Tables are created automatically on first run, and the schema captures the RFC 9990 aggregate fields (`np`, `testing`, `discovery_method`, `generator`, `xml_namespace`, and per-result `human_result`). A Grafana dashboard (`grafana/Grafana-DMARC_Reports-PostgreSQL.json`) is included. Aggregate and SMTP-TLS reports are de-duplicated via `ON CONFLICT`; failure reports via an arrival-date / From / To / Subject check mirroring the Elasticsearch backend.
The backend is opt-in: install it with `pip install parsedmarc[postgresql]` (it pulls in `psycopg`). It is not a mandatory dependency because the prebuilt `psycopg` binary wheels are not available for every platform.
#### Docker-secret support via `_FILE` env vars
Any `PARSEDMARC_{SECTION}_{KEY}` environment variable can now also be supplied via a file by appending `_FILE` to its name (e.g. `PARSEDMARC_IMAP_PASSWORD_FILE=/run/secrets/imap_password`). The file's contents (with trailing CR/LF stripped) are used as the value. This is the same convention used by the official Postgres, MariaDB, and Redis container images, so credentials no longer have to appear in plain `environment:` blocks where `docker inspect`, container logs, and `/proc/<pid>/environ` would expose them.
+46
View File
@@ -367,6 +367,52 @@ The full set of configuration options are:
`%` characters must be escaped with another `%` character,
so use `%%` wherever a `%` character is used.
:::
- `postgresql`
- `host` - str: The PostgreSQL server hostname or IP address.
Required unless `connection_string` is provided.
- `port` - int: The PostgreSQL server port (Default: `5432`)
- `user` - str: The database user name (Optional)
- `password` - str: The database user password (Optional)
- `database` - str: The database name (Optional)
- `connection_string` - str: A full libpq connection string or URI
(e.g. `postgresql://user:pass@host/dbname`). When provided,
all individual parameters above are ignored.
The PostgreSQL backend is an optional extra. Install it with
`pip install parsedmarc[postgresql]` (it pulls in `psycopg`); the
prebuilt binary wheels are not available for every platform, which is
why it is not a mandatory dependency.
Tables are created automatically on first run using
`CREATE TABLE IF NOT EXISTS`, so no manual schema migration is needed
for fresh installations.
**Example configuration:**
```ini
[postgresql]
host = localhost
port = 5432
user = parsedmarc
password = secret
database = parsedmarc
```
Or using a DSN/URI:
```ini
[postgresql]
connection_string = postgresql://parsedmarc:secret@localhost/parsedmarc
```
Saving parsed data to PostgreSQL is controlled by the `[general]`
options `save_aggregate`, `save_failure`, and `save_smtp_tls`
(`save_forensic` is still accepted as a deprecated alias for
`save_failure`). These flags must be set to `True` for the
corresponding report types (aggregate DMARC, failure DMARC, and
SMTP TLS reports) or no data will be written to PostgreSQL, even if
this section is configured.
- `s3`
- `bucket` - str: The S3 bucket name
- `path` - str: The path to upload reports to (Default: `/`)
File diff suppressed because it is too large Load Diff
+68
View File
@@ -34,6 +34,7 @@ from parsedmarc import (
loganalytics,
opensearch,
parse_report_file,
postgres,
s3,
save_output,
splunk,
@@ -923,6 +924,26 @@ def _parse_config(config: ConfigParser, opts):
if "secret_access_key" in s3_config:
opts.s3_secret_access_key = s3_config["secret_access_key"]
if "postgresql" in config.sections():
pg_config = config["postgresql"]
if "connection_string" in pg_config:
opts.postgresql_connection_string = pg_config["connection_string"]
elif "host" in pg_config:
opts.postgresql_host = pg_config["host"]
if "port" in pg_config:
opts.postgresql_port = pg_config.getint("port")
if "user" in pg_config:
opts.postgresql_user = pg_config["user"]
if "password" in pg_config:
opts.postgresql_password = pg_config["password"]
if "database" in pg_config:
opts.postgresql_database = pg_config["database"]
else:
raise ConfigurationError(
"host (or connection_string) setting missing from the "
"postgresql config section"
)
if "syslog" in config.sections():
syslog_config = config["syslog"]
if "server" in syslog_config:
@@ -1109,6 +1130,22 @@ def _init_output_clients(opts):
except Exception as e:
raise RuntimeError(f"S3: {e}") from e
try:
if opts.postgresql_host or opts.postgresql_connection_string:
logger.debug("Initializing PostgreSQL client")
pg_client = postgres.PostgreSQLClient(
connection_string=opts.postgresql_connection_string,
host=opts.postgresql_host,
port=int(opts.postgresql_port or 5432),
user=opts.postgresql_user,
password=opts.postgresql_password,
database=opts.postgresql_database,
)
pg_client.create_tables()
clients["postgresql_client"] = pg_client
except Exception as e:
raise RuntimeError(f"PostgreSQL: {e}") from e
try:
if opts.syslog_server:
logger.debug(
@@ -1394,6 +1431,7 @@ def _main():
hec_client = clients.get("hec_client")
gelf_client = clients.get("gelf_client")
webhook_client = clients.get("webhook_client")
pg_client = clients.get("postgresql_client")
kafka_aggregate_topic = opts.kafka_aggregate_topic
kafka_failure_topic = opts.kafka_failure_topic
@@ -1455,6 +1493,14 @@ def _main():
except Exception as error_:
log_output_error("S3", error_.__str__())
try:
if pg_client:
pg_client.save_aggregate_report_to_postgresql(report)
except postgres.AlreadySaved as warning:
logger.warning(warning.__str__())
except postgres.PostgreSQLError as error_:
log_output_error("PostgreSQL", error_.__str__())
try:
if syslog_client:
syslog_client.save_aggregate_report_to_syslog(report)
@@ -1540,6 +1586,14 @@ def _main():
except Exception as error_:
log_output_error("S3", error_.__str__())
try:
if pg_client:
pg_client.save_failure_report_to_postgresql(report)
except postgres.AlreadySaved as warning:
logger.warning(warning.__str__())
except postgres.PostgreSQLError as error_:
log_output_error("PostgreSQL", error_.__str__())
try:
if syslog_client:
syslog_client.save_failure_report_to_syslog(report)
@@ -1625,6 +1679,14 @@ def _main():
except Exception as error_:
log_output_error("S3", error_.__str__())
try:
if pg_client:
pg_client.save_smtp_tls_report_to_postgresql(report)
except postgres.AlreadySaved as warning:
logger.warning(warning.__str__())
except postgres.PostgreSQLError as error_:
log_output_error("PostgreSQL", error_.__str__())
try:
if syslog_client:
syslog_client.save_smtp_tls_report_to_syslog(report)
@@ -1940,6 +2002,12 @@ def _main():
webhook_smtp_tls_url=None,
webhook_timeout=60,
normalize_timespan_threshold_hours=24.0,
postgresql_host=None,
postgresql_port=5432,
postgresql_user=None,
postgresql_password=None,
postgresql_database=None,
postgresql_connection_string=None,
fail_on_output_error=False,
)
+847
View File
@@ -0,0 +1,847 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import datetime
from typing import Optional, Union
try:
import psycopg
from psycopg import types as psycopg_types
except ImportError:
psycopg = None # type: ignore[assignment]
psycopg_types = None # type: ignore[assignment]
from parsedmarc.log import logger
from parsedmarc.utils import human_timestamp_to_datetime
# psycopg is an optional dependency (the PostgreSQL backend is opt-in). The
# pure helper functions below work without it; only PostgreSQLClient needs a
# live driver, so the import error surfaces at client construction with a
# pip-install hint rather than breaking ``import parsedmarc`` for everyone.
_PSYCOPG_INSTALL_HINT = (
"The PostgreSQL backend requires the 'psycopg' package. "
"Install it with: pip install parsedmarc[postgresql]"
)
# Two timestamp conventions coexist in parsed reports, so two helpers are
# needed — do not collapse them into one. Aggregate *report* begin/end dates
# come from ``timestamp_to_human()`` → ``datetime.fromtimestamp()``, which is
# **local** naive time, so they go through ``_naive_local_to_timestamptz``.
# Aggregate *record* interval_begin/end and SMTP-TLS begin/end are already
# **UTC** naive strings, so they only need a ``+00`` suffix via
# ``_ensure_utc_suffix``. Using the wrong helper silently shifts timestamps.
def _ensure_utc_suffix(value: Optional[str]) -> Optional[str]:
"""Append ``+00`` to a timestamp string if it lacks timezone info.
Several parsers produce ``YYYY-MM-DD HH:MM:SS`` format strings that
are known to be UTC but lack an explicit offset. PostgreSQL
``TIMESTAMPTZ`` columns need the offset to avoid interpreting the
value in the session timezone.
"""
if value and "+" not in value and "-" not in value[10:] and "Z" not in value:
return value + "+00"
return value
def _naive_local_to_timestamptz(value: Optional[str]) -> Optional[str]:
"""Convert a naive local-time string to an ISO 8601 string with offset.
``timestamp_to_human()`` produces ``YYYY-MM-DD HH:MM:SS`` in
**local** time (via ``datetime.fromtimestamp()``). Inserting such
a string into a ``TIMESTAMPTZ`` column would cause PostgreSQL to
interpret it using the *session* timezone, which may differ from
the machine's local timezone.
This helper re-parses the string, attaches the local timezone
offset, and returns an ISO 8601 representation that PostgreSQL
will interpret unambiguously.
"""
if not value:
return value
naive = datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
aware = naive.astimezone() # attaches the local system timezone
return aware.isoformat()
def _normalize_arrival_date(value: Optional[str]) -> Optional[str]:
"""Normalize a failure-report ``arrival_date`` for safe TIMESTAMPTZ insert.
The arrival date may be an RFC 2822 string (e.g.
``Fri, 28 Oct 2022 00:34:24 +0800``) or an ISO 8601 string.
``human_timestamp_to_datetime`` (backed by *dateutil*) can parse
both. We convert to UTC and return an ISO 8601 string with offset
so PostgreSQL interprets it unambiguously.
"""
if not value:
return value
try:
dt = human_timestamp_to_datetime(value, to_utc=True)
return dt.strftime("%Y-%m-%d %H:%M:%S") + "+00"
except Exception:
# If parsing fails, return as-is and let PostgreSQL try.
return value
def _contact_info_to_text(
value: Union[str, list, None],
) -> Optional[str]:
"""Ensure ``contact_info`` is a plain string.
The TLS-RPT ``contact-info`` field is normally a single string, but
the TypedDict allows ``Union[str, List[str]]``. If a list is
encountered, join the entries so they fit into a ``TEXT`` column.
"""
if value is None:
return None
if isinstance(value, list):
return ", ".join(str(v) for v in value)
return str(value)
class PostgreSQLError(RuntimeError):
"""Raised when a PostgreSQL-level error occurs"""
class AlreadySaved(ValueError):
"""Raised when an identical report already exists in the database"""
class PostgreSQLClient:
"""A client for saving DMARC reports to a PostgreSQL database.
Accepts either a full libpq connection string/DSN via
*connection_string* or individual connection parameters. When both
are supplied *connection_string* takes precedence.
"""
def __init__(
self,
connection_string: Optional[str] = None,
host: Optional[str] = None,
port: int = 5432,
user: Optional[str] = None,
password: Optional[str] = None,
database: Optional[str] = None,
) -> None:
"""
Initializes the PostgreSQLClient and opens a database connection.
Args:
connection_string: A libpq connection string or URI
(e.g. ``postgresql://user:pass@host/dbname``). When
present, individual keyword arguments are ignored.
host: Database server hostname or IP address.
port: Database server port (default: 5432).
user: Database user name.
password: Database user password.
database: Database name to connect to.
Raises:
PostgreSQLError: If psycopg is not installed or the connection
attempt fails.
"""
if psycopg is None:
raise PostgreSQLError(_PSYCOPG_INSTALL_HINT)
# Store parameters so we can reconnect later if needed.
self._connection_string = connection_string
self._host = host
self._port = port
self._user = user
self._password = password
self._database = database
self._conn: Optional[psycopg.Connection] = None
self._connect()
def _connect(self) -> None:
"""Open a new database connection using stored parameters.
Raises:
PostgreSQLError: If the connection attempt fails.
"""
logger.debug("Connecting to PostgreSQL")
try:
if self._connection_string:
self._conn = psycopg.connect(self._connection_string)
else:
self._conn = psycopg.connect(
host=self._host,
port=self._port,
user=self._user,
password=self._password,
dbname=self._database,
)
self._conn.autocommit = False
except psycopg.Error as exc:
raise PostgreSQLError(str(exc)) from exc
def close(self) -> None:
"""Close the database connection if it is open.
Called by the CLI's output-client cleanup on shutdown / config
reload. Safe to call multiple times.
"""
if self._conn is not None and not self._conn.closed:
self._conn.close()
def _ensure_connected(self) -> None:
"""Check the connection health and reconnect if necessary.
When *parsedmarc* runs in watch mode the process can stay alive
for days or weeks. PostgreSQL may drop idle connections (e.g.
server restart, ``idle_in_transaction_session_timeout``, TCP
keep-alive expiry). This method detects a closed connection
and transparently re-establishes it so that subsequent
``save_*`` calls succeed without manual intervention.
"""
if self._conn is None or self._conn.closed:
logger.warning("PostgreSQL connection lost — attempting to reconnect")
self._connect()
def create_tables(self) -> None:
"""Creates all required tables if they do not already exist.
This method is idempotent and safe to call on every startup.
Raises:
PostgreSQLError: If table creation fails.
"""
self._ensure_connected()
ddl_statements = [
# ----------------------------------------------------------------
# Aggregate reports
# ----------------------------------------------------------------
"""
CREATE TABLE IF NOT EXISTS dmarc_aggregate_report (
id BIGSERIAL PRIMARY KEY,
xml_schema TEXT,
xml_namespace TEXT,
org_name TEXT NOT NULL,
org_email TEXT,
org_extra_contact_info TEXT,
generator TEXT,
report_id TEXT NOT NULL,
begin_date TIMESTAMPTZ NOT NULL,
end_date TIMESTAMPTZ NOT NULL,
errors TEXT[],
domain TEXT NOT NULL,
adkim TEXT,
aspf TEXT,
policy TEXT,
subdomain_policy TEXT,
pct TEXT,
fo TEXT,
np TEXT,
testing TEXT,
discovery_method TEXT,
UNIQUE (org_name, report_id, domain, begin_date, end_date)
)
""",
"""
CREATE TABLE IF NOT EXISTS dmarc_aggregate_record (
id BIGSERIAL PRIMARY KEY,
report_id BIGINT NOT NULL
REFERENCES dmarc_aggregate_report(id)
ON DELETE CASCADE,
interval_begin TIMESTAMPTZ,
interval_end TIMESTAMPTZ,
source_ip_address INET,
source_country TEXT,
source_reverse_dns TEXT,
source_base_domain TEXT,
source_name TEXT,
source_type TEXT,
message_count INTEGER NOT NULL,
spf_aligned BOOLEAN,
dkim_aligned BOOLEAN,
dmarc_passed BOOLEAN,
disposition TEXT,
policy_dkim TEXT,
policy_spf TEXT,
header_from TEXT,
envelope_from TEXT,
envelope_to TEXT
)
""",
"""
CREATE TABLE IF NOT EXISTS dmarc_aggregate_record_dkim (
id BIGSERIAL PRIMARY KEY,
record_id BIGINT NOT NULL
REFERENCES dmarc_aggregate_record(id)
ON DELETE CASCADE,
domain TEXT,
selector TEXT,
result TEXT,
human_result TEXT
)
""",
"""
CREATE TABLE IF NOT EXISTS dmarc_aggregate_record_spf (
id BIGSERIAL PRIMARY KEY,
record_id BIGINT NOT NULL
REFERENCES dmarc_aggregate_record(id)
ON DELETE CASCADE,
domain TEXT,
scope TEXT,
result TEXT,
human_result TEXT
)
""",
"""
CREATE TABLE IF NOT EXISTS dmarc_aggregate_record_policy_override (
id BIGSERIAL PRIMARY KEY,
record_id BIGINT NOT NULL
REFERENCES dmarc_aggregate_record(id)
ON DELETE CASCADE,
override_type TEXT,
comment TEXT
)
""",
# ----------------------------------------------------------------
# Failure reports
# ----------------------------------------------------------------
"""
CREATE TABLE IF NOT EXISTS dmarc_failure_report (
id BIGSERIAL PRIMARY KEY,
feedback_type TEXT,
user_agent TEXT,
version TEXT,
original_envelope_id TEXT,
original_mail_from TEXT,
original_rcpt_to TEXT,
arrival_date TIMESTAMPTZ,
arrival_date_utc TIMESTAMPTZ,
authentication_results TEXT,
delivery_result TEXT,
auth_failure TEXT[],
authentication_mechanisms TEXT[],
dkim_domain TEXT,
reported_domain TEXT,
sample_headers_only BOOLEAN,
source_ip_address INET,
source_country TEXT,
source_reverse_dns TEXT,
source_base_domain TEXT,
source_name TEXT,
source_type TEXT,
sample TEXT,
sample_date TEXT,
sample_subject TEXT,
sample_body TEXT,
sample_has_defects BOOLEAN,
sample_headers JSONB,
sample_from JSONB,
sample_to JSONB
)
""",
"""
CREATE TABLE IF NOT EXISTS dmarc_failure_sample_address (
id BIGSERIAL PRIMARY KEY,
report_id BIGINT NOT NULL
REFERENCES dmarc_failure_report(id)
ON DELETE CASCADE,
address_type TEXT,
display_name TEXT,
address TEXT
)
""",
# ----------------------------------------------------------------
# SMTP TLS reports
# ----------------------------------------------------------------
"""
CREATE TABLE IF NOT EXISTS smtp_tls_report (
id BIGSERIAL PRIMARY KEY,
organization_name TEXT NOT NULL,
begin_date TIMESTAMPTZ NOT NULL,
end_date TIMESTAMPTZ NOT NULL,
contact_info TEXT,
report_id TEXT NOT NULL,
UNIQUE (organization_name, report_id, begin_date, end_date)
)
""",
"""
CREATE TABLE IF NOT EXISTS smtp_tls_policy (
id BIGSERIAL PRIMARY KEY,
report_id BIGINT NOT NULL
REFERENCES smtp_tls_report(id)
ON DELETE CASCADE,
policy_domain TEXT,
policy_type TEXT,
policy_strings TEXT[],
mx_host_patterns TEXT[],
successful_session_count INTEGER,
failed_session_count INTEGER
)
""",
"""
CREATE TABLE IF NOT EXISTS smtp_tls_failure_detail (
id BIGSERIAL PRIMARY KEY,
policy_id BIGINT NOT NULL
REFERENCES smtp_tls_policy(id)
ON DELETE CASCADE,
result_type TEXT,
failed_session_count INTEGER,
sending_mta_ip INET,
receiving_ip INET,
receiving_mx_hostname TEXT,
receiving_mx_helo TEXT,
additional_info_uri TEXT,
failure_reason_code TEXT
)
""",
# ----- indexes for Grafana dashboard query performance -----
"""
CREATE INDEX IF NOT EXISTS idx_agg_report_begin_date
ON dmarc_aggregate_report (begin_date)
""",
"""
CREATE INDEX IF NOT EXISTS idx_agg_record_report_id
ON dmarc_aggregate_record (report_id)
""",
"""
CREATE INDEX IF NOT EXISTS idx_agg_record_header_from
ON dmarc_aggregate_record (header_from)
""",
"""
CREATE INDEX IF NOT EXISTS idx_failure_report_arrival_date
ON dmarc_failure_report (arrival_date_utc)
""",
"""
CREATE INDEX IF NOT EXISTS idx_smtp_tls_report_begin_date
ON smtp_tls_report (begin_date)
""",
"""
CREATE INDEX IF NOT EXISTS idx_smtp_tls_policy_report_id
ON smtp_tls_policy (report_id)
""",
]
try:
with self._conn.transaction():
with self._conn.cursor() as cur:
for stmt in ddl_statements:
cur.execute(stmt)
logger.debug("PostgreSQL tables verified / created")
except psycopg.Error as exc:
raise PostgreSQLError(str(exc)) from exc
def save_aggregate_report_to_postgresql(self, report: dict) -> None:
"""Saves a parsed aggregate DMARC report to PostgreSQL.
Args:
report: A parsed aggregate report dictionary as returned by
:func:`parsedmarc.parse_report_file`.
Raises:
AlreadySaved: If an identical report is already present.
PostgreSQLError: If a database error occurs.
"""
self._ensure_connected()
meta = report.get("report_metadata", {})
pub = report.get("policy_published", {})
try:
with self._conn.transaction():
with self._conn.cursor() as cur:
cur.execute(
"""
INSERT INTO dmarc_aggregate_report (
xml_schema, xml_namespace, org_name, org_email,
org_extra_contact_info, generator, report_id,
begin_date, end_date, errors,
domain, adkim, aspf, policy,
subdomain_policy, pct, fo,
np, testing, discovery_method
) VALUES (
%s, %s, %s, %s,
%s, %s, %s,
%s, %s, %s,
%s, %s, %s, %s,
%s, %s, %s,
%s, %s, %s
)
ON CONFLICT (org_name, report_id, domain,
begin_date, end_date)
DO NOTHING
RETURNING id
""",
(
report.get("xml_schema"),
report.get("xml_namespace"),
meta.get("org_name"),
meta.get("org_email"),
meta.get("org_extra_contact_info"),
meta.get("generator"),
meta.get("report_id"),
_naive_local_to_timestamptz(meta.get("begin_date")),
_naive_local_to_timestamptz(meta.get("end_date")),
meta.get("errors") or [],
pub.get("domain"),
pub.get("adkim"),
pub.get("aspf"),
pub.get("p"),
pub.get("sp"),
pub.get("pct"),
pub.get("fo"),
pub.get("np"),
pub.get("testing"),
pub.get("discovery_method"),
),
)
row = cur.fetchone()
if row is None:
raise AlreadySaved(
"Aggregate report {report_id} from {org} "
"has already been saved".format(
report_id=meta.get("report_id"),
org=meta.get("org_name"),
)
)
report_db_id: int = row[0]
for record in report.get("records", []):
src = record.get("source", {})
pol = record.get("policy_evaluated", {})
idens = record.get("identifiers", {})
cur.execute(
"""
INSERT INTO dmarc_aggregate_record (
report_id, interval_begin, interval_end,
source_ip_address, source_country,
source_reverse_dns, source_base_domain,
source_name, source_type,
message_count,
spf_aligned, dkim_aligned, dmarc_passed,
disposition, policy_dkim, policy_spf,
header_from, envelope_from, envelope_to
) VALUES (
%s, %s, %s,
%s, %s, %s, %s, %s, %s,
%s,
%s, %s, %s,
%s, %s, %s,
%s, %s, %s
)
RETURNING id
""",
(
report_db_id,
_ensure_utc_suffix(record.get("interval_begin")),
_ensure_utc_suffix(record.get("interval_end")),
src.get("ip_address"),
src.get("country"),
src.get("reverse_dns"),
src.get("base_domain"),
src.get("name"),
src.get("type"),
record.get("count"),
record.get("alignment", {}).get("spf"),
record.get("alignment", {}).get("dkim"),
record.get("alignment", {}).get("dmarc"),
pol.get("disposition"),
pol.get("dkim"),
pol.get("spf"),
idens.get("header_from"),
idens.get("envelope_from"),
idens.get("envelope_to"),
),
)
record_db_id: int = cur.fetchone()[0]
for dkim in record.get("auth_results", {}).get("dkim", []):
cur.execute(
"""
INSERT INTO dmarc_aggregate_record_dkim
(record_id, domain, selector, result,
human_result)
VALUES (%s, %s, %s, %s, %s)
""",
(
record_db_id,
dkim.get("domain"),
dkim.get("selector"),
dkim.get("result"),
dkim.get("human_result"),
),
)
for spf in record.get("auth_results", {}).get("spf", []):
cur.execute(
"""
INSERT INTO dmarc_aggregate_record_spf
(record_id, domain, scope, result,
human_result)
VALUES (%s, %s, %s, %s, %s)
""",
(
record_db_id,
spf.get("domain"),
spf.get("scope"),
spf.get("result"),
spf.get("human_result"),
),
)
for override in pol.get("policy_override_reasons", []):
cur.execute(
"""
INSERT INTO dmarc_aggregate_record_policy_override
(record_id, override_type, comment)
VALUES (%s, %s, %s)
""",
(
record_db_id,
override.get("type"),
override.get("comment"),
),
)
except AlreadySaved:
raise
except psycopg.Error as exc:
raise PostgreSQLError(str(exc)) from exc
def save_failure_report_to_postgresql(self, report: dict) -> None:
"""Saves a parsed failure (RUF) DMARC report to PostgreSQL.
Args:
report: A parsed failure report dictionary as returned by
:func:`parsedmarc.parse_report_file`.
Raises:
AlreadySaved: If a matching failure report is already present.
PostgreSQLError: If a database error occurs.
"""
self._ensure_connected()
sample = report.get("parsed_sample", {}) or {}
src = report.get("source", {}) or {}
arrival_date_utc = _ensure_utc_suffix(report.get("arrival_date_utc"))
sample_subject = sample.get("subject")
# JSONB values are reused by both the dedup check and the INSERT.
sample_headers = (
psycopg_types.json.Jsonb(sample["headers"])
if sample.get("headers")
else None
)
sample_from = (
psycopg_types.json.Jsonb(sample["from"]) if sample.get("from") else None
)
sample_to = psycopg_types.json.Jsonb(sample["to"]) if sample.get("to") else None
try:
with self._conn.transaction():
with self._conn.cursor() as cur:
# Failure reports have no natural primary key, so mirror the
# Elasticsearch backend's query-then-insert dedup on the same
# dimensions it uses: arrival date + From + To + Subject.
# IS NOT DISTINCT FROM is NULL-safe (no PG15 NULLS NOT
# DISTINCT dependency); JSONB equality is semantic, so key
# order within the From/To objects doesn't matter.
cur.execute(
"""
SELECT 1 FROM dmarc_failure_report
WHERE arrival_date_utc IS NOT DISTINCT FROM %s
AND sample_subject IS NOT DISTINCT FROM %s
AND sample_from IS NOT DISTINCT FROM %s
AND sample_to IS NOT DISTINCT FROM %s
LIMIT 1
""",
(arrival_date_utc, sample_subject, sample_from, sample_to),
)
if cur.fetchone() is not None:
raise AlreadySaved(
"A failure report with subject {subj!r} arriving "
"at {date} has already been saved".format(
subj=sample_subject, date=arrival_date_utc
)
)
cur.execute(
"""
INSERT INTO dmarc_failure_report (
feedback_type, user_agent, version,
original_envelope_id, original_mail_from,
original_rcpt_to, arrival_date, arrival_date_utc,
authentication_results, delivery_result,
auth_failure, authentication_mechanisms,
dkim_domain, reported_domain, sample_headers_only,
source_ip_address, source_country,
source_reverse_dns, source_base_domain,
source_name, source_type,
sample, sample_date, sample_subject,
sample_body, sample_has_defects,
sample_headers, sample_from, sample_to
) VALUES (
%s, %s, %s,
%s, %s,
%s, %s, %s,
%s, %s,
%s, %s,
%s, %s, %s,
%s, %s,
%s, %s,
%s, %s,
%s, %s, %s,
%s, %s,
%s, %s, %s
)
RETURNING id
""",
(
report.get("feedback_type"),
report.get("user_agent"),
report.get("version"),
report.get("original_envelope_id"),
report.get("original_mail_from"),
report.get("original_rcpt_to"),
_normalize_arrival_date(report.get("arrival_date")),
arrival_date_utc,
report.get("authentication_results"),
report.get("delivery_result"),
report.get("auth_failure") or [],
report.get("authentication_mechanisms") or [],
report.get("dkim_domain"),
report.get("reported_domain"),
report.get("sample_headers_only"),
src.get("ip_address"),
src.get("country"),
src.get("reverse_dns"),
src.get("base_domain"),
src.get("name"),
src.get("type"),
report.get("sample"),
sample.get("date"),
sample_subject,
sample.get("body"),
sample.get("has_defects"),
sample_headers,
sample_from,
sample_to,
),
)
report_db_id: int = cur.fetchone()[0]
for addr_type in ("to", "cc", "bcc", "reply_to"):
entries = sample.get(addr_type) or []
if isinstance(entries, dict):
entries = [entries]
for entry in entries:
cur.execute(
"""
INSERT INTO dmarc_failure_sample_address
(report_id, address_type,
display_name, address)
VALUES (%s, %s, %s, %s)
""",
(
report_db_id,
addr_type,
entry.get("display_name"),
entry.get("address"),
),
)
except AlreadySaved:
raise
except psycopg.Error as exc:
raise PostgreSQLError(str(exc)) from exc
def save_smtp_tls_report_to_postgresql(self, report: dict) -> None:
"""Saves a parsed SMTP TLS report to PostgreSQL.
Args:
report: A parsed SMTP TLS report dictionary as returned by
:func:`parsedmarc.parse_report_file`.
Raises:
AlreadySaved: If an identical report is already present.
PostgreSQLError: If a database error occurs.
"""
self._ensure_connected()
try:
with self._conn.transaction():
with self._conn.cursor() as cur:
cur.execute(
"""
INSERT INTO smtp_tls_report (
organization_name, begin_date, end_date,
contact_info, report_id
) VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (organization_name, report_id,
begin_date, end_date)
DO NOTHING
RETURNING id
""",
(
report.get("organization_name"),
_ensure_utc_suffix(report.get("begin_date")),
_ensure_utc_suffix(report.get("end_date")),
_contact_info_to_text(report.get("contact_info")),
report.get("report_id"),
),
)
row = cur.fetchone()
if row is None:
raise AlreadySaved(
"SMTP TLS report {report_id} from {org} "
"has already been saved".format(
report_id=report.get("report_id"),
org=report.get("organization_name"),
)
)
report_db_id: int = row[0]
for policy in report.get("policies", []):
cur.execute(
"""
INSERT INTO smtp_tls_policy (
report_id, policy_domain, policy_type,
policy_strings, mx_host_patterns,
successful_session_count, failed_session_count
) VALUES (%s, %s, %s, %s, %s, %s, %s)
RETURNING id
""",
(
report_db_id,
policy.get("policy_domain"),
policy.get("policy_type"),
policy.get("policy_strings") or [],
policy.get("mx_host_patterns") or [],
policy.get("successful_session_count"),
policy.get("failed_session_count"),
),
)
policy_db_id: int = cur.fetchone()[0]
for detail in policy.get("failure_details", []):
cur.execute(
"""
INSERT INTO smtp_tls_failure_detail (
policy_id, result_type,
failed_session_count,
sending_mta_ip, receiving_ip,
receiving_mx_hostname, receiving_mx_helo,
additional_info_uri, failure_reason_code
) VALUES (
%s, %s, %s, %s, %s, %s, %s, %s, %s
)
""",
(
policy_db_id,
detail.get("result_type"),
detail.get("failed_session_count"),
detail.get("sending_mta_ip"),
detail.get("receiving_ip"),
detail.get("receiving_mx_hostname"),
detail.get("receiving_mx_helo"),
detail.get("additional_info_uri"),
detail.get("failure_reason_code"),
),
)
except AlreadySaved:
raise
except psycopg.Error as exc:
raise PostgreSQLError(str(exc)) from exc
+6
View File
@@ -54,6 +54,12 @@ dependencies = [
]
[project.optional-dependencies]
postgresql = [
# Optional output backend. psycopg ships prebuilt binary wheels via the
# [binary] extra, but those wheels don't exist for every platform/arch,
# so PostgreSQL support is opt-in rather than a mandatory dependency.
"psycopg[binary]>=3.1.0",
]
build = [
# Used only by maintainer tooling under parsedmarc/resources/maps/ —
# `collect_domain_info.py --use-search-fallback` falls back to a
+240
View File
@@ -2588,6 +2588,246 @@ class TestParseConfigS3(unittest.TestCase):
_parse_config(cp, _opts())
class TestParseConfigPostgreSQL(unittest.TestCase):
def test_postgresql_individual_params(self):
from parsedmarc.cli import _parse_config
cp = _config_with(
"postgresql",
{
"host": "db.example.com",
"port": "6543",
"user": "pmarc",
"password": "secret",
"database": "dmarc",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.postgresql_host, "db.example.com")
self.assertEqual(opts.postgresql_port, 6543)
self.assertEqual(opts.postgresql_user, "pmarc")
self.assertEqual(opts.postgresql_password, "secret")
self.assertEqual(opts.postgresql_database, "dmarc")
def test_postgresql_connection_string_takes_precedence(self):
"""connection_string is read and host parsing is skipped."""
from parsedmarc.cli import _parse_config
cp = _config_with(
"postgresql",
{
"connection_string": "postgresql://u:p@h/db",
"host": "ignored.example.com",
},
)
opts = _opts()
_parse_config(cp, opts)
self.assertEqual(opts.postgresql_connection_string, "postgresql://u:p@h/db")
# The host branch is skipped entirely when a connection_string is set.
self.assertFalse(hasattr(opts, "postgresql_host"))
def test_postgresql_missing_host_and_dsn_raises(self):
from parsedmarc.cli import ConfigurationError, _parse_config
cp = _config_with("postgresql", {"port": "5432"})
with self.assertRaises(ConfigurationError) as ctx:
_parse_config(cp, _opts())
self.assertIn("postgresql", str(ctx.exception))
class TestPostgreSQLCliWiring(unittest.TestCase):
"""End-to-end: a [postgresql] config reaches PostgreSQLClient + create_tables.
Regression guard so the config parse, the Namespace defaults, and the
_init_output_clients wiring can't drift apart.
"""
def test_postgresql_config_constructs_client_and_creates_tables(self):
config = """[general]
save_aggregate = true
silent = true
[imap]
host = imap.example.com
user = test-user
password = test-password
[postgresql]
host = db.example.com
port = 6543
user = pmarc
password = secret
database = dmarc
"""
with tempfile.NamedTemporaryFile(
"w", suffix=".ini", delete=False
) as config_file:
config_file.write(config)
config_path = config_file.name
self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path))
with (
patch("parsedmarc.cli.postgres.PostgreSQLClient") as mock_client_cls,
patch(
"parsedmarc.cli.get_dmarc_reports_from_mailbox",
return_value={
"aggregate_reports": [],
"failure_reports": [],
"smtp_tls_reports": [],
},
),
patch("parsedmarc.cli.IMAPConnection", return_value=object()),
patch.object(sys, "argv", ["parsedmarc", "-c", config_path]),
):
parsedmarc.cli._main()
mock_client_cls.assert_called_once()
kwargs = mock_client_cls.call_args.kwargs
self.assertEqual(kwargs.get("host"), "db.example.com")
self.assertEqual(kwargs.get("port"), 6543)
self.assertEqual(kwargs.get("user"), "pmarc")
self.assertEqual(kwargs.get("database"), "dmarc")
mock_client_cls.return_value.create_tables.assert_called_once()
def test_postgresql_aggregate_report_is_saved(self):
"""An aggregate report reaches the client's save method via the loop."""
config = """[general]
save_aggregate = true
silent = true
[imap]
host = imap.example.com
user = test-user
password = test-password
[postgresql]
host = db.example.com
"""
with tempfile.NamedTemporaryFile(
"w", suffix=".ini", delete=False
) as config_file:
config_file.write(config)
config_path = config_file.name
self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path))
report = {"policy_published": {"domain": "example.com"}, "records": []}
with (
patch("parsedmarc.cli.postgres.PostgreSQLClient") as mock_client_cls,
patch(
"parsedmarc.cli.get_dmarc_reports_from_mailbox",
return_value={
"aggregate_reports": [report],
"failure_reports": [],
"smtp_tls_reports": [],
},
),
patch("parsedmarc.cli.IMAPConnection", return_value=object()),
patch.object(sys, "argv", ["parsedmarc", "-c", config_path]),
):
parsedmarc.cli._main()
pg_client = mock_client_cls.return_value
pg_client.save_aggregate_report_to_postgresql.assert_called_once_with(report)
def _run_main(self, reports, save_side_effect=None):
"""Run _main with all save flags on and PostgreSQLClient mocked.
Returns the mocked client instance for assertions. *save_side_effect*,
if given, is applied to every save_* method so error-handling branches
can be exercised.
"""
config = """[general]
save_aggregate = true
save_failure = true
save_smtp_tls = true
silent = true
[imap]
host = imap.example.com
user = test-user
password = test-password
[postgresql]
host = db.example.com
"""
with tempfile.NamedTemporaryFile(
"w", suffix=".ini", delete=False
) as config_file:
config_file.write(config)
config_path = config_file.name
self.addCleanup(lambda: os.path.exists(config_path) and os.remove(config_path))
with (
patch("parsedmarc.cli.postgres.PostgreSQLClient") as mock_client_cls,
patch(
"parsedmarc.cli.get_dmarc_reports_from_mailbox",
return_value=reports,
),
patch("parsedmarc.cli.IMAPConnection", return_value=object()),
patch.object(sys, "argv", ["parsedmarc", "-c", config_path]),
):
client = mock_client_cls.return_value
if save_side_effect is not None:
for m in (
"save_aggregate_report_to_postgresql",
"save_failure_report_to_postgresql",
"save_smtp_tls_report_to_postgresql",
):
getattr(client, m).side_effect = save_side_effect
parsedmarc.cli._main()
return client
def test_postgresql_all_report_types_saved(self):
"""Failure and SMTP-TLS reports also reach their save methods."""
agg = {"policy_published": {"domain": "example.com"}, "records": []}
fail = {"reported_domain": "example.com", "parsed_sample": {}}
tls = {"organization_name": "Org", "policies": [{"policy_domain": "d"}]}
client = self._run_main(
{
"aggregate_reports": [agg],
"failure_reports": [fail],
"smtp_tls_reports": [tls],
}
)
client.save_aggregate_report_to_postgresql.assert_called_once_with(agg)
client.save_failure_report_to_postgresql.assert_called_once_with(fail)
client.save_smtp_tls_report_to_postgresql.assert_called_once_with(tls)
def test_postgresql_already_saved_is_warned_not_fatal(self):
"""AlreadySaved from any save is swallowed (logged), not propagated."""
from parsedmarc import postgres
agg = {"policy_published": {"domain": "example.com"}, "records": []}
fail = {"reported_domain": "example.com", "parsed_sample": {}}
tls = {"organization_name": "Org", "policies": []}
# Should not raise despite every save raising AlreadySaved.
self._run_main(
{
"aggregate_reports": [agg],
"failure_reports": [fail],
"smtp_tls_reports": [tls],
},
save_side_effect=postgres.AlreadySaved("dup"),
)
def test_postgresql_error_is_logged_not_fatal(self):
"""PostgreSQLError from any save is logged, not propagated."""
from parsedmarc import postgres
agg = {"policy_published": {"domain": "example.com"}, "records": []}
fail = {"reported_domain": "example.com", "parsed_sample": {}}
tls = {"organization_name": "Org", "policies": []}
self._run_main(
{
"aggregate_reports": [agg],
"failure_reports": [fail],
"smtp_tls_reports": [tls],
},
save_side_effect=postgres.PostgreSQLError("boom"),
)
class TestParseConfigSyslog(unittest.TestCase):
def test_syslog_complete(self):
from parsedmarc.cli import _parse_config
+786
View File
@@ -0,0 +1,786 @@
"""Tests for parsedmarc.postgres — the PostgreSQL output backend.
The pure timestamp/contact-info helpers are tested directly. The
``PostgreSQLClient`` save methods are tested with psycopg mocked at the SDK
boundary (``parsedmarc.postgres.psycopg``); the assertions check the SQL and
the bound parameters that a real PostgreSQL server would receive, plus the
real-sample round trip, so the tests fail if the dict-key mapping regresses.
"""
import os
import unittest
from glob import glob
from unittest.mock import MagicMock, patch
import parsedmarc
from parsedmarc.postgres import (
AlreadySaved,
PostgreSQLClient,
PostgreSQLError,
_contact_info_to_text,
_ensure_utc_suffix,
_naive_local_to_timestamptz,
_normalize_arrival_date,
)
OFFLINE_MODE = os.environ.get("GITHUB_ACTIONS", "false").lower() == "true"
# psycopg is an optional dependency and is not installed in CI (which installs
# only the [build] extra). The save methods mock the connection, but the
# failure path also references ``psycopg_types.json.Jsonb`` at module scope, so
# mock that SDK boundary for the whole module when psycopg is absent.
_types_patcher = None
def setUpModule():
global _types_patcher
import parsedmarc.postgres as pg
if pg.psycopg_types is None:
_types_patcher = patch("parsedmarc.postgres.psycopg_types", MagicMock())
_types_patcher.start()
def tearDownModule():
if _types_patcher is not None:
_types_patcher.stop()
class TestPostgreSQLHelpers(unittest.TestCase):
"""Unit tests for the pure helper functions in parsedmarc.postgres."""
# -- _ensure_utc_suffix --------------------------------------------------
def test_ensure_utc_suffix_none(self):
"""None passes through unchanged."""
self.assertIsNone(_ensure_utc_suffix(None))
def test_ensure_utc_suffix_empty_string(self):
"""Empty string passes through unchanged (falsy)."""
self.assertEqual(_ensure_utc_suffix(""), "")
def test_ensure_utc_suffix_naive_utc(self):
"""A naive UTC timestamp gets '+00' appended."""
self.assertEqual(
_ensure_utc_suffix("2024-01-15 10:30:00"),
"2024-01-15 10:30:00+00",
)
def test_ensure_utc_suffix_already_has_plus(self):
"""A timestamp already containing '+' is left unchanged."""
val = "2024-01-15 10:30:00+05:30"
self.assertEqual(_ensure_utc_suffix(val), val)
def test_ensure_utc_suffix_already_has_z(self):
"""A timestamp ending with 'Z' is left unchanged."""
val = "2024-01-15T10:30:00Z"
self.assertEqual(_ensure_utc_suffix(val), val)
def test_ensure_utc_suffix_negative_offset(self):
"""A timestamp with a negative offset after position 10 is unchanged."""
val = "2024-01-15 10:30:00-05:00"
self.assertEqual(_ensure_utc_suffix(val), val)
def test_ensure_utc_suffix_iso_t_naive(self):
"""Naive ISO 8601 with T separator gets '+00'."""
self.assertEqual(
_ensure_utc_suffix("2024-01-15T10:30:00"),
"2024-01-15T10:30:00+00",
)
# -- _naive_local_to_timestamptz -----------------------------------------
def test_naive_local_to_timestamptz_none(self):
self.assertIsNone(_naive_local_to_timestamptz(None))
def test_naive_local_to_timestamptz_empty(self):
self.assertEqual(_naive_local_to_timestamptz(""), "")
def test_naive_local_to_timestamptz_valid(self):
"""A valid naive string is returned with a timezone offset."""
result = _naive_local_to_timestamptz("2024-01-15 10:30:00")
self.assertIsInstance(result, str)
self.assertTrue(
"+" in result or "-" in result[10:],
f"Expected timezone offset in result: {result}",
)
from datetime import datetime as _dt
parsed = _dt.fromisoformat(result)
self.assertIsNotNone(parsed.tzinfo)
def test_naive_local_to_timestamptz_bad_format_raises(self):
"""An unparseable string raises ValueError (from strptime)."""
with self.assertRaises(ValueError):
_naive_local_to_timestamptz("not-a-date")
# -- _normalize_arrival_date ---------------------------------------------
def test_normalize_arrival_date_none(self):
self.assertIsNone(_normalize_arrival_date(None))
def test_normalize_arrival_date_empty(self):
self.assertEqual(_normalize_arrival_date(""), "")
def test_normalize_arrival_date_iso_naive_utc(self):
"""A naive ISO string (known UTC) is returned with +00 suffix."""
result = _normalize_arrival_date("2024-01-15 10:30:00")
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
def test_normalize_arrival_date_rfc2822(self):
"""An RFC 2822 date is converted to UTC with +00 suffix."""
result = _normalize_arrival_date("Fri, 28 Oct 2022 00:34:24 +0800")
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
# 00:34:24 +0800 is 16:34:24 UTC on 27 Oct 2022.
self.assertIn("2022-10-27", result)
self.assertIn("16:34:24", result)
def test_normalize_arrival_date_already_utc(self):
"""A string already ending with +00 still works."""
result = _normalize_arrival_date("2024-01-15 10:30:00+00")
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
def test_normalize_arrival_date_unparseable(self):
"""An unparseable string is returned as-is (fallback)."""
garbage = "not a date at all"
self.assertEqual(_normalize_arrival_date(garbage), garbage)
# -- _contact_info_to_text -----------------------------------------------
def test_contact_info_to_text_none(self):
self.assertIsNone(_contact_info_to_text(None))
def test_contact_info_to_text_string(self):
self.assertEqual(
_contact_info_to_text("admin@example.com"),
"admin@example.com",
)
def test_contact_info_to_text_list(self):
self.assertEqual(
_contact_info_to_text(["admin@example.com", "abuse@example.com"]),
"admin@example.com, abuse@example.com",
)
def test_contact_info_to_text_empty_list(self):
self.assertEqual(_contact_info_to_text([]), "")
def test_contact_info_to_text_numeric(self):
"""Non-string scalars are converted via str()."""
self.assertEqual(_contact_info_to_text(123), "123")
def _make_client():
"""Create a PostgreSQLClient with a fully-mocked psycopg connection."""
with patch("parsedmarc.postgres.psycopg") as mock_psycopg:
mock_conn = MagicMock()
mock_psycopg.connect.return_value = mock_conn
mock_psycopg.Error = Exception
client = PostgreSQLClient(
host="localhost", database="test", user="test", password="test"
)
client._conn = mock_conn
client._conn.closed = False
return client, mock_conn
def _mock_cursor(mock_conn, fetchone_results):
"""Wire up a mock cursor whose fetchone() yields *fetchone_results*."""
mock_cursor = MagicMock()
mock_cursor.fetchone.side_effect = list(fetchone_results)
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value = mock_cursor
mock_conn.transaction.return_value.__enter__ = MagicMock()
mock_conn.transaction.return_value.__exit__ = MagicMock(return_value=False)
return mock_cursor
def _executed_sql(mock_cursor):
"""Return the list of SQL strings passed to cursor.execute()."""
return [c.args[0] for c in mock_cursor.execute.call_args_list]
def _named_params(call):
"""Map an INSERT's column names to the bound parameter values.
Lets tests assert by column name instead of fragile positional indices.
"""
import re
sql = call.args[0]
m = re.search(r"\(([^)]*?)\)\s*VALUES", sql, re.S)
cols = [c.strip() for c in m.group(1).split(",") if c.strip()]
return dict(zip(cols, call.args[1]))
class TestPostgreSQLConstruction(unittest.TestCase):
"""Construction-time behaviour, including the optional-dependency guard."""
def test_missing_psycopg_raises_install_hint(self):
"""Without psycopg installed, construction fails with an install hint."""
with patch("parsedmarc.postgres.psycopg", None):
with self.assertRaises(PostgreSQLError) as ctx:
PostgreSQLClient(host="localhost")
self.assertIn("pip install parsedmarc[postgresql]", str(ctx.exception))
def test_close_closes_open_connection(self):
"""close() closes a live connection and is a no-op once closed."""
client, mock_conn = _make_client()
mock_conn.closed = False
client.close()
mock_conn.close.assert_called_once()
mock_conn.close.reset_mock()
mock_conn.closed = True
client.close()
mock_conn.close.assert_not_called()
def test_ensure_connected_reconnects_on_closed(self):
"""_ensure_connected reconnects when the connection is closed."""
client, mock_conn = _make_client()
mock_conn.closed = True
with patch.object(client, "_connect") as mock_reconnect:
client._ensure_connected()
mock_reconnect.assert_called_once()
def test_connect_uses_connection_string_when_provided(self):
"""A DSN/URI is passed straight to psycopg.connect."""
with patch("parsedmarc.postgres.psycopg") as mock_psycopg:
mock_psycopg.Error = Exception
PostgreSQLClient(connection_string="postgresql://u:p@h/db")
mock_psycopg.connect.assert_called_once_with("postgresql://u:p@h/db")
def test_connect_failure_raises_postgresql_error(self):
"""A driver-level connection error is wrapped in PostgreSQLError."""
with patch("parsedmarc.postgres.psycopg") as mock_psycopg:
mock_psycopg.Error = Exception
mock_psycopg.connect.side_effect = mock_psycopg.Error("refused")
with self.assertRaises(PostgreSQLError) as ctx:
PostgreSQLClient(host="localhost")
self.assertIn("refused", str(ctx.exception))
def test_create_tables_executes_all_ddl(self):
"""create_tables issues CREATE TABLE for every table and the indexes."""
client, mock_conn = _make_client()
cur = _mock_cursor(mock_conn, [])
client.create_tables()
executed = " ".join(_executed_sql(cur))
for table in (
"dmarc_aggregate_report",
"dmarc_aggregate_record",
"dmarc_aggregate_record_dkim",
"dmarc_aggregate_record_spf",
"dmarc_aggregate_record_policy_override",
"dmarc_failure_report",
"dmarc_failure_sample_address",
"smtp_tls_report",
"smtp_tls_policy",
"smtp_tls_failure_detail",
):
self.assertIn(f"CREATE TABLE IF NOT EXISTS {table}", executed)
self.assertIn("CREATE INDEX IF NOT EXISTS", executed)
def test_create_tables_wraps_db_error(self):
"""A driver error during DDL is wrapped in PostgreSQLError."""
class FakeDriverError(Exception):
pass
client, mock_conn = _make_client()
cur = _mock_cursor(mock_conn, [])
cur.execute.side_effect = FakeDriverError("ddl boom")
with patch("parsedmarc.postgres.psycopg") as mp:
mp.Error = FakeDriverError
with self.assertRaises(PostgreSQLError) as ctx:
client.create_tables()
self.assertIn("ddl boom", str(ctx.exception))
class TestPostgreSQLClientSave(unittest.TestCase):
"""Save methods with a mocked DB: assert on SQL and bound parameters."""
# -- aggregate -----------------------------------------------------------
def test_save_aggregate_report_calls_insert(self):
"""Aggregate save executes INSERTs for report, record, dkim and spf."""
client, mock_conn = _make_client()
cur = _mock_cursor(mock_conn, [(1,), (10,)])
report = {
"xml_schema": "1.0",
"xml_namespace": "urn:ietf:params:xml:ns:dmarc-2.0",
"report_metadata": {
"org_name": "Example Inc.",
"org_email": "dmarc@example.com",
"org_extra_contact_info": None,
"report_id": "rpt-123",
"begin_date": "2024-01-15 00:00:00",
"end_date": "2024-01-15 23:59:59",
"errors": [],
"generator": "ExampleReporter/2.0",
},
"policy_published": {
"domain": "example.com",
"adkim": "r",
"aspf": "r",
"p": "none",
"sp": "none",
"pct": "100",
"fo": "0",
"np": "reject",
"testing": "y",
"discovery_method": "treewalk",
},
"records": [
{
"source": {
"ip_address": "203.0.113.1",
"country": "US",
"reverse_dns": "mail.example.com",
"base_domain": "example.com",
"name": None,
"type": None,
},
"count": 5,
"alignment": {"spf": True, "dkim": True, "dmarc": True},
"policy_evaluated": {
"disposition": "none",
"dkim": "pass",
"spf": "pass",
"policy_override_reasons": [],
},
"identifiers": {
"header_from": "example.com",
"envelope_from": "example.com",
"envelope_to": None,
},
"interval_begin": "2024-01-15 00:00:00",
"interval_end": "2024-01-15 23:59:59",
"auth_results": {
"dkim": [
{
"domain": "example.com",
"selector": "s1",
"result": "pass",
"human_result": "valid signature",
}
],
"spf": [
{
"domain": "example.com",
"scope": "mfrom",
"result": "pass",
"human_result": None,
}
],
},
}
],
}
client.save_aggregate_report_to_postgresql(report)
sqls = _executed_sql(cur)
self.assertIn("dmarc_aggregate_report", sqls[0])
self.assertIn("dmarc_aggregate_record", sqls[1])
self.assertTrue(any("dmarc_aggregate_record_dkim" in s for s in sqls))
self.assertTrue(any("dmarc_aggregate_record_spf" in s for s in sqls))
# The RFC 9990 / DMARCbis fields must reach the report INSERT.
report_params = _named_params(cur.execute.call_args_list[0])
self.assertEqual(
report_params["xml_namespace"], "urn:ietf:params:xml:ns:dmarc-2.0"
)
self.assertEqual(report_params["generator"], "ExampleReporter/2.0")
self.assertEqual(report_params["np"], "reject")
self.assertEqual(report_params["testing"], "y")
self.assertEqual(report_params["discovery_method"], "treewalk")
# DKIM auth-result values, including human_result, reach the INSERT.
dkim_sql_idx = next(
i for i, s in enumerate(sqls) if "dmarc_aggregate_record_dkim" in s
)
dkim_params = _named_params(cur.execute.call_args_list[dkim_sql_idx])
self.assertEqual(dkim_params["domain"], "example.com")
self.assertEqual(dkim_params["selector"], "s1")
self.assertEqual(dkim_params["result"], "pass")
self.assertEqual(dkim_params["human_result"], "valid signature")
def test_save_aggregate_report_already_saved(self):
"""AlreadySaved is raised when ON CONFLICT returns no row."""
client, mock_conn = _make_client()
_mock_cursor(mock_conn, [None])
report = {
"report_metadata": {
"org_name": "Dup Inc.",
"report_id": "dup-001",
"begin_date": "2024-01-01 00:00:00",
"end_date": "2024-01-01 23:59:59",
},
"policy_published": {"domain": "example.com"},
"records": [],
}
with self.assertRaises(AlreadySaved):
client.save_aggregate_report_to_postgresql(report)
def test_aggregate_report_normalizes_timestamps(self):
"""Report dates get a tz offset; record intervals get a +00 suffix."""
client, mock_conn = _make_client()
cur = _mock_cursor(mock_conn, [(1,), (10,)])
report = {
"report_metadata": {
"org_name": "TZ Test",
"report_id": "tz-001",
"begin_date": "2024-01-15 00:00:00",
"end_date": "2024-01-15 23:59:59",
},
"policy_published": {"domain": "example.com"},
"records": [
{
"source": {},
"count": 1,
"alignment": {},
"policy_evaluated": {},
"identifiers": {"header_from": "example.com"},
"interval_begin": "2024-01-15 00:00:00",
"interval_end": "2024-01-15 23:59:59",
"auth_results": {"dkim": [], "spf": []},
}
],
}
client.save_aggregate_report_to_postgresql(report)
report_params = _named_params(cur.execute.call_args_list[0])
for label in ("begin_date", "end_date"):
val = report_params[label]
self.assertIsNotNone(val, f"{label} should not be None")
self.assertTrue(
"+" in val or "-" in val[10:],
f"Report {label} should carry a tz offset: {val}",
)
record_params = _named_params(cur.execute.call_args_list[1])
for label in ("interval_begin", "interval_end"):
val = record_params[label]
self.assertIsNotNone(val, f"{label} should not be None")
self.assertTrue(
val.endswith("+00"),
f"Record {label} should end with +00: {val}",
)
# -- failure -------------------------------------------------------------
def test_save_failure_report_calls_insert(self):
"""Failure save dedups, then INSERTs the report and sample addresses."""
client, mock_conn = _make_client()
# 1st fetchone = dedup SELECT (None → not a duplicate); 2nd = INSERT id.
cur = _mock_cursor(mock_conn, [None, (1,)])
report = {
"feedback_type": "auth-failure",
"user_agent": "test/1.0",
"version": "1",
"original_envelope_id": None,
"original_mail_from": "sender@example.com",
"original_rcpt_to": "receiver@example.com",
"arrival_date": "Mon, 15 Jan 2024 10:30:00 +0000",
"arrival_date_utc": "2024-01-15 10:30:00",
"authentication_results": "spf=pass",
"delivery_result": None,
"auth_failure": ["dkim"],
"authentication_mechanisms": [],
"dkim_domain": "example.com",
"reported_domain": "example.com",
"sample_headers_only": False,
"source": {
"ip_address": "203.0.113.1",
"country": "US",
"reverse_dns": "mail.example.com",
"base_domain": "example.com",
"name": None,
"type": None,
},
"sample": "raw email content",
"parsed_sample": {
"date": "2024-01-15",
"subject": "Test",
"body": "Hello",
"has_defects": False,
"headers": {"From": "sender@example.com"},
"from": {"display_name": "Sender", "address": "sender@example.com"},
"to": [{"display_name": "Receiver", "address": "receiver@example.com"}],
"cc": [],
"bcc": [],
"reply_to": [],
},
}
client.save_failure_report_to_postgresql(report)
sqls = _executed_sql(cur)
# First statement is the dedup SELECT, then the report INSERT.
self.assertIn("SELECT", sqls[0])
self.assertIn("dmarc_failure_report", sqls[0])
self.assertTrue(
any("INSERT INTO dmarc_failure_report" in s for s in sqls),
"expected a failure-report INSERT",
)
self.assertTrue(
any("dmarc_failure_sample_address" in s for s in sqls),
"expected a sample-address INSERT for the 'to' recipient",
)
def test_save_failure_report_already_saved(self):
"""A matching existing failure report raises AlreadySaved."""
client, mock_conn = _make_client()
# Dedup SELECT returns a row → duplicate.
_mock_cursor(mock_conn, [(1,)])
report = {
"arrival_date_utc": "2024-01-15 10:30:00",
"reported_domain": "example.com",
"source": {"ip_address": "203.0.113.1"},
"parsed_sample": {"subject": "Test"},
}
with self.assertRaises(AlreadySaved):
client.save_failure_report_to_postgresql(report)
# -- SMTP TLS ------------------------------------------------------------
def test_save_smtp_tls_report_calls_insert(self):
"""SMTP TLS save INSERTs report, policy, and failure detail rows."""
client, mock_conn = _make_client()
cur = _mock_cursor(mock_conn, [(1,), (10,)])
report = {
"organization_name": "Example Inc.",
"begin_date": "2024-01-15T00:00:00Z",
"end_date": "2024-01-16T00:00:00Z",
"contact_info": "admin@example.com",
"report_id": "tls-001",
"policies": [
{
"policy_domain": "example.com",
"policy_type": "sts",
"policy_strings": ["version: STSv1"],
"mx_host_patterns": ["*.example.com"],
"successful_session_count": 100,
"failed_session_count": 2,
"failure_details": [
{
"result_type": "certificate-expired",
"failed_session_count": 2,
"sending_mta_ip": "203.0.113.1",
"receiving_ip": "198.51.100.1",
"receiving_mx_hostname": "mx.example.com",
"receiving_mx_helo": "mx.example.com",
"additional_info_uri": None,
"failure_reason_code": None,
}
],
}
],
}
client.save_smtp_tls_report_to_postgresql(report)
sqls = _executed_sql(cur)
self.assertIn("smtp_tls_report", sqls[0])
self.assertIn("smtp_tls_policy", sqls[1])
self.assertIn("smtp_tls_failure_detail", sqls[2])
# Policy field mapping must reach the INSERT (regression guard).
policy_params = cur.execute.call_args_list[1].args[1]
self.assertIn("example.com", policy_params)
self.assertIn("sts", policy_params)
self.assertIn(100, policy_params)
self.assertIn(2, policy_params)
def test_save_smtp_tls_report_already_saved(self):
"""AlreadySaved is raised when ON CONFLICT returns no row."""
client, mock_conn = _make_client()
_mock_cursor(mock_conn, [None])
report = {
"organization_name": "Dup Inc.",
"begin_date": "2024-01-01T00:00:00Z",
"end_date": "2024-01-02T00:00:00Z",
"contact_info": "admin@dup.com",
"report_id": "dup-tls-001",
"policies": [],
}
with self.assertRaises(AlreadySaved):
client.save_smtp_tls_report_to_postgresql(report)
def test_save_smtp_tls_report_contact_info_list(self):
"""A contact_info list is joined to a string before insert."""
client, mock_conn = _make_client()
cur = _mock_cursor(mock_conn, [(1,)])
report = {
"organization_name": "Multi Inc.",
"begin_date": "2024-01-15T00:00:00Z",
"end_date": "2024-01-16T00:00:00Z",
"contact_info": ["admin@multi.com", "abuse@multi.com"],
"report_id": "multi-001",
"policies": [],
}
client.save_smtp_tls_report_to_postgresql(report)
insert_params = cur.execute.call_args_list[0].args[1]
self.assertEqual(insert_params[3], "admin@multi.com, abuse@multi.com")
def test_save_failure_report_single_address_dict(self):
"""A recipient header parsed as a single dict (not a list) is wrapped."""
client, mock_conn = _make_client()
cur = _mock_cursor(mock_conn, [None, (1,)])
report = {
"arrival_date_utc": "2024-01-15 10:30:00",
"reported_domain": "example.com",
"source": {"ip_address": "203.0.113.1"},
"parsed_sample": {
"subject": "Single",
# 'to' as a lone dict rather than a list of dicts.
"to": {"display_name": "Solo", "address": "solo@example.com"},
},
}
client.save_failure_report_to_postgresql(report)
addr_sqls = [
(c.args[0], c.args[1])
for c in cur.execute.call_args_list
if "dmarc_failure_sample_address" in c.args[0]
]
self.assertEqual(len(addr_sqls), 1)
self.assertIn("solo@example.com", addr_sqls[0][1])
class TestPostgreSQLSaveErrors(unittest.TestCase):
"""Driver errors raised mid-save are wrapped in PostgreSQLError."""
class _FakeDriverError(Exception):
pass
def _run(self, method, report):
client, mock_conn = _make_client()
cur = _mock_cursor(mock_conn, [])
cur.execute.side_effect = self._FakeDriverError("db boom")
with patch("parsedmarc.postgres.psycopg") as mp:
mp.Error = self._FakeDriverError
with self.assertRaises(PostgreSQLError) as ctx:
getattr(client, method)(report)
self.assertIn("db boom", str(ctx.exception))
def test_save_aggregate_wraps_db_error(self):
self._run(
"save_aggregate_report_to_postgresql",
{"report_metadata": {}, "policy_published": {}, "records": []},
)
def test_save_failure_wraps_db_error(self):
self._run(
"save_failure_report_to_postgresql",
{"parsed_sample": {}, "source": {}},
)
def test_save_smtp_tls_wraps_db_error(self):
self._run(
"save_smtp_tls_report_to_postgresql",
{"policies": []},
)
class TestPostgreSQLWithSamples(unittest.TestCase):
"""Feed real parsed sample reports through the save methods (DB mocked)."""
def test_aggregate_samples(self):
client, mock_conn = _make_client()
saved = 0
for sample_path in glob("samples/aggregate/*"):
if os.path.isdir(sample_path):
continue
try:
parsed = parsedmarc.parse_report_file(
sample_path,
always_use_local_files=True,
offline=OFFLINE_MODE,
)
except parsedmarc.ParserError:
continue
if parsed.get("report_type") != "aggregate":
continue
report = parsed["report"]
num_records = len(report.get("records", []))
_mock_cursor(mock_conn, [(rid,) for rid in range(1, 2 + num_records)])
try:
client.save_aggregate_report_to_postgresql(report)
saved += 1
except Exception as exc:
self.fail(f"aggregate save failed for {sample_path}: {exc}")
self.assertGreater(saved, 0, "Expected at least one aggregate sample")
def test_failure_samples(self):
client, mock_conn = _make_client()
saved = 0
for sample_path in glob("samples/failure/*.eml"):
try:
parsed = parsedmarc.parse_report_file(sample_path, offline=OFFLINE_MODE)
except parsedmarc.ParserError:
continue
if parsed.get("report_type") != "failure":
continue
reports = parsed["report"]
if not isinstance(reports, list):
reports = [reports]
for report in reports:
# Dedup SELECT returns None (not a dup), then the INSERT id.
_mock_cursor(mock_conn, [None, (1,)])
try:
client.save_failure_report_to_postgresql(report)
saved += 1
except Exception as exc:
self.fail(f"failure save failed for {sample_path}: {exc}")
self.assertGreater(saved, 0, "Expected at least one failure sample")
def test_smtp_tls_samples(self):
client, mock_conn = _make_client()
saved = 0
for sample_path in glob("samples/smtp_tls/*"):
if os.path.isdir(sample_path):
continue
try:
parsed = parsedmarc.parse_report_file(sample_path, offline=OFFLINE_MODE)
except parsedmarc.ParserError:
continue
if parsed.get("report_type") != "smtp_tls":
continue
report = parsed["report"]
num_policies = len(report.get("policies", []))
_mock_cursor(mock_conn, [(rid,) for rid in range(1, 2 + num_policies)])
try:
client.save_smtp_tls_report_to_postgresql(report)
saved += 1
except Exception as exc:
self.fail(f"smtp_tls save failed for {sample_path}: {exc}")
self.assertGreater(saved, 0, "Expected at least one SMTP TLS sample")
if __name__ == "__main__":
unittest.main()