1
0
mirror of https://github.com/scrapy/scrapy.git synced 2025-02-06 11:00:46 +00:00
scrapy/tests/test_core_downloader.py
2025-01-28 02:08:49 +05:00

173 lines
6.4 KiB
Python

from __future__ import annotations
import shutil
import warnings
from pathlib import Path
from tempfile import mkdtemp
from typing import Any
import OpenSSL.SSL
import pytest
from twisted.internet import reactor
from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.protocols.policies import WrappingFactory
from twisted.trial import unittest
from twisted.web import server, static
from twisted.web.client import Agent, BrowserLikePolicyForHTTPS, readBody
from twisted.web.client import Response as TxResponse
from scrapy.core.downloader import Slot
from scrapy.core.downloader.contextfactory import (
ScrapyClientContextFactory,
load_context_factory_from_settings,
)
from scrapy.core.downloader.handlers.http11 import _RequestBodyProducer
from scrapy.settings import Settings
from scrapy.utils.defer import deferred_f_from_coro_f, maybe_deferred_to_future
from scrapy.utils.misc import build_from_crawler
from scrapy.utils.python import to_bytes
from scrapy.utils.test import get_crawler
from tests.mockserver import PayloadResource, ssl_context_factory
class SlotTest(unittest.TestCase):
    """Checks the string representation of a downloader ``Slot``."""

    def test_repr(self):
        # The 0.1 delay is expected to show up as "0.10" in the repr.
        expected = "Slot(concurrency=8, delay=0.10, randomize_delay=True)"
        downloader_slot = Slot(concurrency=8, delay=0.1, randomize_delay=True)
        self.assertEqual(repr(downloader_slot), expected)
class ContextFactoryBaseTestCase(unittest.TestCase):
    """Shared fixture: a local HTTPS site plus a helper to fetch pages from it.

    Subclasses may set ``context_factory`` to choose the server-side TLS
    context factory; while it is falsy, ``ssl_context_factory()`` from the
    mock server module is used instead.
    """

    context_factory = None

    def _listen(self, site):
        """Serve *site* over TLS on 127.0.0.1 and return the listening port.

        Port number 0 is requested; ``setUp`` reads the actual port back from
        the returned object.
        """
        server_context = self.context_factory or ssl_context_factory()
        return reactor.listenSSL(
            0, site, contextFactory=server_context, interface="127.0.0.1"
        )

    def getURL(self, path):
        """Return the HTTPS URL for *path* on the port recorded by ``setUp``."""
        return f"https://127.0.0.1:{self.portno}/{path}"

    def setUp(self):
        # Scratch directory with a single ten-byte file, served statically.
        self.tmpname = Path(mkdtemp())
        self.tmpname.joinpath("file").write_bytes(b"0123456789")
        root = static.File(str(self.tmpname))
        # The payload resource is mounted under /payload.
        root.putChild(b"payload", PayloadResource())
        self.site = server.Site(root, timeout=None)
        self.wrapper = WrappingFactory(self.site)
        self.port = self._listen(self.wrapper)
        self.portno = self.port.getHost().port

    @inlineCallbacks
    def tearDown(self):
        # Stop listening first, then remove the scratch directory.
        yield self.port.stopListening()
        shutil.rmtree(self.tmpname)

    @staticmethod
    async def get_page(
        url: str,
        client_context_factory: BrowserLikePolicyForHTTPS,
        body: str | None = None,
    ) -> bytes:
        """Send a GET request to *url* and return the response body.

        :param url: target URL; it is encoded to bytes before the request.
        :param client_context_factory: passed to the Twisted ``Agent`` as its
            ``contextFactory``.
        :param body: optional request body; ``None`` or an empty string means
            no body producer is attached.
        :return: the bytes read from the response.
        """
        agent = Agent(reactor, contextFactory=client_context_factory)
        if body:
            body_producer = _RequestBodyProducer(body.encode())
        else:
            body_producer = None
        request_d = agent.request(b"GET", url.encode(), bodyProducer=body_producer)
        response: TxResponse = await maybe_deferred_to_future(request_d)
        with warnings.catch_warnings():
            # Silence the "abortConnection" deprecation warning, see
            # https://github.com/twisted/twisted/issues/8227
            warnings.filterwarnings(
                "ignore",
                category=DeprecationWarning,
                message=r".*does not have an abortConnection method",
            )
            d: Deferred[bytes] = readBody(response)  # type: ignore[arg-type]
        return await maybe_deferred_to_future(d)
