mirror of
https://github.com/domainaware/parsedmarc.git
synced 2026-02-17 07:03:58 +00:00
Enhance type hints and argument formatting in utils.py for improved clarity and consistency
This commit is contained in:
@@ -1,11 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Utility functions that might be useful for other projects"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from datetime import timedelta
|
||||
from collections import OrderedDict
|
||||
from expiringdict import ExpiringDict
|
||||
import tempfile
|
||||
import subprocess
|
||||
import shutil
|
||||
@@ -60,12 +67,12 @@ class DownloadError(RuntimeError):
|
||||
"""Raised when an error occurs when downloading a file"""
|
||||
|
||||
|
||||
def decode_base64(data):
|
||||
def decode_base64(data: str) -> bytes:
|
||||
"""
|
||||
Decodes a base64 string, with padding being optional
|
||||
|
||||
Args:
|
||||
data: A base64 encoded string
|
||||
data (str): A base64 encoded string
|
||||
|
||||
Returns:
|
||||
bytes: The decoded bytes
|
||||
@@ -78,7 +85,7 @@ def decode_base64(data):
|
||||
return base64.b64decode(data)
|
||||
|
||||
|
||||
def get_base_domain(domain):
|
||||
def get_base_domain(domain: str) -> str:
|
||||
"""
|
||||
Gets the base domain name for the given domain
|
||||
|
||||
@@ -102,7 +109,12 @@ def get_base_domain(domain):
|
||||
return publicsuffix
|
||||
|
||||
|
||||
def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0):
|
||||
def query_dns(domain: str,
|
||||
record_type: str,
|
||||
*,
|
||||
cache: Optional[ExpiringDict] = None,
|
||||
nameservers: list[str] = None,
|
||||
timeout:int = 2.0) -> list[str]:
|
||||
"""
|
||||
Queries DNS
|
||||
|
||||
@@ -163,7 +175,11 @@ def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0):
|
||||
return records
|
||||
|
||||
|
||||
def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0):
|
||||
def get_reverse_dns(ip_address,
|
||||
*,
|
||||
cache: Optional[ExpiringDict] = None,
|
||||
nameservers: list[str] = None,
|
||||
timeout:int = 2.0) -> str:
|
||||
"""
|
||||
Resolves an IP address to a hostname using a reverse DNS query
|
||||
|
||||
@@ -191,7 +207,7 @@ def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0):
|
||||
return hostname
|
||||
|
||||
|
||||
def timestamp_to_datetime(timestamp):
|
||||
def timestamp_to_datetime(timestamp: int) -> datetime:
|
||||
"""
|
||||
Converts a UNIX/DMARC timestamp to a Python ``datetime`` object
|
||||
|
||||
@@ -204,7 +220,7 @@ def timestamp_to_datetime(timestamp):
|
||||
return datetime.fromtimestamp(int(timestamp))
|
||||
|
||||
|
||||
def timestamp_to_human(timestamp):
|
||||
def timestamp_to_human(timestamp: int) -> str:
|
||||
"""
|
||||
Converts a UNIX/DMARC timestamp to a human-readable string
|
||||
|
||||
@@ -217,7 +233,9 @@ def timestamp_to_human(timestamp):
|
||||
return timestamp_to_datetime(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def human_timestamp_to_datetime(human_timestamp, to_utc=False):
|
||||
def human_timestamp_to_datetime(human_timestamp: str,
|
||||
*,
|
||||
to_utc: Optional[bool] = False) -> datetime:
|
||||
"""
|
||||
Converts a human-readable timestamp into a Python ``datetime`` object
|
||||
|
||||
@@ -236,7 +254,7 @@ def human_timestamp_to_datetime(human_timestamp, to_utc=False):
|
||||
return dt.astimezone(timezone.utc) if to_utc else dt
|
||||
|
||||
|
||||
def human_timestamp_to_unix_timestamp(human_timestamp):
|
||||
def human_timestamp_to_unix_timestamp(human_timestamp: str) -> int:
|
||||
"""
|
||||
Converts a human-readable timestamp into a UNIX timestamp
|
||||
|
||||
@@ -250,7 +268,9 @@ def human_timestamp_to_unix_timestamp(human_timestamp):
|
||||
return human_timestamp_to_datetime(human_timestamp).timestamp()
|
||||
|
||||
|
||||
def get_ip_address_country(ip_address, db_path=None):
|
||||
def get_ip_address_country(ip_address:str,
|
||||
*,
|
||||
db_path: Optional[str ] = None) -> str:
|
||||
"""
|
||||
Returns the ISO code for the country associated
|
||||
with the given IPv4 or IPv6 address
|
||||
@@ -314,12 +334,13 @@ def get_ip_address_country(ip_address, db_path=None):
|
||||
|
||||
def get_service_from_reverse_dns_base_domain(
|
||||
base_domain,
|
||||
always_use_local_file=False,
|
||||
local_file_path=None,
|
||||
url=None,
|
||||
offline=False,
|
||||
reverse_dns_map=None,
|
||||
):
|
||||
*,
|
||||
always_use_local_file: Optional[bool] = False,
|
||||
local_file_path: Optional[bool] = None,
|
||||
url: Optional[bool] = None,
|
||||
offline: Optional[bool ] = False,
|
||||
reverse_dns_map: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Returns the service name of a given base domain name from reverse DNS.
|
||||
|
||||
@@ -389,16 +410,17 @@ def get_service_from_reverse_dns_base_domain(
|
||||
|
||||
def get_ip_address_info(
|
||||
ip_address,
|
||||
ip_db_path=None,
|
||||
reverse_dns_map_path=None,
|
||||
always_use_local_files=False,
|
||||
reverse_dns_map_url=None,
|
||||
cache=None,
|
||||
reverse_dns_map=None,
|
||||
offline=False,
|
||||
nameservers=None,
|
||||
timeout=2.0,
|
||||
):
|
||||
*,
|
||||
ip_db_path:Optional[str]=None,
|
||||
reverse_dns_map_path:Optional[str]=None,
|
||||
always_use_local_files:Optional[bool]=False,
|
||||
reverse_dns_map_url: Optional[bool ] = None,
|
||||
cache: Optional[ExpiringDict]=None,
|
||||
reverse_dns_map: Optional[bool] = None,
|
||||
offline: Optional[bool]=False,
|
||||
nameservers:Optional[list[str]]=None,
|
||||
timeout:Optional[float]=2.0,
|
||||
) -> OrderedDict[str, str]:
|
||||
"""
|
||||
Returns reverse DNS and country information for the given IP address
|
||||
|
||||
@@ -416,7 +438,7 @@ def get_ip_address_info(
|
||||
timeout (float): Sets the DNS timeout in seconds
|
||||
|
||||
Returns:
|
||||
OrderedDict: ``ip_address``, ``reverse_dns``
|
||||
OrderedDict: ``ip_address``, ``reverse_dns``, ``country``
|
||||
|
||||
"""
|
||||
ip_address = ip_address.lower()
|
||||
@@ -463,7 +485,7 @@ def get_ip_address_info(
|
||||
return info
|
||||
|
||||
|
||||
def parse_email_address(original_address):
|
||||
def parse_email_address(original_address: str) -> OrderedDict[str,str]:
|
||||
if original_address[0] == "":
|
||||
display_name = None
|
||||
else:
|
||||
@@ -486,7 +508,7 @@ def parse_email_address(original_address):
|
||||
)
|
||||
|
||||
|
||||
def get_filename_safe_string(string):
|
||||
def get_filename_safe_string(string: str) -> str:
|
||||
"""
|
||||
Converts a string to a string that is safe for a filename
|
||||
|
||||
@@ -508,7 +530,7 @@ def get_filename_safe_string(string):
|
||||
return string
|
||||
|
||||
|
||||
def is_mbox(path):
|
||||
def is_mbox(path: str) -> bool:
|
||||
"""
|
||||
Checks if the file at the given path is an MBOX mailbox file
|
||||
|
||||
@@ -529,7 +551,7 @@ def is_mbox(path):
|
||||
return _is_mbox
|
||||
|
||||
|
||||
def is_outlook_msg(content):
|
||||
def is_outlook_msg(content) -> bool:
|
||||
"""
|
||||
Checks if the given content is an Outlook msg OLE/MSG file
|
||||
|
||||
@@ -544,7 +566,7 @@ def is_outlook_msg(content):
|
||||
)
|
||||
|
||||
|
||||
def convert_outlook_msg(msg_bytes):
|
||||
def convert_outlook_msg(msg_bytes: bytes) -> str:
|
||||
"""
|
||||
Uses the ``msgconvert`` Perl utility to convert an Outlook MSG file to
|
||||
standard RFC 822 format
|
||||
@@ -580,7 +602,9 @@ def convert_outlook_msg(msg_bytes):
|
||||
return rfc822
|
||||
|
||||
|
||||
def parse_email(data, strip_attachment_payloads=False):
|
||||
def parse_email(data: Union[bytes, str],
|
||||
*,
|
||||
strip_attachment_payloads: Optional[bool]=False):
|
||||
"""
|
||||
A simplified email parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user