mirror of
https://github.com/scrapy/scrapy.git
synced 2025-02-06 11:00:46 +00:00
Sort out webclient tests.
This commit is contained in:
parent
d27c6b46b1
commit
16b998f9ca
@ -121,6 +121,7 @@ class ScrapyClientContextFactory(BrowserLikePolicyForHTTPS):
|
||||
# Retained for the old-style HTTP/1.0 downloader, whose Twisted code path
# still asks for a context directly (e.g. through connectSSL()).
def getContext(self, hostname: Any = None, port: Any = None) -> SSL.Context:
    """Return the SSL context built from this factory's certificate options.

    ``hostname`` and ``port`` are accepted but are not used by this method.
    """
    # FIXME
    certificate_options = self.getCertificateOptions()
    context: SSL.Context = certificate_options.getContext()
    # 0x4 is the numeric value of OP_LEGACY_SERVER_CONNECT.
    context.set_options(0x4)
    return context
|
||||
|
@ -1,3 +1,5 @@
|
||||
"""Deprecated HTTP/1.0 helper classes used by HTTP10DownloadHandler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
@ -1,6 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from tempfile import mkdtemp
|
||||
|
||||
import OpenSSL.SSL
|
||||
import pytest
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import inlineCallbacks
|
||||
from twisted.protocols.policies import WrappingFactory
|
||||
from twisted.trial import unittest
|
||||
from twisted.web import server, static
|
||||
from twisted.web.client import Agent, BrowserLikePolicyForHTTPS, readBody
|
||||
from twisted.web.client import Response as TxResponse
|
||||
|
||||
from scrapy.core.downloader import Slot
|
||||
from scrapy.core.downloader.contextfactory import (
|
||||
ScrapyClientContextFactory,
|
||||
load_context_factory_from_settings,
|
||||
)
|
||||
from scrapy.core.downloader.handlers.http11 import _RequestBodyProducer
|
||||
from scrapy.settings import Settings
|
||||
from scrapy.utils.defer import deferred_f_from_coro_f, maybe_deferred_to_future
|
||||
from scrapy.utils.misc import build_from_crawler
|
||||
from scrapy.utils.python import to_bytes
|
||||
from scrapy.utils.test import get_crawler
|
||||
from tests.mockserver import PayloadResource, ssl_context_factory
|
||||
|
||||
|
||||
class SlotTest(unittest.TestCase):
|
||||
@ -10,3 +35,112 @@ class SlotTest(unittest.TestCase):
|
||||
repr(slot),
|
||||
"Slot(concurrency=8, delay=0.10, randomize_delay=True)",
|
||||
)
|
||||
|
||||
|
||||
class ContextFactoryBaseTestCase(unittest.TestCase):
    """Base test case that serves a temporary directory over TLS on 127.0.0.1.

    ``setUp`` creates a temporary directory containing a 10-byte ``file``,
    exposes it through a ``static.File`` resource with an extra ``payload``
    child (``PayloadResource``), and starts listening with ``listenSSL``.
    ``tearDown`` stops the listener and removes the directory.
    """

    # Server-side TLS context factory. Subclasses may set it; when it is
    # falsy, _listen() falls back to ssl_context_factory().
    context_factory = None

    def _listen(self, site):
        """Start a TLS listener for *site* on 127.0.0.1 and return the port object.

        Port 0 is requested; the actual port number is read back in setUp().
        """
        return reactor.listenSSL(
            0,
            site,
            contextFactory=self.context_factory or ssl_context_factory(),
            interface="127.0.0.1",
        )

    def getURL(self, path):
        """Return the ``https://`` URL of *path* on the listener started by setUp()."""
        return f"https://127.0.0.1:{self.portno}/{path}"

    def setUp(self):
        """Create the temporary site and start listening on it."""
        self.tmpname = Path(mkdtemp())
        (self.tmpname / "file").write_bytes(b"0123456789")
        r = static.File(str(self.tmpname))
        r.putChild(b"payload", PayloadResource())
        self.site = server.Site(r, timeout=None)
        self.wrapper = WrappingFactory(self.site)
        self.port = self._listen(self.wrapper)
        self.portno = self.port.getHost().port

    @inlineCallbacks
    def tearDown(self):
        """Stop the listener and delete the temporary directory."""
        try:
            yield self.port.stopListening()
        finally:
            # Remove the directory even when stopping the listener fails, so
            # a failing teardown does not leave temporary files behind.
            shutil.rmtree(self.tmpname)

    @staticmethod
    async def get_page(
        url: str,
        client_context_factory: BrowserLikePolicyForHTTPS,
        body: str | None = None,
    ) -> bytes:
        """Send a GET request to *url* and return the response body.

        The request is made with a Twisted ``Agent`` built with
        *client_context_factory*. A non-empty *body* is encoded and sent as
        the request body; ``None`` or ``""`` sends no body.
        """
        agent = Agent(reactor, contextFactory=client_context_factory)
        body_producer = _RequestBodyProducer(body.encode()) if body else None
        response: TxResponse = await maybe_deferred_to_future(
            agent.request(b"GET", url.encode(), bodyProducer=body_producer)
        )
        return await maybe_deferred_to_future(readBody(response))  # type: ignore[arg-type]
