Compare commits

..

4 Commits

Author SHA1 Message Date
Lei Xu
ba3ad35b87 [maven-release-plugin] prepare for next development iteration 2024-05-29 14:55:11 -07:00
Lei Xu
28a0fea1d0 [maven-release-plugin] prepare release lancedb-parent-0.0.2 2024-05-29 14:55:08 -07:00
Rong Rong
b0e6c20be2 also add javadoc and source plugin 2024-05-28 21:13:34 -07:00
Rong Rong
d9965476c5 prepare for Java release 2024-05-28 20:57:20 -07:00
61 changed files with 8323 additions and 10101 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.5.1"
current_version = "0.5.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -94,6 +94,6 @@ jobs:
branch: ${{ github.ref }}
tags: true
- uses: ./.github/workflows/update_package_lock
if: ${{ !inputs.dry_run && inputs.other }}
if: ${{ inputs.dry_run }} == "false"
with:
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -14,7 +14,7 @@ repos:
hooks:
- id: local-biome-check
name: biome check
entry: npx @biomejs/biome check --config-path nodejs/biome.json nodejs/
entry: npx biome check
language: system
types: [text]
files: "nodejs/.*"

View File

@@ -1,11 +1,5 @@
[workspace]
members = [
"rust/ffi/node",
"rust/lancedb",
"nodejs",
"python",
"java/core/lancedb-jni",
]
members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python", "java/core/lancedb-jni"]
# Python package needs to be built by maturin.
exclude = ["python"]
resolver = "2"
@@ -20,11 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.12.1", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.12.1" }
lance-linalg = { "version" = "=0.12.1" }
lance-testing = { "version" = "=0.12.1" }
lance-datafusion = { "version" = "=0.12.1" }
lance = { "version" = "=0.11.0", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.11.0" }
lance-linalg = { "version" = "=0.11.0" }
lance-testing = { "version" = "=0.11.0" }
# Note that this one does not include pyarrow
arrow = { version = "51.0", optional = false }
arrow-array = "51.0"
@@ -36,7 +29,6 @@ arrow-arith = "51.0"
arrow-cast = "51.0"
async-trait = "0"
chrono = "0.4.35"
datafusion-physical-plan = "37.1"
half = { "version" = "=2.4.1", default-features = false, features = [
"num-traits",
] }

View File

@@ -83,5 +83,5 @@ result = table.search([100, 100]).limit(2).to_pandas()
```
## Blogs, Tutorials & Videos
* 📈 <a href="https://blog.lancedb.com/benchmarking-random-access-in-lance/">2000x better performance with Lance over Parquet</a>
* 📈 <a href="https://blog.eto.ai/benchmarking-random-access-in-lance-ed690757a826">2000x better performance with Lance over Parquet</a>
* 🤖 <a href="https://github.com/lancedb/lancedb/blob/main/docs/src/notebooks/youtube_transcript_search.ipynb">Build a question and answer bot with LanceDB</a>

View File

@@ -1,14 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.1-SNAPSHOT</version>
<version>0.0.3-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -1,15 +1,34 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.1-SNAPSHOT</version>
<version>0.0.3-SNAPSHOT</version>
<packaging>pom</packaging>
<name>Lance Parent</name>
<description>LanceDB Java API</description>
<url>http://lancedb.com/</url>
<developers>
<developer>
<name>Lance DB Dev Group</name>
<email>dev@lancedb.com</email>
</developer>
</developers>
<licenses>
<license>
<name>The Apache Software License, Version 2.0</name>
<url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
</license>
</licenses>
<scm>
<developerConnection>scm:git:git@github.com:lancedb/lancedb.git</developerConnection>
<tag>HEAD</tag>
<url>scm:git:git@github.com:lancedb/lancedb.git</url>
</scm>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
@@ -64,6 +83,32 @@
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>2.2.1</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar-no-fork</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.9.1</version>
<executions>
<execution>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
@@ -111,7 +156,7 @@
<version>3.2.5</version>
<configuration>
<argLine>--add-opens=java.base/java.nio=ALL-UNNAMED</argLine>
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory" />
<useSystemClassLoader>false</useSystemClassLoader>
</configuration>
</plugin>
@@ -126,4 +171,49 @@
</plugins>
</pluginManagement>
</build>
<profiles>
<profile>
<id>deploy-to-ossrh</id>
<build>
<plugins>
<plugin>
<groupId>org.sonatype.central</groupId>
<artifactId>central-publishing-maven-plugin</artifactId>
<version>0.4.0</version>
<extensions>true</extensions>
<configuration>
<publishingServerId>ossrh</publishingServerId>
<tokenAuth>true</tokenAuth>
</configuration>
</plugin>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<version>1.6.13</version>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.5</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>

View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.5.1",
"version": "0.5.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.5.1",
"version": "0.5.0",
"cpu": [
"x64",
"arm64"

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.5.1",
"version": "0.5.0",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",

View File

@@ -704,9 +704,6 @@ export interface VectorIndex {
export interface IndexStats {
numIndexedRows: number | null
numUnindexedRows: number | null
index_type: string | null
distance_type: string | null
completed_at: string | null
}
/**

View File

@@ -509,8 +509,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
return (await results.body()).indexes?.map((index: any) => ({
columns: index.columns,
name: index.index_name,
uuid: index.index_uuid,
status: index.status
uuid: index.index_uuid
}))
}
@@ -521,10 +520,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
const body = await results.body()
return {
numIndexedRows: body?.num_indexed_rows,
numUnindexedRows: body?.num_unindexed_rows,
index_type: body?.index_type,
distance_type: body?.distance_type,
completed_at: body?.completed_at
numUnindexedRows: body?.num_unindexed_rows
}
}

View File

@@ -31,7 +31,6 @@ import {
Schema,
Struct,
type Table,
Type,
Utf8,
tableFromIPC,
} from "apache-arrow";
@@ -52,12 +51,7 @@ import {
makeArrowTable,
makeEmptyTable,
} from "../lancedb/arrow";
import {
EmbeddingFunction,
FieldOptions,
FunctionOptions,
} from "../lancedb/embedding/embedding_function";
import { EmbeddingFunctionConfig } from "../lancedb/embedding/registry";
import { type EmbeddingFunction } from "../lancedb/embedding/embedding_function";
// biome-ignore lint/suspicious/noExplicitAny: skip
function sampleRecords(): Array<Record<string, any>> {
@@ -286,46 +280,23 @@ describe("The function makeArrowTable", function () {
});
});
class DummyEmbedding extends EmbeddingFunction<string> {
toJSON(): Partial<FunctionOptions> {
return {};
}
class DummyEmbedding implements EmbeddingFunction<string> {
public readonly sourceColumn = "string";
public readonly embeddingDimension = 2;
public readonly embeddingDataType = new Float16();
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
return data.map(() => [0.0, 0.0]);
}
ndims(): number {
return 2;
}
embeddingDataType() {
return new Float16();
}
}
class DummyEmbeddingWithNoDimension extends EmbeddingFunction<string> {
toJSON(): Partial<FunctionOptions> {
return {};
}
embeddingDataType(): Float {
return new Float16();
}
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
async embed(data: string[]): Promise<number[][]> {
return data.map(() => [0.0, 0.0]);
}
}
const dummyEmbeddingConfig: EmbeddingFunctionConfig = {
sourceColumn: "string",
function: new DummyEmbedding(),
};
const dummyEmbeddingConfigWithNoDimension: EmbeddingFunctionConfig = {
sourceColumn: "string",
function: new DummyEmbeddingWithNoDimension(),
};
class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
public readonly sourceColumn = "string";
async embed(data: string[]): Promise<number[][]> {
return data.map(() => [0.0, 0.0]);
}
}
describe("convertToTable", function () {
it("will infer data types correctly", async function () {
@@ -360,7 +331,7 @@ describe("convertToTable", function () {
it("will apply embeddings", async function () {
const records = sampleRecords();
const table = await convertToTable(records, dummyEmbeddingConfig);
const table = await convertToTable(records, new DummyEmbedding());
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
new Float16().toString(),
@@ -369,7 +340,7 @@ describe("convertToTable", function () {
it("will fail if missing the embedding source column", async function () {
await expect(
convertToTable([{ id: 1 }], dummyEmbeddingConfig),
convertToTable([{ id: 1 }], new DummyEmbedding()),
).rejects.toThrow("'string' was not present");
});
@@ -380,7 +351,7 @@ describe("convertToTable", function () {
const table = makeEmptyTable(schema);
// If the embedding specifies the dimension we are fine
await fromTableToBuffer(table, dummyEmbeddingConfig);
await fromTableToBuffer(table, new DummyEmbedding());
// We can also supply a schema and should be ok
const schemaWithEmbedding = new Schema([
@@ -393,13 +364,13 @@ describe("convertToTable", function () {
]);
await fromTableToBuffer(
table,
dummyEmbeddingConfigWithNoDimension,
new DummyEmbeddingWithNoDimension(),
schemaWithEmbedding,
);
// Otherwise we will get an error
await expect(
fromTableToBuffer(table, dummyEmbeddingConfigWithNoDimension),
fromTableToBuffer(table, new DummyEmbeddingWithNoDimension()),
).rejects.toThrow("does not specify `embeddingDimension`");
});
@@ -412,7 +383,7 @@ describe("convertToTable", function () {
false,
),
]);
const table = await convertToTable([], dummyEmbeddingConfig, { schema });
const table = await convertToTable([], new DummyEmbedding(), { schema });
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
new Float16().toString(),
@@ -422,17 +393,16 @@ describe("convertToTable", function () {
it("will complain if embeddings present but schema missing embedding column", async function () {
const schema = new Schema([new Field("string", new Utf8(), false)]);
await expect(
convertToTable([], dummyEmbeddingConfig, { schema }),
convertToTable([], new DummyEmbedding(), { schema }),
).rejects.toThrow("column vector was missing");
});
it("will provide a nice error if run twice", async function () {
const records = sampleRecords();
const table = await convertToTable(records, dummyEmbeddingConfig);
const table = await convertToTable(records, new DummyEmbedding());
// fromTableToBuffer will try and apply the embeddings again
await expect(
fromTableToBuffer(table, dummyEmbeddingConfig),
fromTableToBuffer(table, new DummyEmbedding()),
).rejects.toThrow("already existed");
});
});

View File

@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Field, Float64, Schema } from "apache-arrow";
import * as tmp from "tmp";
import { Connection, Table, connect } from "../lancedb";
import { Connection, connect } from "../lancedb";
describe("when connecting", () => {
let tmpDir: tmp.DirResult;
@@ -87,39 +87,4 @@ describe("given a connection", () => {
tables = await db.tableNames({ startAfter: "a" });
expect(tables).toEqual(["b", "c"]);
});
it("should create tables in v2 mode", async () => {
const db = await connect(tmpDir.name);
const data = [...Array(10000).keys()].map((i) => ({ id: i }));
// Create in v1 mode
let table = await db.createTable("test", data);
const isV2 = async (table: Table) => {
const data = await table.query().toArrow({ maxBatchLength: 100000 });
console.log(data.batches.length);
return data.batches.length < 5;
};
await expect(isV2(table)).resolves.toBe(false);
// Create in v2 mode
table = await db.createTable("test_v2", data, { useLegacyFormat: false });
await expect(isV2(table)).resolves.toBe(true);
await table.add(data);
await expect(isV2(table)).resolves.toBe(true);
// Create empty in v2 mode
const schema = new Schema([new Field("id", new Float64(), true)]);
table = await db.createEmptyTable("test_v2_empty", schema, {
useLegacyFormat: false,
});
await table.add(data);
await expect(isV2(table)).resolves.toBe(true);
});
});

View File

@@ -1,314 +0,0 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import * as tmp from "tmp";
import { connect } from "../lancedb";
import {
Field,
FixedSizeList,
Float,
Float16,
Float32,
Float64,
Schema,
Utf8,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
getRegistry().reset();
});
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
schema,
},
);
getRegistry().reset();
const db2 = await connect(tmpDir.name);
const tbl = await db2.openTable("test");
expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
`Function "mock" not found in registry`,
);
});
test.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide manual embeddings with multiple float datatype",
async (floatType) => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const data = [{ text: "hello" }, { text: "hello world" }];
const schema = new Schema([
new Field("vector", new FixedSizeList(3, new Field("item", floatType))),
new Field("text", new Utf8()),
]);
const func = new MockEmbeddingFunction();
const name = "test";
const db = await connect(tmpDir.name);
const table = await db.createTable(name, data, {
schema,
embeddingFunction: {
sourceColumn: "text",
function: func,
},
});
const res = await table.query().toArray();
expect([...res[0].vector]).toEqual([1, 2, 3]);
},
);
test.only.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide auto embeddings with multiple float datatypes",
async (floatType) => {
@register("test1")
class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
@register("test")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("test")!.create();
const func2 = getRegistry()
.get<MockEmbeddingFunctionWithoutNDims>("test1")!
.create();
const schema = LanceSchema({
text: func.sourceField(new Utf8()),
vector: func.vectorField(floatType),
});
const schema2 = LanceSchema({
text: func2.sourceField(new Utf8()),
vector: func2.vectorField({ datatype: floatType, dims: 3 }),
});
const schema3 = LanceSchema({
text: func2.sourceField(new Utf8()),
vector: func.vectorField({
datatype: new FixedSizeList(3, new Field("item", floatType, true)),
dims: 3,
}),
});
const expectedSchema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", floatType, true)),
true,
),
]);
const stringSchema = JSON.stringify(schema, null, 2);
const stringSchema2 = JSON.stringify(schema2, null, 2);
const stringSchema3 = JSON.stringify(schema3, null, 2);
const stringExpectedSchema = JSON.stringify(expectedSchema, null, 2);
expect(stringSchema).toEqual(stringExpectedSchema);
expect(stringSchema2).toEqual(stringExpectedSchema);
expect(stringSchema3).toEqual(stringExpectedSchema);
},
);
});

View File

@@ -1,169 +0,0 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import * as arrow from "apache-arrow";
import * as arrowOld from "apache-arrow-old";
import * as tmp from "tmp";
import { connect } from "../lancedb";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe.each([arrow, arrowOld])("LanceSchema", (arrow) => {
test("should preserve input order", async () => {
const schema = LanceSchema({
id: new arrow.Int32(),
text: new arrow.Utf8(),
vector: new arrow.Float32(),
});
expect(schema.fields.map((x) => x.name)).toEqual(["id", "text", "vector"]);
});
});
describe("Registry", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
getRegistry().reset();
});
it("should register a new item to the registry", async () => {
@register("mock-embedding")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() {
super();
}
ndims() {
return 3;
}
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
}
const func = getRegistry()
.get<MockEmbeddingFunction>("mock-embedding")!
.create();
const schema = LanceSchema({
id: new arrow.Int32(),
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{ schema },
);
const expected = [
[1, 2, 3],
[1, 2, 3],
];
const actual = await table.query().toArrow();
const vectors = actual
.getChild("vector")
?.toArray()
.map((x: unknown) => {
if (x instanceof arrow.Vector) {
return [...x];
} else {
return x;
}
});
expect(vectors).toEqual(expected);
});
test("should error if registering with the same name", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() {
super();
}
ndims() {
return 3;
}
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
}
register("mock-embedding")(MockEmbeddingFunction);
expect(() => register("mock-embedding")(MockEmbeddingFunction)).toThrow(
'Embedding function with alias "mock-embedding" already exists',
);
});
test("schema should contain correct metadata", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() {
super();
}
ndims() {
return 3;
}
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
}
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
id: new arrow.Int32(),
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const expectedMetadata = new Map<string, string>([
[
"embedding_functions",
JSON.stringify([
{
sourceColumn: "text",
vectorColumn: "vector",
name: "MockEmbeddingFunction",
model: { someText: "hello" },
},
]),
],
]);
expect(schema.metadata).toEqual(expectedMetadata);
});
});

View File

@@ -16,12 +16,7 @@ import * as fs from "fs";
import * as path from "path";
import * as tmp from "tmp";
import * as arrow from "apache-arrow";
import * as arrowOld from "apache-arrow-old";
import { Table, connect } from "../lancedb";
import {
Table as ArrowTable,
Field,
FixedSizeList,
Float32,
@@ -29,20 +24,15 @@ import {
Int32,
Int64,
Schema,
makeArrowTable,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema, register } from "../lancedb/embedding";
} from "apache-arrow";
import { Table, connect } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";
import { Index } from "../lancedb/indices";
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
describe("Given a table", () => {
let tmpDir: tmp.DirResult;
let table: Table;
const schema = new arrow.Schema([
new arrow.Field("id", new arrow.Float64(), true),
]);
const schema = new Schema([new Field("id", new Float64(), true)]);
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name);
@@ -93,43 +83,6 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
expect(await table.countRows("id == 7")).toBe(1);
expect(await table.countRows("id == 10")).toBe(1);
});
// https://github.com/lancedb/lancedb/issues/1293
test.each([new arrow.Float16(), new arrow.Float32(), new arrow.Float64()])(
"can create empty table with non default float type: %s",
async (floatType) => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello", vector: Array(512).fill(1.0) },
{ text: "hello world", vector: Array(512).fill(1.0) },
];
const f64Schema = new arrow.Schema([
new arrow.Field("text", new arrow.Utf8(), true),
new arrow.Field(
"vector",
new arrow.FixedSizeList(512, new arrow.Field("item", floatType)),
true,
),
]);
const f64Table = await db.createEmptyTable("f64", f64Schema, {
mode: "overwrite",
});
try {
await f64Table.add(data);
const res = await f64Table.query().toArray();
expect(res.length).toBe(2);
} catch (e) {
expect(e).toBeUndefined();
}
},
);
it("should return the table as an instance of an arrow table", async () => {
const arrowTbl = await table.toArrow();
expect(arrowTbl).toBeInstanceOf(ArrowTable);
});
});
describe("When creating an index", () => {
@@ -494,99 +447,3 @@ describe("when optimizing a dataset", () => {
expect(stats.prune.oldVersionsRemoved).toBe(3);
});
});
describe("table.search", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
test("can search using a string", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 1;
}
embeddingDataType(): arrow.Float {
return new Float32();
}
// Hardcoded embeddings for the sake of testing
async computeQueryEmbeddings(_data: string) {
switch (_data) {
case "greetings":
return [0.1];
case "farewell":
return [0.2];
default:
return null as never;
}
}
// Hardcoded embeddings for the sake of testing
async computeSourceEmbeddings(data: string[]) {
return data.map((s) => {
switch (s) {
case "hello world":
return [0.1];
case "goodbye world":
return [0.2];
default:
return null as never;
}
});
}
}
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
const data = [{ text: "hello world" }, { text: "goodbye world" }];
const table = await db.createTable("test", data, { schema });
const results = await table.search("greetings").then((r) => r.toArray());
expect(results[0].text).toBe(data[0].text);
const results2 = await table.search("farewell").then((r) => r.toArray());
expect(results2[0].text).toBe(data[1].text);
});
test("rejects if no embedding function provided", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
expect(table.search("hello")).rejects.toThrow(
"No embedding functions are defined in the table",
);
});
test.each([
[0.4, 0.5, 0.599], // number[]
Float32Array.of(0.4, 0.5, 0.599), // Float32Array
Float64Array.of(0.4, 0.5, 0.599), // Float64Array
])("can search using vectorlike datatypes", async (vectorlike) => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
// biome-ignore lint/suspicious/noExplicitAny: test
const results: any[] = await table.search(vectorlike).toArray();
expect(results.length).toBe(2);
expect(results[0].text).toBe(data[1].text);
});
});

View File

@@ -48,7 +48,7 @@
"noUnsafeFinally": "error",
"noUnsafeOptionalChaining": "error",
"noUnusedLabels": "error",
"noUnusedVariables": "warn",
"noUnusedVariables": "error",
"useIsNan": "error",
"useValidForDirection": "error",
"useYield": "error"
@@ -101,13 +101,7 @@
},
"overrides": [
{
"include": [
"**/*.ts",
"**/*.tsx",
"**/*.mts",
"**/*.cts",
"__test__/*.test.ts"
],
"include": ["**/*.ts", "**/*.tsx", "**/*.mts", "**/*.cts"],
"linter": {
"rules": {
"correctness": {

View File

@@ -17,122 +17,24 @@ import {
Binary,
DataType,
Field,
FixedSizeBinary,
FixedSizeList,
Float,
type Float,
Float32,
Int,
LargeBinary,
List,
Null,
RecordBatch,
RecordBatchFileWriter,
RecordBatchStreamWriter,
Schema,
Struct,
Utf8,
Vector,
type Vector,
makeBuilder,
makeData,
type makeTable,
vectorFromArray,
} from "apache-arrow";
import { type EmbeddingFunction } from "./embedding/embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
export * from "apache-arrow";
export type IntoVector = Float32Array | Float64Array | number[];
export function isArrowTable(value: object): value is ArrowTable {
if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value;
}
export function isDataType(value: unknown): value is DataType {
return (
value instanceof DataType ||
DataType.isNull(value) ||
DataType.isInt(value) ||
DataType.isFloat(value) ||
DataType.isBinary(value) ||
DataType.isLargeBinary(value) ||
DataType.isUtf8(value) ||
DataType.isLargeUtf8(value) ||
DataType.isBool(value) ||
DataType.isDecimal(value) ||
DataType.isDate(value) ||
DataType.isTime(value) ||
DataType.isTimestamp(value) ||
DataType.isInterval(value) ||
DataType.isDuration(value) ||
DataType.isList(value) ||
DataType.isStruct(value) ||
DataType.isUnion(value) ||
DataType.isFixedSizeBinary(value) ||
DataType.isFixedSizeList(value) ||
DataType.isMap(value) ||
DataType.isDictionary(value)
);
}
export function isNull(value: unknown): value is Null {
return value instanceof Null || DataType.isNull(value);
}
export function isInt(value: unknown): value is Int {
return value instanceof Int || DataType.isInt(value);
}
export function isFloat(value: unknown): value is Float {
return value instanceof Float || DataType.isFloat(value);
}
export function isBinary(value: unknown): value is Binary {
return value instanceof Binary || DataType.isBinary(value);
}
export function isLargeBinary(value: unknown): value is LargeBinary {
return value instanceof LargeBinary || DataType.isLargeBinary(value);
}
export function isUtf8(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isUtf8(value);
}
export function isLargeUtf8(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isLargeUtf8(value);
}
export function isBool(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isBool(value);
}
export function isDecimal(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isDecimal(value);
}
export function isDate(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isDate(value);
}
export function isTime(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isTime(value);
}
export function isTimestamp(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isTimestamp(value);
}
export function isInterval(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isInterval(value);
}
export function isDuration(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isDuration(value);
}
export function isList(value: unknown): value is List {
return value instanceof List || DataType.isList(value);
}
export function isStruct(value: unknown): value is Struct {
return value instanceof Struct || DataType.isStruct(value);
}
export function isUnion(value: unknown): value is Struct {
return value instanceof Struct || DataType.isUnion(value);
}
export function isFixedSizeBinary(value: unknown): value is FixedSizeBinary {
return value instanceof FixedSizeBinary || DataType.isFixedSizeBinary(value);
}
export function isFixedSizeList(value: unknown): value is FixedSizeList {
return value instanceof FixedSizeList || DataType.isFixedSizeList(value);
}
import { sanitizeSchema } from "./sanitize";
/** Data type accepted by NodeJS SDK */
export type Data = Record<string, unknown>[] | ArrowTable;
@@ -184,7 +86,6 @@ export class MakeArrowTableOptions {
vector: new VectorColumnOptions(),
};
embeddings?: EmbeddingFunction<unknown>;
embeddingFunction?: EmbeddingFunctionConfig;
/**
* If true then string columns will be encoded with dictionary encoding
@@ -297,7 +198,6 @@ export class MakeArrowTableOptions {
export function makeArrowTable(
data: Array<Record<string, unknown>>,
options?: Partial<MakeArrowTableOptions>,
metadata?: Map<string, string>,
): ArrowTable {
if (
data.length === 0 &&
@@ -309,11 +209,7 @@ export function makeArrowTable(
const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema);
opt.schema = validateSchemaEmbeddings(
opt.schema,
data,
options?.embeddingFunction,
);
opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings);
}
const columns: Record<string, Vector> = {};
// TODO: sample dataset to find missing columns
@@ -394,41 +290,20 @@ export function makeArrowTable(
// `new ArrowTable(schema, batches)` which does not do any schema inference
const firstTable = new ArrowTable(columns);
const batchesFixed = firstTable.batches.map(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
(batch) => new RecordBatch(opt.schema!, batch.data),
);
let schema: Schema;
if (metadata !== undefined) {
let schemaMetadata = opt.schema.metadata;
if (schemaMetadata.size === 0) {
schemaMetadata = metadata;
} else {
for (const [key, entry] of schemaMetadata.entries()) {
schemaMetadata.set(key, entry);
}
}
schema = new Schema(opt.schema.fields, schemaMetadata);
} else {
schema = opt.schema;
}
return new ArrowTable(schema, batchesFixed);
return new ArrowTable(opt.schema, batchesFixed);
} else {
return new ArrowTable(columns);
}
const tbl = new ArrowTable(columns);
if (metadata !== undefined) {
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
(<any>tbl.schema).metadata = metadata;
}
return tbl;
}
/**
* Create an empty Arrow table with the provided schema
*/
export function makeEmptyTable(
schema: Schema,
metadata?: Map<string, string>,
): ArrowTable {
return makeArrowTable([], { schema }, metadata);
export function makeEmptyTable(schema: Schema): ArrowTable {
return makeArrowTable([], { schema });
}
/**
@@ -500,74 +375,13 @@ function makeVector(
}
}
/** Helper function to apply embeddings from metadata to an input table */
async function applyEmbeddingsFromMetadata(
table: ArrowTable,
schema: Schema,
): Promise<ArrowTable> {
const registry = getRegistry();
const functions = registry.parseFunctions(schema.metadata);
const columns = Object.fromEntries(
table.schema.fields.map((field) => [
field.name,
table.getChild(field.name)!,
]),
);
for (const functionEntry of functions.values()) {
const sourceColumn = columns[functionEntry.sourceColumn];
const destColumn = functionEntry.vectorColumn ?? "vector";
if (sourceColumn === undefined) {
throw new Error(
`Cannot apply embedding function because the source column '${functionEntry.sourceColumn}' was not present in the data`,
);
}
if (columns[destColumn] !== undefined) {
throw new Error(
`Attempt to apply embeddings to table failed because column ${destColumn} already existed`,
);
}
if (table.batches.length > 1) {
throw new Error(
"Internal error: `makeArrowTable` unexpectedly created a table with more than one batch",
);
}
const values = sourceColumn.toArray();
const vectors =
await functionEntry.function.computeSourceEmbeddings(values);
if (vectors.length !== values.length) {
throw new Error(
"Embedding function did not return an embedding for each input element",
);
}
let destType: DataType;
const dtype = schema.fields.find((f) => f.name === destColumn)!.type;
if (isFixedSizeList(dtype)) {
destType = sanitizeType(dtype);
} else {
throw new Error(
"Expected FixedSizeList as datatype for vector field, instead got: " +
dtype,
);
}
const vector = makeVector(vectors, destType);
columns[destColumn] = vector;
}
const newTable = new ArrowTable(columns);
return alignTable(newTable, schema);
}
/** Helper function to apply embeddings to an input table */
async function applyEmbeddings<T>(
table: ArrowTable,
embeddings?: EmbeddingFunctionConfig,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<ArrowTable> {
if (schema?.metadata.has("embedding_functions")) {
return applyEmbeddingsFromMetadata(table, schema!);
} else if (embeddings == null || embeddings === undefined) {
if (embeddings == null) {
return table;
}
@@ -585,9 +399,8 @@ async function applyEmbeddings<T>(
const newColumns = Object.fromEntries(colEntries);
const sourceColumn = newColumns[embeddings.sourceColumn];
const destColumn = embeddings.vectorColumn ?? "vector";
const innerDestType =
embeddings.function.embeddingDataType() ?? new Float32();
const destColumn = embeddings.destColumn ?? "vector";
const innerDestType = embeddings.embeddingDataType ?? new Float32();
if (sourceColumn === undefined) {
throw new Error(
`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`,
@@ -601,9 +414,11 @@ async function applyEmbeddings<T>(
// if we call convertToTable with 0 records and a schema that includes the embedding
return table;
}
const dimensions = embeddings.function.ndims();
if (dimensions !== undefined) {
const destType = newVectorType(dimensions, innerDestType);
if (embeddings.embeddingDimension !== undefined) {
const destType = newVectorType(
embeddings.embeddingDimension,
innerDestType,
);
newColumns[destColumn] = makeVector([], destType);
} else if (schema != null) {
const destField = schema.fields.find((f) => f.name === destColumn);
@@ -631,9 +446,7 @@ async function applyEmbeddings<T>(
);
}
const values = sourceColumn.toArray();
const vectors = await embeddings.function.computeSourceEmbeddings(
values as T[],
);
const vectors = await embeddings.embed(values as T[]);
if (vectors.length !== values.length) {
throw new Error(
"Embedding function did not return an embedding for each input element",
@@ -673,9 +486,9 @@ async function applyEmbeddings<T>(
* embedding columns. If no schema is provded then embedding columns will
* be placed at the end of the table, after all of the input columns.
*/
export async function convertToTable(
export async function convertToTable<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunctionConfig,
embeddings?: EmbeddingFunction<T>,
makeTableOptions?: Partial<MakeArrowTableOptions>,
): Promise<ArrowTable> {
const table = makeArrowTable(data, makeTableOptions);
@@ -683,13 +496,13 @@ export async function convertToTable(
}
/** Creates the Arrow Type for a Vector column with dimension `dim` */
export function newVectorType<T extends Float>(
function newVectorType<T extends Float>(
dim: number,
innerType: T,
): FixedSizeList<T> {
// in Lance we always default to have the elements nullable, so we need to set it to true
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
const children = new Field("item", <T>sanitizeType(innerType), true);
const children = new Field<T>("item", innerType, true);
return new FixedSizeList(dim, children);
}
@@ -700,9 +513,9 @@ export function newVectorType<T extends Float>(
*
* `schema` is required if data is empty
*/
export async function fromRecordsToBuffer(
export async function fromRecordsToBuffer<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunctionConfig,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
@@ -720,9 +533,9 @@ export async function fromRecordsToBuffer(
*
* `schema` is required if data is empty
*/
export async function fromRecordsToStreamBuffer(
export async function fromRecordsToStreamBuffer<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunctionConfig,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
@@ -741,9 +554,9 @@ export async function fromRecordsToStreamBuffer(
*
* `schema` is required if the table is empty
*/
export async function fromTableToBuffer(
export async function fromTableToBuffer<T>(
table: ArrowTable,
embeddings?: EmbeddingFunctionConfig,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
@@ -762,19 +575,19 @@ export async function fromTableToBuffer(
*
* `schema` is required if the table is empty
*/
export async function fromDataToBuffer(
export async function fromDataToBuffer<T>(
data: Data,
embeddings?: EmbeddingFunctionConfig,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}
if (isArrowTable(data)) {
if (data instanceof ArrowTable) {
return fromTableToBuffer(data, embeddings, schema);
} else {
const table = await convertToTable(data, embeddings, { schema });
return fromTableToBuffer(table);
const table = await convertToTable(data);
return fromTableToBuffer(table, embeddings, schema);
}
}
@@ -786,9 +599,9 @@ export async function fromDataToBuffer(
*
* `schema` is required if the table is empty
*/
export async function fromTableToStreamBuffer(
export async function fromTableToStreamBuffer<T>(
table: ArrowTable,
embeddings?: EmbeddingFunctionConfig,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
@@ -841,7 +654,7 @@ export function createEmptyTable(schema: Schema): ArrowTable {
function validateSchemaEmbeddings(
schema: Schema,
data: Array<Record<string, unknown>>,
embeddings: EmbeddingFunctionConfig | undefined,
embeddings: EmbeddingFunction<unknown> | undefined,
) {
const fields = [];
const missingEmbeddingFields = [];
@@ -851,25 +664,10 @@ function validateSchemaEmbeddings(
// if it does not, we add it to the list of missing embedding fields
// Finally, we check if those missing embedding fields are `this._embeddings`
// if they are not, we throw an error
for (let field of schema.fields) {
if (isFixedSizeList(field.type)) {
field = sanitizeField(field);
for (const field of schema.fields) {
if (field.type instanceof FixedSizeList) {
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
if (schema.metadata.has("embedding_functions")) {
const embeddings = JSON.parse(
schema.metadata.get("embedding_functions")!,
);
if (
// biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f`
embeddings.find((f: any) => f["vectorColumn"] === field.name) ===
undefined
) {
missingEmbeddingFields.push(field);
}
} else {
missingEmbeddingFields.push(field);
}
missingEmbeddingFields.push(field);
} else {
fields.push(field);
}

