Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep

This commit is contained in:
BubbleCal
2024-09-04 16:45:46 +08:00
88 changed files with 2552 additions and 2828 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.13.0-beta.0"
current_version = "0.13.0-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.13.0-beta.0"
version = "0.13.0-beta.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -18,7 +18,7 @@ description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
license = { file = "LICENSE" }
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
keywords = [
"data-format",
"data-science",

View File

@@ -74,6 +74,7 @@ class Query:
def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> Query: ...
async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ...
class VectorQuery:

View File

@@ -276,6 +276,10 @@ class DBConnection(EnforceOverrides):
"""
raise NotImplementedError
@property
def uri(self) -> str:
return self._uri
class LanceDBConnection(DBConnection):
"""
@@ -340,10 +344,6 @@ class LanceDBConnection(DBConnection):
val += ")"
return val
@property
def uri(self) -> str:
return self._uri
async def _async_get_table_names(self, start_after: Optional[str], limit: int):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)

View File

@@ -127,6 +127,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
normalize_embeddings=self.normalize_embeddings,
device=self.device,
).tolist()
return res

View File

@@ -26,12 +26,23 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
An embedding function that uses the sentence-transformers library
https://huggingface.co/sentence-transformers
Parameters
----------
name: str, default "all-MiniLM-L6-v2"
The name of the model to use.
device: str, default "cpu"
The device to use for the model
normalize: bool, default True
Whether to normalize the embeddings
trust_remote_code: bool, default True
Whether to trust the remote code
"""
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = True
trust_remote_code: bool = False
trust_remote_code: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -36,6 +36,10 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
The name of the model to use. This should be a model name that can be loaded
by transformers.AutoModel.from_pretrained. For example, "bert-base-uncased".
default: "colbert-ir/colbertv2.0""
device : str
The device to use for the model. Default is "cpu".
show_progress_bar : bool
Whether to show a progress bar when loading the model. Default is True.
to download package, run :
`pip install transformers`
@@ -44,6 +48,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
"""
name: str = "colbert-ir/colbertv2.0"
device: str = "cpu"
_tokenizer: Any = PrivateAttr()
_model: Any = PrivateAttr()
@@ -53,6 +58,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
transformers = attempt_import_or_raise("transformers")
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
self._model = transformers.AutoModel.from_pretrained(self.name)
self._model.to(self.device)
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
@@ -75,9 +81,9 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
for text in texts:
encoding = self._tokenizer(
text, return_tensors="pt", padding=True, truncation=True
)
).to(self.device)
emb = self._model(**encoding).last_hidden_state.mean(dim=1).squeeze()
embedding.append(emb.detach().numpy())
embedding.append(emb.tolist())
return embedding

View File

