diff --git a/src/http/websocket.rs b/src/http/websocket.rs index dc8d8d87..27779f44 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -93,55 +93,58 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let mut this = self.project(); - while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) { - let message = match message { - Some(message) => message, - None => return Poll::Ready(None), - }; + if this.init_fut.is_none() { + while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) { + let message = match message { + Some(message) => message, + None => return Poll::Ready(None), + }; - let message: ClientMessage = match serde_json::from_slice(message.as_ref()) { - Ok(message) => message, - Err(e) => { - return Poll::Ready(Some( - serde_json::to_string(&ServerMessage::ConnectionError { - payload: Error::new(e.to_string()), - }) - .unwrap(), - )) - } - }; - - match message { - ClientMessage::ConnectionInit { payload } => { - if let Some(data_initializer) = this.data_initializer.take() { - *this.init_fut = Some(Box::pin(async move { - data_initializer(payload.unwrap_or_default()).await - })); - } - } - ClientMessage::Start { - id, - payload: request, - } => { - this.streams.insert( - id, - Box::pin( - this.schema - .execute_stream_with_ctx_data(request, Arc::clone(this.data)), - ), - ); - } - ClientMessage::Stop { id } => { - if this.streams.remove(id).is_some() { + let message: ClientMessage = match serde_json::from_slice(message.as_ref()) { + Ok(message) => message, + Err(e) => { return Poll::Ready(Some( - serde_json::to_string(&ServerMessage::Complete { id }).unwrap(), - )); + serde_json::to_string(&ServerMessage::ConnectionError { + payload: Error::new(e.to_string()), + }) + .unwrap(), + )) } + }; + + match message { + ClientMessage::ConnectionInit { payload } => { + if let Some(data_initializer) = this.data_initializer.take() { + *this.init_fut = Some(Box::pin(async move { + data_initializer(payload.unwrap_or_default()).await + })); + break; + } + } + ClientMessage::Start { + id, + payload: request, + } => { + this.streams.insert( + id, + Box::pin( + this.schema + .execute_stream_with_ctx_data(request, Arc::clone(this.data)), + ), + ); + } + ClientMessage::Stop { id } => { + if this.streams.remove(id).is_some() { + return Poll::Ready(Some( + serde_json::to_string(&ServerMessage::Complete { id }).unwrap(), + )); + } + } + // Note: in the revised `graphql-ws` spec, there is no equivalent to the + // `CONNECTION_TERMINATE` `client -> server` message; rather, disconnection is + // handled by disconnecting the websocket + ClientMessage::ConnectionTerminate => return Poll::Ready(None), } - // Note: in the revised `graphql-ws` spec, there is no equivalent to the - // `CONNECTION_TERMINATE` `client -> server` message; rather, disconnection is - // handled by disconnecting the websocket - ClientMessage::ConnectionTerminate => return Poll::Ready(None), } }