//! Utilities for working with the PostgreSQL binary copy format.

use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{ready, SinkExt, Stream};
use pin_project_lite::pin_project;
use postgres_types::BorrowToSql;
use std::convert::TryFrom;
use std::io;
use std::io::Cursor;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;

pin_project! {
    /// A type which serializes rows into the PostgreSQL binary copy format.
    ///
    /// The copy *must* be explicitly completed via the `finish` method. If it is
    /// not, the copy will be aborted.
    pub struct BinaryCopyInWriter {
        #[pin]
        sink: CopyInSink<Bytes>,
        types: Vec<Type>,
        buf: BytesMut,
    }
}

impl BinaryCopyInWriter {
    /// Creates a new writer which will write rows of the provided types to the provided sink.
    pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
        let mut buf = BytesMut::new();
        buf.put_slice(MAGIC);
        buf.put_i32(0); // flags
        buf.put_i32(0); // header extension

        BinaryCopyInWriter {
            sink,
            types: types.to_vec(),
            buf,
        }
    }

    /// Writes a single row.
    ///
    /// # Panics
    ///
    /// Panics if the number of values provided does not match the number expected.
    pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
        self.write_raw(slice_iter(values)).await
    }

    /// A maximally-flexible version of `write`.
    ///
    /// # Panics
    ///
    /// Panics if the number of values provided does not match the number expected.
    pub async fn write_raw<P, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
    where
        P: BorrowToSql,
        I: IntoIterator<Item = P>,
        I::IntoIter: ExactSizeIterator,
    {
        let mut this = self.project();

        let values = values.into_iter();
        assert!(
            values.len() == this.types.len(),
            "expected {} values but got {}",
            this.types.len(),
            values.len(),
        );

        this.buf.put_i16(this.types.len() as i16);

        for (i, (value, type_)) in values.zip(this.types.iter()).enumerate() {
            let idx = this.buf.len();
            this.buf.put_i32(0);
            let len = match value
                .borrow_to_sql()
                .to_sql_checked(type_, this.buf)
                .map_err(|e| Error::to_sql(e, i))?
            {
                IsNull::Yes => -1,
                IsNull::No => i32::try_from(this.buf.len() - idx - 4)
                    .map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            };
            BigEndian::write_i32(&mut this.buf[idx..], len);
        }

        if this.buf.len() > 4096 {
            this.sink.send(this.buf.split().freeze()).await?;
        }

        Ok(())
    }

    /// Completes the copy, returning the number of rows added.
    ///
    /// This method *must* be used to complete the copy process. If it is not, the
    /// copy will be aborted.
    pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
        let mut this = self.project();

        this.buf.put_i16(-1);
        this.sink.send(this.buf.split().freeze()).await?;
        this.sink.finish().await
    }
}
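
// A brief usage sketch for `BinaryCopyInWriter`. It assumes a connected `tokio_postgres`
// `Client` and a hypothetical `people (id INT4, name TEXT)` table; the statement text,
// column types, and values below are illustrative, not part of this module's API.
#[allow(dead_code)]
async fn copy_in_example(client: &crate::Client) -> Result<u64, Error> {
    use futures::pin_mut;

    // Start a binary COPY IN and wrap the raw sink with the column types it will receive.
    let sink = client
        .copy_in("COPY people (id, name) FROM STDIN BINARY")
        .await?;
    let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]);
    pin_mut!(writer);

    // Each `write` call appends one row; the values must match the declared types in order.
    writer.as_mut().write(&[&1i32, &"alice"]).await?;
    writer.as_mut().write(&[&2i32, &"bob"]).await?;

    // `finish` sends the trailer and completes the copy, returning the number of rows added.
    writer.finish().await
}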

struct Header {
    has_oids: bool,
}

pin_project! {
    /// A stream of rows deserialized from the PostgreSQL binary copy format.
    pub struct BinaryCopyOutStream {
        #[pin]
        stream: CopyOutStream,
        types: Arc<Vec<Type>>,
        header: Option<Header>,
    }
}

impl BinaryCopyOutStream {
    /// Creates a stream from a raw copy out stream and the types of the columns being returned.
    pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
        BinaryCopyOutStream {
            stream,
            types: Arc::new(types.to_vec()),
            header: None,
        }
    }
}

impl Stream for BinaryCopyOutStream {
    type Item = Result<BinaryCopyOutRow, Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        let chunk = match ready!(this.stream.poll_next(cx)) {
            Some(Ok(chunk)) => chunk,
            Some(Err(e)) => return Poll::Ready(Some(Err(e))),
            None => return Poll::Ready(Some(Err(Error::closed()))),
        };
        let mut chunk = Cursor::new(chunk);

        let has_oids = match &this.header {
            Some(header) => header.has_oids,
            None => {
                check_remaining(&chunk, HEADER_LEN)?;
                if !chunk.chunk().starts_with(MAGIC) {
                    return Poll::Ready(Some(Err(Error::parse(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "invalid magic value",
                    )))));
                }
                chunk.advance(MAGIC.len());

                let flags = chunk.get_i32();
                let has_oids = (flags & (1 << 16)) != 0;

                let header_extension = chunk.get_u32() as usize;
                check_remaining(&chunk, header_extension)?;
                chunk.advance(header_extension);

                *this.header = Some(Header { has_oids });
                has_oids
            }
        };

        check_remaining(&chunk, 2)?;
        let mut len = chunk.get_i16();
        if len == -1 {
            return Poll::Ready(None);
        }

        if has_oids {
            len += 1;
        }
        if len as usize != this.types.len() {
            return Poll::Ready(Some(Err(Error::parse(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("expected {} values but got {}", this.types.len(), len),
            )))));
        }

        let mut ranges = vec![];
        for _ in 0..len {
            check_remaining(&chunk, 4)?;
            let len = chunk.get_i32();
            if len == -1 {
                ranges.push(None);
            } else {
                let len = len as usize;
                check_remaining(&chunk, len)?;
                let start = chunk.position() as usize;
                ranges.push(Some(start..start + len));
                chunk.advance(len);
            }
        }

        Poll::Ready(Some(Ok(BinaryCopyOutRow {
            buf: chunk.into_inner(),
            ranges,
            types: this.types.clone(),
        })))
    }
}

fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
    if buf.remaining() < len {
        Err(Error::parse(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "unexpected EOF",
        )))
    } else {
        Ok(())
    }
}

/// A row of data parsed from a binary copy out stream.
pub struct BinaryCopyOutRow {
    buf: Bytes,
    ranges: Vec<Option<Range<usize>>>,
    types: Arc<Vec<Type>>,
}

impl BinaryCopyOutRow {
    /// Like `get`, but returns a `Result` rather than panicking.
    pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
    where
        T: FromSql<'a>,
    {
        let type_ = match self.types.get(idx) {
            Some(type_) => type_,
            None => return Err(Error::column(idx.to_string())),
        };

        if !T::accepts(type_) {
            return Err(Error::from_sql(
                Box::new(WrongType::new::<T>(type_.clone())),
                idx,
            ));
        }

        let r = match &self.ranges[idx] {
            Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
            None => T::from_sql_null(type_),
        };

        r.map_err(|e| Error::from_sql(e, idx))
    }

    /// Deserializes a value from the row.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
    pub fn get<'a, T>(&'a self, idx: usize) -> T
    where
        T: FromSql<'a>,
    {
        match self.try_get(idx) {
            Ok(value) => value,
            Err(e) => panic!("error retrieving column {}: {}", idx, e),
        }
    }
}
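
// A brief usage sketch for `BinaryCopyOutStream` and `BinaryCopyOutRow`. It assumes a
// connected `tokio_postgres` `Client` and the same hypothetical `people (id INT4, name TEXT)`
// table as above; the statement text and column layout are illustrative assumptions.
#[allow(dead_code)]
async fn copy_out_example(client: &crate::Client) -> Result<Vec<(i32, String)>, Error> {
    use futures::{pin_mut, TryStreamExt};

    // Start a binary COPY OUT and wrap the raw stream with the expected column types.
    let stream = client
        .copy_out("COPY people (id, name) TO STDOUT BINARY")
        .await?;
    let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT]);
    pin_mut!(rows);

    // Each item is a `BinaryCopyOutRow`; values are deserialized by column index.
    let mut out = Vec::new();
    while let Some(row) = rows.try_next().await? {
        out.push((row.get(0), row.get(1)));
    }
    Ok(out)
}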