diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index a617098..cda25e2 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -22,10 +22,16 @@ jobs: - name: Install Python dependencies run: | python -m pip install --upgrade pip - pip install .[build] + # postgresql extra included so pyright can resolve the optional + # psycopg import in parsedmarc/postgres.py + pip install .[build,postgresql] - name: Check code style run: | ruff check . + ruff format --check . + - name: Check types + run: | + pyright - name: Test building documentation run: | cd docs diff --git a/AGENTS.md b/AGENTS.md index 24b1791..b49c27d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,6 +25,11 @@ pytest tests/test_init.py::Test::testAggregateSamples ruff check . ruff format . +# Type check (config in pyproject.toml [tool.pyright]; CI enforces zero +# errors/warnings; needs the [postgresql] extra installed so the optional +# psycopg import resolves) +pyright + # Test CLI with sample reports parsedmarc --debug -c ci.ini samples/aggregate/* parsedmarc --debug -c ci.ini samples/failure/* @@ -108,6 +113,7 @@ IP address info cached for 4 hours, seen aggregate report IDs cached for 1 hour ## Code Style - Ruff for formatting and linting (configured in `.vscode/settings.json`). Run `ruff check .` and `ruff format --check .` after every code edit, before committing. +- Pyright for type checking (configured in `pyproject.toml` `[tool.pyright]`, pinned in the `[build]` extra, enforced in CI). Run `pyright` from the repo root before committing; the whole codebase — library and tests — must stay at zero errors and warnings. Prefer real fixes (narrowing, `Optional` annotations, `TYPE_CHECKING` imports) over `# pyright: ignore[...]`; reserve targeted ignores for deliberate wrong-type tests and version-conditional imports, and never use a bare blanket ignore. - TypedDict for structured data, type hints throughout. - Python ≥3.10 required. - Tests live under `tests/` as `tests/test_.py`, one per top-level `parsedmarc/*` module (e.g. `tests/test_init.py` for `parsedmarc/__init__.py`, `tests/test_cli.py` for `parsedmarc/cli.py`). All test classes use `unittest`. Sample reports live in `samples/`. Run with `pytest tests/`; run one file with `pytest tests/test_init.py`. New tests go in the file whose module they exercise — do not reintroduce a monolithic test file. diff --git a/CHANGELOG.md b/CHANGELOG.md index 1af4f5f..e6eff99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ - **Example systemd unit** in `docs/source/usage.md` now sets `KillSignal=SIGTERM` and `TimeoutStopSec=60` so systemd waits long enough for the watcher to drain (keep it above `mailbox_check_timeout`). - Switch the Kafka client dependency from `kafka-python-ng` back to `kafka-python>=2.3.2` ([#795](https://github.com/domainaware/parsedmarc/issues/795)). `kafka-python-ng` was a fork created while `kafka-python` was unmaintained; upstream `kafka-python` is active again, and the now-archived fork is vulnerable to [CVE-2026-10142](https://nvd.nist.gov/vuln/detail/CVE-2026-10142) and [CVE-2026-10143](https://nvd.nist.gov/vuln/detail/CVE-2026-10143), both fixed in `kafka-python` 2.3.2. Both packages install the same `kafka` module, so if you are upgrading an existing environment in place with `pip`, run `pip uninstall kafka-python-ng` before upgrading parsedmarc so the two distributions don't conflict with each other's files. - parsedmarc is compatible with both `kafka-python` 2.3.2+ and 3.x: `kafka-python` 3.0 removed the `NoBrokersAvailable` exception (a failed bootstrap now raises `KafkaTimeoutError`), and parsedmarc handles whichever the installed version provides. +- The whole codebase (library and tests) now passes `pyright` with zero errors and warnings, and CI enforces this (plus `ruff format --check`) on every push and pull request. Pyright is configured in `pyproject.toml` (`[tool.pyright]`) and pinned in the `[build]` extra. The fixes are annotation-level only (`Optional` parameters, TypedDict-aware signatures on the syslog/GELF save methods, `TYPE_CHECKING`-aware optional imports for `psycopg` and the `kafka-python` 2.x/3.x bootstrap-error fallback) — runtime behavior is unchanged apart from the `SyslogClient` fix below. Builds on the class-body alias declarations from [#797](https://github.com/domainaware/parsedmarc/pull/797). + +### Bug fixes + +- `SyslogClient`: constructing a TCP/TLS client with `retry_attempts` < 1 now raises `ValueError` instead of silently skipping the connection loop and registering a broken (`None`) log handler. ## 10.0.4 diff --git a/parsedmarc/elastic.py b/parsedmarc/elastic.py index 185f7b5..1433062 100644 --- a/parsedmarc/elastic.py +++ b/parsedmarc/elastic.py @@ -125,7 +125,7 @@ class _AggregateReportDoc(Document): domain: str, selector: str, result: _DKIMResult, - human_result: str = None, + human_result: Optional[str] = None, ): self.dkim_results.append( _DKIMResult( @@ -141,7 +141,7 @@ class _AggregateReportDoc(Document): domain: str, scope: str, result: _SPFResult, - human_result: str = None, + human_result: Optional[str] = None, ): self.spf_results.append( _SPFResult( diff --git a/parsedmarc/gelf.py b/parsedmarc/gelf.py index 374a8bc..2b1f6b1 100644 --- a/parsedmarc/gelf.py +++ b/parsedmarc/gelf.py @@ -12,9 +12,7 @@ from parsedmarc import ( parsed_failure_reports_to_csv_rows, parsed_smtp_tls_reports_to_csv_rows, ) -from typing import Any - -from parsedmarc.types import AggregateReport, SMTPTLSReport +from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport log_context_data = threading.local() @@ -51,7 +49,9 @@ class GelfClient(object): ) self.logger.addHandler(self.handler) - def save_aggregate_report_to_gelf(self, aggregate_reports: list[AggregateReport]): + def save_aggregate_report_to_gelf( + self, aggregate_reports: AggregateReport | list[AggregateReport] + ): rows = parsed_aggregate_reports_to_csv_rows(aggregate_reports) for row in rows: log_context_data.parsedmarc = row @@ -59,13 +59,17 @@ class GelfClient(object): log_context_data.parsedmarc = None - def save_failure_report_to_gelf(self, failure_reports: list[dict[str, Any]]): + def save_failure_report_to_gelf( + self, failure_reports: FailureReport | list[FailureReport] + ): rows = parsed_failure_reports_to_csv_rows(failure_reports) for row in rows: log_context_data.parsedmarc = row self.logger.info("parsedmarc failure report") - def save_smtp_tls_report_to_gelf(self, smtp_tls_reports: SMTPTLSReport): + def save_smtp_tls_report_to_gelf( + self, smtp_tls_reports: SMTPTLSReport | list[SMTPTLSReport] + ): rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports) for row in rows: log_context_data.parsedmarc = row diff --git a/parsedmarc/kafkaclient.py b/parsedmarc/kafkaclient.py index 2b8cf04..f22b9e6 100644 --- a/parsedmarc/kafkaclient.py +++ b/parsedmarc/kafkaclient.py @@ -4,18 +4,21 @@ from __future__ import annotations import json from ssl import SSLContext, create_default_context -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from kafka import KafkaProducer from kafka.errors import UnknownTopicOrPartitionError -try: - # kafka-python < 3.0 raises this when the producer cannot bootstrap - from kafka.errors import NoBrokersAvailable as _BootstrapError # pyright: ignore[reportAttributeAccessIssue] -except ImportError: - # kafka-python >= 3.0 removed NoBrokersAvailable; a failed bootstrap - # raises KafkaTimeoutError instead +if TYPE_CHECKING: from kafka.errors import KafkaTimeoutError as _BootstrapError +else: + try: + # kafka-python < 3.0 raises this when the producer cannot bootstrap + from kafka.errors import NoBrokersAvailable as _BootstrapError + except ImportError: + # kafka-python >= 3.0 removed NoBrokersAvailable; a failed bootstrap + # raises KafkaTimeoutError instead + from kafka.errors import KafkaTimeoutError as _BootstrapError from parsedmarc import __version__ from parsedmarc.log import logger @@ -185,6 +188,9 @@ class KafkaClient(object): except Exception as e: raise KafkaError("Kafka error: {0}".format(e.__str__())) + # Backward-compatible alias + save_forensic_reports_to_kafka = save_failure_reports_to_kafka + def save_smtp_tls_reports_to_kafka( self, smtp_tls_reports: Union[list[dict[str, Any]], dict[str, Any]], @@ -218,6 +224,3 @@ class KafkaClient(object): self.producer.flush() except Exception as e: raise KafkaError("Kafka error: {0}".format(e.__str__())) - - # Backward-compatible alias - save_forensic_reports_to_kafka = save_failure_reports_to_kafka diff --git a/parsedmarc/opensearch.py b/parsedmarc/opensearch.py index 5f4d958..898707c 100644 --- a/parsedmarc/opensearch.py +++ b/parsedmarc/opensearch.py @@ -116,7 +116,7 @@ class _AggregateReportDoc(Document): domain: str, selector: str, result: _DKIMResult, - human_result: str = None, + human_result: Optional[str] = None, ): self.dkim_results.append( _DKIMResult( @@ -132,7 +132,7 @@ class _AggregateReportDoc(Document): domain: str, scope: str, result: _SPFResult, - human_result: str = None, + human_result: Optional[str] = None, ): self.spf_results.append( _SPFResult( diff --git a/parsedmarc/postgres.py b/parsedmarc/postgres.py index 32ce3bd..acf3acb 100644 --- a/parsedmarc/postgres.py +++ b/parsedmarc/postgres.py @@ -3,14 +3,21 @@ from __future__ import annotations from datetime import datetime -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + # LiteralString requires Python >= 3.11, so only import it for type checking + from typing import LiteralString -try: import psycopg - from psycopg import types as psycopg_types -except ImportError: - psycopg = None # type: ignore[assignment] - psycopg_types = None # type: ignore[assignment] + from psycopg.types import json as psycopg_json +else: + try: + import psycopg + from psycopg.types import json as psycopg_json + except ImportError: + psycopg = None + psycopg_json = None from parsedmarc.log import logger from parsedmarc.utils import human_timestamp_to_datetime @@ -156,7 +163,7 @@ class PostgreSQLClient: self._conn: Optional[psycopg.Connection] = None self._connect() - def _connect(self) -> None: + def _connect(self) -> psycopg.Connection: """Open a new database connection using stored parameters. Raises: @@ -165,18 +172,20 @@ class PostgreSQLClient: logger.debug("Connecting to PostgreSQL") try: if self._connection_string: - self._conn = psycopg.connect(self._connection_string) + conn = psycopg.connect(self._connection_string) else: - self._conn = psycopg.connect( + conn = psycopg.connect( host=self._host, port=self._port, user=self._user, password=self._password, dbname=self._database, ) - self._conn.autocommit = False + conn.autocommit = False except psycopg.Error as exc: raise PostgreSQLError(str(exc)) from exc + self._conn = conn + return conn def close(self) -> None: """Close the database connection if it is open. @@ -187,7 +196,7 @@ class PostgreSQLClient: if self._conn is not None and not self._conn.closed: self._conn.close() - def _ensure_connected(self) -> None: + def _ensure_connected(self) -> psycopg.Connection: """Check the connection health and reconnect if necessary. When *parsedmarc* runs in watch mode the process can stay alive @@ -197,9 +206,11 @@ class PostgreSQLClient: and transparently re-establishes it so that subsequent ``save_*`` calls succeed without manual intervention. """ - if self._conn is None or self._conn.closed: + conn = self._conn + if conn is None or conn.closed: logger.warning("PostgreSQL connection lost — attempting to reconnect") - self._connect() + conn = self._connect() + return conn def create_tables(self) -> None: """Creates all required tables if they do not already exist. @@ -209,8 +220,8 @@ class PostgreSQLClient: Raises: PostgreSQLError: If table creation fails. """ - self._ensure_connected() - ddl_statements = [ + conn = self._ensure_connected() + ddl_statements: list[LiteralString] = [ # ---------------------------------------------------------------- # Aggregate reports # ---------------------------------------------------------------- @@ -420,8 +431,8 @@ class PostgreSQLClient: ] try: - with self._conn.transaction(): - with self._conn.cursor() as cur: + with conn.transaction(): + with conn.cursor() as cur: for stmt in ddl_statements: cur.execute(stmt) logger.debug("PostgreSQL tables verified / created") @@ -439,13 +450,13 @@ class PostgreSQLClient: AlreadySaved: If an identical report is already present. PostgreSQLError: If a database error occurs. """ - self._ensure_connected() + conn = self._ensure_connected() meta = report.get("report_metadata", {}) pub = report.get("policy_published", {}) try: - with self._conn.transaction(): - with self._conn.cursor() as cur: + with conn.transaction(): + with conn.cursor() as cur: cur.execute( """ INSERT INTO dmarc_aggregate_report ( @@ -549,7 +560,8 @@ class PostgreSQLClient: idens.get("envelope_to"), ), ) - record_db_id: int = cur.fetchone()[0] + # INSERT ... RETURNING always yields one row + record_db_id: int = cur.fetchone()[0] # pyright: ignore[reportOptionalSubscript] for dkim in record.get("auth_results", {}).get("dkim", []): cur.execute( @@ -615,25 +627,21 @@ class PostgreSQLClient: AlreadySaved: If a matching failure report is already present. PostgreSQLError: If a database error occurs. """ - self._ensure_connected() + conn = self._ensure_connected() sample = report.get("parsed_sample", {}) or {} src = report.get("source", {}) or {} arrival_date_utc = _ensure_utc_suffix(report.get("arrival_date_utc")) sample_subject = sample.get("subject") # JSONB values are reused by both the dedup check and the INSERT. sample_headers = ( - psycopg_types.json.Jsonb(sample["headers"]) - if sample.get("headers") - else None + psycopg_json.Jsonb(sample["headers"]) if sample.get("headers") else None ) - sample_from = ( - psycopg_types.json.Jsonb(sample["from"]) if sample.get("from") else None - ) - sample_to = psycopg_types.json.Jsonb(sample["to"]) if sample.get("to") else None + sample_from = psycopg_json.Jsonb(sample["from"]) if sample.get("from") else None + sample_to = psycopg_json.Jsonb(sample["to"]) if sample.get("to") else None try: - with self._conn.transaction(): - with self._conn.cursor() as cur: + with conn.transaction(): + with conn.cursor() as cur: # Failure reports have no natural primary key, so mirror the # Elasticsearch backend's query-then-insert dedup on the same # dimensions it uses: arrival date + From + To + Subject. @@ -721,7 +729,8 @@ class PostgreSQLClient: sample_to, ), ) - report_db_id: int = cur.fetchone()[0] + # INSERT ... RETURNING always yields one row + report_db_id: int = cur.fetchone()[0] # pyright: ignore[reportOptionalSubscript] for addr_type in ("to", "cc", "bcc", "reply_to"): entries = sample.get(addr_type) or [] @@ -759,10 +768,10 @@ class PostgreSQLClient: AlreadySaved: If an identical report is already present. PostgreSQLError: If a database error occurs. """ - self._ensure_connected() + conn = self._ensure_connected() try: - with self._conn.transaction(): - with self._conn.cursor() as cur: + with conn.transaction(): + with conn.cursor() as cur: cur.execute( """ INSERT INTO smtp_tls_report ( @@ -813,7 +822,8 @@ class PostgreSQLClient: policy.get("failed_session_count"), ), ) - policy_db_id: int = cur.fetchone()[0] + # INSERT ... RETURNING always yields one row + policy_db_id: int = cur.fetchone()[0] # pyright: ignore[reportOptionalSubscript] for detail in policy.get("failure_details", []): cur.execute( diff --git a/parsedmarc/resources/maps/find_bad_utf8.py b/parsedmarc/resources/maps/find_bad_utf8.py index 90ddb0e..4d447fe 100755 --- a/parsedmarc/resources/maps/find_bad_utf8.py +++ b/parsedmarc/resources/maps/find_bad_utf8.py @@ -6,7 +6,7 @@ import codecs import os import sys import shutil -from typing import List, Tuple +from typing import List, Optional, Tuple """ Locates and optionally corrects bad UTF-8 bytes in a file. @@ -127,7 +127,9 @@ def detect_encoding_text(path: str) -> Tuple[str, str]: return match.encoding, str(match) -def convert_to_utf8(src_path: str, out_path: str, src_encoding: str = None) -> str: +def convert_to_utf8( + src_path: str, out_path: str, src_encoding: Optional[str] = None +) -> str: """ Convert an entire file to UTF-8 (re-decoding everything). If src_encoding is provided, use it; else auto-detect. diff --git a/parsedmarc/resources/maps/sortlists.py b/parsedmarc/resources/maps/sortlists.py index 9595b38..0a30e3b 100755 --- a/parsedmarc/resources/maps/sortlists.py +++ b/parsedmarc/resources/maps/sortlists.py @@ -107,7 +107,7 @@ def sort_csv( case_insensitive_sort: bool = False, required_fields: Optional[Iterable[str]] = None, allowed_values: Optional[Mapping[str, Collection[str]]] = None, -) -> List[Dict[str, str]]: +) -> None: """ Read a CSV, optionally normalize rows (strip whitespace, lowercase certain fields), validate field values, and write the sorted CSV back to the same path. @@ -124,8 +124,7 @@ def sort_csv( required_fields = set(required_fields or []) lower_set = set(fields_to_lowercase or []) allowed_sets = {k: set(v) for k, v in (allowed_values or {}).items()} - if sort_field_value_must_be_unique: - seen_sort_field_values = [] + seen_sort_field_values: list[str] = [] with path.open("r", newline="") as infile: reader = csv.DictReader(infile) diff --git a/parsedmarc/splunk.py b/parsedmarc/splunk.py index a1aa6ac..4fd2525 100644 --- a/parsedmarc/splunk.py +++ b/parsedmarc/splunk.py @@ -63,10 +63,12 @@ class HECClient(object): host=self.host, source=self.source, index=self.index ) - self.session.headers = { - "User-Agent": USER_AGENT, - "Authorization": "Splunk {0}".format(self.access_token), - } + self.session.headers.update( + { + "User-Agent": USER_AGENT, + "Authorization": "Splunk {0}".format(self.access_token), + } + ) def save_aggregate_reports_to_splunk( self, diff --git a/parsedmarc/syslog.py b/parsedmarc/syslog.py index 7862797..4e5b478 100644 --- a/parsedmarc/syslog.py +++ b/parsedmarc/syslog.py @@ -9,13 +9,14 @@ import logging.handlers import socket import ssl import time -from typing import Any, Optional +from typing import Optional from parsedmarc import ( parsed_aggregate_reports_to_csv_rows, parsed_failure_reports_to_csv_rows, parsed_smtp_tls_reports_to_csv_rows, ) +from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport class SyslogClient(object): @@ -103,8 +104,9 @@ class SyslogClient(object): socktype=socket.SOCK_STREAM, ) # Set timeout on the socket - if hasattr(handler, "socket") and handler.socket: - handler.socket.settimeout(timeout) + sock = getattr(handler, "socket", None) + if sock is not None: + sock.settimeout(timeout) return handler else: # TLS protocol @@ -139,12 +141,14 @@ class SyslogClient(object): ) # Wrap socket with TLS - if hasattr(handler, "socket") and handler.socket: - handler.socket = ssl_context.wrap_socket( - handler.socket, + sock = getattr(handler, "socket", None) + if sock is not None: + tls_sock = ssl_context.wrap_socket( + sock, server_hostname=server_name, ) - handler.socket.settimeout(timeout) + tls_sock.settimeout(timeout) + setattr(handler, "socket", tls_sock) return handler @@ -160,22 +164,31 @@ class SyslogClient(object): f"Syslog connection failed after {retry_attempts} attempts: {e}" ) raise + # Only reachable when retry_attempts < 1, which would otherwise + # silently return None and break the caller's addHandler() call. + raise ValueError("retry_attempts must be at least 1") else: raise ValueError( f"Invalid protocol '{protocol}'. Must be 'udp', 'tcp', or 'tls'." ) - def save_aggregate_report_to_syslog(self, aggregate_reports: list[dict[str, Any]]): + def save_aggregate_report_to_syslog( + self, aggregate_reports: AggregateReport | list[AggregateReport] + ): rows = parsed_aggregate_reports_to_csv_rows(aggregate_reports) for row in rows: self.logger.info(json.dumps(row)) - def save_failure_report_to_syslog(self, failure_reports: list[dict[str, Any]]): + def save_failure_report_to_syslog( + self, failure_reports: FailureReport | list[FailureReport] + ): rows = parsed_failure_reports_to_csv_rows(failure_reports) for row in rows: self.logger.info(json.dumps(row)) - def save_smtp_tls_report_to_syslog(self, smtp_tls_reports: list[dict[str, Any]]): + def save_smtp_tls_report_to_syslog( + self, smtp_tls_reports: SMTPTLSReport | list[SMTPTLSReport] + ): rows = parsed_smtp_tls_reports_to_csv_rows(smtp_tls_reports) for row in rows: self.logger.info(json.dumps(row)) diff --git a/parsedmarc/webhook.py b/parsedmarc/webhook.py index 7d2dfc8..ce9d576 100644 --- a/parsedmarc/webhook.py +++ b/parsedmarc/webhook.py @@ -33,10 +33,12 @@ class WebhookClient(object): self.smtp_tls_url = smtp_tls_url self.timeout = timeout self.session = requests.Session() - self.session.headers = { - "User-Agent": USER_AGENT, - "Content-Type": "application/json", - } + self.session.headers.update( + { + "User-Agent": USER_AGENT, + "Content-Type": "application/json", + } + ) def save_failure_report_to_webhook(self, report: str): self._send_to_webhook(self.failure_url, report) diff --git a/pyproject.toml b/pyproject.toml index 89c099c..1ba9032 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,10 @@ build = [ "hatch>=1.14.0", "myst-parser[linkify]", "nose", + # Pinned exactly: pyright's checks evolve between releases, so an + # unpinned version could break CI without any code change. Bump + # deliberately (and fix any new findings) rather than implicitly. + "pyright==1.1.410", "pytest", "pytest-cov", "ruff", @@ -103,6 +107,15 @@ exclude = [ "parsedmarc/resources/maps/[!_]*.py", ] +[tool.pyright] +# The whole codebase passes pyright with zero errors and warnings; CI +# enforces this (see .github/workflows/python-tests.yml). Run locally with +# `pyright` from the repo root. Requires the [postgresql] extra to be +# installed so the optional psycopg import in parsedmarc/postgres.py +# resolves. +include = ["parsedmarc", "tests", "docs"] +typeCheckingMode = "standard" + [tool.pytest.ini_options] # Default to the per-module test layout under tests/. New tests should go # into tests/test_.py to match the file they exercise; do not diff --git a/tests/test_cli.py b/tests/test_cli.py index b208a7d..940faa2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import parsedmarc import parsedmarc.cli +import parsedmarc.elastic import parsedmarc.opensearch as opensearch_module @@ -34,7 +35,7 @@ class _DummyMailboxConnection(parsedmarc.MailboxConnection): self.fetch_calls.append({"reports_folder": reports_folder, **kwargs}) return [] - def fetch_message(self, message_id) -> str: + def fetch_message(self, message_id, **kwargs) -> str: return "" def delete_message(self, message_id): diff --git a/tests/test_gelf.py b/tests/test_gelf.py index 6974f16..091c221 100644 --- a/tests/test_gelf.py +++ b/tests/test_gelf.py @@ -2,15 +2,17 @@ import logging import unittest +from typing import Any, cast from unittest.mock import MagicMock, patch from parsedmarc.gelf import ContextFilter, GelfClient, log_context_data +from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport -def _sample_aggregate_report(): +def _sample_aggregate_report() -> AggregateReport: """Minimal aggregate report shape acceptable to parsed_aggregate_reports_to_csv_rows.""" - return { + report = { "xml_schema": "draft", "xml_namespace": None, "report_metadata": { @@ -87,6 +89,7 @@ def _sample_aggregate_report(): } ], } + return cast(AggregateReport, report) class _Handler(logging.Handler): @@ -95,7 +98,7 @@ class _Handler(logging.Handler): def __init__(self): super().__init__() - self.records: list[tuple[str, dict]] = [] + self.records: list[tuple[str, Any]] = [] def emit(self, record): # ContextFilter has run by this point so `record.parsedmarc` is @@ -212,8 +215,8 @@ class TestGelfClientSaveFailure(unittest.TestCase): reports. Build one through the CSV-row helper to verify GelfClient surfaces the right fields.""" - def _sample_failure_report(self): - return { + def _sample_failure_report(self) -> FailureReport: + report = { "feedback_type": "auth-failure", "user_agent": "test/1.0", "version": "1", @@ -243,6 +246,7 @@ class TestGelfClientSaveFailure(unittest.TestCase): "sample": "...", "parsed_sample": {"subject": "Test"}, } + return cast(FailureReport, report) def test_emits_one_record_per_failure_report(self): client = _gelf_client() @@ -256,8 +260,8 @@ class TestGelfClientSaveFailure(unittest.TestCase): class TestGelfClientSaveSmtpTls(unittest.TestCase): - def _sample_smtp_tls(self): - return { + def _sample_smtp_tls(self) -> SMTPTLSReport: + report = { "organization_name": "example.com", "begin_date": "2024-02-03T00:00:00Z", "end_date": "2024-02-04T00:00:00Z", @@ -272,6 +276,7 @@ class TestGelfClientSaveSmtpTls(unittest.TestCase): } ], } + return cast(SMTPTLSReport, report) def test_emits_one_record_per_policy(self): client = _gelf_client() diff --git a/tests/test_init.py b/tests/test_init.py index f4c1690..7546c22 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -568,7 +568,9 @@ class Test(unittest.TestCase): always_use_local_files=True, offline=True, ) - csv_text = parsedmarc.parsed_aggregate_reports_to_csv(result["report"]) + csv_text = parsedmarc.parsed_aggregate_reports_to_csv( + cast(AggregateReport, result["report"]) + ) header = csv_text.splitlines()[0].split(",") self.assertIn("source_asn", header) self.assertIn("source_as_name", header) @@ -2330,7 +2332,10 @@ class TestGetDmarcReportsFromMailboxValidation(unittest.TestCase): def test_none_connection_raises(self): with self.assertRaises(ValueError) as ctx: - parsedmarc.get_dmarc_reports_from_mailbox(connection=None) + parsedmarc.get_dmarc_reports_from_mailbox( + # Deliberately invalid: exercises the runtime None check + connection=None # pyright: ignore[reportArgumentType] + ) self.assertIn("connection", str(ctx.exception).lower()) @@ -2535,7 +2540,8 @@ class TestEmailResultsErrorBranches(unittest.TestCase): }, host="smtp.example.com", mail_from="from@example.com", - mail_to="admin@example.com", # str, not list — triggers assert + # str, not list — triggers assert + mail_to="admin@example.com", # pyright: ignore[reportArgumentType] ) @@ -2548,7 +2554,7 @@ class TestAppendJson(unittest.TestCase): path = tf.name os.remove(path) # ensure file is fresh try: - parsedmarc.append_json(path, [{"a": 1}]) + parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}])) with open(path) as f: data = json.loads(f.read()) self.assertEqual(data, [{"a": 1}]) @@ -2560,8 +2566,8 @@ class TestAppendJson(unittest.TestCase): with NamedTemporaryFile("w", suffix=".json", delete=False) as tf: path = tf.name try: - parsedmarc.append_json(path, [{"a": 1}]) - parsedmarc.append_json(path, [{"b": 2}]) + parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}])) + parsedmarc.append_json(path, cast(list[AggregateReport], [{"b": 2}])) with open(path) as f: data = json.loads(f.read()) self.assertEqual(data, [{"a": 1}, {"b": 2}]) @@ -2573,7 +2579,7 @@ class TestAppendJson(unittest.TestCase): with NamedTemporaryFile("w", suffix=".json", delete=False) as tf: path = tf.name try: - parsedmarc.append_json(path, [{"a": 1}]) + parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}])) parsedmarc.append_json(path, []) with open(path) as f: data = json.loads(f.read()) @@ -2595,7 +2601,7 @@ class TestAppendJson(unittest.TestCase): tf.write("{ this is not valid json at all") path = tf.name try: - parsedmarc.append_json(path, [{"new": "data"}]) + parsedmarc.append_json(path, cast(list[AggregateReport], [{"new": "data"}])) with open(path) as f: data = json.loads(f.read()) self.assertEqual(data, [{"new": "data"}]) @@ -2612,7 +2618,7 @@ class TestAppendJson(unittest.TestCase): tf.write('{"not": "a list"}') path = tf.name try: - parsedmarc.append_json(path, [{"new": "data"}]) + parsedmarc.append_json(path, cast(list[AggregateReport], [{"new": "data"}])) with open(path) as f: data = json.loads(f.read()) self.assertEqual(data, [{"new": "data"}]) diff --git a/tests/test_kafkaclient.py b/tests/test_kafkaclient.py index 2c4a673..0d93bfc 100644 --- a/tests/test_kafkaclient.py +++ b/tests/test_kafkaclient.py @@ -2,6 +2,7 @@ import json import unittest +from typing import cast from unittest.mock import MagicMock, patch from kafka.errors import UnknownTopicOrPartitionError @@ -9,6 +10,11 @@ from kafka.errors import UnknownTopicOrPartitionError from parsedmarc.kafkaclient import KafkaClient, KafkaError, _BootstrapError +def _producer(client: KafkaClient) -> MagicMock: + """The patched KafkaProducer as a MagicMock, for assertion access.""" + return cast(MagicMock, client.producer) + + def _aggregate_report(): return { "report_metadata": { @@ -121,15 +127,15 @@ class TestSaveAggregateReportsToKafka(unittest.TestCase): client = self._client() client.save_aggregate_reports_to_kafka(_aggregate_report(), "dmarc-aggregate") # 2 records in the sample report → 2 producer.send calls. - self.assertEqual(client.producer.send.call_count, 2) + self.assertEqual(_producer(client).send.call_count, 2) # Topic is forwarded verbatim. - for call in client.producer.send.call_args_list: + for call in _producer(client).send.call_args_list: self.assertEqual(call.args[0], "dmarc-aggregate") def test_each_slice_carries_metadata(self): client = self._client() client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic") - sent = [call.args[1] for call in client.producer.send.call_args_list] + sent = [call.args[1] for call in _producer(client).send.call_args_list] for slice_ in sent: self.assertEqual(slice_["org_name"], "TestOrg") self.assertEqual(slice_["org_email"], "test@example.com") @@ -144,32 +150,32 @@ class TestSaveAggregateReportsToKafka(unittest.TestCase): def test_empty_list_is_a_noop(self): client = self._client() client.save_aggregate_reports_to_kafka([], "topic") - client.producer.send.assert_not_called() + _producer(client).send.assert_not_called() def test_dict_input_normalized_to_list(self): """Single-report dict input is wrapped to a list.""" client = self._client() client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic") # 2 records still sent (one report with 2 records, not multiple reports). - self.assertEqual(client.producer.send.call_count, 2) + self.assertEqual(_producer(client).send.call_count, 2) def test_unknown_topic_translates_to_kafka_error(self): client = self._client() - client.producer.send.side_effect = UnknownTopicOrPartitionError() + _producer(client).send.side_effect = UnknownTopicOrPartitionError() with self.assertRaises(KafkaError) as ctx: client.save_aggregate_reports_to_kafka(_aggregate_report(), "missing") self.assertIn("Unknown topic or partition", str(ctx.exception)) def test_generic_send_exception_translates_to_kafka_error(self): client = self._client() - client.producer.send.side_effect = RuntimeError("transport failure") + _producer(client).send.side_effect = RuntimeError("transport failure") with self.assertRaises(KafkaError) as ctx: client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic") self.assertIn("transport failure", str(ctx.exception)) def test_flush_exception_translates_to_kafka_error(self): client = self._client() - client.producer.flush.side_effect = RuntimeError("flush failure") + _producer(client).flush.side_effect = RuntimeError("flush failure") with self.assertRaises(KafkaError) as ctx: client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic") self.assertIn("flush failure", str(ctx.exception)) @@ -186,35 +192,35 @@ class TestSaveFailureReportsToKafka(unittest.TestCase): client = self._client() reports = [{"id": "f1"}, {"id": "f2"}] client.save_failure_reports_to_kafka(reports, "dmarc-failure") - client.producer.send.assert_called_once_with("dmarc-failure", reports) + _producer(client).send.assert_called_once_with("dmarc-failure", reports) def test_dict_input_normalized_to_list(self): client = self._client() client.save_failure_reports_to_kafka({"id": "single"}, "topic") # The send payload is wrapped to a single-element list. - args = client.producer.send.call_args.args + args = _producer(client).send.call_args.args self.assertEqual(args[1], [{"id": "single"}]) def test_empty_list_is_a_noop(self): client = self._client() client.save_failure_reports_to_kafka([], "topic") - client.producer.send.assert_not_called() + _producer(client).send.assert_not_called() def test_unknown_topic_translates_to_kafka_error(self): client = self._client() - client.producer.send.side_effect = UnknownTopicOrPartitionError() + _producer(client).send.side_effect = UnknownTopicOrPartitionError() with self.assertRaises(KafkaError): client.save_failure_reports_to_kafka([{"a": 1}], "missing") def test_generic_send_error_translates_to_kafka_error(self): client = self._client() - client.producer.send.side_effect = OSError("net") + _producer(client).send.side_effect = OSError("net") with self.assertRaises(KafkaError): client.save_failure_reports_to_kafka([{"a": 1}], "topic") def test_flush_error_translates_to_kafka_error(self): client = self._client() - client.producer.flush.side_effect = OSError("flush") + _producer(client).flush.side_effect = OSError("flush") with self.assertRaises(KafkaError): client.save_failure_reports_to_kafka([{"a": 1}], "topic") @@ -228,34 +234,34 @@ class TestSaveSmtpTlsReportsToKafka(unittest.TestCase): client = self._client() reports = [{"organization_name": "x"}] client.save_smtp_tls_reports_to_kafka(reports, "smtp-tls") - client.producer.send.assert_called_once_with("smtp-tls", reports) + _producer(client).send.assert_called_once_with("smtp-tls", reports) def test_dict_input_normalized_to_list(self): client = self._client() client.save_smtp_tls_reports_to_kafka({"organization_name": "x"}, "topic") - args = client.producer.send.call_args.args + args = _producer(client).send.call_args.args self.assertEqual(args[1], [{"organization_name": "x"}]) def test_empty_list_is_a_noop(self): client = self._client() client.save_smtp_tls_reports_to_kafka([], "topic") - client.producer.send.assert_not_called() + _producer(client).send.assert_not_called() def test_unknown_topic_translates_to_kafka_error(self): client = self._client() - client.producer.send.side_effect = UnknownTopicOrPartitionError() + _producer(client).send.side_effect = UnknownTopicOrPartitionError() with self.assertRaises(KafkaError): client.save_smtp_tls_reports_to_kafka([{"a": 1}], "missing") def test_generic_send_error_translates_to_kafka_error(self): client = self._client() - client.producer.send.side_effect = RuntimeError("oops") + _producer(client).send.side_effect = RuntimeError("oops") with self.assertRaises(KafkaError): client.save_smtp_tls_reports_to_kafka([{"a": 1}], "topic") def test_flush_error_translates_to_kafka_error(self): client = self._client() - client.producer.flush.side_effect = RuntimeError("flush") + _producer(client).flush.side_effect = RuntimeError("flush") with self.assertRaises(KafkaError): client.save_smtp_tls_reports_to_kafka([{"a": 1}], "topic") @@ -265,7 +271,7 @@ class TestKafkaClientClose(unittest.TestCase): with patch("parsedmarc.kafkaclient.KafkaProducer"): client = KafkaClient(kafka_hosts=["b"]) client.close() - client.producer.close.assert_called_once() + _producer(client).close.assert_called_once() class TestKafkaBackwardCompatAlias(unittest.TestCase): diff --git a/tests/test_postgres.py b/tests/test_postgres.py index ef62d20..4b05d21 100644 --- a/tests/test_postgres.py +++ b/tests/test_postgres.py @@ -10,6 +10,7 @@ real-sample round trip, so the tests fail if the dict-key mapping regresses. import os import unittest from glob import glob +from typing import cast from unittest.mock import MagicMock, patch import parsedmarc @@ -27,7 +28,7 @@ OFFLINE_MODE = os.environ.get("GITHUB_ACTIONS", "false").lower() == "true" # psycopg is an optional dependency and is not installed in CI (which installs # only the [build] extra). The save methods mock the connection, but the -# failure path also references ``psycopg_types.json.Jsonb`` at module scope, so +# failure path also references ``psycopg_json.Jsonb`` at module scope, so # mock that SDK boundary for the whole module when psycopg is absent. _types_patcher = None @@ -36,8 +37,8 @@ def setUpModule(): global _types_patcher import parsedmarc.postgres as pg - if pg.psycopg_types is None: - _types_patcher = patch("parsedmarc.postgres.psycopg_types", MagicMock()) + if pg.psycopg_json is None: + _types_patcher = patch("parsedmarc.postgres.psycopg_json", MagicMock()) _types_patcher.start() @@ -99,6 +100,7 @@ class TestPostgreSQLHelpers(unittest.TestCase): def test_naive_local_to_timestamptz_valid(self): """A valid naive string is returned with a timezone offset.""" result = _naive_local_to_timestamptz("2024-01-15 10:30:00") + assert result is not None self.assertIsInstance(result, str) self.assertTrue( "+" in result or "-" in result[10:], @@ -125,11 +127,13 @@ class TestPostgreSQLHelpers(unittest.TestCase): def test_normalize_arrival_date_iso_naive_utc(self): """A naive ISO string (known UTC) is returned with +00 suffix.""" result = _normalize_arrival_date("2024-01-15 10:30:00") + assert result is not None self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}") def test_normalize_arrival_date_rfc2822(self): """An RFC 2822 date is converted to UTC with +00 suffix.""" result = _normalize_arrival_date("Fri, 28 Oct 2022 00:34:24 +0800") + assert result is not None self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}") # 00:34:24 +0800 is 16:34:24 UTC on 27 Oct 2022. self.assertIn("2022-10-27", result) @@ -138,6 +142,7 @@ class TestPostgreSQLHelpers(unittest.TestCase): def test_normalize_arrival_date_already_utc(self): """A string already ending with +00 still works.""" result = _normalize_arrival_date("2024-01-15 10:30:00+00") + assert result is not None self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}") def test_normalize_arrival_date_unparseable(self): @@ -167,7 +172,8 @@ class TestPostgreSQLHelpers(unittest.TestCase): def test_contact_info_to_text_numeric(self): """Non-string scalars are converted via str().""" - self.assertEqual(_contact_info_to_text(123), "123") + # Deliberately outside the annotated parameter types + self.assertEqual(_contact_info_to_text(123), "123") # pyright: ignore[reportArgumentType] def _make_client(): @@ -180,8 +186,8 @@ def _make_client(): client = PostgreSQLClient( host="localhost", database="test", user="test", password="test" ) + mock_conn.closed = False client._conn = mock_conn - client._conn.closed = False return client, mock_conn @@ -211,6 +217,7 @@ def _named_params(call): sql = call.args[0] m = re.search(r"\(([^)]*?)\)\s*VALUES", sql, re.S) + assert m is not None cols = [c.strip() for c in m.group(1).split(",") if c.strip()] return dict(zip(cols, call.args[1])) @@ -758,7 +765,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase): num_records = len(report.get("records", [])) _mock_cursor(mock_conn, [(rid,) for rid in range(1, 2 + num_records)]) try: - client.save_aggregate_report_to_postgresql(report) + client.save_aggregate_report_to_postgresql(cast(dict, report)) saved += 1 except Exception as exc: self.fail(f"aggregate save failed for {sample_path}: {exc}") @@ -783,7 +790,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase): # Dedup SELECT returns None (not a dup), then the INSERT id. _mock_cursor(mock_conn, [None, (1,)]) try: - client.save_failure_report_to_postgresql(report) + client.save_failure_report_to_postgresql(cast(dict, report)) saved += 1 except Exception as exc: self.fail(f"failure save failed for {sample_path}: {exc}") @@ -807,7 +814,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase): num_policies = len(report.get("policies", [])) _mock_cursor(mock_conn, [(rid,) for rid in range(1, 2 + num_policies)]) try: - client.save_smtp_tls_report_to_postgresql(report) + client.save_smtp_tls_report_to_postgresql(cast(dict, report)) saved += 1 except Exception as exc: self.fail(f"smtp_tls save failed for {sample_path}: {exc}") diff --git a/tests/test_syslog.py b/tests/test_syslog.py index 33c9dd7..06004dd 100644 --- a/tests/test_syslog.py +++ b/tests/test_syslog.py @@ -6,11 +6,14 @@ import socket import unittest from unittest.mock import MagicMock, patch +from typing import cast + from parsedmarc.syslog import SyslogClient +from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport -def _sample_aggregate_report(): - return { +def _sample_aggregate_report() -> AggregateReport: + report = { "xml_schema": "draft", "xml_namespace": None, "report_metadata": { @@ -70,6 +73,7 @@ def _sample_aggregate_report(): } ], } + return cast(AggregateReport, report) class _CapturingHandler(logging.Handler): @@ -259,6 +263,18 @@ class TestSyslogClientInitInvalidProtocol(unittest.TestCase): self.assertIn("udb", str(ctx.exception)) self.assertIn("'udp', 'tcp', or 'tls'", str(ctx.exception)) + def test_zero_retry_attempts_raises_value_error(self): + """retry_attempts < 1 means the TCP/TLS connect loop never runs. + Before the fix, _create_syslog_handler fell through and returned + None, which was then passed to logger.addHandler(); now it raises + ValueError instead of silently configuring a broken client.""" + _fresh_logger() + with patch("parsedmarc.syslog.logging.handlers.SysLogHandler") as handler_cls: + with self.assertRaises(ValueError) as ctx: + SyslogClient("s", 514, protocol="tcp", retry_attempts=0) + handler_cls.assert_not_called() + self.assertIn("retry_attempts", str(ctx.exception)) + class TestSyslogClientSave(unittest.TestCase): """save_* methods emit one syslog message per CSV row, each as a @@ -314,7 +330,7 @@ class TestSyslogClientSave(unittest.TestCase): "sample": "...", "parsed_sample": {"subject": "Test"}, } - client.save_failure_report_to_syslog([failure_report]) + client.save_failure_report_to_syslog([cast(FailureReport, failure_report)]) self.assertEqual(len(cap.messages), 1) payload = json.loads(cap.messages[0]) self.assertEqual(payload["reported_domain"], "example.com") @@ -337,7 +353,7 @@ class TestSyslogClientSave(unittest.TestCase): } ], } - client.save_smtp_tls_report_to_syslog([report]) + client.save_smtp_tls_report_to_syslog([cast(SMTPTLSReport, report)]) self.assertEqual(len(cap.messages), 1) payload = json.loads(cap.messages[0]) self.assertEqual(payload["policy_domain"], "example.com")