diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 77d717d..c3abed4 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -2,12 +2,45 @@ package main import ( "fmt" + "github.com/k0kubun/go-ansi" + "github.com/schollz/progressbar/v3" + "os" + "time" hibpsync "github.com/exaring/go-hibp-sync" ) func main() { - if err := hibpsync.Sync(); err != nil { + stateFile, err := os.OpenFile(hibpsync.DefaultStateFile, os.O_RDWR|os.O_CREATE, 0644) + if err != nil { + fmt.Printf("opening state file error: %q", err) + } + + bar := progressbar.NewOptions(0xFFFFF, + progressbar.OptionSetWriter(ansi.NewAnsiStdout()), + progressbar.OptionEnableColorCodes(true), + progressbar.OptionSetDescription("[cyan]Syncing HIBP data...[reset]"), + progressbar.OptionShowCount(), + progressbar.OptionShowIts(), + progressbar.OptionSetItsString("prefixes"), + progressbar.OptionThrottle(100*time.Millisecond), + progressbar.OptionSetPredictTime(false), + progressbar.OptionSetElapsedTime(true), + progressbar.OptionSetTheme(progressbar.Theme{ + Saucer: "[green]=[reset]", + SaucerHead: "[green]>[reset]", + SaucerPadding: " ", + BarStart: "[", + BarEnd: "]", + })) + + updateProgressBar := func(lowest, current, _ int64) error { + _ = bar.Set64(current) + + return nil + } + + if err := hibpsync.Sync(hibpsync.WithProgressFn(updateProgressBar), hibpsync.WithStateFile(stateFile)); err != nil { fmt.Printf("sync error: %q", err) } } diff --git a/go.mod b/go.mod index f4f3df3..d6e41e9 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,18 @@ module github.com/exaring/go-hibp-sync go 1.21.6 require ( - github.com/alitto/pond v1.8.3 // indirect + github.com/alitto/pond v1.8.3 + github.com/deckarep/golang-set/v2 v2.6.0 + github.com/hashicorp/go-retryablehttp v0.7.5 + github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213 + github.com/schollz/progressbar/v3 v3.14.1 +) + +require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect - github.com/hashicorp/go-retryablehttp v0.7.5 // indirect + 
github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/rivo/uniseg v0.4.7 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/term v0.17.0 // indirect ) diff --git a/go.sum b/go.sum index 51323ed..060228c 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,37 @@ github.com/alitto/pond v1.8.3 h1:ydIqygCLVPqIX/USe5EaV/aSRXTRXDEI9JwuDdu+/xs= github.com/alitto/pond v1.8.3/go.mod h1:CmvIIGd5jKLasGI3D87qDkQxjzChdKMmnXMg3fG6M6Q= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/deckarep/golang-set/v2 v2.6.0 h1:XfcQbWM1LlMB8BsJ8N9vW5ehnnPVIw0je80NsVHagjM= +github.com/deckarep/golang-set/v2 v2.6.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= github.com/hashicorp/go-retryablehttp v0.7.5/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8= +github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213 h1:qGQQKEcAR99REcMpsXCp3lJ03zYT1PkRd3kQGPn9GVg= +github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db 
h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/schollz/progressbar/v3 v3.14.1 h1:VD+MJPCr4s3wdhTc7OEJ/Z3dAeBzJ7yKH/P4lC5yRTI= +github.com/schollz/progressbar/v3 v3.14.1/go.mod h1:Zc9xXneTzWXF81TGoqL71u0sBPjULtEHYtj/WVgVy8E= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= +golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= diff --git a/lib.go b/lib.go index 8c1e807..cbe4cba 100644 --- a/lib.go +++ b/lib.go @@ -1,24 +1,32 @@ package hibpsync import ( + "bytes" + "errors" "fmt" - "github.com/alitto/pond" "github.com/hashicorp/go-retryablehttp" + "io" + "os" + "strconv" + "sync" ) const ( defaultDataDir = "./.hibp-data" 
defaultEndpoint = "https://api.pwnedpasswords.com/range/" defaultCheckETag = true - defaultWorkers = 100 + defaultWorkers = 50 + DefaultStateFile = "./.hibp-data/state" ) type syncConfig struct { - dataDir string - endpoint string - checkETag bool - worker int + dataDir string + endpoint string + checkETag bool + minWorkers int + progressFn ProgressFunc + stateFile io.ReadWriteSeeker } type SyncOption func(*syncConfig) @@ -41,81 +49,162 @@ func WithCheckETag(checkETag bool) SyncOption { } } -func WithWorkers(workers int) SyncOption { +func WithMinWorkers(workers int) SyncOption { return func(c *syncConfig) { - c.worker = workers + c.minWorkers = workers + } +} + +func WithStateFile(stateFile io.ReadWriteSeeker) SyncOption { + return func(c *syncConfig) { + c.stateFile = stateFile + } +} + +func WithProgressFn(progressFn ProgressFunc) SyncOption { + return func(c *syncConfig) { + c.progressFn = progressFn } } func Sync(options ...SyncOption) error { config := &syncConfig{ - dataDir: defaultDataDir, - endpoint: defaultEndpoint, - checkETag: defaultCheckETag, - worker: defaultWorkers, + dataDir: defaultDataDir, + endpoint: defaultEndpoint, + checkETag: defaultCheckETag, + minWorkers: defaultWorkers, + progressFn: func(_, _, _ int64) error { return nil }, } for _, option := range options { option(config) } - rG, err := newRangeGenerator(0x00000, 0xFFFFF, "") - if err != nil { - return fmt.Errorf("creating range generator: %w", err) + from := int64(0x00000) + + if config.stateFile != nil { + lastState, err := readStateFile(config.stateFile) + if err != nil { + return fmt.Errorf("error reading state file: %w", err) + } + + from = lastState + innerProgressFn := config.progressFn + + config.progressFn = func(lowest, current, to int64) error { + err := func() error { + if lowest < lastState+1000 { + return nil + } + + if _, err := config.stateFile.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("seeking to beginning of state file: %w", err) + } + + if _, err := 
config.stateFile.Write([]byte(fmt.Sprintf("%d", lowest))); err != nil { + return fmt.Errorf("writing state file: %w", err) + } + + lastState = lowest + + return nil + }() + + if err != nil { + fmt.Printf("updating state file: %v\n", err) + } + + return innerProgressFn(lowest, current, to) + } } - retryClient := retryablehttp.NewClient() //TODO: add dnscache, timeout + rG := newRangeGenerator(from, 0xFFFFF+1, config.progressFn) + + retryClient := retryablehttp.NewClient() retryClient.RetryMax = 10 retryClient.Logger = nil hc := hibpClient{ endpoint: config.endpoint, httpClient: retryClient.StandardClient(), + maxRetries: 2, } storage := fsStorage{ dataDir: config.dataDir, } - pool := pond.New(config.worker, 0, pond.MinWorkers(config.worker)) + pool := pond.New(config.minWorkers, 0, pond.MinWorkers(config.minWorkers)) defer pool.Stop() - for { - rangeIndex, ok, err := rG.Next() - if err != nil { - return fmt.Errorf("getting next range: %w", err) - } + var ( + outerErr error + errLock sync.Mutex + ) - if !ok { - break - } + for !pool.Stopped() { + pool.Submit(func() { + keepGoing, err := rG.Next(func(r int64) error { + rangePrefix := toRangeString(r) - if rangeIndex%100 == 0 || rangeIndex < 10 { - fmt.Printf("processing range %d\n", rangeIndex) - } + etag, _ := storage.LoadETag(rangePrefix) + // TODO: Log error with debug level - pool.Submit(func() { - rangePrefix := toRangeString(rangeIndex) - etag, err := storage.LoadETag(rangePrefix) - if err != nil { - fmt.Printf("error loading etag for range %q: %v\n", rangePrefix, err) - return - } + resp, err := hc.RequestRange(rangePrefix, etag) + if err != nil { + return fmt.Errorf("error requesting range %q: %w", rangePrefix, err) + } - resp, err := hc.RequestRange(rangePrefix, etag) + if resp.NotModified { + return nil + } + + if err := storage.Save(rangePrefix, resp.ETag, resp.Data); err != nil { + return fmt.Errorf("error saving range %q: %w", rangePrefix, err) + } + + return nil + }) if err != nil { - fmt.Printf("error 
requesting range %q: %v\n", rangePrefix, err) - return - } + errLock.Lock() + defer errLock.Unlock() - if resp.NotModified { - return + outerErr = errors.Join(outerErr, fmt.Errorf("processing range: %w", err)) } - if err := storage.Save(rangePrefix, resp.ETag, resp.Data); err != nil { - fmt.Printf("error saving range %q: %v\n", rangePrefix, err) + + if !keepGoing { + pool.Stop() } }) } - return nil + return outerErr +} + +func readStateFile(stateFile io.ReadWriteSeeker) (int64, error) { + state, err := io.ReadAll(stateFile) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return 0, nil + } + + return 0, fmt.Errorf("reading state file: %w", err) + } + + state = bytes.TrimSpace(state) + + if len(state) == 0 { + return 0, nil + } + + lastState, err := strconv.ParseInt(string(state), 10, 64) + if err != nil { + return 0, fmt.Errorf("parsing state file: %w", err) + } + + if _, err := stateFile.Seek(0, io.SeekStart); err != nil { + return 0, fmt.Errorf("seeking to beginning of state file: %w", err) + } + + return lastState, nil } diff --git a/ranges.go b/ranges.go index 5bde52a..7da1fd3 100644 --- a/ranges.go +++ b/ranges.go @@ -1,64 +1,75 @@ package hibpsync import ( - "errors" "fmt" - "io/fs" - "os" - "strconv" + mapset "github.com/deckarep/golang-set/v2" + "math" "sync" + "sync/atomic" ) -const writeStateEveryN = 10 +type ProgressFunc func(lowest, current, to int64) error type rangeGenerator struct { - idx, to int - lock sync.Mutex - stateFilePath string + from, to int64 + idx *atomic.Int64 + inFlightSet mapset.Set[int64] + onProgress ProgressFunc + onProgressLock sync.Mutex } -func newRangeGenerator(from, to int, stateFilePath string) (*rangeGenerator, error) { - // Check if the state file exists and read the last state from it. - // This is useful to resume the sync process after a crash. 
- if stateFilePath != "" { - bytez, err := os.ReadFile(stateFilePath) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("reading state file: %w", err) - } - - from, err = strconv.Atoi(string(bytez)) - if err != nil { - return nil, fmt.Errorf("parsing state file: %w", err) - } - } +func newRangeGenerator(from, to int64, onProgress ProgressFunc) *rangeGenerator { + idx := &atomic.Int64{} + idx.Store(from) return &rangeGenerator{ - idx: from, - to: to, - stateFilePath: stateFilePath, - }, nil + from: from, + to: to, + idx: idx, + inFlightSet: mapset.NewSet[int64](), + onProgress: onProgress, + } } -func (r *rangeGenerator) Next() (int, bool, error) { - r.lock.Lock() - defer r.lock.Unlock() +func (r *rangeGenerator) Next(fn func(r int64) error) (bool, error) { + current := r.idx.Add(1) - 1 + + if current >= r.to { + return false, nil + } + + r.inFlightSet.Add(current) - if r.idx > r.to { - return 0, false, nil + if err := fn(current); err != nil { + return false, err } - current := r.idx - r.idx++ + r.inFlightSet.Remove(current) - if r.stateFilePath != "" && (current%writeStateEveryN == 0 || current == r.to) { - if err := os.WriteFile(r.stateFilePath, []byte(fmt.Sprintf("%d", current)), 0644); err != nil { - return 0, false, fmt.Errorf("writing state file: %w", err) + if current%10 == 0 || current == r.to-1 { + r.onProgressLock.Lock() + defer r.onProgressLock.Unlock() + + // TODO: Compute remaining and provide to progress function + + if err := r.onProgress(r.lowestInFlight(), current, r.to); err != nil { + return false, err } } - return current, true, nil + return true, nil +} + +func (r *rangeGenerator) lowestInFlight() int64 { + lowest := int64(math.MaxInt64) + + for _, a := range r.inFlightSet.ToSlice() { + lowest = min(lowest, a) + } + + return lowest } -func toRangeString(i int) string { +func toRangeString(i int64) string { return fmt.Sprintf("%05X", i) } diff --git a/storage.go b/storage.go index 432085d..9f6355d 100644 --- 
a/storage.go +++ b/storage.go @@ -115,5 +115,5 @@ func (f *fsStorage) subDir(key string) string { } func (f *fsStorage) filePath(key string) string { - return path.Join(f.subDir(key), key) + return path.Join(f.subDir(key), key[2:]) } diff --git a/upstream.go b/upstream.go index 15e6718..c1506cc 100644 --- a/upstream.go +++ b/upstream.go @@ -1,6 +1,7 @@ package hibpsync import ( + "errors" "fmt" "io" "net/http" @@ -9,6 +10,7 @@ import ( type hibpClient struct { endpoint string httpClient *http.Client + maxRetries int } type hibpResponse struct { @@ -27,9 +29,26 @@ func (h *hibpClient) RequestRange(rangePrefix, etag string) (*hibpResponse, erro req.Header.Set("If-None-Match", etag) } + var mErr error + + for i := 0; i < 1+h.maxRetries; i++ { + resp, err := h.request(req) + if err == nil { + return resp, nil + } + + // TODO: Log error with debug level + + mErr = errors.Join(mErr, err) + } + + return nil, fmt.Errorf("requesting range %q: %w", rangePrefix, mErr) +} + +func (h *hibpClient) request(req *http.Request) (*hibpResponse, error) { resp, err := h.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("executing request for range %q: %w", rangePrefix, err) + return nil, fmt.Errorf("executing request: %w", err) } if resp.StatusCode == http.StatusNotModified { @@ -37,13 +56,13 @@ func (h *hibpClient) RequestRange(rangePrefix, etag string) (*hibpResponse, erro } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code requesting range %q: %d", rangePrefix, resp.StatusCode) + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("reading response body for range %q: %w", rangePrefix, err) + return nil, fmt.Errorf("reading response body: %w", err) } return &hibpResponse{