View File

@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Table as ArrowTable, Schema } from "./arrow";
import {
fromTableToBuffer,
isArrowTable,
makeArrowTable,
makeEmptyTable,
} from "./arrow";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { Table as ArrowTable, Schema } from "apache-arrow";
import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow";
import { ConnectionOptions, Connection as LanceDbConnection } from "./native";
import { Table } from "./table";
@@ -71,14 +65,6 @@ export interface CreateTableOptions {
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>;
/**
* If true then data files will be written with the legacy format
*
* The default is true while the new format is in beta
*/
useLegacyFormat?: boolean;
schema?: Schema;
embeddingFunction?: EmbeddingFunctionConfig;
}
export interface OpenTableOptions {
@@ -188,7 +174,6 @@ export class Connection {
cleanseStorageOptions(options?.storageOptions),
options?.indexCacheSize,
);
return new Table(innerTable);
}
@@ -211,25 +196,18 @@ export class Connection {
}
let table: ArrowTable;
if (isArrowTable(data)) {
if (data instanceof ArrowTable) {
table = data;
} else {
table = makeArrowTable(data, options);
table = makeArrowTable(data);
}
const buf = await fromTableToBuffer(
table,
options?.embeddingFunction,
options?.schema,
);
const buf = await fromTableToBuffer(table);
const innerTable = await this.inner.createTable(
name,
buf,
mode,
cleanseStorageOptions(options?.storageOptions),
options?.useLegacyFormat,
);
return new Table(innerTable);
}
@@ -249,21 +227,14 @@ export class Connection {
if (mode === "create" && existOk) {
mode = "exist_ok";
}
let metadata: Map<string, string> | undefined = undefined;
if (options?.embeddingFunction !== undefined) {
const embeddingFunction = options.embeddingFunction;
const registry = getRegistry();
metadata = registry.getTableMetadata([embeddingFunction]);
}
const table = makeEmptyTable(schema, metadata);
const table = makeEmptyTable(schema);
const buf = await fromTableToBuffer(table);
const innerTable = await this.inner.createEmptyTable(
name,
buf,
mode,
cleanseStorageOptions(options?.storageOptions),
options?.useLegacyFormat,
);
return new Table(innerTable);
}

