diff --git a/fn/context_guard.go b/fn/context_guard.go index 3868a3ff02b..94b2d8ee8f1 100644 --- a/fn/context_guard.go +++ b/fn/context_guard.go @@ -51,6 +51,10 @@ func (g *ContextGuard) Quit() { cancel() } + // Clear cancelFns. It is safe to use nil, because no write + // operations to it can happen after g.quit is closed. + g.cancelFns = nil + close(g.quit) }) } @@ -149,7 +153,7 @@ func (g *ContextGuard) Create(ctx context.Context, } if opts.blocking { - g.ctxBlocking(ctx, cancel) + g.ctxBlocking(ctx) return ctx, cancel } @@ -169,9 +173,10 @@ func (g *ContextGuard) Create(ctx context.Context, return ctx, cancel } -// ctxQuitUnsafe spins off a goroutine that will block until the passed context -// is cancelled or until the quit channel has been signaled after which it will -// call the passed cancel function and decrement the wait group. +// ctxQuitUnsafe increases the wait group counter, waits until the context is +// cancelled and decreases the wait group counter. It stores the passed cancel +// function and returns a wrapped version, which removed the stored one and +// calls it. The Quit method calls all the stored cancel functions. // // NOTE: the caller must hold the ContextGuard's mutex before calling this // function. @@ -181,35 +186,27 @@ func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context, cancel = g.addCancelFnUnsafe(cancel) g.wg.Add(1) - go func() { - defer cancel() - defer g.wg.Done() - - select { - case <-g.quit: - case <-ctx.Done(): - } - }() + // We don't have to wait on g.quit here: g.quit can be closed only in + // the Quit method, which also closes the context we are waiting for. + context.AfterFunc(ctx, func() { + g.wg.Done() + }) return cancel } -// ctxBlocking spins off a goroutine that will block until the passed context -// is cancelled after which it will call the passed cancel function and -// decrement the wait group. -func (g *ContextGuard) ctxBlocking(ctx context.Context, - cancel context.CancelFunc) { - +// ctxBlocking increases the wait group counter, waits until the context is +// cancelled and decreases the wait group counter. +// +// NOTE: the caller must hold the ContextGuard's mutex before calling this +// function. +func (g *ContextGuard) ctxBlocking(ctx context.Context) { g.wg.Add(1) - go func() { - defer cancel() - defer g.wg.Done() - select { - case <-ctx.Done(): - } - }() + context.AfterFunc(ctx, func() { + g.wg.Done() + }) } // addCancelFnUnsafe adds a context cancel function to the manager and returns a diff --git a/fn/context_guard_test.go b/fn/context_guard_test.go index e13cba5dc05..576ca5364f8 100644 --- a/fn/context_guard_test.go +++ b/fn/context_guard_test.go @@ -2,8 +2,11 @@ package fn import ( "context" + "runtime" "testing" "time" + + "github.com/stretchr/testify/require" ) // TestContextGuard tests the behaviour of the ContextGuard. @@ -298,6 +301,12 @@ func TestContextGuard(t *testing.T) { case <-time.After(time.Second): t.Fatalf("timeout") } + + // Cancel the context. + cancel() + + // Make sure wg's counter gets to 0 eventually. + g.WgWait() }) // Test that if we add the CustomTimeoutCGOpt option, then the context @@ -433,3 +442,36 @@ func TestContextGuard(t *testing.T) { } }) } + +// TestContextGuardCountGoroutines makes sure that ContextGuard doesn't create +// any goroutines while waiting for contexts. +func TestContextGuardCountGoroutines(t *testing.T) { + // NOTE: t.Parallel() is not called in this test because it relies on an + // accurate count of active goroutines. Running other tests in parallel + // would introduce additional goroutines, leading to unreliable results. + + g := NewContextGuard() + + ctx, cancel := context.WithCancel(context.Background()) + + // Count goroutines before contexts are created. + count1 := runtime.NumGoroutine() + + // Create 1000 contexts of each type. + for i := 0; i < 1000; i++ { + _, _ = g.Create(ctx) + _, _ = g.Create(ctx, WithBlockingCG()) + _, _ = g.Create(ctx, WithTimeoutCG()) + _, _ = g.Create(ctx, WithBlockingCG(), WithTimeoutCG()) + } + + // Make sure no new goroutine was launched. + count2 := runtime.NumGoroutine() + require.LessOrEqual(t, count2, count1) + + // Cancel root context. + cancel() + + // Make sure wg's counter gets to 0 eventually. + g.WgWait() +}