diff --git a/.vscode/settings.json b/.vscode/settings.json
index 379bd1f..366dc9f 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -13,148 +13,154 @@
         "MD024": false
     },
     "cSpell.words": [
-        "adkim",
-        "akamaiedge",
-        "amsmath",
-        "andrewmcgilvray",
-        "arcname",
-        "aspf",
-        "autoclass",
-        "automodule",
-        "backported",
-        "bellsouth",
-        "boto",
-        "brakhane",
-        "Brightmail",
-        "CEST",
-        "CHACHA",
-        "checkdmarc",
-        "Codecov",
-        "confnew",
-        "dateparser",
-        "dateutil",
-        "Davmail",
-        "DBIP",
-        "dearmor",
-        "deflist",
-        "devel",
-        "DMARC",
-        "Dmarcian",
-        "dnspython",
-        "dollarmath",
-        "dpkg",
-        "exampleuser",
-        "expiringdict",
-        "fieldlist",
-        "GELF",
-        "genindex",
-        "geoip",
-        "geoipupdate",
-        "Geolite",
-        "geolocation",
-        "githubpages",
-        "Grafana",
-        "hostnames",
-        "htpasswd",
-        "httpasswd",
-        "httplib",
-        "IMAP",
-        "imapclient",
-        "infile",
-        "Interaktive",
-        "IPDB",
-        "journalctl",
-        "keepalive",
-        "keyout",
-        "keyrings",
-        "Leeman",
-        "libemail",
-        "linkify",
-        "LISTSERV",
-        "lxml",
-        "mailparser",
-        "mailrelay",
-        "mailsuite",
-        "maxdepth",
-        "MAXHEADERS",
-        "maxmind",
-        "mbox",
-        "mfrom",
-        "michaeldavie",
-        "mikesiegel",
-        "Mimecast",
-        "mitigations",
-        "MMDB",
-        "modindex",
-        "msgconvert",
-        "msgraph",
-        "MSSP",
-        "multiprocess",
-        "Munge",
-        "ndjson",
-        "newkey",
-        "Nhcm",
-        "nojekyll",
-        "nondigest",
-        "nosecureimap",
-        "nosniff",
-        "nwettbewerb",
-        "opensearch",
-        "opensearchpy",
-        "parsedmarc",
-        "passsword",
-        "Postorius",
-        "premade",
-        "procs",
-        "publicsuffix",
-        "publicsuffixlist",
-        "publixsuffix",
-        "pygelf",
-        "pypy",
-        "pytest",
-        "quickstart",
-        "Reindex",
-        "replyto",
-        "reversename",
-        "Rollup",
-        "Rpdm",
-        "SAMEORIGIN",
-        "sdist",
-        "Servernameone",
-        "setuptools",
-        "smartquotes",
-        "SMTPTLS",
-        "sortlists",
-        "sortmaps",
-        "sourcetype",
-        "STARTTLS",
-        "tasklist",
-        "timespan",
-        "tlsa",
-        "tlsrpt",
-        "toctree",
-        "TQDDM",
-        "tqdm",
-        "truststore",
-        "Übersicht",
-        "uids",
-        "Uncategorized",
-        "unparasable",
-        "uper",
-        "urllib",
-        "Valimail",
-        "venv",
-        "Vhcw",
-        "viewcode",
-        "virtualenv",
-        "WBITS",
-        "webmail",
-        "Wettbewerber",
-        "Whalen",
-        "whitespaces",
-        "xennn",
-        "xmltodict",
-        "xpack",
-        "zscholl"
+        "adkim",
+        "akamaiedge",
+        "amsmath",
+        "andrewmcgilvray",
+        "arcname",
+        "aspf",
+        "autoclass",
+        "automodule",
+        "backported",
+        "bellsouth",
+        "boto",
+        "brakhane",
+        "Brightmail",
+        "CEST",
+        "CHACHA",
+        "checkdmarc",
+        "Codecov",
+        "confnew",
+        "dateparser",
+        "dateutil",
+        "Davmail",
+        "DBIP",
+        "dearmor",
+        "deflist",
+        "devel",
+        "DMARC",
+        "Dmarcian",
+        "dnspython",
+        "dollarmath",
+        "dpkg",
+        "exampleuser",
+        "expiringdict",
+        "fieldlist",
+        "GELF",
+        "genindex",
+        "geoip",
+        "geoipupdate",
+        "Geolite",
+        "geolocation",
+        "githubpages",
+        "Grafana",
+        "hostnames",
+        "htpasswd",
+        "httpasswd",
+        "httplib",
+        "ifhost",
+        "IMAP",
+        "imapclient",
+        "infile",
+        "Interaktive",
+        "IPDB",
+        "journalctl",
+        "kafkaclient",
+        "keepalive",
+        "keyout",
+        "keyrings",
+        "Leeman",
+        "libemail",
+        "linkify",
+        "LISTSERV",
+        "loganalytics",
+        "lxml",
+        "mailparser",
+        "mailrelay",
+        "mailsuite",
+        "maxdepth",
+        "MAXHEADERS",
+        "maxmind",
+        "mbox",
+        "mfrom",
+        "mhdw",
+        "michaeldavie",
+        "mikesiegel",
+        "Mimecast",
+        "mitigations",
+        "MMDB",
+        "modindex",
+        "msgconvert",
+        "msgraph",
+        "MSSP",
+        "multiprocess",
+        "Munge",
+        "ndjson",
+        "newkey",
+        "Nhcm",
+        "nojekyll",
+        "nondigest",
+        "nosecureimap",
+        "nosniff",
+        "nwettbewerb",
+        "opensearch",
+        "opensearchpy",
+        "parsedmarc",
+        "passsword",
+        "pbar",
+        "Postorius",
+        "premade",
+        "privatesuffix",
+        "procs",
+        "publicsuffix",
+        "publicsuffixlist",
+        "publixsuffix",
+        "pygelf",
+        "pypy",
+        "pytest",
+        "quickstart",
+        "Reindex",
+        "replyto",
+        "reversename",
+        "Rollup",
+        "Rpdm",
+        "SAMEORIGIN",
+        "sdist",
+        "Servernameone",
+        "setuptools",
+        "smartquotes",
+        "SMTPTLS",
+        "sortlists",
+        "sortmaps",
+        "sourcetype",
+        "STARTTLS",
+        "tasklist",
+        "timespan",
+        "tlsa",
+        "tlsrpt",
+        "toctree",
+        "TQDDM",
+        "tqdm",
+        "truststore",
+        "Übersicht",
+        "uids",
+        "Uncategorized",
+        "unparasable",
+        "uper",
+        "urllib",
+        "Valimail",
+        "venv",
+        "Vhcw",
+        "viewcode",
+        "virtualenv",
+        "WBITS",
+        "webmail",
+        "Wettbewerber",
+        "Whalen",
+        "whitespaces",
+        "xennn",
+        "xmltodict",
+        "xpack",
+        "zscholl"
     ],
 }
\ No newline at end of file
diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py
index 919d33f..8ef96b6 100644
--- a/parsedmarc/__init__.py
+++ b/parsedmarc/__init__.py
@@ -18,20 +18,17 @@ import zipfile
 import zlib
 from base64 import b64decode
 from csv import DictWriter
-from datetime import datetime, timedelta, timezone, tzinfo
+from datetime import date, datetime, timedelta, timezone, tzinfo
 from io import BytesIO, StringIO
 from typing import (
-    IO,
     Any,
     BinaryIO,
     Callable,
     Dict,
     List,
     Optional,
-    Protocol,
     Union,
     cast,
-    runtime_checkable,
 )
 
 import lxml.etree as etree
@@ -864,14 +861,7 @@ def parse_aggregate_report_xml(
         raise InvalidAggregateReport("Unexpected error: {0}".format(error.__str__()))
 
 
-@runtime_checkable
-class _ReadableSeekable(Protocol):
-    def read(self, n: int = -1) -> bytes: ...
-    def seek(self, offset: int, whence: int = 0) -> int: ...
-    def tell(self) -> int: ...
-
-
-def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO]) -> str:
+def extract_report(content: Union[bytes, str, BinaryIO]) -> str:
     """
     Extracts text from a zip or gzip file, as a base64-encoded string,
     file-like object, or bytes.
@@ -884,31 +874,59 @@ def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO])
         str: The extracted text
     """
-    file_object: Optional[_ReadableSeekable] = None
+    file_object: Optional[BinaryIO] = None
+    header: bytes
     try:
         if isinstance(content, str):
             try:
                 file_object = BytesIO(b64decode(content))
             except binascii.Error:
                 return content
-        elif isinstance(content, (bytes, bytearray, memoryview)):
+            header = file_object.read(6)
+            file_object.seek(0)
+        elif isinstance(content, bytes):
             file_object = BytesIO(bytes(content))
+            header = file_object.read(6)
+            file_object.seek(0)
         else:
-            file_object = cast(_ReadableSeekable, content)
+            stream = cast(BinaryIO, content)
+            seekable = getattr(stream, "seekable", None)
+            can_seek = False
+            if callable(seekable):
+                try:
+                    can_seek = bool(seekable())
+                except Exception:
+                    can_seek = False
 
-        header = file_object.read(6)
-        if isinstance(header, str):
-            raise ParserError("File objects must be opened in binary (rb) mode")
+            if can_seek:
+                header_raw = stream.read(6)
+                if isinstance(header_raw, str):
+                    raise ParserError("File objects must be opened in binary (rb) mode")
+                header = bytes(header_raw)
+                stream.seek(0)
+                file_object = stream
+            else:
+                header_raw = stream.read(6)
+                if isinstance(header_raw, str):
+                    raise ParserError("File objects must be opened in binary (rb) mode")
+                header = bytes(header_raw)
+                remainder = stream.read()
+                file_object = BytesIO(header + bytes(remainder))
 
-        file_object.seek(0)
-        if header.startswith(MAGIC_ZIP):
+        if file_object is None:
+            raise ParserError("Invalid report content")
+
+        if header[: len(MAGIC_ZIP)] == MAGIC_ZIP:
             _zip = zipfile.ZipFile(file_object)
             report = _zip.open(_zip.namelist()[0]).read().decode(errors="ignore")
-        elif header.startswith(MAGIC_GZIP):
+        elif header[: len(MAGIC_GZIP)] == MAGIC_GZIP:
             report = zlib.decompress(file_object.read(), zlib.MAX_WBITS | 16).decode(
                 errors="ignore"
             )
-        elif header.startswith(MAGIC_XML) or header.startswith(MAGIC_JSON):
+        elif (
+            header[: len(MAGIC_XML)] == MAGIC_XML
+            or header[: len(MAGIC_JSON)] == MAGIC_JSON
+        ):
             report = file_object.read().decode(errors="ignore")
         else:
             raise ParserError("Not a valid zip, gzip, json, or xml file")
@@ -918,7 +936,7 @@ def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO])
     except Exception as error:
         raise ParserError("Invalid archive file: {0}".format(error.__str__()))
     finally:
-        if file_object and hasattr(file_object, "close"):
+        if file_object:
             try:
                 file_object.close()
             except Exception:
@@ -937,17 +955,17 @@ def extract_report_from_file_path(file_path: str):
 
 
 def parse_aggregate_report_file(
-    _input: Union[str, bytes, IO[Any]],
+    _input: Union[str, bytes, BinaryIO],
     *,
-    offline: Optional[bool] = False,
-    always_use_local_files: Optional[bool] = None,
+    offline: bool = False,
+    always_use_local_files: bool = False,
     reverse_dns_map_path: Optional[str] = None,
     reverse_dns_map_url: Optional[str] = None,
     ip_db_path: Optional[str] = None,
     nameservers: Optional[list[str]] = None,
-    dns_timeout: Optional[float] = 2.0,
+    dns_timeout: float = 2.0,
     keep_alive: Optional[Callable] = None,
-    normalize_timespan_threshold_hours: Optional[float] = 24.0,
+    normalize_timespan_threshold_hours: float = 24.0,
 ) -> dict[str, Any]:
     """Parses a file at the given path, a file-like object.
    or bytes as an aggregate DMARC report
@@ -1187,14 +1205,14 @@ def parse_forensic_report(
     sample: str,
     msg_date: datetime,
     *,
-    always_use_local_files: Optional[bool] = False,
+    always_use_local_files: bool = False,
     reverse_dns_map_path: Optional[str] = None,
     reverse_dns_map_url: Optional[str] = None,
-    offline: Optional[bool] = False,
+    offline: bool = False,
     ip_db_path: Optional[str] = None,
     nameservers: Optional[list[str]] = None,
-    dns_timeout: Optional[float] = 2.0,
-    strip_attachment_payloads: Optional[bool] = False,
+    dns_timeout: float = 2.0,
+    strip_attachment_payloads: bool = False,
 ) -> dict[str, Any]:
     """
     Converts a DMARC forensic report and sample to a dict
@@ -1449,20 +1467,32 @@ def parse_report_email(
     * ``report_type``: ``aggregate`` or ``forensic``
     * ``report``: The parsed report
     """
-    result = None
+    result: Optional[dict[str, Any]] = None
     msg_date: datetime = datetime.now(timezone.utc)
     try:
-        if isinstance(input_, bytes) and is_outlook_msg(input_):
-            input_ = convert_outlook_msg(input_)
-        if isinstance(input_, bytes):
-            input_ = input_.decode(encoding="utf8", errors="replace")
-        msg = mailparser.parse_from_string(input_)
+        input_data: Union[str, bytes, bytearray, memoryview] = input_
+        if isinstance(input_data, (bytes, bytearray, memoryview)):
+            input_bytes = bytes(input_data)
+            if is_outlook_msg(input_bytes):
+                converted = convert_outlook_msg(input_bytes)
+                if isinstance(converted, str):
+                    input_str = converted
+                else:
+                    input_str = bytes(converted).decode(
+                        encoding="utf8", errors="replace"
+                    )
+            else:
+                input_str = input_bytes.decode(encoding="utf8", errors="replace")
+        else:
+            input_str = input_data
+
+        msg = mailparser.parse_from_string(input_str)
         msg_headers = json.loads(msg.headers_json)
         if "Date" in msg_headers:
             msg_date = human_timestamp_to_datetime(msg_headers["Date"])
         date = email.utils.format_datetime(msg_date)
-        msg = email.message_from_string(input_)
+        msg = email.message_from_string(input_str)
     except Exception as e:
         raise ParserError(e.__str__())
@@ -1477,10 +1507,10 @@ def parse_report_email(
         subject = msg_headers["Subject"]
     for part in msg.walk():
         content_type = part.get_content_type().lower()
-        payload = part.get_payload()
-        if not isinstance(payload, list):
-            payload = [payload]
-        payload = payload[0].__str__()
+        payload_obj = part.get_payload()
+        if not isinstance(payload_obj, list):
+            payload_obj = [payload_obj]
+        payload = str(payload_obj[0])
         if content_type.startswith("multipart/"):
             continue
         if content_type == "text/html":
@@ -1501,7 +1531,7 @@ def parse_report_email(
             sample = payload
         elif content_type == "application/tlsrpt+json":
             if not payload.strip().startswith("{"):
-                payload = str(b64decode(payload))
+                payload = b64decode(payload).decode("utf-8", errors="replace")
             smtp_tls_report = parse_smtp_tls_report_json(payload)
             return {"report_type": "smtp_tls", "report": smtp_tls_report}
         elif content_type == "application/tlsrpt+gzip":
@@ -1531,18 +1561,21 @@ def parse_report_email(
             logger.debug(sample)
         else:
             try:
-                payload = b64decode(payload)
-                if payload.startswith(MAGIC_ZIP) or payload.startswith(MAGIC_GZIP):
-                    payload = extract_report(payload)
-                if isinstance(payload, bytes):
-                    payload = payload.decode("utf-8", errors="replace")
-                if payload.strip().startswith("{"):
-                    smtp_tls_report = parse_smtp_tls_report_json(payload)
+                payload_bytes = b64decode(payload)
+                if payload_bytes.startswith(MAGIC_ZIP) or payload_bytes.startswith(
+                    MAGIC_GZIP
+                ):
+                    payload_text = extract_report(payload_bytes)
+                else:
+                    payload_text = payload_bytes.decode("utf-8", errors="replace")
errors="replace") + + if payload_text.strip().startswith("{"): + smtp_tls_report = parse_smtp_tls_report_json(payload_text) result = {"report_type": "smtp_tls", "report": smtp_tls_report} return result - elif payload.strip().startswith("<"): + elif payload_text.strip().startswith("<"): aggregate_report = parse_aggregate_report_xml( - payload, + payload_text, ip_db_path=ip_db_path, always_use_local_files=always_use_local_files, reverse_dns_map_path=reverse_dns_map_path, @@ -1604,26 +1637,28 @@ def parse_report_email( error = 'Message with subject "{0}" is not a valid report'.format(subject) raise InvalidDMARCReport(error) + return result + def parse_report_file( - input_: Union[bytes, str, IO[Any]], + input_: Union[bytes, str, BinaryIO], *, nameservers: Optional[list[str]] = None, - dns_timeout: Optional[float] = 2.0, - strip_attachment_payloads: Optional[bool] = False, + dns_timeout: float = 2.0, + strip_attachment_payloads: bool = False, ip_db_path: Optional[str] = None, - always_use_local_files: Optional[bool] = False, + always_use_local_files: bool = False, reverse_dns_map_path: Optional[str] = None, reverse_dns_map_url: Optional[str] = None, - offline: Optional[bool] = False, + offline: bool = False, keep_alive: Optional[Callable] = None, - normalize_timespan_threshold_hours: Optional[float] = 24, + normalize_timespan_threshold_hours: float = 24, ) -> dict[str, Any]: """Parses a DMARC aggregate or forensic file at the given path, a file-like object. or bytes Args: - input_ (str | bytes | IO): A path to a file, a file like object, or bytes + input_ (str | bytes | BinaryIO): A path to a file, a file like object, or bytes nameservers (list): A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) dns_timeout (float): Sets the DNS timeout in seconds @@ -1639,11 +1674,12 @@ def parse_report_file( Returns: dict: The parsed DMARC report """ - if type(input_) is str: + file_object: BinaryIO + if isinstance(input_, str): logger.debug("Parsing {0}".format(input_)) file_object = open(input_, "rb") - elif type(input_) is bytes: - file_object = BytesIO(input_) + elif isinstance(input_, (bytes, bytearray, memoryview)): + file_object = BytesIO(bytes(input_)) else: file_object = input_ @@ -1653,6 +1689,7 @@ def parse_report_file( content = extract_report(content) results: Optional[dict[str, Any]] = None + try: report = parse_aggregate_report_file( content, @@ -1673,7 +1710,6 @@ def parse_report_file( results = {"report_type": "smtp_tls", "report": report} except InvalidSMTPTLSReport: try: - sa = strip_attachment_payloads results = parse_report_email( content, ip_db_path=ip_db_path, @@ -1683,7 +1719,7 @@ def parse_report_file( offline=offline, nameservers=nameservers, dns_timeout=dns_timeout, - strip_attachment_payloads=sa, + strip_attachment_payloads=strip_attachment_payloads, keep_alive=keep_alive, normalize_timespan_threshold_hours=normalize_timespan_threshold_hours, ) @@ -1799,7 +1835,7 @@ def get_dmarc_reports_from_mailbox( strip_attachment_payloads: bool = False, results: Optional[dict[str, Any]] = None, batch_size: int = 10, - since: Optional[Union[datetime, str]] = None, + since: Optional[Union[datetime, date, str]] = None, create_folders: bool = True, normalize_timespan_threshold_hours: float = 24, ) -> dict[str, list[dict[str, Any]]]: @@ -1840,7 +1876,7 @@ def get_dmarc_reports_from_mailbox( raise ValueError("Must supply a connection") # current_time useful to fetch_messages later in the program - current_time = None + current_time: Optional[Union[datetime, 
 
     aggregate_reports = []
     forensic_reports = []
@@ -1865,7 +1901,7 @@ def get_dmarc_reports_from_mailbox(
             connection.create_folder(smtp_tls_reports_folder)
             connection.create_folder(invalid_reports_folder)
 
-    if since:
+    if since and isinstance(since, str):
         _since = 1440  # default one day
         if re.match(r"\d+[mhdw]$", since):
             s = re.split(r"(\d+)", since)
@@ -1926,13 +1962,16 @@ def get_dmarc_reports_from_mailbox(
                         i + 1, message_limit, msg_uid
                     )
                 )
-            if isinstance(mailbox, MSGraphConnection):
-                if test:
-                    msg_content = connection.fetch_message(msg_uid, mark_read=False)
-                else:
-                    msg_content = connection.fetch_message(msg_uid, mark_read=True)
+            message_id: Union[int, str]
+            if isinstance(connection, IMAPConnection):
+                message_id = int(msg_uid)
+                msg_content = connection.fetch_message(message_id)
+            elif isinstance(connection, MSGraphConnection):
+                message_id = str(msg_uid)
+                msg_content = connection.fetch_message(message_id, mark_read=not test)
             else:
-                msg_content = connection.fetch_message(msg_uid)
+                message_id = str(msg_uid) if not isinstance(msg_uid, str) else msg_uid
+                msg_content = connection.fetch_message(message_id)
             try:
                 sa = strip_attachment_payloads
                 parsed_email = parse_report_email(
@@ -1959,26 +1998,32 @@ def get_dmarc_reports_from_mailbox(
                         logger.debug(
                             f"Skipping duplicate aggregate report with ID: {report_id}"
                         )
-                    aggregate_report_msg_uids.append(msg_uid)
+                    aggregate_report_msg_uids.append(message_id)
                 elif parsed_email["report_type"] == "forensic":
                     forensic_reports.append(parsed_email["report"])
-                    forensic_report_msg_uids.append(msg_uid)
+                    forensic_report_msg_uids.append(message_id)
                 elif parsed_email["report_type"] == "smtp_tls":
                     smtp_tls_reports.append(parsed_email["report"])
-                    smtp_tls_msg_uids.append(msg_uid)
+                    smtp_tls_msg_uids.append(message_id)
             except ParserError as error:
                 logger.warning(error.__str__())
                 if not test:
                     if delete:
                         logger.debug("Deleting message UID {0}".format(msg_uid))
-                        connection.delete_message(msg_uid)
+                        if isinstance(connection, IMAPConnection):
+                            connection.delete_message(int(message_id))
+                        else:
+                            connection.delete_message(str(message_id))
                     else:
                         logger.debug(
                             "Moving message UID {0} to {1}".format(
                                 msg_uid, invalid_reports_folder
                             )
                         )
-                        connection.move_message(msg_uid, invalid_reports_folder)
+                        if isinstance(connection, IMAPConnection):
+                            connection.move_message(int(message_id), invalid_reports_folder)
+                        else:
+                            connection.move_message(str(message_id), invalid_reports_folder)
 
     if not test:
         if delete:
@@ -2106,21 +2151,21 @@ def watch_inbox(
     mailbox_connection: MailboxConnection,
     callback: Callable,
     *,
-    reports_folder: Optional[str] = "INBOX",
-    archive_folder: Optional[str] = "Archive",
-    delete: Optional[bool] = False,
-    test: Optional[bool] = False,
-    check_timeout: Optional[int] = 30,
+    reports_folder: str = "INBOX",
+    archive_folder: str = "Archive",
+    delete: bool = False,
+    test: bool = False,
+    check_timeout: int = 30,
     ip_db_path: Optional[str] = None,
-    always_use_local_files: Optional[bool] = False,
+    always_use_local_files: bool = False,
     reverse_dns_map_path: Optional[str] = None,
     reverse_dns_map_url: Optional[str] = None,
-    offline: Optional[bool] = False,
+    offline: bool = False,
     nameservers: Optional[list[str]] = None,
-    dns_timeout: Optional[float] = 6.0,
-    strip_attachment_payloads: Optional[bool] = False,
-    batch_size: Optional[int] = None,
-    normalize_timespan_threshold_hours: Optional[float] = 24,
+    dns_timeout: float = 6.0,
+    strip_attachment_payloads: bool = False,
+    batch_size: int = 10,
+    normalize_timespan_threshold_hours: float = 24,
 ):
     """
     Watches the mailbox for new messages and
@@ -2150,7 +2195,6 @@
     """
 
     def check_callback(connection):
-        sa = strip_attachment_payloads
         res = get_dmarc_reports_from_mailbox(
             connection=connection,
             reports_folder=reports_folder,
@@ -2164,7 +2208,7 @@
             offline=offline,
             nameservers=nameservers,
             dns_timeout=dns_timeout,
-            strip_attachment_payloads=sa,
+            strip_attachment_payloads=strip_attachment_payloads,
             batch_size=batch_size,
             create_folders=False,
             normalize_timespan_threshold_hours=normalize_timespan_threshold_hours,
         )
@@ -2212,13 +2256,13 @@ def append_csv(filename, csv):
 def save_output(
     results: dict[str, Any],
     *,
-    output_directory: Optional[str] = "output",
-    aggregate_json_filename: Optional[str] = "aggregate.json",
-    forensic_json_filename: Optional[str] = "forensic.json",
-    smtp_tls_json_filename: Optional[str] = "smtp_tls.json",
-    aggregate_csv_filename: Optional[str] = "aggregate.csv",
-    forensic_csv_filename: Optional[str] = "forensic.csv",
-    smtp_tls_csv_filename: Optional[str] = "smtp_tls.csv",
+    output_directory: str = "output",
+    aggregate_json_filename: str = "aggregate.json",
+    forensic_json_filename: str = "forensic.json",
+    smtp_tls_json_filename: str = "smtp_tls.json",
+    aggregate_csv_filename: str = "aggregate.csv",
+    forensic_csv_filename: str = "forensic.csv",
+    smtp_tls_csv_filename: str = "smtp_tls.csv",
 ):
     """
     Save report data in the given directory
@@ -2322,7 +2366,7 @@ def get_report_zip(results: dict[str, Any]) -> bytes:
     storage = BytesIO()
     tmp_dir = tempfile.mkdtemp()
     try:
-        save_output(results, tmp_dir)
+        save_output(results, output_directory=tmp_dir)
         with zipfile.ZipFile(storage, "w", zipfile.ZIP_DEFLATED) as zip_file:
             for root, dirs, files in os.walk(tmp_dir):
                 for file in files:
@@ -2345,10 +2389,10 @@ def email_results(
     results: dict[str, Any],
     host: str,
     mail_from: str,
-    mail_to: str,
+    mail_to: Optional[list[str]],
     *,
-    mail_cc: Optional[list] = None,
-    mail_bcc: Optional[list] = None,
+    mail_cc: Optional[list[str]] = None,
+    mail_bcc: Optional[list[str]] = None,
     port: int = 0,
     require_encryption: bool = False,
     verify: bool = True,
@@ -2377,7 +2421,7 @@ def email_results(
         attachment_filename (str): Override the default attachment filename
         message (str): Override the default plain text body
     """
-    logger.debug("Emailing report to: {0}".format(",".join(mail_to)))
+    logger.debug("Emailing report to: {0}".format(",".join(mail_to or [])))
     date_string = datetime.now().strftime("%Y-%m-%d")
     if attachment_filename:
         if not attachment_filename.lower().endswith(".zip"):
diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py
index c7eeead..4361499 100644
--- a/parsedmarc/cli.py
+++ b/parsedmarc/cli.py
@@ -678,7 +678,7 @@ def _main():
     if "general" in config.sections():
         general_config = config["general"]
         if "silent" in general_config:
-            opts.silent = general_config.getboolean("silent")
+            opts.silent = bool(general_config.getboolean("silent"))
         if "normalize_timespan_threshold_hours" in general_config:
             opts.normalize_timespan_threshold_hours = general_config.getfloat(
                 "normalize_timespan_threshold_hours"
             )
@@ -687,10 +687,10 @@
        if "index_prefix_domain_map" in general_config:
            with open(general_config["index_prefix_domain_map"]) as f:
                index_prefix_domain_map = yaml.safe_load(f)
        if "offline" in general_config:
-            opts.offline = general_config.getboolean("offline")
+            opts.offline = bool(general_config.getboolean("offline"))
        if "strip_attachment_payloads" in general_config:
-            opts.strip_attachment_payloads = general_config.getboolean(
-                "strip_attachment_payloads"
+            opts.strip_attachment_payloads = bool(
+                general_config.getboolean("strip_attachment_payloads")
             )
         if "output" in general_config:
             opts.output = general_config["output"]
@@ -732,19 +732,19 @@
             )
             exit(-1)
         if "save_aggregate" in general_config:
-            opts.save_aggregate = general_config.getboolean("save_aggregate")
+            opts.save_aggregate = bool(general_config.getboolean("save_aggregate"))
         if "save_forensic" in general_config:
-            opts.save_forensic = general_config.getboolean("save_forensic")
+            opts.save_forensic = bool(general_config.getboolean("save_forensic"))
         if "save_smtp_tls" in general_config:
-            opts.save_smtp_tls = general_config.getboolean("save_smtp_tls")
+            opts.save_smtp_tls = bool(general_config.getboolean("save_smtp_tls"))
         if "debug" in general_config:
-            opts.debug = general_config.getboolean("debug")
+            opts.debug = bool(general_config.getboolean("debug"))
         if "verbose" in general_config:
-            opts.verbose = general_config.getboolean("verbose")
+            opts.verbose = bool(general_config.getboolean("verbose"))
         if "silent" in general_config:
-            opts.silent = general_config.getboolean("silent")
+            opts.silent = bool(general_config.getboolean("silent"))
         if "warnings" in general_config:
-            opts.warnings = general_config.getboolean("warnings")
+            opts.warnings = bool(general_config.getboolean("warnings"))
         if "log_file" in general_config:
             opts.log_file = general_config["log_file"]
         if "n_procs" in general_config:
@@ -754,15 +754,15 @@
         else:
             opts.ip_db_path = None
         if "always_use_local_files" in general_config:
-            opts.always_use_local_files = general_config.getboolean(
-                "always_use_local_files"
+            opts.always_use_local_files = bool(
+                general_config.getboolean("always_use_local_files")
             )
         if "reverse_dns_map_path" in general_config:
             opts.reverse_dns_map_path = general_config["reverse_dns_path"]
         if "reverse_dns_map_url" in general_config:
             opts.reverse_dns_map_url = general_config["reverse_dns_url"]
         if "prettify_json" in general_config:
-            opts.prettify_json = general_config.getboolean("prettify_json")
+            opts.prettify_json = bool(general_config.getboolean("prettify_json"))
 
     if "mailbox" in config.sections():
         mailbox_config = config["mailbox"]
@@ -773,11 +773,11 @@
         if "archive_folder" in mailbox_config:
             opts.mailbox_archive_folder = mailbox_config["archive_folder"]
         if "watch" in mailbox_config:
-            opts.mailbox_watch = mailbox_config.getboolean("watch")
+            opts.mailbox_watch = bool(mailbox_config.getboolean("watch"))
         if "delete" in mailbox_config:
-            opts.mailbox_delete = mailbox_config.getboolean("delete")
+            opts.mailbox_delete = bool(mailbox_config.getboolean("delete"))
         if "test" in mailbox_config:
-            opts.mailbox_test = mailbox_config.getboolean("test")
+            opts.mailbox_test = bool(mailbox_config.getboolean("test"))
         if "batch_size" in mailbox_config:
             opts.mailbox_batch_size = mailbox_config.getint("batch_size")
         if "check_timeout" in mailbox_config:
@@ -805,10 +805,10 @@
         if "max_retries" in imap_config:
             opts.imap_max_retries = imap_config.getint("max_retries")
         if "ssl" in imap_config:
-            opts.imap_ssl = imap_config.getboolean("ssl")
+            opts.imap_ssl = bool(imap_config.getboolean("ssl"))
         if "skip_certificate_verification" in imap_config:
-            opts.imap_skip_certificate_verification = imap_config.getboolean(
-                "skip_certificate_verification"
+            opts.imap_skip_certificate_verification = bool(
+                imap_config.getboolean("skip_certificate_verification")
             )
         if "user" in imap_config:
             opts.imap_user = imap_config["user"]
@@ -837,7 +837,7 @@
                 "section instead."
             )
         if "watch" in imap_config:
-            opts.mailbox_watch = imap_config.getboolean("watch")
+            opts.mailbox_watch = bool(imap_config.getboolean("watch"))
             logger.warning(
                 "Use of the watch option in the imap "
                 "configuration section has been deprecated. "
@@ -852,7 +852,7 @@
                 "section instead."
             )
         if "test" in imap_config:
-            opts.mailbox_test = imap_config.getboolean("test")
+            opts.mailbox_test = bool(imap_config.getboolean("test"))
             logger.warning(
                 "Use of the test option in the imap "
                 "configuration section has been deprecated. "
@@ -946,8 +946,8 @@
             opts.graph_url = graph_config["graph_url"]
 
         if "allow_unencrypted_storage" in graph_config:
-            opts.graph_allow_unencrypted_storage = graph_config.getboolean(
-                "allow_unencrypted_storage"
+            opts.graph_allow_unencrypted_storage = bool(
+                graph_config.getboolean("allow_unencrypted_storage")
             )
 
     if "elasticsearch" in config:
@@ -975,10 +975,10 @@
         if "index_prefix" in elasticsearch_config:
             opts.elasticsearch_index_prefix = elasticsearch_config["index_prefix"]
         if "monthly_indexes" in elasticsearch_config:
-            monthly = elasticsearch_config.getboolean("monthly_indexes")
+            monthly = bool(elasticsearch_config.getboolean("monthly_indexes"))
             opts.elasticsearch_monthly_indexes = monthly
         if "ssl" in elasticsearch_config:
-            opts.elasticsearch_ssl = elasticsearch_config.getboolean("ssl")
+            opts.elasticsearch_ssl = bool(elasticsearch_config.getboolean("ssl"))
         if "cert_path" in elasticsearch_config:
             opts.elasticsearch_ssl_cert_path = elasticsearch_config["cert_path"]
         if "user" in elasticsearch_config:
@@ -1015,10 +1015,10 @@
         if "index_prefix" in opensearch_config:
             opts.opensearch_index_prefix = opensearch_config["index_prefix"]
         if "monthly_indexes" in opensearch_config:
-            monthly = opensearch_config.getboolean("monthly_indexes")
+            monthly = bool(opensearch_config.getboolean("monthly_indexes"))
             opts.opensearch_monthly_indexes = monthly
         if "ssl" in opensearch_config:
-            opts.opensearch_ssl = opensearch_config.getboolean("ssl")
+            opts.opensearch_ssl = bool(opensearch_config.getboolean("ssl"))
         if "cert_path" in opensearch_config:
             opts.opensearch_ssl_cert_path = opensearch_config["cert_path"]
         if "user" in opensearch_config:
@@ -1072,9 +1072,11 @@
         if "password" in kafka_config:
             opts.kafka_password = kafka_config["password"]
         if "ssl" in kafka_config:
-            opts.kafka_ssl = kafka_config.getboolean("ssl")
+            opts.kafka_ssl = bool(kafka_config.getboolean("ssl"))
         if "skip_certificate_verification" in kafka_config:
-            kafka_verify = kafka_config.getboolean("skip_certificate_verification")
+            kafka_verify = bool(
+                kafka_config.getboolean("skip_certificate_verification")
+            )
             opts.kafka_skip_certificate_verification = kafka_verify
         if "aggregate_topic" in kafka_config:
             opts.kafka_aggregate_topic = kafka_config["aggregate_topic"]
@@ -1106,9 +1108,11 @@
         if "port" in smtp_config:
             opts.smtp_port = smtp_config.getint("port")
         if "ssl" in smtp_config:
-            opts.smtp_ssl = smtp_config.getboolean("ssl")
+            opts.smtp_ssl = bool(smtp_config.getboolean("ssl"))
         if "skip_certificate_verification" in smtp_config:
-            smtp_verify = smtp_config.getboolean("skip_certificate_verification")
+            smtp_verify = bool(
+                smtp_config.getboolean("skip_certificate_verification")
+            )
             opts.smtp_skip_certificate_verification = smtp_verify
         if "user" in smtp_config:
             opts.smtp_user = smtp_config["user"]
@@ -1176,11 +1180,11 @@
         gmail_api_config = config["gmail_api"]
         opts.gmail_api_credentials_file = gmail_api_config.get("credentials_file")
         opts.gmail_api_token_file = gmail_api_config.get("token_file", ".token")
-        opts.gmail_api_include_spam_trash = gmail_api_config.getboolean(
-            "include_spam_trash", False
+        opts.gmail_api_include_spam_trash = bool(
+            gmail_api_config.getboolean("include_spam_trash", False)
         )
-        opts.gmail_api_paginate_messages = gmail_api_config.getboolean(
-            "paginate_messages", True
+        opts.gmail_api_paginate_messages = bool(
+            gmail_api_config.getboolean("paginate_messages", True)
         )
         opts.gmail_api_scopes = gmail_api_config.get(
             "scopes", default_gmail_api_scope
@@ -1194,8 +1198,8 @@
     if "maildir" in config.sections():
         maildir_api_config = config["maildir"]
         opts.maildir_path = maildir_api_config.get("maildir_path")
-        opts.maildir_create = maildir_api_config.getboolean(
-            "maildir_create", fallback=False
+        opts.maildir_create = bool(
+            maildir_api_config.getboolean("maildir_create", fallback=False)
         )
 
     if "log_analytics" in config.sections():
@@ -1274,6 +1278,7 @@
         exit(1)
 
     logger.info("Starting parsedmarc")
+
     if opts.save_aggregate or opts.save_forensic or opts.save_smtp_tls:
         try:
@@ -1513,6 +1518,11 @@
                     smtp_tls_reports.append(result[0]["report"])
 
     for mbox_path in mbox_paths:
+        normalize_timespan_threshold_hours_value = (
+            float(opts.normalize_timespan_threshold_hours)
+            if opts.normalize_timespan_threshold_hours is not None
+            else 24.0
+        )
         strip = opts.strip_attachment_payloads
         reports = get_dmarc_reports_from_mbox(
             mbox_path,
@@ -1524,13 +1534,17 @@
             reverse_dns_map_path=opts.reverse_dns_map_path,
             reverse_dns_map_url=opts.reverse_dns_map_url,
             offline=opts.offline,
-            normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
+            normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
         )
         aggregate_reports += reports["aggregate_reports"]
         forensic_reports += reports["forensic_reports"]
         smtp_tls_reports += reports["smtp_tls_reports"]
 
     mailbox_connection = None
+    mailbox_batch_size_value = 10
+    mailbox_check_timeout_value = 30
+    normalize_timespan_threshold_hours_value = 24.0
+
     if opts.imap_host:
         try:
             if opts.imap_user is None or opts.imap_password is None:
@@ -1552,9 +1566,10 @@
             imap_max_retries = (
                 int(opts.imap_max_retries) if opts.imap_max_retries is not None else 4
             )
+            imap_port_value = int(opts.imap_port) if opts.imap_port is not None else 993
             mailbox_connection = IMAPConnection(
                 host=opts.imap_host,
-                port=opts.imap_port,
+                port=imap_port_value,
                 ssl=ssl,
                 verify=verify,
                 timeout=imap_timeout,
@@ -1624,11 +1639,24 @@
             exit(1)
 
     if mailbox_connection:
+        mailbox_batch_size_value = (
+            int(opts.mailbox_batch_size) if opts.mailbox_batch_size is not None else 10
+        )
+        mailbox_check_timeout_value = (
+            int(opts.mailbox_check_timeout)
+            if opts.mailbox_check_timeout is not None
+            else 30
+        )
+        normalize_timespan_threshold_hours_value = (
+            float(opts.normalize_timespan_threshold_hours)
+            if opts.normalize_timespan_threshold_hours is not None
+            else 24.0
+        )
        try:
            reports = get_dmarc_reports_from_mailbox(
                connection=mailbox_connection,
                delete=opts.mailbox_delete,
-                batch_size=opts.mailbox_batch_size,
+                batch_size=mailbox_batch_size_value,
                reports_folder=opts.mailbox_reports_folder,
                archive_folder=opts.mailbox_archive_folder,
                ip_db_path=opts.ip_db_path,
@@ -1640,7 +1668,7 @@
                test=opts.mailbox_test,
                strip_attachment_payloads=opts.strip_attachment_payloads,
                since=opts.mailbox_since,
-                normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
+                normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
             )
 
             aggregate_reports += reports["aggregate_reports"]
@@ -1664,12 +1692,18 @@
             verify = True
             if opts.smtp_skip_certificate_verification:
                 verify = False
+            smtp_port_value = int(opts.smtp_port) if opts.smtp_port is not None else 25
+            smtp_to_value = (
+                list(opts.smtp_to)
+                if isinstance(opts.smtp_to, list)
+                else _str_to_list(str(opts.smtp_to))
+            )
             email_results(
                 results,
                 opts.smtp_host,
                 opts.smtp_from,
-                opts.smtp_to,
-                port=opts.smtp_port,
+                smtp_to_value,
+                port=smtp_port_value,
                 verify=verify,
                 username=opts.smtp_user,
                 password=opts.smtp_password,
@@ -1683,6 +1717,7 @@
 
     if mailbox_connection and opts.mailbox_watch:
         logger.info("Watching for email - Quit with ctrl-c")
+
         try:
             watch_inbox(
                 mailbox_connection=mailbox_connection,
@@ -1691,17 +1726,17 @@
                 archive_folder=opts.mailbox_archive_folder,
                 delete=opts.mailbox_delete,
                 test=opts.mailbox_test,
-                check_timeout=opts.mailbox_check_timeout,
+                check_timeout=mailbox_check_timeout_value,
                 nameservers=opts.nameservers,
                 dns_timeout=opts.dns_timeout,
                 strip_attachment_payloads=opts.strip_attachment_payloads,
-                batch_size=opts.mailbox_batch_size,
+                batch_size=mailbox_batch_size_value,
                 ip_db_path=opts.ip_db_path,
                 always_use_local_files=opts.always_use_local_files,
                 reverse_dns_map_path=opts.reverse_dns_map_path,
                 reverse_dns_map_url=opts.reverse_dns_map_url,
                 offline=opts.offline,
-                normalize_timespan_threshold_hours=opts.normalize_timespan_threshold_hours,
+                normalize_timespan_threshold_hours=normalize_timespan_threshold_hours_value,
             )
         except FileExistsError as error:
             logger.error("{0}".format(error.__str__()))
diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py
index a39b154..35929ed 100644
--- a/parsedmarc/utils.py
+++ b/parsedmarc/utils.py
@@ -17,7 +17,7 @@
 import shutil
 import subprocess
 import tempfile
 from datetime import datetime, timedelta, timezone
-from typing import Optional, Union
+from typing import Optional, TypedDict, Union, cast
 
 import mailparser
 from expiringdict import ExpiringDict
@@ -64,7 +64,24 @@ class DownloadError(RuntimeError):
     """Raised when an error occurs when downloading a file"""
 
 
-def decode_base64(data) -> bytes:
+class ReverseDNSService(TypedDict):
+    name: str
+    type: Optional[str]
+
+
+ReverseDNSMap = dict[str, ReverseDNSService]
+
+
+class IPAddressInfo(TypedDict):
+    ip_address: str
+    reverse_dns: Optional[str]
+    country: Optional[str]
+    base_domain: Optional[str]
+    name: Optional[str]
+    type: Optional[str]
+
+
+def decode_base64(data: str) -> bytes:
     """
     Decodes a base64 string, with padding being optional
 
@@ -75,11 +92,11 @@ def decode_base64(data) -> bytes:
     Returns:
         bytes: The decoded bytes
     """
-    data = bytes(data, encoding="ascii")
-    missing_padding = len(data) % 4
+    data_bytes = bytes(data, encoding="ascii")
+    missing_padding = len(data_bytes) % 4
     if missing_padding != 0:
-        data += b"=" * (4 - missing_padding)
-    return base64.b64decode(data)
+        data_bytes += b"=" * (4 - missing_padding)
+    return base64.b64decode(data_bytes)
 
 
 def get_base_domain(domain: str) -> Optional[str]:
@@ -132,9 +149,9 @@ def query_dns(
     record_type = record_type.upper()
     cache_key = "{0}_{1}".format(domain, record_type)
     if cache:
-        records = cache.get(cache_key, None)
-        if records:
-            return records
+        cached_records = cache.get(cache_key, None)
+        if isinstance(cached_records, list):
+            return cast(list[str], cached_records)
 
     resolver = dns.resolver.Resolver()
     timeout = float(timeout)
@@ -148,26 +165,12 @@ def query_dns(
     resolver.nameservers = nameservers
     resolver.timeout = timeout
     resolver.lifetime = timeout
-    if record_type == "TXT":
-        resource_records = list(
-            map(
-                lambda r: r.strings,
-                resolver.resolve(domain, record_type, lifetime=timeout),
-            )
-        )
-        _resource_record = [
-            resource_record[0][:0].join(resource_record)
-            for resource_record in resource_records
-            if resource_record
-        ]
-        records = [r.decode() for r in _resource_record]
-    else:
-        records = list(
-            map(
-                lambda r: r.to_text().replace('"', "").rstrip("."),
-                resolver.resolve(domain, record_type, lifetime=timeout),
-            )
+    records = list(
+        map(
+            lambda r: r.to_text().replace('"', "").rstrip("."),
+            resolver.resolve(domain, record_type, lifetime=timeout),
         )
+    )
     if cache:
         cache[cache_key] = records
@@ -180,7 +183,7 @@ def get_reverse_dns(
     cache: Optional[ExpiringDict] = None,
     nameservers: Optional[list[str]] = None,
     timeout: float = 2.0,
-) -> str:
+) -> Optional[str]:
     """
     Resolves an IP address to a hostname using a reverse DNS query
 
@@ -198,7 +201,7 @@ def get_reverse_dns(
     try:
         address = dns.reversename.from_address(ip_address)
         hostname = query_dns(
-            address, "PTR", cache=cache, nameservers=nameservers, timeout=timeout
+            str(address), "PTR", cache=cache, nameservers=nameservers, timeout=timeout
         )[0]
 
     except dns.exception.DNSException as e:
@@ -266,10 +269,12 @@ def human_timestamp_to_unix_timestamp(human_timestamp: str) -> int:
-        float: The converted timestamp
+        int: The converted timestamp
     """
     human_timestamp = human_timestamp.replace("T", " ")
-    return human_timestamp_to_datetime(human_timestamp).timestamp()
+    return int(human_timestamp_to_datetime(human_timestamp).timestamp())
 
 
-def get_ip_address_country(ip_address: str, *, db_path: Optional[str] = None) -> str:
+def get_ip_address_country(
+    ip_address: str, *, db_path: Optional[str] = None
+) -> Optional[str]:
     """
     Returns the ISO code for the country associated with the given
     IPv4 or IPv6 address
@@ -335,11 +340,11 @@ def get_service_from_reverse_dns_base_domain(
     base_domain,
     *,
     always_use_local_file: Optional[bool] = False,
-    local_file_path: Optional[bool] = None,
-    url: Optional[bool] = None,
+    local_file_path: Optional[str] = None,
+    url: Optional[str] = None,
     offline: Optional[bool] = False,
-    reverse_dns_map: Optional[bool] = None,
-) -> str:
+    reverse_dns_map: Optional[ReverseDNSMap] = None,
+) -> ReverseDNSService:
     """
     Returns the service name of a given base domain name from reverse DNS.
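
As a usage sketch (not part of this change set), the new `ReverseDNSService`/`ReverseDNSMap` types describe the values returned by this lookup; the domain below is only illustrative:

```python
from parsedmarc.utils import get_service_from_reverse_dns_base_domain

# offline=True skips the HTTP fetch and loads the CSV map bundled with
# parsedmarc; unknown base domains fall back to
# {"name": <base_domain>, "type": None}
service = get_service_from_reverse_dns_base_domain("example.com", offline=True)
print(service["name"], service["type"])
```
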
@@ -356,12 +361,6 @@ def get_service_from_reverse_dns_base_domain(
         the supplied reverse_dns_base_domain and the type will be None
     """
 
-    def load_csv(_csv_file):
-        reader = csv.DictReader(_csv_file)
-        for row in reader:
-            key = row["base_reverse_dns"].lower().strip()
-            reverse_dns_map[key] = dict(name=row["name"], type=row["type"])
-
     base_domain = base_domain.lower().strip()
     if url is None:
         url = (
@@ -369,11 +368,24 @@ def get_service_from_reverse_dns_base_domain(
             "/parsedmarc/master/parsedmarc/"
             "resources/maps/base_reverse_dns_map.csv"
         )
+    reverse_dns_map_value: ReverseDNSMap
     if reverse_dns_map is None:
-        reverse_dns_map = dict()
+        reverse_dns_map_value = {}
+    else:
+        reverse_dns_map_value = reverse_dns_map
+
+    def load_csv(_csv_file):
+        reader = csv.DictReader(_csv_file)
+        for row in reader:
+            key = row["base_reverse_dns"].lower().strip()
+            reverse_dns_map_value[key] = {
+                "name": row["name"],
+                "type": row["type"],
+            }
+
     csv_file = io.StringIO()
-    if not (offline or always_use_local_file) and len(reverse_dns_map) == 0:
+    if not (offline or always_use_local_file) and len(reverse_dns_map_value) == 0:
         try:
             logger.debug(f"Trying to fetch reverse DNS map from {url}...")
             headers = {"User-Agent": USER_AGENT}
@@ -390,7 +402,7 @@ def get_service_from_reverse_dns_base_domain(
             logging.debug("Response body:")
             logger.debug(csv_file.read())
 
-    if len(reverse_dns_map) == 0:
+    if len(reverse_dns_map_value) == 0:
         logger.info("Loading included reverse DNS map...")
         path = str(
             files(parsedmarc.resources.maps).joinpath("base_reverse_dns_map.csv")
@@ -399,10 +411,11 @@ def get_service_from_reverse_dns_base_domain(
         )
         if local_file_path is not None:
             path = local_file_path
         with open(path) as csv_file:
             load_csv(csv_file)
+    service: ReverseDNSService
     try:
-        service = reverse_dns_map[base_domain]
+        service = reverse_dns_map_value[base_domain]
     except KeyError:
-        service = dict(name=base_domain, type=None)
+        service = {"name": base_domain, "type": None}
     return service
 
 
@@ -415,11 +428,11 @@ def get_ip_address_info(
     ip_address: str,
     *,
     ip_db_path: Optional[str] = None,
     always_use_local_files: Optional[bool] = False,
     reverse_dns_map_url: Optional[str] = None,
     cache: Optional[ExpiringDict] = None,
-    reverse_dns_map: Optional[dict] = None,
+    reverse_dns_map: Optional[ReverseDNSMap] = None,
     offline: Optional[bool] = False,
     nameservers: Optional[list[str]] = None,
-    timeout: Optional[float] = 2.0,
-) -> dict[str, str]:
+    timeout: float = 2.0,
+) -> IPAddressInfo:
     """
     Returns reverse DNS and country information for the given
     IP address
@@ -442,12 +455,22 @@ def get_ip_address_info(
     """
     ip_address = ip_address.lower()
     if cache is not None:
-        info = cache.get(ip_address, None)
-        if info:
+        cached_info = cache.get(ip_address, None)
+        if (
+            cached_info
+            and isinstance(cached_info, dict)
+            and "ip_address" in cached_info
+        ):
             logger.debug(f"IP address {ip_address} was found in cache")
-            return info
-    info: dict[str, str] = {}
-    info["ip_address"] = ip_address
+            return cast(IPAddressInfo, cached_info)
+    info: IPAddressInfo = {
+        "ip_address": ip_address,
+        "reverse_dns": None,
+        "country": None,
+        "base_domain": None,
+        "name": None,
+        "type": None,
+    }
     if offline:
         reverse_dns = None
     else:
@@ -457,9 +480,6 @@ def get_ip_address_info(
     country = get_ip_address_country(ip_address, db_path=ip_db_path)
     info["country"] = country
     info["reverse_dns"] = reverse_dns
-    info["base_domain"] = None
-    info["name"] = None
-    info["type"] = None
     if reverse_dns is not None:
         base_domain = get_base_domain(reverse_dns)
         if base_domain is not None:
@@ -484,7 +504,7 @@ def get_ip_address_info(
     return info
 
 
-def parse_email_address(original_address: str) -> dict[str, str]:
+def parse_email_address(original_address: str) -> dict[str, Optional[str]]:
     if original_address[0] == "":
         display_name = None
     else:
@@ -563,7 +583,7 @@ def is_outlook_msg(content) -> bool:
     )
 
 
-def convert_outlook_msg(msg_bytes: bytes) -> str:
+def convert_outlook_msg(msg_bytes: bytes) -> bytes:
     """
     Uses the ``msgconvert`` Perl utility to convert an Outlook
     MS file to standard RFC 822 format
@@ -572,7 +592,7 @@ def convert_outlook_msg(msg_bytes: bytes) -> str:
     Args:
         msg_bytes (bytes): the content of the .msg file
 
     Returns:
-        A RFC 822 string
+        An RFC 822 bytes payload
     """
     if not is_outlook_msg(msg_bytes):
         raise ValueError("The supplied bytes are not an Outlook MSG file")
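
Since `convert_outlook_msg` now returns bytes rather than `str`, callers need to decode the result themselves. A minimal sketch of the calling pattern used in `parse_report_email` (the helper name here is illustrative):

```python
from parsedmarc.utils import convert_outlook_msg, is_outlook_msg


def to_rfc822_text(raw: bytes) -> str:
    """Illustrative helper: normalize raw message bytes to an RFC 822 string."""
    if is_outlook_msg(raw):
        # msgconvert output now arrives as bytes
        raw = convert_outlook_msg(raw)
    return raw.decode("utf8", errors="replace")
```
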