1use std::{
2 collections::{HashMap, HashSet},
3 sync::{Arc, Mutex},
4 time::Duration,
5};
6
7use exponential_backoff::Backoff;
8use futures::FutureExt;
9use libp2p::{swarm::SwarmEvent, Multiaddr, PeerId};
10use rand::seq::IteratorRandom;
11use tokio::{
12 sync::{mpsc, watch},
13 task::JoinHandle,
14};
15use tokio_util::sync::CancellationToken;
16
17use crate::{
18 config::RouteDescriptor,
19 control::get_routes_bypass,
20 net::{behaviour::SnowstormBehaviourEvent, is_snowflake},
21};
22
/// Connection state of a single exit peer.
///
/// All payloads are plain counters, so the state is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExitState {
    /// Registered (or re-registered by discovery) and not dialed since.
    Known,
    /// A dial is in progress; `failed_attempts` counts earlier failed dials.
    Connecting { failed_attempts: u32 },
    /// A connection that this node dialed is established.
    Connected,
    /// Not connected; `failed_attempts` counts failed dials recorded since the exit
    /// was registered or last connected.
    Disconnected { failed_attempts: u32 },
}
30
/// An exit peer this node may dial, together with its addresses and connection state.
#[derive(Debug, Clone)]
pub struct Exit {
    /// Peer id of the exit.
    pub peer_id: PeerId,
    /// Route id associated with this exit, if any (handed out by `ConnectedExits::pick`).
    pub route_id: Option<String>,
    /// Addresses sent along with dial requests for this exit.
    pub addrs: Vec<Multiaddr>,

