diff --git a/internal/guard/app/server/server.go b/internal/guard/app/server/server.go index cc162f2..eb16a8f 100644 --- a/internal/guard/app/server/server.go +++ b/internal/guard/app/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "errors" "fmt" "io/fs" "mime" @@ -154,13 +155,41 @@ func (s *Server) RuntimeCore() *runtimecore.Core { return s.core } -func (s *Server) ListenAndServe(addr string) error { +func (s *Server) ListenAndServe(ctx context.Context, addr string) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return nil + default: + } httpServer := &http.Server{ Addr: addr, Handler: s.Handler(), ReadHeaderTimeout: 5 * time.Second, } - return httpServer.ListenAndServe() + errs := make(chan error, 1) + go func() { + errs <- httpServer.ListenAndServe() + }() + select { + case err := <-errs: + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := httpServer.Shutdown(shutdownCtx); err != nil { + return err + } + if err := <-errs; err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil + } } func (s *Server) routes() { diff --git a/internal/guard/cli/cli.go b/internal/guard/cli/cli.go index 341c37e..a09a81d 100644 --- a/internal/guard/cli/cli.go +++ b/internal/guard/cli/cli.go @@ -172,7 +172,7 @@ func runDaemon(ctx context.Context, args []string, out io.Writer) error { if err != nil { return err } - if err := runtimeService.Start(context.Background()); err != nil { + if err := runtimeService.Start(ctx); err != nil { return fmt.Errorf("local runtime start: %w", err) } defer runtimeService.Stop() @@ -184,7 +184,7 @@ func runDaemon(ctx context.Context, args []string, out io.Writer) error { if !*noOpen { _ = browser.OpenURL("http://" + *addr) } - return localServer.ListenAndServe(*addr) + return localServer.ListenAndServe(ctx, *addr) } func localJudgeStatusLine(localJudge judge.Judge) string { diff --git a/internal/guard/cli/cli_test.go b/internal/guard/cli/cli_test.go index 6ecc43d..38dd5fd 100644 --- a/internal/guard/cli/cli_test.go +++ b/internal/guard/cli/cli_test.go @@ -4,12 +4,16 @@ import ( "bytes" "context" "encoding/json" + "fmt" + "net" "net/http" "net/http/httptest" "os" "path/filepath" "strings" + "sync" "testing" + "time" "github.com/kontext-security/kontext-cli/internal/guard/judge" "github.com/kontext-security/kontext-cli/internal/guard/judgeruntime" @@ -240,6 +244,46 @@ func TestPrintHookStatusReportsGuardAndHostedConflict(t *testing.T) { } } +func TestRunDaemonReturnsWhenContextCanceled(t *testing.T) { + probe, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Skipf("local TCP listen unavailable: %v", err) + } + _ = probe.Close() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dir := t.TempDir() + socketPath := filepath.Join("/tmp", fmt.Sprintf("kontext-%d.sock", time.Now().UnixNano())) + defer os.Remove(socketPath) + out := newNotifyWriter("Kontext Guard local daemon listening") + errs := make(chan error, 1) + go func() { + errs <- runDaemon(ctx, []string{ + "--skip-hook-install", + "--no-open", + "--addr", "127.0.0.1:0", + "--db", filepath.Join(dir, "guard.db"), + "--socket", socketPath, + }, out) + }() + select { + case <-out.seen: + case err := <-errs: + t.Fatalf("runDaemon() returned before startup: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("runDaemon() did not start") + } + cancel() + select { + case err := <-errs: + if err != nil { + t.Fatalf("runDaemon() error = %v, want nil", err) + } + case <-time.After(5 * time.Second): + t.Fatal("runDaemon() did not return after context cancellation") + } +} + func TestValidateLocalJudgeURLRejectsHostedURL(t *testing.T) { if err := validateLocalJudgeURL("https://api.example.com/v1"); err == nil { t.Fatal("validateLocalJudgeURL() error = nil, want hosted URL rejection") @@ -446,3 +490,30 @@ type fakeJudge struct{} func (fakeJudge) Decide(context.Context, judge.Input) (judge.Result, error) { return judge.Result{}, nil } + +type notifyWriter struct { + needle string + seen chan struct{} + once sync.Once + mu sync.Mutex + buf bytes.Buffer +} + +func newNotifyWriter(needle string) *notifyWriter { + return ¬ifyWriter{ + needle: needle, + seen: make(chan struct{}), + } +} + +func (w *notifyWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + n, err := w.buf.Write(p) + if strings.Contains(w.buf.String(), w.needle) { + w.once.Do(func() { + close(w.seen) + }) + } + return n, err +}