View File

@@ -1,4 +1,4 @@
// Copyright 2024 Lance Developers.
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,172 +12,67 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import "reflect-metadata";
import {
DataType,
Field,
FixedSizeList,
Float,
Float32,
type IntoVector,
isDataType,
isFixedSizeList,
isFloat,
newVectorType,
} from "../arrow";
import { sanitizeType } from "../sanitize";
/**
* Options for a given embedding function
*/
export interface FunctionOptions {
// biome-ignore lint/suspicious/noExplicitAny: options can be anything
[key: string]: any;
}
import { type Float } from "apache-arrow";
/**
* An embedding function that automatically creates vector representation for a given column.
*/
export abstract class EmbeddingFunction<
// biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
T = any,
M extends FunctionOptions = FunctionOptions,
> {
export interface EmbeddingFunction<T> {
/**
* Convert the embedding function to a JSON object
* It is used to serialize the embedding function to the schema
* It's important that any object returned by this method contains all the necessary
* information to recreate the embedding function
*
* It should return the same object that was passed to the constructor
* If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
*
* @example
* ```ts
* class MyEmbeddingFunction extends EmbeddingFunction {
* constructor(options: {model: string, timeout: number}) {
* super();
* this.model = options.model;
* this.timeout = options.timeout;
* }
* toJSON() {
* return {
* model: this.model,
* timeout: this.timeout,
* };
* }
* ```
* The name of the column that will be used as input for the Embedding Function.
*/
abstract toJSON(): Partial<M>;
sourceColumn: string;
/**
* sourceField is used in combination with `LanceSchema` to provide a declarative data model
* The data type of the embedding
*
* @param optionsOrDatatype - The options for the field or the datatype
*
* @see {@link lancedb.LanceSchema}
* The embedding function should return `number`. This will be converted into
* an Arrow float array. By default this will be Float32 but this property can
* be used to control the conversion.
*/
sourceField(
optionsOrDatatype: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] {
let datatype = isDataType(optionsOrDatatype)
? optionsOrDatatype
: optionsOrDatatype?.datatype;
if (!datatype) {
throw new Error("Datatype is required");
}
datatype = sanitizeType(datatype);
const metadata = new Map<string, EmbeddingFunction>();
metadata.set("source_column_for", this);
return [datatype, metadata];
}
embeddingDataType?: Float;
/**
* vectorField is used in combination with `LanceSchema` to provide a declarative data model
* The dimension of the embedding
*
* @param options - The options for the field
*
* @see {@link lancedb.LanceSchema}
* This is optional, normally this can be determined by looking at the results of
* `embed`. If this is not specified, and there is an attempt to apply the embedding
* to an empty table, then that process will fail.
*/
vectorField(
optionsOrDatatype?: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] {
let dtype: DataType | undefined;
let vectorType: DataType;
let dims: number | undefined = this.ndims();
embeddingDimension?: number;
// `func.vectorField(new Float32())`
if (isDataType(optionsOrDatatype)) {
dtype = optionsOrDatatype;
} else {
// `func.vectorField({
// datatype: new Float32(),
// dims: 10
// })`
dims = dims ?? optionsOrDatatype?.dims;
dtype = optionsOrDatatype?.datatype;
}
/**
* The name of the column that will contain the embedding
*
* By default this is "vector"
*/
destColumn?: string;
if (dtype !== undefined) {
// `func.vectorField(new FixedSizeList(dims, new Field("item", new Float32(), true)))`
// or `func.vectorField({datatype: new FixedSizeList(dims, new Field("item", new Float32(), true))})`
if (isFixedSizeList(dtype)) {
vectorType = dtype;
// `func.vectorField(new Float32())`
// or `func.vectorField({datatype: new Float32()})`
} else if (isFloat(dtype)) {
// No `ndims` impl and no `{dims: n}` provided;
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
vectorType = newVectorType(dims, dtype);
} else {
throw new Error(
"Expected FixedSizeList or Float as datatype for vector field",
);
}
} else {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
vectorType = new FixedSizeList(
dims,
new Field("item", new Float32(), true),
);
}
const metadata = new Map<string, EmbeddingFunction>();
metadata.set("vector_column_for", this);
return [vectorType, metadata];
}
/** The number of dimensions of the embeddings */
ndims(): number | undefined {
return undefined;
}
/** The datatype of the embeddings */
abstract embeddingDataType(): Float;
/**
* Should the source column be excluded from the resulting table
*
* By default the source column is included. Set this to true and
* only the embedding will be stored.
*/
excludeSource?: boolean;
/**
* Creates a vector representation for the given values.
*/
abstract computeSourceEmbeddings(
data: T[],
): Promise<number[][] | Float32Array[] | Float64Array[]>;
embed: (data: T[]) => Promise<number[][]>;
}
/**
Compute the embeddings for a single query
*/
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
return this.computeSourceEmbeddings([data]).then(
(embeddings) => embeddings[0],
);
/** Test if the input seems to be an embedding function */
export function isEmbeddingFunction<T>(
value: unknown,
): value is EmbeddingFunction<T> {
if (typeof value !== "object" || value === null) {
return false;
}
}
export interface FieldOptions<T extends DataType = DataType> {
datatype: T;
dims?: number;
if (!("sourceColumn" in value) || !("embed" in value)) {
return false;
}
return (
typeof value.sourceColumn === "string" && typeof value.embed === "function"
);
}

View File

@@ -1,113 +1,2 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { DataType, Field, Schema } from "../arrow";
import { isDataType } from "../arrow";
import { sanitizeType } from "../sanitize";
import { EmbeddingFunction } from "./embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./registry";
export { EmbeddingFunction } from "./embedding_function";
// We need to explicitly export '*' so that the `register` decorator actually registers the class.
export * from "./openai";
export * from "./registry";
/**
* Create a schema with embedding functions.
*
* @param fields
* @returns Schema
* @example
* ```ts
* class MyEmbeddingFunction extends EmbeddingFunction {
* // ...
* }
* const func = new MyEmbeddingFunction();
* const schema = LanceSchema({
* id: new Int32(),
* text: func.sourceField(new Utf8()),
* vector: func.vectorField(),
* // optional: specify the datatype and/or dimensions
* vector2: func.vectorField({ datatype: new Float32(), dims: 3}),
* });
*
* const table = await db.createTable("my_table", data, { schema });
* ```
*/
export function LanceSchema(
fields: Record<string, [object, Map<string, EmbeddingFunction>] | object>,
): Schema {
const arrowFields: Field[] = [];
const embeddingFunctions = new Map<
EmbeddingFunction,
Partial<EmbeddingFunctionConfig>
>();
Object.entries(fields).forEach(([key, value]) => {
if (isDataType(value)) {
arrowFields.push(new Field(key, sanitizeType(value), true));
} else {
const [dtype, metadata] = value as [
object,
Map<string, EmbeddingFunction>,
];
arrowFields.push(new Field(key, sanitizeType(dtype), true));
parseEmbeddingFunctions(embeddingFunctions, key, metadata);
}
});
const registry = getRegistry();
const metadata = registry.getTableMetadata(
Array.from(embeddingFunctions.values()) as EmbeddingFunctionConfig[],
);
const schema = new Schema(arrowFields, metadata);
return schema;
}
function parseEmbeddingFunctions(
embeddingFunctions: Map<EmbeddingFunction, Partial<EmbeddingFunctionConfig>>,
key: string,
metadata: Map<string, EmbeddingFunction>,
): void {
if (metadata.has("source_column_for")) {
const embedFunction = metadata.get("source_column_for")!;
const current = embeddingFunctions.get(embedFunction);
if (current !== undefined) {
embeddingFunctions.set(embedFunction, {
...current,
sourceColumn: key,
});
} else {
embeddingFunctions.set(embedFunction, {
sourceColumn: key,
function: embedFunction,
});
}
} else if (metadata.has("vector_column_for")) {
const embedFunction = metadata.get("vector_column_for")!;
const current = embeddingFunctions.get(embedFunction);
if (current !== undefined) {
embeddingFunctions.set(embedFunction, {
...current,
vectorColumn: key,
});
} else {
embeddingFunctions.set(embedFunction, {
vectorColumn: key,
function: embedFunction,
});
}
}
}
export { EmbeddingFunction, isEmbeddingFunction } from "./embedding_function";
export { OpenAIEmbeddingFunction } from "./openai";

View File

@@ -13,31 +13,17 @@
// limitations under the License.
import type OpenAI from "openai";
import { Float, Float32 } from "../arrow";
import { EmbeddingFunction } from "./embedding_function";
import { register } from "./registry";
import { type EmbeddingFunction } from "./embedding_function";
export type OpenAIOptions = {
apiKey?: string;
model?: string;
};
@register("openai")
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
string,
OpenAIOptions
> {
#openai: OpenAI;
#modelName: string;
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
super();
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
if (!openAIKey) {
throw new Error("OpenAI API key is required");
}
const modelName = options?.model ?? "text-embedding-ada-002";
export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
private readonly _openai: OpenAI;
private readonly _modelName: string;
constructor(
sourceColumn: string,
openAIKey: string,
modelName: string = "text-embedding-ada-002",
) {
/**
* @type {import("openai").default}
*/
@@ -50,40 +36,18 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
throw new Error("please install openai@^4.24.1 using npm install openai");
}
this.sourceColumn = sourceColumn;
const configuration = {
apiKey: openAIKey,
};
this.#openai = new Openai(configuration);
this.#modelName = modelName;
this._openai = new Openai(configuration);
this._modelName = modelName;
}
toJSON() {
return {
model: this.#modelName,
};
}
ndims(): number {
switch (this.#modelName) {
case "text-embedding-ada-002":
return 1536;
case "text-embedding-3-large":
return 3072;
case "text-embedding-3-small":
return 1536;
default:
return null as never;
}
}
embeddingDataType(): Float {
return new Float32();
}
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
const response = await this.#openai.embeddings.create({
model: this.#modelName,
async embed(data: string[]): Promise<number[][]> {
const response = await this._openai.embeddings.create({
model: this._modelName,
input: data,
});
@@ -94,15 +58,5 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
return embeddings;
}
async computeQueryEmbeddings(data: string): Promise<number[]> {
if (typeof data !== "string") {
throw new Error("Data must be a string");
}
const response = await this.#openai.embeddings.create({
model: this.#modelName,
input: data,
});
return response.data[0].embedding;
}
sourceColumn: string;
}

