Refactor and improve parsing and extraction functions

- Reworked `extract_report` to handle `bytes`, base64 `str`, and binary file-like inputs more robustly, removing unnecessary complexity and improving error handling (see the usage sketch below).
- Simplified the handling of file-like objects, including non-seekable streams, and added checks that they are opened in binary (`rb`) mode.
- Streamlined input processing in `parse_report_email`, including Outlook `.msg` conversion, and improved type handling.
- Introduced TypedDicts in `utils.py` for better type safety, specifically for reverse DNS and IP address information.
- Refined configuration loading in `cli.py` so that `getboolean` results are consistently cast to `bool`.
- Improved overall readability and maintainability by restructuring and clarifying logic in several functions.
Sean Whalen
2025-12-25 15:30:20 -05:00
parent 3608bce344
commit 4b904444e5
4 changed files with 460 additions and 355 deletions
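The reworked input handling is easiest to see from the call side. A minimal usage sketch, assuming a parsedmarc build that includes this commit and a hypothetical gzipped report on disk; all three input forms should yield the same extracted text:

```python
import base64
from io import BytesIO

from parsedmarc import extract_report

with open("report.xml.gz", "rb") as f:  # hypothetical sample file
    raw = f.read()

from_bytes = extract_report(raw)                                  # bytes
from_file = extract_report(BytesIO(raw))                          # binary file-like object
from_b64 = extract_report(base64.b64encode(raw).decode("ascii"))  # base64 string

assert from_bytes == from_file == from_b64
```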

View File

@@ -58,12 +58,14 @@
"htpasswd",
"httpasswd",
"httplib",
"ifhost",
"IMAP",
"imapclient",
"infile",
"Interaktive",
"IPDB",
"journalctl",
"kafkaclient",
"keepalive",
"keyout",
"keyrings",
@@ -71,6 +73,7 @@
"libemail",
"linkify",
"LISTSERV",
"loganalytics",
"lxml",
"mailparser",
"mailrelay",
@@ -80,6 +83,7 @@
"maxmind",
"mbox",
"mfrom",
"mhdw",
"michaeldavie",
"mikesiegel",
"Mimecast",
@@ -103,8 +107,10 @@
"opensearchpy",
"parsedmarc",
"passsword",
"pbar",
"Postorius",
"premade",
"privatesuffix",
"procs",
"publicsuffix",
"publicsuffixlist",

View File

@@ -18,20 +18,17 @@ import zipfile
import zlib
from base64 import b64decode
from csv import DictWriter
from datetime import datetime, timedelta, timezone, tzinfo
from datetime import date, datetime, timedelta, timezone, tzinfo
from io import BytesIO, StringIO
from typing import (
IO,
Any,
BinaryIO,
Callable,
Dict,
List,
Optional,
Protocol,
Union,
cast,
runtime_checkable,
)
import lxml.etree as etree
@@ -864,14 +861,7 @@ def parse_aggregate_report_xml(
raise InvalidAggregateReport("Unexpected error: {0}".format(error.__str__()))
@runtime_checkable
class _ReadableSeekable(Protocol):
def read(self, n: int = -1) -> bytes: ...
def seek(self, offset: int, whence: int = 0) -> int: ...
def tell(self) -> int: ...
def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO]) -> str:
def extract_report(content: Union[bytes, str, BinaryIO]) -> str:
"""
Extracts text from a zip or gzip file supplied as a base64-encoded
string, file-like object, or bytes.
@@ -884,31 +874,59 @@ def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO])
str: The extracted text
"""
file_object: Optional[_ReadableSeekable] = None
file_object: Optional[BinaryIO] = None
header: bytes
try:
if isinstance(content, str):
try:
file_object = BytesIO(b64decode(content))
except binascii.Error:
return content
elif isinstance(content, (bytes, bytearray, memoryview)):
file_object = BytesIO(bytes(content))
else:
file_object = cast(_ReadableSeekable, content)
header = file_object.read(6)
if isinstance(header, str):
raise ParserError("File objects must be opened in binary (rb) mode")
file_object.seek(0)
if header.startswith(MAGIC_ZIP):
elif isinstance(content, (bytes)):
file_object = BytesIO(bytes(content))
header = file_object.read(6)
file_object.seek(0)
else:
stream = cast(BinaryIO, content)
seekable = getattr(stream, "seekable", None)
can_seek = False
if callable(seekable):
try:
can_seek = bool(seekable())
except Exception:
can_seek = False
if can_seek:
header_raw = stream.read(6)
if isinstance(header_raw, str):
raise ParserError("File objects must be opened in binary (rb) mode")
header = bytes(header_raw)
stream.seek(0)
file_object = stream
else:
header_raw = stream.read(6)
if isinstance(header_raw, str):
raise ParserError("File objects must be opened in binary (rb) mode")
header = bytes(header_raw)
remainder = stream.read()
file_object = BytesIO(header + bytes(remainder))
if file_object is None:
raise ParserError("Invalid report content")
if header[: len(MAGIC_ZIP)] == MAGIC_ZIP:
_zip = zipfile.ZipFile(file_object)
report = _zip.open(_zip.namelist()[0]).read().decode(errors="ignore")
elif header.startswith(MAGIC_GZIP):
elif header[: len(MAGIC_GZIP)] == MAGIC_GZIP:
report = zlib.decompress(file_object.read(), zlib.MAX_WBITS | 16).decode(
errors="ignore"
)
elif header.startswith(MAGIC_XML) or header.startswith(MAGIC_JSON):
elif (
header[: len(MAGIC_XML)] == MAGIC_XML
or header[: len(MAGIC_JSON)] == MAGIC_JSON
):
report = file_object.read().decode(errors="ignore")
else:
raise ParserError("Not a valid zip, gzip, json, or xml file")
@@ -918,7 +936,7 @@ def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO])
except Exception as error:
raise ParserError("Invalid archive file: {0}".format(error.__str__()))
finally:
if file_object and hasattr(file_object, "close"):
if file_object:
try:
file_object.close()
except Exception:
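Stepping out of the diff for a moment: the branches above dispatch on a six-byte header sniff. A standalone sketch of that dispatch, with magic values mirroring what parsedmarc's constants appear to be (treat the exact byte strings as assumptions):

```python
# Assumed values for parsedmarc's magic constants
MAGIC_ZIP = b"PK\x03\x04"  # ZIP local file header
MAGIC_GZIP = b"\x1f\x8b"   # gzip header
MAGIC_XML = b"<?xml"       # plain XML report
MAGIC_JSON = b"{"          # plain JSON report

def sniff(header: bytes) -> str:
    # Compare explicit-length prefixes, as the new code does
    if header[: len(MAGIC_ZIP)] == MAGIC_ZIP:
        return "zip"
    if header[: len(MAGIC_GZIP)] == MAGIC_GZIP:
        return "gzip"
    if header[: len(MAGIC_XML)] == MAGIC_XML or header[: len(MAGIC_JSON)] == MAGIC_JSON:
        return "text"
    raise ValueError("Not a valid zip, gzip, json, or xml file")

print(sniff(b"\x1f\x8b\x08\x00\x00\x00"))  # -> gzip
```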
@@ -937,17 +955,17 @@ def extract_report_from_file_path(file_path: str):
def parse_aggregate_report_file(
_input: Union[str, bytes, IO[Any]],
_input: Union[str, bytes, BinaryIO],
*,
offline: Optional[bool] = False,
always_use_local_files: Optional[bool] = None,
offline: bool = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
ip_db_path: Optional[str] = None,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
dns_timeout: float = 2.0,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: Optional[float] = 24.0,
normalize_timespan_threshold_hours: float = 24.0,
) -> dict[str, Any]:
"""Parses a file at the given path, a file-like object. or bytes as an
aggregate DMARC report
@@ -1187,14 +1205,14 @@ def parse_forensic_report(
sample: str,
msg_date: datetime,
*,
always_use_local_files: Optional[bool] = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
offline: bool = False,
ip_db_path: Optional[str] = None,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
strip_attachment_payloads: Optional[bool] = False,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
) -> dict[str, Any]:
"""
Converts a DMARC forensic report and sample to a dict
@@ -1449,20 +1467,32 @@ def parse_report_email(
* ``report_type``: ``aggregate`` or ``forensic``
* ``report``: The parsed report
"""
result = None
result: Optional[dict[str, Any]] = None
msg_date: datetime = datetime.now(timezone.utc)
try:
if isinstance(input_, bytes) and is_outlook_msg(input_):
input_ = convert_outlook_msg(input_)
if isinstance(input_, bytes):
input_ = input_.decode(encoding="utf8", errors="replace")
msg = mailparser.parse_from_string(input_)
input_data: Union[str, bytes, bytearray, memoryview] = input_
if isinstance(input_data, (bytes, bytearray, memoryview)):
input_bytes = bytes(input_data)
if is_outlook_msg(input_bytes):
converted = convert_outlook_msg(input_bytes)
if isinstance(converted, str):
input_str = converted
else:
input_str = bytes(converted).decode(
encoding="utf8", errors="replace"
)
else:
input_str = input_bytes.decode(encoding="utf8", errors="replace")
else:
input_str = input_data
msg = mailparser.parse_from_string(input_str)
msg_headers = json.loads(msg.headers_json)
if "Date" in msg_headers:
msg_date = human_timestamp_to_datetime(msg_headers["Date"])
date = email.utils.format_datetime(msg_date)
msg = email.message_from_string(input_)
msg = email.message_from_string(input_str)
except Exception as e:
raise ParserError(e.__str__())
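The normalization above, condensed into one helper. This sketch reuses parsedmarc's own utilities and mirrors the branch order in the diff; it is illustrative, not the shipped implementation:

```python
from typing import Union

from parsedmarc.utils import convert_outlook_msg, is_outlook_msg

def to_message_string(data: Union[str, bytes, bytearray, memoryview]) -> str:
    """Normalize raw input to the RFC 822 string mailparser expects."""
    if isinstance(data, str):
        return data
    raw = bytes(data)
    if is_outlook_msg(raw):
        converted = convert_outlook_msg(raw)  # requires the msgconvert utility
        if isinstance(converted, str):
            return converted
        return bytes(converted).decode(encoding="utf8", errors="replace")
    return raw.decode(encoding="utf8", errors="replace")
```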
@@ -1477,10 +1507,10 @@ def parse_report_email(
subject = msg_headers["Subject"]
for part in msg.walk():
content_type = part.get_content_type().lower()
payload = part.get_payload()
if not isinstance(payload, list):
payload = [payload]
payload = payload[0].__str__()
payload_obj = part.get_payload()
if not isinstance(payload_obj, list):
payload_obj = [payload_obj]
payload = str(payload_obj[0])
if content_type.startswith("multipart/"):
continue
if content_type == "text/html":
@@ -1501,7 +1531,7 @@ def parse_report_email(
sample = payload
elif content_type == "application/tlsrpt+json":
if not payload.strip().startswith("{"):
payload = str(b64decode(payload))
payload = b64decode(payload).decode("utf-8", errors="replace")
smtp_tls_report = parse_smtp_tls_report_json(payload)
return {"report_type": "smtp_tls", "report": smtp_tls_report}
elif content_type == "application/tlsrpt+gzip":
@@ -1531,18 +1561,21 @@ def parse_report_email(
logger.debug(sample)
else:
try:
payload = b64decode(payload)
if payload.startswith(MAGIC_ZIP) or payload.startswith(MAGIC_GZIP):
payload = extract_report(payload)
if isinstance(payload, bytes):
payload = payload.decode("utf-8", errors="replace")
if payload.strip().startswith("{"):
smtp_tls_report = parse_smtp_tls_report_json(payload)
payload_bytes = b64decode(payload)
if payload_bytes.startswith(MAGIC_ZIP) or payload_bytes.startswith(
MAGIC_GZIP
):
payload_text = extract_report(payload_bytes)
else:
payload_text = payload_bytes.decode("utf-8", errors="replace")
if payload_text.strip().startswith("{"):
smtp_tls_report = parse_smtp_tls_report_json(payload_text)
result = {"report_type": "smtp_tls", "report": smtp_tls_report}
return result
elif payload.strip().startswith("<"):
elif payload_text.strip().startswith("<"):
aggregate_report = parse_aggregate_report_xml(
payload,
payload_text,
ip_db_path=ip_db_path,
always_use_local_files=always_use_local_files,
reverse_dns_map_path=reverse_dns_map_path,
@@ -1604,26 +1637,28 @@ def parse_report_email(
error = 'Message with subject "{0}" is not a valid report'.format(subject)
raise InvalidDMARCReport(error)
return result
def parse_report_file(
input_: Union[bytes, str, IO[Any]],
input_: Union[bytes, str, BinaryIO],
*,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 2.0,
strip_attachment_payloads: Optional[bool] = False,
dns_timeout: float = 2.0,
strip_attachment_payloads: bool = False,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
offline: bool = False,
keep_alive: Optional[Callable] = None,
normalize_timespan_threshold_hours: Optional[float] = 24,
normalize_timespan_threshold_hours: float = 24,
) -> dict[str, Any]:
"""Parses a DMARC aggregate or forensic file at the given path, a
file-like object. or bytes
Args:
input_ (str | bytes | IO): A path to a file, a file like object, or bytes
input_ (str | bytes | BinaryIO): A path to a file, a file like object, or bytes
nameservers (list): A list of one or more nameservers to use
(Cloudflare's public DNS resolvers by default)
dns_timeout (float): Sets the DNS timeout in seconds
@@ -1639,11 +1674,12 @@ def parse_report_file(
Returns:
dict: The parsed DMARC report
"""
if type(input_) is str:
file_object: BinaryIO
if isinstance(input_, str):
logger.debug("Parsing {0}".format(input_))
file_object = open(input_, "rb")
elif type(input_) is bytes:
file_object = BytesIO(input_)
elif isinstance(input_, (bytes, bytearray, memoryview)):
file_object = BytesIO(bytes(input_))
else:
file_object = input_
@@ -1653,6 +1689,7 @@ def parse_report_file(
content = extract_report(content)
results: Optional[dict[str, Any]] = None
try:
report = parse_aggregate_report_file(
content,
@@ -1673,7 +1710,6 @@ def parse_report_file(
results = {"report_type": "smtp_tls", "report": report}
except InvalidSMTPTLSReport:
try:
sa = strip_attachment_payloads
results = parse_report_email(
content,
ip_db_path=ip_db_path,
@@ -1683,7 +1719,7 @@ def parse_report_file(
offline=offline,
nameservers=nameservers,
dns_timeout=dns_timeout,
strip_attachment_payloads=sa,
strip_attachment_payloads=strip_attachment_payloads,
keep_alive=keep_alive,
normalize_timespan_threshold_hours=normalize_timespan_threshold_hours,
)
@@ -1799,7 +1835,7 @@ def get_dmarc_reports_from_mailbox(
strip_attachment_payloads: bool = False,
results: Optional[dict[str, Any]] = None,
batch_size: int = 10,
since: Optional[Union[datetime, str]] = None,
since: Optional[Union[datetime, date, str]] = None,
create_folders: bool = True,
normalize_timespan_threshold_hours: float = 24,
) -> dict[str, list[dict[str, Any]]]:
@@ -1840,7 +1876,7 @@ def get_dmarc_reports_from_mailbox(
raise ValueError("Must supply a connection")
# current_time is used by fetch_messages later in the program
current_time = None
current_time: Optional[Union[datetime, date, str]] = None
aggregate_reports = []
forensic_reports = []
@@ -1865,7 +1901,7 @@ def get_dmarc_reports_from_mailbox(
connection.create_folder(smtp_tls_reports_folder)
connection.create_folder(invalid_reports_folder)
if since:
if since and isinstance(since, str):
_since = 1440 # default one day
if re.match(r"\d+[mhdw]$", since):
s = re.split(r"(\d+)", since)
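For reference, the `since` shorthand accepted here: strings such as `30m`, `12h`, `7d`, or `2w`. A sketch of the conversion, assuming the unit multipliers in the elided lines are the usual minutes-per-unit values:

```python
import re

def since_to_minutes(since: str) -> int:
    minutes = 1440  # default: one day
    if re.match(r"\d+[mhdw]$", since):
        _, number, unit = re.split(r"(\d+)", since)  # "7d" -> ["", "7", "d"]
        per_unit = {"m": 1, "h": 60, "d": 1440, "w": 10080}  # assumed multipliers
        minutes = int(number) * per_unit[unit]
    return minutes

print(since_to_minutes("7d"))  # 10080
```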
@@ -1926,13 +1962,16 @@ def get_dmarc_reports_from_mailbox(
i + 1, message_limit, msg_uid
)
)
if isinstance(mailbox, MSGraphConnection):
if test:
msg_content = connection.fetch_message(msg_uid, mark_read=False)
message_id: Union[int, str]
if isinstance(connection, IMAPConnection):
message_id = int(msg_uid)
msg_content = connection.fetch_message(message_id)
elif isinstance(connection, MSGraphConnection):
message_id = str(msg_uid)
msg_content = connection.fetch_message(message_id, mark_read=not test)
else:
msg_content = connection.fetch_message(msg_uid, mark_read=True)
else:
msg_content = connection.fetch_message(msg_uid)
message_id = str(msg_uid) if not isinstance(msg_uid, str) else msg_uid
msg_content = connection.fetch_message(message_id)
try:
sa = strip_attachment_payloads
parsed_email = parse_report_email(
@@ -1959,26 +1998,32 @@ def get_dmarc_reports_from_mailbox(
logger.debug(
f"Skipping duplicate aggregate report with ID: {report_id}"
)
aggregate_report_msg_uids.append(msg_uid)
aggregate_report_msg_uids.append(message_id)
elif parsed_email["report_type"] == "forensic":
forensic_reports.append(parsed_email["report"])
forensic_report_msg_uids.append(msg_uid)
forensic_report_msg_uids.append(message_id)
elif parsed_email["report_type"] == "smtp_tls":
smtp_tls_reports.append(parsed_email["report"])
smtp_tls_msg_uids.append(msg_uid)
smtp_tls_msg_uids.append(message_id)
except ParserError as error:
logger.warning(error.__str__())
if not test:
if delete:
logger.debug("Deleting message UID {0}".format(msg_uid))
connection.delete_message(msg_uid)
if isinstance(connection, IMAPConnection):
connection.delete_message(int(message_id))
else:
connection.delete_message(str(message_id))
else:
logger.debug(
"Moving message UID {0} to {1}".format(
msg_uid, invalid_reports_folder
)
)
connection.move_message(msg_uid, invalid_reports_folder)
if isinstance(connection, IMAPConnection):
connection.move_message(int(message_id), invalid_reports_folder)
else:
connection.move_message(str(message_id), invalid_reports_folder)
if not test:
if delete:
@@ -2106,21 +2151,21 @@ def watch_inbox(
mailbox_connection: MailboxConnection,
callback: Callable,
*,
reports_folder: Optional[str] = "INBOX",
archive_folder: Optional[str] = "Archive",
delete: Optional[bool] = False,
test: Optional[bool] = False,
check_timeout: Optional[int] = 30,
reports_folder: str = "INBOX",
archive_folder: str = "Archive",
delete: bool = False,
test: bool = False,
check_timeout: int = 30,
ip_db_path: Optional[str] = None,
always_use_local_files: Optional[bool] = False,
always_use_local_files: bool = False,
reverse_dns_map_path: Optional[str] = None,
reverse_dns_map_url: Optional[str] = None,
offline: Optional[bool] = False,
offline: bool = False,
nameservers: Optional[list[str]] = None,
dns_timeout: Optional[float] = 6.0,
strip_attachment_payloads: Optional[bool] = False,
batch_size: Optional[int] = None,
normalize_timespan_threshold_hours: Optional[float] = 24,
dns_timeout: float = 6.0,
strip_attachment_payloads: bool = False,
batch_size: int = 10,
normalize_timespan_threshold_hours: float = 24,
):
"""
Watches the mailbox for new messages and
@@ -2150,7 +2195,6 @@ def watch_inbox(
"""
def check_callback(connection):
sa = strip_attachment_payloads
res = get_dmarc_reports_from_mailbox(
connection=connection,
reports_folder=reports_folder,
@@ -2164,7 +2208,7 @@ def watch_inbox(
offline=offline,
nameservers=nameservers,
dns_timeout=dns_timeout,
strip_attachment_payloads=sa,
strip_attachment_payloads=strip_attachment_payloads,
batch_size=batch_size,
create_folders=False,
normalize_timespan_threshold_hours=normalize_timespan_threshold_hours,
@@ -2212,13 +2256,13 @@ def append_csv(filename, csv):
def save_output(
results: dict[str, Any],
*,
output_directory: Optional[str] = "output",
aggregate_json_filename: Optional[str] = "aggregate.json",
forensic_json_filename: Optional[str] = "forensic.json",
smtp_tls_json_filename: Optional[str] = "smtp_tls.json",
aggregate_csv_filename: Optional[str] = "aggregate.csv",
forensic_csv_filename: Optional[str] = "forensic.csv",
smtp_tls_csv_filename: Optional[str] = "smtp_tls.csv",
output_directory: str = "output",
aggregate_json_filename: str = "aggregate.json",
forensic_json_filename: str = "forensic.json",
smtp_tls_json_filename: str = "smtp_tls.json",
aggregate_csv_filename: str = "aggregate.csv",
forensic_csv_filename: str = "forensic.csv",
smtp_tls_csv_filename: str = "smtp_tls.csv",
):
"""
Save report data in the given directory
@@ -2322,7 +2366,7 @@ def get_report_zip(results: dict[str, Any]) -> bytes:
storage = BytesIO()
tmp_dir = tempfile.mkdtemp()
try:
save_output(results, tmp_dir)
save_output(results, output_directory=tmp_dir)
with zipfile.ZipFile(storage, "w", zipfile.ZIP_DEFLATED) as zip_file:
for root, dirs, files in os.walk(tmp_dir):
for file in files:
@@ -2345,10 +2389,10 @@ def email_results(
results: dict[str, Any],
host: str,
mail_from: str,
mail_to: str,
mail_to: Optional[list[str]],
*,
mail_cc: Optional[list] = None,
mail_bcc: Optional[list] = None,
mail_cc: Optional[list[str]] = None,
mail_bcc: Optional[list[str]] = None,
port: int = 0,
require_encryption: bool = False,
verify: bool = True,
@@ -2377,7 +2421,7 @@ def email_results(
attachment_filename (str): Override the default attachment filename
message (str): Override the default plain text body
"""
logger.debug("Emailing report to: {0}".format(",".join(mail_to)))
logger.debug("Emailing report")
date_string = datetime.now().strftime("%Y-%m-%d")
if attachment_filename:
if not attachment_filename.lower().endswith(".zip"):

View File

@@ -678,7 +678,7 @@ def _main():
if "general" in config.sections():
general_config = config["general"]
if "silent" in general_config:
opts.silent = general_config.getboolean("silent")
opts.silent = bool(general_config.getboolean("silent"))
if "normalize_timespan_threshold_hours" in general_config:
opts.normalize_timespan_threshold_hours = general_config.getfloat(
"normalize_timespan_threshold_hours"
@@ -687,10 +687,10 @@ def _main():
with open(general_config["index_prefix_domain_map"]) as f:
index_prefix_domain_map = yaml.safe_load(f)
if "offline" in general_config:
opts.offline = general_config.getboolean("offline")
opts.offline = bool(general_config.getboolean("offline"))
if "strip_attachment_payloads" in general_config:
opts.strip_attachment_payloads = general_config.getboolean(
"strip_attachment_payloads"
opts.strip_attachment_payloads = bool(
general_config.getboolean("strip_attachment_payloads")
)
if "output" in general_config:
opts.output = general_config["output"]
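The motivation for the repeated `bool(...)` wrappers throughout this file: `SectionProxy.getboolean()` is typed as returning `Optional[bool]`, because it yields `None` for a missing key without a fallback, so strict type checkers reject assigning its result to a `bool`-typed option. A minimal reproduction:

```python
from configparser import ConfigParser

config = ConfigParser()
config.read_string("[general]\nsilent = yes\n")
general = config["general"]

silent: bool = bool(general.getboolean("silent"))        # passes strict type checking
missing: bool = bool(general.getboolean("no_such_key"))  # getboolean() -> None -> False
print(silent, missing)  # True False
```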
@@ -732,19 +732,19 @@ def _main():
)
exit(-1)
if "save_aggregate" in general_config:
opts.save_aggregate = general_config.getboolean("save_aggregate")
opts.save_aggregate = bool(general_config.getboolean("save_aggregate"))
if "save_forensic" in general_config:
opts.save_forensic = general_config.getboolean("save_forensic")
opts.save_forensic = bool(general_config.getboolean("save_forensic"))
if "save_smtp_tls" in general_config:
opts.save_smtp_tls = general_config.getboolean("save_smtp_tls")
opts.save_smtp_tls = bool(general_config.getboolean("save_smtp_tls"))
if "debug" in general_config:
opts.debug = general_config.getboolean("debug")
opts.debug = bool(general_config.getboolean("debug"))
if "verbose" in general_config:
opts.verbose = general_config.getboolean("verbose")
opts.verbose = bool(general_config.getboolean("verbose"))
if "silent" in general_config:
opts.silent = general_config.getboolean("silent")
opts.silent = bool(general_config.getboolean("silent"))
if "warnings" in general_config:
opts.warnings = general_config.getboolean("warnings")
opts.warnings = bool(general_config.getboolean("warnings"))
if "log_file" in general_config:
opts.log_file = general_config["log_file"]
if "n_procs" in general_config:
@@ -754,15 +754,15 @@ def _main():
else:
opts.ip_db_path = None
if "always_use_local_files" in general_config:
opts.always_use_local_files = general_config.getboolean(
"always_use_local_files"
opts.always_use_local_files = bool(
general_config.getboolean("always_use_local_files")
)
if "reverse_dns_map_path" in general_config:
opts.reverse_dns_map_path = general_config["reverse_dns_path"]
if "reverse_dns_map_url" in general_config:
opts.reverse_dns_map_url = general_config["reverse_dns_url"]
if "prettify_json" in general_config:
opts.prettify_json = general_config.getboolean("prettify_json")
opts.prettify_json = bool(general_config.getboolean("prettify_json"))
if "mailbox" in config.sections():
mailbox_config = config["mailbox"]
@@ -773,11 +773,11 @@ def _main():
if "archive_folder" in mailbox_config:
opts.mailbox_archive_folder = mailbox_config["archive_folder"]
if "watch" in mailbox_config:
opts.mailbox_watch = mailbox_config.getboolean("watch")
opts.mailbox_watch = bool(mailbox_config.getboolean("watch"))
if "delete" in mailbox_config:
opts.mailbox_delete = mailbox_config.getboolean("delete")
opts.mailbox_delete = bool(mailbox_config.getboolean("delete"))
if "test" in mailbox_config:
opts.mailbox_test = mailbox_config.getboolean("test")
opts.mailbox_test = bool(mailbox_config.getboolean("test"))
if "batch_size" in mailbox_config:
opts.mailbox_batch_size = mailbox_config.getint("batch_size")
if "check_timeout" in mailbox_config:
@@ -805,10 +805,10 @@ def _main():
if "max_retries" in imap_config:
opts.imap_max_retries = imap_config.getint("max_retries")
if "ssl" in imap_config:
opts.imap_ssl = imap_config.getboolean("ssl")
opts.imap_ssl = bool(imap_config.getboolean("ssl"))
if "skip_certificate_verification" in imap_config:
opts.imap_skip_certificate_verification = imap_config.getboolean(
"skip_certificate_verification"
opts.imap_skip_certificate_verification = bool(
imap_config.getboolean("skip_certificate_verification")
)
if "user" in imap_config:
opts.imap_user = imap_config["user"]
@@ -837,7 +837,7 @@ def _main():
"section instead."
)
if "watch" in imap_config:
opts.mailbox_watch = imap_config.getboolean("watch")
opts.mailbox_watch = bool(imap_config.getboolean("watch"))
logger.warning(
"Use of the watch option in the imap "
"configuration section has been deprecated. "
@@ -852,7 +852,7 @@ def _main():
"section instead."
)
if "test" in imap_config:
opts.mailbox_test = imap_config.getboolean("test")
opts.mailbox_test = bool(imap_config.getboolean("test"))
logger.warning(
"Use of the test option in the imap "
"configuration section has been deprecated. "
@@ -946,8 +946,8 @@ def _main():
opts.graph_url = graph_config["graph_url"]
if "allow_unencrypted_storage" in graph_config:
opts.graph_allow_unencrypted_storage = graph_config.getboolean(
"allow_unencrypted_storage"
opts.graph_allow_unencrypted_storage = bool(
graph_config.getboolean("allow_unencrypted_storage")
)
if "elasticsearch" in config:
@@ -975,10 +975,10 @@ def _main():
if "index_prefix" in elasticsearch_config:
opts.elasticsearch_index_prefix = elasticsearch_config["index_prefix"]
if "monthly_indexes" in elasticsearch_config:
monthly = elasticsearch_config.getboolean("monthly_indexes")
monthly = bool(elasticsearch_config.getboolean("monthly_indexes"))
opts.elasticsearch_monthly_indexes = monthly
if "ssl" in elasticsearch_config:
opts.elasticsearch_ssl = elasticsearch_config.getboolean("ssl")
opts.elasticsearch_ssl = bool(elasticsearch_config.getboolean("ssl"))
if "cert_path" in elasticsearch_config:
opts.elasticsearch_ssl_cert_path = elasticsearch_config["cert_path"]
if "user" in elasticsearch_config:
@@ -1015,10 +1015,10 @@ def _main():
if "index_prefix" in opensearch_config:
opts.opensearch_index_prefix = opensearch_config["index_prefix"]
if "monthly_indexes" in opensearch_config:
monthly = opensearch_config.getboolean("monthly_indexes")
monthly = bool(opensearch_config.getboolean("monthly_indexes"))
opts.opensearch_monthly_indexes = monthly
if "ssl" in opensearch_config:
opts.opensearch_ssl = opensearch_config.getboolean("ssl")
opts.opensearch_ssl = bool(opensearch_config.getboolean("ssl"))
if "cert_path" in opensearch_config:
opts.opensearch_ssl_cert_path = opensearch_config["cert_path"]
if "user" in opensearch_config:
@@ -1072,9 +1072,11 @@ def _main():
if "password" in kafka_config:
opts.kafka_password = kafka_config["password"]
if "ssl" in kafka_config:
opts.kafka_ssl = kafka_config.getboolean("ssl")
opts.kafka_ssl = bool(kafka_config.getboolean("ssl"))
if "skip_certificate_verification" in kafka_config:
kafka_verify = kafka_config.getboolean("skip_certificate_verification")
kafka_verify = bool(
kafka_config.getboolean("skip_certificate_verification")
)
opts.kafka_skip_certificate_verification = kafka_verify
if "aggregate_topic" in kafka_config:
opts.kafka_aggregate_topic = kafka_config["aggregate_topic"]
@@ -1106,9 +1108,11 @@ def _main():
if "port" in smtp_config:
opts.smtp_port = smtp_config.getint("port")
if "ssl" in smtp_config:
opts.smtp_ssl = smtp_config.getboolean("ssl")
opts.smtp_ssl = bool(smtp_config.getboolean("ssl"))
if "skip_certificate_verification" in smtp_config:
smtp_verify = smtp_config.getboolean("skip_certificate_verification")
smtp_verify = bool(
smtp_config.getboolean("skip_certificate_verification")
)
opts.smtp_skip_certificate_verification = smtp_verify
if "user" in smtp_config:
opts.smtp_user = smtp_config["user"]
@@ -1176,11 +1180,11 @@ def _main():
gmail_api_config = config["gmail_api"]
opts.gmail_api_credentials_file = gmail_api_config.get("credentials_file")
opts.gmail_api_token_file = gmail_api_config.get("token_file", ".token")
opts.gmail_api_include_spam_trash = gmail_api_config.getboolean(
"include_spam_trash", False
opts.gmail_api_include_spam_trash = bool(
gmail_api_config.getboolean("include_spam_trash", False)
)
opts.gmail_api_paginate_messages = gmail_api_config.getboolean(
"paginate_messages", True
opts.gmail_api_paginate_messages = bool(
gmail_api_config.getboolean("paginate_messages", True)
)
opts.gmail_api_scopes = gmail_api_config.get(
"scopes", default_gmail_api_scope
@@ -1194,8 +1198,8 @@ def _main():
if "maildir" in config.sections():
maildir_api_config = config["maildir"]
opts.maildir_path = maildir_api_config.get("maildir_path")
opts.maildir_create = maildir_api_config.getboolean(
"maildir_create", fallback=False
opts.maildir_create = bool(
maildir_api_config.getboolean("maildir_create", fallback=False)
)
if "log_analytics" in config.sections():
@@ -1275,6 +1279,7 @@ def _main():
logger.info("Starting parsedmarc")
if opts.save_aggregate or opts.save_forensic or opts.save_smtp_tls:
try:
if opts.elasticsearch_hosts:
@@ -1513,6 +1518,11 @@ def _main():
smtp_tls_reports.append(result[0]["report"])
for mbox_path in mbox_paths:
normalize_timespan_threshold_hours_value = (
float(opts.normalize_timespan_threshold_hours)
if opts.normalize_timespan_threshold_hours is not None
else 24.0
)
strip = opts.strip_attachment_payloads
reports = get_dmarc_reports_from_mbox(
mbox_path,
@@ -1524,13 +1534,17 @@ def _main():
reverse_dns_map_path=opts.reverse_dns_map_path,
reverse_dns_map_url=opts.reverse_dns_map_url,
offline=opts.offline,
normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
)
aggregate_reports += reports["aggregate_reports"]
forensic_reports += reports["forensic_reports"]
smtp_tls_reports += reports["smtp_tls_reports"]
mailbox_connection = None
mailbox_batch_size_value = 10
mailbox_check_timeout_value = 30
normalize_timespan_threshold_hours_value = 24.0
if opts.imap_host:
try:
if opts.imap_user is None or opts.imap_password is None:
@@ -1552,9 +1566,10 @@ def _main():
imap_max_retries = (
int(opts.imap_max_retries) if opts.imap_max_retries is not None else 4
)
imap_port_value = int(opts.imap_port) if opts.imap_port is not None else 993
mailbox_connection = IMAPConnection(
host=opts.imap_host,
port=opts.imap_port,
port=imap_port_value,
ssl=ssl,
verify=verify,
timeout=imap_timeout,
@@ -1624,11 +1639,24 @@ def _main():
exit(1)
if mailbox_connection:
mailbox_batch_size_value = (
int(opts.mailbox_batch_size) if opts.mailbox_batch_size is not None else 10
)
mailbox_check_timeout_value = (
int(opts.mailbox_check_timeout)
if opts.mailbox_check_timeout is not None
else 30
)
normalize_timespan_threshold_hours_value = (
float(opts.normalize_timespan_threshold_hours)
if opts.normalize_timespan_threshold_hours is not None
else 24.0
)
try:
reports = get_dmarc_reports_from_mailbox(
connection=mailbox_connection,
delete=opts.mailbox_delete,
batch_size=opts.mailbox_batch_size,
batch_size=mailbox_batch_size_value,
reports_folder=opts.mailbox_reports_folder,
archive_folder=opts.mailbox_archive_folder,
ip_db_path=opts.ip_db_path,
@@ -1640,7 +1668,7 @@ def _main():
test=opts.mailbox_test,
strip_attachment_payloads=opts.strip_attachment_payloads,
since=opts.mailbox_since,
normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
)
aggregate_reports += reports["aggregate_reports"]
@@ -1664,12 +1692,18 @@ def _main():
verify = True
if opts.smtp_skip_certificate_verification:
verify = False
smtp_port_value = int(opts.smtp_port) if opts.smtp_port is not None else 25
smtp_to_value = (
list(opts.smtp_to)
if isinstance(opts.smtp_to, list)
else _str_to_list(str(opts.smtp_to))
)
email_results(
results,
opts.smtp_host,
opts.smtp_from,
opts.smtp_to,
port=opts.smtp_port,
smtp_to_value,
port=smtp_port_value,
verify=verify,
username=opts.smtp_user,
password=opts.smtp_password,
@@ -1683,6 +1717,7 @@ def _main():
if mailbox_connection and opts.mailbox_watch:
logger.info("Watching for email - Quit with ctrl-c")
try:
watch_inbox(
mailbox_connection=mailbox_connection,
@@ -1691,17 +1726,17 @@ def _main():
archive_folder=opts.mailbox_archive_folder,
delete=opts.mailbox_delete,
test=opts.mailbox_test,
check_timeout=opts.mailbox_check_timeout,
check_timeout=mailbox_check_timeout_value,
nameservers=opts.nameservers,
dns_timeout=opts.dns_timeout,
strip_attachment_payloads=opts.strip_attachment_payloads,
batch_size=opts.mailbox_batch_size,
batch_size=mailbox_batch_size_value,
ip_db_path=opts.ip_db_path,
always_use_local_files=opts.always_use_local_files,
reverse_dns_map_path=opts.reverse_dns_map_path,
reverse_dns_map_url=opts.reverse_dns_map_url,
offline=opts.offline,
normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
)
except FileExistsError as error:
logger.error("{0}".format(error.__str__()))

View File

@@ -17,7 +17,7 @@ import shutil
import subprocess
import tempfile
from datetime import datetime, timedelta, timezone
from typing import Optional, Union
from typing import Optional, TypedDict, Union, cast
import mailparser
from expiringdict import ExpiringDict
@@ -64,7 +64,24 @@ class DownloadError(RuntimeError):
"""Raised when an error occurs when downloading a file"""
def decode_base64(data) -> bytes:
class ReverseDNSService(TypedDict):
name: str
type: Optional[str]
ReverseDNSMap = dict[str, ReverseDNSService]
class IPAddressInfo(TypedDict):
ip_address: str
reverse_dns: Optional[str]
country: Optional[str]
base_domain: Optional[str]
name: Optional[str]
type: Optional[str]
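A usage sketch for the new TypedDicts; the key names come from the definitions above, while the example values are illustrative only:

```python
# Assumes the new TypedDicts are importable from parsedmarc.utils
from parsedmarc.utils import IPAddressInfo, ReverseDNSMap

info: IPAddressInfo = {
    "ip_address": "192.0.2.1",  # RFC 5737 documentation address
    "reverse_dns": None,
    "country": None,
    "base_domain": None,
    "name": None,
    "type": None,
}

service_map: ReverseDNSMap = {
    "example.com": {"name": "Example Service", "type": None},
}
```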
def decode_base64(data: str) -> bytes:
"""
Decodes a base64 string, with padding being optional
@@ -75,11 +92,11 @@ def decode_base64(data) -> bytes:
bytes: The decoded bytes
"""
data = bytes(data, encoding="ascii")
missing_padding = len(data) % 4
data_bytes = bytes(data, encoding="ascii")
missing_padding = len(data_bytes) % 4
if missing_padding != 0:
data += b"=" * (4 - missing_padding)
return base64.b64decode(data)
data_bytes += b"=" * (4 - missing_padding)
return base64.b64decode(data_bytes)
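The padding repair shown standalone: base64 payloads in reports sometimes arrive with their trailing `=` stripped, which `b64decode()` rejects unless padding is restored first:

```python
import base64

def decode_base64(data: str) -> bytes:
    data_bytes = bytes(data, encoding="ascii")
    missing_padding = len(data_bytes) % 4
    if missing_padding != 0:
        data_bytes += b"=" * (4 - missing_padding)
    return base64.b64decode(data_bytes)

print(decode_base64("cGFyc2VkbWFyYw"))  # b'parsedmarc', despite the missing "=="
```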
def get_base_domain(domain: str) -> Optional[str]:
@@ -132,9 +149,9 @@ def query_dns(
record_type = record_type.upper()
cache_key = "{0}_{1}".format(domain, record_type)
if cache:
records = cache.get(cache_key, None)
if records:
return records
cached_records = cache.get(cache_key, None)
if isinstance(cached_records, list):
return cast(list[str], cached_records)
resolver = dns.resolver.Resolver()
timeout = float(timeout)
@@ -148,20 +165,6 @@ def query_dns(
resolver.nameservers = nameservers
resolver.timeout = timeout
resolver.lifetime = timeout
if record_type == "TXT":
resource_records = list(
map(
lambda r: r.strings,
resolver.resolve(domain, record_type, lifetime=timeout),
)
)
_resource_record = [
resource_record[0][:0].join(resource_record)
for resource_record in resource_records
if resource_record
]
records = [r.decode() for r in _resource_record]
else:
records = list(
map(
lambda r: r.to_text().replace('"', "").rstrip("."),
@@ -180,7 +183,7 @@ def get_reverse_dns(
cache: Optional[ExpiringDict] = None,
nameservers: Optional[list[str]] = None,
timeout: float = 2.0,
) -> str:
) -> Optional[str]:
"""
Resolves an IP address to a hostname using a reverse DNS query
@@ -198,7 +201,7 @@ def get_reverse_dns(
try:
address = dns.reversename.from_address(ip_address)
hostname = query_dns(
address, "PTR", cache=cache, nameservers=nameservers, timeout=timeout
str(address), "PTR", cache=cache, nameservers=nameservers, timeout=timeout
)[0]
except dns.exception.DNSException as e:
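The `str()` cast above matters because `dns.reversename.from_address()` returns a `dns.name.Name`, not a `str`. The underlying PTR lookup, sketched directly against dnspython:

```python
from typing import Optional

import dns.exception
import dns.resolver
import dns.reversename

def reverse_dns(ip_address: str) -> Optional[str]:
    address = dns.reversename.from_address(ip_address)  # e.g. 1.2.0.192.in-addr.arpa.
    try:
        answer = dns.resolver.resolve(str(address), "PTR", lifetime=2.0)
        return answer[0].to_text().rstrip(".")
    except dns.exception.DNSException:
        return None
```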
@@ -266,10 +269,12 @@ def human_timestamp_to_unix_timestamp(human_timestamp: str) -> int:
float: The converted timestamp
"""
human_timestamp = human_timestamp.replace("T", " ")
return human_timestamp_to_datetime(human_timestamp).timestamp()
return int(human_timestamp_to_datetime(human_timestamp).timestamp())
def get_ip_address_country(ip_address: str, *, db_path: Optional[str] = None) -> str:
def get_ip_address_country(
ip_address: str, *, db_path: Optional[str] = None
) -> Optional[str]:
"""
Returns the ISO code for the country associated
with the given IPv4 or IPv6 address
@@ -335,11 +340,11 @@ def get_service_from_reverse_dns_base_domain(
base_domain,
*,
always_use_local_file: Optional[bool] = False,
local_file_path: Optional[bool] = None,
url: Optional[bool] = None,
local_file_path: Optional[str] = None,
url: Optional[str] = None,
offline: Optional[bool] = False,
reverse_dns_map: Optional[bool] = None,
) -> str:
reverse_dns_map: Optional[ReverseDNSMap] = None,
) -> ReverseDNSService:
"""
Returns the service name of a given base domain name from reverse DNS.
@@ -356,12 +361,6 @@ def get_service_from_reverse_dns_base_domain(
the supplied reverse_dns_base_domain and the type will be None
"""
def load_csv(_csv_file):
reader = csv.DictReader(_csv_file)
for row in reader:
key = row["base_reverse_dns"].lower().strip()
reverse_dns_map[key] = dict(name=row["name"], type=row["type"])
base_domain = base_domain.lower().strip()
if url is None:
url = (
@@ -369,11 +368,24 @@ def get_service_from_reverse_dns_base_domain(
"/parsedmarc/master/parsedmarc/"
"resources/maps/base_reverse_dns_map.csv"
)
reverse_dns_map_value: ReverseDNSMap
if reverse_dns_map is None:
reverse_dns_map = dict()
reverse_dns_map_value = {}
else:
reverse_dns_map_value = reverse_dns_map
def load_csv(_csv_file):
reader = csv.DictReader(_csv_file)
for row in reader:
key = row["base_reverse_dns"].lower().strip()
reverse_dns_map_value[key] = {
"name": row["name"],
"type": row["type"],
}
csv_file = io.StringIO()
if not (offline or always_use_local_file) and len(reverse_dns_map) == 0:
if not (offline or always_use_local_file) and len(reverse_dns_map_value) == 0:
try:
logger.debug(f"Trying to fetch reverse DNS map from {url}...")
headers = {"User-Agent": USER_AGENT}
@@ -390,7 +402,7 @@ def get_service_from_reverse_dns_base_domain(
logging.debug("Response body:")
logger.debug(csv_file.read())
if len(reverse_dns_map) == 0:
if len(reverse_dns_map_value) == 0:
logger.info("Loading included reverse DNS map...")
path = str(
files(parsedmarc.resources.maps).joinpath("base_reverse_dns_map.csv")
@@ -399,10 +411,11 @@ def get_service_from_reverse_dns_base_domain(
path = local_file_path
with open(path) as csv_file:
load_csv(csv_file)
service: ReverseDNSService
try:
service = reverse_dns_map[base_domain]
service = reverse_dns_map_value[base_domain]
except KeyError:
service = dict(name=base_domain, type=None)
service = {"name": base_domain, "type": None}
return service
@@ -415,11 +428,11 @@ def get_ip_address_info(
always_use_local_files: Optional[bool] = False,
reverse_dns_map_url: Optional[str] = None,
cache: Optional[ExpiringDict] = None,
reverse_dns_map: Optional[dict] = None,
reverse_dns_map: Optional[ReverseDNSMap] = None,
offline: Optional[bool] = False,
nameservers: Optional[list[str]] = None,
timeout: Optional[float] = 2.0,
) -> dict[str, str]:
timeout: float = 2.0,
) -> IPAddressInfo:
"""
Returns reverse DNS and country information for the given IP address
@@ -442,12 +455,22 @@ def get_ip_address_info(
"""
ip_address = ip_address.lower()
if cache is not None:
info = cache.get(ip_address, None)
if info:
cached_info = cache.get(ip_address, None)
if (
cached_info
and isinstance(cached_info, dict)
and "ip_address" in cached_info
):
logger.debug(f"IP address {ip_address} was found in cache")
return info
info: dict[str, str] = {}
info["ip_address"] = ip_address
return cast(IPAddressInfo, cached_info)
info: IPAddressInfo = {
"ip_address": ip_address,
"reverse_dns": None,
"country": None,
"base_domain": None,
"name": None,
"type": None,
}
if offline:
reverse_dns = None
else:
@@ -457,9 +480,6 @@ def get_ip_address_info(
country = get_ip_address_country(ip_address, db_path=ip_db_path)
info["country"] = country
info["reverse_dns"] = reverse_dns
info["base_domain"] = None
info["name"] = None
info["type"] = None
if reverse_dns is not None:
base_domain = get_base_domain(reverse_dns)
if base_domain is not None:
@@ -484,7 +504,7 @@ def get_ip_address_info(
return info
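One detail worth noting in the cache path above: only values that already look like an `IPAddressInfo` mapping are returned, so foreign or stale cache entries cannot escape with the wrong shape. Sketched against `ExpiringDict` (the cache sizes are illustrative):

```python
from typing import Optional

from expiringdict import ExpiringDict

cache = ExpiringDict(max_len=10000, max_age_seconds=3600)

def cached_info(ip_address: str) -> Optional[dict]:
    cached = cache.get(ip_address, None)
    if cached and isinstance(cached, dict) and "ip_address" in cached:
        return cached  # trusted: shaped like IPAddressInfo
    return None  # caller falls through to a fresh lookup
```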
def parse_email_address(original_address: str) -> dict[str, str]:
def parse_email_address(original_address: str) -> dict[str, Optional[str]]:
if original_address[0] == "":
display_name = None
else:
@@ -563,7 +583,7 @@ def is_outlook_msg(content) -> bool:
)
def convert_outlook_msg(msg_bytes: bytes) -> str:
def convert_outlook_msg(msg_bytes: bytes) -> bytes:
"""
Uses the ``msgconvert`` Perl utility to convert an Outlook MS file to
standard RFC 822 format
@@ -572,7 +592,7 @@ def convert_outlook_msg(msg_bytes: bytes) -> str:
msg_bytes (bytes): the content of the .msg file
Returns:
An RFC 822 string
An RFC 822 bytes payload
"""
if not is_outlook_msg(msg_bytes):
raise ValueError("The supplied bytes are not an Outlook MSG file")