diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 47e33de0..aa0552a2 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -31,6 +31,7 @@ import { Schema, makeArrowTable, } from "../lancedb/arrow"; +import { EmbeddingFunction, LanceSchema, register } from "../lancedb/embedding"; import { Index } from "../lancedb/indices"; // biome-ignore lint/suspicious/noExplicitAny: @@ -493,3 +494,99 @@ describe("when optimizing a dataset", () => { expect(stats.prune.oldVersionsRemoved).toBe(3); }); }); + +describe("table.search", () => { + let tmpDir: tmp.DirResult; + beforeEach(() => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + }); + afterEach(() => tmpDir.removeCallback()); + + test("can search using a string", async () => { + @register() + class MockEmbeddingFunction extends EmbeddingFunction { + toJSON(): object { + return {}; + } + ndims() { + return 1; + } + embeddingDataType(): arrow.Float { + return new Float32(); + } + + // Hardcoded embeddings for the sake of testing + async computeQueryEmbeddings(_data: string) { + switch (_data) { + case "greetings": + return [0.1]; + case "farewell": + return [0.2]; + default: + return null as never; + } + } + + // Hardcoded embeddings for the sake of testing + async computeSourceEmbeddings(data: string[]) { + return data.map((s) => { + switch (s) { + case "hello world": + return [0.1]; + case "goodbye world": + return [0.2]; + default: + return null as never; + } + }); + } + } + + const func = new MockEmbeddingFunction(); + const schema = LanceSchema({ + text: func.sourceField(new arrow.Utf8()), + vector: func.vectorField(), + }); + const db = await connect(tmpDir.name); + const data = [{ text: "hello world" }, { text: "goodbye world" }]; + const table = await db.createTable("test", data, { schema }); + + const results = await table.search("greetings").then((r) => r.toArray()); + expect(results[0].text).toBe(data[0].text); + + const results2 = await 
table.search("farewell").then((r) => r.toArray()); + expect(results2[0].text).toBe(data[1].text); + }); + + test("rejects if no embedding function provided", async () => { + const db = await connect(tmpDir.name); + const data = [ + { text: "hello world", vector: [0.1, 0.2, 0.3] }, + { text: "goodbye world", vector: [0.4, 0.5, 0.6] }, + ]; + const table = await db.createTable("test", data); + + await expect(table.search("hello")).rejects.toThrow( + "No embedding functions are defined in the table", + ); + }); + + test.each([ + [0.4, 0.5, 0.599], // number[] + Float32Array.of(0.4, 0.5, 0.599), // Float32Array + Float64Array.of(0.4, 0.5, 0.599), // Float64Array + ])("can search using vectorlike datatypes", async (vectorlike) => { + const db = await connect(tmpDir.name); + const data = [ + { text: "hello world", vector: [0.1, 0.2, 0.3] }, + { text: "goodbye world", vector: [0.4, 0.5, 0.6] }, + ]; + const table = await db.createTable("test", data); + + // biome-ignore lint/suspicious/noExplicitAny: test + const results: any[] = await table.search(vectorlike).toArray(); + + expect(results.length).toBe(2); + expect(results[0].text).toBe(data[1].text); + }); +}); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 1836dc30..8309c161 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -42,6 +42,8 @@ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize"; export * from "apache-arrow"; +export type IntoVector = Float32Array | Float64Array | number[]; + export function isArrowTable(value: object): value is ArrowTable { if (value instanceof ArrowTable) return true; return "schema" in value && "batches" in value; diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index 8e752a8f..e2e098a3 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts 
@@ -19,6 +19,7 @@ import { FixedSizeList, Float, Float32, + type IntoVector, isDataType, isFixedSizeList, isFloat, @@ -169,9 +170,7 @@ export abstract class EmbeddingFunction< /** Compute the embeddings for a single query */ - async computeQueryEmbeddings( - data: T, - ): Promise { + async computeQueryEmbeddings(data: T): Promise { return this.computeSourceEmbeddings([data]).then( (embeddings) => embeddings[0], ); diff --git a/nodejs/lancedb/embedding/registry.ts b/nodejs/lancedb/embedding/registry.ts index f79e3f9d..47e52917 100644 --- a/nodejs/lancedb/embedding/registry.ts +++ b/nodejs/lancedb/embedding/registry.ts @@ -42,6 +42,7 @@ export class EmbeddingFunctionRegistry { * Register an embedding function * @param name The name of the function * @param func The function to register + * @throws Error if the function is already registered */ register( this: EmbeddingFunctionRegistry, @@ -89,6 +90,9 @@ export class EmbeddingFunctionRegistry { this.#functions.clear(); } + /** + * @ignore + */ parseFunctions( this: EmbeddingFunctionRegistry, metadata: Map, diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 0ac40378..dea06155 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -12,7 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow"; +import { + Table as ArrowTable, + type IntoVector, + RecordBatch, + tableFromIPC, +} from "./arrow"; import { type IvfPqOptions } from "./indices"; import { RecordBatchIterator as NativeBatchIterator, @@ -108,9 +113,12 @@ export class QueryBase< * object insertion order is easy to get wrong and `Map` is more foolproof. 
*/ select( - columns: string[] | Map | Record, + columns: string[] | Map | Record | string, ): QueryType { let columnTuples: [string, string][]; + if (typeof columns === "string") { + columns = [columns]; + } if (Array.isArray(columns)) { columnTuples = columns.map((c) => [c, c]); } else if (columns instanceof Map) { @@ -370,9 +378,8 @@ export class Query extends QueryBase { * Vector searches always have a `limit`. If `limit` has not been called then * a default `limit` of 10 will be used. @see {@link Query#limit} */ - nearestTo(vector: unknown): VectorQuery { - // biome-ignore lint/suspicious/noExplicitAny: skip - const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any)); + nearestTo(vector: IntoVector): VectorQuery { + const vectorQuery = this.inner.nearestTo(Float32Array.from(vector)); return new VectorQuery(vectorQuery); } } diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 9d0f8adf..9b1363aa 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -11,15 +11,17 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + import { Table as ArrowTable, Data, + IntoVector, Schema, fromDataToBuffer, tableFromIPC, } from "./arrow"; -import { getRegistry } from "./embedding/registry"; +import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; import { IndexOptions } from "./indices"; import { AddColumnsSql, @@ -115,6 +117,14 @@ export class Table { return this.inner.display(); } + async #getEmbeddingFunctions(): Promise< + Map + > { + const schema = await this.schema(); + const registry = getRegistry(); + return registry.parseFunctions(schema.metadata); + } + /** Get the schema of the table. 
*/ async schema(): Promise { const schemaBuf = await this.inner.schema(); @@ -276,6 +286,40 @@ export class Table { return new Query(this.inner); } + /** + * Create a search query to find the nearest neighbors + * of the given query vector + * @param {string} query - the query. This will be converted to a vector using the table's provided embedding function + * @rejects {Error} If no embedding functions are defined in the table + */ + search(query: string): Promise; + /** + * Create a search query to find the nearest neighbors + * of the given query vector + * @param {IntoVector} query - the query vector + */ + search(query: IntoVector): VectorQuery; + search(query: string | IntoVector): Promise | VectorQuery { + if (typeof query !== "string") { + return this.vectorSearch(query); + } else { + return this.#getEmbeddingFunctions().then(async (functions) => { + // TODO: Support multiple embedding functions + const embeddingFunc: EmbeddingFunctionConfig | undefined = functions + .values() + .next().value; + if (!embeddingFunc) { + return Promise.reject( + new Error("No embedding functions are defined in the table"), + ); + } + const embeddings = + await embeddingFunc.function.computeQueryEmbeddings(query); + return this.query().nearestTo(embeddings); + }); + } + } + /** * Search the table with a given query vector. * @@ -283,7 +327,7 @@ export class Table { * is the same thing as calling `nearestTo` on the builder returned * by `query`. @see {@link Query#nearestTo} for more details. */ - vectorSearch(vector: unknown): VectorQuery { + vectorSearch(vector: IntoVector): VectorQuery { return this.query().nearestTo(vector); }