View File

@@ -1,176 +0,0 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import type { EmbeddingFunction } from "./embedding_function";
import "reflect-metadata";
export interface EmbeddingFunctionOptions {
[key: string]: unknown;
}
export interface EmbeddingFunctionFactory<
T extends EmbeddingFunction = EmbeddingFunction,
> {
new (modelOptions?: EmbeddingFunctionOptions): T;
}
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
create(options?: EmbeddingFunctionOptions): T;
}
/**
* This is a singleton class used to register embedding functions
* and fetch them by name. It also handles serializing and deserializing.
* You can implement your own embedding function by subclassing EmbeddingFunction
* or TextEmbeddingFunction and registering it with the registry
*/
export class EmbeddingFunctionRegistry {
#functions: Map<string, EmbeddingFunctionFactory> = new Map();
/**
* Register an embedding function
* @param name The name of the function
* @param func The function to register
* @throws Error if the function is already registered
*/
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
this: EmbeddingFunctionRegistry,
alias?: string,
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
): (ctor: T) => any {
const self = this;
return function (ctor: T) {
if (!alias) {
alias = ctor.name;
}
if (self.#functions.has(alias)) {
throw new Error(
`Embedding function with alias "${alias}" already exists`,
);
}
self.#functions.set(alias, ctor);
Reflect.defineMetadata("lancedb::embedding::name", alias, ctor);
return ctor;
};
}
/**
* Fetch an embedding function by name
* @param name The name of the function
*/
get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
name: string,
): EmbeddingFunctionCreate<T> | undefined {
const factory = this.#functions.get(name);
if (!factory) {
return undefined;
}
return {
create: function (options: EmbeddingFunctionOptions) {
return new factory(options) as unknown as T;
},
};
}
/**
* reset the registry to the initial state
*/
reset(this: EmbeddingFunctionRegistry) {
this.#functions.clear();
}
/**
* @ignore
*/
parseFunctions(
this: EmbeddingFunctionRegistry,
metadata: Map<string, string>,
): Map<string, EmbeddingFunctionConfig> {
if (!metadata.has("embedding_functions")) {
return new Map();
} else {
type FunctionConfig = {
name: string;
sourceColumn: string;
vectorColumn: string;
model: EmbeddingFunctionOptions;
};
const functions = <FunctionConfig[]>(
JSON.parse(metadata.get("embedding_functions")!)
);
return new Map(
functions.map((f) => {
const fn = this.get(f.name);
if (!fn) {
throw new Error(`Function "${f.name}" not found in registry`);
}
return [
f.name,
{
sourceColumn: f.sourceColumn,
vectorColumn: f.vectorColumn,
function: this.get(f.name)!.create(f.model),
},
];
}),
);
}
}
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
functionToMetadata(conf: EmbeddingFunctionConfig): Record<string, any> {
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
const metadata: Record<string, any> = {};
const name = Reflect.getMetadata(
"lancedb::embedding::name",
conf.function.constructor,
);
metadata["sourceColumn"] = conf.sourceColumn;
metadata["vectorColumn"] = conf.vectorColumn ?? "vector";
metadata["name"] = name ?? conf.function.constructor.name;
metadata["model"] = conf.function.toJSON();
return metadata;
}
getTableMetadata(functions: EmbeddingFunctionConfig[]): Map<string, string> {
const metadata = new Map<string, string>();
const jsonData = functions.map((conf) => this.functionToMetadata(conf));
metadata.set("embedding_functions", JSON.stringify(jsonData));
return metadata;
}
}
const _REGISTRY = new EmbeddingFunctionRegistry();
export function register(name?: string) {
return _REGISTRY.register(name);
}
/**
* Utility function to get the global instance of the registry
* @returns `EmbeddingFunctionRegistry` The global instance of the registry
* @example
* ```ts
* const registry = getRegistry();
* const openai = registry.get("openai").create();
*/
export function getRegistry(): EmbeddingFunctionRegistry {
return _REGISTRY;
}
export interface EmbeddingFunctionConfig {
sourceColumn: string;
vectorColumn?: string;
function: EmbeddingFunction;
}

