Compare commits


11 Commits

Author SHA1 Message Date
Lance Release
1090c311e8 [python] Bump version: 0.6.10 → 0.6.11 2024-04-27 03:54:58 +00:00
Weston Pace
e767cbb374 chore: update to Lance version 0.10.16 and Arrow version 51 (#1247) 2024-04-26 16:26:57 -07:00
Weston Pace
3d7c48feca feat: allow the index_cache_size to be configured when opening a table (#1245)
This was already configurable in the rust API but it wasn't actually
being passed down to the underlying dataset. I added this option to both
the async python API and the new nodejs API.

I also added this option to the synchronous python API.

I did not add the option to vectordb.
2024-04-26 13:42:02 -07:00
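A minimal sketch of the new option in the Python APIs touched by this PR (the database path and table name are illustrative):

```python
import lancedb

# Synchronous API: index_cache_size is a new keyword-only argument.
db = lancedb.connect("/tmp/db")
tbl = db.open_table("my_table", index_cache_size=256)

# Async API: the option is passed alongside storage_options.
async def open_with_cache():
    adb = await lancedb.connect_async("/tmp/db")
    return await adb.open_table("my_table", index_cache_size=256)
```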
Bert
08d62550bb fix: passing data to createTable as option (#1242)
Fixes an issue where we would throw `Either data or schema needs to
defined` when passing `data` to `createTable` as a property of the first
argument (an object).

```ts
await db.createTable({
  name: 'table1',
  data,
  schema
})
```
2024-04-26 15:26:08 -04:00
Lei Xu
b272408b05 chore: fix main branch test failure (#1240) 2024-04-24 13:49:37 -07:00
Weston Pace
46ffa87cd4 chore: disable the remote feature by default (#1239)
The rust implementation of the remote client is not yet ready. This is
understandably confusing for users since it is enabled by default. This
PR disables it by default. We can re-enable it when we are ready (even
then it is not clear this is something that should be a default
feature).

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2024-04-24 09:28:24 -07:00
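Since `remote` is no longer a default feature, a downstream crate that still wants the in-development remote client must opt in explicitly in its Cargo.toml; a sketch, with an illustrative version number:

```toml
[dependencies]
lancedb = { version = "0.4", features = ["remote"] }
```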
QianZhu
cd9fc37b95 add rename_table fn and more data for index_stats to return (#1234)
1. Added a rename_table fn so the dashboard can rename a table.
2. Added index_type and distance_type (for vector indices) to index_stats
so that more detailed data can be shown on the table page.
2024-04-23 16:42:26 -07:00
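A hedged sketch of how these additions surface in the remote Python client (the connection URI, API key, table names, and index UUID are all illustrative; on local connections `rename_table` raises `NotImplementedError`):

```python
import lancedb

# rename_table is implemented on the remote (LanceDB Cloud) connection.
db = lancedb.connect("db://my-project", api_key="...")
db.rename_table("events", "events_v2")

# index_stats now also reports index_type and distance_type.
tbl = db.open_table("events_v2")
stats = tbl.index_stats("<index-uuid>")
print(stats)
```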
Lance Release
431f94e564 [python] Bump version: 0.6.9 → 0.6.10 2024-04-22 17:42:24 +00:00
Alex Kohler
c1a7d65473 chore: fix get_registry call in baai embeddings example (#1230) 2024-04-20 07:25:16 +05:30
Rob Meng
1e5ccb1614 chore: upgrade lance to 0.10.15 (#1229) 2024-04-19 10:31:39 -04:00
Bert
2e7ab373dc fix: update lance to 0.10.13 (#1226) 2024-04-17 09:29:10 -04:00
29 changed files with 201 additions and 1120 deletions

View File

@@ -14,19 +14,19 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.10.12", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.12" }
lance-linalg = { "version" = "=0.10.12" }
lance-testing = { "version" = "=0.10.12" }
lance = { "version" = "=0.10.16", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.16" }
lance-linalg = { "version" = "=0.10.16" }
lance-testing = { "version" = "=0.10.16" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"
arrow-data = "50.0"
arrow-ipc = "50.0"
arrow-ord = "50.0"
arrow-schema = "50.0"
arrow-arith = "50.0"
arrow-cast = "50.0"
arrow = { version = "51.0", optional = false }
arrow-array = "51.0"
arrow-data = "51.0"
arrow-ipc = "51.0"
arrow-ord = "51.0"
arrow-schema = "51.0"
arrow-arith = "51.0"
arrow-cast = "51.0"
async-trait = "0"
chrono = "0.4.35"
half = { "version" = "=2.3.1", default-features = false, features = [

View File

@@ -1,280 +0,0 @@
import argparse
import os
from llama_index.core import SimpleDirectoryReader
from llama_index.core.llama_dataset import LabelledRagDataset
from llama_index.core.node_parser import SentenceSplitter
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings import get_registry
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.llama_pack import download_llama_pack
import time
import wandb
import pandas as pd
def get_paths_from_dataset(dataset: str, split=True):
"""
Returns paths of:
- downloaded dataset, lance train dataset, lance test dataset, finetuned model
"""
if split:
return (
f"./data/{dataset}",
f"./data/{dataset}_lance_train",
f"./data/{dataset}_lance_test",
f"./data/tuned_{dataset}",
)
return f"./data/{dataset}", f"./data/{dataset}_lance", f"./data/tuned_{dataset}"
def get_llama_dataset(dataset: str):
"""
returns:
- nodes, documents, rag_dataset
"""
if not os.path.exists(f"./data/{dataset}"):
os.system(
f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}" # noqa
)
rag_dataset = LabelledRagDataset.from_json(f"./data/{dataset}/rag_dataset.json")
docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()
parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)
return nodes, docs, rag_dataset
def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True):
llm = Openai()
chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
# train test split 65-35
if not split:
if os.path.exists(f"./data/{name}_lance"):
ds = QADataset.load(f"./data/{name}_lance")
return ds
ds = QADataset.from_llm(chunks, llm)
ds.save(f"./data/{name}_lance")
return ds
if os.path.exists(f"./data/{name}_lance_train") and os.path.exists(
f"./data/{name}_lance_test"
):
train_ds = QADataset.load(f"./data/{name}_lance_train")
test_ds = QADataset.load(f"./data/{name}_lance_test")
return train_ds, test_ds
# split the chunks into train and test sets
train_size = int(len(chunks) * 0.65)
train_chunks = chunks[:train_size]
test_chunks = chunks[train_size:]
train_ds = QADataset.from_llm(train_chunks, llm)
test_ds = QADataset.from_llm(test_chunks, llm)
train_ds.save(f"./data/{name}_lance_train")
test_ds.save(f"./data/{name}_lance_test")
return train_ds, test_ds
def finetune(
trainset: QADataset, model: str, epochs: int, path: str, valset: QADataset = None, top_k=5
):
valset = valset if valset is not None else trainset
print(f"Finetuning {model} for {epochs} epochs")
print(f"trainset query instances: {len(trainset.queries)}")
print(f"valset query instances: {len(valset.queries)}")
model = get_registry().get("sentence-transformers").create(name=model)
base_result = model.evaluate(valset, path="./data/eval/", top_k=top_k)
base_hit_rate = pd.DataFrame(base_result)["is_hit"].mean()
model.finetune(trainset=trainset, valset=valset, path=path, epochs=epochs)
tuned = get_registry().get("sentence-transformers").create(name=path)
tuned_result = tuned.evaluate(
valset, path=f"./data/eval/{str(time.time())}", top_k=top_k
)
tuned_hit_rate = pd.DataFrame(tuned_result)["is_hit"].mean()
return base_hit_rate, tuned_hit_rate
def do_eval_rag(dataset: str, model: str):
# Requires - pip install llama-index-vector-stores-lancedb
# Requires - pip install llama-index-embeddings-huggingface
nodes, docs, rag_dataset = get_llama_dataset(dataset)
embed_model = HuggingFaceEmbedding(model)
vector_store = LanceDBVectorStore(uri="/tmp/lancedb")
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(embed_model=embed_model)
index = VectorStoreIndex(
nodes,
service_context=service_context,
show_progress=True,
storage_context=storage_context,
)
# build basic RAG system
index = VectorStoreIndex.from_documents(documents=docs)
query_engine = index.as_query_engine()
# evaluate using the RagEvaluatorPack
RagEvaluatorPack = download_llama_pack("RagEvaluatorPack", "./rag_evaluator_pack")
rag_evaluator_pack = RagEvaluatorPack(
rag_dataset=rag_dataset, query_engine=query_engine
)
metrics = rag_evaluator_pack.run(
batch_size=20, # batches the number of openai api calls to make
sleep_time_in_seconds=1, # seconds to sleep before making an api call
)
return metrics
def main(
dataset,
model,
epochs,
top_k=5,
eval_rag=False,
split=True,
project: str = "lancedb_finetune",
):
nodes, _, _ = get_llama_dataset(dataset)
trainset = None
valset = None
if split:
trainset, valset = lance_dataset_from_llama_nodes(nodes, dataset, split)
data_path, lance_train_path, lance_test_path, tuned_path = (
get_paths_from_dataset(dataset, split=split)
)
else:
trainset = lance_dataset_from_llama_nodes(nodes, dataset, split)
valset = trainset
data_path, lance_path, tuned_path = get_paths_from_dataset(dataset, split=split)
base_hit_rate, tuned_hit_rate = finetune(
trainset, model, epochs, tuned_path, valset, top_k=top_k
)
# Base model metrics
metrics = do_eval_rag(dataset, model) if eval_rag else {}
# Tuned model metrics
metrics_tuned = do_eval_rag(dataset, tuned_path) if eval_rag else {}
wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_{epochs}")
wandb.log(
{
"hit_rate": tuned_hit_rate,
}
)
wandb.log(metrics_tuned)
wandb.finish()
wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_base")
wandb.log(
{
"hit_rate": base_hit_rate,
}
)
wandb.log(metrics)
wandb.finish()
def benchmark_all():
datasets = [
"Uber10KDataset2021",
"MiniTruthfulQADataset",
"MiniSquadV2Dataset",
"MiniEsgBenchDataset",
"MiniCovidQaDataset",
"Llama2PaperDataset",
"HistoryOfAlexnetDataset",
"PatronusAIFinanceBenchDataset",
]
models = ["BAAI/bge-small-en-v1.5"]
top_ks = [5]
for top_k in top_ks:
for model in models:
for dataset in datasets:
main(dataset, model, 5, top_k=top_k)
if __name__ == "__main__":
"""
Benchmark the fine-tuning process for a given dataset and model.
Usage:
- For a single dataset
python lancedb/docs/benchmarks/llama-index-datasets.py --dataset Uber10KDataset2021 --model BAAI/bge-small-en-v1.5 --epochs 4 --top_k 5 --split 1 --project lancedb_finetune
- For all datasets and models across all top_ks
python lancedb/docs/benchmarks/llama-index-datasets.py --benchmark-all
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default="BraintrustCodaHelpDeskDataset",
help="The dataset to use for fine-tuning",
)
parser.add_argument(
"--model",
type=str,
default="BAAI/bge-small-en-v1.5",
help="The model to use for fine-tuning",
)
parser.add_argument(
"--epochs",
type=int,
default=4,
help="The number of epochs to fine-tune the model",
)
parser.add_argument(
"--project",
type=str,
default="lancedb_finetune",
help="The wandb project to log the results",
)
parser.add_argument(
"--top_k", type=int, default=5, help="The number of top results to evaluate"
)
parser.add_argument(
"--split",
type=int,
default=1,
help="Whether to split the dataset into train and test(65-35 split), default is 1",
)
parser.add_argument(
"--eval-rag",
action="store_true",
default=False,
help="Whether to evaluate the model using RAG",
)
parser.add_argument(
"--benchmark-all",
action="store_true",
default=False,
help="Benchmark all datasets across all models and top_ks",
)
args = parser.parse_args()
if args.benchmark_all:
benchmark_all()
else:
main(
args.dataset,
args.model,
args.epochs,
args.top_k,
args.eval_rag,
args.split,
args.project,
)

View File

@@ -159,7 +159,7 @@ Allows you to set parameters when registering a `sentence-transformers` object.
from lancedb.embeddings import get_registry
db = lancedb.connect("/tmp/db")
model = get_registry.get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")
model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")
class Words(LanceModel):
text: str = model.SourceField()

View File

@@ -140,6 +140,9 @@ export class RemoteConnection implements Connection {
schema = nameOrOpts.schema
embeddings = nameOrOpts.embeddingFunction
tableName = nameOrOpts.name
if (data === undefined) {
data = nameOrOpts.data
}
}
let buffer: Buffer

View File

@@ -77,6 +77,18 @@ export interface OpenTableOptions {
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>;
/**
* Set the size of the index cache, specified as a number of entries
*
* The exact meaning of an "entry" will depend on the type of index:
* - IVF: there is one entry for each IVF partition
* - BTREE: there is one entry for the entire index
*
* This cache applies to the entire opened table, across all indices.
* Setting this value higher will increase performance on larger datasets
* at the expense of more RAM
*/
indexCacheSize?: number;
}
export interface TableNamesOptions {
@@ -160,6 +172,7 @@ export class Connection {
const innerTable = await this.inner.openTable(
name,
cleanseStorageOptions(options?.storageOptions),
options?.indexCacheSize,
);
return new Table(innerTable);
}

View File

@@ -176,6 +176,7 @@ impl Connection {
&self,
name: String,
storage_options: Option<HashMap<String, String>>,
index_cache_size: Option<u32>,
) -> napi::Result<Table> {
let mut builder = self.get_inner()?.open_table(&name);
if let Some(storage_options) = storage_options {
@@ -183,6 +184,9 @@ impl Connection {
builder = builder.storage_option(key, value);
}
}
if let Some(index_cache_size) = index_cache_size {
builder = builder.index_cache_size(index_cache_size);
}
let tbl = builder
.execute()
.await

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.9
current_version = 0.6.11
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -14,7 +14,7 @@ name = "_lancedb"
crate-type = ["cdylib"]
[dependencies]
arrow = { version = "50.0.0", features = ["pyarrow"] }
arrow = { version = "51.0.0", features = ["pyarrow"] }
lancedb = { path = "../rust/lancedb" }
env_logger = "0.10"
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }

View File

@@ -1,6 +1,6 @@
[project]
name = "lancedb"
version = "0.6.9"
version = "0.6.11"
dependencies = [
"deprecation",
"pylance==0.10.12",

View File

@@ -224,13 +224,23 @@ class DBConnection(EnforceOverrides):
def __getitem__(self, name: str) -> LanceTable:
return self.open_table(name)
def open_table(self, name: str) -> Table:
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
index_cache_size: int, default 256
Set the size of the index cache, specified as a number of entries
The exact meaning of an "entry" will depend on the type of index:
* IVF - there is one entry for each IVF partition
* BTREE - there is one entry for the entire index
This cache applies to the entire opened table, across all indices.
Setting this value higher will increase performance on larger datasets
at the expense of more RAM
Returns
-------
@@ -248,6 +258,18 @@ class DBConnection(EnforceOverrides):
"""
raise NotImplementedError
def rename_table(self, cur_name: str, new_name: str):
"""Rename a table in the database.
Parameters
----------
cur_name: str
The current name of the table.
new_name: str
The new name of the table.
"""
raise NotImplementedError
def drop_database(self):
"""
Drop database
@@ -407,7 +429,9 @@ class LanceDBConnection(DBConnection):
return tbl
@override
def open_table(self, name: str) -> LanceTable:
def open_table(
self, name: str, *, index_cache_size: Optional[int] = None
) -> LanceTable:
"""Open a table in the database.
Parameters
@@ -419,7 +443,7 @@ class LanceDBConnection(DBConnection):
-------
A LanceTable object representing the table.
"""
return LanceTable.open(self, name)
return LanceTable.open(self, name, index_cache_size=index_cache_size)
@override
def drop_table(self, name: str, ignore_missing: bool = False):
@@ -751,7 +775,10 @@ class AsyncConnection(object):
return AsyncTable(new_table)
async def open_table(
self, name: str, storage_options: Optional[Dict[str, str]] = None
self,
name: str,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
) -> Table:
"""Open a Lance Table in the database.
@@ -764,12 +791,22 @@ class AsyncConnection(object):
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
index_cache_size: int, default 256
Set the size of the index cache, specified as a number of entries
The exact meaning of an "entry" will depend on the type of index:
* IVF - there is one entry for each IVF partition
* BTREE - there is one entry for the entire index
This cache applies to the entire opened table, across all indices.
Setting this value higher will increase performance on larger datasets
at the expense of more RAM
Returns
-------
A LanceTable object representing the table.
"""
table = await self._inner.open_table(name, storage_options)
table = await self._inner.open_table(name, storage_options, index_cache_size)
return AsyncTable(table)
async def drop_table(self, name: str):

View File

@@ -10,18 +10,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import List, Union
import numpy as np
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr
from tqdm import tqdm
import lancedb
from .fine_tuner import QADataset
from .utils import TEXT, retry_with_exponential_backoff
@@ -131,22 +126,6 @@ class EmbeddingFunction(BaseModel, ABC):
def __hash__(self) -> int:
return hash(frozenset(vars(self).items()))
def finetune(self, dataset: QADataset, *args, **kwargs):
"""
Finetune the embedding function on a dataset
"""
raise NotImplementedError(
"Finetuning is not supported for this embedding function"
)
def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
"""
Evaluate the embedding function on a dataset
"""
raise NotImplementedError(
"Evaluation is not supported for this embedding function"
)
class EmbeddingFunctionConfig(BaseModel):
"""
@@ -180,52 +159,3 @@ class TextEmbeddingFunction(EmbeddingFunction):
Generate the embeddings for the given texts
"""
pass
def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
"""
Evaluate the embedding function on a dataset. This calculates the hit-rate for
the top-k retrieved documents for each query in the dataset. Assumes that the
first relevant document is the expected document.
Pro - Should work for any embedding model
Con - Returns only a very simple metric.
Parameters
----------
dataset: QADataset
The dataset to evaluate on
Returns
-------
dict
The evaluation results
"""
corpus = dataset.corpus
queries = dataset.queries
relevant_docs = dataset.relevant_docs
path = path or os.path.join(os.getcwd(), "eval")
db = lancedb.connect(path)
class Schema(lancedb.pydantic.LanceModel):
id: str
text: str = self.SourceField()
vector: lancedb.pydantic.Vector(self.ndims()) = self.VectorField()
retriever = db.create_table("eval", schema=Schema, mode="overwrite")
pylist = [{"id": str(k), "text": v} for k, v in corpus.items()]
retriever.add(pylist)
eval_results = []
for query_id, query in tqdm(queries.items()):
retrieved_nodes = retriever.search(query).limit(top_k).to_list()
retrieved_ids = [node["id"] for node in retrieved_nodes]
expected_id = relevant_docs[query_id][0]
is_hit = expected_id in retrieved_ids # assume 1 relevant doc
eval_result = {
"is_hit": is_hit,
"retrieved": retrieved_ids,
"expected": expected_id,
"query": query_id,
}
eval_results.append(eval_result)
return eval_results

View File

@@ -1,147 +0,0 @@
### Fine-tuning workflow
The fine-tuning workflow is as follows:
1. Create a `QADataset` object.
2. Initialize any embedding function using the LanceDB embedding API
3. Call `finetune` method on the embedding object with the `QADataset` object as an argument.
4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API.
# End-to-End Examples
The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.
## Example 1: Fine-tuning from a synthetic dataset
```python
import os
import pandas as pd
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.pydantic import LanceModel, Vector
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import MetadataMode
from lancedb.embeddings import get_registry
dataset = "Uber10KDataset2021"
lance_dataset_dir = dataset + "_lance"
valset_dir = dataset + "_lance_val"
finetuned_model_path = "./model_finetuned"
# 1. Create a QADataset object. See all datasets on llama-index here: https://github.com/run-llama/llama_index/tree/main/llama-datasets
if not os.path.exists(f"./data/{dataset}"):
os.system(
f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}"
)
docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()
parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)
# convert Llama-index TextNode to TextChunk
chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
llm = Openai()
if os.path.exists(lance_dataset_dir):
trainset = QADataset.load(lance_dataset_dir)
else:
trainset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2)
trainset.save(lance_dataset_dir)
# Ideally, we should have a standard dataset for validation, but here we're just generating a synthetic dataset.
if os.path.exists(valset_dir):
valset = QADataset.load(valset_dir)
else:
valset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=4)
valset.save(valset_dir)
# 2. Initialize the embedding model
model = get_registry().get("sentence-transformers").create(name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
# 3. Fine-tune the model
model.finetune(trainset=trainset, path=finetuned_model_path, epochs=4)
# 4. Evaluate the fine-tuned model
base = get_registry().get("sentence-transformers").create(name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
base_results = base.evaluate(valset, top_k=5)
tuned = get_registry().get("sentence-transformers").create(name=finetuned_model_path)
tuned_results = tuned.evaluate(valset, top_k=5)
openai = get_registry().get("openai").create(name="text-embedding-3-small")
openai_results = openai.evaluate(valset, top_k=5)
print("openai-embedding-v3 hit-rate - ", pd.DataFrame(openai_results)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(tuned_results)["is_hit"].mean())
print("Base model hite-rate - ", pd.DataFrame(base_results)["is_hit"].mean())
```
The fine-tuning workflow for embeddings consists of the following parts:
### QADataset
This class is used for managing the data for fine-tuning. It contains the following builder methods:
```
- from_llm(
nodes: 'List[TextChunk]' ,
llm: BaseLLM,
qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
num_questions_per_chunk: int = 2,
) -> "QADataset"
```
Create synthetic data from a language model and text chunks of the original document on which the model is to be fine-tuned.
```python
from_responses(docs: List['TextChunk'], queries: Dict[str, str], relevant_docs: Dict[str, List[str]])-> "QADataset"
```
Create a dataset from queries and responses based on a real-world scenario. Designed to be used for knowledge distillation from a larger LLM to a smaller one.
It also contains the following data attributes:
```
queries (Dict[str, str]): Dict id -> query.
corpus (Dict[str, str]): Dict id -> string.
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
```
### TextChunk
This class wraps a single chunk of text used for fine-tuning. It is designed to standardize and interoperate with various text splitting/pre-processing tools like llama-index and langchain. It contains the following attributes:
```
text: str
id: str
metadata: Dict[str, Any] = {}
```
Builder Methods:
```python
from_llama_index_node(node) -> "TextChunk"
```
Create a text chunk from a llama index node.
```python
from_langchain_node(node) -> "TextChunk"
```
Create a text chunk from a langchain index node.
```python
from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk"
```
Create a text chunk from a string.
### FineTuner
This class is used for fine-tuning embeddings. It is exposed to the user via a high-level function in the base embedding API.
```python
class BaseEmbeddingTuner(ABC):
"""Base Embedding finetuning engine."""
@abstractmethod
def finetune(self) -> None:
"""Goes off and does stuff."""
def helper(self) -> None:
"""A helper method."""
pass
```
### Embedding API finetuning implementation
Each embedding API needs to implement the `finetune` method in order to support fine-tuning. A vanilla evaluation technique implemented in the `BaseEmbedding` class calculates hit_rate @ `top_k`.

View File

@@ -1,4 +0,0 @@
from .dataset import QADataset, TextChunk
from .llm import Gemini, Openai
__all__ = ["QADataset", "TextChunk", "Openai", "Gemini"]

View File

@@ -1,19 +0,0 @@
from abc import ABC, abstractmethod
class BaseEmbeddingTuner(ABC):
"""Base Embedding finetuning engine."""
@abstractmethod
def finetune(self) -> None:
"""
Finetune the embedding model.
"""
pass
def helper(self) -> None:
"""
A helper method called after finetuning. This is meant to provide
usage instructions or other helpful information.
"""
pass

View File

@@ -1,283 +0,0 @@
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple
import lance
import pyarrow as pa
from pydantic import BaseModel
from tqdm import tqdm
from .llm import BaseLLM
DEFAULT_PROMPT_TMPL = """\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and no prior knowledge,
generate only questions based on the below query.
You are a Teacher/ Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided."
"""
class QADataset(BaseModel):
"""Embedding QA Finetuning Dataset.
Args:
queries (Dict[str, str]): Dict id -> query.
corpus (Dict[str, str]): Dict id -> string.
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
"""
queries: Dict[str, str] # id -> query
corpus: Dict[str, str] # id -> text
relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids
mode: str = "text"
@property
def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
"""Get query, relevant doc ids."""
return [
(query, self.relevant_docs[query_id])
for query_id, query in self.queries.items()
]
def save(self, path: str, mode: str = "overwrite") -> None:
"""
Save the current dataset to a directory as .lance files.
Parameters
----------
path : str
The path to save the dataset.
mode : str, optional
The mode to save the dataset, by default "overwrite". Accepts
lance modes.
"""
save_dir = Path(path)
save_dir.mkdir(parents=True, exist_ok=True)
# convert to pydict {"id": []}
queries = {
"id": list(self.queries.keys()),
"query": list(self.queries.values()),
}
corpus = {
"id": list(self.corpus.keys()),
"text": [
val or " " for val in self.corpus.values()
], # lance saves empty strings as null
}
relevant_docs = {
"query_id": list(self.relevant_docs.keys()),
"doc_id": list(self.relevant_docs.values()),
}
# write to lance
lance.write_dataset(
pa.Table.from_pydict(queries), save_dir / "queries.lance", mode=mode
)
lance.write_dataset(
pa.Table.from_pydict(corpus), save_dir / "corpus.lance", mode=mode
)
lance.write_dataset(
pa.Table.from_pydict(relevant_docs),
save_dir / "relevant_docs.lance",
mode=mode,
)
@classmethod
def load(cls, path: str) -> "QADataset":
"""
Load QADataset from a directory.
Parameters
----------
path : str
The path to load the dataset from.
Returns
-------
QADataset
The loaded QADataset.
"""
load_dir = Path(path)
queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
relevant_docs = (
lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
)
return cls(
queries=dict(zip(queries["id"], queries["query"])),
corpus=dict(zip(corpus["id"], corpus["text"])),
relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
)
# generate queries as a convenience function
@classmethod
def from_llm(
cls,
nodes: "List[TextChunk]",
llm: BaseLLM,
qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
num_questions_per_chunk: int = 2,
) -> "QADataset":
"""
Generate a QADataset from a list of TextChunks.
Parameters
----------
nodes : List[TextChunk]
The list of text chunks.
llm : BaseLLM
The language model to generate questions.
qa_generate_prompt_tmpl : str, optional
The template for generating questions, by default DEFAULT_PROMPT_TMPL.
num_questions_per_chunk : int, optional
The number of questions to generate per chunk, by default 2.
Returns
-------
QADataset
The generated QADataset.
"""
node_dict = {node.id: node.text for node in nodes}
queries = {}
relevant_docs = {}
for node_id, text in tqdm(node_dict.items()):
query = qa_generate_prompt_tmpl.format(
context_str=text, num_questions_per_chunk=num_questions_per_chunk
)
response = llm.chat_completion(query)
result = str(response).strip().split("\n")
questions = [
re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
]
questions = [question for question in questions if len(question) > 0]
for question in questions:
question_id = str(uuid.uuid4())
queries[question_id] = question
relevant_docs[question_id] = [node_id]
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
@classmethod
def from_responses(
cls,
nodes: List["TextChunk"],
queries: Dict[str, str],
relevant_docs: Dict[str, List[str]],
) -> "QADataset":
"""
Create a QADataset from a list of TextChunks and a list of
questions, queries, and relevant docs.
Parameters
----------
nodes : List[TextChunk]
The list of text chunks.
queries : Dict[str, str]
The queries. query id -> query.
relevant_docs : Dict[str, List[str]]
The relevant docs. Dict query id -> list of doc ids.
Returns
-------
QADataset
The QADataset.
"""
node_dict = {node.id: node.text for node in nodes}
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
class TextChunk(BaseModel):
"""
Simple text chunk for storing text nodes. Acts as a wrapper around text.
Allow interoperability between different text processing libraries.
Args:
text (str): The text of the chunk.
id (str): The id of the chunk.
metadata (Dict[str, Any], optional): The metadata of the chunk. Defaults to {}.
"""
text: str
id: str
metadata: Dict[str, Any] = {}
@classmethod
def from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk":
"""
Create a TextChunk from a string chunk.
Parameters
----------
chunk : str
The text chunk.
metadata : dict, optional
The metadata, by default {}.
Returns
-------
TextChunk
The text chunk.
"""
# generate a unique id
return cls(text=chunk, id=str(uuid.uuid4()), metadata=metadata)
@classmethod
def from_llama_index_node(cls, node):
"""
Generate a TextChunk from a llama index node.
Parameters
----------
node : llama_index.core.TextNode
The llama index node.
"""
return cls(text=node.text, id=node.node_id, metadata=node.metadata)
@classmethod
def from_langchain_node(cls, node):
"""
Generate a TextChunk from a langchain node.
Parameters
----------
node : langchain.core.TextNode
The langchain node.
"""
raise NotImplementedError("Not implemented yet.")
def to_dict(self) -> Dict[str, Any]:
"""
Convert to a dictionary.
Returns
-------
Dict[str, Any]
The dictionary.
"""
return self.dict()
def __str__(self) -> str:
return self.text
def __repr__(self) -> str:
return f"SimpleTextChunk(text={self.text}, id={self.id}, \
metadata={self.metadata})"

View File

@@ -1,88 +0,0 @@
import os
import re
from functools import cached_property
from typing import Optional
from pydantic import BaseModel
from ...util import attempt_import_or_raise
from ..utils import api_key_not_found_help
class BaseLLM(BaseModel):
"""
TODO:
Base class for Language Model based Embedding Functions. This class is
loosely designed right now, and will be updated as the usage gets clearer.
"""
class Config:
protected_namespaces = () # Disable protected namespace check
model_name: str
model_kwargs: dict = {}
@cached_property
def _client(self):
"""
Get the client for the language model
"""
raise NotImplementedError
def chat_completion(self, prompt: str, **kwargs):
"""
Get the chat completion for the given prompt
"""
raise NotImplementedError
class Openai(BaseLLM):
model_name: str = "gpt-3.5-turbo"
kwargs: dict = {}
api_key: Optional[str] = None
@cached_property
def _client(self):
"""
Get the client for the language model
"""
openai = attempt_import_or_raise("openai")
if not os.environ.get("OPENAI_API_KEY"):
api_key_not_found_help("openai")
return openai.OpenAI()
def chat_completion(self, prompt: str) -> str:
"""
Get the chat completion for the given prompt
"""
# TODO: this is the legacy OpenAI API; replace with completions
completion = self._client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
**self.kwargs,
)
text = completion.choices[0].message.content
return text
def get_questions(self, prompt: str) -> str:
"""
Get the chat completion for the given prompt
"""
response = self.chat_completion(prompt)
result = str(response).strip().split("\n")
questions = [
re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
]
questions = [question for question in questions if len(question) > 0]
return questions
class Gemini(BaseLLM):
pass

View File

@@ -10,15 +10,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union
from typing import List, Union
import numpy as np
import logging
from lancedb.embeddings.fine_tuner import QADataset
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .fine_tuner.basetuner import BaseEmbeddingTuner
from .registry import register
from .utils import weak_lru
@@ -83,151 +80,3 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
def finetune(self, trainset: QADataset, *args, **kwargs):
"""
Finetune the Sentence Transformers model
Parameters
----------
dataset: QADataset
The dataset to use for finetuning
"""
tuner = SentenceTransformersTuner(
model=self.embedding_model,
trainset=trainset,
**kwargs,
)
tuner.finetune()
class SentenceTransformersTuner(BaseEmbeddingTuner):
"""Sentence Transformers Embedding Finetuning Engine."""
def __init__(
self,
model: Any,
trainset: QADataset,
valset: Optional[QADataset] = None,
path: Optional[str] = "~/.lancedb/embeddings/models",
batch_size: int = 8,
epochs: int = 1,
show_progress: bool = True,
eval_steps: int = 50,
max_input_per_doc: int = -1,
loss: Optional[Any] = None,
evaluator: Optional[Any] = None,
run_name: Optional[str] = None,
log_wandb: bool = False,
) -> None:
"""
Parameters
----------
model: str
The model to use for finetuning.
trainset: QADataset
The training dataset.
valset: Optional[QADataset]
The validation dataset.
path: Optional[str]
The path to save the model.
batch_size: int, default=8
The batch size.
epochs: int, default=1
The number of epochs.
show_progress: bool, default=True
Whether to show progress.
eval_steps: int, default=50
The number of steps to evaluate.
max_input_per_doc: int, default=-1
The number of inputs per document.
If -1, use all documents.
"""
from sentence_transformers import InputExample, losses
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch.utils.data import DataLoader
self.model = model
self.trainset = trainset
self.valset = valset
self.path = path
self.batch_size = batch_size
self.epochs = epochs
self.show_progress = show_progress
self.eval_steps = eval_steps
self.max_input_per_doc = max_input_per_doc
self.evaluator = None
self.run_name = run_name
self.log_wandb = log_wandb
if self.max_input_per_doc < -1:
raise ValueError("max_input_per_doc must be -1 or greater than 0.")
examples: Any = []
for query_id, query in self.trainset.queries.items():
if max_input_per_doc == -1:
for node_id in self.trainset.relevant_docs[query_id]:
text = self.trainset.corpus[node_id]
example = InputExample(texts=[query, text])
examples.append(example)
else:
node_id = self.trainset.relevant_docs[query_id][
min(max_input_per_doc, len(self.trainset.relevant_docs[query_id]))
]
text = self.trainset.corpus[node_id]
example = InputExample(texts=[query, text])
examples.append(example)
self.examples = examples
self.loader: DataLoader = DataLoader(examples, batch_size=batch_size)
if self.valset is not None:
eval_engine = evaluator or InformationRetrievalEvaluator
self.evaluator = eval_engine(
valset.queries, valset.corpus, valset.relevant_docs
)
# define loss
self.loss = loss or losses.MultipleNegativesRankingLoss(self.model)
self.warmup_steps = int(len(self.loader) * epochs * 0.1)
def finetune(self) -> None:
"""Finetune the Sentence Transformers model."""
self.model.fit(
train_objectives=[(self.loader, self.loss)],
epochs=self.epochs,
warmup_steps=self.warmup_steps,
output_path=self.path,
show_progress_bar=self.show_progress,
evaluator=self.evaluator,
evaluation_steps=self.eval_steps,
callback=self._wandb_callback if self.log_wandb else None,
)
self.helper()
def helper(self) -> None:
"""A helper method."""
logging.info("Finetuning complete.")
logging.info(f"Model saved to {self.path}.") # noqa
logging.info("You can now use the model as follows:")
logging.info(
f"model = get_registry().get('sentence-transformers').create(name='./{self.path}')" # noqa
)
def _wandb_callback(self, score, epoch, steps):
try:
import wandb
except ImportError:
raise ImportError(
"wandb is not installed. Please install it using `pip install wandb`"
)
run = wandb.run or wandb.init(
project="sbert_lancedb_finetune", name=self.run_name
)
run.log({"epoch": epoch, "steps": steps, "score": score})

View File

@@ -94,7 +94,7 @@ class RemoteDBConnection(DBConnection):
yield item
@override
def open_table(self, name: str) -> Table:
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
"""Open a Lance Table in the database.
Parameters
@@ -110,6 +110,12 @@ class RemoteDBConnection(DBConnection):
self._client.mount_retry_adapter_for_table(name)
if index_cache_size is not None:
logging.info(
"index_cache_size is ignored in LanceDb Cloud"
" (there is no local cache to configure)"
)
# check if table exists
if self._table_cache.get(name) is None:
self._client.post(f"/v1/table/{name}/describe/")
@@ -281,6 +287,24 @@ class RemoteDBConnection(DBConnection):
)
self._table_cache.pop(name)
@override
def rename_table(self, cur_name: str, new_name: str):
"""Rename a table in the database.
Parameters
----------
cur_name: str
The current name of the table.
new_name: str
The new name of the table.
"""
self._client.post(
f"/v1/table/{cur_name}/rename/",
json={"new_table_name": new_name},
)
self._table_cache.pop(cur_name)
self._table_cache[new_name] = True
async def close(self):
"""Close the connection to the database."""
self._client.close()

View File

@@ -72,7 +72,7 @@ class RemoteTable(Table):
return resp
def index_stats(self, index_uuid: str):
"""List all the indices on the table"""
"""List all the stats of a specified index"""
resp = self._conn._client.post(
f"/v1/table/{self._name}/index/{index_uuid}/stats/"
)

View File

@@ -806,6 +806,7 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
"""Reference to the latest version of a LanceDataset."""
uri: str
index_cache_size: Optional[int] = None
read_consistency_interval: Optional[timedelta] = None
last_consistency_check: Optional[float] = None
_dataset: Optional[LanceDataset] = None
@@ -813,7 +814,9 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
@property
def dataset(self) -> LanceDataset:
if not self._dataset:
self._dataset = lance.dataset(self.uri)
self._dataset = lance.dataset(
self.uri, index_cache_size=self.index_cache_size
)
self.last_consistency_check = time.monotonic()
elif self.read_consistency_interval is not None:
now = time.monotonic()
@@ -842,12 +845,15 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
class _LanceTimeTravelRef(_LanceDatasetRef):
uri: str
version: int
index_cache_size: Optional[int] = None
_dataset: Optional[LanceDataset] = None
@property
def dataset(self) -> LanceDataset:
if not self._dataset:
self._dataset = lance.dataset(self.uri, version=self.version)
self._dataset = lance.dataset(
self.uri, version=self.version, index_cache_size=self.index_cache_size
)
return self._dataset
@dataset.setter
@@ -884,6 +890,8 @@ class LanceTable(Table):
connection: "LanceDBConnection",
name: str,
version: Optional[int] = None,
*,
index_cache_size: Optional[int] = None,
):
self._conn = connection
self.name = name
@@ -892,11 +900,13 @@ class LanceTable(Table):
self._ref = _LanceTimeTravelRef(
uri=self._dataset_uri,
version=version,
index_cache_size=index_cache_size,
)
else:
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=connection.read_consistency_interval,
index_cache_size=index_cache_size,
)
@classmethod

View File

@@ -368,6 +368,15 @@ async def test_create_exist_ok_async(tmp_path):
# await db.create_table("test", schema=bad_schema, exist_ok=True)
def test_open_table_sync(tmp_path):
db = lancedb.connect(tmp_path)
db.create_table("test", data=[{"id": 0}])
assert db.open_table("test").count_rows() == 1
assert db.open_table("test", index_cache_size=0).count_rows() == 1
with pytest.raises(FileNotFoundError, match="does not exist"):
db.open_table("does_not_exist")
@pytest.mark.asyncio
async def test_open_table(tmp_path):
db = await lancedb.connect_async(tmp_path)
@@ -397,6 +406,10 @@ async def test_open_table(tmp_path):
}
)
# No way to verify this yet, but at least make sure we
# can pass the parameter
await db.open_table("test", index_cache_size=0)
with pytest.raises(ValueError, match="was not found"):
await db.open_table("does_not_exist")

View File

@@ -1,45 +0,0 @@
import uuid
import pytest
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner import QADataset, TextChunk
from tqdm import tqdm
@pytest.mark.slow
def test_finetuning_sentence_transformers(tmp_path):
queries = {}
relevant_docs = {}
chunks = [
"This is a chunk related to legal docs",
"This is another chunk related financial docs",
"This is a chunk related to sports docs",
"This is another chunk related to fashion docs",
]
text_chunks = [TextChunk.from_chunk(chunk) for chunk in chunks]
for chunk in tqdm(text_chunks):
questions = [
"What is this chunk about?",
"What is the main topic of this chunk?",
]
for question in questions:
question_id = str(uuid.uuid4())
queries[question_id] = question
relevant_docs[question_id] = [chunk.id]
ds = QADataset.from_responses(text_chunks, queries, relevant_docs)
assert len(ds.queries) == 8
assert len(ds.corpus) == 4
model = get_registry().get("sentence-transformers").create()
model.finetune(trainset=ds, valset=ds, path=str(tmp_path / "model"), epochs=1)
model = (
get_registry().get("sentence-transformers").create(name=str(tmp_path / "model"))
)
res = model.evaluate(ds)
assert res is not None
def test_text_chunk():
# TODO
pass

View File

@@ -134,17 +134,21 @@ impl Connection {
})
}
#[pyo3(signature = (name, storage_options = None))]
#[pyo3(signature = (name, storage_options = None, index_cache_size = None))]
pub fn open_table(
self_: PyRef<'_, Self>,
name: String,
storage_options: Option<HashMap<String, String>>,
index_cache_size: Option<u32>,
) -> PyResult<&PyAny> {
let inner = self_.get_inner()?.clone();
let mut builder = inner.open_table(name);
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
if let Some(index_cache_size) = index_cache_size {
builder = builder.index_cache_size(index_cache_size);
}
future_into_py(self_.py(), async move {
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))

View File

@@ -52,7 +52,7 @@ aws-sdk-kms = { version = "1.0" }
aws-config = { version = "1.0" }
[features]
default = ["remote"]
default = []
remote = ["dep:reqwest"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
s3-test = []

View File

@@ -33,6 +33,9 @@ use crate::table::{NativeTable, WriteOptions};
use crate::utils::validate_table_name;
use crate::Table;
#[cfg(feature = "remote")]
use log::warn;
pub const LANCE_FILE_EXTENSION: &str = "lance";
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
@@ -579,6 +582,7 @@ impl ConnectBuilder {
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
warn!("The rust implementation of the remote client is not yet ready for use.");
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
&self.uri,
&api_key,
@@ -909,12 +913,23 @@ impl ConnectionInternal for Database {
}
}
// Some ReadParams are exposed in the OpenTableBuilder, but we also
// let the user provide their own ReadParams.
//
// If the user provided ReadParams, use those. Otherwise, start with the
// default ReadParams and customize them with the options from the
// OpenTableBuilder.
let read_params = options.lance_read_params.unwrap_or_else(|| ReadParams {
index_cache_size: options.index_cache_size as usize,
..Default::default()
});
let native_table = Arc::new(
NativeTable::open_with_params(
&table_uri,
&options.name,
self.store_wrapper.clone(),
options.lance_read_params,
Some(read_params),
self.read_consistency_interval,
)
.await?,
@@ -1032,7 +1047,6 @@ mod tests {
}
#[tokio::test]
#[ignore = "this can't pass due to https://github.com/lancedb/lancedb/issues/1019, enable it after the bug fixed"]
async fn test_open_table() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();

View File

@@ -46,10 +46,18 @@ impl VectorIndex {
}
}
#[derive(Debug, Deserialize)]
pub struct VectorIndexMetadata {
pub metric_type: String,
pub index_type: String,
}
#[derive(Debug, Deserialize)]
pub struct VectorIndexStatistics {
pub num_indexed_rows: usize,
pub num_unindexed_rows: usize,
pub index_type: String,
pub indices: Vec<VectorIndexMetadata>,
}
/// Builder for an IVF PQ index.

View File

@@ -350,8 +350,16 @@ mod test {
#[tokio::test]
async fn test_e2e() {
let dir1 = tempfile::tempdir().unwrap().into_path();
let dir2 = tempfile::tempdir().unwrap().into_path();
let dir1 = tempfile::tempdir()
.unwrap()
.into_path()
.canonicalize()
.unwrap();
let dir2 = tempfile::tempdir()
.unwrap()
.into_path()
.canonicalize()
.unwrap();
let secondary_store = LocalFileSystem::new_with_prefix(dir2.to_str().unwrap()).unwrap();
let object_store_wrapper = Arc::new(MirroringObjectStoreWrapper {

View File

@@ -34,6 +34,16 @@
//! cargo install lancedb
//! ```
//!
//! ## Crate Features
//!
//! ### Experimental Features
//!
//! These features are not enabled by default. They are experimental or in-development features that
//! are not yet ready to be released.
//!
//! - `remote` - Enable the remote client to connect to LanceDB Cloud. This is not yet fully implemented
//! and should not be enabled.
//!
//! ### Quick Start
//!
//! #### Connect to a database.

View File

@@ -1061,6 +1061,26 @@ impl NativeTable {
}
}
pub async fn get_index_type(&self, index_uuid: &str) -> Result<Option<String>> {
match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(stats.index_type)),
None => Ok(None),
}
}
pub async fn get_distance_type(&self, index_uuid: &str) -> Result<Option<String>> {
match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(
stats
.indices
.iter()
.map(|i| i.metric_type.clone())
.collect(),
)),
None => Ok(None),
}
}
pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
let dataset = self.dataset.get().await?;
let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;