diff --git a/scrapy/utils/iterators.py b/scrapy/utils/iterators.py
index 6b89334e0..9c53ab524 100644
--- a/scrapy/utils/iterators.py
+++ b/scrapy/utils/iterators.py
@@ -12,11 +12,14 @@ from typing import (
     List,
     Literal,
     Optional,
+    Tuple,
     Union,
     cast,
     overload,
 )
 
+from lxml import etree
+
 from scrapy.http import Response, TextResponse
 from scrapy.selector import Selector
 from scrapy.utils.python import re_rsearch, to_unicode
@@ -77,15 +80,42 @@ def xmliter(
         yield Selector(text=nodetext, type="xml")
 
 
+def _resolve_xml_namespace(
+    element_name: str, data: Union[Response, str, bytes]
+) -> Tuple[str, Optional[str], Optional[str]]:
+    """Split a ``prefix:name`` node name and look the prefix up in *data*.
+
+    Return a ``(name, prefix, namespace)`` tuple. When *element_name* has no
+    prefix, or no namespace declaration in *data* uses that prefix, the name
+    is returned unchanged and both *prefix* and *namespace* are ``None``.
+    """
+    if ":" not in element_name:
+        return element_name, None, None
+    node_prefix, local_name = element_name.split(":", maxsplit=1)
+    reader = _StreamReader(data)
+    ns_iterator = etree.iterparse(
+        cast("SupportsReadClose[bytes]", reader),
+        encoding=reader.encoding,
+        events=("start-ns",),
+    )
+    for _event, (doc_prefix, doc_namespace) in ns_iterator:
+        if doc_prefix == node_prefix:
+            return local_name, doc_prefix, doc_namespace
+    return element_name, None, None
+
+
 def xmliter_lxml(
     obj: Union[Response, str, bytes],
     nodename: str,
     namespace: Optional[str] = None,
     prefix: str = "x",
 ) -> Generator[Selector, Any, None]:
-    from lxml import etree
+    if not namespace:
+        nodename, resolved_prefix, namespace = _resolve_xml_namespace(nodename, obj)
+        if resolved_prefix is not None:
+            prefix = resolved_prefix
 
     reader = _StreamReader(obj)
     tag = f"{{{namespace}}}{nodename}" if namespace else nodename
     iterable = etree.iterparse(
         cast("SupportsReadClose[bytes]", reader), tag=tag, encoding=reader.encoding
diff --git a/tests/test_utils_iterators.py b/tests/test_utils_iterators.py
index 3598fa0bb..24f03155b 100644
--- a/tests/test_utils_iterators.py
+++ b/tests/test_utils_iterators.py
@@ -1,4 +1,3 @@
-from pytest import mark
 from twisted.trial import unittest
 
 from scrapy.http import Response, TextResponse, XmlResponse
@@ -247,10 +246,6 @@ class XmliterTestCase(unittest.TestCase):
 class LxmlXmliterTestCase(XmliterTestCase):
     xmliter = staticmethod(xmliter_lxml)
 
-    @mark.xfail(reason="known bug of the current implementation")
-    def test_xmliter_namespaced_nodename(self):
-        super().test_xmliter_namespaced_nodename()
-
     def test_xmliter_iterate_namespace(self):
         body = b"""