diff --git a/backup_handlers.go b/backup_handlers.go index 3d263102..fa677173 100644 --- a/backup_handlers.go +++ b/backup_handlers.go @@ -15,6 +15,7 @@ package backuplib import ( + "context" "io" "github.com/aerospike/aerospike-tools-backup-lib/encoding/asb" @@ -167,8 +168,13 @@ func (bwh *BackupHandler) GetStats() BackupStatus { } // Wait waits for the backup job to complete and returns an error if the job failed -func (bwh *BackupHandler) Wait() error { - return <-bwh.errors +func (bwh *BackupHandler) Wait(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-bwh.errors: + return err + } } func getDataWriter(eb EncoderBuilder, w io.Writer, namespace string, first bool) (*WriteWorker[*models.Token], error) { diff --git a/integration_test.go b/integration_test.go index 4e115fe0..f2036f0e 100644 --- a/integration_test.go +++ b/integration_test.go @@ -16,6 +16,7 @@ package backuplib_test import ( "bytes" + "context" "fmt" "io" "testing" @@ -178,7 +179,8 @@ func (suite *backupRestoreTestSuite) TestBackupRestoreIO() { suite.Nil(err) suite.NotNil(bh) - err = bh.Wait() + ctx := context.Background() + err = bh.Wait(ctx) suite.Nil(err) err = suite.testClient.Truncate(namespace, set) @@ -195,7 +197,7 @@ func (suite *backupRestoreTestSuite) TestBackupRestoreIO() { suite.Nil(err) suite.NotNil(rh) - err = rh.Wait() + err = rh.Wait(ctx) suite.Nil(err) err = suite.testClient.ValidateRecords(expectedRecs, numRec, namespace, set) diff --git a/restore_handlers.go b/restore_handlers.go index e0252b91..e8ddc632 100644 --- a/restore_handlers.go +++ b/restore_handlers.go @@ -15,6 +15,7 @@ package backuplib import ( + "context" "io" "github.com/aerospike/aerospike-tools-backup-lib/models" @@ -161,6 +162,11 @@ func (rrh *RestoreHandler) GetStats() RestoreStatus { } // Wait waits for the restore job to complete and returns an error if the job failed -func (rrh *RestoreHandler) Wait() error { - return <-rrh.errors +func (rrh *RestoreHandler) Wait(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-rrh.errors: + return err + } }