mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 22:29:58 +00:00
fix: can't do structured FTS in python (#2300)
Structured FTS was not supported in the `search()` API, and there were some pydantic errors.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Enhanced full-text search capabilities by incorporating additional parameters, enabling more flexible query definitions.
  - Extended table search functionality to support full-text queries alongside existing search types.
- **Tests**
  - Introduced new tests that validate both structured and conditional full-text search behaviors.
  - Expanded test coverage for various query types, including MatchQuery, BoostQuery, MultiMatchQuery, and PhraseQuery.
- **Bug Fixes**
  - Fixed a logic issue in query processing to ensure correct handling of full-text search queries.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -117,6 +117,12 @@ class FullTextQuery(abc.ABC, pydantic.BaseModel):
|
||||
|
||||
|
||||
class MatchQuery(FullTextQuery):
|
||||
query: str
|
||||
column: str
|
||||
boost: float = 1.0
|
||||
fuzziness: int = 0
|
||||
max_expansions: int = 50
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
@@ -149,11 +155,13 @@ class MatchQuery(FullTextQuery):
|
||||
The maximum number of terms to consider for fuzzy matching.
|
||||
Defaults to 50.
|
||||
"""
|
||||
self.column = column
|
||||
self.query = query
|
||||
self.boost = boost
|
||||
self.fuzziness = fuzziness
|
||||
self.max_expansions = max_expansions
|
||||
super().__init__(
|
||||
query=query,
|
||||
column=column,
|
||||
boost=boost,
|
||||
fuzziness=fuzziness,
|
||||
max_expansions=max_expansions,
|
||||
)
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.MATCH
|
||||
@@ -172,6 +180,9 @@ class MatchQuery(FullTextQuery):
|
||||
|
||||
|
||||
class PhraseQuery(FullTextQuery):
|
||||
query: str
|
||||
column: str
|
||||
|
||||
def __init__(self, query: str, column: str):
|
||||
"""
|
||||
Phrase query for full-text search.
|
||||
@@ -183,8 +194,7 @@ class PhraseQuery(FullTextQuery):
|
||||
column : str
|
||||
The name of the column to match against.
|
||||
"""
|
||||
self.column = column
|
||||
self.query = query
|
||||
super().__init__(query=query, column=column)
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.MATCH_PHRASE
|
||||
@@ -198,11 +208,16 @@ class PhraseQuery(FullTextQuery):
|
||||
|
||||
|
||||
class BoostQuery(FullTextQuery):
|
||||
positive: FullTextQuery
|
||||
negative: FullTextQuery
|
||||
negative_boost: float = 0.5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
positive: FullTextQuery,
|
||||
negative: FullTextQuery,
|
||||
negative_boost: float,
|
||||
*,
|
||||
negative_boost: float = 0.5,
|
||||
):
|
||||
"""
|
||||
Boost query for full-text search.
|
||||
@@ -216,9 +231,9 @@ class BoostQuery(FullTextQuery):
|
||||
negative_boost : float
|
||||
The boost factor for the negative query.
|
||||
"""
|
||||
self.positive = positive
|
||||
self.negative = negative
|
||||
self.negative_boost = negative_boost
|
||||
super().__init__(
|
||||
positive=positive, negative=negative, negative_boost=negative_boost
|
||||
)
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.BOOST
|
||||
@@ -234,6 +249,10 @@ class BoostQuery(FullTextQuery):
|
||||
|
||||
|
||||
class MultiMatchQuery(FullTextQuery):
|
||||
query: str
|
||||
columns: list[str]
|
||||
boosts: list[float]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
@@ -256,11 +275,9 @@ class MultiMatchQuery(FullTextQuery):
|
||||
The list of boost factors for each column. If not provided,
|
||||
all columns will have the same boost factor.
|
||||
"""
|
||||
self.query = query
|
||||
self.columns = columns
|
||||
if boosts is None:
|
||||
boosts = [1.0] * len(columns)
|
||||
self.boosts = boosts
|
||||
super().__init__(query=query, columns=columns, boosts=boosts)
|
||||
|
||||
def query_type(self) -> FullTextQueryType:
|
||||
return FullTextQueryType.MULTI_MATCH
|
||||
@@ -544,7 +561,7 @@ class LanceQueryBuilder(ABC):
|
||||
table, query, vector_column_name, fts_columns=fts_columns
|
||||
)
|
||||
|
||||
if isinstance(query, str):
|
||||
if isinstance(query, (str, FullTextQuery)):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(
|
||||
table,
|
||||
@@ -569,8 +586,10 @@ class LanceQueryBuilder(ABC):
|
||||
# If query_type is fts, then query must be a string or a FullTextQuery.
|
||||
# otherwise raise TypeError
|
||||
if query_type == "fts":
|
||||
if not isinstance(query, str):
|
||||
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
||||
if not isinstance(query, (str, FullTextQuery)):
|
||||
raise TypeError(
|
||||
f"'fts' query must be a string or FullTextQuery: {type(query)}"
|
||||
)
|
||||
return query, query_type
|
||||
elif query_type == "vector":
|
||||
query = cls._query_to_vector(table, query, vector_column_name)
|
||||
@@ -1486,7 +1505,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: Optional[str] = None,
|
||||
query: Optional[Union[str, FullTextQuery]] = None,
|
||||
vector_column: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
):
|
||||
@@ -1516,8 +1535,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
text_query = text or query
|
||||
if text_query is None:
|
||||
raise ValueError("Text query must be provided for hybrid search.")
|
||||
if not isinstance(text_query, str):
|
||||
raise ValueError("Text query must be a string")
|
||||
if not isinstance(text_query, (str, FullTextQuery)):
|
||||
raise ValueError("Text query must be a string or FullTextQuery")
|
||||
|
||||
return vector_query, text_query
|
||||
|
||||
@@ -2308,7 +2327,7 @@ class AsyncQuery(AsyncQueryBase):
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
# FullTextQuery object
|
||||
return AsyncFTSQuery(self._inner.nearest_to_text(query.to_dict()))
|
||||
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
||||
|
||||
|
||||
class AsyncFTSQuery(AsyncQueryBase):
|
||||
@@ -2627,7 +2646,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
# FullTextQuery object
|
||||
return AsyncHybridQuery(self._inner.nearest_to_text(query.to_dict()))
|
||||
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
||||
|
||||
async def to_batches(
|
||||
self, *, max_batch_length: Optional[int] = None
|
||||
|
||||
@@ -52,6 +52,7 @@ from .query import (
|
||||
AsyncHybridQuery,
|
||||
AsyncQuery,
|
||||
AsyncVectorQuery,
|
||||
FullTextQuery,
|
||||
LanceEmptyQueryBuilder,
|
||||
LanceFtsQueryBuilder,
|
||||
LanceHybridQueryBuilder,
|
||||
@@ -919,7 +920,9 @@ class Table(ABC):
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
query: Optional[
|
||||
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||
] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: QueryType = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
@@ -2039,7 +2042,9 @@ class LanceTable(Table):
|
||||
@overload
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
query: Optional[
|
||||
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||
] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: Literal["hybrid"] = "hybrid",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
@@ -2058,7 +2063,9 @@ class LanceTable(Table):
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
query: Optional[
|
||||
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||
] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: QueryType = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
@@ -3134,7 +3141,9 @@ class AsyncTable:
|
||||
@overload
|
||||
async def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
query: Optional[
|
||||
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||
] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: Literal["vector"] = ...,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
@@ -3143,7 +3152,9 @@ class AsyncTable:
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
query: Optional[
|
||||
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||
] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: QueryType = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
@@ -3253,6 +3264,8 @@ class AsyncTable:
|
||||
if is_embedding(query):
|
||||
vector_query = query
|
||||
query_type = "vector"
|
||||
elif isinstance(query, FullTextQuery):
|
||||
query_type = "fts"
|
||||
elif isinstance(query, str):
|
||||
try:
|
||||
(
|
||||
|
||||
@@ -20,6 +20,7 @@ from unittest import mock
|
||||
import lancedb as ldb
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.index import FTS
|
||||
from lancedb.query import BoostQuery, MatchQuery, MultiMatchQuery, PhraseQuery
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
@@ -178,11 +179,47 @@ def test_search_fts(table, use_tantivy):
|
||||
results = table.search("puppy").select(["id", "text"]).to_list()
|
||||
assert len(results) == 10
|
||||
|
||||
if not use_tantivy:
|
||||
# Test with a query
|
||||
results = (
|
||||
table.search(MatchQuery("puppy", "text"))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
|
||||
# Test boost query
|
||||
results = (
|
||||
table.search(
|
||||
BoostQuery(
|
||||
MatchQuery("puppy", "text"),
|
||||
MatchQuery("runs", "text"),
|
||||
)
|
||||
)
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
|
||||
# Test multi match query
|
||||
table.create_fts_index("text2", use_tantivy=use_tantivy)
|
||||
results = (
|
||||
table.search(MultiMatchQuery("puppy", ["text", "text2"]))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fts_select_async(async_table):
|
||||
tbl = await async_table
|
||||
await tbl.create_index("text", config=FTS())
|
||||
await tbl.create_index("text2", config=FTS())
|
||||
results = (
|
||||
await tbl.query()
|
||||
.nearest_to_text("puppy")
|
||||
@@ -193,6 +230,54 @@ async def test_fts_select_async(async_table):
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
|
||||
# Test with FullTextQuery
|
||||
results = (
|
||||
await tbl.query()
|
||||
.nearest_to_text(MatchQuery("puppy", "text"))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
|
||||
# Test with BoostQuery
|
||||
results = (
|
||||
await tbl.query()
|
||||
.nearest_to_text(
|
||||
BoostQuery(
|
||||
MatchQuery("puppy", "text"),
|
||||
MatchQuery("runs", "text"),
|
||||
)
|
||||
)
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
|
||||
# Test with MultiMatchQuery
|
||||
results = (
|
||||
await tbl.query()
|
||||
.nearest_to_text(MultiMatchQuery("puppy", ["text", "text2"]))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
|
||||
# Test with search() API
|
||||
results = (
|
||||
await (await tbl.search(MatchQuery("puppy", "text")))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
|
||||
|
||||
def test_search_fts_phrase_query(table):
|
||||
table.create_fts_index("text", use_tantivy=False, with_position=False)
|
||||
@@ -207,6 +292,13 @@ def test_search_fts_phrase_query(table):
|
||||
assert len(results) > len(phrase_results)
|
||||
assert len(phrase_results) > 0
|
||||
|
||||
# Test with a query
|
||||
phrase_results = (
|
||||
table.search(PhraseQuery("puppy runs", "text")).limit(100).to_list()
|
||||
)
|
||||
assert len(results) > len(phrase_results)
|
||||
assert len(phrase_results) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fts_phrase_query_async(async_table):
|
||||
@@ -227,6 +319,16 @@ async def test_search_fts_phrase_query_async(async_table):
|
||||
assert len(results) > len(phrase_results)
|
||||
assert len(phrase_results) > 0
|
||||
|
||||
# Test with a query
|
||||
phrase_results = (
|
||||
await async_table.query()
|
||||
.nearest_to_text(PhraseQuery("puppy runs", "text"))
|
||||
.limit(100)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) > len(phrase_results)
|
||||
assert len(phrase_results) > 0
|
||||
|
||||
|
||||
def test_search_fts_specify_column(table):
|
||||
table.create_fts_index("text", use_tantivy=False)
|
||||
|
||||
@@ -279,7 +279,7 @@ impl Query {
|
||||
}
|
||||
}
|
||||
query
|
||||
} else if let Ok(query) = query.downcast::<PyDict>() {
|
||||
} else if let Ok(query) = fts_query.downcast::<PyDict>() {
|
||||
let query = parse_fts_query(query)?;
|
||||
FullTextSearchQuery::new_query(query)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user