Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions internal/guard/app/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/fs"
"mime"
Expand Down Expand Up @@ -154,13 +155,41 @@ func (s *Server) RuntimeCore() *runtimecore.Core {
return s.core
}

func (s *Server) ListenAndServe(addr string) error {
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return nil
default:
}
httpServer := &http.Server{
Addr: addr,
Handler: s.Handler(),
ReadHeaderTimeout: 5 * time.Second,
}
return httpServer.ListenAndServe()
errs := make(chan error, 1)
go func() {
errs <- httpServer.ListenAndServe()
}()
select {
case err := <-errs:
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
case <-ctx.Done():
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
return err
}
if err := <-errs; err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
}
}

func (s *Server) routes() {
Expand Down
4 changes: 2 additions & 2 deletions internal/guard/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func runDaemon(ctx context.Context, args []string, out io.Writer) error {
if err != nil {
return err
}
if err := runtimeService.Start(context.Background()); err != nil {
if err := runtimeService.Start(ctx); err != nil {
return fmt.Errorf("local runtime start: %w", err)
}
defer runtimeService.Stop()
Expand All @@ -184,7 +184,7 @@ func runDaemon(ctx context.Context, args []string, out io.Writer) error {
if !*noOpen {
_ = browser.OpenURL("http://" + *addr)
}
return localServer.ListenAndServe(*addr)
return localServer.ListenAndServe(ctx, *addr)
}

func localJudgeStatusLine(localJudge judge.Judge) string {
Expand Down
71 changes: 71 additions & 0 deletions internal/guard/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"

"github.com/kontext-security/kontext-cli/internal/guard/judge"
"github.com/kontext-security/kontext-cli/internal/guard/judgeruntime"
Expand Down Expand Up @@ -240,6 +244,46 @@ func TestPrintHookStatusReportsGuardAndHostedConflict(t *testing.T) {
}
}

func TestRunDaemonReturnsWhenContextCanceled(t *testing.T) {
probe, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Skipf("local TCP listen unavailable: %v", err)
}
_ = probe.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dir := t.TempDir()
socketPath := filepath.Join("/tmp", fmt.Sprintf("kontext-%d.sock", time.Now().UnixNano()))
defer os.Remove(socketPath)
out := newNotifyWriter("Kontext Guard local daemon listening")
errs := make(chan error, 1)
go func() {
errs <- runDaemon(ctx, []string{
"--skip-hook-install",
"--no-open",
"--addr", "127.0.0.1:0",
"--db", filepath.Join(dir, "guard.db"),
"--socket", socketPath,
}, out)
}()
select {
case <-out.seen:
case err := <-errs:
t.Fatalf("runDaemon() returned before startup: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("runDaemon() did not start")
}
cancel()
select {
case err := <-errs:
if err != nil {
t.Fatalf("runDaemon() error = %v, want nil", err)
}
case <-time.After(5 * time.Second):
t.Fatal("runDaemon() did not return after context cancellation")
}
}

func TestValidateLocalJudgeURLRejectsHostedURL(t *testing.T) {
if err := validateLocalJudgeURL("https://api.example.com/v1"); err == nil {
t.Fatal("validateLocalJudgeURL() error = nil, want hosted URL rejection")
Expand Down Expand Up @@ -446,3 +490,30 @@ type fakeJudge struct{}
func (fakeJudge) Decide(context.Context, judge.Input) (judge.Result, error) {
return judge.Result{}, nil
}

type notifyWriter struct {
needle string
seen chan struct{}
once sync.Once
mu sync.Mutex
buf bytes.Buffer
}

func newNotifyWriter(needle string) *notifyWriter {
return &notifyWriter{
needle: needle,
seen: make(chan struct{}),
}
}

func (w *notifyWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
n, err := w.buf.Write(p)
if strings.Contains(w.buf.String(), w.needle) {
w.once.Do(func() {
close(w.seen)
})
}
return n, err
}
Loading