mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-14 15:52:57 +00:00
Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.13.0-beta.0"
|
||||
current_version = "0.13.0-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.13.0-beta.0"
|
||||
version = "0.13.0-beta.1"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -18,7 +18,7 @@ description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
license = { file = "LICENSE" }
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
requires-python = ">=3.9"
|
||||
keywords = [
|
||||
"data-format",
|
||||
"data-science",
|
||||
|
||||
@@ -74,6 +74,7 @@ class Query:
|
||||
def select(self, columns: Tuple[str, str]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
|
||||
def nearest_to_text(self, query: dict) -> Query: ...
|
||||
async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ...
|
||||
|
||||
class VectorQuery:
|
||||
|
||||
@@ -276,6 +276,10 @@ class DBConnection(EnforceOverrides):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
|
||||
class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
@@ -340,10 +344,6 @@ class LanceDBConnection(DBConnection):
|
||||
val += ")"
|
||||
return val
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
async def _async_get_table_names(self, start_after: Optional[str], limit: int):
|
||||
conn = AsyncConnection(await lancedb_connect(self.uri))
|
||||
return await conn.table_names(start_after=start_after, limit=limit)
|
||||
|
||||
@@ -127,6 +127,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||
batch_size=self.batch_size,
|
||||
show_progress_bar=self.show_progress_bar,
|
||||
normalize_embeddings=self.normalize_embeddings,
|
||||
device=self.device,
|
||||
).tolist()
|
||||
return res
|
||||
|
||||
|
||||
@@ -26,12 +26,23 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
An embedding function that uses the sentence-transformers library
|
||||
|
||||
https://huggingface.co/sentence-transformers
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "all-MiniLM-L6-v2"
|
||||
The name of the model to use.
|
||||
device: str, default "cpu"
|
||||
The device to use for the model
|
||||
normalize: bool, default True
|
||||
Whether to normalize the embeddings
|
||||
trust_remote_code: bool, default True
|
||||
Whether to trust the remote code
|
||||
"""
|
||||
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = True
|
||||
trust_remote_code: bool = False
|
||||
trust_remote_code: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -36,6 +36,10 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
|
||||
The name of the model to use. This should be a model name that can be loaded
|
||||
by transformers.AutoModel.from_pretrained. For example, "bert-base-uncased".
|
||||
default: "colbert-ir/colbertv2.0"
|
||||
device : str
|
||||
The device to use for the model. Default is "cpu".
|
||||
show_progress_bar : bool
|
||||
Whether to show a progress bar when loading the model. Default is True.
|
||||
|
||||
To install the package, run:
|
||||
`pip install transformers`
|
||||
@@ -44,6 +48,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
|
||||
name: str = "colbert-ir/colbertv2.0"
|
||||
device: str = "cpu"
|
||||
_tokenizer: Any = PrivateAttr()
|
||||
_model: Any = PrivateAttr()
|
||||
|
||||
@@ -53,6 +58,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
|
||||
self._model = transformers.AutoModel.from_pretrained(self.name)
|
||||
self._model.to(self.device)
|
||||
|
||||
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
|
||||
|
||||
@@ -75,9 +81,9 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
|
||||
for text in texts:
|
||||
encoding = self._tokenizer(
|
||||
text, return_tensors="pt", padding=True, truncation=True
|
||||
)
|
||||
).to(self.device)
|
||||
emb = self._model(**encoding).last_hidden_state.mean(dim=1).squeeze()
|
||||
embedding.append(emb.detach().numpy())
|
||||
embedding.append(emb.tolist())
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
@@ -70,6 +70,18 @@ class LabelList:
|
||||
self._inner = LanceDbIndex.label_list()
|
||||
|
||||
|
||||
class FTS:
|
||||
"""Describe a FTS index configuration.
|
||||
|
||||
`FTS` is a full-text search index that can be used on `String` columns
|
||||
|
||||
For example, it works with `title`, `description`, `content`, etc.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._inner = LanceDbIndex.fts()
|
||||
|
||||
|
||||
class IvfPq:
|
||||
"""Describes an IVF PQ Index
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
@@ -35,15 +34,15 @@ import pydantic
|
||||
|
||||
from . import __version__
|
||||
from .arrow import AsyncRecordBatchReader
|
||||
from .common import VEC
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.linear_combination import LinearCombinationReranker
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .util import safe_import_pandas
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
import polars as pl
|
||||
|
||||
from .common import VEC
|
||||
from ._lancedb import Query as LanceQuery
|
||||
from ._lancedb import VectorQuery as LanceVectorQuery
|
||||
from .pydantic import LanceModel
|
||||
@@ -133,8 +132,8 @@ class LanceQueryBuilder(ABC):
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
ordering_field_name: str = None,
|
||||
fts_columns: Union[str, List[str]] = None,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
) -> LanceQueryBuilder:
|
||||
"""
|
||||
Create a query builder based on the given query and query type.
|
||||
@@ -152,12 +151,15 @@ class LanceQueryBuilder(ABC):
|
||||
vector_column_name: str
|
||||
The name of the vector column to use for vector search.
|
||||
"""
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# Check hybrid search first as it supports empty query pattern
|
||||
if query_type == "hybrid":
|
||||
# hybrid fts and vector query
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, fts_columns=fts_columns
|
||||
)
|
||||
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# remember the string query for reranking purpose
|
||||
str_query = query if isinstance(query, str) else None
|
||||
@@ -169,12 +171,17 @@ class LanceQueryBuilder(ABC):
|
||||
)
|
||||
|
||||
if query_type == "hybrid":
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, fts_columns=fts_columns
|
||||
)
|
||||
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(
|
||||
table, query, ordering_field_name=ordering_field_name
|
||||
table,
|
||||
query,
|
||||
ordering_field_name=ordering_field_name,
|
||||
fts_columns=fts_columns,
|
||||
)
|
||||
|
||||
if isinstance(query, list):
|
||||
@@ -200,8 +207,6 @@ class LanceQueryBuilder(ABC):
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
if isinstance(query, tuple):
|
||||
return query, "hybrid"
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
@@ -232,6 +237,8 @@ class LanceQueryBuilder(ABC):
|
||||
self._where = None
|
||||
self._prefilter = False
|
||||
self._with_row_id = False
|
||||
self._vector = None
|
||||
self._text = None
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.3.1",
|
||||
@@ -356,7 +363,10 @@ class LanceQueryBuilder(ABC):
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
if limit is None or limit <= 0:
|
||||
self._limit = None
|
||||
if isinstance(self, LanceVectorQueryBuilder):
|
||||
raise ValueError("Limit is required for ANN/KNN queries")
|
||||
else:
|
||||
self._limit = None
|
||||
else:
|
||||
self._limit = limit
|
||||
return self
|
||||
@@ -456,6 +466,52 @@ class LanceQueryBuilder(ABC):
|
||||
},
|
||||
).explain_plan(verbose)
|
||||
|
||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
|
||||
"""Set the vector to search for.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vector: np.ndarray or list
|
||||
The vector to search for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def text(self, text: str) -> LanceQueryBuilder:
|
||||
"""Set the text to search for.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
The text to search for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def rerank(self, reranker: Reranker) -> LanceQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
reranker: Reranker
|
||||
The reranker to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
@@ -673,14 +729,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
self,
|
||||
table: "Table",
|
||||
query: str,
|
||||
ordering_field_name: str = None,
|
||||
fts_columns: Union[str, List[str]] = None,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._phrase_query = False
|
||||
self.ordering_field_name = ordering_field_name
|
||||
self._reranker = None
|
||||
if isinstance(fts_columns, str):
|
||||
fts_columns = [fts_columns]
|
||||
self._fts_columns = fts_columns
|
||||
|
||||
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
|
||||
@@ -701,8 +759,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
return self
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
tantivy_index_path = self._table._get_fts_index_path()
|
||||
if Path(tantivy_index_path).exists():
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
if exist:
|
||||
return self.tantivy_to_arrow()
|
||||
|
||||
query = self._query
|
||||
@@ -711,23 +769,23 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
"Phrase query is not yet supported in Lance FTS. "
|
||||
"Use tantivy-based index instead for now."
|
||||
)
|
||||
if self._reranker:
|
||||
raise NotImplementedError(
|
||||
"Reranking is not yet supported in Lance FTS. "
|
||||
"Use tantivy-based index instead for now."
|
||||
)
|
||||
ds = self._table.to_lance()
|
||||
return ds.to_table(
|
||||
query = Query(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
k=self._limit,
|
||||
prefilter=self._prefilter,
|
||||
with_row_id=self._with_row_id,
|
||||
full_text_query={
|
||||
"query": query,
|
||||
"columns": self._fts_columns,
|
||||
},
|
||||
vector=[],
|
||||
)
|
||||
results = self._table._execute_query(query)
|
||||
results = results.read_all()
|
||||
if self._reranker is not None:
|
||||
results = self._reranker.rerank_fts(self._query, results)
|
||||
return results
|
||||
|
||||
def tantivy_to_arrow(self) -> pa.Table:
|
||||
try:
|
||||
@@ -740,24 +798,24 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
from .fts import search_index
|
||||
|
||||
# get the index path
|
||||
index_path = self._table._get_fts_index_path()
|
||||
|
||||
# Check that we are on local filesystem
|
||||
fs, _path = fs_from_uri(index_path)
|
||||
if not isinstance(fs, pa_fs.LocalFileSystem):
|
||||
raise NotImplementedError(
|
||||
"Full-text search is only supported on the local filesystem"
|
||||
)
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
|
||||
# check if the index exists
|
||||
if not Path(index_path).exists():
|
||||
if not exist:
|
||||
raise FileNotFoundError(
|
||||
"Fts index does not exist. "
|
||||
"Please first call table.create_fts_index(['<field_names>']) to "
|
||||
"create the fts index."
|
||||
)
|
||||
|
||||
# Check that we are on local filesystem
|
||||
if not isinstance(fs, pa_fs.LocalFileSystem):
|
||||
raise NotImplementedError(
|
||||
"Tantivy-based full text search "
|
||||
"is only supported on the local filesystem"
|
||||
)
|
||||
# open the index
|
||||
index = tantivy.Index.open(index_path)
|
||||
index = tantivy.Index.open(path)
|
||||
# get the scores and doc ids
|
||||
query = self._query
|
||||
if self._phrase_query:
|
||||
@@ -825,7 +883,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
LanceFtsQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError("Reranking is not yet supported for FTS queries.")
|
||||
self._reranker = reranker
|
||||
return self
|
||||
|
||||
|
||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
@@ -837,54 +896,101 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
limit=self._limit,
|
||||
)
|
||||
|
||||
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
reranker: Reranker
|
||||
The reranker to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceEmptyQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError("Reranking is not yet supported.")
|
||||
|
||||
|
||||
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
A query builder that performs hybrid vector and full text search.
|
||||
Results are combined and reranked based on the specified reranker.
|
||||
By default, the results are reranked using the LinearCombinationReranker.
|
||||
By default, the results are reranked using the RRFReranker, which
|
||||
uses reciprocal rank fusion score for reranking.
|
||||
|
||||
To make the vector and fts results comparable, the scores are normalized.
|
||||
Instead of normalizing scores, the `normalize` parameter can be set to "rank"
|
||||
in the `rerank` method to convert the scores to ranks and then normalize them.
|
||||
"""
|
||||
|
||||
def __init__(self, table: "Table", query: str, vector_column: str):
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: str = None,
|
||||
vector_column: str = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
):
|
||||
super().__init__(table)
|
||||
self._validate_fts_index()
|
||||
vector_query, fts_query = self._validate_query(query)
|
||||
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
||||
vector_query = self._query_to_vector(table, vector_query, vector_column)
|
||||
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
|
||||
self._query = query
|
||||
self._vector_column = vector_column
|
||||
self._fts_columns = fts_columns
|
||||
self._norm = "score"
|
||||
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
|
||||
self._reranker = RRFReranker()
|
||||
self._nprobes = None
|
||||
self._refine_factor = None
|
||||
|
||||
def _validate_fts_index(self):
|
||||
if self._table._get_fts_index_path() is None:
|
||||
def _validate_query(self, query, vector=None, text=None):
|
||||
if query is not None and (vector is not None or text is not None):
|
||||
raise ValueError(
|
||||
"Please create a full-text search index " "to perform hybrid search."
|
||||
"You can either provide a string query in search() method"
|
||||
"or set `vector()` and `text()` explicitly for hybrid search."
|
||||
"But not both."
|
||||
)
|
||||
|
||||
def _validate_query(self, query):
|
||||
# Temp hack to support vectorized queries for hybrid search
|
||||
if isinstance(query, str):
|
||||
return query, query
|
||||
elif isinstance(query, tuple):
|
||||
if len(query) != 2:
|
||||
raise ValueError(
|
||||
"The query must be a tuple of (vector_query, fts_query)."
|
||||
)
|
||||
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
|
||||
raise ValueError(f"The vector query must be one of {VEC}.")
|
||||
if not isinstance(query[1], str):
|
||||
raise ValueError("The fts query must be a string.")
|
||||
return query[0], query[1]
|
||||
else:
|
||||
raise ValueError(
|
||||
"The query must be either a string or a tuple of (vector, string)."
|
||||
)
|
||||
vector_query = vector if vector is not None else query
|
||||
if not isinstance(vector_query, (str, list, np.ndarray)):
|
||||
raise ValueError("Vector query must be either a string or a vector")
|
||||
|
||||
text_query = text or query
|
||||
if text_query is None:
|
||||
raise ValueError("Text query must be provided for hybrid search.")
|
||||
if not isinstance(text_query, str):
|
||||
raise ValueError("Text query must be a string")
|
||||
|
||||
return vector_query, text_query
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
vector_query, fts_query = self._validate_query(
|
||||
self._query, self._vector, self._text
|
||||
)
|
||||
self._fts_query = LanceFtsQueryBuilder(
|
||||
self._table, fts_query, fts_columns=self._fts_columns
|
||||
)
|
||||
vector_query = self._query_to_vector(
|
||||
self._table, vector_query, self._vector_column
|
||||
)
|
||||
self._vector_query = LanceVectorQueryBuilder(
|
||||
self._table, vector_query, self._vector_column
|
||||
)
|
||||
|
||||
if self._limit:
|
||||
self._vector_query.limit(self._limit)
|
||||
self._fts_query.limit(self._limit)
|
||||
if self._columns:
|
||||
self._vector_query.select(self._columns)
|
||||
self._fts_query.select(self._columns)
|
||||
if self._where:
|
||||
self._vector_query.where(self._where, self._prefilter)
|
||||
self._fts_query.where(self._where, self._prefilter)
|
||||
if self._with_row_id:
|
||||
self._vector_query.with_row_id(True)
|
||||
self._fts_query.with_row_id(True)
|
||||
if self._nprobes:
|
||||
self._vector_query.nprobes(self._nprobes)
|
||||
if self._refine_factor:
|
||||
self._vector_query.refine_factor(self._refine_factor)
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||
vector_future = executor.submit(
|
||||
@@ -961,7 +1067,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def rerank(
|
||||
self,
|
||||
normalize="score",
|
||||
reranker: Reranker = LinearCombinationReranker(weight=0.7, fill=1.0),
|
||||
reranker: Reranker = RRFReranker(),
|
||||
) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Rerank the hybrid search results using the specified reranker. The reranker
|
||||
@@ -973,7 +1079,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
The method to normalize the scores. Can be "rank" or "score". If "rank",
|
||||
the scores are converted to ranks and then normalized. If "score", the
|
||||
scores are normalized directly.
|
||||
reranker: Reranker, default LinearCombinationReranker(weight=0.7, fill=1.0)
|
||||
reranker: Reranker, default RRFReranker()
|
||||
The reranker to use. Must be an instance of Reranker class.
|
||||
Returns
|
||||
-------
|
||||
@@ -990,87 +1096,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
return self
|
||||
|
||||
def limit(self, limit: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the maximum number of results to return for both vector and fts search
|
||||
components.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limit: int
|
||||
The maximum number of results to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.limit(limit)
|
||||
self._fts_query.limit(limit)
|
||||
self._limit = limit
|
||||
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the columns to return for both vector and fts search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.select(columns)
|
||||
self._fts_query.select(columns)
|
||||
return self
|
||||
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the where clause for both vector and fts search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
|
||||
prefilter: bool, default False
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
|
||||
self._vector_query.where(where, prefilter=prefilter)
|
||||
self._fts_query.where(where)
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the distance metric to use for vector search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: "L2" or "cosine"
|
||||
The distance metric to use. By default "L2" is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.metric(metric)
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the number of probes to use for vector search.
|
||||
@@ -1088,7 +1113,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.nprobes(nprobes)
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
|
||||
@@ -1106,7 +1131,15 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.refine_factor(refine_factor)
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceHybridQueryBuilder:
|
||||
self._vector = vector
|
||||
return self
|
||||
|
||||
def text(self, text: str) -> LanceHybridQueryBuilder:
|
||||
self._text = text
|
||||
return self
|
||||
|
||||
|
||||
@@ -1354,6 +1387,34 @@ class AsyncQuery(AsyncQueryBase):
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
)
|
||||
|
||||
def nearest_to_text(
|
||||
self, query: str, columns: Union[str, List[str]] = []
|
||||
) -> AsyncQuery:
|
||||
"""
|
||||
Find the documents that are most relevant to the given text query.
|
||||
|
||||
This method will perform a full text search on the table and return
|
||||
the most relevant documents. The relevance is determined by BM25.
|
||||
|
||||
The columns to search must have a native FTS index
|
||||
(Tantivy-based indices do not work with this method).
|
||||
|
||||
By default, all indexed columns are searched,
|
||||
but for now only one column can be searched at a time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str
|
||||
The text query to search for.
|
||||
columns: str or list of str, default []
|
||||
The columns to search in. If empty, all indexed columns are searched.
|
||||
For now only one column can be searched at a time.
|
||||
"""
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
return self
|
||||
|
||||
|
||||
class AsyncVectorQuery(AsyncQueryBase):
|
||||
def __init__(self, inner: LanceVectorQuery):
|
||||
|
||||
@@ -49,6 +49,7 @@ class RemoteDBConnection(DBConnection):
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self._uri = str(db_url)
|
||||
self.db_name = parsed.netloc
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(
|
||||
|
||||
@@ -15,7 +15,7 @@ import logging
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Dict, Iterable, Optional, Union
|
||||
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||
|
||||
import pyarrow as pa
|
||||
from lance import json_to_schema
|
||||
@@ -35,10 +35,10 @@ from .db import RemoteDBConnection
|
||||
class RemoteTable(Table):
|
||||
def __init__(self, conn: RemoteDBConnection, name: str):
|
||||
self._conn = conn
|
||||
self._name = name
|
||||
self.name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self._conn.db_name}.{self._name})"
|
||||
return f"RemoteTable({self._conn.db_name}.{self.name})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.count_rows(None)
|
||||
@@ -49,14 +49,14 @@ class RemoteTable(Table):
|
||||
of this Table
|
||||
|
||||
"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
||||
schema = json_to_schema(resp["schema"])
|
||||
return schema
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version of the table"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
||||
return resp["version"]
|
||||
|
||||
@cached_property
|
||||
@@ -84,19 +84,20 @@ class RemoteTable(Table):
|
||||
|
||||
def list_indices(self):
|
||||
"""List all the indices on the table"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
|
||||
return resp
|
||||
|
||||
def index_stats(self, index_uuid: str):
|
||||
"""List all the stats of a specified index"""
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/index/{index_uuid}/stats/"
|
||||
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
|
||||
)
|
||||
return resp
|
||||
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
|
||||
):
|
||||
"""Creates a scalar index
|
||||
Parameters
|
||||
@@ -104,8 +105,10 @@ class RemoteTable(Table):
|
||||
column : str
|
||||
The column to be indexed. Must be a boolean, integer, float,
|
||||
or string column.
|
||||
index_type : str
|
||||
The index type of the scalar index. Must be "scalar" (BTREE),
|
||||
"BTREE", "BITMAP", or "LABEL_LIST"
|
||||
"""
|
||||
index_type = "scalar"
|
||||
|
||||
data = {
|
||||
"column": column,
|
||||
@@ -113,11 +116,27 @@ class RemoteTable(Table):
|
||||
"replace": True,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/create_scalar_index/", data=data
|
||||
f"/v1/table/{self.name}/create_scalar_index/", data=data
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
def create_fts_index(
|
||||
self,
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = False,
|
||||
):
|
||||
data = {
|
||||
"column": column,
|
||||
"index_type": "FTS",
|
||||
"replace": replace,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/create_index/", data=data
|
||||
)
|
||||
return resp
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
@@ -191,7 +210,7 @@ class RemoteTable(Table):
|
||||
"index_cache_size": index_cache_size,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/create_index/", data=data
|
||||
f"/v1/table/{self.name}/create_index/", data=data
|
||||
)
|
||||
|
||||
return resp
|
||||
@@ -238,7 +257,7 @@ class RemoteTable(Table):
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/insert/",
|
||||
f"/v1/table/{self.name}/insert/",
|
||||
data=payload,
|
||||
params={"request_id": request_id, "mode": mode},
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
@@ -248,6 +267,8 @@ class RemoteTable(Table):
|
||||
self,
|
||||
query: Union[VEC, str],
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type="auto",
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -307,10 +328,19 @@ class RemoteTable(Table):
|
||||
- and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
query = LanceQueryBuilder._query_to_vector(self, query, vector_column_name)
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
if vector_column_name is None and query is not None and query_type != "fts":
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
query,
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
fts_columns=fts_columns,
|
||||
)
|
||||
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
@@ -339,12 +369,12 @@ class RemoteTable(Table):
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
results.append(submit(self._name, q))
|
||||
results.append(submit(self.name, q))
|
||||
return pa.concat_tables(
|
||||
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
|
||||
).to_reader()
|
||||
else:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
result = self._conn._client.query(self.name, query)
|
||||
return result.to_arrow().to_reader()
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
@@ -394,7 +424,7 @@ class RemoteTable(Table):
|
||||
)
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/merge_insert/",
|
||||
f"/v1/table/{self.name}/merge_insert/",
|
||||
data=payload,
|
||||
params=params,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
@@ -448,7 +478,7 @@ class RemoteTable(Table):
|
||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
"""
|
||||
payload = {"predicate": predicate}
|
||||
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -509,7 +539,7 @@ class RemoteTable(Table):
|
||||
updates = [[k, v] for k, v in values_sql.items()]
|
||||
|
||||
payload = {"predicate": where, "updates": updates}
|
||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
|
||||
|
||||
def cleanup_old_versions(self, *_):
|
||||
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
||||
@@ -526,7 +556,7 @@ class RemoteTable(Table):
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
payload = {"predicate": filter}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/count_rows/", data=payload
|
||||
f"/v1/table/{self.name}/count_rows/", data=payload
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from .linear_combination import LinearCombinationReranker
|
||||
from .openai import OpenaiReranker
|
||||
from .jinaai import JinaReranker
|
||||
from .rrf import RRFReranker
|
||||
from .answerdotai import AnswerdotaiRerankers
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
@@ -16,4 +17,5 @@ __all__ = [
|
||||
"ColbertReranker",
|
||||
"JinaReranker",
|
||||
"RRFReranker",
|
||||
"AnswerdotaiRerankers",
|
||||
]
|
||||
|
||||
99
python/python/lancedb/rerankers/answerdotai.py
Normal file
99
python/python/lancedb/rerankers/answerdotai.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pyarrow as pa
|
||||
from .base import Reranker
|
||||
from ..util import attempt_import_or_raise
|
||||
|
||||
|
||||
class AnswerdotaiRerankers(Reranker):
    """
    Reranks the results using the AnswerDotAI ``rerankers`` library.
    All supported reranker model types can be found here:
    - https://github.com/AnswerDotAI/rerankers

    Parameters
    ----------
    model_type : str, default "colbert"
        The type of the model to use.
    model_name : str, default "answerdotai/answerai-colbert-small-v1"
        The name of the model to use from the given model type.
    column : str, default "text"
        The name of the column whose values are passed to the reranker as
        the documents to score.
    return_score : str, default "relevance"
        options are "relevance" or "all". Only "relevance" is supported for
        hybrid reranking for now.
    """

    def __init__(
        self,
        model_type="colbert",
        model_name: str = "answerdotai/answerai-colbert-small-v1",
        column: str = "text",
        return_score="relevance",
    ):
        super().__init__(return_score)
        self.column = column
        rerankers = attempt_import_or_raise(
            "rerankers"
        )  # import here for faster ops later
        # Pass model_type by keyword: the second positional parameter of the
        # rerankers factory is the language code, not the model type.
        self.reranker = rerankers.Reranker(model_name, model_type=model_type)

    def _rerank(self, result_set: pa.Table, query: str):
        """Append a float32 ``_relevance_score`` column to ``result_set``.

        The values of ``self.column`` are scored against ``query`` and the
        scores are attached in the same row order as the input table.
        """
        docs = result_set[self.column].to_pylist()
        if not docs:
            # Nothing to score: skip the model call but keep the output schema.
            return result_set.append_column(
                "_relevance_score", pa.array([], type=pa.float32())
            )

        doc_ids = list(range(len(docs)))
        result = self.reranker.rank(query, docs, doc_ids=doc_ids)

        # get the scores of each document in the same order as the input
        scores = [result.get_result_by_docid(i).score for i in doc_ids]

        # add the scores
        return result_set.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """Merge vector and full-text results, rerank them, and sort by relevance.

        Raises
        ------
        NotImplementedError
            If the reranker was created with ``return_score="all"``.
        """
        # Reject the unsupported mode before doing any merging or model work.
        if self.score == "all":
            raise NotImplementedError(
                "Answerdotai Reranker does not support score='all' yet"
            )

        combined_results = self.merge_results(vector_results, fts_results)
        combined_results = self._rerank(combined_results, query)
        if self.score == "relevance":
            combined_results = self._keep_relevance_score(combined_results)
        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )
        return combined_results

    def rerank_vector(self, query: str, vector_results: pa.Table):
        """Rerank vector search results and sort them by relevance score.

        When only the relevance score is requested, the ``_distance`` column
        is dropped from the output.
        """
        vector_results = self._rerank(vector_results, query)
        if self.score == "relevance":
            vector_results = vector_results.drop_columns(["_distance"])

        vector_results = vector_results.sort_by([("_relevance_score", "descending")])
        return vector_results

    def rerank_fts(self, query: str, fts_results: pa.Table):
        """Rerank full-text search results and sort them by relevance score.

        When only the relevance score is requested, the ``_score`` column is
        dropped from the output.
        """
        fts_results = self._rerank(fts_results, query)
        if self.score == "relevance":
            fts_results = fts_results.drop_columns(["_score"])

        fts_results = fts_results.sort_by([("_relevance_score", "descending")])

        return fts_results
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from packaging.version import Version
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from packaging.version import Version
|
||||
from functools import cached_property
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
from functools import cached_property
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
from .answerdotai import AnswerdotaiRerankers
|
||||
|
||||
|
||||
class ColbertReranker(Reranker):
|
||||
class ColbertReranker(AnswerdotaiRerankers):
|
||||
"""
|
||||
Reranks the results using the ColBERT model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "colbert-ir/colbertv2.0"
|
||||
model_name : str, default "colbert" (colbert-ir/colbert-v2.0)
|
||||
The name of the cross encoder model to use.
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
@@ -26,115 +34,9 @@ class ColbertReranker(Reranker):
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.torch = attempt_import_or_raise(
|
||||
"torch"
|
||||
) # import here for faster ops later
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
|
||||
tokenizer, model = self._model
|
||||
|
||||
# Encode the query
|
||||
query_encoding = tokenizer(query, return_tensors="pt")
|
||||
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
|
||||
scores = []
|
||||
# Get score for each document
|
||||
for document in docs:
|
||||
document_encoding = tokenizer(
|
||||
document, return_tensors="pt", truncation=True, max_length=512
|
||||
)
|
||||
document_embedding = model(**document_encoding).last_hidden_state
|
||||
# Calculate MaxSim score
|
||||
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
|
||||
scores.append(score.item())
|
||||
|
||||
# replace the self.column column with the docs
|
||||
result_set = result_set.drop(self.column)
|
||||
result_set = result_set.append_column(
|
||||
self.column, pa.array(docs, type=pa.string())
|
||||
super().__init__(
|
||||
model_type="colbert",
|
||||
model_name=model_name,
|
||||
column=column,
|
||||
return_score=return_score,
|
||||
)
|
||||
# add the scores
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"OpenAI Reranker does not support score='all' yet"
|
||||
)
|
||||
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
|
||||
result_set = result_set.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_score"])
|
||||
|
||||
result_set = result_set.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return result_set
|
||||
|
||||
@cached_property
|
||||
def _model(self):
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
|
||||
model = transformers.AutoModel.from_pretrained(self.model_name)
|
||||
|
||||
return tokenizer, model
|
||||
|
||||
def maxsim(self, query_embedding, document_embedding):
|
||||
# Expand dimensions for broadcasting
|
||||
# Query: [batch, length, size] -> [batch, query, 1, size]
|
||||
# Document: [batch, length, size] -> [batch, 1, length, size]
|
||||
expanded_query = query_embedding.unsqueeze(2)
|
||||
expanded_doc = document_embedding.unsqueeze(1)
|
||||
|
||||
# Compute cosine similarity across the embedding dimension
|
||||
sim_matrix = self.torch.nn.functional.cosine_similarity(
|
||||
expanded_query, expanded_doc, dim=-1
|
||||
)
|
||||
|
||||
# Take the maximum similarity for each query token (across all document tokens)
|
||||
# sim_matrix shape: [batch_size, query_length, doc_length]
|
||||
max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
|
||||
|
||||
# Average these maximum scores across all query tokens
|
||||
avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
|
||||
return avg_max_sim
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
@@ -22,6 +35,11 @@ class CrossEncoderReranker(Reranker):
|
||||
device : str, default None
|
||||
The device to use for the cross encoder model. If None, will use "cuda"
|
||||
if available, otherwise "cpu".
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
trust_remote_code : bool, default True
|
||||
If True, will trust the remote code to be safe. If False, will not trust
|
||||
the remote code and will not run it
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -30,19 +48,26 @@ class CrossEncoderReranker(Reranker):
|
||||
column: str = "text",
|
||||
device: Union[str, None] = None,
|
||||
return_score="relevance",
|
||||
trust_remote_code: bool = True,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
torch = attempt_import_or_raise("torch")
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.device = device
|
||||
self.trust_remote_code = trust_remote_code
|
||||
if self.device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
sbert = attempt_import_or_raise("sentence_transformers")
|
||||
cross_encoder = sbert.CrossEncoder(self.model_name)
|
||||
# Allows overriding the automatically selected device
|
||||
cross_encoder = sbert.CrossEncoder(
|
||||
self.model_name,
|
||||
device=self.device,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
return cross_encoder
|
||||
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import requests
|
||||
from functools import cached_property
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from .base import Reranker
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if TYPE_CHECKING:
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
from ._lancedb import Table as LanceDBTable, OptimizeStats
|
||||
from .db import LanceDBConnection
|
||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList
|
||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
@@ -339,9 +339,9 @@ class Table(ABC):
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
|
||||
*,
|
||||
replace: bool = True,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
|
||||
):
|
||||
"""Create a scalar index on a column.
|
||||
|
||||
@@ -391,6 +391,8 @@ class Table(ABC):
|
||||
or string column.
|
||||
replace : bool, default True
|
||||
Replace the existing index if it exists.
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE"
|
||||
The type of index to create.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -403,6 +405,47 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_fts_index(
    self,
    field_names: Union[str, List[str]],
    ordering_field_names: Optional[Union[str, List[str]]] = None,
    *,
    replace: bool = False,
    writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
    tokenizer_name: str = "default",
    use_tantivy: bool = True,
):
    """Create a full-text search index on the table.

    Warning - this API is highly experimental and is highly likely to change
    in the future.

    This base implementation always raises ``NotImplementedError``;
    subclasses that support full-text search override it.

    Parameters
    ----------
    field_names: str or list of str
        The name(s) of the field to index.
        The LanceTable implementation requires a single str when
        use_tantivy=False.
    ordering_field_names: str or list of str, optional
        A list of unsigned type fields to index to optionally order
        results on at search time.
        only available with use_tantivy=True
    replace: bool, default False
        If True, replace the existing index if it exists. Note that this is
        not yet an atomic operation; the index will be temporarily
        unavailable while the new index is being created.
    writer_heap_size: int, default 1GB
        Only available with use_tantivy=True
    tokenizer_name: str, default "default"
        The tokenizer to use for the index. Can be "raw", "default" or the 2 letter
        language code followed by "_stem". So for english it would be "en_stem".
        For available languages see: https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html
        only available with use_tantivy=True for now
    use_tantivy: bool, default True
        If True, use the legacy full-text search implementation based on tantivy.
        If False, use the new full-text search implementation based on lance-index.

    Raises
    ------
    NotImplementedError
        Always, unless overridden by a subclass.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
@@ -502,7 +545,7 @@ class Table(ABC):
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -799,6 +842,18 @@ class Table(ABC):
|
||||
The names of the columns to drop.
|
||||
"""
|
||||
|
||||
@cached_property
def _dataset_uri(self) -> str:
    """URI of this table's dataset, built from the connection URI and the table name.

    Computed once per instance and then cached.
    """
    return _table_uri(self._conn.uri, self.name)
