mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 18:32:55 +00:00
Compare commits
6 Commits
release-0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
291ed41c3e | ||
|
|
fdda7b1a76 | ||
|
|
eb2cbedf19 | ||
|
|
bc139000bd | ||
|
|
dbea3a7544 | ||
|
|
3bb7c546d7 |
@@ -14,7 +14,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: local-biome-check
|
- id: local-biome-check
|
||||||
name: biome check
|
name: biome check
|
||||||
entry: npx biome check
|
entry: npx @biomejs/biome check --config-path nodejs/biome.json nodejs/
|
||||||
language: system
|
language: system
|
||||||
types: [text]
|
types: [text]
|
||||||
files: "nodejs/.*"
|
files: "nodejs/.*"
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
|
|||||||
categories = ["database-implementations"]
|
categories = ["database-implementations"]
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { "version" = "=0.11.0", "features" = ["dynamodb"] }
|
lance = { "version" = "=0.11.1", "features" = ["dynamodb"] }
|
||||||
lance-index = { "version" = "=0.11.0" }
|
lance-index = { "version" = "=0.11.1" }
|
||||||
lance-linalg = { "version" = "=0.11.0" }
|
lance-linalg = { "version" = "=0.11.1" }
|
||||||
lance-testing = { "version" = "=0.11.0" }
|
lance-testing = { "version" = "=0.11.1" }
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "51.0", optional = false }
|
arrow = { version = "51.0", optional = false }
|
||||||
arrow-array = "51.0"
|
arrow-array = "51.0"
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
<parent>
|
<parent>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.0.3-SNAPSHOT</version>
|
<version>0.1-SNAPSHOT</version>
|
||||||
<relativePath>../pom.xml</relativePath>
|
<relativePath>../pom.xml</relativePath>
|
||||||
</parent>
|
</parent>
|
||||||
|
|
||||||
|
|||||||
100
java/pom.xml
100
java/pom.xml
@@ -1,34 +1,15 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.0.3-SNAPSHOT</version>
|
<version>0.1-SNAPSHOT</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
|
|
||||||
<name>Lance Parent</name>
|
<name>Lance Parent</name>
|
||||||
<description>LanceDB Java API</description>
|
|
||||||
<url>http://lancedb.com/</url>
|
|
||||||
|
|
||||||
<developers>
|
|
||||||
<developer>
|
|
||||||
<name>Lance DB Dev Group</name>
|
|
||||||
<email>dev@lancedb.com</email>
|
|
||||||
</developer>
|
|
||||||
</developers>
|
|
||||||
<licenses>
|
|
||||||
<license>
|
|
||||||
<name>The Apache Software License, Version 2.0</name>
|
|
||||||
<url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
|
||||||
</license>
|
|
||||||
</licenses>
|
|
||||||
|
|
||||||
<scm>
|
|
||||||
<developerConnection>scm:git:git@github.com:lancedb/lancedb.git</developerConnection>
|
|
||||||
<tag>HEAD</tag>
|
|
||||||
<url>scm:git:git@github.com:lancedb/lancedb.git</url>
|
|
||||||
</scm>
|
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
@@ -83,32 +64,6 @@
|
|||||||
|
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-source-plugin</artifactId>
|
|
||||||
<version>2.2.1</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>attach-sources</id>
|
|
||||||
<goals>
|
|
||||||
<goal>jar-no-fork</goal>
|
|
||||||
</goals>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-javadoc-plugin</artifactId>
|
|
||||||
<version>2.9.1</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>attach-javadocs</id>
|
|
||||||
<goals>
|
|
||||||
<goal>jar</goal>
|
|
||||||
</goals>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-checkstyle-plugin</artifactId>
|
<artifactId>maven-checkstyle-plugin</artifactId>
|
||||||
@@ -156,7 +111,7 @@
|
|||||||
<version>3.2.5</version>
|
<version>3.2.5</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<argLine>--add-opens=java.base/java.nio=ALL-UNNAMED</argLine>
|
<argLine>--add-opens=java.base/java.nio=ALL-UNNAMED</argLine>
|
||||||
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory" />
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<useSystemClassLoader>false</useSystemClassLoader>
|
<useSystemClassLoader>false</useSystemClassLoader>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
@@ -171,49 +126,4 @@
|
|||||||
</plugins>
|
</plugins>
|
||||||
</pluginManagement>
|
</pluginManagement>
|
||||||
</build>
|
</build>
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>deploy-to-ossrh</id>
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.sonatype.central</groupId>
|
|
||||||
<artifactId>central-publishing-maven-plugin</artifactId>
|
|
||||||
<version>0.4.0</version>
|
|
||||||
<extensions>true</extensions>
|
|
||||||
<configuration>
|
|
||||||
<publishingServerId>ossrh</publishingServerId>
|
|
||||||
<tokenAuth>true</tokenAuth>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.sonatype.plugins</groupId>
|
|
||||||
<artifactId>nexus-staging-maven-plugin</artifactId>
|
|
||||||
<version>1.6.13</version>
|
|
||||||
<extensions>true</extensions>
|
|
||||||
<configuration>
|
|
||||||
<serverId>ossrh</serverId>
|
|
||||||
<nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>
|
|
||||||
<autoReleaseAfterClose>true</autoReleaseAfterClose>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-gpg-plugin</artifactId>
|
|
||||||
<version>1.5</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>sign-artifacts</id>
|
|
||||||
<phase>verify</phase>
|
|
||||||
<goals>
|
|
||||||
<goal>sign</goal>
|
|
||||||
</goals>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import {
|
|||||||
Schema,
|
Schema,
|
||||||
Struct,
|
Struct,
|
||||||
type Table,
|
type Table,
|
||||||
|
Type,
|
||||||
Utf8,
|
Utf8,
|
||||||
tableFromIPC,
|
tableFromIPC,
|
||||||
} from "apache-arrow";
|
} from "apache-arrow";
|
||||||
@@ -51,7 +52,12 @@ import {
|
|||||||
makeArrowTable,
|
makeArrowTable,
|
||||||
makeEmptyTable,
|
makeEmptyTable,
|
||||||
} from "../lancedb/arrow";
|
} from "../lancedb/arrow";
|
||||||
import { type EmbeddingFunction } from "../lancedb/embedding/embedding_function";
|
import {
|
||||||
|
EmbeddingFunction,
|
||||||
|
FieldOptions,
|
||||||
|
FunctionOptions,
|
||||||
|
} from "../lancedb/embedding/embedding_function";
|
||||||
|
import { EmbeddingFunctionConfig } from "../lancedb/embedding/registry";
|
||||||
|
|
||||||
// biome-ignore lint/suspicious/noExplicitAny: skip
|
// biome-ignore lint/suspicious/noExplicitAny: skip
|
||||||
function sampleRecords(): Array<Record<string, any>> {
|
function sampleRecords(): Array<Record<string, any>> {
|
||||||
@@ -280,23 +286,46 @@ describe("The function makeArrowTable", function () {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
class DummyEmbedding implements EmbeddingFunction<string> {
|
class DummyEmbedding extends EmbeddingFunction<string> {
|
||||||
public readonly sourceColumn = "string";
|
toJSON(): Partial<FunctionOptions> {
|
||||||
public readonly embeddingDimension = 2;
|
return {};
|
||||||
public readonly embeddingDataType = new Float16();
|
}
|
||||||
|
|
||||||
async embed(data: string[]): Promise<number[][]> {
|
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
|
||||||
return data.map(() => [0.0, 0.0]);
|
return data.map(() => [0.0, 0.0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ndims(): number {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddingDataType() {
|
||||||
|
return new Float16();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
|
class DummyEmbeddingWithNoDimension extends EmbeddingFunction<string> {
|
||||||
public readonly sourceColumn = "string";
|
toJSON(): Partial<FunctionOptions> {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
async embed(data: string[]): Promise<number[][]> {
|
embeddingDataType(): Float {
|
||||||
|
return new Float16();
|
||||||
|
}
|
||||||
|
|
||||||
|
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
|
||||||
return data.map(() => [0.0, 0.0]);
|
return data.map(() => [0.0, 0.0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const dummyEmbeddingConfig: EmbeddingFunctionConfig = {
|
||||||
|
sourceColumn: "string",
|
||||||
|
function: new DummyEmbedding(),
|
||||||
|
};
|
||||||
|
|
||||||
|
const dummyEmbeddingConfigWithNoDimension: EmbeddingFunctionConfig = {
|
||||||
|
sourceColumn: "string",
|
||||||
|
function: new DummyEmbeddingWithNoDimension(),
|
||||||
|
};
|
||||||
|
|
||||||
describe("convertToTable", function () {
|
describe("convertToTable", function () {
|
||||||
it("will infer data types correctly", async function () {
|
it("will infer data types correctly", async function () {
|
||||||
@@ -331,7 +360,7 @@ describe("convertToTable", function () {
|
|||||||
|
|
||||||
it("will apply embeddings", async function () {
|
it("will apply embeddings", async function () {
|
||||||
const records = sampleRecords();
|
const records = sampleRecords();
|
||||||
const table = await convertToTable(records, new DummyEmbedding());
|
const table = await convertToTable(records, dummyEmbeddingConfig);
|
||||||
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
|
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
|
||||||
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
|
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
|
||||||
new Float16().toString(),
|
new Float16().toString(),
|
||||||
@@ -340,7 +369,7 @@ describe("convertToTable", function () {
|
|||||||
|
|
||||||
it("will fail if missing the embedding source column", async function () {
|
it("will fail if missing the embedding source column", async function () {
|
||||||
await expect(
|
await expect(
|
||||||
convertToTable([{ id: 1 }], new DummyEmbedding()),
|
convertToTable([{ id: 1 }], dummyEmbeddingConfig),
|
||||||
).rejects.toThrow("'string' was not present");
|
).rejects.toThrow("'string' was not present");
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -351,7 +380,7 @@ describe("convertToTable", function () {
|
|||||||
const table = makeEmptyTable(schema);
|
const table = makeEmptyTable(schema);
|
||||||
|
|
||||||
// If the embedding specifies the dimension we are fine
|
// If the embedding specifies the dimension we are fine
|
||||||
await fromTableToBuffer(table, new DummyEmbedding());
|
await fromTableToBuffer(table, dummyEmbeddingConfig);
|
||||||
|
|
||||||
// We can also supply a schema and should be ok
|
// We can also supply a schema and should be ok
|
||||||
const schemaWithEmbedding = new Schema([
|
const schemaWithEmbedding = new Schema([
|
||||||
@@ -364,13 +393,13 @@ describe("convertToTable", function () {
|
|||||||
]);
|
]);
|
||||||
await fromTableToBuffer(
|
await fromTableToBuffer(
|
||||||
table,
|
table,
|
||||||
new DummyEmbeddingWithNoDimension(),
|
dummyEmbeddingConfigWithNoDimension,
|
||||||
schemaWithEmbedding,
|
schemaWithEmbedding,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Otherwise we will get an error
|
// Otherwise we will get an error
|
||||||
await expect(
|
await expect(
|
||||||
fromTableToBuffer(table, new DummyEmbeddingWithNoDimension()),
|
fromTableToBuffer(table, dummyEmbeddingConfigWithNoDimension),
|
||||||
).rejects.toThrow("does not specify `embeddingDimension`");
|
).rejects.toThrow("does not specify `embeddingDimension`");
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -383,7 +412,7 @@ describe("convertToTable", function () {
|
|||||||
false,
|
false,
|
||||||
),
|
),
|
||||||
]);
|
]);
|
||||||
const table = await convertToTable([], new DummyEmbedding(), { schema });
|
const table = await convertToTable([], dummyEmbeddingConfig, { schema });
|
||||||
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
|
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
|
||||||
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
|
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
|
||||||
new Float16().toString(),
|
new Float16().toString(),
|
||||||
@@ -393,16 +422,17 @@ describe("convertToTable", function () {
|
|||||||
it("will complain if embeddings present but schema missing embedding column", async function () {
|
it("will complain if embeddings present but schema missing embedding column", async function () {
|
||||||
const schema = new Schema([new Field("string", new Utf8(), false)]);
|
const schema = new Schema([new Field("string", new Utf8(), false)]);
|
||||||
await expect(
|
await expect(
|
||||||
convertToTable([], new DummyEmbedding(), { schema }),
|
convertToTable([], dummyEmbeddingConfig, { schema }),
|
||||||
).rejects.toThrow("column vector was missing");
|
).rejects.toThrow("column vector was missing");
|
||||||
});
|
});
|
||||||
|
|
||||||
it("will provide a nice error if run twice", async function () {
|
it("will provide a nice error if run twice", async function () {
|
||||||
const records = sampleRecords();
|
const records = sampleRecords();
|
||||||
const table = await convertToTable(records, new DummyEmbedding());
|
const table = await convertToTable(records, dummyEmbeddingConfig);
|
||||||
|
|
||||||
// fromTableToBuffer will try and apply the embeddings again
|
// fromTableToBuffer will try and apply the embeddings again
|
||||||
await expect(
|
await expect(
|
||||||
fromTableToBuffer(table, new DummyEmbedding()),
|
fromTableToBuffer(table, dummyEmbeddingConfig),
|
||||||
).rejects.toThrow("already existed");
|
).rejects.toThrow("already existed");
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import * as tmp from "tmp";
|
import * as tmp from "tmp";
|
||||||
|
|
||||||
import { Connection, connect } from "../lancedb";
|
import { Connection, connect } from "../lancedb";
|
||||||
|
|
||||||
describe("when connecting", () => {
|
describe("when connecting", () => {
|
||||||
|
|||||||
169
nodejs/__test__/registry.test.ts
Normal file
169
nodejs/__test__/registry.test.ts
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
// Copyright 2024 Lance Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
import * as arrow from "apache-arrow";
|
||||||
|
import * as arrowOld from "apache-arrow-old";
|
||||||
|
|
||||||
|
import * as tmp from "tmp";
|
||||||
|
|
||||||
|
import { connect } from "../lancedb";
|
||||||
|
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
|
||||||
|
import { getRegistry, register } from "../lancedb/embedding/registry";
|
||||||
|
|
||||||
|
// Runs the same checks against both the `apache-arrow` and the
// `apache-arrow-old` modules imported by this file.
describe.each([arrow, arrowOld])("LanceSchema", (arrow) => {
  test("should preserve input order", async () => {
    // Keys are declared in the order the resulting schema fields must follow.
    const fieldTypes = {
      id: new arrow.Int32(),
      text: new arrow.Utf8(),
      vector: new arrow.Float32(),
    };

    const fieldNames = LanceSchema(fieldTypes).fields.map(
      (field) => field.name,
    );

    expect(fieldNames).toEqual(["id", "text", "vector"]);
  });
});
|
||||||
|
|
||||||
|
// Covers the embedding-function registry: registering a class under an alias,
// rejecting a duplicate alias, and the embedding metadata written into a
// schema built with LanceSchema.
describe("Registry", () => {
  let tmpDir: tmp.DirResult;

  beforeEach(() => {
    tmpDir = tmp.dirSync({ unsafeCleanup: true });
  });

  afterEach(() => {
    tmpDir.removeCallback();
    // Reset the shared registry so an alias registered by one test is not
    // still present when the next test registers it again.
    getRegistry().reset();
  });

  it("should register a new item to the registry", async () => {
    // The decorator registers the class under the "mock-embedding" alias.
    @register("mock-embedding")
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
      constructor() {
        super();
      }

      toJSON(): object {
        return { someText: "hello" };
      }

      ndims() {
        return 3;
      }

      embeddingDataType(): arrow.Float {
        return new arrow.Float32();
      }

      // Every input row maps to the same fixed 3-element vector.
      async computeSourceEmbeddings(data: string[]) {
        return data.map(() => [1, 2, 3]);
      }
    }

    // Look the class up by alias rather than instantiating it directly.
    const embedder = getRegistry()
      .get<MockEmbeddingFunction>("mock-embedding")!
      .create();

    const schema = LanceSchema({
      id: new arrow.Int32(),
      text: embedder.sourceField(new arrow.Utf8()),
      vector: embedder.vectorField(),
    });

    const db = await connect(tmpDir.name);
    const table = await db.createTable(
      "test",
      [
        { id: 1, text: "hello" },
        { id: 2, text: "world" },
      ],
      { schema },
    );

    const result = await table.query().toArrow();
    // Spread Arrow vectors into plain arrays so they compare by value.
    const vectors = result
      .getChild("vector")
      ?.toArray()
      .map((row: unknown) => (row instanceof arrow.Vector ? [...row] : row));

    expect(vectors).toEqual([
      [1, 2, 3],
      [1, 2, 3],
    ]);
  });

  test("should error if registering with the same name", async () => {
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
      constructor() {
        super();
      }

      toJSON(): object {
        return { someText: "hello" };
      }

      ndims() {
        return 3;
      }

      embeddingDataType(): arrow.Float {
        return new arrow.Float32();
      }

      async computeSourceEmbeddings(data: string[]) {
        return data.map(() => [1, 2, 3]);
      }
    }

    // The first registration succeeds; repeating the alias must throw.
    const registerMock = () =>
      register("mock-embedding")(MockEmbeddingFunction);
    registerMock();

    expect(registerMock).toThrow(
      'Embedding function with alias "mock-embedding" already exists',
    );
  });

  test("schema should contain correct metadata", async () => {
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
      constructor() {
        super();
      }

      toJSON(): object {
        return { someText: "hello" };
      }

      ndims() {
        return 3;
      }

      embeddingDataType(): arrow.Float {
        return new arrow.Float32();
      }

      async computeSourceEmbeddings(data: string[]) {
        return data.map(() => [1, 2, 3]);
      }
    }

    const embedder = new MockEmbeddingFunction();
    const schema = LanceSchema({
      id: new arrow.Int32(),
      text: embedder.sourceField(new arrow.Utf8()),
      vector: embedder.vectorField(),
    });

    // The schema is expected to carry one "embedding_functions" entry that
    // links the source column, the vector column and the function's toJSON().
    const embeddingFunctionsJson = JSON.stringify([
      {
        sourceColumn: "text",
        vectorColumn: "vector",
        name: "MockEmbeddingFunction",
        model: { someText: "hello" },
      },
    ]);
    const expectedMetadata = new Map<string, string>([
      ["embedding_functions", embeddingFunctionsJson],
    ]);

    expect(schema.metadata).toEqual(expectedMetadata);
  });
});
|
||||||
@@ -16,23 +16,34 @@ import * as fs from "fs";
|
|||||||
import * as path from "path";
|
import * as path from "path";
|
||||||
import * as tmp from "tmp";
|
import * as tmp from "tmp";
|
||||||
|
|
||||||
|
import * as arrow from "apache-arrow";
|
||||||
|
import * as arrowOld from "apache-arrow-old";
|
||||||
|
|
||||||
|
import { Table, connect } from "../lancedb";
|
||||||
import {
|
import {
|
||||||
Field,
|
Field,
|
||||||
FixedSizeList,
|
FixedSizeList,
|
||||||
|
Float,
|
||||||
Float32,
|
Float32,
|
||||||
Float64,
|
Float64,
|
||||||
Int32,
|
Int32,
|
||||||
Int64,
|
Int64,
|
||||||
Schema,
|
Schema,
|
||||||
} from "apache-arrow";
|
Utf8,
|
||||||
import { Table, connect } from "../lancedb";
|
makeArrowTable,
|
||||||
import { makeArrowTable } from "../lancedb/arrow";
|
} from "../lancedb/arrow";
|
||||||
|
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
|
||||||
|
import { getRegistry, register } from "../lancedb/embedding/registry";
|
||||||
import { Index } from "../lancedb/indices";
|
import { Index } from "../lancedb/indices";
|
||||||
|
|
||||||
describe("Given a table", () => {
|
// biome-ignore lint/suspicious/noExplicitAny: the suite runs against two apache-arrow module versions
|
||||||
|
describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
|
||||||
let tmpDir: tmp.DirResult;
|
let tmpDir: tmp.DirResult;
|
||||||
let table: Table;
|
let table: Table;
|
||||||
const schema = new Schema([new Field("id", new Float64(), true)]);
|
|
||||||
|
const schema = new arrow.Schema([
|
||||||
|
new arrow.Field("id", new arrow.Float64(), true),
|
||||||
|
]);
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||||
const conn = await connect(tmpDir.name);
|
const conn = await connect(tmpDir.name);
|
||||||
@@ -420,6 +431,161 @@ describe("when dealing with versioning", () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// End-to-end checks that a table applies its configured embedding function
// when rows are written, and that a table whose embedding function is no
// longer registered refuses new rows.
describe("embedding functions", () => {
  let tmpDir: tmp.DirResult;

  beforeEach(() => {
    tmpDir = tmp.dirSync({ unsafeCleanup: true });
  });

  afterEach(() => {
    tmpDir.removeCallback();
    // Tests below register classes with the shared registry; reset it so a
    // registration made by one test cannot affect the next.
    getRegistry().reset();
  });

  it("should be able to create a table with an embedding function", async () => {
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
      toJSON(): object {
        return {};
      }

      ndims() {
        return 3;
      }

      embeddingDataType(): Float {
        return new Float32();
      }

      async computeQueryEmbeddings(_data: string) {
        return [1, 2, 3];
      }

      // One fixed 3-element vector per input row.
      async computeSourceEmbeddings(data: string[]) {
        return data.map(() => [1, 2, 3]);
      }
    }

    const func = new MockEmbeddingFunction();
    const db = await connect(tmpDir.name);
    const table = await db.createTable(
      "test",
      [
        { id: 1, text: "hello" },
        { id: 2, text: "world" },
      ],
      {
        embeddingFunction: {
          function: func,
          sourceColumn: "text",
        },
      },
    );

    // biome-ignore lint/suspicious/noExplicitAny: test
    const arr = (await table.query().toArray()) as any;
    expect(arr[0].vector).toBeDefined();

    // we round trip through JSON to make sure the vector properly gets converted to an array
    // otherwise it'll be a TypedArray or Vector
    const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
    expect(vector0).toEqual([1, 2, 3]);
  });

  it("should be able to create an empty table with an embedding function", async () => {
    @register()
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
      toJSON(): object {
        return {};
      }

      ndims() {
        return 3;
      }

      embeddingDataType(): Float {
        return new Float32();
      }

      async computeQueryEmbeddings(_data: string) {
        return [1, 2, 3];
      }

      // One fixed 3-element vector per input row.
      async computeSourceEmbeddings(data: string[]) {
        return data.map(() => [1, 2, 3]);
      }
    }

    // The vector column is declared up front; its list size matches ndims().
    const schema = new Schema([
      new Field("text", new Utf8(), true),
      new Field(
        "vector",
        new FixedSizeList(3, new Field("item", new Float32(), true)),
        true,
      ),
    ]);

    const func = new MockEmbeddingFunction();
    const db = await connect(tmpDir.name);
    const table = await db.createEmptyTable("test", schema, {
      embeddingFunction: {
        function: func,
        sourceColumn: "text",
      },
    });

    // The embedding configuration must be recorded in the schema metadata.
    const outSchema = await table.schema();
    expect(outSchema.metadata.get("embedding_functions")).toBeDefined();

    // Only the source column is supplied; the vector comes from the function.
    await table.add([{ text: "hello world" }]);

    // biome-ignore lint/suspicious/noExplicitAny: test
    const arr = (await table.query().toArray()) as any;
    expect(arr[0].vector).toBeDefined();

    // we round trip through JSON to make sure the vector properly gets converted to an array
    // otherwise it'll be a TypedArray or Vector
    const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
    expect(vector0).toEqual([1, 2, 3]);
  });

  it("should error when appending to a table with an unregistered embedding function", async () => {
    @register("mock")
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
      toJSON(): object {
        return {};
      }

      ndims() {
        return 3;
      }

      embeddingDataType(): Float {
        return new Float32();
      }

      async computeQueryEmbeddings(_data: string) {
        return [1, 2, 3];
      }

      // One fixed 3-element vector per input row.
      async computeSourceEmbeddings(data: string[]) {
        return data.map(() => [1, 2, 3]);
      }
    }

    const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();

    const schema = LanceSchema({
      id: new arrow.Float64(),
      text: func.sourceField(new Utf8()),
      vector: func.vectorField(),
    });

    const db = await connect(tmpDir.name);
    await db.createTable(
      "test",
      [
        { id: 1, text: "hello" },
        { id: 2, text: "world" },
      ],
      {
        schema,
      },
    );

    // Forget the "mock" registration, then reopen the table through a fresh
    // connection so the function has to be resolved from the registry again.
    getRegistry().reset();
    const db2 = await connect(tmpDir.name);

    const tbl = await db2.openTable("test");

    // `.rejects` assertions are asynchronous: without `await` the test would
    // finish before the expectation is evaluated and could never fail.
    await expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
      `Function "mock" not found in registry`,
    );
  });
});
|
||||||
|
|
||||||
describe("when optimizing a dataset", () => {
|
describe("when optimizing a dataset", () => {
|
||||||
let tmpDir: tmp.DirResult;
|
let tmpDir: tmp.DirResult;
|
||||||
let table: Table;
|
let table: Table;
|
||||||
|
|||||||
@@ -48,7 +48,7 @@
|
|||||||
"noUnsafeFinally": "error",
|
"noUnsafeFinally": "error",
|
||||||
"noUnsafeOptionalChaining": "error",
|
"noUnsafeOptionalChaining": "error",
|
||||||
"noUnusedLabels": "error",
|
"noUnusedLabels": "error",
|
||||||
"noUnusedVariables": "error",
|
"noUnusedVariables": "warn",
|
||||||
"useIsNan": "error",
|
"useIsNan": "error",
|
||||||
"useValidForDirection": "error",
|
"useValidForDirection": "error",
|
||||||
"useYield": "error"
|
"useYield": "error"
|
||||||
@@ -101,7 +101,13 @@
|
|||||||
},
|
},
|
||||||
"overrides": [
|
"overrides": [
|
||||||
{
|
{
|
||||||
"include": ["**/*.ts", "**/*.tsx", "**/*.mts", "**/*.cts"],
|
"include": [
|
||||||
|
"**/*.ts",
|
||||||
|
"**/*.tsx",
|
||||||
|
"**/*.mts",
|
||||||
|
"**/*.cts",
|
||||||
|
"__test__/*.test.ts"
|
||||||
|
],
|
||||||
"linter": {
|
"linter": {
|
||||||
"rules": {
|
"rules": {
|
||||||
"correctness": {
|
"correctness": {
|
||||||
|
|||||||
@@ -17,10 +17,14 @@ import {
|
|||||||
Binary,
|
Binary,
|
||||||
DataType,
|
DataType,
|
||||||
Field,
|
Field,
|
||||||
|
FixedSizeBinary,
|
||||||
FixedSizeList,
|
FixedSizeList,
|
||||||
type Float,
|
Float,
|
||||||
Float32,
|
Float32,
|
||||||
|
Int,
|
||||||
|
LargeBinary,
|
||||||
List,
|
List,
|
||||||
|
Null,
|
||||||
RecordBatch,
|
RecordBatch,
|
||||||
RecordBatchFileWriter,
|
RecordBatchFileWriter,
|
||||||
RecordBatchStreamWriter,
|
RecordBatchStreamWriter,
|
||||||
@@ -34,7 +38,99 @@ import {
|
|||||||
vectorFromArray,
|
vectorFromArray,
|
||||||
} from "apache-arrow";
|
} from "apache-arrow";
|
||||||
import { type EmbeddingFunction } from "./embedding/embedding_function";
|
import { type EmbeddingFunction } from "./embedding/embedding_function";
|
||||||
import { sanitizeSchema } from "./sanitize";
|
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
||||||
|
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
|
||||||
|
export * from "apache-arrow";
|
||||||
|
|
||||||
|
export function isArrowTable(value: object): value is ArrowTable {
|
||||||
|
if (value instanceof ArrowTable) return true;
|
||||||
|
return "schema" in value && "batches" in value;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isDataType(value: unknown): value is DataType {
|
||||||
|
return (
|
||||||
|
value instanceof DataType ||
|
||||||
|
DataType.isNull(value) ||
|
||||||
|
DataType.isInt(value) ||
|
||||||
|
DataType.isFloat(value) ||
|
||||||
|
DataType.isBinary(value) ||
|
||||||
|
DataType.isLargeBinary(value) ||
|
||||||
|
DataType.isUtf8(value) ||
|
||||||
|
DataType.isLargeUtf8(value) ||
|
||||||
|
DataType.isBool(value) ||
|
||||||
|
DataType.isDecimal(value) ||
|
||||||
|
DataType.isDate(value) ||
|
||||||
|
DataType.isTime(value) ||
|
||||||
|
DataType.isTimestamp(value) ||
|
||||||
|
DataType.isInterval(value) ||
|
||||||
|
DataType.isDuration(value) ||
|
||||||
|
DataType.isList(value) ||
|
||||||
|
DataType.isStruct(value) ||
|
||||||
|
DataType.isUnion(value) ||
|
||||||
|
DataType.isFixedSizeBinary(value) ||
|
||||||
|
DataType.isFixedSizeList(value) ||
|
||||||
|
DataType.isMap(value) ||
|
||||||
|
DataType.isDictionary(value)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
export function isNull(value: unknown): value is Null {
|
||||||
|
return value instanceof Null || DataType.isNull(value);
|
||||||
|
}
|
||||||
|
export function isInt(value: unknown): value is Int {
|
||||||
|
return value instanceof Int || DataType.isInt(value);
|
||||||
|
}
|
||||||
|
export function isFloat(value: unknown): value is Float {
|
||||||
|
return value instanceof Float || DataType.isFloat(value);
|
||||||
|
}
|
||||||
|
export function isBinary(value: unknown): value is Binary {
|
||||||
|
return value instanceof Binary || DataType.isBinary(value);
|
||||||
|
}
|
||||||
|
export function isLargeBinary(value: unknown): value is LargeBinary {
|
||||||
|
return value instanceof LargeBinary || DataType.isLargeBinary(value);
|
||||||
|
}
|
||||||
|
export function isUtf8(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isUtf8(value);
|
||||||
|
}
|
||||||
|
export function isLargeUtf8(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isLargeUtf8(value);
|
||||||
|
}
|
||||||
|
export function isBool(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isBool(value);
|
||||||
|
}
|
||||||
|
export function isDecimal(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isDecimal(value);
|
||||||
|
}
|
||||||
|
export function isDate(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isDate(value);
|
||||||
|
}
|
||||||
|
export function isTime(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isTime(value);
|
||||||
|
}
|
||||||
|
export function isTimestamp(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isTimestamp(value);
|
||||||
|
}
|
||||||
|
export function isInterval(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isInterval(value);
|
||||||
|
}
|
||||||
|
export function isDuration(value: unknown): value is Utf8 {
|
||||||
|
return value instanceof Utf8 || DataType.isDuration(value);
|
||||||
|
}
|
||||||
|
export function isList(value: unknown): value is List {
|
||||||
|
return value instanceof List || DataType.isList(value);
|
||||||
|
}
|
||||||
|
export function isStruct(value: unknown): value is Struct {
|
||||||
|
return value instanceof Struct || DataType.isStruct(value);
|
||||||
|
}
|
||||||
|
export function isUnion(value: unknown): value is Struct {
|
||||||
|
return value instanceof Struct || DataType.isUnion(value);
|
||||||
|
}
|
||||||
|
export function isFixedSizeBinary(value: unknown): value is FixedSizeBinary {
|
||||||
|
return value instanceof FixedSizeBinary || DataType.isFixedSizeBinary(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isFixedSizeList(value: unknown): value is FixedSizeList {
|
||||||
|
return value instanceof FixedSizeList || DataType.isFixedSizeList(value);
|
||||||
|
}
|
||||||
|
|
||||||
/** Data type accepted by NodeJS SDK */
|
/** Data type accepted by NodeJS SDK */
|
||||||
export type Data = Record<string, unknown>[] | ArrowTable;
|
export type Data = Record<string, unknown>[] | ArrowTable;
|
||||||
@@ -198,6 +294,7 @@ export class MakeArrowTableOptions {
|
|||||||
export function makeArrowTable(
|
export function makeArrowTable(
|
||||||
data: Array<Record<string, unknown>>,
|
data: Array<Record<string, unknown>>,
|
||||||
options?: Partial<MakeArrowTableOptions>,
|
options?: Partial<MakeArrowTableOptions>,
|
||||||
|
metadata?: Map<string, string>,
|
||||||
): ArrowTable {
|
): ArrowTable {
|
||||||
if (
|
if (
|
||||||
data.length === 0 &&
|
data.length === 0 &&
|
||||||
@@ -290,20 +387,41 @@ export function makeArrowTable(
|
|||||||
// `new ArrowTable(schema, batches)` which does not do any schema inference
|
// `new ArrowTable(schema, batches)` which does not do any schema inference
|
||||||
const firstTable = new ArrowTable(columns);
|
const firstTable = new ArrowTable(columns);
|
||||||
const batchesFixed = firstTable.batches.map(
|
const batchesFixed = firstTable.batches.map(
|
||||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
|
||||||
(batch) => new RecordBatch(opt.schema!, batch.data),
|
(batch) => new RecordBatch(opt.schema!, batch.data),
|
||||||
);
|
);
|
||||||
return new ArrowTable(opt.schema, batchesFixed);
|
let schema: Schema;
|
||||||
} else {
|
if (metadata !== undefined) {
|
||||||
return new ArrowTable(columns);
|
let schemaMetadata = opt.schema.metadata;
|
||||||
|
if (schemaMetadata.size === 0) {
|
||||||
|
schemaMetadata = metadata;
|
||||||
|
} else {
|
||||||
|
for (const [key, entry] of schemaMetadata.entries()) {
|
||||||
|
schemaMetadata.set(key, entry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = new Schema(opt.schema.fields, schemaMetadata);
|
||||||
|
} else {
|
||||||
|
schema = opt.schema;
|
||||||
|
}
|
||||||
|
return new ArrowTable(schema, batchesFixed);
|
||||||
}
|
}
|
||||||
|
const tbl = new ArrowTable(columns);
|
||||||
|
if (metadata !== undefined) {
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
(<any>tbl.schema).metadata = metadata;
|
||||||
|
}
|
||||||
|
return tbl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create an empty Arrow table with the provided schema
|
* Create an empty Arrow table with the provided schema
|
||||||
*/
|
*/
|
||||||
export function makeEmptyTable(schema: Schema): ArrowTable {
|
export function makeEmptyTable(
|
||||||
return makeArrowTable([], { schema });
|
schema: Schema,
|
||||||
|
metadata?: Map<string, string>,
|
||||||
|
): ArrowTable {
|
||||||
|
return makeArrowTable([], { schema }, metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -375,13 +493,75 @@ function makeVector(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Helper function to apply embeddings from metadata to an input table */
|
||||||
|
async function applyEmbeddingsFromMetadata(
|
||||||
|
table: ArrowTable,
|
||||||
|
schema: Schema,
|
||||||
|
): Promise<ArrowTable> {
|
||||||
|
const registry = getRegistry();
|
||||||
|
const functions = registry.parseFunctions(schema.metadata);
|
||||||
|
|
||||||
|
const columns = Object.fromEntries(
|
||||||
|
table.schema.fields.map((field) => [
|
||||||
|
field.name,
|
||||||
|
table.getChild(field.name)!,
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
|
||||||
|
for (const functionEntry of functions.values()) {
|
||||||
|
const sourceColumn = columns[functionEntry.sourceColumn];
|
||||||
|
const destColumn = functionEntry.vectorColumn ?? "vector";
|
||||||
|
if (sourceColumn === undefined) {
|
||||||
|
throw new Error(
|
||||||
|
`Cannot apply embedding function because the source column '${functionEntry.sourceColumn}' was not present in the data`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (columns[destColumn] !== undefined) {
|
||||||
|
throw new Error(
|
||||||
|
`Attempt to apply embeddings to table failed because column ${destColumn} already existed`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (table.batches.length > 1) {
|
||||||
|
throw new Error(
|
||||||
|
"Internal error: `makeArrowTable` unexpectedly created a table with more than one batch",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const values = sourceColumn.toArray();
|
||||||
|
|
||||||
|
const vectors =
|
||||||
|
await functionEntry.function.computeSourceEmbeddings(values);
|
||||||
|
if (vectors.length !== values.length) {
|
||||||
|
throw new Error(
|
||||||
|
"Embedding function did not return an embedding for each input element",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let destType: DataType;
|
||||||
|
const dtype = schema.fields.find((f) => f.name === destColumn)!.type;
|
||||||
|
if (isFixedSizeList(dtype)) {
|
||||||
|
destType = sanitizeType(dtype);
|
||||||
|
} else {
|
||||||
|
throw new Error(
|
||||||
|
"Expected FixedSizeList as datatype for vector field, instead got: " +
|
||||||
|
dtype,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const vector = makeVector(vectors, destType);
|
||||||
|
columns[destColumn] = vector;
|
||||||
|
}
|
||||||
|
const newTable = new ArrowTable(columns);
|
||||||
|
return alignTable(newTable, schema);
|
||||||
|
}
|
||||||
|
|
||||||
/** Helper function to apply embeddings to an input table */
|
/** Helper function to apply embeddings to an input table */
|
||||||
async function applyEmbeddings<T>(
|
async function applyEmbeddings<T>(
|
||||||
table: ArrowTable,
|
table: ArrowTable,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: Schema,
|
||||||
): Promise<ArrowTable> {
|
): Promise<ArrowTable> {
|
||||||
if (embeddings == null) {
|
if (schema?.metadata.has("embedding_functions")) {
|
||||||
|
return applyEmbeddingsFromMetadata(table, schema!);
|
||||||
|
} else if (embeddings == null || embeddings === undefined) {
|
||||||
return table;
|
return table;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,8 +579,9 @@ async function applyEmbeddings<T>(
|
|||||||
const newColumns = Object.fromEntries(colEntries);
|
const newColumns = Object.fromEntries(colEntries);
|
||||||
|
|
||||||
const sourceColumn = newColumns[embeddings.sourceColumn];
|
const sourceColumn = newColumns[embeddings.sourceColumn];
|
||||||
const destColumn = embeddings.destColumn ?? "vector";
|
const destColumn = embeddings.vectorColumn ?? "vector";
|
||||||
const innerDestType = embeddings.embeddingDataType ?? new Float32();
|
const innerDestType =
|
||||||
|
embeddings.function.embeddingDataType() ?? new Float32();
|
||||||
if (sourceColumn === undefined) {
|
if (sourceColumn === undefined) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`,
|
`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`,
|
||||||
@@ -414,11 +595,9 @@ async function applyEmbeddings<T>(
|
|||||||
// if we call convertToTable with 0 records and a schema that includes the embedding
|
// if we call convertToTable with 0 records and a schema that includes the embedding
|
||||||
return table;
|
return table;
|
||||||
}
|
}
|
||||||
if (embeddings.embeddingDimension !== undefined) {
|
const dimensions = embeddings.function.ndims();
|
||||||
const destType = newVectorType(
|
if (dimensions !== undefined) {
|
||||||
embeddings.embeddingDimension,
|
const destType = newVectorType(dimensions, innerDestType);
|
||||||
innerDestType,
|
|
||||||
);
|
|
||||||
newColumns[destColumn] = makeVector([], destType);
|
newColumns[destColumn] = makeVector([], destType);
|
||||||
} else if (schema != null) {
|
} else if (schema != null) {
|
||||||
const destField = schema.fields.find((f) => f.name === destColumn);
|
const destField = schema.fields.find((f) => f.name === destColumn);
|
||||||
@@ -446,7 +625,9 @@ async function applyEmbeddings<T>(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
const values = sourceColumn.toArray();
|
const values = sourceColumn.toArray();
|
||||||
const vectors = await embeddings.embed(values as T[]);
|
const vectors = await embeddings.function.computeSourceEmbeddings(
|
||||||
|
values as T[],
|
||||||
|
);
|
||||||
if (vectors.length !== values.length) {
|
if (vectors.length !== values.length) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
"Embedding function did not return an embedding for each input element",
|
"Embedding function did not return an embedding for each input element",
|
||||||
@@ -486,9 +667,9 @@ async function applyEmbeddings<T>(
|
|||||||
* embedding columns. If no schema is provded then embedding columns will
|
* embedding columns. If no schema is provded then embedding columns will
|
||||||
* be placed at the end of the table, after all of the input columns.
|
* be placed at the end of the table, after all of the input columns.
|
||||||
*/
|
*/
|
||||||
export async function convertToTable<T>(
|
export async function convertToTable(
|
||||||
data: Array<Record<string, unknown>>,
|
data: Array<Record<string, unknown>>,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
makeTableOptions?: Partial<MakeArrowTableOptions>,
|
makeTableOptions?: Partial<MakeArrowTableOptions>,
|
||||||
): Promise<ArrowTable> {
|
): Promise<ArrowTable> {
|
||||||
const table = makeArrowTable(data, makeTableOptions);
|
const table = makeArrowTable(data, makeTableOptions);
|
||||||
@@ -496,13 +677,13 @@ export async function convertToTable<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Creates the Arrow Type for a Vector column with dimension `dim` */
|
/** Creates the Arrow Type for a Vector column with dimension `dim` */
|
||||||
function newVectorType<T extends Float>(
|
export function newVectorType<T extends Float>(
|
||||||
dim: number,
|
dim: number,
|
||||||
innerType: T,
|
innerType: T,
|
||||||
): FixedSizeList<T> {
|
): FixedSizeList<T> {
|
||||||
// in Lance we always default to have the elements nullable, so we need to set it to true
|
// in Lance we always default to have the elements nullable, so we need to set it to true
|
||||||
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
|
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
|
||||||
const children = new Field<T>("item", innerType, true);
|
const children = new Field("item", <T>sanitizeType(innerType), true);
|
||||||
return new FixedSizeList(dim, children);
|
return new FixedSizeList(dim, children);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -513,9 +694,9 @@ function newVectorType<T extends Float>(
|
|||||||
*
|
*
|
||||||
* `schema` is required if data is empty
|
* `schema` is required if data is empty
|
||||||
*/
|
*/
|
||||||
export async function fromRecordsToBuffer<T>(
|
export async function fromRecordsToBuffer(
|
||||||
data: Array<Record<string, unknown>>,
|
data: Array<Record<string, unknown>>,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: Schema,
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
if (schema !== undefined && schema !== null) {
|
if (schema !== undefined && schema !== null) {
|
||||||
@@ -533,9 +714,9 @@ export async function fromRecordsToBuffer<T>(
|
|||||||
*
|
*
|
||||||
* `schema` is required if data is empty
|
* `schema` is required if data is empty
|
||||||
*/
|
*/
|
||||||
export async function fromRecordsToStreamBuffer<T>(
|
export async function fromRecordsToStreamBuffer(
|
||||||
data: Array<Record<string, unknown>>,
|
data: Array<Record<string, unknown>>,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: Schema,
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
if (schema !== undefined && schema !== null) {
|
if (schema !== undefined && schema !== null) {
|
||||||
@@ -554,9 +735,9 @@ export async function fromRecordsToStreamBuffer<T>(
|
|||||||
*
|
*
|
||||||
* `schema` is required if the table is empty
|
* `schema` is required if the table is empty
|
||||||
*/
|
*/
|
||||||
export async function fromTableToBuffer<T>(
|
export async function fromTableToBuffer(
|
||||||
table: ArrowTable,
|
table: ArrowTable,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: Schema,
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
if (schema !== undefined && schema !== null) {
|
if (schema !== undefined && schema !== null) {
|
||||||
@@ -575,19 +756,19 @@ export async function fromTableToBuffer<T>(
|
|||||||
*
|
*
|
||||||
* `schema` is required if the table is empty
|
* `schema` is required if the table is empty
|
||||||
*/
|
*/
|
||||||
export async function fromDataToBuffer<T>(
|
export async function fromDataToBuffer(
|
||||||
data: Data,
|
data: Data,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: Schema,
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
if (schema !== undefined && schema !== null) {
|
if (schema !== undefined && schema !== null) {
|
||||||
schema = sanitizeSchema(schema);
|
schema = sanitizeSchema(schema);
|
||||||
}
|
}
|
||||||
if (data instanceof ArrowTable) {
|
if (isArrowTable(data)) {
|
||||||
return fromTableToBuffer(data, embeddings, schema);
|
return fromTableToBuffer(data, embeddings, schema);
|
||||||
} else {
|
} else {
|
||||||
const table = await convertToTable(data);
|
const table = await convertToTable(data, embeddings, { schema });
|
||||||
return fromTableToBuffer(table, embeddings, schema);
|
return fromTableToBuffer(table);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -599,9 +780,9 @@ export async function fromDataToBuffer<T>(
|
|||||||
*
|
*
|
||||||
* `schema` is required if the table is empty
|
* `schema` is required if the table is empty
|
||||||
*/
|
*/
|
||||||
export async function fromTableToStreamBuffer<T>(
|
export async function fromTableToStreamBuffer(
|
||||||
table: ArrowTable,
|
table: ArrowTable,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: Schema,
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
|
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
|
||||||
@@ -664,10 +845,25 @@ function validateSchemaEmbeddings(
|
|||||||
// if it does not, we add it to the list of missing embedding fields
|
// if it does not, we add it to the list of missing embedding fields
|
||||||
// Finally, we check if those missing embedding fields are `this._embeddings`
|
// Finally, we check if those missing embedding fields are `this._embeddings`
|
||||||
// if they are not, we throw an error
|
// if they are not, we throw an error
|
||||||
for (const field of schema.fields) {
|
for (let field of schema.fields) {
|
||||||
if (field.type instanceof FixedSizeList) {
|
if (isFixedSizeList(field.type)) {
|
||||||
|
field = sanitizeField(field);
|
||||||
|
|
||||||
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
|
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
|
||||||
missingEmbeddingFields.push(field);
|
if (schema.metadata.has("embedding_functions")) {
|
||||||
|
const embeddings = JSON.parse(
|
||||||
|
schema.metadata.get("embedding_functions")!,
|
||||||
|
);
|
||||||
|
if (
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f`
|
||||||
|
embeddings.find((f: any) => f["vectorColumn"] === field.name) ===
|
||||||
|
undefined
|
||||||
|
) {
|
||||||
|
missingEmbeddingFields.push(field);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
missingEmbeddingFields.push(field);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
fields.push(field);
|
fields.push(field);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,14 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { Table as ArrowTable, Schema } from "apache-arrow";
|
import { Table as ArrowTable, Schema } from "./arrow";
|
||||||
import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow";
|
import {
|
||||||
|
fromTableToBuffer,
|
||||||
|
isArrowTable,
|
||||||
|
makeArrowTable,
|
||||||
|
makeEmptyTable,
|
||||||
|
} from "./arrow";
|
||||||
|
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
||||||
import { ConnectionOptions, Connection as LanceDbConnection } from "./native";
|
import { ConnectionOptions, Connection as LanceDbConnection } from "./native";
|
||||||
import { Table } from "./table";
|
import { Table } from "./table";
|
||||||
|
|
||||||
@@ -65,6 +71,8 @@ export interface CreateTableOptions {
|
|||||||
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
|
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
|
||||||
*/
|
*/
|
||||||
storageOptions?: Record<string, string>;
|
storageOptions?: Record<string, string>;
|
||||||
|
schema?: Schema;
|
||||||
|
embeddingFunction?: EmbeddingFunctionConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface OpenTableOptions {
|
export interface OpenTableOptions {
|
||||||
@@ -174,6 +182,7 @@ export class Connection {
|
|||||||
cleanseStorageOptions(options?.storageOptions),
|
cleanseStorageOptions(options?.storageOptions),
|
||||||
options?.indexCacheSize,
|
options?.indexCacheSize,
|
||||||
);
|
);
|
||||||
|
|
||||||
return new Table(innerTable);
|
return new Table(innerTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,18 +205,24 @@ export class Connection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let table: ArrowTable;
|
let table: ArrowTable;
|
||||||
if (data instanceof ArrowTable) {
|
if (isArrowTable(data)) {
|
||||||
table = data;
|
table = data;
|
||||||
} else {
|
} else {
|
||||||
table = makeArrowTable(data);
|
table = makeArrowTable(data, options);
|
||||||
}
|
}
|
||||||
const buf = await fromTableToBuffer(table);
|
|
||||||
|
const buf = await fromTableToBuffer(
|
||||||
|
table,
|
||||||
|
options?.embeddingFunction,
|
||||||
|
options?.schema,
|
||||||
|
);
|
||||||
const innerTable = await this.inner.createTable(
|
const innerTable = await this.inner.createTable(
|
||||||
name,
|
name,
|
||||||
buf,
|
buf,
|
||||||
mode,
|
mode,
|
||||||
cleanseStorageOptions(options?.storageOptions),
|
cleanseStorageOptions(options?.storageOptions),
|
||||||
);
|
);
|
||||||
|
|
||||||
return new Table(innerTable);
|
return new Table(innerTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,8 +242,14 @@ export class Connection {
|
|||||||
if (mode === "create" && existOk) {
|
if (mode === "create" && existOk) {
|
||||||
mode = "exist_ok";
|
mode = "exist_ok";
|
||||||
}
|
}
|
||||||
|
let metadata: Map<string, string> | undefined = undefined;
|
||||||
|
if (options?.embeddingFunction !== undefined) {
|
||||||
|
const embeddingFunction = options.embeddingFunction;
|
||||||
|
const registry = getRegistry();
|
||||||
|
metadata = registry.getTableMetadata([embeddingFunction]);
|
||||||
|
}
|
||||||
|
|
||||||
const table = makeEmptyTable(schema);
|
const table = makeEmptyTable(schema, metadata);
|
||||||
const buf = await fromTableToBuffer(table);
|
const buf = await fromTableToBuffer(table);
|
||||||
const innerTable = await this.inner.createEmptyTable(
|
const innerTable = await this.inner.createEmptyTable(
|
||||||
name,
|
name,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Copyright 2023 Lance Developers.
|
// Copyright 2024 Lance Developers.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@@ -12,67 +12,151 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { type Float } from "apache-arrow";
|
import "reflect-metadata";
|
||||||
|
import {
|
||||||
|
DataType,
|
||||||
|
Field,
|
||||||
|
FixedSizeList,
|
||||||
|
Float,
|
||||||
|
Float32,
|
||||||
|
isDataType,
|
||||||
|
isFixedSizeList,
|
||||||
|
isFloat,
|
||||||
|
newVectorType,
|
||||||
|
} from "../arrow";
|
||||||
|
import { sanitizeType } from "../sanitize";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Options for a given embedding function
|
||||||
|
*/
|
||||||
|
export interface FunctionOptions {
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: options can be anything
|
||||||
|
[key: string]: any;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An embedding function that automatically creates vector representation for a given column.
|
* An embedding function that automatically creates vector representation for a given column.
|
||||||
*/
|
*/
|
||||||
export interface EmbeddingFunction<T> {
|
export abstract class EmbeddingFunction<
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
|
||||||
|
T = any,
|
||||||
|
M extends FunctionOptions = FunctionOptions,
|
||||||
|
> {
|
||||||
/**
|
/**
|
||||||
* The name of the column that will be used as input for the Embedding Function.
|
* Convert the embedding function to a JSON object
|
||||||
|
* It is used to serialize the embedding function to the schema
|
||||||
|
* It's important that any object returned by this method contains all the necessary
|
||||||
|
* information to recreate the embedding function
|
||||||
|
*
|
||||||
|
* It should return the same object that was passed to the constructor
|
||||||
|
* If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* ```ts
|
||||||
|
* class MyEmbeddingFunction extends EmbeddingFunction {
|
||||||
|
* constructor(options: {model: string, timeout: number}) {
|
||||||
|
* super();
|
||||||
|
* this.model = options.model;
|
||||||
|
* this.timeout = options.timeout;
|
||||||
|
* }
|
||||||
|
* toJSON() {
|
||||||
|
* return {
|
||||||
|
* model: this.model,
|
||||||
|
* timeout: this.timeout,
|
||||||
|
* };
|
||||||
|
* }
|
||||||
|
* ```
|
||||||
*/
|
*/
|
||||||
sourceColumn: string;
|
abstract toJSON(): Partial<M>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The data type of the embedding
|
* sourceField is used in combination with `LanceSchema` to provide a declarative data model
|
||||||
*
|
*
|
||||||
* The embedding function should return `number`. This will be converted into
|
* @param optionsOrDatatype - The options for the field or the datatype
|
||||||
* an Arrow float array. By default this will be Float32 but this property can
|
*
|
||||||
* be used to control the conversion.
|
* @see {@link lancedb.LanceSchema}
|
||||||
*/
|
*/
|
||||||
embeddingDataType?: Float;
|
sourceField(
|
||||||
|
optionsOrDatatype: Partial<FieldOptions> | DataType,
|
||||||
|
): [DataType, Map<string, EmbeddingFunction>] {
|
||||||
|
let datatype = isDataType(optionsOrDatatype)
|
||||||
|
? optionsOrDatatype
|
||||||
|
: optionsOrDatatype?.datatype;
|
||||||
|
if (!datatype) {
|
||||||
|
throw new Error("Datatype is required");
|
||||||
|
}
|
||||||
|
datatype = sanitizeType(datatype);
|
||||||
|
const metadata = new Map<string, EmbeddingFunction>();
|
||||||
|
metadata.set("source_column_for", this);
|
||||||
|
|
||||||
|
return [datatype, metadata];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The dimension of the embedding
|
* vectorField is used in combination with `LanceSchema` to provide a declarative data model
|
||||||
*
|
*
|
||||||
* This is optional, normally this can be determined by looking at the results of
|
* @param options - The options for the field
|
||||||
* `embed`. If this is not specified, and there is an attempt to apply the embedding
|
*
|
||||||
* to an empty table, then that process will fail.
|
* @see {@link lancedb.LanceSchema}
|
||||||
*/
|
*/
|
||||||
embeddingDimension?: number;
|
vectorField(
|
||||||
|
options?: Partial<FieldOptions>,
|
||||||
|
): [DataType, Map<string, EmbeddingFunction>] {
|
||||||
|
let dtype: DataType;
|
||||||
|
const dims = this.ndims() ?? options?.dims;
|
||||||
|
if (!options?.datatype) {
|
||||||
|
if (dims === undefined) {
|
||||||
|
throw new Error("ndims is required for vector field");
|
||||||
|
}
|
||||||
|
dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
|
||||||
|
} else {
|
||||||
|
if (isFixedSizeList(options.datatype)) {
|
||||||
|
dtype = options.datatype;
|
||||||
|
} else if (isFloat(options.datatype)) {
|
||||||
|
if (dims === undefined) {
|
||||||
|
throw new Error("ndims is required for vector field");
|
||||||
|
}
|
||||||
|
dtype = newVectorType(dims, options.datatype);
|
||||||
|
} else {
|
||||||
|
throw new Error(
|
||||||
|
"Expected FixedSizeList or Float as datatype for vector field",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const metadata = new Map<string, EmbeddingFunction>();
|
||||||
|
metadata.set("vector_column_for", this);
|
||||||
|
|
||||||
/**
|
return [dtype, metadata];
|
||||||
* The name of the column that will contain the embedding
|
}
|
||||||
*
|
|
||||||
* By default this is "vector"
|
|
||||||
*/
|
|
||||||
destColumn?: string;
|
|
||||||
|
|
||||||
/**
|
/** The number of dimensions of the embeddings */
|
||||||
* Should the source column be excluded from the resulting table
|
ndims(): number | undefined {
|
||||||
*
|
return undefined;
|
||||||
* By default the source column is included. Set this to true and
|
}
|
||||||
* only the embedding will be stored.
|
|
||||||
*/
|
/** The datatype of the embeddings */
|
||||||
excludeSource?: boolean;
|
abstract embeddingDataType(): Float;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a vector representation for the given values.
|
* Creates a vector representation for the given values.
|
||||||
*/
|
*/
|
||||||
embed: (data: T[]) => Promise<number[][]>;
|
abstract computeSourceEmbeddings(
|
||||||
|
data: T[],
|
||||||
|
): Promise<number[][] | Float32Array[] | Float64Array[]>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
Compute the embeddings for a single query
|
||||||
|
*/
|
||||||
|
async computeQueryEmbeddings(
|
||||||
|
data: T,
|
||||||
|
): Promise<number[] | Float32Array | Float64Array> {
|
||||||
|
return this.computeSourceEmbeddings([data]).then(
|
||||||
|
(embeddings) => embeddings[0],
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Test if the input seems to be an embedding function */
|
export interface FieldOptions<T extends DataType = DataType> {
|
||||||
export function isEmbeddingFunction<T>(
|
datatype: T;
|
||||||
value: unknown,
|
dims?: number;
|
||||||
): value is EmbeddingFunction<T> {
|
|
||||||
if (typeof value !== "object" || value === null) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!("sourceColumn" in value) || !("embed" in value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
typeof value.sourceColumn === "string" && typeof value.embed === "function"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,2 +1,113 @@
|
|||||||
export { EmbeddingFunction, isEmbeddingFunction } from "./embedding_function";
|
// Copyright 2023 Lance Developers.
|
||||||
export { OpenAIEmbeddingFunction } from "./openai";
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import { DataType, Field, Schema } from "../arrow";
|
||||||
|
import { isDataType } from "../arrow";
|
||||||
|
import { sanitizeType } from "../sanitize";
|
||||||
|
import { EmbeddingFunction } from "./embedding_function";
|
||||||
|
import { EmbeddingFunctionConfig, getRegistry } from "./registry";
|
||||||
|
|
||||||
|
export { EmbeddingFunction } from "./embedding_function";
|
||||||
|
|
||||||
|
// We need to explicitly export '*' so that the `register` decorator actually registers the class.
|
||||||
|
export * from "./openai";
|
||||||
|
export * from "./registry";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a schema with embedding functions.
|
||||||
|
*
|
||||||
|
* @param fields
|
||||||
|
* @returns Schema
|
||||||
|
* @example
|
||||||
|
* ```ts
|
||||||
|
* class MyEmbeddingFunction extends EmbeddingFunction {
|
||||||
|
* // ...
|
||||||
|
* }
|
||||||
|
* const func = new MyEmbeddingFunction();
|
||||||
|
* const schema = LanceSchema({
|
||||||
|
* id: new Int32(),
|
||||||
|
* text: func.sourceField(new Utf8()),
|
||||||
|
* vector: func.vectorField(),
|
||||||
|
* // optional: specify the datatype and/or dimensions
|
||||||
|
* vector2: func.vectorField({ datatype: new Float32(), dims: 3}),
|
||||||
|
* });
|
||||||
|
*
|
||||||
|
* const table = await db.createTable("my_table", data, { schema });
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
export function LanceSchema(
|
||||||
|
fields: Record<string, [object, Map<string, EmbeddingFunction>] | object>,
|
||||||
|
): Schema {
|
||||||
|
const arrowFields: Field[] = [];
|
||||||
|
|
||||||
|
const embeddingFunctions = new Map<
|
||||||
|
EmbeddingFunction,
|
||||||
|
Partial<EmbeddingFunctionConfig>
|
||||||
|
>();
|
||||||
|
Object.entries(fields).forEach(([key, value]) => {
|
||||||
|
if (isDataType(value)) {
|
||||||
|
arrowFields.push(new Field(key, sanitizeType(value), true));
|
||||||
|
} else {
|
||||||
|
const [dtype, metadata] = value as [
|
||||||
|
object,
|
||||||
|
Map<string, EmbeddingFunction>,
|
||||||
|
];
|
||||||
|
arrowFields.push(new Field(key, sanitizeType(dtype), true));
|
||||||
|
parseEmbeddingFunctions(embeddingFunctions, key, metadata);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const registry = getRegistry();
|
||||||
|
const metadata = registry.getTableMetadata(
|
||||||
|
Array.from(embeddingFunctions.values()) as EmbeddingFunctionConfig[],
|
||||||
|
);
|
||||||
|
const schema = new Schema(arrowFields, metadata);
|
||||||
|
return schema;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseEmbeddingFunctions(
|
||||||
|
embeddingFunctions: Map<EmbeddingFunction, Partial<EmbeddingFunctionConfig>>,
|
||||||
|
key: string,
|
||||||
|
metadata: Map<string, EmbeddingFunction>,
|
||||||
|
): void {
|
||||||
|
if (metadata.has("source_column_for")) {
|
||||||
|
const embedFunction = metadata.get("source_column_for")!;
|
||||||
|
const current = embeddingFunctions.get(embedFunction);
|
||||||
|
if (current !== undefined) {
|
||||||
|
embeddingFunctions.set(embedFunction, {
|
||||||
|
...current,
|
||||||
|
sourceColumn: key,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
embeddingFunctions.set(embedFunction, {
|
||||||
|
sourceColumn: key,
|
||||||
|
function: embedFunction,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else if (metadata.has("vector_column_for")) {
|
||||||
|
const embedFunction = metadata.get("vector_column_for")!;
|
||||||
|
|
||||||
|
const current = embeddingFunctions.get(embedFunction);
|
||||||
|
if (current !== undefined) {
|
||||||
|
embeddingFunctions.set(embedFunction, {
|
||||||
|
...current,
|
||||||
|
vectorColumn: key,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
embeddingFunctions.set(embedFunction, {
|
||||||
|
vectorColumn: key,
|
||||||
|
function: embedFunction,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,17 +13,31 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import type OpenAI from "openai";
|
import type OpenAI from "openai";
|
||||||
import { type EmbeddingFunction } from "./embedding_function";
|
import { Float, Float32 } from "../arrow";
|
||||||
|
import { EmbeddingFunction } from "./embedding_function";
|
||||||
|
import { register } from "./registry";
|
||||||
|
|
||||||
export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
|
export type OpenAIOptions = {
|
||||||
private readonly _openai: OpenAI;
|
apiKey?: string;
|
||||||
private readonly _modelName: string;
|
model?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
@register("openai")
|
||||||
|
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
|
||||||
|
string,
|
||||||
|
OpenAIOptions
|
||||||
|
> {
|
||||||
|
#openai: OpenAI;
|
||||||
|
#modelName: string;
|
||||||
|
|
||||||
|
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
|
||||||
|
super();
|
||||||
|
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
|
||||||
|
if (!openAIKey) {
|
||||||
|
throw new Error("OpenAI API key is required");
|
||||||
|
}
|
||||||
|
const modelName = options?.model ?? "text-embedding-ada-002";
|
||||||
|
|
||||||
constructor(
|
|
||||||
sourceColumn: string,
|
|
||||||
openAIKey: string,
|
|
||||||
modelName: string = "text-embedding-ada-002",
|
|
||||||
) {
|
|
||||||
/**
|
/**
|
||||||
* @type {import("openai").default}
|
* @type {import("openai").default}
|
||||||
*/
|
*/
|
||||||
@@ -36,18 +50,40 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
|
|||||||
throw new Error("please install openai@^4.24.1 using npm install openai");
|
throw new Error("please install openai@^4.24.1 using npm install openai");
|
||||||
}
|
}
|
||||||
|
|
||||||
this.sourceColumn = sourceColumn;
|
|
||||||
const configuration = {
|
const configuration = {
|
||||||
apiKey: openAIKey,
|
apiKey: openAIKey,
|
||||||
};
|
};
|
||||||
|
|
||||||
this._openai = new Openai(configuration);
|
this.#openai = new Openai(configuration);
|
||||||
this._modelName = modelName;
|
this.#modelName = modelName;
|
||||||
}
|
}
|
||||||
|
|
||||||
async embed(data: string[]): Promise<number[][]> {
|
toJSON() {
|
||||||
const response = await this._openai.embeddings.create({
|
return {
|
||||||
model: this._modelName,
|
model: this.#modelName,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
ndims(): number {
|
||||||
|
switch (this.#modelName) {
|
||||||
|
case "text-embedding-ada-002":
|
||||||
|
return 1536;
|
||||||
|
case "text-embedding-3-large":
|
||||||
|
return 3072;
|
||||||
|
case "text-embedding-3-small":
|
||||||
|
return 1536;
|
||||||
|
default:
|
||||||
|
return null as never;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddingDataType(): Float {
|
||||||
|
return new Float32();
|
||||||
|
}
|
||||||
|
|
||||||
|
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
|
||||||
|
const response = await this.#openai.embeddings.create({
|
||||||
|
model: this.#modelName,
|
||||||
input: data,
|
input: data,
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -58,5 +94,15 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
|
|||||||
return embeddings;
|
return embeddings;
|
||||||
}
|
}
|
||||||
|
|
||||||
sourceColumn: string;
|
async computeQueryEmbeddings(data: string): Promise<number[]> {
|
||||||
|
if (typeof data !== "string") {
|
||||||
|
throw new Error("Data must be a string");
|
||||||
|
}
|
||||||
|
const response = await this.#openai.embeddings.create({
|
||||||
|
model: this.#modelName,
|
||||||
|
input: data,
|
||||||
|
});
|
||||||
|
|
||||||
|
return response.data[0].embedding;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
172
nodejs/lancedb/embedding/registry.ts
Normal file
172
nodejs/lancedb/embedding/registry.ts
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
// Copyright 2024 Lance Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import type { EmbeddingFunction } from "./embedding_function";
|
||||||
|
import "reflect-metadata";
|
||||||
|
|
||||||
|
export interface EmbeddingFunctionOptions {
|
||||||
|
[key: string]: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EmbeddingFunctionFactory<
|
||||||
|
T extends EmbeddingFunction = EmbeddingFunction,
|
||||||
|
> {
|
||||||
|
new (modelOptions?: EmbeddingFunctionOptions): T;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
||||||
|
create(options?: EmbeddingFunctionOptions): T;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is a singleton class used to register embedding functions
|
||||||
|
* and fetch them by name. It also handles serializing and deserializing.
|
||||||
|
* You can implement your own embedding function by subclassing EmbeddingFunction
|
||||||
|
* or TextEmbeddingFunction and registering it with the registry
|
||||||
|
*/
|
||||||
|
export class EmbeddingFunctionRegistry {
|
||||||
|
#functions: Map<string, EmbeddingFunctionFactory> = new Map();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register an embedding function
|
||||||
|
* @param name The name of the function
|
||||||
|
* @param func The function to register
|
||||||
|
*/
|
||||||
|
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
|
||||||
|
this: EmbeddingFunctionRegistry,
|
||||||
|
alias?: string,
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
): (ctor: T) => any {
|
||||||
|
const self = this;
|
||||||
|
return function (ctor: T) {
|
||||||
|
if (!alias) {
|
||||||
|
alias = ctor.name;
|
||||||
|
}
|
||||||
|
if (self.#functions.has(alias)) {
|
||||||
|
throw new Error(
|
||||||
|
`Embedding function with alias "${alias}" already exists`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
self.#functions.set(alias, ctor);
|
||||||
|
Reflect.defineMetadata("lancedb::embedding::name", alias, ctor);
|
||||||
|
return ctor;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetch an embedding function by name
|
||||||
|
* @param name The name of the function
|
||||||
|
*/
|
||||||
|
get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
|
||||||
|
name: string,
|
||||||
|
): EmbeddingFunctionCreate<T> | undefined {
|
||||||
|
const factory = this.#functions.get(name);
|
||||||
|
if (!factory) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
create: function (options: EmbeddingFunctionOptions) {
|
||||||
|
return new factory(options) as unknown as T;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* reset the registry to the initial state
|
||||||
|
*/
|
||||||
|
reset(this: EmbeddingFunctionRegistry) {
|
||||||
|
this.#functions.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
parseFunctions(
|
||||||
|
this: EmbeddingFunctionRegistry,
|
||||||
|
metadata: Map<string, string>,
|
||||||
|
): Map<string, EmbeddingFunctionConfig> {
|
||||||
|
if (!metadata.has("embedding_functions")) {
|
||||||
|
return new Map();
|
||||||
|
} else {
|
||||||
|
type FunctionConfig = {
|
||||||
|
name: string;
|
||||||
|
sourceColumn: string;
|
||||||
|
vectorColumn: string;
|
||||||
|
model: EmbeddingFunctionOptions;
|
||||||
|
};
|
||||||
|
const functions = <FunctionConfig[]>(
|
||||||
|
JSON.parse(metadata.get("embedding_functions")!)
|
||||||
|
);
|
||||||
|
return new Map(
|
||||||
|
functions.map((f) => {
|
||||||
|
const fn = this.get(f.name);
|
||||||
|
if (!fn) {
|
||||||
|
throw new Error(`Function "${f.name}" not found in registry`);
|
||||||
|
}
|
||||||
|
return [
|
||||||
|
f.name,
|
||||||
|
{
|
||||||
|
sourceColumn: f.sourceColumn,
|
||||||
|
vectorColumn: f.vectorColumn,
|
||||||
|
function: this.get(f.name)!.create(f.model),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
functionToMetadata(conf: EmbeddingFunctionConfig): Record<string, any> {
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
const metadata: Record<string, any> = {};
|
||||||
|
const name = Reflect.getMetadata(
|
||||||
|
"lancedb::embedding::name",
|
||||||
|
conf.function.constructor,
|
||||||
|
);
|
||||||
|
metadata["sourceColumn"] = conf.sourceColumn;
|
||||||
|
metadata["vectorColumn"] = conf.vectorColumn ?? "vector";
|
||||||
|
metadata["name"] = name ?? conf.function.constructor.name;
|
||||||
|
metadata["model"] = conf.function.toJSON();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
getTableMetadata(functions: EmbeddingFunctionConfig[]): Map<string, string> {
|
||||||
|
const metadata = new Map<string, string>();
|
||||||
|
const jsonData = functions.map((conf) => this.functionToMetadata(conf));
|
||||||
|
metadata.set("embedding_functions", JSON.stringify(jsonData));
|
||||||
|
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const _REGISTRY = new EmbeddingFunctionRegistry();
|
||||||
|
|
||||||
|
export function register(name?: string) {
|
||||||
|
return _REGISTRY.register(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility function to get the global instance of the registry
|
||||||
|
* @returns `EmbeddingFunctionRegistry` The global instance of the registry
|
||||||
|
* @example
|
||||||
|
* ```ts
|
||||||
|
* const registry = getRegistry();
|
||||||
|
* const openai = registry.get("openai").create();
|
||||||
|
*/
|
||||||
|
export function getRegistry(): EmbeddingFunctionRegistry {
|
||||||
|
return _REGISTRY;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EmbeddingFunctionConfig {
|
||||||
|
sourceColumn: string;
|
||||||
|
vectorColumn?: string;
|
||||||
|
function: EmbeddingFunction;
|
||||||
|
}
|
||||||
@@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "apache-arrow";
|
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow";
|
||||||
import { type IvfPqOptions } from "./indices";
|
import { type IvfPqOptions } from "./indices";
|
||||||
import {
|
import {
|
||||||
RecordBatchIterator as NativeBatchIterator,
|
RecordBatchIterator as NativeBatchIterator,
|
||||||
@@ -170,6 +170,7 @@ export class QueryBase<
|
|||||||
/** Collect the results as an array of objects. */
|
/** Collect the results as an array of objects. */
|
||||||
async toArray(): Promise<unknown[]> {
|
async toArray(): Promise<unknown[]> {
|
||||||
const tbl = await this.toArrow();
|
const tbl = await this.toArrow();
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
|
||||||
return tbl.toArray();
|
return tbl.toArray();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@
|
|||||||
// comes from the exact same library instance. This is not always the case
|
// comes from the exact same library instance. This is not always the case
|
||||||
// and so we must sanitize the input to ensure that it is compatible.
|
// and so we must sanitize the input to ensure that it is compatible.
|
||||||
|
|
||||||
|
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
|
||||||
import {
|
import {
|
||||||
Binary,
|
Binary,
|
||||||
Bool,
|
Bool,
|
||||||
@@ -75,10 +76,9 @@ import {
|
|||||||
Uint64,
|
Uint64,
|
||||||
Union,
|
Union,
|
||||||
Utf8,
|
Utf8,
|
||||||
} from "apache-arrow";
|
} from "./arrow";
|
||||||
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
|
|
||||||
|
|
||||||
function sanitizeMetadata(
|
export function sanitizeMetadata(
|
||||||
metadataLike?: unknown,
|
metadataLike?: unknown,
|
||||||
): Map<string, string> | undefined {
|
): Map<string, string> | undefined {
|
||||||
if (metadataLike === undefined || metadataLike === null) {
|
if (metadataLike === undefined || metadataLike === null) {
|
||||||
@@ -97,7 +97,7 @@ function sanitizeMetadata(
|
|||||||
return metadataLike as Map<string, string>;
|
return metadataLike as Map<string, string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeInt(typeLike: object) {
|
export function sanitizeInt(typeLike: object) {
|
||||||
if (
|
if (
|
||||||
!("bitWidth" in typeLike) ||
|
!("bitWidth" in typeLike) ||
|
||||||
typeof typeLike.bitWidth !== "number" ||
|
typeof typeLike.bitWidth !== "number" ||
|
||||||
@@ -111,14 +111,14 @@ function sanitizeInt(typeLike: object) {
|
|||||||
return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
|
return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeFloat(typeLike: object) {
|
export function sanitizeFloat(typeLike: object) {
|
||||||
if (!("precision" in typeLike) || typeof typeLike.precision !== "number") {
|
if (!("precision" in typeLike) || typeof typeLike.precision !== "number") {
|
||||||
throw Error("Expected a Float Type to have a `precision` property");
|
throw Error("Expected a Float Type to have a `precision` property");
|
||||||
}
|
}
|
||||||
return new Float(typeLike.precision as Precision);
|
return new Float(typeLike.precision as Precision);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeDecimal(typeLike: object) {
|
export function sanitizeDecimal(typeLike: object) {
|
||||||
if (
|
if (
|
||||||
!("scale" in typeLike) ||
|
!("scale" in typeLike) ||
|
||||||
typeof typeLike.scale !== "number" ||
|
typeof typeLike.scale !== "number" ||
|
||||||
@@ -134,14 +134,14 @@ function sanitizeDecimal(typeLike: object) {
|
|||||||
return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
|
return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeDate(typeLike: object) {
|
export function sanitizeDate(typeLike: object) {
|
||||||
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
|
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
|
||||||
throw Error("Expected a Date type to have a `unit` property");
|
throw Error("Expected a Date type to have a `unit` property");
|
||||||
}
|
}
|
||||||
return new Date_(typeLike.unit as DateUnit);
|
return new Date_(typeLike.unit as DateUnit);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeTime(typeLike: object) {
|
export function sanitizeTime(typeLike: object) {
|
||||||
if (
|
if (
|
||||||
!("unit" in typeLike) ||
|
!("unit" in typeLike) ||
|
||||||
typeof typeLike.unit !== "number" ||
|
typeof typeLike.unit !== "number" ||
|
||||||
@@ -155,7 +155,7 @@ function sanitizeTime(typeLike: object) {
|
|||||||
return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
|
return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeTimestamp(typeLike: object) {
|
export function sanitizeTimestamp(typeLike: object) {
|
||||||
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
|
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
|
||||||
throw Error("Expected a Timestamp type to have a `unit` property");
|
throw Error("Expected a Timestamp type to have a `unit` property");
|
||||||
}
|
}
|
||||||
@@ -166,7 +166,7 @@ function sanitizeTimestamp(typeLike: object) {
|
|||||||
return new Timestamp(typeLike.unit, timezone);
|
return new Timestamp(typeLike.unit, timezone);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeTypedTimestamp(
|
export function sanitizeTypedTimestamp(
|
||||||
typeLike: object,
|
typeLike: object,
|
||||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||||
Datatype:
|
Datatype:
|
||||||
@@ -182,14 +182,14 @@ function sanitizeTypedTimestamp(
|
|||||||
return new Datatype(timezone);
|
return new Datatype(timezone);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeInterval(typeLike: object) {
|
export function sanitizeInterval(typeLike: object) {
|
||||||
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
|
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
|
||||||
throw Error("Expected an Interval type to have a `unit` property");
|
throw Error("Expected an Interval type to have a `unit` property");
|
||||||
}
|
}
|
||||||
return new Interval(typeLike.unit);
|
return new Interval(typeLike.unit);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeList(typeLike: object) {
|
export function sanitizeList(typeLike: object) {
|
||||||
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
|
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
|
||||||
throw Error(
|
throw Error(
|
||||||
"Expected a List type to have an array-like `children` property",
|
"Expected a List type to have an array-like `children` property",
|
||||||
@@ -201,7 +201,7 @@ function sanitizeList(typeLike: object) {
|
|||||||
return new List(sanitizeField(typeLike.children[0]));
|
return new List(sanitizeField(typeLike.children[0]));
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeStruct(typeLike: object) {
|
export function sanitizeStruct(typeLike: object) {
|
||||||
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
|
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
|
||||||
throw Error(
|
throw Error(
|
||||||
"Expected a Struct type to have an array-like `children` property",
|
"Expected a Struct type to have an array-like `children` property",
|
||||||
@@ -210,7 +210,7 @@ function sanitizeStruct(typeLike: object) {
|
|||||||
return new Struct(typeLike.children.map((child) => sanitizeField(child)));
|
return new Struct(typeLike.children.map((child) => sanitizeField(child)));
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeUnion(typeLike: object) {
|
export function sanitizeUnion(typeLike: object) {
|
||||||
if (
|
if (
|
||||||
!("typeIds" in typeLike) ||
|
!("typeIds" in typeLike) ||
|
||||||
!("mode" in typeLike) ||
|
!("mode" in typeLike) ||
|
||||||
@@ -234,7 +234,7 @@ function sanitizeUnion(typeLike: object) {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeTypedUnion(
|
export function sanitizeTypedUnion(
|
||||||
typeLike: object,
|
typeLike: object,
|
||||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||||
UnionType: typeof DenseUnion | typeof SparseUnion,
|
UnionType: typeof DenseUnion | typeof SparseUnion,
|
||||||
@@ -256,7 +256,7 @@ function sanitizeTypedUnion(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeFixedSizeBinary(typeLike: object) {
|
export function sanitizeFixedSizeBinary(typeLike: object) {
|
||||||
if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
|
if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
|
||||||
throw Error(
|
throw Error(
|
||||||
"Expected a FixedSizeBinary type to have a `byteWidth` property",
|
"Expected a FixedSizeBinary type to have a `byteWidth` property",
|
||||||
@@ -265,7 +265,7 @@ function sanitizeFixedSizeBinary(typeLike: object) {
|
|||||||
return new FixedSizeBinary(typeLike.byteWidth);
|
return new FixedSizeBinary(typeLike.byteWidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeFixedSizeList(typeLike: object) {
|
export function sanitizeFixedSizeList(typeLike: object) {
|
||||||
if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") {
|
if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") {
|
||||||
throw Error("Expected a FixedSizeList type to have a `listSize` property");
|
throw Error("Expected a FixedSizeList type to have a `listSize` property");
|
||||||
}
|
}
|
||||||
@@ -283,7 +283,7 @@ function sanitizeFixedSizeList(typeLike: object) {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeMap(typeLike: object) {
|
export function sanitizeMap(typeLike: object) {
|
||||||
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
|
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
|
||||||
throw Error(
|
throw Error(
|
||||||
"Expected a Map type to have an array-like `children` property",
|
"Expected a Map type to have an array-like `children` property",
|
||||||
@@ -300,14 +300,14 @@ function sanitizeMap(typeLike: object) {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeDuration(typeLike: object) {
|
export function sanitizeDuration(typeLike: object) {
|
||||||
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
|
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
|
||||||
throw Error("Expected a Duration type to have a `unit` property");
|
throw Error("Expected a Duration type to have a `unit` property");
|
||||||
}
|
}
|
||||||
return new Duration(typeLike.unit);
|
return new Duration(typeLike.unit);
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeDictionary(typeLike: object) {
|
export function sanitizeDictionary(typeLike: object) {
|
||||||
if (!("id" in typeLike) || typeof typeLike.id !== "number") {
|
if (!("id" in typeLike) || typeof typeLike.id !== "number") {
|
||||||
throw Error("Expected a Dictionary type to have an `id` property");
|
throw Error("Expected a Dictionary type to have an `id` property");
|
||||||
}
|
}
|
||||||
@@ -329,7 +329,7 @@ function sanitizeDictionary(typeLike: object) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// biome-ignore lint/suspicious/noExplicitAny: skip
|
// biome-ignore lint/suspicious/noExplicitAny: skip
|
||||||
function sanitizeType(typeLike: unknown): DataType<any> {
|
export function sanitizeType(typeLike: unknown): DataType<any> {
|
||||||
if (typeof typeLike !== "object" || typeLike === null) {
|
if (typeof typeLike !== "object" || typeLike === null) {
|
||||||
throw Error("Expected a Type but object was null/undefined");
|
throw Error("Expected a Type but object was null/undefined");
|
||||||
}
|
}
|
||||||
@@ -449,7 +449,7 @@ function sanitizeType(typeLike: unknown): DataType<any> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeField(fieldLike: unknown): Field {
|
export function sanitizeField(fieldLike: unknown): Field {
|
||||||
if (fieldLike instanceof Field) {
|
if (fieldLike instanceof Field) {
|
||||||
return fieldLike;
|
return fieldLike;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,9 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { Schema, tableFromIPC } from "apache-arrow";
|
import { Data, Schema, fromDataToBuffer, tableFromIPC } from "./arrow";
|
||||||
import { Data, fromDataToBuffer } from "./arrow";
|
|
||||||
|
import { getRegistry } from "./embedding/registry";
|
||||||
import { IndexOptions } from "./indices";
|
import { IndexOptions } from "./indices";
|
||||||
import {
|
import {
|
||||||
AddColumnsSql,
|
AddColumnsSql,
|
||||||
@@ -122,8 +123,14 @@ export class Table {
|
|||||||
*/
|
*/
|
||||||
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
|
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
|
||||||
const mode = options?.mode ?? "append";
|
const mode = options?.mode ?? "append";
|
||||||
|
const schema = await this.schema();
|
||||||
|
const registry = getRegistry();
|
||||||
|
const functions = registry.parseFunctions(schema.metadata);
|
||||||
|
|
||||||
const buffer = await fromDataToBuffer(data);
|
const buffer = await fromDataToBuffer(
|
||||||
|
data,
|
||||||
|
functions.values().next().value,
|
||||||
|
);
|
||||||
await this.inner.add(buffer, mode);
|
await this.inner.add(buffer, mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
15383
nodejs/package-lock.json
generated
15383
nodejs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.5.0",
|
"version": "0.5.0",
|
||||||
"main": "./dist/index.js",
|
"main": "dist/index.js",
|
||||||
"types": "./dist/index.d.ts",
|
"exports": {
|
||||||
|
".": "./dist/index.js",
|
||||||
|
"./embedding": "./dist/embedding/index.js"
|
||||||
|
},
|
||||||
|
"types": "dist/index.d.ts",
|
||||||
"napi": {
|
"napi": {
|
||||||
"name": "lancedb",
|
"name": "lancedb",
|
||||||
"triples": {
|
"triples": {
|
||||||
@@ -62,6 +66,7 @@
|
|||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"apache-arrow": "^15.0.0",
|
"apache-arrow": "^15.0.0",
|
||||||
"openai": "^4.29.2"
|
"openai": "^4.29.2",
|
||||||
|
"reflect-metadata": "^0.2.2"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,9 @@
|
|||||||
"outDir": "./dist",
|
"outDir": "./dist",
|
||||||
"strict": true,
|
"strict": true,
|
||||||
"allowJs": true,
|
"allowJs": true,
|
||||||
"resolveJsonModule": true
|
"resolveJsonModule": true,
|
||||||
|
"emitDecoratorMetadata": true,
|
||||||
|
"experimentalDecorators": true
|
||||||
},
|
},
|
||||||
"exclude": ["./dist/*"],
|
"exclude": ["./dist/*"],
|
||||||
"typedocOptions": {
|
"typedocOptions": {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.8.0"
|
current_version = "0.8.1"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.8.0"
|
version = "0.8.1"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "Python bindings for LanceDB"
|
description = "Python bindings for LanceDB"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ name = "lancedb"
|
|||||||
# version in Cargo.toml
|
# version in Cargo.toml
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"deprecation",
|
"deprecation",
|
||||||
"pylance==0.11.0",
|
"pylance==0.11.1",
|
||||||
"ratelimiter~=1.0",
|
"ratelimiter~=1.0",
|
||||||
"requests>=2.31.0",
|
"requests>=2.31.0",
|
||||||
"retry>=0.9.2",
|
"retry>=0.9.2",
|
||||||
|
|||||||
@@ -509,7 +509,7 @@ class AsyncConnection(object):
|
|||||||
return self._inner.__repr__()
|
return self._inner.__repr__()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *_):
|
def __exit__(self, *_):
|
||||||
self.close()
|
self.close()
|
||||||
@@ -779,7 +779,7 @@ class AsyncConnection(object):
|
|||||||
name: str,
|
name: str,
|
||||||
storage_options: Optional[Dict[str, str]] = None,
|
storage_options: Optional[Dict[str, str]] = None,
|
||||||
index_cache_size: Optional[int] = None,
|
index_cache_size: Optional[int] = None,
|
||||||
) -> Table:
|
) -> AsyncTable:
|
||||||
"""Open a Lance Table in the database.
|
"""Open a Lance Table in the database.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|||||||
@@ -296,6 +296,13 @@ async def test_close(tmp_path):
|
|||||||
await db.table_names()
|
await db.table_names()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_manager(tmp_path):
|
||||||
|
with await lancedb.connect_async(tmp_path) as db:
|
||||||
|
assert db.is_open()
|
||||||
|
assert not db.is_open()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_mode_async(tmp_path):
|
async def test_create_mode_async(tmp_path):
|
||||||
db = await lancedb.connect_async(tmp_path)
|
db = await lancedb.connect_async(tmp_path)
|
||||||
|
|||||||
@@ -1751,7 +1751,7 @@ impl TableInternal for NativeTable {
|
|||||||
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
|
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
|
||||||
}
|
}
|
||||||
let job = builder.try_build()?;
|
let job = builder.try_build()?;
|
||||||
let new_dataset = job.execute_reader(new_data).await?;
|
let (new_dataset, _stats) = job.execute_reader(new_data).await?;
|
||||||
self.dataset.set_latest(new_dataset.as_ref().clone()).await;
|
self.dataset.set_latest(new_dataset.as_ref().clone()).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user