mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
feat: support to query/index FTS on RemoteTable/AsyncTable (#1537)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -74,6 +74,7 @@ class Query:
|
||||
def select(self, columns: Tuple[str, str]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
|
||||
def nearest_to_text(self, query: dict) -> Query: ...
|
||||
async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ...
|
||||
|
||||
class VectorQuery:
|
||||
|
||||
@@ -276,6 +276,10 @@ class DBConnection(EnforceOverrides):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
|
||||
class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
@@ -340,10 +344,6 @@ class LanceDBConnection(DBConnection):
|
||||
val += ")"
|
||||
return val
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
async def _async_get_table_names(self, start_after: Optional[str], limit: int):
|
||||
conn = AsyncConnection(await lancedb_connect(self.uri))
|
||||
return await conn.table_names(start_after=start_after, limit=limit)
|
||||
|
||||
@@ -70,6 +70,18 @@ class LabelList:
|
||||
self._inner = LanceDbIndex.label_list()
|
||||
|
||||
|
||||
class FTS:
|
||||
"""Describe a FTS index configuration.
|
||||
|
||||
`FTS` is a full-text search index that can be used on `String` columns
|
||||
|
||||
For example, it works with `title`, `description`, `content`, etc.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._inner = LanceDbIndex.fts()
|
||||
|
||||
|
||||
class IvfPq:
|
||||
"""Describes an IVF PQ Index
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
@@ -38,7 +37,7 @@ from .arrow import AsyncRecordBatchReader
|
||||
from .common import VEC
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.linear_combination import LinearCombinationReranker
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
from .util import safe_import_pandas
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
@@ -174,7 +173,9 @@ class LanceQueryBuilder(ABC):
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(
|
||||
table, query, ordering_field_name=ordering_field_name
|
||||
table,
|
||||
query,
|
||||
ordering_field_name=ordering_field_name,
|
||||
)
|
||||
|
||||
if isinstance(query, list):
|
||||
@@ -681,6 +682,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
self._phrase_query = False
|
||||
self.ordering_field_name = ordering_field_name
|
||||
self._reranker = None
|
||||
if isinstance(fts_columns, str):
|
||||
fts_columns = [fts_columns]
|
||||
self._fts_columns = fts_columns
|
||||
|
||||
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
|
||||
@@ -701,8 +704,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
return self
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
tantivy_index_path = self._table._get_fts_index_path()
|
||||
if Path(tantivy_index_path).exists():
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
if exist:
|
||||
return self.tantivy_to_arrow()
|
||||
|
||||
query = self._query
|
||||
@@ -711,23 +714,20 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
"Phrase query is not yet supported in Lance FTS. "
|
||||
"Use tantivy-based index instead for now."
|
||||
)
|
||||
if self._reranker:
|
||||
raise NotImplementedError(
|
||||
"Reranking is not yet supported in Lance FTS. "
|
||||
"Use tantivy-based index instead for now."
|
||||
)
|
||||
ds = self._table.to_lance()
|
||||
return ds.to_table(
|
||||
query = Query(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
k=self._limit,
|
||||
prefilter=self._prefilter,
|
||||
with_row_id=self._with_row_id,
|
||||
full_text_query={
|
||||
"query": query,
|
||||
"columns": self._fts_columns,
|
||||
},
|
||||
vector=[],
|
||||
)
|
||||
results = self._table._execute_query(query)
|
||||
return results.read_all()
|
||||
|
||||
def tantivy_to_arrow(self) -> pa.Table:
|
||||
try:
|
||||
@@ -740,24 +740,24 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
from .fts import search_index
|
||||
|
||||
# get the index path
|
||||
index_path = self._table._get_fts_index_path()
|
||||
|
||||
# Check that we are on local filesystem
|
||||
fs, _path = fs_from_uri(index_path)
|
||||
if not isinstance(fs, pa_fs.LocalFileSystem):
|
||||
raise NotImplementedError(
|
||||
"Full-text search is only supported on the local filesystem"
|
||||
)
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
|
||||
# check if the index exist
|
||||
if not Path(index_path).exists():
|
||||
if not exist:
|
||||
raise FileNotFoundError(
|
||||
"Fts index does not exist. "
|
||||
"Please first call table.create_fts_index(['<field_names>']) to "
|
||||
"create the fts index."
|
||||
)
|
||||
|
||||
# Check that we are on local filesystem
|
||||
if not isinstance(fs, pa_fs.LocalFileSystem):
|
||||
raise NotImplementedError(
|
||||
"Tantivy-based full text search "
|
||||
"is only supported on the local filesystem"
|
||||
)
|
||||
# open the index
|
||||
index = tantivy.Index.open(index_path)
|
||||
index = tantivy.Index.open(path)
|
||||
# get the scores and doc ids
|
||||
query = self._query
|
||||
if self._phrase_query:
|
||||
@@ -851,7 +851,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
def __init__(self, table: "Table", query: str, vector_column: str):
|
||||
super().__init__(table)
|
||||
self._validate_fts_index()
|
||||
vector_query, fts_query = self._validate_query(query)
|
||||
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
||||
vector_query = self._query_to_vector(table, vector_query, vector_column)
|
||||
@@ -859,12 +858,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._norm = "score"
|
||||
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
|
||||
|
||||
def _validate_fts_index(self):
|
||||
if self._table._get_fts_index_path() is None:
|
||||
raise ValueError(
|
||||
"Please create a full-text search index " "to perform hybrid search."
|
||||
)
|
||||
|
||||
def _validate_query(self, query):
|
||||
# Temp hack to support vectorized queries for hybrid search
|
||||
if isinstance(query, str):
|
||||
@@ -1354,6 +1347,35 @@ class AsyncQuery(AsyncQueryBase):
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
)
|
||||
|
||||
def nearest_to_text(
|
||||
self, query: str, columns: Union[str, List[str]] = None
|
||||
) -> AsyncQuery:
|
||||
"""
|
||||
Find the documents that are most relevant to the given text query.
|
||||
|
||||
This method will perform a full text search on the table and return
|
||||
the most relevant documents. The relevance is determined by BM25.
|
||||
|
||||
The columns to search must be with native FTS index
|
||||
(Tantivy-based can't work with this method).
|
||||
|
||||
By default, all indexed columns are searched,
|
||||
now only one column can be searched at a time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str
|
||||
The text query to search for.
|
||||
columns: str or list of str, default None
|
||||
The columns to search in. If None, all indexed columns are searched.
|
||||
For now only one column can be searched at a time.
|
||||
"""
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
return AsyncQuery(
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
|
||||
|
||||
class AsyncVectorQuery(AsyncQueryBase):
|
||||
def __init__(self, inner: LanceVectorQuery):
|
||||
|
||||
@@ -49,6 +49,7 @@ class RemoteDBConnection(DBConnection):
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self._uri = str(db_url)
|
||||
self.db_name = parsed.netloc
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(
|
||||
|
||||
@@ -35,10 +35,10 @@ from .db import RemoteDBConnection
|
||||
class RemoteTable(Table):
|
||||
def __init__(self, conn: RemoteDBConnection, name: str):
|
||||
self._conn = conn
|
||||
self._name = name
|
||||
self.name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self._conn.db_name}.{self._name})"
|
||||
return f"RemoteTable({self._conn.db_name}.{self.name})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.count_rows(None)
|
||||
@@ -49,14 +49,14 @@ class RemoteTable(Table):
|
||||
of this Table
|
||||
|
||||
"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
||||
schema = json_to_schema(resp["schema"])
|
||||
return schema
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version of the table"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
||||
return resp["version"]
|
||||
|
||||
@cached_property
|
||||
@@ -84,13 +84,13 @@ class RemoteTable(Table):
|
||||
|
||||
def list_indices(self):
|
||||
"""List all the indices on the table"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
|
||||
return resp
|
||||
|
||||
def index_stats(self, index_uuid: str):
|
||||
"""List all the stats of a specified index"""
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/index/{index_uuid}/stats/"
|
||||
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -116,11 +116,27 @@ class RemoteTable(Table):
|
||||
"replace": True,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/create_scalar_index/", data=data
|
||||
f"/v1/table/{self.name}/create_scalar_index/", data=data
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
def create_fts_index(
|
||||
self,
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = False,
|
||||
):
|
||||
data = {
|
||||
"column": column,
|
||||
"index_type": "FTS",
|
||||
"replace": replace,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/create_index/", data=data
|
||||
)
|
||||
return resp
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
@@ -194,7 +210,7 @@ class RemoteTable(Table):
|
||||
"index_cache_size": index_cache_size,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/create_index/", data=data
|
||||
f"/v1/table/{self.name}/create_index/", data=data
|
||||
)
|
||||
|
||||
return resp
|
||||
@@ -241,7 +257,7 @@ class RemoteTable(Table):
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/insert/",
|
||||
f"/v1/table/{self.name}/insert/",
|
||||
data=payload,
|
||||
params={"request_id": request_id, "mode": mode},
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
@@ -251,6 +267,7 @@ class RemoteTable(Table):
|
||||
self,
|
||||
query: Union[VEC, str],
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type="auto",
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -310,10 +327,18 @@ class RemoteTable(Table):
|
||||
- and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
query = LanceQueryBuilder._query_to_vector(self, query, vector_column_name)
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
if vector_column_name is None and query is not None and query_type != "fts":
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
query,
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
)
|
||||
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
@@ -342,12 +367,12 @@ class RemoteTable(Table):
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
results.append(submit(self._name, q))
|
||||
results.append(submit(self.name, q))
|
||||
return pa.concat_tables(
|
||||
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
|
||||
).to_reader()
|
||||
else:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
result = self._conn._client.query(self.name, query)
|
||||
return result.to_arrow().to_reader()
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
@@ -397,7 +422,7 @@ class RemoteTable(Table):
|
||||
)
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/merge_insert/",
|
||||
f"/v1/table/{self.name}/merge_insert/",
|
||||
data=payload,
|
||||
params=params,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
@@ -451,7 +476,7 @@ class RemoteTable(Table):
|
||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
"""
|
||||
payload = {"predicate": predicate}
|
||||
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -512,7 +537,7 @@ class RemoteTable(Table):
|
||||
updates = [[k, v] for k, v in values_sql.items()]
|
||||
|
||||
payload = {"predicate": where, "updates": updates}
|
||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
|
||||
|
||||
def cleanup_old_versions(self, *_):
|
||||
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
||||
@@ -529,7 +554,7 @@ class RemoteTable(Table):
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
payload = {"predicate": filter}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/count_rows/", data=payload
|
||||
f"/v1/table/{self.name}/count_rows/", data=payload
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if TYPE_CHECKING:
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
from ._lancedb import Table as LanceDBTable, OptimizeStats
|
||||
from .db import LanceDBConnection
|
||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList
|
||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
@@ -840,6 +840,18 @@ class Table(ABC):
|
||||
The names of the columns to drop.
|
||||
"""
|
||||
|
||||
@cached_property
|
||||
def _dataset_uri(self) -> str:
|
||||
return _table_uri(self._conn.uri, self.name)
|
||||
|
||||
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
|
||||
if get_uri_scheme(self._dataset_uri) != "file":
|
||||
return ("", None, False)
|
||||
path = join_uri(self._dataset_uri, "_indices", "fts")
|
||||
fs, path = fs_from_uri(path)
|
||||
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
|
||||
return (path, fs, index_exists)
|
||||
|
||||
|
||||
class _LanceDatasetRef(ABC):
|
||||
@property
|
||||
@@ -979,10 +991,6 @@ class LanceTable(Table):
|
||||
# Cacheable since it's deterministic
|
||||
return _table_path(self._conn.uri, self.name)
|
||||
|
||||
@cached_property
|
||||
def _dataset_uri(self) -> str:
|
||||
return _table_uri(self._conn.uri, self.name)
|
||||
|
||||
@property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return self._ref.dataset
|
||||
@@ -1247,9 +1255,8 @@ class LanceTable(Table):
|
||||
raise ValueError("field_names must be a string when use_tantivy=False")
|
||||
# delete the existing legacy index if it exists
|
||||
if replace:
|
||||
fs, path = fs_from_uri(self._get_fts_index_path())
|
||||
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
|
||||
if index_exists:
|
||||
path, fs, exist = self._get_fts_index_path()
|
||||
if exist:
|
||||
fs.delete_dir(path)
|
||||
self._dataset_mut.create_scalar_index(
|
||||
field_names, index_type="INVERTED", replace=replace
|
||||
@@ -1264,9 +1271,8 @@ class LanceTable(Table):
|
||||
if isinstance(ordering_field_names, str):
|
||||
ordering_field_names = [ordering_field_names]
|
||||
|
||||
fs, path = fs_from_uri(self._get_fts_index_path())
|
||||
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
|
||||
if index_exists:
|
||||
path, fs, exist = self._get_fts_index_path()
|
||||
if exist:
|
||||
if not replace:
|
||||
raise ValueError("Index already exists. Use replace=True to overwrite.")
|
||||
fs.delete_dir(path)
|
||||
@@ -1277,7 +1283,7 @@ class LanceTable(Table):
|
||||
)
|
||||
|
||||
index = create_index(
|
||||
self._get_fts_index_path(),
|
||||
path,
|
||||
field_names,
|
||||
ordering_fields=ordering_field_names,
|
||||
tokenizer_name=tokenizer_name,
|
||||
@@ -1290,13 +1296,6 @@ class LanceTable(Table):
|
||||
writer_heap_size=writer_heap_size,
|
||||
)
|
||||
|
||||
def _get_fts_index_path(self):
|
||||
if get_uri_scheme(self._dataset_uri) != "file":
|
||||
raise NotImplementedError(
|
||||
"Full-text search is not supported on object stores."
|
||||
)
|
||||
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
||||
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
@@ -1492,14 +1491,11 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None and query is not None:
|
||||
if vector_column_name is None and query is not None and query_type != "fts":
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
except Exception as e:
|
||||
if query_type == "fts":
|
||||
vector_column_name = ""
|
||||
else:
|
||||
raise e
|
||||
raise e
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
@@ -1690,18 +1686,22 @@ class LanceTable(Table):
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
ds = self.to_lance()
|
||||
return ds.scanner(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest={
|
||||
nearest = None
|
||||
if len(query.vector) > 0:
|
||||
nearest = {
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
"k": query.k,
|
||||
"metric": query.metric,
|
||||
"nprobes": query.nprobes,
|
||||
"refine_factor": query.refine_factor,
|
||||
},
|
||||
}
|
||||
return ds.scanner(
|
||||
columns=query.columns,
|
||||
limit=query.k,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest=nearest,
|
||||
full_text_query=query.full_text_query,
|
||||
with_row_id=query.with_row_id,
|
||||
batch_size=batch_size,
|
||||
@@ -2126,7 +2126,7 @@ class AsyncTable:
|
||||
column: str,
|
||||
*,
|
||||
replace: Optional[bool] = None,
|
||||
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList]] = None,
|
||||
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
|
||||
):
|
||||
"""Create an index to speed up queries
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import random
|
||||
from unittest import mock
|
||||
|
||||
import lancedb as ldb
|
||||
from lancedb.index import FTS
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
@@ -60,6 +61,43 @@ def table(tmp_path) -> ldb.table.LanceTable:
|
||||
return table
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_table(tmp_path) -> ldb.table.AsyncTable:
|
||||
db = await ldb.connect_async(tmp_path)
|
||||
vectors = [np.random.randn(128) for _ in range(100)]
|
||||
|
||||
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
|
||||
verbs = ("runs", "hits", "jumps", "drives", "barfs")
|
||||
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
|
||||
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
|
||||
text = [
|
||||
" ".join(
|
||||
[
|
||||
nouns[random.randrange(0, 5)],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
]
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
count = [random.randint(1, 10000) for _ in range(100)]
|
||||
table = await db.create_table(
|
||||
"test",
|
||||
data=pd.DataFrame(
|
||||
{
|
||||
"vector": vectors,
|
||||
"id": [i % 2 for i in range(100)],
|
||||
"text": text,
|
||||
"text2": text,
|
||||
"nested": [{"text": t} for t in text],
|
||||
"count": count,
|
||||
}
|
||||
),
|
||||
)
|
||||
return table
|
||||
|
||||
|
||||
def test_create_index(tmp_path):
|
||||
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
|
||||
assert isinstance(index, tantivy.Index)
|
||||
@@ -91,17 +129,23 @@ def test_search_index(tmp_path, table):
|
||||
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
|
||||
ldb.fts.populate_index(index, table, ["text"])
|
||||
index.reload()
|
||||
results = ldb.fts.search_index(index, query="puppy", limit=10)
|
||||
results = ldb.fts.search_index(index, query="puppy", limit=5)
|
||||
assert len(results) == 2
|
||||
assert len(results[0]) == 10 # row_ids
|
||||
assert len(results[1]) == 10 # _distance
|
||||
assert len(results[0]) == 5 # row_ids
|
||||
assert len(results[1]) == 5 # _score
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_search_fts(table, use_tantivy):
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
results = table.search("puppy").limit(10).to_list()
|
||||
assert len(results) == 10
|
||||
results = table.search("puppy").limit(5).to_list()
|
||||
assert len(results) == 5
|
||||
|
||||
|
||||
async def test_search_fts_async(async_table):
|
||||
await async_table.create_index("text", config=FTS())
|
||||
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
|
||||
assert len(results) == 5
|
||||
|
||||
|
||||
def test_search_ordering_field_index_table(tmp_path, table):
|
||||
@@ -125,11 +169,11 @@ def test_search_ordering_field_index(tmp_path, table):
|
||||
ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"])
|
||||
index.reload()
|
||||
results = ldb.fts.search_index(
|
||||
index, query="puppy", limit=10, ordering_field="count"
|
||||
index, query="puppy", limit=5, ordering_field="count"
|
||||
)
|
||||
assert len(results) == 2
|
||||
assert len(results[0]) == 10 # row_ids
|
||||
assert len(results[1]) == 10 # _distance
|
||||
assert len(results[0]) == 5 # row_ids
|
||||
assert len(results[1]) == 5 # _distance
|
||||
rows = table.to_lance().take(results[0]).to_pylist()
|
||||
|
||||
for r in rows:
|
||||
@@ -140,8 +184,8 @@ def test_search_ordering_field_index(tmp_path, table):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_create_index_from_table(tmp_path, table, use_tantivy):
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
|
||||
assert len(df) <= 10
|
||||
df = table.search("puppy").limit(5).select(["text"]).to_pandas()
|
||||
assert len(df) <= 5
|
||||
assert "text" in df.columns
|
||||
|
||||
# Check whether it can be updated
|
||||
@@ -167,8 +211,8 @@ def test_create_index_from_table(tmp_path, table, use_tantivy):
|
||||
|
||||
def test_create_index_multiple_columns(tmp_path, table):
|
||||
table.create_fts_index(["text", "text2"], use_tantivy=True)
|
||||
df = table.search("puppy").limit(10).to_pandas()
|
||||
assert len(df) == 10
|
||||
df = table.search("puppy").limit(5).to_pandas()
|
||||
assert len(df) == 5
|
||||
assert "text" in df.columns
|
||||
assert "text2" in df.columns
|
||||
|
||||
@@ -176,14 +220,14 @@ def test_create_index_multiple_columns(tmp_path, table):
|
||||
def test_empty_rs(tmp_path, table, mocker):
|
||||
table.create_fts_index(["text", "text2"], use_tantivy=True)
|
||||
mocker.patch("lancedb.fts.search_index", return_value=([], []))
|
||||
df = table.search("puppy").limit(10).to_pandas()
|
||||
df = table.search("puppy").limit(5).to_pandas()
|
||||
assert len(df) == 0
|
||||
|
||||
|
||||
def test_nested_schema(tmp_path, table):
|
||||
table.create_fts_index("nested.text", use_tantivy=True)
|
||||
rs = table.search("puppy").limit(10).to_list()
|
||||
assert len(rs) == 10
|
||||
rs = table.search("puppy").limit(5).to_list()
|
||||
assert len(rs) == 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
|
||||
@@ -251,7 +251,8 @@ def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
|
||||
|
||||
# FTS indices should error since they are not supported yet.
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="Full-text search is not supported on object stores."
|
||||
NotImplementedError,
|
||||
match="Full-text search is only supported on the local filesystem",
|
||||
):
|
||||
table.create_fts_index("x")
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from pydantic import BaseModel
|
||||
|
||||
class MockDB:
|
||||
def __init__(self, uri: Path):
|
||||
self.uri = uri
|
||||
self.uri = str(uri)
|
||||
self.read_consistency_interval = None
|
||||
|
||||
@functools.cached_property
|
||||
|
||||
@@ -15,17 +15,20 @@
|
||||
use arrow::array::make_array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use lancedb::index::scalar::FullTextSearchQuery;
|
||||
use lancedb::query::QueryExecutionOptions;
|
||||
use lancedb::query::{
|
||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::pyclass;
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyDict;
|
||||
use pyo3::Bound;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3::{pyclass, PyErr};
|
||||
use pyo3_asyncio_0_21::tokio::future_into_py;
|
||||
|
||||
use crate::arrow::RecordBatchStream;
|
||||
@@ -68,6 +71,24 @@ impl Query {
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<()> {
|
||||
let query_text = query
|
||||
.get_item("query")?
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"Query text is required for nearest_to_text",
|
||||
))?
|
||||
.extract::<String>()?;
|
||||
let columns = query
|
||||
.get_item("columns")?
|
||||
.map(|columns| columns.extract::<Vec<String>>())
|
||||
.transpose()?;
|
||||
|
||||
let fts_query = FullTextSearchQuery::new(query_text).columns(columns);
|
||||
self.inner = self.inner.clone().full_text_search(fts_query);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
max_batch_length: Option<u32>,
|
||||
|
||||
Reference in New Issue
Block a user