Compare commits

...

23 Commits

Author SHA1 Message Date
Chang She
86c9bc0d2d stuff 2024-03-12 19:03:26 -07:00
Chang She
c1dfad675a see if we make EncodedImage work 2024-03-12 19:03:26 -07:00
Chang She
2e1838a62a ruff 2024-03-12 19:03:26 -07:00
Chang She
4d39f63cf6 add import guidance 2024-03-12 19:03:26 -07:00
Chang She
3c4f2a7020 fix 2024-03-12 19:03:26 -07:00
Chang She
48a4202748 fix 2024-03-12 19:03:26 -07:00
Chang She
2084fbcff4 working? 2024-03-12 19:03:26 -07:00
Chang She
408988abce just keep EncodedImage for now 2024-03-12 19:03:25 -07:00
Chang She
e68fbf65cc foo 2024-03-12 18:45:32 -07:00
Rok Mihevc
63399dc0ee unused imports 2024-03-12 18:45:32 -07:00
Rok Mihevc
0b0f4e9d1c __get_pydantic_core_schema__ 2024-03-12 18:45:32 -07:00
Rok Mihevc
2ec0e79303 Minor change 2024-03-12 18:45:32 -07:00
Rok Mihevc
d86dd2c60d test automatic reading of uris 2024-03-12 18:45:32 -07:00
Rok Mihevc
67b38d6115 changes 2024-03-12 18:45:32 -07:00
Rok Mihevc
c112dea28b work 2024-03-12 18:45:32 -07:00
Rok Mihevc
d662b9744e black 2024-03-12 18:45:32 -07:00
Rok Mihevc
ac955a5a7e initial commit 2024-03-12 18:45:32 -07:00
Weston Pace
4dc7497547 feat: add list_indices to the async api (#1074) 2024-03-12 14:41:21 -07:00
Weston Pace
d744972f2f feat: add update to the async API (#1093) 2024-03-12 14:11:37 -07:00
Will Jones
9bc320874a fix: handle uri in object (#1091)
Fixes #1078
2024-03-12 13:25:56 -07:00
Weston Pace
510d449167 feat: add time travel operations to the async API (#1070) 2024-03-12 09:20:23 -07:00
Weston Pace
356e89a800 feat: add create_index to the async python API (#1052)
This also refactors the rust lancedb index builder API (and,
correspondingly, the nodejs API)
2024-03-12 05:17:05 -07:00
Will Jones
ae1cf4441d fix: propagate filter validation errors (#1092)
In Rust and Node, we have been swallowing filter validation errors. If
there was an error in parsing the filter, then the filter was silently
ignored, returning unfiltered results.

Fixes #1081
2024-03-11 14:11:39 -07:00
49 changed files with 2645 additions and 995 deletions

View File

@@ -28,13 +28,14 @@ arrow-schema = "50.0"
arrow-arith = "50.0"
arrow-cast = "50.0"
async-trait = "0"
chrono = "0.4.23"
chrono = "0.4.35"
half = { "version" = "=2.3.1", default-features = false, features = [
"num-traits",
] }
futures = "0"
log = "0.4"
object_store = "0.9.0"
pin-project = "1.0.7"
snafu = "0.7.4"
url = "2"
num-traits = "0.2"

View File

@@ -176,6 +176,10 @@ export async function connect (
opts = { uri: arg }
} else {
// opts = { uri: arg.uri, awsCredentials = arg.awsCredentials }
const keys = Object.keys(arg)
if (keys.length === 1 && keys[0] === 'uri' && typeof arg.uri === 'string') {
opts = { uri: arg.uri }
} else {
opts = Object.assign(
{
uri: '',
@@ -187,6 +191,7 @@ export async function connect (
arg
)
}
}
if (opts.uri.startsWith('db://')) {
// Remote connection

View File

@@ -128,6 +128,11 @@ describe('LanceDB client', function () {
assertResults(results)
results = await table.where('id % 2 = 0').execute()
assertResults(results)
// Should reject a bad filter
await expect(table.filter('id % 2 = 0 AND').execute()).to.be.rejectedWith(
/.*sql parser error: Expected an expression:, found: EOF.*/
)
})
it('uses a filter / where clause', async function () {
@@ -283,7 +288,8 @@ describe('LanceDB client', function () {
it('create a table from an Arrow Table', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
// Also test the connect function with an object
const con = await lancedb.connect({ uri: dir })
const i32s = new Int32Array(new Array<number>(10))
const i32 = makeVector(i32s)
@@ -745,11 +751,11 @@ describe('LanceDB client', function () {
num_sub_vectors: 2
})
await expect(createIndex).to.be.rejectedWith(
/VectorIndex requires the column data type to be fixed size list of float32s/
"index cannot be created on the column `name` which has data type Utf8"
)
})
it('it should fail when the column is not a vector', async function () {
it('it should fail when num_partitions is invalid', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')

View File

@@ -14,12 +14,10 @@ crate-type = ["cdylib"]
[dependencies]
arrow-ipc.workspace = true
futures.workspace = true
lance-linalg.workspace = true
lance.workspace = true
lancedb = { path = "../rust/lancedb" }
napi = { version = "2.15", default-features = false, features = [
"napi7",
"async"
"async",
] }
napi-derive = "2"

View File

@@ -27,6 +27,7 @@ import {
Float64,
} from "apache-arrow";
import { makeArrowTable } from "../dist/arrow";
import { Index } from "../dist/indices";
describe("Given a table", () => {
let tmpDir: tmp.DirResult;
@@ -65,21 +66,36 @@ describe("Given a table", () => {
expect(table.isOpen()).toBe(false);
expect(table.countRows()).rejects.toThrow("Table some_table is closed");
});
it("should let me update values", async () => {
await table.add([{ id: 1 }]);
expect(await table.countRows("id == 1")).toBe(1);
expect(await table.countRows("id == 7")).toBe(0);
await table.update({ id: "7" });
expect(await table.countRows("id == 1")).toBe(0);
expect(await table.countRows("id == 7")).toBe(1);
await table.add([{ id: 2 }]);
// Test Map as input
await table.update(new Map(Object.entries({ id: "10" })), {
where: "id % 2 == 0",
});
expect(await table.countRows("id == 2")).toBe(0);
expect(await table.countRows("id == 7")).toBe(1);
expect(await table.countRows("id == 10")).toBe(1);
});
});
describe("Test creating index", () => {
describe("When creating an index", () => {
let tmpDir: tmp.DirResult;
const schema = new Schema([
new Field("id", new Int32(), true),
new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))),
]);
let tbl: Table;
let queryVec: number[];
beforeEach(() => {
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
test("create vector index with no column", async () => {
const db = await connect(tmpDir.name);
const data = makeArrowTable(
Array(300)
@@ -94,47 +110,66 @@ describe("Test creating index", () => {
schema,
},
);
const tbl = await db.createTable("test", data);
await tbl.createIndex().build();
queryVec = data.toArray()[5].vec.toJSON();
tbl = await db.createTable("test", data);
});
afterEach(() => tmpDir.removeCallback());
it("should create a vector index on vector columns", async () => {
await tbl.createIndex("vec");
// check index directory
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
// TODO: check index type.
const indices = await tbl.listIndices();
expect(indices.length).toBe(1);
expect(indices[0]).toEqual({
indexType: "IvfPq",
columns: ["vec"],
});
// Search without specifying the column
const queryVector = data.toArray()[5].vec.toJSON();
const rst = await tbl.query().nearestTo(queryVector).limit(2).toArrow();
const rst = await tbl.query().nearestTo(queryVec).limit(2).toArrow();
expect(rst.numRows).toBe(2);
// Search with specifying the column
const rst2 = await tbl.search(queryVector, "vec").limit(2).toArrow();
const rst2 = await tbl.search(queryVec, "vec").limit(2).toArrow();
expect(rst2.numRows).toBe(2);
expect(rst.toString()).toEqual(rst2.toString());
});
test("no vector column available", async () => {
const db = await connect(tmpDir.name);
const tbl = await db.createTable(
"no_vec",
makeArrowTable([
{ id: 1, val: 2 },
{ id: 2, val: 3 },
]),
);
await expect(tbl.createIndex().build()).rejects.toThrow(
"No vector column found",
);
it("should allow parameters to be specified", async () => {
await tbl.createIndex("vec", {
config: Index.ivfPq({
numPartitions: 10,
}),
});
await tbl.createIndex("val").build();
const indexDir = path.join(tmpDir.name, "no_vec.lance", "_indices");
// TODO: Verify parameters when we can load index config as part of list indices
});
it("should allow me to replace (or not) an existing index", async () => {
await tbl.createIndex("id");
// Default is replace=true
await tbl.createIndex("id");
await expect(tbl.createIndex("id", { replace: false })).rejects.toThrow(
"already exists",
);
await tbl.createIndex("id", { replace: true });
});
test("should create a scalar index on scalar columns", async () => {
await tbl.createIndex("id");
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
for await (const r of tbl.query().filter("id > 1").select(["id"])) {
expect(r.numRows).toBe(1);
expect(r.numRows).toBe(298);
}
});
// TODO: Move this test to the query API test (making sure we can reject queries
// when the dimension is incorrect)
test("two columns with different dimensions", async () => {
const db = await connect(tmpDir.name);
const schema = new Schema([
@@ -164,14 +199,9 @@ describe("Test creating index", () => {
);
// Only build index over v1
await expect(tbl.createIndex().build()).rejects.toThrow(
/.*More than one vector columns found.*/,
);
tbl
.createIndex("vec")
// eslint-disable-next-line @typescript-eslint/naming-convention
.ivf_pq({ num_partitions: 2, num_sub_vectors: 2 })
.build();
await tbl.createIndex("vec", {
config: Index.ivfPq({ numPartitions: 2, numSubVectors: 2 }),
});
const rst = await tbl
.query()
@@ -205,30 +235,6 @@ describe("Test creating index", () => {
expect(rst64Query.toString()).toEqual(rst64Search.toString());
expect(rst64Query.numRows).toBe(2);
});
test("create scalar index", async () => {
const db = await connect(tmpDir.name);
const data = makeArrowTable(
Array(300)
.fill(1)
.map((_, i) => ({
id: i,
vec: Array(32)
.fill(1)
.map(() => Math.random()),
})),
{
schema,
},
);
const tbl = await db.createTable("test", data);
await tbl.createIndex("id").build();
// check index directory
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
// TODO: check index type.
});
});
describe("Read consistency interval", () => {
@@ -348,3 +354,48 @@ describe("schema evolution", function () {
expect(await table.schema()).toEqual(expectedSchema);
});
});
describe("when dealing with versioning", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
});
it("can travel in time", async () => {
// Setup
const con = await connect(tmpDir.name);
const table = await con.createTable("vectors", [
{ id: 1n, vector: [0.1, 0.2] },
]);
const version = await table.version();
await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
expect(await table.countRows()).toBe(2);
// Make sure we can rewind
await table.checkout(version);
expect(await table.countRows()).toBe(1);
// Can't add data in time travel mode
await expect(table.add([{ id: 3n, vector: [0.1, 0.2] }])).rejects.toThrow(
"table cannot be modified when a specific version is checked out",
);
// Can go back to normal mode
await table.checkoutLatest();
expect(await table.countRows()).toBe(2);
// Should be able to add data again
await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
expect(await table.countRows()).toBe(3);
// Now checkout and restore
await table.checkout(version);
await table.restore();
expect(await table.countRows()).toBe(1);
// Should be able to add data
await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
expect(await table.countRows()).toBe(2);
// Can't use restore if not checked out
await expect(table.restore()).rejects.toThrow(
"checkout before running restore",
);
});
});

View File

@@ -18,15 +18,9 @@ import {
ConnectionOptions,
} from "./native.js";
export {
ConnectionOptions,
WriteOptions,
Query,
MetricType,
} from "./native.js";
export { Connection } from "./connection";
export { Table } from "./table";
export { IvfPQOptions, IndexBuilder } from "./indexer";
export { ConnectionOptions, WriteOptions, Query } from "./native.js";
export { Connection, CreateTableOptions } from "./connection";
export { Table, AddDataOptions } from "./table";
/**
* Connect to a LanceDB instance at the given URI.

View File

@@ -1,105 +0,0 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO: Re-enable this as part of https://github.com/lancedb/lancedb/pull/1052
/* eslint-disable @typescript-eslint/naming-convention */
import {
MetricType,
IndexBuilder as NativeBuilder,
Table as NativeTable,
} from "./native";
/** Options to create `IVF_PQ` index */
export interface IvfPQOptions {
/** Number of IVF partitions. */
num_partitions?: number;
/** Number of sub-vectors in PQ coding. */
num_sub_vectors?: number;
/** Number of bits used for each PQ code.
*/
num_bits?: number;
/** Metric type to calculate the distance between vectors.
*
* Supported metrics: `L2`, `Cosine` and `Dot`.
*/
metric_type?: MetricType;
/** Number of iterations to train K-means.
*
* Default is 50. The more iterations it usually yield better results,
* but it takes longer to train.
*/
max_iterations?: number;
sample_rate?: number;
}
/**
* Building an index on LanceDB {@link Table}
*
* @see {@link Table.createIndex} for detailed usage.
*/
export class IndexBuilder {
private inner: NativeBuilder;
constructor(tbl: NativeTable) {
this.inner = tbl.createIndex();
}
/** Instruct the builder to build an `IVF_PQ` index */
ivf_pq(options?: IvfPQOptions): IndexBuilder {
this.inner.ivfPq(
options?.metric_type,
options?.num_partitions,
options?.num_sub_vectors,
options?.num_bits,
options?.max_iterations,
options?.sample_rate,
);
return this;
}
/** Instruct the builder to build a Scalar index. */
scalar(): IndexBuilder {
this.scalar();
return this;
}
/** Set the column(s) to create index on top of. */
column(col: string): IndexBuilder {
this.inner.column(col);
return this;
}
/** Set to true to replace existing index. */
replace(val: boolean): IndexBuilder {
this.inner.replace(val);
return this;
}
/** Specify the name of the index. Optional */
name(n: string): IndexBuilder {
this.inner.name(n);
return this;
}
/** Building the index. */
async build() {
await this.inner.build();
}
}

195
nodejs/lancedb/indices.ts Normal file
View File

@@ -0,0 +1,195 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { Index as LanceDbIndex } from "./native";
/**
* Options to create an `IVF_PQ` index
*/
export interface IvfPqOptions {
/** The number of IVF partitions to create.
*
* This value should generally scale with the number of rows in the dataset.
* By default the number of partitions is the square root of the number of
* rows.
*
* If this value is too large then the first part of the search (picking the
* right partition) will be slow. If this value is too small then the second
* part of the search (searching within a partition) will be slow.
*/
numPartitions?: number;
/** Number of sub-vectors of PQ.
*
* This value controls how much the vector is compressed during the quantization step.
* The more sub vectors there are the less the vector is compressed. The default is
* the dimension of the vector divided by 16. If the dimension is not evenly divisible
* by 16 we use the dimension divded by 8.
*
* The above two cases are highly preferred. Having 8 or 16 values per subvector allows
* us to use efficient SIMD instructions.
*
* If the dimension is not visible by 8 then we use 1 subvector. This is not ideal and
* will likely result in poor performance.
*/
numSubVectors?: number;
/** [DistanceType] to use to build the index.
*
* Default value is [DistanceType::L2].
*
* This is used when training the index to calculate the IVF partitions
* (vectors are grouped in partitions with similar vectors according to this
* distance type) and to calculate a subvector's code during quantization.
*
* The distance type used to train an index MUST match the distance type used
* to search the index. Failure to do so will yield inaccurate results.
*
* The following distance types are available:
*
* "l2" - Euclidean distance. This is a very common distance metric that
* accounts for both magnitude and direction when determining the distance
* between vectors. L2 distance has a range of [0, ∞).
*
* "cosine" - Cosine distance. Cosine distance is a distance metric
* calculated from the cosine similarity between two vectors. Cosine
* similarity is a measure of similarity between two non-zero vectors of an
* inner product space. It is defined to equal the cosine of the angle
* between them. Unlike L2, the cosine distance is not affected by the
* magnitude of the vectors. Cosine distance has a range of [0, 2].
*
* Note: the cosine distance is undefined when one (or both) of the vectors
* are all zeros (there is no direction). These vectors are invalid and may
* never be returned from a vector search.
*
* "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
* distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
* L2 norm is 1), then dot distance is equivalent to the cosine distance.
*/
distanceType?: "l2" | "cosine" | "dot";
/** Max iteration to train IVF kmeans.
*
* When training an IVF PQ index we use kmeans to calculate the partitions. This parameter
* controls how many iterations of kmeans to run.
*
* Increasing this might improve the quality of the index but in most cases these extra
* iterations have diminishing returns.
*
* The default value is 50.
*/
maxIterations?: number;
/** The number of vectors, per partition, to sample when training IVF kmeans.
*
* When an IVF PQ index is trained, we need to calculate partitions. These are groups
* of vectors that are similar to each other. To do this we use an algorithm called kmeans.
*
* Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
* random sample of the data. This parameter controls the size of the sample. The total
* number of vectors used to train the index is `sample_rate * num_partitions`.
*
* Increasing this value might improve the quality of the index but in most cases the
* default should be sufficient.
*
* The default value is 256.
*/
sampleRate?: number;
}
export class Index {
private readonly inner: LanceDbIndex;
private constructor(inner: LanceDbIndex) {
this.inner = inner;
}
/**
* Create an IvfPq index
*
* This index stores a compressed (quantized) copy of every vector. These vectors
* are grouped into partitions of similar vectors. Each partition keeps track of
* a centroid which is the average value of all vectors in the group.
*
* During a query the centroids are compared with the query vector to find the closest
* partitions. The compressed vectors in these partitions are then searched to find
* the closest vectors.
*
* The compression scheme is called product quantization. Each vector is divided into
* subvectors and then each subvector is quantized into a small number of bits. the
* parameters `num_bits` and `num_subvectors` control this process, providing a tradeoff
* between index size (and thus search speed) and index accuracy.
*
* The partitioning process is called IVF and the `num_partitions` parameter controls how
* many groups to create.
*
* Note that training an IVF PQ index on a large dataset is a slow operation and
* currently is also a memory intensive operation.
*/
static ivfPq(options?: Partial<IvfPqOptions>) {
return new Index(
LanceDbIndex.ivfPq(
options?.distanceType,
options?.numPartitions,
options?.numSubVectors,
options?.maxIterations,
options?.sampleRate,
),
);
}
/** Create a btree index
*
* A btree index is an index on a scalar columns. The index stores a copy of the column
* in sorted order. A header entry is created for each block of rows (currently the
* block size is fixed at 4096). These header entries are stored in a separate
* cacheable structure (a btree). To search for data the header is used to determine
* which blocks need to be read from disk.
*
* For example, a btree index in a table with 1Bi rows requires sizeof(Scalar) * 256Ki
* bytes of memory and will generally need to read sizeof(Scalar) * 4096 bytes to find
* the correct row ids.
*
* This index is good for scalar columns with mostly distinct values and does best when
* the query is highly selective.
*
* The btree index does not currently have any parameters though parameters such as the
* block size may be added in the future.
*/
static btree() {
return new Index(LanceDbIndex.btree());
}
}
export interface IndexOptions {
/** Advanced index configuration
*
* This option allows you to specify a specfic index to create and also
* allows you to pass in configuration for training the index.
*
* See the static methods on Index for details on the various index types.
*
* If this is not supplied then column data type(s) and column statistics
* will be used to determine the most useful kind of index to create.
*/
config?: Index;
/** Whether to replace the existing index
*
* If this is false, and another index already exists on the same columns
* and the same name, then an error will be returned. This is true even if
* that index is out of date.
*
* The default is true
*/
replace?: boolean;
}

View File

@@ -3,14 +3,17 @@
/* auto-generated by NAPI-RS */
export const enum IndexType {
Scalar = 0,
IvfPq = 1
}
export const enum MetricType {
L2 = 0,
Cosine = 1,
Dot = 2
/** A description of an index currently configured on a column */
export interface IndexConfig {
/** The type of the index */
indexType: string
/**
* The columns in the index
*
* Currently this is always an array of size 1. In the future there may
* be more columns to represent composite indices.
*/
columns: Array<string>
}
/**
* A definition of a column alteration. The alteration changes the column at
@@ -93,13 +96,9 @@ export class Connection {
/** Drop table with the name. Or raise an error if the table does not exist. */
dropTable(name: string): Promise<void>
}
export class IndexBuilder {
replace(v: boolean): void
column(c: string): void
name(name: string): void
ivfPq(metricType?: MetricType | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, numBits?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): void
scalar(): void
build(): Promise<void>
export class Index {
static ivfPq(distanceType?: string | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): Index
static btree(): Index
}
/** Typescript-style Async Iterator over RecordBatches */
export class RecordBatchIterator {
@@ -125,9 +124,15 @@ export class Table {
add(buf: Buffer, mode: string): Promise<void>
countRows(filter?: string | undefined | null): Promise<number>
delete(predicate: string): Promise<void>
createIndex(): IndexBuilder
createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void>
update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
query(): Query
addColumns(transforms: Array<AddColumnsSql>): Promise<void>
alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
dropColumns(columns: Array<string>): Promise<void>
version(): Promise<number>
checkout(version: number): Promise<void>
checkoutLatest(): Promise<void>
restore(): Promise<void>
listIndices(): Promise<Array<IndexConfig>>
}

View File

@@ -295,12 +295,10 @@ if (!nativeBinding) {
throw new Error(`Failed to load native binding`)
}
const { Connection, IndexType, MetricType, IndexBuilder, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding
const { Connection, Index, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding
module.exports.Connection = Connection
module.exports.IndexType = IndexType
module.exports.MetricType = MetricType
module.exports.IndexBuilder = IndexBuilder
module.exports.Index = Index
module.exports.RecordBatchIterator = RecordBatchIterator
module.exports.Query = Query
module.exports.Table = Table

View File

@@ -16,12 +16,14 @@ import { Schema, tableFromIPC } from "apache-arrow";
import {
AddColumnsSql,
ColumnAlteration,
IndexConfig,
Table as _NativeTable,
} from "./native";
import { Query } from "./query";
import { IndexBuilder } from "./indexer";
import { IndexOptions } from "./indices";
import { Data, fromDataToBuffer } from "./arrow";
export { IndexConfig } from "./native";
/**
* Options for adding data to a table.
*/
@@ -33,6 +35,20 @@ export interface AddDataOptions {
mode: "append" | "overwrite";
}
export interface UpdateOptions {
/**
* A filter that limits the scope of the update.
*
* This should be an SQL filter expression.
*
* Only rows that satisfy the expression will be updated.
*
* For example, this could be 'my_col == 0' to replace all instances
* of 0 in a column with some other default value.
*/
where: string;
}
/**
* A Table is a collection of Records in a LanceDB Database.
*
@@ -93,6 +109,45 @@ export class Table {
await this.inner.add(buffer, mode);
}
/**
* Update existing records in the Table
*
* An update operation can be used to adjust existing values. Use the
* returned builder to specify which columns to update. The new value
* can be a literal value (e.g. replacing nulls with some default value)
* or an expression applied to the old value (e.g. incrementing a value)
*
* An optional condition can be specified (e.g. "only update if the old
* value is 0")
*
* Note: if your condition is something like "some_id_column == 7" and
* you are updating many rows (with different ids) then you will get
* better performance with a single [`merge_insert`] call instead of
* repeatedly calilng this method.
*
* @param updates the columns to update
*
* Keys in the map should specify the name of the column to update.
* Values in the map provide the new value of the column. These can
* be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions
* based on the row being updated (e.g. "my_col + 1")
*
* @param options additional options to control the update behavior
*/
async update(
updates: Map<string, string> | Record<string, string>,
options?: Partial<UpdateOptions>,
) {
const onlyIf = options?.where;
let columns: [string, string][];
if (updates instanceof Map) {
columns = Array.from(updates.entries());
} else {
columns = Object.entries(updates);
}
await this.inner.update(onlyIf, columns);
}
/** Count the total number of rows in the dataset. */
async countRows(filter?: string): Promise<number> {
return await this.inner.countRows(filter);
@@ -103,24 +158,28 @@ export class Table {
await this.inner.delete(predicate);
}
/** Create an index over the columns.
/** Create an index to speed up queries.
*
* @param {string} column The column to create the index on. If not specified,
* it will create an index on vector field.
* Indices can be created on vector columns or scalar columns.
* Indices on vector columns will speed up vector searches.
* Indices on scalar columns will speed up filtering (in both
* vector and non-vector searches)
*
* @example
*
* By default, it creates vector idnex on one vector column.
* If the column has a vector (fixed size list) data type then
* an IvfPq vector index will be created.
*
* ```typescript
* const table = await conn.openTable("my_table");
* await table.createIndex().build();
* await table.createIndex(["vector"]);
* ```
*
* You can specify `IVF_PQ` parameters via `ivf_pq({})` call.
* For advanced control over vector index creation you can specify
* the index type and options.
* ```typescript
* const table = await conn.openTable("my_table");
* await table.createIndex("my_vec_col")
* await table.createIndex(["vector"], I)
* .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 })
* .build();
* ```
@@ -131,12 +190,11 @@ export class Table {
* await table.createIndex("my_float_col").build();
* ```
*/
createIndex(column?: string): IndexBuilder {
let builder = new IndexBuilder(this.inner);
if (column !== undefined) {
builder = builder.column(column);
}
return builder;
async createIndex(column: string, options?: Partial<IndexOptions>) {
// Bit of a hack to get around the fact that TS has no package-scope.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const nativeIndex = (options?.config as any)?.inner;
await this.inner.createIndex(nativeIndex, column, options?.replace);
}
/**
@@ -232,4 +290,65 @@ export class Table {
async dropColumns(columnNames: string[]): Promise<void> {
await this.inner.dropColumns(columnNames);
}
/** Retrieve the version of the table
*
* LanceDb supports versioning. Every operation that modifies the table increases
* version. As long as a version hasn't been deleted you can `[Self::checkout]` that
* version to view the data at that point. In addition, you can `[Self::restore]` the
* version to replace the current table with a previous version.
*/
async version(): Promise<number> {
return await this.inner.version();
}
/** Checks out a specific version of the Table
*
* Any read operation on the table will now access the data at the checked out version.
* As a consequence, calling this method will disable any read consistency interval
* that was previously set.
*
* This is a read-only operation that turns the table into a sort of "view"
* or "detached head". Other table instances will not be affected. To make the change
* permanent you can use the `[Self::restore]` method.
*
* Any operation that modifies the table will fail while the table is in a checked
* out state.
*
* To return the table to a normal state use `[Self::checkout_latest]`
*/
async checkout(version: number): Promise<void> {
await this.inner.checkout(version);
}
/** Ensures the table is pointing at the latest version
*
* This can be used to manually update a table when the read_consistency_interval is None
* It can also be used to undo a `[Self::checkout]` operation
*/
async checkoutLatest(): Promise<void> {
await this.inner.checkoutLatest();
}
/** Restore the table to the currently checked out version
*
* This operation will fail if checkout has not been called previously
*
* This operation will overwrite the latest version of the table with a
* previous version. Any changes made since the checked out version will
* no longer be visible.
*
* Once the operation concludes the table will no longer be in a checked
* out state and the read_consistency_interval, if any, will apply.
*/
async restore(): Promise<void> {
await this.inner.restore();
}
/**
* List all indices that have been created with Self::create_index
*/
async listIndices(): Promise<IndexConfig[]> {
return await this.inner.listIndices();
}
}

12
nodejs/src/error.rs Normal file
View File

@@ -0,0 +1,12 @@
pub type Result<T> = napi::Result<T>;
pub trait NapiErrorExt<T> {
/// Convert to a napi error using from_reason(err.to_string())
fn default_error(self) -> Result<T>;
}
impl<T> NapiErrorExt<T> for std::result::Result<T, lancedb::Error> {
fn default_error(self) -> Result<T> {
self.map_err(|err| napi::Error::from_reason(err.to_string()))
}
}

View File

@@ -14,126 +14,73 @@
use std::sync::Mutex;
use lance_linalg::distance::MetricType as LanceMetricType;
use lancedb::index::IndexBuilder as LanceDbIndexBuilder;
use lancedb::Table as LanceDbTable;
use lancedb::index::scalar::BTreeIndexBuilder;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::Index as LanceDbIndex;
use lancedb::DistanceType;
use napi_derive::napi;
#[napi]
pub enum IndexType {
Scalar,
IvfPq,
pub struct Index {
inner: Mutex<Option<LanceDbIndex>>,
}
#[napi]
pub enum MetricType {
L2,
Cosine,
Dot,
}
impl From<MetricType> for LanceMetricType {
fn from(metric: MetricType) -> Self {
match metric {
MetricType::L2 => Self::L2,
MetricType::Cosine => Self::Cosine,
MetricType::Dot => Self::Dot,
}
impl Index {
pub fn consume(&self) -> napi::Result<LanceDbIndex> {
self.inner
.lock()
.unwrap()
.take()
.ok_or(napi::Error::from_reason(
"attempt to use an index more than once",
))
}
}
#[napi]
pub struct IndexBuilder {
inner: Mutex<Option<LanceDbIndexBuilder>>,
}
impl IndexBuilder {
fn modify(
&self,
mod_fn: impl Fn(LanceDbIndexBuilder) -> LanceDbIndexBuilder,
) -> napi::Result<()> {
let mut inner = self.inner.lock().unwrap();
let inner_builder = inner.take().ok_or_else(|| {
napi::Error::from_reason("IndexBuilder has already been consumed".to_string())
})?;
let inner_builder = mod_fn(inner_builder);
inner.replace(inner_builder);
Ok(())
}
}
#[napi]
impl IndexBuilder {
pub fn new(tbl: &LanceDbTable) -> Self {
let inner = tbl.create_index(&[]);
Self {
inner: Mutex::new(Some(inner)),
}
}
#[napi]
pub fn replace(&self, v: bool) -> napi::Result<()> {
self.modify(|b| b.replace(v))
}
#[napi]
pub fn column(&self, c: String) -> napi::Result<()> {
self.modify(|b| b.columns(&[c.as_str()]))
}
#[napi]
pub fn name(&self, name: String) -> napi::Result<()> {
self.modify(|b| b.name(name.as_str()))
}
#[napi]
impl Index {
#[napi(factory)]
pub fn ivf_pq(
&self,
metric_type: Option<MetricType>,
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
) -> napi::Result<()> {
self.modify(|b| {
let mut b = b.ivf_pq();
if let Some(metric_type) = metric_type {
b = b.metric_type(metric_type.into());
) -> napi::Result<Self> {
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = match distance_type.as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(napi::Error::from_reason(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type
))),
}?;
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
b = b.num_partitions(num_partitions);
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
b = b.num_sub_vectors(num_sub_vectors);
}
if let Some(num_bits) = num_bits {
b = b.num_bits(num_bits);
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
b = b.max_iterations(max_iterations);
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
b = b.sample_rate(sample_rate);
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
}
b
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
})
}
#[napi]
pub fn scalar(&self) -> napi::Result<()> {
self.modify(|b| b.scalar())
#[napi(factory)]
pub fn btree() -> Self {
Self {
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
}
#[napi]
pub async fn build(&self) -> napi::Result<()> {
let inner = self.inner.lock().unwrap().take().ok_or_else(|| {
napi::Error::from_reason("IndexBuilder has already been consumed".to_string())
})?;
inner
.build()
.await
.map_err(|e| napi::Error::from_reason(format!("Failed to build index: {}", e)))?;
Ok(())
}
}

View File

@@ -13,7 +13,7 @@
// limitations under the License.
use futures::StreamExt;
use lance::io::RecordBatchStream;
use lancedb::arrow::SendableRecordBatchStream;
use lancedb::ipc::batches_to_ipc_file;
use napi::bindgen_prelude::*;
use napi_derive::napi;
@@ -21,12 +21,12 @@ use napi_derive::napi;
/** Typescript-style Async Iterator over RecordBatches */
#[napi]
pub struct RecordBatchIterator {
inner: Box<dyn RecordBatchStream + Unpin>,
inner: SendableRecordBatchStream,
}
#[napi]
impl RecordBatchIterator {
pub(crate) fn new(inner: Box<dyn RecordBatchStream + Unpin>) -> Self {
pub(crate) fn new(inner: SendableRecordBatchStream) -> Self {
Self { inner }
}

View File

@@ -16,6 +16,7 @@ use connection::Connection;
use napi_derive::*;
mod connection;
mod error;
mod index;
mod iterator;
mod query;

View File

@@ -74,6 +74,6 @@ impl Query {
let inner_stream = self.inner.execute_stream().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(Box::new(inner_stream)))
Ok(RecordBatchIterator::new(inner_stream))
}
}

View File

@@ -13,13 +13,16 @@
// limitations under the License.
use arrow_ipc::writer::FileWriter;
use lance::dataset::ColumnAlteration as LanceColumnAlteration;
use lancedb::ipc::ipc_file_to_batches;
use lancedb::table::{AddDataMode, Table as LanceDbTable};
use lancedb::table::{
AddDataMode, ColumnAlteration as LanceColumnAlteration, NewColumnTransform,
Table as LanceDbTable,
};
use napi::bindgen_prelude::*;
use napi_derive::napi;
use crate::index::IndexBuilder;
use crate::error::NapiErrorExt;
use crate::index::Index;
use crate::query::Query;
#[napi]
@@ -129,8 +132,38 @@ impl Table {
}
#[napi]
pub fn create_index(&self) -> napi::Result<IndexBuilder> {
Ok(IndexBuilder::new(self.inner_ref()?))
pub async fn create_index(
&self,
index: Option<&Index>,
column: String,
replace: Option<bool>,
) -> napi::Result<()> {
let lancedb_index = if let Some(index) = index {
index.consume()?
} else {
lancedb::index::Index::Auto
};
let mut builder = self.inner_ref()?.create_index(&[column], lancedb_index);
if let Some(replace) = replace {
builder = builder.replace(replace);
}
builder.execute().await.default_error()
}
#[napi]
pub async fn update(
&self,
only_if: Option<String>,
columns: Vec<(String, String)>,
) -> napi::Result<()> {
let mut op = self.inner_ref()?.update();
if let Some(only_if) = only_if {
op = op.only_if(only_if);
}
for (column_name, value) in columns {
op = op.column(column_name, value);
}
op.execute().await.default_error()
}
#[napi]
@@ -144,7 +177,7 @@ impl Table {
.into_iter()
.map(|sql| (sql.name, sql.value_sql))
.collect::<Vec<_>>();
let transforms = lance::dataset::NewColumnTransform::SqlExpressions(transforms);
let transforms = NewColumnTransform::SqlExpressions(transforms);
self.inner_ref()?
.add_columns(transforms, None)
.await
@@ -197,6 +230,67 @@ impl Table {
})?;
Ok(())
}
#[napi]
pub async fn version(&self) -> napi::Result<i64> {
self.inner_ref()?
.version()
.await
.map(|val| val as i64)
.default_error()
}
#[napi]
pub async fn checkout(&self, version: i64) -> napi::Result<()> {
self.inner_ref()?
.checkout(version as u64)
.await
.default_error()
}
#[napi]
pub async fn checkout_latest(&self) -> napi::Result<()> {
self.inner_ref()?.checkout_latest().await.default_error()
}
#[napi]
pub async fn restore(&self) -> napi::Result<()> {
self.inner_ref()?.restore().await.default_error()
}
#[napi]
pub async fn list_indices(&self) -> napi::Result<Vec<IndexConfig>> {
Ok(self
.inner_ref()?
.list_indices()
.await
.default_error()?
.into_iter()
.map(IndexConfig::from)
.collect::<Vec<_>>())
}
}
#[napi(object)]
/// A description of an index currently configured on a column
pub struct IndexConfig {
/// The type of the index
pub index_type: String,
/// The columns in the index
///
/// Currently this is always an array of size 1. In the future there may
/// be more columns to represent composite indices.
pub columns: Vec<String>,
}
impl From<lancedb::index::IndexConfig> for IndexConfig {
fn from(value: lancedb::index::IndexConfig) -> Self {
let index_type = format!("{:?}", value.index_type);
Self {
index_type,
columns: value.columns,
}
}
}
/// A definition of a column alteration. The alteration changes the column at

View File

@@ -57,6 +57,7 @@ tests = [
"duckdb",
"pytz",
"polars>=0.19",
"pillow",
]
dev = ["ruff", "pre-commit"]
docs = [

View File

@@ -23,8 +23,9 @@ from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector # noqa: F401
from .utils import sentry_log # noqa: F401
from .schema import vector
from .table import AsyncTable
from .utils import sentry_log
def connect(
@@ -188,3 +189,19 @@ async def connect_async(
read_consistency_interval_secs,
)
)
__all__ = [
"connect",
"connect_async",
"AsyncConnection",
"AsyncTable",
"URI",
"sanitize_uri",
"sentry_log",
"vector",
"DBConnection",
"LanceDBConnection",
"RemoteDBConnection",
"__version__",
]

View File

@@ -1,7 +1,19 @@
from typing import Optional
from typing import Dict, List, Optional
import pyarrow as pa
class Index:
@staticmethod
def ivf_pq(
distance_type: Optional[str],
num_partitions: Optional[int],
num_sub_vectors: Optional[int],
max_iterations: Optional[int],
sample_rate: Optional[int],
) -> Index: ...
@staticmethod
def btree() -> Index: ...
class Connection(object):
async def table_names(
self, start_after: Optional[str], limit: Optional[int]
@@ -13,10 +25,25 @@ class Connection(object):
self, name: str, mode: str, schema: pa.Schema
) -> Table: ...
class Table(object):
class Table:
def name(self) -> str: ...
def __repr__(self) -> str: ...
async def schema(self) -> pa.Schema: ...
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
async def count_rows(self, filter: Optional[str]) -> int: ...
async def create_index(
self, column: str, config: Optional[Index], replace: Optional[bool]
): ...
async def version(self) -> int: ...
async def checkout(self, version): ...
async def checkout_latest(self): ...
async def restore(self): ...
async def list_indices(self) -> List[IndexConfig]: ...
class IndexConfig:
index_type: str
columns: List[str]
async def connect(
uri: str,

View File

@@ -529,7 +529,7 @@ class AsyncConnection(object):
on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
) -> AsyncTable:
"""Create a [Table][lancedb.table.Table] in the database.
Parameters

View File

@@ -126,6 +126,10 @@ class OpenClipEmbeddings(EmbeddingFunction):
"""
Issue concurrent requests to retrieve the image data
"""
return [
self.generate_image_embedding(image) for image in tqdm(images)
]
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.generate_image_embedding, image)

View File

@@ -0,0 +1,163 @@
from typing import Optional
from ._lancedb import (
Index as LanceDbIndex,
)
from ._lancedb import (
IndexConfig,
)
class BTree(object):
"""Describes a btree index configuration
A btree index is an index on scalar columns. The index stores a copy of the
column in sorted order. A header entry is created for each block of rows
(currently the block size is fixed at 4096). These header entries are stored
in a separate cacheable structure (a btree). To search for data the header is
used to determine which blocks need to be read from disk.
For example, a btree index in a table with 1Bi rows requires
sizeof(Scalar) * 256Ki bytes of memory and will generally need to read
sizeof(Scalar) * 4096 bytes to find the correct row ids.
This index is good for scalar columns with mostly distinct values and does best
when the query is highly selective.
The btree index does not currently have any parameters though parameters such as
the block size may be added in the future.
"""
def __init__(self):
self._inner = LanceDbIndex.btree()
class IvfPq(object):
"""Describes an IVF PQ Index
This index stores a compressed (quantized) copy of every vector. These vectors
are grouped into partitions of similar vectors. Each partition keeps track of
a centroid which is the average value of all vectors in the group.
During a query the centroids are compared with the query vector to find the
closest partitions. The compressed vectors in these partitions are then
searched to find the closest vectors.
The compression scheme is called product quantization. Each vector is divide
into subvectors and then each subvector is quantized into a small number of
bits. the parameters `num_bits` and `num_subvectors` control this process,
providing a tradeoff between index size (and thus search speed) and index
accuracy.
The partitioning process is called IVF and the `num_partitions` parameter
controls how many groups to create.
Note that training an IVF PQ index on a large dataset is a slow operation and
currently is also a memory intensive operation.
"""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
):
"""
Create an IVF PQ index config
Parameters
----------
distance_type: str, default "L2"
The distance metric used to train the index
This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and to calculate a subvector's code during quantization.
The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.
The following distance types are available:
"l2" - Euclidean distance. This is a very common distance metric that
accounts for both magnitude and direction when determining the distance
between vectors. L2 distance has a range of [0, ∞).
"cosine" - Cosine distance. Cosine distance is a distance metric
calculated from the cosine similarity between two vectors. Cosine
similarity is a measure of similarity between two non-zero vectors of an
inner product space. It is defined to equal the cosine of the angle
between them. Unlike L2, the cosine distance is not affected by the
magnitude of the vectors. Cosine distance has a range of [0, 2].
Note: the cosine distance is undefined when one (or both) of the vectors
are all zeros (there is no direction). These vectors are invalid and may
never be returned from a vector search.
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
L2 norm is 1), then dot distance is equivalent to the cosine distance.
num_partitions: int, default sqrt(num_rows)
The number of IVF partitions to create.
This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.
If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.
num_sub_vectors: int, default is vector dimension / 16
Number of sub-vectors of PQ.
This value controls how much the vector is compressed during the
quantization step. The more sub vectors there are the less the vector is
compressed. The default is the dimension of the vector divided by 16. If
the dimension is not evenly divisible by 16 we use the dimension divded by
8.
The above two cases are highly preferred. Having 8 or 16 values per
subvector allows us to use efficient SIMD instructions.
If the dimension is not visible by 8 then we use 1 subvector. This is not
ideal and will likely result in poor performance.
max_iterations: int, default 50
Max iteration to train kmeans.
When training an IVF PQ index we use kmeans to calculate the partitions.
This parameter controls how many iterations of kmeans to run.
Increasing this might improve the quality of the index but in most cases
these extra iterations have diminishing returns.
The default value is 50.
sample_rate: int, default 256
The rate used to calculate the number of training vectors for kmeans.
When an IVF PQ index is trained, we need to calculate partitions. These
are groups of vectors that are similar to each other. To do this we use an
algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we run
kmeans on a random sample of the data. This parameter controls the size of
the sample. The total number of vectors used to train the index is
`sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in most
cases the default should be sufficient.
The default value is 256.
"""
self._inner = LanceDbIndex.ivf_pq(
distance_type=distance_type,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
max_iterations=max_iterations,
sample_rate=sample_rate,
)
__all__ = ["BTree", "IvfPq", "IndexConfig"]

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import inspect
import io
import sys
import types
from abc import ABC, abstractmethod
@@ -26,7 +27,9 @@ from typing import (
Callable,
Dict,
Generator,
Iterable,
List,
Tuple,
Type,
Union,
_GenericAlias,
@@ -36,19 +39,30 @@ import numpy as np
import pyarrow as pa
import pydantic
import semver
from lance.arrow import (
EncodedImageType,
)
from lance.util import _check_huggingface
from pydantic.fields import FieldInfo
from pydantic_core import core_schema
from .util import attempt_import_or_raise
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic_core import CoreSchema, core_schema
except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
if TYPE_CHECKING:
from pydantic.fields import FieldInfo
from .embeddings import EmbeddingFunctionConfig
try:
from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
class FixedSizeListMixin(ABC):
@staticmethod
@@ -123,7 +137,7 @@ def Vector(
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
) -> "CoreSchema":
return core_schema.no_info_after_validator_function(
cls,
core_schema.list_schema(
@@ -181,24 +195,117 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
elif getattr(py_type, "__origin__", None) in (list, tuple):
child = py_type.__args__[0]
return pa.list_(_py_type_to_arrow_type(child, field))
elif _safe_is_huggingface_image():
import datasets
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
)
if PYDANTIC_VERSION.major < 2:
class ImageMixin(ABC):
@staticmethod
@abstractmethod
def value_arrow_type() -> pa.DataType:
raise NotImplementedError
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
def EncodedImage():
attempt_import_or_raise("PIL", "pillow or pip install lancedb[embeddings]")
import PIL.Image
class EncodedImage(bytes, ImageMixin):
"""Pydantic type for inlined images.
!!! warning
Experimental feature.
Examples
--------
>>> import pydantic
>>> from lancedb.pydantic import EncodedImage
...
>>> class MyModel(pydantic.BaseModel):
... image: EncodedImage()
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("image", pa.binary(), False)
... ])
"""
def __repr__(self):
return "EncodedImage()"
@staticmethod
def value_arrow_type() -> pa.DataType:
return EncodedImageType()
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> "CoreSchema":
from_bytes_schema = core_schema.bytes_schema()
return core_schema.json_or_python_schema(
json_schema=from_bytes_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(PIL.Image.Image),
from_bytes_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: cls.validate(instance)
),
)
@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: "CoreSchema", handler: "GetJsonSchemaHandler"
) -> "JsonSchemaValue":
return handler(core_schema.bytes_schema())
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v2
@classmethod
def validate(cls, v):
if isinstance(v, bytes):
return v
if isinstance(v, PIL.Image.Image):
with io.BytesIO() as output:
v.save(output, format=v.format)
return output.getvalue()
raise TypeError(
"EncodedImage can take bytes or PIL.Image.Image "
f"as input but got {type(v)}"
)
if PYDANTIC_VERSION < (2, 0):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["type"] = "string"
field_schema["format"] = "binary"
return EncodedImage
if PYDANTIC_VERSION.major < 2:
def _safe_get_fields(model: pydantic.BaseModel):
return model.__fields__
else:
def _safe_get_fields(model: pydantic.BaseModel):
return model.model_fields
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field)
for name, field in model.model_fields.items()
for name, field in _safe_get_fields(model).items()
]
@@ -230,6 +337,9 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
return pa.struct(fields)
elif issubclass(field.annotation, FixedSizeListMixin):
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
elif issubclass(field.annotation, ImageMixin):
return field.annotation.value_arrow_type()
return _py_type_to_arrow_type(field.annotation, field)
@@ -335,13 +445,7 @@ class LanceModel(pydantic.BaseModel):
"""
Get the field names of this model.
"""
return list(cls.safe_get_fields().keys())
@classmethod
def safe_get_fields(cls):
if PYDANTIC_VERSION.major < 2:
return cls.__fields__
return cls.model_fields
return list(_safe_get_fields(cls).keys())
@classmethod
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
@@ -351,14 +455,16 @@ class LanceModel(pydantic.BaseModel):
from .embeddings import EmbeddingFunctionConfig
vec_and_function = []
for name, field_info in cls.safe_get_fields().items():
func = get_extras(field_info, "vector_column_for")
def get_vector_column(name, field_info):
fun = get_extras(field_info, "vector_column_for")
if func is not None:
vec_and_function.append([name, func])
visit_fields(_safe_get_fields(cls).items(), get_vector_column)
configs = []
# find the source columns for each one
for vec, func in vec_and_function:
for source, field_info in cls.safe_get_fields().items():
def get_source_column(source, field_info):
src_func = get_extras(field_info, "source_column_for")
if src_func is func:
# note we can't use == here since the function is a pydantic
@@ -371,20 +477,48 @@ class LanceModel(pydantic.BaseModel):
source_column=source, vector_column=vec, function=func
)
)
visit_fields(_safe_get_fields(cls).items(), get_source_column)
return configs
def get_extras(field_info: FieldInfo, key: str) -> Any:
def visit_fields(fields: Iterable[Tuple[str, FieldInfo]],
visitor: Callable[[str, FieldInfo], Any]):
"""
Get the extra metadata from a Pydantic FieldInfo.
Visit all the leaf fields in a Pydantic model.
Parameters
----------
fields : Iterable[Tuple(str, FieldInfo)]
The fields to visit.
visitor : Callable[[str, FieldInfo], Any]
The visitor function.
"""
if PYDANTIC_VERSION.major >= 2:
return (field_info.json_schema_extra or {}).get(key)
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
for name, field_info in fields:
# if the field is a pydantic model then
# visit all subfields
if (isinstance(getattr(field_info, "annotation"), type)
and issubclass(field_info.annotation, pydantic.BaseModel)):
visit_fields(_safe_get_fields(field_info.annotation).items(),
_add_prefix(visitor, name))
else:
visitor(name, field_info)
def _add_prefix(visitor: Callable[[str, FieldInfo], Any], prefix: str) -> Callable[[str, FieldInfo], Any]:
def prefixed_visitor(name: str, field: FieldInfo):
return visitor(f"{prefix}.{name}", field)
return prefixed_visitor
if PYDANTIC_VERSION.major < 2:
def get_extras(field_info: FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
@@ -393,6 +527,13 @@ if PYDANTIC_VERSION.major < 2:
else:
def get_extras(field_info: FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""
return (field_info.json_schema_extra or {}).get(key)
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.

View File

@@ -37,6 +37,7 @@ import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs as pa_fs
from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -60,6 +61,7 @@ if TYPE_CHECKING:
from ._lancedb import Table as LanceDBTable
from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq
pd = safe_import_pandas()
@@ -73,7 +75,16 @@ def _sanitize_data(
on_bad_vectors: str,
fill_value: Any,
):
if isinstance(data, list):
import pdb; pdb.set_trace()
if _check_for_hugging_face(data):
# Huggingface datasets
import datasets
if isinstance(data, datasets.Dataset):
if schema is None:
schema = data.features.arrow_schema
data = data.data.to_batches()
elif isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema()
@@ -135,12 +146,11 @@ def _to_record_batch_generator(
data: Iterable, schema, metadata, on_bad_vectors, fill_value
):
for batch in data:
if not isinstance(batch, pa.RecordBatch):
if isinstance(batch, pa.RecordBatch):
batch = pa.Table.from_batches([batch])
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for batch in table.to_batches():
yield batch
else:
yield batch
class Table(ABC):
@@ -1917,112 +1927,48 @@ class AsyncTable:
raise NotImplementedError
async def create_index(
self,
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None,
):
"""Create an index on the table.
Parameters
----------
metric: str, default "L2"
The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
L2 is euclidean distance.
num_partitions: int, default 256
The number of IVF partitions to use when creating the index.
Default is 256.
num_sub_vectors: int, default 96
The number of PQ sub-vectors to use when creating the index.
Default is 96.
vector_column_name: str, default "vector"
The vector column name to create the index.
replace: bool, default True
- If True, replace the existing index if it exists.
- If False, raise an error if duplicate index exists.
accelerator: str, default None
If set, use the given accelerator to create the index.
Only support "cuda" for now.
index_cache_size : int, optional
The size of the index cache in number of entries. Default value is 256.
"""
raise NotImplementedError
async def create_scalar_index(
self,
column: str,
*,
replace: bool = True,
replace: Optional[bool] = None,
config: Optional[Union[IvfPq, BTree]] = None,
):
"""Create a scalar index on a column.
"""Create an index to speed up queries
Scalar indices, like vector indices, can be used to speed up scans. A scalar
index can speed up scans that contain filter expressions on the indexed column.
For example, the following scan will be faster if the column ``my_col`` has
a scalar index:
.. code-block:: python
import lancedb
db = lancedb.connect("/data/lance")
img_table = db.open_table("images")
my_df = img_table.search().where("my_col = 7", prefilter=True).to_pandas()
Scalar indices can also speed up scans containing a vector search and a
prefilter:
.. code-block::python
import lancedb
db = lancedb.connect("/data/lance")
img_table = db.open_table("images")
img_table.search([1, 2, 3, 4], vector_column_name="vector")
.where("my_col != 7", prefilter=True)
.to_pandas()
Scalar indices can only speed up scans for basic filters using
equality, comparison, range (e.g. ``my_col BETWEEN 0 AND 100``), and set
membership (e.g. `my_col IN (0, 1, 2)`)
Scalar indices can be used if the filter contains multiple indexed columns and
the filter criteria are AND'd or OR'd together
(e.g. ``my_col < 0 AND other_col> 100``)
Scalar indices may be used if the filter contains non-indexed columns but,
depending on the structure of the filter, they may not be usable. For example,
if the column ``not_indexed`` does not have a scalar index then the filter
``my_col = 0 OR not_indexed = 1`` will not be able to use any scalar index on
``my_col``.
**Experimental API**
Indices can be created on vector columns or scalar columns.
Indices on vector columns will speed up vector searches.
Indices on scalar columns will speed up filtering (in both
vector and non-vector searches)
Parameters
----------
column : str
The column to be indexed. Must be a boolean, integer, float,
or string column.
replace : bool, default True
Replace the existing index if it exists.
index: Index
The index to create.
Examples
--------
LanceDb supports multiple types of indices. See the static methods on
the Index class for more details.
column: str, default None
The column to index.
.. code-block:: python
When building a scalar index this must be set.
import lance
When building a vector index, this is optional. The default will look
for any columns of type fixed-size-list with floating point values. If
there is only one column of this type then it will be used. Otherwise
an error will be returned.
replace: bool, default True
Whether to replace the existing index
dataset = lance.dataset("./images.lance")
dataset.create_scalar_index("category")
If this is false, and another index already exists on the same columns
and the same name, then an error will be returned. This is true even if
that index is out of date.
The default is True
"""
raise NotImplementedError
index = None
if config is not None:
index = config._inner
await self._inner.create_index(column, index=index, replace=replace)
async def add(
self,
@@ -2066,6 +2012,8 @@ class AsyncTable:
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
if isinstance(data, pa.Table):
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
await self._inner.add(data, mode)
register_event("add")
@@ -2275,58 +2223,57 @@ class AsyncTable:
async def update(
self,
where: Optional[str] = None,
values: Optional[dict] = None,
updates: Optional[Dict[str, Any]] = None,
*,
values_sql: Optional[Dict[str, str]] = None,
where: Optional[str] = None,
updates_sql: Optional[Dict[str, str]] = None,
):
"""
This can be used to update zero to all rows depending on how many
rows match the where clause. If no where clause is provided, then
all rows will be updated.
This can be used to update zero to all rows in the table.
Either `values` or `values_sql` must be provided. You cannot provide
both.
If a filter is provided with `where` then only rows matching the
filter will be updated. Otherwise all rows will be updated.
Parameters
----------
updates: dict, optional
The updates to apply. The keys should be the name of the column to
update. The values should be the new values to assign. This is
required unless updates_sql is supplied.
where: str, optional
The SQL where clause to use when updating rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
values: dict, optional
The values to update. The keys are the column names and the values
are the values to set.
values_sql: dict, optional
The values to update, expressed as SQL expression strings. These can
reference existing columns. For example, {"x": "x + 1"} will increment
the x column by 1.
An SQL filter that controls which rows are updated. For example, 'x = 2'
or 'x IN (1, 2, 3)'. Only rows that satisfy this filter will be udpated.
updates_sql: dict, optional
The updates to apply, expressed as SQL expression strings. The keys should
be column names. The values should be SQL expressions. These can be SQL
literals (e.g. "7" or "'foo'") or they can be expressions based on the
previous value of the row (e.g. "x + 1" to increment the x column by 1)
Examples
--------
>>> import asyncio
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
>>> table.update(where="x = 2", values={"vector": [10, 10]})
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 3 [5.0, 6.0]
2 2 [10.0, 10.0]
>>> table.update(values_sql={"x": "x + 1"})
>>> table.to_pandas()
x vector
0 2 [1.0, 2.0]
1 4 [5.0, 6.0]
2 3 [10.0, 10.0]
>>> async def demo_update():
... data = pd.DataFrame({"x": [1, 2], "vector": [[1, 2], [3, 4]]})
... db = await lancedb.connect_async("./.lancedb")
... table = await db.create_table("my_table", data)
... # x is [1, 2], vector is [[1, 2], [3, 4]]
... await table.update({"vector": [10, 10]}, where="x = 2")
... # x is [1, 2], vector is [[1, 2], [10, 10]]
... await table.update(updates_sql={"x": "x + 1"})
... # x is [2, 3], vector is [[1, 2], [10, 10]]
>>> asyncio.run(demo_update())
"""
raise NotImplementedError
if updates is not None and updates_sql is not None:
raise ValueError("Only one of updates or updates_sql can be provided")
if updates is None and updates_sql is None:
raise ValueError("Either updates or updates_sql must be provided")
if updates is not None:
updates_sql = {k: value_to_sql(v) for k, v in updates.items()}
return await self._inner.update(updates_sql, where)
async def cleanup_old_versions(
self,
@@ -2423,3 +2370,65 @@ class AsyncTable:
The names of the columns to drop.
"""
raise NotImplementedError
async def version(self) -> int:
"""
Retrieve the version of the table
LanceDb supports versioning. Every operation that modifies the table increases
version. As long as a version hasn't been deleted you can `[Self::checkout]`
that version to view the data at that point. In addition, you can
`[Self::restore]` the version to replace the current table with a previous
version.
"""
return await self._inner.version()
async def checkout(self, version):
"""
Checks out a specific version of the Table
Any read operation on the table will now access the data at the checked out
version. As a consequence, calling this method will disable any read consistency
interval that was previously set.
This is a read-only operation that turns the table into a sort of "view"
or "detached head". Other table instances will not be affected. To make the
change permanent you can use the `[Self::restore]` method.
Any operation that modifies the table will fail while the table is in a checked
out state.
To return the table to a normal state use `[Self::checkout_latest]`
"""
await self._inner.checkout(version)
async def checkout_latest(self):
"""
Ensures the table is pointing at the latest version
This can be used to manually update a table when the read_consistency_interval
is None
It can also be used to undo a `[Self::checkout]` operation
"""
await self._inner.checkout_latest()
async def restore(self):
"""
Restore the table to the currently checked out version
This operation will fail if checkout has not been called previously
This operation will overwrite the latest version of the table with a
previous version. Any changes made since the checked out version will
no longer be visible.
Once the operation concludes the table will no longer be in a checked
out state and the read_consistency_interval, if any, will apply.
"""
await self._inner.restore()
async def list_indices(self) -> IndexConfig:
"""
List all indices that have been created with Self::create_index
"""
return await self._inner.list_indices()

Binary file not shown.

After

Width:  |  Height:  |  Size: 83 B

View File

@@ -0,0 +1,69 @@
from datetime import timedelta
import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq
@pytest_asyncio.fixture
async def db_async(tmp_path) -> AsyncConnection:
return await connect_async(tmp_path, read_consistency_interval=timedelta(seconds=0))
def sample_fixed_size_list_array(nrows, dim):
vector_data = pa.array([float(i) for i in range(dim * nrows)], pa.float32())
return pa.FixedSizeListArray.from_arrays(vector_data, dim)
DIM = 8
NROWS = 256
@pytest_asyncio.fixture
async def some_table(db_async):
data = pa.Table.from_pydict(
{
"id": list(range(256)),
"vector": sample_fixed_size_list_array(NROWS, DIM),
}
)
return await db_async.create_table(
"some_table",
data,
)
@pytest.mark.asyncio
async def test_create_scalar_index(some_table: AsyncTable):
# Can create
await some_table.create_index("id")
# Can recreate if replace=True
await some_table.create_index("id", replace=True)
indices = await some_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "BTree"
assert indices[0].columns == ["id"]
# Can't recreate if replace=False
with pytest.raises(RuntimeError, match="already exists"):
await some_table.create_index("id", replace=False)
# can also specify index type
await some_table.create_index("id", config=BTree())
@pytest.mark.asyncio
async def test_create_vector_index(some_table: AsyncTable):
# Can create
await some_table.create_index("vector")
# Can recreate if replace=True
await some_table.create_index("vector", replace=True)
# Can't recreate if replace=False
with pytest.raises(RuntimeError, match="already exists"):
await some_table.create_index("vector", replace=False)
# Can also specify index type
await some_table.create_index("vector", config=IvfPq(num_partitions=100))
indices = await some_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "IvfPq"
assert indices[0].columns == ["vector"]

View File

@@ -12,17 +12,27 @@
# limitations under the License.
import io
import json
import os
import sys
from datetime import date, datetime
from pathlib import Path
from typing import List, Optional, Tuple
import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from pydantic import Field
from lancedb.pydantic import (
PYDANTIC_VERSION,
EncodedImage,
LanceModel,
Vector,
pydantic_to_schema,
)
@pytest.mark.skipif(
sys.version_info < (3, 9),
@@ -243,3 +253,23 @@ def test_lance_model():
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
def test_schema_with_images():
pytest.importorskip("PIL")
import PIL.Image
class TestModel(LanceModel):
img: EncodedImage()
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
with open(img_path, "rb") as f:
img_bytes = f.read()
m1 = TestModel(img=PIL.Image.open(img_path))
m2 = TestModel(img=img_bytes)
def tobytes(m):
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
assert tobytes(m1) == tobytes(m2)

View File

@@ -12,6 +12,8 @@
# limitations under the License.
import functools
import io
import os
from copy import copy
from datetime import date, datetime, timedelta
from pathlib import Path
@@ -20,19 +22,21 @@ from typing import List
from unittest.mock import PropertyMock, patch
import lance
import lancedb
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
import pytest_asyncio
from lance.arrow import EncodedImageType
from pydantic import BaseModel
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import AsyncConnection, LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import EncodedImage, LanceModel, Vector
from lancedb.table import LanceTable
from pydantic import BaseModel
class MockDB:
@@ -85,12 +89,30 @@ async def test_close(db_async: AsyncConnection):
assert str(table) == "ClosedTable(some_table)"
@pytest.mark.asyncio
async def test_update_async(db_async: AsyncConnection):
table = await db_async.create_table("some_table", data=[{"id": 0}])
assert await table.count_rows("id == 0") == 1
assert await table.count_rows("id == 7") == 0
await table.update({"id": 7})
assert await table.count_rows("id == 7") == 1
assert await table.count_rows("id == 0") == 0
await table.add([{"id": 2}])
await table.update(where="id % 2 == 0", updates_sql={"id": "5"})
assert await table.count_rows("id == 7") == 1
assert await table.count_rows("id == 2") == 0
assert await table.count_rows("id == 5") == 1
await table.update({"id": 10}, where="id == 5")
assert await table.count_rows("id == 10") == 1
def test_create_table(db):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float32()),
pa.field("encoded_image", EncodedImageType()),
]
)
expected = pa.Table.from_arrays(
@@ -98,13 +120,26 @@ def test_create_table(db):
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]),
pa.ExtensionArray.from_storage(
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
),
],
schema=schema,
)
data = [
[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
{
"vector": [3.1, 4.1],
"item": "foo",
"price": 10.0,
"encoded_image": b"foo",
},
{
"vector": [5.9, 26.5],
"item": "bar",
"price": 20.0,
"encoded_image": b"bar",
},
]
]
df = pd.DataFrame(data[0])
@@ -974,3 +1009,56 @@ def test_drop_columns(tmp_path):
table = LanceTable.create(db, "my_table", data=data)
table.drop_columns(["category"])
assert table.to_arrow().column_names == ["id"]
@pytest.mark.asyncio
async def test_time_travel(db_async: AsyncConnection):
# Setup
table = await db_async.create_table("some_table", data=[{"id": 0}])
version = await table.version()
await table.add([{"id": 1}])
assert await table.count_rows() == 2
# Make sure we can rewind
await table.checkout(version)
assert await table.count_rows() == 1
# Can't add data in time travel mode
with pytest.raises(
ValueError,
match="table cannot be modified when a specific version is checked out",
):
await table.add([{"id": 2}])
# Can go back to normal mode
await table.checkout_latest()
assert await table.count_rows() == 2
# Should be able to add data again
await table.add([{"id": 3}])
assert await table.count_rows() == 3
# Now checkout and restore
await table.checkout(version)
await table.restore()
assert await table.count_rows() == 1
# Should be able to add data
await table.add([{"id": 4}])
assert await table.count_rows() == 2
# Can't use restore if not checked out
with pytest.raises(ValueError, match="checkout before running restore"):
await table.restore()
def test_add_image(tmp_path):
pytest.importorskip("PIL")
import PIL.Image
db = lancedb.connect(tmp_path)
class TestModel(LanceModel):
img: EncodedImage()
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
m1 = TestModel(img=PIL.Image.open(img_path))
def tobytes(m):
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
table = LanceTable.create(db, "my_table", schema=TestModel)
table.add([m1])

109
python/src/index.rs Normal file
View File

@@ -0,0 +1,109 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Mutex;
use lancedb::{
index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
DistanceType,
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pymethods, PyResult,
};
#[pyclass]
pub struct Index {
inner: Mutex<Option<LanceDbIndex>>,
}
impl Index {
pub fn consume(&self) -> PyResult<LanceDbIndex> {
self.inner
.lock()
.unwrap()
.take()
.ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
}
}
#[pymethods]
impl Index {
#[staticmethod]
pub fn ivf_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
) -> PyResult<Self> {
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = match distance_type.as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(PyValueError::new_err(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type
))),
}?;
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
})
}
#[staticmethod]
pub fn btree() -> PyResult<Self> {
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
})
}
}
#[pyclass(get_all)]
/// A description of an index currently configured on a column
pub struct IndexConfig {
/// The type of the index
pub index_type: String,
/// The columns in the index
///
/// Currently this is always a list of size 1. In the future there may
/// be more columns to represent composite indices.
pub columns: Vec<String>,
}
impl From<lancedb::index::IndexConfig> for IndexConfig {
fn from(value: lancedb::index::IndexConfig) -> Self {
let index_type = format!("{:?}", value.index_type);
Self {
index_type,
columns: value.columns,
}
}
}

View File

@@ -14,11 +14,15 @@
use connection::{connect, Connection};
use env_logger::Env;
use index::{Index, IndexConfig};
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
use table::Table;
pub mod connection;
pub mod error;
pub mod index;
pub mod table;
pub mod util;
#[pymodule]
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
@@ -27,6 +31,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
.write_style("LANCEDB_LOG_STYLE");
env_logger::init_from_env(env);
m.add_class::<Connection>()?;
m.add_class::<Table>()?;
m.add_class::<Index>()?;
m.add_class::<IndexConfig>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

View File

@@ -5,11 +5,16 @@ use arrow::{
use lancedb::table::{AddDataMode, Table as LanceDbTable};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pymethods, PyAny, PyRef, PyResult, Python,
pyclass, pymethods,
types::{PyDict, PyString},
PyAny, PyRef, PyResult, Python,
};
use pyo3_asyncio::tokio::future_into_py;
use crate::error::PythonErrorExt;
use crate::{
error::PythonErrorExt,
index::{Index, IndexConfig},
};
#[pyclass]
pub struct Table {
@@ -74,6 +79,28 @@ impl Table {
})
}
pub fn update<'a>(
self_: PyRef<'a, Self>,
updates: &PyDict,
r#where: Option<String>,
) -> PyResult<&'a PyAny> {
let mut op = self_.inner_ref()?.update();
if let Some(only_if) = r#where {
op = op.only_if(only_if);
}
for (column_name, value) in updates.into_iter() {
let column_name: &PyString = column_name.downcast()?;
let column_name = column_name.to_str()?.to_string();
let value: &PyString = value.downcast()?;
let value = value.to_str()?.to_string();
op = op.column(column_name, value);
}
future_into_py(self_.py(), async move {
op.execute().await.infer_error()?;
Ok(())
})
}
pub fn count_rows(self_: PyRef<'_, Self>, filter: Option<String>) -> PyResult<&PyAny> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
@@ -81,10 +108,75 @@ impl Table {
})
}
pub fn create_index<'a>(
self_: PyRef<'a, Self>,
column: String,
index: Option<&Index>,
replace: Option<bool>,
) -> PyResult<&'a PyAny> {
let index = if let Some(index) = index {
index.consume()?
} else {
lancedb::index::Index::Auto
};
let mut op = self_.inner_ref()?.create_index(&[column], index);
if let Some(replace) = replace {
op = op.replace(replace);
}
future_into_py(self_.py(), async move {
op.execute().await.infer_error()?;
Ok(())
})
}
pub fn list_indices(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
Ok(inner
.list_indices()
.await
.infer_error()?
.into_iter()
.map(IndexConfig::from)
.collect::<Vec<_>>())
})
}
pub fn __repr__(&self) -> String {
match &self.inner {
None => format!("ClosedTable({})", self.name),
Some(inner) => inner.to_string(),
}
}
pub fn version(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner_ref()?.clone();
future_into_py(
self_.py(),
async move { inner.version().await.infer_error() },
)
}
pub fn checkout(self_: PyRef<'_, Self>, version: u64) -> PyResult<&PyAny> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner.checkout(version).await.infer_error()
})
}
pub fn checkout_latest(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner.checkout_latest().await.infer_error()
})
}
pub fn restore(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner_ref()?.clone();
future_into_py(
self_.py(),
async move { inner.restore().await.infer_error() },
)
}
}

35
python/src/util.rs Normal file
View File

@@ -0,0 +1,35 @@
use std::sync::Mutex;
use pyo3::{exceptions::PyRuntimeError, PyResult};
/// A wrapper around a rust builder
///
/// Rust builders are often implemented so that the builder methods
/// consume the builder and return a new one. This is not compatible
/// with the pyo3, which, being garbage collected, cannot easily obtain
/// ownership of an object.
///
/// This wrapper converts the compile-time safety of rust into runtime
/// errors if any attempt to use the builder happens after it is consumed.
pub struct BuilderWrapper<T> {
name: String,
inner: Mutex<Option<T>>,
}
impl<T> BuilderWrapper<T> {
pub fn new(name: impl AsRef<str>, inner: T) -> Self {
Self {
name: name.as_ref().to_string(),
inner: Mutex::new(Some(inner)),
}
}
pub fn consume<O>(&self, mod_fn: impl FnOnce(T) -> O) -> PyResult<O> {
let mut inner = self.inner.lock().unwrap();
let inner_builder = inner.take().ok_or_else(|| {
PyRuntimeError::new_err(format!("{} has already been consumed", self.name))
})?;
let result = mod_fn(inner_builder);
Ok(result)
}
}

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use lancedb::index::{scalar::BTreeIndexBuilder, Index};
use neon::{
context::{Context, FunctionContext},
result::JsResult,
@@ -33,9 +34,9 @@ pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise>
rt.spawn(async move {
let idx_result = table
.create_index(&[&column])
.create_index(&[column], Index::BTree(BTreeIndexBuilder::default()))
.replace(replace)
.build()
.execute()
.await;
deferred.settle_with(&channel, move |mut cx| {

View File

@@ -13,12 +13,12 @@
// limitations under the License.
use lance_linalg::distance::MetricType;
use lancedb::index::IndexBuilder;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::Index;
use neon::context::FunctionContext;
use neon::prelude::*;
use std::convert::TryFrom;
use crate::error::Error::InvalidIndexType;
use crate::error::ResultExt;
use crate::neon_ext::js_object_ext::JsObjectExt;
use crate::runtime;
@@ -39,13 +39,20 @@ pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise>
.map(|s| s.value(&mut cx))
.unwrap_or("vector".to_string()); // Backward compatibility
let replace = index_params
.get_opt::<JsBoolean, _, _>(&mut cx, "replace")?
.map(|r| r.value(&mut cx));
let tbl = table.clone();
let index_builder = tbl.create_index(&[&column_name]);
let index_builder =
get_index_params_builder(&mut cx, index_params, index_builder).or_throw(&mut cx)?;
let ivf_pq_builder = get_index_params_builder(&mut cx, index_params).or_throw(&mut cx)?;
let mut index_builder = tbl.create_index(&[column_name], Index::IvfPq(ivf_pq_builder));
if let Some(replace) = replace {
index_builder = index_builder.replace(replace);
}
rt.spawn(async move {
let idx_result = index_builder.build().await;
let idx_result = index_builder.execute().await;
deferred.settle_with(&channel, move |mut cx| {
idx_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
@@ -57,26 +64,17 @@ pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise>
fn get_index_params_builder(
cx: &mut FunctionContext,
obj: Handle<JsObject>,
builder: IndexBuilder,
) -> crate::error::Result<IndexBuilder> {
let mut builder = match obj.get::<JsString, _, _>(cx, "type")?.value(cx).as_str() {
"ivf_pq" => builder.ivf_pq(),
_ => {
return Err(InvalidIndexType {
index_type: "".into(),
})
) -> crate::error::Result<IvfPqIndexBuilder> {
if obj.get_opt::<JsString, _, _>(cx, "index_name")?.is_some() {
return Err(crate::error::Error::LanceDB {
message: "Setting the index_name is no longer supported".to_string(),
});
}
};
if let Some(index_name) = obj.get_opt::<JsString, _, _>(cx, "index_name")? {
builder = builder.name(index_name.value(cx).as_str());
}
let mut builder = IvfPqIndexBuilder::default();
if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
builder = builder.metric_type(metric_type);
builder = builder.distance_type(metric_type);
}
if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? {
builder = builder.num_partitions(np);
}
@@ -86,11 +84,5 @@ fn get_index_params_builder(
if let Some(max_iters) = obj.get_opt_u32(cx, "max_iters")? {
builder = builder.max_iterations(max_iters);
}
if let Some(num_bits) = obj.get_opt_u32(cx, "num_bits")? {
builder = builder.num_bits(num_bits);
}
if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")? {
builder = builder.replace(replace.value(cx));
}
Ok(builder)
}

View File

@@ -297,11 +297,14 @@ impl JsTable {
let predicate = predicate.as_deref();
let update_result = table
.as_native()
.unwrap()
.update(predicate, updates_arg)
.await;
let mut update_op = table.update();
if let Some(predicate) = predicate {
update_op = update_op.only_if(predicate);
}
for (column, value) in updates_arg {
update_op = update_op.column(column, value);
}
let update_result = update_op.execute().await;
deferred.settle_with(&channel, move |mut cx| {
update_result.or_throw(&mut cx)?;
Ok(cx.boxed(Self::from(table)))

View File

@@ -26,6 +26,7 @@ lance = { workspace = true }
lance-index = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
log.workspace = true
async-trait = "0"

View File

@@ -20,6 +20,7 @@ use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::connection::Connection;
use lancedb::index::Index;
use lancedb::{connect, Result, Table as LanceDbTable};
#[tokio::main]
@@ -142,23 +143,18 @@ async fn create_empty_table(db: &Connection) -> Result<LanceDbTable> {
async fn create_index(table: &LanceDbTable) -> Result<()> {
// --8<-- [start:create_index]
table
.create_index(&["vector"])
.ivf_pq()
.num_partitions(8)
.build()
.await
table.create_index(&["vector"], Index::Auto).execute().await
// --8<-- [end:create_index]
}
async fn search(table: &LanceDbTable) -> Result<Vec<RecordBatch>> {
// --8<-- [start:search]
Ok(table
table
.search(&[1.0; 128])
.limit(2)
.execute_stream()
.await?
.try_collect::<Vec<_>>()
.await?)
.await
// --8<-- [end:search]
}

View File

@@ -12,4 +12,92 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub use lance::arrow::*;
use std::{pin::Pin, sync::Arc};
pub use arrow_array;
pub use arrow_schema;
use futures::{Stream, StreamExt};
use crate::error::Result;
/// An iterator of batches that also has a schema
pub trait RecordBatchReader: Iterator<Item = Result<arrow_array::RecordBatch>> {
/// Returns the schema of this `RecordBatchReader`.
///
/// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this
/// reader should have the same schema as returned from this method.
fn schema(&self) -> Arc<arrow_schema::Schema>;
}
/// A simple RecordBatchReader formed from the two parts (iterator + schema)
pub struct SimpleRecordBatchReader<I: Iterator<Item = Result<arrow_array::RecordBatch>>> {
pub schema: Arc<arrow_schema::Schema>,
pub batches: I,
}
impl<I: Iterator<Item = Result<arrow_array::RecordBatch>>> Iterator for SimpleRecordBatchReader<I> {
type Item = Result<arrow_array::RecordBatch>;
fn next(&mut self) -> Option<Self::Item> {
self.batches.next()
}
}
impl<I: Iterator<Item = Result<arrow_array::RecordBatch>>> RecordBatchReader
for SimpleRecordBatchReader<I>
{
fn schema(&self) -> Arc<arrow_schema::Schema> {
self.schema.clone()
}
}
/// A stream of batches that also has a schema
pub trait RecordBatchStream: Stream<Item = Result<arrow_array::RecordBatch>> {
/// Returns the schema of this `RecordBatchStream`.
///
/// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this
/// stream should have the same schema as returned from this method.
fn schema(&self) -> Arc<arrow_schema::Schema>;
}
/// A boxed RecordBatchStream that is also Send
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
impl<I: lance::io::RecordBatchStream + 'static> From<I> for SendableRecordBatchStream {
fn from(stream: I) -> Self {
let schema = stream.schema();
let mapped_stream = Box::pin(stream.map(|r| r.map_err(Into::into)));
Box::pin(SimpleRecordBatchStream {
schema,
stream: mapped_stream,
})
}
}
/// A simple RecordBatchStream formed from the two parts (stream + schema)
#[pin_project::pin_project]
pub struct SimpleRecordBatchStream<S: Stream<Item = Result<arrow_array::RecordBatch>>> {
pub schema: Arc<arrow_schema::Schema>,
#[pin]
pub stream: S,
}
impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> Stream for SimpleRecordBatchStream<S> {
type Item = Result<arrow_array::RecordBatch>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.project();
this.stream.poll_next(cx)
}
}
impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> RecordBatchStream
for SimpleRecordBatchStream<S>
{
fn schema(&self) -> Arc<arrow_schema::Schema> {
self.schema.clone()
}
}

View File

@@ -356,6 +356,15 @@ pub struct ConnectBuilder {
aws_creds: Option<AwsCredential>,
/// The interval at which to check for updates from other processes.
///
/// If None, then consistency is not checked. For performance
/// reasons, this is the default. For strong consistency, set this to
/// zero seconds. Then every read will check for updates from other
/// processes. As a compromise, you can set this to a non-zero timedelta
/// for eventual consistency. If more than that interval has passed since
/// the last check, then the table will be checked for updates. Note: this
/// consistency only applies to read operations. Write operations are
/// always consistent.
read_consistency_interval: Option<std::time::Duration>,
}

View File

@@ -12,181 +12,69 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{cmp::max, sync::Arc};
use lance_index::IndexType;
pub use lance_linalg::distance::MetricType;
pub mod vector;
use std::sync::Arc;
use crate::{table::TableInternal, Result};
/// Index Parameters.
pub enum IndexParams {
Scalar {
replace: bool,
},
IvfPq {
replace: bool,
metric_type: MetricType,
num_partitions: u64,
num_sub_vectors: u32,
num_bits: u32,
sample_rate: u32,
max_iterations: u32,
},
use self::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder};
pub mod scalar;
pub mod vector;
pub enum Index {
Auto,
BTree(BTreeIndexBuilder),
IvfPq(IvfPqIndexBuilder),
}
/// Builder for Index Parameters.
/// Builder for the create_index operation
///
/// The methods on this builder are used to specify options common to all indices.
pub struct IndexBuilder {
parent: Arc<dyn TableInternal>,
pub(crate) index: Index,
pub(crate) columns: Vec<String>,
// General parameters
/// Index name.
pub(crate) name: Option<String>,
/// Replace the existing index.
pub(crate) replace: bool,
pub(crate) index_type: IndexType,
// Scalar index parameters
// Nothing to set here.
// IVF_PQ parameters
pub(crate) metric_type: MetricType,
pub(crate) num_partitions: Option<u32>,
// PQ related
pub(crate) num_sub_vectors: Option<u32>,
pub(crate) num_bits: u32,
/// The rate to find samples to train kmeans.
pub(crate) sample_rate: u32,
/// Max iteration to train kmeans.
pub(crate) max_iterations: u32,
}
impl IndexBuilder {
pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: &[&str]) -> Self {
pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: Vec<String>, index: Index) -> Self {
Self {
parent,
columns: columns.iter().map(|c| c.to_string()).collect(),
name: None,
index,
columns,
replace: true,
index_type: IndexType::Scalar,
metric_type: MetricType::L2,
num_partitions: None,
num_sub_vectors: None,
num_bits: 8,
sample_rate: 256,
max_iterations: 50,
}
}
/// Build a Scalar Index.
/// Whether to replace the existing index, the default is `true`.
///
/// Accepted parameters:
/// - `replace`: Replace the existing index.
/// - `name`: Index name. Default: `None`
pub fn scalar(mut self) -> Self {
self.index_type = IndexType::Scalar;
self
}
/// Build an IVF PQ index.
///
/// Accepted parameters:
/// - `replace`: Replace the existing index.
/// - `name`: Index name. Default: `None`
/// - `metric_type`: [MetricType] to use to build Vector Index.
/// - `num_partitions`: Number of IVF partitions.
/// - `num_sub_vectors`: Number of sub-vectors of PQ.
/// - `num_bits`: Number of bits used for PQ centroids.
/// - `sample_rate`: The rate to find samples to train kmeans.
/// - `max_iterations`: Max iteration to train kmeans.
pub fn ivf_pq(mut self) -> Self {
self.index_type = IndexType::Vector;
self
}
/// The columns to build index on.
pub fn columns(mut self, cols: &[&str]) -> Self {
self.columns = cols.iter().map(|s| s.to_string()).collect();
self
}
/// Whether to replace the existing index, default is `true`.
/// If this is false, and another index already exists on the same columns
/// and the same name, then an error will be returned. This is true even if
/// that index is out of date.
pub fn replace(mut self, v: bool) -> Self {
self.replace = v;
self
}
/// Set the index name.
pub fn name(mut self, name: &str) -> Self {
self.name = Some(name.to_string());
self
pub async fn execute(self) -> Result<()> {
self.parent.clone().create_index(self).await
}
}
/// [MetricType] to use to build Vector Index.
#[derive(Debug, Clone, PartialEq)]
pub enum IndexType {
IvfPq,
BTree,
}
/// A description of an index currently configured on a column
pub struct IndexConfig {
/// The type of the index
pub index_type: IndexType,
/// The columns in the index
///
/// Default value is [MetricType::L2].
pub fn metric_type(mut self, metric_type: MetricType) -> Self {
self.metric_type = metric_type;
self
}
/// Number of IVF partitions.
pub fn num_partitions(mut self, num_partitions: u32) -> Self {
self.num_partitions = Some(num_partitions);
self
}
/// Number of sub-vectors of PQ.
pub fn num_sub_vectors(mut self, num_sub_vectors: u32) -> Self {
self.num_sub_vectors = Some(num_sub_vectors);
self
}
/// Number of bits used for PQ centroids.
pub fn num_bits(mut self, num_bits: u32) -> Self {
self.num_bits = num_bits;
self
}
/// The rate to find samples to train kmeans.
pub fn sample_rate(mut self, sample_rate: u32) -> Self {
self.sample_rate = sample_rate;
self
}
/// Max iteration to train kmeans.
pub fn max_iterations(mut self, max_iterations: u32) -> Self {
self.max_iterations = max_iterations;
self
}
/// Build the parameters.
pub async fn build(self) -> Result<()> {
self.parent.clone().do_create_index(self).await
}
}
pub(crate) fn suggested_num_partitions(rows: usize) -> u32 {
let num_partitions = (rows as f64).sqrt() as u32;
max(1, num_partitions)
}
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
if dim % 16 == 0 {
// Should be more aggressive than this default.
dim / 16
} else if dim % 8 == 0 {
dim / 8
} else {
log::warn!(
"The dimension of the vector is not divisible by 8 or 16, \
which may cause performance degradation in PQ"
);
1
}
/// Currently this is always a Vec of size 1. In the future there may
/// be more columns to represent composite indices.
pub columns: Vec<String>,
}

View File

@@ -0,0 +1,30 @@
//! Scalar indices are exact indices that are used to quickly satisfy a variety of filters
//! against a column of scalar values.
//!
//! Scalar indices are currently supported on numeric, string, boolean, and temporal columns.
//!
//! A scalar index will help with queries with filters like `x > 10`, `x < 10`, `x = 10`,
//! etc. Scalar indices can also speed up prefiltering for vector searches. A single
//! vector search with prefiltering can use both a scalar index and a vector index.
/// Builder for a btree index
///
/// A btree index is an index on scalar columns. The index stores a copy of the column
/// in sorted order. A header entry is created for each block of rows (currently the
/// block size is fixed at 4096). These header entries are stored in a separate
/// cacheable structure (a btree). To search for data the header is used to determine
/// which blocks need to be read from disk.
///
/// For example, a btree index in a table with 1Bi rows requires sizeof(Scalar) * 256Ki
/// bytes of memory and will generally need to read sizeof(Scalar) * 4096 bytes to find
/// the correct row ids.
///
/// This index is good for scalar columns with mostly distinct values and does best when
/// the query is highly selective.
///
/// The btree index does not currently have any parameters though parameters such as the
/// block size may be added in the future.
#[derive(Default, Debug, Clone)]
pub struct BTreeIndexBuilder {}
impl BTreeIndexBuilder {}

View File

@@ -12,10 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! Vector indices are approximate indices that are used to find rows similar to
//! a query vector. Vector indices speed up vector searches.
//!
//! Vector indices are only supported on fixed-size-list (tensor) columns of floating point
//! values
use std::cmp::max;
use serde::Deserialize;
use lance::table::format::{Index, Manifest};
use crate::DistanceType;
pub struct VectorIndex {
pub columns: Vec<String>,
pub index_name: String,
@@ -42,3 +51,145 @@ pub struct VectorIndexStatistics {
pub num_indexed_rows: usize,
pub num_unindexed_rows: usize,
}
/// Builder for an IVF PQ index.
///
/// This index stores a compressed (quantized) copy of every vector. These vectors
/// are grouped into partitions of similar vectors. Each partition keeps track of
/// a centroid which is the average value of all vectors in the group.
///
/// During a query the centroids are compared with the query vector to find the closest
/// partitions. The compressed vectors in these partitions are then searched to find
/// the closest vectors.
///
/// The compression scheme is called product quantization. Each vector is divided into
/// subvectors and then each subvector is quantized into a small number of bits. the
/// parameters `num_bits` and `num_subvectors` control this process, providing a tradeoff
/// between index size (and thus search speed) and index accuracy.
///
/// The partitioning process is called IVF and the `num_partitions` parameter controls how
/// many groups to create.
///
/// Note that training an IVF PQ index on a large dataset is a slow operation and
/// currently is also a memory intensive operation.
#[derive(Debug, Clone)]
pub struct IvfPqIndexBuilder {
pub(crate) distance_type: DistanceType,
pub(crate) num_partitions: Option<u32>,
pub(crate) num_sub_vectors: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
}
impl Default for IvfPqIndexBuilder {
fn default() -> Self {
Self {
distance_type: DistanceType::L2,
num_partitions: None,
num_sub_vectors: None,
sample_rate: 256,
max_iterations: 50,
}
}
}
impl IvfPqIndexBuilder {
/// [DistanceType] to use to build the index.
///
/// Default value is [DistanceType::L2].
///
/// This is used when training the index to calculate the IVF partitions (vectors are
/// grouped in partitions with similar vectors according to this distance type) and to
/// calculate a subvector's code during quantization.
///
/// The metric type used to train an index MUST match the metric type used to search the
/// index. Failure to do so will yield inaccurate results.
pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
self.distance_type = distance_type;
self
}
/// The number of IVF partitions to create.
///
/// This value should generally scale with the number of rows in the dataset. By default
/// the number of partitions is the square root of the number of rows.
///
/// If this value is too large then the first part of the search (picking the right partition)
/// will be slow. If this value is too small then the second part of the search (searching
/// within a partition) will be slow.
pub fn num_partitions(mut self, num_partitions: u32) -> Self {
self.num_partitions = Some(num_partitions);
self
}
/// Number of sub-vectors of PQ.
///
/// This value controls how much the vector is compressed during the quantization step.
/// The more sub vectors there are the less the vector is compressed. The default is
/// the dimension of the vector divided by 16. If the dimension is not evenly divisible
/// by 16 we use the dimension divded by 8.
///
/// The above two cases are highly preferred. Having 8 or 16 values per subvector allows
/// us to use efficient SIMD instructions.
///
/// If the dimension is not visible by 8 then we use 1 subvector. This is not ideal and
/// will likely result in poor performance.
pub fn num_sub_vectors(mut self, num_sub_vectors: u32) -> Self {
self.num_sub_vectors = Some(num_sub_vectors);
self
}
/// The rate used to calculate the number of training vectors for kmeans.
///
/// When an IVF PQ index is trained, we need to calculate partitions. These are groups
/// of vectors that are similar to each other. To do this we use an algorithm called kmeans.
///
/// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
/// random sample of the data. This parameter controls the size of the sample. The total
/// number of vectors used to train the index is `sample_rate * num_partitions`.
///
/// Increasing this value might improve the quality of the index but in most cases the
/// default should be sufficient.
///
/// The default value is 256.
pub fn sample_rate(mut self, sample_rate: u32) -> Self {
self.sample_rate = sample_rate;
self
}
/// Max iterations to train kmeans.
///
/// When training an IVF PQ index we use kmeans to calculate the partitions. This parameter
/// controls how many iterations of kmeans to run.
///
/// Increasing this might improve the quality of the index but in most cases the parameter
/// is unused because kmeans will converge with fewer iterations. The parameter is only
/// used in cases where kmeans does not appear to converge. In those cases it is unlikely
/// that setting this larger will lead to the index converging anyways.
///
/// The default value is 50.
pub fn max_iterations(mut self, max_iterations: u32) -> Self {
self.max_iterations = max_iterations;
self
}
}
pub(crate) fn suggested_num_partitions(rows: usize) -> u32 {
let num_partitions = (rows as f64).sqrt() as u32;
max(1, num_partitions)
}
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
if dim % 16 == 0 {
// Should be more aggressive than this default.
dim / 16
} else if dim % 8 == 0 {
dim / 8
} else {
log::warn!(
"The dimension of the vector is not divisible by 8 or 16, \
which may cause performance degradation in PQ"
);
1
}
}

View File

@@ -130,14 +130,13 @@
//! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
//! # RecordBatchIterator, Int32Array};
//! # use arrow_schema::{Schema, Field, DataType};
//! use lancedb::index::Index;
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! # let tbl = db.open_table("idx_test").execute().await.unwrap();
//! tbl.create_index(&["vector"])
//! .ivf_pq()
//! .num_partitions(256)
//! .build()
//! tbl.create_index(&["vector"], Index::Auto)
//! .execute()
//! .await
//! .unwrap();
//! # });
@@ -181,6 +180,7 @@
//! # });
//! ```
pub mod arrow;
pub mod connection;
pub mod data;
pub mod error;
@@ -194,6 +194,7 @@ pub mod table;
pub mod utils;
pub use error::{Error, Result};
pub use lance_linalg::distance::DistanceType;
pub use table::Table;
/// Connect to a database

View File

@@ -15,9 +15,9 @@
use std::sync::Arc;
use arrow_array::Float32Array;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance_linalg::distance::MetricType;
use crate::arrow::SendableRecordBatchStream;
use crate::error::Result;
use crate::table::TableInternal;
@@ -81,13 +81,15 @@ impl Query {
}
}
/// Convert the query plan to a [`DatasetRecordBatchStream`]
/// Convert the query plan to a [`SendableRecordBatchStream`]
///
/// # Returns
///
/// * A [DatasetRecordBatchStream] with the query's results.
pub async fn execute_stream(&self) -> Result<DatasetRecordBatchStream> {
self.parent.clone().do_query(self).await
/// * A [SendableRecordBatchStream] with the query's results.
pub async fn execute_stream(&self) -> Result<SendableRecordBatchStream> {
Ok(SendableRecordBatchStream::from(
self.parent.clone().query(self).await?,
))
}
/// Set the column to query
@@ -363,6 +365,10 @@ mod tests {
let arr: &Int32Array = b["id"].as_primitive();
assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));
}
// Reject bad filter
let result = table.query().filter("id = 0 AND").execute_stream().await;
assert!(result.is_err());
}
fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {

View File

@@ -5,11 +5,11 @@ use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewCol
use crate::{
error::Result,
index::IndexBuilder,
index::{IndexBuilder, IndexConfig},
query::Query,
table::{
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
TableInternal,
TableInternal, UpdateBuilder,
},
};
@@ -45,25 +45,40 @@ impl TableInternal for RemoteTable {
fn name(&self) -> &str {
&self.name
}
async fn version(&self) -> Result<u64> {
todo!()
}
async fn checkout(&self, _version: u64) -> Result<()> {
todo!()
}
async fn checkout_latest(&self) -> Result<()> {
todo!()
}
async fn restore(&self) -> Result<()> {
todo!()
}
async fn schema(&self) -> Result<SchemaRef> {
todo!()
}
async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
todo!()
}
async fn do_add(&self, _add: AddDataBuilder) -> Result<()> {
async fn add(&self, _add: AddDataBuilder) -> Result<()> {
todo!()
}
async fn do_query(&self, _query: &Query) -> Result<DatasetRecordBatchStream> {
async fn query(&self, _query: &Query) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn update(&self, _update: UpdateBuilder) -> Result<()> {
todo!()
}
async fn delete(&self, _predicate: &str) -> Result<()> {
todo!()
}
async fn do_create_index(&self, _index: IndexBuilder) -> Result<()> {
async fn create_index(&self, _index: IndexBuilder) -> Result<()> {
todo!()
}
async fn do_merge_insert(
async fn merge_insert(
&self,
_params: MergeInsertBuilder,
_new_data: Box<dyn RecordBatchReader + Send>,
@@ -86,4 +101,7 @@ impl TableInternal for RemoteTable {
async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
todo!()
}
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
todo!()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -83,6 +83,33 @@ impl DatasetRef {
}
}
async fn as_time_travel(&mut self, target_version: u64) -> Result<()> {
match self {
Self::Latest { dataset, .. } => {
*self = Self::TimeTravel {
dataset: dataset.checkout_version(target_version).await?,
version: target_version,
};
}
Self::TimeTravel { dataset, version } => {
if *version != target_version {
*self = Self::TimeTravel {
dataset: dataset.checkout_version(target_version).await?,
version: target_version,
};
}
}
}
Ok(())
}
fn time_travel_version(&self) -> Option<u64> {
match self {
Self::Latest { .. } => None,
Self::TimeTravel { version, .. } => Some(*version),
}
}
fn set_latest(&mut self, dataset: Dataset) {
match self {
Self::Latest {
@@ -106,23 +133,6 @@ impl DatasetConsistencyWrapper {
})))
}
/// Create a new wrapper in the time travel mode.
pub fn new_time_travel(dataset: Dataset, version: u64) -> Self {
Self(Arc::new(RwLock::new(DatasetRef::TimeTravel {
dataset,
version,
})))
}
/// Create an independent copy of self.
///
/// Unlike Clone, this will track versions independently of the original wrapper and
/// will be tied to a different RwLock.
pub async fn duplicate(&self) -> Self {
let ds_ref = self.0.read().await;
Self(Arc::new(RwLock::new((*ds_ref).clone())))
}
/// Get an immutable reference to the dataset.
pub async fn get(&self) -> Result<DatasetReadGuard<'_>> {
self.ensure_up_to_date().await?;
@@ -132,7 +142,19 @@ impl DatasetConsistencyWrapper {
}
/// Get a mutable reference to the dataset.
///
/// If the dataset is in time travel mode this will fail
pub async fn get_mut(&self) -> Result<DatasetWriteGuard<'_>> {
self.ensure_mutable().await?;
self.ensure_up_to_date().await?;
Ok(DatasetWriteGuard {
guard: self.0.write().await,
})
}
/// Get a mutable reference to the dataset without requiring the
/// dataset to be in a Latest mode.
pub async fn get_mut_unchecked(&self) -> Result<DatasetWriteGuard<'_>> {
self.ensure_up_to_date().await?;
Ok(DatasetWriteGuard {
guard: self.0.write().await,
@@ -140,7 +162,7 @@ impl DatasetConsistencyWrapper {
}
/// Convert into a wrapper in latest version mode
pub async fn as_latest(&mut self, read_consistency_interval: Option<Duration>) -> Result<()> {
pub async fn as_latest(&self, read_consistency_interval: Option<Duration>) -> Result<()> {
self.0
.write()
.await
@@ -148,6 +170,10 @@ impl DatasetConsistencyWrapper {
.await
}
pub async fn as_time_travel(&self, target_version: u64) -> Result<()> {
self.0.write().await.as_time_travel(target_version).await
}
/// Provide a known latest version of the dataset.
///
/// This is usually done after some write operation, which inherently will
@@ -160,6 +186,22 @@ impl DatasetConsistencyWrapper {
self.0.write().await.reload().await
}
/// Returns the version, if in time travel mode, or None otherwise
pub async fn time_travel_version(&self) -> Option<u64> {
self.0.read().await.time_travel_version()
}
pub async fn ensure_mutable(&self) -> Result<()> {
let dataset_ref = self.0.read().await;
match &*dataset_ref {
DatasetRef::Latest { .. } => Ok(()),
DatasetRef::TimeTravel { .. } => Err(crate::Error::InvalidInput {
message: "table cannot be modified when a specific version is checked out"
.to_string(),
}),
}
}
async fn is_up_to_date(&self) -> Result<bool> {
let dataset_ref = self.0.read().await;
match &*dataset_ref {

View File

@@ -98,6 +98,6 @@ impl MergeInsertBuilder {
///
/// Nothing is returned but the [`super::Table`] is updated
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
self.table.clone().do_merge_insert(self, new_data).await
self.table.clone().merge_insert(self, new_data).await
}
}