fix: specify column to search for FTS (#1572)

Previously the `fts_columns` parameter was ignored. Because FTS currently
supports searching only one column at a time, this could lead to an error
when more than one column has an FTS index.

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2024-08-29 23:43:46 +08:00
committed by GitHub
parent bfe8fccfab
commit 1521435193
5 changed files with 136 additions and 21 deletions

View File

@@ -132,8 +132,8 @@ class LanceQueryBuilder(ABC):
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str,
vector_column_name: str,
ordering_field_name: str = None,
fts_columns: Union[str, List[str]] = None,
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
) -> LanceQueryBuilder:
"""
Create a query builder based on the given query and query type.
@@ -156,7 +156,9 @@ class LanceQueryBuilder(ABC):
if query_type == "hybrid":
# hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name)
return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
# remember the string query for reranking purpose
str_query = query if isinstance(query, str) else None
@@ -168,7 +170,9 @@ class LanceQueryBuilder(ABC):
)
if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name)
return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
if isinstance(query, str):
# fts
@@ -176,6 +180,7 @@ class LanceQueryBuilder(ABC):
table,
query,
ordering_field_name=ordering_field_name,
fts_columns=fts_columns,
)
if isinstance(query, list):
@@ -693,8 +698,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
self,
table: "Table",
query: str,
ordering_field_name: str = None,
fts_columns: Union[str, List[str]] = None,
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
):
super().__init__(table)
self._query = query
@@ -887,10 +892,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
in the `rerank` method to convert the scores to ranks and then normalize them.
"""
def __init__(self, table: "Table", query: str, vector_column: str):
def __init__(
self,
table: "Table",
query: str,
vector_column: str,
fts_columns: Union[str, List[str]] = [],
):
super().__init__(table)
vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
self._fts_query = LanceFtsQueryBuilder(
table, fts_query, fts_columns=fts_columns
)
vector_query = self._query_to_vector(table, vector_query, vector_column)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._norm = "score"
@@ -1386,7 +1399,7 @@ class AsyncQuery(AsyncQueryBase):
)
def nearest_to_text(
self, query: str, columns: Union[str, List[str]] = None
self, query: str, columns: Union[str, List[str]] = []
) -> AsyncQuery:
"""
Find the documents that are most relevant to the given text query.
@@ -1410,9 +1423,8 @@ class AsyncQuery(AsyncQueryBase):
"""
if isinstance(columns, str):
columns = [columns]
return AsyncQuery(
self._inner.nearest_to_text({"query": query, "columns": columns})
)
self._inner.nearest_to_text({"query": query, "columns": columns})
return self
class AsyncVectorQuery(AsyncQueryBase):

View File

@@ -15,7 +15,7 @@ import logging
import uuid
from concurrent.futures import Future
from functools import cached_property
from typing import Dict, Iterable, Optional, Union, Literal
from typing import Dict, Iterable, List, Optional, Union, Literal
import pyarrow as pa
from lance import json_to_schema
@@ -268,6 +268,7 @@ class RemoteTable(Table):
query: Union[VEC, str],
vector_column_name: Optional[str] = None,
query_type="auto",
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceVectorQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -338,6 +339,7 @@ class RemoteTable(Table):
query,
query_type,
vector_column_name=vector_column_name,
fts_columns=fts_columns,
)
def _execute_query(

View File

@@ -545,7 +545,7 @@ class Table(ABC):
vector_column_name: Optional[str] = None,
query_type: str = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1425,7 +1425,7 @@ class LanceTable(Table):
vector_column_name: Optional[str] = None,
query_type: str = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1505,6 +1505,7 @@ class LanceTable(Table):
query_type,
vector_column_name=vector_column_name,
ordering_field_name=ordering_field_name,
fts_columns=fts_columns,
)
@classmethod

View File

