diff --git a/go.mod b/go.mod index 05ccb784..228186c0 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/hashicorp/go-retryablehttp v0.7.8 + github.com/huin/goupnp v1.3.0 github.com/knadh/koanf/parsers/json v1.0.0 github.com/knadh/koanf/providers/rawbytes v1.0.0 github.com/knadh/koanf/v2 v2.3.0 @@ -132,7 +133,6 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hdevalence/ed25519consensus v0.2.0 // indirect github.com/huandu/xstrings v1.3.2 // indirect - github.com/huin/goupnp v1.3.0 // indirect github.com/illarion/gonotify/v2 v2.0.3 // indirect github.com/jsimonetti/rtnetlink v1.4.0 // indirect github.com/klauspost/pgzip v1.2.6 // indirect diff --git a/peer/api.go b/peer/api.go new file mode 100644 index 00000000..05752402 --- /dev/null +++ b/peer/api.go @@ -0,0 +1,124 @@ +// Package peer implements the client side of "Share My Connection". api.go +// is the thin HTTP client for lantern-cloud's /v1/peer/* endpoints. +package peer + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/getlantern/radiance/common/settings" +) + +type RegisterRequest struct { + ExternalIP string `json:"external_ip"` + ExternalPort uint16 `json:"external_port"` + InternalPort uint16 `json:"internal_port"` +} + +type RegisterResponse struct { + RouteID string `json:"route_id"` + ServerConfig string `json:"server_config"` + HeartbeatIntervalSeconds int64 `json:"heartbeat_interval_seconds"` +} + +type LifecycleRequest struct { + RouteID string `json:"route_id"` +} + +// APIError carries the server's HTTP status and body. Callers map specific +// statuses to user-facing errors (404 → not registered, 422 → not reachable +// from the public internet, 503 → feature off). 
+type APIError struct { + Status int + Body string +} + +func (e *APIError) Error() string { + return fmt.Sprintf("peer api: status=%d body=%s", e.Status, e.Body) +} + +type API struct { + httpClient *http.Client + baseURL string + deviceID string +} + +// NewAPI constructs the client. baseURL must not have a trailing slash and +// must not include "/v1" — that's appended per-endpoint. +func NewAPI(httpClient *http.Client, baseURL, deviceID string) *API { + return &API{httpClient: httpClient, baseURL: baseURL, deviceID: deviceID} +} + +func (a *API) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) { + var resp RegisterResponse + if err := a.do(ctx, http.MethodPost, "/v1/peer/register", req, &resp); err != nil { + return nil, fmt.Errorf("register: %w", err) + } + return &resp, nil +} + +// Heartbeat extends the peer route's TTL. The server owner-gates via +// X-Lantern-Device-Id, so a leaked route_id can't be used by another device +// to keep the registration alive. 
+func (a *API) Heartbeat(ctx context.Context, routeID string) error { + if err := a.do(ctx, http.MethodPost, "/v1/peer/heartbeat", LifecycleRequest{RouteID: routeID}, nil); err != nil { + return fmt.Errorf("heartbeat: %w", err) + } + return nil +} + +func (a *API) Deregister(ctx context.Context, routeID string) error { + if err := a.do(ctx, http.MethodPost, "/v1/peer/deregister", LifecycleRequest{RouteID: routeID}, nil); err != nil { + return fmt.Errorf("deregister: %w", err) + } + return nil +} + +func (a *API) do(ctx context.Context, method, path string, body, out any) error { + var reqBody io.Reader + if body != nil { + buf, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("marshal request: %w", err) + } + reqBody = bytes.NewReader(buf) + } + r, err := http.NewRequestWithContext(ctx, method, a.baseURL+path, reqBody) + if err != nil { + return fmt.Errorf("build request: %w", err) + } + if body != nil { + r.Header.Set("Content-Type", "application/json") + } + r.Header.Set("X-Lantern-Device-Id", a.deviceID) + // Forward the same feature-override header that config/fetcher.go uses + // for /config-new requests, so QA can flip on `peer_proxy` ahead of the + // public-flag rollout via FeatureOverridesKey (RADIANCE_FEATURE_OVERRIDES). + // Without this the server-side gate rejects register/heartbeat/deregister + // regardless of the local toggle. 
+ if val := settings.GetString(settings.FeatureOverridesKey); val != "" { + r.Header.Set("X-Lantern-Feature-Override", val) + } + + resp, err := a.httpClient.Do(r) + if err != nil { + return fmt.Errorf("send: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + const maxBody = 4096 + buf, _ := io.ReadAll(io.LimitReader(resp.Body, maxBody)) + return &APIError{Status: resp.StatusCode, Body: string(bytes.TrimSpace(buf))} + } + if out != nil { + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return fmt.Errorf("decode response: %w", err) + } + } + return nil +} diff --git a/peer/peer.go b/peer/peer.go new file mode 100644 index 00000000..a70ad7ca --- /dev/null +++ b/peer/peer.go @@ -0,0 +1,433 @@ +package peer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "math/rand/v2" + "sync" + "time" + + "github.com/sagernet/sing-box/experimental/libbox" + + "github.com/getlantern/radiance/portforward" +) + +// Port range chosen to minimize collision risk on the typical home network, +// not to guarantee one. 30000–50000 sits above the well-known/system range +// (0–1023) and above the ports most services use by default (web/dev/dbs +// usually <30000). It overlaps both the IANA registered range (1024–49151) +// and the OS ephemeral range on some platforms (Linux's default +// net.ipv4.ip_local_port_range starts at 32768, Windows uses 49152+), so +// a collision is still possible. AddPortMapping surfaces the conflict and +// the peer.Client caller can retry with a fresh pick. 
+const ( + internalPortMin = 30000 + internalPortMax = 50000 +) + +type portForwarder interface { + MapPort(ctx context.Context, internalPort uint16, description string) (*portforward.Mapping, error) + UnmapPort(ctx context.Context) error + StartRenewal(ctx context.Context) + ExternalIP(ctx context.Context) (string, error) +} + +type boxService interface { + Start() error + Close() error +} + +type boxFactory func(ctx context.Context, options string) (boxService, error) + +type Status struct { + Active bool `json:"active"` + SharingSince time.Time `json:"sharing_since,omitempty"` + ExternalIP string `json:"external_ip,omitempty"` + ExternalPort uint16 `json:"external_port,omitempty"` + RouteID string `json:"route_id,omitempty"` +} + +// Config plumbs in dependencies. Zero-valued fields fall back to production +// defaults; HeartbeatInterval and HeartbeatTimeout exist so tests can drive +// the loop without sleeping a full minute. +type Config struct { + API *API + NewForwarder func(ctx context.Context) (portForwarder, error) + BuildBoxService boxFactory + HeartbeatInterval time.Duration + HeartbeatTimeout time.Duration +} + +// Client orchestrates one peer-proxy session: open UPnP port → register with +// lantern-cloud → run a sing-box samizdat inbound on the forwarded port → +// heartbeat → on shutdown: deregister + close inbound + unmap. +// +// Re-Starting a stopped Client is allowed. +type Client struct { + cfg Config + + mu sync.Mutex + // startingDone is created when Start sets starting=true and closed when + // the same Start clears it (success or fail). Stop callers that arrive + // mid-Start block on this channel rather than racing the in-flight + // setup. Nil whenever no Start is in flight. + startingDone chan struct{} + // starting and active together serialize Start: starting is true while a + // Start call is in flight, active is true once it succeeds. 
Without + // starting, two concurrent Start calls could both pass the !active check + // and run setup in parallel — the second's state would overwrite the + // first's, orphaning a registered route + open box that this Client can + // no longer Stop. + starting bool + active bool + status Status + cancelRun context.CancelFunc + runDone chan struct{} + forwarder portForwarder + box boxService + routeID string +} + +// peerCleanupTimeout caps how long Start's rollback path waits for +// Deregister / UnmapPort. Cleanup uses a fresh Background context (not the +// caller's ctx) so an already-canceled or expired Start ctx doesn't skip +// teardown and leak the registered route or router rule. +const peerCleanupTimeout = 30 * time.Second + +func NewClient(cfg Config) (*Client, error) { + if cfg.API == nil { + return nil, errors.New("peer: Config.API is required") + } + if cfg.NewForwarder == nil { + cfg.NewForwarder = func(ctx context.Context) (portForwarder, error) { + return portforward.NewForwarder(ctx) + } + } + if cfg.BuildBoxService == nil { + cfg.BuildBoxService = defaultBuildBoxService + } + if cfg.HeartbeatTimeout == 0 { + cfg.HeartbeatTimeout = 30 * time.Second + } + return &Client{cfg: cfg}, nil +} + +// Start opens the peer-proxy session. On success a background heartbeat +// goroutine is running; on error any partial setup is torn down before +// returning. 
+func (c *Client) Start(ctx context.Context) error { + c.mu.Lock() + if c.active || c.starting { + c.mu.Unlock() + return errors.New("peer client already active") + } + c.starting = true + c.startingDone = make(chan struct{}) + c.mu.Unlock() + + var ( + success bool + fwd portForwarder + regResp *RegisterResponse + box boxService + runCtx context.Context + cancelRun context.CancelFunc + ) + defer func() { + c.mu.Lock() + c.starting = false + done := c.startingDone + c.startingDone = nil + c.mu.Unlock() + close(done) // unblocks any Stop call that arrived mid-Start + if success { + return + } + // A fresh ctx — the caller's may already be canceled by the time we + // roll back, which would skip Deregister and UnmapPort and leak the + // registered route + router rule. + cleanupCtx, cancel := context.WithTimeout(context.Background(), peerCleanupTimeout) + defer cancel() + if box != nil { + _ = box.Close() + } + if cancelRun != nil { + cancelRun() + } + if regResp != nil { + _ = c.cfg.API.Deregister(cleanupCtx, regResp.RouteID) + } + if fwd != nil { + _ = fwd.UnmapPort(cleanupCtx) + } + }() + + fwd, err := c.cfg.NewForwarder(ctx) + if err != nil { + return fmt.Errorf("discover gateway: %w", err) + } + internalPort := pickInternalPort() + mapping, err := fwd.MapPort(ctx, internalPort, "Lantern Share My Connection") + if err != nil { + return fmt.Errorf("map port %d: %w", internalPort, err) + } + + externalIP, err := fwd.ExternalIP(ctx) + if err != nil { + return fmt.Errorf("get external ip: %w", err) + } + regResp, err = c.cfg.API.Register(ctx, RegisterRequest{ + ExternalIP: externalIP, + ExternalPort: mapping.ExternalPort, + InternalPort: mapping.InternalPort, + }) + if err != nil { + return fmt.Errorf("register with lantern-cloud: %w", err) + } + + // The peer's outbound traffic must bypass any TUN device the user's own + // VPN may have installed — otherwise censored clients' traffic would + // egress through the local user's Lantern proxy instead of their + // 
residential connection, defeating the whole point of peer-sharing. + // auto_detect_interface tells sing-box to bind outbound dials to the + // underlying physical interface rather than whatever the OS routing + // table picks (which would be the VPN TUN if the VPN is up). + options, err := ensurePeerOutboundsBypassVPN(regResp.ServerConfig) + if err != nil { + return fmt.Errorf("patch sing-box options: %w", err) + } + + // runCtx must outlive Start, so it derives from Background() rather than + // the caller's ctx — otherwise libbox's stored ctx would die when Start + // returns and take the box's internal goroutines with it. + runCtx, cancelRun = context.WithCancel(context.Background()) + box, err = c.cfg.BuildBoxService(runCtx, options) + if err != nil { + cancelRun() + return fmt.Errorf("build sing-box: %w", err) + } + if err := box.Start(); err != nil { + cancelRun() + return fmt.Errorf("start sing-box: %w", err) + } + + // HeartbeatIntervalSeconds is server-driven so lantern-cloud can dial up + // the cadence on registrations it wants to expire faster. Honor any + // positive value verbatim — clamping short intervals up would defeat + // that and risk the server reaping the route between our heartbeats. + // A non-positive value means the field was unset (e.g., older server, + // JSON omitted); fall back to a sane default. 
+ heartbeat := c.cfg.HeartbeatInterval + if heartbeat == 0 { + heartbeat = time.Duration(regResp.HeartbeatIntervalSeconds) * time.Second + if heartbeat <= 0 { + heartbeat = 5 * time.Minute + } + } + runDone := make(chan struct{}) + + c.mu.Lock() + c.active = true + c.forwarder = fwd + c.box = box + c.routeID = regResp.RouteID + c.cancelRun = cancelRun + c.runDone = runDone + c.status = Status{ + Active: true, + SharingSince: time.Now(), + ExternalIP: externalIP, + ExternalPort: mapping.ExternalPort, + RouteID: regResp.RouteID, + } + c.mu.Unlock() + + fwd.StartRenewal(runCtx) + go c.heartbeatLoop(runCtx, heartbeat, runDone) + + slog.Info("peer client started", + "external_ip", externalIP, + "external_port", mapping.ExternalPort, + "internal_port", mapping.InternalPort, + "method", mapping.Method, + "route_id", regResp.RouteID, + "heartbeat", heartbeat, + ) + success = true + return nil +} + +// Stop tears down an active session. Idempotent. Blocks until the heartbeat +// goroutine has exited and all teardown calls have completed (or timed out). +// +// If a Start is in flight when Stop is called, Stop waits for that Start to +// finish (success or fail) before proceeding. Without this, a Stop arriving +// while starting=true would return nil and let the racing Start leave the +// client active afterward — exactly the orphaned-session shape Start's own +// rollback path is designed to prevent. The wait honors ctx so a cancellable +// caller still has an exit door if Start hangs. 
+func (c *Client) Stop(ctx context.Context) error { + c.mu.Lock() + for c.starting { + done := c.startingDone + c.mu.Unlock() + select { + case <-done: + case <-ctx.Done(): + return ctx.Err() + } + c.mu.Lock() + } + if !c.active { + c.mu.Unlock() + return nil + } + cancel := c.cancelRun + done := c.runDone + fwd := c.forwarder + box := c.box + routeID := c.routeID + c.active = false + c.cancelRun = nil + c.runDone = nil + c.forwarder = nil + c.box = nil + c.routeID = "" + c.status = Status{} + c.mu.Unlock() + + cancel() + <-done + + var firstErr error + if err := c.cfg.API.Deregister(ctx, routeID); err != nil { + firstErr = fmt.Errorf("deregister: %w", err) + slog.Warn("peer client deregister failed (continuing teardown)", "err", err) + } + if err := box.Close(); err != nil { + if firstErr == nil { + firstErr = fmt.Errorf("close sing-box: %w", err) + } + slog.Warn("peer client sing-box close failed", "err", err) + } + if err := fwd.UnmapPort(ctx); err != nil { + if firstErr == nil { + firstErr = fmt.Errorf("unmap port: %w", err) + } + slog.Warn("peer client unmap port failed", "err", err) + } + slog.Info("peer client stopped", "route_id", routeID) + return firstErr +} + +func (c *Client) IsActive() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.active +} + +func (c *Client) CurrentStatus() Status { + c.mu.Lock() + defer c.mu.Unlock() + return c.status +} + +// heartbeatLoop closes done on exit so Stop can wait for the loop before +// tearing down resources. The channel is passed in rather than read off the +// Client because Stop nils c.runDone before waiting on its local copy. 
+func (c *Client) heartbeatLoop(ctx context.Context, interval time.Duration, done chan struct{}) { + defer close(done) + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + c.mu.Lock() + routeID := c.routeID + c.mu.Unlock() + if routeID == "" { + return + } + hbCtx, cancel := context.WithTimeout(ctx, c.cfg.HeartbeatTimeout) + err := c.cfg.API.Heartbeat(hbCtx, routeID) + cancel() + if err != nil { + // A single transient blip shouldn't kill the registration — + // the server-side reaper will deprecate the row if heartbeats + // stay missing past expiration, and we'll observe that on a + // later heartbeat as a 404. + slog.Warn("peer heartbeat failed", "err", err, "route_id", routeID) + if isNotRegistered(err) { + slog.Info("peer route no longer registered server-side, stopping client") + // Stop runs in a separate goroutine to avoid the cyclic + // Stop → cancelRun → loop-exit deadlock. + go func() { + stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + _ = c.Stop(stopCtx) + }() + return + } + } + } + } +} + +// isNotRegistered reports whether an error from the heartbeat is a 404 from +// the server (deprecated / reaped / wrong owner). On 404 the registration is +// gone and we stop ourselves; on any other error we keep trying. +func isNotRegistered(err error) bool { + var apiErr *APIError + return errors.As(err, &apiErr) && apiErr.Status == 404 +} + +// ensurePeerOutboundsBypassVPN guarantees the peer sing-box's outbound dials +// bind to the physical interface rather than whatever the OS routing table +// picks. Without this, when the user's own Lantern VPN is up its TUN holds +// the default route and the peer's outbound traffic — i.e. the censored +// client's destination requests — would egress through Lantern's proxy +// network instead of the user's residential connection. That defeats the +// whole point of using the user's home IP as a circumvention exit. 
+// +// We splice the flag into whatever sing-box options the server supplied +// rather than relying on the server-side track config to set it, since the +// VPN-bypass requirement is a property of the *client's* environment, not +// the proxy track config. +func ensurePeerOutboundsBypassVPN(options string) (string, error) { + var raw map[string]any + if err := json.Unmarshal([]byte(options), &raw); err != nil { + return "", fmt.Errorf("decode options: %w", err) + } + route, _ := raw["route"].(map[string]any) + if route == nil { + route = map[string]any{} + raw["route"] = route + } + route["auto_detect_interface"] = true + out, err := json.Marshal(raw) + if err != nil { + return "", fmt.Errorf("encode options: %w", err) + } + return string(out), nil +} + +func pickInternalPort() uint16 { + return uint16(internalPortMin + rand.IntN(internalPortMax-internalPortMin)) +} + +// We pass a nil PlatformInterface — peer-proxy inbounds don't need TUN / +// platform-VPN integration the way the main VPN tunnel does. The samizdat +// inbound is just an HTTPS server bound to a TCP port; sing-box's default +// network stack handles it. 
+func defaultBuildBoxService(ctx context.Context, options string) (boxService, error) { + bs, err := libbox.NewServiceWithContext(ctx, options, nil) + if err != nil { + return nil, fmt.Errorf("libbox.NewServiceWithContext: %w", err) + } + return bs, nil +} diff --git a/peer/peer_test.go b/peer/peer_test.go new file mode 100644 index 00000000..5872419b --- /dev/null +++ b/peer/peer_test.go @@ -0,0 +1,579 @@ +package peer + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getlantern/radiance/portforward" +) + +type fakeForwarder struct { + mu sync.Mutex + mapErr error + extIPErr error + unmapErr error + mapped bool + unmapped bool + renewals int + externalIP string + mapping *portforward.Mapping + cancelRenew context.CancelFunc +} + +func (f *fakeForwarder) MapPort(_ context.Context, internalPort uint16, _ string) (*portforward.Mapping, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.mapErr != nil { + return nil, f.mapErr + } + f.mapped = true + f.mapping = &portforward.Mapping{ + ExternalPort: internalPort, + InternalPort: internalPort, + InternalIP: "192.168.1.10", + Protocol: "TCP", + LeaseDuration: time.Hour, + Method: "fake", + } + return f.mapping, nil +} + +func (f *fakeForwarder) UnmapPort(_ context.Context) error { + f.mu.Lock() + defer f.mu.Unlock() + f.unmapped = true + return f.unmapErr +} + +func (f *fakeForwarder) StartRenewal(ctx context.Context) { + f.mu.Lock() + f.renewals++ + rctx, cancel := context.WithCancel(ctx) + f.cancelRenew = cancel + f.mu.Unlock() + go func() { <-rctx.Done() }() +} + +func (f *fakeForwarder) ExternalIP(_ context.Context) (string, error) { + if f.extIPErr != nil { + return "", f.extIPErr + } + if f.externalIP == "" { + return "203.0.113.99", nil + } + return f.externalIP, nil +} + +func (f *fakeForwarder) wasUnmapped() bool { + f.mu.Lock() + 
defer f.mu.Unlock() + return f.unmapped +} + +func (f *fakeForwarder) wasMapped() bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.mapped +} + +// slowMapForwarder blocks MapPort on a gate channel and signals via entered +// when the call is in flight. Used to race two concurrent Starts so the +// test can observe the serialization invariant. +type slowMapForwarder struct { + gate chan struct{} + entered chan struct{} +} + +func (f *slowMapForwarder) MapPort(_ context.Context, internalPort uint16, _ string) (*portforward.Mapping, error) { + select { + case f.entered <- struct{}{}: + default: + } + <-f.gate + return &portforward.Mapping{ + ExternalPort: internalPort, InternalPort: internalPort, + InternalIP: "192.168.1.10", Protocol: "TCP", + LeaseDuration: time.Hour, Method: "fake", + }, nil +} +func (f *slowMapForwarder) UnmapPort(context.Context) error { return nil } +func (f *slowMapForwarder) StartRenewal(context.Context) {} +func (f *slowMapForwarder) ExternalIP(context.Context) (string, error) { + return "203.0.113.99", nil +} + +type fakeBoxService struct { + startErr error + closeErr error + started atomic.Bool + closed atomic.Bool + gotConfig string +} + +func (b *fakeBoxService) Start() error { + if b.startErr != nil { + return b.startErr + } + b.started.Store(true) + return nil +} + +func (b *fakeBoxService) Close() error { + b.closed.Store(true) + return b.closeErr +} + +type stubServer struct { + t *testing.T + server *httptest.Server + registerStatus int + registerResp RegisterResponse + heartbeatStatus int + deregisterStatus int + registerCount atomic.Int64 + heartbeatCount atomic.Int64 + deregisterCount atomic.Int64 + registerDeviceID atomic.Value // string + heartbeatDeviceID atomic.Value // string + deregisterDeviceID atomic.Value // string + lastRegisterReq atomic.Value // RegisterRequest +} + +func newStubServer(t *testing.T) *stubServer { + t.Helper() + s := &stubServer{ + t: t, + registerStatus: http.StatusOK, + heartbeatStatus: 
http.StatusOK, + deregisterStatus: http.StatusOK, + registerResp: RegisterResponse{ + RouteID: "00000000-0000-0000-0000-000000000123", + ServerConfig: `{"inbounds": [{"type":"samizdat","tag":"samizdat-in"}]}`, + HeartbeatIntervalSeconds: 60, + }, + } + mux := http.NewServeMux() + mux.HandleFunc("/v1/peer/register", func(w http.ResponseWriter, r *http.Request) { + s.registerCount.Add(1) + s.registerDeviceID.Store(r.Header.Get("X-Lantern-Device-Id")) + var req RegisterRequest + _ = json.NewDecoder(r.Body).Decode(&req) + s.lastRegisterReq.Store(req) + if s.registerStatus != http.StatusOK { + http.Error(w, "register failed", s.registerStatus) + return + } + _ = json.NewEncoder(w).Encode(s.registerResp) + }) + mux.HandleFunc("/v1/peer/heartbeat", func(w http.ResponseWriter, r *http.Request) { + s.heartbeatCount.Add(1) + s.heartbeatDeviceID.Store(r.Header.Get("X-Lantern-Device-Id")) + if s.heartbeatStatus != http.StatusOK { + http.Error(w, "heartbeat failed", s.heartbeatStatus) + return + } + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("/v1/peer/deregister", func(w http.ResponseWriter, r *http.Request) { + s.deregisterCount.Add(1) + s.deregisterDeviceID.Store(r.Header.Get("X-Lantern-Device-Id")) + if s.deregisterStatus != http.StatusOK { + http.Error(w, "deregister failed", s.deregisterStatus) + return + } + w.WriteHeader(http.StatusOK) + }) + s.server = httptest.NewServer(mux) + t.Cleanup(s.server.Close) + return s +} + +// newTestClient builds a Client wired to the supplied test doubles. The +// HeartbeatInterval default of 0 leaves the production floor in place +// (caller can override per test). 
func newTestClient(t *testing.T, fwd portForwarder, box *fakeBoxService, srv *stubServer, opts ...func(*Config)) *Client {
	t.Helper()
	cfg := Config{
		API: NewAPI(srv.server.Client(), srv.server.URL, "test-device"),
		NewForwarder: func(_ context.Context) (portForwarder, error) {
			return fwd, nil
		},
		BuildBoxService: func(_ context.Context, options string) (boxService, error) {
			// Record the options the Client handed to the box factory.
			box.gotConfig = options
			return box, nil
		},
	}
	// Per-test overrides run after the defaults above are in place.
	for _, opt := range opts {
		opt(&cfg)
	}
	c, err := NewClient(cfg)
	require.NoError(t, err)
	return c
}

// A successful Start maps a port, registers once with the test device ID,
// starts the box, and exposes the session through IsActive / CurrentStatus.
func TestClient_Start_HappyPath(t *testing.T) {
	fwd := &fakeForwarder{externalIP: "203.0.113.42"}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)

	ctx := context.Background()
	require.NoError(t, c.Start(ctx))
	t.Cleanup(func() { _ = c.Stop(ctx) })

	assert.True(t, c.IsActive())
	assert.True(t, fwd.wasMapped())
	assert.True(t, box.started.Load())
	assert.Equal(t, int64(1), srv.registerCount.Load())
	assert.Equal(t, "test-device", srv.registerDeviceID.Load())

	req := srv.lastRegisterReq.Load().(RegisterRequest)
	assert.Equal(t, "203.0.113.42", req.ExternalIP)
	assert.NotZero(t, req.ExternalPort)
	assert.NotZero(t, req.InternalPort)

	status := c.CurrentStatus()
	assert.True(t, status.Active)
	assert.Equal(t, "203.0.113.42", status.ExternalIP)
	assert.Equal(t, "00000000-0000-0000-0000-000000000123", status.RouteID)
}

// Calling Start on an already-active Client is rejected with an
// "already active" error.
func TestClient_Start_DoubleStartIsError(t *testing.T) {
	fwd := &fakeForwarder{}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)

	require.NoError(t, c.Start(context.Background()))
	t.Cleanup(func() { _ = c.Stop(context.Background()) })

	err := c.Start(context.Background())
	assert.ErrorContains(t, err, "already active")
}

