diff --git a/requests_toolbelt/sessions.py b/requests_toolbelt/sessions.py index 362924f..c747596 100644 --- a/requests_toolbelt/sessions.py +++ b/requests_toolbelt/sessions.py @@ -6,7 +6,7 @@ class BaseUrlSession(requests.Session): """A Session with a URL that all requests will use as a base. - Let's start by looking at an example: + Let's start by looking at a few examples: .. code-block:: python @@ -19,10 +19,22 @@ class BaseUrlSession(requests.Session): Our call to the ``get`` method will make a request to the URL passed in when we created the Session and the partial resource name we provide. + We implement this by overriding the ``request`` method of the Session. - We implement this by overriding the ``request`` method so most uses of a - Session are covered. (This, however, precludes the use of PreparedRequest - objects). + Likewise, we override the ``prepare_request`` method so you can construct + a PreparedRequest in the same way: + + .. code-block:: python + + >>> from requests import Request + >>> from requests_toolbelt import sessions + >>> s = sessions.BaseUrlSession( + ... base_url='https://example.com/resource/') + >>> request = Request(method='GET', url='sub-resource/') + >>> prepared_request = s.prepare_request(request) + >>> r = s.send(prepared_request) + >>> print(r.request.url) + https://example.com/resource/sub-resource .. note:: @@ -65,6 +77,13 @@ def request(self, method, url, *args, **kwargs): method, url, *args, **kwargs ) + def prepare_request(self, request, *args, **kwargs): + """Prepare the request after generating the complete URL.""" + request.url = self.create_url(request.url) + return super(BaseUrlSession, self).prepare_request( + request, *args, **kwargs + ) + def create_url(self, url): """Create the URL based off this partial path.""" return urljoin(self.base_url, url) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 297d9ea..e375578 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -3,26 +3,53 @@ import pytest from requests_toolbelt import sessions +from requests import Request from . import get_betamax class TestBasedSession(unittest.TestCase): - def test_with_base(self): + def test_request_with_base(self): session = sessions.BaseUrlSession('https://httpbin.org/') recorder = get_betamax(session) with recorder.use_cassette('simple_get_request'): response = session.get('/get') response.raise_for_status() - def test_without_base(self): + def test_request_without_base(self): session = sessions.BaseUrlSession() with pytest.raises(ValueError): session.get('/') - def test_override_base(self): + def test_request_override_base(self): session = sessions.BaseUrlSession('https://www.google.com') recorder = get_betamax(session) with recorder.use_cassette('simple_get_request'): response = session.get('https://httpbin.org/get') response.raise_for_status() assert response.json()['headers']['Host'] == 'httpbin.org' + + def test_prepared_request_with_base(self): + session = sessions.BaseUrlSession('https://httpbin.org') + request = Request(method="GET", url="/get") + prepared_request = session.prepare_request(request) + recorder = get_betamax(session) + with recorder.use_cassette('simple_get_request'): + response = session.send(prepared_request) + response.raise_for_status() + + def test_prepared_request_without_base(self): + session = sessions.BaseUrlSession() + request = Request(method="GET", url="/") + with pytest.raises(ValueError): + prepared_request = session.prepare_request(request) + session.send(prepared_request) + + def test_prepared_request_override_base(self): + session = sessions.BaseUrlSession('https://www.google.com') + request = Request(method="GET", url="https://httpbin.org/get") + prepared_request = session.prepare_request(request) + recorder = get_betamax(session) + with recorder.use_cassette('simple_get_request'): + response = session.send(prepared_request) + response.raise_for_status() + assert response.json()['headers']['Host'] == 'httpbin.org'