mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-25 16:00:49 +00:00
feat(nodejs): table.search functionality (#1341)
closes https://github.com/lancedb/lancedb/issues/1256
This commit is contained in:
@@ -42,6 +42,8 @@ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
||||
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
|
||||
export * from "apache-arrow";
|
||||
|
||||
export type IntoVector = Float32Array | Float64Array | number[];
|
||||
|
||||
export function isArrowTable(value: object): value is ArrowTable {
|
||||
if (value instanceof ArrowTable) return true;
|
||||
return "schema" in value && "batches" in value;
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
FixedSizeList,
|
||||
Float,
|
||||
Float32,
|
||||
type IntoVector,
|
||||
isDataType,
|
||||
isFixedSizeList,
|
||||
isFloat,
|
||||
@@ -169,9 +170,7 @@ export abstract class EmbeddingFunction<
|
||||
/**
|
||||
Compute the embeddings for a single query
|
||||
*/
|
||||
async computeQueryEmbeddings(
|
||||
data: T,
|
||||
): Promise<number[] | Float32Array | Float64Array> {
|
||||
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
|
||||
return this.computeSourceEmbeddings([data]).then(
|
||||
(embeddings) => embeddings[0],
|
||||
);
|
||||
|
||||
@@ -42,6 +42,7 @@ export class EmbeddingFunctionRegistry {
|
||||
* Register an embedding function
|
||||
* @param name The name of the function
|
||||
* @param func The function to register
|
||||
* @throws Error if the function is already registered
|
||||
*/
|
||||
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
|
||||
this: EmbeddingFunctionRegistry,
|
||||
@@ -89,6 +90,9 @@ export class EmbeddingFunctionRegistry {
|
||||
this.#functions.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* @ignore
|
||||
*/
|
||||
parseFunctions(
|
||||
this: EmbeddingFunctionRegistry,
|
||||
metadata: Map<string, string>,
|
||||
|
||||
@@ -12,7 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow";
|
||||
import {
|
||||
Table as ArrowTable,
|
||||
type IntoVector,
|
||||
RecordBatch,
|
||||
tableFromIPC,
|
||||
} from "./arrow";
|
||||
import { type IvfPqOptions } from "./indices";
|
||||
import {
|
||||
RecordBatchIterator as NativeBatchIterator,
|
||||
@@ -108,9 +113,12 @@ export class QueryBase<
|
||||
* object insertion order is easy to get wrong and `Map` is more foolproof.
|
||||
*/
|
||||
select(
|
||||
columns: string[] | Map<string, string> | Record<string, string>,
|
||||
columns: string[] | Map<string, string> | Record<string, string> | string,
|
||||
): QueryType {
|
||||
let columnTuples: [string, string][];
|
||||
if (typeof columns === "string") {
|
||||
columns = [columns];
|
||||
}
|
||||
if (Array.isArray(columns)) {
|
||||
columnTuples = columns.map((c) => [c, c]);
|
||||
} else if (columns instanceof Map) {
|
||||
@@ -370,9 +378,8 @@ export class Query extends QueryBase<NativeQuery, Query> {
|
||||
* Vector searches always have a `limit`. If `limit` has not been called then
|
||||
* a default `limit` of 10 will be used. @see {@link Query#limit}
|
||||
*/
|
||||
nearestTo(vector: unknown): VectorQuery {
|
||||
// biome-ignore lint/suspicious/noExplicitAny: skip
|
||||
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
|
||||
nearestTo(vector: IntoVector): VectorQuery {
|
||||
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
|
||||
return new VectorQuery(vectorQuery);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,15 +11,17 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
Table as ArrowTable,
|
||||
Data,
|
||||
IntoVector,
|
||||
Schema,
|
||||
fromDataToBuffer,
|
||||
tableFromIPC,
|
||||
} from "./arrow";
|
||||
|
||||
import { getRegistry } from "./embedding/registry";
|
||||
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
||||
import { IndexOptions } from "./indices";
|
||||
import {
|
||||
AddColumnsSql,
|
||||
@@ -115,6 +117,14 @@ export class Table {
|
||||
return this.inner.display();
|
||||
}
|
||||
|
||||
async #getEmbeddingFunctions(): Promise<
|
||||
Map<string, EmbeddingFunctionConfig>
|
||||
> {
|
||||
const schema = await this.schema();
|
||||
const registry = getRegistry();
|
||||
return registry.parseFunctions(schema.metadata);
|
||||
}
|
||||
|
||||
/** Get the schema of the table. */
|
||||
async schema(): Promise<Schema> {
|
||||
const schemaBuf = await this.inner.schema();
|
||||
@@ -276,6 +286,40 @@ export class Table {
|
||||
return new Query(this.inner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a search query to find the nearest neighbors
|
||||
* of the given query vector
|
||||
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
|
||||
* @rejects {Error} If no embedding functions are defined in the table
|
||||
*/
|
||||
search(query: string): Promise<VectorQuery>;
|
||||
/**
|
||||
* Create a search query to find the nearest neighbors
|
||||
* of the given query vector
|
||||
* @param {IntoVector} query - the query vector
|
||||
*/
|
||||
search(query: IntoVector): VectorQuery;
|
||||
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
|
||||
if (typeof query !== "string") {
|
||||
return this.vectorSearch(query);
|
||||
} else {
|
||||
return this.#getEmbeddingFunctions().then(async (functions) => {
|
||||
// TODO: Support multiple embedding functions
|
||||
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
|
||||
.values()
|
||||
.next().value;
|
||||
if (!embeddingFunc) {
|
||||
return Promise.reject(
|
||||
new Error("No embedding functions are defined in the table"),
|
||||
);
|
||||
}
|
||||
const embeddings =
|
||||
await embeddingFunc.function.computeQueryEmbeddings(query);
|
||||
return this.query().nearestTo(embeddings);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Search the table with a given query vector.
|
||||
*
|
||||
@@ -283,7 +327,7 @@ export class Table {
|
||||
* is the same thing as calling `nearestTo` on the builder returned
|
||||
* by `query`. @see {@link Query#nearestTo} for more details.
|
||||
*/
|
||||
vectorSearch(vector: unknown): VectorQuery {
|
||||
vectorSearch(vector: IntoVector): VectorQuery {
|
||||
return this.query().nearestTo(vector);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user