Files
lancedb/rust/ffi/node/src/arrow.rs
gsilvestrin 6036cf48a7 fix(node) Replace panic errors with friendlier ones (#366)
- Implement Result/Error in the node FFI
- Implement a trait (ResultExt) to make error handling less verbose
- Refactor some parts of the code that touch arrow into arrow.rs
2023-07-26 13:44:58 -07:00

86 lines
2.9 KiB
Rust

// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::Cursor;
use std::ops::Deref;
use std::sync::Arc;

use arrow_array::cast::as_list_array;
use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch};
use arrow_ipc::reader::FileReader;
use arrow_ipc::writer::FileWriter;
use arrow_schema::{ArrowError, DataType, Field, Schema};
use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
use snafu::prelude::*;
use vectordb::table::VECTOR_COLUMN_NAME;

use crate::error::{MissingColumnSnafu, Result};
pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> Result<RecordBatch> {
let column = get_column(VECTOR_COLUMN_NAME, &record_batch)?;
// TODO: we should just consume the underlying js buffer in the future instead of this arrow around a bunch of times
let arr = as_list_array(column.as_ref());
let list_size = arr.values().len() / record_batch.num_rows();
let r = FixedSizeListArray::try_new_from_values(arr.values().to_owned(), list_size as i32)?;
let schema = Arc::new(Schema::new(vec![Field::new(
VECTOR_COLUMN_NAME,
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
list_size as i32,
),
true,
)]));
let mut new_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(r)])?;
if record_batch.num_columns() > 1 {
let rb = record_batch.drop_column(VECTOR_COLUMN_NAME)?;
new_batch = new_batch.merge(&rb)?;
}
Ok(new_batch)
}
fn get_column(column_name: &str, record_batch: &RecordBatch) -> Result<ArrayRef> {
record_batch
.column_by_name(column_name)
.cloned()
.context(MissingColumnSnafu { name: column_name })
}
/// Decodes an Arrow IPC file held in `slice` and runs every batch through
/// [`convert_record_batch`].
///
/// Batches are read and converted in file order. Processing stops at the
/// first failure.
///
/// # Errors
///
/// Returns an error if `slice` cannot be opened as an IPC file, if a batch
/// fails to decode, or if converting a batch fails.
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<Vec<RecordBatch>> {
    let reader = FileReader::try_new(Cursor::new(slice), None)?;
    reader
        .map(|decoded| convert_record_batch(decoded?))
        .collect()
}
/// Serializes `batches` into an in-memory Arrow IPC file.
///
/// The file is written with the schema of the first batch; the remaining
/// batches are assumed to share it. An empty input yields an empty buffer
/// rather than an IPC file with no batches.
///
/// # Errors
///
/// Returns an error if the IPC writer cannot be created, if any batch fails
/// to write, or if finishing or unwrapping the writer fails.
pub(crate) fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
    let first = match batches.first() {
        Some(batch) => batch,
        None => return Ok(Vec::new()),
    };
    let schema = first.schema();
    let mut writer = FileWriter::try_new(Vec::new(), schema.deref())?;
    batches.iter().try_for_each(|batch| writer.write(batch))?;
    writer.finish()?;
    let buffer = writer.into_inner()?;
    Ok(buffer)
}