// Two goroutines hitting Start at the same time must not both run setup —
// the second one would overwrite the first's state, leaving the first
// session orphaned with no way to Stop it through this Client.
func TestClient_Start_ConcurrentStartsAreSerialized(t *testing.T) {
	fwd := &slowMapForwarder{
		gate:    make(chan struct{}),
		entered: make(chan struct{}, 1),
	}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)
	t.Cleanup(func() { _ = c.Stop(context.Background()) })

	results := make(chan error, 2)
	for range 2 {
		go func() { results <- c.Start(context.Background()) }()
	}
	// Wait for one Start to be inside MapPort holding starting=true, then
	// release it. The other Start is rejected whether it runs while the
	// first is still in flight (starting) or after it has finished (active),
	// because Start checks both flags.
	<-fwd.entered
	close(fwd.gate)

	var nilCount, errCount int
	for range 2 {
		if err := <-results; err == nil {
			nilCount++
		} else {
			errCount++
			assert.ErrorContains(t, err, "already active")
		}
	}
	assert.Equal(t, 1, nilCount, "exactly one Start must succeed")
	assert.Equal(t, 1, errCount, "the racing Start must be rejected")
	assert.Equal(t, int64(1), srv.registerCount.Load())
}

// A Stop that arrives while Start is still in flight must wait for that
// Start to finish — otherwise it returns nil and the racing Start happily
// leaves the client active afterward, which produces the exact orphaned-
// session shape Start's own rollback path is designed to prevent.
func TestClient_Stop_WaitsForInflightStart(t *testing.T) {
	fwd := &slowMapForwarder{
		gate:    make(chan struct{}),
		entered: make(chan struct{}, 1),
	}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)

	startErr := make(chan error, 1)
	go func() { startErr <- c.Start(context.Background()) }()

	// Wait until Start is blocked inside MapPort (starting=true, active=false).
	<-fwd.entered

	stopErr := make(chan error, 1)
	go func() { stopErr <- c.Stop(context.Background()) }()

	// Stop must not return while Start is still in flight.
	select {
	case <-stopErr:
		t.Fatal("Stop returned before Start finished — would orphan the session")
	case <-time.After(50 * time.Millisecond):
	}

	// Let Start complete. Stop should unblock and tear down what Start set up.
	close(fwd.gate)

	require.NoError(t, <-startErr)
	require.NoError(t, <-stopErr)

	// Client must be in clean post-Stop state — not active and ready to be
	// Started again.
	assert.False(t, c.IsActive())
	assert.Equal(t, int64(1), srv.registerCount.Load(), "Start completed once")
	assert.Equal(t, int64(1), srv.deregisterCount.Load(), "Stop tore down what Start set up")
}

