Refactor and improve parsing and extraction functions

- Updated `extract_report` to handle string, bytes, and binary file-like inputs more robustly, dropping the custom `_ReadableSeekable` protocol in favor of `BinaryIO` and improving error handling (a sketch of the approach follows this list).
- Simplified the handling of file-like objects, including non-seekable streams, and added checks for binary (`rb`) mode.
- Reworked `parse_report_email` to normalize its input to a string up front and to decode base64 payloads explicitly, improving type handling.
- Introduced TypedDicts in `utils.py` for reverse DNS and IP address information, improving type safety.
- Refined configuration loading in `cli.py` so that values read with `getboolean` are consistently cast to `bool`.
- Improved overall readability and maintainability by restructuring and clarifying logic in several functions.
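
A minimal sketch of the input-normalization idea behind the `extract_report` changes (illustrative only, not the shipped code; `MAGIC_GZIP` below is a stand-in for the module's own magic-byte constants):

    # Normalize str/bytes/file-like input to a stream and sniff its header.
    from base64 import b64decode
    from io import BytesIO
    import binascii

    MAGIC_GZIP = b"\x1f\x8b"  # stand-in constant for this sketch

    def sniff_header(content):
        if isinstance(content, str):
            try:
                stream = BytesIO(b64decode(content))
            except binascii.Error:
                return None  # not base64, treat as plain text
        elif isinstance(content, bytes):
            stream = BytesIO(content)
        else:
            stream = content  # assumed to be a binary-mode file object
        header = stream.read(6)
        if isinstance(header, str):
            # Text-mode file objects return str from read(); reject early.
            raise ValueError("File objects must be opened in binary (rb) mode")
        return header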
Sean Whalen
2025-12-25 15:30:20 -05:00
parent 3608bce344
commit 4b904444e5
4 changed files with 460 additions and 355 deletions

.vscode/settings.json

@@ -13,148 +13,154 @@
     "MD024": false
   },
   "cSpell.words": [
     "adkim",
     "akamaiedge",
     "amsmath",
     "andrewmcgilvray",
     "arcname",
     "aspf",
     "autoclass",
     "automodule",
     "backported",
     "bellsouth",
     "boto",
     "brakhane",
     "Brightmail",
     "CEST",
     "CHACHA",
     "checkdmarc",
     "Codecov",
     "confnew",
     "dateparser",
     "dateutil",
     "Davmail",
     "DBIP",
     "dearmor",
     "deflist",
     "devel",
     "DMARC",
     "Dmarcian",
     "dnspython",
     "dollarmath",
     "dpkg",
     "exampleuser",
     "expiringdict",
     "fieldlist",
     "GELF",
     "genindex",
     "geoip",
     "geoipupdate",
     "Geolite",
     "geolocation",
     "githubpages",
     "Grafana",
     "hostnames",
     "htpasswd",
     "httpasswd",
     "httplib",
+    "ifhost",
     "IMAP",
     "imapclient",
     "infile",
     "Interaktive",
     "IPDB",
     "journalctl",
+    "kafkaclient",
     "keepalive",
     "keyout",
     "keyrings",
     "Leeman",
     "libemail",
     "linkify",
     "LISTSERV",
+    "loganalytics",
     "lxml",
     "mailparser",
     "mailrelay",
     "mailsuite",
     "maxdepth",
     "MAXHEADERS",
     "maxmind",
     "mbox",
     "mfrom",
+    "mhdw",
     "michaeldavie",
     "mikesiegel",
     "Mimecast",
     "mitigations",
     "MMDB",
     "modindex",
     "msgconvert",
     "msgraph",
     "MSSP",
     "multiprocess",
     "Munge",
     "ndjson",
     "newkey",
     "Nhcm",
     "nojekyll",
     "nondigest",
     "nosecureimap",
     "nosniff",
     "nwettbewerb",
     "opensearch",
     "opensearchpy",
     "parsedmarc",
     "passsword",
+    "pbar",
     "Postorius",
     "premade",
+    "privatesuffix",
     "procs",
     "publicsuffix",
     "publicsuffixlist",
     "publixsuffix",
     "pygelf",
     "pypy",
     "pytest",
     "quickstart",
     "Reindex",
     "replyto",
     "reversename",
     "Rollup",
     "Rpdm",
     "SAMEORIGIN",
     "sdist",
     "Servernameone",
     "setuptools",
     "smartquotes",
     "SMTPTLS",
     "sortlists",
     "sortmaps",
     "sourcetype",
     "STARTTLS",
     "tasklist",
     "timespan",
     "tlsa",
     "tlsrpt",
     "toctree",
     "TQDDM",
     "tqdm",
     "truststore",
     "Übersicht",
     "uids",
     "Uncategorized",
     "unparasable",
     "uper",
     "urllib",
     "Valimail",
     "venv",
     "Vhcw",
     "viewcode",
     "virtualenv",
     "WBITS",
     "webmail",
     "Wettbewerber",
     "Whalen",
     "whitespaces",
     "xennn",
     "xmltodict",
     "xpack",
     "zscholl"
   ],
 }

parsedmarc/__init__.py

@@ -18,20 +18,17 @@ import zipfile
 import zlib
 from base64 import b64decode
 from csv import DictWriter
-from datetime import datetime, timedelta, timezone, tzinfo
+from datetime import date, datetime, timedelta, timezone, tzinfo
 from io import BytesIO, StringIO
 from typing import (
-    IO,
     Any,
     BinaryIO,
     Callable,
     Dict,
     List,
     Optional,
-    Protocol,
     Union,
     cast,
-    runtime_checkable,
 )

 import lxml.etree as etree
@@ -864,14 +861,7 @@ def parse_aggregate_report_xml(
         raise InvalidAggregateReport("Unexpected error: {0}".format(error.__str__()))


-@runtime_checkable
-class _ReadableSeekable(Protocol):
-    def read(self, n: int = -1) -> bytes: ...
-    def seek(self, offset: int, whence: int = 0) -> int: ...
-    def tell(self) -> int: ...
-
-
-def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO]) -> str:
+def extract_report(content: Union[bytes, str, BinaryIO]) -> str:
     """
     Extracts text from a zip or gzip file, as a base64-encoded string,
     file-like object, or bytes.
@@ -884,31 +874,59 @@ def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO])
         str: The extracted text
     """
-    file_object: Optional[_ReadableSeekable] = None
+    file_object: Optional[BinaryIO] = None
+    header: bytes
     try:
         if isinstance(content, str):
             try:
                 file_object = BytesIO(b64decode(content))
             except binascii.Error:
                 return content
-        elif isinstance(content, (bytes, bytearray, memoryview)):
+            header = file_object.read(6)
+            file_object.seek(0)
+        elif isinstance(content, (bytes)):
             file_object = BytesIO(bytes(content))
+            header = file_object.read(6)
+            file_object.seek(0)
         else:
-            file_object = cast(_ReadableSeekable, content)
-
-        header = file_object.read(6)
-        if isinstance(header, str):
-            raise ParserError("File objects must be opened in binary (rb) mode")
-        file_object.seek(0)
-        if header.startswith(MAGIC_ZIP):
+            stream = cast(BinaryIO, content)
+            seekable = getattr(stream, "seekable", None)
+            can_seek = False
+            if callable(seekable):
+                try:
+                    can_seek = bool(seekable())
+                except Exception:
+                    can_seek = False
+            if can_seek:
+                header_raw = stream.read(6)
+                if isinstance(header_raw, str):
+                    raise ParserError("File objects must be opened in binary (rb) mode")
+                header = bytes(header_raw)
+                stream.seek(0)
+                file_object = stream
+            else:
+                header_raw = stream.read(6)
+                if isinstance(header_raw, str):
+                    raise ParserError("File objects must be opened in binary (rb) mode")
+                header = bytes(header_raw)
+                remainder = stream.read()
+                file_object = BytesIO(header + bytes(remainder))
+        if file_object is None:
+            raise ParserError("Invalid report content")
+        if header[: len(MAGIC_ZIP)] == MAGIC_ZIP:
             _zip = zipfile.ZipFile(file_object)
             report = _zip.open(_zip.namelist()[0]).read().decode(errors="ignore")
-        elif header.startswith(MAGIC_GZIP):
+        elif header[: len(MAGIC_GZIP)] == MAGIC_GZIP:
             report = zlib.decompress(file_object.read(), zlib.MAX_WBITS | 16).decode(
                 errors="ignore"
             )
-        elif header.startswith(MAGIC_XML) or header.startswith(MAGIC_JSON):
+        elif (
+            header[: len(MAGIC_XML)] == MAGIC_XML
+            or header[: len(MAGIC_JSON)] == MAGIC_JSON
+        ):
             report = file_object.read().decode(errors="ignore")
         else:
             raise ParserError("Not a valid zip, gzip, json, or xml file")
@@ -918,7 +936,7 @@ def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO])
     except Exception as error:
         raise ParserError("Invalid archive file: {0}".format(error.__str__()))
     finally:
-        if file_object and hasattr(file_object, "close"):
+        if file_object:
             try:
                 file_object.close()
             except Exception:
@@ -937,17 +955,17 @@ def extract_report_from_file_path(file_path: str):


 def parse_aggregate_report_file(
-    _input: Union[str, bytes, IO[Any]],
+    _input: Union[str, bytes, BinaryIO],
     *,
-    offline: Optional[bool] = False,
-    always_use_local_files: Optional[bool] = None,
+    offline: bool = False,
+    always_use_local_files: bool = False,
     reverse_dns_map_path: Optional[str] = None,
     reverse_dns_map_url: Optional[str] = None,
     ip_db_path: Optional[str] = None,
     nameservers: Optional[list[str]] = None,
-    dns_timeout: Optional[float] = 2.0,
+    dns_timeout: float = 2.0,
     keep_alive: Optional[Callable] = None,
-    normalize_timespan_threshold_hours: Optional[float] = 24.0,
+    normalize_timespan_threshold_hours: float = 24.0,
 ) -> dict[str, Any]:
     """Parses a file at the given path, a file-like object. or bytes as an
     aggregate DMARC report
@@ -1187,14 +1205,14 @@ def parse_forensic_report(
     sample: str,
     msg_date: datetime,
     *,
-    always_use_local_files: Optional[bool] = False,
+    always_use_local_files: bool = False,
     reverse_dns_map_path: Optional[str] = None,
     reverse_dns_map_url: Optional[str] = None,
-    offline: Optional[bool] = False,
+    offline: bool = False,
     ip_db_path: Optional[str] = None,
     nameservers: Optional[list[str]] = None,
-    dns_timeout: Optional[float] = 2.0,
-    strip_attachment_payloads: Optional[bool] = False,
+    dns_timeout: float = 2.0,
+    strip_attachment_payloads: bool = False,
 ) -> dict[str, Any]:
     """
     Converts a DMARC forensic report and sample to a dict
@@ -1449,20 +1467,32 @@ def parse_report_email(
         * ``report_type``: ``aggregate`` or ``forensic``
         * ``report``: The parsed report
     """
-    result = None
+    result: Optional[dict[str, Any]] = None
     msg_date: datetime = datetime.now(timezone.utc)
     try:
-        if isinstance(input_, bytes) and is_outlook_msg(input_):
-            input_ = convert_outlook_msg(input_)
-        if isinstance(input_, bytes):
-            input_ = input_.decode(encoding="utf8", errors="replace")
-        msg = mailparser.parse_from_string(input_)
+        input_data: Union[str, bytes, bytearray, memoryview] = input_
+        if isinstance(input_data, (bytes, bytearray, memoryview)):
+            input_bytes = bytes(input_data)
+            if is_outlook_msg(input_bytes):
+                converted = convert_outlook_msg(input_bytes)
+                if isinstance(converted, str):
+                    input_str = converted
+                else:
+                    input_str = bytes(converted).decode(
+                        encoding="utf8", errors="replace"
+                    )
+            else:
+                input_str = input_bytes.decode(encoding="utf8", errors="replace")
+        else:
+            input_str = input_data
+        msg = mailparser.parse_from_string(input_str)
         msg_headers = json.loads(msg.headers_json)
         if "Date" in msg_headers:
             msg_date = human_timestamp_to_datetime(msg_headers["Date"])
         date = email.utils.format_datetime(msg_date)
-        msg = email.message_from_string(input_)
+        msg = email.message_from_string(input_str)
     except Exception as e:
         raise ParserError(e.__str__())
@@ -1477,10 +1507,10 @@ def parse_report_email(
         subject = msg_headers["Subject"]
     for part in msg.walk():
         content_type = part.get_content_type().lower()
-        payload = part.get_payload()
-        if not isinstance(payload, list):
-            payload = [payload]
-        payload = payload[0].__str__()
+        payload_obj = part.get_payload()
+        if not isinstance(payload_obj, list):
+            payload_obj = [payload_obj]
+        payload = str(payload_obj[0])
         if content_type.startswith("multipart/"):
             continue
         if content_type == "text/html":
@@ -1501,7 +1531,7 @@ def parse_report_email(
             sample = payload
         elif content_type == "application/tlsrpt+json":
             if not payload.strip().startswith("{"):
-                payload = str(b64decode(payload))
+                payload = b64decode(payload).decode("utf-8", errors="replace")
             smtp_tls_report = parse_smtp_tls_report_json(payload)
             return {"report_type": "smtp_tls", "report": smtp_tls_report}
         elif content_type == "application/tlsrpt+gzip":
@@ -1531,18 +1561,21 @@ def parse_report_email(
                 logger.debug(sample)
         else:
             try:
-                payload = b64decode(payload)
-                if payload.startswith(MAGIC_ZIP) or payload.startswith(MAGIC_GZIP):
-                    payload = extract_report(payload)
-                if isinstance(payload, bytes):
-                    payload = payload.decode("utf-8", errors="replace")
-                if payload.strip().startswith("{"):
-                    smtp_tls_report = parse_smtp_tls_report_json(payload)
+                payload_bytes = b64decode(payload)
+                if payload_bytes.startswith(MAGIC_ZIP) or payload_bytes.startswith(
+                    MAGIC_GZIP
+                ):
+                    payload_text = extract_report(payload_bytes)
+                else:
+                    payload_text = payload_bytes.decode("utf-8", errors="replace")
+                if payload_text.strip().startswith("{"):
+                    smtp_tls_report = parse_smtp_tls_report_json(payload_text)
                     result = {"report_type": "smtp_tls", "report": smtp_tls_report}
                     return result
-                elif payload.strip().startswith("<"):
+                elif payload_text.strip().startswith("<"):
                     aggregate_report = parse_aggregate_report_xml(
-                        payload,
+                        payload_text,
                         ip_db_path=ip_db_path,
                         always_use_local_files=always_use_local_files,
                         reverse_dns_map_path=reverse_dns_map_path,
@@ -1604,26 +1637,28 @@ def parse_report_email(
         error = 'Message with subject "{0}" is not a valid report'.format(subject)
         raise InvalidDMARCReport(error)
+    return result


 def parse_report_file(
-    input_: Union[bytes, str, IO[Any]],
+    input_: Union[bytes, str, BinaryIO],
     *,
     nameservers: Optional[list[str]] = None,
-    dns_timeout: Optional[float] = 2.0,
-    strip_attachment_payloads: Optional[bool] = False,
+    dns_timeout: float = 2.0,
+    strip_attachment_payloads: bool = False,
     ip_db_path: Optional[str] = None,
-    always_use_local_files: Optional[bool] = False,
+    always_use_local_files: bool = False,
     reverse_dns_map_path: Optional[str] = None,
     reverse_dns_map_url: Optional[str] = None,
-    offline: Optional[bool] = False,
+    offline: bool = False,
     keep_alive: Optional[Callable] = None,
-    normalize_timespan_threshold_hours: Optional[float] = 24,
+    normalize_timespan_threshold_hours: float = 24,
 ) -> dict[str, Any]:
     """Parses a DMARC aggregate or forensic file at the given path, a
     file-like object. or bytes

     Args:
-        input_ (str | bytes | IO): A path to a file, a file like object, or bytes
+        input_ (str | bytes | BinaryIO): A path to a file, a file like object, or bytes
         nameservers (list): A list of one or more nameservers to use
             (Cloudflare's public DNS resolvers by default)
         dns_timeout (float): Sets the DNS timeout in seconds
@@ -1639,11 +1674,12 @@ def parse_report_file(
     Returns:
         dict: The parsed DMARC report
     """
-    if type(input_) is str:
+    file_object: BinaryIO
+    if isinstance(input_, str):
         logger.debug("Parsing {0}".format(input_))
         file_object = open(input_, "rb")
-    elif type(input_) is bytes:
-        file_object = BytesIO(input_)
+    elif isinstance(input_, (bytes, bytearray, memoryview)):
+        file_object = BytesIO(bytes(input_))
     else:
         file_object = input_
@@ -1653,6 +1689,7 @@ def parse_report_file(
         content = extract_report(content)

     results: Optional[dict[str, Any]] = None
+
     try:
         report = parse_aggregate_report_file(
             content,
@@ -1673,7 +1710,6 @@ def parse_report_file(
             results = {"report_type": "smtp_tls", "report": report}
         except InvalidSMTPTLSReport:
             try:
-                sa = strip_attachment_payloads
                 results = parse_report_email(
                     content,
                     ip_db_path=ip_db_path,
@@ -1683,7 +1719,7 @@ def parse_report_file(
                     offline=offline,
                     nameservers=nameservers,
                     dns_timeout=dns_timeout,
-                    strip_attachment_payloads=sa,
+                    strip_attachment_payloads=strip_attachment_payloads,
                     keep_alive=keep_alive,
                     normalize_timespan_threshold_hours=normalize_timespan_threshold_hours,
                 )
@@ -1799,7 +1835,7 @@ def get_dmarc_reports_from_mailbox(
     strip_attachment_payloads: bool = False,
     results: Optional[dict[str, Any]] = None,
     batch_size: int = 10,
-    since: Optional[Union[datetime, str]] = None,
+    since: Optional[Union[datetime, date, str]] = None,
     create_folders: bool = True,
     normalize_timespan_threshold_hours: float = 24,
 ) -> dict[str, list[dict[str, Any]]]:
@@ -1840,7 +1876,7 @@ def get_dmarc_reports_from_mailbox(
         raise ValueError("Must supply a connection")

     # current_time useful to fetch_messages later in the program
-    current_time = None
+    current_time: Optional[Union[datetime, date, str]] = None

     aggregate_reports = []
     forensic_reports = []
@@ -1865,7 +1901,7 @@ def get_dmarc_reports_from_mailbox(
             connection.create_folder(smtp_tls_reports_folder)
             connection.create_folder(invalid_reports_folder)

-    if since:
+    if since and isinstance(since, str):
         _since = 1440  # default one day
         if re.match(r"\d+[mhdw]$", since):
             s = re.split(r"(\d+)", since)
@@ -1926,13 +1962,16 @@ def get_dmarc_reports_from_mailbox(
                         i + 1, message_limit, msg_uid
                     )
                 )
-                if isinstance(mailbox, MSGraphConnection):
-                    if test:
-                        msg_content = connection.fetch_message(msg_uid, mark_read=False)
-                    else:
-                        msg_content = connection.fetch_message(msg_uid, mark_read=True)
+                message_id: Union[int, str]
+                if isinstance(connection, IMAPConnection):
+                    message_id = int(msg_uid)
+                    msg_content = connection.fetch_message(message_id)
+                elif isinstance(connection, MSGraphConnection):
+                    message_id = str(msg_uid)
+                    msg_content = connection.fetch_message(message_id, mark_read=not test)
                 else:
-                    msg_content = connection.fetch_message(msg_uid)
+                    message_id = str(msg_uid) if not isinstance(msg_uid, str) else msg_uid
+                    msg_content = connection.fetch_message(message_id)
                 try:
                     sa = strip_attachment_payloads
                     parsed_email = parse_report_email(
@@ -1959,26 +1998,32 @@ def get_dmarc_reports_from_mailbox(
                         logger.debug(
                             f"Skipping duplicate aggregate report with ID: {report_id}"
                         )
-                        aggregate_report_msg_uids.append(msg_uid)
+                        aggregate_report_msg_uids.append(message_id)
                     elif parsed_email["report_type"] == "forensic":
                         forensic_reports.append(parsed_email["report"])
-                        forensic_report_msg_uids.append(msg_uid)
+                        forensic_report_msg_uids.append(message_id)
                     elif parsed_email["report_type"] == "smtp_tls":
                         smtp_tls_reports.append(parsed_email["report"])
-                        smtp_tls_msg_uids.append(msg_uid)
+                        smtp_tls_msg_uids.append(message_id)
                 except ParserError as error:
                     logger.warning(error.__str__())
                     if not test:
                         if delete:
                             logger.debug("Deleting message UID {0}".format(msg_uid))
-                            connection.delete_message(msg_uid)
+                            if isinstance(connection, IMAPConnection):
+                                connection.delete_message(int(message_id))
+                            else:
+                                connection.delete_message(str(message_id))
                         else:
                             logger.debug(
                                 "Moving message UID {0} to {1}".format(
                                     msg_uid, invalid_reports_folder
                                 )
                             )
-                            connection.move_message(msg_uid, invalid_reports_folder)
+                            if isinstance(connection, IMAPConnection):
+                                connection.move_message(int(message_id), invalid_reports_folder)
+                            else:
+                                connection.move_message(str(message_id), invalid_reports_folder)

     if not test:
         if delete:
@@ -2106,21 +2151,21 @@ def watch_inbox(
     mailbox_connection: MailboxConnection,
     callback: Callable,
     *,
-    reports_folder: Optional[str] = "INBOX",
-    archive_folder: Optional[str] = "Archive",
-    delete: Optional[bool] = False,
-    test: Optional[bool] = False,
-    check_timeout: Optional[int] = 30,
+    reports_folder: str = "INBOX",
+    archive_folder: str = "Archive",
+    delete: bool = False,
+    test: bool = False,
+    check_timeout: int = 30,
     ip_db_path: Optional[str] = None,
-    always_use_local_files: Optional[bool] = False,
+    always_use_local_files: bool = False,
     reverse_dns_map_path: Optional[str] = None,
     reverse_dns_map_url: Optional[str] = None,
-    offline: Optional[bool] = False,
+    offline: bool = False,
     nameservers: Optional[list[str]] = None,
-    dns_timeout: Optional[float] = 6.0,
-    strip_attachment_payloads: Optional[bool] = False,
-    batch_size: Optional[int] = None,
-    normalize_timespan_threshold_hours: Optional[float] = 24,
+    dns_timeout: float = 6.0,
+    strip_attachment_payloads: bool = False,
+    batch_size: int = 10,
+    normalize_timespan_threshold_hours: float = 24,
 ):
     """
     Watches the mailbox for new messages and
@@ -2150,7 +2195,6 @@ def watch_inbox(
     """

     def check_callback(connection):
-        sa = strip_attachment_payloads
         res = get_dmarc_reports_from_mailbox(
             connection=connection,
             reports_folder=reports_folder,
@@ -2164,7 +2208,7 @@ def watch_inbox(
             offline=offline,
             nameservers=nameservers,
             dns_timeout=dns_timeout,
-            strip_attachment_payloads=sa,
+            strip_attachment_payloads=strip_attachment_payloads,
             batch_size=batch_size,
             create_folders=False,
             normalize_timespan_threshold_hours=normalize_timespan_threshold_hours,
@@ -2212,13 +2256,13 @@ def append_csv(filename, csv):
 def save_output(
     results: dict[str, Any],
     *,
-    output_directory: Optional[str] = "output",
-    aggregate_json_filename: Optional[str] = "aggregate.json",
-    forensic_json_filename: Optional[str] = "forensic.json",
-    smtp_tls_json_filename: Optional[str] = "smtp_tls.json",
-    aggregate_csv_filename: Optional[str] = "aggregate.csv",
-    forensic_csv_filename: Optional[str] = "forensic.csv",
-    smtp_tls_csv_filename: Optional[str] = "smtp_tls.csv",
+    output_directory: str = "output",
+    aggregate_json_filename: str = "aggregate.json",
+    forensic_json_filename: str = "forensic.json",
+    smtp_tls_json_filename: str = "smtp_tls.json",
+    aggregate_csv_filename: str = "aggregate.csv",
+    forensic_csv_filename: str = "forensic.csv",
+    smtp_tls_csv_filename: str = "smtp_tls.csv",
 ):
     """
     Save report data in the given directory
@@ -2322,7 +2366,7 @@ def get_report_zip(results: dict[str, Any]) -> bytes:
     storage = BytesIO()
     tmp_dir = tempfile.mkdtemp()
     try:
-        save_output(results, tmp_dir)
+        save_output(results, output_directory=tmp_dir)
         with zipfile.ZipFile(storage, "w", zipfile.ZIP_DEFLATED) as zip_file:
             for root, dirs, files in os.walk(tmp_dir):
                 for file in files:
@@ -2345,10 +2389,10 @@ def email_results(
     results: dict[str, Any],
     host: str,
     mail_from: str,
-    mail_to: str,
+    mail_to: Optional[list[str]],
     *,
-    mail_cc: Optional[list] = None,
-    mail_bcc: Optional[list] = None,
+    mail_cc: Optional[list[str]] = None,
+    mail_bcc: Optional[list[str]] = None,
     port: int = 0,
     require_encryption: bool = False,
     verify: bool = True,
@@ -2377,7 +2421,7 @@ def email_results(
         attachment_filename (str): Override the default attachment filename
         message (str): Override the default plain text body
     """
-    logger.debug("Emailing report to: {0}".format(",".join(mail_to)))
+    logger.debug("Emailing report")
     date_string = datetime.now().strftime("%Y-%m-%d")
     if attachment_filename:
         if not attachment_filename.lower().endswith(".zip"):
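
As an aside, the `since` option handled above accepts shorthand strings such as "30m", "12h", "7d", or "2w". A hypothetical helper (not project API) showing the conversion the `\d+[mhdw]` pattern implies:

    import re

    UNIT_MINUTES = {"m": 1, "h": 60, "d": 1440, "w": 10080}

    def since_to_minutes(since: str) -> int:
        # Unrecognized input falls back to one day (1440 minutes),
        # matching the default in the hunk above.
        if not re.match(r"\d+[mhdw]$", since):
            return 1440
        return int(since[:-1]) * UNIT_MINUTES[since[-1]]

    assert since_to_minutes("7d") == 10080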

parsedmarc/cli.py

@@ -678,7 +678,7 @@ def _main():
     if "general" in config.sections():
         general_config = config["general"]
         if "silent" in general_config:
-            opts.silent = general_config.getboolean("silent")
+            opts.silent = bool(general_config.getboolean("silent"))
         if "normalize_timespan_threshold_hours" in general_config:
             opts.normalize_timespan_threshold_hours = general_config.getfloat(
                 "normalize_timespan_threshold_hours"
@@ -687,10 +687,10 @@ def _main():
             with open(general_config["index_prefix_domain_map"]) as f:
                 index_prefix_domain_map = yaml.safe_load(f)
         if "offline" in general_config:
-            opts.offline = general_config.getboolean("offline")
+            opts.offline = bool(general_config.getboolean("offline"))
         if "strip_attachment_payloads" in general_config:
-            opts.strip_attachment_payloads = general_config.getboolean(
-                "strip_attachment_payloads"
+            opts.strip_attachment_payloads = bool(
+                general_config.getboolean("strip_attachment_payloads")
             )
         if "output" in general_config:
             opts.output = general_config["output"]
@@ -732,19 +732,19 @@ def _main():
             )
             exit(-1)
         if "save_aggregate" in general_config:
-            opts.save_aggregate = general_config.getboolean("save_aggregate")
+            opts.save_aggregate = bool(general_config.getboolean("save_aggregate"))
         if "save_forensic" in general_config:
-            opts.save_forensic = general_config.getboolean("save_forensic")
+            opts.save_forensic = bool(general_config.getboolean("save_forensic"))
         if "save_smtp_tls" in general_config:
-            opts.save_smtp_tls = general_config.getboolean("save_smtp_tls")
+            opts.save_smtp_tls = bool(general_config.getboolean("save_smtp_tls"))
         if "debug" in general_config:
-            opts.debug = general_config.getboolean("debug")
+            opts.debug = bool(general_config.getboolean("debug"))
         if "verbose" in general_config:
-            opts.verbose = general_config.getboolean("verbose")
+            opts.verbose = bool(general_config.getboolean("verbose"))
         if "silent" in general_config:
-            opts.silent = general_config.getboolean("silent")
+            opts.silent = bool(general_config.getboolean("silent"))
         if "warnings" in general_config:
-            opts.warnings = general_config.getboolean("warnings")
+            opts.warnings = bool(general_config.getboolean("warnings"))
         if "log_file" in general_config:
             opts.log_file = general_config["log_file"]
         if "n_procs" in general_config:
@@ -754,15 +754,15 @@ def _main():
         else:
             opts.ip_db_path = None
         if "always_use_local_files" in general_config:
-            opts.always_use_local_files = general_config.getboolean(
-                "always_use_local_files"
+            opts.always_use_local_files = bool(
+                general_config.getboolean("always_use_local_files")
             )
         if "reverse_dns_map_path" in general_config:
             opts.reverse_dns_map_path = general_config["reverse_dns_path"]
         if "reverse_dns_map_url" in general_config:
             opts.reverse_dns_map_url = general_config["reverse_dns_url"]
         if "prettify_json" in general_config:
-            opts.prettify_json = general_config.getboolean("prettify_json")
+            opts.prettify_json = bool(general_config.getboolean("prettify_json"))

     if "mailbox" in config.sections():
         mailbox_config = config["mailbox"]
@@ -773,11 +773,11 @@ def _main():
         if "archive_folder" in mailbox_config:
             opts.mailbox_archive_folder = mailbox_config["archive_folder"]
         if "watch" in mailbox_config:
-            opts.mailbox_watch = mailbox_config.getboolean("watch")
+            opts.mailbox_watch = bool(mailbox_config.getboolean("watch"))
         if "delete" in mailbox_config:
-            opts.mailbox_delete = mailbox_config.getboolean("delete")
+            opts.mailbox_delete = bool(mailbox_config.getboolean("delete"))
         if "test" in mailbox_config:
-            opts.mailbox_test = mailbox_config.getboolean("test")
+            opts.mailbox_test = bool(mailbox_config.getboolean("test"))
         if "batch_size" in mailbox_config:
             opts.mailbox_batch_size = mailbox_config.getint("batch_size")
         if "check_timeout" in mailbox_config:
@@ -805,10 +805,10 @@ def _main():
         if "max_retries" in imap_config:
             opts.imap_max_retries = imap_config.getint("max_retries")
         if "ssl" in imap_config:
-            opts.imap_ssl = imap_config.getboolean("ssl")
+            opts.imap_ssl = bool(imap_config.getboolean("ssl"))
         if "skip_certificate_verification" in imap_config:
-            opts.imap_skip_certificate_verification = imap_config.getboolean(
-                "skip_certificate_verification"
+            opts.imap_skip_certificate_verification = bool(
+                imap_config.getboolean("skip_certificate_verification")
             )
         if "user" in imap_config:
             opts.imap_user = imap_config["user"]
@@ -837,7 +837,7 @@ def _main():
                 "section instead."
             )
         if "watch" in imap_config:
-            opts.mailbox_watch = imap_config.getboolean("watch")
+            opts.mailbox_watch = bool(imap_config.getboolean("watch"))
             logger.warning(
                 "Use of the watch option in the imap "
                 "configuration section has been deprecated. "
@@ -852,7 +852,7 @@ def _main():
                 "section instead."
             )
         if "test" in imap_config:
-            opts.mailbox_test = imap_config.getboolean("test")
+            opts.mailbox_test = bool(imap_config.getboolean("test"))
             logger.warning(
                 "Use of the test option in the imap "
                 "configuration section has been deprecated. "
@@ -946,8 +946,8 @@ def _main():
             opts.graph_url = graph_config["graph_url"]
         if "allow_unencrypted_storage" in graph_config:
-            opts.graph_allow_unencrypted_storage = graph_config.getboolean(
-                "allow_unencrypted_storage"
+            opts.graph_allow_unencrypted_storage = bool(
+                graph_config.getboolean("allow_unencrypted_storage")
             )

     if "elasticsearch" in config:
@@ -975,10 +975,10 @@ def _main():
         if "index_prefix" in elasticsearch_config:
             opts.elasticsearch_index_prefix = elasticsearch_config["index_prefix"]
         if "monthly_indexes" in elasticsearch_config:
-            monthly = elasticsearch_config.getboolean("monthly_indexes")
+            monthly = bool(elasticsearch_config.getboolean("monthly_indexes"))
             opts.elasticsearch_monthly_indexes = monthly
         if "ssl" in elasticsearch_config:
-            opts.elasticsearch_ssl = elasticsearch_config.getboolean("ssl")
+            opts.elasticsearch_ssl = bool(elasticsearch_config.getboolean("ssl"))
         if "cert_path" in elasticsearch_config:
             opts.elasticsearch_ssl_cert_path = elasticsearch_config["cert_path"]
         if "user" in elasticsearch_config:
@@ -1015,10 +1015,10 @@ def _main():
         if "index_prefix" in opensearch_config:
             opts.opensearch_index_prefix = opensearch_config["index_prefix"]
         if "monthly_indexes" in opensearch_config:
-            monthly = opensearch_config.getboolean("monthly_indexes")
+            monthly = bool(opensearch_config.getboolean("monthly_indexes"))
             opts.opensearch_monthly_indexes = monthly
         if "ssl" in opensearch_config:
-            opts.opensearch_ssl = opensearch_config.getboolean("ssl")
+            opts.opensearch_ssl = bool(opensearch_config.getboolean("ssl"))
         if "cert_path" in opensearch_config:
             opts.opensearch_ssl_cert_path = opensearch_config["cert_path"]
         if "user" in opensearch_config:
@@ -1072,9 +1072,11 @@ def _main():
         if "password" in kafka_config:
             opts.kafka_password = kafka_config["password"]
         if "ssl" in kafka_config:
-            opts.kafka_ssl = kafka_config.getboolean("ssl")
+            opts.kafka_ssl = bool(kafka_config.getboolean("ssl"))
         if "skip_certificate_verification" in kafka_config:
-            kafka_verify = kafka_config.getboolean("skip_certificate_verification")
+            kafka_verify = bool(
+                kafka_config.getboolean("skip_certificate_verification")
+            )
             opts.kafka_skip_certificate_verification = kafka_verify
         if "aggregate_topic" in kafka_config:
             opts.kafka_aggregate_topic = kafka_config["aggregate_topic"]
@@ -1106,9 +1108,11 @@ def _main():
         if "port" in smtp_config:
             opts.smtp_port = smtp_config.getint("port")
         if "ssl" in smtp_config:
-            opts.smtp_ssl = smtp_config.getboolean("ssl")
+            opts.smtp_ssl = bool(smtp_config.getboolean("ssl"))
         if "skip_certificate_verification" in smtp_config:
-            smtp_verify = smtp_config.getboolean("skip_certificate_verification")
+            smtp_verify = bool(
+                smtp_config.getboolean("skip_certificate_verification")
+            )
             opts.smtp_skip_certificate_verification = smtp_verify
         if "user" in smtp_config:
             opts.smtp_user = smtp_config["user"]
@@ -1176,11 +1180,11 @@ def _main():
         gmail_api_config = config["gmail_api"]
         opts.gmail_api_credentials_file = gmail_api_config.get("credentials_file")
         opts.gmail_api_token_file = gmail_api_config.get("token_file", ".token")
-        opts.gmail_api_include_spam_trash = gmail_api_config.getboolean(
-            "include_spam_trash", False
+        opts.gmail_api_include_spam_trash = bool(
+            gmail_api_config.getboolean("include_spam_trash", False)
         )
-        opts.gmail_api_paginate_messages = gmail_api_config.getboolean(
-            "paginate_messages", True
+        opts.gmail_api_paginate_messages = bool(
+            gmail_api_config.getboolean("paginate_messages", True)
         )
         opts.gmail_api_scopes = gmail_api_config.get(
             "scopes", default_gmail_api_scope
@@ -1194,8 +1198,8 @@ def _main():
     if "maildir" in config.sections():
         maildir_api_config = config["maildir"]
         opts.maildir_path = maildir_api_config.get("maildir_path")
-        opts.maildir_create = maildir_api_config.getboolean(
-            "maildir_create", fallback=False
+        opts.maildir_create = bool(
+            maildir_api_config.getboolean("maildir_create", fallback=False)
         )

     if "log_analytics" in config.sections():
@@ -1274,6 +1278,7 @@ def _main():
         exit(1)

     logger.info("Starting parsedmarc")
+
     if opts.save_aggregate or opts.save_forensic or opts.save_smtp_tls:
         try:
@@ -1513,6 +1518,11 @@ def _main():
                 smtp_tls_reports.append(result[0]["report"])

     for mbox_path in mbox_paths:
+        normalize_timespan_threshold_hours_value = (
+            float(opts.normalize_timespan_threshold_hours)
+            if opts.normalize_timespan_threshold_hours is not None
+            else 24.0
+        )
         strip = opts.strip_attachment_payloads
         reports = get_dmarc_reports_from_mbox(
             mbox_path,
@@ -1524,13 +1534,17 @@ def _main():
             reverse_dns_map_path=opts.reverse_dns_map_path,
             reverse_dns_map_url=opts.reverse_dns_map_url,
             offline=opts.offline,
-            normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
+            normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
         )
         aggregate_reports += reports["aggregate_reports"]
         forensic_reports += reports["forensic_reports"]
         smtp_tls_reports += reports["smtp_tls_reports"]

     mailbox_connection = None
+    mailbox_batch_size_value = 10
+    mailbox_check_timeout_value = 30
+    normalize_timespan_threshold_hours_value = 24.0
     if opts.imap_host:
         try:
             if opts.imap_user is None or opts.imap_password is None:
@@ -1552,9 +1566,10 @@ def _main():
             imap_max_retries = (
                 int(opts.imap_max_retries) if opts.imap_max_retries is not None else 4
             )
+            imap_port_value = int(opts.imap_port) if opts.imap_port is not None else 993
             mailbox_connection = IMAPConnection(
                 host=opts.imap_host,
-                port=opts.imap_port,
+                port=imap_port_value,
                 ssl=ssl,
                 verify=verify,
                 timeout=imap_timeout,
@@ -1624,11 +1639,24 @@ def _main():
             exit(1)

     if mailbox_connection:
+        mailbox_batch_size_value = (
+            int(opts.mailbox_batch_size) if opts.mailbox_batch_size is not None else 10
+        )
+        mailbox_check_timeout_value = (
+            int(opts.mailbox_check_timeout)
+            if opts.mailbox_check_timeout is not None
+            else 30
+        )
+        normalize_timespan_threshold_hours_value = (
+            float(opts.normalize_timespan_threshold_hours)
+            if opts.normalize_timespan_threshold_hours is not None
+            else 24.0
+        )
         try:
             reports = get_dmarc_reports_from_mailbox(
                 connection=mailbox_connection,
                 delete=opts.mailbox_delete,
-                batch_size=opts.mailbox_batch_size,
+                batch_size=mailbox_batch_size_value,
                 reports_folder=opts.mailbox_reports_folder,
                 archive_folder=opts.mailbox_archive_folder,
                 ip_db_path=opts.ip_db_path,
@@ -1640,7 +1668,7 @@ def _main():
                 test=opts.mailbox_test,
                 strip_attachment_payloads=opts.strip_attachment_payloads,
                 since=opts.mailbox_since,
-                normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
+                normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
             )

             aggregate_reports += reports["aggregate_reports"]
@@ -1664,12 +1692,18 @@ def _main():
             verify = True
             if opts.smtp_skip_certificate_verification:
                 verify = False
+            smtp_port_value = int(opts.smtp_port) if opts.smtp_port is not None else 25
+            smtp_to_value = (
+                list(opts.smtp_to)
+                if isinstance(opts.smtp_to, list)
+                else _str_to_list(str(opts.smtp_to))
+            )
             email_results(
                 results,
                 opts.smtp_host,
                 opts.smtp_from,
-                opts.smtp_to,
-                port=opts.smtp_port,
+                smtp_to_value,
+                port=smtp_port_value,
                 verify=verify,
                 username=opts.smtp_user,
                 password=opts.smtp_password,
@@ -1683,6 +1717,7 @@ def _main():
     if mailbox_connection and opts.mailbox_watch:
         logger.info("Watching for email - Quit with ctrl-c")
+
         try:
             watch_inbox(
                 mailbox_connection=mailbox_connection,
@@ -1691,17 +1726,17 @@ def _main():
                 archive_folder=opts.mailbox_archive_folder,
                 delete=opts.mailbox_delete,
                 test=opts.mailbox_test,
-                check_timeout=opts.mailbox_check_timeout,
+                check_timeout=mailbox_check_timeout_value,
                 nameservers=opts.nameservers,
                 dns_timeout=opts.dns_timeout,
                 strip_attachment_payloads=opts.strip_attachment_payloads,
-                batch_size=opts.mailbox_batch_size,
+                batch_size=mailbox_batch_size_value,
                 ip_db_path=opts.ip_db_path,
                 always_use_local_files=opts.always_use_local_files,
                 reverse_dns_map_path=opts.reverse_dns_map_path,
                 reverse_dns_map_url=opts.reverse_dns_map_url,
                 offline=opts.offline,
-                normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
+                normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
             )
         except FileExistsError as error:
             logger.error("{0}".format(error.__str__()))
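
For context on the repeated `bool(...)` casts above: `configparser` section proxies type `getboolean()` as returning `bool | None`, since a missing option with no fallback yields `None`, which fails strict type checks against `bool`-typed options. A small self-contained illustration:

    from configparser import ConfigParser

    parser = ConfigParser()
    parser.read_string("[general]\nsilent = yes\n")
    general = parser["general"]

    silent = bool(general.getboolean("silent"))       # "yes" -> True
    debug = bool(general.getboolean("debug", False))  # absent -> False fallback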

parsedmarc/utils.py

@@ -17,7 +17,7 @@ import shutil
 import subprocess
 import tempfile
 from datetime import datetime, timedelta, timezone
-from typing import Optional, Union
+from typing import Optional, TypedDict, Union, cast

 import mailparser
 from expiringdict import ExpiringDict
@@ -64,7 +64,24 @@ class DownloadError(RuntimeError):
     """Raised when an error occurs when downloading a file"""


-def decode_base64(data) -> bytes:
+class ReverseDNSService(TypedDict):
+    name: str
+    type: Optional[str]
+
+
+ReverseDNSMap = dict[str, ReverseDNSService]
+
+
+class IPAddressInfo(TypedDict):
+    ip_address: str
+    reverse_dns: Optional[str]
+    country: Optional[str]
+    base_domain: Optional[str]
+    name: Optional[str]
+    type: Optional[str]
+
+
+def decode_base64(data: str) -> bytes:
     """
     Decodes a base64 string, with padding being optional
@@ -75,11 +92,11 @@ def decode_base64(data) -> bytes:
         bytes: The decoded bytes
     """
-    data = bytes(data, encoding="ascii")
-    missing_padding = len(data) % 4
+    data_bytes = bytes(data, encoding="ascii")
+    missing_padding = len(data_bytes) % 4
     if missing_padding != 0:
-        data += b"=" * (4 - missing_padding)
-    return base64.b64decode(data)
+        data_bytes += b"=" * (4 - missing_padding)
+    return base64.b64decode(data_bytes)


 def get_base_domain(domain: str) -> Optional[str]:
@@ -132,9 +149,9 @@ def query_dns(
     record_type = record_type.upper()
     cache_key = "{0}_{1}".format(domain, record_type)
     if cache:
-        records = cache.get(cache_key, None)
-        if records:
-            return records
+        cached_records = cache.get(cache_key, None)
+        if isinstance(cached_records, list):
+            return cast(list[str], cached_records)

     resolver = dns.resolver.Resolver()
     timeout = float(timeout)
@@ -148,26 +165,12 @@ def query_dns(
         resolver.nameservers = nameservers
     resolver.timeout = timeout
     resolver.lifetime = timeout
-    if record_type == "TXT":
-        resource_records = list(
-            map(
-                lambda r: r.strings,
-                resolver.resolve(domain, record_type, lifetime=timeout),
-            )
-        )
-        _resource_record = [
-            resource_record[0][:0].join(resource_record)
-            for resource_record in resource_records
-            if resource_record
-        ]
-        records = [r.decode() for r in _resource_record]
-    else:
-        records = list(
-            map(
-                lambda r: r.to_text().replace('"', "").rstrip("."),
-                resolver.resolve(domain, record_type, lifetime=timeout),
-            )
+    records = list(
+        map(
+            lambda r: r.to_text().replace('"', "").rstrip("."),
+            resolver.resolve(domain, record_type, lifetime=timeout),
         )
+    )
     if cache:
         cache[cache_key] = records
@@ -180,7 +183,7 @@ def get_reverse_dns(
     cache: Optional[ExpiringDict] = None,
     nameservers: Optional[list[str]] = None,
     timeout: float = 2.0,
-) -> str:
+) -> Optional[str]:
     """
     Resolves an IP address to a hostname using a reverse DNS query
@@ -198,7 +201,7 @@ def get_reverse_dns(
     try:
         address = dns.reversename.from_address(ip_address)
         hostname = query_dns(
-            address, "PTR", cache=cache, nameservers=nameservers, timeout=timeout
+            str(address), "PTR", cache=cache, nameservers=nameservers, timeout=timeout
         )[0]
     except dns.exception.DNSException as e:
@@ -266,10 +269,12 @@ def human_timestamp_to_unix_timestamp(human_timestamp: str) -> int:
         float: The converted timestamp
     """
     human_timestamp = human_timestamp.replace("T", " ")
-    return human_timestamp_to_datetime(human_timestamp).timestamp()
+    return int(human_timestamp_to_datetime(human_timestamp).timestamp())


-def get_ip_address_country(ip_address: str, *, db_path: Optional[str] = None) -> str:
+def get_ip_address_country(
+    ip_address: str, *, db_path: Optional[str] = None
+) -> Optional[str]:
     """
     Returns the ISO code for the country associated
     with the given IPv4 or IPv6 address
@@ -335,11 +340,11 @@ def get_service_from_reverse_dns_base_domain(
     base_domain,
     *,
     always_use_local_file: Optional[bool] = False,
-    local_file_path: Optional[bool] = None,
-    url: Optional[bool] = None,
+    local_file_path: Optional[str] = None,
+    url: Optional[str] = None,
     offline: Optional[bool] = False,
-    reverse_dns_map: Optional[bool] = None,
-) -> str:
+    reverse_dns_map: Optional[ReverseDNSMap] = None,
+) -> ReverseDNSService:
     """
     Returns the service name of a given base domain name from reverse DNS.
@@ -356,12 +361,6 @@ def get_service_from_reverse_dns_base_domain(
         the supplied reverse_dns_base_domain and the type will be None
     """

-    def load_csv(_csv_file):
-        reader = csv.DictReader(_csv_file)
-        for row in reader:
-            key = row["base_reverse_dns"].lower().strip()
-            reverse_dns_map[key] = dict(name=row["name"], type=row["type"])
-
     base_domain = base_domain.lower().strip()
     if url is None:
         url = (
@@ -369,11 +368,24 @@ def get_service_from_reverse_dns_base_domain(
             "/parsedmarc/master/parsedmarc/"
             "resources/maps/base_reverse_dns_map.csv"
         )
+    reverse_dns_map_value: ReverseDNSMap
     if reverse_dns_map is None:
-        reverse_dns_map = dict()
+        reverse_dns_map_value = {}
+    else:
+        reverse_dns_map_value = reverse_dns_map
+
+    def load_csv(_csv_file):
+        reader = csv.DictReader(_csv_file)
+        for row in reader:
+            key = row["base_reverse_dns"].lower().strip()
+            reverse_dns_map_value[key] = {
+                "name": row["name"],
+                "type": row["type"],
+            }
+
     csv_file = io.StringIO()
-    if not (offline or always_use_local_file) and len(reverse_dns_map) == 0:
+    if not (offline or always_use_local_file) and len(reverse_dns_map_value) == 0:
         try:
             logger.debug(f"Trying to fetch reverse DNS map from {url}...")
             headers = {"User-Agent": USER_AGENT}
@@ -390,7 +402,7 @@ def get_service_from_reverse_dns_base_domain(
             logging.debug("Response body:")
             logger.debug(csv_file.read())
-    if len(reverse_dns_map) == 0:
+    if len(reverse_dns_map_value) == 0:
         logger.info("Loading included reverse DNS map...")
         path = str(
             files(parsedmarc.resources.maps).joinpath("base_reverse_dns_map.csv")
@@ -399,10 +411,11 @@ def get_service_from_reverse_dns_base_domain(
             path = local_file_path
         with open(path) as csv_file:
             load_csv(csv_file)
+    service: ReverseDNSService
     try:
-        service = reverse_dns_map[base_domain]
+        service = reverse_dns_map_value[base_domain]
     except KeyError:
-        service = dict(name=base_domain, type=None)
+        service = {"name": base_domain, "type": None}
     return service
@@ -415,11 +428,11 @@ def get_ip_address_info(
     always_use_local_files: Optional[bool] = False,
     reverse_dns_map_url: Optional[str] = None,
     cache: Optional[ExpiringDict] = None,
-    reverse_dns_map: Optional[dict] = None,
+    reverse_dns_map: Optional[ReverseDNSMap] = None,
     offline: Optional[bool] = False,
     nameservers: Optional[list[str]] = None,
-    timeout: Optional[float] = 2.0,
-) -> dict[str, str]:
+    timeout: float = 2.0,
+) -> IPAddressInfo:
     """
     Returns reverse DNS and country information for the given IP address
@@ -442,12 +455,22 @@ def get_ip_address_info(
     """
     ip_address = ip_address.lower()
     if cache is not None:
-        info = cache.get(ip_address, None)
-        if info:
+        cached_info = cache.get(ip_address, None)
+        if (
+            cached_info
+            and isinstance(cached_info, dict)
+            and "ip_address" in cached_info
+        ):
             logger.debug(f"IP address {ip_address} was found in cache")
-            return info
+            return cast(IPAddressInfo, cached_info)

-    info: dict[str, str] = {}
-    info["ip_address"] = ip_address
+    info: IPAddressInfo = {
+        "ip_address": ip_address,
+        "reverse_dns": None,
+        "country": None,
+        "base_domain": None,
+        "name": None,
+        "type": None,
+    }
     if offline:
         reverse_dns = None
     else:
@@ -457,9 +480,6 @@ def get_ip_address_info(
     country = get_ip_address_country(ip_address, db_path=ip_db_path)
     info["country"] = country
     info["reverse_dns"] = reverse_dns
-    info["base_domain"] = None
-    info["name"] = None
-    info["type"] = None
     if reverse_dns is not None:
         base_domain = get_base_domain(reverse_dns)
         if base_domain is not None:
@@ -484,7 +504,7 @@ def get_ip_address_info(
     return info


-def parse_email_address(original_address: str) -> dict[str, str]:
+def parse_email_address(original_address: str) -> dict[str, Optional[str]]:
     if original_address[0] == "":
         display_name = None
     else:
@@ -563,7 +583,7 @@ def is_outlook_msg(content) -> bool:
     )


-def convert_outlook_msg(msg_bytes: bytes) -> str:
+def convert_outlook_msg(msg_bytes: bytes) -> bytes:
     """
     Uses the ``msgconvert`` Perl utility to convert an Outlook MS file to
     standard RFC 822 format
@@ -572,7 +592,7 @@ def convert_outlook_msg(msg_bytes: bytes) -> str:
         msg_bytes (bytes): the content of the .msg file

     Returns:
-        A RFC 822 string
+        A RFC 822 bytes payload
     """
     if not is_outlook_msg(msg_bytes):
         raise ValueError("The supplied bytes are not an Outlook MSG file")
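
To illustrate the TypedDict shapes introduced in utils.py, a trimmed sketch (the map entry shown is a made-up example value):

    from typing import Optional, TypedDict

    class ReverseDNSService(TypedDict):
        name: str
        type: Optional[str]

    ReverseDNSMap = dict[str, ReverseDNSService]

    reverse_dns_map: ReverseDNSMap = {
        "amazonses.com": {"name": "Amazon SES", "type": "Email Provider"},
    }

    def lookup(base_domain: str) -> ReverseDNSService:
        # Unknown domains fall back to the domain itself with type None,
        # mirroring get_service_from_reverse_dns_base_domain above.
        return reverse_dns_map.get(base_domain, {"name": base_domain, "type": None})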