|
||||
|
||||
def _get_fts_index_path(self) -> Tuple[str, Optional[pa_fs.FileSystem], bool]:
    """Locate the full-text search index directory under the dataset.

    Returns
    -------
    Tuple[str, Optional[pa_fs.FileSystem], bool]
        ``(path, filesystem, index_exists)``. When the dataset URI scheme is
        not "file", ``("", None, False)`` is returned and no filesystem
        lookup is made.
    """
    # Only datasets with a "file" URI scheme are probed for an index directory.
    if get_uri_scheme(self._dataset_uri) != "file":
        return ("", None, False)
    path = join_uri(self._dataset_uri, "_indices", "fts")
    # fs_from_uri returns the filesystem and the path to use with it.
    fs, path = fs_from_uri(path)
    index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
    return (path, fs, index_exists)
|
||||
|
||||
|
||||
class _LanceDatasetRef(ABC):
|
||||
@property
|
||||
@@ -938,10 +993,6 @@ class LanceTable(Table):
|
||||
# Cacheable since it's deterministic
|
||||
return _table_path(self._conn.uri, self.name)
|
||||
|
||||
@cached_property
|
||||
def _dataset_uri(self) -> str:
|
||||
return _table_uri(self._conn.uri, self.name)
|
||||
|
||||
@property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return self._ref.dataset
|
||||
@@ -1183,9 +1234,9 @@ class LanceTable(Table):
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
|
||||
*,
|
||||
replace: bool = True,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
|
||||
):
|
||||
self._dataset_mut.create_scalar_index(
|
||||
column, index_type=index_type, replace=replace
|
||||
@@ -1201,42 +1252,13 @@ class LanceTable(Table):
|
||||
tokenizer_name: str = "default",
|
||||
use_tantivy: bool = True,
|
||||
):
|
||||
"""Create a full-text search index on the table.
|
||||
|
||||
Warning - this API is highly experimental and is highly likely to change
|
||||
in the future.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
field_names: str or list of str
|
||||
The name(s) of the field to index.
|
||||
can be only str if use_tantivy=True for now.
|
||||
replace: bool, default False
|
||||
If True, replace the existing index if it exists. Note that this is
|
||||
not yet an atomic operation; the index will be temporarily
|
||||
unavailable while the new index is being created.
|
||||
writer_heap_size: int, default 1GB
|
||||
ordering_field_names:
|
||||
A list of unsigned type fields to index to optionally order
|
||||
results on at search time.
|
||||
only available with use_tantivy=True
|
||||
tokenizer_name: str, default "default"
|
||||
The tokenizer to use for the index. Can be "raw", "default" or the 2 letter
|
||||
language code followed by "_stem". So for english it would be "en_stem".
|
||||
For available languages see: https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html
|
||||
only available with use_tantivy=True for now
|
||||
use_tantivy: bool, default False
|
||||
If True, use the legacy full-text search implementation based on tantivy.
|
||||
If False, use the new full-text search implementation based on lance-index.
|
||||
"""
|
||||
if not use_tantivy:
|
||||
if not isinstance(field_names, str):
|
||||
raise ValueError("field_names must be a string when use_tantivy=False")
|
||||
# delete the existing legacy index if it exists
|
||||
if replace:
|
||||
fs, path = fs_from_uri(self._get_fts_index_path())
|
||||
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
|
||||
if index_exists:
|
||||
path, fs, exist = self._get_fts_index_path()
|
||||
if exist:
|
||||
fs.delete_dir(path)
|
||||
self._dataset_mut.create_scalar_index(
|
||||
field_names, index_type="INVERTED", replace=replace
|
||||
@@ -1251,9 +1273,8 @@ class LanceTable(Table):
|
||||
if isinstance(ordering_field_names, str):
|
||||
ordering_field_names = [ordering_field_names]
|
||||
|
||||
fs, path = fs_from_uri(self._get_fts_index_path())
|
||||
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
|
||||
if index_exists:
|
||||
path, fs, exist = self._get_fts_index_path()
|
||||
if exist:
|
||||
if not replace:
|
||||
raise ValueError("Index already exists. Use replace=True to overwrite.")
|
||||
fs.delete_dir(path)
|
||||
@@ -1264,7 +1285,7 @@ class LanceTable(Table):
|
||||
)
|
||||
|
||||
index = create_index(
|
||||
self._get_fts_index_path(),
|
||||
path,
|
||||
field_names,
|
||||
ordering_fields=ordering_field_names,
|
||||
tokenizer_name=tokenizer_name,
|
||||
@@ -1277,13 +1298,6 @@ class LanceTable(Table):
|
||||
writer_heap_size=writer_heap_size,
|
||||
)
|
||||
|
||||
def _get_fts_index_path(self):
|
||||
if get_uri_scheme(self._dataset_uri) != "file":
|
||||
raise NotImplementedError(
|
||||
"Full-text search is not supported on object stores."
|
||||
)
|
||||
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
||||
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
@@ -1411,7 +1425,7 @@ class LanceTable(Table):
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -1479,14 +1493,11 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None and query is not None:
|
||||
if vector_column_name is None and query is not None and query_type != "fts":
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
except Exception as e:
|
||||
if query_type == "fts":
|
||||
vector_column_name = ""
|
||||
else:
|
||||
raise e
|
||||
raise e
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
@@ -1494,6 +1505,7 @@ class LanceTable(Table):
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
ordering_field_name=ordering_field_name,
|
||||
fts_columns=fts_columns,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1677,18 +1689,22 @@ class LanceTable(Table):
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
ds = self.to_lance()
|
||||
return ds.scanner(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest={
|
||||
nearest = None
|
||||
if len(query.vector) > 0:
|
||||
nearest = {
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
"k": query.k,
|
||||
"metric": query.metric,
|
||||
"nprobes": query.nprobes,
|
||||
"refine_factor": query.refine_factor,
|
||||
},
|
||||
}
|
||||
return ds.scanner(
|
||||
columns=query.columns,
|
||||
limit=query.k,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest=nearest,
|
||||
full_text_query=query.full_text_query,
|
||||
with_row_id=query.with_row_id,
|
||||
batch_size=batch_size,
|
||||
@@ -2113,7 +2129,7 @@ class AsyncTable:
|
||||
column: str,
|
||||
*,
|
||||
replace: Optional[bool] = None,
|
||||
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList]] = None,
|
||||
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
|
||||
):
|
||||
"""Create an index to speed up queries
|
||||
|
||||
@@ -2438,7 +2454,10 @@ class AsyncTable:
|
||||
await self._inner.restore()
|
||||
|
||||
async def optimize(
|
||||
self, *, cleanup_older_than: Optional[timedelta] = None
|
||||
self,
|
||||
*,
|
||||
cleanup_older_than: Optional[timedelta] = None,
|
||||
delete_unverified: bool = False,
|
||||
) -> OptimizeStats:
|
||||
"""
|
||||
Optimize the on-disk data and indices for better performance.
|
||||
@@ -2457,6 +2476,11 @@ class AsyncTable:
|
||||
All files belonging to versions older than this will be removed. Set
|
||||
to 0 days to remove all versions except the latest. The latest version
|
||||
is never removed.
|
||||
delete_unverified: bool, default False
|
||||
Files leftover from a failed transaction may appear to be part of an
|
||||
in-progress operation (e.g. appending new data) and these files will not
|
||||
be deleted unless they are at least 7 days old. If delete_unverified is True
|
||||
then these files will be deleted regardless of their age.
|
||||
|
||||
Experimental API
|
||||
----------------
|
||||
@@ -2478,7 +2502,7 @@ class AsyncTable:
|
||||
"""
|
||||
if cleanup_older_than is not None:
|
||||
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
|
||||
return await self._inner.optimize(cleanup_older_than)
|
||||
return await self._inner.optimize(cleanup_older_than, delete_unverified)
|
||||
|
||||
async def list_indices(self) -> IndexConfig:
|
||||
"""
|
||||
|
||||
@@ -15,6 +15,7 @@ import random
|
||||
from unittest import mock
|
||||
|
||||
import lancedb as ldb
|
||||
from lancedb.index import FTS
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
@@ -28,14 +29,26 @@ def table(tmp_path) -> ldb.table.LanceTable:
|
||||
db = ldb.connect(tmp_path)
|
||||
vectors = [np.random.randn(128) for _ in range(100)]
|
||||
|
||||
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
|
||||
text_nouns = ("puppy", "car")
|
||||
text2_nouns = ("rabbit", "girl", "monkey")
|
||||
verbs = ("runs", "hits", "jumps", "drives", "barfs")
|
||||
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
|
||||
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
|
||||
text = [
|
||||
" ".join(
|
||||
[
|
||||
nouns[random.randrange(0, 5)],
|
||||
text_nouns[random.randrange(0, len(text_nouns))],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
]
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
text2 = [
|
||||
" ".join(
|
||||
[
|
||||
text2_nouns[random.randrange(0, len(text2_nouns))],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
@@ -51,7 +64,56 @@ def table(tmp_path) -> ldb.table.LanceTable:
|
||||
"vector": vectors,
|
||||
"id": [i % 2 for i in range(100)],
|
||||
"text": text,
|
||||
"text2": text,
|
||||
"text2": text2,
|
||||
"nested": [{"text": t} for t in text],
|
||||
"count": count,
|
||||
}
|
||||
),
|
||||
)
|
||||
return table
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_table(tmp_path) -> ldb.table.AsyncTable:
|
||||
db = await ldb.connect_async(tmp_path)
|
||||
vectors = [np.random.randn(128) for _ in range(100)]
|
||||
|
||||
text_nouns = ("puppy", "car")
|
||||
text2_nouns = ("rabbit", "girl", "monkey")
|
||||
verbs = ("runs", "hits", "jumps", "drives", "barfs")
|
||||
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
|
||||
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
|
||||
text = [
|
||||
" ".join(
|
||||
[
|
||||
text_nouns[random.randrange(0, len(text_nouns))],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
]
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
text2 = [
|
||||
" ".join(
|
||||
[
|
||||
text2_nouns[random.randrange(0, len(text2_nouns))],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
]
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
count = [random.randint(1, 10000) for _ in range(100)]
|
||||
table = await db.create_table(
|
||||
"test",
|
||||
data=pd.DataFrame(
|
||||
{
|
||||
"vector": vectors,
|
||||
"id": [i % 2 for i in range(100)],
|
||||
"text": text,
|
||||
"text2": text2,
|
||||
"nested": [{"text": t} for t in text],
|
||||
"count": count,
|
||||
}
|
||||
@@ -91,17 +153,92 @@ def test_search_index(tmp_path, table):
|
||||
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
|
||||
ldb.fts.populate_index(index, table, ["text"])
|
||||
index.reload()
|
||||
results = ldb.fts.search_index(index, query="puppy", limit=10)
|
||||
results = ldb.fts.search_index(index, query="puppy", limit=5)
|
||||
assert len(results) == 2
|
||||
assert len(results[0]) == 10 # row_ids
|
||||
assert len(results[1]) == 10 # _distance
|
||||
assert len(results[0]) == 5 # row_ids
|
||||
assert len(results[1]) == 5 # _score
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_search_fts(table, use_tantivy):
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
results = table.search("puppy").limit(10).to_list()
|
||||
assert len(results) == 10
|
||||
results = table.search("puppy").limit(5).to_list()
|
||||
assert len(results) == 5
|
||||
|
||||
|
||||
def test_search_fts_specify_column(table):
    """Check column selection for native (non-tantivy) FTS with two indices.

    Builds separate native FTS indices on ``text`` and ``text2`` and verifies
    that:

    * a query restricted to a single indexed column returns the requested
      number of rows, and
    * a query naming several columns, or naming none while two FTS indices
      exist, raises.

    The negative cases use ``pytest.raises`` rather than
    ``try: ...; assert False; except Exception: pass``: ``AssertionError`` is
    itself an ``Exception``, so the latter form swallows the ``assert False``
    and can never fail.
    """
    table.create_fts_index("text", use_tantivy=False)
    table.create_fts_index("text2", use_tantivy=False)

    results = table.search("puppy", fts_columns="text").limit(5).to_list()
    assert len(results) == 5

    results = table.search("rabbit", fts_columns="text2").limit(5).to_list()
    assert len(results) == 5

    # we can only specify one column for now
    with pytest.raises(Exception):
        table.search("puppy", fts_columns=["text", "text2"]).limit(5).to_list()

    # have to specify a column because we have two fts indices
    with pytest.raises(Exception):
        table.search("puppy").limit(5).to_list()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_fts_async(async_table):
    """Full-text search through the async query builder returns `limit` rows."""
    # The fixture value is awaited here before use as a table.
    tbl = await async_table
    await tbl.create_index("text", config=FTS())

    fts_query = tbl.query().nearest_to_text("puppy").limit(5)
    rows = await fts_query.to_list()
    assert len(rows) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_fts_specify_column_async(async_table):
    """Async counterpart of the FTS column-selection test.

    Creates FTS indices on both ``text`` and ``text2`` and verifies that:

    * a query restricted to one indexed column returns the requested number
      of rows, and
    * a query naming several columns, or naming none while two FTS indices
      exist, raises.

    The negative cases use ``pytest.raises`` rather than
    ``try: ...; assert False; except Exception: pass``, which swallows the
    ``AssertionError`` and therefore can never fail. The multi-column case
    passes both column names (mirroring the synchronous test); previously it
    repeated a valid single-column query and so did not exercise that path.
    """
    # The fixture value is awaited here before use as a table.
    async_table = await async_table
    await async_table.create_index("text", config=FTS())
    await async_table.create_index("text2", config=FTS())

    results = (
        await async_table.query()
        .nearest_to_text("puppy", columns="text")
        .limit(5)
        .to_list()
    )
    assert len(results) == 5

    results = (
        await async_table.query()
        .nearest_to_text("rabbit", columns="text2")
        .limit(5)
        .to_list()
    )
    assert len(results) == 5

    # we can only specify one column for now
    with pytest.raises(Exception):
        await (
            async_table.query()
            .nearest_to_text("puppy", columns=["text", "text2"])
            .limit(5)
            .to_list()
        )

    # have to specify a column because we have two fts indices
    with pytest.raises(Exception):
        await async_table.query().nearest_to_text("puppy").limit(5).to_list()
|
||||
|
||||
|
||||
def test_search_ordering_field_index_table(tmp_path, table):
|
||||
@@ -125,11 +262,11 @@ def test_search_ordering_field_index(tmp_path, table):
|
||||
ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"])
|
||||
index.reload()
|
||||
results = ldb.fts.search_index(
|
||||
index, query="puppy", limit=10, ordering_field="count"
|
||||
index, query="puppy", limit=5, ordering_field="count"
|
||||
)
|
||||
assert len(results) == 2
|
||||
assert len(results[0]) == 10 # row_ids
|
||||
assert len(results[1]) == 10 # _distance
|
||||
assert len(results[0]) == 5 # row_ids
|
||||
assert len(results[1]) == 5 # _distance
|
||||
rows = table.to_lance().take(results[0]).to_pylist()
|
||||
|
||||
for r in rows:
|
||||
@@ -140,8 +277,8 @@ def test_search_ordering_field_index(tmp_path, table):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_create_index_from_table(tmp_path, table, use_tantivy):
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
|
||||
assert len(df) <= 10
|
||||
df = table.search("puppy").limit(5).select(["text"]).to_pandas()
|
||||
assert len(df) <= 5
|
||||
assert "text" in df.columns
|
||||
|
||||
# Check whether it can be updated
|
||||
@@ -167,8 +304,8 @@ def test_create_index_from_table(tmp_path, table, use_tantivy):
|
||||
|
||||
def test_create_index_multiple_columns(tmp_path, table):
|
||||
table.create_fts_index(["text", "text2"], use_tantivy=True)
|
||||
df = table.search("puppy").limit(10).to_pandas()
|
||||
assert len(df) == 10
|
||||
df = table.search("puppy").limit(5).to_pandas()
|
||||
assert len(df) == 5
|
||||
assert "text" in df.columns
|
||||
assert "text2" in df.columns
|
||||
|
||||
@@ -176,14 +313,14 @@ def test_create_index_multiple_columns(tmp_path, table):
|
||||
def test_empty_rs(tmp_path, table, mocker):
|
||||
table.create_fts_index(["text", "text2"], use_tantivy=True)
|
||||
mocker.patch("lancedb.fts.search_index", return_value=([], []))
|
||||
df = table.search("puppy").limit(10).to_pandas()
|
||||
df = table.search("puppy").limit(5).to_pandas()
|
||||
assert len(df) == 0
|
||||
|
||||
|
||||
def test_nested_schema(tmp_path, table):
|
||||
table.create_fts_index("nested.text", use_tantivy=True)
|
||||
rs = table.search("puppy").limit(10).to_list()
|
||||
assert len(rs) == 10
|
||||
rs = table.search("puppy").limit(5).to_list()
|
||||
assert len(rs) == 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
|
||||
@@ -117,6 +117,18 @@ def test_query_builder(table):
|
||||
assert all(np.array(rs[0]["vector"]) == [1, 2])
|
||||
|
||||
|
||||
def test_vector_query_with_no_limit(table):
    """A vector query whose limit is ``0`` or ``None`` must raise ``ValueError``."""
    for bad_limit in (0, None):
        with pytest.raises(ValueError):
            builder = LanceVectorQueryBuilder(table, [0, 0], "vector")
            builder.limit(bad_limit).select(["id", "vector"]).to_list()
|
||||
|
||||
|
||||
def test_query_builder_batches(table):
|
||||
rs = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
|
||||
@@ -15,6 +15,7 @@ from lancedb.rerankers import (
|
||||
CrossEncoderReranker,
|
||||
OpenaiReranker,
|
||||
JinaReranker,
|
||||
AnswerdotaiRerankers,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -110,7 +111,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query), vector_column_name="vector")
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
@@ -206,14 +209,26 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query), vector_column_name="vector")
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
# Fail if both query and (vector or text) are provided
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector").vector(
|
||||
query_vector
|
||||
).to_arrow()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector").text(
|
||||
query
|
||||
).to_arrow()
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
@@ -236,33 +251,45 @@ def test_rrf_reranker(tmp_path, use_tantivy):
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
)
|
||||
def test_cohere_reranker(tmp_path):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_cohere_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("cohere")
|
||||
reranker = CohereReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
def test_cross_encoder_reranker(tmp_path):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_cross_encoder_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
reranker = CrossEncoderReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
def test_colbert_reranker(tmp_path):
|
||||
pytest.importorskip("transformers")
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_colbert_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("rerankers")
|
||||
reranker = ColbertReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_answerdotai_reranker(tmp_path, use_tantivy):
    """Run the shared reranker checks against ``AnswerdotaiRerankers``.

    Skipped when the optional ``rerankers`` package is not importable;
    parametrized over both values of ``use_tantivy``.
    """
    pytest.importorskip("rerankers")
    answerdotai = AnswerdotaiRerankers()
    fixture_table, fixture_schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(
        answerdotai,
        fixture_table,
        "single player experience",
        None,
        fixture_schema,
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
)
|
||||
def test_openai_reranker(tmp_path):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_openai_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("openai")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
reranker = OpenaiReranker()
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
@@ -270,8 +297,9 @@ def test_openai_reranker(tmp_path):
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
|
||||
)
|
||||
def test_jina_reranker(tmp_path):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_jina_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("jina")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
reranker = JinaReranker()
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
@@ -251,7 +251,8 @@ def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
|
||||
|
||||
# FTS indices should error since they are not supported yet.
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="Full-text search is not supported on object stores."
|
||||
NotImplementedError,
|
||||
match="Full-text search is only supported on the local filesystem",
|
||||
):
|
||||
table.create_fts_index("x")
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
import os
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
@@ -27,7 +28,7 @@ from pydantic import BaseModel
|
||||
|
||||
class MockDB:
|
||||
def __init__(self, uri: Path):
|
||||
self.uri = uri
|
||||
self.uri = str(uri)
|
||||
self.read_consistency_interval = None
|
||||
|
||||
@functools.cached_property
|
||||
@@ -1052,3 +1053,25 @@ async def test_optimize(db_async: AsyncConnection):
|
||||
assert stats.prune.old_versions_removed == 3
|
||||
|
||||
assert await table.query().to_arrow() == pa.table({"x": [[1], [2]]})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_optimize_delete_unverified(db_async: AsyncConnection, tmp_path):
    """Exercise the ``delete_unverified`` flag of ``AsyncTable.optimize``.

    After the manifest of the previous table version is removed from disk,
    ``optimize(delete_unverified=False)`` is expected to prune nothing, while
    ``optimize(cleanup_older_than=0s, delete_unverified=True)`` is expected to
    report two old versions removed.
    """
    tbl = await db_async.create_table("test", data=[{"x": [1]}])
    await tbl.add(data=[{"x": [2]}])

    # Remove the manifest file belonging to the version before the current one.
    latest = await tbl.version()
    previous_manifest = (
        tmp_path / "test.lance" / "_versions" / f"{latest - 1}.manifest"
    )
    os.remove(previous_manifest)

    untouched = await tbl.optimize(delete_unverified=False)
    assert untouched.prune.old_versions_removed == 0

    pruned = await tbl.optimize(
        cleanup_older_than=timedelta(seconds=0), delete_unverified=True
    )
    assert pruned.prune.old_versions_removed == 2
