// snowstorm/net/netif.rs

1use std::future::Future;
2use std::net::IpAddr;
3
4use futures::future::BoxFuture;
5use futures::stream::FusedStream;
6use futures::Stream;
7use if_watch::tokio::IfWatcher;
8use pin_project::pin_project;
9use std::pin::Pin;
10
11use crate::net::get_default_interface;
12
13pub struct Watcher {
14    outbound_addrs: Vec<IpAddr>,
15    inner: IfWatcher,
16}
17
18impl Watcher {
19    pub fn new() -> std::io::Result<Self> {
20        Ok(Watcher {
21            outbound_addrs: get_outbound_ips(),
22            inner: IfWatcher::new()?,
23        })
24    }
25}
26
27impl Stream for Watcher {
28    type Item = Vec<IpAddr>;
29
30    fn poll_next(
31        self: std::pin::Pin<&mut Self>,
32        cx: &mut std::task::Context<'_>,
33    ) -> std::task::Poll<Option<Self::Item>> {
34        let this = self.get_mut();
35
36        // We're looping in order to be able to poll multiple times
37        // Since we potentially need to ignore some events from inner
38        loop {
39            match Pin::new(&mut this.inner).poll_next(cx) {
40                std::task::Poll::Ready(Some(_event)) => {
41                    let new_addrs = get_outbound_ips();
42
43                    if new_addrs != this.outbound_addrs {
44                        this.outbound_addrs = new_addrs.clone();
45
46                        return std::task::Poll::Ready(Some(new_addrs));
47                    } else {
48                        continue;
49                    }
50                }
51                std::task::Poll::Ready(None) => return std::task::Poll::Ready(None),
52                std::task::Poll::Pending => return std::task::Poll::Pending,
53            }
54        }
55    }
56}
57
58impl FusedStream for Watcher {
59    fn is_terminated(&self) -> bool {
60        self.inner.is_terminated()
61    }
62}
63
64#[pin_project]
65pub struct DebouncedWatcher {
66    inner: Watcher,
67    debounce_duration: std::time::Duration,
68
69    #[pin]
70    pending_value: Option<BoxFuture<'static, Vec<IpAddr>>>,
71}
72
73impl DebouncedWatcher {
74    pub fn new(
75        debounce_duration: std::time::Duration,
76    ) -> std::io::Result<Self> {
77        Ok(DebouncedWatcher {
78            inner: Watcher::new()?,
79            debounce_duration,
80            pending_value: None,
81        })
82    }
83}
84
85impl Stream for DebouncedWatcher {
86    type Item = Vec<IpAddr>;
87
88    fn poll_next(
89        self: Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91    ) -> std::task::Poll<Option<Self::Item>> {
92        let mut this = self.project();
93        let debounce_duration = *this.debounce_duration;
94
95        match Pin::new(&mut this.inner).poll_next(cx) {
96            std::task::Poll::Ready(Some(value)) => {
97                // Set a new pending value with the debounce duration
98                this.pending_value.set(Some(Box::pin(async move {
99                    tokio::time::sleep(debounce_duration).await;
100                    value
101                })));
102
103                // Notify the waker to poll us again now that we have a pending value
104                cx.waker().wake_by_ref();
105
106                return std::task::Poll::Pending;
107            }
108            std::task::Poll::Ready(None) => {
109                // If we still have a pending value -> make sure to keep polling for it
110                // Otherwise it's safe to also terminate the stream
111                if this.pending_value.is_none() {
112                    return std::task::Poll::Ready(None);
113                }
114            }
115            std::task::Poll::Pending => {}
116        }
117
118        if let Some(pending) = this.pending_value.as_mut().as_pin_mut() {
119            match pending.poll(cx) {
120                std::task::Poll::Ready(value) => {
121                    this.pending_value.set(None);
122                    return std::task::Poll::Ready(Some(value));
123                }
124                std::task::Poll::Pending => {}
125            }
126        }
127
128        std::task::Poll::Pending
129    }
130}
131
132impl FusedStream for DebouncedWatcher {
133    fn is_terminated(&self) -> bool {
134        self.pending_value.is_none() && self.inner.is_terminated()
135    }
136}
137
138fn get_outbound_ips() -> Vec<IpAddr> {
139    let default_interface = get_default_interface();
140
141    match default_interface {
142        Some(iface) => {
143            let mut addrs: Vec<IpAddr> = vec![];
144
145            if let Some(ipv4) = iface.ipv4.first() {
146                addrs.push(ipv4.addr().clone().into());
147            }
148
149            if let Some(ipv6) = iface
150                .ipv6
151                .into_iter()
152                .find(|net| !net.addr().is_unicast_link_local())
153            {
154                addrs.push(ipv6.addr().clone().into());
155            }
156
157            addrs
158        }
159        None => Vec::new(),
160    }
161}