Fix etree import so type checkers don't complain

This commit is contained in:
Sean Whalen
2025-12-24 14:37:38 -05:00
parent b99bd67225
commit b5773c6b4a

View File

@@ -4,7 +4,19 @@
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Any, Union, Optional, IO, Callable from typing import (
Dict,
List,
Any,
Union,
Optional,
IO,
Callable,
BinaryIO,
Protocol,
runtime_checkable,
cast,
)
import binascii import binascii
import email import email
@@ -27,7 +39,7 @@ from io import BytesIO, StringIO
import mailparser import mailparser
import xmltodict import xmltodict
from expiringdict import ExpiringDict from expiringdict import ExpiringDict
from lxml import etree import lxml.etree as etree
from mailsuite.smtp import send_email from mailsuite.smtp import send_email
from parsedmarc.log import logger from parsedmarc.log import logger
@@ -847,7 +859,14 @@ def parse_aggregate_report_xml(
raise InvalidAggregateReport("Unexpected error: {0}".format(error.__str__())) raise InvalidAggregateReport("Unexpected error: {0}".format(error.__str__()))
def extract_report(content: Union[bytes, str, IO[Any]]) -> str: @runtime_checkable
class _ReadableSeekable(Protocol):
def read(self, n: int = -1) -> bytes: ...
def seek(self, offset: int, whence: int = 0) -> int: ...
def tell(self) -> int: ...
def extract_report(content: Union[bytes, bytearray, memoryview, str, BinaryIO]) -> str:
""" """
Extracts text from a zip or gzip file, as a base64-encoded string, Extracts text from a zip or gzip file, as a base64-encoded string,
file-like object, or bytes. file-like object, or bytes.
@@ -860,19 +879,22 @@ def extract_report(content: Union[bytes, str, IO[Any]]) -> str:
str: The extracted text str: The extracted text
""" """
file_object = None file_object: Optional[_ReadableSeekable] = None
try: try:
if isinstance(content, str): if isinstance(content, str):
try: try:
file_object = BytesIO(b64decode(content)) file_object = BytesIO(b64decode(content))
except binascii.Error: except binascii.Error:
return content return content
elif type(content) is bytes: elif isinstance(content, (bytes, bytearray, memoryview)):
file_object = BytesIO(content) file_object = BytesIO(bytes(content))
else: else:
file_object = content file_object = cast(_ReadableSeekable, content)
header = file_object.read(6) header = file_object.read(6)
if isinstance(header, str):
raise ParserError("File objects must be opened in binary (rb) mode")
file_object.seek(0) file_object.seek(0)
if header.startswith(MAGIC_ZIP): if header.startswith(MAGIC_ZIP):
_zip = zipfile.ZipFile(file_object) _zip = zipfile.ZipFile(file_object)
@@ -884,19 +906,18 @@ def extract_report(content: Union[bytes, str, IO[Any]]) -> str:
elif header.startswith(MAGIC_XML) or header.startswith(MAGIC_JSON): elif header.startswith(MAGIC_XML) or header.startswith(MAGIC_JSON):
report = file_object.read().decode(errors="ignore") report = file_object.read().decode(errors="ignore")
else: else:
file_object.close()
raise ParserError("Not a valid zip, gzip, json, or xml file") raise ParserError("Not a valid zip, gzip, json, or xml file")
file_object.close()
except UnicodeDecodeError: except UnicodeDecodeError:
if file_object:
file_object.close()
raise ParserError("File objects must be opened in binary (rb) mode") raise ParserError("File objects must be opened in binary (rb) mode")
except Exception as error: except Exception as error:
if file_object:
file_object.close()
raise ParserError("Invalid archive file: {0}".format(error.__str__())) raise ParserError("Invalid archive file: {0}".format(error.__str__()))
finally:
if file_object and hasattr(file_object, "close"):
try:
file_object.close()
except Exception:
pass
return report return report