    /// Current connection state. Private: transitions happen through `connecting`,
    /// `connected`, `disconnected` and `connection_failed`.
    state: ExitState,
}
39
40impl Exit {
41 pub fn new(peer_id: PeerId, route_id: Option<String>, addrs: Vec<Multiaddr>) -> Self {
42 Self {
43 peer_id,
44 route_id,
45 addrs,
46 state: ExitState::Known,
47 }
48 }
49
50 pub fn is_known(&self) -> bool {
51 matches!(self.state, ExitState::Known)
52 }
53
54 pub fn is_connected(&self) -> bool {
55 matches!(self.state, ExitState::Connected)
56 }
57
58 pub fn is_connecting(&self) -> bool {
59 matches!(self.state, ExitState::Connecting { .. })
60 }
61
62 pub fn is_ok_to_reconnect(&self) -> bool {
63 match self.state {
64 ExitState::Known => false,
65 ExitState::Connecting { .. } => false,
66 ExitState::Connected => false,
67 ExitState::Disconnected { failed_attempts } => failed_attempts < 4,
68 }
69 }
70
71 pub fn connected(&mut self) {
72 self.state = ExitState::Connected;
73 }
74
75 pub fn connecting(&mut self) {
76 match self.state {
77 ExitState::Known => {
78 self.state = ExitState::Connecting { failed_attempts: 0 };
79 }
80 ExitState::Disconnected { failed_attempts } => {
81 self.state = ExitState::Connecting { failed_attempts };
82 }
83 ExitState::Connecting { .. } => {
84 log::warn!("Exit is already connecting: {}", self.peer_id);
85 }
86 ExitState::Connected => {
87 log::warn!("Exit is already connected: {}", self.peer_id);
88 }
89 }
90 }
91
92 pub fn disconnected(&mut self) {
93 match self.state {
94 ExitState::Disconnected { .. } => {}
95 _ => {
96 self.state = ExitState::Disconnected { failed_attempts: 0 };
97 }
98 }
99 }
100
101 pub fn connection_failed(&mut self) {
102 log::info!("Exit connection failed: {:?}", self);
103 match self.state {
104 ExitState::Known | ExitState::Connected => {
105 self.state = ExitState::Disconnected { failed_attempts: 1 };
106 }
107 ExitState::Connecting { failed_attempts } => {
108 self.state = ExitState::Disconnected {
109 failed_attempts: failed_attempts + 1,
110 };
111 }
112 ExitState::Disconnected { .. } => {}
113 }
114 }
115}
116
/// Reason why `ConnectedExits` decided not to dial an exit right now.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NextExitError {
    /// A delayed dial is already scheduled and has not fired yet.
    Sleeping,
    /// An exit is currently being dialed.
    Connecting,
    /// At least one peer is already in the connected set.
    Connected,
    /// No exit is eligible for dialing.
    NoMoreExits,
}
123
124pub struct ConnectedExits {
125 connected_tx: watch::Sender<HashSet<PeerId>>,
126 exits: Arc<Mutex<HashMap<PeerId, Exit>>>,
127 dials_tx: mpsc::Sender<(PeerId, Vec<Multiaddr>)>,
128 delay_handle: Mutex<Option<JoinHandle<()>>>,
129 api_token: String,
130 token: CancellationToken,
131 backoff: Mutex<exponential_backoff::IntoIter>,
132}
133
134impl ConnectedExits {
135 fn backoff() -> Backoff {
136 let min = Duration::from_secs(1);
137 let max = Duration::from_secs(120);
138
139 Backoff::new(100, min, max)
140 }
141
142 pub fn new(
143 api_token: String,
144 token: CancellationToken,
145 ) -> (Self, mpsc::Receiver<(PeerId, Vec<Multiaddr>)>) {
146 let (connected_tx, _) = watch::channel(HashSet::new());
147 let (dials_tx, dials_rx) = mpsc::channel(100);
148
149 let this = Self {
150 connected_tx,
151 exits: Arc::new(Mutex::new(HashMap::new())),
152 delay_handle: Mutex::new(None),
153 dials_tx,
154 token,
155 api_token,
156 backoff: Mutex::new(Self::backoff().into_iter()),
157 };
158
159 this.spawn_exit_discovery();
160
161 (this, dials_rx)
162 }
163
164 pub fn expect(&self, peer_id: &PeerId, route_id: Option<&str>, addrs: Vec<Multiaddr>) {
165 let exit = Exit::new(peer_id.clone(), route_id.map(|s| s.to_string()), addrs);
166 self.exits.lock().unwrap().insert(peer_id.clone(), exit);
167 }
168
169 pub fn is_expected(&self) -> bool {
170 !self.exits.lock().unwrap().is_empty()
171 }
172
173 pub fn on_swarm_event(&self, event: &SwarmEvent<SnowstormBehaviourEvent>) {
174 match event {
175 SwarmEvent::ConnectionEstablished {
176 peer_id, endpoint, ..
177 } => {
178 if endpoint.is_dialer() {
179 let mad = endpoint.get_remote_address();
180
181 if let Some(exit) = self.exits.lock().unwrap().get_mut(peer_id) {
182 exit.connected();
183 }
184
185 if is_snowflake(mad) {
186 self.add(peer_id.clone());
187 } else {
188 self.connected(peer_id);
189 }
190 }
191 }
192 SwarmEvent::ConnectionClosed { peer_id, .. } => {
193 self.disconnected(peer_id);
194 if let Some(exit) = self.exits.lock().unwrap().get_mut(peer_id) {
195 exit.disconnected();
196 }
197 self.queue_next_dial();
198 }
199 SwarmEvent::OutgoingConnectionError { peer_id, .. } => {
200 let Some(peer_id) = peer_id else {
201 return;
202 };
203
204 if let Some(exit) = self.exits.lock().unwrap().get_mut(peer_id) {
205 exit.connection_failed();
206 }
207
208 self.queue_next_dial();
209 }
210 _ => {}
211 }
212 }
213
214 pub fn connected(&self, peer_id: &PeerId) {
215 let is_expected = self.exits.lock().unwrap().contains_key(peer_id);
216
217 if is_expected {
218 self.add(peer_id.clone());
219 }
220 }
221
222 pub fn disconnected(&self, peer_id: &PeerId) {
223 self.connected_tx.send_modify(|connected| {
224 connected.remove(peer_id);
225 });
226 }
227
228 pub fn add(&self, peer_id: PeerId) {
229 self.connected_tx.send_modify(|connected| {
230 connected.insert(peer_id);
231 });
232 }
233
234 pub fn subscribe(&self) -> watch::Receiver<HashSet<PeerId>> {
235 self.connected_tx.subscribe()
236 }
237
238 pub async fn pick(&self) -> (Option<PeerId>, Option<String>) {
239 let mut rx = self.connected_tx.subscribe();
240
241 let mut connected = rx.borrow().clone();
243
244 if connected.is_empty() {
245 let f = rx.wait_for(|c| !c.is_empty());
246 let timeout = tokio::time::timeout(std::time::Duration::from_secs(2), f);
247
248 match timeout.await {
249 Ok(Ok(r)) => {
250 connected = r.clone();
251 }
252 Ok(_) => {
253 return (None, None);
254 }
255 Err(_) => {
256 log::error!("Timeout waiting for exit");
257 return (None, None);
258 }
259 }
260 }
261
262 log::info!("connected exits: {:?}", connected);
263
264 let peer_id = connected.iter().choose(&mut rand::rng()).cloned();
265 let route_id = peer_id.and_then(|peer_id| {
266 let exits = self.exits.lock().unwrap();
267 let exit = exits.get(&peer_id)?;
268
269 exit.route_id.clone()
270 });
271
272 (peer_id, route_id)
273 }
274
275 fn next_exit_to_dial(&self) -> Result<(PeerId, Vec<Multiaddr>), NextExitError> {
276 let is_waiting = self
277 .delay_handle
278 .lock()
279 .unwrap()
280 .as_ref()
281 .map(|h| !h.is_finished())
282 .unwrap_or(false);
283
284 if is_waiting {
285 return Err(NextExitError::Sleeping);
286 }
287
288 let mut exits = self.exits.lock().unwrap();
289
290 let is_connected = !self.connected_tx.borrow().is_empty();
291
292 if is_connected {
294 log::info!("Already connected to at least one exit, no need to dial another one");
295 return Err(NextExitError::Connected);
296 }
297
298 let is_connecting = exits.values().any(|exit| exit.is_connecting());
299
300 if is_connecting {
302 log::info!(
303 "Already trying to connect to an exit, waiting for connection to be established"
304 );
305 return Err(NextExitError::Connecting);
306 }
307
308 if let Some(retry_exit) = exits.values_mut().find(|exit| exit.is_ok_to_reconnect()) {
310 log::info!(
311 "Dialing previously failed exit, still ok to retry: {}",
312 retry_exit.peer_id
313 );
314
315 retry_exit.connecting();
316
317 return Ok((retry_exit.peer_id.clone(), retry_exit.addrs.clone()));
318 }
319
320 if let Some(known_exit) = exits.values_mut().find(|exit| exit.is_known()) {
322 log::info!("Dialing known exit: {}", known_exit.peer_id);
323
324 known_exit.connecting();
325
326 return Ok((known_exit.peer_id.clone(), known_exit.addrs.clone()));
327 }
328
329 return Err(NextExitError::NoMoreExits);
330 }
331
332 fn queue_next_dial(&self) {
333 match self.next_exit_to_dial() {
334 Ok((peer_id, addrs)) => {
335 let delay = {
336 let mut backoff = self.backoff.lock().unwrap();
337 let Some(Some(delay)) = backoff.next() else {
338 log::warn!("No more backoff delay available, giving up on dialing exits");
339 return;
340 };
341
342 delay
343 };
344
345 let dials_tx = self.dials_tx.clone();
346 let token = self.token.clone();
347 let handle = tokio::spawn(async move {
348 log::info!("Scheduling next dial for {} in {:?}", peer_id, delay);
349
350 tokio::select! {
351 _ = token.cancelled() => {
352 log::info!("Cancellation token triggered, stopping next dial");
353 return;
354 }
355 _ = tokio::time::sleep(delay) => {
356 if let Err(e) = dials_tx.try_send((peer_id, addrs)) {
357 log::error!("Failed to queue next dial: {}", e);
358 }
359 }
360 }
361
362 ()
363 });
364
365 self.delay_handle.lock().unwrap().replace(handle);
366 }
367 Err(NextExitError::NoMoreExits) => {
368 }
370 _ => {}
371 }
372 }
373
374 fn spawn_exit_discovery(&self) {
375 let token = self.token.clone();
376 let api_token = self.api_token.clone();
377 let exits = self.exits.clone();
378 let mut connected_rx = self.subscribe();
379
380 tokio::spawn(async move {
384 log::info!("Starting exit discovery... Waiting for the first connection");
385
386 tokio::select! {
387 _ = token.cancelled() => {
388 log::info!("Cancellation token triggered, stopping exit discovery");
389 return;
390 }
391 _ = connected_rx.wait_for(|connected| !connected.is_empty()) => {}
393 }
394
395 log::info!("Discovering exits...");
396
397 let mut delay = futures::future::ready::<()>(()).boxed();
399
400 loop {
401 let waiting_disconnected = connected_rx
402 .wait_for(|connected| connected.is_empty())
403 .map(|_| ())
404 .boxed();
405 let waiting_with_delay = futures::future::join(delay, waiting_disconnected);
406
407 tokio::select! {
408 _ = token.cancelled() => {
409 log::info!("Cancellation token triggered, stopping exit discovery");
410 return;
411 }
412 _ = waiting_with_delay => {
413 log::info!("Connected exits changed, refreshing exits");
414
415 let routes = get_routes_bypass(api_token.as_ref()).await;
416 log::info!("Routes fetched: {:?}", routes);
417
418 if let Ok(routes) = routes {
419 log::info!("Discovered new routes: {:?}", routes);
420
421 let mut exits = exits.lock().unwrap();
422 for route in routes {
423 let Ok(route) = RouteDescriptor::try_from(route) else {
424 continue;
425 };
426
427 let Some(first_hop) = route.path.first() else {
428 continue;
429 };
430
431 let Some(addr_info) = first_hop.addr_info.first() else {
432 continue;
433 };
434
435 let exit = Exit::new(addr_info.peer_id.clone(), Some(route.id.clone()), addr_info.multiaddrs.clone());
436 log::info!("Adding exit: {:?}", exit);
437 exits.insert(addr_info.peer_id.clone(),exit);
438 }
439 };
440
441 delay = tokio::time::sleep(Duration::from_secs(60)).map(|_| ()).boxed();
443 }
444 }
445 }
446 });
447 }
448}