126 changes: 97 additions & 29 deletions controlplane/worker_mgr.go
@@ -163,28 +163,31 @@ func waitForWorker(socketPath, bearerToken string, timeout time.Duration) (*flig
 }
 
 // doHealthCheck performs a HealthCheck action on the worker.
+// The server sends exactly one Result message with {"healthy": true, ...}.
 func doHealthCheck(ctx context.Context, client *flightsql.Client) error {
 	// Use the underlying flight client for custom actions.
 	// flightsql.Client.Client is a flight.Client interface which embeds
 	// FlightServiceClient, giving us access to DoAction directly.
 	stream, err := client.Client.DoAction(ctx, &flight.Action{Type: "HealthCheck"})
 	if err != nil {
-		return err
+		return fmt.Errorf("health check action: %w", err)
 	}
 
-	for {
-		msg, err := stream.Recv()
-		if err != nil {
-			break
-		}
-		var body struct {
-			Healthy bool `json:"healthy"`
-		}
-		if err := json.Unmarshal(msg.Body, &body); err == nil && body.Healthy {
-			return nil
-		}
+	msg, err := stream.Recv()
+	if err != nil {
+		return fmt.Errorf("health check recv: %w", err)
 	}
+
+	var body struct {
+		Healthy bool `json:"healthy"`
+	}
-	return fmt.Errorf("worker not healthy")
+	if err := json.Unmarshal(msg.Body, &body); err != nil {
+		return fmt.Errorf("health check unmarshal: %w", err)
+	}
+	if !body.Healthy {
+		return fmt.Errorf("worker reported unhealthy")
+	}
+	return nil
 }
 
 // SetSessionCounter sets the session counter for load balancing.
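The rewritten doHealthCheck assumes the worker answers the HealthCheck action with exactly one Result whose body is a JSON object containing "healthy". For context, here is a minimal sketch of a worker-side handler that would satisfy that contract, assuming the worker exposes a plain Arrow Flight server; the worker's server code is not part of this diff, and the workerServer type and the module version in the import path are illustrative:

package worker

import (
	"encoding/json"

	"github.com/apache/arrow/go/v17/arrow/flight" // version assumed; match the repo's go.mod
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// workerServer is a hypothetical stand-in for the worker's Flight server.
type workerServer struct {
	flight.BaseFlightServer
}

// DoAction answers the custom HealthCheck action with exactly one Result,
// which is the shape doHealthCheck above now expects.
func (s *workerServer) DoAction(action *flight.Action, stream flight.FlightService_DoActionServer) error {
	if action.Type != "HealthCheck" {
		return status.Errorf(codes.Unimplemented, "unknown action %q", action.Type)
	}
	body, err := json.Marshal(map[string]any{"healthy": true})
	if err != nil {
		return err
	}
	return stream.Send(&flight.Result{Body: body})
}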
@@ -314,23 +314,41 @@ func (p *FlightWorkerPool) RetireWorkerIfNoSessions(id int) {
 
 // retireWorkerProcess handles the actual process shutdown and socket cleanup.
 func retireWorkerProcess(w *ManagedWorker) {
-	slog.Info("Retiring worker.", "id", w.ID)
-
-	// Send SIGINT first so the worker can drain in-flight requests
-	if w.cmd.Process != nil {
-		_ = w.cmd.Process.Signal(os.Interrupt)
-	}
-
-	// Wait up to 3s for graceful exit. The worker just had its session
-	// destroyed and should exit almost immediately.
+	// Check if the process already exited before we try to retire it.
+	// This happens when a worker crashes and the client disconnect triggers
+	// RetireWorker before the health check loop detects the crash.
+	alreadyDead := false
 	select {
 	case <-w.done:
-	case <-time.After(3 * time.Second):
-		slog.Warn("Worker did not exit in time, killing.", "id", w.ID)
+		alreadyDead = true
+	default:
+	}
+
+	if alreadyDead {
+		exitCode := -1
+		if w.cmd.ProcessState != nil {
+			exitCode = w.cmd.ProcessState.ExitCode()
+		}
+		slog.Warn("Retiring worker that already exited unexpectedly.", "id", w.ID, "exit_code", exitCode, "error", w.exitErr)
+	} else {
+		slog.Info("Retiring worker.", "id", w.ID)
+
+		// Send SIGINT first so the worker can drain in-flight requests
 		if w.cmd.Process != nil {
-			_ = w.cmd.Process.Kill()
+			_ = w.cmd.Process.Signal(os.Interrupt)
 		}
+
+		// Wait up to 3s for graceful exit. The worker just had its session
+		// destroyed and should exit almost immediately.
+		select {
+		case <-w.done:
+		case <-time.After(3 * time.Second):
+			slog.Warn("Worker did not exit in time, killing.", "id", w.ID)
+			if w.cmd.Process != nil {
+				_ = w.cmd.Process.Kill()
+			}
+			<-w.done
+		}
-		<-w.done
 	}
 
 	// Close gRPC client after the process has exited
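The new already-exited branch relies on w.done being closed, and w.exitErr and w.cmd.ProcessState being populated, as soon as the worker process dies. That wiring lives outside this hunk; a sketch of the shape it presumably takes in the same package, with startExitWatcher as an illustrative name rather than code from this PR:

// startExitWatcher shows the assumed contract: exactly one goroutine calls
// cmd.Wait, records the error, and then closes done. os/exec guarantees that
// cmd.ProcessState is set once Wait returns, which is what the alreadyDead
// branch of retireWorkerProcess reads for the exit code.
func startExitWatcher(w *ManagedWorker) {
	go func() {
		w.exitErr = w.cmd.Wait()
		close(w.done)
	}()
}

Because exitErr is written before done is closed, any goroutine that observes the closed channel also sees the final exit error without extra locking.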
Expand Down Expand Up @@ -378,14 +399,23 @@ func (p *FlightWorkerPool) ShutdownAll() {
// WorkerCrashHandler is called when a worker crash is detected, before respawning.
type WorkerCrashHandler func(workerID int)

// maxConsecutiveHealthFailures is the number of consecutive health check failures
// before a worker is force-killed. With a typical 2s health check interval,
// this means ~6s of unresponsiveness triggers retirement.
const maxConsecutiveHealthFailures = 3

// HealthCheckLoop periodically checks worker health and handles crashed workers.
// In the elastic 1:1 model, crashed workers with active sessions trigger crash
// notification (so sessions see errors), and the dead worker is cleaned up.
// Workers without sessions are simply retired.
// Workers that fail maxConsecutiveHealthFailures health checks in a row are
// force-killed and their sessions notified.
func (p *FlightWorkerPool) HealthCheckLoop(ctx context.Context, interval time.Duration, onCrash ...WorkerCrashHandler) {
ticker := time.NewTicker(interval)
defer ticker.Stop()

failures := make(map[int]int) // workerID → consecutive failure count

for {
select {
case <-ctx.Done():
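With the new constant, the time to force-kill an unresponsive worker is roughly interval × maxConsecutiveHealthFailures. A sketch of how the control plane might start the loop under that assumption, reusing the file's existing imports (context, log/slog, time); startHealthChecks and the inline handler are illustrative, not code from this PR:

// startHealthChecks is a hypothetical wiring example: with a 2s interval and
// maxConsecutiveHealthFailures = 3, an unresponsive worker is force-killed
// after roughly 6s, matching the comment on the constant.
func startHealthChecks(ctx context.Context, pool *FlightWorkerPool) {
	onCrash := func(workerID int) {
		// A real handler would fail any sessions pinned to this worker so
		// clients see an error instead of a hang.
		slog.Warn("Worker crashed; notifying its sessions.", "id", workerID)
	}
	go pool.HealthCheckLoop(ctx, 2*time.Second, onCrash)
}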
@@ -403,6 +433,7 @@ func (p *FlightWorkerPool) HealthCheckLoop(ctx context.Context, interval time.Du
 			case <-ctx.Done():
 				return
 			case <-w.done:
+				delete(failures, w.ID)
 				// Check if already cleaned up by RetireWorker (intentional shutdown).
 				// If so, skip — this is not a crash.
 				p.mu.Lock()
Expand All @@ -426,10 +457,47 @@ func (p *FlightWorkerPool) HealthCheckLoop(ctx context.Context, interval time.Du
default:
// Worker is alive, do a health check
hctx, cancel := context.WithTimeout(ctx, 3*time.Second)
if err := doHealthCheck(hctx, w.client); err != nil {
slog.Warn("Worker health check failed.", "id", w.ID, "error", err)
}
err := doHealthCheck(hctx, w.client)
cancel()

if err != nil {
failures[w.ID]++
count := failures[w.ID]
slog.Warn("Worker health check failed.", "id", w.ID, "error", err, "consecutive_failures", count)

if count >= maxConsecutiveHealthFailures {
slog.Error("Worker unresponsive, force-killing.", "id", w.ID, "consecutive_failures", count)
delete(failures, w.ID)

p.mu.Lock()
_, stillInPool := p.workers[w.ID]
if stillInPool {
delete(p.workers, w.ID)
}
p.mu.Unlock()

if stillInPool {
for _, h := range onCrash {
h(w.ID)
}
// Skip SIGINT (unlike retireWorkerProcess) since the worker
// has already proven unresponsive. Go straight to SIGKILL.
go func() {
if w.cmd.Process != nil {
_ = w.cmd.Process.Kill()
}
<-w.done
slog.Warn("Force-killed worker exited.", "id", w.ID, "error", w.exitErr)
if w.client != nil {
_ = w.client.Close()
}
_ = os.Remove(w.socketPath)
}()
}
}
} else {
delete(failures, w.ID)
}
}
}
}