|
||||
|
||||
|
||||
class ContextFactoryTestCase(ContextFactoryBaseTestCase):
    """Round-trip check for the context factory loaded from default settings."""

    @deferred_f_from_coro_f
    async def testPayload(self):
        """Request ``payload`` with a 100-character body and compare the response."""
        payload = "0123456789" * 10
        crawler = get_crawler()
        context_factory = load_context_factory_from_settings(Settings(), crawler)
        url = self.getURL("payload")
        received = await self.get_page(url, context_factory, body=payload)
        self.assertEqual(received, to_bytes(payload))
|
||||
|
||||
|
||||
class ContextFactoryTLSMethodTestCase(ContextFactoryBaseTestCase):
    """Tests for how the TLS method of a client context factory is chosen.

    Each passing case checks the factory's ``_ssl_method`` and then performs
    a real request against the test server with that factory.
    """

    async def _assert_factory_works(
        self, client_context_factory: ScrapyClientContextFactory
    ) -> None:
        """Request ``payload`` through *client_context_factory* and compare bodies."""
        payload = "0123456789" * 10
        url = self.getURL("payload")
        received = await self.get_page(url, client_context_factory, body=payload)
        self.assertEqual(received, to_bytes(payload))

    def _assert_tls_method_rejected(self, value) -> None:
        """Assert that loading a factory with TLS method *value* raises KeyError."""
        crawler = get_crawler()
        settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": value})
        with pytest.raises(KeyError):
            load_context_factory_from_settings(settings, crawler)

    @deferred_f_from_coro_f
    async def test_setting_default(self):
        """Default settings yield SSLv23_METHOD."""
        crawler = get_crawler()
        factory = load_context_factory_from_settings(Settings(), crawler)
        assert factory._ssl_method == OpenSSL.SSL.SSLv23_METHOD
        await self._assert_factory_works(factory)

    def test_setting_none(self):
        """A ``None`` TLS method is rejected with KeyError."""
        self._assert_tls_method_rejected(None)

    def test_setting_bad(self):
        """An unknown TLS method name is rejected with KeyError."""
        self._assert_tls_method_rejected("bad")

    @deferred_f_from_coro_f
    async def test_setting_explicit(self):
        """``"TLSv1.2"`` in the settings yields TLSv1_2_METHOD."""
        crawler = get_crawler()
        tls_settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": "TLSv1.2"})
        factory = load_context_factory_from_settings(tls_settings, crawler)
        assert factory._ssl_method == OpenSSL.SSL.TLSv1_2_METHOD
        await self._assert_factory_works(factory)

    @deferred_f_from_coro_f
    async def test_direct_from_crawler(self):
        """Building the factory class directly from a crawler ignores the setting."""
        # the setting is ignored
        crawler = get_crawler(settings_dict={"DOWNLOADER_CLIENT_TLS_METHOD": "bad"})
        factory = build_from_crawler(ScrapyClientContextFactory, crawler)
        assert factory._ssl_method == OpenSSL.SSL.SSLv23_METHOD
        await self._assert_factory_works(factory)

    @deferred_f_from_coro_f
    async def test_direct_init(self):
        """A method passed to the constructor is stored as ``_ssl_method``."""
        factory = ScrapyClientContextFactory(OpenSSL.SSL.TLSv1_2_METHOD)
        assert factory._ssl_method == OpenSSL.SSL.TLSv1_2_METHOD
        await self._assert_factory_works(factory)
