diff --git a/pkg/ring/basic_lifecycler.go b/pkg/ring/basic_lifecycler.go index fb751e4fa1d..956af373668 100644 --- a/pkg/ring/basic_lifecycler.go +++ b/pkg/ring/basic_lifecycler.go @@ -16,6 +16,7 @@ import ( "github.com/cortexproject/cortex/pkg/ring/kv" "github.com/cortexproject/cortex/pkg/util/services" + utiltimer "github.com/cortexproject/cortex/pkg/util/timer" ) type BasicLifecyclerDelegate interface { @@ -327,7 +328,9 @@ func (l *BasicLifecycler) waitStableTokens(ctx context.Context, period time.Dura // The first observation will occur after the specified period. level.Info(l.logger).Log("msg", "waiting stable tokens", "ring", l.ringName) - observeChan := time.After(period) + observeTimer := time.NewTimer(period) + defer utiltimer.StopAndDrainTimer(observeTimer) + observeChan := observeTimer.C for { select { @@ -335,7 +338,7 @@ func (l *BasicLifecycler) waitStableTokens(ctx context.Context, period time.Dura if !l.verifyTokens(ctx) { // The verification has failed level.Info(l.logger).Log("msg", "tokens verification failed, keep observing", "ring", l.ringName) - observeChan = time.After(period) + utiltimer.ResetTimer(observeTimer, period) break } diff --git a/pkg/ring/kv/dynamodb/client.go b/pkg/ring/kv/dynamodb/client.go index a7c2bd5a98d..a1bd15a24a1 100644 --- a/pkg/ring/kv/dynamodb/client.go +++ b/pkg/ring/kv/dynamodb/client.go @@ -13,6 +13,7 @@ import ( "github.com/cortexproject/cortex/pkg/ring/kv/codec" "github.com/cortexproject/cortex/pkg/util/backoff" + utiltimer "github.com/cortexproject/cortex/pkg/util/timer" ) const ( @@ -185,7 +186,7 @@ func (c *Client) CAS(ctx context.Context, key string, f func(in any) (out any, r continue } - putRequests := map[dynamodbKey]dynamodbItem{} + putRequests := make(map[dynamodbKey]dynamodbItem, len(buf)) for childKey, bytes := range buf { version := int64(0) if ddbItem, ok := resp[childKey]; ok { @@ -267,7 +268,7 @@ func (c *Client) WatchKey(ctx context.Context, key string, f func(any) bool) { } bo.Reset() - 
resetTimer(syncTimer, c.pullerSyncTime) + utiltimer.ResetTimer(syncTimer, c.pullerSyncTime) select { case <-ctx.Done(): return @@ -305,7 +306,7 @@ func (c *Client) WatchPrefix(ctx context.Context, prefix string, f func(string, } bo.Reset() - resetTimer(syncTimer, c.pullerSyncTime) + utiltimer.ResetTimer(syncTimer, c.pullerSyncTime) select { case <-ctx.Done(): return @@ -314,16 +315,6 @@ func (c *Client) WatchPrefix(ctx context.Context, prefix string, f func(string, } } -func resetTimer(timer *time.Timer, d time.Duration) { - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timer.Reset(d) -} - func (c *Client) decodeMultikey(data map[string]dynamodbItem) (codec.MultiKey, error) { multiKeyData := make(map[string][]byte, len(data)) for key, ddbItem := range data { diff --git a/pkg/ring/kv/dynamodb/client_timer_benchmark_test.go b/pkg/ring/kv/dynamodb/client_timer_benchmark_test.go index 3fb494ef685..0f948037766 100644 --- a/pkg/ring/kv/dynamodb/client_timer_benchmark_test.go +++ b/pkg/ring/kv/dynamodb/client_timer_benchmark_test.go @@ -3,6 +3,8 @@ package dynamodb import ( "testing" "time" + + utiltimer "github.com/cortexproject/cortex/pkg/util/timer" ) func BenchmarkWatchLoopWaitWithTimeAfter(b *testing.B) { @@ -30,7 +32,7 @@ func BenchmarkWatchLoopWaitWithReusableTimer(b *testing.B) { b.ReportAllocs() for b.Loop() { - resetTimer(timer, interval) + utiltimer.ResetTimer(timer, interval) select { case <-ctx.Done(): diff --git a/pkg/ring/lifecycler.go b/pkg/ring/lifecycler.go index 6038de2277b..1db33a929f0 100644 --- a/pkg/ring/lifecycler.go +++ b/pkg/ring/lifecycler.go @@ -21,6 +21,7 @@ import ( "github.com/cortexproject/cortex/pkg/ring/kv" "github.com/cortexproject/cortex/pkg/util/flagext" "github.com/cortexproject/cortex/pkg/util/services" + utiltimer "github.com/cortexproject/cortex/pkg/util/timer" ) var ( @@ -526,11 +527,34 @@ func (i *Lifecycler) loop(ctx context.Context) error { } // We do various period tasks + var autoJoinTimer *time.Timer 
var autoJoinAfter <-chan time.Time + var observeTimer *time.Timer var observeChan <-chan time.Time + setAutoJoinAfter := func(d time.Duration) { + if autoJoinTimer == nil { + autoJoinTimer = time.NewTimer(d) + } else { + utiltimer.ResetTimer(autoJoinTimer, d) + } + autoJoinAfter = autoJoinTimer.C + } + + setObserveAfter := func(d time.Duration) { + if observeTimer == nil { + observeTimer = time.NewTimer(d) + } else { + utiltimer.ResetTimer(observeTimer, d) + } + observeChan = observeTimer.C + } + + defer func() { utiltimer.StopAndDrainTimer(autoJoinTimer) }() // closures: both timers are still nil here, so read them when loop returns + defer func() { utiltimer.StopAndDrainTimer(observeTimer) }() + if i.autoJoinOnStartup { - autoJoinAfter = time.After(i.cfg.JoinAfter) + setAutoJoinAfter(i.cfg.JoinAfter) } var heartbeatTickerChan <-chan time.Time @@ -556,7 +580,7 @@ func (i *Lifecycler) loop(ctx context.Context) error { for { select { case <-i.autojoinChan: - autoJoinAfter = time.After(i.cfg.JoinAfter) + setAutoJoinAfter(i.cfg.JoinAfter) case <-autoJoinAfter: if joined { continue @@ -576,7 +600,7 @@ func (i *Lifecycler) loop(ctx context.Context) error { } level.Info(i.logger).Log("msg", "observing tokens before going ACTIVE", "ring", i.RingName) - observeChan = time.After(i.cfg.ObservePeriod) + setObserveAfter(i.cfg.ObservePeriod) } else { if err := i.autoJoin(context.Background(), i.getPreviousState(), addedInRing); err != nil { return errors.Wrapf(err, "failed to pick tokens in the KV store, ring: %s, state: %s", i.RingName, i.getPreviousState()) @@ -593,6 +617,7 @@ func (i *Lifecycler) loop(ctx context.Context) error { // When observing is done, observeChan is set to nil. 
observeChan = nil + utiltimer.StopAndDrainTimer(observeTimer) if s := i.GetState(); s != JOINING { level.Error(i.logger).Log("msg", "unexpected state while observing tokens", "state", s, "ring", i.RingName) } @@ -611,7 +636,7 @@ func (i *Lifecycler) loop(ctx context.Context) error { } else { level.Info(i.logger).Log("msg", "token verification failed, observing", "ring", i.RingName) // keep observing - observeChan = time.After(i.cfg.ObservePeriod) + setObserveAfter(i.cfg.ObservePeriod) } case <-heartbeatTickerChan: diff --git a/pkg/util/backoff/backoff.go b/pkg/util/backoff/backoff.go index 2146f3b928e..777025aae8b 100644 --- a/pkg/util/backoff/backoff.go +++ b/pkg/util/backoff/backoff.go @@ -6,6 +6,8 @@ import ( "fmt" "math/rand" "time" + + utiltimer "github.com/cortexproject/cortex/pkg/util/timer" ) // Config configures a Backoff @@ -29,6 +31,7 @@ type Backoff struct { numRetries int nextDelayMin time.Duration nextDelayMax time.Duration + waitTimer *time.Timer } // New creates a Backoff object. Pass a Context that can also terminate the operation. 
@@ -77,9 +80,18 @@ func (b *Backoff) Wait() {
 	sleepTime := b.NextDelay()
 
 	if b.Ongoing() {
+		// Reuse a single timer across calls instead of allocating one per wait.
+		if b.waitTimer != nil {
+			utiltimer.ResetTimer(b.waitTimer, sleepTime)
+		} else {
+			b.waitTimer = time.NewTimer(sleepTime)
+		}
+
 		select {
 		case <-b.ctx.Done():
-		case <-time.After(sleepTime):
+			// Stop the pending timer so the next Wait starts from a clean channel.
+			utiltimer.StopAndDrainTimer(b.waitTimer)
+		case <-b.waitTimer.C:
 		}
 	}
 }
diff --git a/pkg/util/backoff/backoff_test.go b/pkg/util/backoff/backoff_test.go
index 942cebb6a40..4613289ef25 100644
--- a/pkg/util/backoff/backoff_test.go
+++ b/pkg/util/backoff/backoff_test.go
@@ -4,6 +4,8 @@ import (
 	"context"
 	"testing"
 	"time"
+
+	utiltimer "github.com/cortexproject/cortex/pkg/util/timer"
 )
 
 func TestBackoff_NextDelay(t *testing.T) {
@@ -100,3 +102,68 @@ func TestBackoff_NextDelay(t *testing.T) {
 		})
 	}
 }
+
+// TestBackoff_WaitReusesTimer checks that consecutive Wait calls keep using
+// the timer allocated by the first call.
+func TestBackoff_WaitReusesTimer(t *testing.T) {
+	t.Parallel()
+
+	cfg := Config{
+		MinBackoff: time.Nanosecond,
+		MaxBackoff: time.Nanosecond,
+		MaxRetries: 0,
+	}
+	b := New(context.Background(), cfg)
+
+	b.Wait()
+	initial := b.waitTimer
+	if initial == nil {
+		t.Fatal("expected wait timer to be initialized")
+	}
+
+	b.Wait()
+	if initial != b.waitTimer {
+		t.Fatal("expected wait timer to be reused")
+	}
+}
+
+// TestBackoff_WaitReturnsWhenContextCancelled checks that Wait unblocks on
+// context cancellation and leaves the timer in a state where it can be re-armed.
+func TestBackoff_WaitReturnsWhenContextCancelled(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	t.Cleanup(cancel)
+
+	cfg := Config{
+		MinBackoff: time.Second,
+		MaxBackoff: time.Second,
+		MaxRetries: 0,
+	}
+	b := New(ctx, cfg)
+
+	// Cancel shortly after Wait has started blocking.
+	go func() {
+		time.Sleep(10 * time.Millisecond)
+		cancel()
+	}()
+
+	begin := time.Now()
+	b.Wait()
+
+	if elapsed := time.Since(begin); elapsed >= 900*time.Millisecond {
+		t.Fatal("expected Wait to return quickly after context cancellation")
+	}
+
+	if b.waitTimer == nil {
+		t.Fatal("expected wait timer to be initialized")
+	}
+
+	// Re-arm the timer and make sure it still delivers a tick.
+	utiltimer.ResetTimer(b.waitTimer, time.Nanosecond)
+	select {
+	case <-b.waitTimer.C:
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("expected wait timer to be reusable after cancellation")
+	}
+}
diff --git a/pkg/util/timer/timer.go b/pkg/util/timer/timer.go
new file mode 100644
index 00000000000..541ac291613
--- /dev/null
+++ b/pkg/util/timer/timer.go
@@ -0,0 +1,33 @@
+// Package timer provides helpers for safely stopping and reusing time.Timer values.
+package timer
+
+import "time"
+
+// StopAndDrainTimer stops timer and discards a tick that is already waiting on
+// its channel, so a later Reset cannot observe a stale expiry.
+//
+// A nil timer is a no-op. The drain is non-blocking, so the call is also safe
+// when the tick has already been consumed by the caller.
+func StopAndDrainTimer(timer *time.Timer) {
+	if timer == nil {
+		return
+	}
+
+	if !timer.Stop() {
+		// Stop reports false when the timer already expired or was stopped;
+		// take the pending tick if there is one, without blocking.
+		select {
+		case <-timer.C:
+		default:
+		}
+	}
+}
+
+// ResetTimer stops and drains timer, then re-arms it to fire after d.
+//
+// Unlike StopAndDrainTimer, timer must be non-nil: Reset is called on it
+// unconditionally.
+func ResetTimer(timer *time.Timer, d time.Duration) {
+	StopAndDrainTimer(timer)
+	timer.Reset(d)
+}
diff --git a/pkg/util/timer/timer_test.go b/pkg/util/timer/timer_test.go
new file mode 100644
index 00000000000..cf990dcbc4e
--- /dev/null
+++ b/pkg/util/timer/timer_test.go
@@ -0,0 +1,102 @@
+package timer
+
+import (
+	"testing"
+	"time"
+)
+
+func TestStopAndDrainTimer_NilTimer(t *testing.T) {
+	// Should not panic on nil timer.
+	StopAndDrainTimer(nil)
+}
+
+func TestStopAndDrainTimer_UnfiredTimer(t *testing.T) {
+	timer := time.NewTimer(time.Hour)
+	StopAndDrainTimer(timer)
+
+	// Channel should be empty after stop+drain.
+	select {
+	case <-timer.C:
+		t.Fatal("expected timer channel to be drained")
+	default:
+	}
+}
+
+func TestStopAndDrainTimer_FiredTimer(t *testing.T) {
+	timer := time.NewTimer(time.Nanosecond)
+	// Wait for it to fire.
+	time.Sleep(time.Millisecond)
+
+	StopAndDrainTimer(timer)
+
+	// Channel should be empty after stop+drain.
+	select {
+	case <-timer.C:
+		t.Fatal("expected timer channel to be drained")
+	default:
+	}
+}
+
+func TestResetTimer(t *testing.T) {
+	timer := time.NewTimer(time.Hour)
+
+	// Reset to a very short duration.
+	ResetTimer(timer, time.Nanosecond)
+
+	select {
+	case <-timer.C:
+		// Expected.
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("expected timer to fire after reset")
+	}
+}
+
+func TestResetTimer_AfterFired(t *testing.T) {
+	timer := time.NewTimer(time.Nanosecond)
+	// Wait for it to fire.
+	time.Sleep(time.Millisecond)
+	<-timer.C
+
+	// Reset after consuming the fired event.
+	ResetTimer(timer, time.Nanosecond)
+
+	select {
+	case <-timer.C:
+		// Expected.
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("expected timer to fire after reset")
+	}
+}
+
+// TestResetTimer_FiredNotDrained covers a reset while an expired tick is still unread.
+func TestResetTimer_FiredNotDrained(t *testing.T) {
+	timer := time.NewTimer(time.Nanosecond)
+	defer timer.Stop()
+	// Wait for it to fire, but leave the tick unconsumed.
+	time.Sleep(time.Millisecond)
+
+	// Re-arm far in the future; the earlier tick must not be delivered.
+	ResetTimer(timer, time.Hour)
+
+	select {
+	case <-timer.C:
+		t.Fatal("expected stale tick to be discarded by reset")
+	default:
+	}
+}
+
+func TestResetTimer_MultipleTimes(t *testing.T) {
+	timer := time.NewTimer(time.Hour)
+	defer timer.Stop()
+
+	for i := range 10 {
+		ResetTimer(timer, time.Nanosecond)
+
+		select {
+		case <-timer.C:
+			// Expected.
+		case <-time.After(100 * time.Millisecond):
+			t.Fatalf("iteration %d: expected timer to fire after reset", i)
+		}
+	}
+}