Enhance type hints and argument formatting in utils.py for improved clarity and consistency

This commit is contained in:
Sean Whalen
2025-12-02 16:21:30 -05:00
parent 1127f65fbb
commit 888d717476

View File

@@ -1,11 +1,18 @@
# -*- coding: utf-8 -*-
"""Utility functions that might be useful for other projects"""
from __future__ import annotations
from typing import Optional, Union
import logging
import os
from datetime import datetime
from datetime import timezone
from datetime import timedelta
from collections import OrderedDict
from expiringdict import ExpiringDict
import tempfile
import subprocess
import shutil
@@ -60,12 +67,12 @@ class DownloadError(RuntimeError):
"""Raised when an error occurs when downloading a file"""
def decode_base64(data):
def decode_base64(data: str) -> bytes:
"""
Decodes a base64 string, with padding being optional
Args:
data: A base64 encoded string
data (str): A base64 encoded string
Returns:
bytes: The decoded bytes
@@ -78,7 +85,7 @@ def decode_base64(data):
return base64.b64decode(data)
def get_base_domain(domain):
def get_base_domain(domain: str) -> str:
"""
Gets the base domain name for the given domain
@@ -102,7 +109,12 @@ def get_base_domain(domain):
return publicsuffix
def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0):
def query_dns(domain: str,
record_type: str,
*,
cache: Optional[ExpiringDict] = None,
nameservers: list[str] = None,
timeout:int = 2.0) -> list[str]:
"""
Queries DNS
@@ -163,7 +175,11 @@ def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0):
return records
def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0):
def get_reverse_dns(ip_address,
*,
cache: Optional[ExpiringDict] = None,
nameservers: list[str] = None,
timeout:int = 2.0) -> str:
"""
Resolves an IP address to a hostname using a reverse DNS query
@@ -191,7 +207,7 @@ def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0):
return hostname
def timestamp_to_datetime(timestamp):
def timestamp_to_datetime(timestamp: int) -> datetime:
"""
Converts a UNIX/DMARC timestamp to a Python ``datetime`` object
@@ -204,7 +220,7 @@ def timestamp_to_datetime(timestamp):
return datetime.fromtimestamp(int(timestamp))
def timestamp_to_human(timestamp):
def timestamp_to_human(timestamp: int) -> str:
"""
Converts a UNIX/DMARC timestamp to a human-readable string
@@ -217,7 +233,9 @@ def timestamp_to_human(timestamp):
return timestamp_to_datetime(timestamp).strftime("%Y-%m-%d %H:%M:%S")
def human_timestamp_to_datetime(human_timestamp, to_utc=False):
def human_timestamp_to_datetime(human_timestamp: str,
*,
to_utc: Optional[bool] = False) -> datetime:
"""
Converts a human-readable timestamp into a Python ``datetime`` object
@@ -236,7 +254,7 @@ def human_timestamp_to_datetime(human_timestamp, to_utc=False):
return dt.astimezone(timezone.utc) if to_utc else dt
def human_timestamp_to_unix_timestamp(human_timestamp):
def human_timestamp_to_unix_timestamp(human_timestamp: str) -> int:
"""
Converts a human-readable timestamp into a UNIX timestamp
@@ -250,7 +268,9 @@ def human_timestamp_to_unix_timestamp(human_timestamp):
return human_timestamp_to_datetime(human_timestamp).timestamp()
def get_ip_address_country(ip_address, db_path=None):
def get_ip_address_country(ip_address:str,
*,
db_path: Optional[str ] = None) -> str:
"""
Returns the ISO code for the country associated
with the given IPv4 or IPv6 address
@@ -314,12 +334,13 @@ def get_ip_address_country(ip_address, db_path=None):
def get_service_from_reverse_dns_base_domain(
base_domain,
always_use_local_file=False,
local_file_path=None,
url=None,
offline=False,
reverse_dns_map=None,
):
*,
always_use_local_file: Optional[bool] = False,
local_file_path: Optional[bool] = None,
url: Optional[bool] = None,
offline: Optional[bool ] = False,
reverse_dns_map: Optional[bool] = None,
) -> str:
"""
Returns the service name of a given base domain name from reverse DNS.
@@ -389,16 +410,17 @@ def get_service_from_reverse_dns_base_domain(
def get_ip_address_info(
ip_address,
ip_db_path=None,
reverse_dns_map_path=None,
always_use_local_files=False,
reverse_dns_map_url=None,
cache=None,
reverse_dns_map=None,
offline=False,
nameservers=None,
timeout=2.0,
):
*,
ip_db_path:Optional[str]=None,
reverse_dns_map_path:Optional[str]=None,
always_use_local_files:Optional[bool]=False,
reverse_dns_map_url: Optional[bool ] = None,
cache: Optional[ExpiringDict]=None,
reverse_dns_map: Optional[bool] = None,
offline: Optional[bool]=False,
nameservers:Optional[list[str]]=None,
timeout:Optional[float]=2.0,
) -> OrderedDict[str, str]:
"""
Returns reverse DNS and country information for the given IP address
@@ -416,7 +438,7 @@ def get_ip_address_info(
timeout (float): Sets the DNS timeout in seconds
Returns:
OrderedDict: ``ip_address``, ``reverse_dns``
OrderedDict: ``ip_address``, ``reverse_dns``, ``country``
"""
ip_address = ip_address.lower()
@@ -463,7 +485,7 @@ def get_ip_address_info(
return info
def parse_email_address(original_address):
def parse_email_address(original_address: str) -> OrderedDict[str,str]:
if original_address[0] == "":
display_name = None
else:
@@ -486,7 +508,7 @@ def parse_email_address(original_address):
)
def get_filename_safe_string(string):
def get_filename_safe_string(string: str) -> str:
"""
Converts a string to a string that is safe for a filename
@@ -508,7 +530,7 @@ def get_filename_safe_string(string):
return string
def is_mbox(path):
def is_mbox(path: str) -> bool:
"""
Checks if the given content is an MBOX mailbox file
@@ -529,7 +551,7 @@ def is_mbox(path):
return _is_mbox
def is_outlook_msg(content):
def is_outlook_msg(content) -> bool:
"""
Checks if the given content is an Outlook msg OLE/MSG file
@@ -544,7 +566,7 @@ def is_outlook_msg(content):
)
def convert_outlook_msg(msg_bytes):
def convert_outlook_msg(msg_bytes: bytes) -> str:
"""
Uses the ``msgconvert`` Perl utility to convert an Outlook MS file to
standard RFC 822 format
@@ -580,7 +602,9 @@ def convert_outlook_msg(msg_bytes):
return rfc822
def parse_email(data, strip_attachment_payloads=False):
def parse_email(data: Union[bytes, str],
*,
strip_attachment_payloads: Optional[bool]=False):
"""
A simplified email parser