|
||||
|
||||
@@ -98,6 +98,13 @@ impl Index {
|
||||
inner: Mutex::new(Some(LanceDbIndex::LabelList(Default::default()))),
|
||||
})
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
pub fn fts() -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::FTS(Default::default()))),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
|
||||
@@ -15,17 +15,20 @@
|
||||
use arrow::array::make_array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use lancedb::index::scalar::FullTextSearchQuery;
|
||||
use lancedb::query::QueryExecutionOptions;
|
||||
use lancedb::query::{
|
||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::pyclass;
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyDict;
|
||||
use pyo3::Bound;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3::{pyclass, PyErr};
|
||||
use pyo3_asyncio_0_21::tokio::future_into_py;
|
||||
|
||||
use crate::arrow::RecordBatchStream;
|
||||
@@ -68,6 +71,24 @@ impl Query {
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<()> {
|
||||
let query_text = query
|
||||
.get_item("query")?
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"Query text is required for nearest_to_text",
|
||||
))?
|
||||
.extract::<String>()?;
|
||||
let columns = query
|
||||
.get_item("columns")?
|
||||
.map(|columns| columns.extract::<Vec<String>>())
|
||||
.transpose()?;
|
||||
|
||||
let fts_query = FullTextSearchQuery::new(query_text).columns(columns);
|
||||
self.inner = self.inner.clone().full_text_search(fts_query);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
max_batch_length: Option<u32>,
|
||||
|
||||
@@ -248,6 +248,7 @@ impl Table {
|
||||
pub fn optimize(
|
||||
self_: PyRef<'_, Self>,
|
||||
cleanup_since_ms: Option<u64>,
|
||||
delete_unverified: Option<bool>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
let older_than = if let Some(ms) = cleanup_since_ms {
|
||||
@@ -275,7 +276,7 @@ impl Table {
|
||||
let prune_stats = inner
|
||||
.optimize(OptimizeAction::Prune {
|
||||
older_than,
|
||||
delete_unverified: None,
|
||||
delete_unverified,
|
||||
error_if_tagged_old_versions: None,
|
||||
})
|
||||
.await
|
||||
|
||||
Reference in New Issue
Block a user