mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
fix: specify column to search for FTS (#1572)
Before this change, the `fts_columns` parameter was ignored. For now, a full-text search can target only one column, so ignoring the parameter could lead to an error when multiple columns have FTS indexes. --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -132,8 +132,8 @@ class LanceQueryBuilder(ABC):
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
ordering_field_name: str = None,
|
||||
fts_columns: Union[str, List[str]] = None,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
) -> LanceQueryBuilder:
|
||||
"""
|
||||
Create a query builder based on the given query and query type.
|
||||
@@ -156,7 +156,9 @@ class LanceQueryBuilder(ABC):
|
||||
|
||||
if query_type == "hybrid":
|
||||
# hybrid fts and vector query
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, fts_columns=fts_columns
|
||||
)
|
||||
|
||||
# remember the string query for reranking purpose
|
||||
str_query = query if isinstance(query, str) else None
|
||||
@@ -168,7 +170,9 @@ class LanceQueryBuilder(ABC):
|
||||
)
|
||||
|
||||
if query_type == "hybrid":
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, fts_columns=fts_columns
|
||||
)
|
||||
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
@@ -176,6 +180,7 @@ class LanceQueryBuilder(ABC):
|
||||
table,
|
||||
query,
|
||||
ordering_field_name=ordering_field_name,
|
||||
fts_columns=fts_columns,
|
||||
)
|
||||
|
||||
if isinstance(query, list):
|
||||
@@ -693,8 +698,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
self,
|
||||
table: "Table",
|
||||
query: str,
|
||||
ordering_field_name: str = None,
|
||||
fts_columns: Union[str, List[str]] = None,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
@@ -887,10 +892,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
in the `rerank` method to convert the scores to ranks and then normalize them.
|
||||
"""
|
||||
|
||||
def __init__(self, table: "Table", query: str, vector_column: str):
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: str,
|
||||
vector_column: str,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
):
|
||||
super().__init__(table)
|
||||
vector_query, fts_query = self._validate_query(query)
|
||||
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
||||
self._fts_query = LanceFtsQueryBuilder(
|
||||
table, fts_query, fts_columns=fts_columns
|
||||
)
|
||||
vector_query = self._query_to_vector(table, vector_query, vector_column)
|
||||
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
|
||||
self._norm = "score"
|
||||
@@ -1386,7 +1399,7 @@ class AsyncQuery(AsyncQueryBase):
|
||||
)
|
||||
|
||||
def nearest_to_text(
|
||||
self, query: str, columns: Union[str, List[str]] = None
|
||||
self, query: str, columns: Union[str, List[str]] = []
|
||||
) -> AsyncQuery:
|
||||
"""
|
||||
Find the documents that are most relevant to the given text query.
|
||||
@@ -1410,9 +1423,8 @@ class AsyncQuery(AsyncQueryBase):
|
||||
"""
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
return AsyncQuery(
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
return self
|
||||
|
||||
|
||||
class AsyncVectorQuery(AsyncQueryBase):
|
||||
|
||||
@@ -15,7 +15,7 @@ import logging
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Dict, Iterable, Optional, Union, Literal
|
||||
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||
|
||||
import pyarrow as pa
|
||||
from lance import json_to_schema
|
||||
@@ -268,6 +268,7 @@ class RemoteTable(Table):
|
||||
query: Union[VEC, str],
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type="auto",
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -338,6 +339,7 @@ class RemoteTable(Table):
|
||||
query,
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
fts_columns=fts_columns,
|
||||
)
|
||||
|
||||
def _execute_query(
|
||||
|
||||
@@ -545,7 +545,7 @@ class Table(ABC):
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -1425,7 +1425,7 @@ class LanceTable(Table):
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -1505,6 +1505,7 @@ class LanceTable(Table):
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
ordering_field_name=ordering_field_name,
|
||||
fts_columns=fts_columns,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -29,14 +29,26 @@ def table(tmp_path) -> ldb.table.LanceTable:
|
||||
db = ldb.connect(tmp_path)
|
||||
vectors = [np.random.randn(128) for _ in range(100)]
|
||||
|
||||
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
|
||||
text_nouns = ("puppy", "car")
|
||||
text2_nouns = ("rabbit", "girl", "monkey")
|
||||
verbs = ("runs", "hits", "jumps", "drives", "barfs")
|
||||
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
|
||||
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
|
||||
text = [
|
||||
" ".join(
|
||||
[
|
||||
nouns[random.randrange(0, 5)],
|
||||
text_nouns[random.randrange(0, len(text_nouns))],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
]
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
text2 = [
|
||||
" ".join(
|
||||
[
|
||||
text2_nouns[random.randrange(0, len(text2_nouns))],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
@@ -52,7 +64,7 @@ def table(tmp_path) -> ldb.table.LanceTable:
|
||||
"vector": vectors,
|
||||
"id": [i % 2 for i in range(100)],
|
||||
"text": text,
|
||||
"text2": text,
|
||||
"text2": text2,
|
||||
"nested": [{"text": t} for t in text],
|
||||
"count": count,
|
||||
}
|
||||
@@ -66,14 +78,26 @@ async def async_table(tmp_path) -> ldb.table.AsyncTable:
|
||||
db = await ldb.connect_async(tmp_path)
|
||||
vectors = [np.random.randn(128) for _ in range(100)]
|
||||
|
||||
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
|
||||
text_nouns = ("puppy", "car")
|
||||
text2_nouns = ("rabbit", "girl", "monkey")
|
||||
verbs = ("runs", "hits", "jumps", "drives", "barfs")
|
||||
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
|
||||
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
|
||||
text = [
|
||||
" ".join(
|
||||
[
|
||||
nouns[random.randrange(0, 5)],
|
||||
text_nouns[random.randrange(0, len(text_nouns))],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
]
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
text2 = [
|
||||
" ".join(
|
||||
[
|
||||
text2_nouns[random.randrange(0, len(text2_nouns))],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
@@ -89,7 +113,7 @@ async def async_table(tmp_path) -> ldb.table.AsyncTable:
|
||||
"vector": vectors,
|
||||
"id": [i % 2 for i in range(100)],
|
||||
"text": text,
|
||||
"text2": text,
|
||||
"text2": text2,
|
||||
"nested": [{"text": t} for t in text],
|
||||
"count": count,
|
||||
}
|
||||
@@ -142,12 +166,81 @@ def test_search_fts(table, use_tantivy):
|
||||
assert len(results) == 5
|
||||
|
||||
|
||||
def test_search_fts_specify_column(table):
    """Check that `fts_columns` selects which FTS-indexed column is searched.

    Two native (non-tantivy) FTS indexes are created, on "text" and "text2".
    Searching with a single named column must return results; searching with
    several columns, or with none at all, is expected to raise.
    """
    table.create_fts_index("text", use_tantivy=False)
    table.create_fts_index("text2", use_tantivy=False)

    results = table.search("puppy", fts_columns="text").limit(5).to_list()
    assert len(results) == 5

    results = table.search("rabbit", fts_columns="text2").limit(5).to_list()
    assert len(results) == 5

    # The failure is raised from the `else` clause, not from inside `try`:
    # AssertionError is a subclass of Exception, so an `assert False` placed
    # in the `try` body would be swallowed by the handler and the check
    # could never fail.
    # NOTE(review): the concrete exception type is not visible from this
    # test, so the broad `Exception` is kept -- narrow it once confirmed.
    try:
        # we can only specify one column for now
        table.search("puppy", fts_columns=["text", "text2"]).limit(5).to_list()
    except Exception:
        pass
    else:
        raise AssertionError(
            "expected an error when searching multiple FTS columns"
        )

    try:
        # have to specify a column because we have two fts indices
        table.search("puppy").limit(5).to_list()
    except Exception:
        pass
    else:
        raise AssertionError(
            "expected an error when no FTS column is specified "
            "and two FTS indices exist"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_fts_async(async_table):
    """Run a full-text search through the async query API on one FTS index."""
    # The fixture value is itself awaitable; resolve it before use.
    tbl = await async_table
    await tbl.create_index("text", config=FTS())

    text_query = tbl.query().nearest_to_text("puppy").limit(5)
    hits = await text_query.to_list()
    assert len(hits) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_fts_specify_column_async(async_table):
    """Check that `columns` selects the FTS-indexed column in the async API.

    FTS indexes are created on both "text" and "text2". Searching with a
    single named column must return results; searching with several columns,
    or with none at all, is expected to raise.
    """
    # The fixture value is itself awaitable; resolve it before use.
    async_table = await async_table
    await async_table.create_index("text", config=FTS())
    await async_table.create_index("text2", config=FTS())

    results = (
        await async_table.query()
        .nearest_to_text("puppy", columns="text")
        .limit(5)
        .to_list()
    )
    assert len(results) == 5

    results = (
        await async_table.query()
        .nearest_to_text("rabbit", columns="text2")
        .limit(5)
        .to_list()
    )
    assert len(results) == 5

    # The failure is raised from the `else` clause, not from inside `try`:
    # AssertionError is a subclass of Exception, so an `assert False` placed
    # in the `try` body would be swallowed by the handler and the check
    # could never fail.
    # NOTE(review): the concrete exception type is not visible from this
    # test, so the broad `Exception` is kept -- narrow it once confirmed.
    try:
        # we can only specify one column for now
        # (pass two columns so the multi-column case is really exercised,
        # mirroring the synchronous test)
        await (
            async_table.query()
            .nearest_to_text("rabbit", columns=["text", "text2"])
            .limit(5)
            .to_list()
        )
    except Exception:
        pass
    else:
        raise AssertionError(
            "expected an error when searching multiple FTS columns"
        )

    try:
        # have to specify a column because we have two fts indices
        await async_table.query().nearest_to_text("puppy").limit(5).to_list()
    except Exception:
        pass
    else:
        raise AssertionError(
            "expected an error when no FTS column is specified "
            "and two FTS indices exist"
        )
|
||||
|
||||
|
||||
def test_search_ordering_field_index_table(tmp_path, table):
|
||||
table.create_fts_index("text", ordering_field_names=["count"], use_tantivy=True)
|
||||
rows = (
|
||||
|
||||
Reference in New Issue
Block a user