@@ -70,6 +70,18 @@ class LabelList:
self._inner = LanceDbIndex.label_list()
class FTS:
"""Describe a FTS index configuration.
`FTS` is a full-text search index that can be used on `String` columns
For example, it works with `title`, `description`, `content`, etc.
"""
def __init__(self):
self._inner = LanceDbIndex.fts()
class IvfPq:
"""Describes an IVF PQ Index

View File

@@ -15,7 +15,6 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
TYPE_CHECKING,
Dict,
@@ -35,15 +34,15 @@ import pydantic
from . import __version__
from .arrow import AsyncRecordBatchReader
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import fs_from_uri, safe_import_pandas
from .rerankers.rrf import RRFReranker
from .util import safe_import_pandas
if TYPE_CHECKING:
import PIL
import polars as pl
from .common import VEC
from ._lancedb import Query as LanceQuery
from ._lancedb import VectorQuery as LanceVectorQuery
from .pydantic import LanceModel
@@ -133,8 +132,8 @@ class LanceQueryBuilder(ABC):
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str,
vector_column_name: str,
ordering_field_name: str = None,
fts_columns: Union[str, List[str]] = None,
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
) -> LanceQueryBuilder:
"""
Create a query builder based on the given query and query type.
@@ -152,12 +151,15 @@ class LanceQueryBuilder(ABC):
vector_column_name: str
The name of the vector column to use for vector search.
"""
if query is None:
return LanceEmptyQueryBuilder(table)
# Check hybrid search first as it supports empty query pattern
if query_type == "hybrid":
# hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name)
return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
if query is None:
return LanceEmptyQueryBuilder(table)
# remember the string query for reranking purpose
str_query = query if isinstance(query, str) else None
@@ -169,12 +171,17 @@ class LanceQueryBuilder(ABC):
)
if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name)
return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(
table, query, ordering_field_name=ordering_field_name
table,
query,
ordering_field_name=ordering_field_name,
fts_columns=fts_columns,
)
if isinstance(query, list):
@@ -200,8 +207,6 @@ class LanceQueryBuilder(ABC):
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
@@ -232,6 +237,8 @@ class LanceQueryBuilder(ABC):
self._where = None
self._prefilter = False
self._with_row_id = False
self._vector = None
self._text = None
@deprecation.deprecated(
deprecated_in="0.3.1",
@@ -356,7 +363,10 @@ class LanceQueryBuilder(ABC):
The LanceQueryBuilder object.
"""
if limit is None or limit <= 0:
self._limit = None
if isinstance(self, LanceVectorQueryBuilder):
raise ValueError("Limit is required for ANN/KNN queries")
else:
self._limit = None
else:
self._limit = limit
return self
@@ -456,6 +466,52 @@ class LanceQueryBuilder(ABC):
},
).explain_plan(verbose)
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
"""Set the vector to search for.
Parameters
----------
vector: np.ndarray or list
The vector to search for.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError
def text(self, text: str) -> LanceQueryBuilder:
"""Set the text to search for.
Parameters
----------
text: str
The text to search for.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError
@abstractmethod
def rerank(self, reranker: Reranker) -> LanceQueryBuilder:
"""Rerank the results using the specified reranker.
Parameters
----------
reranker: Reranker
The reranker to use.
Returns
-------
The LanceQueryBuilder object.
"""
raise NotImplementedError
class LanceVectorQueryBuilder(LanceQueryBuilder):
"""
@@ -673,14 +729,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
self,
table: "Table",
query: str,
ordering_field_name: str = None,
fts_columns: Union[str, List[str]] = None,
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
):
super().__init__(table)
self._query = query
self._phrase_query = False
self.ordering_field_name = ordering_field_name
self._reranker = None
if isinstance(fts_columns, str):
fts_columns = [fts_columns]
self._fts_columns = fts_columns
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
@@ -701,8 +759,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
return self
def to_arrow(self) -> pa.Table:
tantivy_index_path = self._table._get_fts_index_path()
if Path(tantivy_index_path).exists():
path, fs, exist = self._table._get_fts_index_path()
if exist:
return self.tantivy_to_arrow()
query = self._query
@@ -711,23 +769,23 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
"Phrase query is not yet supported in Lance FTS. "
"Use tantivy-based index instead for now."
)
if self._reranker:
raise NotImplementedError(
"Reranking is not yet supported in Lance FTS. "
"Use tantivy-based index instead for now."
)
ds = self._table.to_lance()
return ds.to_table(
query = Query(
columns=self._columns,
filter=self._where,
limit=self._limit,
k=self._limit,
prefilter=self._prefilter,
with_row_id=self._with_row_id,
full_text_query={
"query": query,
"columns": self._fts_columns,
},
vector=[],
)
results = self._table._execute_query(query)
results = results.read_all()
if self._reranker is not None:
results = self._reranker.rerank_fts(self._query, results)
return results
def tantivy_to_arrow(self) -> pa.Table:
try:
@@ -740,24 +798,24 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
from .fts import search_index
# get the index path
index_path = self._table._get_fts_index_path()
# Check that we are on local filesystem
fs, _path = fs_from_uri(index_path)
if not isinstance(fs, pa_fs.LocalFileSystem):
raise NotImplementedError(
"Full-text search is only supported on the local filesystem"
)
path, fs, exist = self._table._get_fts_index_path()
# check if the index exist
if not Path(index_path).exists():
if not exist:
raise FileNotFoundError(
"Fts index does not exist. "
"Please first call table.create_fts_index(['<field_names>']) to "
"create the fts index."
)
# Check that we are on local filesystem
if not isinstance(fs, pa_fs.LocalFileSystem):
raise NotImplementedError(
"Tantivy-based full text search "
"is only supported on the local filesystem"
)
# open the index
index = tantivy.Index.open(index_path)
index = tantivy.Index.open(path)
# get the scores and doc ids
query = self._query
if self._phrase_query:
@@ -825,7 +883,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
LanceFtsQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError("Reranking is not yet supported for FTS queries.")
self._reranker = reranker
return self
class LanceEmptyQueryBuilder(LanceQueryBuilder):
@@ -837,54 +896,101 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
limit=self._limit,
)
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker.
Parameters
----------
reranker: Reranker
The reranker to use.
Returns
-------
LanceEmptyQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError("Reranking is not yet supported.")
class LanceHybridQueryBuilder(LanceQueryBuilder):
"""
A query builder that performs hybrid vector and full text search.
Results are combined and reranked based on the specified reranker.
By default, the results are reranked using the LinearCombinationReranker.
By default, the results are reranked using the RRFReranker, which
uses reciprocal rank fusion score for reranking.
To make the vector and fts results comparable, the scores are normalized.
Instead of normalizing scores, the `normalize` parameter can be set to "rank"
in the `rerank` method to convert the scores to ranks and then normalize them.
"""
def __init__(self, table: "Table", query: str, vector_column: str):
def __init__(
self,
table: "Table",
query: str = None,
vector_column: str = None,
fts_columns: Union[str, List[str]] = [],
):
super().__init__(table)
self._validate_fts_index()
vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._query = query
self._vector_column = vector_column
self._fts_columns = fts_columns
self._norm = "score"
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
self._reranker = RRFReranker()
self._nprobes = None
self._refine_factor = None
def _validate_fts_index(self):
if self._table._get_fts_index_path() is None:
def _validate_query(self, query, vector=None, text=None):
if query is not None and (vector is not None or text is not None):
raise ValueError(
"Please create a full-text search index " "to perform hybrid search."
"You can either provide a string query in search() method"
"or set `vector()` and `text()` explicitly for hybrid search."
"But not both."
)
def _validate_query(self, query):
# Temp hack to support vectorized queries for hybrid search
if isinstance(query, str):
return query, query
elif isinstance(query, tuple):
if len(query) != 2:
raise ValueError(
"The query must be a tuple of (vector_query, fts_query)."
)
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
raise ValueError(f"The vector query must be one of {VEC}.")
if not isinstance(query[1], str):
raise ValueError("The fts query must be a string.")
return query[0], query[1]
else:
raise ValueError(
"The query must be either a string or a tuple of (vector, string)."
)
vector_query = vector if vector is not None else query
if not isinstance(vector_query, (str, list, np.ndarray)):
raise ValueError("Vector query must be either a string or a vector")
text_query = text or query
if text_query is None:
raise ValueError("Text query must be provided for hybrid search.")
if not isinstance(text_query, str):
raise ValueError("Text query must be a string")
return vector_query, text_query
def to_arrow(self) -> pa.Table:
vector_query, fts_query = self._validate_query(
self._query, self._vector, self._text
)
self._fts_query = LanceFtsQueryBuilder(
self._table, fts_query, fts_columns=self._fts_columns
)
vector_query = self._query_to_vector(
self._table, vector_query, self._vector_column
)
self._vector_query = LanceVectorQueryBuilder(
self._table, vector_query, self._vector_column
)
if self._limit:
self._vector_query.limit(self._limit)
self._fts_query.limit(self._limit)
if self._columns:
self._vector_query.select(self._columns)
self._fts_query.select(self._columns)
if self._where:
self._vector_query.where(self._where, self._prefilter)
self._fts_query.where(self._where, self._prefilter)
if self._with_row_id:
self._vector_query.with_row_id(True)
self._fts_query.with_row_id(True)
if self._nprobes:
self._vector_query.nprobes(self._nprobes)
if self._refine_factor:
self._vector_query.refine_factor(self._refine_factor)
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
vector_future = executor.submit(
@@ -961,7 +1067,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def rerank(
self,
normalize="score",
reranker: Reranker = LinearCombinationReranker(weight=0.7, fill=1.0),
reranker: Reranker = RRFReranker(),
) -> LanceHybridQueryBuilder:
"""
Rerank the hybrid search results using the specified reranker. The reranker
@@ -973,7 +1079,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
The method to normalize the scores. Can be "rank" or "score". If "rank",
the scores are converted to ranks and then normalized. If "score", the
scores are normalized directly.
reranker: Reranker, default LinearCombinationReranker(weight=0.7, fill=1.0)
reranker: Reranker, default RRFReranker()
The reranker to use. Must be an instance of Reranker class.
Returns
-------
@@ -990,87 +1096,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
return self
def limit(self, limit: int) -> LanceHybridQueryBuilder:
"""
Set the maximum number of results to return for both vector and fts search
components.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.limit(limit)
self._fts_query.limit(limit)
self._limit = limit
return self
def select(self, columns: list) -> LanceHybridQueryBuilder:
"""
Set the columns to return for both vector and fts search.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.select(columns)
self._fts_query.select(columns)
return self
def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
"""
Set the where clause for both vector and fts search.
Parameters
----------
where: str
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
for valid SQL expressions.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.where(where, prefilter=prefilter)
self._fts_query.where(where)
return self
def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
"""
Set the distance metric to use for vector search.
Parameters
----------
metric: "L2" or "cosine"
The distance metric to use. By default "L2" is used.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.metric(metric)
return self
def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
"""
Set the number of probes to use for vector search.
@@ -1088,7 +1113,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.nprobes(nprobes)
self._nprobes = nprobes
return self
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
@@ -1106,7 +1131,15 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.refine_factor(refine_factor)
self._refine_factor = refine_factor
return self
def vector(self, vector: Union[np.ndarray, list]) -> LanceHybridQueryBuilder:
self._vector = vector
return self
def text(self, text: str) -> LanceHybridQueryBuilder:
self._text = text
return self
@@ -1354,6 +1387,34 @@ class AsyncQuery(AsyncQueryBase):
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
def nearest_to_text(
self, query: str, columns: Union[str, List[str]] = []
) -> AsyncQuery:
"""
Find the documents that are most relevant to the given text query.
This method will perform a full text search on the table and return
the most relevant documents. The relevance is determined by BM25.
The columns to search must be with native FTS index
(Tantivy-based can't work with this method).
By default, all indexed columns are searched,
now only one column can be searched at a time.
Parameters
----------
query: str
The text query to search for.
columns: str or list of str, default None
The columns to search in. If None, all indexed columns are searched.
For now only one column can be searched at a time.
"""
if isinstance(columns, str):
columns = [columns]
self._inner.nearest_to_text({"query": query, "columns": columns})
return self
class AsyncVectorQuery(AsyncQueryBase):
def __init__(self, inner: LanceVectorQuery):

View File

@@ -49,6 +49,7 @@ class RemoteDBConnection(DBConnection):
parsed = urlparse(db_url)
if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self._uri = str(db_url)
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(

View File

@@ -15,7 +15,7 @@ import logging
import uuid
from concurrent.futures import Future
from functools import cached_property
from typing import Dict, Iterable, Optional, Union
from typing import Dict, Iterable, List, Optional, Union, Literal
import pyarrow as pa
from lance import json_to_schema
@@ -35,10 +35,10 @@ from .db import RemoteDBConnection
class RemoteTable(Table):
def __init__(self, conn: RemoteDBConnection, name: str):
self._conn = conn
self._name = name
self.name = name
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self._name})"
return f"RemoteTable({self._conn.db_name}.{self.name})"
def __len__(self) -> int:
self.count_rows(None)
@@ -49,14 +49,14 @@ class RemoteTable(Table):
of this Table
"""
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
schema = json_to_schema(resp["schema"])
return schema
@property
def version(self) -> int:
"""Get the current version of the table"""
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
return resp["version"]
@cached_property
@@ -84,19 +84,20 @@ class RemoteTable(Table):
def list_indices(self):
"""List all the indices on the table"""
resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
return resp
def index_stats(self, index_uuid: str):
"""List all the stats of a specified index"""
resp = self._conn._client.post(
f"/v1/table/{self._name}/index/{index_uuid}/stats/"
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
)
return resp
def create_scalar_index(
self,
column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
):
"""Creates a scalar index
Parameters
@@ -104,8 +105,10 @@ class RemoteTable(Table):
column : str
The column to be indexed. Must be a boolean, integer, float,
or string column.
index_type : str
The index type of the scalar index. Must be "scalar" (BTREE),
"BTREE", "BITMAP", or "LABEL_LIST"
"""
index_type = "scalar"
data = {
"column": column,
@@ -113,11 +116,27 @@ class RemoteTable(Table):
"replace": True,
}
resp = self._conn._client.post(
f"/v1/table/{self._name}/create_scalar_index/", data=data
f"/v1/table/{self.name}/create_scalar_index/", data=data
)
return resp
def create_fts_index(
self,
column: str,
*,
replace: bool = False,
):
data = {
"column": column,
"index_type": "FTS",
"replace": replace,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_index/", data=data
)
return resp
def create_index(
self,
metric="L2",
@@ -191,7 +210,7 @@ class RemoteTable(Table):
"index_cache_size": index_cache_size,
}
resp = self._conn._client.post(
f"/v1/table/{self._name}/create_index/", data=data
f"/v1/table/{self.name}/create_index/", data=data
)
return resp
@@ -238,7 +257,7 @@ class RemoteTable(Table):
request_id = uuid.uuid4().hex
self._conn._client.post(
f"/v1/table/{self._name}/insert/",
f"/v1/table/{self.name}/insert/",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
@@ -248,6 +267,8 @@ class RemoteTable(Table):
self,
query: Union[VEC, str],
vector_column_name: Optional[str] = None,
query_type="auto",
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceVectorQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -307,10 +328,19 @@ class RemoteTable(Table):
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None:
vector_column_name = inf_vector_column_query(self.schema)
query = LanceQueryBuilder._query_to_vector(self, query, vector_column_name)
return LanceVectorQueryBuilder(self, query, vector_column_name)
if vector_column_name is None and query is not None and query_type != "fts":
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:
raise e
return LanceQueryBuilder.create(
self,
query,
query_type,
vector_column_name=vector_column_name,
fts_columns=fts_columns,
)
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
@@ -339,12 +369,12 @@ class RemoteTable(Table):
v = list(v)
q = query.copy()
q.vector = v
results.append(submit(self._name, q))
results.append(submit(self.name, q))
return pa.concat_tables(
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
).to_reader()
else:
result = self._conn._client.query(self._name, query)
result = self._conn._client.query(self.name, query)
return result.to_arrow().to_reader()
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
@@ -394,7 +424,7 @@ class RemoteTable(Table):
)
self._conn._client.post(
f"/v1/table/{self._name}/merge_insert/",
f"/v1/table/{self.name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
@@ -448,7 +478,7 @@ class RemoteTable(Table):
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
"""
payload = {"predicate": predicate}
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
def update(
self,
@@ -509,7 +539,7 @@ class RemoteTable(Table):
updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates}
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
@@ -526,7 +556,7 @@ class RemoteTable(Table):
def count_rows(self, filter: Optional[str] = None) -> int:
payload = {"predicate": filter}
resp = self._conn._client.post(
f"/v1/table/{self._name}/count_rows/", data=payload
f"/v1/table/{self.name}/count_rows/", data=payload
)
return resp

View File

@@ -6,6 +6,7 @@ from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
from .jinaai import JinaReranker
from .rrf import RRFReranker
from .answerdotai import AnswerdotaiRerankers
__all__ = [
"Reranker",
@@ -16,4 +17,5 @@ __all__ = [
"ColbertReranker",
"JinaReranker",
"RRFReranker",
"AnswerdotaiRerankers",
]

View File

@@ -0,0 +1,99 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow as pa
from .base import Reranker
from ..util import attempt_import_or_raise
class AnswerdotaiRerankers(Reranker):
"""
Reranks the results using the Answerdotai Rerank API.
All supported reranker model types can be found here:
- https://github.com/AnswerDotAI/rerankers
Parameters
----------
model_type : str, default "colbert"
The type of the model to use.
model_name : str, default "rerank-english-v2.0"
The name of the model to use from the given model type.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
"""
def __init__(
self,
model_type="colbert",
model_name: str = "answerdotai/answerai-colbert-small-v1",
column: str = "text",
return_score="relevance",
):
super().__init__(return_score)
self.column = column
rerankers = attempt_import_or_raise(
"rerankers"
) # import here for faster ops later
self.reranker = rerankers.Reranker(model_name, model_type)
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
doc_ids = list(range(len(docs)))
result = self.reranker.rank(query, docs, doc_ids=doc_ids)
# get the scores of each document in the same order as the input
scores = [result.get_result_by_docid(i).score for i in doc_ids]
# add the scores
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"Answerdotai Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
def rerank_vector(self, query: str, vector_results: pa.Table):
vector_results = self._rerank(vector_results, query)
if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["_score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from packaging.version import Version
from typing import Union, List, TYPE_CHECKING

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from packaging.version import Version
from functools import cached_property

View File

@@ -1,18 +1,26 @@
from functools import cached_property
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import Reranker
from .answerdotai import AnswerdotaiRerankers
class ColbertReranker(Reranker):
class ColbertReranker(AnswerdotaiRerankers):
"""
Reranks the results using the ColBERT model.
Parameters
----------
model_name : str, default "colbert-ir/colbertv2.0"
model_name : str, default "colbert" (colbert-ir/colbert-v2.0)
The name of the cross encoder model to use.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
@@ -26,115 +34,9 @@ class ColbertReranker(Reranker):
column: str = "text",
return_score="relevance",
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.torch = attempt_import_or_raise(
"torch"
) # import here for faster ops later
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
tokenizer, model = self._model
# Encode the query
query_encoding = tokenizer(query, return_tensors="pt")
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
scores = []
# Get score for each document
for document in docs:
document_encoding = tokenizer(
document, return_tensors="pt", truncation=True, max_length=512
)
document_embedding = model(**document_encoding).last_hidden_state
# Calculate MaxSim score
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
scores.append(score.item())
# replace the self.column column with the docs
result_set = result_set.drop(self.column)
result_set = result_set.append_column(
self.column, pa.array(docs, type=pa.string())
super().__init__(
model_type="colbert",
model_name=model_name,
column=column,
return_score=return_score,
)
# add the scores
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
result_set = result_set.sort_by([("_relevance_score", "descending")])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_score"])
result_set = result_set.sort_by([("_relevance_score", "descending")])
return result_set
@cached_property
def _model(self):
transformers = attempt_import_or_raise("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
model = transformers.AutoModel.from_pretrained(self.model_name)
return tokenizer, model
def maxsim(self, query_embedding, document_embedding):
# Expand dimensions for broadcasting
# Query: [batch, length, size] -> [batch, query, 1, size]
# Document: [batch, length, size] -> [batch, 1, length, size]
expanded_query = query_embedding.unsqueeze(2)
expanded_doc = document_embedding.unsqueeze(1)
# Compute cosine similarity across the embedding dimension
sim_matrix = self.torch.nn.functional.cosine_similarity(
expanded_query, expanded_doc, dim=-1
)
# Take the maximum similarity for each query token (across all document tokens)
# sim_matrix shape: [batch_size, query_length, doc_length]
max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
# Average these maximum scores across all query tokens
avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
return avg_max_sim

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Union
@@ -22,6 +35,11 @@ class CrossEncoderReranker(Reranker):
device : str, default None
The device to use for the cross encoder model. If None, will use "cuda"
if available, otherwise "cpu".
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
trust_remote_code : bool, default True
If True, will trust the remote code to be safe. If False, will not trust
the remote code and will not run it
"""
def __init__(
@@ -30,19 +48,26 @@ class CrossEncoderReranker(Reranker):
column: str = "text",
device: Union[str, None] = None,
return_score="relevance",
trust_remote_code: bool = True,
):
super().__init__(return_score)
torch = attempt_import_or_raise("torch")
self.model_name = model_name
self.column = column
self.device = device
self.trust_remote_code = trust_remote_code
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@cached_property
def model(self):
sbert = attempt_import_or_raise("sentence_transformers")
cross_encoder = sbert.CrossEncoder(self.model_name)
# Allows overriding the automatically selected device
cross_encoder = sbert.CrossEncoder(
self.model_name,
device=self.device,
trust_remote_code=self.trust_remote_code,
)
return cross_encoder

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import requests
from functools import cached_property

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow as pa
from .base import Reranker

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from functools import cached_property

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, List, TYPE_CHECKING
import pyarrow as pa

View File

@@ -51,7 +51,7 @@ if TYPE_CHECKING:
from lance.dataset import CleanupStats, ReaderLike
from ._lancedb import Table as LanceDBTable, OptimizeStats
from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
pd = safe_import_pandas()
@@ -339,9 +339,9 @@ class Table(ABC):
def create_scalar_index(
self,
column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
*,
replace: bool = True,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
):
"""Create a scalar index on a column.
@@ -391,6 +391,8 @@ class Table(ABC):
or string column.
replace : bool, default True
Replace the existing index if it exists.
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE"
The type of index to create.
Examples
--------
@@ -403,6 +405,47 @@ class Table(ABC):
"""
raise NotImplementedError
def create_fts_index(
self,
field_names: Union[str, List[str]],
ordering_field_names: Union[str, List[str]] = None,
*,
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
tokenizer_name: str = "default",
use_tantivy: bool = True,
):
"""Create a full-text search index on the table.
Warning - this API is highly experimental and is highly likely to change
in the future.
Parameters
----------
field_names: str or list of str
The name(s) of the field to index.
can be only str if use_tantivy=True for now.
replace: bool, default False
If True, replace the existing index if it exists. Note that this is
not yet an atomic operation; the index will be temporarily
unavailable while the new index is being created.
writer_heap_size: int, default 1GB
Only available with use_tantivy=True
ordering_field_names:
A list of unsigned type fields to index to optionally order
results on at search time.
only available with use_tantivy=True
tokenizer_name: str, default "default"
The tokenizer to use for the index. Can be "raw", "default" or the 2 letter
language code followed by "_stem". So for english it would be "en_stem".
For available languages see: https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html
only available with use_tantivy=True for now
use_tantivy: bool, default True
If True, use the legacy full-text search implementation based on tantivy.
If False, use the new full-text search implementation based on lance-index.
"""
raise NotImplementedError
@abstractmethod
def add(
self,
@@ -502,7 +545,7 @@ class Table(ABC):
vector_column_name: Optional[str] = None,
query_type: str = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -799,6 +842,18 @@ class Table(ABC):
The names of the columns to drop.
"""
@cached_property
def _dataset_uri(self) -> str:
return _table_uri(self._conn.uri, self.name)
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
if get_uri_scheme(self._dataset_uri) != "file":
return ("", None, False)
path = join_uri(self._dataset_uri, "_indices", "fts")
fs, path = fs_from_uri(path)
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
return (path, fs, index_exists)
class _LanceDatasetRef(ABC):
@property
@@ -938,10 +993,6 @@ class LanceTable(Table):
# Cacheable since it's deterministic
return _table_path(self._conn.uri, self.name)
@cached_property
def _dataset_uri(self) -> str:
return _table_uri(self._conn.uri, self.name)
@property
def _dataset(self) -> LanceDataset:
return self._ref.dataset
@@ -1183,9 +1234,9 @@ class LanceTable(Table):
def create_scalar_index(
self,
column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
*,
replace: bool = True,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
):
self._dataset_mut.create_scalar_index(
column, index_type=index_type, replace=replace
@@ -1201,42 +1252,13 @@ class LanceTable(Table):
tokenizer_name: str = "default",
use_tantivy: bool = True,
):
"""Create a full-text search index on the table.
Warning - this API is highly experimental and is highly likely to change
in the future.
Parameters
----------
field_names: str or list of str
The name(s) of the field to index.
can be only str if use_tantivy=True for now.
replace: bool, default False
If True, replace the existing index if it exists. Note that this is
not yet an atomic operation; the index will be temporarily
unavailable while the new index is being created.
writer_heap_size: int, default 1GB
ordering_field_names:
A list of unsigned type fields to index to optionally order
results on at search time.
only available with use_tantivy=True
tokenizer_name: str, default "default"
The tokenizer to use for the index. Can be "raw", "default" or the 2 letter
language code followed by "_stem". So for english it would be "en_stem".
For available languages see: https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html
only available with use_tantivy=True for now
use_tantivy: bool, default False
If True, use the legacy full-text search implementation based on tantivy.
If False, use the new full-text search implementation based on lance-index.
"""
if not use_tantivy:
if not isinstance(field_names, str):
raise ValueError("field_names must be a string when use_tantivy=False")
# delete the existing legacy index if it exists
if replace:
fs, path = fs_from_uri(self._get_fts_index_path())
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
path, fs, exist = self._get_fts_index_path()
if exist:
fs.delete_dir(path)
self._dataset_mut.create_scalar_index(
field_names, index_type="INVERTED", replace=replace
@@ -1251,9 +1273,8 @@ class LanceTable(Table):
if isinstance(ordering_field_names, str):
ordering_field_names = [ordering_field_names]
fs, path = fs_from_uri(self._get_fts_index_path())
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
path, fs, exist = self._get_fts_index_path()
if exist:
if not replace:
raise ValueError("Index already exists. Use replace=True to overwrite.")
fs.delete_dir(path)
@@ -1264,7 +1285,7 @@ class LanceTable(Table):
)
index = create_index(
self._get_fts_index_path(),
path,
field_names,
ordering_fields=ordering_field_names,
tokenizer_name=tokenizer_name,
@@ -1277,13 +1298,6 @@ class LanceTable(Table):
writer_heap_size=writer_heap_size,
)
def _get_fts_index_path(self):
if get_uri_scheme(self._dataset_uri) != "file":
raise NotImplementedError(
"Full-text search is not supported on object stores."
)
return join_uri(self._dataset_uri, "_indices", "tantivy")
def add(
self,
data: DATA,
@@ -1411,7 +1425,7 @@ class LanceTable(Table):
vector_column_name: Optional[str] = None,
query_type: str = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1479,14 +1493,11 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None and query is not None:
if vector_column_name is None and query is not None and query_type != "fts":
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:
if query_type == "fts":
vector_column_name = ""
else:
raise e
raise e
return LanceQueryBuilder.create(
self,
@@ -1494,6 +1505,7 @@ class LanceTable(Table):
query_type,
vector_column_name=vector_column_name,
ordering_field_name=ordering_field_name,
fts_columns=fts_columns,
)
@classmethod
@@ -1677,18 +1689,22 @@ class LanceTable(Table):
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
ds = self.to_lance()
return ds.scanner(
columns=query.columns,
filter=query.filter,
prefilter=query.prefilter,
nearest={
nearest = None
if len(query.vector) > 0:
nearest = {
"column": query.vector_column,
"q": query.vector,
"k": query.k,
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
}
return ds.scanner(
columns=query.columns,
limit=query.k,
filter=query.filter,
prefilter=query.prefilter,
nearest=nearest,
full_text_query=query.full_text_query,
with_row_id=query.with_row_id,
batch_size=batch_size,
@@ -2113,7 +2129,7 @@ class AsyncTable:
column: str,
*,
replace: Optional[bool] = None,
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList]] = None,
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
):
"""Create an index to speed up queries
@@ -2438,7 +2454,10 @@ class AsyncTable:
await self._inner.restore()
async def optimize(
self, *, cleanup_older_than: Optional[timedelta] = None
self,
*,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
) -> OptimizeStats:
"""
Optimize the on-disk data and indices for better performance.
@@ -2457,6 +2476,11 @@ class AsyncTable:
All files belonging to versions older than this will be removed. Set
to 0 days to remove all versions except the latest. The latest version
is never removed.
delete_unverified: bool, default False
Files leftover from a failed transaction may appear to be part of an
in-progress operation (e.g. appending new data) and these files will not
be deleted unless they are at least 7 days old. If delete_unverified is True
then these files will be deleted regardless of their age.
Experimental API
----------------
@@ -2478,7 +2502,7 @@ class AsyncTable:
"""
if cleanup_older_than is not None:
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
return await self._inner.optimize(cleanup_older_than)
return await self._inner.optimize(cleanup_older_than, delete_unverified)
async def list_indices(self) -> IndexConfig:
"""

View File

@@ -15,6 +15,7 @@ import random
from unittest import mock
import lancedb as ldb
from lancedb.index import FTS
import numpy as np
import pandas as pd
import pytest
@@ -28,14 +29,26 @@ def table(tmp_path) -> ldb.table.LanceTable:
db = ldb.connect(tmp_path)
vectors = [np.random.randn(128) for _ in range(100)]
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
text_nouns = ("puppy", "car")
text2_nouns = ("rabbit", "girl", "monkey")
verbs = ("runs", "hits", "jumps", "drives", "barfs")
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
text = [
" ".join(
[
nouns[random.randrange(0, 5)],
text_nouns[random.randrange(0, len(text_nouns))],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
]
)
for _ in range(100)
]
text2 = [
" ".join(
[
text2_nouns[random.randrange(0, len(text2_nouns))],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
@@ -51,7 +64,56 @@ def table(tmp_path) -> ldb.table.LanceTable:
"vector": vectors,
"id": [i % 2 for i in range(100)],
"text": text,
"text2": text,
"text2": text2,
"nested": [{"text": t} for t in text],
"count": count,
}
),
)
return table
@pytest.fixture
async def async_table(tmp_path) -> ldb.table.AsyncTable:
db = await ldb.connect_async(tmp_path)
vectors = [np.random.randn(128) for _ in range(100)]
text_nouns = ("puppy", "car")
text2_nouns = ("rabbit", "girl", "monkey")
verbs = ("runs", "hits", "jumps", "drives", "barfs")
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
text = [
" ".join(
[
text_nouns[random.randrange(0, len(text_nouns))],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
]
)
for _ in range(100)
]
text2 = [
" ".join(
[
text2_nouns[random.randrange(0, len(text2_nouns))],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
]
)
for _ in range(100)
]
count = [random.randint(1, 10000) for _ in range(100)]
table = await db.create_table(
"test",
data=pd.DataFrame(
{
"vector": vectors,
"id": [i % 2 for i in range(100)],
"text": text,
"text2": text2,
"nested": [{"text": t} for t in text],
"count": count,
}
@@ -91,17 +153,92 @@ def test_search_index(tmp_path, table):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
ldb.fts.populate_index(index, table, ["text"])
index.reload()
results = ldb.fts.search_index(index, query="puppy", limit=10)
results = ldb.fts.search_index(index, query="puppy", limit=5)
assert len(results) == 2
assert len(results[0]) == 10 # row_ids
assert len(results[1]) == 10 # _distance
assert len(results[0]) == 5 # row_ids
assert len(results[1]) == 5 # _score
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_search_fts(table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
results = table.search("puppy").limit(10).to_list()
assert len(results) == 10
results = table.search("puppy").limit(5).to_list()
assert len(results) == 5
def test_search_fts_specify_column(table):
table.create_fts_index("text", use_tantivy=False)
table.create_fts_index("text2", use_tantivy=False)
results = table.search("puppy", fts_columns="text").limit(5).to_list()
assert len(results) == 5
results = table.search("rabbit", fts_columns="text2").limit(5).to_list()
assert len(results) == 5
try:
# we can only specify one column for now
table.search("puppy", fts_columns=["text", "text2"]).limit(5).to_list()
assert False
except Exception:
pass
try:
# have to specify a column because we have two fts indices
table.search("puppy").limit(5).to_list()
assert False
except Exception:
pass
@pytest.mark.asyncio
async def test_search_fts_async(async_table):
async_table = await async_table
await async_table.create_index("text", config=FTS())
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
assert len(results) == 5
@pytest.mark.asyncio
async def test_search_fts_specify_column_async(async_table):
async_table = await async_table
await async_table.create_index("text", config=FTS())
await async_table.create_index("text2", config=FTS())
results = (
await async_table.query()
.nearest_to_text("puppy", columns="text")
.limit(5)
.to_list()
)
assert len(results) == 5
results = (
await async_table.query()
.nearest_to_text("rabbit", columns="text2")
.limit(5)
.to_list()
)
assert len(results) == 5
try:
# we can only specify one column for now
await (
async_table.query()
.nearest_to_text("rabbit", columns="text2")
.limit(5)
.to_list()
)
assert False
except Exception:
pass
try:
# have to specify a column because we have two fts indices
await async_table.query().nearest_to_text("puppy").limit(5).to_list()
assert False
except Exception:
pass
def test_search_ordering_field_index_table(tmp_path, table):
@@ -125,11 +262,11 @@ def test_search_ordering_field_index(tmp_path, table):
ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"])
index.reload()
results = ldb.fts.search_index(
index, query="puppy", limit=10, ordering_field="count"
index, query="puppy", limit=5, ordering_field="count"
)
assert len(results) == 2
assert len(results[0]) == 10 # row_ids
assert len(results[1]) == 10 # _distance
assert len(results[0]) == 5 # row_ids
assert len(results[1]) == 5 # _distance
rows = table.to_lance().take(results[0]).to_pylist()
for r in rows:
@@ -140,8 +277,8 @@ def test_search_ordering_field_index(tmp_path, table):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_create_index_from_table(tmp_path, table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
assert len(df) <= 10
df = table.search("puppy").limit(5).select(["text"]).to_pandas()
assert len(df) <= 5
assert "text" in df.columns
# Check whether it can be updated
@@ -167,8 +304,8 @@ def test_create_index_from_table(tmp_path, table, use_tantivy):
def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"], use_tantivy=True)
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 10
df = table.search("puppy").limit(5).to_pandas()
assert len(df) == 5
assert "text" in df.columns
assert "text2" in df.columns
@@ -176,14 +313,14 @@ def test_create_index_multiple_columns(tmp_path, table):
def test_empty_rs(tmp_path, table, mocker):
table.create_fts_index(["text", "text2"], use_tantivy=True)
mocker.patch("lancedb.fts.search_index", return_value=([], []))
df = table.search("puppy").limit(10).to_pandas()
df = table.search("puppy").limit(5).to_pandas()
assert len(df) == 0
def test_nested_schema(tmp_path, table):
table.create_fts_index("nested.text", use_tantivy=True)
rs = table.search("puppy").limit(10).to_list()
assert len(rs) == 10
rs = table.search("puppy").limit(5).to_list()
assert len(rs) == 5
@pytest.mark.parametrize("use_tantivy", [True, False])

View File

@@ -117,6 +117,18 @@ def test_query_builder(table):
assert all(np.array(rs[0]["vector"]) == [1, 2])
def test_vector_query_with_no_limit(table):
with pytest.raises(ValueError):
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
["id", "vector"]
).to_list()
with pytest.raises(ValueError):
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(None).select(
["id", "vector"]
).to_list()
def test_query_builder_batches(table):
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")

View File

@@ -15,6 +15,7 @@ from lancedb.rerankers import (
CrossEncoderReranker,
OpenaiReranker,
JinaReranker,
AnswerdotaiRerankers,
)
from lancedb.table import LanceTable
@@ -110,7 +111,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), vector_column_name="vector")
table.search(query_type="hybrid", vector_column_name="vector")
.vector(query_vector)
.text(query)
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
@@ -206,14 +209,26 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), vector_column_name="vector")
table.search(query_type="hybrid", vector_column_name="vector")
.vector(query_vector)
.text(query)
.limit(30)
.rerank(normalize="score")
.to_arrow()
)
assert len(result) == 30
# Fail if both query and (vector or text) are provided
with pytest.raises(ValueError):
table.search(query, query_type="hybrid", vector_column_name="vector").vector(
query_vector
).to_arrow()
with pytest.raises(ValueError):
table.search(query, query_type="hybrid", vector_column_name="vector").text(
query
).to_arrow()
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
@@ -236,33 +251,45 @@ def test_rrf_reranker(tmp_path, use_tantivy):
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_reranker(tmp_path):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_cohere_reranker(tmp_path, use_tantivy):
pytest.importorskip("cohere")
reranker = CohereReranker()
table, schema = get_test_table(tmp_path)
table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema)
def test_cross_encoder_reranker(tmp_path):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_cross_encoder_reranker(tmp_path, use_tantivy):
pytest.importorskip("sentence_transformers")
reranker = CrossEncoderReranker()
table, schema = get_test_table(tmp_path)
table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema)
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_colbert_reranker(tmp_path, use_tantivy):
pytest.importorskip("rerankers")
reranker = ColbertReranker()
table, schema = get_test_table(tmp_path)
table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_answerdotai_reranker(tmp_path, use_tantivy):
pytest.importorskip("rerankers")
reranker = AnswerdotaiRerankers()
table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema)
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_openai_reranker(tmp_path, use_tantivy):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
table, schema = get_test_table(tmp_path, use_tantivy)
reranker = OpenaiReranker()
_run_test_reranker(reranker, table, "single player experience", None, schema)
@@ -270,8 +297,9 @@ def test_openai_reranker(tmp_path):
@pytest.mark.skipif(
os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
)
def test_jina_reranker(tmp_path):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_jina_reranker(tmp_path, use_tantivy):
pytest.importorskip("jina")
table, schema = get_test_table(tmp_path)
table, schema = get_test_table(tmp_path, use_tantivy)
reranker = JinaReranker()
_run_test_reranker(reranker, table, "single player experience", None, schema)

