Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert the token from a string to an object #1585

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
442 changes: 419 additions & 23 deletions client/acquire_token.go

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions client/bearer_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ import (

// BasicAuth structure holds our credentials, this is the authorizer
type bearerAuth struct {
token string
token *tokenGenerator
}

// BearerAuthenticator is an Authenticator for BearerAuth
type bearerAuthenticator struct {
token string
token *tokenGenerator
}

// NewAuthenticator creates a new BearerAuthenticator
Expand All @@ -48,7 +48,11 @@ func (b *bearerAuth) AddAuthenticator(key string, fn gowebdav.AuthFactory) {

// Authorize the current request
func (b *bearerAuthenticator) Authorize(c *http.Client, rq *http.Request, path string) error {
rq.Header.Add("Authorization", "Bearer "+b.token) //set the header with the token
if b.token != nil {
if tokenContents, err := b.token.get(); err == nil && tokenContents != "" {
rq.Header.Add("Authorization", "Bearer "+tokenContents) //set the header with the token
}
}
return nil
}

Expand Down
8 changes: 6 additions & 2 deletions client/bearer_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ func TestBearerAuthenticator_Authorize(t *testing.T) {
}))
defer server.Close()

authenticator := &bearerAuthenticator{token: "some_token_1234_abc"}
token := newTokenGenerator(nil, nil, false, false)
token.SetToken("some_token_1234_abc")
authenticator := &bearerAuthenticator{token: token}
client := &http.Client{}

// Create a HTTP request to be authorized
Expand All @@ -53,7 +55,9 @@ func TestBearerAuthenticator_Authorize(t *testing.T) {
}

func TestBearerAuthenticator_Verify(t *testing.T) {
authenticator := &bearerAuthenticator{token: "some_token_1234_abc"}
token := newTokenGenerator(nil, nil, false, false)
token.SetToken("some_token_1234_abc")
authenticator := &bearerAuthenticator{token: token}
client := &http.Client{}

// Create a dummy HTTP response with a 401 status
Expand Down
4 changes: 2 additions & 2 deletions client/fed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ func TestGetAndPutAuth(t *testing.T) {

// Upload the file with PUT
transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name()))
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17))

