snowstorm/control/
mod.rs

1use crate::{
2    config::serializable::GeoIp,
3    config::serializable::Profile,
4    config::serializable::SerializableRouteDescriptor,
5    control::odoh::do_odoh_request,
6    net::{get_default_interface, get_tun_interface, swarm},
7};
8use hyper::Method;
9use libp2p::Multiaddr;
10use once_cell::sync::Lazy;
11use parking_lot::RwLock;
12use reqwest;
13use serde::{Deserialize, Serialize};
14use std::{net::IpAddr, sync::Arc, time::Duration};
15
16#[cfg(target_os = "android")]
17pub mod android;
18pub mod odoh;
19
/// Backend API base URL that `BACKEND_URL` is initialised with.
static DEFAULT_BACKEND_URL: &str = "https://snowstorm.love/api";
21
22static BACKEND_URL: Lazy<Arc<RwLock<String>>> =
23    Lazy::new(|| Arc::new(RwLock::new(DEFAULT_BACKEND_URL.to_string())));
24
25pub fn get_backend_url() -> String {
26    BACKEND_URL.read().clone()
27}
28
29pub fn set_backend_url(url: String) {
30    let mut guard = BACKEND_URL.write();
31    *guard = url;
32}
33
34static CONTROL_PEER_OVERRIDE: Lazy<Arc<RwLock<Option<Multiaddr>>>> =
35    Lazy::new(|| Arc::new(RwLock::new(None)));
36
37pub fn get_control_peer_override() -> Option<Multiaddr> {
38    CONTROL_PEER_OVERRIDE.read().clone()
39}
40
41pub fn set_control_peer_override(mad: Option<Multiaddr>) {
42    let mut guard = CONTROL_PEER_OVERRIDE.write();
43    *guard = mad;
44}
45
/// Boxed, thread-safe callback that receives a reference to a socket file
/// descriptor. What the callback does with it is decided by whoever installs it.
type SocketProtectFn = Box<dyn Fn(&i32) + Send + Sync>;
47
48static SOCKET_PROTECT_FN: Lazy<Arc<RwLock<Option<SocketProtectFn>>>> =
49    Lazy::new(|| Arc::new(RwLock::new(None)));
50
51pub fn set_socket_protect_fn(f: SocketProtectFn) {
52    let mut guard = SOCKET_PROTECT_FN.write();
53    *guard = Some(f);
54}
55
56pub fn protect_socket(fd: &i32) {
57    match SOCKET_PROTECT_FN.read().as_ref() {
58        Some(protect_fn) => {
59            protect_fn(fd);
60        }
61        None => {}
62    }
63}
64
65#[derive(Debug, Clone, thiserror::Error)]
66pub enum ReqwestClientError {
67    #[error("Failed to get network interface")]
68    IfaceNotFound,
69}
70
71static DEFAULT_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
72    default_builder()
73        .build()
74        .expect("Failed to create default reqwest client")
75});
76
77pub fn default_client() -> reqwest::Client {
78    DEFAULT_CLIENT.clone()
79}
80
81pub fn default_builder() -> reqwest::ClientBuilder {
82    reqwest::ClientBuilder::new()
83}
84
85fn tun_client_builder() -> Result<reqwest::ClientBuilder, ReqwestClientError> {
86    let iface = get_tun_interface().ok_or(ReqwestClientError::IfaceNotFound)?;
87
88    #[cfg(not(any(target_os = "windows", target_os = "android")))]
89    return Ok(reqwest::Client::builder().interface(&iface.name));
90
91    #[cfg(target_os = "windows")]
92    {
93        let addr = iface.ipv4_addrs().first().cloned().map(|ip| ip.into());
94        return Ok(reqwest::Client::builder().local_address(addr));
95    }
96
97    #[cfg(target_os = "android")]
98    return Ok(default_builder());
99}
100
101fn bypass_client_builder() -> Result<reqwest::ClientBuilder, ReqwestClientError> {
102    #[cfg(target_os = "android")]
103    if get_tun_interface().is_none() {
104        // On Android no need to try to bind if tunnel is off.
105        // Caveat: better don't store the client/builder
106        return Ok(default_builder());
107    }
108
109    let iface = get_default_interface().ok_or(ReqwestClientError::IfaceNotFound)?;
110
111    #[cfg(not(target_os = "windows"))]
112    return Ok(reqwest::Client::builder().interface(&iface.name));
113
114    #[cfg(target_os = "windows")]
115    {
116        let addr = iface.ipv4_addrs().first().cloned().map(|ip| ip.into());
117        return Ok(reqwest::Client::builder().local_address(addr));
118    }
119}
120
121#[derive(Debug, Serialize, Deserialize)]
122struct UpdateRequest {
123    token: String,
124    peer_id: String,
125    device_id: Option<String>,
126    multiaddrs: Vec<String>,
127}
128
129pub async fn update_multiaddrs(
130    token: String,
131    peer_id: String,
132    device_id: Option<String>,
133    multiaddrs: Vec<&Multiaddr>,
134    public_only: bool,
135) -> Result<(), Box<dyn std::error::Error>> {
136    if multiaddrs.is_empty() {
137        return Err("No multiaddrs provided".into());
138    }
139
140    let mas = multiaddrs
141        .iter()
142        .filter(|ma| {
143            !swarm::multiaddr_is_loopback(ma) && (!public_only || swarm::multiaddr_is_public(ma))
144        })
145        .map(|ma| ma.to_string())
146        .collect::<Vec<String>>();
147
148    let params = UpdateRequest {
149        token,
150        peer_id,
151        device_id,
152        multiaddrs: mas,
153    };
154
155    let url = "auth?method=update".to_string();
156    let value = do_rpc(Method::POST, &url, Some(&params)).await?;
157    let resp: String = serde_json::from_value(value)?;
158
159    match resp.as_str() {
160        "ok" => Ok(()),
161        _ => Err("Failed to update multiaddrs".into()),
162    }
163}
164
165#[derive(Debug, Serialize, Deserialize)]
166pub struct RegisterDeviceResponse {
167    token: String,
168    id: String,
169}
170
171pub async fn register_device(
172    token: &str,
173    peer_id: &str,
174    dev_type: &str,
175    dev_id: &str,
176    dev_name: &str,
177) -> Result<RegisterDeviceResponse, Box<dyn std::error::Error>> {
178    let query = serde_urlencoded::to_string(&[
179        ("method", "device"),
180        ("Token", token),
181        ("PeerID", peer_id),
182        ("Type", dev_type),
183        ("ID", dev_id),
184        ("Name", dev_name),
185    ])?;
186    let url = format!("auth?{}", query);
187    let value = do_rpc_get(&url).await?;
188    let dev_response: RegisterDeviceResponse = serde_json::from_value(value)?;
189
190    Ok(dev_response)
191}
192
193pub async fn unregister(token: &str, peer_id: Option<&str>, device_id: Option<&str>) -> Result<(), Box<dyn std::error::Error>> {
194    match (peer_id, device_id) {
195        (Some(_), Some(_)) => {
196            Err("Provide either peer_id or device_id, not both".into())
197        }
198        (Some(pid), None) => {
199            let params = UpdateRequest {
200                token: token.to_string(),
201                peer_id: pid.to_string(),
202                device_id: None,
203                multiaddrs: Vec::new(),
204            };
205
206            let url = "auth?method=unregister";
207            let value = do_rpc(Method::POST, &url, Some(&params)).await?;
208            let resp: String = serde_json::from_value(value)?;
209
210            match resp.as_str() {
211                "ok" => Ok(()),
212                _ => Err("Failed to unregister".into()),
213            }
214        }
215        (None, Some(did)) => {
216            let params = UpdateRequest {
217                token: token.to_string(),
218                peer_id: "".to_string(),
219                device_id: Some(did.to_string()),
220                multiaddrs: Vec::new(),
221            };
222
223            let url = "auth?method=unregister";
224            let value = do_rpc(Method::POST, &url, Some(&params)).await?;
225            let resp: String = serde_json::from_value(value)?;
226
227            match resp.as_str() {
228                "ok" => Ok(()),
229                _ => Err("Failed to unregister".into()),
230            }
231        }
232        (None, None) => {
233            Err("Provide either peer_id or device_id".into())
234        }
235    }
236}
237
238pub async fn auth_with_license(license: String) -> Result<String, Box<dyn std::error::Error>> {
239    let url = format!("auth?method=license&License={}", license);
240    let value = do_rpc_get(&url).await?;
241    let resp: String = serde_json::from_value(value)?;
242
243    Ok(resp)
244}
245
246#[derive(Debug, Serialize, Deserialize)]
247pub struct RefreshResponse {
248    #[serde(rename = "Token")]
249    token: String,
250
251    #[serde(rename = "SubscriptionData", default)]
252    pub subscription_data: serde_json::Value,
253
254    #[serde(rename = "Email", default)]
255    pub email: String,
256
257    #[serde(rename = "Kind", default)]
258    pub kind: String,
259}
260
261pub async fn refresh_token(token: String) -> Result<RefreshResponse, Box<dyn std::error::Error>> {
262    let url = format!("auth?method=refresh&Token={}", token);
263    let value = do_rpc_get(&url).await?;
264    let resp: RefreshResponse = serde_json::from_value(value)?;
265
266    Ok(resp)
267}
268
269pub async fn get_control_peer() -> Result<Multiaddr, Box<dyn std::error::Error>> {
270    #[derive(Debug, Deserialize)]
271    struct YukigoInfo {
272        #[allow(dead_code)]
273        key: String,
274
275        #[allow(dead_code)]
276        peer: String,
277    }
278
279    let resp = do_rpc_get("auth?method=info").await?;
280    let kp: YukigoInfo = serde_json::from_value(resp)?;
281    let ma = Multiaddr::try_from(kp.peer).unwrap();
282
283    Ok(ma)
284}
285
286pub async fn do_rpc_get(path: &str) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
287    do_rpc::<()>(Method::GET, path, None).await
288}
289
290pub async fn do_rpc<T: Serialize>(
291    method: Method,
292    path: &str,
293    body: Option<&T>,
294) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
295    let url = format!("{}/{}", get_backend_url(), path);
296    let client = default_client();
297    let mut req = client.request(method, url);
298    if let Some(body) = body {
299        req = req.json(body);
300    }
301    req = req.timeout(Duration::from_secs(10));
302
303    // Should be safe to clone, since we're not streaming
304    let resp = req.try_clone().unwrap().send().await;
305    match resp {
306        Ok(resp) => {
307            let resp = resp.error_for_status()?.json().await?;
308            return Ok(resp);
309        }
310        Err(err) => {
311            log::error!("direct request failed: {}", err);
312            let req = req.build()?;
313            let raw_resp = do_odoh_request(&req).await?;
314            let resp = serde_json::from_slice(&raw_resp)?;
315            return Ok(resp);
316        }
317    }
318}
319
320pub async fn post_measurement<T: Serialize + ?Sized>(
321    query: &T,
322    payload: &serde_json::Value,
323) -> Result<(), Box<dyn std::error::Error>> {
324    let url = format!("{}/tp", get_backend_url());
325    default_builder()
326        .timeout(Duration::from_secs(5))
327        .build()
328        .map_err(|err: reqwest::Error| {
329            log::error!("Failed to create HTTP client: {}", err);
330            "measurement failed"
331        })?
332        .post(url)
333        .query(query)
334        .header("Content-Type", "application/json")
335        .json(payload)
336        .send()
337        .await
338        .map_err(|err| {
339            // Mainly timeouts and IO errors
340            log::error!("Failed to send analytics event: {}", err);
341            err
342        })?
343        .error_for_status()
344        .map_err(|err| {
345            // HTTP status code errors
346            log::error!("Failed to send analytics event: {}", err);
347            "measurement failed"
348        })?;
349
350    Ok(())
351}
352
353pub async fn get_profile(token: &str) -> Result<Profile, Box<dyn std::error::Error>> {
354    match token.len() {
355        0 => {
356            return Err("Cannot get profile without access token".into());
357        }
358        _ => {}
359    }
360    let url = format!("auth?method=profile&Token={}", token);
361    let value = do_rpc_get(&url).await?;
362    let mut resp: Profile = serde_json::from_value(value)?;
363
364    if resp.peers.is_none() {
365        resp.peers = Some(Vec::new());
366    }
367
368    Ok(resp)
369}
370
371pub async fn validate_route(
372    token: &str,
373    route_id: &str,
374) -> Result<bool, Box<dyn std::error::Error>> {
375    let url = format!("route?method=validate&Token={}&Route={}", token, route_id);
376    do_rpc_get(&url).await?;
377
378    Ok(true)
379}
380
381pub async fn update_profile(
382    token: &str,
383    state_header: &str,
384    local_state: &str,
385) -> Result<(), Box<dyn std::error::Error>> {
386    #[derive(Serialize)]
387    struct UpdateProfileParams<'a> {
388        token: &'a str,
389        state_header: &'a str,
390        local_state: &'a str,
391    }
392
393    let params = UpdateProfileParams {
394        token,
395        state_header,
396        local_state,
397    };
398    let url = "auth?method=update".to_string();
399    do_rpc(Method::POST, &url, Some(&params)).await?;
400
401    Ok(())
402}
403#[derive(Serialize)]
404struct RouteParams<'a> {
405    #[serde(rename = "Token")]
406    token: &'a str,
407
408    #[serde(rename = "Share")]
409    secrets: Vec<&'a str>,
410}
411
412pub async fn get_routes(
413    token: &str,
414    secrets: Vec<&str>,
415) -> Result<Vec<SerializableRouteDescriptor>, Box<dyn std::error::Error>> {
416    let params = RouteParams { token, secrets };
417    let url = "route".to_string();
418    let value = do_rpc(Method::POST, &url, Some(&params)).await?;
419    let resp: Vec<SerializableRouteDescriptor> = serde_json::from_value(value)?;
420
421    Ok(resp)
422}
423
424pub async fn get_routes_bypass(
425    token: &str,
426) -> Result<Vec<SerializableRouteDescriptor>, Box<dyn std::error::Error>> {
427    let params = RouteParams {
428        token,
429        secrets: vec![],
430    };
431    let url = format!("{}/route", get_backend_url());
432
433    #[cfg(target_os = "android")]
434    {
435        let request = android::AndroidBypassTunRequest {
436            url: url.clone(),
437            method: "POST".into(),
438            body: Some(serde_json::to_vec(&params)?),
439        };
440        match android::get_bypass_tun_request(request) {
441            Some(res) => return Ok(res),
442            None => {}
443        }
444    }
445
446    // We're not using do_rpc to make sure no proxy is involved
447    let client = bypass_client_builder()?
448        .timeout(Duration::from_secs(5))
449        .build()?;
450
451    let res = client
452        .post(url)
453        .body(serde_json::to_vec(&params)?)
454        .send()
455        .await?
456        .json()
457        .await?;
458    Ok(res)
459}
460
461pub async fn get_origin_ip() -> Result<GeoIp, Box<dyn std::error::Error + Send + Sync>> {
462    let url = format!("{}/meta", get_backend_url());
463
464    #[cfg(target_os = "android")]
465    {
466        let request = android::AndroidBypassTunRequest {
467            url: url.clone(),
468            method: "GET".into(),
469            body: None,
470        };
471        match android::get_bypass_tun_request(request) {
472            Some(ip) => return Ok(ip),
473            None => {}
474        }
475    }
476
477    // We're not using do_rpc to make sure no proxy is involved
478    let client = bypass_client_builder()?
479        .timeout(Duration::from_secs(7))
480        .build()?;
481    let ip: GeoIp = client
482        .get(url)
483        .send()
484        .await?
485        .json()
486        .await?;
487    Ok(ip)
488}
489
490/**
491 * Try various endpoints to figure our our current public IP.
492 * Begins with Snowstorm's official backend, then fallls back to other options.
493 */
494pub async fn get_public_ip(addr: Option<IpAddr>) -> Result<GeoIp, Box<dyn std::error::Error + Send + Sync>> {
495    match get_public_ip_direct(addr).await {
496        Ok(ip) => Ok(ip),
497        Err(err) => {
498            get_public_ip_cloudflare().await.map_err(|_| err)
499        }
500    }
501}
502
503/**
504 * Underlying RPC to Snowstorm backend (yukigo) for our official meta endpoint.
505 * TODO: Set the timeout duration via a configuration option.
506 */
507async fn get_public_ip_direct(addr: Option<IpAddr>) -> Result<GeoIp, Box<dyn std::error::Error + Send + Sync>> {
508    let client = tun_client_builder()
509        .unwrap_or_else(|_| default_builder())
510        .timeout(Duration::from_secs(2))
511        .local_address(addr)
512        .build()?;
513
514    let url = format!("{}/meta", get_backend_url());
515    let ip: GeoIp = client.get(url).send().await?.json().await?;
516
517    Ok(ip)
518}
519
520pub async fn get_exit_ip() -> Result<GeoIp, Box<dyn std::error::Error + Send + Sync>> {
521    // We're not using do_rpc to make sure no proxy is involved
522    let client = tun_client_builder()?
523        .timeout(Duration::from_secs(3))
524        .build()?;
525    let url = format!("{}/meta", get_backend_url());
526    let ip: GeoIp = client
527        .get(url)
528        .send()
529        .await?
530        .json()
531        .await?;
532
533    Ok(ip)
534}
535
536/**
537 * Falback STUN server checks
538 */
539pub async fn get_public_ip_cloudflare() -> Result<GeoIp, Box<dyn std::error::Error + Send + Sync>> {
540    let client = tun_client_builder()
541        .unwrap_or_else(|_| default_builder())
542        .timeout(Duration::from_secs(3))
543        .build()?;
544    let url = "https://cloudflare.com/cdn-cgi/trace".to_string();
545    let ip = client.get(url).send().await?.text().await?;
546    let mut geoip = GeoIp{
547        ip: "".to_string(),
548        geo_ip: None,
549        active: false,
550    };
551
552    for line in ip.lines() {
553        let mut parts = line.splitn(2, '=');
554        let key = parts.next().unwrap_or("");
555        let value = parts.next().unwrap_or("");
556        match key {
557            "ip" => geoip.ip = value.to_string(),
558            _ => {}
559        }
560    }
561    Ok(geoip)
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// Live check against the Cloudflare trace endpoint; needs network access.
    #[tokio::test]
    async fn test_cloudflare_ip() {
        let geo = get_public_ip_cloudflare().await.unwrap();
        println!("Cloudflare IP: {:?}", geo);
        assert!(!geo.ip.is_empty());
    }
}