diff --git a/internal/client/client_batch.go b/internal/client/client_batch.go
index a23c373d4..4cddc86d4 100644
--- a/internal/client/client_batch.go
+++ b/internal/client/client_batch.go
@@ -110,12 +110,7 @@ func (b *batchCommandsBuilder) push(entry *batchCommandsEntry) {
 const highTaskPriority = 10
 
 func (b *batchCommandsBuilder) hasHighPriorityTask() bool {
-	item := b.entries.top()
-	if item == nil {
-		return false
-	}
-
-	return item.priority() >= highTaskPriority
+	return b.entries.highestPriority() >= highTaskPriority
 }
 
 // buildWithLimit builds BatchCommandsRequests and calls collect() for each valid entry.
@@ -123,36 +118,45 @@ func (b *batchCommandsBuilder) hasHighPriorityTask() bool {
 // The second is a map that maps forwarded hosts to requests.
 func (b *batchCommandsBuilder) buildWithLimit(limit int64, collect func(id uint64, e *batchCommandsEntry),
 ) (*tikvpb.BatchCommandsRequest, map[string]*tikvpb.BatchCommandsRequest) {
-	pending := b.entries.Len()
-	for count, i := int64(0), 0; i < pending; i++ {
-		e := b.entries.Pop().(*batchCommandsEntry)
-		if e.isCanceled() {
-			continue
-		}
-		if e.priority() < highTaskPriority {
-			count++
-			if count > limit {
-				b.push(e)
-				break
+	count := int64(0)
+	build := func(reqs []Item) {
+		for _, e := range reqs {
+			e := e.(*batchCommandsEntry)
+			if e.isCanceled() {
+				continue
+			}
+			if e.priority() < highTaskPriority {
+				count++
 			}
-		}
-		if collect != nil {
-			collect(b.idAlloc, e)
-		}
-		if e.forwardedHost == "" {
-			b.requestIDs = append(b.requestIDs, b.idAlloc)
-			b.requests = append(b.requests, e.req)
-		} else {
-			batchReq, ok := b.forwardingReqs[e.forwardedHost]
-			if !ok {
-				batchReq = &tikvpb.BatchCommandsRequest{}
-				b.forwardingReqs[e.forwardedHost] = batchReq
+			if collect != nil {
+				collect(b.idAlloc, e)
+			}
+			if e.forwardedHost == "" {
+				b.requestIDs = append(b.requestIDs, b.idAlloc)
+				b.requests = append(b.requests, e.req)
+			} else {
+				batchReq, ok := b.forwardingReqs[e.forwardedHost]
+				if !ok {
+					batchReq = &tikvpb.BatchCommandsRequest{}
+					b.forwardingReqs[e.forwardedHost] = batchReq
+				}
+				batchReq.RequestIds = append(batchReq.RequestIds, b.idAlloc)
+				batchReq.Requests = append(batchReq.Requests, e.req)
 			}
-			batchReq.RequestIds = append(batchReq.RequestIds, b.idAlloc)
-			batchReq.Requests = append(batchReq.Requests, e.req)
+			b.idAlloc++
+		}
+	}
+	for (count < limit && b.entries.Len() > 0) || b.hasHighPriorityTask() {
+		n := limit
+		if limit == 0 {
+			n = 1
+		}
+		reqs := b.entries.Take(int(n))
+		if len(reqs) == 0 {
+			break
 		}
-		b.idAlloc++
+		build(reqs)
 	}
 	var req *tikvpb.BatchCommandsRequest
 	if len(b.requests) > 0 {
@@ -175,7 +179,7 @@ func (b *batchCommandsBuilder) cancel(e error) {
 // reset resets the builder to the initial state.
 // Should call it before collecting a new batch.
 func (b *batchCommandsBuilder) reset() {
-	b.entries.clean()
+	b.entries.Clean()
 	// NOTE: We can't simply set entries = entries[:0] here.
 	// The data in the cap part of the slice would reference the prewrite keys whose
 	// underlying memory is borrowed from memdb. The reference cause GC can't release
@@ -397,16 +401,16 @@ func (a *batchConn) getClientAndSend() {
 		a.index = (a.index + 1) % uint32(len(a.batchCommandsClients))
 		target = a.batchCommandsClients[a.index].target
 		// The lock protects the batchCommandsClient from been closed while it's in use.
-		if c := a.batchCommandsClients[a.index]; c.tryLockForSend() {
-			if hasHighPriorityTask || c.sent.Load() <= c.maxConcurrencyRequestLimit.Load() {
+		c := a.batchCommandsClients[a.index]
+		if hasHighPriorityTask || c.sent.Load() <= c.maxConcurrencyRequestLimit.Load() {
+			if c.tryLockForSend() {
 				cli = c
 				break
 			} else {
 				reason = SendFailedReasonNoAvailableLimit
-				c.unlockForSend()
 			}
 		} else {
-			reason = SendFailedReasonTryLockForSendFail
+			reason = SendFailedReasonNoAvailableLimit
 		}
 	}
 	if cli == nil {
diff --git a/internal/client/client_test.go b/internal/client/client_test.go
index 8f95b6339..d4d22bfe6 100644
--- a/internal/client/client_test.go
+++ b/internal/client/client_test.go
@@ -730,18 +730,18 @@ func TestLimitConcurrency(t *testing.T) {
 		batch.reqBuilder.reset()
 	}
 
+	// The highest-priority task is sent immediately and is not limited by concurrency.
 	{
 		batch.reqBuilder.push(&batchCommandsEntry{req: &tikvpb.BatchCommandsRequest_Request{}, pri: highTaskPriority})
 		batch.reqBuilder.push(&batchCommandsEntry{req: &tikvpb.BatchCommandsRequest_Request{}, pri: highTaskPriority - 1})
 		reqs, _ := batch.reqBuilder.buildWithLimit(0, func(_ uint64, _ *batchCommandsEntry) {})
 		re.Len(reqs.RequestIds, 1)
-		re.Equal(1, batch.reqBuilder.len())
 		batch.reqBuilder.reset()
-		batch.reqBuilder.entries.Reset()
+		re.Equal(1, batch.reqBuilder.len())
 	}
 
+	// Medium-priority tasks are limited by concurrency.
 	{
-		batch.reqBuilder.push(&batchCommandsEntry{req: &tikvpb.BatchCommandsRequest_Request{}})
 		batch.reqBuilder.push(&batchCommandsEntry{req: &tikvpb.BatchCommandsRequest_Request{}})
 		batch.reqBuilder.push(&batchCommandsEntry{req: &tikvpb.BatchCommandsRequest_Request{}})
 		reqs, _ := batch.reqBuilder.buildWithLimit(2, func(_ uint64, _ *batchCommandsEntry) {})
@@ -750,6 +750,14 @@ func TestLimitConcurrency(t *testing.T) {
 		batch.reqBuilder.reset()
 	}
 
+	// Canceled tasks should be removed from the queue on reset.
+	{
+		batch.reqBuilder.push(&batchCommandsEntry{req: &tikvpb.BatchCommandsRequest_Request{}, canceled: 1})
+		batch.reqBuilder.push(&batchCommandsEntry{req: &tikvpb.BatchCommandsRequest_Request{}, canceled: 1})
+		batch.reqBuilder.reset()
+		re.Equal(1, batch.reqBuilder.len())
+	}
+
 }
 
 func TestPrioritySentLimit(t *testing.T) {
diff --git a/internal/client/priority_queue.go b/internal/client/priority_queue.go
index 5e8b80db9..d89da66a3 100644
--- a/internal/client/priority_queue.go
+++ b/internal/client/priority_queue.go
@@ -87,16 +87,43 @@ func (pq *PriorityQueue) Push(item Item) {
 	heap.Push(&pq.ps, entry{entry: item})
 }
 
-// Pop removes the highest priority entry from the priority queue.
-func (pq *PriorityQueue) Pop() Item {
-	return heap.Pop(&pq.ps).(entry).entry
+// pop removes the highest priority entry from the priority queue.
+func (pq *PriorityQueue) pop() Item {
+	e := heap.Pop(&pq.ps)
+	if e == nil {
+		return nil
+	}
+	return e.(entry).entry
 }
 
-func (pq *PriorityQueue) top() Item {
-	if pq.Len() == 0 {
+// Take removes and returns up to n of the highest priority entries from the priority queue.
+func (pq *PriorityQueue) Take(n int) []Item {
+	if n <= 0 {
 		return nil
 	}
-	return pq.ps[0].entry
+	if n >= pq.Len() {
+		ret := make([]Item, pq.Len())
+		for i := 0; i < pq.Len(); i++ {
+			ret[i] = pq.ps[i].entry
+		}
+
+		pq.ps = pq.ps[:0]
+		return ret
+	} else {
+		ret := make([]Item, n)
+		for i := 0; i < n; i++ {
+			ret[i] = pq.pop()
+		}
+		return ret
+	}
+
+}
+
+func (pq *PriorityQueue) highestPriority() uint64 {
+	if pq.Len() == 0 {
+		return 0
+	}
+	return pq.ps[0].entry.priority()
 }
 
 // All returns all entries in the priority queue not ensure the priority.
@@ -108,11 +135,13 @@ func (pq *PriorityQueue) All() []Item {
 	return items
 }
 
-func (pq *PriorityQueue) clean() {
-	for i := 0; i < pq.Len(); i++ {
+func (pq *PriorityQueue) Clean() {
+	for i := 0; i < pq.Len(); {
 		if pq.ps[i].entry.isCanceled() {
 			heap.Remove(&pq.ps, pq.ps[i].index)
+			continue
 		}
+		i++
 	}
 }
 
diff --git a/internal/client/priority_queue_test.go b/internal/client/priority_queue_test.go
index a7b1a2952..8f0b87f2d 100644
--- a/internal/client/priority_queue_test.go
+++ b/internal/client/priority_queue_test.go
@@ -21,8 +21,9 @@ import (
 )
 
 type FakeItem struct {
-	pri   uint64
-	value int
+	pri      uint64
+	value    int
+	canceled bool
 }
 
 func (f *FakeItem) priority() uint64 {
@@ -30,19 +31,43 @@ func (f *FakeItem) priority() uint64 {
 }
 
 func (f *FakeItem) isCanceled() bool {
-	return false
+	return f.canceled
 }
 
 func TestPriority(t *testing.T) {
 	re := require.New(t)
-	pq := NewPriorityQueue()
-	for i := 1; i <= 5; i++ {
-		pq.Push(&FakeItem{value: i, pri: uint64(i)})
-	}
-	re.Equal(5, pq.Len())
-	arr := pq.All()
-	re.Len(arr, 5)
-	for i := pq.Len(); i > 0; i-- {
-		re.Equal(i, pq.Pop().(*FakeItem).value)
+	testFunc := func(aq *PriorityQueue) {
+		for i := 1; i <= 5; i++ {
+			aq.Push(&FakeItem{value: i, pri: uint64(i)})
+		}
+		re.Equal(5, aq.Len())
+		re.Equal(uint64(5), aq.highestPriority())
+		aq.Clean()
+		re.Equal(5, aq.Len())
+
+		arr := aq.Take(1)
+		re.Len(arr, 1)
+		re.Equal(uint64(5), arr[0].priority())
+		re.Equal(uint64(4), aq.highestPriority())
+
+		arr = aq.Take(2)
+		re.Len(arr, 2)
+		re.Equal(uint64(4), arr[0].priority())
+		re.Equal(uint64(3), arr[1].priority())
+		re.Equal(uint64(2), aq.highestPriority())
+
+		arr = aq.Take(5)
+		re.Len(arr, 2)
+		re.Equal(uint64(2), arr[0].priority())
+		re.Equal(uint64(1), arr[1].priority())
+		re.Equal(uint64(0), aq.highestPriority())
+		re.Equal(0, aq.Len())
+
+		aq.Push(&FakeItem{value: 1, pri: 1, canceled: true})
+		re.Equal(1, aq.Len())
+		aq.Clean()
+		re.Equal(0, aq.Len())
 	}
+	hq := NewPriorityQueue()
+	testFunc(hq)
 }
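A minimal usage sketch of the reworked PriorityQueue API (not part of the patch). It reuses the FakeItem helper from priority_queue_test.go and the highTaskPriority constant from client_batch.go; the example function name is hypothetical and would have to live in a _test.go file inside internal/client.

// Sketch only: shows how highestPriority, Take, and Clean compose.
func ExamplePriorityQueue_Take() {
	pq := NewPriorityQueue()
	for i := 1; i <= 5; i++ {
		pq.Push(&FakeItem{value: i, pri: uint64(i)})
	}

	// highestPriority peeks at the top entry without removing it; this is what
	// hasHighPriorityTask now uses to decide whether to bypass the concurrency limit.
	if pq.highestPriority() >= highTaskPriority {
		// a high-priority entry is pending; send it regardless of the limit
	}

	// Take removes up to n entries in one call, replacing the old
	// Pop-one-at-a-time loop in buildWithLimit.
	batch := pq.Take(3) // entries with priority 5, 4, 3

	// Clean drops canceled entries; reset() relies on this so canceled
	// requests do not linger between batches.
	pq.Clean()

	_ = batch
}

Note the two branches in Take: when n covers the whole queue it copies out the backing slice and truncates it, with no per-entry heap fix-ups; otherwise it pops entry by entry so the remaining heap stays ordered.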