mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
feat: expose table trait (#2097)
Similar to
c269524b2f
this PR reworks another internal trait (this time
`TableInternal`) into a public trait, now named `BaseTable`. Together,
these two PRs should make it possible for others to integrate LanceDB
on top of other catalogs.
This PR also adds a basic `TableProvider` implementation for tables,
although some work still needs to be done here (filter pushdown is not
yet enabled).
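To illustrate what the newly public surface enables, here is a minimal, hypothetical sketch of plugging a LanceDB table into a DataFusion SessionContext via the new `BaseTableAdapter`. The database path, table name, and `anyhow` error handling are assumptions for this example, not part of the PR:

    use std::sync::Arc;

    use datafusion::prelude::SessionContext;
    use lancedb::table::datafusion::BaseTableAdapter;

    async fn register_with_datafusion() -> anyhow::Result<()> {
        // Hypothetical database path and table name, for illustration only.
        let db = lancedb::connect("data/sample-lancedb").execute().await?;
        let table = db.open_table("my_table").execute().await?;

        // `Table::base_table` (added in this PR) exposes the `Arc<dyn BaseTable>`
        // handle, which the adapter wraps as a DataFusion `TableProvider`.
        let provider = Arc::new(BaseTableAdapter::try_new(table.base_table().clone()).await?);

        let ctx = SessionContext::new();
        ctx.register_table("my_table", provider)?;
        // Filters are not pushed down yet (see the note above);
        // projections and limits are.
        ctx.sql("SELECT * FROM my_table LIMIT 5").await?.show().await?;
        Ok(())
    }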
Cargo.lock (generated, 3 changes)
@@ -4135,7 +4135,10 @@ dependencies = [
  "candle-transformers",
  "chrono",
  "crunchy",
+ "datafusion-catalog",
  "datafusion-common",
+ "datafusion-execution",
+ "datafusion-expr",
  "datafusion-physical-plan",
  "futures",
  "half",
@@ -42,7 +42,10 @@ arrow-arith = "53.2"
 arrow-cast = "53.2"
 async-trait = "0"
 chrono = "0.4.35"
-datafusion-common = "44.0"
+datafusion-catalog = "44.0"
+datafusion-common = { version = "44.0", default-features = false }
+datafusion-execution = "44.0"
+datafusion-expr = "44.0"
 datafusion-physical-plan = "44.0"
 env_logger = "0.10"
 half = { "version" = "=2.4.1", default-features = false, features = [
@@ -7,8 +7,7 @@ use arrow::pyarrow::FromPyArrow;
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::QueryExecutionOptions;
 use lancedb::query::{
-    ExecutableQuery, HasQuery, Query as LanceDbQuery, QueryBase, Select,
-    VectorQuery as LanceDbVectorQuery,
+    ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
 };
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::{PyAnyMethods, PyDictMethods};
@@ -313,7 +312,8 @@ impl VectorQuery {
     }

     pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<HybridQuery> {
-        let fts_query = Query::new(self.inner.mut_query().clone()).nearest_to_text(query)?;
+        let base_query = self.inner.clone().into_plain();
+        let fts_query = Query::new(base_query).nearest_to_text(query)?;
         Ok(HybridQuery {
             inner_vec: self.clone(),
             inner_fts: fts_query,
@@ -411,10 +411,14 @@ impl HybridQuery {
     }

     pub fn get_limit(&mut self) -> Option<u32> {
-        self.inner_fts.inner.limit.map(|i| i as u32)
+        self.inner_fts
+            .inner
+            .current_request()
+            .limit
+            .map(|i| i as u32)
     }

     pub fn get_with_row_id(&mut self) -> bool {
-        self.inner_fts.inner.with_row_id
+        self.inner_fts.inner.current_request().with_row_id
     }
 }
@@ -19,7 +19,10 @@ arrow-ord = { workspace = true }
 arrow-cast = { workspace = true }
 arrow-ipc.workspace = true
 chrono = { workspace = true }
+datafusion-catalog.workspace = true
 datafusion-common.workspace = true
+datafusion-execution.workspace = true
+datafusion-expr.workspace = true
 datafusion-physical-plan.workspace = true
 object_store = { workspace = true }
 snafu = { workspace = true }
@@ -33,7 +36,7 @@ lance-table = { workspace = true }
 lance-linalg = { workspace = true }
 lance-testing = { workspace = true }
 lance-encoding = { workspace = true }
-moka = { workspace = true}
+moka = { workspace = true }
 pin-project = { workspace = true }
 tokio = { version = "1.23", features = ["rt-multi-thread"] }
 log.workspace = true
@@ -82,7 +85,7 @@ aws-sdk-s3 = { version = "1.38.0" }
 aws-sdk-kms = { version = "1.37" }
 aws-config = { version = "1.0" }
 aws-smithy-runtime = { version = "1.3" }
-http-body = "1" # Matching reqwest
+http-body = "1" # Matching reqwest


 [features]
@@ -98,7 +101,7 @@ sentence-transformers = [
     "dep:candle-core",
     "dep:candle-transformers",
     "dep:candle-nn",
-    "dep:tokenizers"
+    "dep:tokenizers",
 ]

 # TLS
@@ -21,7 +21,7 @@ use arrow_array::RecordBatchReader;
 use lance::dataset::ReadParams;

 use crate::error::Result;
-use crate::table::{TableDefinition, TableInternal, WriteOptions};
+use crate::table::{BaseTable, TableDefinition, WriteOptions};

 pub mod listing;
@@ -120,9 +120,9 @@ pub trait Database:
     /// List the names of tables in the database
     async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>>;
     /// Create a table in the database
-    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn TableInternal>>;
+    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>>;
     /// Open a table in the database
-    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn TableInternal>>;
+    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
     /// Rename a table in the database
     async fn rename_table(&self, old_name: &str, new_name: &str) -> Result<()>;
     /// Drop a table in the database
@@ -22,8 +22,8 @@ use crate::table::NativeTable;
 use crate::utils::validate_table_name;

 use super::{
-    CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
-    OpenTableRequest, TableInternal, TableNamesRequest,
+    BaseTable, CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
+    OpenTableRequest, TableNamesRequest,
 };

 /// File extension to indicate a lance table
@@ -356,10 +356,7 @@ impl Database for ListingDatabase {
         Ok(f)
     }

-    async fn create_table(
-        &self,
-        mut request: CreateTableRequest,
-    ) -> Result<Arc<dyn TableInternal>> {
+    async fn create_table(&self, mut request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
         let table_uri = self.table_uri(&request.name)?;
         // Inherit storage options from the connection
         let storage_options = request
@@ -452,7 +449,7 @@ impl Database for ListingDatabase {
         }
     }

-    async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn TableInternal>> {
+    async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
         let table_uri = self.table_uri(&request.name)?;

         // Inherit storage options from the connection
@@ -8,7 +8,7 @@ use serde::Deserialize;
 use serde_with::skip_serializing_none;
 use vector::IvfFlatIndexBuilder;

-use crate::{table::TableInternal, DistanceType, Error, Result};
+use crate::{table::BaseTable, DistanceType, Error, Result};

 use self::{
     scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
@@ -65,14 +65,14 @@ pub enum Index {
 ///
 /// The methods on this builder are used to specify options common to all indices.
 pub struct IndexBuilder {
-    parent: Arc<dyn TableInternal>,
+    parent: Arc<dyn BaseTable>,
     pub(crate) index: Index,
     pub(crate) columns: Vec<String>,
     pub(crate) replace: bool,
 }

 impl IndexBuilder {
-    pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: Vec<String>, index: Index) -> Self {
+    pub(crate) fn new(parent: Arc<dyn BaseTable>, columns: Vec<String>, index: Index) -> Self {
         Self {
             parent,
             index,
@@ -20,12 +20,12 @@ use lance_index::scalar::FullTextSearchQuery;
 use lance_index::vector::DIST_COL;
 use lance_io::stream::RecordBatchStreamAdapter;

-use crate::arrow::SendableRecordBatchStream;
 use crate::error::{Error, Result};
 use crate::rerankers::rrf::RRFReranker;
 use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
-use crate::table::TableInternal;
+use crate::table::BaseTable;
 use crate::DistanceType;
+use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};

 mod hybrid;
@@ -449,7 +449,7 @@ pub trait QueryBase {
 }

 pub trait HasQuery {
-    fn mut_query(&mut self) -> &mut Query;
+    fn mut_query(&mut self) -> &mut QueryRequest;
 }

 impl<T: HasQuery> QueryBase for T {
@@ -577,6 +577,65 @@ pub trait ExecutableQuery {
     fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
 }

+/// A basic query into a table without any kind of search
+///
+/// This will result in a (potentially filtered) scan if executed
+#[derive(Debug, Clone)]
+pub struct QueryRequest {
+    /// limit the number of rows to return.
+    pub limit: Option<usize>,
+
+    /// Offset of the query.
+    pub offset: Option<usize>,
+
+    /// Apply filter to the returned rows.
+    pub filter: Option<String>,
+
+    /// Perform a full text search on the table.
+    pub full_text_search: Option<FullTextSearchQuery>,
+
+    /// Select column projection.
+    pub select: Select,
+
+    /// If set to true, the query is executed only on the indexed data,
+    /// and yields faster results.
+    ///
+    /// By default, this is false.
+    pub fast_search: bool,
+
+    /// If set to true, the query will return the `_rowid` meta column.
+    ///
+    /// By default, this is false.
+    pub with_row_id: bool,
+
+    /// If set to false, the filter will be applied after the vector search.
+    pub prefilter: bool,
+
+    /// Implementation of reranker that can be used to reorder or combine query
+    /// results, especially if using hybrid search
+    pub reranker: Option<Arc<dyn Reranker>>,
+
+    /// Configure how query results are normalized when doing hybrid search
+    pub norm: Option<NormalizeMethod>,
+}
+
+impl Default for QueryRequest {
+    fn default() -> Self {
+        Self {
+            limit: Some(DEFAULT_TOP_K),
+            offset: None,
+            filter: None,
+            full_text_search: None,
+            select: Select::All,
+            fast_search: false,
+            with_row_id: false,
+            prefilter: true,
+            reranker: None,
+            norm: None,
+        }
+    }
+}
+
 /// A builder for LanceDB queries.
 ///
 /// See [`crate::Table::query`] for more details on queries
@@ -591,59 +650,15 @@
 /// times.
 #[derive(Debug, Clone)]
 pub struct Query {
-    parent: Arc<dyn TableInternal>,
-
-    /// limit the number of rows to return.
-    pub limit: Option<usize>,
-
-    /// Offset of the query.
-    pub(crate) offset: Option<usize>,
-
-    /// Apply filter to the returned rows.
-    pub(crate) filter: Option<String>,
-
-    /// Perform a full text search on the table.
-    pub(crate) full_text_search: Option<FullTextSearchQuery>,
-
-    /// Select column projection.
-    pub(crate) select: Select,
-
-    /// If set to true, the query is executed only on the indexed data,
-    /// and yields faster results.
-    ///
-    /// By default, this is false.
-    pub(crate) fast_search: bool,
-
-    /// If set to true, the query will return the `_rowid` meta column.
-    ///
-    /// By default, this is false.
-    pub with_row_id: bool,
-
-    /// If set to false, the filter will be applied after the vector search.
-    pub(crate) prefilter: bool,
-
-    /// Implementation of reranker that can be used to reorder or combine query
-    /// results, especially if using hybrid search
-    pub(crate) reranker: Option<Arc<dyn Reranker>>,
-
-    /// Configure how query results are normalized when doing hybrid search
-    pub(crate) norm: Option<NormalizeMethod>,
+    parent: Arc<dyn BaseTable>,
+    request: QueryRequest,
 }

 impl Query {
-    pub(crate) fn new(parent: Arc<dyn TableInternal>) -> Self {
+    pub(crate) fn new(parent: Arc<dyn BaseTable>) -> Self {
         Self {
             parent,
-            limit: Some(DEFAULT_TOP_K),
-            offset: None,
-            filter: None,
-            full_text_search: None,
-            select: Select::All,
-            fast_search: false,
-            with_row_id: false,
-            prefilter: true,
-            reranker: None,
-            norm: None,
+            request: QueryRequest::default(),
         }
     }
@@ -691,38 +706,98 @@ impl Query {
     pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result<VectorQuery> {
         let mut vector_query = self.into_vector();
         let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
-        vector_query.query_vector.push(query_vector);
+        vector_query.request.query_vector.push(query_vector);
         Ok(vector_query)
     }
+
+    pub fn into_request(self) -> QueryRequest {
+        self.request
+    }
+
+    pub fn current_request(&self) -> &QueryRequest {
+        &self.request
+    }
 }

 impl HasQuery for Query {
-    fn mut_query(&mut self) -> &mut Query {
-        self
+    fn mut_query(&mut self) -> &mut QueryRequest {
+        &mut self.request
     }
 }

 impl ExecutableQuery for Query {
     async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
-        self.parent
-            .clone()
-            .create_plan(&self.clone().into_vector(), options)
-            .await
+        let req = AnyQuery::Query(self.request.clone());
+        self.parent.clone().create_plan(&req, options).await
     }

     async fn execute_with_options(
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
+        let query = AnyQuery::Query(self.request.clone());
         Ok(SendableRecordBatchStream::from(
-            self.parent.clone().plain_query(self, options).await?,
+            self.parent.clone().query(&query, options).await?,
         ))
     }

     async fn explain_plan(&self, verbose: bool) -> Result<String> {
-        self.parent
-            .explain_plan(&self.clone().into_vector(), verbose)
-            .await
+        let query = AnyQuery::Query(self.request.clone());
+        self.parent.explain_plan(&query, verbose).await
     }
 }

+/// A request for a nearest-neighbors search into a table
+#[derive(Debug, Clone)]
+pub struct VectorQueryRequest {
+    /// The base query
+    pub base: QueryRequest,
+    /// The column to run the search on
+    ///
+    /// If None, then the table will need to auto-detect which column to use
+    pub column: Option<String>,
+    /// The vector(s) to search for
+    pub query_vector: Vec<Arc<dyn Array>>,
+    /// The number of partitions to search
+    pub nprobes: usize,
+    /// The lower bound (inclusive) of the distance to search for.
+    pub lower_bound: Option<f32>,
+    /// The upper bound (exclusive) of the distance to search for.
+    pub upper_bound: Option<f32>,
+    /// The number of candidates to return during the refine step for HNSW,
+    /// defaults to 1.5 * limit.
+    pub ef: Option<usize>,
+    /// A multiplier to control how many additional rows are taken during the refine step
+    pub refine_factor: Option<u32>,
+    /// The distance type to use for the search
+    pub distance_type: Option<DistanceType>,
+    /// Default is true. Set to false to enforce a brute force search.
+    pub use_index: bool,
+}
+
+impl Default for VectorQueryRequest {
+    fn default() -> Self {
+        Self {
+            base: QueryRequest::default(),
+            column: None,
+            query_vector: Vec::new(),
+            nprobes: 20,
+            lower_bound: None,
+            upper_bound: None,
+            ef: None,
+            refine_factor: None,
+            distance_type: None,
+            use_index: true,
+        }
+    }
+}
+
+impl VectorQueryRequest {
+    pub fn from_plain_query(query: QueryRequest) -> Self {
+        Self {
+            base: query,
+            ..Default::default()
+        }
+    }
+}
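A rough sketch of how this builder/request split looks from the caller's side. The table setup is omitted and the assertions simply mirror the defaults shown in the hunk above; this example is illustrative, not part of the PR:

    use lancedb::query::{QueryBase, VectorQueryRequest};

    // `table` is an opened lancedb::Table (setup omitted in this sketch).
    let query = table.query().limit(10).nearest_to(&[0.1_f32, 0.2])?;

    // Tear the builder down into the plain request struct that a custom
    // BaseTable implementation receives (wrapped in AnyQuery).
    let request: VectorQueryRequest = query.into_request();
    assert_eq!(request.base.limit, Some(10));
    assert_eq!(request.nprobes, 20); // default from VectorQueryRequest::default()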
@@ -737,39 +812,30 @@ impl ExecutableQuery for Query {
 /// the query and retrieve results.
 #[derive(Debug, Clone)]
 pub struct VectorQuery {
-    pub(crate) base: Query,
-    // The column to run the query on. If not specified, we will attempt to guess
-    // the column based on the dataset's schema.
-    pub(crate) column: Option<String>,
-    // IVF PQ - ANN search.
-    pub(crate) query_vector: Vec<Arc<dyn Array>>,
-    pub(crate) nprobes: usize,
-    // The lower bound (inclusive) of the distance to search for.
-    pub(crate) lower_bound: Option<f32>,
-    // The upper bound (exclusive) of the distance to search for.
-    pub(crate) upper_bound: Option<f32>,
-    // The number of candidates to return during the refine step for HNSW,
-    // defaults to 1.5 * limit.
-    pub(crate) ef: Option<usize>,
-    pub(crate) refine_factor: Option<u32>,
-    pub(crate) distance_type: Option<DistanceType>,
-    /// Default is true. Set to false to enforce a brute force search.
-    pub(crate) use_index: bool,
+    parent: Arc<dyn BaseTable>,
+    request: VectorQueryRequest,
 }

 impl VectorQuery {
     fn new(base: Query) -> Self {
         Self {
-            base,
-            column: None,
-            query_vector: Vec::new(),
-            nprobes: 20,
-            lower_bound: None,
-            upper_bound: None,
-            ef: None,
-            refine_factor: None,
-            distance_type: None,
-            use_index: true,
+            parent: base.parent,
+            request: VectorQueryRequest::from_plain_query(base.request),
         }
     }
+
+    pub fn into_request(self) -> VectorQueryRequest {
+        self.request
+    }
+
+    pub fn current_request(&self) -> &VectorQueryRequest {
+        &self.request
+    }
+
+    pub fn into_plain(self) -> Query {
+        Query {
+            parent: self.parent,
+            request: self.request.base,
+        }
+    }
@@ -781,7 +847,7 @@ impl VectorQuery {
     /// This parameter must be specified if the table has more than one column
     /// whose data type is a fixed-size-list of floats.
    pub fn column(mut self, column: &str) -> Self {
-        self.column = Some(column.to_string());
+        self.request.column = Some(column.to_string());
        self
    }
@@ -797,7 +863,7 @@
     /// result.
     pub fn add_query_vector(mut self, vector: impl IntoQueryVector) -> Result<Self> {
         let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
-        self.query_vector.push(query_vector);
+        self.request.query_vector.push(query_vector);
         Ok(self)
     }
@@ -822,15 +888,15 @@
     /// your actual data to find the smallest possible value that will still give
     /// you the desired recall.
     pub fn nprobes(mut self, nprobes: usize) -> Self {
-        self.nprobes = nprobes;
+        self.request.nprobes = nprobes;
         self
     }

     /// Set the distance range for vector search,
     /// only rows with distances in the range [lower_bound, upper_bound) will be returned
     pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
-        self.lower_bound = lower_bound;
-        self.upper_bound = upper_bound;
+        self.request.lower_bound = lower_bound;
+        self.request.upper_bound = upper_bound;
         self
     }
@@ -842,7 +908,7 @@
     /// Increasing this value will increase the recall of your query but will
     /// also increase the latency of your query. The default value is 1.5*limit.
     pub fn ef(mut self, ef: usize) -> Self {
-        self.ef = Some(ef);
+        self.request.ef = Some(ef);
         self
     }
@@ -874,7 +940,7 @@
     /// and the quantized result vectors. This can be considerably different than the true
     /// distance between the query vector and the actual uncompressed vector.
     pub fn refine_factor(mut self, refine_factor: u32) -> Self {
-        self.refine_factor = Some(refine_factor);
+        self.request.refine_factor = Some(refine_factor);
         self
     }
@@ -891,7 +957,7 @@
     ///
     /// By default [`DistanceType::L2`] is used.
     pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
-        self.distance_type = Some(distance_type);
+        self.request.distance_type = Some(distance_type);
         self
     }
@@ -903,16 +969,19 @@
     /// the vector index can give you ground truth results which you can use to
     /// calculate your recall to select an appropriate value for nprobes.
     pub fn bypass_vector_index(mut self) -> Self {
-        self.use_index = false;
+        self.request.use_index = false;
         self
     }

     pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
         // clone query and specify we want to include row IDs, which can be needed for reranking
-        let fts_query = self.base.clone().with_row_id();
+        let mut fts_query = Query::new(self.parent.clone());
+        fts_query.request = self.request.base.clone();
+        fts_query = fts_query.with_row_id();
+
         let mut vector_query = self.clone().with_row_id();

-        vector_query.base.full_text_search = None;
-        let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
+        vector_query.request.base.full_text_search = None;
+
+        let (fts_results, vec_results) = try_join!(
@@ -928,7 +997,7 @@
         let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
         let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;

-        if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
+        if matches!(self.request.base.norm, Some(NormalizeMethod::Rank)) {
             vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
             fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
         }
@@ -937,14 +1006,20 @@
         fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;

         let reranker = self
+            .request
             .base
             .reranker
             .clone()
             .unwrap_or(Arc::new(RRFReranker::default()));

-        let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
-            message: "there should be an FTS search".to_string(),
-        })?;
+        let fts_query = self
+            .request
+            .base
+            .full_text_search
+            .as_ref()
+            .ok_or(Error::Runtime {
+                message: "there should be an FTS search".to_string(),
+            })?;

         let mut results = reranker
             .rerank_hybrid(&fts_query.query, vec_results, fts_results)
@@ -952,12 +1027,12 @@

         check_reranker_result(&results)?;

-        let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
+        let limit = self.request.base.limit.unwrap_or(DEFAULT_TOP_K);
         if results.num_rows() > limit {
             results = results.slice(0, limit);
         }

-        if !self.base.with_row_id {
+        if !self.request.base.with_row_id {
             results = results.drop_column(ROW_ID)?;
         }
@@ -969,14 +1044,15 @@

 impl ExecutableQuery for VectorQuery {
     async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
-        self.base.parent.clone().create_plan(self, options).await
+        let query = AnyQuery::VectorQuery(self.request.clone());
+        self.parent.clone().create_plan(&query, options).await
     }

     async fn execute_with_options(
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
-        if self.base.full_text_search.is_some() {
+        if self.request.base.full_text_search.is_some() {
             let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
             return Ok(hybrid_result);
         }
@@ -990,13 +1066,14 @@
     }

     async fn explain_plan(&self, verbose: bool) -> Result<String> {
-        self.base.parent.explain_plan(self, verbose).await
+        let query = AnyQuery::VectorQuery(self.request.clone());
+        self.parent.explain_plan(&query, verbose).await
     }
 }

 impl HasQuery for VectorQuery {
-    fn mut_query(&mut self) -> &mut Query {
-        &mut self.base
+    fn mut_query(&mut self) -> &mut QueryRequest {
+        &mut self.request.base
     }
 }
@@ -1036,7 +1113,13 @@ mod tests {
         let vector = Float32Array::from_iter_values([0.1, 0.2]);
         let query = table.query().nearest_to(&[0.1, 0.2]).unwrap();
         assert_eq!(
-            *query.query_vector.first().unwrap().as_ref().as_primitive(),
+            *query
+                .request
+                .query_vector
+                .first()
+                .unwrap()
+                .as_ref()
+                .as_primitive(),
             vector
         );
@@ -1054,15 +1137,21 @@
             .refine_factor(999);

         assert_eq!(
-            *query.query_vector.first().unwrap().as_ref().as_primitive(),
+            *query
+                .request
+                .query_vector
+                .first()
+                .unwrap()
+                .as_ref()
+                .as_primitive(),
             new_vector
         );
-        assert_eq!(query.base.limit.unwrap(), 100);
-        assert_eq!(query.base.offset.unwrap(), 1);
-        assert_eq!(query.nprobes, 1000);
-        assert!(query.use_index);
-        assert_eq!(query.distance_type, Some(DistanceType::Cosine));
-        assert_eq!(query.refine_factor, Some(999));
+        assert_eq!(query.request.base.limit.unwrap(), 100);
+        assert_eq!(query.request.base.offset.unwrap(), 1);
+        assert_eq!(query.request.nprobes, 1000);
+        assert!(query.request.use_index);
+        assert_eq!(query.request.distance_type, Some(DistanceType::Cosine));
+        assert_eq!(query.request.refine_factor, Some(999));
     }

     #[tokio::test]
@@ -14,6 +14,7 @@ pub(crate) mod util;
 const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
 #[cfg(test)]
 const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
+#[cfg(test)]
 const JSON_CONTENT_TYPE: &str = "application/json";

 pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
@@ -18,7 +18,7 @@ use crate::database::{
     TableNamesRequest,
 };
 use crate::error::Result;
-use crate::table::TableInternal;
+use crate::table::BaseTable;

 use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
 use super::table::RemoteTable;
@@ -126,7 +126,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
         Ok(tables)
     }

-    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn TableInternal>> {
+    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
         let data = match request.data {
             CreateTableData::Data(data) => data,
             CreateTableData::Empty(table_definition) => {
@@ -198,7 +198,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
         )))
     }

-    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn TableInternal>> {
+    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
         // We describe the table to confirm it exists before moving on.
         if self.table_cache.get(&request.name).is_none() {
             let req = self
@@ -2,12 +2,13 @@
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

 use std::io::Cursor;
+use std::pin::Pin;
 use std::sync::{Arc, Mutex};

 use crate::index::Index;
 use crate::index::IndexStatistics;
-use crate::query::Select;
-use crate::table::AddDataMode;
+use crate::query::{QueryRequest, Select, VectorQueryRequest};
+use crate::table::{AddDataMode, AnyQuery, Filter};
 use crate::utils::{supported_btree_data_type, supported_vector_data_type};
 use crate::{DistanceType, Error, Table};
 use arrow_array::RecordBatchReader;
@@ -16,14 +17,14 @@ use arrow_schema::{DataType, SchemaRef};
 use async_trait::async_trait;
 use datafusion_common::DataFusionError;
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
-use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
+use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 use futures::TryStreamExt;
 use http::header::CONTENT_TYPE;
 use http::StatusCode;
 use lance::arrow::json::{JsonDataType, JsonSchema};
 use lance::dataset::scanner::DatasetRecordBatchStream;
 use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
-use lance_datafusion::exec::OneShotExec;
+use lance_datafusion::exec::{execute_plan, OneShotExec};
 use serde::{Deserialize, Serialize};
 use tokio::sync::RwLock;
@@ -31,16 +32,16 @@ use crate::{
     connection::NoData,
     error::Result,
     index::{IndexBuilder, IndexConfig},
-    query::{Query, QueryExecutionOptions, VectorQuery},
+    query::QueryExecutionOptions,
     table::{
-        merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
-        TableDefinition, TableInternal, UpdateBuilder,
+        merge::MergeInsertBuilder, AddDataBuilder, BaseTable, OptimizeAction, OptimizeStats,
+        TableDefinition, UpdateBuilder,
     },
 };

 use super::client::RequestResultExt;
 use super::client::{HttpSend, RestfulLanceDbClient, Sender};
-use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE};
+use super::ARROW_STREAM_CONTENT_TYPE;

 #[derive(Debug)]
 pub struct RemoteTable<S: HttpSend = Sender> {
@@ -147,7 +148,7 @@ impl<S: HttpSend> RemoteTable<S> {
         Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
     }

-    fn apply_query_params(body: &mut serde_json::Value, params: &Query) -> Result<()> {
+    fn apply_query_params(body: &mut serde_json::Value, params: &QueryRequest) -> Result<()> {
         if let Some(offset) = params.offset {
             body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset));
         }
@@ -205,7 +206,7 @@

     fn apply_vector_query_params(
         mut body: serde_json::Value,
-        query: &VectorQuery,
+        query: &VectorQueryRequest,
     ) -> Result<Vec<serde_json::Value>> {
         Self::apply_query_params(&mut body, &query.base)?;
@@ -288,6 +289,45 @@ impl<S: HttpSend> RemoteTable<S> {
         let read_guard = self.version.read().await;
         *read_guard
     }
+
+    async fn execute_query(
+        &self,
+        query: &AnyQuery,
+        _options: QueryExecutionOptions,
+    ) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
+        let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
+
+        let version = self.current_version().await;
+        let mut body = serde_json::json!({ "version": version });
+
+        match query {
+            AnyQuery::Query(query) => {
+                Self::apply_query_params(&mut body, query)?;
+                // Empty vector can be passed if no vector search is performed.
+                body["vector"] = serde_json::Value::Array(Vec::new());
+
+                let request = request.json(&body);
+
+                let (request_id, response) = self.client.send(request, true).await?;
+
+                let stream = self.read_arrow_stream(&request_id, response).await?;
+                Ok(vec![stream])
+            }
+            AnyQuery::VectorQuery(query) => {
+                let bodies = Self::apply_vector_query_params(body, query)?;
+                let mut futures = Vec::with_capacity(bodies.len());
+                for body in bodies {
+                    let request = request.try_clone().unwrap().json(&body);
+                    let future = async move {
+                        let (request_id, response) = self.client.send(request, true).await?;
+                        self.read_arrow_stream(&request_id, response).await
+                    };
+                    futures.push(future);
+                }
+                futures::future::try_join_all(futures).await
+            }
+        }
+    }
 }

 #[derive(Deserialize)]
@@ -325,13 +365,10 @@ mod test_utils {
 }

 #[async_trait]
-impl<S: HttpSend> TableInternal for RemoteTable<S> {
+impl<S: HttpSend> BaseTable for RemoteTable<S> {
     fn as_any(&self) -> &dyn std::any::Any {
         self
     }
-    fn as_native(&self) -> Option<&NativeTable> {
-        None
-    }
     fn name(&self) -> &str {
         &self.name
     }
@@ -398,7 +435,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
         let schema = self.describe().await?.schema;
         Ok(Arc::new(schema.try_into()?))
     }
-    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
+    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
         let mut request = self
             .client
             .post(&format!("/v1/table/{}/count_rows/", self.name));
@@ -406,6 +443,11 @@
         let version = self.current_version().await;

         if let Some(filter) = filter {
+            let Filter::Sql(filter) = filter else {
+                return Err(Error::NotSupported {
+                    message: "querying a remote table with a datafusion filter".to_string(),
+                });
+            };
             request = request.json(&serde_json::json!({ "predicate": filter, "version": version }));
         } else {
             let body = serde_json::json!({ "version": version });
@@ -453,25 +495,11 @@

     async fn create_plan(
         &self,
-        query: &VectorQuery,
-        _options: QueryExecutionOptions,
+        query: &AnyQuery,
+        options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
-
-        let version = self.current_version().await;
-        let body = serde_json::json!({ "version": version });
-        let bodies = Self::apply_vector_query_params(body, query)?;
-
-        let mut futures = Vec::with_capacity(bodies.len());
-        for body in bodies {
-            let request = request.try_clone().unwrap().json(&body);
-            let future = async move {
-                let (request_id, response) = self.client.send(request, true).await?;
-                self.read_arrow_stream(&request_id, response).await
-            };
-            futures.push(future);
-        }
-        let streams = futures::future::try_join_all(futures).await?;
+        let streams = self.execute_query(query, options).await?;
         if streams.len() == 1 {
             let stream = streams.into_iter().next().unwrap();
             Ok(Arc::new(OneShotExec::new(stream)))
@@ -484,29 +512,29 @@
         }
     }

-    async fn plain_query(
+    async fn query(
         &self,
-        query: &Query,
+        query: &AnyQuery,
         _options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        let request = self
-            .client
-            .post(&format!("/v1/table/{}/query/", self.name))
-            .header(CONTENT_TYPE, JSON_CONTENT_TYPE);
-
-        let version = self.current_version().await;
-        let mut body = serde_json::json!({ "version": version });
-        Self::apply_query_params(&mut body, query)?;
-        // Empty vector can be passed if no vector search is performed.
-        body["vector"] = serde_json::Value::Array(Vec::new());
-
-        let request = request.json(&body);
-
-        let (request_id, response) = self.client.send(request, true).await?;
-
-        let stream = self.read_arrow_stream(&request_id, response).await?;
-
-        Ok(DatasetRecordBatchStream::new(stream))
+        let streams = self.execute_query(query, _options).await?;
+
+        if streams.len() == 1 {
+            Ok(DatasetRecordBatchStream::new(
+                streams.into_iter().next().unwrap(),
+            ))
+        } else {
+            let stream_execs = streams
+                .into_iter()
+                .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
+                .collect();
+            let plan = Table::multi_vector_plan(stream_execs)?;
+
+            Ok(DatasetRecordBatchStream::new(execute_plan(
+                plan,
+                Default::default(),
+            )?))
+        }
     }
     async fn update(&self, update: UpdateBuilder) -> Result<u64> {
         self.check_mutable().await?;
@@ -891,6 +919,7 @@ mod tests {
     use reqwest::Body;

     use crate::index::vector::IvfFlatIndexBuilder;
+    use crate::remote::JSON_CONTENT_TYPE;
     use crate::{
         index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
         query::{ExecutableQuery, QueryBase},
@@ -12,6 +12,7 @@ use arrow::datatypes::{Float32Type, UInt8Type};
 use arrow_array::{RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
+use datafusion_expr::Expr;
 use datafusion_physical_plan::display::DisplayableExecutionPlan;
 use datafusion_physical_plan::projection::ProjectionExec;
 use datafusion_physical_plan::repartition::RepartitionExec;
@@ -21,12 +22,13 @@ use futures::{StreamExt, TryStreamExt};
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::cleanup::RemovalStats;
 use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
-use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
+use lance::dataset::scanner::Scanner;
 pub use lance::dataset::ColumnAlteration;
 pub use lance::dataset::NewColumnTransform;
 pub use lance::dataset::ReadParams;
+pub use lance::dataset::Version;
 use lance::dataset::{
-    Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, Version, WhenMatched, WriteMode,
+    Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode,
     WriteParams,
 };
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
@@ -60,7 +62,8 @@ use crate::index::{
 };
 use crate::index::{IndexConfig, IndexStatisticsImpl};
 use crate::query::{
-    IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
+    IntoQueryVector, Query, QueryExecutionOptions, QueryRequest, Select, VectorQuery,
+    VectorQueryRequest, DEFAULT_TOP_K,
 };
 use crate::utils::{
     default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
@@ -71,11 +74,13 @@ use crate::utils::{
 use self::dataset::DatasetConsistencyWrapper;
 use self::merge::MergeInsertBuilder;

+pub mod datafusion;
 pub(crate) mod dataset;
 pub mod merge;

 pub use chrono::Duration;
 pub use lance::dataset::optimize::CompactionOptions;
+pub use lance::dataset::scanner::DatasetRecordBatchStream;
 pub use lance_index::optimize::OptimizeOptions;

 /// Defines the type of column
@@ -273,7 +278,7 @@ pub enum AddDataMode {
 /// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`]
 /// operation
 pub struct AddDataBuilder<T: IntoArrow> {
-    parent: Arc<dyn TableInternal>,
+    parent: Arc<dyn BaseTable>,
     pub(crate) data: T,
     pub(crate) mode: AddDataMode,
     pub(crate) write_options: WriteOptions,
@@ -318,13 +323,13 @@ impl<T: IntoArrow> AddDataBuilder<T> {
 /// A builder for configuring an [`Table::update`] operation
 #[derive(Debug, Clone)]
 pub struct UpdateBuilder {
-    parent: Arc<dyn TableInternal>,
+    parent: Arc<dyn BaseTable>,
     pub(crate) filter: Option<String>,
     pub(crate) columns: Vec<(String, String)>,
 }

 impl UpdateBuilder {
-    fn new(parent: Arc<dyn TableInternal>) -> Self {
+    fn new(parent: Arc<dyn BaseTable>) -> Self {
         Self {
             parent,
             filter: None,
@@ -381,64 +386,102 @@
     }
 }

+/// Filters that can be used to limit the rows returned by a query
+pub enum Filter {
+    /// A SQL filter string
+    Sql(String),
+    /// A Datafusion logical expression
+    Datafusion(Expr),
+}
+
+/// A query that can be used to search a LanceDB table
+pub enum AnyQuery {
+    Query(QueryRequest),
+    VectorQuery(VectorQueryRequest),
+}
+
 /// A trait for anything "table-like". This is used for both native tables (which target
 /// Lance datasets) and remote tables (which target LanceDB cloud)
 ///
 /// This trait is still EXPERIMENTAL and subject to change in the future
 #[async_trait]
-pub trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
-    #[allow(dead_code)]
+pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
+    /// Get a reference to std::any::Any
     fn as_any(&self) -> &dyn std::any::Any;
-    /// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`].
-    fn as_native(&self) -> Option<&NativeTable>;
+    /// Get the name of the table.
     fn name(&self) -> &str;
+    /// Get the arrow [Schema] of the table.
     async fn schema(&self) -> Result<SchemaRef>;
-    async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
+    /// Count the number of rows in this table.
+    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize>;
+    /// Create a physical plan for the query.
     async fn create_plan(
         &self,
-        query: &VectorQuery,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>>;
-    async fn plain_query(
+    /// Execute a query and return the results as a stream of RecordBatches.
+    async fn query(
         &self,
-        query: &Query,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream>;
-    async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
+    /// Explain the plan for a query.
+    async fn explain_plan(&self, query: &AnyQuery, verbose: bool) -> Result<String> {
         let plan = self.create_plan(query, Default::default()).await?;
         let display = DisplayableExecutionPlan::new(plan.as_ref());

         Ok(format!("{}", display.indent(verbose)))
     }
+    /// Add new records to the table.
     async fn add(
         &self,
         add: AddDataBuilder<NoData>,
         data: Box<dyn arrow_array::RecordBatchReader + Send>,
     ) -> Result<()>;
+    /// Delete rows from the table.
     async fn delete(&self, predicate: &str) -> Result<()>;
+    /// Update rows in the table.
     async fn update(&self, update: UpdateBuilder) -> Result<u64>;
+    /// Create an index on the provided column(s).
     async fn create_index(&self, index: IndexBuilder) -> Result<()>;
+    /// List the indices on the table.
     async fn list_indices(&self) -> Result<Vec<IndexConfig>>;
+    /// Drop an index from the table.
     async fn drop_index(&self, name: &str) -> Result<()>;
+    /// Get statistics about the index.
     async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>>;
+    /// Merge insert new records into the table.
     async fn merge_insert(
         &self,
         params: MergeInsertBuilder,
         new_data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<()>;
+    /// Optimize the dataset.
     async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats>;
+    /// Add columns to the table.
     async fn add_columns(
         &self,
         transforms: NewColumnTransform,
         read_columns: Option<Vec<String>>,
     ) -> Result<()>;
+    /// Alter columns in the table.
     async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()>;
+    /// Drop columns from the table.
     async fn drop_columns(&self, columns: &[&str]) -> Result<()>;
+    /// Get the version of the table.
     async fn version(&self) -> Result<u64>;
+    /// Checkout a specific version of the table.
     async fn checkout(&self, version: u64) -> Result<()>;
+    /// Checkout the latest version of the table.
     async fn checkout_latest(&self) -> Result<()>;
+    /// Restore the table to the currently checked out version.
     async fn restore(&self) -> Result<()>;
+    /// List the versions of the table.
     async fn list_versions(&self) -> Result<Vec<Version>>;
+    /// Get the table definition.
     async fn table_definition(&self) -> Result<TableDefinition>;
+    /// Get the table URI
     fn dataset_uri(&self) -> &str;
 }
@@ -447,7 +490,7 @@ pub trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
 /// The type of the each row is defined in Apache Arrow [Schema].
 #[derive(Clone)]
 pub struct Table {
-    inner: Arc<dyn TableInternal>,
+    inner: Arc<dyn BaseTable>,
     embedding_registry: Arc<dyn EmbeddingRegistry>,
 }
@@ -483,15 +526,19 @@ impl std::fmt::Display for Table {
 }

 impl Table {
-    pub fn new(inner: Arc<dyn TableInternal>) -> Self {
+    pub fn new(inner: Arc<dyn BaseTable>) -> Self {
         Self {
             inner,
             embedding_registry: Arc::new(MemoryRegistry::new()),
         }
     }

+    pub fn base_table(&self) -> &Arc<dyn BaseTable> {
+        &self.inner
+    }
+
     pub(crate) fn new_with_embedding_registry(
-        inner: Arc<dyn TableInternal>,
+        inner: Arc<dyn BaseTable>,
         embedding_registry: Arc<dyn EmbeddingRegistry>,
     ) -> Self {
         Self {
@@ -524,7 +571,7 @@ impl Table {
     ///
     /// * `filter` if present, only count rows matching the filter
     pub async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
-        self.inner.count_rows(filter).await
+        self.inner.count_rows(filter.map(Filter::Sql)).await
     }

     /// Insert new records into this Table
@@ -1063,6 +1110,17 @@ impl From<NativeTable> for Table {
     }
 }

+pub trait NativeTableExt {
+    /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
+    fn as_native(&self) -> Option<&NativeTable>;
+}
+
+impl NativeTableExt for Arc<dyn BaseTable> {
+    fn as_native(&self) -> Option<&NativeTable> {
+        self.as_any().downcast_ref::<NativeTable>()
+    }
+}
+
 /// A table in a LanceDB database.
 #[derive(Debug, Clone)]
 pub struct NativeTable {
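With `as_native` moved off the trait and onto the `NativeTableExt` extension, code that holds an `Arc<dyn BaseTable>` can still recover the concrete `NativeTable`. A small illustrative sketch (the helper function is hypothetical, not part of the PR):

    use std::sync::Arc;

    use lancedb::table::{BaseTable, NativeTableExt};

    // Returns the underlying Lance dataset URI when the table is local;
    // remote (LanceDB Cloud) tables return None from as_native().
    fn local_dataset_uri(table: &Arc<dyn BaseTable>) -> Option<&str> {
        table.as_native().map(|native| native.dataset_uri())
    }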
@@ -1676,7 +1734,7 @@ impl NativeTable {

     async fn generic_query(
         &self,
-        query: &VectorQuery,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
         let plan = self.create_plan(query, options).await?;
@@ -1766,15 +1824,11 @@
 }

 #[async_trait::async_trait]
-impl TableInternal for NativeTable {
+impl BaseTable for NativeTable {
     fn as_any(&self) -> &dyn std::any::Any {
         self
     }

-    fn as_native(&self) -> Option<&NativeTable> {
-        Some(self)
-    }
-
     fn name(&self) -> &str {
         self.name.as_str()
     }
@@ -1830,8 +1884,15 @@ impl TableInternal for NativeTable {
         TableDefinition::try_from_rich_schema(schema)
     }

-    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
-        Ok(self.dataset.get().await?.count_rows(filter).await?)
+    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
+        let dataset = self.dataset.get().await?;
+        match filter {
+            None => Ok(dataset.count_rows(None).await?),
+            Some(Filter::Sql(sql)) => Ok(dataset.count_rows(Some(sql)).await?),
+            Some(Filter::Datafusion(_)) => Err(Error::NotSupported {
+                message: "Datafusion filters are not yet supported".to_string(),
+            }),
+        }
     }

     async fn add(
@@ -1925,9 +1986,14 @@ impl TableInternal for NativeTable {

     async fn create_plan(
         &self,
-        query: &VectorQuery,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        let query = match query {
+            AnyQuery::VectorQuery(query) => query.clone(),
+            AnyQuery::Query(query) => VectorQueryRequest::from_plain_query(query.clone()),
+        };
+
         let ds_ref = self.dataset.get().await?;
         let mut column = query.column.clone();
         let schema = ds_ref.schema();
@@ -1975,7 +2041,10 @@
                 let mut sub_query = query.clone();
                 sub_query.query_vector = vec![query_vector];
                 let options_ref = options.clone();
-                async move { self.create_plan(&sub_query, options_ref).await }
+                async move {
+                    self.create_plan(&AnyQuery::VectorQuery(sub_query), options_ref)
+                        .await
+                }
             })
             .collect::<Vec<_>>();
         let plans = futures::future::try_join_all(plan_futures).await?;
@@ -2073,13 +2142,12 @@
         Ok(scanner.create_plan().await?)
     }

-    async fn plain_query(
+    async fn query(
         &self,
-        query: &Query,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        self.generic_query(&query.clone().into_vector(), options)
-            .await
+        self.generic_query(query, options).await
     }

     async fn merge_insert(
@@ -2348,7 +2416,10 @@ mod tests {

         assert_eq!(table.count_rows(None).await.unwrap(), 10);
         assert_eq!(
-            table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
+            table
+                .count_rows(Some(Filter::Sql("i >= 5".to_string())))
+                .await
+                .unwrap(),
             5
         );
     }
rust/lancedb/src/table/datafusion.rs (new file, 187 lines)
@@ -0,0 +1,187 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
use std::{collections::HashMap, sync::Arc};

use arrow_schema::Schema as ArrowSchema;
use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider};
use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion_physical_plan::{
    stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use futures::{TryFutureExt, TryStreamExt};

use super::{AnyQuery, BaseTable};
use crate::{
    query::{QueryExecutionOptions, QueryRequest, Select},
    Result,
};

/// Datafusion attempts to maintain batch metadata
///
/// This is needless and it triggers bugs in DF. This operator erases metadata from the batches.
#[derive(Debug)]
struct MetadataEraserExec {
    input: Arc<dyn ExecutionPlan>,
    schema: Arc<ArrowSchema>,
    properties: PlanProperties,
}

impl MetadataEraserExec {
    fn compute_properties_from_input(
        input: &Arc<dyn ExecutionPlan>,
        schema: &Arc<ArrowSchema>,
    ) -> PlanProperties {
        let input_properties = input.properties();
        let eq_properties = input_properties
            .eq_properties
            .clone()
            .with_new_schema(schema.clone())
            .unwrap();
        input_properties.clone().with_eq_properties(eq_properties)
    }

    fn new(input: Arc<dyn ExecutionPlan>) -> Self {
        let schema = Arc::new(
            input
                .schema()
                .as_ref()
                .clone()
                .with_metadata(HashMap::new()),
        );
        Self {
            properties: Self::compute_properties_from_input(&input, &schema),
            input,
            schema,
        }
    }
}

impl DisplayAs for MetadataEraserExec {
    fn fmt_as(&self, _: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MetadataEraserExec")
    }
}

impl ExecutionPlan for MetadataEraserExec {
    fn name(&self) -> &str {
        "MetadataEraserExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert_eq!(children.len(), 1);
        let new_properties = Self::compute_properties_from_input(&children[0], &self.schema);
        Ok(Arc::new(Self {
            input: children[0].clone(),
            schema: self.schema.clone(),
            properties: new_properties,
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let stream = self.input.execute(partition, context)?;
        let schema = self.schema.clone();
        let stream = stream.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
        Ok(
            Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
                as SendableRecordBatchStream,
        )
    }
}

#[derive(Debug)]
pub struct BaseTableAdapter {
    table: Arc<dyn BaseTable>,
    schema: Arc<ArrowSchema>,
}

impl BaseTableAdapter {
    pub async fn try_new(table: Arc<dyn BaseTable>) -> Result<Self> {
        let schema = table.schema().await?;
        Ok(Self { table, schema })
    }
}

#[async_trait]
impl TableProvider for BaseTableAdapter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> Arc<ArrowSchema> {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let mut query = QueryRequest::default();
        if let Some(projection) = projection {
            let field_names = projection
                .iter()
                .map(|i| self.schema.field(*i).name().to_string())
                .collect();
            query.select = Select::Columns(field_names);
        }
        assert!(filters.is_empty());
        if let Some(limit) = limit {
            query.limit = Some(limit);
        } else {
            // Need to override the default of 10
            query.limit = None;
        }
        let plan = self
            .table
            .create_plan(&AnyQuery::Query(query), QueryExecutionOptions::default())
            .map_err(|err| DataFusionError::External(err.into()))
            .await?;
        Ok(Arc::new(MetadataEraserExec::new(plan)))
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
        // TODO: Pushdown unsupported until we can support datafusion filters in BaseTable::create_plan
        Ok(vec![
            TableProviderFilterPushDown::Unsupported;
            filters.len()
        ])
    }

    fn statistics(&self) -> Option<Statistics> {
        // TODO
        None
    }
}
@@ -7,14 +7,14 @@ use arrow_array::RecordBatchReader;

 use crate::Result;

-use super::TableInternal;
+use super::BaseTable;

 /// A builder used to create and run a merge insert operation
 ///
 /// See [`super::Table::merge_insert`] for more context
 #[derive(Debug, Clone)]
 pub struct MergeInsertBuilder {
-    table: Arc<dyn TableInternal>,
+    table: Arc<dyn BaseTable>,
     pub(crate) on: Vec<String>,
     pub(crate) when_matched_update_all: bool,
     pub(crate) when_matched_update_all_filt: Option<String>,
@@ -24,7 +24,7 @@ pub struct MergeInsertBuilder {
 }

 impl MergeInsertBuilder {
-    pub(super) fn new(table: Arc<dyn TableInternal>, on: Vec<String>) -> Self {
+    pub(super) fn new(table: Arc<dyn BaseTable>, on: Vec<String>) -> Self {
         Self {
             table,
             on,