View File

@@ -12,12 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import {
Table as ArrowTable,
type IntoVector,
RecordBatch,
tableFromIPC,
} from "./arrow";
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "apache-arrow";
import { type IvfPqOptions } from "./indices";
import {
RecordBatchIterator as NativeBatchIterator,
@@ -55,39 +50,6 @@ export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
}
/* eslint-enable */
class RecordBatchIterable<
NativeQueryType extends NativeQuery | NativeVectorQuery,
> implements AsyncIterable<RecordBatch>
{
private inner: NativeQueryType;
private options?: QueryExecutionOptions;
constructor(inner: NativeQueryType, options?: QueryExecutionOptions) {
this.inner = inner;
this.options = options;
}
// biome-ignore lint/suspicious/noExplicitAny: skip
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
return new RecordBatchIterator(
this.inner.execute(this.options?.maxBatchLength),
);
}
}
/**
* Options that control the behavior of a particular query execution
*/
export interface QueryExecutionOptions {
/**
* The maximum number of rows to return in a single batch
*
* Batches may have fewer rows if the underlying data is stored
* in smaller chunks.
*/
maxBatchLength?: number;
}
/** Common methods supported by all query types */
export class QueryBase<
NativeQueryType extends NativeQuery | NativeVectorQuery,
@@ -146,12 +108,9 @@ export class QueryBase<
* object insertion order is easy to get wrong and `Map` is more foolproof.
*/
select(
columns: string[] | Map<string, string> | Record<string, string> | string,
columns: string[] | Map<string, string> | Record<string, string>,
): QueryType {
let columnTuples: [string, string][];
if (typeof columns === "string") {
columns = [columns];
}
if (Array.isArray(columns)) {
columnTuples = columns.map((c) => [c, c]);
} else if (columns instanceof Map) {
@@ -174,10 +133,8 @@ export class QueryBase<
return this as unknown as QueryType;
}
protected nativeExecute(
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {
return this.inner.execute(options?.maxBatchLength);
protected nativeExecute(): Promise<NativeBatchIterator> {
return this.inner.execute();
}
/**
@@ -191,10 +148,8 @@ export class QueryBase<
* single query)
*
*/
protected execute(
options?: Partial<QueryExecutionOptions>,
): RecordBatchIterator {
return new RecordBatchIterator(this.nativeExecute(options));
protected execute(): RecordBatchIterator {
return new RecordBatchIterator(this.nativeExecute());
}
// biome-ignore lint/suspicious/noExplicitAny: skip
@@ -204,18 +159,18 @@ export class QueryBase<
}
/** Collect the results as an Arrow @see {@link ArrowTable}. */
async toArrow(options?: Partial<QueryExecutionOptions>): Promise<ArrowTable> {
async toArrow(): Promise<ArrowTable> {
const batches = [];
for await (const batch of new RecordBatchIterable(this.inner, options)) {
for await (const batch of this) {
batches.push(batch);
}
return new ArrowTable(batches);
}
/** Collect the results as an array of objects. */
// biome-ignore lint/suspicious/noExplicitAny: arrow.toArrow() returns any[]
async toArray(options?: Partial<QueryExecutionOptions>): Promise<any[]> {
const tbl = await this.toArrow(options);
async toArray(): Promise<unknown[]> {
const tbl = await this.toArrow();
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return tbl.toArray();
}
}
@@ -414,8 +369,9 @@ export class Query extends QueryBase<NativeQuery, Query> {
* Vector searches always have a `limit`. If `limit` has not been called then
* a default `limit` of 10 will be used. @see {@link Query#limit}
*/
nearestTo(vector: IntoVector): VectorQuery {
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
nearestTo(vector: unknown): VectorQuery {
// biome-ignore lint/suspicious/noExplicitAny: skip
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
return new VectorQuery(vectorQuery);
}
}

View File

@@ -20,7 +20,6 @@
// comes from the exact same library instance. This is not always the case
// and so we must sanitize the input to ensure that it is compatible.
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
import {
Binary,
Bool,
@@ -76,9 +75,10 @@ import {
Uint64,
Union,
Utf8,
} from "./arrow";
} from "apache-arrow";
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
export function sanitizeMetadata(
function sanitizeMetadata(
metadataLike?: unknown,
): Map<string, string> | undefined {
if (metadataLike === undefined || metadataLike === null) {
@@ -97,7 +97,7 @@ export function sanitizeMetadata(
return metadataLike as Map<string, string>;
}
export function sanitizeInt(typeLike: object) {
function sanitizeInt(typeLike: object) {
if (
!("bitWidth" in typeLike) ||
typeof typeLike.bitWidth !== "number" ||
@@ -111,14 +111,14 @@ export function sanitizeInt(typeLike: object) {
return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
}
export function sanitizeFloat(typeLike: object) {
function sanitizeFloat(typeLike: object) {
if (!("precision" in typeLike) || typeof typeLike.precision !== "number") {
throw Error("Expected a Float Type to have a `precision` property");
}
return new Float(typeLike.precision as Precision);
}
export function sanitizeDecimal(typeLike: object) {
function sanitizeDecimal(typeLike: object) {
if (
!("scale" in typeLike) ||
typeof typeLike.scale !== "number" ||
@@ -134,14 +134,14 @@ export function sanitizeDecimal(typeLike: object) {
return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
}
export function sanitizeDate(typeLike: object) {
function sanitizeDate(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Date type to have a `unit` property");
}
return new Date_(typeLike.unit as DateUnit);
}
export function sanitizeTime(typeLike: object) {
function sanitizeTime(typeLike: object) {
if (
!("unit" in typeLike) ||
typeof typeLike.unit !== "number" ||
@@ -155,7 +155,7 @@ export function sanitizeTime(typeLike: object) {
return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
}
export function sanitizeTimestamp(typeLike: object) {
function sanitizeTimestamp(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Timestamp type to have a `unit` property");
}
@@ -166,7 +166,7 @@ export function sanitizeTimestamp(typeLike: object) {
return new Timestamp(typeLike.unit, timezone);
}
export function sanitizeTypedTimestamp(
function sanitizeTypedTimestamp(
typeLike: object,
// eslint-disable-next-line @typescript-eslint/naming-convention
Datatype:
@@ -182,14 +182,14 @@ export function sanitizeTypedTimestamp(
return new Datatype(timezone);
}
export function sanitizeInterval(typeLike: object) {
function sanitizeInterval(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected an Interval type to have a `unit` property");
}
return new Interval(typeLike.unit);
}
export function sanitizeList(typeLike: object) {
function sanitizeList(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a List type to have an array-like `children` property",
@@ -201,7 +201,7 @@ export function sanitizeList(typeLike: object) {
return new List(sanitizeField(typeLike.children[0]));
}
export function sanitizeStruct(typeLike: object) {
function sanitizeStruct(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Struct type to have an array-like `children` property",
@@ -210,7 +210,7 @@ export function sanitizeStruct(typeLike: object) {
return new Struct(typeLike.children.map((child) => sanitizeField(child)));
}
export function sanitizeUnion(typeLike: object) {
function sanitizeUnion(typeLike: object) {
if (
!("typeIds" in typeLike) ||
!("mode" in typeLike) ||
@@ -234,7 +234,7 @@ export function sanitizeUnion(typeLike: object) {
);
}
export function sanitizeTypedUnion(
function sanitizeTypedUnion(
typeLike: object,
// eslint-disable-next-line @typescript-eslint/naming-convention
UnionType: typeof DenseUnion | typeof SparseUnion,
@@ -256,7 +256,7 @@ export function sanitizeTypedUnion(
);
}
export function sanitizeFixedSizeBinary(typeLike: object) {
function sanitizeFixedSizeBinary(typeLike: object) {
if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
throw Error(
"Expected a FixedSizeBinary type to have a `byteWidth` property",
@@ -265,7 +265,7 @@ export function sanitizeFixedSizeBinary(typeLike: object) {
return new FixedSizeBinary(typeLike.byteWidth);
}
export function sanitizeFixedSizeList(typeLike: object) {
function sanitizeFixedSizeList(typeLike: object) {
if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") {
throw Error("Expected a FixedSizeList type to have a `listSize` property");
}
@@ -283,7 +283,7 @@ export function sanitizeFixedSizeList(typeLike: object) {
);
}
export function sanitizeMap(typeLike: object) {
function sanitizeMap(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Map type to have an array-like `children` property",
@@ -300,14 +300,14 @@ export function sanitizeMap(typeLike: object) {
);
}
export function sanitizeDuration(typeLike: object) {
function sanitizeDuration(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Duration type to have a `unit` property");
}
return new Duration(typeLike.unit);
}
export function sanitizeDictionary(typeLike: object) {
function sanitizeDictionary(typeLike: object) {
if (!("id" in typeLike) || typeof typeLike.id !== "number") {
throw Error("Expected a Dictionary type to have an `id` property");
}
@@ -329,7 +329,7 @@ export function sanitizeDictionary(typeLike: object) {
}
// biome-ignore lint/suspicious/noExplicitAny: skip
export function sanitizeType(typeLike: unknown): DataType<any> {
function sanitizeType(typeLike: unknown): DataType<any> {
if (typeof typeLike !== "object" || typeLike === null) {
throw Error("Expected a Type but object was null/undefined");
}
@@ -449,7 +449,7 @@ export function sanitizeType(typeLike: unknown): DataType<any> {
}
}
export function sanitizeField(fieldLike: unknown): Field {
function sanitizeField(fieldLike: unknown): Field {
if (fieldLike instanceof Field) {
return fieldLike;
}

View File

@@ -12,16 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import {
Table as ArrowTable,
Data,
IntoVector,
Schema,
fromDataToBuffer,
tableFromIPC,
} from "./arrow";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { Schema, tableFromIPC } from "apache-arrow";
import { Data, fromDataToBuffer } from "./arrow";
import { IndexOptions } from "./indices";
import {
AddColumnsSql,
@@ -31,8 +23,8 @@ import {
Table as _NativeTable,
} from "./native";
import { Query, VectorQuery } from "./query";
export { IndexConfig } from "./native";
export { IndexConfig } from "./native";
/**
* Options for adding data to a table.
*/
@@ -117,14 +109,6 @@ export class Table {
return this.inner.display();
}
async #getEmbeddingFunctions(): Promise<
Map<string, EmbeddingFunctionConfig>
> {
const schema = await this.schema();
const registry = getRegistry();
return registry.parseFunctions(schema.metadata);
}
/** Get the schema of the table. */
async schema(): Promise<Schema> {
const schemaBuf = await this.inner.schema();
@@ -138,15 +122,8 @@ export class Table {
*/
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
const mode = options?.mode ?? "append";
const schema = await this.schema();
const registry = getRegistry();
const functions = registry.parseFunctions(schema.metadata);
const buffer = await fromDataToBuffer(
data,
functions.values().next().value,
schema,
);
const buffer = await fromDataToBuffer(data);
await this.inner.add(buffer, mode);
}
@@ -286,40 +263,6 @@ export class Table {
return new Query(this.inner);
}
/**
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
* @rejects {Error} If no embedding functions are defined in the table
*/
search(query: string): Promise<VectorQuery>;
/**
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {IntoVector} query - the query vector
*/
search(query: IntoVector): VectorQuery;
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
if (typeof query !== "string") {
return this.vectorSearch(query);
} else {
return this.#getEmbeddingFunctions().then(async (functions) => {
// TODO: Support multiple embedding functions
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
.values()
.next().value;
if (!embeddingFunc) {
return Promise.reject(
new Error("No embedding functions are defined in the table"),
);
}
const embeddings =
await embeddingFunc.function.computeQueryEmbeddings(query);
return this.query().nearestTo(embeddings);
});
}
}
/**
* Search the table with a given query vector.
*
@@ -327,7 +270,7 @@ export class Table {
* is the same thing as calling `nearestTo` on the builder returned
* by `query`. @see {@link Query#nearestTo} for more details.
*/
vectorSearch(vector: IntoVector): VectorQuery {
vectorSearch(vector: unknown): VectorQuery {
return this.query().nearestTo(vector);
}
@@ -473,9 +416,4 @@ export class Table {
async listIndices(): Promise<IndexConfig[]> {
return await this.inner.listIndices();
}
/** Return the table as an arrow table */
async toArrow(): Promise<ArrowTable> {
return await this.query().toArrow();
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.5.1",
"version": "0.5.0",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.5.1",
"version": "0.5.0",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.5.1",
"version": "0.5.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.5.1",
"version": "0.5.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.5.1",
"version": "0.5.0",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

15383
nodejs/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,8 @@
{
"name": "@lancedb/lancedb",
"version": "0.5.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",
"./embedding": "./dist/embedding/index.js"
},
"types": "dist/index.d.ts",
"version": "0.5.0",
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"napi": {
"name": "lancedb",
"triples": {
@@ -66,7 +62,6 @@
},
"dependencies": {
"apache-arrow": "^15.0.0",
"openai": "^4.29.2",
"reflect-metadata": "^0.2.2"
"openai": "^4.29.2"
}
}

View File

@@ -126,7 +126,6 @@ impl Connection {
buf: Buffer,
mode: String,
storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> napi::Result<Table> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
@@ -137,9 +136,6 @@ impl Connection {
builder = builder.storage_option(key, value);
}
}
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
let tbl = builder
.execute()
.await
@@ -154,7 +150,6 @@ impl Connection {
schema_buf: Buffer,
mode: String,
storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> napi::Result<Table> {
let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e))
@@ -169,9 +164,6 @@ impl Connection {
builder = builder.storage_option(key, value);
}
}
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
let tbl = builder
.execute()
.await

View File

@@ -56,7 +56,6 @@ pub enum WriteMode {
/// Write options when creating a Table.
#[napi(object)]
pub struct WriteOptions {
/// Write mode for writing to a table.
pub mode: Option<WriteMode>,
}

View File

@@ -15,7 +15,6 @@
use lancedb::query::ExecutableQuery;
use lancedb::query::Query as LanceDbQuery;
use lancedb::query::QueryBase;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::Select;
use lancedb::query::VectorQuery as LanceDbVectorQuery;
use napi::bindgen_prelude::*;
@@ -63,21 +62,10 @@ impl Query {
}
#[napi]
pub async fn execute(
&self,
max_batch_length: Option<u32>,
) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)
.await
.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream))
}
}
@@ -137,21 +125,10 @@ impl VectorQuery {
}
#[napi]
pub async fn execute(
&self,
max_batch_length: Option<u32>,
) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)
.await
.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream))
}
}

View File

@@ -7,9 +7,7 @@
"outDir": "./dist",
"strict": true,
"allowJs": true,
"resolveJsonModule": true,
"emitDecoratorMetadata": true,
"experimentalDecorators": true
"resolveJsonModule": true
},
"exclude": ["./dist/*"],
"typedocOptions": {

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.8.2"
current_version = "0.8.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.8.2"
version = "0.8.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -3,7 +3,7 @@ name = "lancedb"
# version in Cargo.toml
dependencies = [
"deprecation",
"pylance==0.12.1",
"pylance==0.11.0",
"ratelimiter~=1.0",
"requests>=2.31.0",
"retry>=0.9.2",

View File

@@ -24,7 +24,6 @@ class Connection(object):
mode: str,
data: pa.RecordBatchReader,
storage_options: Optional[Dict[str, str]] = None,
use_legacy_format: Optional[bool] = None,
) -> Table: ...
async def create_empty_table(
self,
@@ -32,7 +31,6 @@ class Connection(object):
mode: str,
schema: pa.Schema,
storage_options: Optional[Dict[str, str]] = None,
use_legacy_format: Optional[bool] = None,
) -> Table: ...
class Table:
@@ -74,7 +72,7 @@ class Query:
def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ...
async def execute(self) -> RecordBatchStream: ...
class VectorQuery:
async def execute(self) -> RecordBatchStream: ...

View File

@@ -509,7 +509,7 @@ class AsyncConnection(object):
return self._inner.__repr__()
def __enter__(self):
return self
self
def __exit__(self, *_):
self.close()
@@ -558,8 +558,6 @@ class AsyncConnection(object):
on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
*,
use_legacy_format: Optional[bool] = None,
) -> AsyncTable:
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
@@ -602,9 +600,6 @@ class AsyncConnection(object):
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
use_legacy_format: bool, optional, default True
If True, use the legacy format for the table. If False, use the new format.
The default is True while the new format is in beta.
Returns
@@ -766,11 +761,7 @@ class AsyncConnection(object):
if data is None:
new_table = await self._inner.create_empty_table(
name,
mode,
schema,
storage_options=storage_options,
use_legacy_format=use_legacy_format,
name, mode, schema, storage_options=storage_options
)
else:
data = data_to_reader(data, schema)
@@ -779,7 +770,6 @@ class AsyncConnection(object):
mode,
data,
storage_options=storage_options,
use_legacy_format=use_legacy_format,
)
return AsyncTable(new_table)
@@ -789,7 +779,7 @@ class AsyncConnection(object):
name: str,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
) -> AsyncTable:
) -> Table:
"""Open a Lance Table in the database.
Parameters

View File

@@ -153,7 +153,7 @@ class TextEmbeddingFunction(EmbeddingFunction):
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Generate the embeddings for the given texts

View File

@@ -73,8 +73,6 @@ class BedRockText(TextEmbeddingFunction):
assumed_role: Union[str, None] = None
profile_name: Union[str, None] = None
role_session_name: str = "lancedb-embeddings"
source_input_type: str = "search_document"
query_input_type: str = "search_query"
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
@@ -89,29 +87,21 @@ class BedRockText(TextEmbeddingFunction):
# TODO: fix hardcoding
if self.name == "amazon.titan-embed-text-v1":
return 1536
elif self.name in [
"amazon.titan-embed-text-v2:0",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]:
# TODO: "amazon.titan-embed-text-v2:0" model supports dynamic ndims
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
return 1024
else:
raise ValueError(f"Model {self.name} not supported")
raise ValueError(f"Unknown model name: {self.name}")
def compute_query_embeddings(
self, query: str, *args, **kwargs
) -> List[List[float]]:
return self.compute_source_embeddings(query, input_type=self.query_input_type)
return self.compute_source_embeddings(query)
def compute_source_embeddings(
self, texts: TEXT, *args, **kwargs
) -> List[List[float]]:
texts = self.sanitize_input(texts)
# assume source input type if not passed by `compute_query_embeddings`
kwargs["input_type"] = kwargs.get("input_type") or self.source_input_type
return self.generate_embeddings(texts, **kwargs)
return self.generate_embeddings(texts)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
@@ -131,11 +121,11 @@ class BedRockText(TextEmbeddingFunction):
"""
results = []
for text in texts:
response = self._generate_embedding(text, *args, **kwargs)
response = self._generate_embedding(text)
results.append(response)
return results
def _generate_embedding(self, text: str, *args, **kwargs) -> List[float]:
def _generate_embedding(self, text: str) -> List[float]:
"""
Get the embeddings for the given texts
@@ -151,12 +141,14 @@ class BedRockText(TextEmbeddingFunction):
"""
# format input body for provider
provider = self.name.split(".")[0]
input_body = {**kwargs}
_model_kwargs = {}
input_body = {**_model_kwargs}
if provider == "cohere":
if "input_type" not in input_body.keys():
input_body["input_type"] = "search_document"
input_body["texts"] = [text]
else:
# includes common provider == "amazon"
input_body.pop("input_type", None)
input_body["inputText"] = text
body = json.dumps(input_body)

View File

@@ -19,7 +19,7 @@ import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT
from .utils import api_key_not_found_help
@register("cohere")
@@ -32,36 +32,8 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
Parameters
----------
name: str, default "embed-multilingual-v2.0"
The name of the model to use. List of acceptable models:
* embed-english-v3.0
* embed-multilingual-v3.0
* embed-english-light-v3.0
* embed-multilingual-light-v3.0
* embed-english-v2.0
* embed-english-light-v2.0
* embed-multilingual-v2.0
source_input_type: str, default "search_document"
The input type for the source column in the database
query_input_type: str, default "search_query"
The input type for the query column in the database
Cohere supports following input types:
| Input Type | Description |
|-------------------------|---------------------------------------|
| "`search_document`" | Used for embeddings stored in a vector|
| | database for search use-cases. |
| "`search_query`" | Used for embeddings of search queries |
| | run against a vector DB |
| "`semantic_similarity`" | Specifies the given text will be used |
| | for Semantic Textual Similarity (STS) |
| "`classification`" | Used for embeddings passed through a |
| | text classifier. |
| "`clustering`" | Used for the embeddings run through a |
| | clustering algorithm |
The name of the model to use. See the Cohere documentation for
a list of available models.
Examples
--------
@@ -89,39 +61,14 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
"""
name: str = "embed-multilingual-v2.0"
source_input_type: str = "search_document"
query_input_type: str = "search_query"
client: ClassVar = None
def ndims(self):
# TODO: fix hardcoding
if self.name in [
"embed-english-v3.0",
"embed-multilingual-v3.0",
"embed-english-light-v2.0",
]:
return 1024
elif self.name in ["embed-english-light-v3.0", "embed-multilingual-light-v3.0"]:
return 384
elif self.name == "embed-english-v2.0":
return 4096
elif self.name == "embed-multilingual-v2.0":
return 768
else:
raise ValueError(f"Model {self.name} not supported")
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, input_type=self.query_input_type)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
input_type = (
kwargs.get("input_type") or self.source_input_type
) # assume source input type if not passed by `compute_query_embeddings`
return self.generate_embeddings(texts, input_type=input_type)
return 768
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
@@ -131,10 +78,9 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
self._init_client()
rs = CohereEmbeddingFunction.client.embed(
texts=texts, model=self.name, **kwargs
)
rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)
return [emb for emb in rs.embeddings]

View File

@@ -1113,22 +1113,11 @@ class AsyncQueryBase(object):
self._inner.limit(limit)
return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader:
async def to_batches(self) -> AsyncRecordBatchReader:
"""
Execute the query and return the results as an Apache Arrow RecordBatchReader.
Parameters
----------
max_batch_length: Optional[int]
The maximum number of selected records in a single RecordBatch object.
If not specified, a default batch length is used.
It is possible for batches to be smaller than the provided length if the
underlying data is stored in smaller chunks.
"""
return AsyncRecordBatchReader(await self._inner.execute(max_batch_length))
return AsyncRecordBatchReader(await self._inner.execute())
async def to_arrow(self) -> pa.Table:
"""

View File

@@ -296,13 +296,6 @@ async def test_close(tmp_path):
await db.table_names()
@pytest.mark.asyncio
async def test_context_manager(tmp_path):
with await lancedb.connect_async(tmp_path) as db:
assert db.is_open()
assert not db.is_open()
@pytest.mark.asyncio
async def test_create_mode_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
@@ -507,52 +500,6 @@ def test_empty_or_nonexistent_table(tmp_path):
assert test.schema == test2.schema
@pytest.mark.asyncio
async def test_create_in_v2_mode(tmp_path):
def make_data():
for i in range(10):
yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
def make_table():
return pa.table([pa.array([x for x in range(10 * 1024)])], names=["x"])
schema = pa.schema([pa.field("x", pa.int64())])
db = await lancedb.connect_async(tmp_path)
# Create table in v1 mode
tbl = await db.create_table("test", data=make_data(), schema=schema)
async def is_in_v2_mode(tbl):
batches = await tbl.query().to_batches(max_batch_length=1024 * 10)
num_batches = 0
async for batch in batches:
num_batches += 1
return num_batches < 10
assert not await is_in_v2_mode(tbl)
# Create table in v2 mode
tbl = await db.create_table(
"test_v2", data=make_data(), schema=schema, use_legacy_format=False
)
assert await is_in_v2_mode(tbl)
# Add data (should remain in v2 mode)
await tbl.add(make_table())
assert await is_in_v2_mode(tbl)
# Create empty table in v2 mode and add data
tbl = await db.create_table(
"test_empty_v2", data=None, schema=schema, use_legacy_format=False
)
await tbl.add(make_table())
assert await is_in_v2_mode(tbl)
def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
table = db.create_table(

View File

@@ -91,7 +91,6 @@ impl Connection {
mode: &str,
data: &PyAny,
storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone();
@@ -104,10 +103,6 @@ impl Connection {
builder = builder.storage_options(storage_options);
}
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
future_into_py(self_.py(), async move {
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))
@@ -120,7 +115,6 @@ impl Connection {
mode: &str,
schema: &PyAny,
storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone();
@@ -134,10 +128,6 @@ impl Connection {
builder = builder.storage_options(storage_options);
}
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
future_into_py(self_.py(), async move {
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))

View File

@@ -15,7 +15,6 @@
use arrow::array::make_array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
@@ -62,14 +61,10 @@ impl Query {
Ok(VectorQuery { inner })
}
pub fn execute(self_: PyRef<'_, Self>, max_batch_length: Option<u32>) -> PyResult<&PyAny> {
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let mut opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
let inner_stream = inner.execute().await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
}
@@ -120,14 +115,10 @@ impl VectorQuery {
self.inner = self.inner.clone().bypass_vector_index()
}
pub fn execute(self_: PyRef<'_, Self>, max_batch_length: Option<u32>) -> PyResult<&PyAny> {
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let mut opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
let inner_stream = inner.execute().await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
}

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.5.1"
version = "0.5.0"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.5.1"
version = "0.5.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -19,13 +19,11 @@ arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true
chrono = { workspace = true }
datafusion-physical-plan.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lazy_static.workspace = true
lance = { workspace = true }
lance-datafusion.workspace = true
lance-index = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
@@ -40,12 +38,11 @@ url.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true}
[dev-dependencies]
tempfile = "3.5.0"
@@ -65,10 +62,4 @@ default = []
remote = ["dep:reqwest"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
openai = ["dep:async-openai", "dep:reqwest"]
polars = ["dep:polars-arrow", "dep:polars"]
[[example]]
name = "openai"
required-features = ["openai"]

View File

@@ -1,82 +0,0 @@
use std::{iter::once, sync::Arc};
use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
arrow::IntoArrow,
connect,
embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
query::{ExecutableQuery, QueryBase},
Result,
};
#[tokio::main]
async fn main() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
let embedding = Arc::new(OpenAIEmbeddingFunction::new_with_model(
api_key,
"text-embedding-3-large",
)?);
let db = connect(tempdir).execute().await?;
db.embedding_registry()
.register("openai", embedding.clone())?;
let table = db
.create_table("vectors", make_data())
.add_embedding(EmbeddingDefinition::new(
"text",
"openai",
Some("embeddings"),
))?
.execute()
.await?;
// there is no equivalent to '.search(<query>)' yet
let query = Arc::new(StringArray::from_iter_values(once("something warm")));
let query_vector = embedding.compute_query_embeddings(query)?;
let mut results = table
.vector_search(query_vector)?
.limit(1)
.execute()
.await?;
let rb = results.next().await.unwrap()?;
let out = rb
.column_by_name("text")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let text = out.iter().next().unwrap().unwrap();
println!("Closest match: {}", text);
Ok(())
}
fn make_data() -> impl IntoArrow {
let schema = Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("text", DataType::Utf8, false),
Field::new("price", DataType::Float64, false),
]);
let id = Int32Array::from(vec![1, 2, 3, 4]);
let text = StringArray::from_iter_values(vec![
"Black T-Shirt",
"Leather Jacket",
"Winter Parka",
"Hooded Sweatshirt",
]);
let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]);
let schema = Arc::new(schema);
let rb = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(id), Arc::new(text), Arc::new(price)],
)
.unwrap();
Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema))
}

View File

@@ -140,7 +140,6 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
pub(crate) write_options: WriteOptions,
pub(crate) table_definition: Option<TableDefinition>,
pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
pub(crate) use_legacy_format: bool,
}
// Builder methods that only apply when we have initial data
@@ -154,7 +153,6 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
write_options: WriteOptions::default(),
table_definition: None,
embeddings: Vec::new(),
use_legacy_format: true,
}
}
@@ -186,7 +184,6 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
mode: self.mode,
write_options: self.write_options,
embeddings: self.embeddings,
use_legacy_format: self.use_legacy_format,
};
Ok((data, builder))
}
@@ -220,7 +217,6 @@ impl CreateTableBuilder<false, NoData> {
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
embeddings: Vec::new(),
use_legacy_format: false,
}
}
@@ -282,20 +278,6 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
}
self
}
/// Set to true to use the v1 format for data files
///
/// This is currently defaulted to true and can be set to false to opt-in
/// to the new format. This should only be used for experimentation and
/// evaluation. The new format is still in beta and may change in ways that
/// are not backwards compatible.
///
/// Once the new format is stable, the default will change to `false` for
/// several releases and then eventually this option will be removed.
pub fn use_legacy_format(mut self, use_legacy_format: bool) -> Self {
self.use_legacy_format = use_legacy_format;
self
}
}
#[derive(Clone, Debug)]
@@ -961,7 +943,6 @@ impl ConnectionInternal for Database {
if matches!(&options.mode, CreateTableMode::Overwrite) {
write_params.mode = WriteMode::Overwrite;
}
write_params.use_legacy_format = options.use_legacy_format;
match NativeTable::create(
&table_uri,
@@ -1059,12 +1040,8 @@ impl ConnectionInternal for Database {
#[cfg(test)]
mod tests {
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
use tempfile::tempdir;
use crate::query::{ExecutableQuery, QueryExecutionOptions};
use super::*;
#[tokio::test]
@@ -1169,58 +1146,6 @@ mod tests {
assert_eq!(tables, vec!["table1".to_owned()]);
}
fn make_data() -> impl RecordBatchReader + Send + 'static {
let id = Box::new(IncrementingInt32::new().named("id".to_string()));
BatchGenerator::new().col(id).batches(10, 2000)
}
#[tokio::test]
async fn test_create_table_v2() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let tbl = db
.create_table("v1_test", make_data())
.execute()
.await
.unwrap();
// In v1 the row group size will trump max_batch_length
let batches = tbl
.query()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 50000,
})
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 20);
let tbl = db
.create_table("v2_test", make_data())
.use_legacy_format(false)
.execute()
.await
.unwrap();
// In v2 the page size is much bigger than 50k so we should get a single batch
let batches = tbl
.query()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 50000,
})
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 1);
}
#[tokio::test]
async fn drop_table() {
let tmp_dir = tempdir().unwrap();

View File

@@ -11,8 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(feature = "openai")]
pub mod openai;
use lance::arrow::RecordBatchExt;
use std::{
@@ -53,10 +51,8 @@ pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
/// The type of the output data
/// This should **always** match the output of the `embed` function
fn dest_type(&self) -> Result<Cow<DataType>>;
/// Compute the embeddings for the source column in the database
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
/// Compute the embeddings for a given user query
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
/// Embed the input
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
}
/// Defines an embedding from input data into a lower-dimensional space
@@ -270,7 +266,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
// todo: parallelize this
for (fld, func) in self.embeddings.iter() {
let src_column = batch.column_by_name(&fld.source_column).unwrap();
let embedding = match func.compute_source_embeddings(src_column.clone()) {
let embedding = match func.embed(src_column.clone()) {
Ok(embedding) => embedding,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(

View File

@@ -1,257 +0,0 @@
use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc};
use arrow::array::{AsArray, Float32Builder};
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use async_openai::{
config::OpenAIConfig,
types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat},
Client,
};
use tokio::{runtime::Handle, task};
use crate::{Error, Result};
use super::EmbeddingFunction;
#[derive(Debug)]
pub enum EmbeddingModel {
TextEmbeddingAda002,
TextEmbedding3Small,
TextEmbedding3Large,
}
impl EmbeddingModel {
fn ndims(&self) -> usize {
match self {
Self::TextEmbeddingAda002 => 1536,
Self::TextEmbedding3Small => 1536,
Self::TextEmbedding3Large => 3072,
}
}
}
impl FromStr for EmbeddingModel {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002),
"text-embedding-3-small" => Ok(Self::TextEmbedding3Small),
"text-embedding-3-large" => Ok(Self::TextEmbedding3Large),
_ => Err(Error::InvalidInput {
message: "Invalid input. Available models are: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' ".to_string()
}),
}
}
}
impl std::fmt::Display for EmbeddingModel {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
match self {
Self::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
Self::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
Self::TextEmbedding3Large => write!(f, "text-embedding-3-large"),
}
}
}
impl TryFrom<&str> for EmbeddingModel {
type Error = Error;
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
value.parse()
}
}
pub struct OpenAIEmbeddingFunction {
model: EmbeddingModel,
api_key: String,
api_base: Option<String>,
org_id: Option<String>,
}
impl std::fmt::Debug for OpenAIEmbeddingFunction {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
// let's be safe and not print the full API key
let creds_display = if self.api_key.len() > 6 {
format!(
"{}***{}",
&self.api_key[0..2],
&self.api_key[self.api_key.len() - 4..]
)
} else {
"[INVALID]".to_string()
};
f.debug_struct("OpenAI")
.field("model", &self.model)
.field("api_key", &creds_display)
.field("api_base", &self.api_base)
.field("org_id", &self.org_id)
.finish()
}
}
impl OpenAIEmbeddingFunction {
/// Create a new OpenAIEmbeddingFunction
pub fn new<A: Into<String>>(api_key: A) -> Self {
Self::new_impl(api_key.into(), EmbeddingModel::TextEmbeddingAda002)
}
pub fn new_with_model<A: Into<String>, M: TryInto<EmbeddingModel>>(
api_key: A,
model: M,
) -> crate::Result<Self>
where
M::Error: Into<crate::Error>,
{
Ok(Self::new_impl(
api_key.into(),
model.try_into().map_err(|e| e.into())?,
))
}
/// concrete implementation to reduce monomorphization
fn new_impl(api_key: String, model: EmbeddingModel) -> Self {
Self {
model,
api_key,
api_base: None,
org_id: None,
}
}
/// To use a API base url different from default "https://api.openai.com/v1"
pub fn api_base<S: Into<String>>(mut self, api_base: S) -> Self {
self.api_base = Some(api_base.into());
self
}
/// To use a different OpenAI organization id other than default
pub fn org_id<S: Into<String>>(mut self, org_id: S) -> Self {
self.org_id = Some(org_id.into());
self
}
}
impl EmbeddingFunction for OpenAIEmbeddingFunction {
fn name(&self) -> &str {
"openai"
}
fn source_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<DataType>> {
let n_dims = self.model.ndims();
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
n_dims as i32,
false,
)))
}
fn compute_source_embeddings(&self, source: ArrayRef) -> crate::Result<ArrayRef> {
let len = source.len();
let n_dims = self.model.ndims();
let inner = self.compute_inner(source)?;
let fsl = DataType::new_fixed_size_list(DataType::Float32, n_dims as i32, false);
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let array_data = ArrayData::builder(fsl)
.len(len)
.add_child_data(inner.into_data())
.build()?;
Ok(Arc::new(FixedSizeListArray::from(array_data)))
}
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
let arr = self.compute_inner(input)?;
Ok(Arc::new(arr))
}
}
impl OpenAIEmbeddingFunction {
fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
// OpenAI only supports non-nullable string arrays
if source.is_nullable() {
return Err(crate::Error::InvalidInput {
message: "Expected non-nullable data type".to_string(),
});
}
// OpenAI only supports string arrays
if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
return Err(crate::Error::InvalidInput {
message: "Expected Utf8 data type".to_string(),
});
};
let mut creds = OpenAIConfig::new().with_api_key(self.api_key.clone());
if let Some(api_base) = &self.api_base {
creds = creds.with_api_base(api_base.clone());
}
if let Some(org_id) = &self.org_id {
creds = creds.with_org_id(org_id.clone());
}
let input = match source.data_type() {
DataType::Utf8 => {
let array = source
.as_string::<i32>()
.into_iter()
.map(|s| {
s.expect("we already asserted that the array is non-nullable")
.to_string()
})
.collect::<Vec<String>>();
EmbeddingInput::StringArray(array)
}
DataType::LargeUtf8 => {
let array = source
.as_string::<i64>()
.into_iter()
.map(|s| {
s.expect("we already asserted that the array is non-nullable")
.to_string()
})
.collect::<Vec<String>>();
EmbeddingInput::StringArray(array)
}
_ => unreachable!("This should not happen. We already checked the data type."),
};
let client = Client::with_config(creds);
let embed = client.embeddings();
let req = CreateEmbeddingRequest {
model: self.model.to_string(),
input,
encoding_format: Some(EncodingFormat::Float),
user: None,
dimensions: None,
};
// TODO: request batching and retry logic
task::block_in_place(move || {
Handle::current().block_on(async {
let mut builder = Float32Builder::new();
let res = embed.create(req).await.map_err(|e| crate::Error::Runtime {
message: format!("OpenAI embed request failed: {e}"),
})?;
for Embedding { embedding, .. } in res.data.iter() {
builder.append_slice(embedding);
}
Ok(builder.finish())
})
})
}
}

