192 lines
6.7 KiB
Rust
192 lines
6.7 KiB
Rust
use futures::io::ErrorKind;
|
|
use futures::task::{Context, Poll};
|
|
use futures::{AsyncBufRead, Future};
|
|
use std::io::{Error, Result};
|
|
use std::pin::Pin;
|
|
|
|
pub trait AsyncTokenReader: AsyncBufRead {
|
|
fn read_until_token<'a>(
|
|
&'a mut self,
|
|
token: &'a [u8],
|
|
buf: &'a mut [u8],
|
|
state: &'a mut ReadUntilState,
|
|
) -> ReadUntilToken<'a, Self> {
|
|
ReadUntilToken {
|
|
reader: self,
|
|
token,
|
|
buf,
|
|
state,
|
|
}
|
|
}
|
|
|
|
fn except_token<'a>(&'a mut self, token: &'a [u8]) -> ExceptToken<'a, Self> {
|
|
ExceptToken {
|
|
reader: self,
|
|
token,
|
|
match_size: 0,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Blanket implementation: every `AsyncBufRead` gets the token-reading
/// methods for free.
///
/// The `?Sized` bound matches the `R: ?Sized` bounds on the future types
/// themselves, so unsized readers such as `dyn AsyncBufRead + Unpin` can use
/// the trait as well.
impl<R: AsyncBufRead + ?Sized> AsyncTokenReader for R {}
|
|
|
|
/// Bookkeeping that `read_until_token` carries from one call to the next.
///
/// Start from `ReadUntilState::default()` and pass the same value to every
/// call made against a given stream.
#[derive(Debug, Default)]
pub struct ReadUntilState {
    /// Number of leading token bytes that have been matched in the stream but
    /// not yet confirmed as a complete token.
    match_size: usize,
    /// `(pos, size)`: token bytes in `pos..size` that were read from the
    /// stream as part of a failed partial match and still have to be copied
    /// into the caller's buffer. `None` when nothing is pending.
    consume_token: Option<(usize, usize)>,
}
|
|
|
|
pub struct ReadUntilToken<'a, R: ?Sized> {
|
|
reader: &'a mut R,
|
|
token: &'a [u8],
|
|
buf: &'a mut [u8],
|
|
state: &'a mut ReadUntilState,
|
|
}
|
|
|
|
impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntilToken<'a, R> {
|
|
type Output = Result<(usize, bool)>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let this = &mut *self;
|
|
let mut rsz = 0;
|
|
|
|
loop {
|
|
let nsz = this.buf.len() - rsz;
|
|
|
|
if let Some((pos, size)) = &mut this.state.consume_token {
|
|
let sz = (*size - *pos).min(nsz);
|
|
this.buf[rsz..rsz + sz].copy_from_slice(&this.token[*pos..*pos + sz]);
|
|
*pos += sz;
|
|
rsz += sz;
|
|
if *pos == *size {
|
|
this.state.consume_token = None;
|
|
}
|
|
if rsz == this.buf.len() {
|
|
return Poll::Ready(Ok((rsz, false)));
|
|
}
|
|
} else {
|
|
match Pin::new(&mut this.reader).poll_fill_buf(cx) {
|
|
Poll::Pending => return Poll::Pending,
|
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
|
Poll::Ready(Ok(data)) if data.is_empty() => {
|
|
return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof)))
|
|
}
|
|
Poll::Ready(Ok(data)) => {
|
|
let mut consume_size = data.len();
|
|
for (idx, b) in data.iter().enumerate() {
|
|
if *b == this.token[this.state.match_size] {
|
|
this.state.match_size += 1;
|
|
if this.state.match_size == this.token.len() {
|
|
Pin::new(&mut this.reader).consume(idx + 1);
|
|
this.state.match_size = 0;
|
|
return Poll::Ready(Ok((rsz, true)));
|
|
}
|
|
} else if this.state.match_size > 0 {
|
|
this.state.consume_token = Some((0, this.state.match_size));
|
|
this.state.match_size = 0;
|
|
consume_size = idx;
|
|
break;
|
|
} else {
|
|
this.buf[rsz] = *b;
|
|
rsz += 1;
|
|
if rsz == this.buf.len() {
|
|
Pin::new(&mut this.reader).consume(idx + 1);
|
|
return Poll::Ready(Ok((rsz, false)));
|
|
}
|
|
}
|
|
}
|
|
Pin::new(&mut this.reader).consume(consume_size);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Future returned by `AsyncTokenReader::except_token`.
///
/// Resolves to `Ok(())` once the reader has produced exactly `token`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ExceptToken<'a, R: ?Sized> {
    /// Source of the bytes being checked.
    reader: &'a mut R,
    /// Byte sequence the stream is required to start with.
    token: &'a [u8],
    /// Number of leading token bytes matched so far.
    match_size: usize,
}
|
|
|
|
impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for ExceptToken<'a, R> {
|
|
type Output = Result<()>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let this = &mut *self;
|
|
|
|
loop {
|
|
match Pin::new(&mut this.reader).poll_fill_buf(cx) {
|
|
Poll::Pending => return Poll::Pending,
|
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
|
Poll::Ready(Ok(data)) if data.is_empty() => {
|
|
return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof)))
|
|
}
|
|
Poll::Ready(Ok(data)) => {
|
|
for b in data {
|
|
if *b == this.token[this.match_size] {
|
|
this.match_size += 1;
|
|
if this.match_size == this.token.len() {
|
|
Pin::new(&mut this.reader).consume(this.match_size);
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
} else {
|
|
return Poll::Ready(Err(Error::from(ErrorKind::InvalidData)));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use crate::http::token_reader::{AsyncTokenReader, ReadUntilState};
    use futures::io::BufReader;

    /// Reads a stream in 3-byte chunks and checks every chunk, including the
    /// ones that end at a token and the one split across a partial match.
    #[async_std::test]
    async fn test_read_until_token() {
        const TOKEN: &[u8] = b"ABC";

        let input: &[u8] = b"12AB567890ABC12345ABC6";
        let mut reader = BufReader::new(input);
        let mut state = ReadUntilState::default();
        let mut buf = [0u8; 3];

        // (bytes written, token found, expected buffer prefix) for each call.
        let expected: [(usize, bool, &[u8]); 6] = [
            (3, false, &b"12A"[..]),
            (3, false, &b"B56"[..]),
            (3, false, &b"789"[..]),
            (1, true, &b"0"[..]),
            (3, false, &b"123"[..]),
            (2, true, &b"45"[..]),
        ];

        for &(len, found, bytes) in expected.iter() {
            let res = reader.read_until_token(TOKEN, &mut buf, &mut state).await;
            assert!(matches!(res, Ok((n, hit)) if n == len && hit == found));
            assert_eq!(&buf[..len], bytes);
        }

        // Only "6" is left, so the stream ends before another token appears.
        let res = reader.read_until_token(TOKEN, &mut buf, &mut state).await;
        assert!(res.is_err());
    }

    /// Two tokens are present back to back; the third attempt hits the end of
    /// the stream.
    #[async_std::test]
    async fn test_read_expect_token() {
        let input: &[u8] = b"ABCABC";
        let mut reader = BufReader::new(input);

        for _ in 0..2 {
            assert!(reader.except_token(b"ABC").await.is_ok());
        }
        assert!(reader.except_token(b"ABC").await.is_err());
    }
}
|