diff --git a/client.go b/client.go index 4d54b39a..7336005e 100644 --- a/client.go +++ b/client.go @@ -1071,10 +1071,15 @@ func (c *Client[TTx]) Start(ctx context.Context) error { return nil }(); err != nil { - defer stopped() + // If the fetch context was cancelled with ErrStop as its cause (i.e. by Stop()), + // just close the stopped channel and return nil. Stop() handles cleanup via finalizeStop(). if errors.Is(context.Cause(fetchCtx), startstop.ErrStop) { + stopped() return nil } + + // For real startup failures, have StartFailed reset start/stop state so that Start() can be called again. + c.baseStartStop.StartFailed(stopped) return err } diff --git a/client_test.go b/client_test.go index 90ccc43e..45669938 100644 --- a/client_test.go +++ b/client_test.go @@ -7115,6 +7115,31 @@ func Test_Client_Start_Error(t *testing.T) { require.ErrorAs(t, err, &pgErr) require.Equal(t, pgerrcode.InvalidCatalogName, pgErr.Code) }) + + t.Run("CanRestartAfterFailure", func(t *testing.T) { + t.Parallel() + + // Point the pool at a non-existent database to trigger a startup failure. + dbConfig := riversharedtest.DBPool(ctx, t).Config().Copy() + dbConfig.ConnConfig.Database = "does-not-exist-and-dont-create-it" + + dbPool, err := pgxpool.NewWithConfig(ctx, dbConfig) + require.NoError(t, err) + + config := newTestConfig(t, "") + + client := newTestClient(t, dbPool, config) + + // The first Start() should fail with a database error. + err = client.Start(ctx) + require.Error(t, err, "first Start() should fail with database error") + + // The second Start() should also fail with an error, NOT return nil. + // An error here shows that the client's internal state was reset after + // the first failure, so that it actually attempted startup again. 
+ err = client.Start(ctx) + require.Error(t, err, "second Start() should return an error, not nil; client state should be reset after failed start") + }) } func Test_NewClient_BaseServiceName(t *testing.T) { diff --git a/rivershared/startstop/start_stop.go b/rivershared/startstop/start_stop.go index b03cbf24..1f13e474 100644 --- a/rivershared/startstop/start_stop.go +++ b/rivershared/startstop/start_stop.go @@ -136,6 +136,28 @@ func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, f } } +// StartFailed should be called when a service fails to start after StartInit. +// It invokes the cancel function with ErrStop (if one is set), calls stopped, and +// then, under the mutex, clears isRunning/started/stopped (no-op if not running) so Start can be called again. +// +// This should not be used when a Stop is already in progress (ErrStop), because +// Stop will handle cleanup via finalizeStop. +func (s *BaseStartStop) StartFailed(stopped func()) { + if s.cancelFunc != nil { + s.cancelFunc(ErrStop) + } + stopped() // NOTE(review): called before s.mu is taken; ordering looks deliberate, confirm against Stop before changing + + s.mu.Lock() + defer s.mu.Unlock() + if !s.isRunning { + return + } + s.isRunning = false + s.started = nil + s.stopped = nil +} + // Started returns a channel that's closed when a service finishes starting, or // if failed to start and is stopped instead. It can be used in conjunction with // WaitAllStarted to verify startup of a constellation of services.