Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 13:29:57 +00:00)

Compare commits: 25 commits (add-python ... 600)
Commits:

- 86c9bc0d2d
- c1dfad675a
- 2e1838a62a
- 4d39f63cf6
- 3c4f2a7020
- 48a4202748
- 2084fbcff4
- 408988abce
- e68fbf65cc
- 63399dc0ee
- 0b0f4e9d1c
- 2ec0e79303
- d86dd2c60d
- 67b38d6115
- c112dea28b
- d662b9744e
- ac955a5a7e
- 4dc7497547
- d744972f2f
- 9bc320874a
- 510d449167
- 356e89a800
- ae1cf4441d
- 1ae08fe31d
- a517629c65
@@ -28,13 +28,14 @@ arrow-schema = "50.0"
 arrow-arith = "50.0"
 arrow-cast = "50.0"
 async-trait = "0"
-chrono = "0.4.23"
+chrono = "0.4.35"
 half = { "version" = "=2.3.1", default-features = false, features = [
   "num-traits",
 ] }
 futures = "0"
 log = "0.4"
 object_store = "0.9.0"
+pin-project = "1.0.7"
 snafu = "0.7.4"
 url = "2"
 num-traits = "0.2"
@@ -176,6 +176,10 @@ export async function connect (
     opts = { uri: arg }
   } else {
     // opts = { uri: arg.uri, awsCredentials = arg.awsCredentials }
+    const keys = Object.keys(arg)
+    if (keys.length === 1 && keys[0] === 'uri' && typeof arg.uri === 'string') {
+      opts = { uri: arg.uri }
+    } else {
     opts = Object.assign(
       {
         uri: '',
@@ -187,6 +191,7 @@ export async function connect (
       arg
     )
   }
+  }

   if (opts.uri.startsWith('db://')) {
     // Remote connection
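With these two hunks, the TypeScript `connect` entry point accepts either a URI string or an options object whose only key is `uri`, while still treating `db://` URIs as remote connections. A minimal usage sketch of the two equivalent call forms (the `vectordb` import name and the `/tmp/lancedb-demo` path are illustrative assumptions):

```typescript
import * as lancedb from "vectordb"; // assumed package name for the Node client

async function main() {
  // Both forms should resolve to the same local connection after this change.
  const fromString = await lancedb.connect("/tmp/lancedb-demo");
  const fromObject = await lancedb.connect({ uri: "/tmp/lancedb-demo" });

  // URIs starting with "db://" are still routed to a remote connection instead.
  console.log(await fromString.tableNames(), await fromObject.tableNames());
}

main().catch(console.error);
```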
@@ -128,6 +128,11 @@ describe('LanceDB client', function () {
     assertResults(results)
     results = await table.where('id % 2 = 0').execute()
     assertResults(results)
+
+    // Should reject a bad filter
+    await expect(table.filter('id % 2 = 0 AND').execute()).to.be.rejectedWith(
+      /.*sql parser error: Expected an expression:, found: EOF.*/
+    )
   })

   it('uses a filter / where clause', async function () {
@@ -283,7 +288,8 @@ describe('LanceDB client', function () {

   it('create a table from an Arrow Table', async function () {
     const dir = await track().mkdir('lancejs')
-    const con = await lancedb.connect(dir)
+    // Also test the connect function with an object
+    const con = await lancedb.connect({ uri: dir })

     const i32s = new Int32Array(new Array<number>(10))
     const i32 = makeVector(i32s)
@@ -745,11 +751,11 @@ describe('LanceDB client', function () {
       num_sub_vectors: 2
     })
     await expect(createIndex).to.be.rejectedWith(
-      /VectorIndex requires the column data type to be fixed size list of float32s/
+      "index cannot be created on the column `name` which has data type Utf8"
    )
   })

-  it('it should fail when the column is not a vector', async function () {
+  it('it should fail when num_partitions is invalid', async function () {
     const uri = await createTestDB(32, 300)
     const con = await lancedb.connect(uri)
     const table = await con.openTable('vectors')
@@ -14,12 +14,10 @@ crate-type = ["cdylib"]
 [dependencies]
 arrow-ipc.workspace = true
 futures.workspace = true
-lance-linalg.workspace = true
-lance.workspace = true
 lancedb = { path = "../rust/lancedb" }
 napi = { version = "2.15", default-features = false, features = [
   "napi7",
-  "async"
+  "async",
 ] }
 napi-derive = "2"
@@ -27,6 +27,7 @@ import {
   Float64,
 } from "apache-arrow";
 import { makeArrowTable } from "../dist/arrow";
+import { Index } from "../dist/indices";

 describe("Given a table", () => {
   let tmpDir: tmp.DirResult;
@@ -65,21 +66,36 @@ describe("Given a table", () => {
     expect(table.isOpen()).toBe(false);
     expect(table.countRows()).rejects.toThrow("Table some_table is closed");
   });

+  it("should let me update values", async () => {
+    await table.add([{ id: 1 }]);
+    expect(await table.countRows("id == 1")).toBe(1);
+    expect(await table.countRows("id == 7")).toBe(0);
+    await table.update({ id: "7" });
+    expect(await table.countRows("id == 1")).toBe(0);
+    expect(await table.countRows("id == 7")).toBe(1);
+    await table.add([{ id: 2 }]);
+    // Test Map as input
+    await table.update(new Map(Object.entries({ id: "10" })), {
+      where: "id % 2 == 0",
+    });
+    expect(await table.countRows("id == 2")).toBe(0);
+    expect(await table.countRows("id == 7")).toBe(1);
+    expect(await table.countRows("id == 10")).toBe(1);
+  });
 });

-describe("Test creating index", () => {
+describe("When creating an index", () => {
   let tmpDir: tmp.DirResult;
   const schema = new Schema([
     new Field("id", new Int32(), true),
     new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))),
   ]);
-  beforeEach(() => {
+  let tbl: Table;
+  let queryVec: number[];
+
+  beforeEach(async () => {
     tmpDir = tmp.dirSync({ unsafeCleanup: true });
-  });
-  afterEach(() => tmpDir.removeCallback());
-
-  test("create vector index with no column", async () => {
     const db = await connect(tmpDir.name);
     const data = makeArrowTable(
       Array(300)
@@ -94,47 +110,66 @@ describe("Test creating index", () => {
         schema,
       },
     );
-    const tbl = await db.createTable("test", data);
-    await tbl.createIndex().build();
+    queryVec = data.toArray()[5].vec.toJSON();
+    tbl = await db.createTable("test", data);
+  });
+  afterEach(() => tmpDir.removeCallback());
+
+  it("should create a vector index on vector columns", async () => {
+    await tbl.createIndex("vec");

     // check index directory
     const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
     expect(fs.readdirSync(indexDir)).toHaveLength(1);
-    // TODO: check index type.
+    const indices = await tbl.listIndices();
+    expect(indices.length).toBe(1);
+    expect(indices[0]).toEqual({
+      indexType: "IvfPq",
+      columns: ["vec"],
+    });

     // Search without specifying the column
-    const queryVector = data.toArray()[5].vec.toJSON();
-    const rst = await tbl.query().nearestTo(queryVector).limit(2).toArrow();
+    const rst = await tbl.query().nearestTo(queryVec).limit(2).toArrow();
     expect(rst.numRows).toBe(2);

     // Search with specifying the column
-    const rst2 = await tbl.search(queryVector, "vec").limit(2).toArrow();
+    const rst2 = await tbl.search(queryVec, "vec").limit(2).toArrow();
     expect(rst2.numRows).toBe(2);
     expect(rst.toString()).toEqual(rst2.toString());
   });

-  test("no vector column available", async () => {
-    const db = await connect(tmpDir.name);
-    const tbl = await db.createTable(
-      "no_vec",
-      makeArrowTable([
-        { id: 1, val: 2 },
-        { id: 2, val: 3 },
-      ]),
-    );
-    await expect(tbl.createIndex().build()).rejects.toThrow(
-      "No vector column found",
-    );
-
-    await tbl.createIndex("val").build();
-    const indexDir = path.join(tmpDir.name, "no_vec.lance", "_indices");
+  it("should allow parameters to be specified", async () => {
+    await tbl.createIndex("vec", {
+      config: Index.ivfPq({
+        numPartitions: 10,
+      }),
+    });
+
+    // TODO: Verify parameters when we can load index config as part of list indices
+  });
+
+  it("should allow me to replace (or not) an existing index", async () => {
+    await tbl.createIndex("id");
+    // Default is replace=true
+    await tbl.createIndex("id");
+    await expect(tbl.createIndex("id", { replace: false })).rejects.toThrow(
+      "already exists",
+    );
+    await tbl.createIndex("id", { replace: true });
+  });
+
+  test("should create a scalar index on scalar columns", async () => {
+    await tbl.createIndex("id");
+    const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
     expect(fs.readdirSync(indexDir)).toHaveLength(1);

     for await (const r of tbl.query().filter("id > 1").select(["id"])) {
-      expect(r.numRows).toBe(1);
+      expect(r.numRows).toBe(298);
     }
   });

+  // TODO: Move this test to the query API test (making sure we can reject queries
+  // when the dimension is incorrect)
   test("two columns with different dimensions", async () => {
     const db = await connect(tmpDir.name);
     const schema = new Schema([
@@ -164,14 +199,9 @@ describe("Test creating index", () => {
     );

     // Only build index over v1
-    await expect(tbl.createIndex().build()).rejects.toThrow(
-      /.*More than one vector columns found.*/,
-    );
-    tbl
-      .createIndex("vec")
-      // eslint-disable-next-line @typescript-eslint/naming-convention
-      .ivf_pq({ num_partitions: 2, num_sub_vectors: 2 })
-      .build();
+    await tbl.createIndex("vec", {
+      config: Index.ivfPq({ numPartitions: 2, numSubVectors: 2 }),
+    });

     const rst = await tbl
       .query()
@@ -205,30 +235,6 @@ describe("Test creating index", () => {
     expect(rst64Query.toString()).toEqual(rst64Search.toString());
     expect(rst64Query.numRows).toBe(2);
   });
-
-  test("create scalar index", async () => {
-    const db = await connect(tmpDir.name);
-    const data = makeArrowTable(
-      Array(300)
-        .fill(1)
-        .map((_, i) => ({
-          id: i,
-          vec: Array(32)
-            .fill(1)
-            .map(() => Math.random()),
-        })),
-      {
-        schema,
-      },
-    );
-    const tbl = await db.createTable("test", data);
-    await tbl.createIndex("id").build();
-
-    // check index directory
-    const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
-    expect(fs.readdirSync(indexDir)).toHaveLength(1);
-    // TODO: check index type.
-  });
 });

 describe("Read consistency interval", () => {
@@ -348,3 +354,48 @@ describe("schema evolution", function () {
     expect(await table.schema()).toEqual(expectedSchema);
   });
 });
+
+describe("when dealing with versioning", () => {
+  let tmpDir: tmp.DirResult;
+  beforeEach(() => {
+    tmpDir = tmp.dirSync({ unsafeCleanup: true });
+  });
+  afterEach(() => {
+    tmpDir.removeCallback();
+  });
+
+  it("can travel in time", async () => {
+    // Setup
+    const con = await connect(tmpDir.name);
+    const table = await con.createTable("vectors", [
+      { id: 1n, vector: [0.1, 0.2] },
+    ]);
+    const version = await table.version();
+    await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
+    expect(await table.countRows()).toBe(2);
+    // Make sure we can rewind
+    await table.checkout(version);
+    expect(await table.countRows()).toBe(1);
+    // Can't add data in time travel mode
+    await expect(table.add([{ id: 3n, vector: [0.1, 0.2] }])).rejects.toThrow(
+      "table cannot be modified when a specific version is checked out",
+    );
+    // Can go back to normal mode
+    await table.checkoutLatest();
+    expect(await table.countRows()).toBe(2);
+    // Should be able to add data again
+    await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
+    expect(await table.countRows()).toBe(3);
+    // Now checkout and restore
+    await table.checkout(version);
+    await table.restore();
+    expect(await table.countRows()).toBe(1);
+    // Should be able to add data
+    await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
+    expect(await table.countRows()).toBe(2);
+    // Can't use restore if not checked out
+    await expect(table.restore()).rejects.toThrow(
+      "checkout before running restore",
+    );
+  });
+});
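The "can travel in time" test above is effectively the reference usage for the versioning API added in this change set (`version`, `checkout`, `checkoutLatest`, `restore`). Condensed into a sketch, with the import name and sample data assumed for illustration:

```typescript
import { connect } from "@lancedb/lancedb"; // import name assumed; the tests import from "../dist"

async function timeTravel(dir: string) {
  const db = await connect(dir);
  const table = await db.createTable("vectors", [{ id: 1n, vector: [0.1, 0.2] }]);

  const version = await table.version(); // remember a point in time
  await table.add([{ id: 2n, vector: [0.3, 0.4] }]);

  await table.checkout(version); // read-only "detached head"; writes are rejected here
  await table.checkoutLatest(); // return to the live table

  await table.checkout(version);
  await table.restore(); // make the checked-out version the current version again
}
```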
@@ -18,15 +18,9 @@ import {
   ConnectionOptions,
 } from "./native.js";

-export {
-  ConnectionOptions,
-  WriteOptions,
-  Query,
-  MetricType,
-} from "./native.js";
-export { Connection } from "./connection";
-export { Table } from "./table";
-export { IvfPQOptions, IndexBuilder } from "./indexer";
+export { ConnectionOptions, WriteOptions, Query } from "./native.js";
+export { Connection, CreateTableOptions } from "./connection";
+export { Table, AddDataOptions } from "./table";

 /**
  * Connect to a LanceDB instance at the given URI.
@@ -1,105 +0,0 @@
-// Copyright 2024 Lance Developers.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// TODO: Re-enable this as part of https://github.com/lancedb/lancedb/pull/1052
-/* eslint-disable @typescript-eslint/naming-convention */
-
-import {
-  MetricType,
-  IndexBuilder as NativeBuilder,
-  Table as NativeTable,
-} from "./native";
-
-/** Options to create `IVF_PQ` index */
-export interface IvfPQOptions {
-  /** Number of IVF partitions. */
-  num_partitions?: number;
-
-  /** Number of sub-vectors in PQ coding. */
-  num_sub_vectors?: number;
-
-  /** Number of bits used for each PQ code.
-   */
-  num_bits?: number;
-
-  /** Metric type to calculate the distance between vectors.
-   *
-   * Supported metrics: `L2`, `Cosine` and `Dot`.
-   */
-  metric_type?: MetricType;
-
-  /** Number of iterations to train K-means.
-   *
-   * Default is 50. The more iterations it usually yield better results,
-   * but it takes longer to train.
-   */
-  max_iterations?: number;
-
-  sample_rate?: number;
-}
-
-/**
- * Building an index on LanceDB {@link Table}
- *
- * @see {@link Table.createIndex} for detailed usage.
- */
-export class IndexBuilder {
-  private inner: NativeBuilder;
-
-  constructor(tbl: NativeTable) {
-    this.inner = tbl.createIndex();
-  }
-
-  /** Instruct the builder to build an `IVF_PQ` index */
-  ivf_pq(options?: IvfPQOptions): IndexBuilder {
-    this.inner.ivfPq(
-      options?.metric_type,
-      options?.num_partitions,
-      options?.num_sub_vectors,
-      options?.num_bits,
-      options?.max_iterations,
-      options?.sample_rate,
-    );
-    return this;
-  }
-
-  /** Instruct the builder to build a Scalar index. */
-  scalar(): IndexBuilder {
-    this.scalar();
-    return this;
-  }
-
-  /** Set the column(s) to create index on top of. */
-  column(col: string): IndexBuilder {
-    this.inner.column(col);
-    return this;
-  }
-
-  /** Set to true to replace existing index. */
-  replace(val: boolean): IndexBuilder {
-    this.inner.replace(val);
-    return this;
-  }
-
-  /** Specify the name of the index. Optional */
-  name(n: string): IndexBuilder {
-    this.inner.name(n);
-    return this;
-  }
-
-  /** Building the index. */
-  async build() {
-    await this.inner.build();
-  }
-}
195  nodejs/lancedb/indices.ts  (new file)
@@ -0,0 +1,195 @@
+// Copyright 2024 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import { Index as LanceDbIndex } from "./native";
+
+/**
+ * Options to create an `IVF_PQ` index
+ */
+export interface IvfPqOptions {
+  /** The number of IVF partitions to create.
+   *
+   * This value should generally scale with the number of rows in the dataset.
+   * By default the number of partitions is the square root of the number of
+   * rows.
+   *
+   * If this value is too large then the first part of the search (picking the
+   * right partition) will be slow. If this value is too small then the second
+   * part of the search (searching within a partition) will be slow.
+   */
+  numPartitions?: number;
+
+  /** Number of sub-vectors of PQ.
+   *
+   * This value controls how much the vector is compressed during the quantization step.
+   * The more sub vectors there are the less the vector is compressed. The default is
+   * the dimension of the vector divided by 16. If the dimension is not evenly divisible
+   * by 16 we use the dimension divided by 8.
+   *
+   * The above two cases are highly preferred. Having 8 or 16 values per subvector allows
+   * us to use efficient SIMD instructions.
+   *
+   * If the dimension is not divisible by 8 then we use 1 subvector. This is not ideal and
+   * will likely result in poor performance.
+   */
+  numSubVectors?: number;
+
+  /** [DistanceType] to use to build the index.
+   *
+   * Default value is [DistanceType::L2].
+   *
+   * This is used when training the index to calculate the IVF partitions
+   * (vectors are grouped in partitions with similar vectors according to this
+   * distance type) and to calculate a subvector's code during quantization.
+   *
+   * The distance type used to train an index MUST match the distance type used
+   * to search the index. Failure to do so will yield inaccurate results.
+   *
+   * The following distance types are available:
+   *
+   * "l2" - Euclidean distance. This is a very common distance metric that
+   * accounts for both magnitude and direction when determining the distance
+   * between vectors. L2 distance has a range of [0, ∞).
+   *
+   * "cosine" - Cosine distance. Cosine distance is a distance metric
+   * calculated from the cosine similarity between two vectors. Cosine
+   * similarity is a measure of similarity between two non-zero vectors of an
+   * inner product space. It is defined to equal the cosine of the angle
+   * between them. Unlike L2, the cosine distance is not affected by the
+   * magnitude of the vectors. Cosine distance has a range of [0, 2].
+   *
+   * Note: the cosine distance is undefined when one (or both) of the vectors
+   * are all zeros (there is no direction). These vectors are invalid and may
+   * never be returned from a vector search.
+   *
+   * "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
+   * distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
+   * L2 norm is 1), then dot distance is equivalent to the cosine distance.
+   */
+  distanceType?: "l2" | "cosine" | "dot";
+
+  /** Max iteration to train IVF kmeans.
+   *
+   * When training an IVF PQ index we use kmeans to calculate the partitions. This parameter
+   * controls how many iterations of kmeans to run.
+   *
+   * Increasing this might improve the quality of the index but in most cases these extra
+   * iterations have diminishing returns.
+   *
+   * The default value is 50.
+   */
+  maxIterations?: number;
+
+  /** The number of vectors, per partition, to sample when training IVF kmeans.
+   *
+   * When an IVF PQ index is trained, we need to calculate partitions. These are groups
+   * of vectors that are similar to each other. To do this we use an algorithm called kmeans.
+   *
+   * Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
+   * random sample of the data. This parameter controls the size of the sample. The total
+   * number of vectors used to train the index is `sample_rate * num_partitions`.
+   *
+   * Increasing this value might improve the quality of the index but in most cases the
+   * default should be sufficient.
+   *
+   * The default value is 256.
+   */
+  sampleRate?: number;
+}
+
+export class Index {
+  private readonly inner: LanceDbIndex;
+  private constructor(inner: LanceDbIndex) {
+    this.inner = inner;
+  }
+
+  /**
+   * Create an IvfPq index
+   *
+   * This index stores a compressed (quantized) copy of every vector. These vectors
+   * are grouped into partitions of similar vectors. Each partition keeps track of
+   * a centroid which is the average value of all vectors in the group.
+   *
+   * During a query the centroids are compared with the query vector to find the closest
+   * partitions. The compressed vectors in these partitions are then searched to find
+   * the closest vectors.
+   *
+   * The compression scheme is called product quantization. Each vector is divided into
+   * subvectors and then each subvector is quantized into a small number of bits. The
+   * parameters `num_bits` and `num_subvectors` control this process, providing a tradeoff
+   * between index size (and thus search speed) and index accuracy.
+   *
+   * The partitioning process is called IVF and the `num_partitions` parameter controls how
+   * many groups to create.
+   *
+   * Note that training an IVF PQ index on a large dataset is a slow operation and
+   * currently is also a memory intensive operation.
+   */
+  static ivfPq(options?: Partial<IvfPqOptions>) {
+    return new Index(
+      LanceDbIndex.ivfPq(
+        options?.distanceType,
+        options?.numPartitions,
+        options?.numSubVectors,
+        options?.maxIterations,
+        options?.sampleRate,
+      ),
+    );
+  }
+
+  /** Create a btree index
+   *
+   * A btree index is an index on scalar columns. The index stores a copy of the column
+   * in sorted order. A header entry is created for each block of rows (currently the
+   * block size is fixed at 4096). These header entries are stored in a separate
+   * cacheable structure (a btree). To search for data the header is used to determine
+   * which blocks need to be read from disk.
+   *
+   * For example, a btree index in a table with 1Bi rows requires sizeof(Scalar) * 256Ki
+   * bytes of memory and will generally need to read sizeof(Scalar) * 4096 bytes to find
+   * the correct row ids.
+   *
+   * This index is good for scalar columns with mostly distinct values and does best when
+   * the query is highly selective.
+   *
+   * The btree index does not currently have any parameters though parameters such as the
+   * block size may be added in the future.
+   */
+  static btree() {
+    return new Index(LanceDbIndex.btree());
+  }
+}
+
+export interface IndexOptions {
+  /** Advanced index configuration
+   *
+   * This option allows you to specify a specific index to create and also
+   * allows you to pass in configuration for training the index.
+   *
+   * See the static methods on Index for details on the various index types.
+   *
+   * If this is not supplied then column data type(s) and column statistics
+   * will be used to determine the most useful kind of index to create.
+   */
+  config?: Index;
+  /** Whether to replace the existing index
+   *
+   * If this is false, and another index already exists on the same columns
+   * and the same name, then an error will be returned. This is true even if
+   * that index is out of date.
+   *
+   * The default is true
+   */
+  replace?: boolean;
+}
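Together with the reworked tests, this module replaces the old `IndexBuilder` chain: callers now pass an optional `Index` configuration to `Table.createIndex`. A hedged sketch of the new surface (the import paths are assumptions — the tests import `Index` from `"../dist/indices"` — and the table/column names are placeholders):

```typescript
import { connect } from "@lancedb/lancedb"; // assumed package entry point
import { Index } from "@lancedb/lancedb/indices"; // assumed path; tests use "../dist/indices"

async function buildIndices(dir: string) {
  const db = await connect(dir);
  const tbl = await db.openTable("test");

  // Vector index with explicit IVF_PQ parameters.
  await tbl.createIndex("vec", {
    config: Index.ivfPq({ numPartitions: 10, distanceType: "l2" }),
  });

  // Scalar BTree index, explicitly replacing any existing index on the column.
  await tbl.createIndex("id", { config: Index.btree(), replace: true });

  // With no config, the column's data type decides which kind of index is trained.
  await tbl.createIndex("id");
}
```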
37  nodejs/lancedb/native.d.ts  (vendored)
@@ -3,14 +3,17 @@

 /* auto-generated by NAPI-RS */

-export const enum IndexType {
-  Scalar = 0,
-  IvfPq = 1
-}
-export const enum MetricType {
-  L2 = 0,
-  Cosine = 1,
-  Dot = 2
+/** A description of an index currently configured on a column */
+export interface IndexConfig {
+  /** The type of the index */
+  indexType: string
+  /**
+   * The columns in the index
+   *
+   * Currently this is always an array of size 1. In the future there may
+   * be more columns to represent composite indices.
+   */
+  columns: Array<string>
 }
 /**
  * A definition of a column alteration. The alteration changes the column at
@@ -93,13 +96,9 @@ export class Connection {
   /** Drop table with the name. Or raise an error if the table does not exist. */
   dropTable(name: string): Promise<void>
 }
-export class IndexBuilder {
-  replace(v: boolean): void
-  column(c: string): void
-  name(name: string): void
-  ivfPq(metricType?: MetricType | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, numBits?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): void
-  scalar(): void
-  build(): Promise<void>
+export class Index {
+  static ivfPq(distanceType?: string | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): Index
+  static btree(): Index
 }
 /** Typescript-style Async Iterator over RecordBatches */
 export class RecordBatchIterator {
@@ -125,9 +124,15 @@ export class Table {
   add(buf: Buffer, mode: string): Promise<void>
   countRows(filter?: string | undefined | null): Promise<number>
   delete(predicate: string): Promise<void>
-  createIndex(): IndexBuilder
+  createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void>
+  update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
   query(): Query
   addColumns(transforms: Array<AddColumnsSql>): Promise<void>
   alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
   dropColumns(columns: Array<string>): Promise<void>
+  version(): Promise<number>
+  checkout(version: number): Promise<void>
+  checkoutLatest(): Promise<void>
+  restore(): Promise<void>
+  listIndices(): Promise<Array<IndexConfig>>
 }
@@ -295,12 +295,10 @@ if (!nativeBinding) {
   throw new Error(`Failed to load native binding`)
 }

-const { Connection, IndexType, MetricType, IndexBuilder, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding
+const { Connection, Index, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding

 module.exports.Connection = Connection
-module.exports.IndexType = IndexType
-module.exports.MetricType = MetricType
-module.exports.IndexBuilder = IndexBuilder
+module.exports.Index = Index
 module.exports.RecordBatchIterator = RecordBatchIterator
 module.exports.Query = Query
 module.exports.Table = Table
@@ -16,12 +16,14 @@ import { Schema, tableFromIPC } from "apache-arrow";
 import {
   AddColumnsSql,
   ColumnAlteration,
+  IndexConfig,
   Table as _NativeTable,
 } from "./native";
 import { Query } from "./query";
-import { IndexBuilder } from "./indexer";
+import { IndexOptions } from "./indices";
 import { Data, fromDataToBuffer } from "./arrow";

+export { IndexConfig } from "./native";
 /**
  * Options for adding data to a table.
  */
@@ -33,6 +35,20 @@ export interface AddDataOptions {
   mode: "append" | "overwrite";
 }

+export interface UpdateOptions {
+  /**
+   * A filter that limits the scope of the update.
+   *
+   * This should be an SQL filter expression.
+   *
+   * Only rows that satisfy the expression will be updated.
+   *
+   * For example, this could be 'my_col == 0' to replace all instances
+   * of 0 in a column with some other default value.
+   */
+  where: string;
+}
+
 /**
  * A Table is a collection of Records in a LanceDB Database.
 *
@@ -93,6 +109,45 @@ export class Table {
     await this.inner.add(buffer, mode);
   }

+  /**
+   * Update existing records in the Table
+   *
+   * An update operation can be used to adjust existing values. Use the
+   * returned builder to specify which columns to update. The new value
+   * can be a literal value (e.g. replacing nulls with some default value)
+   * or an expression applied to the old value (e.g. incrementing a value)
+   *
+   * An optional condition can be specified (e.g. "only update if the old
+   * value is 0")
+   *
+   * Note: if your condition is something like "some_id_column == 7" and
+   * you are updating many rows (with different ids) then you will get
+   * better performance with a single [`merge_insert`] call instead of
+   * repeatedly calling this method.
+   *
+   * @param updates the columns to update
+   *
+   * Keys in the map should specify the name of the column to update.
+   * Values in the map provide the new value of the column. These can
+   * be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions
+   * based on the row being updated (e.g. "my_col + 1")
+   *
+   * @param options additional options to control the update behavior
+   */
+  async update(
+    updates: Map<string, string> | Record<string, string>,
+    options?: Partial<UpdateOptions>,
+  ) {
+    const onlyIf = options?.where;
+    let columns: [string, string][];
+    if (updates instanceof Map) {
+      columns = Array.from(updates.entries());
+    } else {
+      columns = Object.entries(updates);
+    }
+    await this.inner.update(onlyIf, columns);
+  }
+
   /** Count the total number of rows in the dataset. */
   async countRows(filter?: string): Promise<number> {
     return await this.inner.countRows(filter);
@@ -103,24 +158,28 @@ export class Table {
     await this.inner.delete(predicate);
   }

-  /** Create an index over the columns.
+  /** Create an index to speed up queries.
    *
-   * @param {string} column The column to create the index on. If not specified,
-   * it will create an index on vector field.
+   * Indices can be created on vector columns or scalar columns.
+   * Indices on vector columns will speed up vector searches.
+   * Indices on scalar columns will speed up filtering (in both
+   * vector and non-vector searches)
    *
    * @example
    *
-   * By default, it creates vector idnex on one vector column.
+   * If the column has a vector (fixed size list) data type then
+   * an IvfPq vector index will be created.
    *
    * ```typescript
    * const table = await conn.openTable("my_table");
-   * await table.createIndex().build();
+   * await table.createIndex(["vector"]);
    * ```
    *
-   * You can specify `IVF_PQ` parameters via `ivf_pq({})` call.
+   * For advanced control over vector index creation you can specify
+   * the index type and options.
    * ```typescript
    * const table = await conn.openTable("my_table");
-   * await table.createIndex("my_vec_col")
+   * await table.createIndex(["vector"], I)
    * .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 })
    * .build();
    * ```
@@ -131,12 +190,11 @@ export class Table {
    * await table.createIndex("my_float_col").build();
    * ```
    */
-  createIndex(column?: string): IndexBuilder {
-    let builder = new IndexBuilder(this.inner);
-    if (column !== undefined) {
-      builder = builder.column(column);
-    }
-    return builder;
+  async createIndex(column: string, options?: Partial<IndexOptions>) {
+    // Bit of a hack to get around the fact that TS has no package-scope.
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const nativeIndex = (options?.config as any)?.inner;
+    await this.inner.createIndex(nativeIndex, column, options?.replace);
   }

   /**
@@ -232,4 +290,65 @@ export class Table {
   async dropColumns(columnNames: string[]): Promise<void> {
     await this.inner.dropColumns(columnNames);
   }
+
+  /** Retrieve the version of the table
+   *
+   * LanceDb supports versioning. Every operation that modifies the table increases
+   * version. As long as a version hasn't been deleted you can `[Self::checkout]` that
+   * version to view the data at that point. In addition, you can `[Self::restore]` the
+   * version to replace the current table with a previous version.
+   */
+  async version(): Promise<number> {
+    return await this.inner.version();
+  }
+
+  /** Checks out a specific version of the Table
+   *
+   * Any read operation on the table will now access the data at the checked out version.
+   * As a consequence, calling this method will disable any read consistency interval
+   * that was previously set.
+   *
+   * This is a read-only operation that turns the table into a sort of "view"
+   * or "detached head". Other table instances will not be affected. To make the change
+   * permanent you can use the `[Self::restore]` method.
+   *
+   * Any operation that modifies the table will fail while the table is in a checked
+   * out state.
+   *
+   * To return the table to a normal state use `[Self::checkout_latest]`
+   */
+  async checkout(version: number): Promise<void> {
+    await this.inner.checkout(version);
+  }
+
+  /** Ensures the table is pointing at the latest version
+   *
+   * This can be used to manually update a table when the read_consistency_interval is None
+   * It can also be used to undo a `[Self::checkout]` operation
+   */
+  async checkoutLatest(): Promise<void> {
+    await this.inner.checkoutLatest();
+  }
+
+  /** Restore the table to the currently checked out version
+   *
+   * This operation will fail if checkout has not been called previously
+   *
+   * This operation will overwrite the latest version of the table with a
+   * previous version. Any changes made since the checked out version will
+   * no longer be visible.
+   *
+   * Once the operation concludes the table will no longer be in a checked
+   * out state and the read_consistency_interval, if any, will apply.
+   */
+  async restore(): Promise<void> {
+    await this.inner.restore();
+  }
+
+  /**
+   * List all indices that have been created with Self::create_index
+   */
+  async listIndices(): Promise<IndexConfig[]> {
+    return await this.inner.listIndices();
+  }
 }
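The wrapper added above turns `Table.update` into a single call that takes a map of column names to SQL expressions plus an optional `where` filter, and forwards them to the native `update(onlyIf, columns)` binding. A short sketch, assuming an already-open table with an `id` column (as in the update test earlier):

```typescript
import { Table } from "@lancedb/lancedb"; // import name assumed

async function renumber(table: Table) {
  // Values are SQL expressions: "id + 1" references the old value of the row.
  await table.update({ id: "id + 1" }, { where: "id % 2 == 0" });

  // A Map is accepted as well and behaves the same as a plain object.
  await table.update(new Map(Object.entries({ id: "10" })), { where: "id % 2 == 0" });
}
```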
12  nodejs/src/error.rs  (new file)
@@ -0,0 +1,12 @@
+pub type Result<T> = napi::Result<T>;
+
+pub trait NapiErrorExt<T> {
+    /// Convert to a napi error using from_reason(err.to_string())
+    fn default_error(self) -> Result<T>;
+}
+
+impl<T> NapiErrorExt<T> for std::result::Result<T, lancedb::Error> {
+    fn default_error(self) -> Result<T> {
+        self.map_err(|err| napi::Error::from_reason(err.to_string()))
+    }
+}
@@ -14,126 +14,73 @@

 use std::sync::Mutex;

-use lance_linalg::distance::MetricType as LanceMetricType;
-use lancedb::index::IndexBuilder as LanceDbIndexBuilder;
-use lancedb::Table as LanceDbTable;
+use lancedb::index::scalar::BTreeIndexBuilder;
+use lancedb::index::vector::IvfPqIndexBuilder;
+use lancedb::index::Index as LanceDbIndex;
+use lancedb::DistanceType;
 use napi_derive::napi;

 #[napi]
-pub enum IndexType {
-    Scalar,
-    IvfPq,
+pub struct Index {
+    inner: Mutex<Option<LanceDbIndex>>,
 }

-#[napi]
-pub enum MetricType {
-    L2,
-    Cosine,
-    Dot,
-}
-
-impl From<MetricType> for LanceMetricType {
-    fn from(metric: MetricType) -> Self {
-        match metric {
-            MetricType::L2 => Self::L2,
-            MetricType::Cosine => Self::Cosine,
-            MetricType::Dot => Self::Dot,
-        }
+impl Index {
+    pub fn consume(&self) -> napi::Result<LanceDbIndex> {
+        self.inner
+            .lock()
+            .unwrap()
+            .take()
+            .ok_or(napi::Error::from_reason(
+                "attempt to use an index more than once",
+            ))
     }
 }

 #[napi]
-pub struct IndexBuilder {
-    inner: Mutex<Option<LanceDbIndexBuilder>>,
-}
-
-impl IndexBuilder {
-    fn modify(
-        &self,
-        mod_fn: impl Fn(LanceDbIndexBuilder) -> LanceDbIndexBuilder,
-    ) -> napi::Result<()> {
-        let mut inner = self.inner.lock().unwrap();
-        let inner_builder = inner.take().ok_or_else(|| {
-            napi::Error::from_reason("IndexBuilder has already been consumed".to_string())
-        })?;
-        let inner_builder = mod_fn(inner_builder);
-        inner.replace(inner_builder);
-        Ok(())
-    }
-}
-
-#[napi]
-impl IndexBuilder {
-    pub fn new(tbl: &LanceDbTable) -> Self {
-        let inner = tbl.create_index(&[]);
-        Self {
-            inner: Mutex::new(Some(inner)),
-        }
-    }
-
-    #[napi]
-    pub fn replace(&self, v: bool) -> napi::Result<()> {
-        self.modify(|b| b.replace(v))
-    }
-
-    #[napi]
-    pub fn column(&self, c: String) -> napi::Result<()> {
-        self.modify(|b| b.columns(&[c.as_str()]))
-    }
-
-    #[napi]
-    pub fn name(&self, name: String) -> napi::Result<()> {
-        self.modify(|b| b.name(name.as_str()))
-    }
-
-    #[napi]
+impl Index {
+    #[napi(factory)]
     pub fn ivf_pq(
-        &self,
-        metric_type: Option<MetricType>,
+        distance_type: Option<String>,
         num_partitions: Option<u32>,
         num_sub_vectors: Option<u32>,
-        num_bits: Option<u32>,
         max_iterations: Option<u32>,
         sample_rate: Option<u32>,
-    ) -> napi::Result<()> {
-        self.modify(|b| {
-            let mut b = b.ivf_pq();
-            if let Some(metric_type) = metric_type {
-                b = b.metric_type(metric_type.into());
+    ) -> napi::Result<Self> {
+        let mut ivf_pq_builder = IvfPqIndexBuilder::default();
+        if let Some(distance_type) = distance_type {
+            let distance_type = match distance_type.as_str() {
+                "l2" => Ok(DistanceType::L2),
+                "cosine" => Ok(DistanceType::Cosine),
+                "dot" => Ok(DistanceType::Dot),
+                _ => Err(napi::Error::from_reason(format!(
+                    "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
+                    distance_type
+                ))),
+            }?;
+            ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
         }
         if let Some(num_partitions) = num_partitions {
-            b = b.num_partitions(num_partitions);
+            ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
         }
         if let Some(num_sub_vectors) = num_sub_vectors {
-            b = b.num_sub_vectors(num_sub_vectors);
-        }
-        if let Some(num_bits) = num_bits {
-            b = b.num_bits(num_bits);
+            ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
         }
         if let Some(max_iterations) = max_iterations {
-            b = b.max_iterations(max_iterations);
+            ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
         }
         if let Some(sample_rate) = sample_rate {
-            b = b.sample_rate(sample_rate);
+            ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
         }
-        b
+        Ok(Self {
+            inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
         })
     }

-    #[napi]
-    pub fn scalar(&self) -> napi::Result<()> {
-        self.modify(|b| b.scalar())
-    }
-
-    #[napi]
-    pub async fn build(&self) -> napi::Result<()> {
-        let inner = self.inner.lock().unwrap().take().ok_or_else(|| {
-            napi::Error::from_reason("IndexBuilder has already been consumed".to_string())
-        })?;
-        inner
-            .build()
-            .await
-            .map_err(|e| napi::Error::from_reason(format!("Failed to build index: {}", e)))?;
-        Ok(())
+    #[napi(factory)]
+    pub fn btree() -> Self {
+        Self {
+            inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
+        }
     }
 }
@@ -13,7 +13,7 @@
 // limitations under the License.

 use futures::StreamExt;
-use lance::io::RecordBatchStream;
+use lancedb::arrow::SendableRecordBatchStream;
 use lancedb::ipc::batches_to_ipc_file;
 use napi::bindgen_prelude::*;
 use napi_derive::napi;
@@ -21,12 +21,12 @@ use napi_derive::napi;
 /** Typescript-style Async Iterator over RecordBatches */
 #[napi]
 pub struct RecordBatchIterator {
-    inner: Box<dyn RecordBatchStream + Unpin>,
+    inner: SendableRecordBatchStream,
 }

 #[napi]
 impl RecordBatchIterator {
-    pub(crate) fn new(inner: Box<dyn RecordBatchStream + Unpin>) -> Self {
+    pub(crate) fn new(inner: SendableRecordBatchStream) -> Self {
         Self { inner }
     }

@@ -16,6 +16,7 @@ use connection::Connection;
 use napi_derive::*;

 mod connection;
+mod error;
 mod index;
 mod iterator;
 mod query;
@@ -74,6 +74,6 @@ impl Query {
         let inner_stream = self.inner.execute_stream().await.map_err(|e| {
             napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
         })?;
-        Ok(RecordBatchIterator::new(Box::new(inner_stream)))
+        Ok(RecordBatchIterator::new(inner_stream))
     }
 }
@@ -13,13 +13,16 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use arrow_ipc::writer::FileWriter;
|
use arrow_ipc::writer::FileWriter;
|
||||||
use lance::dataset::ColumnAlteration as LanceColumnAlteration;
|
|
||||||
use lancedb::ipc::ipc_file_to_batches;
|
use lancedb::ipc::ipc_file_to_batches;
|
||||||
use lancedb::table::{AddDataMode, Table as LanceDbTable};
|
use lancedb::table::{
|
||||||
|
AddDataMode, ColumnAlteration as LanceColumnAlteration, NewColumnTransform,
|
||||||
|
Table as LanceDbTable,
|
||||||
|
};
|
||||||
use napi::bindgen_prelude::*;
|
use napi::bindgen_prelude::*;
|
use napi_derive::napi;

use crate::index::IndexBuilder;
use crate::error::NapiErrorExt;
use crate::index::Index;
use crate::query::Query;

#[napi]
@@ -129,8 +132,38 @@ impl Table {
    }

    #[napi]
    pub fn create_index(&self) -> napi::Result<IndexBuilder> {
        Ok(IndexBuilder::new(self.inner_ref()?))
    pub async fn create_index(
        &self,
        index: Option<&Index>,
        column: String,
        replace: Option<bool>,
    ) -> napi::Result<()> {
        let lancedb_index = if let Some(index) = index {
            index.consume()?
        } else {
            lancedb::index::Index::Auto
        };
        let mut builder = self.inner_ref()?.create_index(&[column], lancedb_index);
        if let Some(replace) = replace {
            builder = builder.replace(replace);
        }
        builder.execute().await.default_error()
    }

    #[napi]
    pub async fn update(
        &self,
        only_if: Option<String>,
        columns: Vec<(String, String)>,
    ) -> napi::Result<()> {
        let mut op = self.inner_ref()?.update();
        if let Some(only_if) = only_if {
            op = op.only_if(only_if);
        }
        for (column_name, value) in columns {
            op = op.column(column_name, value);
        }
        op.execute().await.default_error()
    }

    #[napi]
@@ -144,7 +177,7 @@ impl Table {
            .into_iter()
            .map(|sql| (sql.name, sql.value_sql))
            .collect::<Vec<_>>();
        let transforms = lance::dataset::NewColumnTransform::SqlExpressions(transforms);
        let transforms = NewColumnTransform::SqlExpressions(transforms);
        self.inner_ref()?
            .add_columns(transforms, None)
            .await
@@ -197,6 +230,67 @@ impl Table {
            })?;
        Ok(())
    }

    #[napi]
    pub async fn version(&self) -> napi::Result<i64> {
        self.inner_ref()?
            .version()
            .await
            .map(|val| val as i64)
            .default_error()
    }

    #[napi]
    pub async fn checkout(&self, version: i64) -> napi::Result<()> {
        self.inner_ref()?
            .checkout(version as u64)
            .await
            .default_error()
    }

    #[napi]
    pub async fn checkout_latest(&self) -> napi::Result<()> {
        self.inner_ref()?.checkout_latest().await.default_error()
    }

    #[napi]
    pub async fn restore(&self) -> napi::Result<()> {
        self.inner_ref()?.restore().await.default_error()
    }

    #[napi]
    pub async fn list_indices(&self) -> napi::Result<Vec<IndexConfig>> {
        Ok(self
            .inner_ref()?
            .list_indices()
            .await
            .default_error()?
            .into_iter()
            .map(IndexConfig::from)
            .collect::<Vec<_>>())
    }
}

#[napi(object)]
/// A description of an index currently configured on a column
pub struct IndexConfig {
    /// The type of the index
    pub index_type: String,
    /// The columns in the index
    ///
    /// Currently this is always an array of size 1. In the future there may
    /// be more columns to represent composite indices.
    pub columns: Vec<String>,
}

impl From<lancedb::index::IndexConfig> for IndexConfig {
    fn from(value: lancedb::index::IndexConfig) -> Self {
        let index_type = format!("{:?}", value.index_type);
        Self {
            index_type,
            columns: value.columns,
        }
    }
}

/// A definition of a column alteration. The alteration changes the column at
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.2
current_version = 0.6.3
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

@@ -1,6 +1,6 @@
[project]
name = "lancedb"
version = "0.6.2"
version = "0.6.3"
dependencies = [
    "deprecation",
    "pylance==0.10.2",
@@ -57,6 +57,7 @@ tests = [
    "duckdb",
    "pytz",
    "polars>=0.19",
    "pillow",
]
dev = ["ruff", "pre-commit"]
docs = [
@@ -23,8 +23,9 @@ from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector  # noqa: F401
from .schema import vector
from .utils import sentry_log  # noqa: F401
from .table import AsyncTable
from .utils import sentry_log


def connect(
@@ -35,6 +36,7 @@ def connect(
    host_override: Optional[str] = None,
    read_consistency_interval: Optional[timedelta] = None,
    request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
    **kwargs,
) -> DBConnection:
    """Connect to a LanceDB database.

@@ -99,7 +101,12 @@ def connect(
        if isinstance(request_thread_pool, int):
            request_thread_pool = ThreadPoolExecutor(request_thread_pool)
        return RemoteDBConnection(
            uri, api_key, region, host_override, request_thread_pool=request_thread_pool
            uri,
            api_key,
            region,
            host_override,
            request_thread_pool=request_thread_pool,
            **kwargs,
        )
    return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)

@@ -182,3 +189,19 @@ async def connect_async(
            read_consistency_interval_secs,
        )
    )


__all__ = [
    "connect",
    "connect_async",
    "AsyncConnection",
    "AsyncTable",
    "URI",
    "sanitize_uri",
    "sentry_log",
    "vector",
    "DBConnection",
    "LanceDBConnection",
    "RemoteDBConnection",
    "__version__",
]
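The re-exported surface above can be exercised end to end through the async API; a minimal sketch, assuming only a local ./.lancedb directory (the table name and data are illustrative):

import asyncio

import lancedb


async def demo():
    # connect_async returns an AsyncConnection; tables made from it are AsyncTable
    db = await lancedb.connect_async("./.lancedb")
    table = await db.create_table("items", data=[{"id": 0}])
    print(await table.count_rows())


asyncio.run(demo())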
@@ -1,7 +1,19 @@
from typing import Optional
from typing import Dict, List, Optional

import pyarrow as pa


class Index:
    @staticmethod
    def ivf_pq(
        distance_type: Optional[str],
        num_partitions: Optional[int],
        num_sub_vectors: Optional[int],
        max_iterations: Optional[int],
        sample_rate: Optional[int],
    ) -> Index: ...
    @staticmethod
    def btree() -> Index: ...

class Connection(object):
    async def table_names(
        self, start_after: Optional[str], limit: Optional[int]
@@ -13,10 +25,25 @@ class Connection(object):
        self, name: str, mode: str, schema: pa.Schema
    ) -> Table: ...

class Table(object):
class Table:
    def name(self) -> str: ...
    def __repr__(self) -> str: ...
    async def schema(self) -> pa.Schema: ...
    async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
    async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
    async def count_rows(self, filter: Optional[str]) -> int: ...
    async def create_index(
        self, column: str, config: Optional[Index], replace: Optional[bool]
    ): ...
    async def version(self) -> int: ...
    async def checkout(self, version): ...
    async def checkout_latest(self): ...
    async def restore(self): ...
    async def list_indices(self) -> List[IndexConfig]: ...

class IndexConfig:
    index_type: str
    columns: List[str]

async def connect(
    uri: str,
@@ -529,7 +529,7 @@ class AsyncConnection(object):
        on_bad_vectors: Optional[str] = None,
        fill_value: Optional[float] = None,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
    ) -> Table:
    ) -> AsyncTable:
        """Create a [Table][lancedb.table.Table] in the database.

        Parameters
@@ -126,6 +126,10 @@ class OpenClipEmbeddings(EmbeddingFunction):
        """
        Issue concurrent requests to retrieve the image data
        """
        return [
            self.generate_image_embedding(image) for image in tqdm(images)
        ]

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_embedding, image)
163  python/python/lancedb/index.py  Normal file
@@ -0,0 +1,163 @@
from typing import Optional

from ._lancedb import (
    Index as LanceDbIndex,
)
from ._lancedb import (
    IndexConfig,
)


class BTree(object):
    """Describes a btree index configuration

    A btree index is an index on scalar columns. The index stores a copy of the
    column in sorted order. A header entry is created for each block of rows
    (currently the block size is fixed at 4096). These header entries are stored
    in a separate cacheable structure (a btree). To search for data the header is
    used to determine which blocks need to be read from disk.

    For example, a btree index in a table with 1Bi rows requires
    sizeof(Scalar) * 256Ki bytes of memory and will generally need to read
    sizeof(Scalar) * 4096 bytes to find the correct row ids.

    This index is good for scalar columns with mostly distinct values and does best
    when the query is highly selective.

    The btree index does not currently have any parameters though parameters such as
    the block size may be added in the future.
    """

    def __init__(self):
        self._inner = LanceDbIndex.btree()


class IvfPq(object):
    """Describes an IVF PQ Index

    This index stores a compressed (quantized) copy of every vector. These vectors
    are grouped into partitions of similar vectors. Each partition keeps track of
    a centroid which is the average value of all vectors in the group.

    During a query the centroids are compared with the query vector to find the
    closest partitions. The compressed vectors in these partitions are then
    searched to find the closest vectors.

    The compression scheme is called product quantization. Each vector is divided
    into subvectors and then each subvector is quantized into a small number of
    bits. The parameters `num_bits` and `num_subvectors` control this process,
    providing a tradeoff between index size (and thus search speed) and index
    accuracy.

    The partitioning process is called IVF and the `num_partitions` parameter
    controls how many groups to create.

    Note that training an IVF PQ index on a large dataset is a slow operation and
    currently is also a memory intensive operation.
    """

    def __init__(
        self,
        *,
        distance_type: Optional[str] = None,
        num_partitions: Optional[int] = None,
        num_sub_vectors: Optional[int] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[int] = None,
    ):
        """
        Create an IVF PQ index config

        Parameters
        ----------
        distance_type: str, default "L2"
            The distance metric used to train the index

            This is used when training the index to calculate the IVF partitions
            (vectors are grouped in partitions with similar vectors according to this
            distance type) and to calculate a subvector's code during quantization.

            The distance type used to train an index MUST match the distance type used
            to search the index. Failure to do so will yield inaccurate results.

            The following distance types are available:

            "l2" - Euclidean distance. This is a very common distance metric that
            accounts for both magnitude and direction when determining the distance
            between vectors. L2 distance has a range of [0, ∞).

            "cosine" - Cosine distance. Cosine distance is a distance metric
            calculated from the cosine similarity between two vectors. Cosine
            similarity is a measure of similarity between two non-zero vectors of an
            inner product space. It is defined to equal the cosine of the angle
            between them. Unlike L2, the cosine distance is not affected by the
            magnitude of the vectors. Cosine distance has a range of [0, 2].

            Note: the cosine distance is undefined when one (or both) of the vectors
            are all zeros (there is no direction). These vectors are invalid and may
            never be returned from a vector search.

            "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
            distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
            L2 norm is 1), then dot distance is equivalent to the cosine distance.
        num_partitions: int, default sqrt(num_rows)
            The number of IVF partitions to create.

            This value should generally scale with the number of rows in the dataset.
            By default the number of partitions is the square root of the number of
            rows.

            If this value is too large then the first part of the search (picking the
            right partition) will be slow. If this value is too small then the second
            part of the search (searching within a partition) will be slow.
        num_sub_vectors: int, default is vector dimension / 16
            Number of sub-vectors of PQ.

            This value controls how much the vector is compressed during the
            quantization step. The more sub vectors there are the less the vector is
            compressed. The default is the dimension of the vector divided by 16. If
            the dimension is not evenly divisible by 16 we use the dimension divided by
            8.

            The above two cases are highly preferred. Having 8 or 16 values per
            subvector allows us to use efficient SIMD instructions.

            If the dimension is not divisible by 8 then we use 1 subvector. This is not
            ideal and will likely result in poor performance.
        max_iterations: int, default 50
            Max iteration to train kmeans.

            When training an IVF PQ index we use kmeans to calculate the partitions.
            This parameter controls how many iterations of kmeans to run.

            Increasing this might improve the quality of the index but in most cases
            these extra iterations have diminishing returns.

            The default value is 50.
        sample_rate: int, default 256
            The rate used to calculate the number of training vectors for kmeans.

            When an IVF PQ index is trained, we need to calculate partitions. These
            are groups of vectors that are similar to each other. To do this we use an
            algorithm called kmeans.

            Running kmeans on a large dataset can be slow. To speed this up we run
            kmeans on a random sample of the data. This parameter controls the size of
            the sample. The total number of vectors used to train the index is
            `sample_rate * num_partitions`.

            Increasing this value might improve the quality of the index but in most
            cases the default should be sufficient.

            The default value is 256.
        """
        self._inner = LanceDbIndex.ivf_pq(
            distance_type=distance_type,
            num_partitions=num_partitions,
            num_sub_vectors=num_sub_vectors,
            max_iterations=max_iterations,
            sample_rate=sample_rate,
        )


__all__ = ["BTree", "IvfPq", "IndexConfig"]
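These config classes are passed to AsyncTable.create_index through its config parameter; a minimal sketch mirroring the tests added later in this changeset (the table handle and column names are illustrative):

from lancedb.index import BTree, IvfPq


async def build_indices(table):
    # BTree: scalar index, best for selective filters over mostly-distinct values
    await table.create_index("id", config=BTree())
    # IvfPq: vector index; omitted parameters fall back to the documented defaults
    await table.create_index("vector", config=IvfPq(num_partitions=100), replace=True)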
@@ -16,6 +16,7 @@
from __future__ import annotations

import inspect
import io
import sys
import types
from abc import ABC, abstractmethod
@@ -26,7 +27,9 @@ from typing import (
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Tuple,
    Type,
    Union,
    _GenericAlias,
@@ -36,19 +39,30 @@ import numpy as np
import pyarrow as pa
import pydantic
import semver
from lance.arrow import (
    EncodedImageType,
)
from lance.util import _check_huggingface
from pydantic.fields import FieldInfo
from pydantic_core import core_schema

from .util import attempt_import_or_raise

PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
    from pydantic_core import CoreSchema, core_schema
except ImportError:
    if PYDANTIC_VERSION >= (2,):
        raise

if TYPE_CHECKING:
    from pydantic.fields import FieldInfo

    from .embeddings import EmbeddingFunctionConfig

try:
    from pydantic import GetJsonSchemaHandler
    from pydantic.json_schema import JsonSchemaValue
    from pydantic_core import CoreSchema
except ImportError:
    if PYDANTIC_VERSION >= (2,):
        raise


class FixedSizeListMixin(ABC):
    @staticmethod
@@ -123,7 +137,7 @@ def Vector(
    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
    ) -> CoreSchema:
    ) -> "CoreSchema":
        return core_schema.no_info_after_validator_function(
            cls,
            core_schema.list_schema(
@@ -181,24 +195,117 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
    elif getattr(py_type, "__origin__", None) in (list, tuple):
        child = py_type.__args__[0]
        return pa.list_(_py_type_to_arrow_type(child, field))
    elif _safe_is_huggingface_image():
        import datasets

    raise TypeError(
        f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
    )


if PYDANTIC_VERSION.major < 2:

    def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
        return [
            _pydantic_to_field(name, field) for name, field in model.__fields__.items()
class ImageMixin(ABC):
    @staticmethod
    @abstractmethod
    def value_arrow_type() -> pa.DataType:
        raise NotImplementedError


def EncodedImage():
    attempt_import_or_raise("PIL", "pillow or pip install lancedb[embeddings]")
    import PIL.Image

    class EncodedImage(bytes, ImageMixin):
        """Pydantic type for inlined images.

        !!! warning
            Experimental feature.

        Examples
        --------

        >>> import pydantic
        >>> from lancedb.pydantic import EncodedImage
        ...
        >>> class MyModel(pydantic.BaseModel):
        ...     image: EncodedImage()
        >>> schema = pydantic_to_schema(MyModel)
        >>> assert schema == pa.schema([
        ...     pa.field("image", pa.binary(), False)
        ... ])
        """

        def __repr__(self):
            return "EncodedImage()"

        @staticmethod
        def value_arrow_type() -> pa.DataType:
            return EncodedImageType()

        @classmethod
        def __get_pydantic_core_schema__(
            cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
        ) -> "CoreSchema":
            from_bytes_schema = core_schema.bytes_schema()

            return core_schema.json_or_python_schema(
                json_schema=from_bytes_schema,
                python_schema=core_schema.union_schema(
                    [
                        core_schema.is_instance_schema(PIL.Image.Image),
                        from_bytes_schema,
                    ]
                ),
                serialization=core_schema.plain_serializer_function_ser_schema(
                    lambda instance: cls.validate(instance)
                ),
            )

        @classmethod
        def __get_pydantic_json_schema__(
            cls, _core_schema: "CoreSchema", handler: "GetJsonSchemaHandler"
        ) -> "JsonSchemaValue":
            return handler(core_schema.bytes_schema())

        @classmethod
        def __get_validators__(cls) -> Generator[Callable, None, None]:
            yield cls.validate

        # For pydantic v2
        @classmethod
        def validate(cls, v):
            if isinstance(v, bytes):
                return v
            if isinstance(v, PIL.Image.Image):
                with io.BytesIO() as output:
                    v.save(output, format=v.format)
                    return output.getvalue()
            raise TypeError(
                "EncodedImage can take bytes or PIL.Image.Image "
                f"as input but got {type(v)}"
            )

        if PYDANTIC_VERSION < (2, 0):

            @classmethod
            def __modify_schema__(cls, field_schema: Dict[str, Any]):
                field_schema["type"] = "string"
                field_schema["format"] = "binary"

    return EncodedImage


if PYDANTIC_VERSION.major < 2:

    def _safe_get_fields(model: pydantic.BaseModel):
        return model.__fields__
else:

    def _safe_get_fields(model: pydantic.BaseModel):
        return model.model_fields


def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
    return [
        _pydantic_to_field(name, field)
        for name, field in model.model_fields.items()
        for name, field in _safe_get_fields(model).items()
    ]


@@ -230,6 +337,9 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
        return pa.struct(fields)
    elif issubclass(field.annotation, FixedSizeListMixin):
        return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
    elif issubclass(field.annotation, ImageMixin):
        return field.annotation.value_arrow_type()

    return _py_type_to_arrow_type(field.annotation, field)


@@ -335,13 +445,7 @@ class LanceModel(pydantic.BaseModel):
        """
        Get the field names of this model.
        """
        return list(cls.safe_get_fields().keys())
        return list(_safe_get_fields(cls).keys())

    @classmethod
    def safe_get_fields(cls):
        if PYDANTIC_VERSION.major < 2:
            return cls.__fields__
        return cls.model_fields

    @classmethod
    def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
@@ -351,14 +455,16 @@ class LanceModel(pydantic.BaseModel):
        from .embeddings import EmbeddingFunctionConfig

        vec_and_function = []
        for name, field_info in cls.safe_get_fields().items():
        def get_vector_column(name, field_info):
            func = get_extras(field_info, "vector_column_for")
            fun = get_extras(field_info, "vector_column_for")
            if func is not None:
                vec_and_function.append([name, func])
        visit_fields(_safe_get_fields(cls).items(), get_vector_column)

        configs = []
        # find the source columns for each one
        for vec, func in vec_and_function:
            for source, field_info in cls.safe_get_fields().items():
            def get_source_column(source, field_info):
                src_func = get_extras(field_info, "source_column_for")
                if src_func is func:
                    # note we can't use == here since the function is a pydantic
@@ -371,20 +477,48 @@ class LanceModel(pydantic.BaseModel):
                            source_column=source, vector_column=vec, function=func
                        )
                    )
            visit_fields(_safe_get_fields(cls).items(), get_source_column)
        return configs


def visit_fields(fields: Iterable[Tuple[str, FieldInfo]],
                 visitor: Callable[[str, FieldInfo], Any]):
    """
    Visit all the leaf fields in a Pydantic model.

    Parameters
    ----------
    fields : Iterable[Tuple(str, FieldInfo)]
        The fields to visit.
    visitor : Callable[[str, FieldInfo], Any]
        The visitor function.
    """
    for name, field_info in fields:
        # if the field is a pydantic model then
        # visit all subfields
        if (isinstance(getattr(field_info, "annotation"), type)
                and issubclass(field_info.annotation, pydantic.BaseModel)):
            visit_fields(_safe_get_fields(field_info.annotation).items(),
                         _add_prefix(visitor, name))
        else:
            visitor(name, field_info)


def _add_prefix(visitor: Callable[[str, FieldInfo], Any], prefix: str) -> Callable[[str, FieldInfo], Any]:
    def prefixed_visitor(name: str, field: FieldInfo):
        return visitor(f"{prefix}.{name}", field)
    return prefixed_visitor


if PYDANTIC_VERSION.major < 2:

    def get_extras(field_info: FieldInfo, key: str) -> Any:
        """
        Get the extra metadata from a Pydantic FieldInfo.
        """
        if PYDANTIC_VERSION.major >= 2:
            return (field_info.json_schema_extra or {}).get(key)
        return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)


if PYDANTIC_VERSION.major < 2:

    def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
        """
        Convert a Pydantic model to a dictionary.
@@ -393,6 +527,13 @@ if PYDANTIC_VERSION.major < 2:

else:

    def get_extras(field_info: FieldInfo, key: str) -> Any:
        """
        Get the extra metadata from a Pydantic FieldInfo.
        """
        return (field_info.json_schema_extra or {}).get(key)


    def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
        """
        Convert a Pydantic model to a dictionary.
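A minimal sketch of the new EncodedImage field type on a LanceModel, modeled on the tests added later in this changeset (the model and field names are illustrative; whether bytes or a PIL image is supplied, the validated value is the encoded bytes):

from lancedb.pydantic import EncodedImage, LanceModel


class ImageRecord(LanceModel):
    # stored as a binary-backed EncodedImageType column; accepts bytes or PIL images
    img: EncodedImage()


record = ImageRecord(img=open("1.png", "rb").read())
schema = ImageRecord.to_arrow_schema()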
@@ -58,6 +58,9 @@ class RestfulLanceDBClient:

    closed: bool = attrs.field(default=False, init=False)

    connection_timeout: float = attrs.field(default=120.0, kw_only=True)
    read_timeout: float = attrs.field(default=300.0, kw_only=True)

    @functools.cached_property
    def session(self) -> requests.Session:
        sess = requests.Session()
@@ -117,7 +120,7 @@ class RestfulLanceDBClient:
            urljoin(self.url, uri),
            params=params,
            headers=self.headers,
            timeout=(120.0, 300.0),
            timeout=(self.connection_timeout, self.read_timeout),
        ) as resp:
            self._check_status(resp)
            return resp.json()
@@ -159,7 +162,7 @@ class RestfulLanceDBClient:
            urljoin(self.url, uri),
            headers=headers,
            params=params,
            timeout=(120.0, 300.0),
            timeout=(self.connection_timeout, self.read_timeout),
            **req_kwargs,
        ) as resp:
            self._check_status(resp)
@@ -41,6 +41,8 @@ class RemoteDBConnection(DBConnection):
        region: str,
        host_override: Optional[str] = None,
        request_thread_pool: Optional[ThreadPoolExecutor] = None,
        connection_timeout: float = 120.0,
        read_timeout: float = 300.0,
    ):
        """Connect to a remote LanceDB database."""
        parsed = urlparse(db_url)
@@ -49,7 +51,12 @@ class RemoteDBConnection(DBConnection):
        self.db_name = parsed.netloc
        self.api_key = api_key
        self._client = RestfulLanceDBClient(
            self.db_name, region, api_key, host_override
            self.db_name,
            region,
            api_key,
            host_override,
            connection_timeout=connection_timeout,
            read_timeout=read_timeout,
        )
        self._request_thread_pool = request_thread_pool
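Because the synchronous connect() now forwards **kwargs to RemoteDBConnection, the new timeouts can be supplied at connection time; a hedged sketch (the URI, API key, and region are placeholders):

import lancedb

db = lancedb.connect(
    "db://my-database",        # placeholder remote URI
    api_key="sk-...",          # placeholder credential
    region="us-east-1",
    connection_timeout=30.0,   # seconds allowed to establish the HTTP connection
    read_timeout=120.0,        # seconds allowed to wait for a response
)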
@@ -68,14 +68,10 @@ class RemoteTable(Table):

    def list_indices(self):
        """List all the indices on the table"""
        print(self._name)
        resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
        return resp

    def index_stats(self, index_uuid: str):
        """List all the indices on the table"""
        resp = self._conn._client.post(f"/v1/table/{self._name}/index/{index_uuid}/stats/")
        return resp

    def create_scalar_index(
        self,
        column: str,
@@ -37,6 +37,7 @@ import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs as pa_fs
from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from lance.vector import vec_to_table

from .common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -60,6 +61,7 @@ if TYPE_CHECKING:

    from ._lancedb import Table as LanceDBTable
    from .db import LanceDBConnection
    from .index import BTree, IndexConfig, IvfPq


pd = safe_import_pandas()
@@ -73,7 +75,16 @@ def _sanitize_data(
    on_bad_vectors: str,
    fill_value: Any,
):
    if isinstance(data, list):
    import pdb; pdb.set_trace()
    if _check_for_hugging_face(data):
        # Huggingface datasets
        import datasets

        if isinstance(data, datasets.Dataset):
            if schema is None:
                schema = data.features.arrow_schema
            data = data.data.to_batches()
    elif isinstance(data, list):
        # convert to list of dict if data is a bunch of LanceModels
        if isinstance(data[0], LanceModel):
            schema = data[0].__class__.to_arrow_schema()
@@ -135,12 +146,11 @@ def _to_record_batch_generator(
    data: Iterable, schema, metadata, on_bad_vectors, fill_value
):
    for batch in data:
        if not isinstance(batch, pa.RecordBatch):
        if isinstance(batch, pa.RecordBatch):
            batch = pa.Table.from_batches([batch])
        table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
        for batch in table.to_batches():
            yield batch
        else:
            yield batch


class Table(ABC):
@@ -1917,112 +1927,48 @@ class AsyncTable:
        raise NotImplementedError

    async def create_index(
        self,
        metric="L2",
        num_partitions=256,
        num_sub_vectors=96,
        vector_column_name: str = VECTOR_COLUMN_NAME,
        replace: bool = True,
        accelerator: Optional[str] = None,
        index_cache_size: Optional[int] = None,
    ):
        """Create an index on the table.

        Parameters
        ----------
        metric: str, default "L2"
            The distance metric to use when creating the index.
            Valid values are "L2", "cosine", or "dot".
            L2 is euclidean distance.
        num_partitions: int, default 256
            The number of IVF partitions to use when creating the index.
            Default is 256.
        num_sub_vectors: int, default 96
            The number of PQ sub-vectors to use when creating the index.
            Default is 96.
        vector_column_name: str, default "vector"
            The vector column name to create the index.
        replace: bool, default True
            - If True, replace the existing index if it exists.

            - If False, raise an error if duplicate index exists.
        accelerator: str, default None
            If set, use the given accelerator to create the index.
            Only support "cuda" for now.
        index_cache_size : int, optional
            The size of the index cache in number of entries. Default value is 256.
        """
        raise NotImplementedError

    async def create_scalar_index(
        self,
        column: str,
        *,
        replace: bool = True,
        replace: Optional[bool] = None,
        config: Optional[Union[IvfPq, BTree]] = None,
    ):
        """Create a scalar index on a column.
        """Create an index to speed up queries

        Scalar indices, like vector indices, can be used to speed up scans. A scalar
        index can speed up scans that contain filter expressions on the indexed column.
        For example, the following scan will be faster if the column ``my_col`` has
        a scalar index:

        .. code-block:: python

            import lancedb

            db = lancedb.connect("/data/lance")
            img_table = db.open_table("images")
            my_df = img_table.search().where("my_col = 7", prefilter=True).to_pandas()

        Scalar indices can also speed up scans containing a vector search and a
        prefilter:

        .. code-block::python

            import lancedb

            db = lancedb.connect("/data/lance")
            img_table = db.open_table("images")
            img_table.search([1, 2, 3, 4], vector_column_name="vector")
                .where("my_col != 7", prefilter=True)
                .to_pandas()

        Scalar indices can only speed up scans for basic filters using
        equality, comparison, range (e.g. ``my_col BETWEEN 0 AND 100``), and set
        membership (e.g. `my_col IN (0, 1, 2)`)

        Scalar indices can be used if the filter contains multiple indexed columns and
        the filter criteria are AND'd or OR'd together
        (e.g. ``my_col < 0 AND other_col> 100``)

        Scalar indices may be used if the filter contains non-indexed columns but,
        depending on the structure of the filter, they may not be usable. For example,
        if the column ``not_indexed`` does not have a scalar index then the filter
        ``my_col = 0 OR not_indexed = 1`` will not be able to use any scalar index on
        ``my_col``.

        **Experimental API**

        Indices can be created on vector columns or scalar columns.
        Indices on vector columns will speed up vector searches.
        Indices on scalar columns will speed up filtering (in both
        vector and non-vector searches)

        Parameters
        ----------
        column : str
            The column to be indexed. Must be a boolean, integer, float,
            or string column.
        index: Index
            The index to create.

            LanceDb supports multiple types of indices. See the static methods on
            the Index class for more details.
        column: str, default None
            The column to index.

            When building a scalar index this must be set.

            When building a vector index, this is optional. The default will look
            for any columns of type fixed-size-list with floating point values. If
            there is only one column of this type then it will be used. Otherwise
            an error will be returned.
        replace: bool, default True
            Replace the existing index if it exists.
            Whether to replace the existing index

            If this is false, and another index already exists on the same columns
            and the same name, then an error will be returned. This is true even if
            that index is out of date.

            The default is True

        Examples
        --------

        .. code-block:: python

            import lance

            dataset = lance.dataset("./images.lance")
            dataset.create_scalar_index("category")
        """
        raise NotImplementedError
        index = None
        if config is not None:
            index = config._inner
        await self._inner.create_index(column, index=index, replace=replace)

    async def add(
        self,
@@ -2066,6 +2012,8 @@ class AsyncTable:
            on_bad_vectors=on_bad_vectors,
            fill_value=fill_value,
        )
        if isinstance(data, pa.Table):
            data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
        await self._inner.add(data, mode)
        register_event("add")

@@ -2275,58 +2223,57 @@ class AsyncTable:

    async def update(
        self,
        where: Optional[str] = None,
        values: Optional[dict] = None,
        updates: Optional[Dict[str, Any]] = None,
        *,
        values_sql: Optional[Dict[str, str]] = None,
        where: Optional[str] = None,
        updates_sql: Optional[Dict[str, str]] = None,
    ):
        """
        This can be used to update zero to all rows depending on how many
        rows match the where clause. If no where clause is provided, then
        all rows will be updated.

        Either `values` or `values_sql` must be provided. You cannot provide
        both.
        This can be used to update zero to all rows in the table.

        If a filter is provided with `where` then only rows matching the
        filter will be updated. Otherwise all rows will be updated.

        Parameters
        ----------
        updates: dict, optional
            The updates to apply. The keys should be the name of the column to
            update. The values should be the new values to assign. This is
            required unless updates_sql is supplied.
        where: str, optional
            The SQL where clause to use when updating rows. For example, 'x = 2'
            or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
            An SQL filter that controls which rows are updated. For example, 'x = 2'
            or 'x IN (1, 2, 3)'. Only rows that satisfy this filter will be updated.
        values: dict, optional
            The values to update. The keys are the column names and the values
            are the values to set.
        values_sql: dict, optional
            The values to update, expressed as SQL expression strings. These can
            reference existing columns. For example, {"x": "x + 1"} will increment
            the x column by 1.
        updates_sql: dict, optional
            The updates to apply, expressed as SQL expression strings. The keys should
            be column names. The values should be SQL expressions. These can be SQL
            literals (e.g. "7" or "'foo'") or they can be expressions based on the
            previous value of the row (e.g. "x + 1" to increment the x column by 1)

        Examples
        --------
        >>> import lancedb
        >>> import pandas as pd
        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
        >>> db = lancedb.connect("./.lancedb")
        >>> table = db.create_table("my_table", data)
        >>> table.to_pandas()
           x      vector
        0  1  [1.0, 2.0]
        1  2  [3.0, 4.0]
        2  3  [5.0, 6.0]
        >>> table.update(where="x = 2", values={"vector": [10, 10]})
        >>> table.to_pandas()
           x        vector
        0  1    [1.0, 2.0]
        1  3    [5.0, 6.0]
        2  2  [10.0, 10.0]
        >>> table.update(values_sql={"x": "x + 1"})
        >>> table.to_pandas()
           x        vector
        0  2    [1.0, 2.0]
        1  4    [5.0, 6.0]
        2  3  [10.0, 10.0]
        >>> import asyncio
        >>> import lancedb
        >>> import pandas as pd
        >>> async def demo_update():
        ...     data = pd.DataFrame({"x": [1, 2], "vector": [[1, 2], [3, 4]]})
        ...     db = await lancedb.connect_async("./.lancedb")
        ...     table = await db.create_table("my_table", data)
        ...     # x is [1, 2], vector is [[1, 2], [3, 4]]
        ...     await table.update({"vector": [10, 10]}, where="x = 2")
        ...     # x is [1, 2], vector is [[1, 2], [10, 10]]
        ...     await table.update(updates_sql={"x": "x + 1"})
        ...     # x is [2, 3], vector is [[1, 2], [10, 10]]
        >>> asyncio.run(demo_update())
        """
        raise NotImplementedError
        if updates is not None and updates_sql is not None:
            raise ValueError("Only one of updates or updates_sql can be provided")
        if updates is None and updates_sql is None:
            raise ValueError("Either updates or updates_sql must be provided")

        if updates is not None:
            updates_sql = {k: value_to_sql(v) for k, v in updates.items()}

        return await self._inner.update(updates_sql, where)

    async def cleanup_old_versions(
        self,
@@ -2423,3 +2370,65 @@ class AsyncTable:
            The names of the columns to drop.
        """
        raise NotImplementedError

    async def version(self) -> int:
        """
        Retrieve the version of the table

        LanceDb supports versioning. Every operation that modifies the table increases
        version. As long as a version hasn't been deleted you can `[Self::checkout]`
        that version to view the data at that point. In addition, you can
        `[Self::restore]` the version to replace the current table with a previous
        version.
        """
        return await self._inner.version()

    async def checkout(self, version):
        """
        Checks out a specific version of the Table

        Any read operation on the table will now access the data at the checked out
        version. As a consequence, calling this method will disable any read consistency
        interval that was previously set.

        This is a read-only operation that turns the table into a sort of "view"
        or "detached head". Other table instances will not be affected. To make the
        change permanent you can use the `[Self::restore]` method.

        Any operation that modifies the table will fail while the table is in a checked
        out state.

        To return the table to a normal state use `[Self::checkout_latest]`
        """
        await self._inner.checkout(version)

    async def checkout_latest(self):
        """
        Ensures the table is pointing at the latest version

        This can be used to manually update a table when the read_consistency_interval
        is None
        It can also be used to undo a `[Self::checkout]` operation
        """
        await self._inner.checkout_latest()

    async def restore(self):
        """
        Restore the table to the currently checked out version

        This operation will fail if checkout has not been called previously

        This operation will overwrite the latest version of the table with a
        previous version. Any changes made since the checked out version will
        no longer be visible.

        Once the operation concludes the table will no longer be in a checked
        out state and the read_consistency_interval, if any, will apply.
        """
        await self._inner.restore()

    async def list_indices(self) -> IndexConfig:
        """
        List all indices that have been created with Self::create_index
        """
        return await self._inner.list_indices()
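The versioning methods documented above compose into a simple time-travel workflow; a minimal sketch based on the test added later in this changeset (the table name and rows are illustrative):

import asyncio

import lancedb


async def time_travel_demo():
    db = await lancedb.connect_async("./.lancedb")
    table = await db.create_table("events", data=[{"id": 0}])
    v = await table.version()
    await table.add([{"id": 1}])
    # Read the older snapshot; the table is read-only until checkout_latest()
    await table.checkout(v)
    assert await table.count_rows() == 1
    # Make the checked-out snapshot the new latest version
    await table.restore()


asyncio.run(time_travel_demo())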
BIN  python/python/tests/images/1.png  Normal file
Binary file not shown.
After Width: | Height: | Size: 83 B

69  python/python/tests/test_index.py  Normal file
@@ -0,0 +1,69 @@
from datetime import timedelta

import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq


@pytest_asyncio.fixture
async def db_async(tmp_path) -> AsyncConnection:
    return await connect_async(tmp_path, read_consistency_interval=timedelta(seconds=0))


def sample_fixed_size_list_array(nrows, dim):
    vector_data = pa.array([float(i) for i in range(dim * nrows)], pa.float32())
    return pa.FixedSizeListArray.from_arrays(vector_data, dim)


DIM = 8
NROWS = 256


@pytest_asyncio.fixture
async def some_table(db_async):
    data = pa.Table.from_pydict(
        {
            "id": list(range(256)),
            "vector": sample_fixed_size_list_array(NROWS, DIM),
        }
    )
    return await db_async.create_table(
        "some_table",
        data,
    )


@pytest.mark.asyncio
async def test_create_scalar_index(some_table: AsyncTable):
    # Can create
    await some_table.create_index("id")
    # Can recreate if replace=True
    await some_table.create_index("id", replace=True)
    indices = await some_table.list_indices()
    assert len(indices) == 1
    assert indices[0].index_type == "BTree"
    assert indices[0].columns == ["id"]
    # Can't recreate if replace=False
    with pytest.raises(RuntimeError, match="already exists"):
        await some_table.create_index("id", replace=False)
    # can also specify index type
    await some_table.create_index("id", config=BTree())


@pytest.mark.asyncio
async def test_create_vector_index(some_table: AsyncTable):
    # Can create
    await some_table.create_index("vector")
    # Can recreate if replace=True
    await some_table.create_index("vector", replace=True)
    # Can't recreate if replace=False
    with pytest.raises(RuntimeError, match="already exists"):
        await some_table.create_index("vector", replace=False)
    # Can also specify index type
    await some_table.create_index("vector", config=IvfPq(num_partitions=100))
    indices = await some_table.list_indices()
    assert len(indices) == 1
    assert indices[0].index_type == "IvfPq"
    assert indices[0].columns == ["vector"]
@@ -12,17 +12,27 @@
# limitations under the License.


import io
import json
import os
import sys
from datetime import date, datetime
from pathlib import Path
from typing import List, Optional, Tuple

import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from pydantic import Field

from lancedb.pydantic import (
    PYDANTIC_VERSION,
    EncodedImage,
    LanceModel,
    Vector,
    pydantic_to_schema,
)


@pytest.mark.skipif(
    sys.version_info < (3, 9),
@@ -243,3 +253,23 @@ def test_lance_model():

    t = TestModel()
    assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])


def test_schema_with_images():
    pytest.importorskip("PIL")
    import PIL.Image

    class TestModel(LanceModel):
        img: EncodedImage()

    img_path = Path(os.path.dirname(__file__)) / "images/1.png"
    with open(img_path, "rb") as f:
        img_bytes = f.read()

    m1 = TestModel(img=PIL.Image.open(img_path))
    m2 = TestModel(img=img_bytes)

    def tobytes(m):
        return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()

    assert tobytes(m1) == tobytes(m2)
@@ -12,6 +12,8 @@
# limitations under the License.

import functools
import io
import os
from copy import copy
from datetime import date, datetime, timedelta
from pathlib import Path
@@ -20,19 +22,21 @@ from typing import List
from unittest.mock import PropertyMock, patch

import lance
import lancedb
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
import pytest_asyncio
from lance.arrow import EncodedImageType
from pydantic import BaseModel

import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import AsyncConnection, LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import EncodedImage, LanceModel, Vector
from lancedb.table import LanceTable
from pydantic import BaseModel


class MockDB:
@@ -85,12 +89,30 @@ async def test_close(db_async: AsyncConnection):
    assert str(table) == "ClosedTable(some_table)"


@pytest.mark.asyncio
async def test_update_async(db_async: AsyncConnection):
    table = await db_async.create_table("some_table", data=[{"id": 0}])
    assert await table.count_rows("id == 0") == 1
    assert await table.count_rows("id == 7") == 0
    await table.update({"id": 7})
    assert await table.count_rows("id == 7") == 1
    assert await table.count_rows("id == 0") == 0
    await table.add([{"id": 2}])
    await table.update(where="id % 2 == 0", updates_sql={"id": "5"})
    assert await table.count_rows("id == 7") == 1
    assert await table.count_rows("id == 2") == 0
    assert await table.count_rows("id == 5") == 1
    await table.update({"id": 10}, where="id == 5")
    assert await table.count_rows("id == 10") == 1


def test_create_table(db):
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float32()),
            pa.field("encoded_image", EncodedImageType()),
        ]
    )
    expected = pa.Table.from_arrays(
@@ -98,13 +120,26 @@ def test_create_table(db):
            pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
            pa.array(["foo", "bar"]),
            pa.array([10.0, 20.0]),
            pa.ExtensionArray.from_storage(
                EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
            ),
        ],
        schema=schema,
    )
    data = [
        [
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
            {
                "vector": [3.1, 4.1],
                "item": "foo",
                "price": 10.0,
                "encoded_image": b"foo",
            },
            {
                "vector": [5.9, 26.5],
                "item": "bar",
                "price": 20.0,
                "encoded_image": b"bar",
            },
        ]
    ]
    df = pd.DataFrame(data[0])
@@ -974,3 +1009,56 @@ def test_drop_columns(tmp_path):
    table = LanceTable.create(db, "my_table", data=data)
    table.drop_columns(["category"])
    assert table.to_arrow().column_names == ["id"]


@pytest.mark.asyncio
async def test_time_travel(db_async: AsyncConnection):
    # Setup
    table = await db_async.create_table("some_table", data=[{"id": 0}])
    version = await table.version()
    await table.add([{"id": 1}])
    assert await table.count_rows() == 2
    # Make sure we can rewind
    await table.checkout(version)
    assert await table.count_rows() == 1
    # Can't add data in time travel mode
    with pytest.raises(
        ValueError,
        match="table cannot be modified when a specific version is checked out",
    ):
        await table.add([{"id": 2}])
    # Can go back to normal mode
    await table.checkout_latest()
    assert await table.count_rows() == 2
    # Should be able to add data again
    await table.add([{"id": 3}])
    assert await table.count_rows() == 3
    # Now checkout and restore
    await table.checkout(version)
    await table.restore()
    assert await table.count_rows() == 1
    # Should be able to add data
    await table.add([{"id": 4}])
    assert await table.count_rows() == 2
    # Can't use restore if not checked out
    with pytest.raises(ValueError, match="checkout before running restore"):
        await table.restore()


def test_add_image(tmp_path):
    pytest.importorskip("PIL")
    import PIL.Image

    db = lancedb.connect(tmp_path)

    class TestModel(LanceModel):
        img: EncodedImage()

    img_path = Path(os.path.dirname(__file__)) / "images/1.png"
    m1 = TestModel(img=PIL.Image.open(img_path))

    def tobytes(m):
        return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()

    table = LanceTable.create(db, "my_table", schema=TestModel)
    table.add([m1])
|||||||
109
python/src/index.rs
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
// Copyright 2024 Lance Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use lancedb::{
|
||||||
|
index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
|
||||||
|
DistanceType,
|
||||||
|
};
|
||||||
|
use pyo3::{
|
||||||
|
exceptions::{PyRuntimeError, PyValueError},
|
||||||
|
pyclass, pymethods, PyResult,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[pyclass]
|
||||||
|
pub struct Index {
|
||||||
|
inner: Mutex<Option<LanceDbIndex>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Index {
|
||||||
|
pub fn consume(&self) -> PyResult<LanceDbIndex> {
|
||||||
|
self.inner
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl Index {
|
||||||
|
#[staticmethod]
|
||||||
|
pub fn ivf_pq(
|
||||||
|
distance_type: Option<String>,
|
||||||
|
num_partitions: Option<u32>,
|
||||||
|
num_sub_vectors: Option<u32>,
|
||||||
|
max_iterations: Option<u32>,
|
||||||
|
sample_rate: Option<u32>,
|
||||||
|
) -> PyResult<Self> {
|
||||||
|
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
|
||||||
|
if let Some(distance_type) = distance_type {
|
||||||
|
let distance_type = match distance_type.as_str() {
|
||||||
|
"l2" => Ok(DistanceType::L2),
|
||||||
|
"cosine" => Ok(DistanceType::Cosine),
|
||||||
|
"dot" => Ok(DistanceType::Dot),
|
||||||
|
_ => Err(PyValueError::new_err(format!(
|
||||||
|
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
||||||
|
distance_type
|
||||||
|
))),
|
||||||
|
}?;
|
||||||
|
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
|
||||||
|
}
|
||||||
|
if let Some(num_partitions) = num_partitions {
|
||||||
|
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
|
||||||
|
}
|
||||||
|
if let Some(num_sub_vectors) = num_sub_vectors {
|
||||||
|
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||||
|
}
|
||||||
|
if let Some(max_iterations) = max_iterations {
|
||||||
|
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
|
||||||
|
}
|
||||||
|
if let Some(sample_rate) = sample_rate {
|
||||||
|
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[staticmethod]
|
||||||
|
pub fn btree() -> PyResult<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyclass(get_all)]
|
||||||
|
/// A description of an index currently configured on a column
|
||||||
|
pub struct IndexConfig {
|
||||||
|
/// The type of the index
|
||||||
|
pub index_type: String,
|
||||||
|
/// The columns in the index
|
||||||
|
///
|
||||||
|
/// Currently this is always a list of size 1. In the future there may
|
||||||
|
/// be more columns to represent composite indices.
|
||||||
|
pub columns: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<lancedb::index::IndexConfig> for IndexConfig {
|
||||||
|
fn from(value: lancedb::index::IndexConfig) -> Self {
|
||||||
|
let index_type = format!("{:?}", value.index_type);
|
||||||
|
Self {
|
||||||
|
index_type,
|
||||||
|
columns: value.columns,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,11 +14,15 @@
|
|||||||
|
|
||||||
use connection::{connect, Connection};
|
use connection::{connect, Connection};
|
||||||
use env_logger::Env;
|
use env_logger::Env;
|
||||||
|
use index::{Index, IndexConfig};
|
||||||
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
||||||
|
use table::Table;
|
||||||
|
|
||||||
pub mod connection;
|
pub mod connection;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod index;
|
||||||
pub mod table;
|
pub mod table;
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
@@ -27,6 +31,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
|||||||
.write_style("LANCEDB_LOG_STYLE");
|
.write_style("LANCEDB_LOG_STYLE");
|
||||||
env_logger::init_from_env(env);
|
env_logger::init_from_env(env);
|
||||||
m.add_class::<Connection>()?;
|
m.add_class::<Connection>()?;
|
||||||
|
m.add_class::<Table>()?;
|
||||||
|
m.add_class::<Index>()?;
|
||||||
|
m.add_class::<IndexConfig>()?;
|
||||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -5,11 +5,16 @@ use arrow::{
|
|||||||
use lancedb::table::{AddDataMode, Table as LanceDbTable};
|
use lancedb::table::{AddDataMode, Table as LanceDbTable};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
exceptions::{PyRuntimeError, PyValueError},
|
exceptions::{PyRuntimeError, PyValueError},
|
||||||
pyclass, pymethods, PyAny, PyRef, PyResult, Python,
|
pyclass, pymethods,
|
||||||
|
types::{PyDict, PyString},
|
||||||
|
PyAny, PyRef, PyResult, Python,
|
||||||
};
|
};
|
||||||
use pyo3_asyncio::tokio::future_into_py;
|
use pyo3_asyncio::tokio::future_into_py;
|
||||||
|
|
||||||
use crate::error::PythonErrorExt;
|
use crate::{
|
||||||
|
error::PythonErrorExt,
|
||||||
|
index::{Index, IndexConfig},
|
||||||
|
};
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
pub struct Table {
|
pub struct Table {
|
||||||
@@ -74,6 +79,28 @@ impl Table {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn update<'a>(
|
||||||
|
self_: PyRef<'a, Self>,
|
||||||
|
updates: &PyDict,
|
||||||
|
r#where: Option<String>,
|
||||||
|
) -> PyResult<&'a PyAny> {
|
||||||
|
let mut op = self_.inner_ref()?.update();
|
||||||
|
if let Some(only_if) = r#where {
|
||||||
|
op = op.only_if(only_if);
|
||||||
|
}
|
||||||
|
for (column_name, value) in updates.into_iter() {
|
||||||
|
let column_name: &PyString = column_name.downcast()?;
|
||||||
|
let column_name = column_name.to_str()?.to_string();
|
||||||
|
let value: &PyString = value.downcast()?;
|
||||||
|
let value = value.to_str()?.to_string();
|
||||||
|
op = op.column(column_name, value);
|
||||||
|
}
|
||||||
|
future_into_py(self_.py(), async move {
|
||||||
|
op.execute().await.infer_error()?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
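
For reference, the Rust builder this binding drives looks roughly like the sketch below (a hedged example, not taken verbatim from the diff: `table` is assumed to be a lancedb::Table, and the exact argument types of only_if/column may differ slightly):

async fn bump_even_ids(table: &lancedb::Table) -> lancedb::Result<()> {
    table
        .update()
        .only_if("id % 2 == 0")   // optional SQL filter; omit it to update every row
        .column("id", "5")        // column name plus a SQL expression for the new value
        .execute()
        .await?;
    Ok(())
}
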
|
||||||
|
|
||||||
pub fn count_rows(self_: PyRef<'_, Self>, filter: Option<String>) -> PyResult<&PyAny> {
|
pub fn count_rows(self_: PyRef<'_, Self>, filter: Option<String>) -> PyResult<&PyAny> {
|
||||||
let inner = self_.inner_ref()?.clone();
|
let inner = self_.inner_ref()?.clone();
|
||||||
future_into_py(self_.py(), async move {
|
future_into_py(self_.py(), async move {
|
||||||
@@ -81,10 +108,75 @@ impl Table {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn create_index<'a>(
|
||||||
|
self_: PyRef<'a, Self>,
|
||||||
|
column: String,
|
||||||
|
index: Option<&Index>,
|
||||||
|
replace: Option<bool>,
|
||||||
|
) -> PyResult<&'a PyAny> {
|
||||||
|
let index = if let Some(index) = index {
|
||||||
|
index.consume()?
|
||||||
|
} else {
|
||||||
|
lancedb::index::Index::Auto
|
||||||
|
};
|
||||||
|
let mut op = self_.inner_ref()?.create_index(&[column], index);
|
||||||
|
if let Some(replace) = replace {
|
||||||
|
op = op.replace(replace);
|
||||||
|
}
|
||||||
|
|
||||||
|
future_into_py(self_.py(), async move {
|
||||||
|
op.execute().await.infer_error()?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn list_indices(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||||
|
let inner = self_.inner_ref()?.clone();
|
||||||
|
future_into_py(self_.py(), async move {
|
||||||
|
Ok(inner
|
||||||
|
.list_indices()
|
||||||
|
.await
|
||||||
|
.infer_error()?
|
||||||
|
.into_iter()
|
||||||
|
.map(IndexConfig::from)
|
||||||
|
.collect::<Vec<_>>())
|
||||||
|
})
|
||||||
|
}
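
On the Rust side the same call is available directly on the table; a minimal sketch (assuming `table` is a lancedb::Table and the IndexConfig fields introduced later in this change):

async fn show_indices(table: &lancedb::Table) -> lancedb::Result<()> {
    for index in table.list_indices().await? {
        // index_type is the IndexType enum (IvfPq or BTree); columns is a Vec<String>
        println!("{:?} index on {:?}", index.index_type, index.columns);
    }
    Ok(())
}
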
|
||||||
|
|
||||||
pub fn __repr__(&self) -> String {
|
pub fn __repr__(&self) -> String {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
None => format!("ClosedTable({})", self.name),
|
None => format!("ClosedTable({})", self.name),
|
||||||
Some(inner) => inner.to_string(),
|
Some(inner) => inner.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn version(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||||
|
let inner = self_.inner_ref()?.clone();
|
||||||
|
future_into_py(
|
||||||
|
self_.py(),
|
||||||
|
async move { inner.version().await.infer_error() },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn checkout(self_: PyRef<'_, Self>, version: u64) -> PyResult<&PyAny> {
|
||||||
|
let inner = self_.inner_ref()?.clone();
|
||||||
|
future_into_py(self_.py(), async move {
|
||||||
|
inner.checkout(version).await.infer_error()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn checkout_latest(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||||
|
let inner = self_.inner_ref()?.clone();
|
||||||
|
future_into_py(self_.py(), async move {
|
||||||
|
inner.checkout_latest().await.infer_error()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn restore(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||||
|
let inner = self_.inner_ref()?.clone();
|
||||||
|
future_into_py(
|
||||||
|
self_.py(),
|
||||||
|
async move { inner.restore().await.infer_error() },
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
35
python/src/util.rs
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use pyo3::{exceptions::PyRuntimeError, PyResult};
|
||||||
|
|
||||||
|
/// A wrapper around a rust builder
|
||||||
|
///
|
||||||
|
/// Rust builders are often implemented so that the builder methods
|
||||||
|
/// consume the builder and return a new one. This is not compatible
|
||||||
|
/// with pyo3, which, being garbage collected, cannot easily obtain
|
||||||
|
/// ownership of an object.
|
||||||
|
///
|
||||||
|
/// This wrapper converts the compile-time safety of rust into runtime
|
||||||
|
/// errors if any attempt to use the builder happens after it is consumed.
|
||||||
|
pub struct BuilderWrapper<T> {
|
||||||
|
name: String,
|
||||||
|
inner: Mutex<Option<T>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> BuilderWrapper<T> {
|
||||||
|
pub fn new(name: impl AsRef<str>, inner: T) -> Self {
|
||||||
|
Self {
|
||||||
|
name: name.as_ref().to_string(),
|
||||||
|
inner: Mutex::new(Some(inner)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn consume<O>(&self, mod_fn: impl FnOnce(T) -> O) -> PyResult<O> {
|
||||||
|
let mut inner = self.inner.lock().unwrap();
|
||||||
|
let inner_builder = inner.take().ok_or_else(|| {
|
||||||
|
PyRuntimeError::new_err(format!("{} has already been consumed", self.name))
|
||||||
|
})?;
|
||||||
|
let result = mod_fn(inner_builder);
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
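
A toy sketch of how this wrapper behaves (illustrative only; `ToyBuilder` is made up, but the `new`/`consume` calls match the API above):

struct ToyBuilder;
impl ToyBuilder {
    fn execute(self) -> String {
        // consumes self, like the real lancedb builders
        "done".to_string()
    }
}

fn demo() -> pyo3::PyResult<String> {
    let wrapped = BuilderWrapper::new("ToyBuilder", ToyBuilder);
    let result = wrapped.consume(|b| b.execute())?;      // first use succeeds
    assert!(wrapped.consume(|b| b.execute()).is_err());  // second use: "ToyBuilder has already been consumed"
    Ok(result)
}
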
|
||||||
@@ -12,6 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use lancedb::index::{scalar::BTreeIndexBuilder, Index};
|
||||||
use neon::{
|
use neon::{
|
||||||
context::{Context, FunctionContext},
|
context::{Context, FunctionContext},
|
||||||
result::JsResult,
|
result::JsResult,
|
||||||
@@ -33,9 +34,9 @@ pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise>
|
|||||||
|
|
||||||
rt.spawn(async move {
|
rt.spawn(async move {
|
||||||
let idx_result = table
|
let idx_result = table
|
||||||
.create_index(&[&column])
|
.create_index(&[column], Index::BTree(BTreeIndexBuilder::default()))
|
||||||
.replace(replace)
|
.replace(replace)
|
||||||
.build()
|
.execute()
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
|
|||||||
@@ -13,12 +13,12 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use lance_linalg::distance::MetricType;
|
use lance_linalg::distance::MetricType;
|
||||||
use lancedb::index::IndexBuilder;
|
use lancedb::index::vector::IvfPqIndexBuilder;
|
||||||
|
use lancedb::index::Index;
|
||||||
use neon::context::FunctionContext;
|
use neon::context::FunctionContext;
|
||||||
use neon::prelude::*;
|
use neon::prelude::*;
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
|
|
||||||
use crate::error::Error::InvalidIndexType;
|
|
||||||
use crate::error::ResultExt;
|
use crate::error::ResultExt;
|
||||||
use crate::neon_ext::js_object_ext::JsObjectExt;
|
use crate::neon_ext::js_object_ext::JsObjectExt;
|
||||||
use crate::runtime;
|
use crate::runtime;
|
||||||
@@ -39,13 +39,20 @@ pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise>
|
|||||||
.map(|s| s.value(&mut cx))
|
.map(|s| s.value(&mut cx))
|
||||||
.unwrap_or("vector".to_string()); // Backward compatibility
|
.unwrap_or("vector".to_string()); // Backward compatibility
|
||||||
|
|
||||||
|
let replace = index_params
|
||||||
|
.get_opt::<JsBoolean, _, _>(&mut cx, "replace")?
|
||||||
|
.map(|r| r.value(&mut cx));
|
||||||
|
|
||||||
let tbl = table.clone();
|
let tbl = table.clone();
|
||||||
let index_builder = tbl.create_index(&[&column_name]);
|
let ivf_pq_builder = get_index_params_builder(&mut cx, index_params).or_throw(&mut cx)?;
|
||||||
let index_builder =
|
|
||||||
get_index_params_builder(&mut cx, index_params, index_builder).or_throw(&mut cx)?;
|
let mut index_builder = tbl.create_index(&[column_name], Index::IvfPq(ivf_pq_builder));
|
||||||
|
if let Some(replace) = replace {
|
||||||
|
index_builder = index_builder.replace(replace);
|
||||||
|
}
|
||||||
|
|
||||||
rt.spawn(async move {
|
rt.spawn(async move {
|
||||||
let idx_result = index_builder.build().await;
|
let idx_result = index_builder.execute().await;
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
idx_result.or_throw(&mut cx)?;
|
idx_result.or_throw(&mut cx)?;
|
||||||
Ok(cx.boxed(JsTable::from(table)))
|
Ok(cx.boxed(JsTable::from(table)))
|
||||||
@@ -57,26 +64,17 @@ pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise>
|
|||||||
fn get_index_params_builder(
|
fn get_index_params_builder(
|
||||||
cx: &mut FunctionContext,
|
cx: &mut FunctionContext,
|
||||||
obj: Handle<JsObject>,
|
obj: Handle<JsObject>,
|
||||||
builder: IndexBuilder,
|
) -> crate::error::Result<IvfPqIndexBuilder> {
|
||||||
) -> crate::error::Result<IndexBuilder> {
|
if obj.get_opt::<JsString, _, _>(cx, "index_name")?.is_some() {
|
||||||
let mut builder = match obj.get::<JsString, _, _>(cx, "type")?.value(cx).as_str() {
|
return Err(crate::error::Error::LanceDB {
|
||||||
"ivf_pq" => builder.ivf_pq(),
|
message: "Setting the index_name is no longer supported".to_string(),
|
||||||
_ => {
|
});
|
||||||
return Err(InvalidIndexType {
|
|
||||||
index_type: "".into(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
};
|
let mut builder = IvfPqIndexBuilder::default();
|
||||||
|
|
||||||
if let Some(index_name) = obj.get_opt::<JsString, _, _>(cx, "index_name")? {
|
|
||||||
builder = builder.name(index_name.value(cx).as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
|
if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
|
||||||
let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
|
let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
|
||||||
builder = builder.metric_type(metric_type);
|
builder = builder.distance_type(metric_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? {
|
if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? {
|
||||||
builder = builder.num_partitions(np);
|
builder = builder.num_partitions(np);
|
||||||
}
|
}
|
||||||
@@ -86,11 +84,5 @@ fn get_index_params_builder(
|
|||||||
if let Some(max_iters) = obj.get_opt_u32(cx, "max_iters")? {
|
if let Some(max_iters) = obj.get_opt_u32(cx, "max_iters")? {
|
||||||
builder = builder.max_iterations(max_iters);
|
builder = builder.max_iterations(max_iters);
|
||||||
}
|
}
|
||||||
if let Some(num_bits) = obj.get_opt_u32(cx, "num_bits")? {
|
|
||||||
builder = builder.num_bits(num_bits);
|
|
||||||
}
|
|
||||||
if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")? {
|
|
||||||
builder = builder.replace(replace.value(cx));
|
|
||||||
}
|
|
||||||
Ok(builder)
|
Ok(builder)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -297,11 +297,14 @@ impl JsTable {
|
|||||||
|
|
||||||
let predicate = predicate.as_deref();
|
let predicate = predicate.as_deref();
|
||||||
|
|
||||||
let update_result = table
|
let mut update_op = table.update();
|
||||||
.as_native()
|
if let Some(predicate) = predicate {
|
||||||
.unwrap()
|
update_op = update_op.only_if(predicate);
|
||||||
.update(predicate, updates_arg)
|
}
|
||||||
.await;
|
for (column, value) in updates_arg {
|
||||||
|
update_op = update_op.column(column, value);
|
||||||
|
}
|
||||||
|
let update_result = update_op.execute().await;
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
update_result.or_throw(&mut cx)?;
|
update_result.or_throw(&mut cx)?;
|
||||||
Ok(cx.boxed(Self::from(table)))
|
Ok(cx.boxed(Self::from(table)))
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ lance = { workspace = true }
|
|||||||
lance-index = { workspace = true }
|
lance-index = { workspace = true }
|
||||||
lance-linalg = { workspace = true }
|
lance-linalg = { workspace = true }
|
||||||
lance-testing = { workspace = true }
|
lance-testing = { workspace = true }
|
||||||
|
pin-project = { workspace = true }
|
||||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
async-trait = "0"
|
async-trait = "0"
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ use arrow_schema::{DataType, Field, Schema};
|
|||||||
use futures::TryStreamExt;
|
use futures::TryStreamExt;
|
||||||
|
|
||||||
use lancedb::connection::Connection;
|
use lancedb::connection::Connection;
|
||||||
|
use lancedb::index::Index;
|
||||||
use lancedb::{connect, Result, Table as LanceDbTable};
|
use lancedb::{connect, Result, Table as LanceDbTable};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@@ -142,23 +143,18 @@ async fn create_empty_table(db: &Connection) -> Result<LanceDbTable> {
|
|||||||
|
|
||||||
async fn create_index(table: &LanceDbTable) -> Result<()> {
|
async fn create_index(table: &LanceDbTable) -> Result<()> {
|
||||||
// --8<-- [start:create_index]
|
// --8<-- [start:create_index]
|
||||||
table
|
table.create_index(&["vector"], Index::Auto).execute().await
|
||||||
.create_index(&["vector"])
|
|
||||||
.ivf_pq()
|
|
||||||
.num_partitions(8)
|
|
||||||
.build()
|
|
||||||
.await
|
|
||||||
// --8<-- [end:create_index]
|
// --8<-- [end:create_index]
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn search(table: &LanceDbTable) -> Result<Vec<RecordBatch>> {
|
async fn search(table: &LanceDbTable) -> Result<Vec<RecordBatch>> {
|
||||||
// --8<-- [start:search]
|
// --8<-- [start:search]
|
||||||
Ok(table
|
table
|
||||||
.search(&[1.0; 128])
|
.search(&[1.0; 128])
|
||||||
.limit(2)
|
.limit(2)
|
||||||
.execute_stream()
|
.execute_stream()
|
||||||
.await?
|
.await?
|
||||||
.try_collect::<Vec<_>>()
|
.try_collect::<Vec<_>>()
|
||||||
.await?)
|
.await
|
||||||
// --8<-- [end:search]
|
// --8<-- [end:search]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,4 +12,92 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
pub use lance::arrow::*;
|
use std::{pin::Pin, sync::Arc};
|
||||||
|
|
||||||
|
pub use arrow_array;
|
||||||
|
pub use arrow_schema;
|
||||||
|
use futures::{Stream, StreamExt};
|
||||||
|
|
||||||
|
use crate::error::Result;
|
||||||
|
|
||||||
|
/// An iterator of batches that also has a schema
|
||||||
|
pub trait RecordBatchReader: Iterator<Item = Result<arrow_array::RecordBatch>> {
|
||||||
|
/// Returns the schema of this `RecordBatchReader`.
|
||||||
|
///
|
||||||
|
/// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this
|
||||||
|
/// reader should have the same schema as returned from this method.
|
||||||
|
fn schema(&self) -> Arc<arrow_schema::Schema>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A simple RecordBatchReader formed from the two parts (iterator + schema)
|
||||||
|
pub struct SimpleRecordBatchReader<I: Iterator<Item = Result<arrow_array::RecordBatch>>> {
|
||||||
|
pub schema: Arc<arrow_schema::Schema>,
|
||||||
|
pub batches: I,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I: Iterator<Item = Result<arrow_array::RecordBatch>>> Iterator for SimpleRecordBatchReader<I> {
|
||||||
|
type Item = Result<arrow_array::RecordBatch>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
self.batches.next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I: Iterator<Item = Result<arrow_array::RecordBatch>>> RecordBatchReader
|
||||||
|
for SimpleRecordBatchReader<I>
|
||||||
|
{
|
||||||
|
fn schema(&self) -> Arc<arrow_schema::Schema> {
|
||||||
|
self.schema.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A stream of batches that also has a schema
|
||||||
|
pub trait RecordBatchStream: Stream<Item = Result<arrow_array::RecordBatch>> {
|
||||||
|
/// Returns the schema of this `RecordBatchStream`.
|
||||||
|
///
|
||||||
|
/// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this
|
||||||
|
/// stream should have the same schema as returned from this method.
|
||||||
|
fn schema(&self) -> Arc<arrow_schema::Schema>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A boxed RecordBatchStream that is also Send
|
||||||
|
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
|
||||||
|
|
||||||
|
impl<I: lance::io::RecordBatchStream + 'static> From<I> for SendableRecordBatchStream {
|
||||||
|
fn from(stream: I) -> Self {
|
||||||
|
let schema = stream.schema();
|
||||||
|
let mapped_stream = Box::pin(stream.map(|r| r.map_err(Into::into)));
|
||||||
|
Box::pin(SimpleRecordBatchStream {
|
||||||
|
schema,
|
||||||
|
stream: mapped_stream,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A simple RecordBatchStream formed from the two parts (stream + schema)
|
||||||
|
#[pin_project::pin_project]
|
||||||
|
pub struct SimpleRecordBatchStream<S: Stream<Item = Result<arrow_array::RecordBatch>>> {
|
||||||
|
pub schema: Arc<arrow_schema::Schema>,
|
||||||
|
#[pin]
|
||||||
|
pub stream: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> Stream for SimpleRecordBatchStream<S> {
|
||||||
|
type Item = Result<arrow_array::RecordBatch>;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Option<Self::Item>> {
|
||||||
|
let this = self.project();
|
||||||
|
this.stream.poll_next(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> RecordBatchStream
|
||||||
|
for SimpleRecordBatchStream<S>
|
||||||
|
{
|
||||||
|
fn schema(&self) -> Arc<arrow_schema::Schema> {
|
||||||
|
self.schema.clone()
|
||||||
|
}
|
||||||
|
}
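
A small sketch of the (iterator + schema) pairing described above, built from this module's re-exports (a hedged example; the helper name `make_reader` is made up):

use std::sync::Arc;
use lancedb::arrow::{
    arrow_array::{Int32Array, RecordBatch},
    arrow_schema::{DataType, Field, Schema},
    RecordBatchReader, SimpleRecordBatchReader,
};

fn make_reader() -> impl RecordBatchReader {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();
    // Pair the schema with any iterator of Result<RecordBatch> to get a reader.
    SimpleRecordBatchReader {
        schema,
        batches: vec![Ok(batch)].into_iter(),
    }
}
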
|
||||||
|
|||||||
@@ -356,6 +356,15 @@ pub struct ConnectBuilder {
|
|||||||
aws_creds: Option<AwsCredential>,
|
aws_creds: Option<AwsCredential>,
|
||||||
|
|
||||||
/// The interval at which to check for updates from other processes.
|
/// The interval at which to check for updates from other processes.
|
||||||
|
///
|
||||||
|
/// If None, then consistency is not checked. For performance
|
||||||
|
/// reasons, this is the default. For strong consistency, set this to
|
||||||
|
/// zero seconds. Then every read will check for updates from other
|
||||||
|
/// processes. As a compromise, you can set this to a non-zero timedelta
|
||||||
|
/// for eventual consistency. If more than that interval has passed since
|
||||||
|
/// the last check, then the table will be checked for updates. Note: this
|
||||||
|
/// consistency only applies to read operations. Write operations are
|
||||||
|
/// always consistent.
|
||||||
read_consistency_interval: Option<std::time::Duration>,
|
read_consistency_interval: Option<std::time::Duration>,
|
||||||
}
|
}
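
A hedged sketch of how this option might be set when connecting (the `read_consistency_interval` setter name is an assumption; only the field and its semantics are shown in this change):

use std::time::Duration;

async fn open_db(uri: &str) -> lancedb::Result<lancedb::connection::Connection> {
    lancedb::connect(uri)
        .read_consistency_interval(Duration::from_secs(0)) // assumed setter: 0s = check before every read
        .execute()
        .await
}
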
|
||||||
|
|
||||||
|
|||||||
@@ -12,181 +12,69 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{cmp::max, sync::Arc};
|
use std::sync::Arc;
|
||||||
|
|
||||||
use lance_index::IndexType;
|
|
||||||
pub use lance_linalg::distance::MetricType;
|
|
||||||
|
|
||||||
pub mod vector;
|
|
||||||
|
|
||||||
use crate::{table::TableInternal, Result};
|
use crate::{table::TableInternal, Result};
|
||||||
|
|
||||||
/// Index Parameters.
|
use self::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder};
|
||||||
pub enum IndexParams {
|
|
||||||
Scalar {
|
pub mod scalar;
|
||||||
replace: bool,
|
pub mod vector;
|
||||||
},
|
|
||||||
IvfPq {
|
pub enum Index {
|
||||||
replace: bool,
|
Auto,
|
||||||
metric_type: MetricType,
|
BTree(BTreeIndexBuilder),
|
||||||
num_partitions: u64,
|
IvfPq(IvfPqIndexBuilder),
|
||||||
num_sub_vectors: u32,
|
|
||||||
num_bits: u32,
|
|
||||||
sample_rate: u32,
|
|
||||||
max_iterations: u32,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builder for Index Parameters.
|
/// Builder for the create_index operation
|
||||||
|
///
|
||||||
|
/// The methods on this builder are used to specify options common to all indices.
|
||||||
pub struct IndexBuilder {
|
pub struct IndexBuilder {
|
||||||
parent: Arc<dyn TableInternal>,
|
parent: Arc<dyn TableInternal>,
|
||||||
|
pub(crate) index: Index,
|
||||||
pub(crate) columns: Vec<String>,
|
pub(crate) columns: Vec<String>,
|
||||||
// General parameters
|
|
||||||
/// Index name.
|
|
||||||
pub(crate) name: Option<String>,
|
|
||||||
/// Replace the existing index.
|
|
||||||
pub(crate) replace: bool,
|
pub(crate) replace: bool,
|
||||||
|
|
||||||
pub(crate) index_type: IndexType,
|
|
||||||
|
|
||||||
// Scalar index parameters
|
|
||||||
// Nothing to set here.
|
|
||||||
|
|
||||||
// IVF_PQ parameters
|
|
||||||
pub(crate) metric_type: MetricType,
|
|
||||||
pub(crate) num_partitions: Option<u32>,
|
|
||||||
// PQ related
|
|
||||||
pub(crate) num_sub_vectors: Option<u32>,
|
|
||||||
pub(crate) num_bits: u32,
|
|
||||||
|
|
||||||
/// The rate to find samples to train kmeans.
|
|
||||||
pub(crate) sample_rate: u32,
|
|
||||||
/// Max iteration to train kmeans.
|
|
||||||
pub(crate) max_iterations: u32,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IndexBuilder {
|
impl IndexBuilder {
|
||||||
pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: &[&str]) -> Self {
|
pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: Vec<String>, index: Index) -> Self {
|
||||||
Self {
|
Self {
|
||||||
parent,
|
parent,
|
||||||
columns: columns.iter().map(|c| c.to_string()).collect(),
|
index,
|
||||||
name: None,
|
columns,
|
||||||
replace: true,
|
replace: true,
|
||||||
index_type: IndexType::Scalar,
|
|
||||||
metric_type: MetricType::L2,
|
|
||||||
num_partitions: None,
|
|
||||||
num_sub_vectors: None,
|
|
||||||
num_bits: 8,
|
|
||||||
sample_rate: 256,
|
|
||||||
max_iterations: 50,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a Scalar Index.
|
/// Whether to replace the existing index, the default is `true`.
|
||||||
///
|
///
|
||||||
/// Accepted parameters:
|
/// If this is false, and another index already exists on the same columns
|
||||||
/// - `replace`: Replace the existing index.
|
/// and the same name, then an error will be returned. This is true even if
|
||||||
/// - `name`: Index name. Default: `None`
|
/// that index is out of date.
|
||||||
pub fn scalar(mut self) -> Self {
|
|
||||||
self.index_type = IndexType::Scalar;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build an IVF PQ index.
|
|
||||||
///
|
|
||||||
/// Accepted parameters:
|
|
||||||
/// - `replace`: Replace the existing index.
|
|
||||||
/// - `name`: Index name. Default: `None`
|
|
||||||
/// - `metric_type`: [MetricType] to use to build Vector Index.
|
|
||||||
/// - `num_partitions`: Number of IVF partitions.
|
|
||||||
/// - `num_sub_vectors`: Number of sub-vectors of PQ.
|
|
||||||
/// - `num_bits`: Number of bits used for PQ centroids.
|
|
||||||
/// - `sample_rate`: The rate to find samples to train kmeans.
|
|
||||||
/// - `max_iterations`: Max iteration to train kmeans.
|
|
||||||
pub fn ivf_pq(mut self) -> Self {
|
|
||||||
self.index_type = IndexType::Vector;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The columns to build index on.
|
|
||||||
pub fn columns(mut self, cols: &[&str]) -> Self {
|
|
||||||
self.columns = cols.iter().map(|s| s.to_string()).collect();
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Whether to replace the existing index, default is `true`.
|
|
||||||
pub fn replace(mut self, v: bool) -> Self {
|
pub fn replace(mut self, v: bool) -> Self {
|
||||||
self.replace = v;
|
self.replace = v;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the index name.
|
pub async fn execute(self) -> Result<()> {
|
||||||
pub fn name(mut self, name: &str) -> Self {
|
self.parent.clone().create_index(self).await
|
||||||
self.name = Some(name.to_string());
|
}
|
||||||
self
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// [MetricType] to use to build Vector Index.
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum IndexType {
|
||||||
|
IvfPq,
|
||||||
|
BTree,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A description of an index currently configured on a column
|
||||||
|
pub struct IndexConfig {
|
||||||
|
/// The type of the index
|
||||||
|
pub index_type: IndexType,
|
||||||
|
/// The columns in the index
|
||||||
///
|
///
|
||||||
/// Default value is [MetricType::L2].
|
/// Currently this is always a Vec of size 1. In the future there may
|
||||||
pub fn metric_type(mut self, metric_type: MetricType) -> Self {
|
/// be more columns to represent composite indices.
|
||||||
self.metric_type = metric_type;
|
pub columns: Vec<String>,
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Number of IVF partitions.
|
|
||||||
pub fn num_partitions(mut self, num_partitions: u32) -> Self {
|
|
||||||
self.num_partitions = Some(num_partitions);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Number of sub-vectors of PQ.
|
|
||||||
pub fn num_sub_vectors(mut self, num_sub_vectors: u32) -> Self {
|
|
||||||
self.num_sub_vectors = Some(num_sub_vectors);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Number of bits used for PQ centroids.
|
|
||||||
pub fn num_bits(mut self, num_bits: u32) -> Self {
|
|
||||||
self.num_bits = num_bits;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The rate to find samples to train kmeans.
|
|
||||||
pub fn sample_rate(mut self, sample_rate: u32) -> Self {
|
|
||||||
self.sample_rate = sample_rate;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Max iteration to train kmeans.
|
|
||||||
pub fn max_iterations(mut self, max_iterations: u32) -> Self {
|
|
||||||
self.max_iterations = max_iterations;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build the parameters.
|
|
||||||
pub async fn build(self) -> Result<()> {
|
|
||||||
self.parent.clone().do_create_index(self).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn suggested_num_partitions(rows: usize) -> u32 {
|
|
||||||
let num_partitions = (rows as f64).sqrt() as u32;
|
|
||||||
max(1, num_partitions)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
|
|
||||||
if dim % 16 == 0 {
|
|
||||||
// Should be more aggressive than this default.
|
|
||||||
dim / 16
|
|
||||||
} else if dim % 8 == 0 {
|
|
||||||
dim / 8
|
|
||||||
} else {
|
|
||||||
log::warn!(
|
|
||||||
"The dimension of the vector is not divisible by 8 or 16, \
|
|
||||||
which may cause performance degradation in PQ"
|
|
||||||
);
|
|
||||||
1
|
|
||||||
}
|
|
||||||
}
|
}
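
With the rewritten builder, creating an index reduces to choosing an Index variant and calling execute; a minimal sketch (`table` is assumed to be an open lancedb::Table):

use lancedb::index::Index;

async fn index_vectors(table: &lancedb::Table) -> lancedb::Result<()> {
    // Index::Auto lets LanceDB pick a suitable index type for the column;
    // chain .replace(false) to fail instead of overwriting an existing index.
    table.create_index(&["vector"], Index::Auto).execute().await
}
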
|
||||||
|
|||||||
30
rust/lancedb/src/index/scalar.rs
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
//! Scalar indices are exact indices that are used to quickly satisfy a variety of filters
|
||||||
|
//! against a column of scalar values.
|
||||||
|
//!
|
||||||
|
//! Scalar indices are currently supported on numeric, string, boolean, and temporal columns.
|
||||||
|
//!
|
||||||
|
//! A scalar index will help with queries with filters like `x > 10`, `x < 10`, `x = 10`,
|
||||||
|
//! etc. Scalar indices can also speed up prefiltering for vector searches. A single
|
||||||
|
//! vector search with prefiltering can use both a scalar index and a vector index.
|
||||||
|
|
||||||
|
/// Builder for a btree index
|
||||||
|
///
|
||||||
|
/// A btree index is an index on scalar columns. The index stores a copy of the column
|
||||||
|
/// in sorted order. A header entry is created for each block of rows (currently the
|
||||||
|
/// block size is fixed at 4096). These header entries are stored in a separate
|
||||||
|
/// cacheable structure (a btree). To search for data the header is used to determine
|
||||||
|
/// which blocks need to be read from disk.
|
||||||
|
///
|
||||||
|
/// For example, a btree index in a table with 1Bi rows requires sizeof(Scalar) * 256Ki
|
||||||
|
/// bytes of memory and will generally need to read sizeof(Scalar) * 4096 bytes to find
|
||||||
|
/// the correct row ids.
|
||||||
|
///
|
||||||
|
/// This index is good for scalar columns with mostly distinct values and does best when
|
||||||
|
/// the query is highly selective.
|
||||||
|
///
|
||||||
|
/// The btree index does not currently have any parameters though parameters such as the
|
||||||
|
/// block size may be added in the future.
|
||||||
|
#[derive(Default, Debug, Clone)]
|
||||||
|
pub struct BTreeIndexBuilder {}
|
||||||
|
|
||||||
|
impl BTreeIndexBuilder {}
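
A short sketch of attaching this builder to a column (mirrors the Node binding further up; `table` is assumed to be a lancedb::Table with an `id` column):

use lancedb::index::{scalar::BTreeIndexBuilder, Index};

async fn index_ids(table: &lancedb::Table) -> lancedb::Result<()> {
    // Speeds up filters like `id > 10` and prefiltered vector searches on `id`.
    table
        .create_index(&["id"], Index::BTree(BTreeIndexBuilder::default()))
        .execute()
        .await
}
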
|
||||||
@@ -12,10 +12,19 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! Vector indices are approximate indices that are used to find rows similar to
|
||||||
|
//! a query vector. Vector indices speed up vector searches.
|
||||||
|
//!
|
||||||
|
//! Vector indices are only supported on fixed-size-list (tensor) columns of floating point
|
||||||
|
//! values
|
||||||
|
use std::cmp::max;
|
||||||
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use lance::table::format::{Index, Manifest};
|
use lance::table::format::{Index, Manifest};
|
||||||
|
|
||||||
|
use crate::DistanceType;
|
||||||
|
|
||||||
pub struct VectorIndex {
|
pub struct VectorIndex {
|
||||||
pub columns: Vec<String>,
|
pub columns: Vec<String>,
|
||||||
pub index_name: String,
|
pub index_name: String,
|
||||||
@@ -42,3 +51,145 @@ pub struct VectorIndexStatistics {
|
|||||||
pub num_indexed_rows: usize,
|
pub num_indexed_rows: usize,
|
||||||
pub num_unindexed_rows: usize,
|
pub num_unindexed_rows: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builder for an IVF PQ index.
|
||||||
|
///
|
||||||
|
/// This index stores a compressed (quantized) copy of every vector. These vectors
|
||||||
|
/// are grouped into partitions of similar vectors. Each partition keeps track of
|
||||||
|
/// a centroid which is the average value of all vectors in the group.
|
||||||
|
///
|
||||||
|
/// During a query the centroids are compared with the query vector to find the closest
|
||||||
|
/// partitions. The compressed vectors in these partitions are then searched to find
|
||||||
|
/// the closest vectors.
|
||||||
|
///
|
||||||
|
/// The compression scheme is called product quantization. Each vector is divided into
|
||||||
|
/// subvectors and then each subvector is quantized into a small number of bits. The
|
||||||
|
/// parameters `num_bits` and `num_subvectors` control this process, providing a tradeoff
|
||||||
|
/// between index size (and thus search speed) and index accuracy.
|
||||||
|
///
|
||||||
|
/// The partitioning process is called IVF and the `num_partitions` parameter controls how
|
||||||
|
/// many groups to create.
|
||||||
|
///
|
||||||
|
/// Note that training an IVF PQ index on a large dataset is a slow operation and
|
||||||
|
/// currently is also a memory intensive operation.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct IvfPqIndexBuilder {
|
||||||
|
pub(crate) distance_type: DistanceType,
|
||||||
|
pub(crate) num_partitions: Option<u32>,
|
||||||
|
pub(crate) num_sub_vectors: Option<u32>,
|
||||||
|
pub(crate) sample_rate: u32,
|
||||||
|
pub(crate) max_iterations: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for IvfPqIndexBuilder {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
distance_type: DistanceType::L2,
|
||||||
|
num_partitions: None,
|
||||||
|
num_sub_vectors: None,
|
||||||
|
sample_rate: 256,
|
||||||
|
max_iterations: 50,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IvfPqIndexBuilder {
|
||||||
|
/// [DistanceType] to use to build the index.
|
||||||
|
///
|
||||||
|
/// Default value is [DistanceType::L2].
|
||||||
|
///
|
||||||
|
/// This is used when training the index to calculate the IVF partitions (vectors are
|
||||||
|
/// grouped in partitions with similar vectors according to this distance type) and to
|
||||||
|
/// calculate a subvector's code during quantization.
|
||||||
|
///
|
||||||
|
/// The metric type used to train an index MUST match the metric type used to search the
|
||||||
|
/// index. Failure to do so will yield inaccurate results.
|
||||||
|
pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
|
||||||
|
self.distance_type = distance_type;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of IVF partitions to create.
|
||||||
|
///
|
||||||
|
/// This value should generally scale with the number of rows in the dataset. By default
|
||||||
|
/// the number of partitions is the square root of the number of rows.
|
||||||
|
///
|
||||||
|
/// If this value is too large then the first part of the search (picking the right partition)
|
||||||
|
/// will be slow. If this value is too small then the second part of the search (searching
|
||||||
|
/// within a partition) will be slow.
|
||||||
|
pub fn num_partitions(mut self, num_partitions: u32) -> Self {
|
||||||
|
self.num_partitions = Some(num_partitions);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of sub-vectors of PQ.
|
||||||
|
///
|
||||||
|
/// This value controls how much the vector is compressed during the quantization step.
|
||||||
|
/// The more sub vectors there are the less the vector is compressed. The default is
|
||||||
|
/// the dimension of the vector divided by 16. If the dimension is not evenly divisible
|
||||||
|
/// by 16 we use the dimension divided by 8.
|
||||||
|
///
|
||||||
|
/// The above two cases are highly preferred. Having 8 or 16 values per subvector allows
|
||||||
|
/// us to use efficient SIMD instructions.
|
||||||
|
///
|
||||||
|
/// If the dimension is not divisible by 8 then we use 1 subvector. This is not ideal and
|
||||||
|
/// will likely result in poor performance.
|
||||||
|
pub fn num_sub_vectors(mut self, num_sub_vectors: u32) -> Self {
|
||||||
|
self.num_sub_vectors = Some(num_sub_vectors);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The rate used to calculate the number of training vectors for kmeans.
|
||||||
|
///
|
||||||
|
/// When an IVF PQ index is trained, we need to calculate partitions. These are groups
|
||||||
|
/// of vectors that are similar to each other. To do this we use an algorithm called kmeans.
|
||||||
|
///
|
||||||
|
/// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
|
||||||
|
/// random sample of the data. This parameter controls the size of the sample. The total
|
||||||
|
/// number of vectors used to train the index is `sample_rate * num_partitions`.
|
||||||
|
///
|
||||||
|
/// Increasing this value might improve the quality of the index but in most cases the
|
||||||
|
/// default should be sufficient.
|
||||||
|
///
|
||||||
|
/// The default value is 256.
|
||||||
|
pub fn sample_rate(mut self, sample_rate: u32) -> Self {
|
||||||
|
self.sample_rate = sample_rate;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Max iterations to train kmeans.
|
||||||
|
///
|
||||||
|
/// When training an IVF PQ index we use kmeans to calculate the partitions. This parameter
|
||||||
|
/// controls how many iterations of kmeans to run.
|
||||||
|
///
|
||||||
|
/// Increasing this might improve the quality of the index but in most cases the parameter
|
||||||
|
/// is unused because kmeans will converge with fewer iterations. The parameter is only
|
||||||
|
/// used in cases where kmeans does not appear to converge. In those cases it is unlikely
|
||||||
|
/// that increasing this value will make the index converge anyway.
|
||||||
|
///
|
||||||
|
/// The default value is 50.
|
||||||
|
pub fn max_iterations(mut self, max_iterations: u32) -> Self {
|
||||||
|
self.max_iterations = max_iterations;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
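
Putting the options above together, a hedged configuration sketch (`table` is assumed to be a lancedb::Table with a fixed-size-list `vector` column; the parameter values are illustrative, not recommendations):

use lancedb::index::{vector::IvfPqIndexBuilder, Index};
use lancedb::DistanceType;

async fn index_vectors_ivf_pq(table: &lancedb::Table) -> lancedb::Result<()> {
    let ivf_pq = IvfPqIndexBuilder::default()
        .distance_type(DistanceType::Cosine) // must match the distance type used at query time
        .num_partitions(256)                 // roughly sqrt(row count), per the doc above
        .num_sub_vectors(16);                // dimension / 16 is the usual default
    table
        .create_index(&["vector"], Index::IvfPq(ivf_pq))
        .execute()
        .await
}
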
|
||||||
|
|
||||||
|
pub(crate) fn suggested_num_partitions(rows: usize) -> u32 {
|
||||||
|
let num_partitions = (rows as f64).sqrt() as u32;
|
||||||
|
max(1, num_partitions)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
|
||||||
|
if dim % 16 == 0 {
|
||||||
|
// Should be more aggressive than this default.
|
||||||
|
dim / 16
|
||||||
|
} else if dim % 8 == 0 {
|
||||||
|
dim / 8
|
||||||
|
} else {
|
||||||
|
log::warn!(
|
||||||
|
"The dimension of the vector is not divisible by 8 or 16, \
|
||||||
|
which may cause performance degradation in PQ"
|
||||||
|
);
|
||||||
|
1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -130,14 +130,13 @@
|
|||||||
//! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
|
//! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
|
||||||
//! # RecordBatchIterator, Int32Array};
|
//! # RecordBatchIterator, Int32Array};
|
||||||
//! # use arrow_schema::{Schema, Field, DataType};
|
//! # use arrow_schema::{Schema, Field, DataType};
|
||||||
|
//! use lancedb::index::Index;
|
||||||
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||||
//! # let tmpdir = tempfile::tempdir().unwrap();
|
//! # let tmpdir = tempfile::tempdir().unwrap();
|
||||||
//! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
|
//! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
|
||||||
//! # let tbl = db.open_table("idx_test").execute().await.unwrap();
|
//! # let tbl = db.open_table("idx_test").execute().await.unwrap();
|
||||||
//! tbl.create_index(&["vector"])
|
//! tbl.create_index(&["vector"], Index::Auto)
|
||||||
//! .ivf_pq()
|
//! .execute()
|
||||||
//! .num_partitions(256)
|
|
||||||
//! .build()
|
|
||||||
//! .await
|
//! .await
|
||||||
//! .unwrap();
|
//! .unwrap();
|
||||||
//! # });
|
//! # });
|
||||||
@@ -181,6 +180,7 @@
|
|||||||
//! # });
|
//! # });
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
pub mod arrow;
|
||||||
pub mod connection;
|
pub mod connection;
|
||||||
pub mod data;
|
pub mod data;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
@@ -194,6 +194,7 @@ pub mod table;
|
|||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
|
pub use lance_linalg::distance::DistanceType;
|
||||||
pub use table::Table;
|
pub use table::Table;
|
||||||
|
|
||||||
/// Connect to a database
|
/// Connect to a database
|
||||||
|
|||||||
@@ -15,9 +15,9 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use arrow_array::Float32Array;
|
use arrow_array::Float32Array;
|
||||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
|
||||||
use lance_linalg::distance::MetricType;
|
use lance_linalg::distance::MetricType;
|
||||||
|
|
||||||
|
use crate::arrow::SendableRecordBatchStream;
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use crate::table::TableInternal;
|
use crate::table::TableInternal;
|
||||||
|
|
||||||
@@ -81,13 +81,15 @@ impl Query {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert the query plan to a [`DatasetRecordBatchStream`]
|
/// Convert the query plan to a [`SendableRecordBatchStream`]
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// * A [DatasetRecordBatchStream] with the query's results.
|
/// * A [SendableRecordBatchStream] with the query's results.
|
||||||
pub async fn execute_stream(&self) -> Result<DatasetRecordBatchStream> {
|
pub async fn execute_stream(&self) -> Result<SendableRecordBatchStream> {
|
||||||
self.parent.clone().do_query(self).await
|
Ok(SendableRecordBatchStream::from(
|
||||||
|
self.parent.clone().query(self).await?,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the column to query
|
/// Set the column to query
|
||||||
@@ -363,6 +365,10 @@ mod tests {
|
|||||||
let arr: &Int32Array = b["id"].as_primitive();
|
let arr: &Int32Array = b["id"].as_primitive();
|
||||||
assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));
|
assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reject bad filter
|
||||||
|
let result = table.query().filter("id = 0 AND").execute_stream().await;
|
||||||
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {
|
fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {
|
||||||
|
|||||||
@@ -5,11 +5,11 @@ use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewCol
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::Result,
|
error::Result,
|
||||||
index::IndexBuilder,
|
index::{IndexBuilder, IndexConfig},
|
||||||
query::Query,
|
query::Query,
|
||||||
table::{
|
table::{
|
||||||
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
|
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
|
||||||
TableInternal,
|
TableInternal, UpdateBuilder,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -45,25 +45,40 @@ impl TableInternal for RemoteTable {
|
|||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
&self.name
|
&self.name
|
||||||
}
|
}
|
||||||
|
async fn version(&self) -> Result<u64> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn checkout(&self, _version: u64) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn checkout_latest(&self) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn restore(&self) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
async fn schema(&self) -> Result<SchemaRef> {
|
async fn schema(&self) -> Result<SchemaRef> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
|
async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
async fn do_add(&self, _add: AddDataBuilder) -> Result<()> {
|
async fn add(&self, _add: AddDataBuilder) -> Result<()> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
async fn do_query(&self, _query: &Query) -> Result<DatasetRecordBatchStream> {
|
async fn query(&self, _query: &Query) -> Result<DatasetRecordBatchStream> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn update(&self, _update: UpdateBuilder) -> Result<()> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
async fn delete(&self, _predicate: &str) -> Result<()> {
|
async fn delete(&self, _predicate: &str) -> Result<()> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
async fn do_create_index(&self, _index: IndexBuilder) -> Result<()> {
|
async fn create_index(&self, _index: IndexBuilder) -> Result<()> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
async fn do_merge_insert(
|
async fn merge_insert(
|
||||||
&self,
|
&self,
|
||||||
_params: MergeInsertBuilder,
|
_params: MergeInsertBuilder,
|
||||||
_new_data: Box<dyn RecordBatchReader + Send>,
|
_new_data: Box<dyn RecordBatchReader + Send>,
|
||||||
@@ -86,4 +101,7 @@ impl TableInternal for RemoteTable {
|
|||||||
async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
|
async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
@@ -83,6 +83,33 @@ impl DatasetRef {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn as_time_travel(&mut self, target_version: u64) -> Result<()> {
|
||||||
|
match self {
|
||||||
|
Self::Latest { dataset, .. } => {
|
||||||
|
*self = Self::TimeTravel {
|
||||||
|
dataset: dataset.checkout_version(target_version).await?,
|
||||||
|
version: target_version,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Self::TimeTravel { dataset, version } => {
|
||||||
|
if *version != target_version {
|
||||||
|
*self = Self::TimeTravel {
|
||||||
|
dataset: dataset.checkout_version(target_version).await?,
|
||||||
|
version: target_version,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn time_travel_version(&self) -> Option<u64> {
|
||||||
|
match self {
|
||||||
|
Self::Latest { .. } => None,
|
||||||
|
Self::TimeTravel { version, .. } => Some(*version),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn set_latest(&mut self, dataset: Dataset) {
|
fn set_latest(&mut self, dataset: Dataset) {
|
||||||
match self {
|
match self {
|
||||||
Self::Latest {
|
Self::Latest {
|
||||||
@@ -106,23 +133,6 @@ impl DatasetConsistencyWrapper {
|
|||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new wrapper in the time travel mode.
|
|
||||||
pub fn new_time_travel(dataset: Dataset, version: u64) -> Self {
|
|
||||||
Self(Arc::new(RwLock::new(DatasetRef::TimeTravel {
|
|
||||||
dataset,
|
|
||||||
version,
|
|
||||||
})))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an independent copy of self.
|
|
||||||
///
|
|
||||||
/// Unlike Clone, this will track versions independently of the original wrapper and
|
|
||||||
/// will be tied to a different RwLock.
|
|
||||||
pub async fn duplicate(&self) -> Self {
|
|
||||||
let ds_ref = self.0.read().await;
|
|
||||||
Self(Arc::new(RwLock::new((*ds_ref).clone())))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get an immutable reference to the dataset.
|
/// Get an immutable reference to the dataset.
|
||||||
pub async fn get(&self) -> Result<DatasetReadGuard<'_>> {
|
pub async fn get(&self) -> Result<DatasetReadGuard<'_>> {
|
||||||
self.ensure_up_to_date().await?;
|
self.ensure_up_to_date().await?;
|
||||||
@@ -132,7 +142,19 @@ impl DatasetConsistencyWrapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get a mutable reference to the dataset.
|
/// Get a mutable reference to the dataset.
|
||||||
|
///
|
||||||
|
/// If the dataset is in time travel mode this will fail
|
||||||
pub async fn get_mut(&self) -> Result<DatasetWriteGuard<'_>> {
|
pub async fn get_mut(&self) -> Result<DatasetWriteGuard<'_>> {
|
||||||
|
self.ensure_mutable().await?;
|
||||||
|
self.ensure_up_to_date().await?;
|
||||||
|
Ok(DatasetWriteGuard {
|
||||||
|
guard: self.0.write().await,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a mutable reference to the dataset without requiring the
|
||||||
|
/// dataset to be in a Latest mode.
|
||||||
|
pub async fn get_mut_unchecked(&self) -> Result<DatasetWriteGuard<'_>> {
|
||||||
self.ensure_up_to_date().await?;
|
self.ensure_up_to_date().await?;
|
||||||
Ok(DatasetWriteGuard {
|
Ok(DatasetWriteGuard {
|
||||||
guard: self.0.write().await,
|
guard: self.0.write().await,
|
||||||
@@ -140,7 +162,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Convert into a wrapper in latest version mode
|
/// Convert into a wrapper in latest version mode
|
||||||
pub async fn as_latest(&mut self, read_consistency_interval: Option<Duration>) -> Result<()> {
|
pub async fn as_latest(&self, read_consistency_interval: Option<Duration>) -> Result<()> {
|
||||||
self.0
|
self.0
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
@@ -148,6 +170,10 @@ impl DatasetConsistencyWrapper {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn as_time_travel(&self, target_version: u64) -> Result<()> {
|
||||||
|
self.0.write().await.as_time_travel(target_version).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Provide a known latest version of the dataset.
|
/// Provide a known latest version of the dataset.
|
||||||
///
|
///
|
||||||
/// This is usually done after some write operation, which inherently will
|
/// This is usually done after some write operation, which inherently will
|
||||||
@@ -160,6 +186,22 @@ impl DatasetConsistencyWrapper {
|
|||||||
self.0.write().await.reload().await
|
self.0.write().await.reload().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the version, if in time travel mode, or None otherwise
|
||||||
|
pub async fn time_travel_version(&self) -> Option<u64> {
|
||||||
|
self.0.read().await.time_travel_version()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ensure_mutable(&self) -> Result<()> {
|
||||||
|
let dataset_ref = self.0.read().await;
|
||||||
|
match &*dataset_ref {
|
||||||
|
DatasetRef::Latest { .. } => Ok(()),
|
||||||
|
DatasetRef::TimeTravel { .. } => Err(crate::Error::InvalidInput {
|
||||||
|
message: "table cannot be modified when a specific version is checked out"
|
||||||
|
.to_string(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
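
At the table level this check surfaces as the checkout/restore flow exercised by the Python test above; a Rust sketch (`table` is assumed to be a lancedb::Table):

async fn rewind_and_restore(table: &lancedb::Table) -> lancedb::Result<()> {
    let version = table.version().await?;
    table.checkout(version).await?; // time travel mode: writes now fail with the error above
    table.restore().await?;         // make this version the latest again; writes are allowed
    Ok(())
}
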
|
||||||
|
|
||||||
async fn is_up_to_date(&self) -> Result<bool> {
|
async fn is_up_to_date(&self) -> Result<bool> {
|
||||||
let dataset_ref = self.0.read().await;
|
let dataset_ref = self.0.read().await;
|
||||||
match &*dataset_ref {
|
match &*dataset_ref {
|
||||||
|
|||||||
@@ -98,6 +98,6 @@ impl MergeInsertBuilder {
|
|||||||
///
|
///
|
||||||
/// Nothing is returned but the [`super::Table`] is updated
|
/// Nothing is returned but the [`super::Table`] is updated
|
||||||
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
|
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
|
||||||
self.table.clone().do_merge_insert(self, new_data).await
|
self.table.clone().merge_insert(self, new_data).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||