diff --git a/docs/Makefile b/docs/Makefile index 424bf6c..5564a63 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line. SPHINXOPTS = -SPHINXBUILD = python -msphinx +SPHINXBUILD = python3 -msphinx SPHINXPROJ = parsedmarc SOURCEDIR = . BUILDDIR = _build diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index 5bc909f..9c369c1 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -27,6 +27,7 @@ import smtplib from ssl import SSLError, CertificateError, create_default_context import time +from expiringdict import ExpiringDict import xmltodict import imapclient import imapclient.exceptions @@ -50,7 +51,7 @@ MAGIC_ZIP = b"\x50\x4B\x03\x04" MAGIC_GZIP = b"\x1F\x8B" MAGIC_XML = b"\x3c\x3f\x78\x6d\x6c\x20" -DNS_CACHE = dict() +IP_ADDRESS_CACHE = ExpiringDict(max_len=10000, max_age_seconds=1800) class ParserError(RuntimeError): @@ -95,12 +96,10 @@ def _parse_report_record(record, nameservers=None, timeout=2.0): nameservers = ["8.8.8.8", "4.4.4.4"] record = record.copy() new_record = OrderedDict() - new_record_source = DNS_CACHE.get(record["row"]["source_ip"], None) - if not new_record_source: - new_record_source = get_ip_address_info(record["row"]["source_ip"], - nameservers=nameservers, - timeout=timeout) - DNS_CACHE[record["row"]["source_ip"]] = new_record_source + new_record_source = get_ip_address_info(record["row"]["source_ip"], + cache=IP_ADDRESS_CACHE, + nameservers=nameservers, + timeout=timeout) new_record["source"] = new_record_source new_record["count"] = int(record["row"]["count"]) policy_evaluated = record["row"]["policy_evaluated"].copy() @@ -542,12 +541,9 @@ def parse_forensic_report(feedback_report, sample, msg_date, parsed_report["arrival_date_utc"] = arrival_utc ip_address = parsed_report["source_ip"] - parsed_report_source = DNS_CACHE.get(ip_address, None) - if not parsed_report_source: - parsed_report_source = get_ip_address_info(ip_address, - nameservers=nameservers, - timeout=timeout) - DNS_CACHE[ip_address] = parsed_report_source + parsed_report_source = get_ip_address_info(ip_address, + nameservers=nameservers, + timeout=timeout) parsed_report["source"] = parsed_report_source del parsed_report["source_ip"] diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py index 0f5bbe7..a3870bf 100644 --- a/parsedmarc/utils.py +++ b/parsedmarc/utils.py @@ -104,13 +104,14 @@ def get_base_domain(domain): return psl.get_public_suffix(domain) -def query_dns(domain, record_type, nameservers=None, timeout=2.0): +def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0): """ Queries DNS Args: domain (str): The domain or subdomain to query about record_type (str): The record type to query for + cache (ExpiringDict): Cache storage nameservers (list): A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) timeout (float): Sets the DNS timeout in seconds @@ -118,6 +119,14 @@ def query_dns(domain, record_type, nameservers=None, timeout=2.0): Returns: list: A list of answers """ + domain = str(domain).lower() + record_type = record_type.upper() + cache_key = "{0}_{1}".format(domain, record_type) + if cache: + records = cache.get(cache_key, None) + if records: + return records + resolver = dns.resolver.Resolver() timeout = float(timeout) if nameservers is None: @@ -134,19 +143,24 @@ def query_dns(domain, record_type, nameservers=None, timeout=2.0): _resource_record = [ resource_record[0][:0].join(resource_record) for resource_record in resource_records if resource_record] - return [r.decode() for r in _resource_record] + records = [r.decode() for r in _resource_record] else: - return list(map( + records = list(map( lambda r: r.to_text().replace('"', '').rstrip("."), resolver.query(domain, record_type, tcp=True))) + if cache: + cache[cache_key] = records + + return records -def get_reverse_dns(ip_address, nameservers=None, timeout=2.0): +def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0): """ Resolves an IP address to a hostname using a reverse DNS query Args: ip_address (str): The IP address to resolve + cache (ExpiringDict): Cache storage nameservers (list): A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) timeout (float): Sets the DNS query timeout in seconds @@ -157,7 +171,7 @@ def get_reverse_dns(ip_address, nameservers=None, timeout=2.0): hostname = None try: address = dns.reversename.from_address(ip_address) - hostname = query_dns(address, "PTR", + hostname = query_dns(address, "PTR", cache=cache, nameservers=nameservers, timeout=timeout)[0] @@ -290,12 +304,13 @@ def get_ip_address_country(ip_address): return country -def get_ip_address_info(ip_address, nameservers=None, timeout=2.0): +def get_ip_address_info(ip_address, cache=None, nameservers=None, timeout=2.0): """ Returns reverse DNS and country information for the given IP address Args: ip_address (str): The IP address to check + cache (ExpiringDict): Cache storage nameservers (list): A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) timeout (float): Sets the DNS timeout in seconds @@ -305,6 +320,10 @@ def get_ip_address_info(ip_address, nameservers=None, timeout=2.0): """ ip_address = ip_address.lower() + if cache: + info = cache.get(ip_address, None) + if info: + return info info = OrderedDict() info["ip_address"] = ip_address reverse_dns = get_reverse_dns(ip_address, diff --git a/requirements.txt b/requirements.txt index 6791c47..5e6d1e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ dnspython +expiringdict urllib3<1.24,>=1.21.1 requests publicsuffix @@ -12,7 +13,7 @@ elasticsearch>=6.3.0,<7.0.0 elasticsearch-dsl>=6.2.1,<7.0.0 kafka-python flake8 -sphinx==1.7.9 +sphinx sphinx_rtd_theme wheel -rstcheck +rstcheck>=3.3.1