feat(nodejs): table.search functionality (#1341)

closes https://github.com/lancedb/lancedb/issues/1256
This commit is contained in:
Cory Grinstead
2024-06-04 14:04:03 -05:00
committed by GitHub
parent d9fb6457e1
commit 70f92f19a6
6 changed files with 163 additions and 10 deletions

View File

@@ -31,6 +31,7 @@ import {
Schema,
makeArrowTable,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema, register } from "../lancedb/embedding";
import { Index } from "../lancedb/indices";
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
@@ -493,3 +494,99 @@ describe("when optimizing a dataset", () => {
expect(stats.prune.oldVersionsRemoved).toBe(3);
});
});
describe("table.search", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
test("can search using a string", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 1;
}
embeddingDataType(): arrow.Float {
return new Float32();
}
// Hardcoded embeddings for the sake of testing
async computeQueryEmbeddings(_data: string) {
switch (_data) {
case "greetings":
return [0.1];
case "farewell":
return [0.2];
default:
return null as never;
}
}
// Hardcoded embeddings for the sake of testing
async computeSourceEmbeddings(data: string[]) {
return data.map((s) => {
switch (s) {
case "hello world":
return [0.1];
case "goodbye world":
return [0.2];
default:
return null as never;
}
});
}
}
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
const data = [{ text: "hello world" }, { text: "goodbye world" }];
const table = await db.createTable("test", data, { schema });
const results = await table.search("greetings").then((r) => r.toArray());
expect(results[0].text).toBe(data[0].text);
const results2 = await table.search("farewell").then((r) => r.toArray());
expect(results2[0].text).toBe(data[1].text);
});
test("rejects if no embedding function provided", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
expect(table.search("hello")).rejects.toThrow(
"No embedding functions are defined in the table",
);
});
test.each([
[0.4, 0.5, 0.599], // number[]
Float32Array.of(0.4, 0.5, 0.599), // Float32Array
Float64Array.of(0.4, 0.5, 0.599), // Float64Array
])("can search using vectorlike datatypes", async (vectorlike) => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
// biome-ignore lint/suspicious/noExplicitAny: test
const results: any[] = await table.search(vectorlike).toArray();
expect(results.length).toBe(2);
expect(results[0].text).toBe(data[1].text);
});
});

View File

@@ -42,6 +42,8 @@ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
export * from "apache-arrow";
export type IntoVector = Float32Array | Float64Array | number[];
export function isArrowTable(value: object): value is ArrowTable {
if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value;

View File

@@ -19,6 +19,7 @@ import {
FixedSizeList,
Float,
Float32,
type IntoVector,
isDataType,
isFixedSizeList,
isFloat,
@@ -169,9 +170,7 @@ export abstract class EmbeddingFunction<
/**
Compute the embeddings for a single query
*/
async computeQueryEmbeddings(
data: T,
): Promise<number[] | Float32Array | Float64Array> {
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
return this.computeSourceEmbeddings([data]).then(
(embeddings) => embeddings[0],
);

View File

@@ -42,6 +42,7 @@ export class EmbeddingFunctionRegistry {
* Register an embedding function
* @param name The name of the function
* @param func The function to register
* @throws Error if the function is already registered
*/
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
this: EmbeddingFunctionRegistry,
@@ -89,6 +90,9 @@ export class EmbeddingFunctionRegistry {
this.#functions.clear();
}
/**
* @ignore
*/
parseFunctions(
this: EmbeddingFunctionRegistry,
metadata: Map<string, string>,

View File

@@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow";
import {
Table as ArrowTable,
type IntoVector,
RecordBatch,
tableFromIPC,
} from "./arrow";
import { type IvfPqOptions } from "./indices";
import {
RecordBatchIterator as NativeBatchIterator,
@@ -108,9 +113,12 @@ export class QueryBase<
* object insertion order is easy to get wrong and `Map` is more foolproof.
*/
select(
columns: string[] | Map<string, string> | Record<string, string>,
columns: string[] | Map<string, string> | Record<string, string> | string,
): QueryType {
let columnTuples: [string, string][];
if (typeof columns === "string") {
columns = [columns];
}
if (Array.isArray(columns)) {
columnTuples = columns.map((c) => [c, c]);
} else if (columns instanceof Map) {
@@ -370,9 +378,8 @@ export class Query extends QueryBase<NativeQuery, Query> {
* Vector searches always have a `limit`. If `limit` has not been called then
* a default `limit` of 10 will be used. @see {@link Query#limit}
*/
nearestTo(vector: unknown): VectorQuery {
// biome-ignore lint/suspicious/noExplicitAny: skip
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
nearestTo(vector: IntoVector): VectorQuery {
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
return new VectorQuery(vectorQuery);
}
}

View File

@@ -11,15 +11,17 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import {
Table as ArrowTable,
Data,
IntoVector,
Schema,
fromDataToBuffer,
tableFromIPC,
} from "./arrow";
import { getRegistry } from "./embedding/registry";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { IndexOptions } from "./indices";
import {
AddColumnsSql,
@@ -115,6 +117,14 @@ export class Table {
return this.inner.display();
}
async #getEmbeddingFunctions(): Promise<
Map<string, EmbeddingFunctionConfig>
> {
const schema = await this.schema();
const registry = getRegistry();
return registry.parseFunctions(schema.metadata);
}
/** Get the schema of the table. */
async schema(): Promise<Schema> {
const schemaBuf = await this.inner.schema();
@@ -276,6 +286,40 @@ export class Table {
return new Query(this.inner);
}
/**
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
* @rejects {Error} If no embedding functions are defined in the table
*/
search(query: string): Promise<VectorQuery>;
/**
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {IntoVector} query - the query vector
*/
search(query: IntoVector): VectorQuery;
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
if (typeof query !== "string") {
return this.vectorSearch(query);
} else {
return this.#getEmbeddingFunctions().then(async (functions) => {
// TODO: Support multiple embedding functions
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
.values()
.next().value;
if (!embeddingFunc) {
return Promise.reject(
new Error("No embedding functions are defined in the table"),
);
}
const embeddings =
await embeddingFunc.function.computeQueryEmbeddings(query);
return this.query().nearestTo(embeddings);
});
}
}
/**
* Search the table with a given query vector.
*
@@ -283,7 +327,7 @@ export class Table {
* is the same thing as calling `nearestTo` on the builder returned
* by `query`. @see {@link Query#nearestTo} for more details.
*/
vectorSearch(vector: unknown): VectorQuery {
vectorSearch(vector: IntoVector): VectorQuery {
return this.query().nearestTo(vector);
}