change n_cpus to n_procs. fix PEP8 issues. remove debugging statements.

This commit is contained in:
zscholl
2019-02-20 11:25:46 -07:00
parent ad6860817f
commit 8fc856d0e3
4 changed files with 25 additions and 19 deletions
+1 -1
View File
@@ -138,7 +138,7 @@ The full set of configuration options are:
- ``debug`` - bool: Print debugging messages
- ``silent`` - bool: Only print errors (Default: True)
- ``log_file`` - str: Write log messages to a file at this path
- ``n_cpus`` - str: Number of process to run in parallel when parsing in CLI mode (Default: 1)
- ``n_procs`` - int: Number of processes to run in parallel when parsing in CLI mode (Default: 1)
- ``chunksize`` - int: Number of files to give to each process when running in parallel. Setting this to a number larger than one can improve performance when processing thousands of files.
- ``imap``
- ``host`` - str: The IMAP server hostname or IP address
+7 -5
View File
@@ -83,7 +83,8 @@ class InvalidForensicReport(InvalidDMARCReport):
"""Raised when an invalid DMARC forensic report is encountered"""
def _parse_report_record(record, nameservers=None, dns_timeout=2.0, parallel=False):
def _parse_report_record(record, nameservers=None, dns_timeout=2.0,
parallel=False):
"""
Converts a record from a DMARC aggregate report into a more consistent
format
@@ -207,7 +208,8 @@ def _parse_report_record(record, nameservers=None, dns_timeout=2.0, parallel=Fal
return new_record
def parse_aggregate_report_xml(xml, nameservers=None, timeout=2.0, parallel=False):
def parse_aggregate_report_xml(xml, nameservers=None, timeout=2.0,
parallel=False):
"""Parses a DMARC XML report string and returns a consistent OrderedDict
Args:
@@ -380,7 +382,8 @@ def extract_xml(input_):
return xml
def parse_aggregate_report_file(_input, nameservers=None, dns_timeout=2.0, parallel=False):
def parse_aggregate_report_file(_input, nameservers=None, dns_timeout=2.0,
parallel=False):
"""Parses a file at the given path, a file-like object. or bytes as a
aggregate DMARC report
@@ -817,8 +820,7 @@ def parse_report_file(input_, nameservers=None, dns_timeout=2.0,
dns_timeout=dns_timeout,
strip_attachment_payloads=sa,
parallel=parallel)
except InvalidDMARCReport as e:
print("DEBUGGING: {}".format(e))
except InvalidDMARCReport:
raise InvalidDMARCReport("Not a valid aggregate or forensic "
"report")
return results
+15 -11
View File
@@ -13,7 +13,6 @@ import json
from ssl import CERT_NONE, create_default_context
from multiprocessing import Pool, Value
from itertools import repeat
from contextlib import contextmanager
import time
from tqdm import tqdm
@@ -24,11 +23,13 @@ from parsedmarc import IMAPError, get_dmarc_reports_from_inbox, \
logger = logging.getLogger("parsedmarc")
def _str_to_list(s):
"""Converts a comma separated string to a list"""
_list = s.split(",")
return list(map(lambda i: i.lstrip(), _list))
def cli_parse(file_path, sa, nameservers, dns_timeout, parallel=False):
"""Separated this function for multiprocessing"""
try:
@@ -38,17 +39,19 @@ def cli_parse(file_path, sa, nameservers, dns_timeout, parallel=False):
strip_attachment_payloads=sa,
parallel=parallel)
except ParserError as error:
return (error, file_path)
return (error, file_path)
finally:
global counter
with counter.get_lock():
counter.value += 1
return (file_results, file_path)
def init(ctr):
    """Pool initializer that publishes the shared progress counter

    Passed as ``initializer`` to ``multiprocessing.Pool`` so that the object
    given in ``initargs`` is bound to the module-level name ``counter``,
    which ``cli_parse`` increments after each file.

    Args:
        ctr: The object to bind to the module-level name ``counter``
    """
    global counter
    counter = ctr
def _main():
"""Called when the module is executed"""
def process_reports(reports_):
@@ -235,8 +238,8 @@ def _main():
opts.silent = general_config.getboolean("silent")
if "log_file" in general_config:
opts.log_file = general_config["log_file"]
if "n_cpus" in general_config:
opts.n_cpus = general_config.getint("n_cpus")
if "n_procs" in general_config:
opts.n_procs = general_config.getint("n_procs")
if "chunksize" in general_config:
opts.chunksize = general_config.getint("chunksize")
if "imap" in config.sections():
@@ -389,12 +392,13 @@ def _main():
file_paths = list(set(file_paths))
counter = Value('i', 0)
pool = Pool(opts.n_cpus, initializer=init, initargs=(counter,))
results = pool.starmap_async(cli_parse, zip(file_paths,
repeat(opts.strip_attachment_payloads),
repeat(opts.nameservers),
repeat(opts.dns_timeout),
repeat(opts.n_cpus >= 1)), opts.chunksize)
pool = Pool(opts.n_procs, initializer=init, initargs=(counter,))
results = pool.starmap_async(cli_parse,
zip(file_paths,
repeat(opts.strip_attachment_payloads),
repeat(opts.nameservers),
repeat(opts.dns_timeout),
repeat(opts.n_procs >= 1)), opts.chunksize)
pbar = tqdm(total=len(file_paths))
while not results.ready():
pbar.update(counter.value - pbar.n)
@@ -402,7 +406,7 @@ def _main():
pbar.close()
results = results.get()
pool.close()
pool.join()
pool.join()
for result in results:
if type(result[0]) is InvalidDMARCReport:
+2 -2
View File
@@ -267,7 +267,6 @@ def get_ip_address_country(ip_address):
Args:
location (str): Local location for the database file
"""
import pdb; pdb.set_trace()
url = "https://geolite.maxmind.com/download/geoip/database/" \
"GeoLite2-Country.tar.gz"
# Use a browser-like user agent string to bypass some proxy blocks
@@ -324,7 +323,8 @@ def get_ip_address_country(ip_address):
return country
def get_ip_address_info(ip_address, cache=None, nameservers=None, timeout=2.0, parallel=False):
def get_ip_address_info(ip_address, cache=None, nameservers=None,
timeout=2.0, parallel=False):
"""
Returns reverse DNS and country information for the given IP address