diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py index 24f6dd2..7b8d211 100644 --- a/parsedmarc/utils.py +++ b/parsedmarc/utils.py @@ -1,11 +1,18 @@ +# -*- coding: utf-8 -*- + """Utility functions that might be useful for other projects""" +from __future__ import annotations + +from typing import Optional, Union + import logging import os from datetime import datetime from datetime import timezone from datetime import timedelta from collections import OrderedDict +from expiringdict import ExpiringDict import tempfile import subprocess import shutil @@ -60,12 +67,12 @@ class DownloadError(RuntimeError): """Raised when an error occurs when downloading a file""" -def decode_base64(data): +def decode_base64(data: str) -> bytes: """ Decodes a base64 string, with padding being optional Args: - data: A base64 encoded string + data (str): A base64 encoded string Returns: bytes: The decoded bytes @@ -78,7 +85,7 @@ def decode_base64(data): return base64.b64decode(data) -def get_base_domain(domain): +def get_base_domain(domain: str) -> str: """ Gets the base domain name for the given domain @@ -102,7 +109,12 @@ def get_base_domain(domain): return publicsuffix -def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0): +def query_dns(domain: str, + record_type: str, + *, + cache: Optional[ExpiringDict] = None, + nameservers: list[str] = None, + timeout:int = 2.0) -> list[str]: """ Queries DNS @@ -163,7 +175,11 @@ def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0): return records -def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0): +def get_reverse_dns(ip_address, + *, + cache: Optional[ExpiringDict] = None, + nameservers: list[str] = None, + timeout:int = 2.0) -> str: """ Resolves an IP address to a hostname using a reverse DNS query @@ -191,7 +207,7 @@ def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0): return hostname -def timestamp_to_datetime(timestamp): +def timestamp_to_datetime(timestamp: int) -> datetime: """ Converts a UNIX/DMARC timestamp to a Python ``datetime`` object @@ -204,7 +220,7 @@ def timestamp_to_datetime(timestamp): return datetime.fromtimestamp(int(timestamp)) -def timestamp_to_human(timestamp): +def timestamp_to_human(timestamp: int) -> str: """ Converts a UNIX/DMARC timestamp to a human-readable string @@ -217,7 +233,9 @@ def timestamp_to_human(timestamp): return timestamp_to_datetime(timestamp).strftime("%Y-%m-%d %H:%M:%S") -def human_timestamp_to_datetime(human_timestamp, to_utc=False): +def human_timestamp_to_datetime(human_timestamp: str, + *, + to_utc: Optional[bool] = False) -> datetime: """ Converts a human-readable timestamp into a Python ``datetime`` object @@ -236,7 +254,7 @@ def human_timestamp_to_datetime(human_timestamp, to_utc=False): return dt.astimezone(timezone.utc) if to_utc else dt -def human_timestamp_to_unix_timestamp(human_timestamp): +def human_timestamp_to_unix_timestamp(human_timestamp: str) -> int: """ Converts a human-readable timestamp into a UNIX timestamp @@ -250,7 +268,9 @@ def human_timestamp_to_unix_timestamp(human_timestamp): return human_timestamp_to_datetime(human_timestamp).timestamp() -def get_ip_address_country(ip_address, db_path=None): +def get_ip_address_country(ip_address:str, + *, + db_path: Optional[str ] = None) -> str: """ Returns the ISO code for the country associated with the given IPv4 or IPv6 address @@ -314,12 +334,13 @@ def get_ip_address_country(ip_address, db_path=None): def get_service_from_reverse_dns_base_domain( base_domain, - always_use_local_file=False, - local_file_path=None, - url=None, - offline=False, - reverse_dns_map=None, -): + *, + always_use_local_file: Optional[bool] = False, + local_file_path: Optional[bool] = None, + url: Optional[bool] = None, + offline: Optional[bool ] = False, + reverse_dns_map: Optional[bool] = None, +) -> str: """ Returns the service name of a given base domain name from reverse DNS. @@ -389,16 +410,17 @@ def get_service_from_reverse_dns_base_domain( def get_ip_address_info( ip_address, - ip_db_path=None, - reverse_dns_map_path=None, - always_use_local_files=False, - reverse_dns_map_url=None, - cache=None, - reverse_dns_map=None, - offline=False, - nameservers=None, - timeout=2.0, -): + *, + ip_db_path:Optional[str]=None, + reverse_dns_map_path:Optional[str]=None, + always_use_local_files:Optional[bool]=False, + reverse_dns_map_url: Optional[bool ] = None, + cache: Optional[ExpiringDict]=None, + reverse_dns_map: Optional[bool] = None, + offline: Optional[bool]=False, + nameservers:Optional[list[str]]=None, + timeout:Optional[float]=2.0, +) -> OrderedDict[str, str]: """ Returns reverse DNS and country information for the given IP address @@ -416,7 +438,7 @@ def get_ip_address_info( timeout (float): Sets the DNS timeout in seconds Returns: - OrderedDict: ``ip_address``, ``reverse_dns`` + OrderedDict: ``ip_address``, ``reverse_dns``, ``country`` """ ip_address = ip_address.lower() @@ -463,7 +485,7 @@ def get_ip_address_info( return info -def parse_email_address(original_address): +def parse_email_address(original_address: str) -> OrderedDict[str,str]: if original_address[0] == "": display_name = None else: @@ -486,7 +508,7 @@ def parse_email_address(original_address): ) -def get_filename_safe_string(string): +def get_filename_safe_string(string: str) -> str: """ Converts a string to a string that is safe for a filename @@ -508,7 +530,7 @@ def get_filename_safe_string(string): return string -def is_mbox(path): +def is_mbox(path: str) -> bool: """ Checks if the given content is an MBOX mailbox file @@ -529,7 +551,7 @@ def is_mbox(path): return _is_mbox -def is_outlook_msg(content): +def is_outlook_msg(content) -> bool: """ Checks if the given content is an Outlook msg OLE/MSG file @@ -544,7 +566,7 @@ def is_outlook_msg(content): ) -def convert_outlook_msg(msg_bytes): +def convert_outlook_msg(msg_bytes: bytes) -> str: """ Uses the ``msgconvert`` Perl utility to convert an Outlook MS file to standard RFC 822 format @@ -580,7 +602,9 @@ def convert_outlook_msg(msg_bytes): return rfc822 -def parse_email(data, strip_attachment_payloads=False): +def parse_email(data: Union[bytes, str], + *, + strip_attachment_payloads: Optional[bool]=False): """ A simplified email parser