View File

@@ -251,7 +251,8 @@ def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
# FTS indices should error since they are not supported yet.
with pytest.raises(
NotImplementedError, match="Full-text search is not supported on object stores."
NotImplementedError,
match="Full-text search is only supported on the local filesystem",
):
table.create_fts_index("x")

View File

@@ -8,6 +8,7 @@ from pathlib import Path
from time import sleep
from typing import List
from unittest.mock import PropertyMock, patch
import os
import lance
import lancedb
@@ -27,7 +28,7 @@ from pydantic import BaseModel
class MockDB:
def __init__(self, uri: Path):
self.uri = uri
self.uri = str(uri)
self.read_consistency_interval = None
@functools.cached_property
@@ -1052,3 +1053,25 @@ async def test_optimize(db_async: AsyncConnection):
assert stats.prune.old_versions_removed == 3
assert await table.query().to_arrow() == pa.table({"x": [[1], [2]]})
@pytest.mark.asyncio
async def test_optimize_delete_unverified(db_async: AsyncConnection, tmp_path):
table = await db_async.create_table(
"test",
data=[{"x": [1]}],
)
await table.add(
data=[
{"x": [2]},
],
)
version = await table.version()
path = tmp_path / "test.lance" / "_versions" / f"{version - 1}.manifest"
os.remove(path)
stats = await table.optimize(delete_unverified=False)
assert stats.prune.old_versions_removed == 0
stats = await table.optimize(
cleanup_older_than=timedelta(seconds=0), delete_unverified=True
)
assert stats.prune.old_versions_removed == 2