// A Stop with an already-canceled ctx that races a slow Start should give
// up promptly rather than wait forever.
func TestClient_Stop_RespectsCtxWhileWaitingForStart(t *testing.T) {
	fwd := &slowMapForwarder{
		gate:    make(chan struct{}),
		entered: make(chan struct{}, 1),
	}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)
	t.Cleanup(func() {
		close(fwd.gate)
		// Drain the in-flight Start so the test goroutines don't leak.
		_ = c.Stop(context.Background())
	})

	go func() { _ = c.Start(context.Background()) }()
	<-fwd.entered

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	err := c.Stop(ctx)
	assert.ErrorIs(t, err, context.Canceled)
}

// A MapPort failure aborts Start before anything is registered or started.
func TestClient_Start_PortForwardFailureUnwinds(t *testing.T) {
	fwd := &fakeForwarder{mapErr: portforward.ErrNoPortForwarding}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)

	err := c.Start(context.Background())
	require.Error(t, err)
	assert.False(t, c.IsActive())
	assert.Equal(t, int64(0), srv.registerCount.Load())
	assert.False(t, box.started.Load())
}

// An ExternalIP failure happens after the port is mapped, so rollback must
// unmap it; nothing is registered or started.
func TestClient_Start_ExternalIPFailureUnwinds(t *testing.T) {
	fwd := &fakeForwarder{extIPErr: errors.New("gateway returned empty")}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)

	err := c.Start(context.Background())
	require.Error(t, err)
	assert.False(t, c.IsActive())
	assert.True(t, fwd.wasUnmapped(), "port must be unmapped after external-ip failure")
	assert.Equal(t, int64(0), srv.registerCount.Load())
	assert.False(t, box.started.Load())
}