// Download that same file with GET
transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name()))
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes)
}
})
Expand Down
59 changes: 36 additions & 23 deletions client/handle_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ type (
callback TransferCallbackFunc
remoteURL *url.URL
localPath string
token string
token *tokenGenerator
upload bool
packOption string
attempts []transferAttemptDetails
Expand All @@ -234,12 +234,10 @@ type (
localPath string
upload bool
recursive bool
skipAcquire bool
prefObjServers []*url.URL // holds any client-requested caches/origins
dirResp server_structs.DirectorResponse
directorUrl string
tokenLocation string
token string
token *tokenGenerator
project string
}

Expand Down Expand Up @@ -639,7 +637,7 @@ func (te *TransferEngine) newPelicanURL(remoteUrl *url.URL) (pelicanURL pelicanU

// If the values do not exist, exit with failure
if pelicanURL.directorUrl == "" {
return pelicanUrl{}, fmt.Errorf("Missing metadata information in config, ensure Federation DirectorUrl, RegistryUrl, and DiscoverUrl are all set")
return pelicanUrl{}, errors.New("missing metadata information in config, ensure Federation DirectorUrl, RegistryUrl, and DiscoverUrl are all set")
}
}
return
Expand Down Expand Up @@ -1105,12 +1103,16 @@ func (tc *TransferClient) NewTransferJob(ctx context.Context, remoteUrl *url.URL
localPath: localPath,
remoteURL: &copyUrl,
callback: tc.callback,
skipAcquire: tc.skipAcquire,
tokenLocation: tc.tokenLocation,
upload: upload,
uuid: id,
token: tc.token,
project: project,
token: newTokenGenerator(&copyUrl, nil, upload, !tc.skipAcquire),
}
if tc.token != "" {
tj.token.SetToken(tc.token)
}
if tc.tokenLocation != "" {
tj.token.SetTokenLocation(tc.tokenLocation)
}

mergeCancel := func(ctx1, ctx2 context.Context) (context.Context, context.CancelFunc) {
Expand All @@ -1133,11 +1135,11 @@ func (tc *TransferClient) NewTransferJob(ctx context.Context, remoteUrl *url.URL
case identTransferOptionCallback{}:
tj.callback = option.Value().(TransferCallbackFunc)
case identTransferOptionTokenLocation{}:
tj.tokenLocation = option.Value().(string)
tj.token.SetTokenLocation(option.Value().(string))
case identTransferOptionAcquireToken{}:
tj.skipAcquire = !option.Value().(bool)
tj.token.EnableAcquire = option.Value().(bool)
case identTransferOptionToken{}:
tj.token = option.Value().(string)
tj.token.SetToken(option.Value().(string))
}
}

Expand All @@ -1149,11 +1151,12 @@ func (tc *TransferClient) NewTransferJob(ctx context.Context, remoteUrl *url.URL
return
}
tj.dirResp = dirResp
tj.token.DirResp = &dirResp

if (upload || dirResp.XPelNsHdr.RequireToken) && tj.token == "" {
tj.token, err = getToken(remoteUrl, dirResp, true, "", tc.tokenLocation, !tj.skipAcquire)
if err != nil {
return nil, fmt.Errorf("failed to get token for transfer: %v", err)
if upload || dirResp.XPelNsHdr.RequireToken {
contents, err := tj.token.get()
if err != nil || contents == "" {
return nil, errors.Wrap(err, "failed to get token for transfer")
}
}

Expand Down Expand Up @@ -1513,7 +1516,7 @@ func runTransferWorker(ctx context.Context, workChan <-chan *clientTransferFile,
//
// Attempts a HEAD against all the endpoints simultaneously. Put any that don't respond within
// a second behind those that do respond.
func sortAttempts(ctx context.Context, path string, attempts []transferAttemptDetails, token string) (size int64, results []transferAttemptDetails) {
func sortAttempts(ctx context.Context, path string, attempts []transferAttemptDetails, token *tokenGenerator) (size int64, results []transferAttemptDetails) {
size = -1
if len(attempts) < 2 {
results = attempts
Expand Down Expand Up @@ -1556,8 +1559,10 @@ func sortAttempts(ctx context.Context, path string, attempts []transferAttemptDe
// header for GETs
headRequest, _ := http.NewRequestWithContext(ctx, http.MethodGet, tUrl.String(), nil)
headRequest.Header.Set("Range", "0-0")
if token != "" {
headRequest.Header.Set("Authorization", "Bearer "+token)
if token != nil {
if tokenContents, err := token.get(); err == nil && tokenContents != "" {
headRequest.Header.Set("Authorization", "Bearer "+tokenContents)
}
}
var headResponse *http.Response
headResponse, err := headClient.Do(headRequest)
Expand Down Expand Up @@ -1698,8 +1703,12 @@ func downloadObject(transfer *transferFile) (transferResults TransferResults, er
transferEndpointUrl.Path = transfer.remoteURL.Path
transferEndpoint.Url = &transferEndpointUrl
transferStartTime = time.Now() // Update start time for this attempt
tokenContents := ""
if transfer.token != nil {
tokenContents, _ = transfer.token.get()
}
attemptDownloaded, timeToFirstByte, cacheAge, serverVersion, err := downloadHTTP(
transfer.ctx, transfer.engine, transfer.callback, transferEndpoint, transfer.localPath, size, transfer.token, transfer.project,
transfer.ctx, transfer.engine, transfer.callback, transferEndpoint, transfer.localPath, size, tokenContents, transfer.project,
)
endTime := time.Now()
if cacheAge >= 0 {
Expand Down Expand Up @@ -2250,7 +2259,11 @@ func uploadObject(transfer *transferFile) (transferResult TransferResults, err e
return transferResult, err
}
// Set the authorization header as well as other headers
request.Header.Set("Authorization", "Bearer "+transfer.token)
if transfer.token != nil {
if tokenContents, err := transfer.token.get(); tokenContents != "" && err == nil {
request.Header.Set("Authorization", "Bearer "+tokenContents)
}
}
request.Header.Set("User-Agent", getUserAgent(transfer.project))
if searchJobAd(jobId) != "" {
request.Header.Set("X-Pelican-JobId", searchJobAd(jobId))
Expand Down Expand Up @@ -2388,7 +2401,7 @@ func runPut(request *http.Request, responseChan chan<- *http.Response, errorChan
}

// This helper function creates a web dav client to walkDavDir's. Used for recursive downloads and lists
func createWebDavClient(collectionsUrl *url.URL, token string, project string) (client *gowebdav.Client) {
func createWebDavClient(collectionsUrl *url.URL, token *tokenGenerator, project string) (client *gowebdav.Client) {
auth := &bearerAuth{token: token}
client = gowebdav.NewAuthClient(collectionsUrl.String(), auth)
client.SetHeader("User-Agent", getUserAgent(project))
Expand Down Expand Up @@ -2508,7 +2521,7 @@ func (te *TransferEngine) walkDirUpload(job *clientTransferJob, transfers []tran
}

// This function performs the ls command by walking through the specified collections and printing the contents of the files
func listHttp(remoteObjectUrl *url.URL, dirResp server_structs.DirectorResponse, token string) (fileInfos []FileInfo, err error) {
func listHttp(remoteObjectUrl *url.URL, dirResp server_structs.DirectorResponse, token *tokenGenerator) (fileInfos []FileInfo, err error) {
// Get our collection listing host
collectionsUrl := dirResp.XPelNsHdr.CollectionsUrl
log.Debugln("Collections URL: ", collectionsUrl.String())
Expand Down Expand Up @@ -2569,7 +2582,7 @@ func listHttp(remoteObjectUrl *url.URL, dirResp server_structs.DirectorResponse,
// Otherwise, the first three caches are queried simultaneously.
// For any of the queries, if the attempt with the proxy fails, a second attempt
// is made without.
func statHttp(dest *url.URL, dirResp server_structs.DirectorResponse, token string) (info FileInfo, err error) {
func statHttp(dest *url.URL, dirResp server_structs.DirectorResponse, token *tokenGenerator) (info FileInfo, err error) {
statHosts := make([]url.URL, 0, 3)
collectionsUrl := dirResp.XPelNsHdr.CollectionsUrl

Expand Down
14 changes: 9 additions & 5 deletions client/handle_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,24 +459,26 @@ func TestSortAttempts(t *testing.T) {

defer cancel()

size, results := sortAttempts(ctx, "/path", []transferAttemptDetails{attempt1, attempt2, attempt3}, "")
token := newTokenGenerator(nil, nil, false, false)
token.SetToken("aaa")
size, results := sortAttempts(ctx, "/path", []transferAttemptDetails{attempt1, attempt2, attempt3}, token)
assert.Equal(t, int64(42), size)
assert.Equal(t, svr2.URL, results[0].Url.String())
assert.Equal(t, svr3.URL, results[1].Url.String())
assert.Equal(t, svr1.URL, results[2].Url.String())

size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt2, attempt3, attempt1}, "")
size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt2, attempt3, attempt1}, token)
assert.Equal(t, int64(42), size)
assert.Equal(t, svr2.URL, results[0].Url.String())
assert.Equal(t, svr3.URL, results[1].Url.String())
assert.Equal(t, svr1.URL, results[2].Url.String())

size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt1, attempt1}, "")
size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt1, attempt1}, token)
assert.Equal(t, int64(-1), size)
assert.Equal(t, svr1.URL, results[0].Url.String())
assert.Equal(t, svr1.URL, results[1].Url.String())

size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt2, attempt3}, "")
size, results = sortAttempts(ctx, "/path", []transferAttemptDetails{attempt2, attempt3}, token)
assert.Equal(t, int64(42), size)
assert.Equal(t, svr2.URL, results[0].Url.String())
assert.Equal(t, svr3.URL, results[1].Url.String())
Expand Down Expand Up @@ -1023,6 +1025,8 @@ func TestHeadRequestWithDownloadToken(t *testing.T) {
svrURL, err := url.Parse(svr.URL)
require.NoError(t, err)

token := newTokenGenerator(nil, nil, false, false)
token.SetToken("test-token")
transfer := &transferFile{
ctx: context.Background(),
job: &TransferJob{},
Expand All @@ -1033,7 +1037,7 @@ func TestHeadRequestWithDownloadToken(t *testing.T) {
Url: svrURL,
},
},
token: "test-token",
token: token,
}
_, _ = downloadObject(transfer)
}
Expand Down
4 changes: 2 additions & 2 deletions client/handle_ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func generateDestination(filePath string, originPrefix string, shadowOriginPrefi
if strings.HasPrefix(hashString, cleanedOriginPrefix) {
return shadowOriginPrefix + hashString[len(cleanedOriginPrefix):], localSize, nil
}
return "", 0, errors.New("File path must have the origin prefix")
return "", 0, errors.New("file path must have the origin prefix")
}

func DoShadowIngest(ctx context.Context, sourceFile string, originPrefix string, shadowOriginPrefix string, options ...TransferOption) (int64, string, error) {
Expand Down Expand Up @@ -118,5 +118,5 @@ func DoShadowIngest(ctx context.Context, sourceFile string, originPrefix string,
return transferResults[0].TransferredBytes, shadowFile, err
}
}
return 0, "", errors.New("After 10 upload attempts, file was still being modified during ingest.")
return 0, "", errors.New("after 10 upload attempts, file was still being modified during ingest")
}
Loading
Loading