mirror of
https://github.com/shiroyashik/sculptor.git
synced 2025-12-06 04:51:13 +03:00
243 lines
No EOL
11 KiB
Rust
243 lines
No EOL
11 KiB
Rust
use anyhow::bail;
|
|
use axum::{body::Bytes, extract::{ws::{Message, WebSocket}, State}};
|
|
use dashmap::DashMap;
|
|
use tokio::sync::{broadcast, mpsc};
|
|
use tracing::instrument;
|
|
|
|
use crate::{auth::Userinfo, AppState};
|
|
|
|
use super::{AuthModeError, C2SMessage, RADError, RecvAndDecode, S2CMessage, SessionMessage, WSSession};
|
|
|
|
pub async fn initial(
|
|
ws: axum::extract::WebSocketUpgrade,
|
|
State(state): State<AppState>
|
|
) -> axum::response::Response {
|
|
ws.on_upgrade(|socket| handle_socket(socket, state))
|
|
}
|
|
|
|
async fn handle_socket(mut ws: WebSocket, state: AppState) {
|
|
// Trying authenticate & get user data or dropping connection
|
|
match authenticate(&mut ws, &state).await {
|
|
Ok(user) => {
|
|
|
|
// Creating session & creating/getting channels
|
|
let mut session = {
|
|
let sub_workers_aborthandles = DashMap::new();
|
|
|
|
// Channel for receiving messages from internal functions.
|
|
let (own_tx, own_rx) = mpsc::channel(32);
|
|
state.session.insert(user.uuid, own_tx.clone());
|
|
|
|
// Channel for sending messages to subscribers
|
|
let subs_tx = match state.subscribes.get(&user.uuid) {
|
|
Some(tx) => tx.clone(),
|
|
None => {
|
|
tracing::debug!("[Subscribes] Can't find own subs channel for {}, creating new...", user.uuid);
|
|
let (subs_tx, _) = broadcast::channel(32);
|
|
state.subscribes.insert(user.uuid, subs_tx.clone());
|
|
subs_tx
|
|
},
|
|
};
|
|
|
|
WSSession { user: user.clone(), own_tx, own_rx, subs_tx, sub_workers_aborthandles }
|
|
};
|
|
|
|
// Starting main worker
|
|
if let Err(kind) = main_worker(&mut session, &mut ws, &state).await {
|
|
tracing::info!(error = %kind, nickname = %session.user.nickname, "Main worker exited");
|
|
}
|
|
|
|
for (_, handle) in session.sub_workers_aborthandles {
|
|
handle.abort();
|
|
}
|
|
|
|
// Removing session data
|
|
state.session.remove(&user.uuid);
|
|
state.user_manager.remove(&user.uuid);
|
|
},
|
|
Err(kind) => {
|
|
tracing::info!(error = %kind, "Can't authenticate");
|
|
}
|
|
}
|
|
|
|
// Closing connection
|
|
if let Err(kind) = ws.send(Message::Close(None)).await { tracing::trace!("[WebSocket] Closing fault: {}", kind) }
|
|
}
|
|
|
|
#[instrument(skip_all, fields(nickname = %session.user.nickname))]
|
|
async fn main_worker(session: &mut WSSession, ws: &mut WebSocket, state: &AppState) -> anyhow::Result<()> {
|
|
state.state_pings.insert(session.user.uuid, vec![]);
|
|
let mut next_state_ping = false;
|
|
|
|
tracing::debug!("WebSocket control for {} is transferred to the main worker", session.user.nickname);
|
|
loop {
|
|
tokio::select! {
|
|
external_msg = ws.recv_and_decode() => {
|
|
|
|
// Getting a value or halt the worker without an error
|
|
let external_msg = match external_msg {
|
|
Ok(m) => m,
|
|
Err(kind) => {
|
|
match kind {
|
|
RADError::Close(_) => return Ok(()),
|
|
RADError::StreamClosed => return Ok(()),
|
|
_ => return Err(kind.into())
|
|
}
|
|
},
|
|
};
|
|
|
|
tracing::trace!(ping = ?external_msg);
|
|
// Processing message
|
|
match external_msg {
|
|
C2SMessage::Token(_) => bail!("authentication passed, but the client sent the Token again"),
|
|
C2SMessage::Ping(func_id, echo, data) => {
|
|
if !(func_id == 252645133) {
|
|
// Normal procedure
|
|
let s2c_ping: Vec<u8> = S2CMessage::Ping(session.user.uuid, func_id, echo, data).into();
|
|
|
|
// State ping storing
|
|
if next_state_ping {
|
|
let mut vec = state.state_pings.get_mut(&session.user.uuid).unwrap();
|
|
vec.push(s2c_ping.clone().into());
|
|
next_state_ping = false;
|
|
}
|
|
// Echo check
|
|
if echo {
|
|
ws.send(Message::Binary(s2c_ping.clone().into())).await?
|
|
}
|
|
// Sending to others
|
|
let _ = session.subs_tx.send(s2c_ping);
|
|
} else {
|
|
// State ping procedure
|
|
match data[1] {
|
|
0 => { state.state_pings.insert(session.user.uuid, vec![]); },
|
|
1 => { next_state_ping = !next_state_ping; },
|
|
_ => {}
|
|
}
|
|
}
|
|
},
|
|
C2SMessage::Sub(uuid) => {
|
|
tracing::debug!("[WebSocket] {} subscribes to {}", session.user.nickname, uuid);
|
|
|
|
// Doesn't allow to subscribe to yourself
|
|
if session.user.uuid != uuid {
|
|
// Creates a channel to send pings to a subscriber if it can't find an existing one
|
|
let rx = match state.subscribes.get(&uuid) {
|
|
Some(tx) => tx.subscribe(),
|
|
None => {
|
|
let (tx, rx) = broadcast::channel(32);
|
|
state.subscribes.insert(uuid, tx);
|
|
rx
|
|
},
|
|
};
|
|
let handle = tokio::spawn(sub_worker(session.own_tx.clone(), rx)).abort_handle();
|
|
session.sub_workers_aborthandles.insert(uuid, handle);
|
|
|
|
// Apply state pings / bmpdpvw / 252645133
|
|
if let Some(vec) = state.state_pings.get(&uuid) {
|
|
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
|
for ping in &*vec {
|
|
let msg = Message::Binary(ping.clone().into());
|
|
tracing::trace!(msg = ?msg);
|
|
ws.send(msg).await?
|
|
}
|
|
};
|
|
}
|
|
},
|
|
C2SMessage::Unsub(uuid) => {
|
|
tracing::debug!("[WebSocket] {} unsubscribes from {}", session.user.nickname, uuid);
|
|
|
|
match session.sub_workers_aborthandles.get(&uuid) {
|
|
Some(handle) => handle.abort(),
|
|
None => tracing::warn!("[WebSocket] {} was not subscribed.", session.user.nickname),
|
|
};
|
|
},
|
|
}
|
|
},
|
|
internal_msg = session.own_rx.recv() => {
|
|
let internal_msg = internal_msg.ok_or(anyhow::anyhow!("Unexpected error! Session channel broken!"))?;
|
|
match internal_msg {
|
|
SessionMessage::Ping(msg) => {
|
|
ws.send(Message::Binary(msg.into())).await?
|
|
},
|
|
SessionMessage::Banned => {
|
|
let _ = ban_action(ws).await
|
|
.inspect_err(
|
|
|kind| tracing::warn!("[WebSocket] Didn't get the ban message due to {}", kind)
|
|
);
|
|
bail!("{} banned!", session.user.nickname)
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn sub_worker(tx_main: mpsc::Sender<SessionMessage>, mut rx: broadcast::Receiver<Vec<u8>>) {
|
|
loop {
|
|
let msg = match rx.recv().await {
|
|
Ok(m) => m,
|
|
Err(kind) => {
|
|
tracing::error!("[Subscribes_Worker] Broadcast error! {}", kind);
|
|
return;
|
|
},
|
|
};
|
|
match tx_main.send(SessionMessage::Ping(msg)).await {
|
|
Ok(_) => (),
|
|
Err(kind) => {
|
|
tracing::error!("[Subscribes_Worker] Session error! {}", kind);
|
|
return;
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn authenticate(socket: &mut WebSocket, state: &AppState) -> Result<Userinfo, AuthModeError> {
|
|
match socket.recv_and_decode().await {
|
|
Ok(msg) => {
|
|
match msg {
|
|
C2SMessage::Token(token) => {
|
|
let token = String::from_utf8(token.to_vec()).map_err(|_| AuthModeError::ConvertError)?;
|
|
match state.user_manager.get(&token) {
|
|
Some(user) => {
|
|
if socket.send(Message::Binary(Bytes::from(Into::<Vec<u8>>::into(S2CMessage::Auth)))).await.is_err() {
|
|
Err(AuthModeError::SendError)
|
|
} else if !user.banned {
|
|
Ok(user.clone())
|
|
} else {
|
|
let _ = ban_action(socket).await
|
|
.inspect_err(
|
|
|kind| tracing::warn!("[WebSocket] Didn't get the ban message due to {}", kind)
|
|
);
|
|
Err(AuthModeError::Banned(user.nickname.clone()))
|
|
}
|
|
},
|
|
None => {
|
|
if socket.send(
|
|
Message::Close(Some(axum::extract::ws::CloseFrame { code: 4000, reason: "Re-auth".into() }))
|
|
).await.is_err() {
|
|
Err(AuthModeError::SendError)
|
|
} else {
|
|
Err(AuthModeError::AuthenticationFailure)
|
|
}
|
|
},
|
|
}
|
|
},
|
|
_ => {
|
|
Err(AuthModeError::UnauthorizedAction)
|
|
}
|
|
}
|
|
},
|
|
Err(err) => {
|
|
Err(AuthModeError::RecvError(err))
|
|
},
|
|
}
|
|
}
|
|
|
|
async fn ban_action(ws: &mut WebSocket) -> anyhow::Result<()> {
|
|
ws.send(Message::Binary(Into::<Vec<u8>>::into(S2CMessage::Toast(2, "You're banned!".to_string(), None)).into())).await?;
|
|
tokio::time::sleep(std::time::Duration::from_secs(6)).await;
|
|
ws.send(Message::Close(Some(axum::extract::ws::CloseFrame { code: 4001, reason: "You're banned!".into() }))).await?;
|
|
|
|
Ok(())
|
|
} |