From 7b188517b4b0d7a86f9b51b1be5713cc95636d77 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi <29942790+sansyrox@users.noreply.github.com> Date: Thu, 23 May 2024 22:20:45 +0100 Subject: [PATCH] fix: add an optional return type in headers (#805) * fix: add an optional return type in headers * add documentation --- .../documentation/api_reference/getting_started.mdx | 7 +++++++ integration_tests/base_routes.py | 3 +++ src/types/headers.rs | 11 ++++------- src/types/request.rs | 2 +- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/docs_src/src/pages/documentation/api_reference/getting_started.mdx b/docs_src/src/pages/documentation/api_reference/getting_started.mdx index 955db3c30..213090689 100644 --- a/docs_src/src/pages/documentation/api_reference/getting_started.mdx +++ b/docs_src/src/pages/documentation/api_reference/getting_started.mdx @@ -546,6 +546,9 @@ Either, by using the `headers` field in the `Request` object: headers = request.headers print("These are the request headers: ", headers) + existing_header = headers.get("exisiting_header") + existing_header = headers.get("exisiting_header", "default_value") + exisiting_header = headers["exisiting_header"] # This syntax is also valid headers.set("modified", "modified_value") headers["new_header"] = "new_value" # This syntax is also valid @@ -563,8 +566,12 @@ Either, by using the `headers` field in the `Request` object: headers = request.headers print("These are the request headers: ", headers) + existing_header = headers.get("exisiting_header") + existing_header = headers.get("exisiting_header", "default_value") + exisiting_header = headers["exisiting_header"] # This syntax is also valid headers.set("modified", "modified_value") + headers["new_header"] = "new_value" # This syntax is also valid print("These are the modified request headers: ", headers) diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 5f56382da..9d8bb5832 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -145,6 +145,9 @@ def global_after_request(response: Response): @app.get("/sync/global/middlewares") def sync_global_middlewares(request: Request): + print(request.headers) + print(request.headers.get("txt")) + print(request.headers["txt"]) assert "global_before" in request.headers assert request.headers.get("global_before") == "global_before_request" return "sync global middlewares" diff --git a/src/types/headers.rs b/src/types/headers.rs index 129477af1..c2d093cd8 100644 --- a/src/types/headers.rs +++ b/src/types/headers.rs @@ -62,18 +62,15 @@ impl Headers { } } - pub fn get(&self, key: String) -> PyResult { + pub fn get(&self, key: String) -> Option { // return the last value match self.headers.get(&key.to_lowercase()) { Some(iter) => { let (_, values) = iter.pair(); let last_value = values.last().unwrap(); - Ok(last_value.to_string()) + Some(last_value.to_string()) } - None => Err(pyo3::exceptions::PyKeyError::new_err(format!( - "KeyError: {}", - key - ))), + None => None, } } @@ -151,7 +148,7 @@ impl Headers { self.set(key, value); } - pub fn __getitem__(&self, key: String) -> PyResult { + pub fn __getitem__(&self, key: String) -> Option { self.get(key) } } diff --git a/src/types/request.rs b/src/types/request.rs index 292aa0c9d..47b378423 100644 --- a/src/types/request.rs +++ b/src/types/request.rs @@ -141,7 +141,7 @@ impl Request { let body: Vec = if headers.contains(String::from("content-type")) && headers .get(String::from("content-type")) - .is_ok_and(|val| val.contains("multipart/form-data")) + .is_some_and(|val| val.contains("multipart/form-data")) { let h = headers.get(String::from("content-type")).unwrap(); debug!("Content-Type: {:?}", h);