diff --git a/README.md b/README.md index 353bd0f3..c27a634b 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,8 @@ Everything the UI does is also available in the CLI. Copy `config.example.json` "google_ip": "216.239.38.120", "front_domain": "www.google.com", "script_id": "PASTE_YOUR_DEPLOYMENT_ID_HERE", + "cfw_script_id": "OPTIONAL_CFW_APPS_SCRIPT_DEPLOYMENT_ID", + "cfw_hosts": ["x.com", ".twitter.com"], "auth_key": "same-secret-as-in-code-gs", "listen_host": "127.0.0.1", "listen_port": 8085, @@ -190,6 +192,11 @@ Everything the UI does is also available in the CLI. Copy `config.example.json` } ``` +`cfw_script_id` + `cfw_hosts` are optional hybrid-routing knobs: when a request hostname matches `cfw_hosts`, `mhrv-rs` sends only that request through the CFW-backed Apps Script deployment (`assets/apps_script/CodeHybrid.gs`) and keeps everything else on the normal Apps Script deployment. + + +For a clean one-repo setup, everything you need now lives under `assets/apps_script/` (Apps Script + Worker templates). + Then: ```bash diff --git a/assets/apps_script/CodeHybrid.gs b/assets/apps_script/CodeHybrid.gs new file mode 100644 index 00000000..7e9bf137 --- /dev/null +++ b/assets/apps_script/CodeHybrid.gs @@ -0,0 +1,259 @@ +/** + * MHRV Hybrid Relay (Apps Script + optional Cloudflare Worker) + * + * Client protocol (same as mhrv-rs): + * Single: POST { k, m, u, h, b, ct, r } -> { s, h, b } or { e } + * Batch : POST { k, q: [{m,u,h,b,ct,r}, ...] } -> { q: [{s,h,b}|{e}, ...] } + * + * Routing: + * - Default: direct UrlFetchApp to destination URL + * - Optional CFW path: for hostnames listed in CFW_HOSTS, forward via WORKER_URL + * + * Notes: + * - Keep AUTH_KEY secret and match it in mhrv-rs config. + * - If WORKER_URL is empty, CFW route is effectively disabled. + */ + +const AUTH_KEY = "CHANGE_ME_TO_A_STRONG_SECRET"; + +// Optional Cloudflare Worker endpoint (ex: "https://myrelay.workers.dev") +const WORKER_URL = ""; + +// Optional host routing list for worker path. +// Exact host: "x.com" +// Suffix : ".twitter.com" (matches api.twitter.com) +const CFW_HOSTS = [ + // "x.com", + // ".twitter.com", +]; + +const SKIP_HEADERS = { + host: 1, + connection: 1, + "content-length": 1, + "transfer-encoding": 1, + "proxy-connection": 1, + "proxy-authorization": 1, + "priority": 1, + te: 1, +}; + +const DECOY_HTML = + 'Web App' + + '

The script completed but did not return anything.