class ContextFactoryTestCase(ContextFactoryBaseTestCase):
    """Tests for the settings-loaded context factory and ``getContext`` overrides."""

    @deferred_f_from_coro_f
    async def testPayload(self):
        # A factory loaded from default settings must be usable against the
        # local HTTPS server: the response body has to equal the body sent.
        s = "0123456789" * 10
        crawler = get_crawler()
        settings = Settings()
        client_context_factory = load_context_factory_from_settings(settings, crawler)
        body = await self.get_page(
            self.getURL("payload"), client_context_factory, body=s
        )
        self.assertEqual(body, to_bytes(s))

    def test_override_getContext(self):
        # Instantiating a subclass that overrides getContext() must emit
        # exactly one deprecation warning.
        class MyFactory(ScrapyClientContextFactory):
            def getContext(
                self, hostname: Any = None, port: Any = None
            ) -> OpenSSL.SSL.Context:
                ctx: OpenSSL.SSL.Context = super().getContext(hostname, port)
                return ctx

        with warnings.catch_warnings(record=True) as w:
            # Make sure the warning under test is always recorded, whatever
            # filters are active in the surrounding environment (e.g. an
            # "error" or "ignore" filter). The filter is limited to the
            # expected message so that unrelated warnings keep being handled
            # exactly as before and cannot inflate the count below.
            warnings.filterwarnings(
                "always",
                message=(
                    r"Overriding ScrapyClientContextFactory\.getContext\(\) "
                    r"is deprecated"
                ),
            )
            MyFactory()
            self.assertEqual(len(w), 1)
            self.assertIn(
                "Overriding ScrapyClientContextFactory.getContext() is deprecated",
                str(w[0].message),
            )
class ContextFactoryTLSMethodTestCase(ContextFactoryBaseTestCase):
    """Tests for how the client context factory ends up with its TLS method.

    Covers the ``DOWNLOADER_CLIENT_TLS_METHOD`` setting (default, ``None``,
    invalid and explicit values) as well as building the factory directly.
    """

    async def _assert_factory_works(
        self, client_context_factory: ScrapyClientContextFactory
    ) -> None:
        """Fetch the payload page with the factory and check the body sent is returned."""
        payload = "0123456789" * 10
        url = self.getURL("payload")
        received = await self.get_page(url, client_context_factory, body=payload)
        self.assertEqual(received, to_bytes(payload))

    @deferred_f_from_coro_f
    async def test_setting_default(self):
        # Default settings are expected to give SSLv23_METHOD.
        test_crawler = get_crawler()
        default_settings = Settings()
        factory = load_context_factory_from_settings(default_settings, test_crawler)
        assert factory._ssl_method == OpenSSL.SSL.SSLv23_METHOD
        await self._assert_factory_works(factory)

    def test_setting_none(self):
        # A None value for the setting is rejected with KeyError.
        test_crawler = get_crawler()
        none_settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": None})
        with pytest.raises(KeyError):
            load_context_factory_from_settings(none_settings, test_crawler)

    def test_setting_bad(self):
        # An unknown method name is rejected with KeyError as well.
        test_crawler = get_crawler()
        bad_settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": "bad"})
        with pytest.raises(KeyError):
            load_context_factory_from_settings(bad_settings, test_crawler)

    @deferred_f_from_coro_f
    async def test_setting_explicit(self):
        # "TLSv1.2" is expected to map to TLSv1_2_METHOD.
        test_crawler = get_crawler()
        tls12_settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": "TLSv1.2"})
        factory = load_context_factory_from_settings(tls12_settings, test_crawler)
        assert factory._ssl_method == OpenSSL.SSL.TLSv1_2_METHOD
        await self._assert_factory_works(factory)

    @deferred_f_from_coro_f
    async def test_direct_from_crawler(self):
        # When the factory is built straight from the crawler, the (invalid)
        # setting is expected to be ignored: the method is still SSLv23_METHOD.
        test_crawler = get_crawler(
            settings_dict={"DOWNLOADER_CLIENT_TLS_METHOD": "bad"}
        )
        factory = build_from_crawler(ScrapyClientContextFactory, test_crawler)
        assert factory._ssl_method == OpenSSL.SSL.SSLv23_METHOD
        await self._assert_factory_works(factory)

    @deferred_f_from_coro_f
    async def test_direct_init(self):
        # The method passed to the constructor is the one the factory stores.
        factory = ScrapyClientContextFactory(OpenSSL.SSL.TLSv1_2_METHOD)
        assert factory._ssl_method == OpenSSL.SSL.TLSv1_2_METHOD
        await self._assert_factory_works(factory)