From 626b4e9987ef482604342f5a71016dbf8b31dbaa Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Wed, 17 Mar 2021 18:42:40 +0300 Subject: [PATCH] basic support of postgres backend protocol --- .gitignore | 11 +- Cargo.lock | 1 + Cargo.toml | 1 + src/main.rs | 21 +++- src/page_cache.rs | 2 +- src/page_service.rs | 285 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 304 insertions(+), 17 deletions(-) create mode 100644 src/page_service.rs diff --git a/.gitignore b/.gitignore index 088ba6ba7d..ea8c4bf7f3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk +/target diff --git a/Cargo.lock b/Cargo.lock index 72c00cfc1e..c3f688abb2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -339,6 +339,7 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" name = "pageserver" version = "0.1.0" dependencies = [ + "byteorder", "bytes", "lazy_static", "postgres-protocol", diff --git a/Cargo.toml b/Cargo.toml index cff99beecc..4c66c14caa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ edition = "2018" [dependencies] bytes = "1.0.1" +byteorder = "1.4.3" lazy_static = "1.4.0" tokio = { version = "1.3.0", features = ["full"] } tokio-stream = { version = "0.1.4" } diff --git a/src/main.rs b/src/main.rs index cca3d1d026..187bdcee4d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,24 +3,33 @@ use std::thread; mod page_cache; mod waldecoder; mod walreceiver; +mod page_service; use std::io::Error; -#[macro_use] -extern crate lazy_static; - fn main() -> Result<(), Error> { - + let mut threads = Vec::new(); // Launch the WAL receiver thread. It will try to connect to the WAL safekeeper, // and stream the WAL. If the connection is lost, it will reconnect on its own. // We just fire and forget it here. - let handler = thread::spawn(|| { + let walreceiver_thread = thread::spawn(|| { // thread code walreceiver::thread_main(); }); + threads.push(walreceiver_thread); - let _unused = handler.join(); // never returns. + let page_server_thread = thread::spawn(|| { + // thread code + page_service::thread_main(); + }); + threads.push(page_server_thread); + + + // never returns. + for t in threads { + t.join().unwrap() + } Ok(()) } diff --git a/src/page_cache.rs b/src/page_cache.rs index fad62156c4..6329505e67 100644 --- a/src/page_cache.rs +++ b/src/page_cache.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; use bytes::Bytes; +use lazy_static::lazy_static; #[derive(PartialEq, Eq, PartialOrd, Ord)] pub struct BufferTag { @@ -14,7 +15,6 @@ pub struct BufferTag { pub struct CacheKey { pub tag: BufferTag, pub lsn: u64 - } pub struct WALRecord { diff --git a/src/page_service.rs b/src/page_service.rs new file mode 100644 index 0000000000..886b9e6cfb --- /dev/null +++ b/src/page_service.rs @@ -0,0 +1,285 @@ +use tokio::net::{TcpListener, TcpStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; +use tokio::runtime; +use tokio::task; +use byteorder::{BigEndian, ByteOrder}; +use bytes::{Buf, Bytes, BytesMut}; +use std::io::{self}; + +type Result = std::result::Result; + +/// +/// Basic support for postgres backend protocol. +/// + +enum FeMessage { + StartupMessage(FeStartupMessage), + Query(FeQueryMessage), + Terminate +} + +enum BeMessage { + AuthenticationOk, + ReadyForQuery, + RowDescription, + DataRow, + CommandComplete +} + +#[derive(Debug)] +struct FeStartupMessage { + version: u32, + body: Bytes +} + +impl FeStartupMessage { + pub fn parse(buf: &mut BytesMut) -> Result> { + const MAX_STARTUP_PACKET_LENGTH: u32 = 10000; + + if buf.len() < 4 { + return Ok(None); + } + let len = BigEndian::read_u32(&buf[0..4]); + + if len < 4 || len as u32 > MAX_STARTUP_PACKET_LENGTH { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length", + )); + } + + let version = BigEndian::read_u32(&buf[4..8]); + buf.advance(len as usize); + + Ok(Some(FeMessage::StartupMessage(FeStartupMessage{ version, body: Bytes::new() }))) + } +} + +#[derive(Debug)] +struct Buffer { + bytes: Bytes, + idx: usize, +} + +#[derive(Debug)] +struct FeQueryMessage { + body: Buffer +} + +impl FeMessage { + pub fn parse(buf: &mut BytesMut) -> Result> { + if buf.len() < 5 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + let tag = buf[0]; + let len = BigEndian::read_u32(&buf[1..5]); + + if len < 4 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parsing u32", + )); + } + + let total_len = len as usize + 1; + if buf.len() < total_len { + let to_read = total_len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + let buf = Buffer { + bytes: buf.split_to(total_len).freeze(), + idx: 5, + }; + + match tag { + b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage{body: buf}))), + b'X' => Ok(Some(FeMessage::Terminate)), + tag => { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown message tag: {}", tag), + )) + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////// + +pub fn thread_main() { + + let runtime = runtime::Runtime::new().unwrap(); + + let listen_address = "127.0.0.1:5430"; + println!("Starting page server on {}", listen_address); + + runtime.block_on(async { + let _unused = page_service_main(listen_address).await; + }); +} + +async fn page_service_main(listen_address: &str) { + + let listener = TcpListener::bind(listen_address).await.unwrap(); + + loop { + let (socket, _) = listener.accept().await.unwrap(); + + let mut conn_handler = Connection::new(socket); + + task::spawn(async move { + if let Err(err) = conn_handler.run().await { + println!("error: {}", err); + } + }); + } +} + +#[derive(Debug)] +struct Connection { + stream: BufWriter, + buffer: BytesMut, + init_done: bool +} + +impl Connection { + + pub fn new(socket: TcpStream) -> Connection { + Connection { + stream: BufWriter::new(socket), + buffer: BytesMut::with_capacity(10 * 1024), + init_done: false + } + } + + // + // Read full message or return None if connection is closed + // + async fn read_message(&mut self) -> Result> { + loop { + if let Some(message) = self.parse_message()? { + return Ok(Some(message)); + } + + if self.stream.read_buf(&mut self.buffer).await? == 0 { + if self.buffer.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new(io::ErrorKind::Other,"connection reset by peer")); + } + } + } + } + + fn parse_message(&mut self) -> Result> { + + if !self.init_done { + FeStartupMessage::parse(&mut self.buffer) + } else { + FeMessage::parse(&mut self.buffer) + } + + } + + async fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<()> { + + match message { + BeMessage::AuthenticationOk => { + self.stream.write_u8(b'R').await?; + self.stream.write_i32(4 + 4).await?; + self.stream.write_i32(0).await?; + } + + BeMessage::ReadyForQuery => { + self.stream.write_u8(b'Z').await?; + self.stream.write_i32(4 + 1).await?; + self.stream.write_u8(b'I').await?; + } + + BeMessage::RowDescription => { + // XXX + let mut b = Bytes::from("data\0"); + + self.stream.write_u8(b'T').await?; + self.stream.write_i32(4 + 2 + b.len() as i32 + 3*(4 + 2)).await?; + + self.stream.write_i16(1).await?; + self.stream.write_buf(&mut b).await?; + self.stream.write_i32(0).await?; /* table oid */ + self.stream.write_i16(0).await?; /* attnum */ + self.stream.write_i32(25).await?; /* TEXTOID */ + self.stream.write_i16(-1).await?; /* typlen */ + self.stream.write_i32(0).await?; /* typmod */ + self.stream.write_i16(0).await?; /* format code */ + } + + // XXX: accept some text data + BeMessage::DataRow => { + // XXX + let mut b = Bytes::from("hello world"); + + self.stream.write_u8(b'D').await?; + self.stream.write_i32(4 + 2 + 4 + b.len() as i32).await?; + + self.stream.write_i16(1).await?; + self.stream.write_i32(b.len() as i32).await?; + self.stream.write_buf(&mut b).await?; + } + + BeMessage::CommandComplete => { + let mut b = Bytes::from("SELECT 1\0"); + + self.stream.write_u8(b'C').await?; + self.stream.write_i32(4 + b.len() as i32).await?; + self.stream.write_buf(&mut b).await?; + } + } + + Ok(()) + } + + async fn write_message(&mut self, message: &BeMessage) -> io::Result<()> { + self.write_message_noflush(message).await?; + self.stream.flush().await + } + + async fn run(&mut self) -> Result<()> { + + loop { + let message = self.read_message().await?; + + match message { + // TODO: support ssl request + Some(FeMessage::StartupMessage(m)) => { + println!("got message {:?}", m); + self.write_message_noflush(&BeMessage::AuthenticationOk).await?; + self.write_message(&BeMessage::ReadyForQuery).await?; + self.init_done = true; + }, + Some(FeMessage::Query(m)) => { + self.write_message_noflush(&BeMessage::RowDescription).await?; + self.write_message_noflush(&BeMessage::DataRow).await?; + self.write_message_noflush(&BeMessage::CommandComplete).await?; + self.write_message(&BeMessage::ReadyForQuery).await?; + }, + Some(FeMessage::Terminate) => { + break; + } + None => { + println!("connection closed"); + break; + } + } + } + + Ok(()) + } + +} + +