'; + +function doPost(e) { + try { + var req = JSON.parse(e.postData.contents); + if (req.k !== AUTH_KEY) return _decoy(); + + if (Array.isArray(req.q)) return _doBatch(req.q); + return _doSingle(req); + } catch (err) { + return _decoy(); + } +} + +function doGet(e) { + return ContentService.createTextOutput(DECOY_HTML).setMimeType(ContentService.MimeType.HTML); +} + +function _doSingle(req) { + if (!_isValidUrl(req.u)) return _json({ e: "bad url" }); + + try { + var resp = _fetchRelay(req); + return _json(_packResponse(resp)); + } catch (err) { + return _json({ e: String(err) }); + } +} + +function _doBatch(items) { + var fetchArgs = []; + var indexMap = []; + var results = []; + + for (var i = 0; i < items.length; i++) { + var item = items[i]; + if (!_isValidUrl(item.u)) { + results[i] = { e: "bad url" }; + continue; + } + + var built = _buildFetch(item); + fetchArgs.push(built.opts); + indexMap.push({ idx: i, worker: built.worker }); + } + + if (fetchArgs.length > 0) { + var responses = UrlFetchApp.fetchAll(fetchArgs); + for (var j = 0; j < responses.length; j++) { + var meta = indexMap[j]; + try { + if (meta.worker) { + results[meta.idx] = JSON.parse(responses[j].getContentText()); + } else { + results[meta.idx] = _packResponse(responses[j]); + } + } catch (err) { + results[meta.idx] = { e: "invalid worker response" }; + } + } + } + + for (var k = 0; k < items.length; k++) { + if (!results[k]) results[k] = { e: "unknown" }; + } + + return _json({ q: results }); +} + +function _fetchRelay(req) { + var built = _buildFetch(req); + var resp = UrlFetchApp.fetch(built.url, built.opts); + if (!built.worker) return resp; + + var txt = resp.getContentText(); + return { + _worker: true, + _parsed: JSON.parse(txt), + }; +} + +function _packResponse(resp) { + if (resp && resp._worker) return resp._parsed; + return { + s: resp.getResponseCode(), + h: _respHeaders(resp), + b: Utilities.base64Encode(resp.getContent()), + }; +} + +function _buildFetch(req) { + var useWorker = _shouldUseWorker(req.u); + if (!useWorker) { + return { + url: req.u, + worker: false, + opts: _buildDirectOpts(req), + }; + } + + if (!WORKER_URL) { + throw new Error("WORKER_URL is empty but request matched CFW_HOSTS"); + } + + return { + url: WORKER_URL, + worker: true, + opts: { + method: "post", + contentType: "application/json", + payload: JSON.stringify(_buildWorkerPayload(req)), + muteHttpExceptions: true, + followRedirects: true, + validateHttpsCertificates: true, + escaping: false, + }, + }; +} + +function _buildDirectOpts(req) { + var opts = { + method: (req.m || "GET").toLowerCase(), + muteHttpExceptions: true, + followRedirects: req.r !== false, + validateHttpsCertificates: true, + escaping: false, + }; + var headers = _filteredHeaders(req.h); + if (Object.keys(headers).length > 0) opts.headers = headers; + if (req.b) { + opts.payload = Utilities.base64Decode(req.b); + if (req.ct) opts.contentType = req.ct; + } + return opts; +} + +function _buildWorkerPayload(req) { + return { + u: req.u, + m: (req.m || "GET").toUpperCase(), + h: _filteredHeaders(req.h), + b: req.b || null, + ct: req.ct || null, + r: req.r !== false, + }; +} + +function _filteredHeaders(inHeaders) { + var headers = {}; + if (!inHeaders || typeof inHeaders !== "object") return headers; + + for (var k in inHeaders) { + if (!inHeaders.hasOwnProperty(k)) continue; + if (SKIP_HEADERS[k.toLowerCase()]) continue; + headers[k] = inHeaders[k]; + } + return headers; +} + +function _shouldUseWorker(url) { + if (!CFW_HOSTS || CFW_HOSTS.length === 0) return false; + + var host; + try { + host = _hostFromUrl(url); + } catch (_) { + return false; + } + + for (var i = 0; i < CFW_HOSTS.length; i++) { + var entry = String(CFW_HOSTS[i] || "").trim().toLowerCase().replace(/\.+$/, ""); + if (!entry) continue; + if (entry.charAt(0) === ".") { + var suffix = entry.slice(1); + if (!suffix) continue; + if (host === suffix || host.endsWith("." + suffix)) return true; + } else { + if (host === entry) return true; + } + } + return false; +} + +function _hostFromUrl(url) { + var m = String(url || "").match(/^https?:\/\/([^\/]+)/i); + if (!m) throw new Error("invalid url"); + var authority = m[1].toLowerCase(); + var noAuth = authority.indexOf("@") >= 0 ? authority.split("@").pop() : authority; + if (noAuth.charAt(0) === "[") { + var r = noAuth.indexOf("]"); + return (r > 0 ? noAuth.slice(1, r) : noAuth).replace(/\.+$/, ""); + } + return noAuth.split(":")[0].replace(/\.+$/, ""); +} + +function _isValidUrl(u) { + return typeof u === "string" && /^https?:\/\//i.test(u); +} + +function _respHeaders(resp) { + try { + if (typeof resp.getAllHeaders === "function") { + return resp.getAllHeaders(); + } + } catch (_) {} + return resp.getHeaders(); +} + +function _decoy() { + return ContentService.createTextOutput(DECOY_HTML).setMimeType(ContentService.MimeType.HTML); +} + +function _json(obj) { + return ContentService + .createTextOutput(JSON.stringify(obj)) + .setMimeType(ContentService.MimeType.JSON); +} diff --git a/assets/apps_script/README.md b/assets/apps_script/README.md index 1cf339a2..15e1eb2f 100644 --- a/assets/apps_script/README.md +++ b/assets/apps_script/README.md @@ -1,13 +1,26 @@ -# Apps Script source (mirrored) +# Apps Script / Worker templates for `mhrv-rs` -The file `Code.gs` next to this README is a verbatim snapshot of the upstream script you deploy in your own Google Apps Script project: +This folder contains deploy-ready scripts used by the Rust client. -- Upstream: -- Raw link: +## Files -This copy lives in our repo for two reasons: +- `Code.gs` — upstream-compatible direct Apps Script relay. +- `CodeFull.gs` — full-mode tunnel relay script (for `mode = "full"`). +- `CodeHybrid.gs` — new hybrid relay script: + - default route: direct `UrlFetchApp` (normal Apps Script behavior) + - optional route: forwards selected hostnames to your Cloudflare Worker +- `worker.js` — minimal Cloudflare Worker endpoint that accepts the same relay payload and returns `{s,h,b}`. -1. **Survives upstream outages**: if the user is on a network where raw.githubusercontent.com is temporarily unreachable but they can clone or ZIP this repo, they still have the deploy-ready file. -2. **Pins what we tested against**: the relay protocol between `mhrv-rs` and the script is informal; upstream changes can silently break us. Keeping a snapshot here lets us diff and see if a spec drift is responsible for any reported breakage. +## When to use which -All credit for `Code.gs` goes to [@masterking32](https://github.com/masterking32) — we do not modify it. If you're using mhrv-rs, follow the upstream deploy instructions in the script's header comment. The only edit **you** must make is the `AUTH_KEY` constant — set it to a strong secret and reuse that exact string in your `mhrv-rs` config. +- Want classic setup only: deploy **`Code.gs`**. +- Want full tunnel mode: deploy **`CodeFull.gs`**. +- Want mixed routing (normal via Apps Script + specific hosts via CFW): deploy **`CodeHybrid.gs`** and configure: + - `WORKER_URL`, `CFW_HOSTS` in script + - `cfw_script_id` / `cfw_hosts` in `mhrv-rs` config + +## Security notes + +- Always change `AUTH_KEY` before deployment. +- Keep Worker URL private if possible. +- Do not share deployment IDs and auth key publicly. diff --git a/assets/apps_script/worker.js b/assets/apps_script/worker.js new file mode 100644 index 00000000..fd92c99c --- /dev/null +++ b/assets/apps_script/worker.js @@ -0,0 +1,88 @@ +const WORKER_URL = "myworker.workers.dev"; + +export default { + async fetch(request) { + try { + if (request.headers.get("x-relay-hop") === "1") { + return json({ e: "loop detected" }, 508); + } + + const req = await request.json(); + + if (!req.u) { + return json({ e: "missing url" }, 400); + } + + const targetUrl = new URL(req.u); + + const BLOCKED_HOSTS = [ + WORKER_URL, + ]; + + if (BLOCKED_HOSTS.some(h => targetUrl.hostname.endsWith(h))) { + return json({ e: "self-fetch blocked" }, 400); + } + + const headers = new Headers(); + if (req.h && typeof req.h === "object") { + for (const [k, v] of Object.entries(req.h)) { + headers.set(k, v); + } + } + + headers.set("x-relay-hop", "1"); + + const fetchOptions = { + method: (req.m || "GET").toUpperCase(), + headers, + redirect: req.r === false ? "manual" : "follow" + }; + + if (req.b) { + const binary = Uint8Array.from(atob(req.b), c => c.charCodeAt(0)); + fetchOptions.body = binary; + } + + const resp = await fetch(targetUrl.toString(), fetchOptions); + + // Read response safely (no stack overflow) + const buffer = await resp.arrayBuffer(); + const uint8 = new Uint8Array(buffer); + + let binary = ""; + const chunkSize = 0x8000; // prevent call stack overflow + + for (let i = 0; i < uint8.length; i += chunkSize) { + binary += String.fromCharCode.apply( + null, + uint8.subarray(i, i + chunkSize) + ); + } + + const base64 = btoa(binary); + + const responseHeaders = {}; + resp.headers.forEach((v, k) => { + responseHeaders[k] = v; + }); + + return json({ + s: resp.status, + h: responseHeaders, + b: base64 + }); + + } catch (err) { + return json({ e: String(err) }, 500); + } + } +}; + +function json(obj, status = 200) { + return new Response(JSON.stringify(obj), { + status, + headers: { + "content-type": "application/json" + } + }); +} diff --git a/config.example.json b/config.example.json index fbd6acbb..438e3092 100644 --- a/config.example.json +++ b/config.example.json @@ -3,6 +3,8 @@ "google_ip": "216.239.38.120", "front_domain": "www.google.com", "script_id": "YOUR_APPS_SCRIPT_DEPLOYMENT_ID", + "cfw_script_id": "", + "cfw_hosts": [], "auth_key": "CHANGE_ME_TO_A_STRONG_SECRET", "listen_host": "127.0.0.1", "listen_port": 8085, diff --git a/config.full.example.json b/config.full.example.json index 106112eb..801fc5c1 100644 --- a/config.full.example.json +++ b/config.full.example.json @@ -3,6 +3,8 @@ "google_ip": "216.239.38.120", "front_domain": "www.google.com", "script_id": "YOUR_APPS_SCRIPT_DEPLOYMENT_ID", + "cfw_script_id": "", + "cfw_hosts": [], "auth_key": "CHANGE_ME_TO_A_STRONG_SECRET", "listen_host": "127.0.0.1", "listen_port": 8085, diff --git a/src/config.rs b/src/config.rs index a5808266..7ed87a18 100644 --- a/src/config.rs +++ b/src/config.rs @@ -62,6 +62,22 @@ pub struct Config { pub script_id: Option, #[serde(default)] pub script_ids: Option, + /// Optional Apps Script deployment ID(s) for Cloudflare-Worker-backed + /// relay scripts (`assets/apps_script/CodeHybrid.gs`). Requests to hosts matched by + /// `cfw_hosts` are sent to this pool instead of the default `script_id(s)`. + /// + /// Accepts either a single string or an array. + #[serde(default)] + pub cfw_script_id: Option, + #[serde(default)] + pub cfw_script_ids: Option, + /// Host routing table for "send this site through CFW-backed script IDs". + /// + /// Matching is case-insensitive. Entries support exact hostnames + /// ("x.com") and leading-dot suffixes (".x.com" matches all subdomains). + /// Requires `cfw_script_id` / `cfw_script_ids` to be set. + #[serde(default)] + pub cfw_hosts: Vec, #[serde(default)] pub auth_key: String, #[serde(default = "default_listen_host")] @@ -111,7 +127,7 @@ pub struct Config { pub max_ips_to_scan: usize, #[serde(default = "default_scan_batch_size")] - pub scan_batch_size:usize, + pub scan_batch_size: usize, #[serde(default = "default_google_ip_validation")] pub google_ip_validation: bool, @@ -207,10 +223,18 @@ pub struct Config { pub disable_padding: bool, } -fn default_fetch_ips_from_api() -> bool { false } -fn default_max_ips_to_scan() -> usize { 100 } -fn default_scan_batch_size() -> usize {500} -fn default_google_ip_validation() -> bool {true} +fn default_fetch_ips_from_api() -> bool { + false +} +fn default_max_ips_to_scan() -> usize { + 100 +} +fn default_scan_batch_size() -> usize { + 500 +} +fn default_google_ip_validation() -> bool { + true +} fn default_google_ip() -> String { "216.239.38.120".into() @@ -267,6 +291,21 @@ impl Config { "scan_batch_size must be greater than 0".into(), )); } + if !self.cfw_hosts.is_empty() { + let ids = self.cfw_script_ids_resolved(); + if ids.is_empty() { + return Err(ConfigError::Invalid( + "cfw_hosts is set but cfw_script_id / cfw_script_ids is missing".into(), + )); + } + for id in &ids { + if id.is_empty() || id == "YOUR_APPS_SCRIPT_DEPLOYMENT_ID" { + return Err(ConfigError::Invalid( + "cfw_script_id is not set — deploy assets/apps_script/CodeHybrid.gs and paste its Deployment ID".into(), + )); + } + } + } if self.socks5_port == Some(self.listen_port) { return Err(ConfigError::Invalid( "listen_port and socks5_port must be different".into(), @@ -296,6 +335,16 @@ impl Config { } Vec::new() } + + pub fn cfw_script_ids_resolved(&self) -> Vec { + if let Some(s) = &self.cfw_script_ids { + return s.clone().into_vec(); + } + if let Some(s) = &self.cfw_script_id { + return s.clone().into_vec(); + } + Vec::new() + } } #[cfg(test)] @@ -355,7 +404,8 @@ mod tests { "mode": "google_only" }"#; let cfg: Config = serde_json::from_str(s).unwrap(); - cfg.validate().expect("google_only must validate without script_id / auth_key"); + cfg.validate() + .expect("google_only must validate without script_id / auth_key"); assert_eq!(cfg.mode_kind().unwrap(), Mode::GoogleOnly); } @@ -429,6 +479,32 @@ mod tests { let cfg: Config = serde_json::from_str(s).unwrap(); assert!(cfg.validate().is_err()); } + + #[test] + fn cfw_hosts_require_cfw_script_id() { + let s = r#"{ + "mode": "apps_script", + "auth_key": "SECRET", + "script_id": "MAIN", + "cfw_hosts": ["x.com", ".twitter.com"] + }"#; + let cfg: Config = serde_json::from_str(s).unwrap(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn cfw_hosts_with_cfw_script_id_validate() { + let s = r#"{ + "mode": "apps_script", + "auth_key": "SECRET", + "script_id": "MAIN", + "cfw_script_ids": ["CFW1", "CFW2"], + "cfw_hosts": ["x.com", ".twitter.com"] + }"#; + let cfg: Config = serde_json::from_str(s).unwrap(); + assert_eq!(cfg.cfw_script_ids_resolved(), vec!["CFW1", "CFW2"]); + cfg.validate().unwrap(); + } } #[cfg(test)] diff --git a/src/domain_fronter.rs b/src/domain_fronter.rs index 21b14173..b383457f 100644 --- a/src/domain_fronter.rs +++ b/src/domain_fronter.rs @@ -84,6 +84,8 @@ pub struct DomainFronter { http_host: &'static str, auth_key: String, script_ids: Vec, + cfw_script_ids: Vec, + cfw_hosts: Vec, script_idx: AtomicUsize, /// Fan-out factor: fire this many Apps Script instances in parallel /// per request and return first success. `<= 1` = off. @@ -278,6 +280,8 @@ impl DomainFronter { normalize_x_graphql: config.normalize_x_graphql, cert_hint_shown: std::sync::atomic::AtomicBool::new(false), script_ids, + cfw_script_ids: config.cfw_script_ids_resolved(), + cfw_hosts: config.cfw_hosts.clone(), script_idx: AtomicUsize::new(0), tls_connector, pool: Arc::new(Mutex::new(Vec::new())), @@ -395,14 +399,21 @@ impl DomainFronter { } pub fn next_script_id(&self) -> String { - let n = self.script_ids.len(); + self.next_script_id_from(&self.script_ids) + } + + fn next_script_id_from(&self, pool: &[String]) -> String { + let n = pool.len(); + if n == 0 { + return String::new(); + } let mut bl = self.blacklist.lock().unwrap(); let now = Instant::now(); bl.retain(|_, until| *until > now); for _ in 0..n { let idx = self.script_idx.fetch_add(1, Ordering::Relaxed); - let sid = &self.script_ids[idx % n]; + let sid = &pool[idx % n]; if !bl.contains_key(sid) { return sid.clone(); } @@ -413,7 +424,7 @@ impl DomainFronter { bl.remove(&sid); return sid; } - self.script_ids[0].clone() + pool[0].clone() } /// Pick `want` distinct non-blacklisted script IDs for a parallel fan-out @@ -421,7 +432,11 @@ impl DomainFronter { /// IDs available. Advances the round-robin index by `want` to spread load /// across subsequent calls. fn next_script_ids(&self, want: usize) -> Vec { - let n = self.script_ids.len(); + self.next_script_ids_from(want, &self.script_ids) + } + + fn next_script_ids_from(&self, want: usize, pool: &[String]) -> Vec { + let n = pool.len(); if n == 0 { return vec![]; } @@ -435,13 +450,13 @@ impl DomainFronter { break; } let idx = self.script_idx.fetch_add(1, Ordering::Relaxed); - let sid = &self.script_ids[idx % n]; + let sid = &pool[idx % n]; if !bl.contains_key(sid) && !picked.iter().any(|p| p == sid) { picked.push(sid.clone()); } } if picked.is_empty() { - picked.push(self.script_ids[0].clone()); + picked.push(pool[0].clone()); } picked } @@ -475,9 +490,7 @@ impl DomainFronter { pub(crate) fn record_timeout_strike(&self, script_id: &str) { let now = Instant::now(); let mut counts = self.script_timeouts.lock().unwrap(); - let entry = counts - .entry(script_id.to_string()) - .or_insert((now, 0)); + let entry = counts.entry(script_id.to_string()).or_insert((now, 0)); if now.duration_since(entry.0) > TIMEOUT_STRIKE_WINDOW { *entry = (now, 1); } else { @@ -655,13 +668,22 @@ impl DomainFronter { // Range header is present, skip cache and coalesce entirely. let has_range = headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("range")); let coalescible = is_cacheable_method(method) && body.is_empty() && !has_range; - let key = if coalescible { Some(cache_key(method, url)) } else { None }; + let key = if coalescible { + Some(cache_key(method, url)) + } else { + None + }; let t_start = Instant::now(); if let Some(ref k) = key { if let Some(hit) = self.cache.get(k) { tracing::debug!("cache hit: {}", url); - self.record_site(url, true, hit.len() as u64, t_start.elapsed().as_nanos() as u64); + self.record_site( + url, + true, + hit.len() as u64, + t_start.elapsed().as_nanos() as u64, + ); return hit; } } @@ -694,7 +716,10 @@ impl DomainFronter { } } - let bytes = self.relay_uncoalesced(method, url, headers, body, key.as_deref()).await; + let use_cfw = self.should_use_cfw_for_url(url); + let bytes = self + .relay_uncoalesced(method, url, headers, body, key.as_deref(), use_cfw) + .await; if let Some(ref k) = key { let mut inflight = self.inflight.lock().await; @@ -703,7 +728,12 @@ impl DomainFronter { } } - self.record_site(url, false, bytes.len() as u64, t_start.elapsed().as_nanos() as u64); + self.record_site( + url, + false, + bytes.len() as u64, + t_start.elapsed().as_nanos() as u64, + ); bytes } @@ -772,8 +802,7 @@ impl DomainFronter { return first; } - let probe_range = match validate_probe_range(status, &resp_headers, resp_body, chunk - 1) - { + let probe_range = match validate_probe_range(status, &resp_headers, resp_body, chunk - 1) { Some(r) => r, None => { tracing::warn!( @@ -812,7 +841,9 @@ impl DomainFronter { tracing::info!( "range-parallel: {} bytes total, {} chunks remaining after probe, up to {} in flight", - total, ranges.len(), MAX_PARALLEL, + total, + ranges.len(), + MAX_PARALLEL, ); // Concurrent fetch with `buffered` — preserves input order @@ -878,7 +909,9 @@ impl DomainFronter { // when the parallel stitch can't be trusted. tracing::warn!( "range-parallel: stitched {}/{} bytes for {}; falling back to single GET", - full.len(), total, url, + full.len(), + total, + url, ); return self.relay(method, url, headers, body).await; } @@ -897,11 +930,12 @@ impl DomainFronter { headers: &[(String, String)], body: &[u8], cache_key_opt: Option<&str>, + use_cfw: bool, ) -> Vec { self.relay_calls.fetch_add(1, Ordering::Relaxed); let bytes = match timeout( Duration::from_secs(REQUEST_TIMEOUT_SECS), - self.do_relay_with_retry(method, url, headers, body), + self.do_relay_with_retry(method, url, headers, body, use_cfw), ) .await { @@ -932,7 +966,8 @@ impl DomainFronter { ); } }; - self.bytes_relayed.fetch_add(bytes.len() as u64, Ordering::Relaxed); + self.bytes_relayed + .fetch_add(bytes.len() as u64, Ordering::Relaxed); // Daily-budget counters (reset at 00:00 UTC). Only counts // successful relays — the two error branches above don't reach // here, matching what Google actually billed to quota. @@ -953,21 +988,33 @@ impl DomainFronter { url: &str, headers: &[(String, String)], body: &[u8], + use_cfw: bool, ) -> Result, FronterError> { + let ids_pool = if use_cfw && !self.cfw_script_ids.is_empty() { + &self.cfw_script_ids + } else { + &self.script_ids + }; // Fan-out path: fire N instances in parallel, return first Ok, cancel // the rest. Clamps to number of available script IDs so the single-ID // case is a no-op even if parallel_relay>1 was configured. - let fan = self.parallel_relay.min(self.script_ids.len()).max(1); + let fan = self.parallel_relay.min(ids_pool.len()).max(1); if fan >= 2 { - return self.do_relay_parallel(method, url, headers, body, fan).await; + return self + .do_relay_parallel(method, url, headers, body, fan, ids_pool) + .await; } // Sequential path: one retry on connection failure. - match self.do_relay_once(method, url, headers, body).await { + match self + .do_relay_once(method, url, headers, body, ids_pool) + .await + { Ok(v) => Ok(v), Err(e) => { tracing::debug!("relay attempt 1 failed: {}; retrying", e); - self.do_relay_once(method, url, headers, body).await + self.do_relay_once(method, url, headers, body, ids_pool) + .await } } } @@ -979,9 +1026,10 @@ impl DomainFronter { headers: &[(String, String)], body: &[u8], fan: usize, + ids_pool: &[String], ) -> Result, FronterError> { use futures_util::future::FutureExt; - let ids = self.next_script_ids(fan); + let ids = self.next_script_ids_from(fan, ids_pool); if ids.is_empty() { return Err(FronterError::Relay("no script_ids available".into())); } @@ -990,7 +1038,9 @@ impl DomainFronter { // `select_ok` over them. let mut futs = Vec::with_capacity(ids.len()); for sid in ids { - let fut = self.do_relay_once_with(sid.clone(), method, url, headers, body).boxed(); + let fut = self + .do_relay_once_with(sid.clone(), method, url, headers, body) + .boxed(); futs.push(fut); } @@ -1009,9 +1059,24 @@ impl DomainFronter { url: &str, headers: &[(String, String)], body: &[u8], + ids_pool: &[String], ) -> Result, FronterError> { - let script_id = self.next_script_id(); - self.do_relay_once_with(script_id, method, url, headers, body).await + let script_id = self.next_script_id_from(ids_pool); + self.do_relay_once_with(script_id, method, url, headers, body) + .await + } + + fn should_use_cfw_for_url(&self, url: &str) -> bool { + if self.cfw_script_ids.is_empty() || self.cfw_hosts.is_empty() { + return false; + } + let Ok(parsed) = url::Url::parse(url) else { + return false; + }; + let Some(host) = parsed.host_str() else { + return false; + }; + host_matches_list(host, &self.cfw_hosts) } async fn do_relay_once_with( @@ -1260,7 +1325,10 @@ impl DomainFronter { text } else { let start = text.find('{').ok_or_else(|| { - FronterError::BadResponse(format!("no json in tunnel response: {}", &text[..text.len().min(200)])) + FronterError::BadResponse(format!( + "no json in tunnel response: {}", + &text[..text.len().min(200)] + )) })?; let end = text.rfind('}').ok_or_else(|| { FronterError::BadResponse("no json end in tunnel response".into()) @@ -1357,8 +1425,12 @@ impl DomainFronter { // Follow redirect chain for _ in 0..5 { - if !matches!(status, 301 | 302 | 303 | 307 | 308) { break; } - let Some(loc) = header_get(&resp_headers, "location") else { break; }; + if !matches!(status, 301 | 302 | 303 | 307 | 308) { + break; + } + let Some(loc) = header_get(&resp_headers, "location") else { + break; + }; let (rpath, rhost) = parse_redirect(&loc); let rhost = rhost.unwrap_or_else(|| self.http_host.to_string()); let req = format!( @@ -1367,15 +1439,23 @@ impl DomainFronter { entry.stream.write_all(req.as_bytes()).await?; entry.stream.flush().await?; let (s, h, b) = read_http_response(&mut entry.stream).await?; - status = s; resp_headers = h; resp_body = b; + status = s; + resp_headers = h; + resp_body = b; } if status != 200 { - let body_txt = String::from_utf8_lossy(&resp_body).chars().take(200).collect::(); + let body_txt = String::from_utf8_lossy(&resp_body) + .chars() + .take(200) + .collect::(); if should_blacklist(status, &body_txt) { self.blacklist_script(&script_id, &format!("HTTP {}", status)); } - return Err(FronterError::Relay(format!("batch tunnel HTTP {}: {}", status, body_txt))); + return Err(FronterError::Relay(format!( + "batch tunnel HTTP {}: {}", + status, body_txt + ))); } let text = std::str::from_utf8(&resp_body) @@ -1386,20 +1466,30 @@ impl DomainFronter { text } else { let start = text.find('{').ok_or_else(|| { - FronterError::BadResponse(format!("no json in batch response: {}", &text[..text.len().min(200)])) - })?; - let end = text.rfind('}').ok_or_else(|| { - FronterError::BadResponse("no json end in batch response".into()) + FronterError::BadResponse(format!( + "no json in batch response: {}", + &text[..text.len().min(200)] + )) })?; + let end = text + .rfind('}') + .ok_or_else(|| FronterError::BadResponse("no json end in batch response".into()))?; &text[start..=end] }; - tracing::debug!("batch response body: {}", &json_str[..json_str.len().min(500)]); + tracing::debug!( + "batch response body: {}", + &json_str[..json_str.len().min(500)] + ); let resp: BatchTunnelResponse = match serde_json::from_str(json_str) { Ok(v) => v, Err(e) => { - tracing::error!("batch JSON parse error: {} — body: {}", e, &json_str[..json_str.len().min(300)]); + tracing::error!( + "batch JSON parse error: {} — body: {}", + e, + &json_str[..json_str.len().min(300)] + ); return Err(FronterError::Json(e)); } }; @@ -1442,7 +1532,9 @@ fn split_response(raw: &[u8]) -> Option<(u16, Vec<(String, String)>, &[u8])> { let mut lines = head.split(|&b| b == b'\n'); let status_line = lines.next()?; // Status line: "HTTP/1.1 206 Partial Content" - let status_line = std::str::from_utf8(status_line).ok()?.trim_end_matches('\r'); + let status_line = std::str::from_utf8(status_line) + .ok()? + .trim_end_matches('\r'); let mut parts = status_line.splitn(3, ' '); let _version = parts.next()?; let code = parts.next()?.parse::().ok()?; @@ -1607,7 +1699,10 @@ fn normalize_x_graphql_url(url: &str) -> String { // Split host from the rest. We accept both "x.com" and common legacy // forms; the Python patch only checks x.com so we do the same to be // safe about the endpoint actually accepting truncated requests. - let Some(rest) = url.strip_prefix("https://").or_else(|| url.strip_prefix("http://")) else { + let Some(rest) = url + .strip_prefix("https://") + .or_else(|| url.strip_prefix("http://")) + else { return url.to_string(); }; let Some(slash) = rest.find('/') else { @@ -1636,7 +1731,11 @@ fn normalize_x_graphql_url(url: &str) -> String { Some(amp) => &query[..amp], None => query, }; - let scheme = if url.starts_with("https://") { "https://" } else { "http://" }; + let scheme = if url.starts_with("https://") { + "https://" + } else { + "http://" + }; format!("{}{}{}?{}", scheme, host, path, new_query) } @@ -1801,7 +1900,10 @@ fn extract_host(url: &str) -> Option { let after_scheme = url.split_once("://").map(|(_, rest)| rest).unwrap_or(url); let authority = after_scheme.split('/').next().unwrap_or(""); // Strip userinfo if present. - let authority = authority.rsplit_once('@').map(|(_, a)| a).unwrap_or(authority); + let authority = authority + .rsplit_once('@') + .map(|(_, a)| a) + .unwrap_or(authority); // Strip port. Handle IPv6 literals in brackets. let host = if let Some(stripped) = authority.strip_prefix('[') { // [::1]:443 -> ::1 @@ -1816,6 +1918,22 @@ fn extract_host(url: &str) -> Option { } } +fn host_matches_list(host: &str, list: &[String]) -> bool { + let h = host.to_ascii_lowercase(); + let h = h.trim_end_matches('.'); + list.iter().any(|entry| { + let e = entry.trim().trim_end_matches('.').to_ascii_lowercase(); + if e.is_empty() { + return false; + } + if let Some(suffix) = e.strip_prefix('.') { + h == suffix || h.ends_with(&format!(".{}", suffix)) + } else { + h == e + } + }) +} + /// The default pool of SNI names that share the Google Front End with /// `www.google.com`. Used both when auto-expanding from `front_domain` and /// when the UI wants to show the starting candidates for the SNI editor. @@ -1953,7 +2071,12 @@ fn strip_brotli_from_accept_encoding(value: &str) -> String { let kept: Vec<&str> = parts .into_iter() .filter(|p| { - let tok = p.split(';').next().unwrap_or("").trim().to_ascii_lowercase(); + let tok = p + .split(';') + .next() + .unwrap_or("") + .trim() + .to_ascii_lowercase(); tok != "br" && tok != "zstd" }) .collect(); @@ -1976,10 +2099,17 @@ fn header_get(headers: &[(String, String)], name: &str) -> Option { fn parse_redirect(location: &str) -> (String, Option) { // Absolute URL: http(s)://host/path?query - if let Some(rest) = location.strip_prefix("https://").or_else(|| location.strip_prefix("http://")) { + if let Some(rest) = location + .strip_prefix("https://") + .or_else(|| location.strip_prefix("http://")) + { let slash = rest.find('/').unwrap_or(rest.len()); let host = rest[..slash].to_string(); - let path = if slash < rest.len() { rest[slash..].to_string() } else { "/".into() }; + let path = if slash < rest.len() { + rest[slash..].to_string() + } else { + "/".into() + }; return (path, Some(host)); } // Relative path. @@ -1988,17 +2118,22 @@ fn parse_redirect(location: &str) -> (String, Option) { /// Read a single HTTP/1.1 response from the stream. Keep-alive safe: respects /// Content-Length or chunked transfer-encoding. -async fn read_http_response(stream: &mut S) -> Result<(u16, Vec<(String, String)>, Vec), FronterError> +async fn read_http_response( + stream: &mut S, +) -> Result<(u16, Vec<(String, String)>, Vec), FronterError> where S: tokio::io::AsyncRead + Unpin, { let mut buf = Vec::with_capacity(8192); let mut tmp = [0u8; 8192]; let header_end = loop { - let n = timeout(Duration::from_secs(10), stream.read(&mut tmp)).await + let n = timeout(Duration::from_secs(10), stream.read(&mut tmp)) + .await .map_err(|_| FronterError::Timeout)??; if n == 0 { - return Err(FronterError::BadResponse("connection closed before headers".into())); + return Err(FronterError::BadResponse( + "connection closed before headers".into(), + )); } buf.extend_from_slice(&tmp[..n]); if let Some(pos) = find_double_crlf(&buf) { @@ -2024,8 +2159,8 @@ where } let mut body = buf[header_end + 4..].to_vec(); - let content_length: Option = header_get(&headers_out, "content-length") - .and_then(|v| v.parse().ok()); + let content_length: Option = + header_get(&headers_out, "content-length").and_then(|v| v.parse().ok()); let te = header_get(&headers_out, "transfer-encoding").unwrap_or_default(); let is_chunked = te.to_ascii_lowercase().contains("chunked"); @@ -2035,7 +2170,8 @@ where while body.len() < cl { let need = cl - body.len(); let want = need.min(tmp.len()); - let n = timeout(Duration::from_secs(20), stream.read(&mut tmp[..want])).await + let n = timeout(Duration::from_secs(20), stream.read(&mut tmp[..want])) + .await .map_err(|_| FronterError::Timeout)??; if n == 0 { return Err(FronterError::BadResponse( @@ -2075,18 +2211,18 @@ where let mut out: Vec = Vec::new(); let mut tmp = [0u8; 16384]; loop { - let size_line_owned = std::str::from_utf8(&read_crlf_line(stream, &mut buf, &mut tmp).await?) - .map_err(|_| FronterError::BadResponse("bad chunk size".into()))? - .trim() - .to_string(); + let size_line_owned = + std::str::from_utf8(&read_crlf_line(stream, &mut buf, &mut tmp).await?) + .map_err(|_| FronterError::BadResponse("bad chunk size".into()))? + .trim() + .to_string(); if size_line_owned.is_empty() { continue; } - let size = usize::from_str_radix( - size_line_owned.split(';').next().unwrap_or(""), - 16, - ) - .map_err(|_| FronterError::BadResponse(format!("bad chunk size '{}'", size_line_owned)))?; + let size = usize::from_str_radix(size_line_owned.split(';').next().unwrap_or(""), 16) + .map_err(|_| { + FronterError::BadResponse(format!("bad chunk size '{}'", size_line_owned)) + })?; if size == 0 { loop { if read_crlf_line(stream, &mut buf, &mut tmp).await?.is_empty() { @@ -2095,7 +2231,8 @@ where } } while buf.len() < size + 2 { - let n = timeout(Duration::from_secs(20), stream.read(&mut tmp)).await + let n = timeout(Duration::from_secs(20), stream.read(&mut tmp)) + .await .map_err(|_| FronterError::Timeout)??; if n == 0 { return Err(FronterError::BadResponse( @@ -2128,7 +2265,8 @@ where buf.drain(..idx + 2); return Ok(line); } - let n = timeout(Duration::from_secs(20), stream.read(tmp)).await + let n = timeout(Duration::from_secs(20), stream.read(tmp)) + .await .map_err(|_| FronterError::Timeout)??; if n == 0 { return Err(FronterError::BadResponse( @@ -2154,10 +2292,11 @@ fn parse_status_line(line: &str) -> Result { // "HTTP/1.1 200 OK" let mut parts = line.split_whitespace(); let _version = parts.next(); - let code = parts.next().ok_or_else(|| { - FronterError::BadResponse(format!("bad status line: {}", line)) - })?; - code.parse::().map_err(|_| FronterError::BadResponse(format!("bad status code: {}", code))) + let code = parts + .next() + .ok_or_else(|| FronterError::BadResponse(format!("bad status line: {}", line)))?; + code.parse::() + .map_err(|_| FronterError::BadResponse(format!("bad status code: {}", code))) } /// Parse the JSON envelope from Apps Script and build a raw HTTP response. @@ -2177,7 +2316,10 @@ fn parse_relay_json(body: &[u8]) -> Result, FronterError> { FronterError::BadResponse(format!("no json in: {}", &text[..text.len().min(200)])) })?; let end = text.rfind('}').ok_or_else(|| { - FronterError::BadResponse(format!("no json end in: {}", &text[..text.len().min(200)])) + FronterError::BadResponse(format!( + "no json end in: {}", + &text[..text.len().min(200)] + )) })?; serde_json::from_str(&text[start..=end])? } @@ -2397,7 +2539,9 @@ pub fn error_response(status: u16, message: &str) -> Vec { } fn html_escape(s: &str) -> String { - s.replace('&', "&").replace('<', "<").replace('>', ">") + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") } // Dangerous "accept anything" TLS verifier, used only when config.verify_ssl=false. @@ -2458,14 +2602,14 @@ mod tests { fn unix_to_ymd_utc_handles_known_epochs() { // Anchors chosen to catch the common off-by-one errors (pre/post // leap day, pre/post epoch, year-end rollover). - assert_eq!(unix_to_ymd_utc(0), (1970, 1, 1)); // epoch - assert_eq!(unix_to_ymd_utc(86_399), (1970, 1, 1)); // one sec before day 2 - assert_eq!(unix_to_ymd_utc(86_400), (1970, 1, 2)); // day 2 starts at midnight - assert_eq!(unix_to_ymd_utc(951_782_400), (2000, 2, 29)); // leap day (Feb 29, 2000) - assert_eq!(unix_to_ymd_utc(951_868_800), (2000, 3, 1)); // day after leap Feb - assert_eq!(unix_to_ymd_utc(1_583_020_800), (2020, 3, 1)); // day after a leap Feb - assert_eq!(unix_to_ymd_utc(1_735_689_599), (2024, 12, 31)); // last sec of 2024 - assert_eq!(unix_to_ymd_utc(1_735_689_600), (2025, 1, 1)); // first sec of 2025 + assert_eq!(unix_to_ymd_utc(0), (1970, 1, 1)); // epoch + assert_eq!(unix_to_ymd_utc(86_399), (1970, 1, 1)); // one sec before day 2 + assert_eq!(unix_to_ymd_utc(86_400), (1970, 1, 2)); // day 2 starts at midnight + assert_eq!(unix_to_ymd_utc(951_782_400), (2000, 2, 29)); // leap day (Feb 29, 2000) + assert_eq!(unix_to_ymd_utc(951_868_800), (2000, 3, 1)); // day after leap Feb + assert_eq!(unix_to_ymd_utc(1_583_020_800), (2020, 3, 1)); // day after a leap Feb + assert_eq!(unix_to_ymd_utc(1_735_689_599), (2024, 12, 31)); // last sec of 2024 + assert_eq!(unix_to_ymd_utc(1_735_689_600), (2025, 1, 1)); // first sec of 2025 } #[test] @@ -2499,7 +2643,7 @@ mod tests { assert!(!pacific_is_dst(2026, 12, 25)); assert!(!pacific_is_dst(2026, 2, 28)); assert!(!pacific_is_dst(2026, 11, 5)); // first Sun of Nov 2026 = Nov 1; Nov 5 is past - // Inside: PDT. + // Inside: PDT. assert!(pacific_is_dst(2026, 6, 1)); assert!(pacific_is_dst(2026, 9, 30)); // Boundary: March 8, 2026 (DST start day) and after = PDT. @@ -2593,7 +2737,7 @@ mod tests { let cases = [ "https://x.com/home", "https://x.com/i/api/2/notifications/view/generic.json", - "https://x.com/i/api/graphql/x/y", // no query + "https://x.com/i/api/graphql/x/y", // no query "https://x.com/i/api/graphql/x/y?features=1&variables=2", // variables not first ]; for u in cases { @@ -2612,14 +2756,35 @@ mod tests { #[test] fn extract_host_strips_scheme_port_path() { - assert_eq!(extract_host("https://example.com/foo"), Some("example.com".into())); - assert_eq!(extract_host("http://foo.bar:8080/x"), Some("foo.bar".into())); - assert_eq!(extract_host("https://user:pw@host.test/x"), Some("host.test".into())); - assert_eq!(extract_host("https://[2001:db8::1]:443/"), Some("2001:db8::1".into())); + assert_eq!( + extract_host("https://example.com/foo"), + Some("example.com".into()) + ); + assert_eq!( + extract_host("http://foo.bar:8080/x"), + Some("foo.bar".into()) + ); + assert_eq!( + extract_host("https://user:pw@host.test/x"), + Some("host.test".into()) + ); + assert_eq!( + extract_host("https://[2001:db8::1]:443/"), + Some("2001:db8::1".into()) + ); assert_eq!(extract_host("API.X.com/foo"), Some("api.x.com".into())); assert_eq!(extract_host(""), None); } + #[test] + fn host_matches_list_exact_and_suffix() { + let list = vec!["x.com".to_string(), ".twitter.com".to_string()]; + assert!(host_matches_list("x.com", &list)); + assert!(host_matches_list("api.twitter.com", &list)); + assert!(host_matches_list("TWITTER.com", &list)); + assert!(!host_matches_list("example.com", &list)); + } + #[test] fn build_sni_pool_extends_for_google() { let p = build_sni_pool("www.google.com"); @@ -2770,7 +2935,10 @@ Content-Length: 45812\r\n\r\n" checked_stitched_range_capacity(MAX_STITCHED_RANGE_BYTES), Some(MAX_STITCHED_RANGE_BYTES as usize), ); - assert_eq!(checked_stitched_range_capacity(MAX_STITCHED_RANGE_BYTES + 1), None); + assert_eq!( + checked_stitched_range_capacity(MAX_STITCHED_RANGE_BYTES + 1), + None + ); assert_eq!(checked_stitched_range_capacity(u64::MAX), None); } @@ -2812,10 +2980,15 @@ hello"; fn blacklist_heuristics() { assert!(should_blacklist(429, "")); assert!(should_blacklist(403, "quota")); - assert!(should_blacklist(500, "Service invoked too many times per day: urlfetch")); + assert!(should_blacklist( + 500, + "Service invoked too many times per day: urlfetch" + )); assert!(!should_blacklist(200, "")); assert!(!should_blacklist(502, "bad gateway")); - assert!(looks_like_quota_error("Exception: Service invoked too many times per day")); + assert!(looks_like_quota_error( + "Exception: Service invoked too many times per day" + )); assert!(looks_like_quota_error( "Exception: Bandbreitenkontingent überschritten: https://example.com. Verringern Sie die Datenübertragungsrate." )); @@ -2873,7 +3046,11 @@ hello"; let err = read_http_response(&mut server).await.unwrap_err(); match err { FronterError::BadResponse(msg) => { - assert!(msg.contains("full response body"), "unexpected error: {}", msg); + assert!( + msg.contains("full response body"), + "unexpected error: {}", + msg + ); } other => panic!("unexpected error: {}", other), }