mirror of
https://github.com/domainaware/parsedmarc.git
synced 2026-06-25 11:34:18 +00:00
Make the whole codebase pass pyright cleanly and enforce it in CI (#798)
* Make the whole codebase pass pyright cleanly and enforce it in CI Fix all 102 pyright (1.1.410, standard mode) errors across the library, tests, and maps scripts, then pin and enforce the zero-errors bar: - postgres.py: make the optional psycopg import TYPE_CHECKING-aware so the module is properly typed while keeping the runtime install-hint fallback; import psycopg.types.json explicitly as psycopg_json (the old psycopg_types.json attribute access only worked because psycopg imports the submodule eagerly); have _connect()/_ensure_connected() return the live connection so save methods use a non-Optional local; type the DDL list as list[LiteralString] to match psycopg's execute() overloads. - kafkaclient.py: resolve the kafka-python 2.x/3.x bootstrap-error fallback statically via TYPE_CHECKING (kafka-python 3.0 removed NoBrokersAvailable), which also fixes _BootstrapError's import resolution in tests. - syslog.py: go through getattr/setattr for SysLogHandler.socket (absent from typeshed); type the save_* methods with the report TypedDicts (single or list, matching cli.py call sites — gelf.py gets the same signatures); raise ValueError when retry_attempts < 1 instead of falling through and registering a None handler (bug fix, with a regression test and a CHANGELOG entry). - elastic.py / opensearch.py: human_result params are Optional[str]. - maps scripts: sort_csv declared a return type but never returned (now -> None); seen_sort_field_values was possibly unbound; convert_to_utf8's src_encoding is Optional[str]. - tests: cast sample-report dict helpers to their TypedDicts; mark deliberate wrong-type calls with targeted pyright ignores; add narrowing asserts for Optional results; access the mocked KafkaProducer through a cast helper; match the mailsuite fetch_message base signature (**kwargs); patch the renamed parsedmarc.postgres.psycopg_json in test_postgres's setUpModule. Enforcement: [tool.pyright] in pyproject.toml (include parsedmarc, tests, docs; standard mode), pyright==1.1.410 pinned in the [build] extra (pinned exactly so a new pyright release can't break CI without a code change), and a "Check types" step in the lint CI job — which now also runs ruff format --check and installs the [postgresql] extra so the optional psycopg import resolves. Documented in AGENTS.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * Set session headers via update() instead of replacing the dict requests 2.34 ships inline type annotations, and Session.headers is a CaseInsensitiveDict[str] — assigning a plain dict fails pyright there (the CI runner resolved 2.34.2; the local venv's untyped 2.32.4 hid it). headers.update() is correctly typed against both versions, and is the documented requests idiom: it overrides User-Agent and the client-specific headers while keeping the session's defaults (Accept-Encoding, Connection) instead of wiping them. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> --------- Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
+2
-1
@@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import parsedmarc
|
||||
import parsedmarc.cli
|
||||
import parsedmarc.elastic
|
||||
import parsedmarc.opensearch as opensearch_module
|
||||
|
||||
|
||||
@@ -34,7 +35,7 @@ class _DummyMailboxConnection(parsedmarc.MailboxConnection):
|
||||
self.fetch_calls.append({"reports_folder": reports_folder, **kwargs})
|
||||
return []
|
||||
|
||||
def fetch_message(self, message_id) -> str:
|
||||
def fetch_message(self, message_id, **kwargs) -> str:
|
||||
return ""
|
||||
|
||||
def delete_message(self, message_id):
|
||||
|
||||
+12
-7
@@ -2,15 +2,17 @@
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from parsedmarc.gelf import ContextFilter, GelfClient, log_context_data
|
||||
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
|
||||
|
||||
|
||||
def _sample_aggregate_report():
|
||||
def _sample_aggregate_report() -> AggregateReport:
|
||||
"""Minimal aggregate report shape acceptable to
|
||||
parsed_aggregate_reports_to_csv_rows."""
|
||||
return {
|
||||
report = {
|
||||
"xml_schema": "draft",
|
||||
"xml_namespace": None,
|
||||
"report_metadata": {
|
||||
@@ -87,6 +89,7 @@ def _sample_aggregate_report():
|
||||
}
|
||||
],
|
||||
}
|
||||
return cast(AggregateReport, report)
|
||||
|
||||
|
||||
class _Handler(logging.Handler):
|
||||
@@ -95,7 +98,7 @@ class _Handler(logging.Handler):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.records: list[tuple[str, dict]] = []
|
||||
self.records: list[tuple[str, Any]] = []
|
||||
|
||||
def emit(self, record):
|
||||
# ContextFilter has run by this point so `record.parsedmarc` is
|
||||
@@ -212,8 +215,8 @@ class TestGelfClientSaveFailure(unittest.TestCase):
|
||||
reports. Build one through the CSV-row helper to verify GelfClient
|
||||
surfaces the right fields."""
|
||||
|
||||
def _sample_failure_report(self):
|
||||
return {
|
||||
def _sample_failure_report(self) -> FailureReport:
|
||||
report = {
|
||||
"feedback_type": "auth-failure",
|
||||
"user_agent": "test/1.0",
|
||||
"version": "1",
|
||||
@@ -243,6 +246,7 @@ class TestGelfClientSaveFailure(unittest.TestCase):
|
||||
"sample": "...",
|
||||
"parsed_sample": {"subject": "Test"},
|
||||
}
|
||||
return cast(FailureReport, report)
|
||||
|
||||
def test_emits_one_record_per_failure_report(self):
|
||||
client = _gelf_client()
|
||||
@@ -256,8 +260,8 @@ class TestGelfClientSaveFailure(unittest.TestCase):
|
||||
|
||||
|
||||
class TestGelfClientSaveSmtpTls(unittest.TestCase):
|
||||
def _sample_smtp_tls(self):
|
||||
return {
|
||||
def _sample_smtp_tls(self) -> SMTPTLSReport:
|
||||
report = {
|
||||
"organization_name": "example.com",
|
||||
"begin_date": "2024-02-03T00:00:00Z",
|
||||
"end_date": "2024-02-04T00:00:00Z",
|
||||
@@ -272,6 +276,7 @@ class TestGelfClientSaveSmtpTls(unittest.TestCase):
|
||||
}
|
||||
],
|
||||
}
|
||||
return cast(SMTPTLSReport, report)
|
||||
|
||||
def test_emits_one_record_per_policy(self):
|
||||
client = _gelf_client()
|
||||
|
||||
+15
-9
@@ -568,7 +568,9 @@ class Test(unittest.TestCase):
|
||||
always_use_local_files=True,
|
||||
offline=True,
|
||||
)
|
||||
csv_text = parsedmarc.parsed_aggregate_reports_to_csv(result["report"])
|
||||
csv_text = parsedmarc.parsed_aggregate_reports_to_csv(
|
||||
cast(AggregateReport, result["report"])
|
||||
)
|
||||
header = csv_text.splitlines()[0].split(",")
|
||||
self.assertIn("source_asn", header)
|
||||
self.assertIn("source_as_name", header)
|
||||
@@ -2330,7 +2332,10 @@ class TestGetDmarcReportsFromMailboxValidation(unittest.TestCase):
|
||||
|
||||
def test_none_connection_raises(self):
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
parsedmarc.get_dmarc_reports_from_mailbox(connection=None)
|
||||
parsedmarc.get_dmarc_reports_from_mailbox(
|
||||
# Deliberately invalid: exercises the runtime None check
|
||||
connection=None # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
self.assertIn("connection", str(ctx.exception).lower())
|
||||
|
||||
|
||||
@@ -2535,7 +2540,8 @@ class TestEmailResultsErrorBranches(unittest.TestCase):
|
||||
},
|
||||
host="smtp.example.com",
|
||||
mail_from="from@example.com",
|
||||
mail_to="admin@example.com", # str, not list — triggers assert
|
||||
# str, not list — triggers assert
|
||||
mail_to="admin@example.com", # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
|
||||
@@ -2548,7 +2554,7 @@ class TestAppendJson(unittest.TestCase):
|
||||
path = tf.name
|
||||
os.remove(path) # ensure file is fresh
|
||||
try:
|
||||
parsedmarc.append_json(path, [{"a": 1}])
|
||||
parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}]))
|
||||
with open(path) as f:
|
||||
data = json.loads(f.read())
|
||||
self.assertEqual(data, [{"a": 1}])
|
||||
@@ -2560,8 +2566,8 @@ class TestAppendJson(unittest.TestCase):
|
||||
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
|
||||
path = tf.name
|
||||
try:
|
||||
parsedmarc.append_json(path, [{"a": 1}])
|
||||
parsedmarc.append_json(path, [{"b": 2}])
|
||||
parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}]))
|
||||
parsedmarc.append_json(path, cast(list[AggregateReport], [{"b": 2}]))
|
||||
with open(path) as f:
|
||||
data = json.loads(f.read())
|
||||
self.assertEqual(data, [{"a": 1}, {"b": 2}])
|
||||
@@ -2573,7 +2579,7 @@ class TestAppendJson(unittest.TestCase):
|
||||
with NamedTemporaryFile("w", suffix=".json", delete=False) as tf:
|
||||
path = tf.name
|
||||
try:
|
||||
parsedmarc.append_json(path, [{"a": 1}])
|
||||
parsedmarc.append_json(path, cast(list[AggregateReport], [{"a": 1}]))
|
||||
parsedmarc.append_json(path, [])
|
||||
with open(path) as f:
|
||||
data = json.loads(f.read())
|
||||
@@ -2595,7 +2601,7 @@ class TestAppendJson(unittest.TestCase):
|
||||
tf.write("{ this is not valid json at all")
|
||||
path = tf.name
|
||||
try:
|
||||
parsedmarc.append_json(path, [{"new": "data"}])
|
||||
parsedmarc.append_json(path, cast(list[AggregateReport], [{"new": "data"}]))
|
||||
with open(path) as f:
|
||||
data = json.loads(f.read())
|
||||
self.assertEqual(data, [{"new": "data"}])
|
||||
@@ -2612,7 +2618,7 @@ class TestAppendJson(unittest.TestCase):
|
||||
tf.write('{"not": "a list"}')
|
||||
path = tf.name
|
||||
try:
|
||||
parsedmarc.append_json(path, [{"new": "data"}])
|
||||
parsedmarc.append_json(path, cast(list[AggregateReport], [{"new": "data"}]))
|
||||
with open(path) as f:
|
||||
data = json.loads(f.read())
|
||||
self.assertEqual(data, [{"new": "data"}])
|
||||
|
||||
+27
-21
@@ -2,6 +2,7 @@
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from kafka.errors import UnknownTopicOrPartitionError
|
||||
@@ -9,6 +10,11 @@ from kafka.errors import UnknownTopicOrPartitionError
|
||||
from parsedmarc.kafkaclient import KafkaClient, KafkaError, _BootstrapError
|
||||
|
||||
|
||||
def _producer(client: KafkaClient) -> MagicMock:
|
||||
"""The patched KafkaProducer as a MagicMock, for assertion access."""
|
||||
return cast(MagicMock, client.producer)
|
||||
|
||||
|
||||
def _aggregate_report():
|
||||
return {
|
||||
"report_metadata": {
|
||||
@@ -121,15 +127,15 @@ class TestSaveAggregateReportsToKafka(unittest.TestCase):
|
||||
client = self._client()
|
||||
client.save_aggregate_reports_to_kafka(_aggregate_report(), "dmarc-aggregate")
|
||||
# 2 records in the sample report → 2 producer.send calls.
|
||||
self.assertEqual(client.producer.send.call_count, 2)
|
||||
self.assertEqual(_producer(client).send.call_count, 2)
|
||||
# Topic is forwarded verbatim.
|
||||
for call in client.producer.send.call_args_list:
|
||||
for call in _producer(client).send.call_args_list:
|
||||
self.assertEqual(call.args[0], "dmarc-aggregate")
|
||||
|
||||
def test_each_slice_carries_metadata(self):
|
||||
client = self._client()
|
||||
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
|
||||
sent = [call.args[1] for call in client.producer.send.call_args_list]
|
||||
sent = [call.args[1] for call in _producer(client).send.call_args_list]
|
||||
for slice_ in sent:
|
||||
self.assertEqual(slice_["org_name"], "TestOrg")
|
||||
self.assertEqual(slice_["org_email"], "test@example.com")
|
||||
@@ -144,32 +150,32 @@ class TestSaveAggregateReportsToKafka(unittest.TestCase):
|
||||
def test_empty_list_is_a_noop(self):
|
||||
client = self._client()
|
||||
client.save_aggregate_reports_to_kafka([], "topic")
|
||||
client.producer.send.assert_not_called()
|
||||
_producer(client).send.assert_not_called()
|
||||
|
||||
def test_dict_input_normalized_to_list(self):
|
||||
"""Single-report dict input is wrapped to a list."""
|
||||
client = self._client()
|
||||
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
|
||||
# 2 records still sent (one report with 2 records, not multiple reports).
|
||||
self.assertEqual(client.producer.send.call_count, 2)
|
||||
self.assertEqual(_producer(client).send.call_count, 2)
|
||||
|
||||
def test_unknown_topic_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.send.side_effect = UnknownTopicOrPartitionError()
|
||||
_producer(client).send.side_effect = UnknownTopicOrPartitionError()
|
||||
with self.assertRaises(KafkaError) as ctx:
|
||||
client.save_aggregate_reports_to_kafka(_aggregate_report(), "missing")
|
||||
self.assertIn("Unknown topic or partition", str(ctx.exception))
|
||||
|
||||
def test_generic_send_exception_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.send.side_effect = RuntimeError("transport failure")
|
||||
_producer(client).send.side_effect = RuntimeError("transport failure")
|
||||
with self.assertRaises(KafkaError) as ctx:
|
||||
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
|
||||
self.assertIn("transport failure", str(ctx.exception))
|
||||
|
||||
def test_flush_exception_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.flush.side_effect = RuntimeError("flush failure")
|
||||
_producer(client).flush.side_effect = RuntimeError("flush failure")
|
||||
with self.assertRaises(KafkaError) as ctx:
|
||||
client.save_aggregate_reports_to_kafka(_aggregate_report(), "topic")
|
||||
self.assertIn("flush failure", str(ctx.exception))
|
||||
@@ -186,35 +192,35 @@ class TestSaveFailureReportsToKafka(unittest.TestCase):
|
||||
client = self._client()
|
||||
reports = [{"id": "f1"}, {"id": "f2"}]
|
||||
client.save_failure_reports_to_kafka(reports, "dmarc-failure")
|
||||
client.producer.send.assert_called_once_with("dmarc-failure", reports)
|
||||
_producer(client).send.assert_called_once_with("dmarc-failure", reports)
|
||||
|
||||
def test_dict_input_normalized_to_list(self):
|
||||
client = self._client()
|
||||
client.save_failure_reports_to_kafka({"id": "single"}, "topic")
|
||||
# The send payload is wrapped to a single-element list.
|
||||
args = client.producer.send.call_args.args
|
||||
args = _producer(client).send.call_args.args
|
||||
self.assertEqual(args[1], [{"id": "single"}])
|
||||
|
||||
def test_empty_list_is_a_noop(self):
|
||||
client = self._client()
|
||||
client.save_failure_reports_to_kafka([], "topic")
|
||||
client.producer.send.assert_not_called()
|
||||
_producer(client).send.assert_not_called()
|
||||
|
||||
def test_unknown_topic_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.send.side_effect = UnknownTopicOrPartitionError()
|
||||
_producer(client).send.side_effect = UnknownTopicOrPartitionError()
|
||||
with self.assertRaises(KafkaError):
|
||||
client.save_failure_reports_to_kafka([{"a": 1}], "missing")
|
||||
|
||||
def test_generic_send_error_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.send.side_effect = OSError("net")
|
||||
_producer(client).send.side_effect = OSError("net")
|
||||
with self.assertRaises(KafkaError):
|
||||
client.save_failure_reports_to_kafka([{"a": 1}], "topic")
|
||||
|
||||
def test_flush_error_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.flush.side_effect = OSError("flush")
|
||||
_producer(client).flush.side_effect = OSError("flush")
|
||||
with self.assertRaises(KafkaError):
|
||||
client.save_failure_reports_to_kafka([{"a": 1}], "topic")
|
||||
|
||||
@@ -228,34 +234,34 @@ class TestSaveSmtpTlsReportsToKafka(unittest.TestCase):
|
||||
client = self._client()
|
||||
reports = [{"organization_name": "x"}]
|
||||
client.save_smtp_tls_reports_to_kafka(reports, "smtp-tls")
|
||||
client.producer.send.assert_called_once_with("smtp-tls", reports)
|
||||
_producer(client).send.assert_called_once_with("smtp-tls", reports)
|
||||
|
||||
def test_dict_input_normalized_to_list(self):
|
||||
client = self._client()
|
||||
client.save_smtp_tls_reports_to_kafka({"organization_name": "x"}, "topic")
|
||||
args = client.producer.send.call_args.args
|
||||
args = _producer(client).send.call_args.args
|
||||
self.assertEqual(args[1], [{"organization_name": "x"}])
|
||||
|
||||
def test_empty_list_is_a_noop(self):
|
||||
client = self._client()
|
||||
client.save_smtp_tls_reports_to_kafka([], "topic")
|
||||
client.producer.send.assert_not_called()
|
||||
_producer(client).send.assert_not_called()
|
||||
|
||||
def test_unknown_topic_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.send.side_effect = UnknownTopicOrPartitionError()
|
||||
_producer(client).send.side_effect = UnknownTopicOrPartitionError()
|
||||
with self.assertRaises(KafkaError):
|
||||
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "missing")
|
||||
|
||||
def test_generic_send_error_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.send.side_effect = RuntimeError("oops")
|
||||
_producer(client).send.side_effect = RuntimeError("oops")
|
||||
with self.assertRaises(KafkaError):
|
||||
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "topic")
|
||||
|
||||
def test_flush_error_translates_to_kafka_error(self):
|
||||
client = self._client()
|
||||
client.producer.flush.side_effect = RuntimeError("flush")
|
||||
_producer(client).flush.side_effect = RuntimeError("flush")
|
||||
with self.assertRaises(KafkaError):
|
||||
client.save_smtp_tls_reports_to_kafka([{"a": 1}], "topic")
|
||||
|
||||
@@ -265,7 +271,7 @@ class TestKafkaClientClose(unittest.TestCase):
|
||||
with patch("parsedmarc.kafkaclient.KafkaProducer"):
|
||||
client = KafkaClient(kafka_hosts=["b"])
|
||||
client.close()
|
||||
client.producer.close.assert_called_once()
|
||||
_producer(client).close.assert_called_once()
|
||||
|
||||
|
||||
class TestKafkaBackwardCompatAlias(unittest.TestCase):
|
||||
|
||||
+15
-8
@@ -10,6 +10,7 @@ real-sample round trip, so the tests fail if the dict-key mapping regresses.
|
||||
import os
|
||||
import unittest
|
||||
from glob import glob
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import parsedmarc
|
||||
@@ -27,7 +28,7 @@ OFFLINE_MODE = os.environ.get("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
|
||||
# psycopg is an optional dependency and is not installed in CI (which installs
|
||||
# only the [build] extra). The save methods mock the connection, but the
|
||||
# failure path also references ``psycopg_types.json.Jsonb`` at module scope, so
|
||||
# failure path also references ``psycopg_json.Jsonb`` at module scope, so
|
||||
# mock that SDK boundary for the whole module when psycopg is absent.
|
||||
_types_patcher = None
|
||||
|
||||
@@ -36,8 +37,8 @@ def setUpModule():
|
||||
global _types_patcher
|
||||
import parsedmarc.postgres as pg
|
||||
|
||||
if pg.psycopg_types is None:
|
||||
_types_patcher = patch("parsedmarc.postgres.psycopg_types", MagicMock())
|
||||
if pg.psycopg_json is None:
|
||||
_types_patcher = patch("parsedmarc.postgres.psycopg_json", MagicMock())
|
||||
_types_patcher.start()
|
||||
|
||||
|
||||
@@ -99,6 +100,7 @@ class TestPostgreSQLHelpers(unittest.TestCase):
|
||||
def test_naive_local_to_timestamptz_valid(self):
|
||||
"""A valid naive string is returned with a timezone offset."""
|
||||
result = _naive_local_to_timestamptz("2024-01-15 10:30:00")
|
||||
assert result is not None
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertTrue(
|
||||
"+" in result or "-" in result[10:],
|
||||
@@ -125,11 +127,13 @@ class TestPostgreSQLHelpers(unittest.TestCase):
|
||||
def test_normalize_arrival_date_iso_naive_utc(self):
|
||||
"""A naive ISO string (known UTC) is returned with +00 suffix."""
|
||||
result = _normalize_arrival_date("2024-01-15 10:30:00")
|
||||
assert result is not None
|
||||
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
|
||||
|
||||
def test_normalize_arrival_date_rfc2822(self):
|
||||
"""An RFC 2822 date is converted to UTC with +00 suffix."""
|
||||
result = _normalize_arrival_date("Fri, 28 Oct 2022 00:34:24 +0800")
|
||||
assert result is not None
|
||||
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
|
||||
# 00:34:24 +0800 is 16:34:24 UTC on 27 Oct 2022.
|
||||
self.assertIn("2022-10-27", result)
|
||||
@@ -138,6 +142,7 @@ class TestPostgreSQLHelpers(unittest.TestCase):
|
||||
def test_normalize_arrival_date_already_utc(self):
|
||||
"""A string already ending with +00 still works."""
|
||||
result = _normalize_arrival_date("2024-01-15 10:30:00+00")
|
||||
assert result is not None
|
||||
self.assertTrue(result.endswith("+00"), f"Expected +00 suffix: {result}")
|
||||
|
||||
def test_normalize_arrival_date_unparseable(self):
|
||||
@@ -167,7 +172,8 @@ class TestPostgreSQLHelpers(unittest.TestCase):
|
||||
|
||||
def test_contact_info_to_text_numeric(self):
|
||||
"""Non-string scalars are converted via str()."""
|
||||
self.assertEqual(_contact_info_to_text(123), "123")
|
||||
# Deliberately outside the annotated parameter types
|
||||
self.assertEqual(_contact_info_to_text(123), "123") # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
def _make_client():
|
||||
@@ -180,8 +186,8 @@ def _make_client():
|
||||
client = PostgreSQLClient(
|
||||
host="localhost", database="test", user="test", password="test"
|
||||
)
|
||||
mock_conn.closed = False
|
||||
client._conn = mock_conn
|
||||
client._conn.closed = False
|
||||
return client, mock_conn
|
||||
|
||||
|
||||
@@ -211,6 +217,7 @@ def _named_params(call):
|
||||
|
||||
sql = call.args[0]
|
||||
m = re.search(r"\(([^)]*?)\)\s*VALUES", sql, re.S)
|
||||
assert m is not None
|
||||
cols = [c.strip() for c in m.group(1).split(",") if c.strip()]
|
||||
return dict(zip(cols, call.args[1]))
|
||||
|
||||
@@ -758,7 +765,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase):
|
||||
num_records = len(report.get("records", []))
|
||||
_mock_cursor(mock_conn, [(rid,) for rid in range(1, 2 + num_records)])
|
||||
try:
|
||||
client.save_aggregate_report_to_postgresql(report)
|
||||
client.save_aggregate_report_to_postgresql(cast(dict, report))
|
||||
saved += 1
|
||||
except Exception as exc:
|
||||
self.fail(f"aggregate save failed for {sample_path}: {exc}")
|
||||
@@ -783,7 +790,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase):
|
||||
# Dedup SELECT returns None (not a dup), then the INSERT id.
|
||||
_mock_cursor(mock_conn, [None, (1,)])
|
||||
try:
|
||||
client.save_failure_report_to_postgresql(report)
|
||||
client.save_failure_report_to_postgresql(cast(dict, report))
|
||||
saved += 1
|
||||
except Exception as exc:
|
||||
self.fail(f"failure save failed for {sample_path}: {exc}")
|
||||
@@ -807,7 +814,7 @@ class TestPostgreSQLWithSamples(unittest.TestCase):
|
||||
num_policies = len(report.get("policies", []))
|
||||
_mock_cursor(mock_conn, [(rid,) for rid in range(1, 2 + num_policies)])
|
||||
try:
|
||||
client.save_smtp_tls_report_to_postgresql(report)
|
||||
client.save_smtp_tls_report_to_postgresql(cast(dict, report))
|
||||
saved += 1
|
||||
except Exception as exc:
|
||||
self.fail(f"smtp_tls save failed for {sample_path}: {exc}")
|
||||
|
||||
+20
-4
@@ -6,11 +6,14 @@ import socket
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from typing import cast
|
||||
|
||||
from parsedmarc.syslog import SyslogClient
|
||||
from parsedmarc.types import AggregateReport, FailureReport, SMTPTLSReport
|
||||
|
||||
|
||||
def _sample_aggregate_report():
|
||||
return {
|
||||
def _sample_aggregate_report() -> AggregateReport:
|
||||
report = {
|
||||
"xml_schema": "draft",
|
||||
"xml_namespace": None,
|
||||
"report_metadata": {
|
||||
@@ -70,6 +73,7 @@ def _sample_aggregate_report():
|
||||
}
|
||||
],
|
||||
}
|
||||
return cast(AggregateReport, report)
|
||||
|
||||
|
||||
class _CapturingHandler(logging.Handler):
|
||||
@@ -259,6 +263,18 @@ class TestSyslogClientInitInvalidProtocol(unittest.TestCase):
|
||||
self.assertIn("udb", str(ctx.exception))
|
||||
self.assertIn("'udp', 'tcp', or 'tls'", str(ctx.exception))
|
||||
|
||||
def test_zero_retry_attempts_raises_value_error(self):
|
||||
"""retry_attempts < 1 means the TCP/TLS connect loop never runs.
|
||||
Before the fix, _create_syslog_handler fell through and returned
|
||||
None, which was then passed to logger.addHandler(); now it raises
|
||||
ValueError instead of silently configuring a broken client."""
|
||||
_fresh_logger()
|
||||
with patch("parsedmarc.syslog.logging.handlers.SysLogHandler") as handler_cls:
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
SyslogClient("s", 514, protocol="tcp", retry_attempts=0)
|
||||
handler_cls.assert_not_called()
|
||||
self.assertIn("retry_attempts", str(ctx.exception))
|
||||
|
||||
|
||||
class TestSyslogClientSave(unittest.TestCase):
|
||||
"""save_* methods emit one syslog message per CSV row, each as a
|
||||
@@ -314,7 +330,7 @@ class TestSyslogClientSave(unittest.TestCase):
|
||||
"sample": "...",
|
||||
"parsed_sample": {"subject": "Test"},
|
||||
}
|
||||
client.save_failure_report_to_syslog([failure_report])
|
||||
client.save_failure_report_to_syslog([cast(FailureReport, failure_report)])
|
||||
self.assertEqual(len(cap.messages), 1)
|
||||
payload = json.loads(cap.messages[0])
|
||||
self.assertEqual(payload["reported_domain"], "example.com")
|
||||
@@ -337,7 +353,7 @@ class TestSyslogClientSave(unittest.TestCase):
|
||||
}
|
||||
],
|
||||
}
|
||||
client.save_smtp_tls_report_to_syslog([report])
|
||||
client.save_smtp_tls_report_to_syslog([cast(SMTPTLSReport, report)])
|
||||
self.assertEqual(len(cap.messages), 1)
|
||||
payload = json.loads(cap.messages[0])
|
||||
self.assertEqual(payload["policy_domain"], "example.com")
|
||||
|
||||
Reference in New Issue
Block a user