diff --git a/pkg/storage/tsdb/config.go b/pkg/storage/tsdb/config.go
index fa6f7b1c93..4d52b22398 100644
--- a/pkg/storage/tsdb/config.go
+++ b/pkg/storage/tsdb/config.go
@@ -64,6 +64,7 @@ var (
 	ErrInvalidTokenBucketBytesLimiterMode               = errors.New("invalid token bucket bytes limiter mode")
 	ErrInvalidLazyExpandedPostingGroupMaxKeySeriesRatio = errors.New("lazy expanded posting group max key series ratio needs to be equal or greater than 0")
 	ErrInvalidBucketStoreType                           = errors.New("invalid bucket store type")
+	ErrInvalidMaxConcurrentBytes                        = errors.New("max concurrent bytes must be non-negative")
 )
 
 // BlocksStorageConfig holds the config information for the blocks storage.
@@ -281,6 +282,7 @@ type BucketStoreConfig struct {
 	SyncInterval          time.Duration `yaml:"sync_interval"`
 	MaxConcurrent         int           `yaml:"max_concurrent"`
 	MaxInflightRequests   int           `yaml:"max_inflight_requests"`
+	MaxConcurrentBytes    int64         `yaml:"max_concurrent_bytes"`
 	TenantSyncConcurrency int           `yaml:"tenant_sync_concurrency"`
 	BlockSyncConcurrency  int           `yaml:"block_sync_concurrency"`
 	MetaSyncConcurrency   int           `yaml:"meta_sync_concurrency"`
@@ -365,6 +367,7 @@ func (cfg *BucketStoreConfig) RegisterFlags(f *flag.FlagSet) {
 	f.IntVar(&cfg.ChunkPoolMaxBucketSizeBytes, "blocks-storage.bucket-store.chunk-pool-max-bucket-size-bytes", ChunkPoolDefaultMaxBucketSize, "Size - in bytes - of the largest chunks pool bucket.")
 	f.IntVar(&cfg.MaxConcurrent, "blocks-storage.bucket-store.max-concurrent", 100, "Max number of concurrent queries to execute against the long-term storage. The limit is shared across all tenants.")
 	f.IntVar(&cfg.MaxInflightRequests, "blocks-storage.bucket-store.max-inflight-requests", 0, "Max number of inflight queries to execute against the long-term storage. The limit is shared across all tenants. 0 to disable.")
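+	// Illustrative sizing note (not from the original change): a value of
+	// 1073741824 (1GiB) rejects further byte reservations once bytes tracked
+	// across all in-flight requests would exceed 1GiB; 0 keeps the tracker
+	// counting but never rejects.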
0 to disable.") f.IntVar(&cfg.TenantSyncConcurrency, "blocks-storage.bucket-store.tenant-sync-concurrency", 10, "Maximum number of concurrent tenants syncing blocks.") f.IntVar(&cfg.BlockSyncConcurrency, "blocks-storage.bucket-store.block-sync-concurrency", 20, "Maximum number of concurrent blocks syncing per tenant.") f.IntVar(&cfg.MetaSyncConcurrency, "blocks-storage.bucket-store.meta-sync-concurrency", 20, "Number of Go routines to use when syncing block meta files from object storage per tenant.") @@ -429,6 +432,9 @@ func (cfg *BucketStoreConfig) Validate() error { if cfg.LazyExpandedPostingGroupMaxKeySeriesRatio < 0 { return ErrInvalidLazyExpandedPostingGroupMaxKeySeriesRatio } + if cfg.MaxConcurrentBytes < 0 { + return ErrInvalidMaxConcurrentBytes + } return nil } diff --git a/pkg/storage/tsdb/config_test.go b/pkg/storage/tsdb/config_test.go index 7a642cc600..41ec872425 100644 --- a/pkg/storage/tsdb/config_test.go +++ b/pkg/storage/tsdb/config_test.go @@ -145,6 +145,24 @@ func TestConfig_Validate(t *testing.T) { }, expectedErr: errUnSupportedWALCompressionType, }, + "should fail on negative max concurrent bytes": { + setup: func(cfg *BlocksStorageConfig) { + cfg.BucketStore.MaxConcurrentBytes = -1 + }, + expectedErr: ErrInvalidMaxConcurrentBytes, + }, + "should pass on zero max concurrent bytes (disabled)": { + setup: func(cfg *BlocksStorageConfig) { + cfg.BucketStore.MaxConcurrentBytes = 0 + }, + expectedErr: nil, + }, + "should pass on positive max concurrent bytes": { + setup: func(cfg *BlocksStorageConfig) { + cfg.BucketStore.MaxConcurrentBytes = 1024 * 1024 * 1024 // 1GB + }, + expectedErr: nil, + }, } for testName, testData := range tests { diff --git a/pkg/storegateway/bucket_stores.go b/pkg/storegateway/bucket_stores.go index f017457a9f..ca5c45ab42 100644 --- a/pkg/storegateway/bucket_stores.go +++ b/pkg/storegateway/bucket_stores.go @@ -95,6 +95,13 @@ type ThanosBucketStores struct { // Keeps number of inflight requests inflightRequests *util.InflightRequestTracker + // Concurrent bytes tracker for limiting bytes being processed across all queries. + concurrentBytesTracker ConcurrentBytesTracker + + // Holder for per-request bytes trackers. The BytesLimiterFactory (created + // once per user store) reads from this to find the current request's tracker. + requestBytesTrackerHolder *requestBytesTrackerHolder + // Metrics. 
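+	// Register this request's tracker under the current goroutine ID so the
+	// BytesLimiterFactory invoked during store.Series (on this same goroutine)
+	// can find it. The deferred ReleaseAll returns every tracked byte even if
+	// Series panics.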
+	reqTracker := newRequestBytesTracker(u.concurrentBytesTracker)
+	u.requestBytesTrackerHolder.Set(reqTracker)
+	defer func() {
+		u.requestBytesTrackerHolder.Clear()
+		reqTracker.ReleaseAll()
+	}()
+
 	err = store.Series(req, spanSeriesServer{
 		Store_SeriesServer: srv,
 		ctx:                spanCtx,
@@ -697,7 +713,7 @@ func (u *ThanosBucketStores) getOrCreateStore(userID string) (*store.BucketStore
 		u.syncDirForUser(userID),
 		newChunksLimiterFactory(u.limits, userID),
 		newSeriesLimiterFactory(u.limits, userID),
-		newBytesLimiterFactory(u.limits, userID, u.getUserTokenBucket(userID), u.instanceTokenBucket, u.cfg.BucketStore.TokenBucketBytesLimiter, u.getTokensToRetrieve),
+		newBytesLimiterFactory(u.limits, userID, u.getUserTokenBucket(userID), u.instanceTokenBucket, u.cfg.BucketStore.TokenBucketBytesLimiter, u.getTokensToRetrieve, u.requestBytesTrackerHolder),
 		u.partitioner,
 		u.cfg.BucketStore.BlockSyncConcurrency,
 		false, // No need to enable backward compatibility with Thanos pre 0.8.0 queriers
diff --git a/pkg/storegateway/concurrent_bytes_tracker.go b/pkg/storegateway/concurrent_bytes_tracker.go
new file mode 100644
index 0000000000..69d7ae3076
--- /dev/null
+++ b/pkg/storegateway/concurrent_bytes_tracker.go
@@ -0,0 +1,123 @@
+package storegateway
+
+import (
+	"sync/atomic"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/thanos-io/thanos/pkg/pool"
+)
+
+const peakResetInterval = 30 * time.Second
+
+type ConcurrentBytesTracker interface {
+	Add(bytes uint64) error
+	Release(bytes uint64)
+	Current() uint64
+	Stop()
+}
+
+type concurrentBytesTracker struct {
+	maxConcurrentBytes uint64
+	currentBytes       atomic.Uint64
+	peakBytes          atomic.Uint64
+	stop               chan struct{}
+
+	peakBytesGauge        prometheus.Gauge
+	maxBytesGauge         prometheus.Gauge
+	rejectedRequestsTotal prometheus.Counter
+}
+
+func NewConcurrentBytesTracker(maxConcurrentBytes uint64, reg prometheus.Registerer) ConcurrentBytesTracker {
+	tracker := &concurrentBytesTracker{
+		maxConcurrentBytes: maxConcurrentBytes,
+		stop:               make(chan struct{}),
+		peakBytesGauge: prometheus.NewGauge(prometheus.GaugeOpts{
+			Name: "cortex_storegateway_concurrent_bytes_peak",
+			Help: "Peak concurrent bytes observed in the last 30s window.",
+		}),
+		maxBytesGauge: prometheus.NewGauge(prometheus.GaugeOpts{
+			Name: "cortex_storegateway_concurrent_bytes_max",
+			Help: "Configured maximum concurrent bytes limit.",
+		}),
+		rejectedRequestsTotal: prometheus.NewCounter(prometheus.CounterOpts{
+			Name: "cortex_storegateway_bytes_limiter_rejected_requests_total",
+			Help: "Total requests rejected due to concurrent bytes limit.",
+		}),
+	}
+
+	tracker.maxBytesGauge.Set(float64(maxConcurrentBytes))
+	if reg != nil {
+		reg.MustRegister(tracker.peakBytesGauge)
+		reg.MustRegister(tracker.maxBytesGauge)
+		reg.MustRegister(tracker.rejectedRequestsTotal)
+	}
+
+	go tracker.publishPeakLoop()
+
+	return tracker
+}
+
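+// Add reserves bytes against the global limit. Note the limit check and the
+// increment below are two separate atomic operations, not one, so concurrent
+// callers can transiently push currentBytes slightly past maxConcurrentBytes;
+// enforcement is best-effort.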
+func (t *concurrentBytesTracker) Add(bytes uint64) error {
+	if t.maxConcurrentBytes > 0 && t.Current()+bytes > t.maxConcurrentBytes {
+		t.rejectedRequestsTotal.Inc()
+		return pool.ErrPoolExhausted
+	}
+
+	newValue := t.currentBytes.Add(bytes)
+	for {
+		peak := t.peakBytes.Load()
+		if newValue <= peak {
+			break
+		}
+		if t.peakBytes.CompareAndSwap(peak, newValue) {
+			break
+		}
+		// CAS failed, retry
+	}
+
+	return nil
+}
+
+func (t *concurrentBytesTracker) Release(bytes uint64) {
+	for {
+		current := t.currentBytes.Load()
+		newValue := current - bytes
+		if t.currentBytes.CompareAndSwap(current, newValue) {
+			return
+		}
+		// CAS failed, retry
+	}
+}
+
+func (t *concurrentBytesTracker) Current() uint64 {
+	return t.currentBytes.Load()
+}
+
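+// publishPeakLoop exports the high-water mark every peakResetInterval and then
+// resets it to the bytes in flight at that instant, so the gauge reports the
+// peak of the most recent window rather than an all-time maximum.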
+func (t *concurrentBytesTracker) publishPeakLoop() {
+	ticker := time.NewTicker(peakResetInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			current := t.currentBytes.Load()
+			peak := t.peakBytes.Swap(current)
+			if current > peak {
+				peak = current
+			}
+			t.peakBytesGauge.Set(float64(peak))
+		case <-t.stop:
+			return
+		}
+	}
+}
+
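+// Stop terminates the peak-publishing goroutine. The select makes a repeated
+// Stop after the channel is closed a no-op; it does not guard two fully
+// concurrent Stop calls, which callers are expected to avoid.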
+func (t *concurrentBytesTracker) Stop() {
+	select {
+	case <-t.stop:
+		// Already stopped.
+	default:
+		close(t.stop)
+	}
+}
diff --git a/pkg/storegateway/concurrent_bytes_tracker_test.go b/pkg/storegateway/concurrent_bytes_tracker_test.go
new file mode 100644
index 0000000000..d8f3171769
--- /dev/null
+++ b/pkg/storegateway/concurrent_bytes_tracker_test.go
@@ -0,0 +1,374 @@
+package storegateway
+
+import (
+	"math/rand"
+	"sync"
+	"testing"
+	"testing/quick"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/thanos-io/thanos/pkg/pool"
+	"github.com/thanos-io/thanos/pkg/store"
+)
+
+func TestConcurrentBytesTracker_Basic(t *testing.T) {
+	t.Run("add increments counter", func(t *testing.T) {
+		tracker := NewConcurrentBytesTracker(1000, nil)
+		assert.Equal(t, uint64(0), tracker.Current())
+
+		require.NoError(t, tracker.Add(100))
+		assert.Equal(t, uint64(100), tracker.Current())
+
+		tracker.Release(100)
+		assert.Equal(t, uint64(0), tracker.Current())
+	})
+
+	t.Run("add rejects when would exceed limit", func(t *testing.T) {
+		tracker := NewConcurrentBytesTracker(100, nil)
+		require.NoError(t, tracker.Add(100))
+
+		err := tracker.Add(1)
+		assert.ErrorIs(t, err, pool.ErrPoolExhausted)
+		assert.Equal(t, uint64(100), tracker.Current())
+	})
+
+	t.Run("add allows when at exactly limit", func(t *testing.T) {
+		tracker := NewConcurrentBytesTracker(100, nil)
+		require.NoError(t, tracker.Add(50))
+		require.NoError(t, tracker.Add(50))
+		assert.Equal(t, uint64(100), tracker.Current())
+	})
+}
+
+func TestTrackerWithLimitingDisabled(t *testing.T) {
+	t.Run("add always succeeds", func(t *testing.T) {
+		tracker := NewConcurrentBytesTracker(0, nil)
+		require.NoError(t, tracker.Add(1000))
+		assert.Equal(t, uint64(1000), tracker.Current())
+		tracker.Release(1000)
+		assert.Equal(t, uint64(0), tracker.Current())
+	})
+
+	t.Run("add succeeds even with very high byte count", func(t *testing.T) {
+		tracker := NewConcurrentBytesTracker(0, nil)
+		assert.NoError(t, tracker.Add(uint64(100)*1024*1024*1024))
+	})
+}
+
+func TestConcurrentBytesTracker_Metrics(t *testing.T) {
+	reg := prometheus.NewRegistry()
+	tracker := NewConcurrentBytesTracker(1000, reg)
+	require.NoError(t, tracker.Add(500))
+
+	metricFamilies, err := reg.Gather()
+	require.NoError(t, err)
+
+	metricNames := make(map[string]bool)
+	for _, mf := range metricFamilies {
+		metricNames[mf.GetName()] = true
+	}
+
+	assert.True(t, metricNames["cortex_storegateway_concurrent_bytes_peak"])
+	assert.True(t, metricNames["cortex_storegateway_concurrent_bytes_max"])
+	assert.True(t, metricNames["cortex_storegateway_bytes_limiter_rejected_requests_total"])
+
+	tracker.Release(500)
+}
+
+func TestProperty_AddIncrementsCounter(t *testing.T) {
+	f := func(bytes uint64) bool {
+		bytes = (bytes % (1024 * 1024 * 1024)) + 1
+		tracker := NewConcurrentBytesTracker(uint64(10)*1024*1024*1024, nil)
+
+		err := tracker.Add(bytes)
+		if err != nil {
+			return false
+		}
+		return tracker.Current() == bytes
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_ReleaseRoundTrip(t *testing.T) {
+	f := func(bytes uint64) bool {
+		bytes = (bytes % (1024 * 1024 * 1024)) + 1
+		tracker := NewConcurrentBytesTracker(uint64(10)*1024*1024*1024, nil)
+
+		_ = tracker.Add(bytes)
+		tracker.Release(bytes)
+		return tracker.Current() == 0
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_ThreadSafeCounterUpdates(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		numOps := rng.Intn(100) + 1
+		bytesPerOp := uint64(rng.Intn(1024*1024)) + 1
+
+		tracker := NewConcurrentBytesTracker(uint64(100)*1024*1024*1024, nil)
+
+		var wg sync.WaitGroup
+		for range numOps {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				_ = tracker.Add(bytesPerOp)
+			}()
+		}
+		wg.Wait()
+
+		if tracker.Current() != uint64(numOps)*bytesPerOp {
+			return false
+		}
+
+		for range numOps {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				tracker.Release(bytesPerOp)
+			}()
+		}
+		wg.Wait()
+
+		return tracker.Current() == 0
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_PeakBytesMetricTracksPeak(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		numOps := rng.Intn(50) + 1
+
+		reg := prometheus.NewRegistry()
+		tracker := NewConcurrentBytesTracker(uint64(100)*1024*1024*1024, reg).(*concurrentBytesTracker)
+
+		var totalBytes uint64
+		var maxSeen uint64
+
+		for range numOps {
+			bytes := uint64(rng.Intn(1024*1024)) + 1
+			_ = tracker.Add(bytes)
+			totalBytes += bytes
+			if current := tracker.Current(); current > maxSeen {
+				maxSeen = current
+			}
+		}
+
+		if tracker.peakBytes.Load() < maxSeen {
+			return false
+		}
+
+		tracker.Release(totalBytes)
+		return tracker.Current() == 0 && tracker.peakBytes.Load() >= maxSeen
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_PositiveLimitEnforcement(t *testing.T) {
+	f := func(limit uint64, bytesToAdd uint64) bool {
+		limit = (limit % (10 * 1024 * 1024 * 1024)) + 1
+		bytesToAdd = limit + (bytesToAdd % (1024 * 1024 * 1024)) + 1
+
+		tracker := NewConcurrentBytesTracker(limit, nil)
+		err := tracker.Add(bytesToAdd)
+		return err != nil
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_BelowLimitAccepts(t *testing.T) {
+	f := func(limit uint64, bytesToAdd uint64) bool {
+		limit = (limit % (10 * 1024 * 1024 * 1024)) + 1024
+		bytesToAdd = bytesToAdd % (limit + 1)
+
+		tracker := NewConcurrentBytesTracker(limit, nil)
+		return tracker.Add(bytesToAdd) == nil
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_RejectionDoesNotCountBytes(t *testing.T) {
+	f := func(limit uint64, numRejections uint8) bool {
+		limit = (limit % (1024 * 1024 * 1024)) + 1024
+		numRejections = (numRejections % 10) + 1
+
+		tracker := NewConcurrentBytesTracker(limit, nil)
+		if err := tracker.Add(limit); err != nil {
+			return false
+		}
+
+		for i := uint8(0); i < numRejections; i++ {
+			if tracker.Add(1) == nil {
+				return false
+			}
+		}
+
+		return tracker.Current() == limit
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_RejectionReturnsPoolExhausted(t *testing.T) {
+	f := func(limit uint64, bytesToAdd uint64) bool {
+		limit = (limit % (1024 * 1024 * 1024)) + 1
+		bytesToAdd = limit + (bytesToAdd % (1024 * 1024 * 1024)) + 1
+
+		tracker := NewConcurrentBytesTracker(limit, nil)
+		return tracker.Add(bytesToAdd) == pool.ErrPoolExhausted
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_RecoveryAfterRelease(t *testing.T) {
+	f := func(limit uint64, bytesToAdd uint64) bool {
+		limit = (limit % (1024 * 1024 * 1024)) + 1024
+		bytesToAdd = (bytesToAdd % limit) + 1
+
+		tracker := NewConcurrentBytesTracker(limit, nil)
+		if err := tracker.Add(limit); err != nil {
+			return false
+		}
+		if tracker.Add(1) == nil {
+			return false
+		}
+
+		tracker.Release(limit)
+		return tracker.Add(bytesToAdd) == nil
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_ConcurrentDecrementsCorrectness(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		numOps := rng.Intn(100) + 10
+
+		tracker := NewConcurrentBytesTracker(uint64(100)*1024*1024*1024, nil)
+
+		var totalBytes uint64
+		bytesPerOp := make([]uint64, numOps)
+		for i := range numOps {
+			bytes := uint64(rng.Intn(1024*1024)) + 1
+			bytesPerOp[i] = bytes
+			totalBytes += bytes
+			_ = tracker.Add(bytes)
+		}
+
+		if tracker.Current() != totalBytes {
+			return false
+		}
+
+		var wg sync.WaitGroup
+		for i := range numOps {
+			wg.Add(1)
+			go func(idx int) {
+				defer wg.Done()
+				tracker.Release(bytesPerOp[idx])
+			}(i)
+		}
+		wg.Wait()
+
+		return tracker.Current() == 0
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
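+// The two panic-recovery properties below exercise the deferred-release
+// pattern that ThanosBucketStores.Series relies on: tracked bytes must be
+// returned even when the request goroutine panics.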
+func TestProperty_PanicRecoveryCleanup(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		bytesLimit := uint64(rng.Intn(1024*1024*1024)) + 1024
+		bytesToTrack := uint64(rng.Intn(1024*1024)) + 1
+
+		tracker := NewConcurrentBytesTracker(bytesLimit, nil)
+
+		func() {
+			defer func() { recover() }()
+
+			if err := tracker.Add(bytesToTrack); err != nil {
+				return
+			}
+			defer tracker.Release(bytesToTrack)
+
+			panic("simulated panic")
+		}()
+
+		return tracker.Current() == 0
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestProperty_PanicRecoveryWithRequestTracker(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		bytesLimit := uint64(rng.Intn(1024*1024*1024)) + 1024
+		numLimiters := rng.Intn(10) + 1
+
+		tracker := NewConcurrentBytesTracker(bytesLimit, nil)
+		reqTracker := newRequestBytesTracker(tracker)
+
+		func() {
+			defer func() { recover() }()
+			defer reqTracker.ReleaseAll()
+
+			for range numLimiters {
+				inner := newMockBytesLimiter(bytesLimit)
+				limiter := newTrackingBytesLimiter(inner, reqTracker)
+
+				bytes := uint64(rng.Intn(1024*1024)) + 1
+				_ = limiter.ReserveWithType(bytes, store.PostingsFetched)
+			}
+
+			panic("simulated panic")
+		}()
+
+		return tracker.Current() == 0
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestPropertyAddReturnsErrPoolExhaustedIffOverLimit(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		limit := uint64(rng.Intn(10*1024*1024)) + 1
+
+		tracker := NewConcurrentBytesTracker(limit, nil)
+		defer tracker.Stop()
+
+		numAdds := rng.Intn(20) + 1
+		var cumulativeBytes uint64
+
+		for i := 0; i < numAdds; i++ {
+			bytes := uint64(rng.Intn(int(limit))) + 1
+			err := tracker.Add(bytes)
+
+			if cumulativeBytes+bytes > limit {
+				if err != pool.ErrPoolExhausted {
+					return false
+				}
+			} else {
+				if err != nil {
+					return false
+				}
+				cumulativeBytes += bytes
+			}
+		}
+
+		return tracker.Current() == cumulativeBytes
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestPropertyAddReturnsNilErrorWhenLimitingDisabled(t *testing.T) {
+	f := func(bytes uint64) bool {
+		bytes = (bytes % (1024 * 1024 * 1024)) + 1
+		tracker := NewConcurrentBytesTracker(0, nil)
+		defer tracker.Stop()
+		return tracker.Add(bytes) == nil
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
diff --git a/pkg/storegateway/limiter.go b/pkg/storegateway/limiter.go
index d907925505..e552c0a6de 100644
--- a/pkg/storegateway/limiter.go
+++ b/pkg/storegateway/limiter.go
@@ -128,7 +128,7 @@ func newSeriesLimiterFactory(limits *validation.Overrides, userID string) store.
 	}
 }
 
-func newBytesLimiterFactory(limits *validation.Overrides, userID string, userTokenBucket, instanceTokenBucket *util.TokenBucket, tokenBucketBytesLimiterCfg tsdb.TokenBucketBytesLimiterConfig, getTokensToRetrieve func(tokens uint64, dataType store.StoreDataType) int64) store.BytesLimiterFactory {
+func newBytesLimiterFactory(limits *validation.Overrides, userID string, userTokenBucket, instanceTokenBucket *util.TokenBucket, tokenBucketBytesLimiterCfg tsdb.TokenBucketBytesLimiterConfig, getTokensToRetrieve func(tokens uint64, dataType store.StoreDataType) int64, trackerHolder *requestBytesTrackerHolder) store.BytesLimiterFactory {
 	return func(failedCounter prometheus.Counter) store.BytesLimiter {
 		limiters := []store.BytesLimiter{}
 		// Since limit overrides could be live reloaded, we have to get the current user's limit
@@ -141,8 +141,16 @@ func newBytesLimiterFactory(limits *validation.Overrides, userID string, userTok
 			limiters = append(limiters, newTokenBucketBytesLimiter(requestTokenBucket, userTokenBucket, instanceTokenBucket, dryRun, failedCounter, getTokensToRetrieve))
 		}
 
-		return &compositeBytesLimiter{
+		innerLimiter := &compositeBytesLimiter{
 			limiters: limiters,
 		}
+
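+		// The holder is keyed by goroutine ID and is only populated by Series,
+		// so Get returns nil for request paths that never registered a tracker;
+		// those fall back to the plain composite limiter below.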
+		if trackerHolder != nil {
+			reqTracker := trackerHolder.Get()
+			if reqTracker != nil {
+				return newTrackingBytesLimiter(innerLimiter, reqTracker)
+			}
+		}
+		return innerLimiter
 	}
 }
diff --git a/pkg/storegateway/tracking_bytes_limiter.go b/pkg/storegateway/tracking_bytes_limiter.go
new file mode 100644
index 0000000000..60ddc669f1
--- /dev/null
+++ b/pkg/storegateway/tracking_bytes_limiter.go
@@ -0,0 +1,96 @@
+package storegateway
+
+import (
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+
+	"github.com/thanos-io/thanos/pkg/store"
+)
+
+type requestBytesTracker struct {
+	tracker  ConcurrentBytesTracker
+	total    atomic.Uint64
+	released atomic.Bool
+}
+
+func newRequestBytesTracker(tracker ConcurrentBytesTracker) *requestBytesTracker {
+	return &requestBytesTracker{
+		tracker: tracker,
+	}
+}
+
+func (r *requestBytesTracker) Add(bytes uint64) error {
+	if err := r.tracker.Add(bytes); err != nil {
+		return err
+	}
+	r.total.Add(bytes)
+	return nil
+}
+
+func (r *requestBytesTracker) ReleaseAll() {
+	if !r.released.CompareAndSwap(false, true) {
+		return
+	}
+	bytes := r.total.Load()
+	if bytes > 0 {
+		r.tracker.Release(bytes)
+	}
+}
+
+func (r *requestBytesTracker) Total() uint64 {
+	return r.total.Load()
+}
+
+type trackingBytesLimiter struct {
+	inner          store.BytesLimiter
+	requestTracker *requestBytesTracker
+}
+
+func newTrackingBytesLimiter(inner store.BytesLimiter, requestTracker *requestBytesTracker) *trackingBytesLimiter {
+	return &trackingBytesLimiter{
+		inner:          inner,
+		requestTracker: requestTracker,
+	}
+}
+
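+// ReserveWithType reserves through the wrapped limiter first and only then
+// records the bytes on the request tracker. If the tracker rejects, the inner
+// reservation is not rolled back; the per-request total released by ReleaseAll
+// only ever covers bytes the tracker actually accepted.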
"goroutine [" + s := strings.TrimPrefix(string(buf[:n]), "goroutine ") + if idx := strings.IndexByte(s, ' '); idx >= 0 { + s = s[:idx] + } + id, _ := strconv.ParseInt(s, 10, 64) + return id +} diff --git a/pkg/storegateway/tracking_bytes_limiter_test.go b/pkg/storegateway/tracking_bytes_limiter_test.go new file mode 100644 index 0000000000..68bd57e28d --- /dev/null +++ b/pkg/storegateway/tracking_bytes_limiter_test.go @@ -0,0 +1,235 @@ +package storegateway + +import ( + "math/rand" + "testing" + "testing/quick" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thanos-io/thanos/pkg/store" +) + +type mockBytesLimiter struct { + reservedBytes uint64 + limit uint64 +} + +func newMockBytesLimiter(limit uint64) *mockBytesLimiter { + return &mockBytesLimiter{limit: limit} +} + +func (m *mockBytesLimiter) ReserveWithType(num uint64, _ store.StoreDataType) error { + m.reservedBytes += num + return nil +} + +func (m *mockBytesLimiter) Reserved() uint64 { + return m.reservedBytes +} + +func TestRequestBytesTracker_Basic(t *testing.T) { + t.Run("add delegates to underlying tracker", func(t *testing.T) { + tracker := NewConcurrentBytesTracker(10000, nil) + reqTracker := newRequestBytesTracker(tracker) + + require.NoError(t, reqTracker.Add(100)) + assert.Equal(t, uint64(100), tracker.Current()) + assert.Equal(t, uint64(100), reqTracker.Total()) + }) + + t.Run("multiple adds accumulate", func(t *testing.T) { + tracker := NewConcurrentBytesTracker(10000, nil) + reqTracker := newRequestBytesTracker(tracker) + + require.NoError(t, reqTracker.Add(100)) + require.NoError(t, reqTracker.Add(200)) + require.NoError(t, reqTracker.Add(300)) + assert.Equal(t, uint64(600), tracker.Current()) + assert.Equal(t, uint64(600), reqTracker.Total()) + }) + + t.Run("release all decrements tracker", func(t *testing.T) { + tracker := NewConcurrentBytesTracker(10000, nil) + reqTracker := newRequestBytesTracker(tracker) + + require.NoError(t, reqTracker.Add(100)) + require.NoError(t, reqTracker.Add(200)) + reqTracker.ReleaseAll() + assert.Equal(t, uint64(0), tracker.Current()) + }) + + t.Run("release all is idempotent", func(t *testing.T) { + tracker := NewConcurrentBytesTracker(10000, nil) + reqTracker := newRequestBytesTracker(tracker) + + require.NoError(t, reqTracker.Add(100)) + reqTracker.ReleaseAll() + assert.Equal(t, uint64(0), tracker.Current()) + + // Second call should be a no-op. 
+		reqTracker.ReleaseAll()
+		assert.Equal(t, uint64(0), tracker.Current())
+	})
+
+	t.Run("propagates add error from tracker", func(t *testing.T) {
+		tracker := NewConcurrentBytesTracker(50, nil)
+		reqTracker := newRequestBytesTracker(tracker)
+
+		require.NoError(t, reqTracker.Add(30))
+		require.Error(t, reqTracker.Add(30)) // would exceed 50
+		assert.Equal(t, uint64(30), tracker.Current())
+		assert.Equal(t, uint64(30), reqTracker.Total())
+	})
+}
+
+func TestTrackingBytesLimiter_Basic(t *testing.T) {
+	t.Run("reserves bytes through inner limiter", func(t *testing.T) {
+		inner := newMockBytesLimiter(1000)
+		tracker := NewConcurrentBytesTracker(10000, nil)
+		reqTracker := newRequestBytesTracker(tracker)
+		limiter := newTrackingBytesLimiter(inner, reqTracker)
+
+		require.NoError(t, limiter.ReserveWithType(100, store.PostingsFetched))
+		assert.Equal(t, uint64(100), inner.Reserved())
+	})
+
+	t.Run("tracks bytes in request tracker", func(t *testing.T) {
+		inner := newMockBytesLimiter(1000)
+		tracker := NewConcurrentBytesTracker(10000, nil)
+		reqTracker := newRequestBytesTracker(tracker)
+		limiter := newTrackingBytesLimiter(inner, reqTracker)
+
+		require.NoError(t, limiter.ReserveWithType(100, store.PostingsFetched))
+		assert.Equal(t, uint64(100), tracker.Current())
+		assert.Equal(t, uint64(100), reqTracker.Total())
+	})
+
+	t.Run("multiple limiters share request tracker", func(t *testing.T) {
+		tracker := NewConcurrentBytesTracker(10000, nil)
+		reqTracker := newRequestBytesTracker(tracker)
+
+		limiter1 := newTrackingBytesLimiter(newMockBytesLimiter(1000), reqTracker)
+		limiter2 := newTrackingBytesLimiter(newMockBytesLimiter(1000), reqTracker)
+		limiter3 := newTrackingBytesLimiter(newMockBytesLimiter(1000), reqTracker)
+
+		require.NoError(t, limiter1.ReserveWithType(100, store.PostingsFetched))
+		require.NoError(t, limiter2.ReserveWithType(200, store.SeriesFetched))
+		require.NoError(t, limiter3.ReserveWithType(300, store.ChunksFetched))
+		assert.Equal(t, uint64(600), tracker.Current())
+		assert.Equal(t, uint64(600), reqTracker.Total())
+
+		reqTracker.ReleaseAll()
+		assert.Equal(t, uint64(0), tracker.Current())
+	})
+}
+
+func TestRequestBytesTracker_PanicRecovery(t *testing.T) {
+	tracker := NewConcurrentBytesTracker(10000, nil)
+
+	func() {
+		reqTracker := newRequestBytesTracker(tracker)
+		defer func() { recover() }()
+		defer reqTracker.ReleaseAll()
+
+		require.NoError(t, reqTracker.Add(100))
+		panic("simulated panic")
+	}()
+
+	assert.Equal(t, uint64(0), tracker.Current())
+}
+
+func TestProperty_BytesLimiterIntegration(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		numReserves := rng.Intn(50) + 1
+
+		inner := newMockBytesLimiter(uint64(100) * 1024 * 1024 * 1024)
+		tracker := NewConcurrentBytesTracker(uint64(100)*1024*1024*1024, nil)
+		reqTracker := newRequestBytesTracker(tracker)
+		limiter := newTrackingBytesLimiter(inner, reqTracker)
+
+		var totalBytes uint64
+		for range numReserves {
+			bytes := uint64(rng.Intn(1024*1024)) + 1
+			if err := limiter.ReserveWithType(bytes, store.StoreDataType(rng.Intn(6))); err != nil {
+				return false
+			}
+			totalBytes += bytes
+		}
+
+		if tracker.Current() != totalBytes || inner.Reserved() != totalBytes || reqTracker.Total() != totalBytes {
+			return false
+		}
+
+		reqTracker.ReleaseAll()
+		return tracker.Current() == 0
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestPropertyReserveWithTypePropagatesAddError(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		limit := uint64(rng.Intn(10*1024*1024)) + 1
+
+		tracker := NewConcurrentBytesTracker(limit, nil)
+		reqTracker := newRequestBytesTracker(tracker)
+		inner := newMockBytesLimiter(^uint64(0))
+		limiter := newTrackingBytesLimiter(inner, reqTracker)
+
+		numReserves := rng.Intn(20) + 1
+		var trackedBytes uint64
+
+		for range numReserves {
+			bytes := uint64(rng.Intn(2*1024*1024)) + 1
+			wouldExceed := trackedBytes+bytes > limit
+
+			err := limiter.ReserveWithType(bytes, store.StoreDataType(rng.Intn(6)))
+
+			if wouldExceed && err == nil {
+				return false
+			}
+			if !wouldExceed && err != nil {
+				return false
+			}
+			if err == nil {
+				trackedBytes += bytes
+			}
+		}
+		return true
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}
+
+func TestPropertyReleaseWorksCorrectlyAfterAddError(t *testing.T) {
+	f := func(seed int64) bool {
+		rng := rand.New(rand.NewSource(seed))
+		limit := uint64(rng.Intn(1024)) + 1
+
+		tracker := NewConcurrentBytesTracker(limit, nil)
+		reqTracker := newRequestBytesTracker(tracker)
+		inner := newMockBytesLimiter(^uint64(0))
+		limiter := newTrackingBytesLimiter(inner, reqTracker)
+
+		numReserves := rng.Intn(10) + 1
+		sawError := false
+
+		for range numReserves {
+			bytes := uint64(rng.Intn(int(limit))) + 1
+			if limiter.ReserveWithType(bytes, store.StoreDataType(rng.Intn(6))) != nil {
+				sawError = true
+			}
+		}
+
+		if !sawError {
+			if limiter.ReserveWithType(limit+1, store.PostingsFetched) == nil {
+				return false
+			}
+		}
+
+		reqTracker.ReleaseAll()
+		return tracker.Current() == 0
+	}
+	require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 100}))
+}