@@ -29,14 +29,26 @@ def table(tmp_path) -> ldb.table.LanceTable:
db = ldb.connect(tmp_path)
vectors = [np.random.randn(128) for _ in range(100)]
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
text_nouns = ("puppy", "car")
text2_nouns = ("rabbit", "girl", "monkey")
verbs = ("runs", "hits", "jumps", "drives", "barfs")
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
text = [
" ".join(
[
nouns[random.randrange(0, 5)],
text_nouns[random.randrange(0, len(text_nouns))],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
]
)
for _ in range(100)
]
text2 = [
" ".join(
[
text2_nouns[random.randrange(0, len(text2_nouns))],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
@@ -52,7 +64,7 @@ def table(tmp_path) -> ldb.table.LanceTable:
"vector": vectors,
"id": [i % 2 for i in range(100)],
"text": text,
"text2": text,
"text2": text2,
"nested": [{"text": t} for t in text],
"count": count,
}
@@ -66,14 +78,26 @@ async def async_table(tmp_path) -> ldb.table.AsyncTable:
db = await ldb.connect_async(tmp_path)
vectors = [np.random.randn(128) for _ in range(100)]
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
text_nouns = ("puppy", "car")
text2_nouns = ("rabbit", "girl", "monkey")
verbs = ("runs", "hits", "jumps", "drives", "barfs")
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
text = [
" ".join(
[
nouns[random.randrange(0, 5)],
text_nouns[random.randrange(0, len(text_nouns))],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
]
)
for _ in range(100)
]
text2 = [
" ".join(
[
text2_nouns[random.randrange(0, len(text2_nouns))],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
@@ -89,7 +113,7 @@ async def async_table(tmp_path) -> ldb.table.AsyncTable:
"vector": vectors,
"id": [i % 2 for i in range(100)],
"text": text,
"text2": text,
"text2": text2,
"nested": [{"text": t} for t in text],
"count": count,
}
@@ -142,12 +166,81 @@ def test_search_fts(table, use_tantivy):
assert len(results) == 5
def test_search_fts_specify_column(table):
    """Check that `fts_columns` selects which FTS index a search runs against.

    Two native (non-tantivy) FTS indices are created, one on `text` and one
    on `text2`. A search that names a single indexed column must return the
    requested number of rows; a search that names both columns, or names
    none while two indices exist, is expected to be rejected.
    """
    table.create_fts_index("text", use_tantivy=False)
    table.create_fts_index("text2", use_tantivy=False)

    results = table.search("puppy", fts_columns="text").limit(5).to_list()
    assert len(results) == 5

    results = table.search("rabbit", fts_columns="text2").limit(5).to_list()
    assert len(results) == 5

    # NOTE: the previous `try: ...; assert False; except Exception: pass` form
    # could never fail, because AssertionError is itself an Exception and was
    # swallowed by the handler. `pytest.raises` fails when nothing is raised.
    # The concrete exception type is not visible from this test, so the broad
    # `Exception` is kept; narrow it once the raised type is confirmed.

    # we can only specify one column for now
    with pytest.raises(Exception):
        table.search("puppy", fts_columns=["text", "text2"]).limit(5).to_list()

    # have to specify a column because we have two fts indices
    with pytest.raises(Exception):
        table.search("puppy").limit(5).to_list()
@pytest.mark.asyncio
async def test_search_fts_async(async_table):
    """Run a text query through the async API against a single FTS index.

    The `async_table` fixture is awaited to obtain the table, an FTS index
    is built on `text`, and a query for "puppy" limited to 5 rows must
    return exactly 5 rows.
    """
    tbl = await async_table
    await tbl.create_index("text", config=FTS())

    text_query = tbl.query().nearest_to_text("puppy").limit(5)
    hits = await text_query.to_list()
    assert len(hits) == 5
@pytest.mark.asyncio
async def test_search_fts_specify_column_async(async_table):
    """Check that `columns` selects the FTS index used by the async text query.

    FTS indices are created on both `text` and `text2`. Naming a single
    indexed column must return the requested number of rows; naming both
    columns, or naming none while two indices exist, is expected to be
    rejected.
    """
    async_table = await async_table
    await async_table.create_index("text", config=FTS())
    await async_table.create_index("text2", config=FTS())

    results = (
        await async_table.query()
        .nearest_to_text("puppy", columns="text")
        .limit(5)
        .to_list()
    )
    assert len(results) == 5

    results = (
        await async_table.query()
        .nearest_to_text("rabbit", columns="text2")
        .limit(5)
        .to_list()
    )
    assert len(results) == 5

    # NOTE: the previous `try: ...; assert False; except Exception: pass` form
    # could never fail, because AssertionError is itself an Exception and was
    # swallowed by the handler. `pytest.raises` fails when nothing is raised.
    # The concrete exception type is not visible from this test, so the broad
    # `Exception` is kept; narrow it once the raised type is confirmed.

    # we can only specify one column for now
    # (the earlier version passed the single column "text2" here, which did
    # not exercise the multi-column case this comment describes)
    with pytest.raises(Exception):
        await (
            async_table.query()
            .nearest_to_text("rabbit", columns=["text", "text2"])
            .limit(5)
            .to_list()
        )

    # have to specify a column because we have two fts indices
    with pytest.raises(Exception):
        await async_table.query().nearest_to_text("puppy").limit(5).to_list()
def test_search_ordering_field_index_table(tmp_path, table):
table.create_fts_index("text", ordering_field_names=["count"], use_tantivy=True)
rows = (