feat(nodejs): make tbl.search chainable (#1421)

so this was annoying me when writing the docs. 

for a `search` query, one needed to chain `async` calls.

```ts
const res = await (await tbl.search("greetings")).toArray()
```

now the promise will be deferred until the query is collected, leading
to a more functional API

```ts
const res = await tbl.search("greetings").toArray()
```
This commit is contained in:
Cory Grinstead
2024-07-02 14:31:57 -05:00
committed by GitHub
parent 46c6ff889d
commit b8ccea9f71
6 changed files with 112 additions and 57 deletions

View File

@@ -706,10 +706,10 @@ describe("table.search", () => {
const data = [{ text: "hello world" }, { text: "goodbye world" }];
const table = await db.createTable("test", data, { schema });
const results = await table.search("greetings").then((r) => r.toArray());
const results = await table.search("greetings").toArray();
expect(results[0].text).toBe(data[0].text);
const results2 = await table.search("farewell").then((r) => r.toArray());
const results2 = await table.search("farewell").toArray();
expect(results2[0].text).toBe(data[1].text);
});
@@ -721,7 +721,7 @@ describe("table.search", () => {
];
const table = await db.createTable("test", data);
expect(table.search("hello")).rejects.toThrow(
expect(table.search("hello").toArray()).rejects.toThrow(
"No embedding functions are defined in the table",
);
});

View File

@@ -97,7 +97,11 @@ export type TableLike =
| ArrowTable
| { schema: SchemaLike; batches: RecordBatchLike[] };
export type IntoVector = Float32Array | Float64Array | number[];
export type IntoVector =
| Float32Array
| Float64Array
| number[]
| Promise<Float32Array | Float64Array | number[]>;
export function isArrowTable(value: object): value is TableLike {
if (value instanceof ArrowTable) return true;

View File

@@ -181,7 +181,7 @@ export abstract class EmbeddingFunction<
/**
Compute the embeddings for a single query
*/
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
async computeQueryEmbeddings(data: T): Promise<Awaited<IntoVector>> {
return this.computeSourceEmbeddings([data]).then(
(embeddings) => embeddings[0],
);

View File

@@ -89,15 +89,26 @@ export interface QueryExecutionOptions {
}
/** Common methods supported by all query types */
export class QueryBase<
NativeQueryType extends NativeQuery | NativeVectorQuery,
QueryType,
> implements AsyncIterable<RecordBatch>
export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
implements AsyncIterable<RecordBatch>
{
protected constructor(protected inner: NativeQueryType) {
protected constructor(
protected inner: NativeQueryType | Promise<NativeQueryType>,
) {
// intentionally empty
}
// call a function on the inner (either a promise or the actual object)
protected doCall(fn: (inner: NativeQueryType) => void) {
if (this.inner instanceof Promise) {
this.inner = this.inner.then((inner) => {
fn(inner);
return inner;
});
} else {
fn(this.inner);
}
}
/**
* A filter statement to be applied to this query.
*
@@ -110,16 +121,16 @@ export class QueryBase<
* Filtering performance can often be improved by creating a scalar index
* on the filter column(s).
*/
where(predicate: string): QueryType {
this.inner.onlyIf(predicate);
return this as unknown as QueryType;
where(predicate: string): this {
this.doCall((inner: NativeQueryType) => inner.onlyIf(predicate));
return this;
}
/**
* A filter statement to be applied to this query.
* @alias where
* @deprecated Use `where` instead
*/
filter(predicate: string): QueryType {
filter(predicate: string): this {
return this.where(predicate);
}
@@ -155,7 +166,7 @@ export class QueryBase<
*/
select(
columns: string[] | Map<string, string> | Record<string, string> | string,
): QueryType {
): this {
let columnTuples: [string, string][];
if (typeof columns === "string") {
columns = [columns];
@@ -167,8 +178,10 @@ export class QueryBase<
} else {
columnTuples = Object.entries(columns);
}
this.inner.select(columnTuples);
return this as unknown as QueryType;
this.doCall((inner: NativeQueryType) => {
inner.select(columnTuples);
});
return this;
}
/**
@@ -177,15 +190,19 @@ export class QueryBase<
* By default, a plain search has no limit. If this method is not
* called then every valid row from the table will be returned.
*/
limit(limit: number): QueryType {
this.inner.limit(limit);
return this as unknown as QueryType;
limit(limit: number): this {
this.doCall((inner: NativeQueryType) => inner.limit(limit));
return this;
}
protected nativeExecute(
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {
return this.inner.execute(options?.maxBatchLength);
if (this.inner instanceof Promise) {
return this.inner.then((inner) => inner.execute(options?.maxBatchLength));
} else {
return this.inner.execute(options?.maxBatchLength);
}
}
/**
@@ -214,7 +231,13 @@ export class QueryBase<
/** Collect the results as an Arrow @see {@link ArrowTable}. */
async toArrow(options?: Partial<QueryExecutionOptions>): Promise<ArrowTable> {
const batches = [];
for await (const batch of new RecordBatchIterable(this.inner, options)) {
let inner;
if (this.inner instanceof Promise) {
inner = await this.inner;
} else {
inner = this.inner;
}
for await (const batch of new RecordBatchIterable(inner, options)) {
batches.push(batch);
}
return new ArrowTable(batches);
@@ -258,8 +281,8 @@ export interface ExecutableQuery {}
*
* This builder can be reused to execute the query many times.
*/
export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
constructor(inner: NativeVectorQuery) {
export class VectorQuery extends QueryBase<NativeVectorQuery> {
constructor(inner: NativeVectorQuery | Promise<NativeVectorQuery>) {
super(inner);
}
@@ -286,7 +309,8 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* you the desired recall.
*/
nprobes(nprobes: number): VectorQuery {
this.inner.nprobes(nprobes);
super.doCall((inner) => inner.nprobes(nprobes));
return this;
}
@@ -300,7 +324,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* whose data type is a fixed-size-list of floats.
*/
column(column: string): VectorQuery {
this.inner.column(column);
super.doCall((inner) => inner.column(column));
return this;
}
@@ -321,7 +345,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
distanceType(
distanceType: Required<IvfPqOptions>["distanceType"],
): VectorQuery {
this.inner.distanceType(distanceType);
super.doCall((inner) => inner.distanceType(distanceType));
return this;
}
@@ -355,7 +379,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* distance between the query vector and the actual uncompressed vector.
*/
refineFactor(refineFactor: number): VectorQuery {
this.inner.refineFactor(refineFactor);
super.doCall((inner) => inner.refineFactor(refineFactor));
return this;
}
@@ -380,7 +404,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* factor can often help restore some of the results lost by post filtering.
*/
postfilter(): VectorQuery {
this.inner.postfilter();
super.doCall((inner) => inner.postfilter());
return this;
}
@@ -394,13 +418,13 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* calculate your recall to select an appropriate value for nprobes.
*/
bypassVectorIndex(): VectorQuery {
this.inner.bypassVectorIndex();
super.doCall((inner) => inner.bypassVectorIndex());
return this;
}
}
/** A builder for LanceDB queries. */
export class Query extends QueryBase<NativeQuery, Query> {
export class Query extends QueryBase<NativeQuery> {
constructor(tbl: NativeTable) {
super(tbl.query());
}
@@ -443,7 +467,37 @@ export class Query extends QueryBase<NativeQuery, Query> {
* a default `limit` of 10 will be used. @see {@link Query#limit}
*/
nearestTo(vector: IntoVector): VectorQuery {
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
return new VectorQuery(vectorQuery);
if (this.inner instanceof Promise) {
const nativeQuery = this.inner.then(async (inner) => {
if (vector instanceof Promise) {
const arr = await vector.then((v) => Float32Array.from(v));
return inner.nearestTo(arr);
} else {
return inner.nearestTo(Float32Array.from(vector));
}
});
return new VectorQuery(nativeQuery);
}
if (vector instanceof Promise) {
const res = (async () => {
try {
const v = await vector;
const arr = Float32Array.from(v);
//
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
const value: any = this.nearestTo(arr);
const inner = value.inner as
| NativeVectorQuery
| Promise<NativeVectorQuery>;
return inner;
} catch (e) {
return Promise.reject(e);
}
})();
return new VectorQuery(res);
} else {
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
return new VectorQuery(vectorQuery);
}
}
}

View File

@@ -122,9 +122,8 @@ export class RemoteTable extends Table {
query(): import("..").Query {
throw new Error("query() is not yet supported on the LanceDB cloud");
}
search(query: IntoVector): VectorQuery;
search(query: string): Promise<VectorQuery>;
search(_query: string | IntoVector): VectorQuery | Promise<VectorQuery> {
search(_query: string | IntoVector): VectorQuery {
throw new Error("search() is not yet supported on the LanceDB cloud");
}
vectorSearch(_vector: unknown): import("..").VectorQuery {

View File

@@ -244,9 +244,9 @@ export abstract class Table {
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
* @rejects {Error} If no embedding functions are defined in the table
* @note If no embedding functions are defined in the table, this will error when collecting the results.
*/
abstract search(query: string): Promise<VectorQuery>;
abstract search(query: string): VectorQuery;
/**
* Create a search query to find the nearest neighbors
* of the given query vector
@@ -502,28 +502,26 @@ export class LocalTable extends Table {
query(): Query {
return new Query(this.inner);
}
search(query: string): Promise<VectorQuery>;
search(query: IntoVector): VectorQuery;
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
search(query: string | IntoVector): VectorQuery {
if (typeof query !== "string") {
return this.vectorSearch(query);
} else {
return this.getEmbeddingFunctions().then(async (functions) => {
// TODO: Support multiple embedding functions
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
.values()
.next().value;
if (!embeddingFunc) {
return Promise.reject(
new Error("No embedding functions are defined in the table"),
);
}
const embeddings =
await embeddingFunc.function.computeQueryEmbeddings(query);
return this.query().nearestTo(embeddings);
});
const queryPromise = this.getEmbeddingFunctions().then(
async (functions) => {
// TODO: Support multiple embedding functions
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
.values()
.next().value;
if (!embeddingFunc) {
return Promise.reject(
new Error("No embedding functions are defined in the table"),
);
}
return await embeddingFunc.function.computeQueryEmbeddings(query);
},
);
return this.query().nearestTo(queryPromise);
}
}