diff --git a/internal/localruntime/socket.go b/internal/localruntime/socket.go index 7bf99cf..3ec704b 100644 --- a/internal/localruntime/socket.go +++ b/internal/localruntime/socket.go @@ -1,9 +1,11 @@ package localruntime import ( + "errors" "fmt" "os" "path/filepath" + "syscall" ) func DefaultSocketPath() string { @@ -14,5 +16,31 @@ func DefaultSocketPath() string { } func EnsureSocketDir(socketPath string) error { - return os.MkdirAll(filepath.Dir(socketPath), 0o700) + dir := filepath.Dir(socketPath) + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + info, err := os.Lstat(dir) + if err != nil { + return err + } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("%s is a symlink", dir) + } + if !info.IsDir() { + return fmt.Errorf("%s is not a directory", dir) + } + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return errors.New("socket directory ownership is unavailable") + } + if int(stat.Uid) != os.Getuid() { + return fmt.Errorf("socket directory %s is owned by uid %d, want %d", dir, stat.Uid, os.Getuid()) + } + if info.Mode().Perm() != 0o700 { + if err := os.Chmod(dir, 0o700); err != nil { + return err + } + } + return nil } diff --git a/internal/localruntime/socket_test.go b/internal/localruntime/socket_test.go new file mode 100644 index 0000000..508878c --- /dev/null +++ b/internal/localruntime/socket_test.go @@ -0,0 +1,56 @@ +package localruntime + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestEnsureSocketDirTightensOwnerWritableDirectory(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix socket directory permissions are not portable to windows") + } + dir := filepath.Join(t.TempDir(), "socket-dir") + if err := os.Mkdir(dir, 0o777); err != nil { + t.Fatalf("Mkdir() error = %v", err) + } + if err := os.Chmod(dir, 0o777); err != nil { + t.Fatalf("Chmod() error = %v", err) + } + + if err := EnsureSocketDir(filepath.Join(dir, "kontext.sock")); err != nil { + t.Fatalf("EnsureSocketDir() error = %v", err) + } + info, err := os.Stat(dir) + if err != nil { + t.Fatalf("Stat() error = %v", err) + } + if got := info.Mode().Perm(); got != 0o700 { + t.Fatalf("socket dir mode = %#o, want 0700", got) + } +} + +func TestEnsureSocketDirRejectsNonDirectoryParent(t *testing.T) { + parent := filepath.Join(t.TempDir(), "socket-parent") + if err := os.WriteFile(parent, []byte("not a directory"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if err := EnsureSocketDir(filepath.Join(parent, "kontext.sock")); err == nil { + t.Fatal("EnsureSocketDir() error = nil, want failure for non-directory parent") + } +} + +func TestEnsureSocketDirRejectsSymlinkParent(t *testing.T) { + target := filepath.Join(t.TempDir(), "target") + if err := os.Mkdir(target, 0o700); err != nil { + t.Fatalf("Mkdir() error = %v", err) + } + link := filepath.Join(t.TempDir(), "socket-link") + if err := os.Symlink(target, link); err != nil { + t.Fatalf("Symlink() error = %v", err) + } + if err := EnsureSocketDir(filepath.Join(link, "kontext.sock")); err == nil { + t.Fatal("EnsureSocketDir() error = nil, want failure for symlink parent") + } +}