snowstorm/net/connected_exits.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::{Arc, Mutex},
4    time::Duration,
5};
6
7use exponential_backoff::Backoff;
8use futures::FutureExt;
9use libp2p::{swarm::SwarmEvent, Multiaddr, PeerId};
10use rand::seq::IteratorRandom;
11use tokio::{
12    sync::{mpsc, watch},
13    task::JoinHandle,
14};
15use tokio_util::sync::CancellationToken;
16
17use crate::{
18    config::RouteDescriptor,
19    control::get_routes_bypass,
20    net::{behaviour::SnowstormBehaviourEvent, is_snowflake},
21};
22
/// Connection lifecycle of a single exit tracked by [`ConnectedExits`].
///
/// `failed_attempts` counts dial failures recorded by `Exit::connection_failed`.
/// The count is carried from `Disconnected` into `Connecting` when a retry starts
/// and is incremented if that retry fails.
#[derive(Debug, Clone)]
pub enum ExitState {
    /// Registered (via `expect` or discovery) but not dialed yet.
    Known,
    /// A dial for this exit has been scheduled or is in flight.
    Connecting { failed_attempts: u32 },
    /// A dialer-side connection to the exit is established.
    Connected,
    /// Not connected, with the number of dial failures seen so far.
    Disconnected { failed_attempts: u32 },
}
30
31#[derive(Debug, Clone)]
32pub struct Exit {
33    pub peer_id: PeerId,
34    pub route_id: Option<String>,
35    pub addrs: Vec<Multiaddr>,
36
37    state: ExitState,
38}
39
40impl Exit {
41    pub fn new(peer_id: PeerId, route_id: Option<String>, addrs: Vec<Multiaddr>) -> Self {
42        Self {
43            peer_id,
44            route_id,
45            addrs,
46            state: ExitState::Known,
47        }
48    }
49
50    pub fn is_known(&self) -> bool {
51        matches!(self.state, ExitState::Known)
52    }
53
54    pub fn is_connected(&self) -> bool {
55        matches!(self.state, ExitState::Connected)
56    }
57
58    pub fn is_connecting(&self) -> bool {
59        matches!(self.state, ExitState::Connecting { .. })
60    }
61
62    pub fn is_ok_to_reconnect(&self) -> bool {
63        match self.state {
64            ExitState::Known => false,
65            ExitState::Connecting { .. } => false,
66            ExitState::Connected => false,
67            ExitState::Disconnected { failed_attempts } => failed_attempts < 4,
68        }
69    }
70
71    pub fn connected(&mut self) {
72        self.state = ExitState::Connected;
73    }
74
75    pub fn connecting(&mut self) {
76        match self.state {
77            ExitState::Known => {
78                self.state = ExitState::Connecting { failed_attempts: 0 };
79            }
80            ExitState::Disconnected { failed_attempts } => {
81                self.state = ExitState::Connecting { failed_attempts };
82            }
83            ExitState::Connecting { .. } => {
84                log::warn!("Exit is already connecting: {}", self.peer_id);
85            }
86            ExitState::Connected => {
87                log::warn!("Exit is already connected: {}", self.peer_id);
88            }
89        }
90    }
91
92    pub fn disconnected(&mut self) {
93        match self.state {
94            ExitState::Disconnected { .. } => {}
95            _ => {
96                self.state = ExitState::Disconnected { failed_attempts: 0 };
97            }
98        }
99    }
100
101    pub fn connection_failed(&mut self) {
102        log::info!("Exit connection failed: {:?}", self);
103        match self.state {
104            ExitState::Known | ExitState::Connected => {
105                self.state = ExitState::Disconnected { failed_attempts: 1 };
106            }
107            ExitState::Connecting { failed_attempts } => {
108                self.state = ExitState::Disconnected {
109                    failed_attempts: failed_attempts + 1,
110                };
111            }
112            ExitState::Disconnected { .. } => {}
113        }
114    }
115}
116
/// Reasons why `ConnectedExits::next_exit_to_dial` declines to pick an exit.
enum NextExitError {
    /// A previously scheduled dial is still waiting out its backoff delay.
    Sleeping,
    /// Some exit is already in the `Connecting` state.
    Connecting,
    /// At least one exit is already connected.
    Connected,
    /// No exit is eligible: none can be retried and none are left undialed.
    NoMoreExits,
}
123
124pub struct ConnectedExits {
125    connected_tx: watch::Sender<HashSet<PeerId>>,
126    exits: Arc<Mutex<HashMap<PeerId, Exit>>>,
127    dials_tx: mpsc::Sender<(PeerId, Vec<Multiaddr>)>,
128    delay_handle: Mutex<Option<JoinHandle<()>>>,
129    api_token: String,
130    token: CancellationToken,
131    backoff: Mutex<exponential_backoff::IntoIter>,
132}
133
134impl ConnectedExits {
135    fn backoff() -> Backoff {
136        let min = Duration::from_secs(1);
137        let max = Duration::from_secs(120);
138
139        Backoff::new(100, min, max)
140    }
141
142    pub fn new(
143        api_token: String,
144        token: CancellationToken,
145    ) -> (Self, mpsc::Receiver<(PeerId, Vec<Multiaddr>)>) {
146        let (connected_tx, _) = watch::channel(HashSet::new());
147        let (dials_tx, dials_rx) = mpsc::channel(100);
148
149        let this = Self {
150            connected_tx,
151            exits: Arc::new(Mutex::new(HashMap::new())),
152            delay_handle: Mutex::new(None),
153            dials_tx,
154            token,
155            api_token,
156            backoff: Mutex::new(Self::backoff().into_iter()),
157        };
158
159        this.spawn_exit_discovery();
160
161        (this, dials_rx)
162    }
163
164    pub fn expect(&self, peer_id: &PeerId, route_id: Option<&str>, addrs: Vec<Multiaddr>) {
165        let exit = Exit::new(peer_id.clone(), route_id.map(|s| s.to_string()), addrs);
166        self.exits.lock().unwrap().insert(peer_id.clone(), exit);
167    }
168
169    pub fn is_expected(&self) -> bool {
170        !self.exits.lock().unwrap().is_empty()
171    }
172
173    pub fn on_swarm_event(&self, event: &SwarmEvent<SnowstormBehaviourEvent>) {
174        match event {
175            SwarmEvent::ConnectionEstablished {
176                peer_id, endpoint, ..
177            } => {
178                if endpoint.is_dialer() {
179                    let mad = endpoint.get_remote_address();
180
181                    if let Some(exit) = self.exits.lock().unwrap().get_mut(peer_id) {
182                        exit.connected();
183                    }
184
185                    if is_snowflake(mad) {
186                        self.add(peer_id.clone());
187                    } else {
188                        self.connected(peer_id);
189                    }
190                }
191            }
192            SwarmEvent::ConnectionClosed { peer_id, .. } => {
193                self.disconnected(peer_id);
194                if let Some(exit) = self.exits.lock().unwrap().get_mut(peer_id) {
195                    exit.disconnected();
196                }
197                self.queue_next_dial();
198            }
199            SwarmEvent::OutgoingConnectionError { peer_id, .. } => {
200                let Some(peer_id) = peer_id else {
201                    return;
202                };
203
204                if let Some(exit) = self.exits.lock().unwrap().get_mut(peer_id) {
205                    exit.connection_failed();
206                }
207
208                self.queue_next_dial();
209            }
210            _ => {}
211        }
212    }
213
214    pub fn connected(&self, peer_id: &PeerId) {
215        let is_expected = self.exits.lock().unwrap().contains_key(peer_id);
216
217        if is_expected {
218            self.add(peer_id.clone());
219        }
220    }
221
222    pub fn disconnected(&self, peer_id: &PeerId) {
223        self.connected_tx.send_modify(|connected| {
224            connected.remove(peer_id);
225        });
226    }
227
228    pub fn add(&self, peer_id: PeerId) {
229        self.connected_tx.send_modify(|connected| {
230            connected.insert(peer_id);
231        });
232    }
233
234    pub fn subscribe(&self) -> watch::Receiver<HashSet<PeerId>> {
235        self.connected_tx.subscribe()
236    }
237
238    pub async fn pick(&self) -> (Option<PeerId>, Option<String>) {
239        let mut rx = self.connected_tx.subscribe();
240
241        // Pick random from the current available ones, otherwise wait until something is added with a 2s timeout
242        let mut connected = rx.borrow().clone();
243
244        if connected.is_empty() {
245            let f = rx.wait_for(|c| !c.is_empty());
246            let timeout = tokio::time::timeout(std::time::Duration::from_secs(2), f);
247
248            match timeout.await {
249                Ok(Ok(r)) => {
250                    connected = r.clone();
251                }
252                Ok(_) => {
253                    return (None, None);
254                }
255                Err(_) => {
256                    log::error!("Timeout waiting for exit");
257                    return (None, None);
258                }
259            }
260        }
261
262        log::info!("connected exits: {:?}", connected);
263
264        let peer_id = connected.iter().choose(&mut rand::rng()).cloned();
265        let route_id = peer_id.and_then(|peer_id| {
266            let exits = self.exits.lock().unwrap();
267            let exit = exits.get(&peer_id)?;
268
269            exit.route_id.clone()
270        });
271
272        (peer_id, route_id)
273    }
274
275    fn next_exit_to_dial(&self) -> Result<(PeerId, Vec<Multiaddr>), NextExitError> {
276        let is_waiting = self
277            .delay_handle
278            .lock()
279            .unwrap()
280            .as_ref()
281            .map(|h| !h.is_finished())
282            .unwrap_or(false);
283
284        if is_waiting {
285            return Err(NextExitError::Sleeping);
286        }
287
288        let mut exits = self.exits.lock().unwrap();
289
290        let is_connected = !self.connected_tx.borrow().is_empty();
291
292        // 1. Nothing to do if we are connected to at least one exit
293        if is_connected {
294            log::info!("Already connected to at least one exit, no need to dial another one");
295            return Err(NextExitError::Connected);
296        }
297
298        let is_connecting = exits.values().any(|exit| exit.is_connecting());
299
300        // 2. Nothing to do if we are trying to connect right now to an exit
301        if is_connecting {
302            log::info!(
303                "Already trying to connect to an exit, waiting for connection to be established"
304            );
305            return Err(NextExitError::Connecting);
306        }
307
308        // 3. Pick a failed exit with less then X failed attempts
309        if let Some(retry_exit) = exits.values_mut().find(|exit| exit.is_ok_to_reconnect()) {
310            log::info!(
311                "Dialing previously failed exit, still ok to retry: {}",
312                retry_exit.peer_id
313            );
314
315            retry_exit.connecting();
316
317            return Ok((retry_exit.peer_id.clone(), retry_exit.addrs.clone()));
318        }
319
320        // 4. Otherwise a known exit
321        if let Some(known_exit) = exits.values_mut().find(|exit| exit.is_known()) {
322            log::info!("Dialing known exit: {}", known_exit.peer_id);
323
324            known_exit.connecting();
325
326            return Ok((known_exit.peer_id.clone(), known_exit.addrs.clone()));
327        }
328
329        return Err(NextExitError::NoMoreExits);
330    }
331
332    fn queue_next_dial(&self) {
333        match self.next_exit_to_dial() {
334            Ok((peer_id, addrs)) => {
335                let delay = {
336                    let mut backoff = self.backoff.lock().unwrap();
337                    let Some(Some(delay)) = backoff.next() else {
338                        log::warn!("No more backoff delay available, giving up on dialing exits");
339                        return;
340                    };
341
342                    delay
343                };
344
345                let dials_tx = self.dials_tx.clone();
346                let token = self.token.clone();
347                let handle = tokio::spawn(async move {
348                    log::info!("Scheduling next dial for {} in {:?}", peer_id, delay);
349
350                    tokio::select! {
351                        _ = token.cancelled() => {
352                            log::info!("Cancellation token triggered, stopping next dial");
353                            return;
354                        }
355                        _ = tokio::time::sleep(delay) => {
356                            if let Err(e) = dials_tx.try_send((peer_id, addrs)) {
357                                log::error!("Failed to queue next dial: {}", e);
358                            }
359                        }
360                    }
361
362                    ()
363                });
364
365                self.delay_handle.lock().unwrap().replace(handle);
366            }
367            Err(NextExitError::NoMoreExits) => {
368                // TODO: schedule exit discovery?
369            }
370            _ => {}
371        }
372    }
373
374    fn spawn_exit_discovery(&self) {
375        let token = self.token.clone();
376        let api_token = self.api_token.clone();
377        let exits = self.exits.clone();
378        let mut connected_rx = self.subscribe();
379
380        // We want to trigger discovery. After we connect to an exit and then get disconnected
381        // refresh the exits, add new ones, wait for connected_rx empty + at least 60 seconds
382        // Make sure to respect the cancellation token
383        tokio::spawn(async move {
384            log::info!("Starting exit discovery... Waiting for the first connection");
385
386            tokio::select! {
387                _ = token.cancelled() => {
388                    log::info!("Cancellation token triggered, stopping exit discovery");
389                    return;
390                }
391                // Connected for the first time, now we can wait for changes
392                _ = connected_rx.wait_for(|connected| !connected.is_empty()) => {}
393            }
394
395            log::info!("Discovering exits...");
396
397            // No delay on the first run
398            let mut delay = futures::future::ready::<()>(()).boxed();
399
400            loop {
401                let waiting_disconnected = connected_rx
402                    .wait_for(|connected| connected.is_empty())
403                    .map(|_| ())
404                    .boxed();
405                let waiting_with_delay = futures::future::join(delay, waiting_disconnected);
406
407                tokio::select! {
408                    _ = token.cancelled() => {
409                        log::info!("Cancellation token triggered, stopping exit discovery");
410                        return;
411                    }
412                    _ = waiting_with_delay => {
413                        log::info!("Connected exits changed, refreshing exits");
414
415                        let routes = get_routes_bypass(api_token.as_ref()).await;
416                        log::info!("Routes fetched: {:?}", routes);
417
418                        if let Ok(routes) = routes {
419                            log::info!("Discovered new routes: {:?}", routes);
420
421                            let mut exits = exits.lock().unwrap();
422                            for route in routes {
423                                let Ok(route) = RouteDescriptor::try_from(route) else {
424                                    continue;
425                                };
426
427                                let Some(first_hop) = route.path.first() else {
428                                    continue;
429                                };
430
431                                let Some(addr_info) = first_hop.addr_info.first() else {
432                                    continue;
433                                };
434
435                                let exit = Exit::new(addr_info.peer_id.clone(), Some(route.id.clone()), addr_info.multiaddrs.clone());
436                                log::info!("Adding exit: {:?}", exit);
437                                exits.insert(addr_info.peer_id.clone(),exit);
438                            }
439                        };
440
441                        // Next time update not sooner than 60 seconds from this try
442                        delay = tokio::time::sleep(Duration::from_secs(60)).map(|_| ()).boxed();
443                    }
444                }
445            }
446        });
447    }
448}