View File

@@ -17,10 +17,7 @@ use std::sync::Arc;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use datafusion_physical_plan::ExecutionPlan;
use half::f16;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance_datafusion::exec::execute_plan;
use crate::arrow::SendableRecordBatchStream;
use crate::error::{Error, Result};
@@ -428,15 +425,6 @@ impl Default for QueryExecutionOptions {
/// There are various kinds of queries but they all return results
/// in the same way.
pub trait ExecutableQuery {
/// Return the Datafusion [ExecutionPlan].
///
/// The caller can further optimize the plan or execute it.
///
fn create_plan(
&self,
options: QueryExecutionOptions,
) -> impl Future<Output = Result<Arc<dyn ExecutionPlan>>> + Send;
/// Execute the query with default options and return results
///
/// See [`ExecutableQuery::execute_with_options`] for more details.
@@ -557,13 +545,6 @@ impl HasQuery for Query {
}
impl ExecutableQuery for Query {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
self.parent
.clone()
.create_plan(&self.clone().into_vector(), options)
.await
}
async fn execute_with_options(
&self,
options: QueryExecutionOptions,
@@ -737,19 +718,12 @@ impl VectorQuery {
}
impl ExecutableQuery for VectorQuery {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
self.base.parent.clone().create_plan(self, options).await
}
async fn execute_with_options(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
Ok(SendableRecordBatchStream::from(
DatasetRecordBatchStream::new(execute_plan(
self.create_plan(options).await?,
Default::default(),
)?),
self.base.parent.clone().vector_query(self, options).await?,
))
}
}
@@ -998,30 +972,6 @@ mod tests {
}
}
fn assert_plan_exists(plan: &Arc<dyn ExecutionPlan>, name: &str) -> bool {
if plan.name() == name {
return true;
}
plan.children()
.iter()
.any(|child| assert_plan_exists(child, name))
}
#[tokio::test]
async fn test_create_execute_plan() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let plan = table
.query()
.nearest_to(vec![0.1, 0.2, 0.3, 0.4])
.unwrap()
.create_plan(QueryExecutionOptions::default())
.await
.unwrap();
assert_plan_exists(&plan, "KNNFlatSearch");
assert_plan_exists(&plan, "ProjectionExec");
}
#[tokio::test]
async fn query_base_methods_on_vector_query() {
// Make sure VectorQuery can be used as a QueryBase
@@ -1039,18 +989,5 @@ mod tests {
let first_batch = results.next().await.unwrap().unwrap();
assert_eq!(first_batch.num_rows(), 1);
assert!(results.next().await.is_none());
// query with wrong vector dimension
let error_result = table
.vector_search(&[1.0, 2.0, 3.0])
.unwrap()
.limit(1)
.execute()
.await;
assert!(error_result
.err()
.unwrap()
.to_string()
.contains("No vector column found to match with the query vector dimension: 3"));
}
}

