Compare commits

...

15 Commits

Author SHA1 Message Date
Lance Release
11fcdb1194 Bump version: 0.8.2-beta.0 → 0.8.2 2024-06-05 13:47:16 +00:00
Lance Release
95a5a0d713 Bump version: 0.8.1 → 0.8.2-beta.0 2024-06-05 13:47:16 +00:00
Weston Pace
c3043a54c6 feat: bump lance dependency to 0.12.1 (#1357) 2024-06-05 06:07:11 -07:00
Weston Pace
d5586c9c32 feat: make it possible to opt in to using the v2 format (#1352)
This also exposes the max_batch_length configuration option in
python/node (it was needed to verify whether we are actually in v2 mode
or not)
2024-06-04 21:52:14 -07:00
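
A minimal sketch of the opt-in from the Node side, mirroring the test added in this PR (the `@lancedb/lancedb` package name and top-level await are assumptions):

```ts
import { connect } from "@lancedb/lancedb"; // package name is an assumption

const db = await connect("/tmp/lancedb-v2-demo");
const data = [...Array(10_000).keys()].map((i) => ({ id: i }));

// Opt in to the v2 file format; the default stays legacy (v1) while v2 is in beta.
const table = await db.createTable("demo_v2", data, { useLegacyFormat: false });

// maxBatchLength caps the rows per RecordBatch; in v2 mode the data comes back
// in far fewer, larger batches than under the v1 row-group layout.
const arrowTable = await table.query().toArrow({ maxBatchLength: 100_000 });
console.log(arrowTable.batches.length);
```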
Rob Meng
d39e7d23f4 feat: fast path for checkout_latest (#1355)
Similar to https://github.com/lancedb/lancedb/pull/1354: do locked IO
less frequently.
2024-06-04 23:01:28 -04:00
Rob Meng
ddceda4ff7 feat: add fast path to dataset reload (#1354)
Most of the time we don't need to reload, and taking the write lock and
performing IO is not an ideal pattern.

This PR tries to make the critical section of `.write()` happen less
frequently.

This isn't the ideal solution. Ideally we would not take the lock until
the new dataset has been loaded, but that would require too much
refactoring.
2024-06-04 19:03:53 -04:00
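
The change itself lives in the Rust crate; purely to illustrate the pattern described above, here is a hypothetical TypeScript sketch: read the latest version without the write lock, and only lock and reload when the cached dataset is actually stale.

```ts
// Hypothetical sketch of the fast-path pattern, not the actual implementation.
type Dataset = { version: number };

class TableHandle {
  private dataset: Dataset = { version: 0 };
  private writeLock: Promise<unknown> = Promise.resolve(); // toy async lock

  constructor(
    private fetchLatestVersion: () => Promise<number>, // cheap, lock-free IO
    private loadVersion: (v: number) => Promise<Dataset>, // expensive reload IO
  ) {}

  async checkoutLatest(): Promise<void> {
    const latest = await this.fetchLatestVersion();
    if (latest === this.dataset.version) return; // fast path: already current

    // Slow path: serialize the actual reload behind the write lock.
    const run = this.writeLock.then(async () => {
      if (latest > this.dataset.version) {
        this.dataset = await this.loadVersion(latest);
      }
    });
    this.writeLock = run.catch(() => {}); // keep the chain alive on failure
    await run;
  }
}
```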
Cory Grinstead
70f92f19a6 feat(nodejs): table.search functionality (#1341)
closes https://github.com/lancedb/lancedb/issues/1256
2024-06-04 14:04:03 -05:00
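
Based on the tests added in this PR, `table.search` accepts either a string (embedded via the table's registered embedding function) or a vector-like value; a rough sketch, with the package name assumed:

```ts
import { connect } from "@lancedb/lancedb"; // package name is an assumption

const db = await connect("/tmp/lancedb-search-demo");
const table = await db.createTable("docs", [
  { text: "hello world", vector: [0.1, 0.2, 0.3] },
  { text: "goodbye world", vector: [0.4, 0.5, 0.6] },
]);

// number[], Float32Array, and Float64Array all work as vector-like queries.
const hits = await table.search(Float32Array.of(0.4, 0.5, 0.6)).toArray();
console.log(hits[0].text); // "goodbye world"

// A string query needs an embedding function registered on the table; without
// one it rejects with "No embedding functions are defined in the table".
```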
Cory Grinstead
d9fb6457e1 fix(nodejs): better support for f16 and f64 (#1343)
closes https://github.com/lancedb/lancedb/issues/1292
closes https://github.com/lancedb/lancedb/issues/1293
2024-06-04 13:41:21 -05:00
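
The fix lets vector columns carry `Float16` and `Float64` items rather than only `Float32`; a sketch along the lines of the regression test below (import paths are assumptions):

```ts
import { connect } from "@lancedb/lancedb"; // package name is an assumption
import {
  Field,
  FixedSizeList,
  Float16,
  Schema,
  Utf8,
} from "@lancedb/lancedb/arrow"; // import path is an assumption

const db = await connect("/tmp/lancedb-f16-demo");
const schema = new Schema([
  new Field("text", new Utf8(), true),
  new Field(
    "vector",
    new FixedSizeList(512, new Field("item", new Float16())),
    true,
  ),
]);
const table = await db.createEmptyTable("f16", schema);
await table.add([
  { text: "hello", vector: Array(512).fill(1.0) },
  { text: "hello world", vector: Array(512).fill(1.0) },
]);
console.log((await table.query().toArray()).length); // 2
```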
Lei Xu
56b4fd2bd9 feat(rust): allow to create execution plan on queries (#1350) 2024-05-31 17:33:58 -07:00
paul n walsh
7c133ec416 feat(nodejs): table.toArrow function (#1282)
Addresses https://github.com/lancedb/lancedb/issues/1254.

---------

Co-authored-by: universalmind303 <cory.grinstead@gmail.com>
2024-05-31 13:24:21 -05:00
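
As the diff below shows, `Table.toArrow()` is shorthand for `table.query().toArrow()`, materializing the whole table as an Arrow `Table`:

```ts
import { connect } from "@lancedb/lancedb"; // package name is an assumption

const db = await connect("/tmp/lancedb-toarrow-demo");
const table = await db.createTable("demo", [{ id: 1 }, { id: 2 }]);

// Collect every row into an in-memory Arrow Table.
const arrowTable = await table.toArrow();
console.log(arrowTable.numRows); // 2
```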
QianZhu
1dbb4cd1e2 fix: error msg when query vector dim is wrong (#1339)
- changed the error message for table.search with a wrong query vector dimension
- added missing fields for listIndices and indexStats to be consistent with the Python API
- will make changes in the node integration test
2024-05-31 10:18:06 -07:00
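
The extra fields surface on the `listIndices` and `indexStats` results of the remote table client; a hypothetical call shape (the exact `indexStats` argument is an assumption):

```ts
// Given an open remote table handle `table`; argument shapes are assumptions.
const indices = await table.listIndices();
console.log(indices[0].status); // newly exposed next to columns/name/uuid

const stats = await table.indexStats(indices[0].uuid);
// New fields matching the Python API, alongside the row counts:
console.log(stats.index_type, stats.distance_type, stats.completed_at);
console.log(stats.numIndexedRows, stats.numUnindexedRows);
```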
Paul Rinaldi
af65417d19 fix: update broken blog link on readme (#1310) 2024-05-31 10:04:56 -07:00
Cory Grinstead
01dd6c5e75 feat(rust): openai embedding function (#1275)
Part of https://github.com/lancedb/lancedb/issues/994.

Adds the ability to use the OpenAI embedding function.


The example can be run with the following:

```sh
> export OPENAI_API_KEY="sk-..."
> cargo run --example openai --features=openai
```

which should output
```
Closest match: Winter Parka
```
2024-05-30 15:55:55 -05:00
Weston Pace
1e85b57c82 ci: don't update package locks if we are not releasing node (#1323)
This doesn't actually block a python-only release, since this step runs
after the version bump has been pushed, but it would still be nice for
the git job to finish successfully.
2024-05-30 04:42:06 -07:00
Ayush Chaurasia
16eff254ea feat: add support for new cohere models in cohere and bedrock embedding functions (#1335)
Fixes #1329

Will update docs on https://github.com/lancedb/lancedb/pull/1326
2024-05-30 10:20:03 +05:30
40 changed files with 1555 additions and 358 deletions

View File

@@ -94,6 +94,6 @@ jobs:
branch: ${{ github.ref }}
tags: true
- uses: ./.github/workflows/update_package_lock
if: ${{ inputs.dry_run }} == "false"
if: ${{ !inputs.dry_run && inputs.other }}
with:
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -1,5 +1,11 @@
[workspace]
members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python", "java/core/lancedb-jni"]
members = [
"rust/ffi/node",
"rust/lancedb",
"nodejs",
"python",
"java/core/lancedb-jni",
]
# Python package needs to be built by maturin.
exclude = ["python"]
resolver = "2"
@@ -14,10 +20,11 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.11.1", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.11.1" }
lance-linalg = { "version" = "=0.11.1" }
lance-testing = { "version" = "=0.11.1" }
lance = { "version" = "=0.12.1", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.12.1" }
lance-linalg = { "version" = "=0.12.1" }
lance-testing = { "version" = "=0.12.1" }
lance-datafusion = { "version" = "=0.12.1" }
# Note that this one does not include pyarrow
arrow = { version = "51.0", optional = false }
arrow-array = "51.0"
@@ -29,6 +36,7 @@ arrow-arith = "51.0"
arrow-cast = "51.0"
async-trait = "0"
chrono = "0.4.35"
datafusion-physical-plan = "37.1"
half = { "version" = "=2.4.1", default-features = false, features = [
"num-traits",
] }

View File

@@ -83,5 +83,5 @@ result = table.search([100, 100]).limit(2).to_pandas()
```
## Blogs, Tutorials & Videos
* 📈 <a href="https://blog.eto.ai/benchmarking-random-access-in-lance-ed690757a826">2000x better performance with Lance over Parquet</a>
* 📈 <a href="https://blog.lancedb.com/benchmarking-random-access-in-lance/">2000x better performance with Lance over Parquet</a>
* 🤖 <a href="https://github.com/lancedb/lancedb/blob/main/docs/src/notebooks/youtube_transcript_search.ipynb">Build a question and answer bot with LanceDB</a>

View File

@@ -704,6 +704,9 @@ export interface VectorIndex {
export interface IndexStats {
numIndexedRows: number | null
numUnindexedRows: number | null
index_type: string | null
distance_type: string | null
completed_at: string | null
}
/**

View File

@@ -509,7 +509,8 @@ export class RemoteTable<T = number[]> implements Table<T> {
return (await results.body()).indexes?.map((index: any) => ({
columns: index.columns,
name: index.index_name,
uuid: index.index_uuid
uuid: index.index_uuid,
status: index.status
}))
}
@@ -520,7 +521,10 @@ export class RemoteTable<T = number[]> implements Table<T> {
const body = await results.body()
return {
numIndexedRows: body?.num_indexed_rows,
numUnindexedRows: body?.num_unindexed_rows
numUnindexedRows: body?.num_unindexed_rows,
index_type: body?.index_type,
distance_type: body?.distance_type,
completed_at: body?.completed_at
}
}

View File

@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Field, Float64, Schema } from "apache-arrow";
import * as tmp from "tmp";
import { Connection, connect } from "../lancedb";
import { Connection, Table, connect } from "../lancedb";
describe("when connecting", () => {
let tmpDir: tmp.DirResult;
@@ -86,4 +87,39 @@ describe("given a connection", () => {
tables = await db.tableNames({ startAfter: "a" });
expect(tables).toEqual(["b", "c"]);
});
it("should create tables in v2 mode", async () => {
const db = await connect(tmpDir.name);
const data = [...Array(10000).keys()].map((i) => ({ id: i }));
// Create in v1 mode
let table = await db.createTable("test", data);
const isV2 = async (table: Table) => {
const data = await table.query().toArrow({ maxBatchLength: 100000 });
console.log(data.batches.length);
return data.batches.length < 5;
};
await expect(isV2(table)).resolves.toBe(false);
// Create in v2 mode
table = await db.createTable("test_v2", data, { useLegacyFormat: false });
await expect(isV2(table)).resolves.toBe(true);
await table.add(data);
await expect(isV2(table)).resolves.toBe(true);
// Create empty in v2 mode
const schema = new Schema([new Field("id", new Float64(), true)]);
table = await db.createEmptyTable("test_v2_empty", schema, {
useLegacyFormat: false,
});
await table.add(data);
await expect(isV2(table)).resolves.toBe(true);
});
});

View File

@@ -0,0 +1,314 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import * as tmp from "tmp";
import { connect } from "../lancedb";
import {
Field,
FixedSizeList,
Float,
Float16,
Float32,
Float64,
Schema,
Utf8,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
getRegistry().reset();
});
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
schema,
},
);
getRegistry().reset();
const db2 = await connect(tmpDir.name);
const tbl = await db2.openTable("test");
expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
`Function "mock" not found in registry`,
);
});
test.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide manual embeddings with multiple float datatype",
async (floatType) => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const data = [{ text: "hello" }, { text: "hello world" }];
const schema = new Schema([
new Field("vector", new FixedSizeList(3, new Field("item", floatType))),
new Field("text", new Utf8()),
]);
const func = new MockEmbeddingFunction();
const name = "test";
const db = await connect(tmpDir.name);
const table = await db.createTable(name, data, {
schema,
embeddingFunction: {
sourceColumn: "text",
function: func,
},
});
const res = await table.query().toArray();
expect([...res[0].vector]).toEqual([1, 2, 3]);
},
);
test.only.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide auto embeddings with multiple float datatypes",
async (floatType) => {
@register("test1")
class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
@register("test")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("test")!.create();
const func2 = getRegistry()
.get<MockEmbeddingFunctionWithoutNDims>("test1")!
.create();
const schema = LanceSchema({
text: func.sourceField(new Utf8()),
vector: func.vectorField(floatType),
});
const schema2 = LanceSchema({
text: func2.sourceField(new Utf8()),
vector: func2.vectorField({ datatype: floatType, dims: 3 }),
});
const schema3 = LanceSchema({
text: func2.sourceField(new Utf8()),
vector: func.vectorField({
datatype: new FixedSizeList(3, new Field("item", floatType, true)),
dims: 3,
}),
});
const expectedSchema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", floatType, true)),
true,
),
]);
const stringSchema = JSON.stringify(schema, null, 2);
const stringSchema2 = JSON.stringify(schema2, null, 2);
const stringSchema3 = JSON.stringify(schema3, null, 2);
const stringExpectedSchema = JSON.stringify(expectedSchema, null, 2);
expect(stringSchema).toEqual(stringExpectedSchema);
expect(stringSchema2).toEqual(stringExpectedSchema);
expect(stringSchema3).toEqual(stringExpectedSchema);
},
);
});

View File

@@ -21,19 +21,17 @@ import * as arrowOld from "apache-arrow-old";
import { Table, connect } from "../lancedb";
import {
Table as ArrowTable,
Field,
FixedSizeList,
Float,
Float32,
Float64,
Int32,
Int64,
Schema,
Utf8,
makeArrowTable,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
import { EmbeddingFunction, LanceSchema, register } from "../lancedb/embedding";
import { Index } from "../lancedb/indices";
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
@@ -44,6 +42,7 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
const schema = new arrow.Schema([
new arrow.Field("id", new arrow.Float64(), true),
]);
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name);
@@ -94,6 +93,43 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
expect(await table.countRows("id == 7")).toBe(1);
expect(await table.countRows("id == 10")).toBe(1);
});
// https://github.com/lancedb/lancedb/issues/1293
test.each([new arrow.Float16(), new arrow.Float32(), new arrow.Float64()])(
"can create empty table with non default float type: %s",
async (floatType) => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello", vector: Array(512).fill(1.0) },
{ text: "hello world", vector: Array(512).fill(1.0) },
];
const f64Schema = new arrow.Schema([
new arrow.Field("text", new arrow.Utf8(), true),
new arrow.Field(
"vector",
new arrow.FixedSizeList(512, new arrow.Field("item", floatType)),
true,
),
]);
const f64Table = await db.createEmptyTable("f64", f64Schema, {
mode: "overwrite",
});
try {
await f64Table.add(data);
const res = await f64Table.query().toArray();
expect(res.length).toBe(2);
} catch (e) {
expect(e).toBeUndefined();
}
},
);
it("should return the table as an instance of an arrow table", async () => {
const arrowTbl = await table.toArrow();
expect(arrowTbl).toBeInstanceOf(ArrowTable);
});
});
describe("When creating an index", () => {
@@ -431,161 +467,6 @@ describe("when dealing with versioning", () => {
});
});
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new arrow.Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
schema,
},
);
getRegistry().reset();
const db2 = await connect(tmpDir.name);
const tbl = await db2.openTable("test");
expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
`Function "mock" not found in registry`,
);
});
});
describe("when optimizing a dataset", () => {
let tmpDir: tmp.DirResult;
let table: Table;
@@ -613,3 +494,99 @@ describe("when optimizing a dataset", () => {
expect(stats.prune.oldVersionsRemoved).toBe(3);
});
});
describe("table.search", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
test("can search using a string", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 1;
}
embeddingDataType(): arrow.Float {
return new Float32();
}
// Hardcoded embeddings for the sake of testing
async computeQueryEmbeddings(_data: string) {
switch (_data) {
case "greetings":
return [0.1];
case "farewell":
return [0.2];
default:
return null as never;
}
}
// Hardcoded embeddings for the sake of testing
async computeSourceEmbeddings(data: string[]) {
return data.map((s) => {
switch (s) {
case "hello world":
return [0.1];
case "goodbye world":
return [0.2];
default:
return null as never;
}
});
}
}
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
const data = [{ text: "hello world" }, { text: "goodbye world" }];
const table = await db.createTable("test", data, { schema });
const results = await table.search("greetings").then((r) => r.toArray());
expect(results[0].text).toBe(data[0].text);
const results2 = await table.search("farewell").then((r) => r.toArray());
expect(results2[0].text).toBe(data[1].text);
});
test("rejects if no embedding function provided", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
expect(table.search("hello")).rejects.toThrow(
"No embedding functions are defined in the table",
);
});
test.each([
[0.4, 0.5, 0.599], // number[]
Float32Array.of(0.4, 0.5, 0.599), // Float32Array
Float64Array.of(0.4, 0.5, 0.599), // Float64Array
])("can search using vectorlike datatypes", async (vectorlike) => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
// biome-ignore lint/suspicious/noExplicitAny: test
const results: any[] = await table.search(vectorlike).toArray();
expect(results.length).toBe(2);
expect(results[0].text).toBe(data[1].text);
});
});

View File

@@ -31,7 +31,7 @@ import {
Schema,
Struct,
Utf8,
type Vector,
Vector,
makeBuilder,
makeData,
type makeTable,
@@ -42,6 +42,8 @@ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
export * from "apache-arrow";
export type IntoVector = Float32Array | Float64Array | number[];
export function isArrowTable(value: object): value is ArrowTable {
if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value;
@@ -182,6 +184,7 @@ export class MakeArrowTableOptions {
vector: new VectorColumnOptions(),
};
embeddings?: EmbeddingFunction<unknown>;
embeddingFunction?: EmbeddingFunctionConfig;
/**
* If true then string columns will be encoded with dictionary encoding
@@ -306,7 +309,11 @@ export function makeArrowTable(
const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema);
opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings);
opt.schema = validateSchemaEmbeddings(
opt.schema,
data,
options?.embeddingFunction,
);
}
const columns: Record<string, Vector> = {};
// TODO: sample dataset to find missing columns
@@ -545,7 +552,6 @@ async function applyEmbeddingsFromMetadata(
dtype,
);
}
const vector = makeVector(vectors, destType);
columns[destColumn] = vector;
}
@@ -835,7 +841,7 @@ export function createEmptyTable(schema: Schema): ArrowTable {
function validateSchemaEmbeddings(
schema: Schema,
data: Array<Record<string, unknown>>,
embeddings: EmbeddingFunction<unknown> | undefined,
embeddings: EmbeddingFunctionConfig | undefined,
) {
const fields = [];
const missingEmbeddingFields = [];

View File

@@ -71,6 +71,12 @@ export interface CreateTableOptions {
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>;
/**
* If true then data files will be written with the legacy format
*
* The default is true while the new format is in beta
*/
useLegacyFormat?: boolean;
schema?: Schema;
embeddingFunction?: EmbeddingFunctionConfig;
}
@@ -221,6 +227,7 @@ export class Connection {
buf,
mode,
cleanseStorageOptions(options?.storageOptions),
options?.useLegacyFormat,
);
return new Table(innerTable);
@@ -256,6 +263,7 @@ export class Connection {
buf,
mode,
cleanseStorageOptions(options?.storageOptions),
options?.useLegacyFormat,
);
return new Table(innerTable);
}

View File

@@ -19,6 +19,7 @@ import {
FixedSizeList,
Float,
Float32,
type IntoVector,
isDataType,
isFixedSizeList,
isFloat,
@@ -100,33 +101,55 @@ export abstract class EmbeddingFunction<
* @see {@link lancedb.LanceSchema}
*/
vectorField(
options?: Partial<FieldOptions>,
optionsOrDatatype?: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] {
let dtype: DataType;
const dims = this.ndims() ?? options?.dims;
if (!options?.datatype) {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
let dtype: DataType | undefined;
let vectorType: DataType;
let dims: number | undefined = this.ndims();
// `func.vectorField(new Float32())`
if (isDataType(optionsOrDatatype)) {
dtype = optionsOrDatatype;
} else {
if (isFixedSizeList(options.datatype)) {
dtype = options.datatype;
} else if (isFloat(options.datatype)) {
// `func.vectorField({
// datatype: new Float32(),
// dims: 10
// })`
dims = dims ?? optionsOrDatatype?.dims;
dtype = optionsOrDatatype?.datatype;
}
if (dtype !== undefined) {
// `func.vectorField(new FixedSizeList(dims, new Field("item", new Float32(), true)))`
// or `func.vectorField({datatype: new FixedSizeList(dims, new Field("item", new Float32(), true))})`
if (isFixedSizeList(dtype)) {
vectorType = dtype;
// `func.vectorField(new Float32())`
// or `func.vectorField({datatype: new Float32()})`
} else if (isFloat(dtype)) {
// No `ndims` impl and no `{dims: n}` provided;
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
dtype = newVectorType(dims, options.datatype);
vectorType = newVectorType(dims, dtype);
} else {
throw new Error(
"Expected FixedSizeList or Float as datatype for vector field",
);
}
} else {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
vectorType = new FixedSizeList(
dims,
new Field("item", new Float32(), true),
);
}
const metadata = new Map<string, EmbeddingFunction>();
metadata.set("vector_column_for", this);
return [dtype, metadata];
return [vectorType, metadata];
}
/** The number of dimensions of the embeddings */
@@ -147,9 +170,7 @@ export abstract class EmbeddingFunction<
/**
Compute the embeddings for a single query
*/
async computeQueryEmbeddings(
data: T,
): Promise<number[] | Float32Array | Float64Array> {
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
return this.computeSourceEmbeddings([data]).then(
(embeddings) => embeddings[0],
);

View File

@@ -42,6 +42,7 @@ export class EmbeddingFunctionRegistry {
* Register an embedding function
* @param name The name of the function
* @param func The function to register
* @throws Error if the function is already registered
*/
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
this: EmbeddingFunctionRegistry,
@@ -89,6 +90,9 @@ export class EmbeddingFunctionRegistry {
this.#functions.clear();
}
/**
* @ignore
*/
parseFunctions(
this: EmbeddingFunctionRegistry,
metadata: Map<string, string>,

View File

@@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow";
import {
Table as ArrowTable,
type IntoVector,
RecordBatch,
tableFromIPC,
} from "./arrow";
import { type IvfPqOptions } from "./indices";
import {
RecordBatchIterator as NativeBatchIterator,
@@ -50,6 +55,39 @@ export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
}
/* eslint-enable */
class RecordBatchIterable<
NativeQueryType extends NativeQuery | NativeVectorQuery,
> implements AsyncIterable<RecordBatch>
{
private inner: NativeQueryType;
private options?: QueryExecutionOptions;
constructor(inner: NativeQueryType, options?: QueryExecutionOptions) {
this.inner = inner;
this.options = options;
}
// biome-ignore lint/suspicious/noExplicitAny: skip
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
return new RecordBatchIterator(
this.inner.execute(this.options?.maxBatchLength),
);
}
}
/**
* Options that control the behavior of a particular query execution
*/
export interface QueryExecutionOptions {
/**
* The maximum number of rows to return in a single batch
*
* Batches may have fewer rows if the underlying data is stored
* in smaller chunks.
*/
maxBatchLength?: number;
}
/** Common methods supported by all query types */
export class QueryBase<
NativeQueryType extends NativeQuery | NativeVectorQuery,
@@ -108,9 +146,12 @@ export class QueryBase<
* object insertion order is easy to get wrong and `Map` is more foolproof.
*/
select(
columns: string[] | Map<string, string> | Record<string, string>,
columns: string[] | Map<string, string> | Record<string, string> | string,
): QueryType {
let columnTuples: [string, string][];
if (typeof columns === "string") {
columns = [columns];
}
if (Array.isArray(columns)) {
columnTuples = columns.map((c) => [c, c]);
} else if (columns instanceof Map) {
@@ -133,8 +174,10 @@ export class QueryBase<
return this as unknown as QueryType;
}
protected nativeExecute(): Promise<NativeBatchIterator> {
return this.inner.execute();
protected nativeExecute(
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {
return this.inner.execute(options?.maxBatchLength);
}
/**
@@ -148,8 +191,10 @@ export class QueryBase<
* single query)
*
*/
protected execute(): RecordBatchIterator {
return new RecordBatchIterator(this.nativeExecute());
protected execute(
options?: Partial<QueryExecutionOptions>,
): RecordBatchIterator {
return new RecordBatchIterator(this.nativeExecute(options));
}
// biome-ignore lint/suspicious/noExplicitAny: skip
@@ -159,19 +204,18 @@ export class QueryBase<
}
/** Collect the results as an Arrow @see {@link ArrowTable}. */
async toArrow(): Promise<ArrowTable> {
async toArrow(options?: Partial<QueryExecutionOptions>): Promise<ArrowTable> {
const batches = [];
for await (const batch of this) {
for await (const batch of new RecordBatchIterable(this.inner, options)) {
batches.push(batch);
}
return new ArrowTable(batches);
}
/** Collect the results as an array of objects. */
async toArray(): Promise<unknown[]> {
const tbl = await this.toArrow();
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
// biome-ignore lint/suspicious/noExplicitAny: arrow.toArrow() returns any[]
async toArray(options?: Partial<QueryExecutionOptions>): Promise<any[]> {
const tbl = await this.toArrow(options);
return tbl.toArray();
}
}
@@ -370,9 +414,8 @@ export class Query extends QueryBase<NativeQuery, Query> {
* Vector searches always have a `limit`. If `limit` has not been called then
* a default `limit` of 10 will be used. @see {@link Query#limit}
*/
nearestTo(vector: unknown): VectorQuery {
// biome-ignore lint/suspicious/noExplicitAny: skip
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
nearestTo(vector: IntoVector): VectorQuery {
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
return new VectorQuery(vectorQuery);
}
}

View File

@@ -12,9 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Data, Schema, fromDataToBuffer, tableFromIPC } from "./arrow";
import {
Table as ArrowTable,
Data,
IntoVector,
Schema,
fromDataToBuffer,
tableFromIPC,
} from "./arrow";
import { getRegistry } from "./embedding/registry";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { IndexOptions } from "./indices";
import {
AddColumnsSql,
@@ -24,8 +31,8 @@ import {
Table as _NativeTable,
} from "./native";
import { Query, VectorQuery } from "./query";
export { IndexConfig } from "./native";
/**
* Options for adding data to a table.
*/
@@ -110,6 +117,14 @@ export class Table {
return this.inner.display();
}
async #getEmbeddingFunctions(): Promise<
Map<string, EmbeddingFunctionConfig>
> {
const schema = await this.schema();
const registry = getRegistry();
return registry.parseFunctions(schema.metadata);
}
/** Get the schema of the table. */
async schema(): Promise<Schema> {
const schemaBuf = await this.inner.schema();
@@ -130,6 +145,7 @@ export class Table {
const buffer = await fromDataToBuffer(
data,
functions.values().next().value,
schema,
);
await this.inner.add(buffer, mode);
}
@@ -270,6 +286,40 @@ export class Table {
return new Query(this.inner);
}
/**
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
* @rejects {Error} If no embedding functions are defined in the table
*/
search(query: string): Promise<VectorQuery>;
/**
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {IntoVector} query - the query vector
*/
search(query: IntoVector): VectorQuery;
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
if (typeof query !== "string") {
return this.vectorSearch(query);
} else {
return this.#getEmbeddingFunctions().then(async (functions) => {
// TODO: Support multiple embedding functions
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
.values()
.next().value;
if (!embeddingFunc) {
return Promise.reject(
new Error("No embedding functions are defined in the table"),
);
}
const embeddings =
await embeddingFunc.function.computeQueryEmbeddings(query);
return this.query().nearestTo(embeddings);
});
}
}
/**
* Search the table with a given query vector.
*
@@ -277,7 +327,7 @@ export class Table {
* is the same thing as calling `nearestTo` on the builder returned
* by `query`. @see {@link Query#nearestTo} for more details.
*/
vectorSearch(vector: unknown): VectorQuery {
vectorSearch(vector: IntoVector): VectorQuery {
return this.query().nearestTo(vector);
}
@@ -423,4 +473,9 @@ export class Table {
async listIndices(): Promise<IndexConfig[]> {
return await this.inner.listIndices();
}
/** Return the table as an arrow table */
async toArrow(): Promise<ArrowTable> {
return await this.query().toArrow();
}
}

View File

@@ -126,6 +126,7 @@ impl Connection {
buf: Buffer,
mode: String,
storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> napi::Result<Table> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
@@ -136,6 +137,9 @@ impl Connection {
builder = builder.storage_option(key, value);
}
}
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
let tbl = builder
.execute()
.await
@@ -150,6 +154,7 @@ impl Connection {
schema_buf: Buffer,
mode: String,
storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> napi::Result<Table> {
let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e))
@@ -164,6 +169,9 @@ impl Connection {
builder = builder.storage_option(key, value);
}
}
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
let tbl = builder
.execute()
.await

View File

@@ -56,6 +56,7 @@ pub enum WriteMode {
/// Write options when creating a Table.
#[napi(object)]
pub struct WriteOptions {
/// Write mode for writing to a table.
pub mode: Option<WriteMode>,
}

View File

@@ -15,6 +15,7 @@
use lancedb::query::ExecutableQuery;
use lancedb::query::Query as LanceDbQuery;
use lancedb::query::QueryBase;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::Select;
use lancedb::query::VectorQuery as LanceDbVectorQuery;
use napi::bindgen_prelude::*;
@@ -62,10 +63,21 @@ impl Query {
}
#[napi]
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
pub async fn execute(
&self,
max_batch_length: Option<u32>,
) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)
.await
.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream))
}
}
@@ -125,10 +137,21 @@ impl VectorQuery {
}
#[napi]
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
pub async fn execute(
&self,
max_batch_length: Option<u32>,
) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)
.await
.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream))
}
}

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.8.1"
current_version = "0.8.2"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.8.1"
version = "0.8.2"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -3,7 +3,7 @@ name = "lancedb"
# version in Cargo.toml
dependencies = [
"deprecation",
"pylance==0.11.1",
"pylance==0.12.1",
"ratelimiter~=1.0",
"requests>=2.31.0",
"retry>=0.9.2",

View File

@@ -24,6 +24,7 @@ class Connection(object):
mode: str,
data: pa.RecordBatchReader,
storage_options: Optional[Dict[str, str]] = None,
use_legacy_format: Optional[bool] = None,
) -> Table: ...
async def create_empty_table(
self,
@@ -31,6 +32,7 @@ class Connection(object):
mode: str,
schema: pa.Schema,
storage_options: Optional[Dict[str, str]] = None,
use_legacy_format: Optional[bool] = None,
) -> Table: ...
class Table:
@@ -72,7 +74,7 @@ class Query:
def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
async def execute(self) -> RecordBatchStream: ...
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
class VectorQuery:
async def execute(self) -> RecordBatchStream: ...

View File

@@ -558,6 +558,8 @@ class AsyncConnection(object):
on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
*,
use_legacy_format: Optional[bool] = None,
) -> AsyncTable:
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
@@ -600,6 +602,9 @@ class AsyncConnection(object):
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
use_legacy_format: bool, optional, default True
If True, use the legacy format for the table. If False, use the new format.
The default is True while the new format is in beta.
Returns
@@ -761,7 +766,11 @@ class AsyncConnection(object):
if data is None:
new_table = await self._inner.create_empty_table(
name, mode, schema, storage_options=storage_options
name,
mode,
schema,
storage_options=storage_options,
use_legacy_format=use_legacy_format,
)
else:
data = data_to_reader(data, schema)
@@ -770,6 +779,7 @@ class AsyncConnection(object):
mode,
data,
storage_options=storage_options,
use_legacy_format=use_legacy_format,
)
return AsyncTable(new_table)

View File

@@ -153,7 +153,7 @@ class TextEmbeddingFunction(EmbeddingFunction):
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]:
"""
Generate the embeddings for the given texts

View File

@@ -73,6 +73,8 @@ class BedRockText(TextEmbeddingFunction):
assumed_role: Union[str, None] = None
profile_name: Union[str, None] = None
role_session_name: str = "lancedb-embeddings"
source_input_type: str = "search_document"
query_input_type: str = "search_query"
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
@@ -87,21 +89,29 @@ class BedRockText(TextEmbeddingFunction):
# TODO: fix hardcoding
if self.name == "amazon.titan-embed-text-v1":
return 1536
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
elif self.name in [
"amazon.titan-embed-text-v2:0",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]:
# TODO: "amazon.titan-embed-text-v2:0" model supports dynamic ndims
return 1024
else:
raise ValueError(f"Unknown model name: {self.name}")
raise ValueError(f"Model {self.name} not supported")
def compute_query_embeddings(
self, query: str, *args, **kwargs
) -> List[List[float]]:
return self.compute_source_embeddings(query)
return self.compute_source_embeddings(query, input_type=self.query_input_type)
def compute_source_embeddings(
self, texts: TEXT, *args, **kwargs
) -> List[List[float]]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
# assume source input type if not passed by `compute_query_embeddings`
kwargs["input_type"] = kwargs.get("input_type") or self.source_input_type
return self.generate_embeddings(texts, **kwargs)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
@@ -121,11 +131,11 @@ class BedRockText(TextEmbeddingFunction):
"""
results = []
for text in texts:
response = self._generate_embedding(text)
response = self._generate_embedding(text, *args, **kwargs)
results.append(response)
return results
def _generate_embedding(self, text: str) -> List[float]:
def _generate_embedding(self, text: str, *args, **kwargs) -> List[float]:
"""
Get the embeddings for the given texts
@@ -141,14 +151,12 @@ class BedRockText(TextEmbeddingFunction):
"""
# format input body for provider
provider = self.name.split(".")[0]
_model_kwargs = {}
input_body = {**_model_kwargs}
input_body = {**kwargs}
if provider == "cohere":
if "input_type" not in input_body.keys():
input_body["input_type"] = "search_document"
input_body["texts"] = [text]
else:
# includes common provider == "amazon"
input_body.pop("input_type", None)
input_body["inputText"] = text
body = json.dumps(input_body)

View File

@@ -19,7 +19,7 @@ import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
from .utils import api_key_not_found_help, TEXT
@register("cohere")
@@ -32,8 +32,36 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
Parameters
----------
name: str, default "embed-multilingual-v2.0"
The name of the model to use. See the Cohere documentation for
a list of available models.
The name of the model to use. List of acceptable models:
* embed-english-v3.0
* embed-multilingual-v3.0
* embed-english-light-v3.0
* embed-multilingual-light-v3.0
* embed-english-v2.0
* embed-english-light-v2.0
* embed-multilingual-v2.0
source_input_type: str, default "search_document"
The input type for the source column in the database
query_input_type: str, default "search_query"
The input type for the query column in the database
Cohere supports following input types:
| Input Type | Description |
|-------------------------|---------------------------------------|
| "`search_document`" | Used for embeddings stored in a vector|
| | database for search use-cases. |
| "`search_query`" | Used for embeddings of search queries |
| | run against a vector DB |
| "`semantic_similarity`" | Specifies the given text will be used |
| | for Semantic Textual Similarity (STS) |
| "`classification`" | Used for embeddings passed through a |
| | text classifier. |
| "`clustering`" | Used for the embeddings run through a |
| | clustering algorithm |
Examples
--------
@@ -61,14 +89,39 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
"""
name: str = "embed-multilingual-v2.0"
source_input_type: str = "search_document"
query_input_type: str = "search_query"
client: ClassVar = None
def ndims(self):
# TODO: fix hardcoding
return 768
if self.name in [
"embed-english-v3.0",
"embed-multilingual-v3.0",
"embed-english-light-v2.0",
]:
return 1024
elif self.name in ["embed-english-light-v3.0", "embed-multilingual-light-v3.0"]:
return 384
elif self.name == "embed-english-v2.0":
return 4096
elif self.name == "embed-multilingual-v2.0":
return 768
else:
raise ValueError(f"Model {self.name} not supported")
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, input_type=self.query_input_type)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
input_type = (
kwargs.get("input_type") or self.source_input_type
) # assume source input type if not passed by `compute_query_embeddings`
return self.generate_embeddings(texts, input_type=input_type)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given texts
@@ -78,9 +131,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
self._init_client()
rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)
rs = CohereEmbeddingFunction.client.embed(
texts=texts, model=self.name, **kwargs
)
return [emb for emb in rs.embeddings]

View File

@@ -1113,11 +1113,22 @@ class AsyncQueryBase(object):
self._inner.limit(limit)
return self
async def to_batches(self) -> AsyncRecordBatchReader:
async def to_batches(
self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader:
"""
Execute the query and return the results as an Apache Arrow RecordBatchReader.
Parameters
----------
max_batch_length: Optional[int]
The maximum number of selected records in a single RecordBatch object.
If not specified, a default batch length is used.
It is possible for batches to be smaller than the provided length if the
underlying data is stored in smaller chunks.
"""
return AsyncRecordBatchReader(await self._inner.execute())
return AsyncRecordBatchReader(await self._inner.execute(max_batch_length))
async def to_arrow(self) -> pa.Table:
"""

View File

@@ -507,6 +507,52 @@ def test_empty_or_nonexistent_table(tmp_path):
assert test.schema == test2.schema
@pytest.mark.asyncio
async def test_create_in_v2_mode(tmp_path):
def make_data():
for i in range(10):
yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
def make_table():
return pa.table([pa.array([x for x in range(10 * 1024)])], names=["x"])
schema = pa.schema([pa.field("x", pa.int64())])
db = await lancedb.connect_async(tmp_path)
# Create table in v1 mode
tbl = await db.create_table("test", data=make_data(), schema=schema)
async def is_in_v2_mode(tbl):
batches = await tbl.query().to_batches(max_batch_length=1024 * 10)
num_batches = 0
async for batch in batches:
num_batches += 1
return num_batches < 10
assert not await is_in_v2_mode(tbl)
# Create table in v2 mode
tbl = await db.create_table(
"test_v2", data=make_data(), schema=schema, use_legacy_format=False
)
assert await is_in_v2_mode(tbl)
# Add data (should remain in v2 mode)
await tbl.add(make_table())
assert await is_in_v2_mode(tbl)
# Create empty table in v2 mode and add data
tbl = await db.create_table(
"test_empty_v2", data=None, schema=schema, use_legacy_format=False
)
await tbl.add(make_table())
assert await is_in_v2_mode(tbl)
def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
table = db.create_table(

View File

@@ -91,6 +91,7 @@ impl Connection {
mode: &str,
data: &PyAny,
storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone();
@@ -103,6 +104,10 @@ impl Connection {
builder = builder.storage_options(storage_options);
}
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
future_into_py(self_.py(), async move {
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))
@@ -115,6 +120,7 @@ impl Connection {
mode: &str,
schema: &PyAny,
storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone();
@@ -128,6 +134,10 @@ impl Connection {
builder = builder.storage_options(storage_options);
}
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
future_into_py(self_.py(), async move {
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))

View File

@@ -15,6 +15,7 @@
use arrow::array::make_array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
@@ -61,10 +62,14 @@ impl Query {
Ok(VectorQuery { inner })
}
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
pub fn execute(self_: PyRef<'_, Self>, max_batch_length: Option<u32>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let inner_stream = inner.execute().await.infer_error()?;
let mut opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
}
@@ -115,10 +120,14 @@ impl VectorQuery {
self.inner = self.inner.clone().bypass_vector_index()
}
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
pub fn execute(self_: PyRef<'_, Self>, max_batch_length: Option<u32>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let inner_stream = inner.execute().await.infer_error()?;
let mut opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
}

View File

@@ -19,11 +19,13 @@ arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true
chrono = { workspace = true }
datafusion-physical-plan.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lazy_static.workspace = true
lance = { workspace = true }
lance-datafusion.workspace = true
lance-index = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
@@ -38,11 +40,12 @@ url.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true}
polars = { version = ">=0.37,<0.40.0", optional = true }
[dev-dependencies]
tempfile = "3.5.0"
@@ -62,4 +65,10 @@ default = []
remote = ["dep:reqwest"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
openai = ["dep:async-openai", "dep:reqwest"]
polars = ["dep:polars-arrow", "dep:polars"]
[[example]]
name = "openai"
required-features = ["openai"]

View File

@@ -0,0 +1,82 @@
use std::{iter::once, sync::Arc};
use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
arrow::IntoArrow,
connect,
embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
query::{ExecutableQuery, QueryBase},
Result,
};
#[tokio::main]
async fn main() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
let embedding = Arc::new(OpenAIEmbeddingFunction::new_with_model(
api_key,
"text-embedding-3-large",
)?);
let db = connect(tempdir).execute().await?;
db.embedding_registry()
.register("openai", embedding.clone())?;
let table = db
.create_table("vectors", make_data())
.add_embedding(EmbeddingDefinition::new(
"text",
"openai",
Some("embeddings"),
))?
.execute()
.await?;
// there is no equivalent to '.search(<query>)' yet
let query = Arc::new(StringArray::from_iter_values(once("something warm")));
let query_vector = embedding.compute_query_embeddings(query)?;
let mut results = table
.vector_search(query_vector)?
.limit(1)
.execute()
.await?;
let rb = results.next().await.unwrap()?;
let out = rb
.column_by_name("text")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let text = out.iter().next().unwrap().unwrap();
println!("Closest match: {}", text);
Ok(())
}
fn make_data() -> impl IntoArrow {
let schema = Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("text", DataType::Utf8, false),
Field::new("price", DataType::Float64, false),
]);
let id = Int32Array::from(vec![1, 2, 3, 4]);
let text = StringArray::from_iter_values(vec![
"Black T-Shirt",
"Leather Jacket",
"Winter Parka",
"Hooded Sweatshirt",
]);
let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]);
let schema = Arc::new(schema);
let rb = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(id), Arc::new(text), Arc::new(price)],
)
.unwrap();
Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema))
}

View File

@@ -140,6 +140,7 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
pub(crate) write_options: WriteOptions,
pub(crate) table_definition: Option<TableDefinition>,
pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
pub(crate) use_legacy_format: bool,
}
// Builder methods that only apply when we have initial data
@@ -153,6 +154,7 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
write_options: WriteOptions::default(),
table_definition: None,
embeddings: Vec::new(),
use_legacy_format: true,
}
}
@@ -184,6 +186,7 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
mode: self.mode,
write_options: self.write_options,
embeddings: self.embeddings,
use_legacy_format: self.use_legacy_format,
};
Ok((data, builder))
}
@@ -217,6 +220,7 @@ impl CreateTableBuilder<false, NoData> {
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
embeddings: Vec::new(),
use_legacy_format: false,
}
}
@@ -278,6 +282,20 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
}
self
}
/// Set to true to use the v1 format for data files
///
/// This currently defaults to true and can be set to false to opt in
/// to the new format, which should only be used for experimentation and
/// evaluation. The new format is still in beta and may change in ways that
/// are not backwards compatible.
///
/// Once the new format is stable, the default will change to `false` for
/// several releases and then eventually this option will be removed.
pub fn use_legacy_format(mut self, use_legacy_format: bool) -> Self {
self.use_legacy_format = use_legacy_format;
self
}
}
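
For reference, opting in to the new format from the builder looks like this minimal sketch (the database path is a placeholder and the inline single-column data stands in for real batches):

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use lancedb::{connect, Result};

async fn create_v2_table() -> Result<()> {
    // A tiny one-column dataset, just enough to create a table
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let data = RecordBatchIterator::new(vec![Ok(batch)], schema);

    let db = connect("/tmp/lancedb-demo").execute().await?;
    let _table = db
        .create_table("v2_table", data)
        .use_legacy_format(false) // opt in to the experimental v2 data files
        .execute()
        .await?;
    Ok(())
}
```
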
#[derive(Clone, Debug)]
@@ -943,6 +961,7 @@ impl ConnectionInternal for Database {
if matches!(&options.mode, CreateTableMode::Overwrite) {
write_params.mode = WriteMode::Overwrite;
}
write_params.use_legacy_format = options.use_legacy_format;
match NativeTable::create(
&table_uri,
@@ -1040,8 +1059,12 @@ impl ConnectionInternal for Database {
#[cfg(test)]
mod tests {
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
use tempfile::tempdir;
use crate::query::{ExecutableQuery, QueryExecutionOptions};
use super::*;
#[tokio::test]
@@ -1146,6 +1169,58 @@ mod tests {
assert_eq!(tables, vec!["table1".to_owned()]);
}
fn make_data() -> impl RecordBatchReader + Send + 'static {
let id = Box::new(IncrementingInt32::new().named("id".to_string()));
BatchGenerator::new().col(id).batches(10, 2000)
}
#[tokio::test]
async fn test_create_table_v2() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let tbl = db
.create_table("v1_test", make_data())
.execute()
.await
.unwrap();
// In v1 the row group size will trump max_batch_length
let batches = tbl
.query()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 50000,
})
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 20);
let tbl = db
.create_table("v2_test", make_data())
.use_legacy_format(false)
.execute()
.await
.unwrap();
// In v2 the page size is much bigger than 50k so we should get a single batch
let batches = tbl
.query()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 50000,
})
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 1);
}
#[tokio::test]
async fn drop_table() {
let tmp_dir = tempdir().unwrap();

View File

@@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(feature = "openai")]
pub mod openai;
use lance::arrow::RecordBatchExt;
use std::{
@@ -51,8 +53,10 @@ pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
/// The type of the output data
/// This should **always** match the output of the `compute_source_embeddings` function
fn dest_type(&self) -> Result<Cow<DataType>>;
/// Embed the input
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
/// Compute the embeddings for the source column in the database
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
/// Compute the embeddings for a given user query
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
}
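
The split between source and query embeddings lets asymmetric retrieval models encode stored rows and incoming queries differently. A minimal sketch of a custom implementation (`ConstEmbed` is hypothetical, not part of the crate; it emits constant 2-d vectors purely for illustration):

```rust
use std::{borrow::Cow, sync::Arc};

use arrow_array::{Array, FixedSizeListArray, Float32Array};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use lancedb::embeddings::EmbeddingFunction;
use lancedb::Result;

#[derive(Debug)]
struct ConstEmbed;

impl EmbeddingFunction for ConstEmbed {
    fn name(&self) -> &str {
        "const-embed"
    }
    fn source_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Owned(DataType::Utf8))
    }
    fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Owned(DataType::new_fixed_size_list(
            DataType::Float32,
            2,
            false,
        )))
    }
    fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // Stored rows must match `dest_type`, i.e. a fixed-size list column
        let inner = Float32Array::from(vec![0.0_f32; source.len() * 2]);
        let fsl = DataType::new_fixed_size_list(DataType::Float32, 2, false);
        let data = ArrayData::builder(fsl)
            .len(source.len())
            .add_child_data(inner.into_data())
            .build()?;
        Ok(Arc::new(FixedSizeListArray::from(data)))
    }
    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // Queries are returned flat; an asymmetric model could use a
        // different encoder here than for source rows
        Ok(Arc::new(Float32Array::from(vec![0.0_f32; input.len() * 2])))
    }
}
```
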
/// Defines an embedding from input data into a lower-dimensional space
@@ -266,7 +270,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
// todo: parallelize this
for (fld, func) in self.embeddings.iter() {
let src_column = batch.column_by_name(&fld.source_column).unwrap();
let embedding = match func.embed(src_column.clone()) {
let embedding = match func.compute_source_embeddings(src_column.clone()) {
Ok(embedding) => embedding,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(

View File

@@ -0,0 +1,257 @@
use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc};
use arrow::array::{AsArray, Float32Builder};
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use async_openai::{
config::OpenAIConfig,
types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat},
Client,
};
use tokio::{runtime::Handle, task};
use crate::{Error, Result};
use super::EmbeddingFunction;
#[derive(Debug)]
pub enum EmbeddingModel {
TextEmbeddingAda002,
TextEmbedding3Small,
TextEmbedding3Large,
}
impl EmbeddingModel {
fn ndims(&self) -> usize {
match self {
Self::TextEmbeddingAda002 => 1536,
Self::TextEmbedding3Small => 1536,
Self::TextEmbedding3Large => 3072,
}
}
}
impl FromStr for EmbeddingModel {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002),
"text-embedding-3-small" => Ok(Self::TextEmbedding3Small),
"text-embedding-3-large" => Ok(Self::TextEmbedding3Large),
_ => Err(Error::InvalidInput {
message: "Invalid input. Available models are: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' ".to_string()
}),
}
}
}
impl std::fmt::Display for EmbeddingModel {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
match self {
Self::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
Self::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
Self::TextEmbedding3Large => write!(f, "text-embedding-3-large"),
}
}
}
impl TryFrom<&str> for EmbeddingModel {
type Error = Error;
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
value.parse()
}
}
pub struct OpenAIEmbeddingFunction {
model: EmbeddingModel,
api_key: String,
api_base: Option<String>,
org_id: Option<String>,
}
impl std::fmt::Debug for OpenAIEmbeddingFunction {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
// let's be safe and not print the full API key
let creds_display = if self.api_key.len() > 6 {
format!(
"{}***{}",
&self.api_key[0..2],
&self.api_key[self.api_key.len() - 4..]
)
} else {
"[INVALID]".to_string()
};
f.debug_struct("OpenAI")
.field("model", &self.model)
.field("api_key", &creds_display)
.field("api_base", &self.api_base)
.field("org_id", &self.org_id)
.finish()
}
}
impl OpenAIEmbeddingFunction {
/// Create a new OpenAIEmbeddingFunction
pub fn new<A: Into<String>>(api_key: A) -> Self {
Self::new_impl(api_key.into(), EmbeddingModel::TextEmbeddingAda002)
}
pub fn new_with_model<A: Into<String>, M: TryInto<EmbeddingModel>>(
api_key: A,
model: M,
) -> crate::Result<Self>
where
M::Error: Into<crate::Error>,
{
Ok(Self::new_impl(
api_key.into(),
model.try_into().map_err(|e| e.into())?,
))
}
/// concrete implementation to reduce monomorphization
fn new_impl(api_key: String, model: EmbeddingModel) -> Self {
Self {
model,
api_key,
api_base: None,
org_id: None,
}
}
/// Use an API base URL other than the default "https://api.openai.com/v1"
pub fn api_base<S: Into<String>>(mut self, api_base: S) -> Self {
self.api_base = Some(api_base.into());
self
}
/// Use an OpenAI organization id other than the default
pub fn org_id<S: Into<String>>(mut self, org_id: S) -> Self {
self.org_id = Some(org_id.into());
self
}
}
impl EmbeddingFunction for OpenAIEmbeddingFunction {
fn name(&self) -> &str {
"openai"
}
fn source_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<DataType>> {
let n_dims = self.model.ndims();
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
n_dims as i32,
false,
)))
}
fn compute_source_embeddings(&self, source: ArrayRef) -> crate::Result<ArrayRef> {
let len = source.len();
let n_dims = self.model.ndims();
let inner = self.compute_inner(source)?;
let fsl = DataType::new_fixed_size_list(DataType::Float32, n_dims as i32, false);
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let array_data = ArrayData::builder(fsl)
.len(len)
.add_child_data(inner.into_data())
.build()?;
Ok(Arc::new(FixedSizeListArray::from(array_data)))
}
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
let arr = self.compute_inner(input)?;
Ok(Arc::new(arr))
}
}
impl OpenAIEmbeddingFunction {
fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
// OpenAI only supports non-nullable string arrays
if source.is_nullable() {
return Err(crate::Error::InvalidInput {
message: "Expected non-nullable data type".to_string(),
});
}
// OpenAI only supports string arrays
if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
return Err(crate::Error::InvalidInput {
message: "Expected Utf8 data type".to_string(),
});
};
let mut creds = OpenAIConfig::new().with_api_key(self.api_key.clone());
if let Some(api_base) = &self.api_base {
creds = creds.with_api_base(api_base.clone());
}
if let Some(org_id) = &self.org_id {
creds = creds.with_org_id(org_id.clone());
}
let input = match source.data_type() {
DataType::Utf8 => {
let array = source
.as_string::<i32>()
.into_iter()
.map(|s| {
s.expect("we already asserted that the array is non-nullable")
.to_string()
})
.collect::<Vec<String>>();
EmbeddingInput::StringArray(array)
}
DataType::LargeUtf8 => {
let array = source
.as_string::<i64>()
.into_iter()
.map(|s| {
s.expect("we already asserted that the array is non-nullable")
.to_string()
})
.collect::<Vec<String>>();
EmbeddingInput::StringArray(array)
}
_ => unreachable!("This should not happen. We already checked the data type."),
};
let client = Client::with_config(creds);
let embed = client.embeddings();
let req = CreateEmbeddingRequest {
model: self.model.to_string(),
input,
encoding_format: Some(EncodingFormat::Float),
user: None,
dimensions: None,
};
// TODO: request batching and retry logic
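// Note: block_in_place panics on a current-thread Tokio runtime, so the multi-threaded runtime is required here.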
task::block_in_place(move || {
Handle::current().block_on(async {
let mut builder = Float32Builder::new();
let res = embed.create(req).await.map_err(|e| crate::Error::Runtime {
message: format!("OpenAI embed request failed: {e}"),
})?;
for Embedding { embedding, .. } in res.data.iter() {
builder.append_slice(embedding);
}
Ok(builder.finish())
})
})
}
}
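
A usage sketch for the builder methods above (requires the `openai` feature; the base URL and org id are placeholders for an OpenAI-compatible endpoint):

```rust
use lancedb::embeddings::openai::OpenAIEmbeddingFunction;
use lancedb::Result;

fn build_embedder(api_key: &str) -> Result<OpenAIEmbeddingFunction> {
    Ok(
        OpenAIEmbeddingFunction::new_with_model(api_key, "text-embedding-3-small")?
            .api_base("https://my-proxy.example.com/v1") // placeholder proxy URL
            .org_id("org-example"), // placeholder organization id
    )
}
```
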

View File

@@ -17,7 +17,10 @@ use std::sync::Arc;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use datafusion_physical_plan::ExecutionPlan;
use half::f16;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance_datafusion::exec::execute_plan;
use crate::arrow::SendableRecordBatchStream;
use crate::error::{Error, Result};
@@ -425,6 +428,15 @@ impl Default for QueryExecutionOptions {
/// There are various kinds of queries but they all return results
/// in the same way.
pub trait ExecutableQuery {
/// Return the DataFusion [ExecutionPlan].
///
/// The caller can further optimize the plan or execute it.
fn create_plan(
&self,
options: QueryExecutionOptions,
) -> impl Future<Output = Result<Arc<dyn ExecutionPlan>>> + Send;
/// Execute the query with default options and return results
///
/// See [`ExecutableQuery::execute_with_options`] for more details.
@@ -545,6 +557,13 @@ impl HasQuery for Query {
}
impl ExecutableQuery for Query {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
self.parent
.clone()
.create_plan(&self.clone().into_vector(), options)
.await
}
async fn execute_with_options(
&self,
options: QueryExecutionOptions,
@@ -718,12 +737,19 @@ impl VectorQuery {
}
impl ExecutableQuery for VectorQuery {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
self.base.parent.clone().create_plan(self, options).await
}
async fn execute_with_options(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
Ok(SendableRecordBatchStream::from(
self.base.parent.clone().vector_query(self, options).await?,
DatasetRecordBatchStream::new(execute_plan(
self.create_plan(options).await?,
Default::default(),
)?),
))
}
}
@@ -972,6 +998,30 @@ mod tests {
}
}
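// Recursively search the plan tree for an operator with the given name.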
fn plan_exists(plan: &Arc<dyn ExecutionPlan>, name: &str) -> bool {
if plan.name() == name {
return true;
}
plan.children()
.iter()
.any(|child| plan_exists(child, name))
}
#[tokio::test]
async fn test_create_execute_plan() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let plan = table
.query()
.nearest_to(vec![0.1, 0.2, 0.3, 0.4])
.unwrap()
.create_plan(QueryExecutionOptions::default())
.await
.unwrap();
assert!(plan_exists(&plan, "KNNFlatSearch"));
assert!(plan_exists(&plan, "ProjectionExec"));
}
#[tokio::test]
async fn query_base_methods_on_vector_query() {
// Make sure VectorQuery can be used as a QueryBase
@@ -989,5 +1039,18 @@ mod tests {
let first_batch = results.next().await.unwrap().unwrap();
assert_eq!(first_batch.num_rows(), 1);
assert!(results.next().await.is_none());
// query with wrong vector dimension
let error_result = table
.vector_search(&[1.0, 2.0, 3.0])
.unwrap()
.limit(1)
.execute()
.await;
assert!(error_result
.err()
.unwrap()
.to_string()
.contains("No vector column found to match with the query vector dimension: 3"));
}
}
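
Since `create_plan` hands back the raw DataFusion plan, a caller can inspect it before executing. A minimal sketch using only `name()` and `children()`, the same calls the test helper above relies on:

```rust
use std::sync::Arc;

use datafusion_physical_plan::ExecutionPlan;

/// Print the operator tree of a plan, one node per line, indented by depth.
fn print_plan(plan: &Arc<dyn ExecutionPlan>, depth: usize) {
    println!("{}{}", "  ".repeat(depth), plan.name());
    for child in plan.children().iter() {
        print_plan(child, depth + 1);
    }
}
```
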

View File

@@ -1,6 +1,9 @@
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};
use crate::{
@@ -71,6 +74,13 @@ impl TableInternal for RemoteTable {
) -> Result<()> {
todo!()
}
async fn create_plan(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
async fn plain_query(
&self,
_query: &Query,
@@ -78,13 +88,6 @@ impl TableInternal for RemoteTable {
) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn vector_query(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn update(&self, _update: UpdateBuilder) -> Result<()> {
todo!()
}

View File

@@ -23,6 +23,7 @@ use arrow::datatypes::Float32Type;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
@@ -35,6 +36,7 @@ use lance::dataset::{
};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_datafusion::exec::execute_plan;
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams;
@@ -231,7 +233,8 @@ pub struct WriteOptions {
// pub on_bad_vectors: BadVectorHandling,
/// Advanced parameters that can be used to customize table creation
///
/// If set, these will take precedence over any overlapping `OpenTableBuilder` options
/// Overlapping `OpenTableBuilder` options (e.g. [AddDataBuilder::mode]) will take
/// precedence over their counterparts in `WriteOptions` (e.g. [WriteParams::mode]).
pub lance_write_params: Option<WriteParams>,
}
@@ -366,16 +369,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
async fn schema(&self) -> Result<SchemaRef>;
/// Count the number of rows in this table.
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
async fn create_plan(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>>;
async fn plain_query(
&self,
query: &Query,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn add(
&self,
add: AddDataBuilder<NoData>,
@@ -1479,79 +1482,11 @@ impl NativeTable {
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
if !f.data_type().is_floating() {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
message: format!(
"The dimension of the query vector does not match with the dimension of the vector column '{}':
query dim={}, expected vector dim={}",
column,
query_vector.len(),
dim,
),
});
}
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.batch_size(options.max_batch_length as usize);
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
}
Select::Dynamic(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type.into());
}
Ok(scanner.try_into_stream().await?)
let plan = self.create_plan(query, options).await?;
Ok(DatasetRecordBatchStream::new(execute_plan(
plan,
Default::default(),
)?))
}
}
@@ -1703,6 +1638,86 @@ impl TableInternal for NativeTable {
Ok(())
}
async fn create_plan(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
if !f.data_type().is_floating() {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
message: format!(
"The dimension of the query vector does not match with the dimension of the vector column '{}': \
query dim={}, expected vector dim={}",
column,
query_vector.len(),
dim,
),
});
}
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.batch_size(options.max_batch_length as usize);
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
}
Select::Dynamic(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type.into());
}
Ok(scanner.create_plan().await?)
}
async fn plain_query(
&self,
query: &Query,
@@ -1712,14 +1727,6 @@ impl TableInternal for NativeTable {
.await
}
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
self.generic_query(query, options).await
}
async fn merge_insert(
&self,
params: MergeInsertBuilder,
@@ -2550,8 +2557,7 @@ mod tests {
.unwrap()
.get_index_type(index_uuid)
.await
.unwrap()
.map(|index_type| index_type.to_string()),
.unwrap(),
Some("IVF".to_string())
);
assert_eq!(

View File

@@ -66,6 +66,19 @@ impl DatasetRef {
Ok(())
}
fn is_latest(&self) -> bool {
matches!(self, Self::Latest { .. })
}
async fn need_reload(&self) -> Result<bool> {
Ok(match self {
Self::Latest { dataset, .. } => {
dataset.latest_version_id().await? != dataset.version().version
}
Self::TimeTravel { dataset, version } => dataset.version().version != *version,
})
}
async fn as_latest(&mut self, read_consistency_interval: Option<Duration>) -> Result<()> {
match self {
Self::Latest { .. } => Ok(()),
@@ -129,7 +142,7 @@ impl DatasetConsistencyWrapper {
Self(Arc::new(RwLock::new(DatasetRef::Latest {
dataset,
read_consistency_interval,
last_consistency_check: None,
last_consistency_check: Some(Instant::now()),
})))
}
@@ -163,11 +176,16 @@ impl DatasetConsistencyWrapper {
/// Convert into a wrapper in latest version mode
pub async fn as_latest(&self, read_consistency_interval: Option<Duration>) -> Result<()> {
self.0
.write()
.await
.as_latest(read_consistency_interval)
.await
if self.0.read().await.is_latest() {
return Ok(());
}
let mut write_guard = self.0.write().await;
if write_guard.is_latest() {
return Ok(());
}
write_guard.as_latest(read_consistency_interval).await
}
pub async fn as_time_travel(&self, target_version: u64) -> Result<()> {
@@ -183,7 +201,18 @@ impl DatasetConsistencyWrapper {
}
pub async fn reload(&self) -> Result<()> {
self.0.write().await.reload().await
if !self.0.read().await.need_reload().await? {
return Ok(());
}
let mut write_guard = self.0.write().await;
// on lock escalation -- check if someone else has already reloaded
if !write_guard.need_reload().await? {
return Ok(());
}
// actually need reloading
write_guard.reload().await
}
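
Both fast paths above are instances of double-checked locking: check under the cheap read lock first, then re-check after escalating to the write lock, since another task may have done the work while we waited. A generic sketch of the pattern (the `u64` state is a stand-in for the dataset version being reloaded above):

```rust
use std::sync::Arc;
use tokio::sync::RwLock;

async fn refresh_if_stale(state: Arc<RwLock<u64>>, latest: u64) {
    if *state.read().await == latest {
        return; // fast path: no write lock taken, no IO performed
    }
    let mut guard = state.write().await;
    if *guard == latest {
        return; // another task refreshed while we waited for the write lock
    }
    *guard = latest; // slow path: actually refresh
}
```
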
/// Returns the version, if in time travel mode, or None otherwise

View File

@@ -101,7 +101,7 @@ pub fn validate_table_name(name: &str) -> Result<()> {
Ok(())
}
/// Find one default column to create index.
/// Find one default column to create an index on or to run a vector query against.
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
// Try to find one fixed size list array column.
let candidates = schema
@@ -118,14 +118,17 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
})
.collect::<Vec<_>>();
if candidates.is_empty() {
Err(Error::Schema {
message: "No vector column found to create index".to_string(),
Err(Error::InvalidInput {
message: format!(
"No vector column found to match with the query vector dimension: {}",
dim.unwrap_or_default()
),
})
} else if candidates.len() != 1 {
Err(Error::Schema {
message: format!(
"More than one vector columns found, \
please specify which column to create index: {:?}",
please specify which column to use for the index or query: {:?}",
candidates
),
})

View File

@@ -302,7 +302,7 @@ impl EmbeddingFunction for MockEmbed {
fn dest_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Borrowed(&self.dest_type))
}
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let len = source.len();
@@ -317,4 +317,9 @@ impl EmbeddingFunction for MockEmbed {
Ok(Arc::new(arr))
}
#[allow(unused_variables)]
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}