// A rejected Register (HTTP 422 here) unmaps the port and never starts the
// box.
func TestClient_Start_RegisterFailureUnwinds(t *testing.T) {
	fwd := &fakeForwarder{}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	srv.registerStatus = http.StatusUnprocessableEntity
	c := newTestClient(t, fwd, box, srv)

	err := c.Start(context.Background())
	require.Error(t, err)
	assert.False(t, c.IsActive())
	assert.True(t, fwd.wasUnmapped())
	assert.False(t, box.started.Load())
}

// A box.Start failure is the latest rollback point: the box is closed, the
// route deregistered, and the port unmapped.
func TestClient_Start_BoxStartFailureUnwinds(t *testing.T) {
	fwd := &fakeForwarder{}
	box := &fakeBoxService{startErr: errors.New("boom")}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)

	err := c.Start(context.Background())
	require.Error(t, err)
	assert.False(t, c.IsActive())
	assert.True(t, fwd.wasUnmapped())
	assert.True(t, box.closed.Load())
	assert.Equal(t, int64(1), srv.deregisterCount.Load())
}

// Stop after a successful Start deregisters (with the device ID header),
// closes the box, and unmaps the port.
func TestClient_Stop_HappyPath(t *testing.T) {
	fwd := &fakeForwarder{}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)

	ctx := context.Background()
	require.NoError(t, c.Start(ctx))
	require.NoError(t, c.Stop(ctx))

	assert.False(t, c.IsActive())
	assert.True(t, fwd.wasUnmapped())
	assert.True(t, box.closed.Load())
	assert.Equal(t, int64(1), srv.deregisterCount.Load())
	assert.Equal(t, "test-device", srv.deregisterDeviceID.Load())
}