View File

@@ -1,9 +1,6 @@
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};
use crate::{
@@ -74,13 +71,6 @@ impl TableInternal for RemoteTable {
) -> Result<()> {
todo!()
}
async fn create_plan(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
async fn plain_query(
&self,
_query: &Query,
@@ -88,6 +78,13 @@ impl TableInternal for RemoteTable {
) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn vector_query(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn update(&self, _update: UpdateBuilder) -> Result<()> {
todo!()
}

View File

@@ -23,7 +23,6 @@ use arrow::datatypes::Float32Type;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
@@ -36,7 +35,6 @@ use lance::dataset::{
};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_datafusion::exec::execute_plan;
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams;
@@ -233,8 +231,7 @@ pub struct WriteOptions {
// pub on_bad_vectors: BadVectorHandling,
/// Advanced parameters that can be used to customize table creation
///
/// Overlapping `OpenTableBuilder` options (e.g. [AddDataBuilder::mode]) will take
/// precedence over their counterparts in `WriteOptions` (e.g. [WriteParams::mode]).
/// If set, these will take precedence over any overlapping `OpenTableBuilder` options
pub lance_write_params: Option<WriteParams>,
}
@@ -369,16 +366,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
async fn schema(&self) -> Result<SchemaRef>;
/// Count the number of rows in this table.
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
async fn create_plan(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>>;
async fn plain_query(
&self,
query: &Query,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn add(
&self,
add: AddDataBuilder<NoData>,
@@ -1482,11 +1479,79 @@ impl NativeTable {
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let plan = self.create_plan(query, options).await?;
Ok(DatasetRecordBatchStream::new(execute_plan(
plan,
Default::default(),
)?))
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
if !f.data_type().is_floating() {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
message: format!(
"The dimension of the query vector does not match with the dimension of the vector column '{}':
query dim={}, expected vector dim={}",
column,
query_vector.len(),
dim,
),
});
}
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.batch_size(options.max_batch_length as usize);
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
}
Select::Dynamic(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type.into());
}
Ok(scanner.try_into_stream().await?)
}
}
@@ -1638,86 +1703,6 @@ impl TableInternal for NativeTable {
Ok(())
}
async fn create_plan(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
if !f.data_type().is_floating() {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
message: format!(
"The dimension of the query vector does not match with the dimension of the vector column '{}': \
query dim={}, expected vector dim={}",
column,
query_vector.len(),
dim,
),
});
}
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.batch_size(options.max_batch_length as usize);
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
}
Select::Dynamic(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type.into());
}
Ok(scanner.create_plan().await?)
}
async fn plain_query(
&self,
query: &Query,
@@ -1727,6 +1712,14 @@ impl TableInternal for NativeTable {
.await
}
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
self.generic_query(query, options).await
}
async fn merge_insert(
&self,
params: MergeInsertBuilder,
@@ -1758,7 +1751,7 @@ impl TableInternal for NativeTable {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
}
let job = builder.try_build()?;
let (new_dataset, _stats) = job.execute_reader(new_data).await?;
let new_dataset = job.execute_reader(new_data).await?;
self.dataset.set_latest(new_dataset.as_ref().clone()).await;
Ok(())
}
@@ -2557,7 +2550,8 @@ mod tests {
.unwrap()
.get_index_type(index_uuid)
.await
.unwrap(),
.unwrap()
.map(|index_type| index_type.to_string()),
Some("IVF".to_string())
);
assert_eq!(

View File

@@ -66,19 +66,6 @@ impl DatasetRef {
Ok(())
}
fn is_latest(&self) -> bool {
matches!(self, Self::Latest { .. })
}
async fn need_reload(&self) -> Result<bool> {
Ok(match self {
Self::Latest { dataset, .. } => {
dataset.latest_version_id().await? != dataset.version().version
}
Self::TimeTravel { dataset, version } => dataset.version().version != *version,
})
}
async fn as_latest(&mut self, read_consistency_interval: Option<Duration>) -> Result<()> {
match self {
Self::Latest { .. } => Ok(()),
@@ -142,7 +129,7 @@ impl DatasetConsistencyWrapper {
Self(Arc::new(RwLock::new(DatasetRef::Latest {
dataset,
read_consistency_interval,
last_consistency_check: Some(Instant::now()),
last_consistency_check: None,
})))
}
@@ -176,16 +163,11 @@ impl DatasetConsistencyWrapper {
/// Convert into a wrapper in latest version mode
pub async fn as_latest(&self, read_consistency_interval: Option<Duration>) -> Result<()> {
if self.0.read().await.is_latest() {
return Ok(());
}
let mut write_guard = self.0.write().await;
if write_guard.is_latest() {
return Ok(());
}
write_guard.as_latest(read_consistency_interval).await
self.0
.write()
.await
.as_latest(read_consistency_interval)
.await
}
pub async fn as_time_travel(&self, target_version: u64) -> Result<()> {
@@ -201,18 +183,7 @@ impl DatasetConsistencyWrapper {
}
pub async fn reload(&self) -> Result<()> {
if !self.0.read().await.need_reload().await? {
return Ok(());
}
let mut write_guard = self.0.write().await;
// on lock escalation -- check if someone else has already reloaded
if !write_guard.need_reload().await? {
return Ok(());
}
// actually need reloading
write_guard.reload().await
self.0.write().await.reload().await
}
/// Returns the version, if in time travel mode, or None otherwise

View File

@@ -101,7 +101,7 @@ pub fn validate_table_name(name: &str) -> Result<()> {
Ok(())
}
/// Find one default column to create index or perform vector query.
/// Find one default column to create index.
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
// Try to find one fixed size list array column.
let candidates = schema
@@ -118,17 +118,14 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
})
.collect::<Vec<_>>();
if candidates.is_empty() {
Err(Error::InvalidInput {
message: format!(
"No vector column found to match with the query vector dimension: {}",
dim.unwrap_or_default()
),
Err(Error::Schema {
message: "No vector column found to create index".to_string(),
})
} else if candidates.len() != 1 {
Err(Error::Schema {
message: format!(
"More than one vector columns found, \
please specify which column to create index or query: {:?}",
please specify which column to create index: {:?}",
candidates
),
})

View File

@@ -302,7 +302,7 @@ impl EmbeddingFunction for MockEmbed {
fn dest_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Borrowed(&self.dest_type))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let len = source.len();
@@ -317,9 +317,4 @@ impl EmbeddingFunction for MockEmbed {
Ok(Arc::new(arr))
}
#[allow(unused_variables)]
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}