Modernize type hints to PEP 585 / PEP 604 syntax (#803)

This commit is contained in:
Sean Whalen
2026-06-25 15:21:48 -04:00
committed by GitHub
parent d13eb86782
commit c423b8dfff
14 changed files with 337 additions and 337 deletions
+1 -1
View File
@@ -117,7 +117,7 @@ IP address info cached for 4 hours, seen aggregate report IDs cached for 1 hour
- Ruff for formatting and linting (configured in `.vscode/settings.json`). Run `ruff check .` and `ruff format --check .` after every code edit, before committing.
- Pyright for type checking (configured in `pyproject.toml` `[tool.pyright]`, pinned in the `[build]` extra, enforced in CI). Run `pyright` from the repo root before committing; the whole codebase — library and tests — must stay at zero errors and warnings. Prefer real fixes (narrowing, `Optional` annotations, `TYPE_CHECKING` imports) over `# pyright: ignore[...]`; reserve targeted ignores for deliberate wrong-type tests and version-conditional imports, and never use a bare blanket ignore.
- TypedDict for structured data, type hints throughout.
- Python ≥3.10 required.
- Python ≥3.10 required. Use modern type-hint syntax: PEP 585 builtins (`list[str]`, `dict[str, Any]`) and PEP 604 unions (`X | Y`, `X | None`) — not `typing.List` / `Union` / `Optional`. Ruff enforces this (`UP006`/`UP007`/`UP035`/`UP045` in `pyproject.toml`). `typing.NotRequired` / `Required` are 3.11+, so for optional TypedDict keys use `total=False` (see `parsedmarc/types.py`).
- Tests live under `tests/` as `tests/test_<module>.py`, one per top-level `parsedmarc/*` module (e.g. `tests/test_init.py` for `parsedmarc/__init__.py`, `tests/test_cli.py` for `parsedmarc/cli.py`). All test classes use `unittest`. Sample reports live in `samples/`. Run with `pytest tests/`; run one file with `pytest tests/test_init.py`. New tests go in the file whose module they exercise — do not reintroduce a monolithic test file.
- File path config values must be wrapped with `_expand_path()` in `cli.py`.
- Maildir UID checks are intentionally relaxed (warn, don't crash) for Docker compatibility.
+83 -90
View File
@@ -22,15 +22,10 @@ from base64 import b64decode
from csv import DictWriter
from datetime import date, datetime, timedelta, timezone, tzinfo
from io import BytesIO, StringIO
from collections.abc import Callable, Sequence
from typing import (
Any,
BinaryIO,
Callable,
Dict,
List,
Optional,
Sequence,
Union,
cast,
)
@@ -166,7 +161,7 @@ def _exc_origin(error: BaseException) -> str:
return " (raised at {0}:{1})".format(last.filename, last.lineno)
def _text(value: Any) -> Optional[str]:
def _text(value: Any) -> str | None:
"""Unwrap a possibly-langAttrString value parsed by xmltodict.
RFC 9990 changed several aggregate-report elements (extra_contact_info,
@@ -190,7 +185,7 @@ def _bucket_interval_by_day(
begin: datetime,
end: datetime,
total_count: int,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Split the interval [begin, end) into daily buckets and distribute
`total_count` proportionally across those buckets.
@@ -252,7 +247,7 @@ def _bucket_interval_by_day(
if day_cursor > begin:
day_cursor -= timedelta(days=1)
day_buckets: List[Dict[str, Any]] = []
day_buckets: list[dict[str, Any]] = []
while day_cursor < end:
day_start = day_cursor
@@ -284,12 +279,12 @@ def _bucket_interval_by_day(
# Then apply a "largest remainder" rounding strategy to ensure the sum
# equals exactly total_count.
exact_values: List[float] = [
exact_values: list[float] = [
(b["seconds"] / interval_seconds) * total_count for b in day_buckets
]
floor_values: List[int] = [int(x) for x in exact_values]
fractional_parts: List[float] = [x - int(x) for x in exact_values]
floor_values: list[int] = [int(x) for x in exact_values]
fractional_parts: list[float] = [x - int(x) for x in exact_values]
# How many counts do we still need to distribute after flooring?
remainder = total_count - sum(floor_values)
@@ -309,7 +304,7 @@ def _bucket_interval_by_day(
final_counts[idx] += 1
# --- Step 3: Build the final per-day result list -------------------------
results: List[Dict[str, Any]] = []
results: list[dict[str, Any]] = []
for bucket, count in zip(day_buckets, final_counts):
if count > 0:
results.append(
@@ -370,12 +365,12 @@ def _append_parsed_record(
def _parse_report_record(
record: dict[str, Any],
*,
ip_db_path: Optional[str] = None,
ip_db_path: str | None = None,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
offline: bool = False,
nameservers: Optional[list[str]] = None,
nameservers: list[str] | None = None,
dns_timeout: float = DEFAULT_DNS_TIMEOUT,
dns_retries: int = DEFAULT_DNS_MAX_RETRIES,
is_rfc_9990: bool = False,
@@ -632,7 +627,7 @@ def _parse_smtp_tls_report_policy(policy: dict[str, Any]):
raise InvalidSMTPTLSReport(str(e) + _exc_origin(e)) from e
def parse_smtp_tls_report_json(report: Union[str, bytes]) -> SMTPTLSReport:
def parse_smtp_tls_report_json(report: str | bytes) -> SMTPTLSReport:
"""Parses and validates an SMTP TLS report"""
required_fields = [
"organization-name",
@@ -675,7 +670,7 @@ def parse_smtp_tls_report_json(report: Union[str, bytes]) -> SMTPTLSReport:
def parsed_smtp_tls_reports_to_csv_rows(
reports: Union[SMTPTLSReport, list[SMTPTLSReport]],
reports: SMTPTLSReport | list[SMTPTLSReport],
) -> list[dict[str, Any]]:
"""Converts one oor more parsed SMTP TLS reports into a list of single
layer dict objects suitable for use in a CSV"""
@@ -715,7 +710,7 @@ def parsed_smtp_tls_reports_to_csv_rows(
def parsed_smtp_tls_reports_to_csv(
reports: Union[SMTPTLSReport, list[SMTPTLSReport]],
reports: SMTPTLSReport | list[SMTPTLSReport],
) -> str:
"""
Converts one or more parsed SMTP TLS reports to flat CSV format, including
@@ -764,15 +759,15 @@ def parsed_smtp_tls_reports_to_csv(
def parse_aggregate_report_xml(
xml: str,
*,
ip_db_path: Optional[str] = None,
ip_db_path: str | None = None,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
offline: bool = False,
nameservers: Optional[list[str]] = None,
nameservers: list[str] | None = None,
timeout: float = DEFAULT_DNS_TIMEOUT,
retries: int = DEFAULT_DNS_MAX_RETRIES,
keep_alive: Optional[Callable] = None,
keep_alive: Callable | None = None,
normalize_timespan_threshold_hours: float = 24.0,
) -> AggregateReport:
"""Parses a DMARC XML report string and returns a consistent dict
@@ -809,7 +804,7 @@ def parse_aggregate_report_xml(
# The final `is_rfc_9990` decision is made post-parse so that
# RFC 9990-only fields (np, testing, discovery_method, generator,
# human_result) can also vote it in.
xml_namespace: Optional[str] = None
xml_namespace: str | None = None
namespace_match = xml_namespace_regex.search(xml)
if namespace_match:
xml_namespace = namespace_match.group(1)
@@ -1062,7 +1057,7 @@ def parse_aggregate_report_xml(
) from error
def extract_report(content: Union[bytes, str, BinaryIO]) -> str:
def extract_report(content: bytes | str | BinaryIO) -> str:
"""
Extracts text from a zip or gzip file, as a base64-encoded string,
file-like object, or bytes.
@@ -1075,7 +1070,7 @@ def extract_report(content: Union[bytes, str, BinaryIO]) -> str:
str: The extracted text
"""
file_object: Optional[BinaryIO] = None
file_object: BinaryIO | None = None
header: bytes
try:
if isinstance(content, str):
@@ -1152,7 +1147,7 @@ def extract_report(content: Union[bytes, str, BinaryIO]) -> str:
def extract_report_from_file_path(
file_path: Union[str, bytes, os.PathLike[str], os.PathLike[bytes]],
file_path: str | bytes | os.PathLike[str] | os.PathLike[bytes],
) -> str:
"""Extracts report from a file at the given file_path"""
try:
@@ -1163,17 +1158,17 @@ def extract_report_from_file_path(
def parse_aggregate_report_file(
_input: Union[str, bytes, BinaryIO],
_input: str | bytes | BinaryIO,
*,
offline: bool = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
ip_db_path: Optional[str] = None,
nameservers: Optional[list[str]] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
ip_db_path: str | None = None,
nameservers: list[str] | None = None,
dns_timeout: float = DEFAULT_DNS_TIMEOUT,
dns_retries: int = DEFAULT_DNS_MAX_RETRIES,
keep_alive: Optional[Callable] = None,
keep_alive: Callable | None = None,
normalize_timespan_threshold_hours: float = 24.0,
) -> AggregateReport:
"""Parses a file at the given path, a file-like object. or bytes as an
@@ -1219,7 +1214,7 @@ def parse_aggregate_report_file(
def parsed_aggregate_reports_to_csv_rows(
reports: Union[AggregateReport, list[AggregateReport]],
reports: AggregateReport | list[AggregateReport],
) -> list[dict[str, Any]]:
"""
Converts one or more parsed aggregate reports to list of dicts in flat CSV
@@ -1354,7 +1349,7 @@ def parsed_aggregate_reports_to_csv_rows(
def parsed_aggregate_reports_to_csv(
reports: Union[AggregateReport, list[AggregateReport]],
reports: AggregateReport | list[AggregateReport],
) -> str:
"""
Converts one or more parsed aggregate reports to flat CSV format, including
@@ -1433,11 +1428,11 @@ def parse_failure_report(
msg_date: datetime,
*,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
offline: bool = False,
ip_db_path: Optional[str] = None,
nameservers: Optional[list[str]] = None,
ip_db_path: str | None = None,
nameservers: list[str] | None = None,
dns_timeout: float = DEFAULT_DNS_TIMEOUT,
dns_retries: int = DEFAULT_DNS_MAX_RETRIES,
strip_attachment_payloads: bool = False,
@@ -1593,7 +1588,7 @@ def parse_failure_report(
def parsed_failure_reports_to_csv_rows(
reports: Union[FailureReport, list[FailureReport]],
reports: FailureReport | list[FailureReport],
) -> list[dict[str, Any]]:
"""
Converts one or more parsed failure reports to a list of dicts in flat CSV
@@ -1634,7 +1629,7 @@ def parsed_failure_reports_to_csv_rows(
def parsed_failure_reports_to_csv(
reports: Union[FailureReport, list[FailureReport]],
reports: FailureReport | list[FailureReport],
) -> str:
"""
Converts one or more parsed failure reports to flat CSV format, including
@@ -1691,18 +1686,18 @@ def parsed_failure_reports_to_csv(
def parse_report_email(
input_: Union[bytes, str],
input_: bytes | str,
*,
offline: bool = False,
ip_db_path: Optional[str] = None,
ip_db_path: str | None = None,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
nameservers: Optional[list[str]] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
nameservers: list[str] | None = None,
dns_timeout: float = DEFAULT_DNS_TIMEOUT,
dns_retries: int = DEFAULT_DNS_MAX_RETRIES,
strip_attachment_payloads: bool = False,
keep_alive: Optional[Callable] = None,
keep_alive: Callable | None = None,
normalize_timespan_threshold_hours: float = 24.0,
) -> ParsedReport:
"""
@@ -1729,11 +1724,11 @@ def parse_report_email(
* ``report_type``: ``aggregate`` or ``failure``
* ``report``: The parsed report
"""
result: Optional[ParsedReport] = None
result: ParsedReport | None = None
msg_date: datetime = datetime.now(timezone.utc)
try:
input_data: Union[str, bytes, bytearray, memoryview] = input_
input_data: str | bytes | bytearray | memoryview = input_
if isinstance(input_data, (bytes, bytearray, memoryview)):
input_bytes = bytes(input_data)
if is_outlook_msg(input_bytes):
@@ -1919,7 +1914,7 @@ def _looks_like_email(text: str) -> bool:
def _describe_parse_failure(
content: Union[str, bytes],
content: str | bytes,
aggregate_error: InvalidAggregateReport,
smtp_tls_error: InvalidSMTPTLSReport,
email_error: InvalidDMARCReport,
@@ -1952,18 +1947,18 @@ def _describe_parse_failure(
def parse_report_file(
input_: Union[bytes, str, os.PathLike[str], os.PathLike[bytes], BinaryIO],
input_: bytes | str | os.PathLike[str] | os.PathLike[bytes] | BinaryIO,
*,
nameservers: Optional[list[str]] = None,
nameservers: list[str] | None = None,
dns_timeout: float = DEFAULT_DNS_TIMEOUT,
dns_retries: int = DEFAULT_DNS_MAX_RETRIES,
strip_attachment_payloads: bool = False,
ip_db_path: Optional[str] = None,
ip_db_path: str | None = None,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
offline: bool = False,
keep_alive: Optional[Callable] = None,
keep_alive: Callable | None = None,
normalize_timespan_threshold_hours: float = 24,
) -> ParsedReport:
"""Parses a DMARC aggregate or failure file at the given path, a
@@ -2004,7 +1999,7 @@ def parse_report_file(
if content.startswith(MAGIC_ZIP) or content.startswith(MAGIC_GZIP):
content = extract_report(content)
results: Optional[ParsedReport] = None
results: ParsedReport | None = None
# parse_report_file tries the three report formats in turn. When all three
# reject the input, keep each format's specific error so the final message
@@ -2060,14 +2055,14 @@ def parse_report_file(
def get_dmarc_reports_from_mbox(
input_: str,
*,
nameservers: Optional[list[str]] = None,
nameservers: list[str] | None = None,
dns_timeout: float = DEFAULT_DNS_TIMEOUT,
dns_retries: int = DEFAULT_DNS_MAX_RETRIES,
strip_attachment_payloads: bool = False,
ip_db_path: Optional[str] = None,
ip_db_path: str | None = None,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
offline: bool = False,
normalize_timespan_threshold_hours: float = 24.0,
) -> ParsingResults:
@@ -2203,18 +2198,18 @@ def get_dmarc_reports_from_mailbox(
archive_folder: str = "Archive",
delete: bool = False,
test: bool = False,
ip_db_path: Optional[str] = None,
ip_db_path: str | None = None,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
offline: bool = False,
nameservers: Optional[list[str]] = None,
nameservers: list[str] | None = None,
dns_timeout: float = 6.0,
dns_retries: int = DEFAULT_DNS_MAX_RETRIES,
strip_attachment_payloads: bool = False,
results: Optional[ParsingResults] = None,
results: ParsingResults | None = None,
batch_size: int = 10,
since: Optional[Union[datetime, date, str]] = None,
since: datetime | date | str | None = None,
create_folders: bool = True,
normalize_timespan_threshold_hours: float = 24,
) -> ParsingResults:
@@ -2257,7 +2252,7 @@ def get_dmarc_reports_from_mailbox(
raise ValueError("Must supply a connection")
# current_time useful to fetch_messages later in the program
current_time: Optional[Union[datetime, date, str]] = None
current_time: datetime | date | str | None = None
aggregate_reports: list[AggregateReport] = []
failure_reports: list[FailureReport] = []
@@ -2344,7 +2339,7 @@ def get_dmarc_reports_from_mailbox(
i + 1, message_limit, msg_uid
)
)
message_id: Union[int, str]
message_id: int | str
if isinstance(connection, IMAPConnection):
message_id = int(msg_uid)
msg_content = connection.fetch_message(message_id)
@@ -2546,19 +2541,19 @@ def watch_inbox(
delete: bool = False,
test: bool = False,
check_timeout: int = 30,
ip_db_path: Optional[str] = None,
ip_db_path: str | None = None,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
reverse_dns_map_path: str | None = None,
reverse_dns_map_url: str | None = None,
offline: bool = False,
nameservers: Optional[list[str]] = None,
nameservers: list[str] | None = None,
dns_timeout: float = 6.0,
dns_retries: int = DEFAULT_DNS_MAX_RETRIES,
strip_attachment_payloads: bool = False,
batch_size: int = 10,
since: Optional[Union[datetime, date, str]] = None,
since: datetime | date | str | None = None,
normalize_timespan_threshold_hours: float = 24,
config_reloading: Optional[Callable] = None,
config_reloading: Callable | None = None,
):
"""
Watches the mailbox for new messages and
@@ -2629,11 +2624,9 @@ def watch_inbox(
def append_json(
filename: str,
reports: Union[
Sequence[AggregateReport],
Sequence[FailureReport],
Sequence[SMTPTLSReport],
],
reports: Sequence[AggregateReport]
| Sequence[FailureReport]
| Sequence[SMTPTLSReport],
) -> None:
"""Append ``reports`` to a JSON array on disk, creating the file
if needed.
@@ -2820,18 +2813,18 @@ def email_results(
results: ParsingResults,
host: str,
mail_from: str,
mail_to: Optional[list[str]],
mail_to: list[str] | None,
*,
mail_cc: Optional[list[str]] = None,
mail_bcc: Optional[list[str]] = None,
mail_cc: list[str] | None = None,
mail_bcc: list[str] | None = None,
port: int = 0,
require_encryption: bool = False,
verify: bool = True,
username: Optional[str] = None,
password: Optional[str] = None,
subject: Optional[str] = None,
attachment_filename: Optional[str] = None,
message: Optional[str] = None,
username: str | None = None,
password: str | None = None,
subject: str | None = None,
attachment_filename: str | None = None,
message: str | None = None,
):
"""
Emails parsing results as a zip file
+31 -31
View File
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Optional, Union
from typing import Any
from elasticsearch.helpers import reindex
from elasticsearch_dsl import (
@@ -125,7 +125,7 @@ class _AggregateReportDoc(Document):
domain: str,
selector: str,
result: _DKIMResult,
human_result: Optional[str] = None,
human_result: str | None = None,
):
self.dkim_results.append(
_DKIMResult(
@@ -141,7 +141,7 @@ class _AggregateReportDoc(Document):
domain: str,
scope: str,
result: _SPFResult,
human_result: Optional[str] = None,
human_result: str | None = None,
):
self.spf_results.append(
_SPFResult(
@@ -254,15 +254,15 @@ class _SMTPTLSPolicyDoc(InnerDoc):
def add_failure_details(
self,
result_type: Optional[str] = None,
ip_address: Optional[str] = None,
receiving_ip: Optional[str] = None,
receiving_mx_helo: Optional[str] = None,
failed_session_count: Optional[int] = None,
sending_mta_ip: Optional[str] = None,
receiving_mx_hostname: Optional[str] = None,
additional_information_uri: Optional[str] = None,
failure_reason_code: Union[str, int, None] = None,
result_type: str | None = None,
ip_address: str | None = None,
receiving_ip: str | None = None,
receiving_mx_helo: str | None = None,
failed_session_count: int | None = None,
sending_mta_ip: str | None = None,
receiving_mx_hostname: str | None = None,
additional_information_uri: str | None = None,
failure_reason_code: str | int | None = None,
):
_details = _SMTPTLSFailureDetailsDoc(
result_type=result_type,
@@ -297,9 +297,9 @@ class _SMTPTLSReportDoc(Document):
successful_session_count: int,
failed_session_count: int,
*,
policy_string: Optional[str] = None,
mx_host_patterns: Optional[list[str]] = None,
failure_details: Optional[str] = None,
policy_string: str | None = None,
mx_host_patterns: list[str] | None = None,
failure_details: str | None = None,
):
self.policies.append(
policy_type=policy_type,
@@ -317,14 +317,14 @@ class AlreadySaved(ValueError):
def set_hosts(
hosts: Union[str, list[str]],
hosts: str | list[str],
*,
use_ssl: bool = False,
ssl_cert_path: Optional[str] = None,
ssl_cert_path: str | None = None,
skip_certificate_verification: bool = False,
username: Optional[str] = None,
password: Optional[str] = None,
api_key: Optional[str] = None,
username: str | None = None,
password: str | None = None,
api_key: str | None = None,
timeout: float = 60.0,
serverless: bool = False,
):
@@ -366,7 +366,7 @@ def set_hosts(
connections.create_connection(**conn_params)
def create_indexes(names: list[str], settings: Optional[dict[str, Any]] = None):
def create_indexes(names: list[str], settings: dict[str, Any] | None = None):
"""
Create Elasticsearch indexes
@@ -400,8 +400,8 @@ def create_indexes(names: list[str], settings: Optional[dict[str, Any]] = None):
def migrate_indexes(
aggregate_indexes: Optional[list[str]] = None,
failure_indexes: Optional[list[str]] = None,
aggregate_indexes: list[str] | None = None,
failure_indexes: list[str] | None = None,
):
"""
Updates index mappings
@@ -450,9 +450,9 @@ def migrate_indexes(
def save_aggregate_report_to_elasticsearch(
aggregate_report: dict[str, Any],
index_suffix: Optional[str] = None,
index_prefix: Optional[str] = None,
monthly_indexes: Optional[bool] = False,
index_suffix: str | None = None,
index_prefix: str | None = None,
monthly_indexes: bool | None = False,
number_of_shards: int = 1,
number_of_replicas: int = 0,
):
@@ -626,9 +626,9 @@ def save_aggregate_report_to_elasticsearch(
def save_failure_report_to_elasticsearch(
failure_report: dict[str, Any],
index_suffix: Optional[Any] = None,
index_prefix: Optional[str] = None,
monthly_indexes: Optional[bool] = False,
index_suffix: Any | None = None,
index_prefix: str | None = None,
monthly_indexes: bool | None = False,
number_of_shards: int = 1,
number_of_replicas: int = 0,
):
@@ -808,8 +808,8 @@ def save_failure_report_to_elasticsearch(
def save_smtp_tls_report_to_elasticsearch(
report: dict[str, Any],
index_suffix: Optional[str] = None,
index_prefix: Optional[str] = None,
index_suffix: str | None = None,
index_prefix: str | None = None,
monthly_indexes: bool = False,
number_of_shards: int = 1,
number_of_replicas: int = 0,
+8 -8
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import json
from ssl import SSLContext, create_default_context
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any
from kafka import KafkaProducer
from kafka.errors import UnknownTopicOrPartitionError
@@ -34,10 +34,10 @@ class KafkaClient(object):
self,
kafka_hosts: list[str],
*,
ssl: Optional[bool] = False,
username: Optional[str] = None,
password: Optional[str] = None,
ssl_context: Optional[SSLContext] = None,
ssl: bool | None = False,
username: str | None = None,
password: str | None = None,
ssl_context: SSLContext | None = None,
):
"""
Initializes the Kafka client
@@ -111,7 +111,7 @@ class KafkaClient(object):
def save_aggregate_reports_to_kafka(
self,
aggregate_reports: Union[dict[str, Any], list[dict[str, Any]]],
aggregate_reports: dict[str, Any] | list[dict[str, Any]],
aggregate_topic: str,
):
"""
@@ -156,7 +156,7 @@ class KafkaClient(object):
def save_failure_reports_to_kafka(
self,
failure_reports: Union[dict[str, Any], list[dict[str, Any]]],
failure_reports: dict[str, Any] | list[dict[str, Any]],
failure_topic: str,
):
"""
@@ -193,7 +193,7 @@ class KafkaClient(object):
def save_smtp_tls_reports_to_kafka(
self,
smtp_tls_reports: Union[list[dict[str, Any]], dict[str, Any]],
smtp_tls_reports: list[dict[str, Any]] | dict[str, Any],
smtp_tls_topic: str,
):
"""
+32 -32
View File
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Optional, Union
from typing import Any
import boto3
from opensearchpy import (
@@ -116,7 +116,7 @@ class _AggregateReportDoc(Document):
domain: str,
selector: str,
result: _DKIMResult,
human_result: Optional[str] = None,
human_result: str | None = None,
):
self.dkim_results.append(
_DKIMResult(
@@ -132,7 +132,7 @@ class _AggregateReportDoc(Document):
domain: str,
scope: str,
result: _SPFResult,
human_result: Optional[str] = None,
human_result: str | None = None,
):
self.spf_results.append(
_SPFResult(
@@ -245,15 +245,15 @@ class _SMTPTLSPolicyDoc(InnerDoc):
def add_failure_details(
self,
result_type: Optional[str] = None,
ip_address: Optional[str] = None,
receiving_ip: Optional[str] = None,
receiving_mx_helo: Optional[str] = None,
failed_session_count: Optional[int] = None,
sending_mta_ip: Optional[str] = None,
receiving_mx_hostname: Optional[str] = None,
additional_information_uri: Optional[str] = None,
failure_reason_code: Union[str, int, None] = None,
result_type: str | None = None,
ip_address: str | None = None,
receiving_ip: str | None = None,
receiving_mx_helo: str | None = None,
failed_session_count: int | None = None,
sending_mta_ip: str | None = None,
receiving_mx_hostname: str | None = None,
additional_information_uri: str | None = None,
failure_reason_code: str | int | None = None,
):
_details = _SMTPTLSFailureDetailsDoc(
result_type=result_type,
@@ -288,9 +288,9 @@ class _SMTPTLSReportDoc(Document):
successful_session_count: int,
failed_session_count: int,
*,
policy_string: Optional[str] = None,
mx_host_patterns: Optional[list[str]] = None,
failure_details: Optional[str] = None,
policy_string: str | None = None,
mx_host_patterns: list[str] | None = None,
failure_details: str | None = None,
):
self.policies.append(
policy_type=policy_type,
@@ -308,17 +308,17 @@ class AlreadySaved(ValueError):
def set_hosts(
hosts: Union[str, list[str]],
hosts: str | list[str],
*,
use_ssl: Optional[bool] = False,
ssl_cert_path: Optional[str] = None,
use_ssl: bool | None = False,
ssl_cert_path: str | None = None,
skip_certificate_verification: bool = False,
username: Optional[str] = None,
password: Optional[str] = None,
api_key: Optional[str] = None,
timeout: Optional[float] = 60.0,
username: str | None = None,
password: str | None = None,
api_key: str | None = None,
timeout: float | None = 60.0,
auth_type: str = "basic",
aws_region: Optional[str] = None,
aws_region: str | None = None,
aws_service: str = "es",
):
"""
@@ -376,7 +376,7 @@ def set_hosts(
connections.create_connection(**conn_params)
def create_indexes(names: list[str], settings: Optional[dict[str, Any]] = None):
def create_indexes(names: list[str], settings: dict[str, Any] | None = None):
"""
Create OpenSearch indexes
@@ -400,8 +400,8 @@ def create_indexes(names: list[str], settings: Optional[dict[str, Any]] = None):
def migrate_indexes(
aggregate_indexes: Optional[list[str]] = None,
failure_indexes: Optional[list[str]] = None,
aggregate_indexes: list[str] | None = None,
failure_indexes: list[str] | None = None,
):
"""
Updates index mappings
@@ -450,8 +450,8 @@ def migrate_indexes(
def save_aggregate_report_to_opensearch(
aggregate_report: dict[str, Any],
index_suffix: Optional[str] = None,
index_prefix: Optional[str] = None,
index_suffix: str | None = None,
index_prefix: str | None = None,
monthly_indexes: bool = False,
number_of_shards: int = 1,
number_of_replicas: int = 0,
@@ -626,8 +626,8 @@ def save_aggregate_report_to_opensearch(
def save_failure_report_to_opensearch(
failure_report: dict[str, Any],
index_suffix: Optional[str] = None,
index_prefix: Optional[str] = None,
index_suffix: str | None = None,
index_prefix: str | None = None,
monthly_indexes: bool = False,
number_of_shards: int = 1,
number_of_replicas: int = 0,
@@ -806,8 +806,8 @@ def save_failure_report_to_opensearch(
def save_smtp_tls_report_to_opensearch(
report: dict[str, Any],
index_suffix: Optional[str] = None,
index_prefix: Optional[str] = None,
index_suffix: str | None = None,
index_prefix: str | None = None,
monthly_indexes: bool = False,
number_of_shards: int = 1,
number_of_replicas: int = 0,
+13 -13
View File
@@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
# LiteralString requires Python >= 3.11, so only import it for type checking
@@ -39,7 +39,7 @@ _PSYCOPG_INSTALL_HINT = (
# Aggregate *record* interval_begin/end and SMTP-TLS begin/end are already
# **UTC** naive strings, so they only need a ``+00`` suffix via
# ``_ensure_utc_suffix``. Using the wrong helper silently shifts timestamps.
def _ensure_utc_suffix(value: Optional[str]) -> Optional[str]:
def _ensure_utc_suffix(value: str | None) -> str | None:
"""Append ``+00`` to a timestamp string if it lacks timezone info.
Several parsers produce ``YYYY-MM-DD HH:MM:SS`` format strings that
@@ -52,7 +52,7 @@ def _ensure_utc_suffix(value: Optional[str]) -> Optional[str]:
return value
def _naive_local_to_timestamptz(value: Optional[str]) -> Optional[str]:
def _naive_local_to_timestamptz(value: str | None) -> str | None:
"""Convert a naive local-time string to an ISO 8601 string with offset.
``timestamp_to_human()`` produces ``YYYY-MM-DD HH:MM:SS`` in
@@ -72,7 +72,7 @@ def _naive_local_to_timestamptz(value: Optional[str]) -> Optional[str]:
return aware.isoformat()
def _normalize_arrival_date(value: Optional[str]) -> Optional[str]:
def _normalize_arrival_date(value: str | None) -> str | None:
"""Normalize a failure-report ``arrival_date`` for safe TIMESTAMPTZ insert.
The arrival date may be an RFC 2822 string (e.g.
@@ -92,12 +92,12 @@ def _normalize_arrival_date(value: Optional[str]) -> Optional[str]:
def _contact_info_to_text(
value: Union[str, list, None],
) -> Optional[str]:
value: str | list | None,
) -> str | None:
"""Ensure ``contact_info`` is a plain string.
The TLS-RPT ``contact-info`` field is normally a single string, but
the TypedDict allows ``Union[str, List[str]]``. If a list is
the TypedDict allows ``str | list[str]``. If a list is
encountered, join the entries so they fit into a ``TEXT`` column.
"""
if value is None:
@@ -125,12 +125,12 @@ class PostgreSQLClient:
def __init__(
self,
connection_string: Optional[str] = None,
host: Optional[str] = None,
connection_string: str | None = None,
host: str | None = None,
port: int = 5432,
user: Optional[str] = None,
password: Optional[str] = None,
database: Optional[str] = None,
user: str | None = None,
password: str | None = None,
database: str | None = None,
) -> None:
"""
Initializes the PostgreSQLClient and opens a database connection.
@@ -160,7 +160,7 @@ class PostgreSQLClient:
self._password = password
self._database = database
self._conn: Optional[psycopg.Connection] = None
self._conn: psycopg.Connection | None = None
self._connect()
def _connect(self) -> psycopg.Connection:
+8 -9
View File
@@ -6,7 +6,6 @@ import codecs
import os
import sys
import shutil
from typing import List, Optional, Tuple
"""
Locates and optionally corrects bad UTF-8 bytes in a file.
@@ -98,7 +97,7 @@ def scan_file_for_utf8_errors(path: str, context: int, limit: int):
# -------------------------
def detect_encoding_text(path: str) -> Tuple[str, str]:
def detect_encoding_text(path: str) -> tuple[str, str]:
"""
Use charset-normalizer to detect file encoding.
Return (encoding_name, decoded_text). Falls back to cp1252 if needed.
@@ -128,7 +127,7 @@ def detect_encoding_text(path: str) -> Tuple[str, str]:
def convert_to_utf8(
src_path: str, out_path: str, src_encoding: Optional[str] = None
src_path: str, out_path: str, src_encoding: str | None = None
) -> str:
"""
Convert an entire file to UTF-8 (re-decoding everything).
@@ -155,7 +154,7 @@ def convert_to_utf8(
return used
def verify_utf8_file(path: str) -> Tuple[bool, str]:
def verify_utf8_file(path: str) -> tuple[bool, str]:
try:
with open(path, "rb") as fb:
fb.read().decode("utf-8", errors="strict")
@@ -182,17 +181,17 @@ def iter_lines_with_offsets(b: bytes):
yield b[start:], start
def detect_probable_fallbacks() -> List[str]:
def detect_probable_fallbacks() -> list[str]:
# Good defaults for Western/Portuguese text
return ["cp1252", "iso-8859-1", "iso-8859-15"]
def repair_mixed_utf8_line(line: bytes, base_offset: int, fallback_chain: List[str]):
def repair_mixed_utf8_line(line: bytes, base_offset: int, fallback_chain: list[str]):
"""
Strictly validate UTF-8 and fix *only* the exact offending byte when an error occurs.
This avoids touching adjacent valid UTF-8 (prevents mojibake like 'é').
"""
out_fragments: List[str] = []
out_fragments: list[str] = []
fixes = []
pos = 0
n = len(line)
@@ -253,7 +252,7 @@ def repair_mixed_utf8_line(line: bytes, base_offset: int, fallback_chain: List[s
def targeted_fix_to_utf8(
src_path: str,
out_path: str,
fallback_chain: List[str],
fallback_chain: list[str],
dry_run: bool,
max_fixes: int,
):
@@ -261,7 +260,7 @@ def targeted_fix_to_utf8(
data = fb.read()
total_fixes = 0
repaired_lines: List[str] = []
repaired_lines: list[str] = []
line_no = 0
max_val = max_fixes if max_fixes != 0 else float("inf")
+15 -15
View File
@@ -6,7 +6,7 @@ import os
import csv
import re
from pathlib import Path
from typing import Mapping, Iterable, Optional, Collection, Union, List, Dict
from collections.abc import Mapping, Iterable, Collection
_TYPES_LIST_RE = re.compile(
@@ -15,9 +15,9 @@ _TYPES_LIST_RE = re.compile(
)
def _parse_types_block(block: str, source: str) -> List[str]:
def _parse_types_block(block: str, source: str) -> list[str]:
"""Extract type names from the raw text between the marker comments."""
types: List[str] = []
types: list[str] = []
for line in block.splitlines():
stripped = line.strip()
if not stripped:
@@ -30,7 +30,7 @@ def _parse_types_block(block: str, source: str) -> List[str]:
return types
def normalize_types_in_readme(readme_path: Union[str, Path]) -> List[str]:
def normalize_types_in_readme(readme_path: str | Path) -> list[str]:
"""Validate, normalize, and load the authoritative `type` list from README.md.
Trims leading/trailing whitespace from each item, deduplicates
@@ -52,7 +52,7 @@ def normalize_types_in_readme(readme_path: Union[str, Path]) -> List[str]:
if not raw_types:
raise ValueError(f"{path}: types-list block is empty")
seen: Dict[str, str] = {}
seen: dict[str, str] = {}
for t in raw_types:
key = t.lower()
if key in seen and seen[key] != t:
@@ -71,7 +71,7 @@ def normalize_types_in_readme(readme_path: Union[str, Path]) -> List[str]:
return normalized
def load_types_from_readme(readme_path: Union[str, Path]) -> List[str]:
def load_types_from_readme(readme_path: str | Path) -> list[str]:
"""Read the authoritative `type` list out of README.md without rewriting.
Use `normalize_types_in_readme` to additionally sort, dedupe, and
@@ -98,15 +98,15 @@ class CSVValidationError(Exception):
def sort_csv(
filepath: Union[str, Path],
filepath: str | Path,
field: str,
*,
sort_field_value_must_be_unique: bool = True,
strip_whitespace: bool = True,
fields_to_lowercase: Optional[Iterable[str]] = None,
fields_to_lowercase: Iterable[str] | None = None,
case_insensitive_sort: bool = False,
required_fields: Optional[Iterable[str]] = None,
allowed_values: Optional[Mapping[str, Collection[str]]] = None,
required_fields: Iterable[str] | None = None,
allowed_values: Mapping[str, Collection[str]] | None = None,
) -> None:
"""
Read a CSV, optionally normalize rows (strip whitespace, lowercase certain fields),
@@ -138,7 +138,7 @@ def sort_csv(
)
rows = list(reader)
def normalize_row(row: Dict[str, str]) -> None:
def normalize_row(row: dict[str, str]) -> None:
if strip_whitespace:
for k, v in row.items():
if isinstance(v, str):
@@ -148,7 +148,7 @@ def sort_csv(
row[fld] = row[fld].lower()
def validate_row(
row: Dict[str, str], sort_field: str, line_no: int, errors: list[str]
row: dict[str, str], sort_field: str, line_no: int, errors: list[str]
) -> None:
if sort_field_value_must_be_unique:
if row[sort_field] in seen_sort_field_values:
@@ -178,7 +178,7 @@ def sort_csv(
if errors:
raise CSVValidationError(errors)
def sort_key(r: Dict[str, str]):
def sort_key(r: dict[str, str]):
v = r.get(field, "")
if isinstance(v, str) and case_insensitive_sort:
return v.casefold()
@@ -193,14 +193,14 @@ def sort_csv(
def sort_list_file(
filepath: Union[str, Path],
filepath: str | Path,
*,
lowercase: bool = True,
strip: bool = True,
deduplicate: bool = True,
remove_blank_lines: bool = True,
ending_newline: bool = True,
newline: Optional[str] = "\n",
newline: str | None = "\n",
):
"""Read a list from a file, sort it, optionally strip and deduplicate the values,
then write that list back to the file.
+6 -6
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import json
import socket
from typing import Any, Union
from typing import Any
from urllib.parse import urlparse
import requests
@@ -59,7 +59,7 @@ class HECClient(object):
self.session = requests.Session()
self.timeout = timeout
self.verify = verify
self._common_data: dict[str, Union[str, int, float, dict]] = dict(
self._common_data: dict[str, str | int | float | dict] = dict(
host=self.host, source=self.source, index=self.index
)
@@ -72,7 +72,7 @@ class HECClient(object):
def save_aggregate_reports_to_splunk(
self,
aggregate_reports: Union[list[dict[str, Any]], dict[str, Any]],
aggregate_reports: list[dict[str, Any]] | dict[str, Any],
):
"""
Saves aggregate DMARC reports to Splunk
@@ -93,7 +93,7 @@ class HECClient(object):
json_str = ""
for report in aggregate_reports:
for record in report["records"]:
new_report: dict[str, Union[str, int, float, dict]] = dict()
new_report: dict[str, str | int | float | dict] = dict()
for metadata in report["report_metadata"]:
new_report[metadata] = report["report_metadata"][metadata]
new_report["interval_begin"] = record["interval_begin"]
@@ -143,7 +143,7 @@ class HECClient(object):
def save_failure_reports_to_splunk(
self,
failure_reports: Union[list[dict[str, Any]], dict[str, Any]],
failure_reports: list[dict[str, Any]] | dict[str, Any],
):
"""
Saves failure DMARC reports to Splunk
@@ -181,7 +181,7 @@ class HECClient(object):
raise SplunkError(response["text"])
def save_smtp_tls_reports_to_splunk(
self, reports: Union[list[dict[str, Any]], dict[str, Any]]
self, reports: list[dict[str, Any]] | dict[str, Any]
):
"""
Saves aggregate DMARC reports to Splunk
+6 -7
View File
@@ -9,7 +9,6 @@ import logging.handlers
import socket
import ssl
import time
from typing import Optional
from parsedmarc import (
parsed_aggregate_reports_to_csv_rows,
@@ -27,9 +26,9 @@ class SyslogClient(object):
server_name: str,
server_port: int,
protocol: str = "udp",
cafile_path: Optional[str] = None,
certfile_path: Optional[str] = None,
keyfile_path: Optional[str] = None,
cafile_path: str | None = None,
certfile_path: str | None = None,
keyfile_path: str | None = None,
timeout: float = 5.0,
retry_attempts: int = 3,
retry_delay: int = 5,
@@ -77,9 +76,9 @@ class SyslogClient(object):
server_name: str,
server_port: int,
protocol: str,
cafile_path: Optional[str],
certfile_path: Optional[str],
keyfile_path: Optional[str],
cafile_path: str | None,
certfile_path: str | None,
keyfile_path: str | None,
timeout: float,
retry_attempts: int,
retry_delay: int,
+68 -67
View File
@@ -1,11 +1,12 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
from typing import Any, Literal, TypedDict
# NOTE: This module is intentionally Python 3.10 compatible.
# - No PEP 604 unions (A | B)
# - No typing.NotRequired / Required (3.11+) to avoid an extra dependency.
# For optional keys, use total=False TypedDicts.
# NOTE: This module targets Python 3.10.
# - PEP 604 unions (A | B) and PEP 585 generics (list[str]) are used; both are
# available in 3.10.
# - No typing.NotRequired / Required (3.11+); for optional TypedDict keys, use
# total=False TypedDicts.
ReportType = Literal["aggregate", "failure", "smtp_tls"]
@@ -14,14 +15,14 @@ ReportType = Literal["aggregate", "failure", "smtp_tls"]
class AggregateReportMetadata(TypedDict):
org_name: str
org_email: str
org_extra_contact_info: Optional[str]
org_extra_contact_info: str | None
report_id: str
begin_date: str
end_date: str
timespan_requires_normalization: bool
original_timespan_seconds: int
errors: List[str]
generator: Optional[str]
errors: list[str]
generator: str | None
class AggregatePolicyPublished(TypedDict):
@@ -30,23 +31,23 @@ class AggregatePolicyPublished(TypedDict):
aspf: str
p: str
sp: str
pct: Optional[str]
fo: Optional[str]
np: Optional[str]
testing: Optional[str]
discovery_method: Optional[str]
pct: str | None
fo: str | None
np: str | None
testing: str | None
discovery_method: str | None
class IPSourceInfo(TypedDict):
ip_address: str
country: Optional[str]
reverse_dns: Optional[str]
base_domain: Optional[str]
name: Optional[str]
type: Optional[str]
asn: Optional[int]
as_name: Optional[str]
as_domain: Optional[str]
country: str | None
reverse_dns: str | None
base_domain: str | None
name: str | None
type: str | None
asn: int | None
as_name: str | None
as_domain: str | None
class AggregateAlignment(TypedDict):
@@ -57,39 +58,39 @@ class AggregateAlignment(TypedDict):
class AggregateIdentifiers(TypedDict):
header_from: str
envelope_from: Optional[str]
envelope_to: Optional[str]
envelope_from: str | None
envelope_to: str | None
class AggregatePolicyOverrideReason(TypedDict):
type: Optional[str]
comment: Optional[str]
type: str | None
comment: str | None
class AggregateAuthResultDKIM(TypedDict):
domain: str
result: str
selector: str
human_result: Optional[str]
human_result: str | None
class AggregateAuthResultSPF(TypedDict):
domain: str
result: str
scope: str
human_result: Optional[str]
human_result: str | None
class AggregateAuthResults(TypedDict):
dkim: List[AggregateAuthResultDKIM]
spf: List[AggregateAuthResultSPF]
dkim: list[AggregateAuthResultDKIM]
spf: list[AggregateAuthResultSPF]
class AggregatePolicyEvaluated(TypedDict):
disposition: str
dkim: str
spf: str
policy_override_reasons: List[AggregatePolicyOverrideReason]
policy_override_reasons: list[AggregatePolicyOverrideReason]
class AggregateRecord(TypedDict):
@@ -106,23 +107,23 @@ class AggregateRecord(TypedDict):
class AggregateReport(TypedDict):
xml_schema: str
xml_namespace: Optional[str]
xml_namespace: str | None
report_metadata: AggregateReportMetadata
policy_published: AggregatePolicyPublished
records: List[AggregateRecord]
records: list[AggregateRecord]
class EmailAddress(TypedDict):
display_name: Optional[str]
display_name: str | None
address: str
local: Optional[str]
domain: Optional[str]
local: str | None
domain: str | None
class EmailAttachment(TypedDict, total=False):
filename: Optional[str]
mail_content_type: Optional[str]
sha256: Optional[str]
filename: str | None
mail_content_type: str | None
sha256: str | None
ParsedEmail = TypedDict(
@@ -130,16 +131,16 @@ ParsedEmail = TypedDict(
{
# This is a lightly-specified version of mailsuite/mailparser JSON.
# It focuses on the fields parsedmarc uses in failure report handling.
"headers": Dict[str, Any],
"subject": Optional[str],
"filename_safe_subject": Optional[str],
"date": Optional[str],
"headers": dict[str, Any],
"subject": str | None,
"filename_safe_subject": str | None,
"date": str | None,
"from": EmailAddress,
"to": List[EmailAddress],
"cc": List[EmailAddress],
"bcc": List[EmailAddress],
"attachments": List[EmailAttachment],
"body": Optional[str],
"to": list[EmailAddress],
"cc": list[EmailAddress],
"bcc": list[EmailAddress],
"attachments": list[EmailAttachment],
"body": str | None,
"has_defects": bool,
"defects": Any,
"defects_categories": Any,
@@ -149,19 +150,19 @@ ParsedEmail = TypedDict(
class FailureReport(TypedDict):
feedback_type: Optional[str]
user_agent: Optional[str]
version: Optional[str]
original_envelope_id: Optional[str]
original_mail_from: Optional[str]
original_rcpt_to: Optional[str]
feedback_type: str | None
user_agent: str | None
version: str | None
original_envelope_id: str | None
original_mail_from: str | None
original_rcpt_to: str | None
arrival_date: str
arrival_date_utc: str
authentication_results: Optional[str]
delivery_result: Optional[str]
auth_failure: List[str]
authentication_mechanisms: List[str]
dkim_domain: Optional[str]
authentication_results: str | None
delivery_result: str | None
auth_failure: list[str]
authentication_mechanisms: list[str]
dkim_domain: str | None
reported_domain: str
sample_headers_only: bool
source: IPSourceInfo
@@ -196,18 +197,18 @@ class SMTPTLSPolicySummary(TypedDict):
class SMTPTLSPolicy(SMTPTLSPolicySummary, total=False):
policy_strings: List[str]
mx_host_patterns: List[str]
failure_details: List[SMTPTLSFailureDetailsOptional]
policy_strings: list[str]
mx_host_patterns: list[str]
failure_details: list[SMTPTLSFailureDetailsOptional]
class SMTPTLSReport(TypedDict):
organization_name: str
begin_date: str
end_date: str
contact_info: Union[str, List[str]]
contact_info: str | list[str]
report_id: str
policies: List[SMTPTLSPolicy]
policies: list[SMTPTLSPolicy]
class AggregateParsedReport(TypedDict):
@@ -229,10 +230,10 @@ class SMTPTLSParsedReport(TypedDict):
report: SMTPTLSReport
ParsedReport = Union[AggregateParsedReport, FailureParsedReport, SMTPTLSParsedReport]
ParsedReport = AggregateParsedReport | FailureParsedReport | SMTPTLSParsedReport
class ParsingResults(TypedDict):
aggregate_reports: List[AggregateReport]
failure_reports: List[FailureReport]
smtp_tls_reports: List[SMTPTLSReport]
aggregate_reports: list[AggregateReport]
failure_reports: list[FailureReport]
smtp_tls_reports: list[SMTPTLSReport]
+51 -53
View File
@@ -17,7 +17,7 @@ import shutil
import subprocess
import tempfile
from datetime import datetime, timedelta, timezone
from typing import Optional, TypedDict, Union, cast
from typing import TypedDict, cast
import mailparser
from expiringdict import ExpiringDict
@@ -67,8 +67,8 @@ psl_overrides: list[str] = []
def load_psl_overrides(
*,
always_use_local_file: bool = False,
local_file_path: Optional[str] = None,
url: Optional[str] = None,
local_file_path: str | None = None,
url: str | None = None,
offline: bool = False,
) -> list[str]:
"""
@@ -138,7 +138,7 @@ class DownloadError(RuntimeError):
class ReverseDNSService(TypedDict):
name: str
type: Optional[str]
type: str | None
ReverseDNSMap = dict[str, ReverseDNSService]
@@ -146,14 +146,14 @@ ReverseDNSMap = dict[str, ReverseDNSService]
class IPAddressInfo(TypedDict):
ip_address: str
reverse_dns: Optional[str]
country: Optional[str]
base_domain: Optional[str]
name: Optional[str]
type: Optional[str]
asn: Optional[int]
as_name: Optional[str]
as_domain: Optional[str]
reverse_dns: str | None
country: str | None
base_domain: str | None
name: str | None
type: str | None
asn: int | None
as_name: str | None
as_domain: str | None
def decode_base64(data: str) -> bytes:
@@ -174,7 +174,7 @@ def decode_base64(data: str) -> bytes:
return base64.b64decode(data_bytes)
def get_base_domain(domain: str) -> Optional[str]:
def get_base_domain(domain: str) -> str | None:
"""
Gets the base domain name for the given domain
@@ -202,8 +202,8 @@ def query_dns(
domain: str,
record_type: str,
*,
cache: Optional[ExpiringDict] = None,
nameservers: Optional[list[str]] = None,
cache: ExpiringDict | None = None,
nameservers: list[str] | None = None,
timeout: float = DEFAULT_DNS_TIMEOUT,
retries: int = DEFAULT_DNS_MAX_RETRIES,
_attempt: int = 0,
@@ -294,11 +294,11 @@ def query_dns(
def get_reverse_dns(
ip_address,
*,
cache: Optional[ExpiringDict] = None,
nameservers: Optional[list[str]] = None,
cache: ExpiringDict | None = None,
nameservers: list[str] | None = None,
timeout: float = DEFAULT_DNS_TIMEOUT,
retries: int = DEFAULT_DNS_MAX_RETRIES,
) -> Optional[str]:
) -> str | None:
"""
Resolves an IP address to a hostname using a reverse DNS query
@@ -393,14 +393,14 @@ def human_timestamp_to_unix_timestamp(human_timestamp: str) -> int:
return int(human_timestamp_to_datetime(human_timestamp).timestamp())
_IP_DB_PATH: Optional[str] = None
_IP_DB_PATH: str | None = None
def load_ip_db(
*,
always_use_local_file: bool = False,
local_file_path: Optional[str] = None,
url: Optional[str] = None,
local_file_path: str | None = None,
url: str | None = None,
offline: bool = False,
) -> None:
"""
@@ -461,10 +461,10 @@ def load_ip_db(
class _IPDatabaseRecord(TypedDict):
country: Optional[str]
asn: Optional[int]
as_name: Optional[str]
as_domain: Optional[str]
country: str | None
asn: int | None
as_name: str | None
as_domain: str | None
class InvalidIPinfoAPIKey(Exception):
@@ -481,12 +481,12 @@ class InvalidIPinfoAPIKey(Exception):
# here — adding it would be inventing behavior the service doesn't document.
# Authentication uses the documented ``?token=`` query parameter.
_IPINFO_API_URL = "https://api.ipinfo.io/lite"
_IPINFO_API_TOKEN: Optional[str] = None
_IPINFO_API_TOKEN: str | None = None
_IPINFO_API_TIMEOUT: float = 5.0
def configure_ipinfo_api(
token: Optional[str],
token: str | None,
*,
probe: bool = True,
) -> None:
@@ -520,7 +520,7 @@ def configure_ipinfo_api(
logger.info("IPinfo API configured")
def _ipinfo_api_lookup(ip_address: str) -> Optional[_IPDatabaseRecord]:
def _ipinfo_api_lookup(ip_address: str) -> _IPDatabaseRecord | None:
"""Look up an IP via the IPinfo Lite REST API.
Returns the normalized record on success, or ``None`` on network error or
@@ -569,10 +569,10 @@ def _normalize_ip_record(record: dict) -> _IPDatabaseRecord:
same output: country as ISO code, ASN as plain int, as_name string,
as_domain lowercased.
"""
country: Optional[str] = None
asn: Optional[int] = None
as_name: Optional[str] = None
as_domain: Optional[str] = None
country: str | None = None
asn: int | None = None
as_name: str | None = None
as_domain: str | None = None
code = record.get("country_code")
if code is None:
@@ -609,7 +609,7 @@ def _normalize_ip_record(record: dict) -> _IPDatabaseRecord:
}
def _get_ip_database_path(db_path: Optional[str]) -> str:
def _get_ip_database_path(db_path: str | None) -> str:
db_paths = [
"ipinfo_lite.mmdb",
"GeoLite2-Country.mmdb",
@@ -655,7 +655,7 @@ def _get_ip_database_path(db_path: Optional[str]) -> str:
def get_ip_address_db_record(
ip_address: str, *, db_path: Optional[str] = None
ip_address: str, *, db_path: str | None = None
) -> _IPDatabaseRecord:
"""Look up an IP and return country + ASN fields.
@@ -686,8 +686,8 @@ def get_ip_address_db_record(
def get_ip_address_country(
ip_address: str, *, db_path: Optional[str] = None
) -> Optional[str]:
ip_address: str, *, db_path: str | None = None
) -> str | None:
"""
Returns the ISO code for the country associated
with the given IPv4 or IPv6 address.
@@ -706,11 +706,11 @@ def load_reverse_dns_map(
reverse_dns_map: ReverseDNSMap,
*,
always_use_local_file: bool = False,
local_file_path: Optional[str] = None,
url: Optional[str] = None,
local_file_path: str | None = None,
url: str | None = None,
offline: bool = False,
psl_overrides_path: Optional[str] = None,
psl_overrides_url: Optional[str] = None,
psl_overrides_path: str | None = None,
psl_overrides_url: str | None = None,
) -> None:
"""
Loads the reverse DNS map from a URL or local file.
@@ -794,10 +794,10 @@ def get_service_from_reverse_dns_base_domain(
base_domain,
*,
always_use_local_file: bool = False,
local_file_path: Optional[str] = None,
url: Optional[str] = None,
local_file_path: str | None = None,
url: str | None = None,
offline: bool = False,
reverse_dns_map: Optional[ReverseDNSMap] = None,
reverse_dns_map: ReverseDNSMap | None = None,
) -> ReverseDNSService:
"""
Returns the service name of a given base domain name from reverse DNS.
@@ -843,14 +843,14 @@ def get_service_from_reverse_dns_base_domain(
def get_ip_address_info(
ip_address,
*,
ip_db_path: Optional[str] = None,
reverse_dns_map_path: Optional[str] = None,
ip_db_path: str | None = None,
reverse_dns_map_path: str | None = None,
always_use_local_files: bool = False,
reverse_dns_map_url: Optional[str] = None,
cache: Optional[ExpiringDict] = None,
reverse_dns_map: Optional[ReverseDNSMap] = None,
reverse_dns_map_url: str | None = None,
cache: ExpiringDict | None = None,
reverse_dns_map: ReverseDNSMap | None = None,
offline: bool = False,
nameservers: Optional[list[str]] = None,
nameservers: list[str] | None = None,
timeout: float = DEFAULT_DNS_TIMEOUT,
retries: int = DEFAULT_DNS_MAX_RETRIES,
) -> IPAddressInfo:
@@ -974,7 +974,7 @@ def get_ip_address_info(
return info
def parse_email_address(original_address: str) -> dict[str, Optional[str]]:
def parse_email_address(original_address: str) -> dict[str, str | None]:
if original_address[0] == "":
display_name = None
else:
@@ -1089,9 +1089,7 @@ def convert_outlook_msg(msg_bytes: bytes) -> bytes:
return rfc822
def parse_email(
data: Union[bytes, str], *, strip_attachment_payloads: bool = False
) -> dict:
def parse_email(data: bytes | str, *, strip_attachment_payloads: bool = False) -> dict:
"""
A simplified email parser
+3 -5
View File
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Optional, Union
from typing import Any
import requests
@@ -18,7 +18,7 @@ class WebhookClient(object):
aggregate_url: str,
failure_url: str,
smtp_tls_url: str,
timeout: Optional[int] = 60,
timeout: int | None = 60,
):
"""
Initializes the WebhookClient
@@ -49,9 +49,7 @@ class WebhookClient(object):
def save_aggregate_report_to_webhook(self, report: str):
self._send_to_webhook(self.aggregate_url, report)
def _send_to_webhook(
self, webhook_url: str, payload: Union[bytes, str, dict[str, Any]]
):
def _send_to_webhook(self, webhook_url: str, payload: bytes | str | dict[str, Any]):
# All HTTP / network errors are swallowed and logged: a failing
# webhook should never abort the surrounding parse-and-output
# batch. The outer save_* methods previously wrapped this in a
+12
View File
@@ -113,6 +113,18 @@ exclude = [
"parsedmarc/resources/maps/[!_]*.py",
]
[tool.ruff.lint]
# Enforce modern type-hint syntax on top of ruff's default rules. With
# requires-python >=3.10, PEP 585 builtins (list[int]) and PEP 604 unions
# (X | Y, X | None) are available, so keep the deprecated typing.List /
# Union / Optional spellings out of the codebase.
extend-select = [
"UP006", # non-pep585-annotation: List -> list, Dict -> dict
"UP007", # non-pep604-annotation-union: Union[X, Y] -> X | Y
"UP035", # deprecated-import: typing.List etc. / typing -> collections.abc
"UP045", # non-pep604-annotation-optional: Optional[X] -> X | None
]
[tool.pyright]
# The whole codebase passes pyright with zero errors and warnings; CI
# enforces this (see .github/workflows/python-tests.yml). Run locally with