megamappingway/server/src/main.rs

647 lines
18 KiB
Rust

use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, Result};
use axum::body::HttpBody;
use axum::error_handling::HandleErrorLayer;
use axum::extract::rejection::BytesRejection;
use axum::extract::{FromRequest, Path, Query, State};
#[cfg(not(debug_assertions))]
use axum::http::Method;
use axum::http::{HeaderValue, Request, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{async_trait, BoxError, Router, Server};
use bytes::Bytes;
use half::f16;
use rand::Rng;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use siphasher::sip::SipHasher;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use sqlx::PgPool;
use tokio::sync::RwLock;
use tower::ServiceBuilder;
use tower_http::compression::CompressionLayer;
use tower_http::cors::CorsLayer;
use tower_http::decompression::RequestDecompressionLayer;
use tower_http::CompressionLevel;

use crate::config::Config;
use crate::generated::game_info::{TerritoryInfo, TERRITORY_INFO};
mod config;
mod generated;
/// Database migrations embedded at compile time by `sqlx::migrate!()`; `main` runs them
/// before the server starts accepting requests.
static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();

/// Population counts keyed by territory id, then by world id, holding the number of player
/// rows the population task last counted for that pair.
type Populations = HashMap<u32, HashMap<u32, i64>>;

/// State shared (behind an `Arc`) by every request handler.
struct AppState {
    /// Random string generated once per process and mixed into every anonymised hash.
    salt: String,
    /// Postgres connection pool.
    pool: Arc<PgPool>,
    /// Hasher used for the salted player/party identifiers sent to clients.
    hasher: SipHasher,
    /// Population cache, refreshed periodically by a background task.
    populations: Arc<RwLock<Populations>>,
}
#[tokio::main]
async fn main() -> Result<()> {
let config: Config = {
let t = tokio::fs::read_to_string("./config.toml").await?;
toml::from_str(&t)?
};
let pool = PgPoolOptions::new()
.max_connections(50)
.connect_with(
PgConnectOptions::new()
.host(&config.database.host)
.port(config.database.port)
.database(&config.database.name)
.username(&config.database.username)
.password(&config.database.password)
)
.await?;
MIGRATOR.run(&pool).await?;
let pool = Arc::new(pool);
// spawn old info deletion task
{
let pool = Arc::clone(&pool);
tokio::task::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60 * 30));
loop {
interval.tick().await;
println!("deleting old records");
let result = sqlx::query!(
// language=postgresql
"delete from players where current_timestamp - timestamp > interval '1 hour'",
)
.execute(&*pool)
.await;
match result {
Ok(res) => println!("{} record(s) deleted", res.rows_affected()),
Err(e) => eprintln!("could not delete old records: {e:#}"),
}
}
});
}
// spawn population task
let populations_cache = Arc::new(RwLock::new(HashMap::default()));
{
let pool = Arc::clone(&pool);
let populations_cache = Arc::clone(&populations_cache);
tokio::task::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
let result = sqlx::query!(
// language=postgresql
r#"select territory, current_world, coalesce(count(*), 0) as "count!" from players where current_timestamp - timestamp < interval '30 seconds' group by territory, current_world"#
)
.fetch_all(&*pool)
.await;
let result = match result {
Ok(r) => r,
Err(e) => {
eprintln!("could not calculate populations: {e:#}");
continue;
}
};
let mut output: Populations = HashMap::with_capacity(result.len());
for record in result {
output.entry(record.territory as u32)
.or_default()
.insert(record.current_world as u32, record.count);
}
*populations_cache.write().await = output;
}
});
}
let state = Arc::new(AppState {
pool,
salt: data_encoding::BASE64_NOPAD.encode(&{
let mut bytes = [0; 16];
rand::thread_rng().fill(&mut bytes);
bytes
}),
hasher: SipHasher::new(),
populations: populations_cache,
});
#[cfg(not(debug_assertions))]
let cors = CorsLayer::new()
.allow_origin([
"https://map.anna.lgbt".parse()?,
])
.allow_methods([Method::GET]);
#[cfg(debug_assertions)]
let cors = CorsLayer::permissive();
let app = Router::new()
.route("/:territory", get(territory))
.route("/upload", post(upload))
.with_state(state)
.layer(cors)
.layer(CompressionLayer::new().quality(CompressionLevel::Best))
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| async move {
(StatusCode::INTERNAL_SERVER_ERROR, "unhandled server error")
}))
.layer(RequestDecompressionLayer::new())
);
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
tokio::task::spawn(async move {
tokio::signal::ctrl_c().await.ok();
println!("caught ctrl-c, shutting down");
shutdown_tx.send(()).ok();
});
Server::bind(&"127.0.0.1:30888".parse()?)
.serve(app.into_make_service())
.with_graceful_shutdown(async {
shutdown_rx.await.ok();
})
.await?;
Ok(())
}
/// Optional query-string filters for `GET /:territory`; absent parameters become `None`.
#[derive(Default, Deserialize)]
#[serde(default)]
struct TerritoryQuery {
    /// Restrict results to players whose `current_world` equals this id.
    world: Option<u32>,
    /// Restrict results to one salted party hash (ignored on party-hidden territories).
    party: Option<u32>,
}
/// `GET /:territory` handler.
///
/// Territories that are unknown to `TERRITORY_INFO`, or whose map is hidden, only get the cached
/// population counts. Otherwise the response lists every player reported in this territory
/// within the last 30 seconds (optionally limited to `?world=`), sorted by
/// `territory_unique_id`. Party hashes are included, and `?party=` filtering is honoured, only
/// where parties are visible.
///
/// # Errors
/// Database failures are returned as [`AppError`].
async fn territory(
    State(state): State<Arc<AppState>>,
    Path(territory): Path<u32>,
    Query(params): Query<TerritoryQuery>,
) -> Result<MsgPack<QueryResponse<Vec<AnonymousPlayerInfo>>>, AppError> {
    let territory_info = TERRITORY_INFO
        .get(territory as usize)
        .copied()
        .unwrap_or_else(|| {
            eprintln!("warn: missing territory info for territory {territory}");
            TerritoryInfo {
                parties_visible: false,
                map_visible: false,
            }
        });

    if !territory_info.map_visible {
        return Ok(MsgPack(QueryResponse {
            populations: state.populations.read().await.clone(),
            parties: Vec::new(),
            data: Vec::new(),
        }));
    }

    let rows = sqlx::query_as!(
        AnonymousPlayerInfoInternal,
        // language=postgresql
        r#"
select hash,
world,
x,
y,
z,
w,
customize,
level,
job,
current_hp,
max_hp,
party_id,
coalesce(extract('epoch' from current_timestamp - timestamp), 30)::bigint as "age!"
from players
where territory = $1
and ($2 or current_world = $3)
and current_timestamp - timestamp < interval '30 seconds'
"#,
        territory as i64,
        params.world.is_none(),
        params.world.map(|x| x as i32),
    )
    .fetch_all(&*state.pool)
    .await?;

    let parties_visible = territory_info.parties_visible;
    let parties = if parties_visible {
        get_parties(&state, territory, &rows)
    } else {
        Vec::new()
    };

    // Don't allow filtering on party-hidden maps: drop the requested party there.
    let party_filter = if parties_visible { params.party } else { None };

    let mut data: Vec<AnonymousPlayerInfo> = rows
        .into_iter()
        .filter(|player| {
            party_filter.is_none()
                || player.party_hash(&state.hasher, &state.salt, territory) == party_filter
        })
        .map(|player| AnonymousPlayerInfo::new_from(player, &state.hasher, &state.salt, territory))
        .collect();
    data.sort_unstable_by_key(|player| player.territory_unique_id);

    let populations = state.populations.read().await.clone();
    Ok(MsgPack(QueryResponse {
        populations,
        parties,
        data,
    }))
}
async fn upload(
state: State<Arc<AppState>>,
data: MsgPack<Update>,
) -> Result<(), AppError> {
if data.version != 3 {
return Err(anyhow::anyhow!("invalid update request version").into());
}
let mut t = state.pool.begin().await?;
for player in &data.players {
if !player.is_sane() {
continue;
}
sqlx::query!(
// language=postgresql
"
insert into players (hash, world, timestamp, territory, current_world,
x, y, z, w, customize,
level, job, current_hp, max_hp, party_id)
values ($1, $2, current_timestamp, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
on conflict (hash) do update set timestamp = current_timestamp,
world = $2,
territory = $3,
current_world = $4,
x = $5,
y = $6,
z = $7,
w = $8,
customize = $9,
level = $10,
job = $11,
current_hp = $12,
max_hp = $13,
party_id = $14
",
player.hash,
player.world as i64,
data.territory as i64,
data.world as i64,
player.x,
player.y,
player.z,
player.w,
player.customize,
player.level as i64,
player.job as i64,
player.current_hp as i64,
player.max_hp as i64,
player.party_id.map(|id| id as i64),
)
.execute(&mut *t)
.await?;
}
t.commit().await?;
Ok(())
}
/// Returns the sorted, deduplicated salted party hashes of the given players.
/// Players without a party contribute nothing.
fn get_parties(state: &AppState, territory: u32, info: &[AnonymousPlayerInfoInternal]) -> Vec<u32> {
    let mut hashes = Vec::new();
    for player in info {
        if let Some(hash) = player.party_hash(&state.hasher, &state.salt, territory) {
            hashes.push(hash);
        }
    }
    hashes.sort_unstable();
    hashes.dedup();
    hashes
}
/// MessagePack body of `POST /upload`.
// NOTE(review): keep field order stable; msgpack structs may be encoded positionally by clients.
#[derive(Deserialize)]
struct Update {
    /// Request format version; the handler only accepts 3.
    version: u8,
    /// Territory id recorded for every player in `players`.
    territory: u32,
    /// World id recorded as `current_world` for every player in `players`.
    world: u32,
    /// Player records in this batch.
    players: Vec<PlayerInfo>,
}
/// One player record inside an [`Update`].
// NOTE(review): keep field order stable; msgpack structs may be encoded positionally by clients.
#[derive(Deserialize)]
struct PlayerInfo {
    /// Opaque player identifier, used as the upsert key.
    #[serde(with = "serde_bytes")]
    hash: Vec<u8>,
    /// The player's own world id (distinct from the batch's current world).
    world: u32,
    // Coordinates. NOTE(review): `w` is presumably a rotation/heading — confirm with the client.
    x: f64,
    y: f64,
    z: f64,
    w: f64,
    /// Raw appearance bytes; at least 26 are needed to be accepted.
    #[serde(with = "serde_bytes")]
    customize: Vec<u8>,
    level: u8,
    job: u32,
    current_hp: u32,
    max_hp: u32,
    /// Party identifier, if any.
    party_id: Option<u64>,
}
impl PlayerInfo {
    /// Returns `true` when this uploaded record is plausible enough to store.
    ///
    /// A record is rejected when:
    /// - its hash is empty;
    /// - its world id is 65535 (NOTE(review): presumably an "unknown world" sentinel);
    /// - its level is 0;
    /// - it carries fewer than 26 customize bytes;
    /// - any coordinate is NaN or infinite.
    ///
    /// Non-finite coordinates come from untrusted input and would otherwise be stored and later
    /// sent to map clients as f16 NaN/infinity.
    fn is_sane(&self) -> bool {
        let position_is_finite = [self.x, self.y, self.z, self.w]
            .iter()
            .all(|component| component.is_finite());

        !self.hash.is_empty()
            && self.world != 65535
            && self.level != 0
            && self.customize.len() >= 26
            && position_is_finite
    }
}
/// Response body of the territory query, serialized as MessagePack.
// NOTE(review): rmp_serde::to_vec encodes structs positionally, so field order is wire format.
#[derive(Serialize)]
struct QueryResponse<T> {
    /// Cached population counts (territory id -> world id -> count).
    populations: Populations,
    /// Sorted, deduplicated salted party hashes; empty where parties are hidden.
    parties: Vec<u32>,
    /// Per-player payload.
    data: T,
}
/// Anonymised, client-facing view of one player.
// NOTE(review): rmp_serde::to_vec encodes structs positionally, so field order is wire format.
#[derive(Serialize)]
struct AnonymousPlayerInfo {
    /// Little-endian f16 values packed back to back: x, y, z, w, then the HP fraction.
    #[serde(with = "serde_bytes")]
    floats: Vec<u8>,
    gender: u8,
    race: u8,
    level: u8,
    job: u8,
    /// Seconds since this player's row was last updated.
    age: u8,
    /// Salted hash of the player hash and territory; not the raw stored hash.
    territory_unique_id: u64,
}
impl AnonymousPlayerInfo {
    /// Builds the client-facing view of a database row.
    ///
    /// `floats` holds five little-endian `f16`s, ten bytes in total:
    /// `x`, `y`, `z`, `w`, then `current_hp / max_hp`. A row with `max_hp == 0`
    /// reports an HP fraction of 0 instead of NaN/infinity.
    ///
    /// `age` saturates to `0..=255` instead of wrapping. Appearance falls back to
    /// `Customize::default()` when the row's customize data is too short.
    fn new_from(value: AnonymousPlayerInfoInternal, hasher: &SipHasher, salt: &str, territory: u32) -> Self {
        let customize = value.customize();

        // Avoid 0/0 = NaN (or x/0 = inf) for rows whose max HP was reported as 0.
        let hp_fraction = if value.max_hp > 0 {
            value.current_hp as f64 / value.max_hp as f64
        } else {
            0.0
        };

        let mut floats = Vec::with_capacity(5 * 2);
        for component in [value.x, value.y, value.z, value.w, hp_fraction] {
            floats.extend_from_slice(&f16::from_f64(component).to_le_bytes());
        }

        Self {
            territory_unique_id: value.gen_hash(hasher, salt, territory),
            floats,
            gender: customize.gender,
            race: customize.race,
            level: value.level as u8,
            job: value.job as u8,
            // Saturate rather than wrap (e.g. 300 -> 44) if age ever exceeds u8 range.
            age: value.age.clamp(0, i64::from(u8::MAX)) as u8,
        }
    }
}
/// Character appearance decoded from the first 26 bytes of a player's `customize` blob.
#[derive(Default)]
// Only `race` and `gender` are read by the server; the remaining fields document the byte layout.
#[allow(dead_code)]
struct Customize {
    race: u8,
    gender: u8,
    model_type: u8,
    height: u8,
    tribe: u8,
    face_type: u8,
    hairstyle: u8,
    has_highlights: u8,
    skin_colour: u8,
    eye_colour: u8,
    hair_colour: u8,
    hair_colour2: u8,
    face_features: u8,
    face_features_colour: u8,
    eyebrows: u8,
    eye_colour2: u8,
    eye_shape: u8,
    nose_shape: u8,
    jaw_shape: u8,
    lip_style: u8,
    lip_colour: u8,
    race_feature_size: u8,
    race_feature_type: u8,
    bust_size: u8,
    facepaint: u8,
    facepaint_colour: u8,
}

impl Customize {
    /// Decodes bytes 0..26 of `data`, one byte per field in declaration order.
    ///
    /// Returns `None` when `data` holds fewer than 26 bytes; any bytes past index 25 are ignored.
    pub fn new(data: &[u8]) -> Option<Self> {
        match data {
            &[race, gender, model_type, height, tribe, face_type, hairstyle, has_highlights,
              skin_colour, eye_colour, hair_colour, hair_colour2, face_features,
              face_features_colour, eyebrows, eye_colour2, eye_shape, nose_shape, jaw_shape,
              lip_style, lip_colour, race_feature_size, race_feature_type, bust_size,
              facepaint, facepaint_colour, ..] => Some(Self {
                race,
                gender,
                model_type,
                height,
                tribe,
                face_type,
                hairstyle,
                has_highlights,
                skin_colour,
                eye_colour,
                hair_colour,
                hair_colour2,
                face_features,
                face_features_colour,
                eyebrows,
                eye_colour2,
                eye_shape,
                nose_shape,
                jaw_shape,
                lip_style,
                lip_colour,
                race_feature_size,
                race_feature_type,
                bust_size,
                facepaint,
                facepaint_colour,
            }),
            _ => None,
        }
    }
}
/// Row shape produced by the territory query via `sqlx::query_as!`.
///
/// Field names must match the selected column names.
// `world` is selected but never read by the handler code, hence the allowance.
#[allow(dead_code)]
struct AnonymousPlayerInfoInternal {
    hash: Vec<u8>,
    world: i64,
    x: f64,
    y: f64,
    z: f64,
    w: f64,
    customize: Vec<u8>,
    level: i64,
    job: i64,
    current_hp: i64,
    max_hp: i64,
    party_id: Option<i64>,
    /// Seconds since the row was last updated (30 if the database returned NULL).
    age: i64,
}
impl AnonymousPlayerInfoInternal {
    /// Salted 64-bit identifier for this player, computed over `"{salt}-{territory}-{hex(hash)}"`.
    /// Because the territory is part of the input, the value differs between territories.
    pub fn gen_hash(&self, hasher: &SipHasher, salt: &str, territory: u32) -> u64 {
        let hex_hash = data_encoding::HEXLOWER.encode(&self.hash);
        let key = format!("{salt}-{territory}-{hex_hash}");
        hasher.hash(key.as_bytes())
    }

    /// Decoded appearance, or all-zero defaults if the stored blob is shorter than 26 bytes.
    pub fn customize(&self) -> Customize {
        match Customize::new(&self.customize) {
            Some(decoded) => decoded,
            None => Customize::default(),
        }
    }

    /// Salted 32-bit party identifier computed over `"{salt}-{territory}-{party_id as hex}"`,
    /// or `None` when the player has no party.
    pub fn party_hash(&self, hasher: &SipHasher, salt: &str, territory: u32) -> Option<u32> {
        self.party_id.map(|party_id| {
            let key = format!("{salt}-{territory}-{party_id:x}");
            // Only the low 32 bits of the 64-bit hash are kept.
            hasher.hash(key.as_bytes()) as u32
        })
    }
}
/// Handler error wrapping [`anyhow::Error`]; always rendered as `500 Internal Server Error`.
struct AppError(anyhow::Error);

/// Tells axum how to convert an [`AppError`] into a response.
///
/// The full error chain is logged to stderr so failures leave a trace on the server, not just
/// in the client's response body. The response body is unchanged.
// NOTE(review): the body echoes the error message (e.g. database errors) to clients; consider
// returning a generic message instead.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        eprintln!("request failed: {:#}", self.0);
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {}", self.0),
        )
            .into_response()
    }
}

