feat: support querying and indexing FTS on RemoteTable/AsyncTable (#1537)

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
Author: BubbleCal
Date: 2024-08-16 12:01:05 +08:00
Committed by: GitHub
Parent: 20faa4424b
Commit: 0fa50775d6
11 changed files with 229 additions and 102 deletions
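
In rough terms, the new API surface looks like the sketch below, adapted from the async test added in this commit; the connection path, table name, and sample data are placeholders rather than part of the diff:

import lancedb
from lancedb.index import FTS

async def fts_on_async_table():
    # connect and create a small table (URI and data are illustrative only)
    db = await lancedb.connect_async("data/sample-lancedb")
    tbl = await db.create_table(
        "docs",
        data=[
            {"vector": [0.1, 0.2], "text": "a puppy runs merrily"},
            {"vector": [0.3, 0.4], "text": "a car drives crazily"},
        ],
    )
    # build a native (non-Tantivy) inverted index on the text column
    await tbl.create_index("text", config=FTS())
    # full-text query, ranked by BM25
    return await tbl.query().nearest_to_text("puppy").limit(5).to_list()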

View File

@@ -74,6 +74,7 @@ class Query:
def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> Query: ...
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
class VectorQuery:

View File

@@ -276,6 +276,10 @@ class DBConnection(EnforceOverrides):
"""
raise NotImplementedError
@property
def uri(self) -> str:
return self._uri
class LanceDBConnection(DBConnection):
"""
@@ -340,10 +344,6 @@ class LanceDBConnection(DBConnection):
val += ")"
return val
@property
def uri(self) -> str:
return self._uri
async def _async_get_table_names(self, start_after: Optional[str], limit: int):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)

View File

@@ -70,6 +70,18 @@ class LabelList:
self._inner = LanceDbIndex.label_list()
class FTS:
"""Describe a FTS index configuration.
`FTS` is a full-text search index that can be used on `String` columns
For example, it works with `title`, `description`, `content`, etc.
"""
def __init__(self):
self._inner = LanceDbIndex.fts()
class IvfPq:
"""Describes an IVF PQ Index

View File

@@ -15,7 +15,6 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
TYPE_CHECKING,
Dict,
@@ -38,7 +37,7 @@ from .arrow import AsyncRecordBatchReader
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import fs_from_uri, safe_import_pandas
from .util import safe_import_pandas
if TYPE_CHECKING:
import PIL
@@ -174,7 +173,9 @@ class LanceQueryBuilder(ABC):
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(
table, query, ordering_field_name=ordering_field_name
table,
query,
ordering_field_name=ordering_field_name,
)
if isinstance(query, list):
@@ -681,6 +682,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
self._phrase_query = False
self.ordering_field_name = ordering_field_name
self._reranker = None
if isinstance(fts_columns, str):
fts_columns = [fts_columns]
self._fts_columns = fts_columns
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
@@ -701,8 +704,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
return self
def to_arrow(self) -> pa.Table:
tantivy_index_path = self._table._get_fts_index_path()
if Path(tantivy_index_path).exists():
path, fs, exist = self._table._get_fts_index_path()
if exist:
return self.tantivy_to_arrow()
query = self._query
@@ -711,23 +714,20 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
"Phrase query is not yet supported in Lance FTS. "
"Use tantivy-based index instead for now."
)
if self._reranker:
raise NotImplementedError(
"Reranking is not yet supported in Lance FTS. "
"Use tantivy-based index instead for now."
)
ds = self._table.to_lance()
return ds.to_table(
query = Query(
columns=self._columns,
filter=self._where,
limit=self._limit,
k=self._limit,
prefilter=self._prefilter,
with_row_id=self._with_row_id,
full_text_query={
"query": query,
"columns": self._fts_columns,
},
vector=[],
)
results = self._table._execute_query(query)
return results.read_all()
def tantivy_to_arrow(self) -> pa.Table:
try:
@@ -740,24 +740,24 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
from .fts import search_index
# get the index path
index_path = self._table._get_fts_index_path()
# Check that we are on local filesystem
fs, _path = fs_from_uri(index_path)
if not isinstance(fs, pa_fs.LocalFileSystem):
raise NotImplementedError(
"Full-text search is only supported on the local filesystem"
)
path, fs, exist = self._table._get_fts_index_path()
# check if the index exists
if not Path(index_path).exists():
if not exist:
raise FileNotFoundError(
"Fts index does not exist. "
"Please first call table.create_fts_index(['<field_names>']) to "
"create the fts index."
)
# Check that we are on local filesystem
if not isinstance(fs, pa_fs.LocalFileSystem):
raise NotImplementedError(
"Tantivy-based full text search "
"is only supported on the local filesystem"
)
# open the index
index = tantivy.Index.open(index_path)
index = tantivy.Index.open(path)
# get the scores and doc ids
query = self._query
if self._phrase_query:
@@ -851,7 +851,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str):
super().__init__(table)
self._validate_fts_index()
vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -859,12 +858,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._norm = "score"
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
def _validate_fts_index(self):
if self._table._get_fts_index_path() is None:
raise ValueError(
"Please create a full-text search index " "to perform hybrid search."
)
def _validate_query(self, query):
# Temp hack to support vectorized queries for hybrid search
if isinstance(query, str):
@@ -1354,6 +1347,35 @@ class AsyncQuery(AsyncQueryBase):
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
def nearest_to_text(
self, query: str, columns: Union[str, List[str]] = None
) -> AsyncQuery:
"""
Find the documents that are most relevant to the given text query.
This method will perform a full text search on the table and return
the most relevant documents. The relevance is determined by BM25.
The columns to search must have a native FTS index
(Tantivy-based indices do not work with this method).
By default, all indexed columns are searched;
for now, only one column can be searched at a time.
Parameters
----------
query: str
The text query to search for.
columns: str or list of str, default None
The columns to search in. If None, all indexed columns are searched.
For now, only one column can be searched at a time.
"""
if isinstance(columns, str):
columns = [columns]
return AsyncQuery(
self._inner.nearest_to_text({"query": query, "columns": columns})
)
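
For example, a minimal sketch (assuming tbl is an AsyncTable that already has a native FTS index on its "text" column):

hits = await tbl.query().nearest_to_text("puppy", columns="text").limit(5).to_list()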
class AsyncVectorQuery(AsyncQueryBase):
def __init__(self, inner: LanceVectorQuery):

View File

@@ -49,6 +49,7 @@ class RemoteDBConnection(DBConnection):
parsed = urlparse(db_url)
if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self._uri = str(db_url)
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(

View File

@@ -35,10 +35,10 @@ from .db import RemoteDBConnection
class RemoteTable(Table):
def __init__(self, conn: RemoteDBConnection, name: str):
self._conn = conn
self._name = name
self.name = name
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self._name})"
return f"RemoteTable({self._conn.db_name}.{self.name})"
def __len__(self) -> int:
return self.count_rows(None)
@@ -49,14 +49,14 @@ class RemoteTable(Table):
of this Table
"""
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
schema = json_to_schema(resp["schema"])
return schema
@property
def version(self) -> int:
"""Get the current version of the table"""
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
return resp["version"]
@cached_property
@@ -84,13 +84,13 @@ class RemoteTable(Table):
def list_indices(self):
"""List all the indices on the table"""
resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
return resp
def index_stats(self, index_uuid: str):
"""List all the stats of a specified index"""
resp = self._conn._client.post(
f"/v1/table/{self._name}/index/{index_uuid}/stats/"
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
)
return resp
@@ -116,11 +116,27 @@ class RemoteTable(Table):
"replace": True,
}
resp = self._conn._client.post(
f"/v1/table/{self._name}/create_scalar_index/", data=data
f"/v1/table/{self.name}/create_scalar_index/", data=data
)
return resp
def create_fts_index(
self,
column: str,
*,
replace: bool = False,
):
data = {
"column": column,
"index_type": "FTS",
"replace": replace,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_index/", data=data
)
return resp
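
A minimal sketch of the remote flow this method enables (the database URI and API key are placeholders):

db = lancedb.connect("db://my-database", api_key="...")
tbl = db.open_table("docs")
# issues a POST to /v1/table/docs/create_index/ with index_type="FTS"
tbl.create_fts_index("text", replace=True)
hits = tbl.search("puppy", query_type="fts").limit(5).to_list()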
def create_index(
self,
metric="L2",
@@ -194,7 +210,7 @@ class RemoteTable(Table):
"index_cache_size": index_cache_size,
}
resp = self._conn._client.post(
f"/v1/table/{self._name}/create_index/", data=data
f"/v1/table/{self.name}/create_index/", data=data
)
return resp
@@ -241,7 +257,7 @@ class RemoteTable(Table):
request_id = uuid.uuid4().hex
self._conn._client.post(
f"/v1/table/{self._name}/insert/",
f"/v1/table/{self.name}/insert/",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
@@ -251,6 +267,7 @@ class RemoteTable(Table):
self,
query: Union[VEC, str],
vector_column_name: Optional[str] = None,
query_type="auto",
) -> LanceVectorQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -310,10 +327,18 @@ class RemoteTable(Table):
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None:
vector_column_name = inf_vector_column_query(self.schema)
query = LanceQueryBuilder._query_to_vector(self, query, vector_column_name)
return LanceVectorQueryBuilder(self, query, vector_column_name)
if vector_column_name is None and query is not None and query_type != "fts":
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:
raise e
return LanceQueryBuilder.create(
self,
query,
query_type,
vector_column_name=vector_column_name,
)
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
@@ -342,12 +367,12 @@ class RemoteTable(Table):
v = list(v)
q = query.copy()
q.vector = v
results.append(submit(self._name, q))
results.append(submit(self.name, q))
return pa.concat_tables(
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
).to_reader()
else:
result = self._conn._client.query(self._name, query)
result = self._conn._client.query(self.name, query)
return result.to_arrow().to_reader()
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
@@ -397,7 +422,7 @@ class RemoteTable(Table):
)
self._conn._client.post(
f"/v1/table/{self._name}/merge_insert/",
f"/v1/table/{self.name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
@@ -451,7 +476,7 @@ class RemoteTable(Table):
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
"""
payload = {"predicate": predicate}
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
def update(
self,
@@ -512,7 +537,7 @@ class RemoteTable(Table):
updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates}
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
@@ -529,7 +554,7 @@ class RemoteTable(Table):
def count_rows(self, filter: Optional[str] = None) -> int:
payload = {"predicate": filter}
resp = self._conn._client.post(
f"/v1/table/{self._name}/count_rows/", data=payload
f"/v1/table/{self.name}/count_rows/", data=payload
)
return resp

View File

@@ -51,7 +51,7 @@ if TYPE_CHECKING:
from lance.dataset import CleanupStats, ReaderLike
from ._lancedb import Table as LanceDBTable, OptimizeStats
from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
pd = safe_import_pandas()
@@ -840,6 +840,18 @@ class Table(ABC):
The names of the columns to drop.
"""
@cached_property
def _dataset_uri(self) -> str:
return _table_uri(self._conn.uri, self.name)
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
if get_uri_scheme(self._dataset_uri) != "file":
return ("", None, False)
path = join_uri(self._dataset_uri, "_indices", "fts")
fs, path = fs_from_uri(path)
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
return (path, fs, index_exists)
class _LanceDatasetRef(ABC):
@property
@@ -979,10 +991,6 @@ class LanceTable(Table):
# Cacheable since it's deterministic
return _table_path(self._conn.uri, self.name)
@cached_property
def _dataset_uri(self) -> str:
return _table_uri(self._conn.uri, self.name)
@property
def _dataset(self) -> LanceDataset:
return self._ref.dataset
@@ -1247,9 +1255,8 @@ class LanceTable(Table):
raise ValueError("field_names must be a string when use_tantivy=False")
# delete the existing legacy index if it exists
if replace:
fs, path = fs_from_uri(self._get_fts_index_path())
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
path, fs, exist = self._get_fts_index_path()
if exist:
fs.delete_dir(path)
self._dataset_mut.create_scalar_index(
field_names, index_type="INVERTED", replace=replace
@@ -1264,9 +1271,8 @@ class LanceTable(Table):
if isinstance(ordering_field_names, str):
ordering_field_names = [ordering_field_names]
fs, path = fs_from_uri(self._get_fts_index_path())
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
path, fs, exist = self._get_fts_index_path()
if exist:
if not replace:
raise ValueError("Index already exists. Use replace=True to overwrite.")
fs.delete_dir(path)
@@ -1277,7 +1283,7 @@ class LanceTable(Table):
)
index = create_index(
self._get_fts_index_path(),
path,
field_names,
ordering_fields=ordering_field_names,
tokenizer_name=tokenizer_name,
@@ -1290,13 +1296,6 @@ class LanceTable(Table):
writer_heap_size=writer_heap_size,
)
def _get_fts_index_path(self):
if get_uri_scheme(self._dataset_uri) != "file":
raise NotImplementedError(
"Full-text search is not supported on object stores."
)
return join_uri(self._dataset_uri, "_indices", "tantivy")
def add(
self,
data: DATA,
@@ -1492,14 +1491,11 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None and query is not None:
if vector_column_name is None and query is not None and query_type != "fts":
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:
if query_type == "fts":
vector_column_name = ""
else:
raise e
raise e
return LanceQueryBuilder.create(
self,
@@ -1690,18 +1686,22 @@ class LanceTable(Table):
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
ds = self.to_lance()
return ds.scanner(
columns=query.columns,
filter=query.filter,
prefilter=query.prefilter,
nearest={
nearest = None
if len(query.vector) > 0:
nearest = {
"column": query.vector_column,
"q": query.vector,
"k": query.k,
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
}
return ds.scanner(
columns=query.columns,
limit=query.k,
filter=query.filter,
prefilter=query.prefilter,
nearest=nearest,
full_text_query=query.full_text_query,
with_row_id=query.with_row_id,
batch_size=batch_size,
@@ -2126,7 +2126,7 @@ class AsyncTable:
column: str,
*,
replace: Optional[bool] = None,
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList]] = None,
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
):
"""Create an index to speed up queries

View File

@@ -15,6 +15,7 @@ import random
from unittest import mock
import lancedb as ldb
from lancedb.index import FTS
import numpy as np
import pandas as pd
import pytest
@@ -60,6 +61,43 @@ def table(tmp_path) -> ldb.table.LanceTable:
return table
@pytest.fixture
async def async_table(tmp_path) -> ldb.table.AsyncTable:
db = await ldb.connect_async(tmp_path)
vectors = [np.random.randn(128) for _ in range(100)]
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
verbs = ("runs", "hits", "jumps", "drives", "barfs")
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
text = [
" ".join(
[
nouns[random.randrange(0, 5)],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
]
)
for _ in range(100)
]
count = [random.randint(1, 10000) for _ in range(100)]
table = await db.create_table(
"test",
data=pd.DataFrame(
{
"vector": vectors,
"id": [i % 2 for i in range(100)],
"text": text,
"text2": text,
"nested": [{"text": t} for t in text],
"count": count,
}
),
)
return table
def test_create_index(tmp_path):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
assert isinstance(index, tantivy.Index)
@@ -91,17 +129,23 @@ def test_search_index(tmp_path, table):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
ldb.fts.populate_index(index, table, ["text"])
index.reload()
results = ldb.fts.search_index(index, query="puppy", limit=10)
results = ldb.fts.search_index(index, query="puppy", limit=5)
assert len(results) == 2
assert len(results[0]) == 10 # row_ids
assert len(results[1]) == 10 # _distance
assert len(results[0]) == 5 # row_ids
assert len(results[1]) == 5 # _score
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_search_fts(table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
results = table.search("puppy").limit(10).to_list()
assert len(results) == 10
results = table.search("puppy").limit(5).to_list()
assert len(results) == 5
async def test_search_fts_async(async_table):
await async_table.create_index("text", config=FTS())
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
assert len(results) == 5
def test_search_ordering_field_index_table(tmp_path, table):
@@ -125,11 +169,11 @@ def test_search_ordering_field_index(tmp_path, table):
ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"])
index.reload()
results = ldb.fts.search_index(
index, query="puppy", limit=10, ordering_field="count"
index, query="puppy", limit=5, ordering_field="count"
)
assert len(results) == 2
assert len(results[0]) == 10 # row_ids
assert len(results[1]) == 10 # _distance
assert len(results[0]) == 5 # row_ids
assert len(results[1]) == 5 # _distance
rows = table.to_lance().take(results[0]).to_pylist()
for r in rows:
@@ -140,8 +184,8 @@ def test_search_ordering_field_index(tmp_path, table):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_create_index_from_table(tmp_path, table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
assert len(df) <= 10
df = table.search("puppy").limit(5).select(["text"]).to_pandas()
assert len(df) <= 5
assert "text" in df.columns
# Check whether it can be updated
@@ -167,8 +211,8 @@ def test_create_index_from_table(tmp_path, table, use_tantivy):
def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"], use_tantivy=True)
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 10
df = table.search("puppy").limit(5).to_pandas()
assert len(df) == 5
assert "text" in df.columns
assert "text2" in df.columns
@@ -176,14 +220,14 @@ def test_create_index_multiple_columns(tmp_path, table):
def test_empty_rs(tmp_path, table, mocker):
table.create_fts_index(["text", "text2"], use_tantivy=True)
mocker.patch("lancedb.fts.search_index", return_value=([], []))
df = table.search("puppy").limit(10).to_pandas()
df = table.search("puppy").limit(5).to_pandas()
assert len(df) == 0
def test_nested_schema(tmp_path, table):
table.create_fts_index("nested.text", use_tantivy=True)
rs = table.search("puppy").limit(10).to_list()
assert len(rs) == 10
rs = table.search("puppy").limit(5).to_list()
assert len(rs) == 5
@pytest.mark.parametrize("use_tantivy", [True, False])

View File

@@ -251,7 +251,8 @@ def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
# FTS indices should error since they are not supported yet.
with pytest.raises(
NotImplementedError, match="Full-text search is not supported on object stores."
NotImplementedError,
match="Full-text search is only supported on the local filesystem",
):
table.create_fts_index("x")

View File

@@ -28,7 +28,7 @@ from pydantic import BaseModel
class MockDB:
def __init__(self, uri: Path):
self.uri = uri
self.uri = str(uri)
self.read_consistency_interval = None
@functools.cached_property