Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep

This commit is contained in:
BubbleCal
2024-11-11 17:36:06 +08:00
18 changed files with 867 additions and 7 deletions

View File

@@ -27,3 +27,4 @@ from .imagebind import ImageBindEmbeddings
from .utils import with_embeddings
from .jinaai import JinaEmbeddings
from .watsonx import WatsonxEmbeddings
from .voyageai import VoyageAIEmbeddingFunction

View File

@@ -0,0 +1,127 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import ClassVar, List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT
@register("voyageai")
class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
"""
An embedding function that uses the VoyageAI API
https://docs.voyageai.com/docs/embeddings
Parameters
----------
name: str
The name of the model to use. List of acceptable models:
* voyage-3
* voyage-3-lite
* voyage-finance-2
* voyage-multilingual-2
* voyage-law-2
* voyage-code-2
Examples
--------
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
voyageai = EmbeddingFunctionRegistry
.get_instance()
.get("voyageai")
.create(name="voyage-3")
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
"""
name: str
client: ClassVar = None
def ndims(self):
if self.name == "voyage-3-lite":
return 512
elif self.name == "voyage-code-2":
return 1536
elif self.name in [
"voyage-3",
"voyage-finance-2",
"voyage-multilingual-2",
"voyage-law-2",
]:
return 1024
else:
raise ValueError(f"Model {self.name} not supported")
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, input_type="query")
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
input_type = (
kwargs.get("input_type") or "document"
) # assume source input type if not passed by `compute_query_embeddings`
return self.generate_embeddings(texts, input_type=input_type)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
input_type: Optional[str]
truncation: Optional[bool]
"""
VoyageAIEmbeddingFunction._init_client()
rs = VoyageAIEmbeddingFunction.client.embed(
texts=texts, model=self.name, **kwargs
)
return [emb for emb in rs.embeddings]
@staticmethod
def _init_client():
if VoyageAIEmbeddingFunction.client is None:
voyageai = attempt_import_or_raise("voyageai")
if os.environ.get("VOYAGE_API_KEY") is None:
api_key_not_found_help("voyageai")
VoyageAIEmbeddingFunction.client = voyageai.Client(
os.environ["VOYAGE_API_KEY"]
)

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
import asyncio
import logging
from functools import cached_property
@@ -495,6 +496,19 @@ class RemoteTable(Table):
"compact_files() is not supported on the LanceDB cloud"
)
def optimize(
self,
*,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
"""optimize() is not supported on the LanceDB cloud.
Indices are optimized automatically."""
raise NotImplementedError(
"optimize() is not supported on the LanceDB cloud. "
"Indices are optimized automatically."
)
def count_rows(self, filter: Optional[str] = None) -> int:
return self._loop.run_until_complete(self._table.count_rows(filter))

View File

@@ -7,6 +7,7 @@ from .openai import OpenaiReranker
from .jinaai import JinaReranker
from .rrf import RRFReranker
from .answerdotai import AnswerdotaiRerankers
from .voyageai import VoyageAIReranker
__all__ = [
"Reranker",
@@ -18,4 +19,5 @@ __all__ = [
"JinaReranker",
"RRFReranker",
"AnswerdotaiRerankers",
"VoyageAIReranker",
]

View File

@@ -0,0 +1,133 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import cached_property
from typing import Union, Optional
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import Reranker
class VoyageAIReranker(Reranker):
"""
Reranks the results using the VoyageAI Rerank API.
https://docs.voyageai.com/docs/reranker
Parameters
----------
model_name : str, default "rerank-english-v2.0"
The name of the cross encoder model to use. Available voyageai models are:
- rerank-2
- rerank-2-lite
column : str, default "text"
The name of the column to use as input to the cross encoder model.
top_n : int, default None
The number of results to return. If None, will return all results.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
api_key : str, default None
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
truncation : Optional[bool], default None
"""
def __init__(
self,
model_name: str,
column: str = "text",
top_n: Optional[int] = None,
return_score="relevance",
api_key: Optional[str] = None,
truncation: Optional[bool] = True,
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.top_n = top_n
self.api_key = api_key
self.truncation = truncation
@cached_property
def _client(self):
voyageai = attempt_import_or_raise("voyageai")
if os.environ.get("VOYAGE_API_KEY") is None and self.api_key is None:
raise ValueError(
"VOYAGE_API_KEY not set. Either set it in your environment or \
pass it as `api_key` argument to the VoyageAIReranker."
)
return voyageai.Client(
api_key=os.environ.get("VOYAGE_API_KEY") or self.api_key,
)
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
response = self._client.rerank(
query=query,
documents=docs,
top_k=self.top_n,
model=self.model_name,
truncation=self.truncation,
)
results = (
response.results
) # returns list (text, idx, relevance) attributes sorted descending by score
indices, scores = list(
zip(*[(result.index, result.relevance_score) for result in results])
) # tuples
result_set = result_set.take(list(indices))
# add the scores
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for voyageai reranker"
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_score"])
return result_set

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import inspect
import time
from abc import ABC, abstractmethod
@@ -32,7 +33,7 @@ import pyarrow.fs as pa_fs
from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .common import DATA, VEC, VECTOR_COLUMN_NAME, sanitize_uri
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
@@ -57,6 +58,8 @@ from .util import (
)
from .index import lang_mapping
from ._lancedb import connect as lancedb_connect
if TYPE_CHECKING:
import PIL
from lance.dataset import CleanupStats, ReaderLike
@@ -893,6 +896,55 @@ class Table(ABC):
For most cases, the default should be fine.
"""
@abstractmethod
def optimize(
self,
*,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
"""
Optimize the on-disk data and indices for better performance.
Modeled after ``VACUUM`` in PostgreSQL.
Optimization covers three operations:
* Compaction: Merges small files into larger ones
* Prune: Removes old versions of the dataset
* Index: Optimizes the indices, adding new data to existing indices
Parameters
----------
cleanup_older_than: timedelta, optional default 7 days
All files belonging to versions older than this will be removed. Set
to 0 days to remove all versions except the latest. The latest version
is never removed.
delete_unverified: bool, default False
Files leftover from a failed transaction may appear to be part of an
in-progress operation (e.g. appending new data) and these files will not
be deleted unless they are at least 7 days old. If delete_unverified is True
then these files will be deleted regardless of their age.
Experimental API
----------------
The optimization process is undergoing active development and may change.
Our goal with these changes is to improve the performance of optimization and
reduce the complexity.
That being said, it is essential today to run optimize if you want the best
performance. It should be stable and safe to use in production, but it our
hope that the API may be simplified (or not even need to be called) in the
future.
The frequency an application shoudl call optimize is based on the frequency of
data modifications. If data is frequently added, deleted, or updated then
optimize should be run frequently. A good rule of thumb is to run optimize if
you have added or modified 100,000 or more records or run more than 20 data
modification operations.
"""
@abstractmethod
def add_columns(self, transforms: Dict[str, str]):
"""
@@ -1971,6 +2023,83 @@ class LanceTable(Table):
"""
return self.to_lance().optimize.compact_files(*args, **kwargs)
def optimize(
self,
*,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
"""
Optimize the on-disk data and indices for better performance.
Modeled after ``VACUUM`` in PostgreSQL.
Optimization covers three operations:
* Compaction: Merges small files into larger ones
* Prune: Removes old versions of the dataset
* Index: Optimizes the indices, adding new data to existing indices
Parameters
----------
cleanup_older_than: timedelta, optional default 7 days
All files belonging to versions older than this will be removed. Set
to 0 days to remove all versions except the latest. The latest version
is never removed.
delete_unverified: bool, default False
Files leftover from a failed transaction may appear to be part of an
in-progress operation (e.g. appending new data) and these files will not
be deleted unless they are at least 7 days old. If delete_unverified is True
then these files will be deleted regardless of their age.
Experimental API
----------------
The optimization process is undergoing active development and may change.
Our goal with these changes is to improve the performance of optimization and
reduce the complexity.
That being said, it is essential today to run optimize if you want the best
performance. It should be stable and safe to use in production, but it our
hope that the API may be simplified (or not even need to be called) in the
future.
The frequency an application shoudl call optimize is based on the frequency of
data modifications. If data is frequently added, deleted, or updated then
optimize should be run frequently. A good rule of thumb is to run optimize if
you have added or modified 100,000 or more records or run more than 20 data
modification operations.
"""
try:
asyncio.get_running_loop()
raise AssertionError(
"Synchronous method called in asynchronous context. "
"If you are writing an asynchronous application "
"then please use the asynchronous APIs"
)
except RuntimeError:
asyncio.run(
self._async_optimize(
cleanup_older_than=cleanup_older_than,
delete_unverified=delete_unverified,
)
)
self.checkout_latest()
async def _async_optimize(
self,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
conn = await lancedb_connect(
sanitize_uri(self._conn.uri),
)
table = AsyncTable(await conn.open_table(self.name))
return await table.optimize(
cleanup_older_than=cleanup_older_than, delete_unverified=delete_unverified
)
def add_columns(self, transforms: Dict[str, str]):
self._dataset_mut.add_columns(transforms)

View File

@@ -196,6 +196,7 @@ def test_add_optional_vector(tmp_path):
"ollama",
"cohere",
"instructor",
"voyageai",
],
)
def test_embedding_function_safe_model_dump(embedding_type):

View File

@@ -481,3 +481,22 @@ def test_ollama_embedding(tmp_path):
json.dumps(dumped_model)
except TypeError:
pytest.fail("Failed to JSON serialize the dumped model")
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function():
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()

View File

@@ -16,6 +16,7 @@ from lancedb.rerankers import (
OpenaiReranker,
JinaReranker,
AnswerdotaiRerankers,
VoyageAIReranker,
)
from lancedb.table import LanceTable
@@ -344,3 +345,14 @@ def test_jina_reranker(tmp_path, use_tantivy):
table, schema = get_test_table(tmp_path, use_tantivy)
reranker = JinaReranker()
_run_test_reranker(reranker, table, "single player experience", None, schema)
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_voyageai_reranker(tmp_path, use_tantivy):
pytest.importorskip("voyageai")
reranker = VoyageAIReranker(model_name="rerank-2")
table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema)

View File

@@ -1223,6 +1223,54 @@ async def test_time_travel(db_async: AsyncConnection):
await table.restore()
def test_sync_optimize(db):
table = LanceTable.create(
db,
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
table.create_scalar_index("price", index_type="BTREE")
stats = table.to_lance().stats.index_stats("price_idx")
assert stats["num_indexed_rows"] == 2
table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}])
assert table.count_rows() == 3
table.optimize()
stats = table.to_lance().stats.index_stats("price_idx")
assert stats["num_indexed_rows"] == 3
@pytest.mark.asyncio
async def test_sync_optimize_in_async(db):
table = LanceTable.create(
db,
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
table.create_scalar_index("price", index_type="BTREE")
stats = table.to_lance().stats.index_stats("price_idx")
assert stats["num_indexed_rows"] == 2
table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}])
assert table.count_rows() == 3
try:
table.optimize()
except Exception as e:
assert (
"Synchronous method called in asynchronous context. "
"If you are writing an asynchronous application "
"then please use the asynchronous APIs" in str(e)
)
@pytest.mark.asyncio
async def test_optimize(db_async: AsyncConnection):
table = await db_async.create_table(