Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mutex to make Client thread-safe #827

Merged
merged 5 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
457 changes: 393 additions & 64 deletions client.go

Large diffs are not rendered by default.

65 changes: 32 additions & 33 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ func TestClientAuthScheme(t *testing.T) {

// Ensure setting the scheme works as well
c.SetAuthScheme("Bearer")
assertEqual(t, "Bearer", c.AuthScheme())

resp2, err2 := c.R().Get("/profile")
assertError(t, err2)
Expand Down Expand Up @@ -240,7 +241,7 @@ func TestClientProxy(t *testing.T) {
assertNotNil(t, resp)
assertNotNil(t, err)

// Error
// error
c.SetProxy("//not.a.user@%66%6f%6f.com:8888")

resp, err = c.R().
Expand Down Expand Up @@ -339,9 +340,9 @@ func TestClientSetHeaderVerbatim(t *testing.T) {
SetHeader("header-lowercase", "value_standard")

//lint:ignore SA1008 valid one, so ignore this!
unConventionHdrValue := strings.Join(c.Header["header-lowercase"], "")
unConventionHdrValue := strings.Join(c.Header()["header-lowercase"], "")
assertEqual(t, "value_lowercase", unConventionHdrValue)
assertEqual(t, "value_standard", c.Header.Get("Header-Lowercase"))
assertEqual(t, "value_standard", c.Header().Get("Header-Lowercase"))
}

func TestClientSetTransport(t *testing.T) {
Expand Down Expand Up @@ -387,20 +388,20 @@ func TestClientOptions(t *testing.T) {
assertEqual(t, client.setContentLength, true)

client.SetBaseURL("http://httpbin.org")
assertEqual(t, "http://httpbin.org", client.BaseURL)
assertEqual(t, "http://httpbin.org", client.BaseURL())

client.SetHeader(hdrContentTypeKey, "application/json; charset=utf-8")
client.SetHeaders(map[string]string{
hdrUserAgentKey: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_5) go-resty v0.1",
"X-Request-Id": strconv.FormatInt(time.Now().UnixNano(), 10),
})
assertEqual(t, "application/json; charset=utf-8", client.Header.Get(hdrContentTypeKey))
assertEqual(t, "application/json; charset=utf-8", client.Header().Get(hdrContentTypeKey))

client.SetCookie(&http.Cookie{
Name: "default-cookie",
Value: "This is cookie default-cookie value",
})
assertEqual(t, "default-cookie", client.Cookies[0].Name)
assertEqual(t, "default-cookie", client.Cookies()[0].Name)

cookies := []*http.Cookie{
{
Expand All @@ -412,45 +413,45 @@ func TestClientOptions(t *testing.T) {
},
}
client.SetCookies(cookies)
assertEqual(t, "default-cookie-1", client.Cookies[1].Name)
assertEqual(t, "default-cookie-2", client.Cookies[2].Name)
assertEqual(t, "default-cookie-1", client.Cookies()[1].Name)
assertEqual(t, "default-cookie-2", client.Cookies()[2].Name)

client.SetQueryParam("test_param_1", "Param_1")
client.SetQueryParams(map[string]string{"test_param_2": "Param_2", "test_param_3": "Param_3"})
assertEqual(t, "Param_3", client.QueryParam.Get("test_param_3"))
assertEqual(t, "Param_3", client.QueryParam().Get("test_param_3"))

rTime := strconv.FormatInt(time.Now().UnixNano(), 10)
client.SetFormData(map[string]string{"r_time": rTime})
assertEqual(t, rTime, client.FormData.Get("r_time"))
assertEqual(t, rTime, client.FormData().Get("r_time"))

client.SetBasicAuth("myuser", "mypass")
assertEqual(t, "myuser", client.UserInfo.Username)
assertEqual(t, "myuser", client.BasicAuth().Username)

client.SetAuthToken("AC75BD37F019E08FBC594900518B4F7E")
assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.Token)
assertEqual(t, "AC75BD37F019E08FBC594900518B4F7E", client.Token())

client.SetDisableWarn(true)
assertEqual(t, client.DisableWarn, true)
assertEqual(t, client.DisableWarn(), true)

client.SetRetryCount(3)
assertEqual(t, 3, client.RetryCount)
assertEqual(t, 3, client.RetryCount())

rwt := time.Duration(1000) * time.Millisecond
client.SetRetryWaitTime(rwt)
assertEqual(t, rwt, client.RetryWaitTime)
assertEqual(t, rwt, client.RetryWaitTime())

mrwt := time.Duration(2) * time.Second
client.SetRetryMaxWaitTime(mrwt)
assertEqual(t, mrwt, client.RetryMaxWaitTime)
assertEqual(t, mrwt, client.RetryMaxWaitTime())

client.AddRetryAfterErrorCondition()
equal(client.RetryConditions[0], func(response *Response, err error) bool {
equal(client.RetryConditions()[0], func(response *Response, err error) bool {
return response.IsError()
})

err := &AuthError{}
client.SetError(err)
if reflect.TypeOf(err) == client.Error {
if reflect.TypeOf(err) == client.Error() {
t.Error("SetError failed")
}

Expand All @@ -476,14 +477,14 @@ func TestClientOptions(t *testing.T) {
client.SetContentLength(true)

client.SetDebug(true)
assertEqual(t, client.Debug, true)
assertEqual(t, client.Debug(), true)

var sl int64 = 1000000
client.SetDebugBodyLimit(sl)
assertEqual(t, client.debugBodySizeLimit, sl)

client.SetAllowGetMethodPayload(true)
assertEqual(t, client.AllowGetMethodPayload, true)
assertEqual(t, client.AllowGetMethodPayload(), true)

client.SetScheme("http")
assertEqual(t, client.scheme, "http")
Expand Down Expand Up @@ -617,31 +618,31 @@ func TestClientNewRequest(t *testing.T) {
func TestClientSetJSONMarshaler(t *testing.T) {
m := func(v interface{}) ([]byte, error) { return nil, nil }
c := New().SetJSONMarshaler(m)
p1 := fmt.Sprintf("%p", c.JSONMarshal)
p1 := fmt.Sprintf("%p", c.JSONMarshaler())
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetJSONUnmarshaler(t *testing.T) {
m := func([]byte, interface{}) error { return nil }
c := New().SetJSONUnmarshaler(m)
p1 := fmt.Sprintf("%p", c.JSONUnmarshal)
p1 := fmt.Sprintf("%p", c.JSONUnmarshaler())
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetXMLMarshaler(t *testing.T) {
m := func(v interface{}) ([]byte, error) { return nil, nil }
c := New().SetXMLMarshaler(m)
p1 := fmt.Sprintf("%p", c.XMLMarshal)
p1 := fmt.Sprintf("%p", c.XMLMarshaler())
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}

func TestClientSetXMLUnmarshaler(t *testing.T) {
m := func([]byte, interface{}) error { return nil }
c := New().SetXMLUnmarshaler(m)
p1 := fmt.Sprintf("%p", c.XMLUnmarshal)
p1 := fmt.Sprintf("%p", c.XMLUnmarshaler())
p2 := fmt.Sprintf("%p", m)
assertEqual(t, p1, p2) // functions can not be compared, we only can compare pointers
}
Expand Down Expand Up @@ -1145,23 +1146,21 @@ func TestClone(t *testing.T) {
parent.SetBaseURL("http://localhost")

// set an interface field
parent.UserInfo = &User{
Username: "parent",
}
parent.SetBasicAuth("parent", "")

clone := parent.Clone()
// update value of non-interface type - change will only happen on clone
clone.SetBaseURL("https://local.host")
// update value of interface type - change will also happen on parent
clone.UserInfo.Username = "clone"
clone.BasicAuth().Username = "clone"

// asert non-interface type
assertEqual(t, "http://localhost", parent.BaseURL)
assertEqual(t, "https://local.host", clone.BaseURL)
assertEqual(t, "http://localhost", parent.BaseURL())
assertEqual(t, "https://local.host", clone.BaseURL())

// assert interface type
assertEqual(t, "clone", parent.UserInfo.Username)
assertEqual(t, "clone", clone.UserInfo.Username)
assertEqual(t, "clone", parent.BasicAuth().Username)
assertEqual(t, "clone", clone.BasicAuth().Username)
}

func TestResponseBodyLimit(t *testing.T) {
Expand All @@ -1172,7 +1171,7 @@ func TestResponseBodyLimit(t *testing.T) {

t.Run("Client body limit", func(t *testing.T) {
c := dc().SetResponseBodyLimit(1024)

assertEqual(t, 1024, c.ResponseBodyLimit())
_, err := c.R().Get(ts.URL + "/")
assertNotNil(t, err)
assertEqual(t, err, ErrResponseBodyTooLarge)
Expand Down
54 changes: 27 additions & 27 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ const debugRequestLogKey = "__restyDebugRequestLog"
//_______________________________________________________________________

func parseRequestURL(c *Client, r *Request) error {
if l := len(c.PathParams) + len(c.RawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 {
if l := len(c.pathParams) + len(c.rawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 {
params := make(map[string]string, l)

// GitHub #103 Path Params
for p, v := range r.PathParams {
params[p] = url.PathEscape(v)
}
for p, v := range c.PathParams {
for p, v := range c.pathParams {
if _, ok := params[p]; !ok {
params[p] = url.PathEscape(v)
}
Expand All @@ -46,7 +46,7 @@ func parseRequestURL(c *Client, r *Request) error {
params[p] = v
}
}
for p, v := range c.RawPathParams {
for p, v := range c.rawPathParams {
if _, ok := params[p]; !ok {
params[p] = v
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func parseRequestURL(c *Client, r *Request) error {
r.URL = "/" + r.URL
}

reqURL, err = url.Parse(c.BaseURL + r.URL)
reqURL, err = url.Parse(c.baseURL + r.URL)
if err != nil {
return err
}
Expand All @@ -126,8 +126,8 @@ func parseRequestURL(c *Client, r *Request) error {
}

// Adding Query Param
if len(c.QueryParam)+len(r.QueryParam) > 0 {
for k, v := range c.QueryParam {
if len(c.queryParam)+len(r.QueryParam) > 0 {
for k, v := range c.queryParam {
// skip query parameter if it was set in request
if _, ok := r.QueryParam[k]; ok {
continue
Expand Down Expand Up @@ -155,7 +155,7 @@ func parseRequestURL(c *Client, r *Request) error {
}

func parseRequestHeader(c *Client, r *Request) error {
for k, v := range c.Header {
for k, v := range c.header {
if _, ok := r.Header[k]; ok {
continue
}
Expand All @@ -174,13 +174,13 @@ func parseRequestHeader(c *Client, r *Request) error {
}

func parseRequestBody(c *Client, r *Request) error {
if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
if isPayloadSupported(r.Method, c.allowGetMethodPayload) {
switch {
case r.isMultiPart: // Handling Multipart
if err := handleMultipart(c, r); err != nil {
return err
}
case len(c.FormData) > 0 || len(r.FormData) > 0: // Handling Form Data
case len(c.formData) > 0 || len(r.FormData) > 0: // Handling Form Data
handleFormData(c, r)
case r.Body != nil: // Handling Request body
handleContentType(c, r)
Expand All @@ -205,7 +205,7 @@ func parseRequestBody(c *Client, r *Request) error {

func createHTTPRequest(c *Client, r *Request) (err error) {
if r.bodyBuf == nil {
if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.allowGetMethodPayload) {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
} else if c.setContentLength || r.setContentLength {
r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody)
Expand All @@ -229,7 +229,7 @@ func createHTTPRequest(c *Client, r *Request) (err error) {
r.RawRequest.Header = r.Header

// Add cookies from client instance into http request
for _, cookie := range c.Cookies {
for _, cookie := range c.cookies {
r.RawRequest.AddCookie(cookie)
}

Expand Down Expand Up @@ -271,32 +271,32 @@ func addCredentials(c *Client, r *Request) error {
if r.UserInfo != nil { // takes precedence
r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
isBasicAuth = true
} else if c.UserInfo != nil {
r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password)
} else if c.userInfo != nil {
r.RawRequest.SetBasicAuth(c.userInfo.Username, c.userInfo.Password)
isBasicAuth = true
}

if !c.DisableWarn {
if !c.disableWarn {
if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
r.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
}
}

// Set the Authorization Header Scheme
// Set the Authorization header Scheme
var authScheme string
if !IsStringEmpty(r.AuthScheme) {
authScheme = r.AuthScheme
} else if !IsStringEmpty(c.AuthScheme) {
authScheme = c.AuthScheme
} else if !IsStringEmpty(c.authScheme) {
authScheme = c.authScheme
} else {
authScheme = "Bearer"
}

// Build the Token Auth header
// Build the token Auth header
if !IsStringEmpty(r.Token) { // takes precedence
r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+r.Token)
} else if !IsStringEmpty(c.Token) {
r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+c.Token)
r.RawRequest.Header.Set(c.headerAuthorizationKey, authScheme+" "+r.Token)
} else if !IsStringEmpty(c.token) {
r.RawRequest.Header.Set(c.headerAuthorizationKey, authScheme+" "+c.token)
}

return nil
Expand Down Expand Up @@ -401,11 +401,11 @@ func parseResponseBody(c *Client, res *Response) (err error) {
}
}

// HTTP status code > 399, considered as Error
// HTTP status code > 399, considered as error
if res.IsError() {
// global error interface
if res.Request.Error == nil && c.Error != nil {
res.Request.Error = reflect.New(c.Error).Interface()
if res.Request.Error == nil && c.error != nil {
res.Request.Error = reflect.New(c.error).Interface()
}

if res.Request.Error != nil {
Expand All @@ -431,7 +431,7 @@ func handleMultipart(c *Client, r *Request) error {
}
}

for k, v := range c.FormData {
for k, v := range c.formData {
for _, iv := range v {
if err := w.WriteField(k, iv); err != nil {
return err
Expand Down Expand Up @@ -472,7 +472,7 @@ func handleMultipart(c *Client, r *Request) error {
}

func handleFormData(c *Client, r *Request) {
for k, v := range c.FormData {
for k, v := range c.formData {
if _, ok := r.FormData[k]; ok {
continue
}
Expand Down Expand Up @@ -520,7 +520,7 @@ func handleRequestBody(c *Client, r *Request) error {
if IsJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
r.bodyBuf, err = jsonMarshal(c, r, r.Body)
} else if IsXMLType(contentType) && (kind == reflect.Struct) {
bodyBytes, err = c.XMLMarshal(r.Body)
bodyBytes, err = c.xmlMarshal(r.Body)
}
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions middleware_test.go
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tttturtle-russ I don't think we need to modify the casing on the name field or comment lines. The accessor field requires an update here.

Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func Test_parseRequestURL(t *testing.T) {
r.SetPathParams(map[string]string{
"foo": "4/5",
}).SetRawPathParams(map[string]string{
"foo": "4/5", // ignored, because the PathParams takes precedence over the RawPathParams
"foo": "4/5", // ignored, because the pathParams takes precedence over the rawPathParams
"bar": "6/7",
})
r.URL = "https://example.com/{foo}/{bar}"
Expand Down Expand Up @@ -182,7 +182,7 @@ func Test_parseRequestURL(t *testing.T) {
{
name: "using deprecated HostURL with relative path in request URL",
init: func(c *Client, r *Request) {
c.BaseURL = "https://example.com"
c.SetBaseURL("https://example.com")
r.URL = "foo/bar"
},
expectedURL: "https://example.com/foo/bar",
Expand Down
Loading
Loading