mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
feat: support new FTS features in python SDK (#2411)
- AND operator - phrase query slop param - boolean query <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for combining full-text search queries using AND/OR operators, enabling more flexible query composition. - Introduced new query types and parameters, including boolean queries, operator selection, occurrence constraints, and phrase slop for advanced search scenarios. - Enhanced asynchronous search to accept rich full-text query objects directly. - **Bug Fixes** - Improved handling and validation of full-text search queries in both synchronous and asynchronous search operations. - **Tests** - Updated and expanded tests to cover new full-text query types and their usage in search functions. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -165,17 +165,14 @@ class HybridQuery:
|
||||
def get_with_row_id(self) -> bool: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class PyFullTextSearchQuery:
|
||||
columns: Optional[List[str]]
|
||||
query: str
|
||||
limit: Optional[int]
|
||||
wand_factor: Optional[float]
|
||||
class FullTextQuery:
|
||||
pass
|
||||
|
||||
class PyQueryRequest:
|
||||
limit: Optional[int]
|
||||
offset: Optional[int]
|
||||
filter: Optional[Union[str, bytes]]
|
||||
full_text_search: Optional[PyFullTextSearchQuery]
|
||||
full_text_search: Optional[FullTextQuery]
|
||||
select: Optional[Union[str, List[str]]]
|
||||
fast_search: Optional[bool]
|
||||
with_row_id: Optional[bool]
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import abc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from datetime import timedelta
|
||||
@@ -88,15 +87,27 @@ def ensure_vector_query(
|
||||
return val
|
||||
|
||||
|
||||
class FullTextQueryType(Enum):
|
||||
class FullTextQueryType(str, Enum):
|
||||
MATCH = "match"
|
||||
MATCH_PHRASE = "match_phrase"
|
||||
BOOST = "boost"
|
||||
MULTI_MATCH = "multi_match"
|
||||
BOOLEAN = "boolean"
|
||||
|
||||
|
||||
class FullTextQuery(abc.ABC, pydantic.BaseModel):
|
||||
@abc.abstractmethod
|
||||
class FullTextOperator(str, Enum):
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
|
||||
|
||||
class Occur(str, Enum):
|
||||
MUST = "MUST"
|
||||
SHOULD = "SHOULD"
|
||||
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class FullTextQuery(ABC):
|
||||
@abstractmethod
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
"""
|
||||
Get the query type of the query.
|
||||
@@ -106,193 +117,174 @@ class FullTextQuery(abc.ABC, pydantic.BaseModel):
|
||||
str
|
||||
The type of the query.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_dict(self) -> dict:
|
||||
def __and__(self, other: "FullTextQuery") -> "FullTextQuery":
|
||||
"""
|
||||
Convert the query to a dictionary.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
The query as a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
class MatchQuery(FullTextQuery):
|
||||
query: str
|
||||
column: str
|
||||
boost: float = 1.0
|
||||
fuzziness: int = 0
|
||||
max_expansions: int = 50
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
column: str,
|
||||
*,
|
||||
boost: float = 1.0,
|
||||
fuzziness: int = 0,
|
||||
max_expansions: int = 50,
|
||||
):
|
||||
"""
|
||||
Match query for full-text search.
|
||||
Combine two queries with a logical AND operation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str
|
||||
The query string to match against.
|
||||
column : str
|
||||
The name of the column to match against.
|
||||
boost : float, default 1.0
|
||||
The boost factor for the query.
|
||||
The score of each matching document is multiplied by this value.
|
||||
fuzziness : int, optional
|
||||
The maximum edit distance for each term in the match query.
|
||||
Defaults to 0 (exact match).
|
||||
If None, fuzziness is applied automatically by the rules:
|
||||
- 0 for terms with length <= 2
|
||||
- 1 for terms with length <= 5
|
||||
- 2 for terms with length > 5
|
||||
max_expansions : int, optional
|
||||
The maximum number of terms to consider for fuzzy matching.
|
||||
Defaults to 50.
|
||||
other : FullTextQuery
|
||||
The other query to combine with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FullTextQuery
|
||||
A new query that combines both queries with AND.
|
||||
"""
|
||||
super().__init__(
|
||||
query=query,
|
||||
column=column,
|
||||
boost=boost,
|
||||
fuzziness=fuzziness,
|
||||
max_expansions=max_expansions,
|
||||
)
|
||||
return BooleanQuery([(Occur.MUST, self), (Occur.MUST, other)])
|
||||
|
||||
def __or__(self, other: "FullTextQuery") -> "FullTextQuery":
|
||||
"""
|
||||
Combine two queries with a logical OR operation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
other : FullTextQuery
|
||||
The other query to combine with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FullTextQuery
|
||||
A new query that combines both queries with OR.
|
||||
"""
|
||||
return BooleanQuery([(Occur.SHOULD, self), (Occur.SHOULD, other)])
|
||||
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class MatchQuery(FullTextQuery):
|
||||
"""
|
||||
Match query for full-text search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str
|
||||
The query string to match against.
|
||||
column : str
|
||||
The name of the column to match against.
|
||||
boost : float, default 1.0
|
||||
The boost factor for the query.
|
||||
The score of each matching document is multiplied by this value.
|
||||
fuzziness : int, optional
|
||||
The maximum edit distance for each term in the match query.
|
||||
Defaults to 0 (exact match).
|
||||
If None, fuzziness is applied automatically by the rules:
|
||||
- 0 for terms with length <= 2
|
||||
- 1 for terms with length <= 5
|
||||
- 2 for terms with length > 5
|
||||
max_expansions : int, optional
|
||||
The maximum number of terms to consider for fuzzy matching.
|
||||
Defaults to 50.
|
||||
operator : FullTextOperator, default OR
|
||||
The operator to use for combining the query results.
|
||||
Can be either `AND` or `OR`.
|
||||
If `AND`, all terms in the query must match.
|
||||
If `OR`, at least one term in the query must match.
|
||||
"""
|
||||
|
||||
query: str
|
||||
column: str
|
||||
boost: float = pydantic.Field(1.0, kw_only=True)
|
||||
fuzziness: int = pydantic.Field(0, kw_only=True)
|
||||
max_expansions: int = pydantic.Field(50, kw_only=True)
|
||||
operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True)
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.MATCH
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"match": {
|
||||
self.column: {
|
||||
"query": self.query,
|
||||
"boost": self.boost,
|
||||
"fuzziness": self.fuzziness,
|
||||
"max_expansions": self.max_expansions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class PhraseQuery(FullTextQuery):
|
||||
"""
|
||||
Phrase query for full-text search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str
|
||||
The query string to match against.
|
||||
column : str
|
||||
The name of the column to match against.
|
||||
"""
|
||||
|
||||
query: str
|
||||
column: str
|
||||
|
||||
def __init__(self, query: str, column: str):
|
||||
"""
|
||||
Phrase query for full-text search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str
|
||||
The query string to match against.
|
||||
column : str
|
||||
The name of the column to match against.
|
||||
"""
|
||||
super().__init__(query=query, column=column)
|
||||
slop: int = pydantic.Field(0, kw_only=True)
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.MATCH_PHRASE
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"match_phrase": {
|
||||
self.column: self.query,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class BoostQuery(FullTextQuery):
|
||||
"""
|
||||
Boost query for full-text search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
positive : dict
|
||||
The positive query object.
|
||||
negative : dict
|
||||
The negative query object.
|
||||
negative_boost : float, default 0.5
|
||||
The boost factor for the negative query.
|
||||
"""
|
||||
|
||||
positive: FullTextQuery
|
||||
negative: FullTextQuery
|
||||
negative_boost: float = 0.5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
positive: FullTextQuery,
|
||||
negative: FullTextQuery,
|
||||
*,
|
||||
negative_boost: float = 0.5,
|
||||
):
|
||||
"""
|
||||
Boost query for full-text search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
positive : dict
|
||||
The positive query object.
|
||||
negative : dict
|
||||
The negative query object.
|
||||
negative_boost : float
|
||||
The boost factor for the negative query.
|
||||
"""
|
||||
super().__init__(
|
||||
positive=positive, negative=negative, negative_boost=negative_boost
|
||||
)
|
||||
negative_boost: float = pydantic.Field(0.5, kw_only=True)
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.BOOST
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"boost": {
|
||||
"positive": self.positive.to_dict(),
|
||||
"negative": self.negative.to_dict(),
|
||||
"negative_boost": self.negative_boost,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class MultiMatchQuery(FullTextQuery):
|
||||
"""
|
||||
Multi-match query for full-text search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str | list[Query]
|
||||
If a string, the query string to match against.
|
||||
columns : list[str]
|
||||
The list of columns to match against.
|
||||
boosts : list[float], optional
|
||||
The list of boost factors for each column. If not provided,
|
||||
all columns will have the same boost factor.
|
||||
operator : FullTextOperator, default OR
|
||||
The operator to use for combining the query results.
|
||||
Can be either `AND` or `OR`.
|
||||
It would be applied to all columns individually.
|
||||
For example, if the operator is `AND`,
|
||||
then the query "hello world" is equal to
|
||||
`match("hello AND world", column1) OR match("hello AND world", column2)`.
|
||||
"""
|
||||
|
||||
query: str
|
||||
columns: list[str]
|
||||
boosts: list[float]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
columns: list[str],
|
||||
*,
|
||||
boosts: Optional[list[float]] = None,
|
||||
):
|
||||
"""
|
||||
Multi-match query for full-text search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str
|
||||
The query string to match against.
|
||||
|
||||
columns : list[str]
|
||||
The list of columns to match against.
|
||||
|
||||
boosts : list[float], optional
|
||||
The list of boost factors for each column. If not provided,
|
||||
all columns will have the same boost factor.
|
||||
"""
|
||||
if boosts is None:
|
||||
boosts = [1.0] * len(columns)
|
||||
super().__init__(query=query, columns=columns, boosts=boosts)
|
||||
boosts: Optional[list[float]] = pydantic.Field(None, kw_only=True)
|
||||
operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True)
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.MULTI_MATCH
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"multi_match": {
|
||||
"query": self.query,
|
||||
"columns": self.columns,
|
||||
"boost": self.boosts,
|
||||
}
|
||||
}
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class BooleanQuery(FullTextQuery):
|
||||
"""
|
||||
Boolean query for full-text search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
queries : list[tuple(Occur, FullTextQuery)]
|
||||
The list of queries with their occurrence requirements.
|
||||
"""
|
||||
|
||||
queries: list[tuple[Occur, FullTextQuery]]
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.BOOLEAN
|
||||
|
||||
|
||||
class FullTextSearchQuery(pydantic.BaseModel):
|
||||
@@ -493,10 +485,8 @@ class Query(pydantic.BaseModel):
|
||||
query.postfilter = req.postfilter
|
||||
if req.full_text_search is not None:
|
||||
query.full_text_query = FullTextSearchQuery(
|
||||
columns=req.full_text_search.columns,
|
||||
query=req.full_text_search.query,
|
||||
limit=req.full_text_search.limit,
|
||||
wand_factor=req.full_text_search.wand_factor,
|
||||
columns=None,
|
||||
query=req.full_text_search,
|
||||
)
|
||||
return query
|
||||
|
||||
@@ -2513,7 +2503,7 @@ class AsyncQuery(AsyncQueryBase):
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
# FullTextQuery object
|
||||
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
||||
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}))
|
||||
|
||||
|
||||
class AsyncFTSQuery(AsyncQueryBase):
|
||||
@@ -2835,7 +2825,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
# FullTextQuery object
|
||||
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
||||
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query}))
|
||||
|
||||
async def to_batches(
|
||||
self,
|
||||
|
||||
@@ -215,6 +215,19 @@ def test_search_fts(table, use_tantivy):
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
|
||||
# Test boolean query
|
||||
results = (
|
||||
table.search(MatchQuery("puppy", "text") & MatchQuery("runs", "text"))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
for r in results:
|
||||
assert "puppy" in r["text"]
|
||||
assert "runs" in r["text"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fts_select_async(async_table):
|
||||
|
||||
@@ -25,6 +25,8 @@ from lancedb.query import (
|
||||
AsyncQueryBase,
|
||||
AsyncVectorQuery,
|
||||
LanceVectorQueryBuilder,
|
||||
MatchQuery,
|
||||
PhraseQuery,
|
||||
Query,
|
||||
FullTextSearchQuery,
|
||||
)
|
||||
@@ -1065,18 +1067,27 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
)
|
||||
|
||||
# FTS queries
|
||||
q = (await table_async.search("foo")).limit(10).to_query_object()
|
||||
match_query = MatchQuery("foo", "text")
|
||||
q = (await table_async.search(match_query)).limit(10).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
||||
full_text_query=FullTextSearchQuery(columns=None, query=match_query),
|
||||
with_row_id=False,
|
||||
)
|
||||
|
||||
q = (await table_async.search("foo", query_type="fts")).to_query_object()
|
||||
q = (await table_async.search(match_query)).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
||||
full_text_query=FullTextSearchQuery(columns=None, query=match_query),
|
||||
with_row_id=False,
|
||||
)
|
||||
|
||||
phrase_query = PhraseQuery("foo", "text", slop=1)
|
||||
q = (await table_async.search(phrase_query)).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
full_text_query=FullTextSearchQuery(columns=None, query=phrase_query),
|
||||
with_row_id=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,15 +9,16 @@ use arrow::array::Array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use arrow::pyarrow::IntoPyArrow;
|
||||
use lancedb::index::scalar::{FtsQuery, FullTextSearchQuery, MatchQuery, PhraseQuery};
|
||||
use lancedb::index::scalar::{
|
||||
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
||||
Operator, PhraseQuery,
|
||||
};
|
||||
use lancedb::query::QueryExecutionOptions;
|
||||
use lancedb::query::QueryFilter;
|
||||
use lancedb::query::{
|
||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use lancedb::table::AnyQuery;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyList;
|
||||
@@ -27,34 +28,182 @@ use pyo3::IntoPyObject;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
|
||||
use pyo3::{
|
||||
exceptions::{PyNotImplementedError, PyValueError},
|
||||
intern,
|
||||
};
|
||||
use pyo3::{pyclass, PyErr};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::arrow::RecordBatchStream;
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::util::{parse_distance_type, parse_fts_query};
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
|
||||
// Python representation of full text search parameters
|
||||
#[derive(Clone)]
|
||||
#[pyclass(get_all)]
|
||||
pub struct PyFullTextSearchQuery {
|
||||
pub columns: Vec<String>,
|
||||
pub query: String,
|
||||
pub limit: Option<i64>,
|
||||
pub wand_factor: Option<f32>,
|
||||
impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
|
||||
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
|
||||
match class_name(ob)?.as_str() {
|
||||
"MatchQuery" => {
|
||||
let query = ob.getattr("query")?.extract()?;
|
||||
let column = ob.getattr("column")?.extract()?;
|
||||
let boost = ob.getattr("boost")?.extract()?;
|
||||
let fuzziness = ob.getattr("fuzziness")?.extract()?;
|
||||
let max_expansions = ob.getattr("max_expansions")?.extract()?;
|
||||
let operator = ob.getattr("operator")?.extract::<String>()?;
|
||||
|
||||
Ok(PyLanceDB(
|
||||
MatchQuery::new(query)
|
||||
.with_column(Some(column))
|
||||
.with_boost(boost)
|
||||
.with_fuzziness(fuzziness)
|
||||
.with_max_expansions(max_expansions)
|
||||
.with_operator(Operator::try_from(operator.as_str()).map_err(|e| {
|
||||
PyValueError::new_err(format!("Invalid operator: {}", e))
|
||||
})?)
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
"PhraseQuery" => {
|
||||
let query = ob.getattr("query")?.extract()?;
|
||||
let column = ob.getattr("column")?.extract()?;
|
||||
let slop = ob.getattr("slop")?.extract()?;
|
||||
|
||||
Ok(PyLanceDB(
|
||||
PhraseQuery::new(query)
|
||||
.with_column(Some(column))
|
||||
.with_slop(slop)
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
"BoostQuery" => {
|
||||
let positive: PyLanceDB<FtsQuery> = ob.getattr("positive")?.extract()?;
|
||||
let negative: PyLanceDB<FtsQuery> = ob.getattr("negative")?.extract()?;
|
||||
let negative_boost = ob.getattr("negative_boost")?.extract()?;
|
||||
Ok(PyLanceDB(
|
||||
BoostQuery::new(positive.0, negative.0, negative_boost).into(),
|
||||
))
|
||||
}
|
||||
"MultiMatchQuery" => {
|
||||
let query = ob.getattr("query")?.extract()?;
|
||||
let columns = ob.getattr("columns")?.extract()?;
|
||||
let boosts: Option<Vec<f32>> = ob.getattr("boosts")?.extract()?;
|
||||
let operator: String = ob.getattr("operator")?.extract()?;
|
||||
|
||||
let q = MultiMatchQuery::try_new(query, columns)
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid query: {}", e)))?;
|
||||
let q = if let Some(boosts) = boosts {
|
||||
q.try_with_boosts(boosts)
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid boosts: {}", e)))?
|
||||
} else {
|
||||
q
|
||||
};
|
||||
|
||||
let op = Operator::try_from(operator.as_str())
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid operator: {}", e)))?;
|
||||
|
||||
Ok(PyLanceDB(q.with_operator(op).into()))
|
||||
}
|
||||
"BooleanQuery" => {
|
||||
let queries: Vec<(String, PyLanceDB<FtsQuery>)> =
|
||||
ob.getattr("queries")?.extract()?;
|
||||
let mut sub_queries = Vec::with_capacity(queries.len());
|
||||
for (occur, q) in queries {
|
||||
let occur = Occur::try_from(occur.as_str())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?;
|
||||
sub_queries.push((occur, q.0));
|
||||
}
|
||||
Ok(PyLanceDB(BooleanQuery::new(sub_queries).into()))
|
||||
}
|
||||
name => Err(PyValueError::new_err(format!(
|
||||
"Unsupported FTS query type: {}",
|
||||
name
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FullTextSearchQuery> for PyFullTextSearchQuery {
|
||||
fn from(query: FullTextSearchQuery) -> Self {
|
||||
Self {
|
||||
columns: query.columns().into_iter().collect(),
|
||||
query: query.query.query().to_owned(),
|
||||
limit: query.limit,
|
||||
wand_factor: query.wand_factor,
|
||||
impl<'py> IntoPyObject<'py> for PyLanceDB<FtsQuery> {
|
||||
type Target = PyAny;
|
||||
type Output = Bound<'py, Self::Target>;
|
||||
type Error = PyErr;
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||
let namespace = py
|
||||
.import(intern!(py, "lancedb"))
|
||||
.and_then(|m| m.getattr(intern!(py, "query")))
|
||||
.expect("Failed to import namespace");
|
||||
|
||||
match self.0 {
|
||||
FtsQuery::Match(query) => {
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("boost", query.boost)?;
|
||||
kwargs.set_item("fuzziness", query.fuzziness)?;
|
||||
kwargs.set_item("max_expansions", query.max_expansions)?;
|
||||
kwargs.set_item("operator", operator_to_str(query.operator))?;
|
||||
namespace
|
||||
.getattr(intern!(py, "MatchQuery"))?
|
||||
.call((query.terms, query.column.unwrap()), Some(&kwargs))
|
||||
}
|
||||
FtsQuery::Phrase(query) => {
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("slop", query.slop)?;
|
||||
namespace
|
||||
.getattr(intern!(py, "PhraseQuery"))?
|
||||
.call((query.terms, query.column.unwrap()), Some(&kwargs))
|
||||
}
|
||||
FtsQuery::Boost(query) => {
|
||||
let positive = PyLanceDB(query.positive.as_ref().clone()).into_pyobject(py)?;
|
||||
let negative = PyLanceDB(query.negative.as_ref().clone()).into_pyobject(py)?;
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("negative_boost", query.negative_boost)?;
|
||||
namespace
|
||||
.getattr(intern!(py, "BoostQuery"))?
|
||||
.call((positive, negative), Some(&kwargs))
|
||||
}
|
||||
FtsQuery::MultiMatch(query) => {
|
||||
let first = &query.match_queries[0];
|
||||
let (columns, boosts): (Vec<_>, Vec<_>) = query
|
||||
.match_queries
|
||||
.iter()
|
||||
.map(|q| (q.column.as_ref().unwrap().clone(), q.boost))
|
||||
.unzip();
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("boosts", boosts)?;
|
||||
kwargs.set_item("operator", operator_to_str(first.operator))?;
|
||||
namespace
|
||||
.getattr(intern!(py, "MultiMatchQuery"))?
|
||||
.call((first.terms.clone(), columns), Some(&kwargs))
|
||||
}
|
||||
FtsQuery::Boolean(query) => {
|
||||
let mut queries = Vec::with_capacity(query.must.len() + query.should.len());
|
||||
for q in query.must {
|
||||
queries.push((occur_to_str(Occur::Must), PyLanceDB(q).into_pyobject(py)?));
|
||||
}
|
||||
for q in query.should {
|
||||
queries.push((occur_to_str(Occur::Should), PyLanceDB(q).into_pyobject(py)?));
|
||||
}
|
||||
namespace
|
||||
.getattr(intern!(py, "BooleanQuery"))?
|
||||
.call1((queries,))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn operator_to_str(op: Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::And => "AND",
|
||||
Operator::Or => "OR",
|
||||
}
|
||||
}
|
||||
|
||||
fn occur_to_str(occur: Occur) -> &'static str {
|
||||
match occur {
|
||||
Occur::Must => "MUST",
|
||||
Occur::Should => "SHOULD",
|
||||
}
|
||||
}
|
||||
|
||||
// Python representation of query vector(s)
|
||||
#[derive(Clone)]
|
||||
pub struct PyQueryVectors(Vec<Arc<dyn Array>>);
|
||||
@@ -80,7 +229,7 @@ pub struct PyQueryRequest {
|
||||
pub limit: Option<usize>,
|
||||
pub offset: Option<usize>,
|
||||
pub filter: Option<PyQueryFilter>,
|
||||
pub full_text_search: Option<PyFullTextSearchQuery>,
|
||||
pub full_text_search: Option<PyLanceDB<FtsQuery>>,
|
||||
pub select: PySelect,
|
||||
pub fast_search: Option<bool>,
|
||||
pub with_row_id: Option<bool>,
|
||||
@@ -106,7 +255,7 @@ impl From<AnyQuery> for PyQueryRequest {
|
||||
filter: query_request.filter.map(PyQueryFilter),
|
||||
full_text_search: query_request
|
||||
.full_text_search
|
||||
.map(PyFullTextSearchQuery::from),
|
||||
.map(|fts| PyLanceDB(fts.query)),
|
||||
select: PySelect(query_request.select),
|
||||
fast_search: Some(query_request.fast_search),
|
||||
with_row_id: Some(query_request.with_row_id),
|
||||
@@ -269,8 +418,8 @@ impl Query {
|
||||
}
|
||||
};
|
||||
let mut query = FullTextSearchQuery::new_query(query);
|
||||
if let Some(cols) = columns {
|
||||
if !cols.is_empty() {
|
||||
match columns {
|
||||
Some(cols) if !cols.is_empty() => {
|
||||
query = query.with_columns(&cols).map_err(|e| {
|
||||
PyValueError::new_err(format!(
|
||||
"Failed to set full text search columns: {}",
|
||||
@@ -278,15 +427,12 @@ impl Query {
|
||||
))
|
||||
})?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
query
|
||||
} else if let Ok(query) = fts_query.downcast::<PyDict>() {
|
||||
let query = parse_fts_query(query)?;
|
||||
FullTextSearchQuery::new_query(query)
|
||||
} else {
|
||||
return Err(PyValueError::new_err(
|
||||
"query must be a string or a Query object",
|
||||
));
|
||||
let query = fts_query.extract::<PyLanceDB<FtsQuery>>()?;
|
||||
FullTextSearchQuery::new_query(query.0)
|
||||
};
|
||||
|
||||
Ok(FTSQuery {
|
||||
|
||||
@@ -3,15 +3,11 @@
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
use lancedb::index::scalar::{BoostQuery, FtsQuery, MatchQuery, MultiMatchQuery, PhraseQuery};
|
||||
use lancedb::DistanceType;
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods};
|
||||
use pyo3::types::PyDict;
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyfunction, PyResult,
|
||||
};
|
||||
use pyo3::{Bound, PyAny};
|
||||
|
||||
/// A wrapper around a rust builder
|
||||
///
|
||||
@@ -64,116 +60,6 @@ pub fn validate_table_name(table_name: &str) -> PyResult<()> {
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn parse_fts_query(query: &Bound<'_, PyDict>) -> PyResult<FtsQuery> {
|
||||
let query_type = query.keys().get_item(0)?.extract::<String>()?;
|
||||
let query_value = query
|
||||
.get_item(&query_type)?
|
||||
.ok_or(PyValueError::new_err(format!(
|
||||
"Query type {} not found",
|
||||
query_type
|
||||
)))?;
|
||||
let query_value = query_value.downcast::<PyDict>()?;
|
||||
|
||||
match query_type.as_str() {
|
||||
"match" => {
|
||||
let column = query_value.keys().get_item(0)?.extract::<String>()?;
|
||||
let params = query_value
|
||||
.get_item(&column)?
|
||||
.ok_or(PyValueError::new_err(format!(
|
||||
"column {} not found",
|
||||
column
|
||||
)))?;
|
||||
let params = params.downcast::<PyDict>()?;
|
||||
|
||||
let query = params
|
||||
.get_item("query")?
|
||||
.ok_or(PyValueError::new_err("query not found"))?
|
||||
.extract::<String>()?;
|
||||
let boost = params
|
||||
.get_item("boost")?
|
||||
.ok_or(PyValueError::new_err("boost not found"))?
|
||||
.extract::<f32>()?;
|
||||
let fuzziness = params
|
||||
.get_item("fuzziness")?
|
||||
.ok_or(PyValueError::new_err("fuzziness not found"))?
|
||||
.extract::<Option<u32>>()?;
|
||||
let max_expansions = params
|
||||
.get_item("max_expansions")?
|
||||
.ok_or(PyValueError::new_err("max_expansions not found"))?
|
||||
.extract::<usize>()?;
|
||||
|
||||
let query = MatchQuery::new(query)
|
||||
.with_column(Some(column))
|
||||
.with_boost(boost)
|
||||
.with_fuzziness(fuzziness)
|
||||
.with_max_expansions(max_expansions);
|
||||
Ok(query.into())
|
||||
}
|
||||
|
||||
"match_phrase" => {
|
||||
let column = query_value.keys().get_item(0)?.extract::<String>()?;
|
||||
let query = query_value
|
||||
.get_item(&column)?
|
||||
.ok_or(PyValueError::new_err(format!(
|
||||
"column {} not found",
|
||||
column
|
||||
)))?
|
||||
.extract::<String>()?;
|
||||
|
||||
let query = PhraseQuery::new(query).with_column(Some(column));
|
||||
Ok(query.into())
|
||||
}
|
||||
|
||||
"boost" => {
|
||||
let positive: Bound<'_, PyAny> = query_value
|
||||
.get_item("positive")?
|
||||
.ok_or(PyValueError::new_err("positive not found"))?;
|
||||
let positive = positive.downcast::<PyDict>()?;
|
||||
|
||||
let negative = query_value
|
||||
.get_item("negative")?
|
||||
.ok_or(PyValueError::new_err("negative not found"))?;
|
||||
let negative = negative.downcast::<PyDict>()?;
|
||||
|
||||
let negative_boost = query_value
|
||||
.get_item("negative_boost")?
|
||||
.ok_or(PyValueError::new_err("negative_boost not found"))?
|
||||
.extract::<f32>()?;
|
||||
|
||||
let positive_query = parse_fts_query(positive)?;
|
||||
let negative_query = parse_fts_query(negative)?;
|
||||
let query = BoostQuery::new(positive_query, negative_query, Some(negative_boost));
|
||||
|
||||
Ok(query.into())
|
||||
}
|
||||
|
||||
"multi_match" => {
|
||||
let query = query_value
|
||||
.get_item("query")?
|
||||
.ok_or(PyValueError::new_err("query not found"))?
|
||||
.extract::<String>()?;
|
||||
|
||||
let columns = query_value
|
||||
.get_item("columns")?
|
||||
.ok_or(PyValueError::new_err("columns not found"))?
|
||||
.extract::<Vec<String>>()?;
|
||||
|
||||
let boost = query_value
|
||||
.get_item("boost")?
|
||||
.ok_or(PyValueError::new_err("boost not found"))?
|
||||
.extract::<Vec<f32>>()?;
|
||||
|
||||
let query = MultiMatchQuery::try_new(query, columns)
|
||||
.and_then(|q| q.try_with_boosts(boost))
|
||||
.map_err(|e| {
|
||||
PyValueError::new_err(format!("Error creating MultiMatchQuery: {}", e))
|
||||
})?;
|
||||
Ok(query.into())
|
||||
}
|
||||
|
||||
_ => Err(PyValueError::new_err(format!(
|
||||
"Unsupported query type: {}",
|
||||
query_type
|
||||
))),
|
||||
}
|
||||
}
|
||||
/// A wrapper around a LanceDB type to allow it to be used in Python
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PyLanceDB<T>(pub T);
|
||||
|
||||
Reference in New Issue
Block a user