View File

@@ -98,6 +98,13 @@ impl Index {
inner: Mutex::new(Some(LanceDbIndex::LabelList(Default::default()))),
})
}
#[staticmethod]
pub fn fts() -> PyResult<Self> {
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::FTS(Default::default()))),
})
}
}
#[pyclass(get_all)]

View File

@@ -15,17 +15,20 @@
use arrow::array::make_array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
use pyo3::exceptions::PyRuntimeError;
use pyo3::pyclass;
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
use pyo3::pymethods;
use pyo3::types::PyDict;
use pyo3::Bound;
use pyo3::PyAny;
use pyo3::PyRef;
use pyo3::PyResult;
use pyo3::{pyclass, PyErr};
use pyo3_asyncio_0_21::tokio::future_into_py;
use crate::arrow::RecordBatchStream;
@@ -68,6 +71,24 @@ impl Query {
Ok(VectorQuery { inner })
}
pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<()> {
let query_text = query
.get_item("query")?
.ok_or(PyErr::new::<PyRuntimeError, _>(
"Query text is required for nearest_to_text",
))?
.extract::<String>()?;
let columns = query
.get_item("columns")?
.map(|columns| columns.extract::<Vec<String>>())
.transpose()?;
let fts_query = FullTextSearchQuery::new(query_text).columns(columns);
self.inner = self.inner.clone().full_text_search(fts_query);
Ok(())
}
pub fn execute(
self_: PyRef<'_, Self>,
max_batch_length: Option<u32>,

View File

@@ -248,6 +248,7 @@ impl Table {
pub fn optimize(
self_: PyRef<'_, Self>,
cleanup_since_ms: Option<u64>,
delete_unverified: Option<bool>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
let older_than = if let Some(ms) = cleanup_since_ms {
@@ -275,7 +276,7 @@ impl Table {
let prune_stats = inner
.optimize(OptimizeAction::Prune {
older_than,
delete_unverified: None,
delete_unverified,
error_if_tagged_old_versions: None,
})
.await