diff --git a/integrations/rocket/Cargo.toml b/integrations/rocket/Cargo.toml index 2eda0c3b..e6431ba9 100644 --- a/integrations/rocket/Cargo.toml +++ b/integrations/rocket/Cargo.toml @@ -16,7 +16,7 @@ categories = ["network-programming", "asynchronous"] [dependencies] async-graphql = { path = "../..", version = "=2.7.1" } -rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "48fd83a", default-features = false } # TODO: Change to Cargo crate when Rocket 0.5.0 is released +rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "2893ce7", default-features = false } # TODO: Change to Cargo crate when Rocket 0.5.0 is released serde = "1.0.125" serde_json = "1.0.64" tokio-util = { version = "0.6.5", default-features = false, features = ["compat"] } diff --git a/integrations/rocket/src/lib.rs b/integrations/rocket/src/lib.rs index c66080da..18ce8fc0 100644 --- a/integrations/rocket/src/lib.rs +++ b/integrations/rocket/src/lib.rs @@ -15,18 +15,14 @@ use std::io::Cursor; use async_graphql::http::MultipartOptions; use async_graphql::{ObjectType, ParseRequestError, Schema, SubscriptionType}; -use query_deserializer::QueryDeserializer; use rocket::{ data::{self, Data, FromData, ToByteUnit}, + form::FromForm, http::{ContentType, Header, Status}, - request::{self, FromQuery}, response::{self, Responder}, }; -use serde::de::Deserialize; use tokio_util::compat::TokioAsyncReadCompatExt; -mod query_deserializer; - /// A batch request which can be extracted from a request's body. /// /// # Examples @@ -56,11 +52,14 @@ impl BatchRequest { } #[rocket::async_trait] -impl FromData for BatchRequest { +impl<'r> FromData<'r> for BatchRequest { type Error = ParseRequestError; - async fn from_data(req: &rocket::Request<'_>, data: Data) -> data::Outcome { - let opts: MultipartOptions = req.managed_state().copied().unwrap_or_default(); + async fn from_data( + req: &'r rocket::Request<'_>, + data: Data, + ) -> data::Outcome { + let opts: MultipartOptions = req.rocket().state().copied().unwrap_or_default(); let request = async_graphql::http::receive_batch_body( req.headers().get_one("Content-Type"), @@ -87,17 +86,12 @@ impl FromData for BatchRequest { } } -/// A GraphQL request which can be extracted from a query string or the request's body. +/// A GraphQL request which can be extracted from the request's body. /// /// # Examples /// /// ```ignore -/// #[rocket::post("/graphql?", rank = 2)] -/// async fn graphql_query(schema: State<'_, ExampleSchema>, query: Request) -> Result { -/// query.execute(&schema).await -/// } -/// -/// #[rocket::post("/graphql", data = "", format = "application/json", rank = 1)] +/// #[rocket::post("/graphql", data = "", format = "application/json", rank = 2)] /// async fn graphql_request(schema: State<'_, ExampleSchema>, request: Request) -> Result { /// request.execute(&schema).await /// } @@ -120,21 +114,66 @@ impl Request { } } -impl<'q> FromQuery<'q> for Request { - type Error = serde::de::value::Error; +impl From for Request { + fn from(query: Query) -> Self { + let mut request = async_graphql::Request::new(query.query); - fn from_query(query: request::Query<'_>) -> Result { - Ok(Self(async_graphql::Request::deserialize( - QueryDeserializer(query), - )?)) + if let Some(operation_name) = query.operation_name { + request = request.operation_name(operation_name); + } + + if let Some(variables) = query.variables { + let value = serde_json::from_str(&variables).unwrap_or_default(); + let variables = async_graphql::Variables::from_json(value); + request = request.variables(variables); + } + + Request(request) + } +} + +/// A GraphQL request which can be extracted from a query string. +/// +/// # Examples +/// +/// ```ignore +/// #[rocket::get("/graphql?")] +/// async fn graphql_query(schema: State<'_, ExampleSchema>, query: Query) -> Result { +/// query.execute(&schema).await +/// } +/// ``` +#[derive(FromForm, Debug)] +pub struct Query { + query: String, + #[field(name = "operationName")] + operation_name: Option, + variables: Option, +} + +impl Query { + /// Shortcut method to execute the request on the schema. + pub async fn execute( + self, + schema: &Schema, + ) -> Response + where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + { + let request: Request = self.into(); + request.execute(schema).await } } #[rocket::async_trait] -impl FromData for Request { +impl<'r> FromData<'r> for Request { type Error = ParseRequestError; - async fn from_data(req: &rocket::Request<'_>, data: Data) -> data::Outcome { + async fn from_data( + req: &'r rocket::Request<'_>, + data: Data, + ) -> data::Outcome { BatchRequest::from_data(req, data) .await .and_then(|request| match request.0.into_single() { diff --git a/integrations/rocket/src/query_deserializer.rs b/integrations/rocket/src/query_deserializer.rs deleted file mode 100644 index b4958522..00000000 --- a/integrations/rocket/src/query_deserializer.rs +++ /dev/null @@ -1,65 +0,0 @@ -use rocket::http::RawStr; -use rocket::request::Query; -use serde::de::{DeserializeSeed, Deserializer, Error as _, IntoDeserializer, MapAccess, Visitor}; -use serde::forward_to_deserialize_any; - -/// A wrapper around `rocket::request::Query` that implements `Deserializer`. -pub(crate) struct QueryDeserializer<'q>(pub(crate) Query<'q>); - -impl<'q, 'de> Deserializer<'de> for QueryDeserializer<'q> { - type Error = serde::de::value::Error; - - fn deserialize_any>(self, visitor: V) -> Result { - visitor.visit_map(QueryMapAccess { - query: self.0, - value: None, - }) - } - - forward_to_deserialize_any! { - bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string - bytes byte_buf option unit unit_struct newtype_struct seq tuple - tuple_struct map struct enum identifier ignored_any - } -} - -struct QueryMapAccess<'q> { - query: Query<'q>, - value: Option<&'q RawStr>, -} - -impl<'q, 'de> MapAccess<'de> for QueryMapAccess<'q> { - type Error = serde::de::value::Error; - - fn next_key_seed>( - &mut self, - seed: K, - ) -> Result, Self::Error> { - self.query - .next() - .map(|item| { - self.value = Some(item.value); - seed.deserialize( - item.key - .url_decode() - .map_err(Self::Error::custom)? - .into_deserializer(), - ) - }) - .transpose() - } - - fn next_value_seed>( - &mut self, - seed: V, - ) -> Result { - seed.deserialize( - self.value - .take() - .unwrap() - .url_decode() - .map_err(Self::Error::custom)? - .into_deserializer(), - ) - } -}