snowstorm/client/
mod.rs

1use libp2p::metrics::Registry;
2use log::info;
3use parking_lot::RwLock;
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6use std::future::Future;
7use std::result::Result as StdResult;
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::runtime::{Builder as RuntimeBuilder, Handle as RuntimeHandle, Runtime};
13use tokio::sync::{mpsc, oneshot, watch};
14use tokio::task::{JoinError, JoinHandle};
15use tokio::time::sleep;
16use tokio_util::sync::CancellationToken;
17
18use crate::config::serializable::GeoIp;
19use crate::config::Config;
20use crate::control;
21use crate::geoip::TunnelInfo;
22use crate::identity::Role;
23use crate::net::connected_exits::ConnectedExits;
24use crate::net::event::Event;
25use crate::net::{get_outbound_ips, metrics};
26
27pub mod provider;
28
/// Result alias used by the client module; the error type is always [`ClientError`].
pub type Result<T> = StdResult<T, ClientError>;
30
31#[derive(Error, Debug)]
32pub enum ClientError {
33    #[error("Failed to start SDK")]
34    StartSdkFailed,
35
36    #[error("TUN device not configured")]
37    TunDeviceNotConfigured,
38
39    #[error("Tunnel provider error")]
40    TunnelProviderError,
41
42    #[error("Could not connect to exit")]
43    ExitConnectionFailed,
44
45    #[error("Timeout connecting to exit")]
46    ExitConnectionTimeout,
47
48    #[error("Timeout connecting to exit")]
49    Cancelled,
50
51    #[error(transparent)]
52    Io(#[from] std::io::Error),
53
54    #[error(transparent)]
55    Fmt(#[from] std::fmt::Error),
56
57    #[error(transparent)]
58    TokioJoin(#[from] JoinError),
59}
60
61pub trait TunnelProvider: Sync + Send + 'static {
62    type Device: Send + Unpin + AsyncRead + AsyncWrite + 'static;
63    type State: Send + Sync + Clone + 'static;
64
65    fn get_device(&self) -> Result<Self::Device>;
66    fn state(&self) -> Self::State;
67
68    fn protect_socket_fd(&self, _fd: i32) {}
69}
70
71#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
72pub enum Status {
73    Idle,
74    Starting,
75    Running,
76}
77
78// Format compatible with what FE expects at the moment.
79impl std::fmt::Display for Status {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        match self {
82            Status::Idle => write!(f, "idle"),
83            Status::Starting => write!(f, "starting"),
84            Status::Running => write!(f, "running"),
85        }
86    }
87}
88
89#[derive(Debug, Serialize, Deserialize, Clone)]
90pub struct State {
91    pub id: String,
92    pub status: Status,
93    pub running_since: Option<String>,
94    pub multiaddrs: Vec<String>,
95    pub tunnel_info: TunnelInfo,
96    pub connected_clients: HashSet<String>,
97    pub connected_exits: HashSet<String>,
98    pub roles: Vec<Role>,
99    pub is_offline: bool,
100    pub last_notify_command: i64,
101}
102
103impl Default for State {
104    fn default() -> Self {
105        Self {
106            id: "".into(),
107            status: Status::Idle,
108            multiaddrs: vec![],
109            running_since: None,
110            tunnel_info: TunnelInfo::default(),
111            connected_clients: HashSet::new(),
112            connected_exits: HashSet::new(),
113            roles: vec![],
114            is_offline: false,
115            last_notify_command: i64::MAX,
116        }
117    }
118}
119
/// Where the client's tasks are executed.
enum TokioEnvironment {
    /// A multi-threaded runtime created and owned by the client (`Client::new`).
    OwnRuntime(Runtime),
    /// A handle to a runtime owned by the caller (`Client::new_with_runtime`).
    Handle(RuntimeHandle),
}
124
125pub struct Client<P: TunnelProvider> {
126    tunnel_provider: Arc<Option<P>>,
127    runtime: TokioEnvironment,
128
129    state_rx: watch::Receiver<State>,
130    state_tx: watch::Sender<State>,
131    config: Option<Config>,
132
133    p2p_control: Arc<Mutex<Option<mpsc::Sender<Event>>>>,
134    // TODO: check if RwLock or maybe Mutex with a .take() is better here
135    connected_exits: Arc<RwLock<Option<Arc<ConnectedExits>>>>,
136    tun_shutdown: Mutex<Option<CancellationToken>>,
137    sdk_shutdown: Mutex<Option<CancellationToken>>,
138    sserver_port_rx: watch::Receiver<u16>,
139    sserver_port_tx: watch::Sender<u16>,
140
141    sdk_handle: Option<JoinHandle<()>>,
142    tun_handle: Option<JoinHandle<()>>,
143
144    metrics_registry: Arc<Mutex<Registry>>,
145}
146
147impl<P: TunnelProvider> Client<P> {
148    pub fn new(provider: Option<P>) -> Self {
149        let (state_tx, state_rx) = watch::channel(State::default());
150        let (sserver_port_tx, sserver_port_rx) = watch::channel(64652);
151
152        let runtime = RuntimeBuilder::new_multi_thread()
153            .enable_all()
154            .thread_name("snowstorm-sdk")
155            .build()
156            .unwrap();
157
158        let this = Self {
159            state_rx,
160            state_tx,
161            config: None,
162
163            tunnel_provider: Arc::new(provider),
164            runtime: TokioEnvironment::OwnRuntime(runtime),
165
166            p2p_control: Arc::new(Mutex::new(None)),
167            connected_exits: Arc::new(RwLock::new(None)),
168            tun_shutdown: Mutex::new(None),
169            sdk_shutdown: Mutex::new(None),
170
171            sserver_port_rx,
172            sserver_port_tx,
173
174            sdk_handle: None,
175            tun_handle: None,
176
177            metrics_registry: Default::default(),
178        };
179
180        // For now just check if we're offline when creating the client.
181        this.update_is_offline();
182
183        this
184    }
185
186    pub fn new_with_runtime(provider: Option<P>, handle: RuntimeHandle) -> Self {
187        let (state_tx, state_rx) = watch::channel(State::default());
188        let (sserver_port_tx, sserver_port_rx) = watch::channel(64652);
189
190        Self {
191            state_rx,
192            state_tx,
193            config: None,
194
195            tunnel_provider: Arc::new(provider),
196            runtime: TokioEnvironment::Handle(handle),
197
198            p2p_control: Arc::new(Mutex::new(None)),
199            connected_exits: Arc::new(RwLock::new(None)),
200            tun_shutdown: Mutex::new(None),
201            sdk_shutdown: Mutex::new(None),
202
203            sserver_port_rx,
204            sserver_port_tx,
205
206            sdk_handle: None,
207            tun_handle: None,
208
209            metrics_registry: Default::default(),
210        }
211    }
212
213    pub fn state(&self) -> State {
214        self.state_rx.borrow().clone()
215    }
216
217    pub fn update_origin_ip(&self) {
218        let tx_clone = self.state_tx.clone();
219        self.spawn(async move {
220            if let Ok(ip) = control::get_origin_ip().await {
221                tx_clone.send_modify(|state: &mut State| {
222                    state.tunnel_info.update_origin_ip(Some(ip));
223                });
224            } else {
225                info!("Failed to update origin IP");
226            }
227        });
228    }
229
230    pub fn update_exit_ip(&self) {
231        let tx_clone = self.state_tx.clone();
232        self.spawn(async move {
233            if let Ok(ip) = control::get_exit_ip().await {
234                tx_clone.send_modify(|state: &mut State| {
235                    state.tunnel_info.update_exit_ip(Some(ip));
236                });
237            } else {
238                info!("Failed to update exit IP");
239            }
240        });
241    }
242
243    pub fn update_state<F>(&self, f: F)
244    where
245        F: FnOnce(&mut State),
246    {
247        self.state_tx.send_modify(f);
248    }
249
250    pub fn set_tunnel_provider(&mut self, provider: P) {
251        self.tunnel_provider = Arc::new(Some(provider));
252    }
253
254    pub fn provider(&self) -> Arc<Option<P>> {
255        self.tunnel_provider.clone()
256    }
257
258    pub fn provider_state(&self) -> Option<P::State> {
259        self.provider().as_ref().as_ref().map(|p| p.state())
260    }
261
262    pub fn watch_state(&self) -> watch::Receiver<State> {
263        self.state_rx.clone()
264    }
265
266    pub fn start_sdk(&mut self, mut config: Config) -> Result<oneshot::Receiver<()>> {
267        // Always get outbound addresses.
268        // Useful if we're restarting the SDK during IF changes.
269        let outbound_addrs = get_outbound_ips();
270        config = config.with_outbound_addrs(outbound_addrs.clone());
271
272        // with_protect_socket_fn is relevant on Android since we need to mark our sockets with the OS
273        // otherwise they will be forwarded back into TUN
274        let provider = self.provider();
275        config = config.with_protect_socket_fn(Arc::new(move |fd| {
276            if let Some(provider) = provider.as_ref() {
277                provider.protect_socket_fd(fd);
278            }
279        }));
280
281        // Main handler, spawns a separate tokio task
282        self.config = Some(config.clone());
283        let (done_rx, handle) = self.run(config);
284        self.sdk_handle = Some(handle);
285
286        Ok(done_rx)
287    }
288
289    pub fn stop_sdk(&mut self) -> Result<()> {
290        if let Some(sender) = self.p2p_control.lock().unwrap().as_ref() {
291            // TODO: handle Result
292            let _ = sender.try_send(Event::Stop);
293        }
294
295        match self.sdk_handle.take() {
296            Some(handle) => {
297                self.block_on(handle)?;
298            }
299            _ => {}
300        }
301
302        {
303            let is_offline = self.is_offline();
304
305            self.update_state(|state| {
306                state.running_since = None;
307                state.status = Status::Idle;
308                state.multiaddrs = vec![];
309                state.is_offline = is_offline;
310                state.tunnel_info = TunnelInfo {
311                    exit_ip: None,
312                    origin_ip: state.tunnel_info.origin_ip.clone(),
313                };
314            });
315        }
316
317        Ok(())
318    }
319
320    pub fn restart_sdk(&mut self) -> Result<oneshot::Receiver<Result<TunnelInfo>>> {
321        match self.config {
322            Some(ref config) => {
323                self.reroute(config.clone())
324            }
325            None => {
326                Err(ClientError::StartSdkFailed)
327            }
328        }
329    }
330
331    pub fn reroute(&mut self, config: Config) -> Result<oneshot::Receiver<Result<TunnelInfo>>> {
332        let connected_exits = self.connected_exits.read().clone();
333        let state_tx = self.state_tx.clone();
334        let (done_tx, done_rx) = oneshot::channel::<Result<TunnelInfo>>();
335
336        // Stopping the SDK performs some blocking which causes tokio runtime to panic
337        // To prevent that, we make sure to run it in a blocking fashion on the current thread
338        if tokio::runtime::Handle::try_current().is_ok() {
339            tokio::task::block_in_place(|| self.stop_sdk())?;
340        } else {
341            self.stop_sdk()?;
342        }
343
344        let sdk_done_rx = self.start_sdk(config)?;
345
346        let token = self
347            .sdk_shutdown
348            .try_lock()
349            .unwrap()
350            .clone()
351            .ok_or(ClientError::TunDeviceNotConfigured)?;
352
353        self.spawn(async move {
354            let Some(_) = token.run_until_cancelled(sdk_done_rx).await else {
355                return Err(ClientError::Cancelled);
356            };
357
358            if let Some(connected_exits) = connected_exits {
359                let result = wait_for_exit(connected_exits, token, Duration::from_secs(10)).await;
360
361                match result {
362                    Ok(tunnel_info) => {
363                        state_tx.send_modify(|state| {
364                            state.tunnel_info = tunnel_info.clone();
365                            state.status = Status::Running;
366                            state.running_since =
367                                Some(format!("{}", chrono::Utc::now().timestamp()));
368                        });
369
370                        let _ = done_tx.send(Ok(tunnel_info));
371                    }
372                    Err(e) => {
373                        let _ = done_tx.send(Err(e));
374                    }
375                }
376            }
377
378            Ok(())
379        });
380
381        return Ok(done_rx);
382    }
383
384    pub fn start_tunnel(&mut self) -> Result<oneshot::Receiver<Result<TunnelInfo>>> {
385        let token = CancellationToken::new();
386        let endpoint_token = token.clone();
387        let event_loop_token = token.clone();
388
389        {
390            let mut state_shutdown = self.tun_shutdown.try_lock().unwrap();
391            if !(*state_shutdown).is_none() {
392                (*state_shutdown).take().unwrap().cancel();
393            }
394            *state_shutdown = Some(token);
395        }
396
397        let provider = self.provider();
398        let sserver_port = self.sserver_port_rx.clone();
399        let connected_exits = self.connected_exits.read().clone();
400        let state_tx = self.state_tx.clone();
401        let (done_tx, done_rx) = oneshot::channel::<Result<TunnelInfo>>();
402
403        let handle = self.spawn(async move {
404            let provider = provider.as_ref().as_ref().unwrap();
405            let device = provider.get_device().unwrap();
406
407            crate::net::endpoint::new_client_with_device(device, endpoint_token, sserver_port)
408                .await
409                .unwrap();
410
411            if let Some(connected_exits) = connected_exits {
412                let result = wait_for_exit(
413                    connected_exits,
414                    event_loop_token.clone(),
415                    Duration::from_secs(45),
416                )
417                .await;
418
419                match result {
420                    Ok(tunnel_info) => {
421                        state_tx.send_modify(|state| {
422                            state.tunnel_info = tunnel_info.clone();
423                            state.status = Status::Running;
424                            state.running_since =
425                                Some(format!("{}", chrono::Utc::now().timestamp()));
426                        });
427
428                        let _ = done_tx.send(Ok(tunnel_info));
429                    }
430                    Err(e) => {
431                        let _ = done_tx.send(Err(e));
432                    }
433                }
434            }
435
436            event_loop_token.cancelled().await;
437        });
438
439        self.tun_handle = Some(handle);
440
441        Ok(done_rx)
442    }
443
444    pub fn stop_tunnel(&mut self) -> Result<()> {
445        {
446            let mut state_shutdown = self.tun_shutdown.lock().unwrap();
447            if !(*state_shutdown).is_none() {
448                (*state_shutdown).take().unwrap().cancel();
449            }
450        }
451
452        match self.tun_handle.take() {
453            Some(handle) => {
454                self.block_on(handle)?;
455            }
456            _ => {}
457        }
458
459        {
460            let mut guard = self.connected_exits.write();
461            *guard = None;
462        }
463
464        Ok(())
465    }
466
467    pub fn is_offline(&self) -> bool {
468        // TODO: add a few TCP probes + take current client-connection state into account
469        get_outbound_ips().is_empty()
470    }
471
472    pub fn update_is_offline(&self) {
473        let is_offline = self.is_offline();
474
475        self.update_state(|state| {
476            state.is_offline = is_offline;
477        });
478    }
479
480    fn run(&self, config: Config) -> (oneshot::Receiver<()>, JoinHandle<()>) {
481        let id = config.keypair.public().to_peer_id().to_base58().to_string();
482        let p2p_control = self.p2p_control.clone();
483        let token = CancellationToken::new();
484        let state_tx = self.state_tx.clone();
485        let sserver_port_tx = self.sserver_port_tx.clone();
486        let metrics_registry = self.metrics_registry.clone();
487        let connected_exits_lock = self.connected_exits.clone();
488
489        {
490            let mut token_guard = self.sdk_shutdown.lock().unwrap();
491            *token_guard = Some(token.clone());
492        }
493
494        self.update_is_offline();
495
496        let (done_tx, done_rx) = oneshot::channel::<()>();
497
498        let join_handle = self.spawn(async move {
499            // TODO: improve error handling, error propagation and just less unwraps
500            let sdk_handle = crate::net::start(&config, metrics_registry.clone())
501                .await
502                .unwrap();
503
504            let mut swarm_rx = sdk_handle.subscribe();
505            let connected_exits = sdk_handle.connected_exits.subscribe();
506            let control_tx = sdk_handle.control_tx.clone();
507
508            {
509                let mut guard = connected_exits_lock.write();
510                *guard = Some(sdk_handle.connected_exits)
511            }
512
513            if let Some(port) = sdk_handle.sserver_port {
514                let _ = sserver_port_tx.send(port);
515            }
516
517            {
518                let mut sender = p2p_control.lock().unwrap();
519                *sender = Some(control_tx);
520            }
521
522            {
523                let token = token.clone();
524                let mut connected_exits = connected_exits.clone();
525                let state_tx = state_tx.clone();
526
527                let is_client = config.roles.contains(&Role::Client);
528                tokio::spawn(async move {
529                    let mut ever_connected = false;
530                    loop {
531                        tokio::select! {
532                            _ = token.cancelled() => {
533                                break;
534                            }
535                            Ok(_) = connected_exits.changed() => {
536                                let exits: HashSet<_> = connected_exits.borrow().iter().map(ToString::to_string).collect();
537                                let connected_now = !exits.is_empty();
538                                let mut exit_ip = state_tx.borrow().tunnel_info.exit_ip.clone();
539
540                                if connected_now {
541                                    if let Ok(ip) = control::get_exit_ip().await {
542                                        exit_ip = Some(ip);
543                                    }
544                                }
545
546                                state_tx.send_modify(|state| {
547                                    if is_client && ever_connected {
548                                        if connected_now {
549                                            state.status = Status::Running;
550                                        } else {
551                                            state.status = Status::Starting;
552                                        }
553                                        state.tunnel_info.exit_ip = exit_ip;
554                                    }
555
556                                    state.connected_exits = exits;
557
558                                    if connected_now { 
559                                        ever_connected = true;
560                                    }
561                                });
562                            }
563                        }
564                    }
565                });
566            }
567
568            let _ = done_tx.send(());
569
570            state_tx.send_modify(move |state| {
571                state.running_since = Some(format!("{}", chrono::Utc::now().timestamp()));
572                state.id = id;
573                state.status = Status::Running;
574                state.roles = config.roles.iter().cloned().collect();
575                state.tunnel_info = TunnelInfo {
576                    exit_ip: None,
577                    origin_ip: state.tunnel_info.origin_ip.clone(),
578                };
579            });
580
581            while let Ok(event) = swarm_rx.recv().await {
582                match event {
583                    crate::net::event::Event::Discovered(peer_id) => {
584                        info!("discovered peer: {:?}", peer_id);
585                    }
586                    crate::net::event::Event::ConnectionEstablished(peer_id) => {
587                        info!("connection established: {:?}", peer_id);
588                        let exits = connected_exits
589                            .borrow()
590                            .iter()
591                            .map(ToString::to_string)
592                            .collect();
593                        state_tx.send_modify(move |state| {
594                            state.connected_exits = exits;
595                        });
596                    }
597                    crate::net::event::Event::ListenerConnectionEstablished(peer_id) => {
598                        info!("listener connection established: {:?}", peer_id);
599
600                        state_tx.send_modify(move |state| {
601                            state.connected_clients.insert(peer_id.to_string());
602                        });
603                    }
604                    crate::net::event::Event::DialerConnectionEstablished(
605                        peer_id,
606                        is_snowflake,
607                    ) => {
608                        info!(
609                            "dialer connection established: {:?}, snowflake: {}",
610                            peer_id, is_snowflake
611                        );
612                    }
613                    crate::net::event::Event::ConnectionClosed(peer_id, cause) => {
614                        info!("connection closed: {:?} {:?}", peer_id, cause);
615
616                        let exits = connected_exits
617                            .borrow()
618                            .iter()
619                            .map(ToString::to_string)
620                            .collect();
621                        state_tx.send_modify(move |state| {
622                            state.connected_clients.remove(&peer_id.to_string());
623                            state.connected_exits = exits;
624                        });
625                    }
626                    crate::net::event::Event::NewListenAddr(multiaddr) => {
627                        info!("new listen addr: {:?}", multiaddr);
628                        state_tx.send_modify(move |state| {
629                            state.multiaddrs.push(multiaddr.to_string());
630                        });
631                    }
632                    crate::net::event::Event::NewObservedAddr(multiaddr) => {
633                        info!("new observed addr: {:?}", multiaddr);
634                        state_tx.send_modify(move |state| {
635                            let mut addrs = vec![multiaddr.to_string()];
636                            addrs.append(&mut state.multiaddrs);
637                            addrs.dedup();
638                            state.multiaddrs = addrs;
639                        });
640                    }
641                    crate::net::event::Event::Stop => {
642                        info!("close received, breaking swarm loop");
643
644                        token.cancel();
645
646                        state_tx.send_modify(move |state| {
647                            state.running_since = None;
648                            state.status = Status::Idle;
649                            state.multiaddrs = vec![];
650                            state.connected_clients = HashSet::new();
651                            state.connected_exits = HashSet::new();
652                            state.roles = vec![];
653                        });
654
655                        break;
656                    }
657                    crate::net::event::Event::NotificationReceived(evt) => {
658                        state_tx.send_modify(move |state| {
659                            state.last_notify_command = evt;
660                        });
661                    }
662                    ev => {
663                        info!("event: {:?}", ev);
664                    }
665                }
666            }
667        });
668
669        (done_rx, join_handle)
670    }
671
672    pub fn sample_metrics(&self) -> Result<metrics::DetailedBandwidthSummary> {
673        metrics::sample_snowstorm_bandwidth(&self.metrics_registry.lock().unwrap())
674            .map_err(Into::into)
675    }
676
677    pub fn tokio_handle(&self) -> RuntimeHandle {
678        match self.runtime {
679            TokioEnvironment::OwnRuntime(ref rt) => rt.handle().clone(),
680            TokioEnvironment::Handle(ref handle) => handle.clone(),
681        }
682    }
683
684    pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
685    where
686        F: Future + Send + 'static,
687        F::Output: Send + 'static,
688    {
689        match self.runtime {
690            TokioEnvironment::OwnRuntime(ref rt) => rt.spawn(future),
691            TokioEnvironment::Handle(ref handle) => handle.spawn(future),
692        }
693    }
694
695    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
696        match self.runtime {
697            TokioEnvironment::OwnRuntime(ref rt) => rt.block_on(future),
698            TokioEnvironment::Handle(ref handle) => handle.block_on(future),
699        }
700    }
701}
702
703impl<T: TunnelProvider> Default for Client<T> {
704    fn default() -> Self {
705        Self::new(None)
706    }
707}
708
709async fn with_retries<F, Fut>(n: usize, timeout: Duration, mut f: F) -> Result<GeoIp>
710where
711    F: FnMut() -> Fut,
712    Fut: std::future::Future<
713        Output = std::result::Result<GeoIp, Box<dyn std::error::Error + Send + Sync>>,
714    >,
715{
716    for _ in 0..n {
717        match f().await {
718            Ok(info) => return Ok(info),
719            Err(_) => sleep(timeout).await,
720        }
721    }
722
723    Err(ClientError::ExitConnectionFailed)
724}
725
726pub async fn wait_for_exit(
727    connected_exits: Arc<ConnectedExits>,
728    token: CancellationToken,
729    timeout: Duration,
730) -> Result<TunnelInfo> {
731    let mut exits = connected_exits.subscribe();
732    
733    let fut = async move {
734        let wait_for_exits = exits.wait_for(|e| !e.is_empty());
735        let is_connected = wait_for_exits.await.is_ok();
736
737        if is_connected {
738            // Longer waiting period is useful on windows, since for some reason dns resolution fails during first few seconds
739            // TODO: investigate if we can wait until tun+dns are ready ina  better way
740            Ok(tokio::join!(
741                with_retries(5, Duration::from_millis(2000), control::get_origin_ip),
742                with_retries(5, Duration::from_millis(2000), control::get_exit_ip)
743            ))
744        } else {
745            Err("No exits connected")
746        }
747    };
748
749    let fut = tokio::time::timeout(timeout, tokio::spawn(fut));
750    let fut = token.run_until_cancelled(fut);
751    let connection_result = fut.await;
752
753    let Some(connection_result) = connection_result else {
754        // Cancelled!
755        return Err(ClientError::Cancelled);
756    };
757
758    let Ok(Ok(connection_result)) = connection_result else {
759        // Timeout
760        return Err(ClientError::ExitConnectionTimeout);
761    };
762
763    log::info!("Connection result: {:?}", connection_result);
764
765    match connection_result {
766        Ok((origin_ip, exit_ip)) => {
767            let tunnel_info = TunnelInfo {
768                exit_ip: exit_ip.ok(),
769                origin_ip: origin_ip.ok(),
770            };
771
772            Ok(tunnel_info)
773        }
774        _ => Err(ClientError::ExitConnectionFailed),
775    }
776}