// A second Stop is a no-op: it returns nil and does not deregister again.
func TestClient_Stop_IsIdempotent(t *testing.T) {
	fwd := &fakeForwarder{}
	box := &fakeBoxService{}
	srv := newStubServer(t)
	c := newTestClient(t, fwd, box, srv)

	ctx := context.Background()
	require.NoError(t, c.Start(ctx))
	require.NoError(t, c.Stop(ctx))
	require.NoError(t, c.Stop(ctx))
	assert.Equal(t, int64(1), srv.deregisterCount.Load())
}

// Stop continues teardown even if individual steps fail. The first error is
// returned; the others are logged. All resources still get released.
func TestClient_Stop_ContinuesPastIndividualErrors(t *testing.T) {
	fwd := &fakeForwarder{unmapErr: errors.New("router said no")}
	box := &fakeBoxService{closeErr: errors.New("box close failed")}
	srv := newStubServer(t)
	srv.deregisterStatus = http.StatusInternalServerError
	c := newTestClient(t, fwd, box, srv)

	ctx := context.Background()
	require.NoError(t, c.Start(ctx))
	err := c.Stop(ctx)
	require.Error(t, err)
	// Deregister is the first teardown step, so its error is the one returned.
	assert.ErrorContains(t, err, "deregister")

	assert.False(t, c.IsActive())
	assert.True(t, fwd.wasUnmapped())
	assert.True(t, box.closed.Load())
	assert.Equal(t, int64(1), srv.deregisterCount.Load())
}

