feat: support new FTS features in python SDK (#2411)

- AND operator
- phrase query slop param
- boolean query

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added support for combining full-text search queries using AND/OR
operators, enabling more flexible query composition.
- Introduced new query types and parameters, including boolean queries,
operator selection, occurrence constraints, and phrase slop for advanced
search scenarios.
- Enhanced asynchronous search to accept rich full-text query objects
directly.

- **Bug Fixes**
- Improved handling and validation of full-text search queries in both
synchronous and asynchronous search operations.

- **Tests**
- Updated and expanded tests to cover new full-text query types and
their usage in search functions.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2025-06-06 14:33:46 +08:00
committed by GitHub
parent 65696d9713
commit 84ded9d678
6 changed files with 364 additions and 321 deletions

View File

@@ -165,17 +165,14 @@ class HybridQuery:
def get_with_row_id(self) -> bool: ... def get_with_row_id(self) -> bool: ...
def to_query_request(self) -> PyQueryRequest: ... def to_query_request(self) -> PyQueryRequest: ...
class PyFullTextSearchQuery: class FullTextQuery:
columns: Optional[List[str]] pass
query: str
limit: Optional[int]
wand_factor: Optional[float]
class PyQueryRequest: class PyQueryRequest:
limit: Optional[int] limit: Optional[int]
offset: Optional[int] offset: Optional[int]
filter: Optional[Union[str, bytes]] filter: Optional[Union[str, bytes]]
full_text_search: Optional[PyFullTextSearchQuery] full_text_search: Optional[FullTextQuery]
select: Optional[Union[str, List[str]]] select: Optional[Union[str, List[str]]]
fast_search: Optional[bool] fast_search: Optional[bool]
with_row_id: Optional[bool] with_row_id: Optional[bool]

View File

@@ -4,7 +4,6 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import abc
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from datetime import timedelta from datetime import timedelta
@@ -88,15 +87,27 @@ def ensure_vector_query(
return val return val
class FullTextQueryType(Enum): class FullTextQueryType(str, Enum):
MATCH = "match" MATCH = "match"
MATCH_PHRASE = "match_phrase" MATCH_PHRASE = "match_phrase"
BOOST = "boost" BOOST = "boost"
MULTI_MATCH = "multi_match" MULTI_MATCH = "multi_match"
BOOLEAN = "boolean"
class FullTextQuery(abc.ABC, pydantic.BaseModel): class FullTextOperator(str, Enum):
@abc.abstractmethod AND = "AND"
OR = "OR"
class Occur(str, Enum):
MUST = "MUST"
SHOULD = "SHOULD"
@pydantic.dataclasses.dataclass
class FullTextQuery(ABC):
@abstractmethod
def query_type(self) -> FullTextQueryType: def query_type(self) -> FullTextQueryType:
""" """
Get the query type of the query. Get the query type of the query.
@@ -106,193 +117,174 @@ class FullTextQuery(abc.ABC, pydantic.BaseModel):
str str
The type of the query. The type of the query.
""" """
pass
@abc.abstractmethod def __and__(self, other: "FullTextQuery") -> "FullTextQuery":
def to_dict(self) -> dict:
""" """
Convert the query to a dictionary. Combine two queries with a logical AND operation.
Returns
-------
dict
The query as a dictionary.
"""
class MatchQuery(FullTextQuery):
query: str
column: str
boost: float = 1.0
fuzziness: int = 0
max_expansions: int = 50
def __init__(
self,
query: str,
column: str,
*,
boost: float = 1.0,
fuzziness: int = 0,
max_expansions: int = 50,
):
"""
Match query for full-text search.
Parameters Parameters
---------- ----------
query : str other : FullTextQuery
The query string to match against. The other query to combine with.
column : str
The name of the column to match against. Returns
boost : float, default 1.0 -------
The boost factor for the query. FullTextQuery
The score of each matching document is multiplied by this value. A new query that combines both queries with AND.
fuzziness : int, optional
The maximum edit distance for each term in the match query.
Defaults to 0 (exact match).
If None, fuzziness is applied automatically by the rules:
- 0 for terms with length <= 2
- 1 for terms with length <= 5
- 2 for terms with length > 5
max_expansions : int, optional
The maximum number of terms to consider for fuzzy matching.
Defaults to 50.
""" """
super().__init__( return BooleanQuery([(Occur.MUST, self), (Occur.MUST, other)])
query=query,
column=column, def __or__(self, other: "FullTextQuery") -> "FullTextQuery":
boost=boost, """
fuzziness=fuzziness, Combine two queries with a logical OR operation.
max_expansions=max_expansions,
) Parameters
----------
other : FullTextQuery
The other query to combine with.
Returns
-------
FullTextQuery
A new query that combines both queries with OR.
"""
return BooleanQuery([(Occur.SHOULD, self), (Occur.SHOULD, other)])
@pydantic.dataclasses.dataclass
class MatchQuery(FullTextQuery):
"""
Match query for full-text search.
Parameters
----------
query : str
The query string to match against.
column : str
The name of the column to match against.
boost : float, default 1.0
The boost factor for the query.
The score of each matching document is multiplied by this value.
fuzziness : int, optional
The maximum edit distance for each term in the match query.
Defaults to 0 (exact match).
If None, fuzziness is applied automatically by the rules:
- 0 for terms with length <= 2
- 1 for terms with length <= 5
- 2 for terms with length > 5
max_expansions : int, optional
The maximum number of terms to consider for fuzzy matching.
Defaults to 50.
operator : FullTextOperator, default OR
The operator to use for combining the query results.
Can be either `AND` or `OR`.
If `AND`, all terms in the query must match.
If `OR`, at least one term in the query must match.
"""
query: str
column: str
boost: float = pydantic.Field(1.0, kw_only=True)
fuzziness: int = pydantic.Field(0, kw_only=True)
max_expansions: int = pydantic.Field(50, kw_only=True)
operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True)
def query_type(self) -> FullTextQueryType: def query_type(self) -> FullTextQueryType:
return FullTextQueryType.MATCH return FullTextQueryType.MATCH
def to_dict(self) -> dict:
return {
"match": {
self.column: {
"query": self.query,
"boost": self.boost,
"fuzziness": self.fuzziness,
"max_expansions": self.max_expansions,
}
}
}
@pydantic.dataclasses.dataclass
class PhraseQuery(FullTextQuery): class PhraseQuery(FullTextQuery):
"""
Phrase query for full-text search.
Parameters
----------
query : str
The query string to match against.
column : str
The name of the column to match against.
"""
query: str query: str
column: str column: str
slop: int = pydantic.Field(0, kw_only=True)
def __init__(self, query: str, column: str):
"""
Phrase query for full-text search.
Parameters
----------
query : str
The query string to match against.
column : str
The name of the column to match against.
"""
super().__init__(query=query, column=column)
def query_type(self) -> FullTextQueryType: def query_type(self) -> FullTextQueryType:
return FullTextQueryType.MATCH_PHRASE return FullTextQueryType.MATCH_PHRASE
def to_dict(self) -> dict:
return {
"match_phrase": {
self.column: self.query,
}
}
@pydantic.dataclasses.dataclass
class BoostQuery(FullTextQuery): class BoostQuery(FullTextQuery):
"""
Boost query for full-text search.
Parameters
----------
positive : dict
The positive query object.
negative : dict
The negative query object.
negative_boost : float, default 0.5
The boost factor for the negative query.
"""
positive: FullTextQuery positive: FullTextQuery
negative: FullTextQuery negative: FullTextQuery
negative_boost: float = 0.5 negative_boost: float = pydantic.Field(0.5, kw_only=True)
def __init__(
self,
positive: FullTextQuery,
negative: FullTextQuery,
*,
negative_boost: float = 0.5,
):
"""
Boost query for full-text search.
Parameters
----------
positive : dict
The positive query object.
negative : dict
The negative query object.
negative_boost : float
The boost factor for the negative query.
"""
super().__init__(
positive=positive, negative=negative, negative_boost=negative_boost
)
def query_type(self) -> FullTextQueryType: def query_type(self) -> FullTextQueryType:
return FullTextQueryType.BOOST return FullTextQueryType.BOOST
def to_dict(self) -> dict:
return {
"boost": {
"positive": self.positive.to_dict(),
"negative": self.negative.to_dict(),
"negative_boost": self.negative_boost,
}
}
@pydantic.dataclasses.dataclass
class MultiMatchQuery(FullTextQuery): class MultiMatchQuery(FullTextQuery):
"""
Multi-match query for full-text search.
Parameters
----------
query : str | list[Query]
If a string, the query string to match against.
columns : list[str]
The list of columns to match against.
boosts : list[float], optional
The list of boost factors for each column. If not provided,
all columns will have the same boost factor.
operator : FullTextOperator, default OR
The operator to use for combining the query results.
Can be either `AND` or `OR`.
It would be applied to all columns individually.
For example, if the operator is `AND`,
then the query "hello world" is equal to
`match("hello AND world", column1) OR match("hello AND world", column2)`.
"""
query: str query: str
columns: list[str] columns: list[str]
boosts: list[float] boosts: Optional[list[float]] = pydantic.Field(None, kw_only=True)
operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True)
def __init__(
self,
query: str,
columns: list[str],
*,
boosts: Optional[list[float]] = None,
):
"""
Multi-match query for full-text search.
Parameters
----------
query : str
The query string to match against.
columns : list[str]
The list of columns to match against.
boosts : list[float], optional
The list of boost factors for each column. If not provided,
all columns will have the same boost factor.
"""
if boosts is None:
boosts = [1.0] * len(columns)
super().__init__(query=query, columns=columns, boosts=boosts)
def query_type(self) -> FullTextQueryType: def query_type(self) -> FullTextQueryType:
return FullTextQueryType.MULTI_MATCH return FullTextQueryType.MULTI_MATCH
def to_dict(self) -> dict:
return { @pydantic.dataclasses.dataclass
"multi_match": { class BooleanQuery(FullTextQuery):
"query": self.query, """
"columns": self.columns, Boolean query for full-text search.
"boost": self.boosts,
} Parameters
} ----------
queries : list[tuple(Occur, FullTextQuery)]
The list of queries with their occurrence requirements.
"""
queries: list[tuple[Occur, FullTextQuery]]
def query_type(self) -> FullTextQueryType:
return FullTextQueryType.BOOLEAN
class FullTextSearchQuery(pydantic.BaseModel): class FullTextSearchQuery(pydantic.BaseModel):
@@ -493,10 +485,8 @@ class Query(pydantic.BaseModel):
query.postfilter = req.postfilter query.postfilter = req.postfilter
if req.full_text_search is not None: if req.full_text_search is not None:
query.full_text_query = FullTextSearchQuery( query.full_text_query = FullTextSearchQuery(
columns=req.full_text_search.columns, columns=None,
query=req.full_text_search.query, query=req.full_text_search,
limit=req.full_text_search.limit,
wand_factor=req.full_text_search.wand_factor,
) )
return query return query
@@ -2513,7 +2503,7 @@ class AsyncQuery(AsyncQueryBase):
self._inner.nearest_to_text({"query": query, "columns": columns}) self._inner.nearest_to_text({"query": query, "columns": columns})
) )
# FullTextQuery object # FullTextQuery object
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query.to_dict()})) return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}))
class AsyncFTSQuery(AsyncQueryBase): class AsyncFTSQuery(AsyncQueryBase):
@@ -2835,7 +2825,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
self._inner.nearest_to_text({"query": query, "columns": columns}) self._inner.nearest_to_text({"query": query, "columns": columns})
) )
# FullTextQuery object # FullTextQuery object
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()})) return AsyncHybridQuery(self._inner.nearest_to_text({"query": query}))
async def to_batches( async def to_batches(
self, self,

View File

@@ -215,6 +215,19 @@ def test_search_fts(table, use_tantivy):
assert len(results) == 5 assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score assert len(results[0]) == 3 # id, text, _score
# Test boolean query
results = (
table.search(MatchQuery("puppy", "text") & MatchQuery("runs", "text"))
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
for r in results:
assert "puppy" in r["text"]
assert "runs" in r["text"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fts_select_async(async_table): async def test_fts_select_async(async_table):

View File

@@ -25,6 +25,8 @@ from lancedb.query import (
AsyncQueryBase, AsyncQueryBase,
AsyncVectorQuery, AsyncVectorQuery,
LanceVectorQueryBuilder, LanceVectorQueryBuilder,
MatchQuery,
PhraseQuery,
Query, Query,
FullTextSearchQuery, FullTextSearchQuery,
) )
@@ -1065,18 +1067,27 @@ async def test_query_serialization_async(table_async: AsyncTable):
) )
# FTS queries # FTS queries
q = (await table_async.search("foo")).limit(10).to_query_object() match_query = MatchQuery("foo", "text")
q = (await table_async.search(match_query)).limit(10).to_query_object()
check_set_props( check_set_props(
q, q,
limit=10, limit=10,
full_text_query=FullTextSearchQuery(columns=[], query="foo"), full_text_query=FullTextSearchQuery(columns=None, query=match_query),
with_row_id=False, with_row_id=False,
) )
q = (await table_async.search("foo", query_type="fts")).to_query_object() q = (await table_async.search(match_query)).to_query_object()
check_set_props( check_set_props(
q, q,
full_text_query=FullTextSearchQuery(columns=[], query="foo"), full_text_query=FullTextSearchQuery(columns=None, query=match_query),
with_row_id=False,
)
phrase_query = PhraseQuery("foo", "text", slop=1)
q = (await table_async.search(phrase_query)).to_query_object()
check_set_props(
q,
full_text_query=FullTextSearchQuery(columns=None, query=phrase_query),
with_row_id=False, with_row_id=False,
) )

View File

@@ -9,15 +9,16 @@ use arrow::array::Array;
use arrow::array::ArrayData; use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow; use arrow::pyarrow::FromPyArrow;
use arrow::pyarrow::IntoPyArrow; use arrow::pyarrow::IntoPyArrow;
use lancedb::index::scalar::{FtsQuery, FullTextSearchQuery, MatchQuery, PhraseQuery}; use lancedb::index::scalar::{
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
Operator, PhraseQuery,
};
use lancedb::query::QueryExecutionOptions; use lancedb::query::QueryExecutionOptions;
use lancedb::query::QueryFilter; use lancedb::query::QueryFilter;
use lancedb::query::{ use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery, ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
}; };
use lancedb::table::AnyQuery; use lancedb::table::AnyQuery;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
use pyo3::prelude::{PyAnyMethods, PyDictMethods}; use pyo3::prelude::{PyAnyMethods, PyDictMethods};
use pyo3::pymethods; use pyo3::pymethods;
use pyo3::types::PyList; use pyo3::types::PyList;
@@ -27,34 +28,182 @@ use pyo3::IntoPyObject;
use pyo3::PyAny; use pyo3::PyAny;
use pyo3::PyRef; use pyo3::PyRef;
use pyo3::PyResult; use pyo3::PyResult;
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
use pyo3::{
exceptions::{PyNotImplementedError, PyValueError},
intern,
};
use pyo3::{pyclass, PyErr}; use pyo3::{pyclass, PyErr};
use pyo3_async_runtimes::tokio::future_into_py; use pyo3_async_runtimes::tokio::future_into_py;
use crate::arrow::RecordBatchStream; use crate::util::parse_distance_type;
use crate::error::PythonErrorExt; use crate::{arrow::RecordBatchStream, util::PyLanceDB};
use crate::util::{parse_distance_type, parse_fts_query}; use crate::{error::PythonErrorExt, index::class_name};
// Python representation of full text search parameters impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
#[derive(Clone)] fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
#[pyclass(get_all)] match class_name(ob)?.as_str() {
pub struct PyFullTextSearchQuery { "MatchQuery" => {
pub columns: Vec<String>, let query = ob.getattr("query")?.extract()?;
pub query: String, let column = ob.getattr("column")?.extract()?;
pub limit: Option<i64>, let boost = ob.getattr("boost")?.extract()?;
pub wand_factor: Option<f32>, let fuzziness = ob.getattr("fuzziness")?.extract()?;
let max_expansions = ob.getattr("max_expansions")?.extract()?;
let operator = ob.getattr("operator")?.extract::<String>()?;
Ok(PyLanceDB(
MatchQuery::new(query)
.with_column(Some(column))
.with_boost(boost)
.with_fuzziness(fuzziness)
.with_max_expansions(max_expansions)
.with_operator(Operator::try_from(operator.as_str()).map_err(|e| {
PyValueError::new_err(format!("Invalid operator: {}", e))
})?)
.into(),
))
}
"PhraseQuery" => {
let query = ob.getattr("query")?.extract()?;
let column = ob.getattr("column")?.extract()?;
let slop = ob.getattr("slop")?.extract()?;
Ok(PyLanceDB(
PhraseQuery::new(query)
.with_column(Some(column))
.with_slop(slop)
.into(),
))
}
"BoostQuery" => {
let positive: PyLanceDB<FtsQuery> = ob.getattr("positive")?.extract()?;
let negative: PyLanceDB<FtsQuery> = ob.getattr("negative")?.extract()?;
let negative_boost = ob.getattr("negative_boost")?.extract()?;
Ok(PyLanceDB(
BoostQuery::new(positive.0, negative.0, negative_boost).into(),
))
}
"MultiMatchQuery" => {
let query = ob.getattr("query")?.extract()?;
let columns = ob.getattr("columns")?.extract()?;
let boosts: Option<Vec<f32>> = ob.getattr("boosts")?.extract()?;
let operator: String = ob.getattr("operator")?.extract()?;
let q = MultiMatchQuery::try_new(query, columns)
.map_err(|e| PyValueError::new_err(format!("Invalid query: {}", e)))?;
let q = if let Some(boosts) = boosts {
q.try_with_boosts(boosts)
.map_err(|e| PyValueError::new_err(format!("Invalid boosts: {}", e)))?
} else {
q
};
let op = Operator::try_from(operator.as_str())
.map_err(|e| PyValueError::new_err(format!("Invalid operator: {}", e)))?;
Ok(PyLanceDB(q.with_operator(op).into()))
}
"BooleanQuery" => {
let queries: Vec<(String, PyLanceDB<FtsQuery>)> =
ob.getattr("queries")?.extract()?;
let mut sub_queries = Vec::with_capacity(queries.len());
for (occur, q) in queries {
let occur = Occur::try_from(occur.as_str())
.map_err(|e| PyValueError::new_err(e.to_string()))?;
sub_queries.push((occur, q.0));
}
Ok(PyLanceDB(BooleanQuery::new(sub_queries).into()))
}
name => Err(PyValueError::new_err(format!(
"Unsupported FTS query type: {}",
name
))),
}
}
} }
impl From<FullTextSearchQuery> for PyFullTextSearchQuery { impl<'py> IntoPyObject<'py> for PyLanceDB<FtsQuery> {
fn from(query: FullTextSearchQuery) -> Self { type Target = PyAny;
Self { type Output = Bound<'py, Self::Target>;
columns: query.columns().into_iter().collect(), type Error = PyErr;
query: query.query.query().to_owned(),
limit: query.limit, fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
wand_factor: query.wand_factor, let namespace = py
.import(intern!(py, "lancedb"))
.and_then(|m| m.getattr(intern!(py, "query")))
.expect("Failed to import namespace");
match self.0 {
FtsQuery::Match(query) => {
let kwargs = PyDict::new(py);
kwargs.set_item("boost", query.boost)?;
kwargs.set_item("fuzziness", query.fuzziness)?;
kwargs.set_item("max_expansions", query.max_expansions)?;
kwargs.set_item("operator", operator_to_str(query.operator))?;
namespace
.getattr(intern!(py, "MatchQuery"))?
.call((query.terms, query.column.unwrap()), Some(&kwargs))
}
FtsQuery::Phrase(query) => {
let kwargs = PyDict::new(py);
kwargs.set_item("slop", query.slop)?;
namespace
.getattr(intern!(py, "PhraseQuery"))?
.call((query.terms, query.column.unwrap()), Some(&kwargs))
}
FtsQuery::Boost(query) => {
let positive = PyLanceDB(query.positive.as_ref().clone()).into_pyobject(py)?;
let negative = PyLanceDB(query.negative.as_ref().clone()).into_pyobject(py)?;
let kwargs = PyDict::new(py);
kwargs.set_item("negative_boost", query.negative_boost)?;
namespace
.getattr(intern!(py, "BoostQuery"))?
.call((positive, negative), Some(&kwargs))
}
FtsQuery::MultiMatch(query) => {
let first = &query.match_queries[0];
let (columns, boosts): (Vec<_>, Vec<_>) = query
.match_queries
.iter()
.map(|q| (q.column.as_ref().unwrap().clone(), q.boost))
.unzip();
let kwargs = PyDict::new(py);
kwargs.set_item("boosts", boosts)?;
kwargs.set_item("operator", operator_to_str(first.operator))?;
namespace
.getattr(intern!(py, "MultiMatchQuery"))?
.call((first.terms.clone(), columns), Some(&kwargs))
}
FtsQuery::Boolean(query) => {
let mut queries = Vec::with_capacity(query.must.len() + query.should.len());
for q in query.must {
queries.push((occur_to_str(Occur::Must), PyLanceDB(q).into_pyobject(py)?));
}
for q in query.should {
queries.push((occur_to_str(Occur::Should), PyLanceDB(q).into_pyobject(py)?));
}
namespace
.getattr(intern!(py, "BooleanQuery"))?
.call1((queries,))
}
} }
} }
} }
fn operator_to_str(op: Operator) -> &'static str {
match op {
Operator::And => "AND",
Operator::Or => "OR",
}
}
fn occur_to_str(occur: Occur) -> &'static str {
match occur {
Occur::Must => "MUST",
Occur::Should => "SHOULD",
}
}
// Python representation of query vector(s) // Python representation of query vector(s)
#[derive(Clone)] #[derive(Clone)]
pub struct PyQueryVectors(Vec<Arc<dyn Array>>); pub struct PyQueryVectors(Vec<Arc<dyn Array>>);
@@ -80,7 +229,7 @@ pub struct PyQueryRequest {
pub limit: Option<usize>, pub limit: Option<usize>,
pub offset: Option<usize>, pub offset: Option<usize>,
pub filter: Option<PyQueryFilter>, pub filter: Option<PyQueryFilter>,
pub full_text_search: Option<PyFullTextSearchQuery>, pub full_text_search: Option<PyLanceDB<FtsQuery>>,
pub select: PySelect, pub select: PySelect,
pub fast_search: Option<bool>, pub fast_search: Option<bool>,
pub with_row_id: Option<bool>, pub with_row_id: Option<bool>,
@@ -106,7 +255,7 @@ impl From<AnyQuery> for PyQueryRequest {
filter: query_request.filter.map(PyQueryFilter), filter: query_request.filter.map(PyQueryFilter),
full_text_search: query_request full_text_search: query_request
.full_text_search .full_text_search
.map(PyFullTextSearchQuery::from), .map(|fts| PyLanceDB(fts.query)),
select: PySelect(query_request.select), select: PySelect(query_request.select),
fast_search: Some(query_request.fast_search), fast_search: Some(query_request.fast_search),
with_row_id: Some(query_request.with_row_id), with_row_id: Some(query_request.with_row_id),
@@ -269,8 +418,8 @@ impl Query {
} }
}; };
let mut query = FullTextSearchQuery::new_query(query); let mut query = FullTextSearchQuery::new_query(query);
if let Some(cols) = columns { match columns {
if !cols.is_empty() { Some(cols) if !cols.is_empty() => {
query = query.with_columns(&cols).map_err(|e| { query = query.with_columns(&cols).map_err(|e| {
PyValueError::new_err(format!( PyValueError::new_err(format!(
"Failed to set full text search columns: {}", "Failed to set full text search columns: {}",
@@ -278,15 +427,12 @@ impl Query {
)) ))
})?; })?;
} }
_ => {}
} }
query query
} else if let Ok(query) = fts_query.downcast::<PyDict>() {
let query = parse_fts_query(query)?;
FullTextSearchQuery::new_query(query)
} else { } else {
return Err(PyValueError::new_err( let query = fts_query.extract::<PyLanceDB<FtsQuery>>()?;
"query must be a string or a Query object", FullTextSearchQuery::new_query(query.0)
));
}; };
Ok(FTSQuery { Ok(FTSQuery {

View File

@@ -3,15 +3,11 @@
use std::sync::Mutex; use std::sync::Mutex;
use lancedb::index::scalar::{BoostQuery, FtsQuery, MatchQuery, MultiMatchQuery, PhraseQuery};
use lancedb::DistanceType; use lancedb::DistanceType;
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods};
use pyo3::types::PyDict;
use pyo3::{ use pyo3::{
exceptions::{PyRuntimeError, PyValueError}, exceptions::{PyRuntimeError, PyValueError},
pyfunction, PyResult, pyfunction, PyResult,
}; };
use pyo3::{Bound, PyAny};
/// A wrapper around a rust builder /// A wrapper around a rust builder
/// ///
@@ -64,116 +60,6 @@ pub fn validate_table_name(table_name: &str) -> PyResult<()> {
.map_err(|e| PyValueError::new_err(e.to_string())) .map_err(|e| PyValueError::new_err(e.to_string()))
} }
pub fn parse_fts_query(query: &Bound<'_, PyDict>) -> PyResult<FtsQuery> { /// A wrapper around a LanceDB type to allow it to be used in Python
let query_type = query.keys().get_item(0)?.extract::<String>()?; #[derive(Debug, Clone)]
let query_value = query pub struct PyLanceDB<T>(pub T);
.get_item(&query_type)?
.ok_or(PyValueError::new_err(format!(
"Query type {} not found",
query_type
)))?;
let query_value = query_value.downcast::<PyDict>()?;
match query_type.as_str() {
"match" => {
let column = query_value.keys().get_item(0)?.extract::<String>()?;
let params = query_value
.get_item(&column)?
.ok_or(PyValueError::new_err(format!(
"column {} not found",
column
)))?;
let params = params.downcast::<PyDict>()?;
let query = params
.get_item("query")?
.ok_or(PyValueError::new_err("query not found"))?
.extract::<String>()?;
let boost = params
.get_item("boost")?
.ok_or(PyValueError::new_err("boost not found"))?
.extract::<f32>()?;
let fuzziness = params
.get_item("fuzziness")?
.ok_or(PyValueError::new_err("fuzziness not found"))?
.extract::<Option<u32>>()?;
let max_expansions = params
.get_item("max_expansions")?
.ok_or(PyValueError::new_err("max_expansions not found"))?
.extract::<usize>()?;
let query = MatchQuery::new(query)
.with_column(Some(column))
.with_boost(boost)
.with_fuzziness(fuzziness)
.with_max_expansions(max_expansions);
Ok(query.into())
}
"match_phrase" => {
let column = query_value.keys().get_item(0)?.extract::<String>()?;
let query = query_value
.get_item(&column)?
.ok_or(PyValueError::new_err(format!(
"column {} not found",
column
)))?
.extract::<String>()?;
let query = PhraseQuery::new(query).with_column(Some(column));
Ok(query.into())
}
"boost" => {
let positive: Bound<'_, PyAny> = query_value
.get_item("positive")?
.ok_or(PyValueError::new_err("positive not found"))?;
let positive = positive.downcast::<PyDict>()?;
let negative = query_value
.get_item("negative")?
.ok_or(PyValueError::new_err("negative not found"))?;
let negative = negative.downcast::<PyDict>()?;
let negative_boost = query_value
.get_item("negative_boost")?
.ok_or(PyValueError::new_err("negative_boost not found"))?
.extract::<f32>()?;
let positive_query = parse_fts_query(positive)?;
let negative_query = parse_fts_query(negative)?;
let query = BoostQuery::new(positive_query, negative_query, Some(negative_boost));
Ok(query.into())
}
"multi_match" => {
let query = query_value
.get_item("query")?
.ok_or(PyValueError::new_err("query not found"))?
.extract::<String>()?;
let columns = query_value
.get_item("columns")?
.ok_or(PyValueError::new_err("columns not found"))?
.extract::<Vec<String>>()?;
let boost = query_value
.get_item("boost")?
.ok_or(PyValueError::new_err("boost not found"))?
.extract::<Vec<f32>>()?;
let query = MultiMatchQuery::try_new(query, columns)
.and_then(|q| q.try_with_boosts(boost))
.map_err(|e| {
PyValueError::new_err(format!("Error creating MultiMatchQuery: {}", e))
})?;
Ok(query.into())
}
_ => Err(PyValueError::new_err(format!(
"Unsupported query type: {}",
query_type
))),
}
}