|
||||
|
@ -422,6 +422,7 @@ class HttpTestCase(unittest.TestCase):
|
||||
return self.download_request(request, Spider("foo")).addCallback(_test)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::scrapy.exceptions.ScrapyDeprecationWarning")
|
||||
class Http10TestCase(HttpTestCase):
|
||||
"""HTTP 1.0 test case"""
|
||||
|
||||
@ -780,6 +781,7 @@ class HttpProxyTestCase(unittest.TestCase):
|
||||
return self.download_request(request, Spider("foo")).addCallback(_test)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::scrapy.exceptions.ScrapyDeprecationWarning")
|
||||
class Http10ProxyTestCase(HttpProxyTestCase):
|
||||
download_handler_cls: type = HTTP10DownloadHandler
|
||||
|
||||
|
@ -8,12 +8,11 @@ from __future__ import annotations
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any
|
||||
|
||||
import OpenSSL.SSL
|
||||
from pytest import raises
|
||||
import pytest
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks
|
||||
from twisted.internet.defer import inlineCallbacks
|
||||
from twisted.internet.testing import StringTransport
|
||||
from twisted.protocols.policies import WrappingFactory
|
||||
from twisted.trial import unittest
|
||||
@ -22,10 +21,8 @@ from twisted.web import resource, server, static, util
|
||||
from scrapy.core.downloader import webclient as client
|
||||
from scrapy.core.downloader.contextfactory import (
|
||||
ScrapyClientContextFactory,
|
||||
load_context_factory_from_settings,
|
||||
)
|
||||
from scrapy.http import Headers, Request
|
||||
from scrapy.settings import Settings
|
||||
from scrapy.utils.misc import build_from_crawler
|
||||
from scrapy.utils.python import to_bytes, to_unicode
|
||||
from scrapy.utils.test import get_crawler
|
||||
@ -38,6 +35,7 @@ from tests.mockserver import (
|
||||
PayloadResource,
|
||||
ssl_context_factory,
|
||||
)
|
||||
from tests.test_core_downloader import ContextFactoryBaseTestCase
|
||||
|
||||
|
||||
def getPage(url, contextFactory=None, response_transform=None, *args, **kwargs):
|
||||
@ -129,6 +127,7 @@ class ParseUrlTestCase(unittest.TestCase):
|
||||
self.assertEqual(client._parse(url), test, url)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::scrapy.exceptions.ScrapyDeprecationWarning")
|
||||
class ScrapyHTTPPageGetterTests(unittest.TestCase):
|
||||
def test_earlyHeaders(self):
|
||||
# basic test stolen from twisted HTTPageGetter
|
||||
@ -272,6 +271,7 @@ class EncodingResource(resource.Resource):
|
||||
return body.encode(self.out_encoding)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::scrapy.exceptions.ScrapyDeprecationWarning")
|
||||
class WebClientTestCase(unittest.TestCase):
|
||||
def _listen(self, site):
|
||||
return reactor.listenTCP(0, site, interface="127.0.0.1")
|
||||
@ -427,35 +427,8 @@ class WebClientTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class WebClientSSLTestCase(unittest.TestCase):
    """Test case that serves a temporary directory over TLS on 127.0.0.1.

    ``setUp`` creates a temporary directory containing a 10-byte ``file``,
    exposes it through a ``static.File`` resource with an extra ``payload``
    child (``PayloadResource``), and starts listening with ``listenSSL``.
    ``tearDown`` stops the listener and removes the directory.
    """

    # Server-side TLS context factory. Subclasses may set it; when it is
    # falsy, _listen() falls back to ssl_context_factory().
    context_factory = None

    def _listen(self, site):
        """Start a TLS listener for *site* on 127.0.0.1 and return the port object.

        Port 0 is requested; the actual port number is read back in setUp().
        """
        return reactor.listenSSL(
            0,
            site,
            contextFactory=self.context_factory or ssl_context_factory(),
            interface="127.0.0.1",
        )

    def getURL(self, path):
        """Return the ``https://`` URL of *path* on the listener started by setUp()."""
        return f"https://127.0.0.1:{self.portno}/{path}"

    def setUp(self):
        """Create the temporary site and start listening on it."""
        self.tmpname = Path(mkdtemp())
        (self.tmpname / "file").write_bytes(b"0123456789")
        r = static.File(str(self.tmpname))
        r.putChild(b"payload", PayloadResource())
        self.site = server.Site(r, timeout=None)
        self.wrapper = WrappingFactory(self.site)
        self.port = self._listen(self.wrapper)
        self.portno = self.port.getHost().port

    @inlineCallbacks
    def tearDown(self):
        """Stop the listener and delete the temporary directory."""
        try:
            yield self.port.stopListening()
        finally:
            # Remove the directory even when stopping the listener fails, so
            # a failing teardown does not leave temporary files behind.
            shutil.rmtree(self.tmpname)
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::scrapy.exceptions.ScrapyDeprecationWarning")
|
||||
class WebClientSSLTestCase(ContextFactoryBaseTestCase):
|
||||
def testPayload(self):
|
||||
s = "0123456789" * 10
|
||||
return getPage(self.getURL("payload"), body=s).addCallback(
|
||||
@ -490,51 +463,3 @@ class WebClientCustomCiphersSSLTestCase(WebClientSSLTestCase):
|
||||
self.getURL("payload"), body=s, contextFactory=client_context_factory
|
||||
)
|
||||
return self.assertFailure(d, OpenSSL.SSL.Error)
|
||||
|
||||
|
||||
class WebClientTLSMethodTestCase(WebClientSSLTestCase):
    """Tests for how the TLS method of a client context factory is chosen.

    Each passing case checks the factory's ``_ssl_method`` and then fetches
    ``payload`` from the test server with ``getPage`` using that factory.
    """

    def _assert_factory_works(
        self, client_context_factory: ScrapyClientContextFactory
    ) -> Deferred[Any]:
        """Return a Deferred that checks a ``payload`` round trip via the factory."""
        payload = "0123456789" * 10
        d = getPage(
            self.getURL("payload"),
            body=payload,
            contextFactory=client_context_factory,
        )
        return d.addCallback(self.assertEqual, to_bytes(payload))

    def _assert_tls_method_rejected(self, value) -> None:
        """Assert that loading a factory with TLS method *value* raises KeyError."""
        crawler = get_crawler()
        settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": value})
        with raises(KeyError):
            load_context_factory_from_settings(settings, crawler)

    def test_setting_default(self):
        """Default settings yield SSLv23_METHOD."""
        crawler = get_crawler()
        factory = load_context_factory_from_settings(Settings(), crawler)
        assert factory._ssl_method == OpenSSL.SSL.SSLv23_METHOD
        return self._assert_factory_works(factory)

    def test_setting_none(self):
        """A ``None`` TLS method is rejected with KeyError."""
        self._assert_tls_method_rejected(None)

    def test_setting_bad(self):
        """An unknown TLS method name is rejected with KeyError."""
        self._assert_tls_method_rejected("bad")

    def test_setting_explicit(self):
        """``"TLSv1.2"`` in the settings yields TLSv1_2_METHOD."""
        crawler = get_crawler()
        tls_settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": "TLSv1.2"})
        factory = load_context_factory_from_settings(tls_settings, crawler)
        assert factory._ssl_method == OpenSSL.SSL.TLSv1_2_METHOD
        return self._assert_factory_works(factory)

    def test_direct_from_crawler(self):
        """Building the factory class directly from a crawler ignores the setting."""
        # the setting is ignored
        crawler = get_crawler(settings_dict={"DOWNLOADER_CLIENT_TLS_METHOD": "bad"})
        factory = build_from_crawler(ScrapyClientContextFactory, crawler)
        assert factory._ssl_method == OpenSSL.SSL.SSLv23_METHOD
        return self._assert_factory_works(factory)

    def test_direct_init(self):
        """A method passed to the constructor is stored as ``_ssl_method``."""
        factory = ScrapyClientContextFactory(OpenSSL.SSL.TLSv1_2_METHOD)
        assert factory._ssl_method == OpenSSL.SSL.TLSv1_2_METHOD
        return self._assert_factory_works(factory)
|
||||
|
Loading…
x
Reference in New Issue
Block a user