// Drives the loop with a 50ms interval (overridden via Config.HeartbeatInterval)
// against a server that always 404s, then waits for the auto-stop goroutine to
// flip IsActive() false and run teardown.
+func TestClient_Heartbeat_404TriggersAutoStop(t *testing.T) { + fwd := &fakeForwarder{} + box := &fakeBoxService{} + srv := newStubServer(t) + srv.heartbeatStatus = http.StatusNotFound + c := newTestClient(t, fwd, box, srv, func(cfg *Config) { + cfg.HeartbeatInterval = 50 * time.Millisecond + cfg.HeartbeatTimeout = 1 * time.Second + }) + + require.NoError(t, c.Start(context.Background())) + + deadline := time.After(3 * time.Second) + for c.IsActive() { + select { + case <-deadline: + t.Fatal("client did not auto-stop within 3s") + case <-time.After(20 * time.Millisecond): + } + } + + assert.GreaterOrEqual(t, srv.heartbeatCount.Load(), int64(1)) + assert.Equal(t, "test-device", srv.heartbeatDeviceID.Load()) + assert.Equal(t, int64(1), srv.deregisterCount.Load()) + assert.True(t, fwd.wasUnmapped()) + assert.True(t, box.closed.Load()) +} + +// Non-404 heartbeat errors must not tear the client down — they're logged +// and the loop keeps trying. +func TestClient_Heartbeat_TransientErrorDoesNotStop(t *testing.T) { + fwd := &fakeForwarder{} + box := &fakeBoxService{} + srv := newStubServer(t) + srv.heartbeatStatus = http.StatusInternalServerError + c := newTestClient(t, fwd, box, srv, func(cfg *Config) { + cfg.HeartbeatInterval = 50 * time.Millisecond + cfg.HeartbeatTimeout = 1 * time.Second + }) + + require.NoError(t, c.Start(context.Background())) + t.Cleanup(func() { _ = c.Stop(context.Background()) }) + + // Wait long enough for several heartbeats to fire. 
+ deadline := time.After(500 * time.Millisecond) + for srv.heartbeatCount.Load() < 3 { + select { + case <-deadline: + t.Fatalf("only %d heartbeats fired in 500ms", srv.heartbeatCount.Load()) + case <-time.After(20 * time.Millisecond): + } + } + assert.True(t, c.IsActive()) + assert.Equal(t, int64(0), srv.deregisterCount.Load()) +} + +// The peer's sing-box must bypass the user's own VPN TUN — verify both the +// "no route block at all" and "existing route block" cases get the flag set, +// and that other route-level keys are preserved. +func TestEnsurePeerOutboundsBypassVPN(t *testing.T) { + t.Run("adds route block when missing", func(t *testing.T) { + in := `{"inbounds":[{"type":"samizdat","tag":"samizdat-in"}]}` + out, err := ensurePeerOutboundsBypassVPN(in) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(out), &parsed)) + route := parsed["route"].(map[string]any) + assert.Equal(t, true, route["auto_detect_interface"]) + assert.Contains(t, parsed, "inbounds", "must preserve other top-level fields") + }) + t.Run("preserves existing route fields", func(t *testing.T) { + in := `{"route":{"rules":[{"action":"sniff"}],"final":"direct"}}` + out, err := ensurePeerOutboundsBypassVPN(in) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(out), &parsed)) + route := parsed["route"].(map[string]any) + assert.Equal(t, true, route["auto_detect_interface"]) + assert.Equal(t, "direct", route["final"]) + assert.NotEmpty(t, route["rules"]) + }) + t.Run("rejects malformed json", func(t *testing.T) { + _, err := ensurePeerOutboundsBypassVPN(`{not json`) + assert.Error(t, err) + }) +} + +func TestPickInternalPort_InRange(t *testing.T) { + for i := 0; i < 100; i++ { + p := pickInternalPort() + assert.GreaterOrEqual(t, int(p), internalPortMin) + assert.Less(t, int(p), internalPortMax) + } +} + +func TestAPIError_StringFormat(t *testing.T) { + e := &APIError{Status: 422, Body: "could not 
connect to peer port"} + assert.Contains(t, e.Error(), "422") + assert.Contains(t, e.Error(), "could not connect") +} + +var _ portForwarder = (*fakeForwarder)(nil) +var _ boxService = (*fakeBoxService)(nil) diff --git a/portforward/portforward.go b/portforward/portforward.go new file mode 100644 index 00000000..6a8fe9c1 --- /dev/null +++ b/portforward/portforward.go @@ -0,0 +1,384 @@ +// Package portforward opens TCP ports on the local network gateway via UPnP +// IGD so a peer-proxy inbound is reachable from the public internet without +// manual router configuration. IGDv2 is tried first and IGDv1 is the +// fallback. +package portforward + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "sync" + "time" + + "github.com/huin/goupnp/dcps/internetgateway1" + "github.com/huin/goupnp/dcps/internetgateway2" +) + +// ErrNoPortForwarding is returned when no UPnP gateway is reachable, the +// gateway refuses to map a port, or the discovery scan times out. Callers +// should treat this as "this network can't host a peer proxy" and surface it +// to the user rather than retry indefinitely. +var ErrNoPortForwarding = errors.New("no port forwarding available") + +type Mapping struct { + ExternalPort uint16 + InternalPort uint16 + InternalIP string + Protocol string + LeaseDuration time.Duration + Method string +} + +// igdClient is the subset of the IGDv2/v1 clients we use. goupnp's generated +// clients already satisfy this shape. +type igdClient interface { + AddPortMapping(remoteHost string, externalPort uint16, protocol string, internalPort uint16, internalClient string, enabled bool, description string, leaseDuration uint32) error + DeletePortMapping(remoteHost string, externalPort uint16, protocol string) error + GetExternalIPAddress() (string, error) +} + +// Forwarder manages a single port mapping on the local gateway. Construct +// one per peer-proxy session. 
+type Forwarder struct { + mu sync.Mutex + client igdClient + method string + mapping *Mapping + cancel context.CancelFunc +} + +// NewForwarder discovers the local gateway and returns a Forwarder bound to +// it. Callers should pick a 5-10s timeout on ctx — UPnP discovery is M-SEARCH +// multicast and waits for replies. +// +// Returns ErrNoPortForwarding only when discovery completes without finding +// a usable gateway. If ctx was canceled or its deadline expired during +// discovery, the ctx error is returned verbatim so callers can distinguish +// "this network can't host a peer" from "we ran out of time, retry later". +func NewForwarder(ctx context.Context) (*Forwarder, error) { + if c, err := discoverIGDv2(ctx); err == nil && c != nil { + return &Forwarder{client: c, method: "upnp-igd2"}, nil + } + if c, err := discoverIGDv1(ctx); err == nil && c != nil { + return &Forwarder{client: c, method: "upnp-igd1"}, nil + } + if err := ctx.Err(); err != nil { + return nil, err + } + return nil, ErrNoPortForwarding +} + +// MapPort asks the gateway to forward externalPort → (LocalIP():internalPort) +// for TCP. Lease duration is requested as 1 hour but some routers ignore the +// request and assign their own (or none — "permanent"). description is shown +// in the router's UI so users can identify and remove the mapping manually +// if needed. +func (f *Forwarder) MapPort(ctx context.Context, internalPort uint16, description string) (*Mapping, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.mapping != nil { + return nil, errors.New("forwarder already has an active mapping") + } + + internalIP, err := localIP() + if err != nil { + return nil, fmt.Errorf("determine local ip: %w", err) + } + + const requestedLease uint32 = 3600 + // externalPort defaults to internalPort. If the router already has that + // port mapped to someone else, AddPortMapping fails and the caller can + // retry with a different internalPort. 
+ externalPort := internalPort + client := f.client + err = runWithCtx(ctx, func() error { + return client.AddPortMapping("", externalPort, "TCP", internalPort, internalIP, true, description, requestedLease) + }) + if err != nil { + // Propagate ctx cancellation/deadline verbatim so callers can retry + // rather than treating it as a permanent "this network won't work". + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, fmt.Errorf("add port mapping: %w", ctxErr) + } + // Per the ErrNoPortForwarding docstring, a gateway refusing to map a + // port is the "this network can't host a peer" case. Join the + // sentinel so callers can detect it via errors.Is while still + // surfacing the underlying router-specific reason for diagnostics. + return nil, fmt.Errorf("add port mapping: %w", errors.Join(ErrNoPortForwarding, err)) + } + + f.mapping = &Mapping{ + ExternalPort: externalPort, + InternalPort: internalPort, + InternalIP: internalIP, + Protocol: "TCP", + LeaseDuration: time.Duration(requestedLease) * time.Second, + Method: f.method, + } + return f.mapping, nil +} + +// UnmapPort removes the active mapping. No-op if no mapping is active. +// Always called as part of teardown — even if the gateway has already let +// the lease expire, DeletePortMapping is the polite signal to the router. +// +// f.mapping is cleared only on a successful delete. A failed delete leaves +// the mapping in place so the caller can retry; otherwise we'd "forget" +// about a router rule that's actually still live and the user would have +// to wait for the UPnP lease to expire. 
+func (f *Forwarder) UnmapPort(ctx context.Context) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.cancel != nil { + f.cancel() + f.cancel = nil + } + if f.mapping == nil { + return nil + } + m := f.mapping + client := f.client + err := runWithCtx(ctx, func() error { + return client.DeletePortMapping("", m.ExternalPort, m.Protocol) + }) + if err != nil { + return fmt.Errorf("delete port mapping: %w", err) + } + f.mapping = nil + return nil +} + +// StartRenewal launches a goroutine that re-issues AddPortMapping at half +// the requested lease duration (minimum 1 minute) until ctx is cancelled +// or UnmapPort is called. The cadence is keyed off what we *requested* +// (mapping.LeaseDuration) — UPnP IGD has no API to query what the router +// actually assigned, so a router that ignored the request and silently +// applied a shorter TTL can still drop the mapping between renewal ticks. +// The peer's heartbeat path will surface that failure and auto-Stop the +// session; routine 30-minute refresh of an hour-long requested lease +// handles the common case where the router honors the requested duration. +func (f *Forwarder) StartRenewal(ctx context.Context) { + f.mu.Lock() + defer f.mu.Unlock() + if f.cancel != nil { + return + } + if f.mapping == nil { + return + } + renewCtx, cancel := context.WithCancel(ctx) + f.cancel = cancel + interval := f.mapping.LeaseDuration / 2 + if interval < 1*time.Minute { + interval = 1 * time.Minute + } + go f.renewLoop(renewCtx, interval) +} + +func (f *Forwarder) renewLoop(ctx context.Context, interval time.Duration) { + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + f.mu.Lock() + m := f.mapping + client := f.client + f.mu.Unlock() + if m == nil { + return + } + // Most routers treat a re-issued AddPortMapping as "extend the + // existing lease"; some replace it with a fresh one. Either is + // fine here. 
+ err := runWithCtx(ctx, func() error { + return client.AddPortMapping("", m.ExternalPort, "TCP", m.InternalPort, m.InternalIP, true, "Lantern peer share (renew)", uint32(m.LeaseDuration/time.Second)) + }) + if err != nil { + slog.Warn("portforward: lease renewal failed", "err", err, "external_port", m.ExternalPort) + } + } + } +} + +// ExternalIP queries the gateway for its WAN-side IP address. Cheaper than +// dialing a public-IP service when we already have a UPnP client open. +func (f *Forwarder) ExternalIP(ctx context.Context) (string, error) { + f.mu.Lock() + c := f.client + f.mu.Unlock() + var ip string + err := runWithCtx(ctx, func() error { + got, gerr := c.GetExternalIPAddress() + if gerr != nil { + return gerr + } + ip = got + return nil + }) + if err != nil { + return "", fmt.Errorf("get external ip: %w", err) + } + if ip == "" { + return "", fmt.Errorf("gateway returned empty external ip") + } + return ip, nil +} + +// localIP returns the LAN address the OS would use to reach the gateway. +// +// First tries the UDP-noop trick (let the kernel pick a route to a known +// public address) — fastest and most accurate when the host has a working +// default route. Falls back to scanning interfaces for a private IPv4 if +// that fails, which covers networks that block 8.8.8.8 outbound or use +// non-default IPv4 routing tables. UPnP IGD itself is IPv4 in IGDv1 and +// almost always IPv4 in IGDv2, so we only consider IPv4 addresses. 
func localIP() (string, error) {
	if ip, err := localIPByDial(); err == nil {
		return ip, nil
	}
	return localIPByInterfaceScan()
}

