mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
feat: support new FTS features in python SDK (#2411)
- AND operator - phrase query slop param - boolean query <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for combining full-text search queries using AND/OR operators, enabling more flexible query composition. - Introduced new query types and parameters, including boolean queries, operator selection, occurrence constraints, and phrase slop for advanced search scenarios. - Enhanced asynchronous search to accept rich full-text query objects directly. - **Bug Fixes** - Improved handling and validation of full-text search queries in both synchronous and asynchronous search operations. - **Tests** - Updated and expanded tests to cover new full-text query types and their usage in search functions. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -165,17 +165,14 @@ class HybridQuery:
|
|||||||
def get_with_row_id(self) -> bool: ...
|
def get_with_row_id(self) -> bool: ...
|
||||||
def to_query_request(self) -> PyQueryRequest: ...
|
def to_query_request(self) -> PyQueryRequest: ...
|
||||||
|
|
||||||
class PyFullTextSearchQuery:
|
class FullTextQuery:
|
||||||
columns: Optional[List[str]]
|
pass
|
||||||
query: str
|
|
||||||
limit: Optional[int]
|
|
||||||
wand_factor: Optional[float]
|
|
||||||
|
|
||||||
class PyQueryRequest:
|
class PyQueryRequest:
|
||||||
limit: Optional[int]
|
limit: Optional[int]
|
||||||
offset: Optional[int]
|
offset: Optional[int]
|
||||||
filter: Optional[Union[str, bytes]]
|
filter: Optional[Union[str, bytes]]
|
||||||
full_text_search: Optional[PyFullTextSearchQuery]
|
full_text_search: Optional[FullTextQuery]
|
||||||
select: Optional[Union[str, List[str]]]
|
select: Optional[Union[str, List[str]]]
|
||||||
fast_search: Optional[bool]
|
fast_search: Optional[bool]
|
||||||
with_row_id: Optional[bool]
|
with_row_id: Optional[bool]
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import abc
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
@@ -88,15 +87,27 @@ def ensure_vector_query(
|
|||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
class FullTextQueryType(Enum):
|
class FullTextQueryType(str, Enum):
|
||||||
MATCH = "match"
|
MATCH = "match"
|
||||||
MATCH_PHRASE = "match_phrase"
|
MATCH_PHRASE = "match_phrase"
|
||||||
BOOST = "boost"
|
BOOST = "boost"
|
||||||
MULTI_MATCH = "multi_match"
|
MULTI_MATCH = "multi_match"
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
|
||||||
|
|
||||||
class FullTextQuery(abc.ABC, pydantic.BaseModel):
|
class FullTextOperator(str, Enum):
|
||||||
@abc.abstractmethod
|
AND = "AND"
|
||||||
|
OR = "OR"
|
||||||
|
|
||||||
|
|
||||||
|
class Occur(str, Enum):
|
||||||
|
MUST = "MUST"
|
||||||
|
SHOULD = "SHOULD"
|
||||||
|
|
||||||
|
|
||||||
|
@pydantic.dataclasses.dataclass
|
||||||
|
class FullTextQuery(ABC):
|
||||||
|
@abstractmethod
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
"""
|
"""
|
||||||
Get the query type of the query.
|
Get the query type of the query.
|
||||||
@@ -106,193 +117,174 @@ class FullTextQuery(abc.ABC, pydantic.BaseModel):
|
|||||||
str
|
str
|
||||||
The type of the query.
|
The type of the query.
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
def __and__(self, other: "FullTextQuery") -> "FullTextQuery":
|
||||||
def to_dict(self) -> dict:
|
|
||||||
"""
|
"""
|
||||||
Convert the query to a dictionary.
|
Combine two queries with a logical AND operation.
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
The query as a dictionary.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class MatchQuery(FullTextQuery):
|
|
||||||
query: str
|
|
||||||
column: str
|
|
||||||
boost: float = 1.0
|
|
||||||
fuzziness: int = 0
|
|
||||||
max_expansions: int = 50
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
column: str,
|
|
||||||
*,
|
|
||||||
boost: float = 1.0,
|
|
||||||
fuzziness: int = 0,
|
|
||||||
max_expansions: int = 50,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Match query for full-text search.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
query : str
|
other : FullTextQuery
|
||||||
The query string to match against.
|
The other query to combine with.
|
||||||
column : str
|
|
||||||
The name of the column to match against.
|
Returns
|
||||||
boost : float, default 1.0
|
-------
|
||||||
The boost factor for the query.
|
FullTextQuery
|
||||||
The score of each matching document is multiplied by this value.
|
A new query that combines both queries with AND.
|
||||||
fuzziness : int, optional
|
|
||||||
The maximum edit distance for each term in the match query.
|
|
||||||
Defaults to 0 (exact match).
|
|
||||||
If None, fuzziness is applied automatically by the rules:
|
|
||||||
- 0 for terms with length <= 2
|
|
||||||
- 1 for terms with length <= 5
|
|
||||||
- 2 for terms with length > 5
|
|
||||||
max_expansions : int, optional
|
|
||||||
The maximum number of terms to consider for fuzzy matching.
|
|
||||||
Defaults to 50.
|
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
return BooleanQuery([(Occur.MUST, self), (Occur.MUST, other)])
|
||||||
query=query,
|
|
||||||
column=column,
|
def __or__(self, other: "FullTextQuery") -> "FullTextQuery":
|
||||||
boost=boost,
|
"""
|
||||||
fuzziness=fuzziness,
|
Combine two queries with a logical OR operation.
|
||||||
max_expansions=max_expansions,
|
|
||||||
)
|
Parameters
|
||||||
|
----------
|
||||||
|
other : FullTextQuery
|
||||||
|
The other query to combine with.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
FullTextQuery
|
||||||
|
A new query that combines both queries with OR.
|
||||||
|
"""
|
||||||
|
return BooleanQuery([(Occur.SHOULD, self), (Occur.SHOULD, other)])
|
||||||
|
|
||||||
|
|
||||||
|
@pydantic.dataclasses.dataclass
|
||||||
|
class MatchQuery(FullTextQuery):
|
||||||
|
"""
|
||||||
|
Match query for full-text search.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : str
|
||||||
|
The query string to match against.
|
||||||
|
column : str
|
||||||
|
The name of the column to match against.
|
||||||
|
boost : float, default 1.0
|
||||||
|
The boost factor for the query.
|
||||||
|
The score of each matching document is multiplied by this value.
|
||||||
|
fuzziness : int, optional
|
||||||
|
The maximum edit distance for each term in the match query.
|
||||||
|
Defaults to 0 (exact match).
|
||||||
|
If None, fuzziness is applied automatically by the rules:
|
||||||
|
- 0 for terms with length <= 2
|
||||||
|
- 1 for terms with length <= 5
|
||||||
|
- 2 for terms with length > 5
|
||||||
|
max_expansions : int, optional
|
||||||
|
The maximum number of terms to consider for fuzzy matching.
|
||||||
|
Defaults to 50.
|
||||||
|
operator : FullTextOperator, default OR
|
||||||
|
The operator to use for combining the query results.
|
||||||
|
Can be either `AND` or `OR`.
|
||||||
|
If `AND`, all terms in the query must match.
|
||||||
|
If `OR`, at least one term in the query must match.
|
||||||
|
"""
|
||||||
|
|
||||||
|
query: str
|
||||||
|
column: str
|
||||||
|
boost: float = pydantic.Field(1.0, kw_only=True)
|
||||||
|
fuzziness: int = pydantic.Field(0, kw_only=True)
|
||||||
|
max_expansions: int = pydantic.Field(50, kw_only=True)
|
||||||
|
operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True)
|
||||||
|
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
return FullTextQueryType.MATCH
|
return FullTextQueryType.MATCH
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"match": {
|
|
||||||
self.column: {
|
|
||||||
"query": self.query,
|
|
||||||
"boost": self.boost,
|
|
||||||
"fuzziness": self.fuzziness,
|
|
||||||
"max_expansions": self.max_expansions,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
@pydantic.dataclasses.dataclass
|
||||||
class PhraseQuery(FullTextQuery):
|
class PhraseQuery(FullTextQuery):
|
||||||
|
"""
|
||||||
|
Phrase query for full-text search.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : str
|
||||||
|
The query string to match against.
|
||||||
|
column : str
|
||||||
|
The name of the column to match against.
|
||||||
|
"""
|
||||||
|
|
||||||
query: str
|
query: str
|
||||||
column: str
|
column: str
|
||||||
|
slop: int = pydantic.Field(0, kw_only=True)
|
||||||
def __init__(self, query: str, column: str):
|
|
||||||
"""
|
|
||||||
Phrase query for full-text search.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
query : str
|
|
||||||
The query string to match against.
|
|
||||||
column : str
|
|
||||||
The name of the column to match against.
|
|
||||||
"""
|
|
||||||
super().__init__(query=query, column=column)
|
|
||||||
|
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
return FullTextQueryType.MATCH_PHRASE
|
return FullTextQueryType.MATCH_PHRASE
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"match_phrase": {
|
|
||||||
self.column: self.query,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
@pydantic.dataclasses.dataclass
|
||||||
class BoostQuery(FullTextQuery):
|
class BoostQuery(FullTextQuery):
|
||||||
|
"""
|
||||||
|
Boost query for full-text search.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
positive : dict
|
||||||
|
The positive query object.
|
||||||
|
negative : dict
|
||||||
|
The negative query object.
|
||||||
|
negative_boost : float, default 0.5
|
||||||
|
The boost factor for the negative query.
|
||||||
|
"""
|
||||||
|
|
||||||
positive: FullTextQuery
|
positive: FullTextQuery
|
||||||
negative: FullTextQuery
|
negative: FullTextQuery
|
||||||
negative_boost: float = 0.5
|
negative_boost: float = pydantic.Field(0.5, kw_only=True)
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
positive: FullTextQuery,
|
|
||||||
negative: FullTextQuery,
|
|
||||||
*,
|
|
||||||
negative_boost: float = 0.5,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Boost query for full-text search.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
positive : dict
|
|
||||||
The positive query object.
|
|
||||||
negative : dict
|
|
||||||
The negative query object.
|
|
||||||
negative_boost : float
|
|
||||||
The boost factor for the negative query.
|
|
||||||
"""
|
|
||||||
super().__init__(
|
|
||||||
positive=positive, negative=negative, negative_boost=negative_boost
|
|
||||||
)
|
|
||||||
|
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
return FullTextQueryType.BOOST
|
return FullTextQueryType.BOOST
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"boost": {
|
|
||||||
"positive": self.positive.to_dict(),
|
|
||||||
"negative": self.negative.to_dict(),
|
|
||||||
"negative_boost": self.negative_boost,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
@pydantic.dataclasses.dataclass
|
||||||
class MultiMatchQuery(FullTextQuery):
|
class MultiMatchQuery(FullTextQuery):
|
||||||
|
"""
|
||||||
|
Multi-match query for full-text search.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : str | list[Query]
|
||||||
|
If a string, the query string to match against.
|
||||||
|
columns : list[str]
|
||||||
|
The list of columns to match against.
|
||||||
|
boosts : list[float], optional
|
||||||
|
The list of boost factors for each column. If not provided,
|
||||||
|
all columns will have the same boost factor.
|
||||||
|
operator : FullTextOperator, default OR
|
||||||
|
The operator to use for combining the query results.
|
||||||
|
Can be either `AND` or `OR`.
|
||||||
|
It would be applied to all columns individually.
|
||||||
|
For example, if the operator is `AND`,
|
||||||
|
then the query "hello world" is equal to
|
||||||
|
`match("hello AND world", column1) OR match("hello AND world", column2)`.
|
||||||
|
"""
|
||||||
|
|
||||||
query: str
|
query: str
|
||||||
columns: list[str]
|
columns: list[str]
|
||||||
boosts: list[float]
|
boosts: Optional[list[float]] = pydantic.Field(None, kw_only=True)
|
||||||
|
operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True)
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
columns: list[str],
|
|
||||||
*,
|
|
||||||
boosts: Optional[list[float]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Multi-match query for full-text search.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
query : str
|
|
||||||
The query string to match against.
|
|
||||||
|
|
||||||
columns : list[str]
|
|
||||||
The list of columns to match against.
|
|
||||||
|
|
||||||
boosts : list[float], optional
|
|
||||||
The list of boost factors for each column. If not provided,
|
|
||||||
all columns will have the same boost factor.
|
|
||||||
"""
|
|
||||||
if boosts is None:
|
|
||||||
boosts = [1.0] * len(columns)
|
|
||||||
super().__init__(query=query, columns=columns, boosts=boosts)
|
|
||||||
|
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
return FullTextQueryType.MULTI_MATCH
|
return FullTextQueryType.MULTI_MATCH
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
@pydantic.dataclasses.dataclass
|
||||||
"multi_match": {
|
class BooleanQuery(FullTextQuery):
|
||||||
"query": self.query,
|
"""
|
||||||
"columns": self.columns,
|
Boolean query for full-text search.
|
||||||
"boost": self.boosts,
|
|
||||||
}
|
Parameters
|
||||||
}
|
----------
|
||||||
|
queries : list[tuple(Occur, FullTextQuery)]
|
||||||
|
The list of queries with their occurrence requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
queries: list[tuple[Occur, FullTextQuery]]
|
||||||
|
|
||||||
|
def query_type(self) -> FullTextQueryType:
|
||||||
|
return FullTextQueryType.BOOLEAN
|
||||||
|
|
||||||
|
|
||||||
class FullTextSearchQuery(pydantic.BaseModel):
|
class FullTextSearchQuery(pydantic.BaseModel):
|
||||||
@@ -493,10 +485,8 @@ class Query(pydantic.BaseModel):
|
|||||||
query.postfilter = req.postfilter
|
query.postfilter = req.postfilter
|
||||||
if req.full_text_search is not None:
|
if req.full_text_search is not None:
|
||||||
query.full_text_query = FullTextSearchQuery(
|
query.full_text_query = FullTextSearchQuery(
|
||||||
columns=req.full_text_search.columns,
|
columns=None,
|
||||||
query=req.full_text_search.query,
|
query=req.full_text_search,
|
||||||
limit=req.full_text_search.limit,
|
|
||||||
wand_factor=req.full_text_search.wand_factor,
|
|
||||||
)
|
)
|
||||||
return query
|
return query
|
||||||
|
|
||||||
@@ -2513,7 +2503,7 @@ class AsyncQuery(AsyncQueryBase):
|
|||||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||||
)
|
)
|
||||||
# FullTextQuery object
|
# FullTextQuery object
|
||||||
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}))
|
||||||
|
|
||||||
|
|
||||||
class AsyncFTSQuery(AsyncQueryBase):
|
class AsyncFTSQuery(AsyncQueryBase):
|
||||||
@@ -2835,7 +2825,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
|||||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||||
)
|
)
|
||||||
# FullTextQuery object
|
# FullTextQuery object
|
||||||
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query}))
|
||||||
|
|
||||||
async def to_batches(
|
async def to_batches(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -215,6 +215,19 @@ def test_search_fts(table, use_tantivy):
|
|||||||
assert len(results) == 5
|
assert len(results) == 5
|
||||||
assert len(results[0]) == 3 # id, text, _score
|
assert len(results[0]) == 3 # id, text, _score
|
||||||
|
|
||||||
|
# Test boolean query
|
||||||
|
results = (
|
||||||
|
table.search(MatchQuery("puppy", "text") & MatchQuery("runs", "text"))
|
||||||
|
.select(["id", "text"])
|
||||||
|
.limit(5)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) == 5
|
||||||
|
assert len(results[0]) == 3 # id, text, _score
|
||||||
|
for r in results:
|
||||||
|
assert "puppy" in r["text"]
|
||||||
|
assert "runs" in r["text"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fts_select_async(async_table):
|
async def test_fts_select_async(async_table):
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from lancedb.query import (
|
|||||||
AsyncQueryBase,
|
AsyncQueryBase,
|
||||||
AsyncVectorQuery,
|
AsyncVectorQuery,
|
||||||
LanceVectorQueryBuilder,
|
LanceVectorQueryBuilder,
|
||||||
|
MatchQuery,
|
||||||
|
PhraseQuery,
|
||||||
Query,
|
Query,
|
||||||
FullTextSearchQuery,
|
FullTextSearchQuery,
|
||||||
)
|
)
|
||||||
@@ -1065,18 +1067,27 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# FTS queries
|
# FTS queries
|
||||||
q = (await table_async.search("foo")).limit(10).to_query_object()
|
match_query = MatchQuery("foo", "text")
|
||||||
|
q = (await table_async.search(match_query)).limit(10).to_query_object()
|
||||||
check_set_props(
|
check_set_props(
|
||||||
q,
|
q,
|
||||||
limit=10,
|
limit=10,
|
||||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
full_text_query=FullTextSearchQuery(columns=None, query=match_query),
|
||||||
with_row_id=False,
|
with_row_id=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
q = (await table_async.search("foo", query_type="fts")).to_query_object()
|
q = (await table_async.search(match_query)).to_query_object()
|
||||||
check_set_props(
|
check_set_props(
|
||||||
q,
|
q,
|
||||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
full_text_query=FullTextSearchQuery(columns=None, query=match_query),
|
||||||
|
with_row_id=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
phrase_query = PhraseQuery("foo", "text", slop=1)
|
||||||
|
q = (await table_async.search(phrase_query)).to_query_object()
|
||||||
|
check_set_props(
|
||||||
|
q,
|
||||||
|
full_text_query=FullTextSearchQuery(columns=None, query=phrase_query),
|
||||||
with_row_id=False,
|
with_row_id=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -9,15 +9,16 @@ use arrow::array::Array;
|
|||||||
use arrow::array::ArrayData;
|
use arrow::array::ArrayData;
|
||||||
use arrow::pyarrow::FromPyArrow;
|
use arrow::pyarrow::FromPyArrow;
|
||||||
use arrow::pyarrow::IntoPyArrow;
|
use arrow::pyarrow::IntoPyArrow;
|
||||||
use lancedb::index::scalar::{FtsQuery, FullTextSearchQuery, MatchQuery, PhraseQuery};
|
use lancedb::index::scalar::{
|
||||||
|
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
||||||
|
Operator, PhraseQuery,
|
||||||
|
};
|
||||||
use lancedb::query::QueryExecutionOptions;
|
use lancedb::query::QueryExecutionOptions;
|
||||||
use lancedb::query::QueryFilter;
|
use lancedb::query::QueryFilter;
|
||||||
use lancedb::query::{
|
use lancedb::query::{
|
||||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||||
};
|
};
|
||||||
use lancedb::table::AnyQuery;
|
use lancedb::table::AnyQuery;
|
||||||
use pyo3::exceptions::PyRuntimeError;
|
|
||||||
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
|
|
||||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||||
use pyo3::pymethods;
|
use pyo3::pymethods;
|
||||||
use pyo3::types::PyList;
|
use pyo3::types::PyList;
|
||||||
@@ -27,34 +28,182 @@ use pyo3::IntoPyObject;
|
|||||||
use pyo3::PyAny;
|
use pyo3::PyAny;
|
||||||
use pyo3::PyRef;
|
use pyo3::PyRef;
|
||||||
use pyo3::PyResult;
|
use pyo3::PyResult;
|
||||||
|
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
|
||||||
|
use pyo3::{
|
||||||
|
exceptions::{PyNotImplementedError, PyValueError},
|
||||||
|
intern,
|
||||||
|
};
|
||||||
use pyo3::{pyclass, PyErr};
|
use pyo3::{pyclass, PyErr};
|
||||||
use pyo3_async_runtimes::tokio::future_into_py;
|
use pyo3_async_runtimes::tokio::future_into_py;
|
||||||
|
|
||||||
use crate::arrow::RecordBatchStream;
|
use crate::util::parse_distance_type;
|
||||||
use crate::error::PythonErrorExt;
|
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||||
use crate::util::{parse_distance_type, parse_fts_query};
|
use crate::{error::PythonErrorExt, index::class_name};
|
||||||
|
|
||||||
// Python representation of full text search parameters
|
impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
|
||||||
#[derive(Clone)]
|
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
|
||||||
#[pyclass(get_all)]
|
match class_name(ob)?.as_str() {
|
||||||
pub struct PyFullTextSearchQuery {
|
"MatchQuery" => {
|
||||||
pub columns: Vec<String>,
|
let query = ob.getattr("query")?.extract()?;
|
||||||
pub query: String,
|
let column = ob.getattr("column")?.extract()?;
|
||||||
pub limit: Option<i64>,
|
let boost = ob.getattr("boost")?.extract()?;
|
||||||
pub wand_factor: Option<f32>,
|
let fuzziness = ob.getattr("fuzziness")?.extract()?;
|
||||||
|
let max_expansions = ob.getattr("max_expansions")?.extract()?;
|
||||||
|
let operator = ob.getattr("operator")?.extract::<String>()?;
|
||||||
|
|
||||||
|
Ok(PyLanceDB(
|
||||||
|
MatchQuery::new(query)
|
||||||
|
.with_column(Some(column))
|
||||||
|
.with_boost(boost)
|
||||||
|
.with_fuzziness(fuzziness)
|
||||||
|
.with_max_expansions(max_expansions)
|
||||||
|
.with_operator(Operator::try_from(operator.as_str()).map_err(|e| {
|
||||||
|
PyValueError::new_err(format!("Invalid operator: {}", e))
|
||||||
|
})?)
|
||||||
|
.into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
"PhraseQuery" => {
|
||||||
|
let query = ob.getattr("query")?.extract()?;
|
||||||
|
let column = ob.getattr("column")?.extract()?;
|
||||||
|
let slop = ob.getattr("slop")?.extract()?;
|
||||||
|
|
||||||
|
Ok(PyLanceDB(
|
||||||
|
PhraseQuery::new(query)
|
||||||
|
.with_column(Some(column))
|
||||||
|
.with_slop(slop)
|
||||||
|
.into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
"BoostQuery" => {
|
||||||
|
let positive: PyLanceDB<FtsQuery> = ob.getattr("positive")?.extract()?;
|
||||||
|
let negative: PyLanceDB<FtsQuery> = ob.getattr("negative")?.extract()?;
|
||||||
|
let negative_boost = ob.getattr("negative_boost")?.extract()?;
|
||||||
|
Ok(PyLanceDB(
|
||||||
|
BoostQuery::new(positive.0, negative.0, negative_boost).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
"MultiMatchQuery" => {
|
||||||
|
let query = ob.getattr("query")?.extract()?;
|
||||||
|
let columns = ob.getattr("columns")?.extract()?;
|
||||||
|
let boosts: Option<Vec<f32>> = ob.getattr("boosts")?.extract()?;
|
||||||
|
let operator: String = ob.getattr("operator")?.extract()?;
|
||||||
|
|
||||||
|
let q = MultiMatchQuery::try_new(query, columns)
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("Invalid query: {}", e)))?;
|
||||||
|
let q = if let Some(boosts) = boosts {
|
||||||
|
q.try_with_boosts(boosts)
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("Invalid boosts: {}", e)))?
|
||||||
|
} else {
|
||||||
|
q
|
||||||
|
};
|
||||||
|
|
||||||
|
let op = Operator::try_from(operator.as_str())
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("Invalid operator: {}", e)))?;
|
||||||
|
|
||||||
|
Ok(PyLanceDB(q.with_operator(op).into()))
|
||||||
|
}
|
||||||
|
"BooleanQuery" => {
|
||||||
|
let queries: Vec<(String, PyLanceDB<FtsQuery>)> =
|
||||||
|
ob.getattr("queries")?.extract()?;
|
||||||
|
let mut sub_queries = Vec::with_capacity(queries.len());
|
||||||
|
for (occur, q) in queries {
|
||||||
|
let occur = Occur::try_from(occur.as_str())
|
||||||
|
.map_err(|e| PyValueError::new_err(e.to_string()))?;
|
||||||
|
sub_queries.push((occur, q.0));
|
||||||
|
}
|
||||||
|
Ok(PyLanceDB(BooleanQuery::new(sub_queries).into()))
|
||||||
|
}
|
||||||
|
name => Err(PyValueError::new_err(format!(
|
||||||
|
"Unsupported FTS query type: {}",
|
||||||
|
name
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<FullTextSearchQuery> for PyFullTextSearchQuery {
|
impl<'py> IntoPyObject<'py> for PyLanceDB<FtsQuery> {
|
||||||
fn from(query: FullTextSearchQuery) -> Self {
|
type Target = PyAny;
|
||||||
Self {
|
type Output = Bound<'py, Self::Target>;
|
||||||
columns: query.columns().into_iter().collect(),
|
type Error = PyErr;
|
||||||
query: query.query.query().to_owned(),
|
|
||||||
limit: query.limit,
|
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||||
wand_factor: query.wand_factor,
|
let namespace = py
|
||||||
|
.import(intern!(py, "lancedb"))
|
||||||
|
.and_then(|m| m.getattr(intern!(py, "query")))
|
||||||
|
.expect("Failed to import namespace");
|
||||||
|
|
||||||
|
match self.0 {
|
||||||
|
FtsQuery::Match(query) => {
|
||||||
|
let kwargs = PyDict::new(py);
|
||||||
|
kwargs.set_item("boost", query.boost)?;
|
||||||
|
kwargs.set_item("fuzziness", query.fuzziness)?;
|
||||||
|
kwargs.set_item("max_expansions", query.max_expansions)?;
|
||||||
|
kwargs.set_item("operator", operator_to_str(query.operator))?;
|
||||||
|
namespace
|
||||||
|
.getattr(intern!(py, "MatchQuery"))?
|
||||||
|
.call((query.terms, query.column.unwrap()), Some(&kwargs))
|
||||||
|
}
|
||||||
|
FtsQuery::Phrase(query) => {
|
||||||
|
let kwargs = PyDict::new(py);
|
||||||
|
kwargs.set_item("slop", query.slop)?;
|
||||||
|
namespace
|
||||||
|
.getattr(intern!(py, "PhraseQuery"))?
|
||||||
|
.call((query.terms, query.column.unwrap()), Some(&kwargs))
|
||||||
|
}
|
||||||
|
FtsQuery::Boost(query) => {
|
||||||
|
let positive = PyLanceDB(query.positive.as_ref().clone()).into_pyobject(py)?;
|
||||||
|
let negative = PyLanceDB(query.negative.as_ref().clone()).into_pyobject(py)?;
|
||||||
|
let kwargs = PyDict::new(py);
|
||||||
|
kwargs.set_item("negative_boost", query.negative_boost)?;
|
||||||
|
namespace
|
||||||
|
.getattr(intern!(py, "BoostQuery"))?
|
||||||
|
.call((positive, negative), Some(&kwargs))
|
||||||
|
}
|
||||||
|
FtsQuery::MultiMatch(query) => {
|
||||||
|
let first = &query.match_queries[0];
|
||||||
|
let (columns, boosts): (Vec<_>, Vec<_>) = query
|
||||||
|
.match_queries
|
||||||
|
.iter()
|
||||||
|
.map(|q| (q.column.as_ref().unwrap().clone(), q.boost))
|
||||||
|
.unzip();
|
||||||
|
let kwargs = PyDict::new(py);
|
||||||
|
kwargs.set_item("boosts", boosts)?;
|
||||||
|
kwargs.set_item("operator", operator_to_str(first.operator))?;
|
||||||
|
namespace
|
||||||
|
.getattr(intern!(py, "MultiMatchQuery"))?
|
||||||
|
.call((first.terms.clone(), columns), Some(&kwargs))
|
||||||
|
}
|
||||||
|
FtsQuery::Boolean(query) => {
|
||||||
|
let mut queries = Vec::with_capacity(query.must.len() + query.should.len());
|
||||||
|
for q in query.must {
|
||||||
|
queries.push((occur_to_str(Occur::Must), PyLanceDB(q).into_pyobject(py)?));
|
||||||
|
}
|
||||||
|
for q in query.should {
|
||||||
|
queries.push((occur_to_str(Occur::Should), PyLanceDB(q).into_pyobject(py)?));
|
||||||
|
}
|
||||||
|
namespace
|
||||||
|
.getattr(intern!(py, "BooleanQuery"))?
|
||||||
|
.call1((queries,))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn operator_to_str(op: Operator) -> &'static str {
|
||||||
|
match op {
|
||||||
|
Operator::And => "AND",
|
||||||
|
Operator::Or => "OR",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn occur_to_str(occur: Occur) -> &'static str {
|
||||||
|
match occur {
|
||||||
|
Occur::Must => "MUST",
|
||||||
|
Occur::Should => "SHOULD",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Python representation of query vector(s)
|
// Python representation of query vector(s)
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct PyQueryVectors(Vec<Arc<dyn Array>>);
|
pub struct PyQueryVectors(Vec<Arc<dyn Array>>);
|
||||||
@@ -80,7 +229,7 @@ pub struct PyQueryRequest {
|
|||||||
pub limit: Option<usize>,
|
pub limit: Option<usize>,
|
||||||
pub offset: Option<usize>,
|
pub offset: Option<usize>,
|
||||||
pub filter: Option<PyQueryFilter>,
|
pub filter: Option<PyQueryFilter>,
|
||||||
pub full_text_search: Option<PyFullTextSearchQuery>,
|
pub full_text_search: Option<PyLanceDB<FtsQuery>>,
|
||||||
pub select: PySelect,
|
pub select: PySelect,
|
||||||
pub fast_search: Option<bool>,
|
pub fast_search: Option<bool>,
|
||||||
pub with_row_id: Option<bool>,
|
pub with_row_id: Option<bool>,
|
||||||
@@ -106,7 +255,7 @@ impl From<AnyQuery> for PyQueryRequest {
|
|||||||
filter: query_request.filter.map(PyQueryFilter),
|
filter: query_request.filter.map(PyQueryFilter),
|
||||||
full_text_search: query_request
|
full_text_search: query_request
|
||||||
.full_text_search
|
.full_text_search
|
||||||
.map(PyFullTextSearchQuery::from),
|
.map(|fts| PyLanceDB(fts.query)),
|
||||||
select: PySelect(query_request.select),
|
select: PySelect(query_request.select),
|
||||||
fast_search: Some(query_request.fast_search),
|
fast_search: Some(query_request.fast_search),
|
||||||
with_row_id: Some(query_request.with_row_id),
|
with_row_id: Some(query_request.with_row_id),
|
||||||
@@ -269,8 +418,8 @@ impl Query {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let mut query = FullTextSearchQuery::new_query(query);
|
let mut query = FullTextSearchQuery::new_query(query);
|
||||||
if let Some(cols) = columns {
|
match columns {
|
||||||
if !cols.is_empty() {
|
Some(cols) if !cols.is_empty() => {
|
||||||
query = query.with_columns(&cols).map_err(|e| {
|
query = query.with_columns(&cols).map_err(|e| {
|
||||||
PyValueError::new_err(format!(
|
PyValueError::new_err(format!(
|
||||||
"Failed to set full text search columns: {}",
|
"Failed to set full text search columns: {}",
|
||||||
@@ -278,15 +427,12 @@ impl Query {
|
|||||||
))
|
))
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
query
|
query
|
||||||
} else if let Ok(query) = fts_query.downcast::<PyDict>() {
|
|
||||||
let query = parse_fts_query(query)?;
|
|
||||||
FullTextSearchQuery::new_query(query)
|
|
||||||
} else {
|
} else {
|
||||||
return Err(PyValueError::new_err(
|
let query = fts_query.extract::<PyLanceDB<FtsQuery>>()?;
|
||||||
"query must be a string or a Query object",
|
FullTextSearchQuery::new_query(query.0)
|
||||||
));
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(FTSQuery {
|
Ok(FTSQuery {
|
||||||
|
|||||||
@@ -3,15 +3,11 @@
|
|||||||
|
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
|
||||||
use lancedb::index::scalar::{BoostQuery, FtsQuery, MatchQuery, MultiMatchQuery, PhraseQuery};
|
|
||||||
use lancedb::DistanceType;
|
use lancedb::DistanceType;
|
||||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods};
|
|
||||||
use pyo3::types::PyDict;
|
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
exceptions::{PyRuntimeError, PyValueError},
|
exceptions::{PyRuntimeError, PyValueError},
|
||||||
pyfunction, PyResult,
|
pyfunction, PyResult,
|
||||||
};
|
};
|
||||||
use pyo3::{Bound, PyAny};
|
|
||||||
|
|
||||||
/// A wrapper around a rust builder
|
/// A wrapper around a rust builder
|
||||||
///
|
///
|
||||||
@@ -64,116 +60,6 @@ pub fn validate_table_name(table_name: &str) -> PyResult<()> {
|
|||||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
.map_err(|e| PyValueError::new_err(e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_fts_query(query: &Bound<'_, PyDict>) -> PyResult<FtsQuery> {
|
/// A wrapper around a LanceDB type to allow it to be used in Python
|
||||||
let query_type = query.keys().get_item(0)?.extract::<String>()?;
|
#[derive(Debug, Clone)]
|
||||||
let query_value = query
|
pub struct PyLanceDB<T>(pub T);
|
||||||
.get_item(&query_type)?
|
|
||||||
.ok_or(PyValueError::new_err(format!(
|
|
||||||
"Query type {} not found",
|
|
||||||
query_type
|
|
||||||
)))?;
|
|
||||||
let query_value = query_value.downcast::<PyDict>()?;
|
|
||||||
|
|
||||||
match query_type.as_str() {
|
|
||||||
"match" => {
|
|
||||||
let column = query_value.keys().get_item(0)?.extract::<String>()?;
|
|
||||||
let params = query_value
|
|
||||||
.get_item(&column)?
|
|
||||||
.ok_or(PyValueError::new_err(format!(
|
|
||||||
"column {} not found",
|
|
||||||
column
|
|
||||||
)))?;
|
|
||||||
let params = params.downcast::<PyDict>()?;
|
|
||||||
|
|
||||||
let query = params
|
|
||||||
.get_item("query")?
|
|
||||||
.ok_or(PyValueError::new_err("query not found"))?
|
|
||||||
.extract::<String>()?;
|
|
||||||
let boost = params
|
|
||||||
.get_item("boost")?
|
|
||||||
.ok_or(PyValueError::new_err("boost not found"))?
|
|
||||||
.extract::<f32>()?;
|
|
||||||
let fuzziness = params
|
|
||||||
.get_item("fuzziness")?
|
|
||||||
.ok_or(PyValueError::new_err("fuzziness not found"))?
|
|
||||||
.extract::<Option<u32>>()?;
|
|
||||||
let max_expansions = params
|
|
||||||
.get_item("max_expansions")?
|
|
||||||
.ok_or(PyValueError::new_err("max_expansions not found"))?
|
|
||||||
.extract::<usize>()?;
|
|
||||||
|
|
||||||
let query = MatchQuery::new(query)
|
|
||||||
.with_column(Some(column))
|
|
||||||
.with_boost(boost)
|
|
||||||
.with_fuzziness(fuzziness)
|
|
||||||
.with_max_expansions(max_expansions);
|
|
||||||
Ok(query.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
"match_phrase" => {
|
|
||||||
let column = query_value.keys().get_item(0)?.extract::<String>()?;
|
|
||||||
let query = query_value
|
|
||||||
.get_item(&column)?
|
|
||||||
.ok_or(PyValueError::new_err(format!(
|
|
||||||
"column {} not found",
|
|
||||||
column
|
|
||||||
)))?
|
|
||||||
.extract::<String>()?;
|
|
||||||
|
|
||||||
let query = PhraseQuery::new(query).with_column(Some(column));
|
|
||||||
Ok(query.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
"boost" => {
|
|
||||||
let positive: Bound<'_, PyAny> = query_value
|
|
||||||
.get_item("positive")?
|
|
||||||
.ok_or(PyValueError::new_err("positive not found"))?;
|
|
||||||
let positive = positive.downcast::<PyDict>()?;
|
|
||||||
|
|
||||||
let negative = query_value
|
|
||||||
.get_item("negative")?
|
|
||||||
.ok_or(PyValueError::new_err("negative not found"))?;
|
|
||||||
let negative = negative.downcast::<PyDict>()?;
|
|
||||||
|
|
||||||
let negative_boost = query_value
|
|
||||||
.get_item("negative_boost")?
|
|
||||||
.ok_or(PyValueError::new_err("negative_boost not found"))?
|
|
||||||
.extract::<f32>()?;
|
|
||||||
|
|
||||||
let positive_query = parse_fts_query(positive)?;
|
|
||||||
let negative_query = parse_fts_query(negative)?;
|
|
||||||
let query = BoostQuery::new(positive_query, negative_query, Some(negative_boost));
|
|
||||||
|
|
||||||
Ok(query.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
"multi_match" => {
|
|
||||||
let query = query_value
|
|
||||||
.get_item("query")?
|
|
||||||
.ok_or(PyValueError::new_err("query not found"))?
|
|
||||||
.extract::<String>()?;
|
|
||||||
|
|
||||||
let columns = query_value
|
|
||||||
.get_item("columns")?
|
|
||||||
.ok_or(PyValueError::new_err("columns not found"))?
|
|
||||||
.extract::<Vec<String>>()?;
|
|
||||||
|
|
||||||
let boost = query_value
|
|
||||||
.get_item("boost")?
|
|
||||||
.ok_or(PyValueError::new_err("boost not found"))?
|
|
||||||
.extract::<Vec<f32>>()?;
|
|
||||||
|
|
||||||
let query = MultiMatchQuery::try_new(query, columns)
|
|
||||||
.and_then(|q| q.try_with_boosts(boost))
|
|
||||||
.map_err(|e| {
|
|
||||||
PyValueError::new_err(format!("Error creating MultiMatchQuery: {}", e))
|
|
||||||
})?;
|
|
||||||
Ok(query.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => Err(PyValueError::new_err(format!(
|
|
||||||
"Unsupported query type: {}",
|
|
||||||
query_type
|
|
||||||
))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user