Stars: 1 · Forks: 0
Mirror of https://github.com/scrapy/scrapy.git — synced 2025-02-14 14:05:01 +00:00

Allow updating flags in follow and follow_all (#4279)

This commit is contained in:
Abhishek Pratap Singh 2020-02-10 18:48:31 +00:00 committed by GitHub
parent b0eaf114e5
commit 4626e90df8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 4 deletions

View File

@ -107,7 +107,7 @@ class Response(object_ref):
def follow(self, url, callback=None, method='GET', headers=None, body=None,
cookies=None, meta=None, encoding='utf-8', priority=0,
dont_filter=False, errback=None, cb_kwargs=None):
dont_filter=False, errback=None, cb_kwargs=None, flags=None):
# type: (...) -> Request
"""
Return a :class:`~.Request` instance to follow a link ``url``.
@ -124,6 +124,7 @@ class Response(object_ref):
elif url is None:
raise ValueError("url can't be None")
url = self.urljoin(url)
return Request(
url=url,
callback=callback,
@ -137,11 +138,12 @@ class Response(object_ref):
dont_filter=dont_filter,
errback=errback,
cb_kwargs=cb_kwargs,
flags=flags,
)
def follow_all(self, urls, callback=None, method='GET', headers=None, body=None,
cookies=None, meta=None, encoding='utf-8', priority=0,
dont_filter=False, errback=None, cb_kwargs=None):
dont_filter=False, errback=None, cb_kwargs=None, flags=None):
# type: (...) -> Generator[Request, None, None]
"""
Return an iterable of :class:`~.Request` instances to follow all links
@ -169,6 +171,7 @@ class Response(object_ref):
dont_filter=dont_filter,
errback=errback,
cb_kwargs=cb_kwargs,
flags=flags,
)
for url in urls
)

View File

@ -121,7 +121,7 @@ class TextResponse(Response):
def follow(self, url, callback=None, method='GET', headers=None, body=None,
cookies=None, meta=None, encoding=None, priority=0,
dont_filter=False, errback=None, cb_kwargs=None):
dont_filter=False, errback=None, cb_kwargs=None, flags=None):
# type: (...) -> Request
"""
Return a :class:`~.Request` instance to follow a link ``url``.
@ -157,11 +157,12 @@ class TextResponse(Response):
dont_filter=dont_filter,
errback=errback,
cb_kwargs=cb_kwargs,
flags=flags,
)
def follow_all(self, urls=None, callback=None, method='GET', headers=None, body=None,
cookies=None, meta=None, encoding=None, priority=0,
dont_filter=False, errback=None, cb_kwargs=None,
dont_filter=False, errback=None, cb_kwargs=None, flags=None,
css=None, xpath=None):
# type: (...) -> Generator[Request, None, None]
"""
@ -214,6 +215,7 @@ class TextResponse(Response):
dont_filter=dont_filter,
errback=errback,
cb_kwargs=cb_kwargs,
flags=flags,
)

View File

@ -166,6 +166,10 @@ class BaseResponseTest(unittest.TestCase):
def test_follow_whitespace_link(self):
    # Whitespace inside a Link's URL must be percent-encoded ('%20')
    # in the Request that follow() builds from it.
    self._assert_followed_url(Link('http://example.com/foo '),
                              'http://example.com/foo%20')
def test_follow_flags(self):
    """follow() must forward the ``flags`` argument to the Request it builds."""
    response = self.response_class('http://example.com/')
    followed = response.follow('http://example.com/', flags=['cached', 'allowed'])
    self.assertEqual(followed.flags, ['cached', 'allowed'])
# Response.follow_all
@ -232,6 +236,17 @@ class BaseResponseTest(unittest.TestCase):
expected = [u.replace(' ', '%20') for u in absolute]
self._assert_followed_all_urls(links, expected)
def test_follow_all_flags(self):
    """follow_all() must forward ``flags`` to every Request it generates.

    Fix: the local variable was named ``re``, shadowing the stdlib ``re``
    module for the rest of the method body — renamed to ``response``.
    """
    response = self.response_class('http://www.example.com/')
    urls = [
        'http://www.example.com/',
        'http://www.example.com/2',
        'http://www.example.com/foo',
    ]
    # follow_all returns a generator; every yielded Request must carry the flags.
    followed = response.follow_all(urls, flags=['cached', 'allowed'])
    for req in followed:
        self.assertEqual(req.flags, ['cached', 'allowed'])
def _assert_followed_url(self, follow_obj, target_url, response=None):
if response is None:
response = self._links_response()
@ -562,6 +577,22 @@ class TextResponseTest(BaseResponseTest):
)
self.assertEqual(req.encoding, 'cp1251')
def test_follow_flags(self):
    """follow() on a TextResponse must propagate ``flags`` into the Request."""
    response = self.response_class('http://example.com/')
    followed = response.follow('http://example.com/', flags=['cached', 'allowed'])
    self.assertEqual(followed.flags, ['cached', 'allowed'])
def test_follow_all_flags(self):
    """follow_all() on a TextResponse must propagate ``flags`` to each Request.

    Fix: the local variable was named ``re``, shadowing the stdlib ``re``
    module for the rest of the method body — renamed to ``response``.
    """
    response = self.response_class('http://www.example.com/')
    urls = [
        'http://www.example.com/',
        'http://www.example.com/2',
        'http://www.example.com/foo',
    ]
    # follow_all returns a generator; every yielded Request must carry the flags.
    followed = response.follow_all(urls, flags=['cached', 'allowed'])
    for req in followed:
        self.assertEqual(req.flags, ['cached', 'allowed'])
def test_follow_all_css(self):
expected = [
'http://example.com/sample3.html',