diff --git a/docs/source/usage.md b/docs/source/usage.md index 98f3aaf..cb73223 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -126,8 +126,6 @@ The full set of configuration options are: - `log_file` - str: Write log messages to a file at this path - `n_procs` - int: Number of process to run in parallel when parsing in CLI mode (Default: `1`) - - `chunk_size` - int: Number of files to give to each process - when running in parallel. :::{note} Setting this to a number larger than one can improve diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index a7b1331..ec596da 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -8,13 +8,12 @@ import os from configparser import ConfigParser from glob import glob import logging +import math from collections import OrderedDict import json from ssl import CERT_NONE, create_default_context -from multiprocessing import Pool, Value -from itertools import repeat +from multiprocessing import Pipe, Process import sys -import time from tqdm import tqdm from parsedmarc import get_dmarc_reports_from_mailbox, watch_inbox, \ @@ -42,7 +41,7 @@ def _str_to_list(s): def cli_parse(file_path, sa, nameservers, dns_timeout, - ip_db_path, offline, parallel=False): + ip_db_path, offline, conn, parallel=False): """Separated this function for multiprocessing""" try: file_results = parse_report_file(file_path, @@ -52,18 +51,11 @@ def cli_parse(file_path, sa, nameservers, dns_timeout, dns_timeout=dns_timeout, strip_attachment_payloads=sa, parallel=parallel) + conn.send([file_results, file_path]) except ParserError as error: - return error, file_path + conn.send([error, file_path]) finally: - global counter - with counter.get_lock(): - counter.value += 1 - return file_results, file_path - - -def init(ctr): - global counter - counter = ctr + conn.close() def _main(): @@ -481,7 +473,6 @@ def _main(): gmail_api_oauth2_port=8080, log_file=args.log_file, n_procs=1, - chunk_size=1, ip_db_path=None, la_client_id=None, la_client_secret=None, @@ -551,8 +542,6 @@ def _main(): opts.log_file = general_config["log_file"] if "n_procs" in general_config: opts.n_procs = general_config.getint("n_procs") - if "chunk_size" in general_config: - opts.chunk_size = general_config.getint("chunk_size") if "ip_db_path" in general_config: opts.ip_db_path = general_config["ip_db_path"] else: @@ -1144,29 +1133,49 @@ def _main(): for mbox_path in mbox_paths: file_paths.remove(mbox_path) - counter = Value('i', 0) - pool = Pool(opts.n_procs, initializer=init, initargs=(counter,)) - results = pool.starmap_async(cli_parse, - zip(file_paths, - repeat(opts.strip_attachment_payloads), - repeat(opts.nameservers), - repeat(opts.dns_timeout), - repeat(opts.ip_db_path), - repeat(opts.offline), - repeat(opts.n_procs >= 1)), - opts.chunk_size) + counter = 0 + + results = [] + if sys.stdout.isatty(): pbar = tqdm(total=len(file_paths)) - while not results.ready(): - pbar.update(counter.value - pbar.n) - time.sleep(0.1) - pbar.close() - else: - while not results.ready(): - time.sleep(0.1) - results = results.get() - pool.close() - pool.join() + + for batch_index in range(math.ceil(len(file_paths) / opts.n_procs)): + processes = [] + connections = [] + + for proc_index in range( + opts.n_procs * batch_index, + opts.n_procs * (batch_index + 1)): + if proc_index >= len(file_paths): + break + + parent_conn, child_conn = Pipe() + connections.append(parent_conn) + + process = Process(target=cli_parse, args=( + file_paths[proc_index], + opts.strip_attachment_payloads, + opts.nameservers, + opts.dns_timeout, + opts.ip_db_path, + opts.offline, + child_conn, + opts.n_procs >= 1, + )) + processes.append(process) + + for proc in processes: + proc.start() + + for proc in processes: + proc.join() + if sys.stdout.isatty(): + counter += 1 + pbar.update(counter - pbar.n) + + for conn in connections: + results.append(conn.recv()) for result in results: if type(result[0]) is ParserError: