feat(nodejs): table.search functionality (#1341)

closes https://github.com/lancedb/lancedb/issues/1256
This commit is contained in:
Cory Grinstead
2024-06-04 14:04:03 -05:00
committed by GitHub
parent d9fb6457e1
commit 70f92f19a6
6 changed files with 163 additions and 10 deletions

View File

@@ -42,6 +42,8 @@ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
export * from "apache-arrow";
export type IntoVector = Float32Array | Float64Array | number[];
export function isArrowTable(value: object): value is ArrowTable {
if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value;

View File

@@ -19,6 +19,7 @@ import {
FixedSizeList,
Float,
Float32,
type IntoVector,
isDataType,
isFixedSizeList,
isFloat,
@@ -169,9 +170,7 @@ export abstract class EmbeddingFunction<
/**
Compute the embeddings for a single query
*/
async computeQueryEmbeddings(
data: T,
): Promise<number[] | Float32Array | Float64Array> {
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
return this.computeSourceEmbeddings([data]).then(
(embeddings) => embeddings[0],
);

View File

@@ -42,6 +42,7 @@ export class EmbeddingFunctionRegistry {
* Register an embedding function
* @param name The name of the function
* @param func The function to register
* @throws Error if the function is already registered
*/
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
this: EmbeddingFunctionRegistry,
@@ -89,6 +90,9 @@ export class EmbeddingFunctionRegistry {
this.#functions.clear();
}
/**
* @ignore
*/
parseFunctions(
this: EmbeddingFunctionRegistry,
metadata: Map<string, string>,

View File

@@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow";
import {
Table as ArrowTable,
type IntoVector,
RecordBatch,
tableFromIPC,
} from "./arrow";
import { type IvfPqOptions } from "./indices";
import {
RecordBatchIterator as NativeBatchIterator,
@@ -108,9 +113,12 @@ export class QueryBase<
* object insertion order is easy to get wrong and `Map` is more foolproof.
*/
select(
columns: string[] | Map<string, string> | Record<string, string>,
columns: string[] | Map<string, string> | Record<string, string> | string,
): QueryType {
let columnTuples: [string, string][];
if (typeof columns === "string") {
columns = [columns];
}
if (Array.isArray(columns)) {
columnTuples = columns.map((c) => [c, c]);
} else if (columns instanceof Map) {
@@ -370,9 +378,8 @@ export class Query extends QueryBase<NativeQuery, Query> {
* Vector searches always have a `limit`. If `limit` has not been called then
* a default `limit` of 10 will be used. @see {@link Query#limit}
*/
nearestTo(vector: unknown): VectorQuery {
// biome-ignore lint/suspicious/noExplicitAny: skip
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
nearestTo(vector: IntoVector): VectorQuery {
  // The input is copied into a Float32Array, passed to the native query's
  // `nearestTo`, and the native result is wrapped in a VectorQuery.
  return new VectorQuery(this.inner.nearestTo(Float32Array.from(vector)));
}
}

View File

@@ -11,15 +11,17 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import {
Table as ArrowTable,
Data,
IntoVector,
Schema,
fromDataToBuffer,
tableFromIPC,
} from "./arrow";
import { getRegistry } from "./embedding/registry";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { IndexOptions } from "./indices";
import {
AddColumnsSql,
@@ -115,6 +117,14 @@ export class Table {
return this.inner.display();
}
/**
 * Load the table schema and let the embedding registry parse the embedding
 * function configurations out of the schema's metadata.
 */
async #getEmbeddingFunctions(): Promise<
  Map<string, EmbeddingFunctionConfig>
> {
  const tableSchema = await this.schema();
  return getRegistry().parseFunctions(tableSchema.metadata);
}
/** Get the schema of the table. */
async schema(): Promise<Schema> {
const schemaBuf = await this.inner.schema();
@@ -276,6 +286,40 @@ export class Table {
return new Query(this.inner);
}
/**
 * Build a nearest-neighbor search from a text query.
 *
 * The text is first converted to a vector with an embedding function parsed
 * from the table's schema metadata, so this overload is asynchronous.
 * @param {string} query - the text to embed and search for
 * @rejects {Error} If no embedding functions are defined in the table
 */
search(query: string): Promise<VectorQuery>;
/**
 * Build a nearest-neighbor search from a query vector.
 * @param {IntoVector} query - the query vector
 */
search(query: IntoVector): VectorQuery;
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
  if (typeof query === "string") {
    return this.#getEmbeddingFunctions().then(async (configs) => {
      // TODO: Support multiple embedding functions
      // Only the first configuration yielded by the map is used for now.
      const firstConfig: EmbeddingFunctionConfig | undefined = configs
        .values()
        .next().value;
      if (firstConfig) {
        const queryVector =
          await firstConfig.function.computeQueryEmbeddings(query);
        return this.query().nearestTo(queryVector);
      }
      return Promise.reject(
        new Error("No embedding functions are defined in the table"),
      );
    });
  }
  // A vector needs no embedding step, so this path stays synchronous.
  return this.vectorSearch(query);
}
/**
* Search the table with a given query vector.
*
@@ -283,7 +327,7 @@ export class Table {
* is the same thing as calling `nearestTo` on the builder returned
* by `query`. @see {@link Query#nearestTo} for more details.
*/
vectorSearch(vector: unknown): VectorQuery {
vectorSearch(vector: IntoVector): VectorQuery {
  // Start from a fresh query builder and convert it into a vector search.
  const builder = this.query();
  return builder.nearestTo(vector);
}