mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
feat(napi): Provide a new createIndex API in the napi SDK. (#857)
This commit is contained in:
@@ -18,5 +18,5 @@ module.exports = {
|
||||
"@typescript-eslint/method-signature-style": "off",
|
||||
"@typescript-eslint/no-explicit-any": "off",
|
||||
},
|
||||
ignorePatterns: ["node_modules/", "dist/", "build/"],
|
||||
ignorePatterns: ["node_modules/", "dist/", "build/", "vectordb/native.*"],
|
||||
};
|
||||
|
||||
@@ -17,6 +17,7 @@ napi = { version = "2.14", default-features = false, features = [
|
||||
napi-derive = "2.14"
|
||||
vectordb = { path = "../rust/vectordb" }
|
||||
lance.workspace = true
|
||||
lance-linalg.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
napi-build = "2.1"
|
||||
|
||||
@@ -1,3 +1,24 @@
|
||||
# (New) LanceDB NodeJS SDK
|
||||
|
||||
It will replace the NodeJS SDK when it is ready.
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
```sh
|
||||
npm run build
|
||||
npm t
|
||||
```
|
||||
|
||||
Generating docs
|
||||
|
||||
```
|
||||
npm run docs
|
||||
|
||||
cd ../docs
|
||||
# Assume the virtual environment was created
|
||||
# python3 -m venv venv
|
||||
# pip install -r requirements.txt
|
||||
. ./venv/bin/activate
|
||||
mkdocs build
|
||||
```
|
||||
|
||||
99
nodejs/__test__/table.test.ts
Normal file
99
nodejs/__test__/table.test.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import * as os from "os";
|
||||
import * as path from "path";
|
||||
import * as fs from "fs";
|
||||
|
||||
import { connect } from "../dist";
|
||||
import { Schema, Field, Float32, Int32, FixedSizeList } from "apache-arrow";
|
||||
import { makeArrowTable } from "../dist/arrow";
|
||||
|
||||
describe("Test creating index", () => {
  let tmpDir: string;
  const schema = new Schema([
    new Field("id", new Int32(), true),
    new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))),
  ]);

  /**
   * Build an Arrow table matching `schema`: a sequential `id` column plus a
   * 32-dimensional `vec` column filled with random values.
   */
  const makeVectorTable = (numRows: number = 300) =>
    makeArrowTable(
      Array.from({ length: numRows }, (_, i) => ({
        id: i,
        vec: Array.from({ length: 32 }, () => Math.random()),
      })),
      { schema }
    );

  beforeEach(() => {
    tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "index-"));
  });

  // Remove the per-test database directory so repeated runs do not pile up
  // datasets under the OS temp dir.
  afterEach(() => {
    fs.rmSync(tmpDir, { recursive: true, force: true });
  });

  test("create vector index with no column", async () => {
    const db = await connect(tmpDir);
    const tbl = await db.createTable("test", makeVectorTable());
    await tbl.createIndex().build();

    // check index directory
    const indexDir = path.join(tmpDir, "test.lance", "_indices");
    expect(fs.readdirSync(indexDir)).toHaveLength(1);
    // TODO: check index type.
  });

  test("no vector column available", async () => {
    const db = await connect(tmpDir);
    const tbl = await db.createTable(
      "no_vec",
      makeArrowTable([
        { id: 1, val: 2 },
        { id: 2, val: 3 },
      ])
    );
    // Without an explicit column the builder looks for a vector column and
    // must fail when the table has none.
    await expect(tbl.createIndex().build()).rejects.toThrow(
      "No vector column found"
    );

    // Naming a column explicitly still works on the same table.
    await tbl.createIndex("val").build();
    const indexDir = path.join(tmpDir, "no_vec.lance", "_indices");
    expect(fs.readdirSync(indexDir)).toHaveLength(1);
  });

  test("create scalar index", async () => {
    const db = await connect(tmpDir);
    const tbl = await db.createTable("test", makeVectorTable());
    await tbl.createIndex("id").build();

    // check index directory
    const indexDir = path.join(tmpDir, "test.lance", "_indices");
    expect(fs.readdirSync(indexDir)).toHaveLength(1);
    // TODO: check index type.
  });
});
|
||||
101
nodejs/src/index.rs
Normal file
101
nodejs/src/index.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use lance_linalg::distance::MetricType as LanceMetricType;
|
||||
use napi_derive::napi;
|
||||
|
||||
/// Index kinds exported to JavaScript through the `#[napi]` binding.
#[napi]
|
||||
pub enum IndexType {
|
||||
/// Scalar index.
Scalar,
|
||||
/// `IVF_PQ` index.
IvfPq,
|
||||
}
|
||||
|
||||
/// Distance metrics exported to JavaScript through the `#[napi]` binding.
///
/// Each variant has a same-named counterpart in
/// `lance_linalg::distance::MetricType` (imported here as `LanceMetricType`).
#[napi]
|
||||
pub enum MetricType {
|
||||
L2,
|
||||
Cosine,
|
||||
Dot,
|
||||
}
|
||||
|
||||
impl From<MetricType> for LanceMetricType {
|
||||
fn from(metric: MetricType) -> Self {
|
||||
match metric {
|
||||
MetricType::L2 => Self::L2,
|
||||
MetricType::Cosine => Self::Cosine,
|
||||
MetricType::Dot => Self::Dot,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Class exported through `#[napi]` that wraps the Rust-side
/// `vectordb::index::IndexBuilder`.
#[napi]
|
||||
pub struct IndexBuilder {
|
||||
// The wrapped builder; every exported method forwards to it.
inner: vectordb::index::IndexBuilder,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl IndexBuilder {
|
||||
pub fn new(tbl: &dyn vectordb::Table) -> Self {
|
||||
let inner = tbl.create_index(&[]);
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub unsafe fn replace(&mut self, v: bool) {
|
||||
self.inner.replace(v);
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub unsafe fn column(&mut self, c: String) {
|
||||
self.inner.columns(&[c.as_str()]);
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub unsafe fn name(&mut self, name: String) {
|
||||
self.inner.name(name.as_str());
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub unsafe fn ivf_pq(
|
||||
&mut self,
|
||||
metric_type: Option<MetricType>,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
num_bits: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
) {
|
||||
self.inner.ivf_pq();
|
||||
metric_type.map(|m| self.inner.metric_type(m.into()));
|
||||
num_partitions.map(|p| self.inner.num_partitions(p));
|
||||
num_sub_vectors.map(|s| self.inner.num_sub_vectors(s));
|
||||
num_bits.map(|b| self.inner.num_bits(b));
|
||||
max_iterations.map(|i| self.inner.max_iterations(i));
|
||||
sample_rate.map(|s| self.inner.sample_rate(s));
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub unsafe fn scalar(&mut self) {
|
||||
self.inner.scalar();
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn build(&self) -> napi::Result<()> {
|
||||
println!("nodejs::index.rs : build");
|
||||
self.inner
|
||||
.build()
|
||||
.await
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to build index: {}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ use connection::Connection;
|
||||
use napi_derive::*;
|
||||
|
||||
mod connection;
|
||||
mod index;
|
||||
mod query;
|
||||
mod table;
|
||||
|
||||
|
||||
@@ -12,12 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::query::Query;
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
use vectordb::{ipc::ipc_file_to_batches, table::TableRef};
|
||||
|
||||
use crate::index::IndexBuilder;
|
||||
use crate::query::Query;
|
||||
|
||||
#[napi]
|
||||
pub struct Table {
|
||||
pub(crate) table: TableRef,
|
||||
@@ -43,7 +45,7 @@ impl Table {
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async unsafe fn add(&mut self, buf: Buffer) -> napi::Result<()> {
|
||||
pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
|
||||
let batches = ipc_file_to_batches(buf.to_vec())
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
||||
self.table.add(Box::new(batches), None).await.map_err(|e| {
|
||||
@@ -65,7 +67,7 @@ impl Table {
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async unsafe fn delete(&mut self, predicate: String) -> napi::Result<()> {
|
||||
pub async fn delete(&self, predicate: String) -> napi::Result<()> {
|
||||
self.table.delete(&predicate).await.map_err(|e| {
|
||||
napi::Error::from_reason(format!(
|
||||
"Failed to delete rows in table {}: predicate={}",
|
||||
@@ -74,6 +76,11 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a new [`IndexBuilder`] for this table.
///
/// Only the builder is constructed here; the index itself is built when
/// `IndexBuilder::build` is called.
#[napi]
|
||||
pub fn create_index(&self) -> IndexBuilder {
|
||||
IndexBuilder::new(self.table.as_ref())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn query(&self) -> Query {
|
||||
Query::new(self)
|
||||
|
||||
@@ -179,5 +179,5 @@ export function toBuffer(data: Data, schema?: Schema): Buffer {
|
||||
} else {
|
||||
tbl = makeArrowTable(data, { schema });
|
||||
}
|
||||
return Buffer.from(tableToIPC(tbl, "file"));
|
||||
return Buffer.from(tableToIPC(tbl));
|
||||
}
|
||||
|
||||
@@ -15,10 +15,16 @@
|
||||
import { Connection } from "./connection";
|
||||
import { Connection as NativeConnection, ConnectionOptions } from "./native.js";
|
||||
|
||||
export { ConnectionOptions, WriteOptions, Query } from "./native.js";
|
||||
export {
|
||||
ConnectionOptions,
|
||||
WriteOptions,
|
||||
Query,
|
||||
MetricType,
|
||||
} from "./native.js";
|
||||
export { Connection } from "./connection";
|
||||
export { Table } from "./table";
|
||||
export { Data } from "./arrow";
|
||||
export { IvfPQOptions, IndexBuilder } from "./indexer";
|
||||
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI.
|
||||
|
||||
102
nodejs/vectordb/indexer.ts
Normal file
102
nodejs/vectordb/indexer.ts
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
MetricType,
|
||||
IndexBuilder as NativeBuilder,
|
||||
Table as NativeTable,
|
||||
} from "./native";
|
||||
|
||||
/** Options for creating an `IVF_PQ` index. All fields are optional. */
|
||||
export interface IvfPQOptions {
|
||||
/** Number of IVF partitions. */
|
||||
num_partitions?: number;
|
||||
|
||||
/** Number of sub-vectors in PQ coding. */
|
||||
num_sub_vectors?: number;
|
||||
|
||||
/** Number of bits used for each PQ code.
|
||||
*/
|
||||
num_bits?: number;
|
||||
|
||||
/** Metric type used to calculate the distance between vectors.
|
||||
*
|
||||
* Supported metrics: `L2`, `Cosine` and `Dot`.
|
||||
*/
|
||||
metric_type?: MetricType;
|
||||
|
||||
/** Number of iterations to train K-means.
|
||||
*
|
||||
* Default is 50. More iterations usually yield better results,
|
||||
* but training takes longer.
|
||||
*/
|
||||
max_iterations?: number;
|
||||
|
||||
/** Sample rate forwarded to the native builder. NOTE(review): semantics are not documented in this file; confirm against the native builder. */
sample_rate?: number;
|
||||
}
|
||||
|
||||
/**
 * Builder for an index on a LanceDB {@link Table}.
 *
 * Every configuration method returns the builder itself so calls can be
 * chained; nothing is built until {@link IndexBuilder.build} is awaited.
 *
 * @see {@link Table.createIndex} for detailed usage.
 */
export class IndexBuilder {
  private inner: NativeBuilder;

  /** Wrap the native builder obtained from the native table. */
  constructor(tbl: NativeTable) {
    this.inner = tbl.createIndex();
  }

  /**
   * Instruct the builder to build an `IVF_PQ` index.
   *
   * @param options Optional `IVF_PQ` parameters; omitted fields are passed to
   *   the native builder as `undefined`.
   */
  ivf_pq(options?: IvfPQOptions): IndexBuilder {
    this.inner.ivfPq(
      options?.metric_type,
      options?.num_partitions,
      options?.num_sub_vectors,
      options?.num_bits,
      options?.max_iterations,
      options?.sample_rate
    );
    return this;
  }

  /** Instruct the builder to build a Scalar index. */
  scalar(): IndexBuilder {
    // Delegate to the native builder. Calling `this.scalar()` here would
    // recurse forever and overflow the stack.
    this.inner.scalar();
    return this;
  }

  /** Set the column to create the index on. */
  column(col: string): IndexBuilder {
    this.inner.column(col);
    return this;
  }

  /** Set to true to replace existing index. */
  replace(val: boolean): IndexBuilder {
    this.inner.replace(val);
    return this;
  }

  /** Specify the name of the index. Optional. */
  name(n: string): IndexBuilder {
    this.inner.name(n);
    return this;
  }

  /** Build the index. Resolves once the native build has finished. */
  async build(): Promise<void> {
    await this.inner.build();
  }
}
|
||||
19
nodejs/vectordb/native.d.ts
vendored
19
nodejs/vectordb/native.d.ts
vendored
@@ -1,7 +1,17 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/* auto-generated by NAPI-RS */
|
||||
|
||||
export const enum IndexType {
|
||||
Scalar = 0,
|
||||
IvfPq = 1
|
||||
}
|
||||
export const enum MetricType {
|
||||
L2 = 0,
|
||||
Cosine = 1,
|
||||
Dot = 2
|
||||
}
|
||||
export interface ConnectionOptions {
|
||||
uri: string
|
||||
apiKey?: string
|
||||
@@ -36,6 +46,14 @@ export class Connection {
|
||||
/** Drop table with the name. Or raise an error if the table does not exist. */
|
||||
dropTable(name: string): Promise<void>
|
||||
}
|
||||
export class IndexBuilder {
|
||||
replace(v: boolean): void
|
||||
column(c: string): void
|
||||
name(name: string): void
|
||||
ivfPq(metricType?: MetricType | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, numBits?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): void
|
||||
scalar(): void
|
||||
build(): Promise<void>
|
||||
}
|
||||
export class Query {
|
||||
vector(vector: Float32Array): void
|
||||
toArrow(): void
|
||||
@@ -46,5 +64,6 @@ export class Table {
|
||||
add(buf: Buffer): Promise<void>
|
||||
countRows(): Promise<bigint>
|
||||
delete(predicate: string): Promise<void>
|
||||
createIndex(): IndexBuilder
|
||||
query(): Query
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
/* prettier-ignore */
|
||||
|
||||
@@ -294,9 +295,12 @@ if (!nativeBinding) {
|
||||
throw new Error(`Failed to load native binding`)
|
||||
}
|
||||
|
||||
const { Connection, Query, Table, WriteMode, connect } = nativeBinding
|
||||
const { Connection, IndexType, MetricType, IndexBuilder, Query, Table, WriteMode, connect } = nativeBinding
|
||||
|
||||
module.exports.Connection = Connection
|
||||
module.exports.IndexType = IndexType
|
||||
module.exports.MetricType = MetricType
|
||||
module.exports.IndexBuilder = IndexBuilder
|
||||
module.exports.Query = Query
|
||||
module.exports.Table = Table
|
||||
module.exports.WriteMode = WriteMode
|
||||
|
||||
@@ -16,6 +16,7 @@ import { Schema, tableFromIPC } from "apache-arrow";
|
||||
import { Table as _NativeTable } from "./native";
|
||||
import { toBuffer, Data } from "./arrow";
|
||||
import { Query } from "./query";
|
||||
import { IndexBuilder } from "./indexer";
|
||||
|
||||
/**
|
||||
* A LanceDB Table is the collection of Records.
|
||||
@@ -58,6 +59,42 @@ export class Table {
|
||||
await this.inner.delete(predicate);
|
||||
}
|
||||
|
||||
/**
 * Start building an index on this table.
 *
 * @param {string} column The column to create the index on. If not specified,
 * the index is created on the vector field.
 *
 * @example
 *
 * By default, it creates a vector index on the one vector column.
 *
 * ```typescript
 * const table = await conn.openTable("my_table");
 * await table.createIndex().build();
 * ```
 *
 * `IVF_PQ` parameters can be passed via the `ivf_pq({})` call.
 * ```typescript
 * const table = await conn.openTable("my_table");
 * await table.createIndex("my_vec_col")
 *   .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 })
 *   .build();
 * ```
 *
 * Or create a Scalar index
 *
 * ```typescript
 * await table.createIndex("my_float_col").build();
 * ```
 */
createIndex(column?: string): IndexBuilder {
  const builder = new IndexBuilder(this.inner);
  // Only pin a column when the caller named one.
  return column === undefined ? builder : builder.column(column);
}
|
||||
|
||||
search(vector?: number[]): Query {
|
||||
const q = new Query(this);
|
||||
if (vector !== undefined) {
|
||||
|
||||
@@ -75,8 +75,8 @@ fn get_index_params_builder(
|
||||
builder.metric_type(metric_type);
|
||||
}
|
||||
|
||||
if let Some(np) = obj.get_opt_usize(cx, "num_partitions")? {
|
||||
builder.num_partitions(np as u64);
|
||||
if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? {
|
||||
builder.num_partitions(np);
|
||||
}
|
||||
if let Some(ns) = obj.get_opt_u32(cx, "num_sub_vectors")? {
|
||||
builder.num_sub_vectors(ns);
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::{cmp::max, sync::Arc};
|
||||
|
||||
use arrow_schema::Schema;
|
||||
use lance_index::{DatasetIndexExt, IndexType};
|
||||
pub use lance_linalg::distance::MetricType;
|
||||
|
||||
@@ -55,7 +56,7 @@ pub struct IndexBuilder {
|
||||
|
||||
// IVF_PQ parameters
|
||||
metric_type: MetricType,
|
||||
num_partitions: Option<u64>,
|
||||
num_partitions: Option<u32>,
|
||||
// PQ related
|
||||
num_sub_vectors: Option<u32>,
|
||||
num_bits: u32,
|
||||
@@ -109,6 +110,11 @@ impl IndexBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the builder's column list with owned copies of `cols`.
pub fn columns(&mut self, cols: &[&str]) -> &mut Self {
    self.columns = cols.iter().map(|&name| String::from(name)).collect();
    self
}
|
||||
|
||||
/// Whether to replace the existing index, default is `true`.
|
||||
pub fn replace(&mut self, v: bool) -> &mut Self {
|
||||
self.replace = v;
|
||||
@@ -130,7 +136,7 @@ impl IndexBuilder {
|
||||
}
|
||||
|
||||
/// Number of IVF partitions.
|
||||
pub fn num_partitions(&mut self, num_partitions: u64) -> &mut Self {
|
||||
pub fn num_partitions(&mut self, num_partitions: u32) -> &mut Self {
|
||||
self.num_partitions = Some(num_partitions);
|
||||
self
|
||||
}
|
||||
@@ -161,16 +167,28 @@ impl IndexBuilder {
|
||||
|
||||
/// Build the parameters.
|
||||
pub async fn build(&self) -> Result<()> {
|
||||
if self.columns.len() != 1 {
|
||||
let schema = self.table.schema();
|
||||
|
||||
// TODO: simplify this after GH lance#1864.
|
||||
let mut index_type = &self.index_type;
|
||||
let columns = if self.columns.is_empty() {
|
||||
// By default we create vector index.
|
||||
index_type = &IndexType::Vector;
|
||||
vec![default_column_for_index(&schema)?]
|
||||
} else {
|
||||
self.columns.clone()
|
||||
};
|
||||
|
||||
if columns.len() != 1 {
|
||||
return Err(Error::Schema {
|
||||
message: "Only one column is supported for index".to_string(),
|
||||
});
|
||||
}
|
||||
let column = &self.columns[0];
|
||||
let schema = self.table.schema();
|
||||
let column = &columns[0];
|
||||
|
||||
let field = schema.field_with_name(column)?;
|
||||
|
||||
let params = match self.index_type {
|
||||
let params = match index_type {
|
||||
IndexType::Scalar => IndexParams::Scalar {
|
||||
replace: self.replace,
|
||||
},
|
||||
@@ -198,7 +216,7 @@ impl IndexBuilder {
|
||||
IndexParams::IvfPq {
|
||||
replace: self.replace,
|
||||
metric_type: self.metric_type,
|
||||
num_partitions,
|
||||
num_partitions: num_partitions as u64,
|
||||
num_sub_vectors,
|
||||
num_bits: self.num_bits,
|
||||
sample_rate: self.sample_rate,
|
||||
@@ -253,8 +271,8 @@ impl IndexBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
fn suggested_num_partitions(rows: usize) -> u64 {
|
||||
let num_partitions = (rows as f64).sqrt() as u64;
|
||||
fn suggested_num_partitions(rows: usize) -> u32 {
|
||||
let num_partitions = (rows as f64).sqrt() as u32;
|
||||
max(1, num_partitions)
|
||||
}
|
||||
|
||||
@@ -272,3 +290,83 @@ fn suggested_num_sub_vectors(dim: u32) -> u32 {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
/// Find one default column to create index.
|
||||
fn default_column_for_index(schema: &Schema) -> Result<String> {
|
||||
// Try to find one fixed size list array column.
|
||||
let candidates = schema
|
||||
.fields()
|
||||
.iter()
|
||||
.filter_map(|field| match field.data_type() {
|
||||
arrow_schema::DataType::FixedSizeList(f, _) if f.data_type().is_floating() => {
|
||||
Some(field.name())
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if candidates.is_empty() {
|
||||
Err(Error::Store {
|
||||
message: "No vector column found to create index".to_string(),
|
||||
})
|
||||
} else if candidates.len() != 1 {
|
||||
Err(Error::Store {
|
||||
message: format!(
|
||||
"More than one vector columns found, \
|
||||
please specify which column to create index: {:?}",
|
||||
candidates
|
||||
),
|
||||
})
|
||||
} else {
|
||||
Ok(candidates[0].to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_schema::{DataType, Field};

    /// Non-nullable fixed-size-list column holding `dim` Float64 items.
    fn f64_vector_field(name: &str, dim: i32) -> Field {
        Field::new(
            name,
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), dim),
            false,
        )
    }

    #[test]
    fn test_guess_default_column() {
        let id_field = || Field::new("id", DataType::Int16, true);

        // No fixed-size-list column at all.
        let schema_no_vector =
            Schema::new(vec![id_field(), Field::new("tag", DataType::Utf8, false)]);
        let err = default_column_for_index(&schema_no_vector).unwrap_err();
        assert!(err.to_string().contains("No vector column"));

        // Exactly one candidate: it is picked.
        let schema_with_vec_col = Schema::new(vec![id_field(), f64_vector_field("vec", 10)]);
        assert_eq!(
            default_column_for_index(&schema_with_vec_col).unwrap(),
            "vec"
        );

        // Two candidates: ambiguous, so an error is reported.
        let multi_vec_col = Schema::new(vec![
            id_field(),
            f64_vector_field("vec", 10),
            f64_vector_field("vec2", 50),
        ]);
        let err = default_column_for_index(&multi_vec_col).unwrap_err();
        assert!(err.to_string().contains("More than one"));
    }
}
|
||||
|
||||
@@ -17,14 +17,14 @@
|
||||
use std::io::Cursor;
|
||||
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_ipc::reader::FileReader;
|
||||
use arrow_ipc::reader::StreamReader;
|
||||
|
||||
use crate::Result;
|
||||
|
||||
/// Convert a Arrow IPC file to a batch reader
|
||||
pub fn ipc_file_to_batches(buf: Vec<u8>) -> Result<impl RecordBatchReader> {
|
||||
let buf_reader = Cursor::new(buf);
|
||||
let reader = FileReader::try_new(buf_reader, None)?;
|
||||
let reader = StreamReader::try_new(buf_reader, None)?;
|
||||
Ok(reader)
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use arrow_array::{Float32Array, Int64Array, RecordBatch};
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use arrow_ipc::writer::StreamWriter;
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -55,7 +55,7 @@ mod tests {
|
||||
fn test_ipc_file_to_batches() -> Result<()> {
|
||||
let batch = create_record_batch()?;
|
||||
|
||||
let mut writer = FileWriter::try_new(vec![], &batch.schema())?;
|
||||
let mut writer = StreamWriter::try_new(vec![], &batch.schema())?;
|
||||
writer.write(&batch)?;
|
||||
writer.finish()?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user