Make the whole codebase pass pyright cleanly and enforce it in CI (#798)

* Make the whole codebase pass pyright cleanly and enforce it in CI

Fix all 102 pyright (1.1.410, standard mode) errors across the library,
tests, and maps scripts, then pin and enforce the zero-errors bar:

- postgres.py: make the optional psycopg import TYPE_CHECKING-aware so
  the module is properly typed while keeping the runtime install-hint
  fallback; import psycopg.types.json explicitly as psycopg_json (the
  old psycopg_types.json attribute access only worked because psycopg
  imports the submodule eagerly); have _connect()/_ensure_connected()
  return the live connection so save methods use a non-Optional local;
  type the DDL list as list[LiteralString] to match psycopg's execute()
  overloads.
- kafkaclient.py: resolve the kafka-python 2.x/3.x bootstrap-error
  fallback statically via TYPE_CHECKING (kafka-python 3.0 removed
  NoBrokersAvailable), which also fixes _BootstrapError's import
  resolution in tests.
- syslog.py: go through getattr/setattr for SysLogHandler.socket
  (absent from typeshed); type the save_* methods with the report
  TypedDicts (single or list, matching cli.py call sites — gelf.py gets
  the same signatures); raise ValueError when retry_attempts < 1
  instead of falling through and registering a None handler (bug fix,
  with a regression test and a CHANGELOG entry).
- elastic.py / opensearch.py: human_result params are Optional[str].
- maps scripts: sort_csv declared a return type but never returned
  (now -> None); seen_sort_field_values was possibly unbound;
  convert_to_utf8's src_encoding is Optional[str].
- tests: cast sample-report dict helpers to their TypedDicts; mark
  deliberate wrong-type calls with targeted pyright ignores; add
  narrowing asserts for Optional results; access the mocked
  KafkaProducer through a cast helper; match the mailsuite
  fetch_message base signature (**kwargs); patch the renamed
  parsedmarc.postgres.psycopg_json in test_postgres's setUpModule.

Enforcement: [tool.pyright] in pyproject.toml (include parsedmarc,
tests, docs; standard mode), pyright==1.1.410 pinned in the [build]
extra (pinned exactly so a new pyright release can't break CI without a
code change), and a "Check types" step in the lint CI job — which now
also runs ruff format --check and installs the [postgresql] extra so
the optional psycopg import resolves. Documented in AGENTS.md.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Set session headers via update() instead of replacing the dict

requests 2.34 ships inline type annotations, and Session.headers is a
CaseInsensitiveDict[str] — assigning a plain dict fails pyright there
(the CI runner resolved 2.34.2; the local venv's untyped 2.32.4 hid
it). headers.update() is correctly typed against both versions, and is
the documented requests idiom: it overrides User-Agent and the
client-specific headers while keeping the session's defaults
(Accept-Encoding, Connection) instead of wiping them.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Sean Whalen
2026-06-12 21:33:01 -04:00
committed by GitHub
parent 0c456d44ed
commit eaeea4f53d
20 changed files with 236 additions and 130 deletions
+7 -1
View File
@@ -22,10 +22,16 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install .[build]
# postgresql extra included so pyright can resolve the optional
# psycopg import in parsedmarc/postgres.py
pip install .[build,postgresql]
- name: Check code style
run: |
ruff check .
ruff format --check .
- name: Check types
run: |
pyright
- name: Test building documentation
run: |
cd docs
+6
View File
@@ -25,6 +25,11 @@ pytest tests/test_init.py::Test::testAggregateSamples
ruff check .
ruff format .
# Type check (config in pyproject.toml [tool.pyright]; CI enforces zero
# errors/warnings; needs the [postgresql] extra installed so the optional
# psycopg import resolves)
pyright
# Test CLI with sample reports
parsedmarc --debug -c ci.ini samples/aggregate/*
parsedmarc --debug -c ci.ini samples/failure/*
@@ -108,6 +113,7 @@ IP address info cached for 4 hours, seen aggregate report IDs cached for 1 hour
## Code Style
- Ruff for formatting and linting (configured in `.vscode/settings.json`). Run `ruff check .` and `ruff format --check .` after every code edit, before committing.
- Pyright for type checking (configured in `pyproject.toml` `[tool.pyright]`, pinned in the `[build]` extra, enforced in CI). Run `pyright` from the repo root before committing; the whole codebase — library and tests — must stay at zero errors and warnings. Prefer real fixes (narrowing, `Optional` annotations, `TYPE_CHECKING` imports) over `# pyright: ignore[...]`; reserve targeted ignores for deliberate wrong-type tests and version-conditional imports, and never use a bare blanket ignore.
- TypedDict for structured data, type hints throughout.
- Python ≥3.10 required.
- Tests live under `tests/` as `tests/test_<module>.py`, one per top-level `parsedmarc/*` module (e.g. `tests/test_init.py` for `parsedmarc/__init__.py`, `tests/test_cli.py` for `parsedmarc/cli.py`). All test classes use `unittest`. Sample reports live in `samples/`. Run with `pytest tests/`; run one file with `pytest tests/test_init.py`. New tests go in the file whose module they exercise — do not reintroduce a monolithic test file.
+5
View File
@@ -9,6 +9,11 @@
- **Example systemd unit** in `docs/source/usage.md` now sets `KillSignal=SIGTERM` and `TimeoutStopSec=60` so systemd waits long enough for the watcher to drain (keep it above `mailbox_check_timeout`).
- Switch the Kafka client dependency from `kafka-python-ng` back to `kafka-python>=2.3.2` ([#795](https://github.com/domainaware/parsedmarc/issues/795)). `kafka-python-ng` was a fork created while `kafka-python` was unmaintained; upstream `kafka-python` is active again, and the now-archived fork is vulnerable to [CVE-2026-10142](https://nvd.nist.gov/vuln/detail/CVE-2026-10142) and [CVE-2026-10143](https://nvd.nist.gov/vuln/detail/CVE-2026-10143), both fixed in `kafka-python` 2.3.2. Both packages install the same `kafka` module, so if you are upgrading an existing environment in place with `pip`, run `pip uninstall kafka-python-ng` before upgrading parsedmarc so the two distributions don't conflict with each other's files.
- parsedmarc is compatible with both `kafka-python` 2.3.2+ and 3.x: `kafka-python` 3.0 removed the `NoBrokersAvailable` exception (a failed bootstrap now raises `KafkaTimeoutError`), and parsedmarc handles whichever the installed version provides.
- The whole codebase (library and tests) now passes `pyright` with zero errors and warnings, and CI enforces this (plus `ruff format --check`) on every push and pull request. Pyright is configured in `pyproject.toml` (`[tool.pyright]`) and pinned in the `[build]` extra. The fixes are annotation-level only (`Optional` parameters, TypedDict-aware signatures on the syslog/GELF save methods, `TYPE_CHECKING`-aware optional imports for `psycopg` and the `kafka-python` 2.x/3.x bootstrap-error fallback) — runtime behavior is unchanged apart from the `SyslogClient` fix below. Builds on the class-body alias declarations from [#797](https://github.com/domainaware/parsedmarc/pull/797).
### Bug fixes
- `SyslogClient`: constructing a TCP/TLS client with `retry_attempts` < 1 now raises `ValueError` instead of silently skipping the connection loop and registering a broken (`None`) log handler.
## 10.0.4
+2 -2
View File
@@ -125,7 +125,7 @@ class _AggregateReportDoc(Document):
domain: str,
selector: str,
result: _DKIMResult,
human_result: str = None,
human_result: Optional[str] = None,
):
self.dkim_results.append(
_DKIMResult(
@@ -141,7 +141,7 @@ class _AggregateReportDoc(Document):
domain: str,
scope: str,
result: _SPFResult,
human_result: str = None,
human_result: Optional[str] = None,
):
self.spf_results.append(
_SPFResult(
+10 -6
View File
@@ -12,9 +12,7 @@ from parsedmarc import (
parsed_failure_reports_to_csv_rows,
parsed_smtp_tls_reports_to_csv_rows,
)
from typing import Any
from parsedmarc.types import AggregateReport, SMTPTLSReport
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
log_context_data = threading.local()
@@ -51,7 +49,9 @@ class GelfClient(object):
)
self.logger.addHandler(self.handler)
def save_aggregate_report_to_gelf(self, aggregate_reports: list[AggregateReport]):
def save_aggregate_report_to_gelf(
self, aggregate_reports: AggregateReport | list[AggregateReport]
):
rows = parsed_aggregate_reports_to_csv_rows(aggregate_reports)
for row in rows:
log_context_data.parsedmarc = row
@@ -59,13 +59,17 @@ class GelfClient(object):
log_context_data.parsedmarc = None
def save_failure_report_to_gelf(self, failure_reports: list[dict[str, Any]]):
def save_failure_report_to_gelf(
self, failure_reports: FailureReport | list[FailureReport]
):
rows = parsed_failure_reports_to_csv_rows(failure_reports)
for row in rows:
log_context_data.parsedmarc = row
self.logger.info("parsedmarc failure report")
def save_smtp_tls_report_to_gelf(self, smtp_tls_reports: SMTPTLSReport):
def save_smtp_tls_report_to_gelf(
self, smtp_tls_reports: SMTPTLSReport | list[SMTPTLSReport]
):
rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports)
for row in rows:
log_context_data.parsedmarc = row
+13 -10
View File
@@ -4,18 +4,21 @@ from __future__ import annotations
import json
from ssl import SSLContext, create_default_context
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from kafka import KafkaProducer
from kafka.errors import UnknownTopicOrPartitionError
try:
# kafka-python < 3.0 raises this when the producer cannot bootstrap
from kafka.errors import NoBrokersAvailable as _BootstrapError # pyright: ignore[reportAttributeAccessIssue]
except ImportError:
# kafka-python >= 3.0 removed NoBrokersAvailable; a failed bootstrap
# raises KafkaTimeoutError instead
if TYPE_CHECKING:
from kafka.errors import KafkaTimeoutError as _BootstrapError
else:
try:
# kafka-python < 3.0 raises this when the producer cannot bootstrap
from kafka.errors import NoBrokersAvailable as _BootstrapError
except ImportError:
# kafka-python >= 3.0 removed NoBrokersAvailable; a failed bootstrap
# raises KafkaTimeoutError instead
from kafka.errors import KafkaTimeoutError as _BootstrapError
from parsedmarc import __version__
from parsedmarc.log import logger
@@ -185,6 +188,9 @@ class KafkaClient(object):
except Exception as e:
raise KafkaError("Kafka error: {0}".format(e.__str__()))
# Backward-compatible alias
save_forensic_reports_to_kafka = save_failure_reports_to_kafka
def save_smtp_tls_reports_to_kafka(
self,
smtp_tls_reports: Union[list[dict[str, Any]], dict[str, Any]],
@@ -218,6 +224,3 @@ class KafkaClient(object):
self.producer.flush()
except Exception as e:
raise KafkaError("Kafka error: {0}".format(e.__str__()))
# Backward-compatible alias
save_forensic_reports_to_kafka = save_failure_reports_to_kafka
+2 -2
View File
@@ -116,7 +116,7 @@ class _AggregateReportDoc(Document):
domain: str,
selector: str,
result: _DKIMResult,
human_result: str = None,
human_result: Optional[str] = None,
):
self.dkim_results.append(
_DKIMResult(
@@ -132,7 +132,7 @@ class _AggregateReportDoc(Document):
domain: str,
scope: str,
result: _SPFResult,
human_result: str = None,
human_result: Optional[str] = None,
):
self.spf_results.append(
_SPFResult(
+46 -36
View File
@@ -3,14 +3,21 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
# LiteralString requires Python >= 3.11, so only import it for type checking
from typing import LiteralString
try:
import psycopg
from psycopg import types as psycopg_types
except ImportError:
psycopg = None # type: ignore[assignment]
psycopg_types = None # type: ignore[assignment]
from psycopg.types import json as psycopg_json
else:
try:
import psycopg
from psycopg.types import json as psycopg_json
except ImportError:
psycopg = None
psycopg_json = None
from parsedmarc.log import logger
from parsedmarc.utils import human_timestamp_to_datetime
@@ -156,7 +163,7 @@ class PostgreSQLClient:
self._conn: Optional[psycopg.Connection] = None
self._connect()
def _connect(self) -> None:
def _connect(self) -> psycopg.Connection:
"""Open a new database connection using stored parameters.
Raises:
@@ -165,18 +172,20 @@ class PostgreSQLClient:
logger.debug("Connecting to PostgreSQL")
try:
if self._connection_string:
self._conn = psycopg.connect(self._connection_string)
conn = psycopg.connect(self._connection_string)
else:
self._conn = psycopg.connect(
conn = psycopg.connect(
host=self._host,
port=self._port,
user=self._user,
password=self._password,
dbname=self._database,
)
self._conn.autocommit = False
conn.autocommit = False
except psycopg.Error as exc:
raise PostgreSQLError(str(exc)) from exc
self._conn = conn
return conn
def close(self) -> None:
"""Close the database connection if it is open.
@@ -187,7 +196,7 @@ class PostgreSQLClient:
if self._conn is not None and not self._conn.closed:
self._conn.close()
def _ensure_connected(self) -> None:
def _ensure_connected(self) -> psycopg.Connection:
"""Check the connection health and reconnect if necessary.
When *parsedmarc* runs in watch mode the process can stay alive
@@ -197,9 +206,11 @@ class PostgreSQLClient:
and transparently re-establishes it so that subsequent
``save_*`` calls succeed without manual intervention.
"""
if self._conn is None or self._conn.closed:
conn = self._conn
if conn is None or conn.closed:
logger.warning("PostgreSQL connection lost — attempting to reconnect")
self._connect()
conn = self._connect()
return conn
def create_tables(self) -> None:
"""Creates all required tables if they do not already exist.
@@ -209,8 +220,8 @@ class PostgreSQLClient:
Raises:
PostgreSQLError: If table creation fails.
"""
self._ensure_connected()
ddl_statements = [
conn = self._ensure_connected()
ddl_statements: list[LiteralString] = [
# ----------------------------------------------------------------
# Aggregate reports
# ----------------------------------------------------------------
@@ -420,8 +431,8 @@ class PostgreSQLClient:
]
try:
with self._conn.transaction():
with self._conn.cursor() as cur:
with conn.transaction():
with conn.cursor() as cur:
for stmt in ddl_statements:
cur.execute(stmt)
logger.debug("PostgreSQL tables verified / created")
@@ -439,13 +450,13 @@ class PostgreSQLClient:
AlreadySaved: If an identical report is already present.
PostgreSQLError: If a database error occurs.
"""
self._ensure_connected()
conn = self._ensure_connected()
meta = report.get("report_metadata", {})
pub = report.get("policy_published", {})
try:
with self._conn.transaction():
with self._conn.cursor() as cur:
with conn.transaction():
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO dmarc_aggregate_report (
@@ -549,7 +560,8 @@ class PostgreSQLClient:
idens.get("envelope_to"),
),
)
record_db_id: int = cur.fetchone()[0]
# INSERT ... RETURNING always yields one row
record_db_id: int = cur.fetchone()[0] # pyright: ignore[reportOptionalSubscript]
for dkim in record.get("auth_results", {}).get("dkim", []):
cur.execute(
@@ -615,25 +627,21 @@ class PostgreSQLClient:
AlreadySaved: If a matching failure report is already present.
PostgreSQLError: If a database error occurs.
"""
self._ensure_connected()
conn = self._ensure_connected()
sample = report.get("parsed_sample", {}) or {}
src = report.get("source", {}) or {}
arrival_date_utc = _ensure_utc_suffix(report.get("arrival_date_utc"))
sample_subject = sample.get("subject")
# JSONB values are reused by both the dedup check and the INSERT.
sample_headers = (
psycopg_types.json.Jsonb(sample["headers"])
if sample.get("headers")
else None
psycopg_json.Jsonb(sample["headers"]) if sample.get("headers") else None
)
sample_from = (
psycopg_types.json.Jsonb(sample["from"]) if sample.get("from") else None
)
sample_to = psycopg_types.json.Jsonb(sample["to"]) if sample.get("to") else None
sample_from = psycopg_json.Jsonb(sample["from"]) if sample.get("from") else None
sample_to = psycopg_json.Jsonb(sample["to"]) if sample.get("to") else None
try:
with self._conn.transaction():
with self._conn.cursor() as cur:
with conn.transaction():
with conn.cursor() as cur:
# Failure reports have no natural primary key, so mirror the
# Elasticsearch backend's query-then-insert dedup on the same
# dimensions it uses: arrival date + From + To + Subject.
@@ -721,7 +729,8 @@ class PostgreSQLClient:
sample_to,
),
)
report_db_id: int = cur.fetchone()[0]
# INSERT ... RETURNING always yields one row
report_db_id: int = cur.fetchone()[0] # pyright: ignore[reportOptionalSubscript]
for addr_type in ("to", "cc", "bcc", "reply_to"):
entries = sample.get(addr_type) or []
@@ -759,10 +768,10 @@ class PostgreSQLClient:
AlreadySaved: If an identical report is already present.
PostgreSQLError: If a database error occurs.
"""
self._ensure_connected()
conn = self._ensure_connected()
try:
with self._conn.transaction():
with self._conn.cursor() as cur:
with conn.transaction():
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO smtp_tls_report (
@@ -813,7 +822,8 @@ class PostgreSQLClient:
policy.get("failed_session_count"),
),
)
policy_db_id: int = cur.fetchone()[0]
# INSERT ... RETURNING always yields one row
policy_db_id: int = cur.fetchone()[0] # pyright: ignore[reportOptionalSubscript]
for detail in policy.get("failure_details", []):
cur.execute(
+4 -2
View File
@@ -6,7 +6,7 @@ import codecs
import os
import sys
import shutil
from typing import List, Tuple
from typing import List, Optional, Tuple
"""
Locates and optionally corrects bad UTF-8 bytes in a file.
@@ -127,7 +127,9 @@ def detect_encoding_text(path: str) -> Tuple[str, str]:
return match.encoding, str(match)
def convert_to_utf8(src_path: str, out_path: str, src_encoding: str = None) -> str:
def convert_to_utf8(
src_path: str, out_path: str, src_encoding: Optional[str] = None
) -> str:
"""
Convert an entire file to UTF-8 (re-decoding everything).
If src_encoding is provided, use it; else auto-detect.
+2 -3
View File
@@ -107,7 +107,7 @@ def sort_csv(
case_insensitive_sort: bool = False,
required_fields: Optional[Iterable[str]] = None,
allowed_values: Optional[Mapping[str, Collection[str]]] = None,
) -> List[Dict[str, str]]:
) -> None:
"""
Read a CSV, optionally normalize rows (strip whitespace, lowercase certain fields),
validate field values, and write the sorted CSV back to the same path.
@@ -124,8 +124,7 @@ def sort_csv(
required_fields = set(required_fields or [])
lower_set = set(fields_to_lowercase or [])
allowed_sets = {k: set(v) for k, v in (allowed_values or {}).items()}
if sort_field_value_must_be_unique:
seen_sort_field_values = []
seen_sort_field_values: list[str] = []
with path.open("r", newline="") as infile:
reader = csv.DictReader(infile)
+6 -4
View File
@@ -63,10 +63,12 @@ class HECClient(object):
host=self.host, source=self.source, index=self.index
)
self.session.headers = {
"User-Agent": USER_AGENT,
"Authorization": "Splunk {0}".format(self.access_token),
}
self.session.headers.update(
{
"User-Agent": USER_AGENT,
"Authorization": "Splunk {0}".format(self.access_token),
}
)
def save_aggregate_reports_to_splunk(
self,
+23 -10
View File
@@ -9,13 +9,14 @@ import logging.handlers
import socket
import ssl
import time
from typing import Any, Optional
from typing import Optional
from parsedmarc import (
parsed_aggregate_reports_to_csv_rows,
parsed_failure_reports_to_csv_rows,
parsed_smtp_tls_reports_to_csv_rows,
)
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
class SyslogClient(object):
@@ -103,8 +104,9 @@ class SyslogClient(object):
socktype=socket.SOCK_STREAM,
)
# Set timeout on the socket
if hasattr(handler, "socket") and handler.socket:
handler.socket.settimeout(timeout)
sock = getattr(handler, "socket", None)
if sock is not None:
sock.settimeout(timeout)
return handler
else:
# TLS protocol
@@ -139,12 +141,14 @@ class SyslogClient(object):
)
# Wrap socket with TLS
if hasattr(handler, "socket") and handler.socket:
handler.socket = ssl_context.wrap_socket(
handler.socket,
sock = getattr(handler, "socket", None)
if sock is not None:
tls_sock = ssl_context.wrap_socket(
sock,
server_hostname=server_name,
)
handler.socket.settimeout(timeout)
tls_sock.settimeout(timeout)
setattr(handler, "socket", tls_sock)
return handler
@@ -160,22 +164,31 @@ class SyslogClient(object):
f"Syslog connection failed after {retry_attempts} attempts: {e}"
)
raise
# Only reachable when retry_attempts < 1, which would otherwise
# silently return None and break the caller's addHandler() call.
raise ValueError("retry_attempts must be at least 1")
else:
raise ValueError(
f"Invalid protocol '{protocol}'. Must be 'udp', 'tcp', or 'tls'."
)
def save_aggregate_report_to_syslog(self, aggregate_reports: list[dict[str, Any]]):
def save_aggregate_report_to_syslog(
self, aggregate_reports: AggregateReport | list[AggregateReport]
):
rows = parsed_aggregate_reports_to_csv_rows(aggregate_reports)
for row in rows:
self.logger.info(json.dumps(row))
def save_failure_report_to_syslog(self, failure_reports: list[dict[str, Any]]):
def save_failure_report_to_syslog(
self, failure_reports: FailureReport | list[FailureReport]
):
rows = parsed_failure_reports_to_csv_rows(failure_reports)
for row in rows:
self.logger.info(json.dumps(row))
def save_smtp_tls_report_to_syslog(self, smtp_tls_reports: list[dict[str, Any]]):
def save_smtp_tls_report_to_syslog(
self, smtp_tls_reports: SMTPTLSReport | list[SMTPTLSReport]
):
rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports)
for row in rows:
self.logger.info(json.dumps(row))
+6 -4
View File
@@ -33,10 +33,12 @@ class WebhookClient(object):
self.smtp_tls_url = smtp_tls_url
self.timeout = timeout
self.session = requests.Session()
self.session.headers = {
"User-Agent": USER_AGENT,
"Content-Type": "application/json",
}
self.session.headers.update(
{
"User-Agent": USER_AGENT,
"Content-Type": "application/json",
}
)
def save_failure_report_to_webhook(self, report: str):
self._send_to_webhook(self.failure_url, report)
+13
View File
@@ -70,6 +70,10 @@ build = [
"hatch>=1.14.0",
"myst-parser[linkify]",
"nose",
# Pinned exactly: pyright's checks evolve between releases, so an
# unpinned version could break CI without any code change. Bump
# deliberately (and fix any new findings) rather than implicitly.
"pyright==1.1.410",
"pytest",
"pytest-cov",
"ruff",
@@ -103,6 +107,15 @@ exclude = [
"parsedmarc/resources/maps/[!_]*.py",
]
[tool.pyright]
# The whole codebase passes pyright with zero errors and warnings; CI
# enforces this (see .github/workflows/python-tests.yml). Run locally with
# `pyright` from the repo root. Requires the [postgresql] extra to be
# installed so the optional psycopg import in parsedmarc/postgres.py
# resolves.
include = ["parsedmarc", "tests", "docs"]
typeCheckingMode = "standard"
[tool.pytest.ini_options]
# Default to the per-module test layout under tests/. New tests should go
# into tests/test_<module>.py to match the file they exercise; do not
+2 -1
View File
@@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
import parsedmarc
import parsedmarc.cli
import parsedmarc.elastic
import parsedmarc.opensearch as opensearch_module
@@ -34,7 +35,7 @@ class _DummyMailboxConnection(parsedmarc.MailboxConnection):
self.fetch_calls.append({"reports_folder": reports_folder, **kwargs})
return []
def fetch_message(self, message_id) -> str:
def fetch_message(self, message_id, **kwargs) -> str:
return ""
def delete_message(self, message_id):
+12 -7
View File
@@ -2,15 +2,17 @@
import logging
import unittest
from typing import Any, cast
from unittest.mock import MagicMock, patch
from parsedmarc.gelf import ContextFilter, GelfClient, log_context_data
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
def _sample_aggregate_report():
def _sample_aggregate_report() -> AggregateReport:
"""Minimal aggregate report shape acceptable to
parsed_aggregate_reports_to_csv_rows."""
return {
report = {
"xml_schema": "draft",
"xml_namespace": None,
"report_metadata": {
@@ -87,6 +89,7 @@ def _sample_aggregate_report():
}
],
}
return cast(AggregateReport, report)
class _Handler(logging.Handler):
@@ -95,7 +98,7 @@ class _Handler(logging.Handler):
def __init__(self):
super().__init__()
self.records: list[tuple[str, dict]] = []
self.records: list[tuple[str, Any]] = []
def emit(self, record):
# ContextFilter has run by this point so `record.parsedmarc` is
@@ -212,8 +215,8 @@ class TestGelfClientSaveFailure(unittest.TestCase):
reports. Build one through the CSV-row helper to verify GelfClient
surfaces the right fields."""
def _sample_failure_report(self):
return {
def _sample_failure_report(self) -> FailureReport:
report = {
"feedback_type": "auth-failure",
"user_agent": "test/1.0",
"version": "1",
@@ -243,6 +246,7 @@ class TestGelfClientSaveFailure(unittest.TestCase):
"sample": "...",
"parsed_sample": {"subject": "Test"},
}
return cast(FailureReport, report)
def test_emits_one_record_per_failure_report(self):
client = _gelf_client()
@@ -256,8 +260,8 @@ class TestGelfClientSaveFailure(unittest.TestCase):
class TestGelfClientSaveSmtpTls(unittest.TestCase):
def _sample_smtp_tls(self):
return {
def _sample_smtp_tls(self) -> SMTPTLSReport:
report = {
"organization_name": "example.com",
"begin_date": "2024-02-03T00:00:00Z",
"end_date": "2024-02-04T00:00:00Z",
@@ -272,6 +276,7 @@ class TestGelfClientSaveSmtpTls(unittest.TestCase):
}
],
}
return cast(SMTPTLSReport, report)
def test_emits_one_record_per_policy(self):
client = _gelf_client()
+15 -9
View File
@@ -568,7 +568,9 @@ class Test(unittest.TestCase):
always_use_local_files=True,
offline=True,
)
csv_text = parsedmarc.parsed_aggregate_reports_to_csv(result["report"])
csv_text = parsedmarc.parsed_aggregate_reports_to_csv(
cast(AggregateReport, result["report"])
)
header = csv_text.splitlines()[0].split(",")
self.assertIn("source_asn", header)
self.assertIn("source_as_name", header)
@@ -2330,7 +2332,10 @@ class TestGetDmarcReportsFromMailboxValidation(unittest.TestCase):
def test_none_connection_raises(self):
with self.assertRaises(ValueError) as ctx:
parsedmarc.get_dmarc_reports_from_mailbox(connection=None)
parsedmarc.get_dmarc_reports_from_mailbox(
# Deliberately invalid: exercises the runtime None check
connection=None # pyright: ignore[reportArgumentType]
)
self.assertIn("connection", str(ctx.exception).lower())
@@ -2535,7 +2540,8 @@ class TestEmailResultsErrorBranches(unittest.TestCase):
},
host="smtp.example.com",
mail_from="from@example.com",
mail_to="admin@example.com", # str, not list — triggers assert
# str, not list — triggers assert
mail_to="admin@example.com", # pyright: ignore[reportArgumentType]
)
@@ -2548,7 +2554,7 @@ class TestAppendJson(unittest.TestCase):
path = tf.name
os.remove(path) # ensure file is fresh
try:
parsedmarc.append_json(path, [{"a": 1}])
parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}]))
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"a": 1}])
@@ -2560,8 +2566,8 @@ class TestAppendJson(unittest.TestCase):
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
path = tf.name
try:
parsedmarc.append_json(path, [{"a": 1}])
parsedmarc.append_json(path, [{"b": 2}])
parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}]))
parsedmarc.append_json(path, cast(list[AggregateReport], [{"b": 2}]))
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"a": 1}, {"b": 2}])
@@ -2573,7 +2579,7 @@ class TestAppendJson(unittest.TestCase):
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
path = tf.name
try:
parsedmarc.append_json(path, [{"a": 1}])
parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}]))
parsedmarc.append_json(path, [])
with open(path) as f:
data = json.loads(f.read())
@@ -2595,7 +2601,7 @@ class TestAppendJson(unittest.TestCase):
tf.write("{ this is not valid json at all")
path = tf.name
try:
parsedmarc.append_json(path, [{"new": "data"}])
parsedmarc.append_json(path, cast(list[AggregateReport], [{"new": "data"}]))
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"new": "data"}])
@@ -2612,7 +2618,7 @@ class TestAppendJson(unittest.TestCase):
tf.write('{"not": "a list"}')
path = tf.name
try:
parsedmarc.append_json(path, [{"new": "data"}])
parsedmarc.append_json(path, cast(list[AggregateReport], [{"new": "data"}]))
with open(path) as f:
data = json.loads(f.read())
self.assertEqual(data, [{"new": "data"}])
+27 -21
View File
@@ -2,6 +2,7 @@
import json
import unittest
from typing import cast
from unittest.mock import MagicMock, patch
from kafka.errors import UnknownTopicOrPartitionError
@@ -9,6 +10,11 @@ from kafka.errors import UnknownTopicOrPartitionError
from parsedmarc.kafkaclient import KafkaClient, KafkaError, _BootstrapError
def _producer(client: KafkaClient) -> MagicMock:
"""The patched KafkaProducer as a MagicMock, for assertion access."""
return cast(MagicMock, client.producer)
def _aggregate_report():
return {
"report_metadata": {
@@ -121,15 +127,15 @@ class TestSaveAggregateReportsToKafka(unittest.TestCase):
client = self._client()
client.save_aggregate_reports_to_kafka(_aggregate_report(), "dmarc-aggregate")
# 2 records in the sample report → 2 producer.send calls.
self.assertEqual(client.producer.send.call_count, 2)
self.assertEqual(_producer(client).send.call_count, 2)
# Topic is forwarded verbatim.
for call in client.producer.send.call_args_list:
for call in _producer(client).send.call_args_list:
self.assertEqual(call.args[0], "dmarc-aggregate")
def test_each_slice_carries_metadata(self):
client = self._client()
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
sent = [call.args[1] for call in client.producer.send.call_args_list]
sent = [call.args[1] for call in _producer(client).send.call_args_list]
for slice_ in sent:
self.assertEqual(slice_["org_name"], "TestOrg")
self.assertEqual(slice_["org_email"], "test@example.com")
@@ -144,32 +150,32 @@ class TestSaveAggregateReportsToKafka(unittest.TestCase):
def test_empty_list_is_a_noop(self):
client = self._client()
client.save_aggregate_reports_to_kafka([], "topic")
client.producer.send.assert_not_called()
_producer(client).send.assert_not_called()
def test_dict_input_normalized_to_list(self):
"""Single-report dict input is wrapped to a list."""
client = self._client()
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
# 2 records still sent (one report with 2 records, not multiple reports).
self.assertEqual(client.producer.send.call_count, 2)
self.assertEqual(_producer(client).send.call_count, 2)
def test_unknown_topic_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = UnknownTopicOrPartitionError()
_producer(client).send.side_effect = UnknownTopicOrPartitionError()
with self.assertRaises(KafkaError) as ctx:
client.save_aggregate_reports_to_kafka(_aggregate_report(), "missing")
self.assertIn("Unknown topic or partition", str(ctx.exception))
def test_generic_send_exception_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = RuntimeError("transport failure")
_producer(client).send.side_effect = RuntimeError("transport failure")
with self.assertRaises(KafkaError) as ctx:
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
self.assertIn("transport failure", str(ctx.exception))
def test_flush_exception_translates_to_kafka_error(self):
client = self._client()
client.producer.flush.side_effect = RuntimeError("flush failure")
_producer(client).flush.side_effect = RuntimeError("flush failure")
with self.assertRaises(KafkaError) as ctx:
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
self.assertIn("flush failure", str(ctx.exception))
@@ -186,35 +192,35 @@ class TestSaveFailureReportsToKafka(unittest.TestCase):
client = self._client()
reports = [{"id": "f1"}, {"id": "f2"}]
client.save_failure_reports_to_kafka(reports, "dmarc-failure")
client.producer.send.assert_called_once_with("dmarc-failure", reports)
_producer(client).send.assert_called_once_with("dmarc-failure", reports)
def test_dict_input_normalized_to_list(self):
client = self._client()
client.save_failure_reports_to_kafka({"id": "single"}, "topic")
# The send payload is wrapped to a single-element list.
args = client.producer.send.call_args.args
args = _producer(client).send.call_args.args
self.assertEqual(args[1], [{"id": "single"}])
def test_empty_list_is_a_noop(self):
client = self._client()
client.save_failure_reports_to_kafka([], "topic")
client.producer.send.assert_not_called()
_producer(client).send.assert_not_called()
def test_unknown_topic_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = UnknownTopicOrPartitionError()
_producer(client).send.side_effect = UnknownTopicOrPartitionError()
with self.assertRaises(KafkaError):
client.save_failure_reports_to_kafka([{"a": 1}], "missing")
def test_generic_send_error_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = OSError("net")
_producer(client).send.side_effect = OSError("net")
with self.assertRaises(KafkaError):
client.save_failure_reports_to_kafka([{"a": 1}], "topic")
def test_flush_error_translates_to_kafka_error(self):
client = self._client()
client.producer.flush.side_effect = OSError("flush")
_producer(client).flush.side_effect = OSError("flush")
with self.assertRaises(KafkaError):
client.save_failure_reports_to_kafka([{"a": 1}], "topic")
@@ -228,34 +234,34 @@ class TestSaveSmtpTlsReportsToKafka(unittest.TestCase):
client = self._client()
reports = [{"organization_name": "x"}]
client.save_smtp_tls_reports_to_kafka(reports, "smtp-tls")
client.producer.send.assert_called_once_with("smtp-tls", reports)
_producer(client).send.assert_called_once_with("smtp-tls", reports)
def test_dict_input_normalized_to_list(self):
client = self._client()
client.save_smtp_tls_reports_to_kafka({"organization_name": "x"}, "topic")
args = client.producer.send.call_args.args
args = _producer(client).send.call_args.args
self.assertEqual(args[1], [{"organization_name": "x"}])
def test_empty_list_is_a_noop(self):
client = self._client()
client.save_smtp_tls_reports_to_kafka([], "topic")
client.producer.send.assert_not_called()
_producer(client).send.assert_not_called()
def test_unknown_topic_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = UnknownTopicOrPartitionError()
_producer(client).send.side_effect = UnknownTopicOrPartitionError()
with self.assertRaises(KafkaError):
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "missing")
def test_generic_send_error_translates_to_kafka_error(self):
client = self._client()
client.producer.send.side_effect = RuntimeError("oops")
_producer(client).send.side_effect = RuntimeError("oops")
with self.assertRaises(KafkaError):
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "topic")
def test_flush_error_translates_to_kafka_error(self):
client = self._client()
client.producer.flush.side_effect = RuntimeError("flush")
_producer(client).flush.side_effect = RuntimeError("flush")
with self.assertRaises(KafkaError):
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "topic")
@@ -265,7 +271,7 @@ class TestKafkaClientClose(unittest.TestCase):
with patch("parsedmarc.kafkaclient.KafkaProducer"):
client = KafkaClient(kafka_hosts=["b"])
client.close()
client.producer.close.assert_called_once()
_producer(client).close.assert_called_once()
class TestKafkaBackwardCompatAlias(unittest.TestCase):
+15 -8
View File
@@ -10,6 +10,7 @@ real-sample round trip, so the tests fail if the dict-key mapping regresses.
import os
import unittest
from glob import glob
from typing import cast
from unittest.mock import MagicMock, patch
import parsedmarc
@@ -27,7 +28,7 @@ OFFLINE_MODE = os.environ.get("GITHUB_ACTIONS", "false").lower() == "true"
# psycopg is an optional dependency and is not installed in CI (which installs
# only the [build] extra). The save methods mock the connection, but the
# failure path also references ``psycopg_types.json.Jsonb`` at module scope, so
# failure path also references ``psycopg_json.Jsonb`` at module scope, so
# mock that SDK boundary for the whole module when psycopg is absent.
_types_patcher = None
@@ -36,8 +37,8 @@ def setUpModule():
global _types_patcher
import parsedmarc.postgres as pg
if pg.psycopg_types is None:
_types_patcher = patch("parsedmarc.postgres.psycopg_types", MagicMock())
if pg.psycopg_json is None:
_types_patcher = patch("parsedmarc.postgres.psycopg_json", MagicMock())
_types_patcher.start()
@@ -99,6 +100,7 @@ class TestPostgreSQLHelpers(unittest.TestCase):
def test_naive_local_to_timestamptz_valid(self):
"""A valid naive string is returned with a timezone offset."""
result = _naive_local_to_timestamptz("2024-01-15 10:30:00")
assert result is not None
self.assertIsInstance(result, str)
self.assertTrue(
"+" in result or "-" in result[10:],
@@ -125,11 +127,13 @@ class TestPostgreSQLHelpers(unittest.TestCase):
def test_normalize_arrival_date_iso_naive_utc(self):
"""A naive ISO string (known UTC) is returned with +00 suffix."""
result = _normalize_arrival_date("2024-01-15 10:30:00")
assert result is not None
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
def test_normalize_arrival_date_rfc2822(self):
"""An RFC 2822 date is converted to UTC with +00 suffix."""
result = _normalize_arrival_date("Fri, 28 Oct 2022 00:34:24 +0800")
assert result is not None
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
# 00:34:24 +0800 is 16:34:24 UTC on 27 Oct 2022.
self.assertIn("2022-10-27", result)
@@ -138,6 +142,7 @@ class TestPostgreSQLHelpers(unittest.TestCase):
def test_normalize_arrival_date_already_utc(self):
"""A string already ending with +00 still works."""
result = _normalize_arrival_date("2024-01-15 10:30:00+00")
assert result is not None
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
def test_normalize_arrival_date_unparseable(self):
@@ -167,7 +172,8 @@ class TestPostgreSQLHelpers(unittest.TestCase):
def test_contact_info_to_text_numeric(self):
"""Non-string scalars are converted via str()."""
self.assertEqual(_contact_info_to_text(123), "123")
# Deliberately outside the annotated parameter types
self.assertEqual(_contact_info_to_text(123), "123") # pyright: ignore[reportArgumentType]
def _make_client():
@@ -180,8 +186,8 @@ def _make_client():
client = PostgreSQLClient(
host="localhost", database="test", user="test", password="test"
)
mock_conn.closed = False
client._conn = mock_conn
client._conn.closed = False
return client, mock_conn
@@ -211,6 +217,7 @@ def _named_params(call):
sql = call.args[0]
m = re.search(r"\(([^)]*?)\)\s*VALUES", sql, re.S)
assert m is not None
cols = [c.strip() for c in m.group(1).split(",") if c.strip()]
return dict(zip(cols, call.args[1]))
@@ -758,7 +765,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase):
num_records = len(report.get("records", []))
_mock_cursor(mock_conn, [(rid,) for rid in range(1, 2 + num_records)])
try:
client.save_aggregate_report_to_postgresql(report)
client.save_aggregate_report_to_postgresql(cast(dict, report))
saved += 1
except Exception as exc:
self.fail(f"aggregate save failed for {sample_path}: {exc}")
@@ -783,7 +790,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase):
# Dedup SELECT returns None (not a dup), then the INSERT id.
_mock_cursor(mock_conn, [None, (1,)])
try:
client.save_failure_report_to_postgresql(report)
client.save_failure_report_to_postgresql(cast(dict, report))
saved += 1
except Exception as exc:
self.fail(f"failure save failed for {sample_path}: {exc}")
@@ -807,7 +814,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase):
num_policies = len(report.get("policies", []))
_mock_cursor(mock_conn, [(rid,) for rid in range(1, 2 + num_policies)])
try:
client.save_smtp_tls_report_to_postgresql(report)
client.save_smtp_tls_report_to_postgresql(cast(dict, report))
saved += 1
except Exception as exc:
self.fail(f"smtp_tls save failed for {sample_path}: {exc}")
+20 -4
View File
@@ -6,11 +6,14 @@ import socket
import unittest
from unittest.mock import MagicMock, patch
from typing import cast
from parsedmarc.syslog import SyslogClient
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
def _sample_aggregate_report():
return {
def _sample_aggregate_report() -> AggregateReport:
report = {
"xml_schema": "draft",
"xml_namespace": None,
"report_metadata": {
@@ -70,6 +73,7 @@ def _sample_aggregate_report():
}
],
}
return cast(AggregateReport, report)
class _CapturingHandler(logging.Handler):
@@ -259,6 +263,18 @@ class TestSyslogClientInitInvalidProtocol(unittest.TestCase):
self.assertIn("udb", str(ctx.exception))
self.assertIn("'udp', 'tcp', or 'tls'", str(ctx.exception))
def test_zero_retry_attempts_raises_value_error(self):
"""retry_attempts < 1 means the TCP/TLS connect loop never runs.
Before the fix, _create_syslog_handler fell through and returned
None, which was then passed to logger.addHandler(); now it raises
ValueError instead of silently configuring a broken client."""
_fresh_logger()
with patch("parsedmarc.syslog.logging.handlers.SysLogHandler") as handler_cls:
with self.assertRaises(ValueError) as ctx:
SyslogClient("s", 514, protocol="tcp", retry_attempts=0)
handler_cls.assert_not_called()
self.assertIn("retry_attempts", str(ctx.exception))
class TestSyslogClientSave(unittest.TestCase):
"""save_* methods emit one syslog message per CSV row, each as a
@@ -314,7 +330,7 @@ class TestSyslogClientSave(unittest.TestCase):
"sample": "...",
"parsed_sample": {"subject": "Test"},
}
client.save_failure_report_to_syslog([failure_report])
client.save_failure_report_to_syslog([cast(FailureReport, failure_report)])
self.assertEqual(len(cap.messages), 1)
payload = json.loads(cap.messages[0])
self.assertEqual(payload["reported_domain"], "example.com")
@@ -337,7 +353,7 @@ class TestSyslogClientSave(unittest.TestCase):
}
],
}
client.save_smtp_tls_report_to_syslog([report])
client.save_smtp_tls_report_to_syslog([cast(SMTPTLSReport, report)])
self.assertEqual(len(cap.messages), 1)
payload = json.loads(cap.messages[0])
self.assertEqual(payload["policy_domain"], "example.com")