snowstorm/net/swarm/
auth.rs1use std::{future::Future, pin::Pin};
2
3use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt};
4use libp2p::{
5 core::{
6 upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade},
7 UpgradeInfo,
8 },
9 PeerId,
10};
11
12use crate::control;
13
14#[derive(thiserror::Error, Debug)]
15pub enum SnowstormAuthError {
16 #[error("No auth provided")]
17 NoAuth,
18
19 #[error("Invalid input")]
20 InvalidInput,
21
22 #[error("Access to route {route_id} denied for {peer_id}")]
23 AccessDenied { route_id: String, peer_id: PeerId },
24}
25
26#[derive(Clone)]
27pub struct SnowstormAuth {
28 #[allow(dead_code)]
30 peer_id: PeerId,
31
32 route_id: Option<String>,
33
34 token: String,
35
36 route_auth_enabled: bool,
37}
38
39impl SnowstormAuth {
40 pub fn new(peer_id: PeerId, token: String, route_id: Option<String>) -> Self {
41 Self {
42 peer_id,
43 token,
44 route_id,
45 route_auth_enabled: false,
46 }
47 }
48
49 pub fn with_route_auth_enabled(mut self, force: bool) -> Self {
50 if force {
51 self.route_auth_enabled = force;
52 self
53 } else {
54 self
55 }
56 }
57}
58
59impl UpgradeInfo for SnowstormAuth {
60 type Info = &'static str;
61 type InfoIter = std::iter::Once<Self::Info>;
62
63 fn protocol_info(&self) -> Self::InfoIter {
64 std::iter::once("/snowstorm-auth")
65 }
66}
67
/// Byte sequence that opens the auth header on the wire.
const AUTH_MARKER: &[u8] = b"ssa";
69
70async fn read_auth_header<T>(io: &mut T) -> Result<String, SnowstormAuthError>
71where
72 T: AsyncRead + Unpin + Send + 'static,
73{
74 let mut marker_slice = [0u8; AUTH_MARKER.len()];
75 io.read_exact(&mut marker_slice)
76 .await
77 .map_err(|_| SnowstormAuthError::NoAuth)?;
78
79 if marker_slice != AUTH_MARKER {
80 return Err(SnowstormAuthError::NoAuth);
81 }
82
83 let mut length_byte = [0u8; 1];
84 io.read_exact(&mut length_byte)
85 .await
86 .map_err(|_| SnowstormAuthError::InvalidInput)?;
87 let length = length_byte[0] as usize;
88 if length > 64 {
89 return Err(SnowstormAuthError::InvalidInput);
90 }
91
92 let mut route_id = vec![0u8; length];
93 io.read_exact(&mut route_id)
94 .await
95 .map_err(|_| SnowstormAuthError::InvalidInput)?;
96
97 String::from_utf8(route_id).map_err(|_| SnowstormAuthError::InvalidInput)
98}
99
100impl<T> InboundConnectionUpgrade<T> for SnowstormAuth
101where
102 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
103{
104 type Output = T;
105 type Error = SnowstormAuthError;
106 type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
107
108 fn upgrade_inbound(self, mut io: T, _: Self::Info) -> Self::Future {
109 async move {
110 let route_id = read_auth_header(&mut io).await?;
111
112 if !self.route_auth_enabled {
113 return Ok(io);
114 }
115
116 match control::validate_route(&self.token, &route_id).await {
118 Ok(true) => Ok(io),
119 _ => Err(SnowstormAuthError::AccessDenied {
120 route_id,
121 peer_id: self.peer_id.clone(),
122 }),
123 }
124 }
125 .boxed()
126 }
127}
128
129impl<T> OutboundConnectionUpgrade<T> for SnowstormAuth
130where
131 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
132{
133 type Output = T;
134 type Error = std::io::Error;
135 type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
136
137 fn upgrade_outbound(self, mut io: T, _: Self::Info) -> Self::Future {
138 async move {
139 if let Some(route_id) = &self.route_id {
140 io.write(AUTH_MARKER).await?;
141 io.write(&[route_id.len() as u8]).await?;
142 io.write(route_id.as_bytes()).await?;
143 }
144 Ok(io)
145 }
146 .boxed()
147 }
148}