More code cleanup

This commit is contained in:
Sean Whalen
2025-12-24 16:36:59 -05:00
parent bb8f4002bf
commit a76c2f9621
7 changed files with 103 additions and 85 deletions

View File

@@ -291,12 +291,12 @@ def _parse_report_record(
record: dict[str, Any],
*,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
offline: bool = False,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
dns_timeout: float = 2.0,
) -> dict[str, Any]:
"""
Converts a record from a DMARC aggregate report into a more consistent
@@ -653,12 +653,12 @@ def parse_aggregate_report_xml(
xml: str,
*,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
offline: bool = False,
nameservers: Optional[list[str]] = None,
timeout: Optional[float] = 2.0,
timeout: float = 2.0,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: float = 24.0,
) -> dict[str, Any]:
@@ -1416,14 +1416,14 @@ def parsed_forensic_reports_to_csv(reports: list[dict[str, Any]]) -> str:
def parse_report_email(
input_: Union[bytes, str],
*,
offline: Optional[bool] = False,
offline: bool = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
strip_attachment_payloads: Optional[bool] = False,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: float = 24.0,
) -> dict[str, Any]:
@@ -1699,14 +1699,14 @@ def get_dmarc_reports_from_mbox(
input_: str,
*,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
strip_attachment_payloads: Optional[bool] = False,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
normalize_timespan_threshold_hours: Optional[float] = 24.0,
offline: bool = False,
normalize_timespan_threshold_hours: float = 24.0,
) -> dict[str, list[dict[str, Any]]]:
"""Parses a mailbox in mbox format containing e-mails with attached
DMARC reports
@@ -1785,23 +1785,23 @@ def get_dmarc_reports_from_mbox(
def get_dmarc_reports_from_mailbox(
connection: MailboxConnection,
*,
reports_folder: Optional[str] = "INBOX",
archive_folder: Optional[str] = "Archive",
delete: Optional[bool] = False,
test: Optional[bool] = False,
reports_folder: str = "INBOX",
archive_folder: str = "Archive",
delete: bool = False,
test: bool = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
offline: bool = False,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 6.0,
strip_attachment_payloads: Optional[bool] = False,
dns_timeout: float = 6.0,
strip_attachment_payloads: bool = False,
results: Optional[dict[str, Any]] = None,
batch_size: Optional[int] = 10,
since: Optional[datetime] = None,
create_folders: Optional[bool] = True,
normalize_timespan_threshold_hours: Optional[float] = 24,
batch_size: int = 10,
since: Optional[Union[datetime, str]] = None,
create_folders: bool = True,
normalize_timespan_threshold_hours: float = 24,
) -> dict[str, list[dict[str, Any]]]:
"""
Fetches and parses DMARC reports from a mailbox
@@ -2343,20 +2343,20 @@ def get_report_zip(results: dict[str, Any]) -> bytes:
def email_results(
results: dict[str, Any],
*,
host: str,
mail_from: str,
mail_to: str,
mail_cc: list = None,
mail_bcc: list = None,
*,
mail_cc: Optional[list] = None,
mail_bcc: Optional[list] = None,
port: int = 0,
require_encryption: bool = False,
verify: bool = True,
username: str = None,
password: str = None,
subject: str = None,
attachment_filename: str = None,
message: str = None,
username: Optional[str] = None,
password: Optional[str] = None,
subject: Optional[str] = None,
attachment_filename: Optional[str] = None,
message: Optional[str] = None,
):
"""
Emails parsing results as a zip file

View File

@@ -6,7 +6,6 @@
import http.client
import json
import logging
import math
import os
import sys
from argparse import ArgumentParser, Namespace
@@ -14,6 +13,7 @@ from configparser import ConfigParser
from glob import glob
from multiprocessing import Pipe, Process
from ssl import CERT_NONE, create_default_context
from typing import Union, cast
import yaml
from tqdm import tqdm
@@ -140,7 +140,7 @@ def _main():
print(output_str)
if opts.output:
save_output(
results,
reports_,
output_directory=opts.output,
aggregate_json_filename=opts.aggregate_json_filename,
forensic_json_filename=opts.forensic_json_filename,
@@ -709,6 +709,8 @@ def _main():
opts.smtp_tls_csv_filename = general_config["smtp_tls_csv_filename"]
if "dns_timeout" in general_config:
opts.dns_timeout = general_config.getfloat("dns_timeout")
if opts.dns_timeout is None:
opts.dns_timeout = 2
if "dns_test_address" in general_config:
opts.dns_test_address = general_config["dns_test_address"]
if "nameservers" in general_config:
@@ -800,7 +802,7 @@ def _main():
if "port" in imap_config:
opts.imap_port = imap_config.getint("port")
if "timeout" in imap_config:
opts.imap_timeout = imap_config.getfloat("timeout")
opts.imap_timeout = imap_config.getint("timeout")
if "max_retries" in imap_config:
opts.imap_max_retries = imap_config.getint("max_retries")
if "ssl" in imap_config:
@@ -1193,7 +1195,9 @@ def _main():
if "maildir" in config.sections():
maildir_api_config = config["maildir"]
opts.maildir_path = maildir_api_config.get("maildir_path")
opts.maildir_create = maildir_api_config.get("maildir_create")
opts.maildir_create = maildir_api_config.getboolean(
"maildir_create", fallback=False
)
if "log_analytics" in config.sections():
log_analytics_config = config["log_analytics"]
@@ -1436,16 +1440,19 @@ def _main():
results = []
pbar = None
if sys.stdout.isatty():
pbar = tqdm(total=len(file_paths))
for batch_index in range(math.ceil(len(file_paths) / opts.n_procs)):
n_procs = int(opts.n_procs or 1)
if n_procs < 1:
n_procs = 1
for batch_index in range((len(file_paths) + n_procs - 1) // n_procs):
processes = []
connections = []
for proc_index in range(
opts.n_procs * batch_index, opts.n_procs * (batch_index + 1)
):
for proc_index in range(n_procs * batch_index, n_procs * (batch_index + 1)):
if proc_index >= len(file_paths):
break
@@ -1478,9 +1485,12 @@ def _main():
for proc in processes:
proc.join()
if sys.stdout.isatty():
if pbar is not None:
counter += 1
pbar.update(counter - pbar.n)
pbar.update(1)
if pbar is not None:
pbar.close()
for result in results:
if isinstance(result[0], ParserError) or result[0] is None:
@@ -1537,13 +1547,19 @@ def _main():
if not opts.imap_ssl:
ssl = False
imap_timeout = (
int(opts.imap_timeout) if opts.imap_timeout is not None else 30
)
imap_max_retries = (
int(opts.imap_max_retries) if opts.imap_max_retries is not None else 4
)
mailbox_connection = IMAPConnection(
host=opts.imap_host,
port=opts.imap_port,
ssl=ssl,
verify=verify,
timeout=opts.imap_timeout,
max_retries=opts.imap_max_retries,
timeout=imap_timeout,
max_retries=imap_max_retries,
user=opts.imap_user,
password=opts.imap_password,
)
@@ -1564,7 +1580,7 @@ def _main():
username=opts.graph_user,
password=opts.graph_password,
token_file=opts.graph_token_file,
allow_unencrypted_storage=opts.graph_allow_unencrypted_storage,
allow_unencrypted_storage=bool(opts.graph_allow_unencrypted_storage),
graph_url=opts.graph_url,
)

View File

@@ -116,14 +116,14 @@ class GmailConnection(MailboxConnection):
else:
return [id for id in self._fetch_all_message_ids(reports_label_id)]
def fetch_message(self, message_id):
def fetch_message(self, message_id) -> str:
msg = (
self.service.users()
.messages()
.get(userId="me", id=message_id, format="raw")
.execute()
)
return urlsafe_b64decode(msg["raw"])
return urlsafe_b64decode(msg["raw"]).decode(errors="replace")
def delete_message(self, message_id: str):
self.service.users().messages().delete(userId="me", id=message_id)

View File

@@ -6,7 +6,7 @@ from enum import Enum
from functools import lru_cache
from pathlib import Path
from time import sleep
from typing import List, Optional
from typing import Any, List, Optional, Union
from azure.identity import (
UsernamePasswordCredential,
@@ -28,7 +28,7 @@ class AuthMethod(Enum):
def _get_cache_args(token_path: Path, allow_unencrypted_storage):
cache_args = {
cache_args: dict[str, Any] = {
"cache_persistence_options": TokenCachePersistenceOptions(
name="parsedmarc", allow_unencrypted_storage=allow_unencrypted_storage
)
@@ -151,9 +151,9 @@ class MSGraphConnection(MailboxConnection):
else:
logger.warning(f"Unknown response {resp.status_code} {resp.json()}")
def fetch_messages(self, folder_name: str, **kwargs) -> List[str]:
def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]:
"""Returns a list of message UIDs in the specified folder"""
folder_id = self._find_folder_id_from_folder_path(folder_name)
folder_id = self._find_folder_id_from_folder_path(reports_folder)
url = f"/users/{self.mailbox_name}/mailFolders/{folder_id}/messages"
since = kwargs.get("since")
if not since:
@@ -166,7 +166,7 @@ class MSGraphConnection(MailboxConnection):
def _get_all_messages(self, url, batch_size, since):
messages: list
params = {"$select": "id"}
params: dict[str, Union[str, int]] = {"$select": "id"}
if since:
params["$filter"] = f"receivedDateTime ge {since}"
if batch_size and batch_size > 0:

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Optional
from typing import cast
from time import sleep
@@ -17,15 +17,14 @@ from parsedmarc.mail.mailbox_connection import MailboxConnection
class IMAPConnection(MailboxConnection):
def __init__(
self,
host: Optional[str] = None,
*,
user: Optional[str] = None,
password: Optional[str] = None,
port: Optional[str] = None,
ssl: Optional[bool] = True,
verify: Optional[bool] = True,
timeout: Optional[int] = 30,
max_retries: Optional[int] = 4,
host: str,
user: str,
password: str,
port: int = 993,
ssl: bool = True,
verify: bool = True,
timeout: int = 30,
max_retries: int = 4,
):
self._username = user
self._password = password
@@ -47,13 +46,13 @@ class IMAPConnection(MailboxConnection):
def fetch_messages(self, reports_folder: str, **kwargs):
self._client.select_folder(reports_folder)
since = kwargs.get("since")
if since:
return self._client.search(["SINCE", since])
if since is not None:
return self._client.search(f"SINCE {since}")
else:
return self._client.search()
def fetch_message(self, message_id: int):
return self._client.fetch_message(message_id, parse=False)
return cast(str, self._client.fetch_message(message_id, parse=False))
def delete_message(self, message_id: int):
self._client.delete_messages([message_id])

View File

@@ -13,16 +13,16 @@ class MailboxConnection(ABC):
def create_folder(self, folder_name: str):
raise NotImplementedError
def fetch_messages(self, reports_folder: str, **kwargs) -> list[str]:
def fetch_messages(self, reports_folder: str, **kwargs):
raise NotImplementedError
def fetch_message(self, message_id) -> str:
raise NotImplementedError
def delete_message(self, message_id: str):
def delete_message(self, message_id):
raise NotImplementedError
def move_message(self, message_id: str, folder_name: str):
def move_message(self, message_id, folder_name: str):
raise NotImplementedError
def keepalive(self):

View File

@@ -2,21 +2,20 @@
from __future__ import annotations
from typing import Optional
import mailbox
import os
from time import sleep
from typing import Dict
from parsedmarc.log import logger
from parsedmarc.mail.mailbox_connection import MailboxConnection
import mailbox
import os
class MaildirConnection(MailboxConnection):
def __init__(
self,
maildir_path: Optional[bool] = None,
maildir_create: Optional[bool] = False,
maildir_path: str,
maildir_create: bool = False,
):
self._maildir_path = maildir_path
self._maildir_create = maildir_create
@@ -33,27 +32,31 @@ class MaildirConnection(MailboxConnection):
)
raise Exception(ex)
self._client = mailbox.Maildir(maildir_path, create=maildir_create)
self._subfolder_client = {}
self._subfolder_client: Dict[str, mailbox.Maildir] = {}
def create_folder(self, folder_name: str):
self._subfolder_client[folder_name] = self._client.add_folder(folder_name)
self._client.add_folder(folder_name)
def fetch_messages(self, reports_folder: str, **kwargs):
return self._client.keys()
def fetch_message(self, message_id: str):
return self._client.get(message_id).as_string()
def fetch_message(self, message_id: str) -> str:
msg = self._client.get(message_id)
if msg is not None:
msg = msg.as_string()
if msg is not None:
return msg
return ""
def delete_message(self, message_id: str):
self._client.remove(message_id)
def move_message(self, message_id: str, folder_name: str):
message_data = self._client.get(message_id)
if folder_name not in self._subfolder_client.keys():
self._subfolder_client = mailbox.Maildir(
os.join(self.maildir_path, folder_name), create=self.maildir_create
)
if message_data is None:
return
if folder_name not in self._subfolder_client:
self._subfolder_client[folder_name] = self._client.add_folder(folder_name)
self._subfolder_client[folder_name].add(message_data)
self._client.remove(message_id)