//! Tokio codec that frames the PostgreSQL wire protocol.
//!
//! Outgoing [`FrontendMessage`]s are appended to the transport's write buffer. Incoming bytes are
//! grouped into [`BackendMessage`]s: either a run of consecutive complete messages, or a single
//! asynchronous (notice / notification / parameter-status) message.

use bytes::{Buf, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use postgres_protocol2::message::frontend::CopyData;
use std::io;
use tokio_util::codec::{Decoder, Encoder};

/// A message to be sent to the server.
pub enum FrontendMessage {
    /// Already-encoded bytes, copied verbatim into the output buffer.
    Raw(Bytes),
    /// A `CopyData` message, serialized with [`CopyData::write`].
    CopyData(CopyData<Box<dyn Buf + Send>>),
}

/// A unit of data decoded from the server by [`PostgresCodec`].
pub enum BackendMessage {
    /// A run of consecutive, fully-buffered backend messages.
    Normal {
        /// The raw bytes of the messages, parsed lazily through [`FallibleIterator`].
        messages: BackendMessages,
        /// `true` when the run ended on a `ReadyForQuery` message.
        request_complete: bool,
    },
    /// A single notice-response, notification-response, or parameter-status message.
    Async(backend::Message),
}

/// A buffer of complete backend messages that are parsed one at a time on iteration.
pub struct BackendMessages(BytesMut);

impl BackendMessages {
    /// Creates a `BackendMessages` backed by an empty buffer.
    pub fn empty() -> BackendMessages {
        BackendMessages(BytesMut::new())
    }
}

impl FallibleIterator for BackendMessages {
    type Item = backend::Message;
    type Error = io::Error;

    /// Parses the next message from the front of the remaining buffer via
    /// `backend::Message::parse`.
    fn next(&mut self) -> io::Result<Option<backend::Message>> {
        backend::Message::parse(&mut self.0)
    }
}

/// Codec translating between [`FrontendMessage`] / [`BackendMessage`] and raw protocol bytes.
pub struct PostgresCodec;

impl Encoder<FrontendMessage> for PostgresCodec {
    type Error = io::Error;

    /// Appends the encoded form of `item` to `dst`. Always returns `Ok(())`.
    fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
        match item {
            FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
            FrontendMessage::CopyData(data) => data.write(dst),
        }
        Ok(())
    }
}

impl Decoder for PostgresCodec {
    type Item = BackendMessage;
    type Error = io::Error;

    /// Decodes the complete backend messages currently buffered at the front of `src`.
    ///
    /// Returns:
    /// - `Ok(Some(BackendMessage::Async(_)))` when the first buffered message is a notice,
    ///   notification, or parameter-status message. Only that one message is consumed.
    /// - `Ok(Some(BackendMessage::Normal { .. }))` when the run contains at least one other
    ///   complete message. The run stops before the first incomplete or asynchronous message, or
    ///   just after the first `ReadyForQuery` (which sets `request_complete`). The run's bytes are
    ///   split off `src`.
    /// - `Ok(None)` when the first message is not yet fully buffered.
    ///
    /// # Errors
    ///
    /// Propagates any error from parsing a message header or an asynchronous message.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
        let mut idx = 0;
        let mut request_complete = false;

        while let Some(header) = backend::Header::parse(&src[idx..])? {
            // The length field counts itself but not the one-byte message tag.
            let len = header.len() as usize + 1;
            if src[idx..].len() < len {
                // The rest of this message has not arrived yet.
                break;
            }

            let is_async = matches!(
                header.tag(),
                backend::NOTICE_RESPONSE_TAG
                    | backend::NOTIFICATION_RESPONSE_TAG
                    | backend::PARAMETER_STATUS_TAG
            );
            if is_async {
                if idx != 0 {
                    // Emit the run gathered so far; the async message stays at the front of
                    // `src` and is returned by a later call.
                    break;
                }
                let message = backend::Message::parse(src)?
                    .expect("header parsed and full message length verified as buffered");
                return Ok(Some(BackendMessage::Async(message)));
            }

            idx += len;

            if header.tag() == backend::READY_FOR_QUERY_TAG {
                request_complete = true;
                break;
            }
        }

        if idx == 0 {
            Ok(None)
        } else {
            Ok(Some(BackendMessage::Normal {
                messages: BackendMessages(src.split_to(idx)),
                request_complete,
            }))
        }
    }
}