/// Lets `?` convert anything that converts into `anyhow::Error` (e.g. `sqlx::Error`) into an
/// [`AppError`], so handlers can return `Result<_, AppError>` directly.
impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}
/// Newtype marking a value as a MessagePack request or response body.
///
/// Derefs to the wrapped value so handlers can use it directly.
struct MsgPack<T>(pub T);

impl<T> Deref for MsgPack<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let MsgPack(inner) = self;
        inner
    }
}

impl<T> DerefMut for MsgPack<T> {
    fn deref_mut(&mut self) -> &mut T {
        let MsgPack(inner) = self;
        inner
    }
}
#[async_trait]
impl<S, B, T> FromRequest<S, B> for MsgPack<T>
where S: Send + Sync,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
T: DeserializeOwned,
{
type Rejection = MsgPackRejection;
async fn from_request(req: Request<B>, state: &S) -> std::result::Result<Self, Self::Rejection> {
if req.headers().get("content-type").and_then(|val| val.to_str().ok()) != Some("application/msgpack") {
return Err(MsgPackRejection::ContentType);
}
let bytes = Bytes::from_request(req, state).await?;
let value = rmp_serde::from_slice(&bytes)?;
// let des = &mut rmp_serde::Deserializer::from_read_ref(&bytes);
// let value = match serde_path_to_error::deserialize(des) {
// Ok(v) => v,
// Err(err) => {
// let rejection = match err.inner() {
// rmp_serde::decode::Error::DepthLimitExceeded => {},
// };
//
// return Err(rejection);
// }
// };
Ok(MsgPack(value))
}
}
/// Serializes the wrapped value with `rmp_serde::to_vec` and sets
/// `Content-Type: application/msgpack`.
///
/// If serialization fails, the response is the plain [`AppError`] 500 response. The msgpack
/// content type is not applied to it, because the error body is plain text.
impl<T> IntoResponse for MsgPack<T>
where
    T: Serialize,
{
    fn into_response(self) -> Response {
        match rmp_serde::to_vec(&self.0) {
            Ok(bytes) => (
                [(
                    axum::http::header::CONTENT_TYPE,
                    HeaderValue::from_static("application/msgpack"),
                )],
                bytes,
            )
                .into_response(),
            Err(e) => AppError(e.into()).into_response(),
        }
    }
}
/// Reasons extracting a [`MsgPack`] body can fail.
enum MsgPackRejection {
    /// `Content-Type` header missing or not `application/msgpack`.
    ContentType,
    /// Reading the request body failed.
    Bytes(BytesRejection),
    /// The body was not valid MessagePack for the target type.
    MsgPack(rmp_serde::decode::Error),
}

/// Content-type and decode failures become `400 Bad Request` with a plain-text message.
///
/// Body-read failures are passed through unchanged. axum's body rejections carry their own
/// status code (for example when a body length limit is exceeded), and wrapping them in a
/// 400 would overwrite it.
impl IntoResponse for MsgPackRejection {
    fn into_response(self) -> Response {
        match self {
            Self::ContentType => (
                StatusCode::BAD_REQUEST,
                "expected application/msgpack content-type header",
            )
                .into_response(),
            Self::Bytes(e) => e.into_response(),
            Self::MsgPack(e) => (
                StatusCode::BAD_REQUEST,
                format!("could not deserialize msgpack: {:#}", e),
            )
                .into_response(),
        }
    }
}

impl From<BytesRejection> for MsgPackRejection {
    fn from(value: BytesRejection) -> Self {
        Self::Bytes(value)
    }
}

impl From<rmp_serde::decode::Error> for MsgPackRejection {
    fn from(value: rmp_serde::decode::Error) -> Self {
        Self::MsgPack(value)
    }
}