mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 02:42:57 +00:00
feat: support new FTS features in python SDK (#2411)
- AND operator
- phrase query slop param
- boolean query

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Added support for combining full-text search queries using AND/OR operators, enabling more flexible query composition.
  - Introduced new query types and parameters, including boolean queries, operator selection, occurrence constraints, and phrase slop for advanced search scenarios.
  - Enhanced asynchronous search to accept rich full-text query objects directly.
- **Bug Fixes**
  - Improved handling and validation of full-text search queries in both synchronous and asynchronous search operations.
- **Tests**
  - Updated and expanded tests to cover new full-text query types and their usage in search functions.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -9,15 +9,16 @@ use arrow::array::Array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use arrow::pyarrow::IntoPyArrow;
|
||||
use lancedb::index::scalar::{FtsQuery, FullTextSearchQuery, MatchQuery, PhraseQuery};
|
||||
use lancedb::index::scalar::{
|
||||
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
||||
Operator, PhraseQuery,
|
||||
};
|
||||
use lancedb::query::QueryExecutionOptions;
|
||||
use lancedb::query::QueryFilter;
|
||||
use lancedb::query::{
|
||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use lancedb::table::AnyQuery;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyList;
|
||||
@@ -27,34 +28,182 @@ use pyo3::IntoPyObject;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
|
||||
use pyo3::{
|
||||
exceptions::{PyNotImplementedError, PyValueError},
|
||||
intern,
|
||||
};
|
||||
use pyo3::{pyclass, PyErr};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::arrow::RecordBatchStream;
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::util::{parse_distance_type, parse_fts_query};
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
|
||||
// Python representation of full text search parameters
|
||||
#[derive(Clone)]
|
||||
#[pyclass(get_all)]
|
||||
pub struct PyFullTextSearchQuery {
|
||||
pub columns: Vec<String>,
|
||||
pub query: String,
|
||||
pub limit: Option<i64>,
|
||||
pub wand_factor: Option<f32>,
|
||||
impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
|
||||
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
|
||||
match class_name(ob)?.as_str() {
|
||||
"MatchQuery" => {
|
||||
let query = ob.getattr("query")?.extract()?;
|
||||
let column = ob.getattr("column")?.extract()?;
|
||||
let boost = ob.getattr("boost")?.extract()?;
|
||||
let fuzziness = ob.getattr("fuzziness")?.extract()?;
|
||||
let max_expansions = ob.getattr("max_expansions")?.extract()?;
|
||||
let operator = ob.getattr("operator")?.extract::<String>()?;
|
||||
|
||||
Ok(PyLanceDB(
|
||||
MatchQuery::new(query)
|
||||
.with_column(Some(column))
|
||||
.with_boost(boost)
|
||||
.with_fuzziness(fuzziness)
|
||||
.with_max_expansions(max_expansions)
|
||||
.with_operator(Operator::try_from(operator.as_str()).map_err(|e| {
|
||||
PyValueError::new_err(format!("Invalid operator: {}", e))
|
||||
})?)
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
"PhraseQuery" => {
|
||||
let query = ob.getattr("query")?.extract()?;
|
||||
let column = ob.getattr("column")?.extract()?;
|
||||
let slop = ob.getattr("slop")?.extract()?;
|
||||
|
||||
Ok(PyLanceDB(
|
||||
PhraseQuery::new(query)
|
||||
.with_column(Some(column))
|
||||
.with_slop(slop)
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
"BoostQuery" => {
|
||||
let positive: PyLanceDB<FtsQuery> = ob.getattr("positive")?.extract()?;
|
||||
let negative: PyLanceDB<FtsQuery> = ob.getattr("negative")?.extract()?;
|
||||
let negative_boost = ob.getattr("negative_boost")?.extract()?;
|
||||
Ok(PyLanceDB(
|
||||
BoostQuery::new(positive.0, negative.0, negative_boost).into(),
|
||||
))
|
||||
}
|
||||
"MultiMatchQuery" => {
|
||||
let query = ob.getattr("query")?.extract()?;
|
||||
let columns = ob.getattr("columns")?.extract()?;
|
||||
let boosts: Option<Vec<f32>> = ob.getattr("boosts")?.extract()?;
|
||||
let operator: String = ob.getattr("operator")?.extract()?;
|
||||
|
||||
let q = MultiMatchQuery::try_new(query, columns)
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid query: {}", e)))?;
|
||||
let q = if let Some(boosts) = boosts {
|
||||
q.try_with_boosts(boosts)
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid boosts: {}", e)))?
|
||||
} else {
|
||||
q
|
||||
};
|
||||
|
||||
let op = Operator::try_from(operator.as_str())
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid operator: {}", e)))?;
|
||||
|
||||
Ok(PyLanceDB(q.with_operator(op).into()))
|
||||
}
|
||||
"BooleanQuery" => {
|
||||
let queries: Vec<(String, PyLanceDB<FtsQuery>)> =
|
||||
ob.getattr("queries")?.extract()?;
|
||||
let mut sub_queries = Vec::with_capacity(queries.len());
|
||||
for (occur, q) in queries {
|
||||
let occur = Occur::try_from(occur.as_str())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?;
|
||||
sub_queries.push((occur, q.0));
|
||||
}
|
||||
Ok(PyLanceDB(BooleanQuery::new(sub_queries).into()))
|
||||
}
|
||||
name => Err(PyValueError::new_err(format!(
|
||||
"Unsupported FTS query type: {}",
|
||||
name
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FullTextSearchQuery> for PyFullTextSearchQuery {
|
||||
fn from(query: FullTextSearchQuery) -> Self {
|
||||
Self {
|
||||
columns: query.columns().into_iter().collect(),
|
||||
query: query.query.query().to_owned(),
|
||||
limit: query.limit,
|
||||
wand_factor: query.wand_factor,
|
||||
impl<'py> IntoPyObject<'py> for PyLanceDB<FtsQuery> {
|
||||
type Target = PyAny;
|
||||
type Output = Bound<'py, Self::Target>;
|
||||
type Error = PyErr;
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||
let namespace = py
|
||||
.import(intern!(py, "lancedb"))
|
||||
.and_then(|m| m.getattr(intern!(py, "query")))
|
||||
.expect("Failed to import namespace");
|
||||
|
||||
match self.0 {
|
||||
FtsQuery::Match(query) => {
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("boost", query.boost)?;
|
||||
kwargs.set_item("fuzziness", query.fuzziness)?;
|
||||
kwargs.set_item("max_expansions", query.max_expansions)?;
|
||||
kwargs.set_item("operator", operator_to_str(query.operator))?;
|
||||
namespace
|
||||
.getattr(intern!(py, "MatchQuery"))?
|
||||
.call((query.terms, query.column.unwrap()), Some(&kwargs))
|
||||
}
|
||||
FtsQuery::Phrase(query) => {
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("slop", query.slop)?;
|
||||
namespace
|
||||
.getattr(intern!(py, "PhraseQuery"))?
|
||||
.call((query.terms, query.column.unwrap()), Some(&kwargs))
|
||||
}
|
||||
FtsQuery::Boost(query) => {
|
||||
let positive = PyLanceDB(query.positive.as_ref().clone()).into_pyobject(py)?;
|
||||
let negative = PyLanceDB(query.negative.as_ref().clone()).into_pyobject(py)?;
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("negative_boost", query.negative_boost)?;
|
||||
namespace
|
||||
.getattr(intern!(py, "BoostQuery"))?
|
||||
.call((positive, negative), Some(&kwargs))
|
||||
}
|
||||
FtsQuery::MultiMatch(query) => {
|
||||
let first = &query.match_queries[0];
|
||||
let (columns, boosts): (Vec<_>, Vec<_>) = query
|
||||
.match_queries
|
||||
.iter()
|
||||
.map(|q| (q.column.as_ref().unwrap().clone(), q.boost))
|
||||
.unzip();
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("boosts", boosts)?;
|
||||
kwargs.set_item("operator", operator_to_str(first.operator))?;
|
||||
namespace
|
||||
.getattr(intern!(py, "MultiMatchQuery"))?
|
||||
.call((first.terms.clone(), columns), Some(&kwargs))
|
||||
}
|
||||
FtsQuery::Boolean(query) => {
|
||||
let mut queries = Vec::with_capacity(query.must.len() + query.should.len());
|
||||
for q in query.must {
|
||||
queries.push((occur_to_str(Occur::Must), PyLanceDB(q).into_pyobject(py)?));
|
||||
}
|
||||
for q in query.should {
|
||||
queries.push((occur_to_str(Occur::Should), PyLanceDB(q).into_pyobject(py)?));
|
||||
}
|
||||
namespace
|
||||
.getattr(intern!(py, "BooleanQuery"))?
|
||||
.call1((queries,))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn operator_to_str(op: Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::And => "AND",
|
||||
Operator::Or => "OR",
|
||||
}
|
||||
}
|
||||
|
||||
fn occur_to_str(occur: Occur) -> &'static str {
|
||||
match occur {
|
||||
Occur::Must => "MUST",
|
||||
Occur::Should => "SHOULD",
|
||||
}
|
||||
}
|
||||
|
||||
// Python representation of query vector(s)
|
||||
#[derive(Clone)]
|
||||
/// Tuple struct wrapping a `Vec` of shared Arrow arrays
/// (presumably one array per query vector — confirm against callers).
pub struct PyQueryVectors(Vec<Arc<dyn Array>>);
|
||||
@@ -80,7 +229,7 @@ pub struct PyQueryRequest {
|
||||
pub limit: Option<usize>,
|
||||
pub offset: Option<usize>,
|
||||
pub filter: Option<PyQueryFilter>,
|
||||
pub full_text_search: Option<PyFullTextSearchQuery>,
|
||||
pub full_text_search: Option<PyLanceDB<FtsQuery>>,
|
||||
pub select: PySelect,
|
||||
pub fast_search: Option<bool>,
|
||||
pub with_row_id: Option<bool>,
|
||||
@@ -106,7 +255,7 @@ impl From<AnyQuery> for PyQueryRequest {
|
||||
filter: query_request.filter.map(PyQueryFilter),
|
||||
full_text_search: query_request
|
||||
.full_text_search
|
||||
.map(PyFullTextSearchQuery::from),
|
||||
.map(|fts| PyLanceDB(fts.query)),
|
||||
select: PySelect(query_request.select),
|
||||
fast_search: Some(query_request.fast_search),
|
||||
with_row_id: Some(query_request.with_row_id),
|
||||
@@ -269,8 +418,8 @@ impl Query {
|
||||
}
|
||||
};
|
||||
let mut query = FullTextSearchQuery::new_query(query);
|
||||
if let Some(cols) = columns {
|
||||
if !cols.is_empty() {
|
||||
match columns {
|
||||
Some(cols) if !cols.is_empty() => {
|
||||
query = query.with_columns(&cols).map_err(|e| {
|
||||
PyValueError::new_err(format!(
|
||||
"Failed to set full text search columns: {}",
|
||||
@@ -278,15 +427,12 @@ impl Query {
|
||||
))
|
||||
})?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
query
|
||||
} else if let Ok(query) = fts_query.downcast::<PyDict>() {
|
||||
let query = parse_fts_query(query)?;
|
||||
FullTextSearchQuery::new_query(query)
|
||||
} else {
|
||||
return Err(PyValueError::new_err(
|
||||
"query must be a string or a Query object",
|
||||
));
|
||||
let query = fts_query.extract::<PyLanceDB<FtsQuery>>()?;
|
||||
FullTextSearchQuery::new_query(query.0)
|
||||
};
|
||||
|
||||
Ok(FTSQuery {
|
||||
|
||||
Reference in New Issue
Block a user