snowstorm/net/swarm/
auth.rs

1use std::{future::Future, pin::Pin};
2
3use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt};
4use libp2p::{
5    core::{
6        upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade},
7        UpgradeInfo,
8    },
9    PeerId,
10};
11
12use crate::control;
13
14#[derive(thiserror::Error, Debug)]
15pub enum SnowstormAuthError {
16    #[error("No auth provided")]
17    NoAuth,
18
19    #[error("Invalid input")]
20    InvalidInput,
21
22    #[error("Access to route {route_id} denied for {peer_id}")]
23    AccessDenied { route_id: String, peer_id: PeerId },
24}
25
26#[derive(Clone)]
27pub struct SnowstormAuth {
28    // TODO: remove once start using for validation!
29    #[allow(dead_code)]
30    peer_id: PeerId,
31
32    route_id: Option<String>,
33
34    token: String,
35
36    route_auth_enabled: bool,
37}
38
39impl SnowstormAuth {
40    pub fn new(peer_id: PeerId, token: String, route_id: Option<String>) -> Self {
41        Self {
42            peer_id,
43            token,
44            route_id,
45            route_auth_enabled: false,
46        }
47    }
48
49    pub fn with_route_auth_enabled(mut self, force: bool) -> Self {
50        if force {
51            self.route_auth_enabled = force;
52            self
53        } else {
54            self
55        }
56    }
57}
58
59impl UpgradeInfo for SnowstormAuth {
60    type Info = &'static str;
61    type InfoIter = std::iter::Once<Self::Info>;
62
63    fn protocol_info(&self) -> Self::InfoIter {
64        std::iter::once("/snowstorm-auth")
65    }
66}
67
68const AUTH_MARKER: &[u8] = b"ssa";
69
70async fn read_auth_header<T>(io: &mut T) -> Result<String, SnowstormAuthError>
71where
72    T: AsyncRead + Unpin + Send + 'static,
73{
74    let mut marker_slice = [0u8; AUTH_MARKER.len()];
75    io.read_exact(&mut marker_slice)
76        .await
77        .map_err(|_| SnowstormAuthError::NoAuth)?;
78
79    if marker_slice != AUTH_MARKER {
80        return Err(SnowstormAuthError::NoAuth);
81    }
82
83    let mut length_byte = [0u8; 1];
84    io.read_exact(&mut length_byte)
85        .await
86        .map_err(|_| SnowstormAuthError::InvalidInput)?;
87    let length = length_byte[0] as usize;
88    if length > 64 {
89        return Err(SnowstormAuthError::InvalidInput);
90    }
91
92    let mut route_id = vec![0u8; length];
93    io.read_exact(&mut route_id)
94        .await
95        .map_err(|_| SnowstormAuthError::InvalidInput)?;
96
97    String::from_utf8(route_id).map_err(|_| SnowstormAuthError::InvalidInput)
98}
99
100impl<T> InboundConnectionUpgrade<T> for SnowstormAuth
101where
102    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
103{
104    type Output = T;
105    type Error = SnowstormAuthError;
106    type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
107
108    fn upgrade_inbound(self, mut io: T, _: Self::Info) -> Self::Future {
109        async move {
110            let route_id = read_auth_header(&mut io).await?;
111
112            if !self.route_auth_enabled {
113                return Ok(io);
114            }
115
116            // Auth check, via yukigo
117            match control::validate_route(&self.token, &route_id).await {
118                Ok(true) => Ok(io),
119                _ => Err(SnowstormAuthError::AccessDenied {
120                    route_id,
121                    peer_id: self.peer_id.clone(),
122                }),
123            }
124        }
125        .boxed()
126    }
127}
128
129impl<T> OutboundConnectionUpgrade<T> for SnowstormAuth
130where
131    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
132{
133    type Output = T;
134    type Error = std::io::Error;
135    type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
136
137    fn upgrade_outbound(self, mut io: T, _: Self::Info) -> Self::Future {
138        async move {
139            if let Some(route_id) = &self.route_id {
140                io.write(AUTH_MARKER).await?;
141                io.write(&[route_id.len() as u8]).await?;
142                io.write(route_id.as_bytes()).await?;
143            }
144            Ok(io)
145        }
146        .boxed()
147    }
148}