// localIPByDial connects a UDP socket towards 8.8.8.8:53 and reports the
// local address the kernel bound it to. The socket is closed before
// returning. The result must be a concrete IPv4 address; anything else is
// reported as an error so localIP falls through to the interface scan.
func localIPByDial() (string, error) {
	// "udp4" pins the socket to IPv4, matching the IPv4-only contract.
	conn, err := net.Dial("udp4", "8.8.8.8:53")
	if err != nil {
		return "", fmt.Errorf("dial udp for local ip: %w", err)
	}
	defer func() { _ = conn.Close() }()
	addr, ok := conn.LocalAddr().(*net.UDPAddr)
	if !ok {
		return "", fmt.Errorf("unexpected local addr type %T", conn.LocalAddr())
	}
	ip4 := addr.IP.To4()
	if ip4 == nil || ip4.IsUnspecified() {
		return "", fmt.Errorf("kernel picked unusable local addr %q", addr.IP.String())
	}
	return ip4.String(), nil
}

// localIPByInterfaceScan walks the host's interfaces and returns the first
// private IPv4 address found on an interface that is up and not loopback.
// Interfaces whose addresses cannot be listed are skipped. Returns an error
// when the interface list is unavailable or no such address exists.
func localIPByInterfaceScan() (string, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return "", fmt.Errorf("list interfaces: %w", err)
	}
	for _, iface := range ifaces {
		if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
			continue
		}
		addrs, err := iface.Addrs()
		if err != nil {
			continue
		}
		for _, addr := range addrs {
			ipnet, ok := addr.(*net.IPNet)
			if !ok || ipnet.IP.IsLoopback() || ipnet.IP.IsLinkLocalUnicast() {
				continue
			}
			ip4 := ipnet.IP.To4()
			if ip4 == nil || !ip4.IsPrivate() {
				continue
			}
			return ip4.String(), nil
		}
	}
	return "", errors.New("no usable private ipv4 found on any interface")
}

// LocalIP is the exported entry point for localIP: it returns the IPv4
// address this host would use on its LAN, or an error if none can be found.
func LocalIP() (string, error) { return localIP() }

