from __future__ import annotations
import json
import random
import re
import shutil
import string
from ipaddress import IPv4Address
from pathlib import Path
from tempfile import mkdtemp
from typing import TYPE_CHECKING
from unittest import mock, skipIf
from urllib.parse import urlencode
from twisted.internet import reactor
from twisted.internet.defer import (
CancelledError,
Deferred,
DeferredList,
inlineCallbacks,
)
from twisted.internet.endpoints import SSL4ClientEndpoint, SSL4ServerEndpoint
from twisted.internet.error import TimeoutError
from twisted.internet.ssl import Certificate, PrivateCertificate, optionsForClientTLS
from twisted.trial.unittest import TestCase
from twisted.web.client import URI, ResponseFailed
from twisted.web.http import H2_ENABLED
from twisted.web.http import Request as TxRequest
from twisted.web.server import NOT_DONE_YET, Site
from twisted.web.static import File
from scrapy.http import JsonRequest, Request, Response
from scrapy.settings import Settings
from scrapy.spiders import Spider
from tests.mockserver import LeafResource, Status, ssl_context_factory
if TYPE_CHECKING:
from twisted.python.failure import Failure
def generate_random_string(size):
return "".join(random.choices(string.ascii_uppercase + string.digits, k=size))
def make_html_body(val):
response = f"""<html>
<h1>Hello from HTTP2</h1>
<p>{val}</p>
</html>"""
return bytes(response, "utf-8")
class DummySpider(Spider):
name = "dummy"
start_urls: list = []
def parse(self, response):
print(response)
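# Payload fixtures shared by the mock-server resources below and the test assertions.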
class Data:
SMALL_SIZE = 1024 # 1 KB
LARGE_SIZE = 1024**2 # 1 MB
STR_SMALL = generate_random_string(SMALL_SIZE)
STR_LARGE = generate_random_string(LARGE_SIZE)
EXTRA_SMALL = generate_random_string(1024 * 15)
EXTRA_LARGE = generate_random_string((1024**2) * 15)
HTML_SMALL = make_html_body(STR_SMALL)
HTML_LARGE = make_html_body(STR_LARGE)
JSON_SMALL = {"data": STR_SMALL}
JSON_LARGE = {"data": STR_LARGE}
DATALOSS = b"Dataloss Content"
NO_CONTENT_LENGTH = b"This response do not have any content-length header"
class GetDataHtmlSmall(LeafResource):
def render_GET(self, request: TxRequest):
request.setHeader("Content-Type", "text/html; charset=UTF-8")
return Data.HTML_SMALL
class GetDataHtmlLarge(LeafResource):
def render_GET(self, request: TxRequest):
request.setHeader("Content-Type", "text/html; charset=UTF-8")
return Data.HTML_LARGE
class PostDataJsonMixin:
@staticmethod
def make_response(request: TxRequest, extra_data: str):
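        """Echo the request's headers and JSON body plus server-side extra data."""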
assert request.content is not None
response = {
"request-headers": {},
"request-body": json.loads(request.content.read()),
"extra-data": extra_data,
}
for k, v in request.requestHeaders.getAllRawHeaders():
response["request-headers"][str(k, "utf-8")] = str(v[0], "utf-8")
response_bytes = bytes(json.dumps(response), "utf-8")
request.setHeader("Content-Type", "application/json; charset=UTF-8")
request.setHeader("Content-Encoding", "UTF-8")
return response_bytes
class PostDataJsonSmall(LeafResource, PostDataJsonMixin):
def render_POST(self, request: TxRequest):
return self.make_response(request, Data.EXTRA_SMALL)
class PostDataJsonLarge(LeafResource, PostDataJsonMixin):
def render_POST(self, request: TxRequest):
return self.make_response(request, Data.EXTRA_LARGE)
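# Advertises Content-Length: 1024 but writes only 16 bytes before finishing,
# so the client should treat the response as truncated.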
class Dataloss(LeafResource):
def render_GET(self, request: TxRequest):
request.setHeader(b"Content-Length", b"1024")
self.deferRequest(request, 0, self._delayed_render, request)
return NOT_DONE_YET
@staticmethod
def _delayed_render(request: TxRequest):
request.write(Data.DATALOSS)
request.finish()
class NoContentLengthHeader(LeafResource):
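    """Streams the body via a deferred write so the response is sent without
    a Content-Length header."""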
def render_GET(self, request: TxRequest):
request.requestHeaders.removeHeader("Content-Length")
self.deferRequest(request, 0, self._delayed_render, request)
return NOT_DONE_YET
@staticmethod
def _delayed_render(request: TxRequest):
request.write(Data.NO_CONTENT_LENGTH)
request.finish()
class TimeoutResponse(LeafResource):
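    """Never finishes the request, leaving the client connection idle."""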
def render_GET(self, request: TxRequest):
return NOT_DONE_YET
class QueryParams(LeafResource):
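    """Echoes the query parameters back as a JSON body."""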
def render_GET(self, request: TxRequest):
request.setHeader("Content-Type", "application/json; charset=UTF-8")
request.setHeader("Content-Encoding", "UTF-8")
query_params: dict[str, str] = {}
assert request.args is not None
for k, v in request.args.items():
query_params[str(k, "utf-8")] = str(v[0], "utf-8")
return bytes(json.dumps(query_params), "utf-8")
class RequestHeaders(LeafResource):
"""Sends all the headers received as a response"""
def render_GET(self, request: TxRequest):
request.setHeader("Content-Type", "application/json; charset=UTF-8")
request.setHeader("Content-Encoding", "UTF-8")
headers = {}
for k, v in request.requestHeaders.getAllRawHeaders():
headers[str(k, "utf-8")] = str(v[0], "utf-8")
return bytes(json.dumps(headers), "utf-8")
def get_client_certificate(
key_file: Path, certificate_file: Path
) -> PrivateCertificate:
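    # Concatenate the key and certificate PEM blocks so loadPEM can build a
    # PrivateCertificate from a single buffer.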
pem = key_file.read_text(encoding="utf-8") + certificate_file.read_text(
encoding="utf-8"
)
return PrivateCertificate.loadPEM(pem)
@skipIf(not H2_ENABLED, "HTTP/2 support in Twisted is not enabled")
class Https2ClientProtocolTestCase(TestCase):
scheme = "https"
key_file = Path(__file__).parent / "keys" / "localhost.key"
certificate_file = Path(__file__).parent / "keys" / "localhost.crt"
def _init_resource(self):
self.temp_directory = mkdtemp()
r = File(self.temp_directory)
r.putChild(b"get-data-html-small", GetDataHtmlSmall())
r.putChild(b"get-data-html-large", GetDataHtmlLarge())
r.putChild(b"post-data-json-small", PostDataJsonSmall())
r.putChild(b"post-data-json-large", PostDataJsonLarge())
r.putChild(b"dataloss", Dataloss())
r.putChild(b"no-content-length-header", NoContentLengthHeader())
r.putChild(b"status", Status())
r.putChild(b"query-params", QueryParams())
r.putChild(b"timeout", TimeoutResponse())
r.putChild(b"request-headers", RequestHeaders())
return r
@inlineCallbacks
def setUp(self):
# Initialize resource tree
root = self._init_resource()
self.site = Site(root, timeout=None)
# Start server for testing
self.hostname = "localhost"
context_factory = ssl_context_factory(
str(self.key_file), str(self.certificate_file)
)
server_endpoint = SSL4ServerEndpoint(
reactor, 0, context_factory, interface=self.hostname
)
self.server = yield server_endpoint.listen(self.site)
self.port_number = self.server.getHost().port
# Connect H2 client with server
self.client_certificate = get_client_certificate(
self.key_file, self.certificate_file
)
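        # Use the server's own certificate as the trust root and restrict ALPN
        # to h2, so the connection fails unless HTTP/2 is negotiated.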
client_options = optionsForClientTLS(
hostname=self.hostname,
trustRoot=self.client_certificate,
acceptableProtocols=[b"h2"],
)
uri = URI.fromBytes(bytes(self.get_url("/"), "utf-8"))
self.conn_closed_deferred = Deferred()
from scrapy.core.http2.protocol import H2ClientFactory
h2_client_factory = H2ClientFactory(uri, Settings(), self.conn_closed_deferred)
client_endpoint = SSL4ClientEndpoint(
reactor, self.hostname, self.port_number, client_options
)
self.client = yield client_endpoint.connect(h2_client_factory)
@inlineCallbacks
def tearDown(self):
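        # Shut the client down first, then the server, then remove the temp docroot.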
if self.client.connected:
yield self.client.transport.loseConnection()
yield self.client.transport.abortConnection()
yield self.server.stopListening()
shutil.rmtree(self.temp_directory)
self.conn_closed_deferred = None
def get_url(self, path):
"""
:param path: Should have / at the starting compulsorily if not empty
:return: Complete url
"""
assert len(path) > 0 and (path[0] == "/" or path[0] == "&")
return f"{self.scheme}://{self.hostname}:{self.port_number}{path}"
def make_request(self, request: Request) -> Deferred:
return self.client.request(request, DummySpider())
@staticmethod
def _check_repeat(get_deferred, count):
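        # Fire `count` copies of the request concurrently; HTTP/2 multiplexes
        # them as streams over the single client connection.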
d_list = []
for _ in range(count):
d = get_deferred()
d_list.append(d)
return DeferredList(d_list, fireOnOneErrback=True)
def _check_GET(self, request: Request, expected_body, expected_status):
def check_response(response: Response):
self.assertEqual(response.status, expected_status)
self.assertEqual(response.body, expected_body)
self.assertEqual(response.request, request)
content_length_header = response.headers.get("Content-Length")
assert content_length_header is not None
content_length = int(content_length_header)
self.assertEqual(len(response.body), content_length)
d = self.make_request(request)
d.addCallback(check_response)
d.addErrback(self.fail)
return d
def test_GET_small_body(self):
request = Request(self.get_url("/get-data-html-small"))
return self._check_GET(request, Data.HTML_SMALL, 200)
def test_GET_large_body(self):
request = Request(self.get_url("/get-data-html-large"))
return self._check_GET(request, Data.HTML_LARGE, 200)
def _check_GET_x10(self, *args, **kwargs):
def get_deferred():
return self._check_GET(*args, **kwargs)
return self._check_repeat(get_deferred, 10)
def test_GET_small_body_x10(self):
return self._check_GET_x10(
Request(self.get_url("/get-data-html-small")), Data.HTML_SMALL, 200
)
def test_GET_large_body_x10(self):
return self._check_GET_x10(
Request(self.get_url("/get-data-html-large")), Data.HTML_LARGE, 200
)
def _check_POST_json(
self,
request: Request,
expected_request_body,
expected_extra_data,
expected_status: int,
):
d = self.make_request(request)
def assert_response(response: Response):
self.assertEqual(response.status, expected_status)
self.assertEqual(response.request, request)
content_length_header = response.headers.get("Content-Length")
assert content_length_header is not None
content_length = int(content_length_header)
self.assertEqual(len(response.body), content_length)
# Parse the body
content_encoding_header = response.headers[b"Content-Encoding"]
assert content_encoding_header is not None
content_encoding = str(content_encoding_header, "utf-8")
body = json.loads(str(response.body, content_encoding))
self.assertIn("request-body", body)
self.assertIn("extra-data", body)
self.assertIn("request-headers", body)
request_body = body["request-body"]
self.assertEqual(request_body, expected_request_body)
extra_data = body["extra-data"]
self.assertEqual(extra_data, expected_extra_data)
# Check if headers were sent successfully
request_headers = body["request-headers"]
for k, v in request.headers.items():
k_str = str(k, "utf-8")
self.assertIn(k_str, request_headers)
self.assertEqual(request_headers[k_str], str(v[0], "utf-8"))
d.addCallback(assert_response)
d.addErrback(self.fail)
return d
def test_POST_small_json(self):
request = JsonRequest(
url=self.get_url("/post-data-json-small"),
method="POST",
data=Data.JSON_SMALL,
)
return self._check_POST_json(request, Data.JSON_SMALL, Data.EXTRA_SMALL, 200)
def test_POST_large_json(self):
request = JsonRequest(
url=self.get_url("/post-data-json-large"),
method="POST",
data=Data.JSON_LARGE,
)
return self._check_POST_json(request, Data.JSON_LARGE, Data.EXTRA_LARGE, 200)
def _check_POST_json_x10(self, *args, **kwargs):
def get_deferred():
return self._check_POST_json(*args, **kwargs)
return self._check_repeat(get_deferred, 10)
def test_POST_small_json_x10(self):
request = JsonRequest(
url=self.get_url("/post-data-json-small"),
method="POST",
data=Data.JSON_SMALL,
)
return self._check_POST_json_x10(
request, Data.JSON_SMALL, Data.EXTRA_SMALL, 200
)
def test_POST_large_json_x10(self):
request = JsonRequest(
url=self.get_url("/post-data-json-large"),
method="POST",
data=Data.JSON_LARGE,
)
return self._check_POST_json_x10(
request, Data.JSON_LARGE, Data.EXTRA_LARGE, 200
)
@inlineCallbacks
def test_invalid_negotiated_protocol(self):
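        # Patch the expected ALPN protocol name so the negotiation check fails.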
        with mock.patch(
            "scrapy.core.http2.protocol.PROTOCOL_NAME", b"not-h2"
        ):
request = Request(url=self.get_url("/status?n=200"))
with self.assertRaises(ResponseFailed):
yield self.make_request(request)
def test_cancel_request(self):
request = Request(url=self.get_url("/get-data-html-large"))
def assert_response(response: Response):
self.assertEqual(response.status, 499)
self.assertEqual(response.request, request)
d = self.make_request(request)
d.addCallback(assert_response)
d.addErrback(self.fail)
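        # Cancelling the in-flight request should resolve the deferred with a
        # synthetic 499 response instead of an errback.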
d.cancel()
return d
def test_download_maxsize_exceeded(self):
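        # A response body larger than meta["download_maxsize"] must cancel the download.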
request = Request(
url=self.get_url("/get-data-html-large"), meta={"download_maxsize": 1000}
)
def assert_cancelled_error(failure):
self.assertIsInstance(failure.value, CancelledError)
error_pattern = re.compile(
rf"Cancelling download of {request.url}: received response "
rf"size \(\d*\) larger than download max size \(1000\)"
)
self.assertEqual(len(re.findall(error_pattern, str(failure.value))), 1)
d = self.make_request(request)
d.addCallback(self.fail)
d.addErrback(assert_cancelled_error)
return d
def test_received_dataloss_response(self):
"""In case when value of Header Content-Length != len(Received Data)
ProtocolError is raised"""
request = Request(url=self.get_url("/dataloss"))
def assert_failure(failure: Failure):
self.assertTrue(len(failure.value.reasons) > 0)
from h2.exceptions import InvalidBodyLengthError
self.assertTrue(
any(
isinstance(error, InvalidBodyLengthError)
for error in failure.value.reasons
)
)
d = self.make_request(request)
d.addCallback(self.fail)
d.addErrback(assert_failure)
return d
def test_missing_content_length_header(self):
request = Request(url=self.get_url("/no-content-length-header"))
def assert_content_length(response: Response):
self.assertEqual(response.status, 200)
self.assertEqual(response.body, Data.NO_CONTENT_LENGTH)
self.assertEqual(response.request, request)
self.assertNotIn("Content-Length", response.headers)
d = self.make_request(request)
d.addCallback(assert_content_length)
d.addErrback(self.fail)
return d
@inlineCallbacks
def _check_log_warnsize(self, request, warn_pattern, expected_body):
with self.assertLogs("scrapy.core.http2.stream", level="WARNING") as cm:
response = yield self.make_request(request)
self.assertEqual(response.status, 200)
self.assertEqual(response.request, request)
self.assertEqual(response.body, expected_body)
# Check the warning is raised only once for this request
self.assertEqual(
sum(len(re.findall(warn_pattern, log)) for log in cm.output), 1
)
@inlineCallbacks
def test_log_expected_warnsize(self):
request = Request(
url=self.get_url("/get-data-html-large"), meta={"download_warnsize": 1000}
)
warn_pattern = re.compile(
rf"Expected response size \(\d*\) larger than "
rf"download warn size \(1000\) in request {request}"
)
yield self._check_log_warnsize(request, warn_pattern, Data.HTML_LARGE)
@inlineCallbacks
def test_log_received_warnsize(self):
request = Request(
url=self.get_url("/no-content-length-header"),
meta={"download_warnsize": 10},
)
warn_pattern = re.compile(
rf"Received more \(\d*\) bytes than download "
rf"warn size \(10\) in request {request}"
)
yield self._check_log_warnsize(request, warn_pattern, Data.NO_CONTENT_LENGTH)
def test_max_concurrent_streams(self):
"""Send 500 requests at one to check if we can handle
very large number of request.
"""
def get_deferred():
return self._check_GET(
Request(self.get_url("/get-data-html-small")), Data.HTML_SMALL, 200
)
return self._check_repeat(get_deferred, 500)
def test_inactive_stream(self):
"""Here we send 110 requests considering the MAX_CONCURRENT_STREAMS
by default is 100. After sending the first 100 requests we close the
connection."""
d_list = []
def assert_inactive_stream(failure):
self.assertIsNotNone(failure.check(ResponseFailed))
from scrapy.core.http2.stream import InactiveStreamClosed
self.assertTrue(
any(isinstance(e, InactiveStreamClosed) for e in failure.value.reasons)
)
        # Send 100 requests (we do not check the results)
for _ in range(100):
d = self.make_request(Request(self.get_url("/get-data-html-small")))
d.addBoth(lambda _: None)
d_list.append(d)
        # Now send 10 extra requests and save the response deferreds in a list
for _ in range(10):
d = self.make_request(Request(self.get_url("/get-data-html-small")))
d.addCallback(self.fail)
d.addErrback(assert_inactive_stream)
d_list.append(d)
        # Close the connection now so all 10 extra requests errback
        # with InactiveStreamClosed
self.client.transport.loseConnection()
return DeferredList(d_list, consumeErrors=True, fireOnOneErrback=True)
def test_invalid_request_type(self):
with self.assertRaises(TypeError):
self.make_request("https://InvalidDataTypePassed.com")
def test_query_parameters(self):
params = {
"a": generate_random_string(20),
"b": generate_random_string(20),
"c": generate_random_string(20),
"d": generate_random_string(20),
}
request = Request(self.get_url(f"/query-params?{urlencode(params)}"))
def assert_query_params(response: Response):
content_encoding_header = response.headers[b"Content-Encoding"]
assert content_encoding_header is not None
content_encoding = str(content_encoding_header, "utf-8")
data = json.loads(str(response.body, content_encoding))
self.assertEqual(data, params)
d = self.make_request(request)
d.addCallback(assert_query_params)
d.addErrback(self.fail)
return d
def test_status_codes(self):
def assert_response_status(response: Response, expected_status: int):
self.assertEqual(response.status, expected_status)
d_list = []
for status in [200, 404]:
request = Request(self.get_url(f"/status?n={status}"))
d = self.make_request(request)
d.addCallback(assert_response_status, status)
d.addErrback(self.fail)
d_list.append(d)
return DeferredList(d_list, fireOnOneErrback=True)
def test_response_has_correct_certificate_ip_address(self):
request = Request(self.get_url("/status?n=200"))
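        # The server presents the same certificate that the client loaded as its
        # trust root, so issuer and public key must match.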
def assert_metadata(response: Response):
self.assertEqual(response.request, request)
self.assertIsInstance(response.certificate, Certificate)
assert response.certificate # typing
self.assertIsNotNone(response.certificate.original)
self.assertEqual(
response.certificate.getIssuer(), self.client_certificate.getIssuer()
)
self.assertTrue(
response.certificate.getPublicKey().matches(
self.client_certificate.getPublicKey()
)
)
self.assertIsInstance(response.ip_address, IPv4Address)
self.assertEqual(str(response.ip_address), "127.0.0.1")
d = self.make_request(request)
d.addCallback(assert_metadata)
d.addErrback(self.fail)
return d
def _check_invalid_netloc(self, url):
request = Request(url)
def assert_invalid_hostname(failure: Failure):
from scrapy.core.http2.stream import InvalidHostname
self.assertIsNotNone(failure.check(InvalidHostname))
error_msg = str(failure.value)
self.assertIn("localhost", error_msg)
self.assertIn("127.0.0.1", error_msg)
self.assertIn(str(request), error_msg)
d = self.make_request(request)
d.addCallback(self.fail)
d.addErrback(assert_invalid_hostname)
return d
def test_invalid_hostname(self):
return self._check_invalid_netloc("https://notlocalhost.notlocalhostdomain")
def test_invalid_host_port(self):
port = self.port_number + 1
return self._check_invalid_netloc(f"https://127.0.0.1:{port}")
def test_connection_stays_with_invalid_requests(self):
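        # Streams that fail due to an invalid netloc must not break the shared
        # connection; subsequent valid requests still succeed.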
d_list = [
self.test_invalid_hostname(),
self.test_invalid_host_port(),
self.test_GET_small_body(),
self.test_POST_small_json(),
]
return DeferredList(d_list, fireOnOneErrback=True)
def test_connection_timeout(self):
request = Request(self.get_url("/timeout"))
d = self.make_request(request)
# Update the timer to 1s to test connection timeout
self.client.setTimeout(1)
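        # The aggregated failure carries several reasons; at least one must be
        # the idle TimeoutError raised by H2ClientProtocol.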
def assert_timeout_error(failure: Failure):
for err in failure.value.reasons:
from scrapy.core.http2.protocol import H2ClientProtocol
if isinstance(err, TimeoutError):
self.assertIn(
f"Connection was IDLE for more than {H2ClientProtocol.IDLE_TIMEOUT}s",
str(err),
)
break
else:
self.fail()
d.addCallback(self.fail)
d.addErrback(assert_timeout_error)
return d
def test_request_headers_received(self):
request = Request(
self.get_url("/request-headers"),
headers={"header-1": "header value 1", "header-2": "header value 2"},
)
d = self.make_request(request)
def assert_request_headers(response: Response):
self.assertEqual(response.status, 200)
self.assertEqual(response.request, request)
response_headers = json.loads(str(response.body, "utf-8"))
self.assertIsInstance(response_headers, dict)
for k, v in request.headers.items():
k, v = str(k, "utf-8"), str(v[0], "utf-8")
self.assertIn(k, response_headers)
self.assertEqual(v, response_headers[k])
        d.addCallback(assert_request_headers)
        d.addErrback(self.fail)
return d