This commit is contained in:
Anna 2022-09-10 06:57:25 -04:00
parent 97f5f50dfc
commit 27ef958ee6
1 changed files with 23 additions and 20 deletions

View File

@ -1,5 +1,4 @@
use std::convert::Infallible; use std::convert::Infallible;
use std::error::Error;
use std::sync::Arc; use std::sync::Arc;
use warp::{Filter, Rejection, Reply}; use warp::{Filter, Rejection, Reply};
@ -39,12 +38,14 @@ pub fn routes(state: Arc<State>) -> BoxedFilter<(impl Reply, )> {
} }
pub fn get_id(state: Arc<State>) -> BoxedFilter<((i64, i64), )> { pub fn get_id(state: Arc<State>) -> BoxedFilter<((i64, i64), )> {
warp::cookie("access_token") warp::header::optional("x-api-key")
.or(warp::header("x-api-key")) .and_then(move |access_token: Option<String>| {
.unify()
.and_then(move |access_token: String| {
let state = Arc::clone(&state); let state = Arc::clone(&state);
async move { async move {
let access_token = match access_token {
Some(t) => t,
None => return Err(warp::reject::custom(WebError::MissingAuthToken)),
};
let hashed = crate::util::hash(&access_token); let hashed = crate::util::hash(&access_token);
let id = sqlx::query!( let id = sqlx::query!(
// language=sqlite // language=sqlite
@ -65,6 +66,7 @@ pub fn get_id(state: Arc<State>) -> BoxedFilter<((i64, i64), )> {
#[derive(Debug)] #[derive(Debug)]
pub enum WebError { pub enum WebError {
MissingAuthToken,
InvalidAuthToken, InvalidAuthToken,
InvalidPackId, InvalidPackId,
InvalidIndex, InvalidIndex,
@ -81,41 +83,42 @@ pub struct AnyhowRejection(anyhow::Error);
impl Reject for AnyhowRejection {} impl Reject for AnyhowRejection {}
async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> { async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> {
let (status, name, desc) = if err.is_not_found() { let (status, name, desc) = if let Some(e) = err.find::<WebError>() {
(StatusCode::NOT_FOUND, "not_found", "route was unknown to the server") match e {
WebError::MissingAuthToken => (StatusCode::BAD_REQUEST, "missing_auth_token", "an auth token was not provided".into()),
WebError::InvalidAuthToken => (StatusCode::BAD_REQUEST, "invalid_auth_token", "the auth token was not valid".into()),
WebError::InvalidPackId => (StatusCode::NOT_FOUND, "invalid_pack_id", "the server does not have a pack registered with that id".into()),
WebError::InvalidIndex => (StatusCode::NOT_FOUND, "invalid_index", "one of the provided indices was out of range".into()),
WebError::TooManyMessages => (StatusCode::BAD_REQUEST, "too_many_messages", "you have run out of messages - delete one and try again".into()),
WebError::NoSuchMessage => (StatusCode::NOT_FOUND, "no_such_message", "no message with that id was found".into()),
WebError::InvalidExtraCode => (StatusCode::BAD_REQUEST, "invalid_extra_code", "that extra code was not found".into()),
}
} else if err.is_not_found() {
(StatusCode::NOT_FOUND, "not_found", "route was unknown to the server".into())
} else if let Some(e) = err.find::<BodyDeserializeError>() { } else if let Some(e) = err.find::<BodyDeserializeError>() {
(StatusCode::BAD_REQUEST, "invalid_body", format!("invalid body: {}", e)) (StatusCode::BAD_REQUEST, "invalid_body", format!("invalid body: {}", e))
} else if let Some(_) = err.find::<MethodNotAllowed>() { } else if let Some(_) = err.find::<MethodNotAllowed>() {
(StatusCode::METHOD_NOT_ALLOWED, "method_not_allowed", "that http method is not allowed on that route") (StatusCode::METHOD_NOT_ALLOWED, "method_not_allowed", "that http method is not allowed on that route".into())
} else if let Some(AnyhowRejection(e)) = err.find::<AnyhowRejection>() { } else if let Some(AnyhowRejection(e)) = err.find::<AnyhowRejection>() {
eprintln!("{:#?}", e); eprintln!("{:#?}", e);
( (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"internal_error", "internal_error",
"an internal logic error occured", "an internal logic error occured".into(),
) )
} else if let Some(e) = err.find::<WebError>() {
match e {
WebError::InvalidAuthToken => (StatusCode::BAD_REQUEST, "invalid_auth_token", "the auth token was not valid"),
WebError::InvalidPackId => (StatusCode::NOT_FOUND, "invalid_pack_id", "the server does not have a pack registered with that id"),
WebError::InvalidIndex => (StatusCode::NOT_FOUND, "invalid_index", "one of the provided indices was out of range"),
WebError::TooManyMessages => (StatusCode::BAD_REQUEST, "too_many_messages", "you have run out of messages - delete one and try again"),
WebError::NoSuchMessage => (StatusCode::NOT_FOUND, "no_such_message", "no message with that id was found"),
WebError::InvalidExtraCode => (StatusCode::BAD_REQUEST, "invalid_extra_code", "that extra code was not found"),
}
} else { } else {
eprintln!("{:#?}", err); eprintln!("{:#?}", err);
( (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"internal_error", "internal_error",
"an unhandled error was encountered", "an unhandled error was encountered".into(),
) )
}; };
#[derive(serde::Serialize)] #[derive(serde::Serialize)]
struct ErrorMessage { struct ErrorMessage {
code: &'static str, code: &'static str,
message: &'static str, message: String,
} }
let message = ErrorMessage { let message = ErrorMessage {