diff --git a/README.rst b/README.rst index e827f45..8685ee8 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,8 @@ The full set of configuration options are: - ``debug`` - bool: Print debugging messages - ``silent`` - bool: Only print errors (Default: True) - ``log_file`` - str: Write log messages to a file at this path + - ``n_procs`` - int: Number of processes to run in parallel when parsing in CLI mode (Default: 1) + - ``chunksize`` - int: Number of files to give to each process when running in parallel. Setting this to a number larger than one can improve performance when processing thousands of files - ``imap`` - ``host`` - str: The IMAP server hostname or IP address - ``port`` - int: The IMAP server port (Default: 993) diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index 1514a8e..aab42e2 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -83,7 +83,8 @@ class InvalidForensicReport(InvalidDMARCReport): """Raised when an invalid DMARC forensic report is encountered""" -def _parse_report_record(record, nameservers=None, dns_timeout=2.0): +def _parse_report_record(record, nameservers=None, dns_timeout=2.0, + parallel=False): """ Converts a record from a DMARC aggregate report into a more consistent format @@ -106,7 +107,8 @@ def _parse_report_record(record, nameservers=None, dns_timeout=2.0): new_record_source = get_ip_address_info(record["row"]["source_ip"], cache=IP_ADDRESS_CACHE, nameservers=nameservers, - timeout=dns_timeout) + timeout=dns_timeout, + parallel=parallel) new_record["source"] = new_record_source new_record["count"] = int(record["row"]["count"]) policy_evaluated = record["row"]["policy_evaluated"].copy() @@ -206,7 +208,8 @@ def _parse_report_record(record, nameservers=None, dns_timeout=2.0): return new_record -def parse_aggregate_report_xml(xml, nameservers=None, timeout=2.0): +def parse_aggregate_report_xml(xml, nameservers=None, timeout=2.0, + parallel=False): """Parses a DMARC XML report string and returns a 
consistent OrderedDict Args: @@ -305,13 +308,15 @@ def parse_aggregate_report_xml(xml, nameservers=None, timeout=2.0): for record in report["record"]: report_record = _parse_report_record(record, nameservers=nameservers, - dns_timeout=timeout) + dns_timeout=timeout, + parallel=parallel) records.append(report_record) else: report_record = _parse_report_record(report["record"], nameservers=nameservers, - dns_timeout=timeout) + dns_timeout=timeout, + parallel=parallel) records.append(report_record) new_report["records"] = records @@ -377,7 +382,8 @@ def extract_xml(input_): return xml -def parse_aggregate_report_file(_input, nameservers=None, dns_timeout=2.0): +def parse_aggregate_report_file(_input, nameservers=None, dns_timeout=2.0, + parallel=False): """Parses a file at the given path, a file-like object. or bytes as a aggregate DMARC report @@ -394,7 +400,8 @@ def parse_aggregate_report_file(_input, nameservers=None, dns_timeout=2.0): return parse_aggregate_report_xml(xml, nameservers=nameservers, - timeout=dns_timeout) + timeout=dns_timeout, + parallel=parallel) def parsed_aggregate_reports_to_csv(reports): @@ -509,7 +516,8 @@ def parsed_aggregate_reports_to_csv(reports): def parse_forensic_report(feedback_report, sample, msg_date, nameservers=None, dns_timeout=2.0, - strip_attachment_payloads=False): + strip_attachment_payloads=False, + parallel=False): """ Converts a DMARC forensic report and sample to a ``OrderedDict`` @@ -553,7 +561,8 @@ def parse_forensic_report(feedback_report, sample, msg_date, ip_address = parsed_report["source_ip"] parsed_report_source = get_ip_address_info(ip_address, nameservers=nameservers, - timeout=dns_timeout) + timeout=dns_timeout, + parallel=parallel) parsed_report["source"] = parsed_report_source del parsed_report["source_ip"] @@ -653,7 +662,7 @@ def parsed_forensic_reports_to_csv(reports): def parse_report_email(input_, nameservers=None, dns_timeout=2.0, - strip_attachment_payloads=False): + strip_attachment_payloads=False, 
parallel=False): """ Parses a DMARC report from an email @@ -724,7 +733,8 @@ def parse_report_email(input_, nameservers=None, dns_timeout=2.0, aggregate_report = parse_aggregate_report_file( payload, nameservers=ns, - dns_timeout=dns_timeout) + dns_timeout=dns_timeout, + parallel=parallel) result = OrderedDict([("report_type", "aggregate"), ("report", aggregate_report)]) return result @@ -751,13 +761,15 @@ def parse_report_email(input_, nameservers=None, dns_timeout=2.0, date, nameservers=nameservers, dns_timeout=dns_timeout, - strip_attachment_payloads=strip_attachment_payloads) + strip_attachment_payloads=strip_attachment_payloads, + parallel=parallel) except InvalidForensicReport as e: error = 'Message with subject "{0}" ' \ 'is not a valid ' \ 'forensic DMARC report: {1}'.format(subject, e) raise InvalidForensicReport(error) except Exception as e: + print("DEBUGGGING: {}".format(e)) raise InvalidForensicReport(e.__str__()) result = OrderedDict([("report_type", "forensic"), @@ -771,7 +783,7 @@ def parse_report_email(input_, nameservers=None, dns_timeout=2.0, def parse_report_file(input_, nameservers=None, dns_timeout=2.0, - strip_attachment_payloads=False): + strip_attachment_payloads=False, parallel=False): """Parses a DMARC aggregate or forensic file at the given path, a file-like object. 
or bytes @@ -796,7 +808,8 @@ def parse_report_file(input_, nameservers=None, dns_timeout=2.0, content = file_object.read() try: report = parse_aggregate_report_file(content, nameservers=nameservers, - dns_timeout=dns_timeout) + dns_timeout=dns_timeout, + parallel=parallel) results = OrderedDict([("report_type", "aggregate"), ("report", report)]) except InvalidAggregateReport: @@ -805,7 +818,8 @@ def parse_report_file(input_, nameservers=None, dns_timeout=2.0, results = parse_report_email(content, nameservers=nameservers, dns_timeout=dns_timeout, - strip_attachment_payloads=sa) + strip_attachment_payloads=sa, + parallel=parallel) except InvalidDMARCReport: raise InvalidDMARCReport("Not a valid aggregate or forensic " "report") diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index 14883a3..33331e3 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -11,10 +11,15 @@ import logging from collections import OrderedDict import json from ssl import CERT_NONE, create_default_context +from multiprocessing import Pool, Value +from itertools import repeat +import time +from tqdm import tqdm from parsedmarc import IMAPError, get_dmarc_reports_from_inbox, \ parse_report_file, elastic, kafkaclient, splunk, save_output, \ - watch_inbox, email_results, SMTPError, ParserError, __version__ + watch_inbox, email_results, SMTPError, ParserError, __version__, \ + InvalidDMARCReport logger = logging.getLogger("parsedmarc") @@ -25,6 +30,28 @@ def _str_to_list(s): return list(map(lambda i: i.lstrip(), _list)) +def cli_parse(file_path, sa, nameservers, dns_timeout, parallel=False): + """Separated this function for multiprocessing""" + try: + file_results = parse_report_file(file_path, + nameservers=nameservers, + dns_timeout=dns_timeout, + strip_attachment_payloads=sa, + parallel=parallel) + except ParserError as error: + return (error, file_path) + finally: + global counter + with counter.get_lock(): + counter.value += 1 + return (file_results, file_path) + + +def init(ctr): + 
global counter + counter = ctr + + def _main(): """Called when the module is executed""" def process_reports(reports_): @@ -134,7 +161,7 @@ def _main(): args = arg_parser.parse_args() opts = Namespace(file_path=args.file_path, - onfig_file=args.config_file, + config_file=args.config_file, strip_attachment_payloads=args.strip_attachment_payloads, output=args.output, nameservers=args.nameservers, @@ -178,7 +205,9 @@ def _main(): smtp_to=[], smtp_subject="parsedmarc report", smtp_message="Please see the attached DMARC results.", - log_file=args.log_file + log_file=args.log_file, + n_procs=1, + chunksize=1 ) args = arg_parser.parse_args() @@ -211,6 +240,10 @@ def _main(): opts.silent = general_config.getboolean("silent") if "log_file" in general_config: opts.log_file = general_config["log_file"] + if "n_procs" in general_config: + opts.n_procs = general_config.getint("n_procs") + if "chunksize" in general_config: + opts.chunksize = general_config.getint("chunksize") if "imap" in config.sections(): imap_config = config["imap"] if "host" in imap_config: @@ -360,21 +393,33 @@ def _main(): file_paths += glob(file_path) file_paths = list(set(file_paths)) - for file_path in file_paths: - try: - sa = opts.strip_attachment_payloads - file_results = parse_report_file(file_path, - nameservers=opts.nameservers, - dns_timeout=opts.dns_timeout, - strip_attachment_payloads=sa) - if file_results["report_type"] == "aggregate": - aggregate_reports.append(file_results["report"]) - elif file_results["report_type"] == "forensic": - forensic_reports.append(file_results["report"]) + counter = Value('i', 0) + pool = Pool(opts.n_procs, initializer=init, initargs=(counter,)) + results = pool.starmap_async(cli_parse, + zip(file_paths, + repeat(opts.strip_attachment_payloads), + repeat(opts.nameservers), + repeat(opts.dns_timeout), + repeat(opts.n_procs >= 1)), + opts.chunksize) + pbar = tqdm(total=len(file_paths)) + while not results.ready(): + pbar.update(counter.value - pbar.n) + 
time.sleep(0.1) + pbar.close() + results = results.get() + pool.close() + pool.join() - except ParserError as error: - logger.error("Failed to parse {0} - {1}".format(file_path, - error)) + for result in results: + if type(result[0]) is InvalidDMARCReport: + logger.error("Failed to parse {0} - {1}".format(result[1], + result[0])) + else: + if result[0]["report_type"] == "aggregate": + aggregate_reports.append(result[0]["report"]) + elif result[0]["report_type"] == "forensic": + forensic_reports.append(result[0]["report"]) if opts.imap_host: try: diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py index 85c126f..e86dfca 100644 --- a/parsedmarc/utils.py +++ b/parsedmarc/utils.py @@ -323,7 +323,8 @@ def get_ip_address_country(ip_address): return country -def get_ip_address_info(ip_address, cache=None, nameservers=None, timeout=2.0): +def get_ip_address_info(ip_address, cache=None, nameservers=None, + timeout=2.0, parallel=False): """ Returns reverse DNS and country information for the given IP address @@ -348,8 +349,11 @@ def get_ip_address_info(ip_address, cache=None, nameservers=None, timeout=2.0): reverse_dns = get_reverse_dns(ip_address, nameservers=nameservers, timeout=timeout) - country = get_ip_address_country(ip_address) - info["country"] = country + if not parallel: + country = get_ip_address_country(ip_address) + info["country"] = country + else: + info["country"] = None info["reverse_dns"] = reverse_dns info["base_domain"] = None if reverse_dns is not None: diff --git a/requirements.txt b/requirements.txt index 8e0ff1a..0e5977d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ flake8 sphinx sphinx_rtd_theme wheel +tqdm rstcheck>=3.3.1 pygments dnspython>=1.16.0 @@ -17,4 +18,4 @@ mail-parser>=3.9.2 dateparser>=0.7.1 elasticsearch>=6.3.1 elasticsearch-dsl>=0.0.12 -kafka-python>=1.4.4 +kafka-python>=1.4.4 \ No newline at end of file