mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 05:12:58 +00:00
feat: add to_query_object method (#2239)
This PR adds a `to_query_object` method to the various query builders (hybrid queries are not covered yet). This makes it possible to inspect the query that is built. In addition, this PR does some normalization between the sync and async query paths. A few custom defaults were removed in favor of None (with the default getting set once, in Rust). Also, the synchronous to_batches method will now actually stream results. Finally, the remote API now defaults to prefiltering.
This commit is contained in:
@@ -1,19 +1,28 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::array::make_array;
|
||||
use arrow::array::Array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use arrow::pyarrow::IntoPyArrow;
|
||||
use lancedb::index::scalar::FullTextSearchQuery;
|
||||
use lancedb::query::QueryExecutionOptions;
|
||||
use lancedb::query::QueryFilter;
|
||||
use lancedb::query::{
|
||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use lancedb::table::AnyQuery;
|
||||
use pyo3::exceptions::PyNotImplementedError;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyDict;
|
||||
use pyo3::types::PyList;
|
||||
use pyo3::Bound;
|
||||
use pyo3::IntoPyObject;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
@@ -24,6 +33,156 @@ use crate::arrow::RecordBatchStream;
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
// Python representation of full text search parameters
|
||||
#[derive(Clone)]
|
||||
#[pyclass(get_all)]
|
||||
pub struct PyFullTextSearchQuery {
|
||||
pub columns: Vec<String>,
|
||||
pub query: String,
|
||||
pub limit: Option<i64>,
|
||||
pub wand_factor: Option<f32>,
|
||||
}
|
||||
|
||||
impl From<FullTextSearchQuery> for PyFullTextSearchQuery {
|
||||
fn from(query: FullTextSearchQuery) -> Self {
|
||||
PyFullTextSearchQuery {
|
||||
columns: query.columns,
|
||||
query: query.query,
|
||||
limit: query.limit,
|
||||
wand_factor: query.wand_factor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Python representation of query vector(s)
|
||||
#[derive(Clone)]
|
||||
pub struct PyQueryVectors(Vec<Arc<dyn Array>>);
|
||||
|
||||
impl<'py> IntoPyObject<'py> for PyQueryVectors {
|
||||
type Target = PyList;
|
||||
type Output = Bound<'py, Self::Target>;
|
||||
type Error = PyErr;
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||
let py_objs = self
|
||||
.0
|
||||
.into_iter()
|
||||
.map(|v| v.to_data().into_pyarrow(py))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
PyList::new(py, py_objs)
|
||||
}
|
||||
}
|
||||
|
||||
// Python representation of a query
|
||||
#[pyclass(get_all)]
|
||||
pub struct PyQueryRequest {
|
||||
pub limit: Option<usize>,
|
||||
pub offset: Option<usize>,
|
||||
pub filter: Option<PyQueryFilter>,
|
||||
pub full_text_search: Option<PyFullTextSearchQuery>,
|
||||
pub select: PySelect,
|
||||
pub fast_search: Option<bool>,
|
||||
pub with_row_id: Option<bool>,
|
||||
pub column: Option<String>,
|
||||
pub query_vector: Option<PyQueryVectors>,
|
||||
pub nprobes: Option<usize>,
|
||||
pub lower_bound: Option<f32>,
|
||||
pub upper_bound: Option<f32>,
|
||||
pub ef: Option<usize>,
|
||||
pub refine_factor: Option<u32>,
|
||||
pub distance_type: Option<String>,
|
||||
pub bypass_vector_index: Option<bool>,
|
||||
pub postfilter: Option<bool>,
|
||||
pub norm: Option<String>,
|
||||
}
|
||||
|
||||
impl From<AnyQuery> for PyQueryRequest {
|
||||
fn from(query: AnyQuery) -> Self {
|
||||
match query {
|
||||
AnyQuery::Query(query_request) => PyQueryRequest {
|
||||
limit: query_request.limit,
|
||||
offset: query_request.offset,
|
||||
filter: query_request.filter.map(PyQueryFilter),
|
||||
full_text_search: query_request
|
||||
.full_text_search
|
||||
.map(PyFullTextSearchQuery::from),
|
||||
select: PySelect(query_request.select),
|
||||
fast_search: Some(query_request.fast_search),
|
||||
with_row_id: Some(query_request.with_row_id),
|
||||
column: None,
|
||||
query_vector: None,
|
||||
nprobes: None,
|
||||
lower_bound: None,
|
||||
upper_bound: None,
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
bypass_vector_index: None,
|
||||
postfilter: None,
|
||||
norm: None,
|
||||
},
|
||||
AnyQuery::VectorQuery(vector_query) => PyQueryRequest {
|
||||
limit: vector_query.base.limit,
|
||||
offset: vector_query.base.offset,
|
||||
filter: vector_query.base.filter.map(PyQueryFilter),
|
||||
full_text_search: None,
|
||||
select: PySelect(vector_query.base.select),
|
||||
fast_search: Some(vector_query.base.fast_search),
|
||||
with_row_id: Some(vector_query.base.with_row_id),
|
||||
column: vector_query.column,
|
||||
query_vector: Some(PyQueryVectors(vector_query.query_vector)),
|
||||
nprobes: Some(vector_query.nprobes),
|
||||
lower_bound: vector_query.lower_bound,
|
||||
upper_bound: vector_query.upper_bound,
|
||||
ef: vector_query.ef,
|
||||
refine_factor: vector_query.refine_factor,
|
||||
distance_type: vector_query.distance_type.map(|d| d.to_string()),
|
||||
bypass_vector_index: Some(!vector_query.use_index),
|
||||
postfilter: Some(!vector_query.base.prefilter),
|
||||
norm: vector_query.base.norm.map(|n| n.to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Python representation of query selection
|
||||
#[derive(Clone)]
|
||||
pub struct PySelect(Select);
|
||||
|
||||
impl<'py> IntoPyObject<'py> for PySelect {
|
||||
type Target = PyAny;
|
||||
type Output = Bound<'py, Self::Target>;
|
||||
type Error = PyErr;
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||
match self.0 {
|
||||
Select::All => Ok(py.None().into_bound(py).into_any()),
|
||||
Select::Columns(columns) => Ok(columns.into_pyobject(py)?.into_any()),
|
||||
Select::Dynamic(columns) => Ok(columns.into_pyobject(py)?.into_any()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Python representation of query filter
|
||||
#[derive(Clone)]
|
||||
pub struct PyQueryFilter(QueryFilter);
|
||||
|
||||
impl<'py> IntoPyObject<'py> for PyQueryFilter {
|
||||
type Target = PyAny;
|
||||
type Output = Bound<'py, Self::Target>;
|
||||
type Error = PyErr;
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||
match self.0 {
|
||||
QueryFilter::Datafusion(_) => Err(PyNotImplementedError::new_err(
|
||||
"Datafusion filter has no conversion to Python",
|
||||
)),
|
||||
QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()),
|
||||
QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct Query {
|
||||
inner: LanceDbQuery,
|
||||
@@ -121,6 +280,10 @@ impl Query {
|
||||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_query_request(&self) -> PyQueryRequest {
|
||||
PyQueryRequest::from(AnyQuery::Query(self.inner.clone().into_request()))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -205,6 +368,12 @@ impl FTSQuery {
|
||||
pub fn get_query(&self) -> String {
|
||||
self.fts_query.query.clone()
|
||||
}
|
||||
|
||||
pub fn to_query_request(&self) -> PyQueryRequest {
|
||||
let mut req = self.inner.clone().into_request();
|
||||
req.full_text_search = Some(self.fts_query.clone());
|
||||
PyQueryRequest::from(AnyQuery::Query(req))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -319,6 +488,10 @@ impl VectorQuery {
|
||||
inner_fts: fts_query,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_query_request(&self) -> PyQueryRequest {
|
||||
PyQueryRequest::from(AnyQuery::VectorQuery(self.inner.clone().into_request()))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -421,4 +594,17 @@ impl HybridQuery {
|
||||
pub fn get_with_row_id(&mut self) -> bool {
|
||||
self.inner_fts.inner.current_request().with_row_id
|
||||
}
|
||||
|
||||
// Returns an inspectable snapshot of this hybrid query.
//
// The snapshot starts from the full text search half (limit, offset, filter,
// select, full_text_search, ...) and is then overlaid with the vector search
// settings taken from the vector half.
//
// `nprobes` is forwarded together with the other vector tuning knobs;
// previously it was dropped, so the snapshot always reported `None` for it.
// `bypass_vector_index`, `postfilter` and `norm` are not overlaid and keep
// the values produced by the full text search half.
pub fn to_query_request(&self) -> PyQueryRequest {
    let mut req = self.inner_fts.to_query_request();
    let vec_req = self.inner_vec.to_query_request();

    // What is searched for, and in which column.
    req.query_vector = vec_req.query_vector;
    req.column = vec_req.column;

    // How the vector search is tuned.
    req.distance_type = vec_req.distance_type;
    req.nprobes = vec_req.nprobes;
    req.ef = vec_req.ef;
    req.refine_factor = vec_req.refine_factor;
    req.lower_bound = vec_req.lower_bound;
    req.upper_bound = vec_req.upper_bound;

    req
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user