diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index 280c7b0..94fe60b 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -962,10 +962,12 @@ def extract_report(content: Union[bytes, str, BinaryIO]) -> str: return report -def extract_report_from_file_path(file_path: str): +def extract_report_from_file_path( + file_path: Union[str, bytes, os.PathLike[str], os.PathLike[bytes]], +) -> str: """Extracts report from a file at the given file_path""" try: - with open(file_path, "rb") as report_file: + with open(os.fspath(file_path), "rb") as report_file: return extract_report(report_file.read()) except FileNotFoundError: raise ParserError("File was not found") @@ -1660,7 +1662,7 @@ def parse_report_email( def parse_report_file( - input_: Union[bytes, str, BinaryIO], + input_: Union[bytes, str, os.PathLike[str], os.PathLike[bytes], BinaryIO], *, nameservers: Optional[list[str]] = None, dns_timeout: float = 2.0, @@ -1677,7 +1679,8 @@ def parse_report_file( file-like object. or bytes Args: - input_ (str | bytes | BinaryIO): A path to a file, a file like object, or bytes + input_ (str | os.PathLike | bytes | BinaryIO): A path to a file, + a file-like object, or bytes nameservers (list): A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) dns_timeout (float): Sets the DNS timeout in seconds @@ -1694,9 +1697,10 @@ def parse_report_file( dict: The parsed DMARC report """ file_object: BinaryIO - if isinstance(input_, str): - logger.debug("Parsing {0}".format(input_)) - file_object = open(input_, "rb") + if isinstance(input_, (str, os.PathLike)): + file_path = os.fspath(input_) + logger.debug("Parsing {0}".format(file_path)) + file_object = open(file_path, "rb") elif isinstance(input_, (bytes, bytearray, memoryview)): file_object = BytesIO(bytes(input_)) else: diff --git a/tests.py b/tests.py index 7f54b60..4ab6fb1 100755 --- a/tests.py +++ b/tests.py @@ -100,15 +100,23 @@ class Test(unittest.TestCase): def testExtractReportXML(self): """Test extract report function for XML input""" print() - file = "samples/extract_report/nice-input.xml" - print("Testing {0}: ".format(file), end="") - xmlout = parsedmarc.extract_report_from_file_path(file) + report_path = "samples/extract_report/nice-input.xml" + print("Testing {0}: ".format(report_path), end="") + xmlout = parsedmarc.extract_report_from_file_path(report_path) xmlin_file = open("samples/extract_report/nice-input.xml") xmlin = xmlin_file.read() xmlin_file.close() self.assertTrue(compare_xml(xmlout, xmlin)) print("Passed!") + def testExtractReportXMLFromPath(self): + """Test extract report function for pathlib.Path input""" + report_path = Path("samples/extract_report/nice-input.xml") + xmlout = parsedmarc.extract_report_from_file_path(report_path) + with open("samples/extract_report/nice-input.xml") as xmlin_file: + xmlin = xmlin_file.read() + self.assertTrue(compare_xml(xmlout, xmlin)) + def testExtractReportGZip(self): """Test extract report function for gzip input""" print() @@ -137,6 +145,28 @@ class Test(unittest.TestCase): self.assertFalse(compare_xml(xmlout, xmlin)) print("Passed!") + def testParseReportFileAcceptsPathForXML(self): + report_path = Path( + "samples/aggregate/protection.outlook.com!example.com!1711756800!1711843200.xml" + ) + result = parsedmarc.parse_report_file( + report_path, + offline=True, + ) + self.assertEqual(result["report_type"], "aggregate") + self.assertEqual(result["report"]["report_metadata"]["org_name"], "outlook.com") + + def testParseReportFileAcceptsPathForEmail(self): + report_path = Path( + "samples/aggregate/Report domain- borschow.com Submitter- google.com Report-ID- 949348866075514174.eml" + ) + result = parsedmarc.parse_report_file( + report_path, + offline=True, + ) + self.assertEqual(result["report_type"], "aggregate") + self.assertEqual(result["report"]["report_metadata"]["org_name"], "google.com") + def testAggregateSamples(self): """Test sample aggregate/rua DMARC reports""" print()