Compare commits


10 Commits

Author SHA1 Message Date
Lance Release
a6544c2a31 Bump version: 0.1.5 → 0.1.6 2023-06-15 16:16:03 +00:00
Leon Yee
39ed70896a [rust] added rust.yml for /rust directory (#193) 2023-06-14 11:46:08 -07:00
gsilvestrin
ae672df1b7 feat(rust): add action to publish release to crates.io (#192) 2023-06-14 11:01:22 -07:00
gsilvestrin
15c3f42387 feat(node): add action to tag node / rust releases (#186) 2023-06-14 11:01:02 -07:00
gsilvestrin
f65d85efcc feat(node): add where method to query builder (#183)
Closes #181
2023-06-14 10:54:43 -07:00
Utkarsh Gautam
6b5c046c3b [Python] Updated to_df implementation in Contextualizer class (#174)
Changes include:
- Contexts smaller than the window param are now included as well
- Added an optional threshold parameter to to_df in Contextualizer
This should close #165.
If maintainers are satisfied with the implementation, I will add more
examples and test cases and update the documentation as well.

---------

Co-authored-by: Nithin PS <47279496+Nithinps021@users.noreply.github.com>
Co-authored-by: Will Jones <willjones127@gmail.com>
2023-06-14 09:22:32 -07:00
Lei Xu
d00f4e51d0 Fix node ffi build (#191) 2023-06-13 19:31:29 -07:00
Benjamin Manns
fbc44d4243 Fix small typo in ann_indexes.md (#190) 2023-06-13 17:43:18 -07:00
Lei Xu
b53eee42ce Upgrade to lance 0.4.21 (#187) 2023-06-13 15:39:44 -07:00
Utkarsh Gautam
7e0d6088ca [docs] Fixed langchain example broken link in index.md (#184) 2023-06-13 12:40:39 -07:00
20 changed files with 374 additions and 40 deletions

.bumpversion.cfg (new file)

@@ -0,0 +1,12 @@
[bumpversion]
current_version = 0.1.6
commit = True
message = Bump version: {current_version} → {new_version}
tag = True
tag_name = v{new_version}
[bumpversion:file:node/package.json]
[bumpversion:file:rust/ffi/node/Cargo.toml]
[bumpversion:file:rust/vectordb/Cargo.toml]
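
This config drives bump2version: it rewrites the version string in each file listed above, then (because commit = True and tag = True) commits the change and tags it v{new_version}. A minimal Python sketch of the substitution step, purely illustrative rather than the tool itself, with the next patch version hard-coded:

# Illustrative sketch of the substitution bump2version performs over
# this config; it is not the tool itself.
from pathlib import Path

current = "0.1.6"  # current_version above
new = "0.1.7"      # what a `bumpversion patch` run would produce

for name in [      # the [bumpversion:file:...] sections above
    "node/package.json",
    "rust/ffi/node/Cargo.toml",
    "rust/vectordb/Cargo.toml",
]:
    path = Path(name)
    path.write_text(path.read_text().replace(current, new))
# bump2version also writes the new current_version back into this file,
# commits with the configured message, and tags the commit v0.1.7.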

.github/workflows/cargo-publish.yml (new file)

@@ -0,0 +1,29 @@
name: Cargo Publish

on:
  release:
    types: [ published ]

env:
  # This env var is used by Swatinem/rust-cache@v2 for the cache
  # key, so we set it to make sure it is always consistent.
  CARGO_TERM_COLOR: always

jobs:
  build:
    runs-on: ubuntu-22.04
    timeout-minutes: 30
    # Only runs on tags that matches the make-release action
    if: startsWith(github.ref, 'refs/tags/v')
    steps:
      - uses: actions/checkout@v3
      - uses: Swatinem/rust-cache@v2
        with:
          workspaces: rust
      - name: Install dependencies
        run: |
          sudo apt update
          sudo apt install -y protobuf-compiler libssl-dev
      - name: Publish the package
        run: |
          cargo publish -p vectordb --all-features --token ${{ secrets.CARGO_REGISTRY_TOKEN }}

New workflow file under .github/workflows/ (file name not shown)

@@ -0,0 +1,55 @@
name: Create release commit

on:
  workflow_dispatch:
    inputs:
      dry_run:
        description: 'Dry run (create the local commit/tags but do not push it)'
        required: true
        default: "false"
        type: choice
        options:
          - "true"
          - "false"
      part:
        description: 'What kind of release is this?'
        required: true
        default: 'patch'
        type: choice
        options:
          - patch
          - minor
          - major

jobs:
  bump-version:
    runs-on: ubuntu-latest
    steps:
      - name: Check out main
        uses: actions/checkout@v3
        with:
          ref: main
          persist-credentials: false
          fetch-depth: 0
          lfs: true
      - name: Set git configs for bumpversion
        shell: bash
        run: |
          git config user.name 'Lance Release'
          git config user.email 'lance-dev@lancedb.com'
      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"
      - name: Bump version, create tag and commit
        run: |
          pip install bump2version
          bumpversion --verbose ${{ inputs.part }}
      - name: Push new version and tag
        if: ${{ inputs.dry_run }} == "false"
        uses: ad-m/github-push-action@master
        with:
          github_token: ${{ secrets.LANCEDB_RELEASE_TOKEN }}
          branch: main
          tags: true

PyPI Publish workflow (file name not shown)

@@ -3,12 +3,12 @@ name: PyPI Publish
 on:
   release:
     types: [ published ]
-  tags:
-    - 'python-v*' # Push events that matches the python-make-release action

 jobs:
   publish:
     runs-on: ubuntu-latest
+    # Only runs on tags that matches the python-make-release action
+    if: startsWith(github.ref, 'refs/tags/python-v')
     defaults:
       run:
         shell: bash

.github/workflows/rust.yml (new file)

@@ -0,0 +1,67 @@
name: Rust

on:
  push:
    branches:
      - main
  pull_request:
    paths:
      - rust/**
      - .github/workflows/rust.yml

env:
  # This env var is used by Swatinem/rust-cache@v2 for the cache
  # key, so we set it to make sure it is always consistent.
  CARGO_TERM_COLOR: always
  # Disable full debug symbol generation to speed up CI build and keep memory down
  # "1" means line tables only, which is useful for panic tracebacks.
  RUSTFLAGS: "-C debuginfo=1"
  RUST_BACKTRACE: "1"

jobs:
  linux:
    timeout-minutes: 30
    runs-on: ubuntu-22.04
    defaults:
      run:
        shell: bash
        working-directory: rust
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          lfs: true
      - uses: Swatinem/rust-cache@v2
        with:
          workspaces: rust
      - name: Install dependencies
        run: |
          sudo apt update
          sudo apt install -y protobuf-compiler libssl-dev
      - name: Build
        run: cargo build --all-features
      - name: Run tests
        run: cargo test --all-features
  macos:
    runs-on: macos-12
    timeout-minutes: 30
    defaults:
      run:
        shell: bash
        working-directory: rust
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          lfs: true
      - name: CPU features
        run: sysctl -a | grep cpu
      - uses: Swatinem/rust-cache@v2
        with:
          workspaces: rust
      - name: Install dependencies
        run: brew install protobuf
      - name: Build
        run: cargo build --all-features
      - name: Run tests
        run: cargo test --all-features

Cargo.lock (generated)

@@ -190,6 +190,7 @@ dependencies = [
  "arrow-data",
  "arrow-schema",
  "flatbuffers",
+ "zstd",
 ]

 [[package]]
@@ -654,6 +655,12 @@ version = "3.12.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9b1ce199063694f33ffb7dd4e0ee620741495c32833cde5aa08f02a0bf96f0c8"

+[[package]]
+name = "bytemuck"
+version = "1.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea"
+
 [[package]]
 name = "byteorder"
 version = "1.4.3"
@@ -1646,9 +1653,9 @@ dependencies = [
 [[package]]
 name = "lance"
-version = "0.4.17"
+version = "0.4.21"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "86dda8185bd1ffae7b910c1f68035af23be9b717c52e9cc4de176cd30b47f772"
+checksum = "3d6c2e7bcfc71c7167ec70cd06c6d55c644a148f6580218c5a0b66e13ac5b5cc"
 dependencies = [
  "accelerate-src",
  "arrow",
@@ -1657,7 +1664,9 @@ dependencies = [
  "arrow-buffer",
  "arrow-cast",
  "arrow-data",
+ "arrow-ipc",
  "arrow-ord",
+ "arrow-row",
  "arrow-schema",
  "arrow-select",
  "async-recursion",
@@ -1668,6 +1677,7 @@ dependencies = [
  "bytes",
  "cblas",
  "chrono",
+ "dashmap",
  "datafusion",
  "futures",
  "lapack",
@@ -1684,6 +1694,7 @@ dependencies = [
  "prost-types",
  "rand",
  "reqwest",
+ "roaring",
  "shellexpand",
  "snafu",
  "sqlparser-lance",
@@ -2598,6 +2609,12 @@ dependencies = [
  "winreg",
 ]

+[[package]]
+name = "retain_mut"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8c31b5c4033f8fdde8700e4657be2c497e7288f01515be52168c631e2e4d4086"
+
 [[package]]
 name = "ring"
 version = "0.16.20"
@@ -2613,6 +2630,17 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "roaring"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef0fb5e826a8bde011ecae6a8539dd333884335c57ff0f003fbe27c25bbe8f71"
+dependencies = [
+ "bytemuck",
+ "byteorder",
+ "retain_mut",
+]
+
 [[package]]
 name = "rustc_version"
 version = "0.4.0"

ann_indexes.md (docs)

@@ -67,7 +67,7 @@ There are a couple of parameters that can be used to fine-tune the search:
   e.g., for 1M vectors divided up into 256 partitions, nprobes should be set to ~20-40.<br/>
   Note: nprobes is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored.
 - **refine_factor** (default: None): Refine the results by reading extra elements and re-ranking them in memory.<br/>
-  A higher number makes search more accurate but also slower. If you find the recall is less than idea, try refine_factor=10 to start.<br/>
+  A higher number makes search more accurate but also slower. If you find the recall is less than ideal, try refine_factor=10 to start.<br/>
   e.g., for 1M vectors divided into 256 partitions, if you're looking for top 20, then refine_factor=200 reranks the whole partition.<br/>
   Note: refine_factor is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored.
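
Both knobs are set on the query builder at search time. For reference, a minimal Python sketch assuming the lancedb query-builder API these docs describe (the connection path, table name, and vector dimension below are placeholders):

import lancedb

db = lancedb.connect("/tmp/lancedb")  # placeholder path
tbl = db.open_table("my_vectors")     # placeholder table with an ANN index

results = (
    tbl.search([0.1] * 128)  # query vector; the dimension is illustrative
    .limit(20)               # top 20 results
    .nprobes(20)             # ~20-40 for 1M vectors in 256 partitions
    .refine_factor(10)       # read and re-rank extra candidates in memory
    .to_df()
)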

index.md (docs)

@@ -14,7 +14,7 @@ The key features of LanceDB include:
 * Zero-copy, automatic versioning, manage versions of your data without needing extra infrastructure.
-* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
+* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lancedb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.

 LanceDB's core is written in Rust 🦀 and is built using <a href="https://github.com/lancedb/lance">Lance</a>, an open-source columnar format designed for performant ML workloads.

node/package.json

@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.1.5",
+  "version": "0.1.6",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",

Node query builder source (file name not shown)

@@ -293,6 +293,8 @@ export class Query<T = number[]> {
     return this
   }

+  where = this.filter
+
   /** Return only the specified columns.
    *
    * @param value Only select the specified columns. If not specified, all columns will be returned.

Node client tests (file name not shown)

@@ -64,13 +64,20 @@ describe('LanceDB client', function () {
     assert.equal(results[0].id, 1)
   })

-  it('uses a filter', async function () {
+  it('uses a filter / where clause', async function () {
+    // eslint-disable-next-line @typescript-eslint/explicit-function-return-type
+    const assertResults = (results: Array<Record<string, unknown>>) => {
+      assert.equal(results.length, 1)
+      assert.equal(results[0].id, 2)
+    }
     const uri = await createTestDB()
     const con = await lancedb.connect(uri)
     const table = await con.openTable('vectors')
-    const results = await table.search([0.1, 0.1]).filter('id == 2').execute()
-    assert.equal(results.length, 1)
-    assert.equal(results[0].id, 2)
+    let results = await table.search([0.1, 0.1]).filter('id == 2').execute()
+    assertResults(results)
+    results = await table.search([0.1, 0.1]).where('id == 2').execute()
+    assertResults(results)
   })

   it('select only a subset of columns', async function () {

lancedb/context.py

@@ -42,34 +42,38 @@ def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
     paragraphs, messages, etc.

     >>> contextualize(data).window(3).stride(1).text_col('token').to_df()
                     token  document_id
     0     The quick brown            1
     1     quick brown fox            1
     2    brown fox jumped            1
     3     fox jumped over            1
     4     jumped over the            1
     5       over the lazy            1
     6        the lazy dog            1
     7          lazy dog I            1
     8          dog I love            1
+    9   I love sandwiches            2
+    10    love sandwiches            2

-    >>> contextualize(data).window(7).stride(1).text_col('token').to_df()
+    >>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_df()
                                        token  document_id
     0    The quick brown fox jumped over the            1
     1   quick brown fox jumped over the lazy            1
     2     brown fox jumped over the lazy dog            1
     3         fox jumped over the lazy dog I            1
     4        jumped over the lazy dog I love            1
+    5    over the lazy dog I love sandwiches            1

     ``stride`` determines how many rows to skip between each window start. This can
     be used to reduce the total number of windows generated.

     >>> contextualize(data).window(4).stride(2).text_col('token').to_df()
                        token  document_id
     0    The quick brown fox            1
     2  brown fox jumped over            1
     4   jumped over the lazy            1
     6         the lazy dog I            1
+    8  dog I love sandwiches            1
+    10       love sandwiches            2

     ``groupby`` determines how to group the rows. For example, we would like to have
     context windows that don't cross document boundaries. In this case, we can
@@ -80,6 +84,25 @@ def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
     0    The quick brown fox            1
     2  brown fox jumped over            1
     4   jumped over the lazy            1
+    6           the lazy dog            1
+    9      I love sandwiches            2
+
+    ``min_window_size`` determines the minimum size of the context windows that are
+    generated. This can be used to trim the last few context windows which have size
+    less than ``min_window_size``. By default context windows of size 1 are skipped.
+
+    >>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_df()
+                                 token  document_id
+    0  The quick brown fox jumped over            1
+    3     fox jumped over the lazy dog            1
+    6                     the lazy dog            1
+    9                I love sandwiches            2
+
+    >>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_df()
+                                 token  document_id
+    0  The quick brown fox jumped over            1
+    3     fox jumped over the lazy dog            1
     """
     return Contextualizer(raw_df)

@@ -92,6 +115,7 @@ class Contextualizer:
         self._groupby = None
         self._stride = None
         self._window = None
+        self._min_window_size = 2
         self._raw_df = raw_df

     def window(self, window: int) -> Contextualizer:
@@ -139,6 +163,17 @@ class Contextualizer:
         self._text_col = text_col
         return self

+    def min_window_size(self, min_window_size: int) -> Contextualizer:
+        """Set the (optional) minimum window size for the context window.
+
+        Parameters
+        ----------
+        min_window_size: int
+            The minimum context window size.
+        """
+        self._min_window_size = min_window_size
+        return self
+
     def to_df(self) -> pd.DataFrame:
         """Create the context windows and return a DataFrame."""

@@ -159,12 +194,19 @@ class Contextualizer:

         def process_group(grp):
             # For each group, create the text rolling window
+            # with values of size >= min_window_size
             text = grp[self._text_col].values
-            contexts = grp.iloc[: -self._window : self._stride, :].copy()
-            contexts[self._text_col] = [
-                " ".join(text[start_i : start_i + self._window])
-                for start_i in range(0, len(grp) - self._window, self._stride)
+            contexts = grp.iloc[:: self._stride, :].copy()
+            windows = [
+                " ".join(text[start_i : min(start_i + self._window, len(grp))])
+                for start_i in range(0, len(grp), self._stride)
+                if start_i + self._window <= len(grp)
+                or len(grp) - start_i >= self._min_window_size
             ]
+            # if last few rows dropped
+            if len(windows) < len(contexts):
+                contexts = contexts.iloc[: len(windows)]
+            contexts[self._text_col] = windows
             return contexts

         if self._groupby is None:
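
To make the new windowing rule concrete, here is a standalone sketch of the same keep-or-drop condition (not the library code itself), checked against document 1 from the doctests above:

def windows(tokens, window, stride, min_window_size=2):
    """Sliding windows that are either full or at least min_window_size long."""
    return [
        " ".join(tokens[start : start + window])
        for start in range(0, len(tokens), stride)
        if start + window <= len(tokens)
        or len(tokens) - start >= min_window_size
    ]

doc1 = "The quick brown fox jumped over the lazy dog".split()
# The default threshold keeps the trailing size-3 window...
assert windows(doc1, window=6, stride=3)[-1] == "the lazy dog"
# ...while min_window_size=4 trims it, matching the threshold test below.
assert windows(doc1, window=6, stride=3, min_window_size=4) == [
    "The quick brown fox jumped over",
    "fox jumped over the lazy dog",
]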

Python tests for lancedb.context (new file; file name not shown)

@@ -0,0 +1,77 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
import pytest

from lancedb.context import contextualize


@pytest.fixture
def raw_df() -> pd.DataFrame:
    return pd.DataFrame(
        {
            "token": [
                "The",
                "quick",
                "brown",
                "fox",
                "jumped",
                "over",
                "the",
                "lazy",
                "dog",
                "I",
                "love",
                "sandwiches",
            ],
            "document_id": [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2],
        }
    )


def test_contextualizer(raw_df: pd.DataFrame):
    result = (
        contextualize(raw_df)
        .window(6)
        .stride(3)
        .text_col("token")
        .groupby("document_id")
        .to_df()["token"]
        .to_list()
    )
    assert result == [
        "The quick brown fox jumped over",
        "fox jumped over the lazy dog",
        "the lazy dog",
        "I love sandwiches",
    ]


def test_contextualizer_with_threshold(raw_df: pd.DataFrame):
    result = (
        contextualize(raw_df)
        .window(6)
        .stride(3)
        .text_col("token")
        .groupby("document_id")
        .min_window_size(4)
        .to_df()["token"]
        .to_list()
    )
    assert result == [
        "The quick brown fox jumped over",
        "fox jumped over the lazy dog",
    ]

rust/ffi/node/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name = "vectordb-node"
-version = "0.1.0"
+version = "0.1.6"
 description = "Serverless, low-latency vector database for AI applications"
 license = "Apache-2.0"
 edition = "2018"

Rust FFI index source (file name not shown)

@@ -97,6 +97,7 @@ fn get_index_params_builder(
         let ivf_params = IvfBuildParams {
             num_partitions: np,
             max_iters,
+            centroids: None,
         };
         index_builder.ivf_params(ivf_params)
     });

rust/vectordb/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name = "vectordb"
-version = "0.0.1"
+version = "0.1.6"
 edition = "2021"
 description = "Serverless, low-latency vector database for AI applications"
 license = "Apache-2.0"
@@ -14,7 +14,7 @@ arrow-data = "37.0"
 arrow-schema = "37.0"
 object_store = "0.5.6"
 snafu = "0.7.4"
-lance = "0.4.17"
+lance = "0.4.21"
 tokio = { version = "1.23", features = ["rt-multi-thread"] }

 [dev-dependencies]

Rust Database source (file name not shown)

@@ -42,7 +42,7 @@ impl Database {
     ///
     /// * A [Database] object.
     pub async fn connect(uri: &str) -> Result<Database> {
-        let object_store = ObjectStore::new(uri).await?;
+        let (object_store, _) = ObjectStore::from_uri(uri).await?;
         if object_store.is_local() {
             Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
         }
@@ -69,7 +69,7 @@ impl Database {
     pub async fn table_names(&self) -> Result<Vec<String>> {
         let f = self
             .object_store
-            .read_dir("/")
+            .read_dir(self.uri.as_str())
             .await?
             .iter()
             .map(|fname| Path::new(fname))

Rust vector index builder source (file name not shown)

@@ -20,6 +20,8 @@ pub trait VectorIndexBuilder {
     fn get_column(&self) -> Option<String>;
     fn get_index_name(&self) -> Option<String>;

     fn build(&self) -> VectorIndexParams;
+
+    fn get_replace(&self) -> bool;
 }

 pub struct IvfPQIndexBuilder {
@@ -28,6 +30,7 @@ pub struct IvfPQIndexBuilder {
     metric_type: Option<MetricType>,
     ivf_params: Option<IvfBuildParams>,
     pq_params: Option<PQBuildParams>,
+    replace: bool,
 }

 impl IvfPQIndexBuilder {
@@ -38,6 +41,7 @@ impl IvfPQIndexBuilder {
             metric_type: None,
             ivf_params: None,
             pq_params: None,
+            replace: true,
         }
     }
 }
@@ -67,6 +71,11 @@ impl IvfPQIndexBuilder {
         self.pq_params = Some(pq_params);
         self
     }
+
+    pub fn replace(&mut self, replace: bool) -> &mut IvfPQIndexBuilder {
+        self.replace = replace;
+        self
+    }
 }

 impl VectorIndexBuilder for IvfPQIndexBuilder {
@@ -84,6 +93,10 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
         VectorIndexParams::with_ivf_pq_params(pq_params.metric_type, ivf_params, pq_params)
     }
+
+    fn get_replace(&self) -> bool {
+        self.replace
+    }
 }

 #[cfg(test)]

Rust query tests (file name not shown)

@@ -177,7 +177,7 @@ mod tests {
     #[tokio::test]
     async fn test_setters_getters() {
         let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
-        let ds = Dataset::write(&mut batches, ":memory:", None)
+        let ds = Dataset::write(&mut batches, "memory://foo", None)
             .await
             .unwrap();
@@ -206,7 +206,7 @@ mod tests {
     #[tokio::test]
     async fn test_execute() {
         let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
-        let ds = Dataset::write(&mut batches, ":memory:", None)
+        let ds = Dataset::write(&mut batches, "memory://foo", None)
             .await
             .unwrap();

Rust Table source (file name not shown)

@@ -130,6 +130,7 @@ impl Table {
                 IndexType::Vector,
                 index_builder.get_index_name(),
                 &index_builder.build(),
+                index_builder.get_replace(),
             )
             .await?;
         self.dataset = Arc::new(dataset);
@@ -233,7 +234,7 @@ mod tests {
         let uri = tmp_dir.path().to_str().unwrap();
         let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
-        let schema = batches.schema().clone();
+        let _ = batches.schema().clone();
         Table::create(&uri, "test", batches).await.unwrap();

         let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());