Mirror of https://github.com/scrapy/scrapy.git

Enable SIM Ruff rules.

Andrey Rakhmatullin 2025-01-01 23:05:07 +05:00
parent 273620488c
commit c87354cd46
35 changed files with 128 additions and 146 deletions
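Background for the diff below: the SIM rules come from flake8-simplify and ask for more compact equivalents of a handful of common patterns, so almost every change in this commit is a mechanical rewrite of that kind. A minimal illustrative sketch of the two rewrites that dominate this diff (the rule codes are the standard Ruff ones; the snippet is not taken from the Scrapy sources):

import contextlib

condition = True

# SIM108: prefer a conditional expression over an if/else assignment.
# Before:
#     if condition:
#         value = "a"
#     else:
#         value = "b"
value = "a" if condition else "b"

# SIM105: prefer contextlib.suppress() over try/except/pass.
# Before:
#     try:
#         int("not a number")
#     except ValueError:
#         pass
with contextlib.suppress(ValueError):
    int("not a number")

print(value)  # prints "a"

Other rewrites in the diff follow the same pattern: nested with statements are merged (SIM117), nested if conditions are combined (SIM102), loops that only test membership become any() (SIM110), and dict.keys() iteration becomes iteration over the dict itself (SIM118).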

View File

@@ -254,6 +254,8 @@ extend-select = [
     "RUF",
     # flake8-bandit
     "S",
+    # flake8-simplify
+    "SIM",
     # flake8-slots
     "SLOT",
     # flake8-debugger
@@ -344,6 +346,12 @@ ignore = [
     "S321",
     # Argument default set to insecure SSL protocol
     "S503",
+    # Use capitalized environment variable
+    "SIM112",
+    # Use a context manager for opening files
+    "SIM115",
+    # Yoda condition detected
+    "SIM300",
 ]
 
 [tool.ruff.lint.per-file-ignores]
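The three SIM checks ignored above cover patterns that the codebase keeps on purpose. A rough, hypothetical illustration of what each one flags (not code from Scrapy):

import os

# SIM112: environment variables are conventionally upper-case, so the rule
# flags lookups of lower-case names like this one.
proxy = os.environ.get("http_proxy")

# SIM115: the rule wants open() used as a context manager instead of a bare
# call whose handle is closed manually.
f = open(os.devnull, "w")
f.write("example")
f.close()

# SIM300: a "Yoda condition" puts the constant on the left of the comparison.
if "value" == proxy:
    print("matched")

Scrapy reads proxy settings from the conventionally lower-case http_proxy/https_proxy variables, which is presumably why SIM112 is suppressed; the other two appear to be kept for similar pragmatic reasons.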

View File

@@ -90,12 +90,10 @@ def _get_commands_dict(
 def _pop_command_name(argv: list[str]) -> str | None:
-    i = 0
-    for arg in argv[1:]:
+    for i, arg in enumerate(argv[1:]):
         if not arg.startswith("-"):
             del argv[i]
             return arg
-        i += 1
     return None

View File

@@ -174,13 +174,12 @@ class Command(BaseRunSpiderCommand):
         display.pprint([ItemAdapter(x).asdict() for x in items], colorize=colour)
 
     def print_requests(self, lvl: int | None = None, colour: bool = True) -> None:
-        if lvl is None:
-            if self.requests:
-                requests = self.requests[max(self.requests)]
-            else:
-                requests = []
-        else:
+        if lvl is not None:
             requests = self.requests.get(lvl, [])
+        elif self.requests:
+            requests = self.requests[max(self.requests)]
+        else:
+            requests = []
 
         print("# Requests ", "-" * 65)
         display.pprint(requests, colorize=colour)

View File

@@ -95,10 +95,7 @@ class Command(ScrapyCommand):
         project_name = args[0]
-        if len(args) == 2:
-            project_dir = Path(args[1])
-        else:
-            project_dir = Path(args[0])
+        project_dir = Path(args[-1])
 
         if (project_dir / "scrapy.cfg").exists():
             self.exitcode = 1

View File

@@ -424,10 +424,7 @@ class ScrapyAgent:
         headers = TxHeaders(request.headers)
         if isinstance(agent, self._TunnelingAgent):
             headers.removeHeader(b"Proxy-Authorization")
-        if request.body:
-            bodyproducer = _RequestBodyProducer(request.body)
-        else:
-            bodyproducer = None
+        bodyproducer = _RequestBodyProducer(request.body) if request.body else None
         start_time = time()
         d: Deferred[TxResponse] = agent.request(
             method, to_bytes(url, encoding="ascii"), headers, bodyproducer

View File

@@ -291,9 +291,7 @@ class ExecutionEngine:
             return False
         if self.slot.start_requests is not None:  # not all start requests are handled
             return False
-        if self.slot.scheduler.has_pending_requests():
-            return False
-        return True
+        return not self.slot.scheduler.has_pending_requests()
 
     def crawl(self, request: Request) -> None:
         """Inject the request into the spider <-> downloader pipeline"""
@@ -388,9 +386,8 @@ class ExecutionEngine:
         )
         self.slot = Slot(start_requests, close_if_idle, nextcall, scheduler)
         self.spider = spider
-        if hasattr(scheduler, "open"):
-            if d := scheduler.open(spider):
-                yield d
+        if hasattr(scheduler, "open") and (d := scheduler.open(spider)):
+            yield d
         yield self.scraper.open_spider(spider)
         assert self.crawler.stats
         self.crawler.stats.open_spider(spider)

View File

@@ -198,10 +198,7 @@ class SpiderMiddlewareManager(MiddlewareManager):
         # chain, they went through it already from the process_spider_exception method
         recovered: MutableChain[_T] | MutableAsyncChain[_T]
         last_result_is_async = isinstance(result, AsyncIterable)
-        if last_result_is_async:
-            recovered = MutableAsyncChain()
-        else:
-            recovered = MutableChain()
+        recovered = MutableAsyncChain() if last_result_is_async else MutableChain()
 
         # There are three cases for the middleware: def foo, async def foo, def foo + async def foo_async.
         # 1. def foo. Sync iterables are passed as is, async ones are downgraded.

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import logging
 import pprint
 import signal
@@ -503,7 +504,6 @@ class CrawlerProcess(CrawlerRunner):
     def _stop_reactor(self, _: Any = None) -> None:
         from twisted.internet import reactor
 
-        try:
+        # raised if already stopped or in shutdown stage
+        with contextlib.suppress(RuntimeError):
             reactor.stop()
-        except RuntimeError:  # raised if already stopped or in shutdown stage
-            pass

View File

@@ -42,7 +42,10 @@ class HttpAuthMiddleware:
         self, request: Request, spider: Spider
     ) -> Request | Response | None:
         auth = getattr(self, "auth", None)
-        if auth and b"Authorization" not in request.headers:
-            if not self.domain or url_is_from_any_domain(request.url, [self.domain]):
-                request.headers[b"Authorization"] = auth
+        if (
+            auth
+            and b"Authorization" not in request.headers
+            and (not self.domain or url_is_from_any_domain(request.url, [self.domain]))
+        ):
+            request.headers[b"Authorization"] = auth
         return None

View File

@@ -51,10 +51,7 @@ class HttpProxyMiddleware:
         proxy_type, user, password, hostport = _parse_proxy(url)
         proxy_url = urlunparse((proxy_type or orig_type, hostport, "", "", "", ""))
 
-        if user:
-            creds = self._basic_auth_header(user, password)
-        else:
-            creds = None
+        creds = self._basic_auth_header(user, password) if user else None
 
         return creds, proxy_url

View File

@@ -81,10 +81,7 @@ class BaseItemExporter:
         include_empty = self.export_empty_fields
         if self.fields_to_export is None:
-            if include_empty:
-                field_iter = item.field_names()
-            else:
-                field_iter = item.keys()
+            field_iter = item.field_names() if include_empty else item.keys()
         elif isinstance(self.fields_to_export, Mapping):
             if include_empty:
                 field_iter = self.fields_to_export.items()

View File

@@ -6,6 +6,7 @@ See documentation in docs/topics/extensions.rst
 from __future__ import annotations
 
+import contextlib
 import logging
 import signal
 import sys
@@ -69,11 +70,9 @@ class StackTraceDump:
 class Debugger:
     def __init__(self) -> None:
-        try:
+        # win32 platforms don't support SIGUSR signals
+        with contextlib.suppress(AttributeError):
             signal.signal(signal.SIGUSR2, self._enter_debugger)  # type: ignore[attr-defined]
-        except AttributeError:
-            # win32 platforms don't support SIGUSR signals
-            pass
 
     def _enter_debugger(self, signum: int, frame: FrameType | None) -> None:
         assert frame

View File

@@ -6,6 +6,7 @@ See documentation in docs/topics/feed-exports.rst
 from __future__ import annotations
 
+import contextlib
 import logging
 import re
 import sys
@@ -642,10 +643,8 @@ class FeedExporter:
         )
         d = {}
         for k, v in conf.items():
-            try:
+            with contextlib.suppress(NotConfigured):
                 d[k] = load_object(v)
-            except NotConfigured:
-                pass
         return d
 
     def _exporter_supported(self, format: str) -> bool:

View File

@@ -89,10 +89,7 @@ class RFC2616Policy:
             return False
         cc = self._parse_cachecontrol(request)
         # obey user-agent directive "Cache-Control: no-store"
-        if b"no-store" in cc:
-            return False
-        # Any other is eligible for caching
-        return True
+        return b"no-store" not in cc
 
     def should_cache_response(self, response: Response, request: Request) -> bool:
         # What is cacheable - https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.9.1

View File

@@ -151,10 +151,7 @@ class PeriodicLog:
             return False
         if exclude and not include:
             return True
-        for p in include:
-            if p in stat_name:
-                return True
-        return False
+        return any(p in stat_name for p in include)
 
     def spider_closed(self, spider: Spider, reason: str) -> None:
         self.log()

View File

@@ -64,9 +64,8 @@ class CookieJar:
                 cookies += self.jar._cookies_for_domain(host, wreq)  # type: ignore[attr-defined]
 
         attrs = self.jar._cookie_attrs(cookies)  # type: ignore[attr-defined]
-        if attrs:
-            if not wreq.has_header("Cookie"):
-                wreq.add_unredirected_header("Cookie", "; ".join(attrs))
+        if attrs and not wreq.has_header("Cookie"):
+            wreq.add_unredirected_header("Cookie", "; ".join(attrs))
 
         self.processed += 1
         if self.processed % self.check_expired_frequency == 0:

View File

@@ -29,7 +29,7 @@ class JsonRequest(Request):
         dumps_kwargs.setdefault("sort_keys", True)
         self._dumps_kwargs: dict[str, Any] = dumps_kwargs
 
-        body_passed = kwargs.get("body", None) is not None
+        body_passed = kwargs.get("body") is not None
         data: Any = kwargs.pop("data", None)
         data_passed: bool = data is not None
@@ -61,7 +61,7 @@ class JsonRequest(Request):
     def replace(
         self, *args: Any, cls: type[Request] | None = None, **kwargs: Any
     ) -> Request:
-        body_passed = kwargs.get("body", None) is not None
+        body_passed = kwargs.get("body") is not None
         data: Any = kwargs.pop("data", None)
         data_passed: bool = data is not None

View File

@@ -41,9 +41,12 @@ _collect_string_content = etree.XPath("string()")
 def _nons(tag: Any) -> Any:
-    if isinstance(tag, str):
-        if tag[0] == "{" and tag[1 : len(XHTML_NAMESPACE) + 1] == XHTML_NAMESPACE:
-            return tag.split("}")[-1]
+    if (
+        isinstance(tag, str)
+        and tag[0] == "{"
+        and tag[1 : len(XHTML_NAMESPACE) + 1] == XHTML_NAMESPACE
+    ):
+        return tag.split("}")[-1]
     return tag
@@ -230,9 +233,7 @@ class LxmlLinkExtractor:
             parsed_url, self.deny_extensions
         ):
             return False
-        if self.restrict_text and not _matches(link.text, self.restrict_text):
-            return False
-        return True
+        return not self.restrict_text or _matches(link.text, self.restrict_text)
 
     def matches(self, url: str) -> bool:
         if self.allow_domains and not url_is_from_any_domain(url, self.allow_domains):

View File

@@ -111,11 +111,9 @@ class MailSender:
     ) -> Deferred[None] | None:
         from twisted.internet import reactor
 
-        msg: MIMEBase
-        if attachs:
-            msg = MIMEMultipart()
-        else:
-            msg = MIMENonMultipart(*mimetype.split("/", 1))
+        msg: MIMEBase = (
+            MIMEMultipart() if attachs else MIMENonMultipart(*mimetype.split("/", 1))
+        )
 
         to = list(arg_to_iter(to))
         cc = list(arg_to_iter(cc))

View File

@@ -553,10 +553,8 @@ class FilesPipeline(MediaPipeline):
         ftp_store.USE_ACTIVE_MODE = settings.getbool("FEED_STORAGE_FTP_ACTIVE")
 
     def _get_store(self, uri: str) -> FilesStoreProtocol:
-        if Path(uri).is_absolute():  # to support win32 paths like: C:\\some\dir
-            scheme = "file"
-        else:
-            scheme = urlparse(uri).scheme
+        # to support win32 paths like: C:\\some\dir
+        scheme = "file" if Path(uri).is_absolute() else urlparse(uri).scheme
         store_cls = self.STORE_SCHEMES[scheme]
         return store_cls(uri)

View File

@@ -6,6 +6,7 @@ See documentation in docs/topics/shell.rst
 from __future__ import annotations
 
+import contextlib
 import os
 import signal
 from typing import TYPE_CHECKING, Any
@@ -143,12 +144,10 @@ class Shell:
         else:
             request.meta["handle_httpstatus_all"] = True
         response = None
-        try:
+        with contextlib.suppress(IgnoreRequest):
             response, spider = threads.blockingCallFromThread(
                 reactor, self._schedule, request, spider
             )
-        except IgnoreRequest:
-            pass
         self.populate_vars(response, request, spider)
 
     def populate_vars(

View File

@@ -360,11 +360,10 @@ class RefererMiddleware:
         - otherwise, the policy from settings is used.
         """
         policy_name = request.meta.get("referrer_policy")
-        if policy_name is None:
-            if isinstance(resp_or_url, Response):
-                policy_header = resp_or_url.headers.get("Referrer-Policy")
-                if policy_header is not None:
-                    policy_name = to_unicode(policy_header.decode("latin1"))
+        if policy_name is None and isinstance(resp_or_url, Response):
+            policy_header = resp_or_url.headers.get("Referrer-Policy")
+            if policy_header is not None:
+                policy_name = to_unicode(policy_header.decode("latin1"))
         if policy_name is None:
             return self.default_policy()

View File

@@ -1,3 +1,4 @@
+import contextlib
 import zlib
 from io import BytesIO
 from warnings import warn
@@ -37,10 +38,8 @@ else:
         return decompressor.process(data)
 
 
-try:
+with contextlib.suppress(ImportError):
     import zstandard
-except ImportError:
-    pass
 
 
 _CHUNK_SIZE = 65536  # 64 KiB

View File

@@ -8,6 +8,7 @@ This module must not depend on any module outside the Standard Library.
 from __future__ import annotations
 
 import collections
+import contextlib
 import warnings
 import weakref
 from collections import OrderedDict
@@ -173,10 +174,9 @@ class LocalWeakReferencedCache(weakref.WeakKeyDictionary):
         self.data: LocalCache = LocalCache(limit=limit)
 
     def __setitem__(self, key: _KT, value: _VT) -> None:
-        try:
+        # if raised, key is not weak-referenceable, skip caching
+        with contextlib.suppress(TypeError):
             super().__setitem__(key, value)
-        except TypeError:
-            pass  # key is not weak-referenceable, skip caching
 
     def __getitem__(self, key: _KT) -> _VT | None:  # type: ignore[override]
         try:

View File

@@ -36,7 +36,7 @@ def send_catch_log(
     dont_log = named.pop("dont_log", ())
     dont_log = tuple(dont_log) if isinstance(dont_log, Sequence) else (dont_log,)
    dont_log += (StopDownload,)
-    spider = named.get("spider", None)
+    spider = named.get("spider")
     responses: list[tuple[TypingAny, TypingAny]] = []
     for receiver in liveReceivers(getAllReceivers(sender, signal)):
         result: TypingAny
@@ -88,7 +88,7 @@ def send_catch_log_deferred(
         return failure
 
     dont_log = named.pop("dont_log", None)
-    spider = named.get("spider", None)
+    spider = named.get("spider")
     dfds: list[Deferred[tuple[TypingAny, TypingAny]]] = []
     for receiver in liveReceivers(getAllReceivers(sender, signal)):
         d: Deferred[TypingAny] = maybeDeferred_coro(

View File

@@ -173,13 +173,19 @@ def strip_url(
         parsed_url.username or parsed_url.password
     ):
         netloc = netloc.split("@")[-1]
-    if strip_default_port and parsed_url.port:
-        if (parsed_url.scheme, parsed_url.port) in (
+    if (
+        strip_default_port
+        and parsed_url.port
+        and (parsed_url.scheme, parsed_url.port)
+        in (
             ("http", 80),
             ("https", 443),
             ("ftp", 21),
-        ):
-            netloc = netloc.replace(f":{parsed_url.port}", "")
+        )
+    ):
+        netloc = netloc.replace(f":{parsed_url.port}", "")
     return urlunparse(
         (
             parsed_url.scheme,

View File

@@ -166,19 +166,21 @@ class AddonManagerTest(unittest.TestCase):
             def update_settings(self, settings):
                 pass
 
-        with patch("scrapy.addons.logger") as logger_mock:
-            with patch("scrapy.addons.build_from_crawler") as build_from_crawler_mock:
-                settings_dict = {
-                    "ADDONS": {LoggedAddon: 1},
-                }
-                addon = LoggedAddon()
-                build_from_crawler_mock.return_value = addon
-                crawler = get_crawler(settings_dict=settings_dict)
-                logger_mock.info.assert_called_once_with(
-                    "Enabled addons:\n%(addons)s",
-                    {"addons": [addon]},
-                    extra={"crawler": crawler},
-                )
+        with (
+            patch("scrapy.addons.logger") as logger_mock,
+            patch("scrapy.addons.build_from_crawler") as build_from_crawler_mock,
+        ):
+            settings_dict = {
+                "ADDONS": {LoggedAddon: 1},
+            }
+            addon = LoggedAddon()
+            build_from_crawler_mock.return_value = addon
+            crawler = get_crawler(settings_dict=settings_dict)
+            logger_mock.info.assert_called_once_with(
+                "Enabled addons:\n%(addons)s",
+                {"addons": [addon]},
+                extra={"crawler": crawler},
+            )
 
     @inlineCallbacks
     def test_enable_addon_in_spider(self):

View File

@@ -530,9 +530,10 @@ class Http11TestCase(HttpTestCase):
         d = self.download_request(request, Spider("foo"))
 
         def checkDataLoss(failure):
-            if failure.check(ResponseFailed):
-                if any(r.check(_DataLoss) for r in failure.value.reasons):
-                    return None
+            if failure.check(ResponseFailed) and any(
+                r.check(_DataLoss) for r in failure.value.reasons
+            ):
+                return None
             return failure
 
         d.addCallback(lambda _: self.fail("No DataLoss exception"))

View File

@@ -756,7 +756,7 @@ class FeedExportTest(FeedExportTestBase):
             )
         finally:
-            for file_path in FEEDS.keys():
+            for file_path in FEEDS:
                 if not Path(file_path).exists():
                     continue
@@ -1229,15 +1229,13 @@ class FeedExportTest(FeedExportTestBase):
         class CustomFilter2(scrapy.extensions.feedexport.ItemFilter):
             def accepts(self, item):
-                if "foo" not in item.fields:
-                    return False
-                return True
+                return "foo" in item.fields
 
         class CustomFilter3(scrapy.extensions.feedexport.ItemFilter):
             def accepts(self, item):
-                if isinstance(item, tuple(self.item_classes)) and item["foo"] == "bar1":
-                    return True
-                return False
+                return (
+                    isinstance(item, tuple(self.item_classes)) and item["foo"] == "bar1"
+                )
 
         formats = {
             "json": b'[\n{"foo": "bar1", "egg": "spam1"}\n]',

View File

@@ -1488,10 +1488,7 @@ def _buildresponse(body, **kwargs):
 def _qs(req, encoding="utf-8", to_unicode=False):
-    if req.method == "POST":
-        qs = req.body
-    else:
-        qs = req.url.partition("?")[2]
+    qs = req.body if req.method == "POST" else req.url.partition("?")[2]
     uqs = unquote_to_bytes(qs)
     if to_unicode:
         uqs = uqs.decode(encoding)

View File

@@ -634,19 +634,21 @@ class TestGCSFilesStore(unittest.TestCase):
             import google.cloud.storage  # noqa: F401
         except ModuleNotFoundError:
             raise unittest.SkipTest("google-cloud-storage is not installed")
-        with mock.patch("google.cloud.storage") as _:
-            with mock.patch("scrapy.pipelines.files.time") as _:
-                uri = "gs://my_bucket/my_prefix/"
-                store = GCSFilesStore(uri)
-                store.bucket = mock.Mock()
-                path = "full/my_data.txt"
-                yield store.persist_file(
-                    path, mock.Mock(), info=None, meta=None, headers=None
-                )
-                yield store.stat_file(path, info=None)
-                expected_blob_path = store.prefix + path
-                store.bucket.blob.assert_called_with(expected_blob_path)
-                store.bucket.get_blob.assert_called_with(expected_blob_path)
+        with (
+            mock.patch("google.cloud.storage"),
+            mock.patch("scrapy.pipelines.files.time"),
+        ):
+            uri = "gs://my_bucket/my_prefix/"
+            store = GCSFilesStore(uri)
+            store.bucket = mock.Mock()
+            path = "full/my_data.txt"
+            yield store.persist_file(
+                path, mock.Mock(), info=None, meta=None, headers=None
+            )
+            yield store.stat_file(path, info=None)
+            expected_blob_path = store.prefix + path
+            store.bucket.blob.assert_called_with(expected_blob_path)
+            store.bucket.get_blob.assert_called_with(expected_blob_path)
 
 
 class TestFTPFileStore(unittest.TestCase):

View File

@@ -170,7 +170,7 @@ class BaseSettingsTest(unittest.TestCase):
         self.assertCountEqual(self.settings.attributes.keys(), ctrl_attributes.keys())
 
-        for key in ctrl_attributes.keys():
+        for key in ctrl_attributes:
             attr = self.settings.attributes[key]
             ctrl_attr = ctrl_attributes[key]
             self.assertEqual(attr.value, ctrl_attr.value)

View File

@@ -1,3 +1,4 @@
+import contextlib
 import shutil
 import sys
 import tempfile
@@ -22,10 +23,8 @@ module_dir = Path(__file__).resolve().parent
 def _copytree(source: Path, target: Path):
-    try:
+    with contextlib.suppress(shutil.Error):
         shutil.copytree(source, target)
-    except shutil.Error:
-        pass
 
 
 class SpiderLoaderTest(unittest.TestCase):

View File

@@ -259,12 +259,14 @@ class WarnWhenSubclassedTest(unittest.TestCase):
         self.assertIn("foo.Bar", str(w[1].message))
 
     def test_inspect_stack(self):
-        with mock.patch("inspect.stack", side_effect=IndexError):
-            with warnings.catch_warnings(record=True) as w:
-                DeprecatedName = create_deprecated_class("DeprecatedName", NewName)
+        with (
+            mock.patch("inspect.stack", side_effect=IndexError),
+            warnings.catch_warnings(record=True) as w,
+        ):
+            DeprecatedName = create_deprecated_class("DeprecatedName", NewName)
 
-                class SubClass(DeprecatedName):
-                    pass
+            class SubClass(DeprecatedName):
+                pass
 
         self.assertIn("Error detecting parent module", str(w[0].message))

View File

@@ -366,7 +366,7 @@ class UtilsCsvTestCase(unittest.TestCase):
         # explicit type check cuz' we no like stinkin' autocasting! yarrr
         for result_row in result:
-            self.assertTrue(all(isinstance(k, str) for k in result_row.keys()))
+            self.assertTrue(all(isinstance(k, str) for k in result_row))
             self.assertTrue(all(isinstance(v, str) for v in result_row.values()))
 
     def test_csviter_delimiter(self):