mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 21:02:58 +00:00
feat(nodejs): make tbl.search chainable (#1421)
so this was annoying me when writing the docs.
for a `search` query, one needed to chain `async` calls.
```ts
const res = await (await tbl.search("greetings")).toArray()
```
now the promise will be deferred until the query is collected, leading
to a more functional API
```ts
const res = await tbl.search("greetings").toArray()
```
This commit is contained in:
@@ -706,10 +706,10 @@ describe("table.search", () => {
|
||||
const data = [{ text: "hello world" }, { text: "goodbye world" }];
|
||||
const table = await db.createTable("test", data, { schema });
|
||||
|
||||
const results = await table.search("greetings").then((r) => r.toArray());
|
||||
const results = await table.search("greetings").toArray();
|
||||
expect(results[0].text).toBe(data[0].text);
|
||||
|
||||
const results2 = await table.search("farewell").then((r) => r.toArray());
|
||||
const results2 = await table.search("farewell").toArray();
|
||||
expect(results2[0].text).toBe(data[1].text);
|
||||
});
|
||||
|
||||
@@ -721,7 +721,7 @@ describe("table.search", () => {
|
||||
];
|
||||
const table = await db.createTable("test", data);
|
||||
|
||||
expect(table.search("hello")).rejects.toThrow(
|
||||
expect(table.search("hello").toArray()).rejects.toThrow(
|
||||
"No embedding functions are defined in the table",
|
||||
);
|
||||
});
|
||||
|
||||
@@ -97,7 +97,11 @@ export type TableLike =
|
||||
| ArrowTable
|
||||
| { schema: SchemaLike; batches: RecordBatchLike[] };
|
||||
|
||||
export type IntoVector = Float32Array | Float64Array | number[];
|
||||
export type IntoVector =
|
||||
| Float32Array
|
||||
| Float64Array
|
||||
| number[]
|
||||
| Promise<Float32Array | Float64Array | number[]>;
|
||||
|
||||
export function isArrowTable(value: object): value is TableLike {
|
||||
if (value instanceof ArrowTable) return true;
|
||||
|
||||
@@ -181,7 +181,7 @@ export abstract class EmbeddingFunction<
|
||||
/**
|
||||
Compute the embeddings for a single query
|
||||
*/
|
||||
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
|
||||
async computeQueryEmbeddings(data: T): Promise<Awaited<IntoVector>> {
|
||||
return this.computeSourceEmbeddings([data]).then(
|
||||
(embeddings) => embeddings[0],
|
||||
);
|
||||
|
||||
@@ -89,15 +89,26 @@ export interface QueryExecutionOptions {
|
||||
}
|
||||
|
||||
/** Common methods supported by all query types */
|
||||
export class QueryBase<
|
||||
NativeQueryType extends NativeQuery | NativeVectorQuery,
|
||||
QueryType,
|
||||
> implements AsyncIterable<RecordBatch>
|
||||
export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
|
||||
implements AsyncIterable<RecordBatch>
|
||||
{
|
||||
protected constructor(protected inner: NativeQueryType) {
|
||||
protected constructor(
|
||||
protected inner: NativeQueryType | Promise<NativeQueryType>,
|
||||
) {
|
||||
// intentionally empty
|
||||
}
|
||||
|
||||
// call a function on the inner (either a promise or the actual object)
|
||||
protected doCall(fn: (inner: NativeQueryType) => void) {
|
||||
if (this.inner instanceof Promise) {
|
||||
this.inner = this.inner.then((inner) => {
|
||||
fn(inner);
|
||||
return inner;
|
||||
});
|
||||
} else {
|
||||
fn(this.inner);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* A filter statement to be applied to this query.
|
||||
*
|
||||
@@ -110,16 +121,16 @@ export class QueryBase<
|
||||
* Filtering performance can often be improved by creating a scalar index
|
||||
* on the filter column(s).
|
||||
*/
|
||||
where(predicate: string): QueryType {
|
||||
this.inner.onlyIf(predicate);
|
||||
return this as unknown as QueryType;
|
||||
where(predicate: string): this {
|
||||
this.doCall((inner: NativeQueryType) => inner.onlyIf(predicate));
|
||||
return this;
|
||||
}
|
||||
/**
|
||||
* A filter statement to be applied to this query.
|
||||
* @alias where
|
||||
* @deprecated Use `where` instead
|
||||
*/
|
||||
filter(predicate: string): QueryType {
|
||||
filter(predicate: string): this {
|
||||
return this.where(predicate);
|
||||
}
|
||||
|
||||
@@ -155,7 +166,7 @@ export class QueryBase<
|
||||
*/
|
||||
select(
|
||||
columns: string[] | Map<string, string> | Record<string, string> | string,
|
||||
): QueryType {
|
||||
): this {
|
||||
let columnTuples: [string, string][];
|
||||
if (typeof columns === "string") {
|
||||
columns = [columns];
|
||||
@@ -167,8 +178,10 @@ export class QueryBase<
|
||||
} else {
|
||||
columnTuples = Object.entries(columns);
|
||||
}
|
||||
this.inner.select(columnTuples);
|
||||
return this as unknown as QueryType;
|
||||
this.doCall((inner: NativeQueryType) => {
|
||||
inner.select(columnTuples);
|
||||
});
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -177,15 +190,19 @@ export class QueryBase<
|
||||
* By default, a plain search has no limit. If this method is not
|
||||
* called then every valid row from the table will be returned.
|
||||
*/
|
||||
limit(limit: number): QueryType {
|
||||
this.inner.limit(limit);
|
||||
return this as unknown as QueryType;
|
||||
limit(limit: number): this {
|
||||
this.doCall((inner: NativeQueryType) => inner.limit(limit));
|
||||
return this;
|
||||
}
|
||||
|
||||
protected nativeExecute(
|
||||
options?: Partial<QueryExecutionOptions>,
|
||||
): Promise<NativeBatchIterator> {
|
||||
return this.inner.execute(options?.maxBatchLength);
|
||||
if (this.inner instanceof Promise) {
|
||||
return this.inner.then((inner) => inner.execute(options?.maxBatchLength));
|
||||
} else {
|
||||
return this.inner.execute(options?.maxBatchLength);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -214,7 +231,13 @@ export class QueryBase<
|
||||
/** Collect the results as an Arrow @see {@link ArrowTable}. */
|
||||
async toArrow(options?: Partial<QueryExecutionOptions>): Promise<ArrowTable> {
|
||||
const batches = [];
|
||||
for await (const batch of new RecordBatchIterable(this.inner, options)) {
|
||||
let inner;
|
||||
if (this.inner instanceof Promise) {
|
||||
inner = await this.inner;
|
||||
} else {
|
||||
inner = this.inner;
|
||||
}
|
||||
for await (const batch of new RecordBatchIterable(inner, options)) {
|
||||
batches.push(batch);
|
||||
}
|
||||
return new ArrowTable(batches);
|
||||
@@ -258,8 +281,8 @@ export interface ExecutableQuery {}
|
||||
*
|
||||
* This builder can be reused to execute the query many times.
|
||||
*/
|
||||
export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
constructor(inner: NativeVectorQuery) {
|
||||
export class VectorQuery extends QueryBase<NativeVectorQuery> {
|
||||
constructor(inner: NativeVectorQuery | Promise<NativeVectorQuery>) {
|
||||
super(inner);
|
||||
}
|
||||
|
||||
@@ -286,7 +309,8 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* you the desired recall.
|
||||
*/
|
||||
nprobes(nprobes: number): VectorQuery {
|
||||
this.inner.nprobes(nprobes);
|
||||
super.doCall((inner) => inner.nprobes(nprobes));
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -300,7 +324,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* whose data type is a fixed-size-list of floats.
|
||||
*/
|
||||
column(column: string): VectorQuery {
|
||||
this.inner.column(column);
|
||||
super.doCall((inner) => inner.column(column));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -321,7 +345,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
distanceType(
|
||||
distanceType: Required<IvfPqOptions>["distanceType"],
|
||||
): VectorQuery {
|
||||
this.inner.distanceType(distanceType);
|
||||
super.doCall((inner) => inner.distanceType(distanceType));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -355,7 +379,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* distance between the query vector and the actual uncompressed vector.
|
||||
*/
|
||||
refineFactor(refineFactor: number): VectorQuery {
|
||||
this.inner.refineFactor(refineFactor);
|
||||
super.doCall((inner) => inner.refineFactor(refineFactor));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -380,7 +404,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* factor can often help restore some of the results lost by post filtering.
|
||||
*/
|
||||
postfilter(): VectorQuery {
|
||||
this.inner.postfilter();
|
||||
super.doCall((inner) => inner.postfilter());
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -394,13 +418,13 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* calculate your recall to select an appropriate value for nprobes.
|
||||
*/
|
||||
bypassVectorIndex(): VectorQuery {
|
||||
this.inner.bypassVectorIndex();
|
||||
super.doCall((inner) => inner.bypassVectorIndex());
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
/** A builder for LanceDB queries. */
|
||||
export class Query extends QueryBase<NativeQuery, Query> {
|
||||
export class Query extends QueryBase<NativeQuery> {
|
||||
constructor(tbl: NativeTable) {
|
||||
super(tbl.query());
|
||||
}
|
||||
@@ -443,7 +467,37 @@ export class Query extends QueryBase<NativeQuery, Query> {
|
||||
* a default `limit` of 10 will be used. @see {@link Query#limit}
|
||||
*/
|
||||
nearestTo(vector: IntoVector): VectorQuery {
|
||||
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
|
||||
return new VectorQuery(vectorQuery);
|
||||
if (this.inner instanceof Promise) {
|
||||
const nativeQuery = this.inner.then(async (inner) => {
|
||||
if (vector instanceof Promise) {
|
||||
const arr = await vector.then((v) => Float32Array.from(v));
|
||||
return inner.nearestTo(arr);
|
||||
} else {
|
||||
return inner.nearestTo(Float32Array.from(vector));
|
||||
}
|
||||
});
|
||||
return new VectorQuery(nativeQuery);
|
||||
}
|
||||
if (vector instanceof Promise) {
|
||||
const res = (async () => {
|
||||
try {
|
||||
const v = await vector;
|
||||
const arr = Float32Array.from(v);
|
||||
//
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||
const value: any = this.nearestTo(arr);
|
||||
const inner = value.inner as
|
||||
| NativeVectorQuery
|
||||
| Promise<NativeVectorQuery>;
|
||||
return inner;
|
||||
} catch (e) {
|
||||
return Promise.reject(e);
|
||||
}
|
||||
})();
|
||||
return new VectorQuery(res);
|
||||
} else {
|
||||
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
|
||||
return new VectorQuery(vectorQuery);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,9 +122,8 @@ export class RemoteTable extends Table {
|
||||
query(): import("..").Query {
|
||||
throw new Error("query() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
search(query: IntoVector): VectorQuery;
|
||||
search(query: string): Promise<VectorQuery>;
|
||||
search(_query: string | IntoVector): VectorQuery | Promise<VectorQuery> {
|
||||
|
||||
search(_query: string | IntoVector): VectorQuery {
|
||||
throw new Error("search() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
vectorSearch(_vector: unknown): import("..").VectorQuery {
|
||||
|
||||
@@ -244,9 +244,9 @@ export abstract class Table {
|
||||
* Create a search query to find the nearest neighbors
|
||||
* of the given query vector
|
||||
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
|
||||
* @rejects {Error} If no embedding functions are defined in the table
|
||||
* @note If no embedding functions are defined in the table, this will error when collecting the results.
|
||||
*/
|
||||
abstract search(query: string): Promise<VectorQuery>;
|
||||
abstract search(query: string): VectorQuery;
|
||||
/**
|
||||
* Create a search query to find the nearest neighbors
|
||||
* of the given query vector
|
||||
@@ -502,28 +502,26 @@ export class LocalTable extends Table {
|
||||
query(): Query {
|
||||
return new Query(this.inner);
|
||||
}
|
||||
|
||||
search(query: string): Promise<VectorQuery>;
|
||||
|
||||
search(query: IntoVector): VectorQuery;
|
||||
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
|
||||
search(query: string | IntoVector): VectorQuery {
|
||||
if (typeof query !== "string") {
|
||||
return this.vectorSearch(query);
|
||||
} else {
|
||||
return this.getEmbeddingFunctions().then(async (functions) => {
|
||||
// TODO: Support multiple embedding functions
|
||||
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
|
||||
.values()
|
||||
.next().value;
|
||||
if (!embeddingFunc) {
|
||||
return Promise.reject(
|
||||
new Error("No embedding functions are defined in the table"),
|
||||
);
|
||||
}
|
||||
const embeddings =
|
||||
await embeddingFunc.function.computeQueryEmbeddings(query);
|
||||
return this.query().nearestTo(embeddings);
|
||||
});
|
||||
const queryPromise = this.getEmbeddingFunctions().then(
|
||||
async (functions) => {
|
||||
// TODO: Support multiple embedding functions
|
||||
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
|
||||
.values()
|
||||
.next().value;
|
||||
if (!embeddingFunc) {
|
||||
return Promise.reject(
|
||||
new Error("No embedding functions are defined in the table"),
|
||||
);
|
||||
}
|
||||
return await embeddingFunc.function.computeQueryEmbeddings(query);
|
||||
},
|
||||
);
|
||||
|
||||
return this.query().nearestTo(queryPromise);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user