// runWithCtx wraps a blocking call so the caller's context can abort the
// wait. Returns ctx.Err() immediately if ctx is already cancelled, so the
// gateway-side side effect (port mapping, etc.) doesn't fire after the
// caller has already given up. If ctx cancels mid-call, the wrapped
// goroutine still runs to completion — UPnP/HTTP calls have their own
// underlying timeouts — but we no longer hand the entire wait time to an
// unresponsive gateway.
+func runWithCtx(ctx context.Context, fn func() error) error { + if err := ctx.Err(); err != nil { + return err + } + done := make(chan error, 1) + go func() { done <- fn() }() + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return err + } +} + +func discoverIGDv2(ctx context.Context) (igdClient, error) { + clients, _, err := internetgateway2.NewWANIPConnection2ClientsCtx(ctx) + if err != nil { + return nil, err + } + if len(clients) == 0 { + return nil, nil + } + return wanIPv2Wrapper{c: clients[0]}, nil +} + +// discoverIGDv1 looks for both WANIPConnection and WANPPPConnection gateways. +// Cable/fiber CPE routers typically expose UPnP via WANIPConnection; DSL and +// other PPPoE-terminated CPEs typically expose it via WANPPPConnection. +// Probing only one would miss large swaths of consumer hardware. +func discoverIGDv1(ctx context.Context) (igdClient, error) { + if clients, _, err := internetgateway1.NewWANIPConnection1ClientsCtx(ctx); err == nil && len(clients) > 0 { + return wanIPv1Wrapper{c: clients[0]}, nil + } + clients, _, err := internetgateway1.NewWANPPPConnection1ClientsCtx(ctx) + if err != nil { + return nil, err + } + if len(clients) == 0 { + return nil, nil + } + return wanPPPv1Wrapper{c: clients[0]}, nil +} + +// IGDv1 and IGDv2's generated clients have slightly different method +// signatures, so wrappers normalize them to a single igdClient interface. 
+ +type wanIPv2Wrapper struct{ c *internetgateway2.WANIPConnection2 } + +func (w wanIPv2Wrapper) AddPortMapping(remoteHost string, externalPort uint16, protocol string, internalPort uint16, internalClient string, enabled bool, description string, leaseDuration uint32) error { + return w.c.AddPortMapping(remoteHost, externalPort, protocol, internalPort, internalClient, enabled, description, leaseDuration) +} +func (w wanIPv2Wrapper) DeletePortMapping(remoteHost string, externalPort uint16, protocol string) error { + return w.c.DeletePortMapping(remoteHost, externalPort, protocol) +} +func (w wanIPv2Wrapper) GetExternalIPAddress() (string, error) { + return w.c.GetExternalIPAddress() +} + +type wanIPv1Wrapper struct{ c *internetgateway1.WANIPConnection1 } + +func (w wanIPv1Wrapper) AddPortMapping(remoteHost string, externalPort uint16, protocol string, internalPort uint16, internalClient string, enabled bool, description string, leaseDuration uint32) error { + return w.c.AddPortMapping(remoteHost, externalPort, protocol, internalPort, internalClient, enabled, description, leaseDuration) +} +func (w wanIPv1Wrapper) DeletePortMapping(remoteHost string, externalPort uint16, protocol string) error { + return w.c.DeletePortMapping(remoteHost, externalPort, protocol) +} +func (w wanIPv1Wrapper) GetExternalIPAddress() (string, error) { + return w.c.GetExternalIPAddress() +} + +type wanPPPv1Wrapper struct{ c *internetgateway1.WANPPPConnection1 } + +func (w wanPPPv1Wrapper) AddPortMapping(remoteHost string, externalPort uint16, protocol string, internalPort uint16, internalClient string, enabled bool, description string, leaseDuration uint32) error { + return w.c.AddPortMapping(remoteHost, externalPort, protocol, internalPort, internalClient, enabled, description, leaseDuration) +} +func (w wanPPPv1Wrapper) DeletePortMapping(remoteHost string, externalPort uint16, protocol string) error { + return w.c.DeletePortMapping(remoteHost, externalPort, protocol) +} +func (w 
wanPPPv1Wrapper) GetExternalIPAddress() (string, error) { + return w.c.GetExternalIPAddress() +} + +var ( + _ igdClient = wanIPv2Wrapper{} + _ igdClient = wanIPv1Wrapper{} + _ igdClient = wanPPPv1Wrapper{} +) diff --git a/portforward/portforward_test.go b/portforward/portforward_test.go new file mode 100644 index 00000000..7d6e0ee2 --- /dev/null +++ b/portforward/portforward_test.go @@ -0,0 +1,259 @@ +package portforward + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeIGD struct { + mu sync.Mutex + addCalls atomic.Int64 + deleteCalls atomic.Int64 + addErr error + deleteErr error + extIPErr error + extIP string + addBlock chan struct{} // if non-nil, AddPortMapping blocks on receive + lastAdd mappingArgs + lastDelete deleteArgs +} + +type mappingArgs struct { + externalPort, internalPort uint16 + internalClient, description string + leaseDuration uint32 +} + +type deleteArgs struct { + externalPort uint16 + protocol string +} + +func (f *fakeIGD) AddPortMapping(_ string, externalPort uint16, _ string, internalPort uint16, internalClient string, _ bool, description string, leaseDuration uint32) error { + f.addCalls.Add(1) + if f.addBlock != nil { + <-f.addBlock + } + f.mu.Lock() + f.lastAdd = mappingArgs{ + externalPort: externalPort, + internalPort: internalPort, + internalClient: internalClient, + description: description, + leaseDuration: leaseDuration, + } + f.mu.Unlock() + return f.addErr +} + +func (f *fakeIGD) DeletePortMapping(_ string, externalPort uint16, protocol string) error { + f.deleteCalls.Add(1) + f.mu.Lock() + f.lastDelete = deleteArgs{externalPort: externalPort, protocol: protocol} + f.mu.Unlock() + return f.deleteErr +} + +func (f *fakeIGD) GetExternalIPAddress() (string, error) { + if f.extIPErr != nil { + return "", f.extIPErr + } + if f.extIP == "" { + return "203.0.113.1", nil + } + return f.extIP, nil +} + +func 
newTestForwarder(t *testing.T, c *fakeIGD) *Forwarder { + t.Helper() + return &Forwarder{client: c, method: "fake"} +} + +func TestForwarder_MapPort_HappyPath(t *testing.T) { + c := &fakeIGD{} + f := newTestForwarder(t, c) + + m, err := f.MapPort(context.Background(), 30001, "test") + require.NoError(t, err) + assert.Equal(t, uint16(30001), m.ExternalPort) + assert.Equal(t, uint16(30001), m.InternalPort) + assert.Equal(t, "TCP", m.Protocol) + assert.Equal(t, "fake", m.Method) + assert.Equal(t, int64(1), c.addCalls.Load()) +} + +func TestForwarder_MapPort_DoubleMapRejected(t *testing.T) { + c := &fakeIGD{} + f := newTestForwarder(t, c) + + _, err := f.MapPort(context.Background(), 30001, "test") + require.NoError(t, err) + _, err = f.MapPort(context.Background(), 30002, "test") + assert.ErrorContains(t, err, "already has an active mapping") +} + +func TestForwarder_MapPort_PropagatesGatewayError(t *testing.T) { + c := &fakeIGD{addErr: errors.New("conflict")} + f := newTestForwarder(t, c) + + _, err := f.MapPort(context.Background(), 30001, "test") + assert.ErrorContains(t, err, "add port mapping") +} + +// MapPort must respect the caller's context — a hung router shouldn't tie up +// Start past its deadline. 
+func TestForwarder_MapPort_RespectsContextCancellation(t *testing.T) { + block := make(chan struct{}) + c := &fakeIGD{addBlock: block} + f := newTestForwarder(t, c) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := f.MapPort(ctx, 30001, "test") + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + close(block) // release the leaked goroutine +} + +func TestForwarder_UnmapPort_NoMappingIsNoop(t *testing.T) { + c := &fakeIGD{} + f := newTestForwarder(t, c) + + require.NoError(t, f.UnmapPort(context.Background())) + assert.Equal(t, int64(0), c.deleteCalls.Load()) +} + +func TestForwarder_UnmapPort_RemovesMapping(t *testing.T) { + c := &fakeIGD{} + f := newTestForwarder(t, c) + + _, err := f.MapPort(context.Background(), 30001, "test") + require.NoError(t, err) + + require.NoError(t, f.UnmapPort(context.Background())) + assert.Equal(t, int64(1), c.deleteCalls.Load()) + assert.Equal(t, uint16(30001), c.lastDelete.externalPort) + assert.Equal(t, "TCP", c.lastDelete.protocol) + + // Calling MapPort after UnmapPort must succeed (mapping cleared). + _, err = f.MapPort(context.Background(), 30002, "test") + require.NoError(t, err) +} + +func TestForwarder_StartRenewal_ReissuesAddPortMapping(t *testing.T) { + c := &fakeIGD{} + f := newTestForwarder(t, c) + + // Use a short lease so the renewal interval clamps to the 1m floor; we + // invoke the loop directly with a fast interval to avoid waiting. + _, err := f.MapPort(context.Background(), 30001, "test") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + go f.renewLoop(ctx, 20*time.Millisecond) + + deadline := time.After(2 * time.Second) + for c.addCalls.Load() < 3 { + select { + case <-deadline: + t.Fatalf("renewal fired only %d times", c.addCalls.Load()) + case <-time.After(10 * time.Millisecond): + } + } + cancel() +} + +// Cancelling the renewal ctx must stop the loop even with a long interval. 
+func TestForwarder_StartRenewal_CancelsCleanly(t *testing.T) { + c := &fakeIGD{} + f := newTestForwarder(t, c) + + _, err := f.MapPort(context.Background(), 30001, "test") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + f.renewLoop(ctx, time.Hour) + close(done) + }() + cancel() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("renewLoop did not exit after ctx cancel") + } +} + +func TestForwarder_ExternalIP(t *testing.T) { + c := &fakeIGD{extIP: "203.0.113.50"} + f := newTestForwarder(t, c) + ip, err := f.ExternalIP(context.Background()) + require.NoError(t, err) + assert.Equal(t, "203.0.113.50", ip) +} + +func TestForwarder_ExternalIP_EmptyIsError(t *testing.T) { + f := &Forwarder{client: emptyExtIPClient{}, method: "fake"} + _, err := f.ExternalIP(context.Background()) + assert.ErrorContains(t, err, "empty external ip") +} + +type emptyExtIPClient struct{} + +func (emptyExtIPClient) AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error { + return nil +} +func (emptyExtIPClient) DeletePortMapping(string, uint16, string) error { return nil } +func (emptyExtIPClient) GetExternalIPAddress() (string, error) { return "", nil } + +func TestForwarder_ExternalIP_PropagatesError(t *testing.T) { + c := &fakeIGD{extIPErr: errors.New("upstream timeout")} + f := newTestForwarder(t, c) + _, err := f.ExternalIP(context.Background()) + assert.ErrorContains(t, err, "upstream timeout") +} + +func TestLocalIP(t *testing.T) { + // Best-effort: localIP needs working UDP. CI machines have it; offline + // dev machines may not. Skip rather than fail if it errors. + ip, err := LocalIP() + if err != nil { + t.Skipf("localIP unavailable in this environment: %v", err) + } + assert.NotEmpty(t, ip) +} + +// The interface-scan fallback covers networks where the UDP-noop trick +// fails (IPv6-only host, kernel rejects 8.8.8.8, etc.). 
Skip if the dev +// machine genuinely lacks a private IPv4 — running this on a CI worker +// without a LAN address shouldn't fail the build. +func TestLocalIPByInterfaceScan(t *testing.T) { + ip, err := localIPByInterfaceScan() + if err != nil { + t.Skipf("no private ipv4 interface available: %v", err) + } + assert.NotEmpty(t, ip) +} + +// MapPort's gateway-refused path must surface ErrNoPortForwarding via +// errors.Is so callers can distinguish "this network won't work" from +// "something else broke", per the package-level docstring. +func TestForwarder_MapPort_GatewayErrorWrapsErrNoPortForwarding(t *testing.T) { + c := &fakeIGD{addErr: errors.New("ConflictInMappingEntry")} + f := newTestForwarder(t, c) + + _, err := f.MapPort(context.Background(), 30001, "test") + require.Error(t, err) + assert.ErrorIs(t, err, ErrNoPortForwarding, "callers must be able to detect via errors.Is") + assert.ErrorContains(t, err, "ConflictInMappingEntry", "underlying gateway error must survive for diagnostics") +}