Caching impovements

This commit is contained in:
Sean Whalen
2018-11-15 13:17:09 -05:00
parent 0ddc904c9d
commit 13e2b50671
4 changed files with 38 additions and 22 deletions
+1 -1
View File
@@ -3,7 +3,7 @@
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = python -msphinx
SPHINXBUILD = python3 -msphinx
SPHINXPROJ = parsedmarc
SOURCEDIR = .
BUILDDIR = _build
+9 -13
View File
@@ -27,6 +27,7 @@ import smtplib
from ssl import SSLError, CertificateError, create_default_context
import time
from expiringdict import ExpiringDict
import xmltodict
import imapclient
import imapclient.exceptions
@@ -50,7 +51,7 @@ MAGIC_ZIP = b"\x50\x4B\x03\x04"
MAGIC_GZIP = b"\x1F\x8B"
MAGIC_XML = b"\x3c\x3f\x78\x6d\x6c\x20"
DNS_CACHE = dict()
IP_ADDRESS_CACHE = ExpiringDict(max_len=10000, max_age_seconds=1800)
class ParserError(RuntimeError):
@@ -95,12 +96,10 @@ def _parse_report_record(record, nameservers=None, timeout=2.0):
nameservers = ["8.8.8.8", "4.4.4.4"]
record = record.copy()
new_record = OrderedDict()
new_record_source = DNS_CACHE.get(record["row"]["source_ip"], None)
if not new_record_source:
new_record_source = get_ip_address_info(record["row"]["source_ip"],
nameservers=nameservers,
timeout=timeout)
DNS_CACHE[record["row"]["source_ip"]] = new_record_source
new_record_source = get_ip_address_info(record["row"]["source_ip"],
cache=IP_ADDRESS_CACHE,
nameservers=nameservers,
timeout=timeout)
new_record["source"] = new_record_source
new_record["count"] = int(record["row"]["count"])
policy_evaluated = record["row"]["policy_evaluated"].copy()
@@ -542,12 +541,9 @@ def parse_forensic_report(feedback_report, sample, msg_date,
parsed_report["arrival_date_utc"] = arrival_utc
ip_address = parsed_report["source_ip"]
parsed_report_source = DNS_CACHE.get(ip_address, None)
if not parsed_report_source:
parsed_report_source = get_ip_address_info(ip_address,
nameservers=nameservers,
timeout=timeout)
DNS_CACHE[ip_address] = parsed_report_source
parsed_report_source = get_ip_address_info(ip_address,
nameservers=nameservers,
timeout=timeout)
parsed_report["source"] = parsed_report_source
del parsed_report["source_ip"]
+25 -6
View File
@@ -104,13 +104,14 @@ def get_base_domain(domain):
return psl.get_public_suffix(domain)
def query_dns(domain, record_type, nameservers=None, timeout=2.0):
def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0):
"""
Queries DNS
Args:
domain (str): The domain or subdomain to query about
record_type (str): The record type to query for
cache (ExpiringDict): Cache storage
nameservers (list): A list of one or more nameservers to use
(Cloudflare's public DNS resolvers by default)
timeout (float): Sets the DNS timeout in seconds
@@ -118,6 +119,14 @@ def query_dns(domain, record_type, nameservers=None, timeout=2.0):
Returns:
list: A list of answers
"""
domain = str(domain).lower()
record_type = record_type.upper()
cache_key = "{0}_{1}".format(domain, record_type)
if cache:
records = cache.get(cache_key, None)
if records:
return records
resolver = dns.resolver.Resolver()
timeout = float(timeout)
if nameservers is None:
@@ -134,19 +143,24 @@ def query_dns(domain, record_type, nameservers=None, timeout=2.0):
_resource_record = [
resource_record[0][:0].join(resource_record)
for resource_record in resource_records if resource_record]
return [r.decode() for r in _resource_record]
records = [r.decode() for r in _resource_record]
else:
return list(map(
records = list(map(
lambda r: r.to_text().replace('"', '').rstrip("."),
resolver.query(domain, record_type, tcp=True)))
if cache:
cache[cache_key] = records
return records
def get_reverse_dns(ip_address, nameservers=None, timeout=2.0):
def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0):
"""
Resolves an IP address to a hostname using a reverse DNS query
Args:
ip_address (str): The IP address to resolve
cache (ExpiringDict): Cache storage
nameservers (list): A list of one or more nameservers to use
(Cloudflare's public DNS resolvers by default)
timeout (float): Sets the DNS query timeout in seconds
@@ -157,7 +171,7 @@ def get_reverse_dns(ip_address, nameservers=None, timeout=2.0):
hostname = None
try:
address = dns.reversename.from_address(ip_address)
hostname = query_dns(address, "PTR",
hostname = query_dns(address, "PTR", cache=cache,
nameservers=nameservers,
timeout=timeout)[0]
@@ -290,12 +304,13 @@ def get_ip_address_country(ip_address):
return country
def get_ip_address_info(ip_address, nameservers=None, timeout=2.0):
def get_ip_address_info(ip_address, cache=None, nameservers=None, timeout=2.0):
"""
Returns reverse DNS and country information for the given IP address
Args:
ip_address (str): The IP address to check
cache (ExpiringDict): Cache storage
nameservers (list): A list of one or more nameservers to use
(Cloudflare's public DNS resolvers by default)
timeout (float): Sets the DNS timeout in seconds
@@ -305,6 +320,10 @@ def get_ip_address_info(ip_address, nameservers=None, timeout=2.0):
"""
ip_address = ip_address.lower()
if cache:
info = cache.get(ip_address, None)
if info:
return info
info = OrderedDict()
info["ip_address"] = ip_address
reverse_dns = get_reverse_dns(ip_address,
+3 -2
View File
@@ -1,4 +1,5 @@
dnspython
expiringdict
urllib3<1.24,>=1.21.1
requests
publicsuffix
@@ -12,7 +13,7 @@ elasticsearch>=6.3.0,<7.0.0
elasticsearch-dsl>=6.2.1,<7.0.0
kafka-python
flake8
sphinx==1.7.9
sphinx
sphinx_rtd_theme
wheel
rstcheck
rstcheck>=3.3.1