diff --git a/integrations/warp/src/subscription.rs b/integrations/warp/src/subscription.rs index a6b5cb74..3b6c5e1e 100644 --- a/integrations/warp/src/subscription.rs +++ b/integrations/warp/src/subscription.rs @@ -67,12 +67,30 @@ pub fn graphql_subscription_with_data( schema: Schema, initializer: F, ) -> impl Filter + Clone +where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + F: FnOnce(serde_json::Value) -> R + Clone + Send + 'static, + R: Future> + Send + 'static, { graphql_subscription_with_data_and_callbacks(schema, initializer, ||{}, ||{}) } + +/// GraphQL subscription filter +/// +/// Specifies that a function converts the init payload to data. +pub fn graphql_subscription_with_data_and_callbacks( + schema: Schema, + initializer: F, + open_callback: FO, + close_callback: FC, +) -> impl Filter + Clone where Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, F: FnOnce(serde_json::Value) -> R + Clone + Send + 'static, R: Future> + Send + 'static, + FO: FnOnce() + Clone + Send + 'static, + FC: FnOnce() + Clone + Send + 'static, { use async_graphql::http::WebSocketProtocols; use std::str::FromStr; @@ -82,6 +100,8 @@ where .map(move |ws: ws::Ws, protocols: Option| { let schema = schema.clone(); let initializer = initializer.clone(); + let open_callback = open_callback.clone(); + let close_callback = close_callback.clone(); let protocol = protocols .and_then(|protocols| { @@ -94,6 +114,8 @@ where let reply = ws.on_upgrade(move |websocket| { let (ws_sender, ws_receiver) = websocket.split(); + open_callback(); + async move { let _ = async_graphql::http::WebSocket::with_data( schema, @@ -112,6 +134,8 @@ where .map(Ok) .forward(ws_sender) .await; + + close_callback(); } });