diff --git a/node/src/sanitize.ts b/node/src/sanitize.ts index 0788f41e..5a10b5d7 100644 --- a/node/src/sanitize.ts +++ b/node/src/sanitize.ts @@ -79,7 +79,7 @@ import { import type { IntBitWidth, TimeBitWidth } from "apache-arrow/type"; function sanitizeMetadata( - metadataLike?: unknown + metadataLike?: unknown, ): Map | undefined { if (metadataLike === undefined || metadataLike === null) { return undefined; @@ -90,7 +90,7 @@ function sanitizeMetadata( for (const item of metadataLike) { if (!(typeof item[0] === "string" || !(typeof item[1] === "string"))) { throw Error( - "Expected metadata, if present, to be a Map but it had non-string keys or values" + "Expected metadata, if present, to be a Map but it had non-string keys or values", ); } } @@ -105,7 +105,7 @@ function sanitizeInt(typeLike: object) { typeof typeLike.isSigned !== "boolean" ) { throw Error( - "Expected an Int Type to have a `bitWidth` and `isSigned` property" + "Expected an Int Type to have a `bitWidth` and `isSigned` property", ); } return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth); @@ -128,7 +128,7 @@ function sanitizeDecimal(typeLike: object) { typeof typeLike.bitWidth !== "number" ) { throw Error( - "Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties" + "Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties", ); } return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth); @@ -149,7 +149,7 @@ function sanitizeTime(typeLike: object) { typeof typeLike.bitWidth !== "number" ) { throw Error( - "Expected a Time type to have `unit` and `bitWidth` properties" + "Expected a Time type to have `unit` and `bitWidth` properties", ); } return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth); @@ -172,7 +172,7 @@ function sanitizeTypedTimestamp( | typeof TimestampNanosecond | typeof TimestampMicrosecond | typeof TimestampMillisecond - | typeof TimestampSecond + | typeof TimestampSecond, ) { let timezone = null; 
if ("timezone" in typeLike && typeof typeLike.timezone === "string") { @@ -191,7 +191,7 @@ function sanitizeInterval(typeLike: object) { function sanitizeList(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a List type to have an array-like `children` property" + "Expected a List type to have an array-like `children` property", ); } if (typeLike.children.length !== 1) { @@ -203,7 +203,7 @@ function sanitizeList(typeLike: object) { function sanitizeStruct(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a Struct type to have an array-like `children` property" + "Expected a Struct type to have an array-like `children` property", ); } return new Struct(typeLike.children.map((child) => sanitizeField(child))); @@ -216,47 +216,47 @@ function sanitizeUnion(typeLike: object) { typeof typeLike.mode !== "number" ) { throw Error( - "Expected a Union type to have `typeIds` and `mode` properties" + "Expected a Union type to have `typeIds` and `mode` properties", ); } if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a Union type to have an array-like `children` property" + "Expected a Union type to have an array-like `children` property", ); } return new Union( typeLike.mode, typeLike.typeIds as any, - typeLike.children.map((child) => sanitizeField(child)) + typeLike.children.map((child) => sanitizeField(child)), ); } function sanitizeTypedUnion( typeLike: object, - UnionType: typeof DenseUnion | typeof SparseUnion + UnionType: typeof DenseUnion | typeof SparseUnion, ) { if (!("typeIds" in typeLike)) { throw Error( - "Expected a DenseUnion/SparseUnion type to have a `typeIds` property" + "Expected a DenseUnion/SparseUnion type to have a `typeIds` property", ); } if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a DenseUnion/SparseUnion type to have an 
array-like `children` property" + "Expected a DenseUnion/SparseUnion type to have an array-like `children` property", ); } return new UnionType( typeLike.typeIds as any, - typeLike.children.map((child) => sanitizeField(child)) + typeLike.children.map((child) => sanitizeField(child)), ); } function sanitizeFixedSizeBinary(typeLike: object) { if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") { throw Error( - "Expected a FixedSizeBinary type to have a `byteWidth` property" + "Expected a FixedSizeBinary type to have a `byteWidth` property", ); } return new FixedSizeBinary(typeLike.byteWidth); @@ -268,7 +268,7 @@ function sanitizeFixedSizeList(typeLike: object) { } if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a FixedSizeList type to have an array-like `children` property" + "Expected a FixedSizeList type to have an array-like `children` property", ); } if (typeLike.children.length !== 1) { @@ -276,14 +276,14 @@ function sanitizeFixedSizeList(typeLike: object) { } return new FixedSizeList( typeLike.listSize, - sanitizeField(typeLike.children[0]) + sanitizeField(typeLike.children[0]), ); } function sanitizeMap(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a Map type to have an array-like `children` property" + "Expected a Map type to have an array-like `children` property", ); } if (!("keysSorted" in typeLike) || typeof typeLike.keysSorted !== "boolean") { @@ -291,7 +291,7 @@ function sanitizeMap(typeLike: object) { } return new Map_( typeLike.children.map((field) => sanitizeField(field)) as any, - typeLike.keysSorted + typeLike.keysSorted, ); } @@ -319,7 +319,7 @@ function sanitizeDictionary(typeLike: object) { sanitizeType(typeLike.dictionary), sanitizeType(typeLike.indices) as any, typeLike.id, - typeLike.isOrdered + typeLike.isOrdered, ); } @@ -454,7 +454,7 @@ function sanitizeField(fieldLike: unknown): Field { !("nullable" in 
fieldLike) ) { throw Error( - "The field passed in is missing a `type`/`name`/`nullable` property" + "The field passed in is missing a `type`/`name`/`nullable` property", ); } const type = sanitizeType(fieldLike.type); @@ -473,6 +473,13 @@ function sanitizeField(fieldLike: unknown): Field { return new Field(name, type, nullable, metadata); } +/** + * Convert something schemaLike into a Schema instance + * + * This method is often needed even when the caller is using a Schema + * instance because they might be using a different instance of apache-arrow + * than lancedb is using. + */ export function sanitizeSchema(schemaLike: unknown): Schema { if (schemaLike instanceof Schema) { return schemaLike; @@ -482,7 +489,7 @@ export function sanitizeSchema(schemaLike: unknown): Schema { } if (!("fields" in schemaLike)) { throw Error( - "The schema passed in does not appear to be a schema (no 'fields' property)" + "The schema passed in does not appear to be a schema (no 'fields' property)", ); } let metadata; @@ -491,11 +498,11 @@ export function sanitizeSchema(schemaLike: unknown): Schema { } if (!Array.isArray(schemaLike.fields)) { throw Error( - "The schema passed in had a 'fields' property but it was not an array" + "The schema passed in had a 'fields' property but it was not an array", ); } const sanitizedFields = schemaLike.fields.map((field) => - sanitizeField(field) + sanitizeField(field), ); return new Schema(sanitizedFields, metadata); } diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index d4ce0c6d..3a673675 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -129,11 +129,25 @@ describe("When creating an index", () => { }); // Search without specifying the column - const rst = await tbl.query().nearestTo(queryVec).limit(2).toArrow(); + let rst = await tbl + .query() + .limit(2) + .nearestTo(queryVec) + .distanceType("DoT") + .toArrow(); + expect(rst.numRows).toBe(2); + + // Search using `vectorSearch` + rst 
= await tbl.vectorSearch(queryVec).limit(2).toArrow(); expect(rst.numRows).toBe(2); // Search with specifying the column - const rst2 = await tbl.search(queryVec, "vec").limit(2).toArrow(); + const rst2 = await tbl + .query() + .limit(2) + .nearestTo(queryVec) + .column("vec") + .toArrow(); expect(rst2.numRows).toBe(2); expect(rst.toString()).toEqual(rst2.toString()); }); @@ -163,7 +177,7 @@ describe("When creating an index", () => { const indexDir = path.join(tmpDir.name, "test.lance", "_indices"); expect(fs.readdirSync(indexDir)).toHaveLength(1); - for await (const r of tbl.query().filter("id > 1").select(["id"])) { + for await (const r of tbl.query().where("id > 1").select(["id"])) { expect(r.numRows).toBe(298); } }); @@ -205,33 +219,39 @@ describe("When creating an index", () => { const rst = await tbl .query() + .limit(2) .nearestTo( Array(32) .fill(1) .map(() => Math.random()), ) - .limit(2) .toArrow(); expect(rst.numRows).toBe(2); // Search with specifying the column await expect( tbl - .search( + .query() + .limit(2) + .nearestTo( Array(64) .fill(1) .map(() => Math.random()), - "vec", ) - .limit(2) + .column("vec") .toArrow(), ).rejects.toThrow(/.*does not match the dimension.*/); const query64 = Array(64) .fill(1) .map(() => Math.random()); - const rst64Query = await tbl.query().nearestTo(query64).limit(2).toArrow(); - const rst64Search = await tbl.search(query64, "vec2").limit(2).toArrow(); + const rst64Query = await tbl.query().limit(2).nearestTo(query64).toArrow(); + const rst64Search = await tbl + .query() + .limit(2) + .nearestTo(query64) + .column("vec2") + .toArrow(); expect(rst64Query.toString()).toEqual(rst64Search.toString()); expect(rst64Query.numRows).toBe(2); }); diff --git a/nodejs/eslint.config.js b/nodejs/eslint.config.js index 73afdbdc..2fc93379 100644 --- a/nodejs/eslint.config.js +++ b/nodejs/eslint.config.js @@ -4,14 +4,25 @@ const eslint = require("@eslint/js"); const tseslint = require("typescript-eslint"); const eslintConfigPrettier 
= require("eslint-config-prettier"); +const jsdoc = require("eslint-plugin-jsdoc"); module.exports = tseslint.config( eslint.configs.recommended, + jsdoc.configs["flat/recommended"], eslintConfigPrettier, ...tseslint.configs.recommended, { rules: { "@typescript-eslint/naming-convention": "error", + "jsdoc/require-returns": "off", + "jsdoc/require-param": "off", + "jsdoc/require-jsdoc": [ + "error", + { + publicOnly: true, + }, + ], }, + plugins: jsdoc, }, ); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 21c84b53..df4129b0 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -31,6 +31,7 @@ import { DataType, Binary, Float32, + type makeTable, } from "apache-arrow"; import { type EmbeddingFunction } from "./embedding/embedding_function"; import { sanitizeSchema } from "./sanitize"; @@ -128,14 +129,7 @@ export class MakeArrowTableOptions { * - Buffer => Binary * - Record => Struct * - Array => List - * - * @param data input data - * @param options options to control the makeArrowTable call. 
- * * @example - * - * ```ts - * * import { fromTableToBuffer, makeArrowTable } from "../arrow"; * import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow"; * @@ -307,7 +301,9 @@ export function makeEmptyTable(schema: Schema): ArrowTable { return makeArrowTable([], { schema }); } -// Helper function to convert Array> to a variable sized list array +/** + * Helper function to convert Array> to a variable sized list array + */ // @ts-expect-error (Vector is not assignable to Vector) function makeListVector(lists: unknown[][]): Vector { if (lists.length === 0 || lists[0].length === 0) { @@ -333,7 +329,7 @@ function makeListVector(lists: unknown[][]): Vector { return listBuilder.finish().toVector(); } -// Helper function to convert an Array of JS values to an Arrow Vector +/** Helper function to convert an Array of JS values to an Arrow Vector */ function makeVector( values: unknown[], type?: DataType, @@ -374,6 +370,7 @@ function makeVector( } } +/** Helper function to apply embeddings to an input table */ async function applyEmbeddings( table: ArrowTable, embeddings?: EmbeddingFunction, @@ -466,7 +463,7 @@ async function applyEmbeddings( return newTable; } -/* +/** * Convert an Array of records into an Arrow Table, optionally applying an * embeddings function to it. * @@ -493,7 +490,7 @@ export async function convertToTable( return await applyEmbeddings(table, embeddings, makeTableOptions?.schema); } -// Creates the Arrow Type for a Vector column with dimension `dim` +/** Creates the Arrow Type for a Vector column with dimension `dim` */ function newVectorType( dim: number, innerType: T, @@ -565,6 +562,14 @@ export async function fromTableToBuffer( return Buffer.from(await writer.toUint8Array()); } +/** + * Serialize an Arrow Table into a buffer using the Arrow IPC File serialization + * + * This function will apply `embeddings` to the table in a manner similar to + * `convertToTable`. 
+ * + * `schema` is required if the table is empty + */ export async function fromDataToBuffer( data: Data, embeddings?: EmbeddingFunction, @@ -599,6 +604,9 @@ export async function fromTableToStreamBuffer( return Buffer.from(await writer.toUint8Array()); } +/** + * Reorder the columns in `batch` so that they agree with the field order in `schema` + */ function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch { const alignedChildren = []; for (const field of schema.fields) { @@ -621,6 +629,9 @@ function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch { return new RecordBatch(schema, newData); } +/** + * Reorder the columns in `table` so that they agree with the field order in `schema` + */ function alignTable(table: ArrowTable, schema: Schema): ArrowTable { const alignedBatches = table.batches.map((batch) => alignBatch(batch, schema), @@ -628,7 +639,9 @@ function alignTable(table: ArrowTable, schema: Schema): ArrowTable { return new ArrowTable(schema, alignedBatches); } -// Creates an empty Arrow Table +/** + * Create an empty table with the given schema + */ export function createEmptyTable(schema: Schema): ArrowTable { return new ArrowTable(sanitizeSchema(schema)); } diff --git a/nodejs/lancedb/connection.ts b/nodejs/lancedb/connection.ts index b42bc3ba..5a8f6f32 100644 --- a/nodejs/lancedb/connection.ts +++ b/nodejs/lancedb/connection.ts @@ -78,7 +78,8 @@ export class Connection { return this.inner.isOpen(); } - /** Close the connection, releasing any underlying resources. + /** + * Close the connection, releasing any underlying resources. * * It is safe to call this method multiple times. * @@ -93,11 +94,12 @@ export class Connection { return this.inner.display(); } - /** List all the table names in this database. + /** + * List all the table names in this database. * * Tables will be returned in lexicographical order. - * - * @param options Optional parameters to control the listing. 
+ * @param {Partial} options - options to control the + * paging / start point */ async tableNames(options?: Partial): Promise { return this.inner.tableNames(options?.startAfter, options?.limit); @@ -105,9 +107,7 @@ export class Connection { /** * Open a table in the database. - * - * @param name The name of the table. - * @param embeddings An embedding function to use on this table + * @param {string} name - The name of the table */ async openTable(name: string): Promise { const innerTable = await this.inner.openTable(name); @@ -116,9 +116,9 @@ export class Connection { /** * Creates a new Table and initialize it with new data. - * * @param {string} name - The name of the table. - * @param data - Non-empty Array of Records to be inserted into the table + * @param {Record[] | ArrowTable} data - Non-empty Array of Records + * to be inserted into the table */ async createTable( name: string, @@ -145,9 +145,8 @@ export class Connection { /** * Creates a new empty Table - * * @param {string} name - The name of the table. - * @param schema - The schema of the table + * @param {Schema} schema - The schema of the table */ async createEmptyTable( name: string, @@ -169,7 +168,7 @@ export class Connection { /** * Drop an existing table. - * @param name The name of the table to drop. + * @param {string} name The name of the table to drop. 
*/ async dropTable(name: string): Promise { return this.inner.dropTable(name); diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index 4e046562..4e102453 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -62,6 +62,7 @@ export interface EmbeddingFunction { embed: (data: T[]) => Promise; } +/** Test if the input seems to be an embedding function */ export function isEmbeddingFunction( value: unknown, ): value is EmbeddingFunction { diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index 9ec464df..4877a6c2 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -30,9 +30,8 @@ export { Table, AddDataOptions } from "./table"; * - `/path/to/database` - local database * - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage * - `db://host:port` - remote database (LanceDB cloud) - * - * @param uri The uri of the database. If the database uri starts with `db://` then it connects to a remote database. - * + * @param {string} uri - The uri of the database. If the database uri starts + * with `db://` then it connects to a remote database. * @see {@link ConnectionOptions} for more details on the URI format. */ export async function connect( diff --git a/nodejs/lancedb/indices.ts b/nodejs/lancedb/indices.ts index c14d335e..cf5fab63 100644 --- a/nodejs/lancedb/indices.ts +++ b/nodejs/lancedb/indices.ts @@ -18,7 +18,8 @@ import { Index as LanceDbIndex } from "./native"; * Options to create an `IVF_PQ` index */ export interface IvfPqOptions { - /** The number of IVF partitions to create. + /** + * The number of IVF partitions to create. * * This value should generally scale with the number of rows in the dataset. 
* By default the number of partitions is the square root of the number of @@ -30,7 +31,8 @@ export interface IvfPqOptions { */ numPartitions?: number; - /** Number of sub-vectors of PQ. + /** + * Number of sub-vectors of PQ. * * This value controls how much the vector is compressed during the quantization step. * The more sub vectors there are the less the vector is compressed. The default is @@ -45,9 +47,10 @@ export interface IvfPqOptions { */ numSubVectors?: number; - /** [DistanceType] to use to build the index. + /** + * Distance type to use to build the index. * - * Default value is [DistanceType::L2]. + * Default value is "l2". * * This is used when training the index to calculate the IVF partitions * (vectors are grouped in partitions with similar vectors according to this @@ -79,7 +82,8 @@ export interface IvfPqOptions { */ distanceType?: "l2" | "cosine" | "dot"; - /** Max iteration to train IVF kmeans. + /** + * Max iteration to train IVF kmeans. * * When training an IVF PQ index we use kmeans to calculate the partitions. This parameter * controls how many iterations of kmeans to run. @@ -91,7 +95,8 @@ export interface IvfPqOptions { */ maxIterations?: number; - /** The number of vectors, per partition, to sample when training IVF kmeans. + /** + * The number of vectors, per partition, to sample when training IVF kmeans. * * When an IVF PQ index is trained, we need to calculate partitions. These are groups * of vectors that are similar to each other. To do this we use an algorithm called kmeans. @@ -148,7 +153,8 @@ export class Index { ); } - /** Create a btree index + /** + * Create a btree index * * A btree index is an index on a scalar columns. The index stores a copy of the column * in sorted order. 
A header entry is created for each block of rows (currently the @@ -172,7 +178,8 @@ export class Index { } export interface IndexOptions { - /** Advanced index configuration + /** + * Advanced index configuration * * This option allows you to specify a specfic index to create and also * allows you to pass in configuration for training the index. @@ -183,7 +190,8 @@ export interface IndexOptions { * will be used to determine the most useful kind of index to create. */ config?: Index; - /** Whether to replace the existing index + /** + * Whether to replace the existing index * * If this is false, and another index already exists on the same columns * and the same name, then an error will be returned. This is true even if diff --git a/nodejs/lancedb/native.d.ts b/nodejs/lancedb/native.d.ts index 9c639856..f025484b 100644 --- a/nodejs/lancedb/native.d.ts +++ b/nodejs/lancedb/native.d.ts @@ -105,15 +105,23 @@ export class RecordBatchIterator { next(): Promise } export class Query { - column(column: string): void - filter(filter: string): void - select(columns: Array): void + onlyIf(predicate: string): void + select(columns: Array<[string, string]>): void limit(limit: number): void - prefilter(prefilter: boolean): void - nearestTo(vector: Float32Array): void + nearestTo(vector: Float32Array): VectorQuery + execute(): Promise +} +export class VectorQuery { + column(column: string): void + distanceType(distanceType: string): void + postfilter(): void refineFactor(refineFactor: number): void nprobes(nprobe: number): void - executeStream(): Promise + bypassVectorIndex(): void + onlyIf(predicate: string): void + select(columns: Array<[string, string]>): void + limit(limit: number): void + execute(): Promise } export class Table { display(): string @@ -127,6 +135,7 @@ export class Table { createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise update(onlyIf: string | undefined | null, columns: Array<[string, string]>): 
Promise query(): Query + vectorSearch(vector: Float32Array): VectorQuery addColumns(transforms: Array): Promise alterColumns(alterations: Array): Promise dropColumns(columns: Array): Promise diff --git a/nodejs/lancedb/native.js b/nodejs/lancedb/native.js index 55eb6df5..e8ea9295 100644 --- a/nodejs/lancedb/native.js +++ b/nodejs/lancedb/native.js @@ -5,302 +5,325 @@ /* auto-generated by NAPI-RS */ const { existsSync, readFileSync } = require('fs') -const { join } = require('path') +const { join } = require("path"); -const { platform, arch } = process +const { platform, arch } = process; -let nativeBinding = null -let localFileExisted = false -let loadError = null +let nativeBinding = null; +let localFileExisted = false; +let loadError = null; function isMusl() { // For Node 10 - if (!process.report || typeof process.report.getReport !== 'function') { + if (!process.report || typeof process.report.getReport !== "function") { try { - const lddPath = require('child_process').execSync('which ldd').toString().trim() - return readFileSync(lddPath, 'utf8').includes('musl') + const lddPath = require("child_process") + .execSync("which ldd") + .toString() + .trim(); + return readFileSync(lddPath, "utf8").includes("musl"); } catch (e) { - return true + return true; } } else { - const { glibcVersionRuntime } = process.report.getReport().header - return !glibcVersionRuntime + const { glibcVersionRuntime } = process.report.getReport().header; + return !glibcVersionRuntime; } } switch (platform) { - case 'android': + case "android": switch (arch) { - case 'arm64': - localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm64.node')) + case "arm64": + localFileExisted = existsSync( + join(__dirname, "lancedb-nodejs.android-arm64.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.android-arm64.node') + nativeBinding = require("./lancedb-nodejs.android-arm64.node"); } else { - nativeBinding = require('lancedb-android-arm64') + 
nativeBinding = require("lancedb-android-arm64"); } } catch (e) { - loadError = e + loadError = e; } - break - case 'arm': - localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm-eabi.node')) + break; + case "arm": + localFileExisted = existsSync( + join(__dirname, "lancedb-nodejs.android-arm-eabi.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.android-arm-eabi.node') + nativeBinding = require("./lancedb-nodejs.android-arm-eabi.node"); } else { - nativeBinding = require('lancedb-android-arm-eabi') + nativeBinding = require("lancedb-android-arm-eabi"); } } catch (e) { - loadError = e + loadError = e; } - break + break; default: - throw new Error(`Unsupported architecture on Android ${arch}`) + throw new Error(`Unsupported architecture on Android ${arch}`); } - break - case 'win32': + break; + case "win32": switch (arch) { - case 'x64': + case "x64": localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.win32-x64-msvc.node') - ) + join(__dirname, "lancedb-nodejs.win32-x64-msvc.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.win32-x64-msvc.node') + nativeBinding = require("./lancedb-nodejs.win32-x64-msvc.node"); } else { - nativeBinding = require('lancedb-win32-x64-msvc') + nativeBinding = require("lancedb-win32-x64-msvc"); } } catch (e) { - loadError = e + loadError = e; } - break - case 'ia32': + break; + case "ia32": localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.win32-ia32-msvc.node') - ) + join(__dirname, "lancedb-nodejs.win32-ia32-msvc.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.win32-ia32-msvc.node') + nativeBinding = require("./lancedb-nodejs.win32-ia32-msvc.node"); } else { - nativeBinding = require('lancedb-win32-ia32-msvc') + nativeBinding = require("lancedb-win32-ia32-msvc"); } } catch (e) { - loadError = e + loadError = e; } - break - case 'arm64': + break; + case "arm64": 
localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.win32-arm64-msvc.node') - ) + join(__dirname, "lancedb-nodejs.win32-arm64-msvc.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.win32-arm64-msvc.node') + nativeBinding = require("./lancedb-nodejs.win32-arm64-msvc.node"); } else { - nativeBinding = require('lancedb-win32-arm64-msvc') + nativeBinding = require("lancedb-win32-arm64-msvc"); } } catch (e) { - loadError = e + loadError = e; } - break + break; default: - throw new Error(`Unsupported architecture on Windows: ${arch}`) + throw new Error(`Unsupported architecture on Windows: ${arch}`); } - break - case 'darwin': - localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-universal.node')) + break; + case "darwin": + localFileExisted = existsSync( + join(__dirname, "lancedb-nodejs.darwin-universal.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.darwin-universal.node') + nativeBinding = require("./lancedb-nodejs.darwin-universal.node"); } else { - nativeBinding = require('lancedb-darwin-universal') + nativeBinding = require("lancedb-darwin-universal"); } - break + break; } catch {} switch (arch) { - case 'x64': - localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-x64.node')) - try { - if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.darwin-x64.node') - } else { - nativeBinding = require('lancedb-darwin-x64') - } - } catch (e) { - loadError = e - } - break - case 'arm64': + case "x64": localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.darwin-arm64.node') - ) + join(__dirname, "lancedb-nodejs.darwin-x64.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.darwin-arm64.node') + nativeBinding = require("./lancedb-nodejs.darwin-x64.node"); } else { - nativeBinding = require('lancedb-darwin-arm64') + nativeBinding = require("lancedb-darwin-x64"); } } catch (e) { - loadError = e + 
loadError = e; } - break + break; + case "arm64": + localFileExisted = existsSync( + join(__dirname, "lancedb-nodejs.darwin-arm64.node"), + ); + try { + if (localFileExisted) { + nativeBinding = require("./lancedb-nodejs.darwin-arm64.node"); + } else { + nativeBinding = require("lancedb-darwin-arm64"); + } + } catch (e) { + loadError = e; + } + break; default: - throw new Error(`Unsupported architecture on macOS: ${arch}`) + throw new Error(`Unsupported architecture on macOS: ${arch}`); } - break - case 'freebsd': - if (arch !== 'x64') { - throw new Error(`Unsupported architecture on FreeBSD: ${arch}`) + break; + case "freebsd": + if (arch !== "x64") { + throw new Error(`Unsupported architecture on FreeBSD: ${arch}`); } - localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.freebsd-x64.node')) + localFileExisted = existsSync( + join(__dirname, "lancedb-nodejs.freebsd-x64.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.freebsd-x64.node') + nativeBinding = require("./lancedb-nodejs.freebsd-x64.node"); } else { - nativeBinding = require('lancedb-freebsd-x64') + nativeBinding = require("lancedb-freebsd-x64"); } } catch (e) { - loadError = e + loadError = e; } - break - case 'linux': + break; + case "linux": switch (arch) { - case 'x64': + case "x64": if (isMusl()) { localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.linux-x64-musl.node') - ) + join(__dirname, "lancedb-nodejs.linux-x64-musl.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.linux-x64-musl.node') + nativeBinding = require("./lancedb-nodejs.linux-x64-musl.node"); } else { - nativeBinding = require('lancedb-linux-x64-musl') + nativeBinding = require("lancedb-linux-x64-musl"); } } catch (e) { - loadError = e + loadError = e; } } else { localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.linux-x64-gnu.node') - ) + join(__dirname, "lancedb-nodejs.linux-x64-gnu.node"), + ); try { if (localFileExisted) 
{ - nativeBinding = require('./lancedb-nodejs.linux-x64-gnu.node') + nativeBinding = require("./lancedb-nodejs.linux-x64-gnu.node"); } else { - nativeBinding = require('lancedb-linux-x64-gnu') + nativeBinding = require("lancedb-linux-x64-gnu"); } } catch (e) { - loadError = e + loadError = e; } } - break - case 'arm64': + break; + case "arm64": if (isMusl()) { localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.linux-arm64-musl.node') - ) + join(__dirname, "lancedb-nodejs.linux-arm64-musl.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.linux-arm64-musl.node') + nativeBinding = require("./lancedb-nodejs.linux-arm64-musl.node"); } else { - nativeBinding = require('lancedb-linux-arm64-musl') + nativeBinding = require("lancedb-linux-arm64-musl"); } } catch (e) { - loadError = e + loadError = e; } } else { localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.linux-arm64-gnu.node') - ) + join(__dirname, "lancedb-nodejs.linux-arm64-gnu.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.linux-arm64-gnu.node') + nativeBinding = require("./lancedb-nodejs.linux-arm64-gnu.node"); } else { - nativeBinding = require('lancedb-linux-arm64-gnu') + nativeBinding = require("lancedb-linux-arm64-gnu"); } } catch (e) { - loadError = e + loadError = e; } } - break - case 'arm': + break; + case "arm": localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.linux-arm-gnueabihf.node') - ) + join(__dirname, "lancedb-nodejs.linux-arm-gnueabihf.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.linux-arm-gnueabihf.node') + nativeBinding = require("./lancedb-nodejs.linux-arm-gnueabihf.node"); } else { - nativeBinding = require('lancedb-linux-arm-gnueabihf') + nativeBinding = require("lancedb-linux-arm-gnueabihf"); } } catch (e) { - loadError = e + loadError = e; } - break - case 'riscv64': + break; + case "riscv64": if (isMusl()) { localFileExisted = 
existsSync( - join(__dirname, 'lancedb-nodejs.linux-riscv64-musl.node') - ) + join(__dirname, "lancedb-nodejs.linux-riscv64-musl.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.linux-riscv64-musl.node') + nativeBinding = require("./lancedb-nodejs.linux-riscv64-musl.node"); } else { - nativeBinding = require('lancedb-linux-riscv64-musl') + nativeBinding = require("lancedb-linux-riscv64-musl"); } } catch (e) { - loadError = e + loadError = e; } } else { localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.linux-riscv64-gnu.node') - ) + join(__dirname, "lancedb-nodejs.linux-riscv64-gnu.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.linux-riscv64-gnu.node') + nativeBinding = require("./lancedb-nodejs.linux-riscv64-gnu.node"); } else { - nativeBinding = require('lancedb-linux-riscv64-gnu') + nativeBinding = require("lancedb-linux-riscv64-gnu"); } } catch (e) { - loadError = e + loadError = e; } } - break - case 's390x': + break; + case "s390x": localFileExisted = existsSync( - join(__dirname, 'lancedb-nodejs.linux-s390x-gnu.node') - ) + join(__dirname, "lancedb-nodejs.linux-s390x-gnu.node"), + ); try { if (localFileExisted) { - nativeBinding = require('./lancedb-nodejs.linux-s390x-gnu.node') + nativeBinding = require("./lancedb-nodejs.linux-s390x-gnu.node"); } else { - nativeBinding = require('lancedb-linux-s390x-gnu') + nativeBinding = require("lancedb-linux-s390x-gnu"); } } catch (e) { - loadError = e + loadError = e; } - break + break; default: - throw new Error(`Unsupported architecture on Linux: ${arch}`) + throw new Error(`Unsupported architecture on Linux: ${arch}`); } - break + break; default: - throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`) + throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`); } if (!nativeBinding) { if (loadError) { - throw loadError + throw loadError; } - throw new Error(`Failed to load native binding`) + throw 
new Error(`Failed to load native binding`); } -const { Connection, Index, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding +const { + Connection, + Index, + RecordBatchIterator, + Query, + VectorQuery, + Table, + WriteMode, + connect, +} = nativeBinding; -module.exports.Connection = Connection -module.exports.Index = Index -module.exports.RecordBatchIterator = RecordBatchIterator -module.exports.Query = Query -module.exports.Table = Table -module.exports.WriteMode = WriteMode -module.exports.connect = connect +module.exports.Connection = Connection; +module.exports.Index = Index; +module.exports.RecordBatchIterator = RecordBatchIterator; +module.exports.Query = Query; +module.exports.VectorQuery = VectorQuery; +module.exports.Table = Table; +module.exports.WriteMode = WriteMode; +module.exports.connect = connect; diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index cd86a310..990f6370 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -17,18 +17,15 @@ import { RecordBatchIterator as NativeBatchIterator, Query as NativeQuery, Table as NativeTable, + VectorQuery as NativeVectorQuery, } from "./native"; - +import { type IvfPqOptions } from "./indices"; class RecordBatchIterator implements AsyncIterator { private promisedInner?: Promise; private inner?: NativeBatchIterator; - constructor( - inner?: NativeBatchIterator, - promise?: Promise, - ) { + constructor(promise?: Promise) { // TODO: check promise reliably so we dont need to pass two arguments. 
- this.inner = inner; this.promisedInner = promise; } @@ -53,82 +50,113 @@ class RecordBatchIterator implements AsyncIterator { } /* eslint-enable */ -/** Query executor */ -export class Query implements AsyncIterable { - private readonly inner: NativeQuery; +/** Common methods supported by all query types */ +export class QueryBase< + NativeQueryType extends NativeQuery | NativeVectorQuery, + QueryType, +> implements AsyncIterable +{ + protected constructor(protected inner: NativeQueryType) {} - constructor(tbl: NativeTable) { - this.inner = tbl.query(); + /** + * A filter statement to be applied to this query. + * + * The filter should be supplied as an SQL query string. For example: + * @example + * x > 10 + * y > 0 AND y < 100 + * x > 5 OR y = 'test' + * + * Filtering performance can often be improved by creating a scalar index + * on the filter column(s). + */ + where(predicate: string): QueryType { + this.inner.onlyIf(predicate); + return this as unknown as QueryType; } - /** Set the column to run query. */ - column(column: string): Query { - this.inner.column(column); - return this; + /** + * Return only the specified columns. + * + * By default a query will return all columns from the table. However, this can have + * a very significant impact on latency. LanceDb stores data in a columnar fashion. This + * means we can finely tune our I/O to select exactly the columns we need. + * + * As a best practice you should always limit queries to the columns that you need. If you + * pass in an array of column names then only those columns will be returned. + * + * You can also use this method to create new "dynamic" columns based on your existing columns. + * For example, you may not care about "a" or "b" but instead simply want "a + b". This is often + * seen in the SELECT clause of an SQL query (e.g. `SELECT a+b FROM my_table`). + * + * To create dynamic columns you can pass in a Map. A column will be returned + * for each entry in the map. 
The key provides the name of the column. The value is + * an SQL string used to specify how the column is calculated. + * + * For example, an SQL query might state `SELECT a + b AS combined, c`. The equivalent + * input to this method would be: + * @example + * new Map([["combined", "a + b"], ["c", "c"]]) + * + * Columns will always be returned in the order given, even if that order is different than + * the order used when adding the data. + * + * Note that you can pass in a `Record` (e.g. an object literal). This method + * uses `Object.entries` which should preserve the insertion order of the object. However, + * object insertion order is easy to get wrong and `Map` is more foolproof. + */ + select( + columns: string[] | Map | Record, + ): QueryType { + let columnTuples: [string, string][]; + if (Array.isArray(columns)) { + columnTuples = columns.map((c) => [c, c]); + } else if (columns instanceof Map) { + columnTuples = Array.from(columns.entries()); + } else { + columnTuples = Object.entries(columns); + } + this.inner.select(columnTuples); + return this as unknown as QueryType; } - /** Set the filter predicate, only returns the results that satisfy the filter. + /** + * Set the maximum number of results to return. + * + * By default, a plain search has no limit. If this method is not + * called then every valid row from the table will be returned. + */ + limit(limit: number): QueryType { + this.inner.limit(limit); + return this as unknown as QueryType; + } + + protected nativeExecute(): Promise { + return this.inner.execute(); + } + + /** + * Execute the query and return the results as an @see {@link AsyncIterator} + * of @see {@link RecordBatch}. + * + * By default, LanceDb will use many threads to calculate results and, when + * the result set is large, multiple batches will be processed at one time. 
+ * This readahead is limited however and backpressure will be applied if this + * stream is consumed slowly (this constrains the maximum memory used by a + * single query) * */ - filter(predicate: string): Query { - this.inner.filter(predicate); - return this; + protected execute(): RecordBatchIterator { + return new RecordBatchIterator(this.nativeExecute()); } - /** - * Select the columns to return. If not set, all columns are returned. - */ - select(columns: string[]): Query { - this.inner.select(columns); - return this; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [Symbol.asyncIterator](): AsyncIterator> { + const promise = this.nativeExecute(); + return new RecordBatchIterator(promise); } - /** - * Set the limit of rows to return. - */ - limit(limit: number): Query { - this.inner.limit(limit); - return this; - } - - prefilter(prefilter: boolean): Query { - this.inner.prefilter(prefilter); - return this; - } - - /** - * Set the query vector. - */ - nearestTo(vector: number[]): Query { - this.inner.nearestTo(Float32Array.from(vector)); - return this; - } - - /** - * Set the number of IVF partitions to use for the query. - */ - nprobes(nprobes: number): Query { - this.inner.nprobes(nprobes); - return this; - } - - /** - * Set the refine factor for the query. - */ - refineFactor(refineFactor: number): Query { - this.inner.refineFactor(refineFactor); - return this; - } - - /** - * Execute the query and return the results as an AsyncIterator. - */ - async executeStream(): Promise { - const inner = await this.inner.executeStream(); - return new RecordBatchIterator(inner); - } - - /** Collect the results as an Arrow Table. */ + /** Collect the results as an Arrow @see {@link ArrowTable}. */ async toArrow(): Promise { const batches = []; for await (const batch of this) { @@ -137,18 +165,211 @@ export class Query implements AsyncIterable { return new ArrowTable(batches); } - /** Returns a JSON Array of All results. 
- * - */ + /** Collect the results as an array of objects. */ async toArray(): Promise { const tbl = await this.toArrow(); // eslint-disable-next-line @typescript-eslint/no-unsafe-return return tbl.toArray(); } +} - // eslint-disable-next-line @typescript-eslint/no-explicit-any - [Symbol.asyncIterator](): AsyncIterator> { - const promise = this.inner.executeStream(); - return new RecordBatchIterator(undefined, promise); +/** + * An interface for a query that can be executed + * + * Supported by all query types + */ +export interface ExecutableQuery {} + +/** + * A builder used to construct a vector search + * + * This builder can be reused to execute the query many times. + */ +export class VectorQuery extends QueryBase { + constructor(inner: NativeVectorQuery) { + super(inner); + } + + /** + * Set the number of partitions to search (probe) + * + * This argument is only used when the vector column has an IVF PQ index. + * If there is no index then this value is ignored. + * + * The IVF stage of IVF PQ divides the input into partitions (clusters) of + * related values. + * + * The partition whose centroids are closest to the query vector will be + * exhaustiely searched to find matches. This parameter controls how many + * partitions should be searched. + * + * Increasing this value will increase the recall of your query but will + * also increase the latency of your query. The default value is 20. This + * default is good for many cases but the best value to use will depend on + * your data and the recall that you need to achieve. + * + * For best results we recommend tuning this parameter with a benchmark against + * your actual data to find the smallest possible value that will still give + * you the desired recall. 
+ */ + nprobes(nprobes: number): VectorQuery { + this.inner.nprobes(nprobes); + return this; + } + + /** + * Set the vector column to query + * + * This controls which column is compared to the query vector supplied in + * the call to @see {@link Query#nearestTo} + * + * This parameter must be specified if the table has more than one column + * whose data type is a fixed-size-list of floats. + */ + column(column: string): VectorQuery { + this.inner.column(column); + return this; + } + + /** + * Set the distance metric to use + * + * When performing a vector search we try and find the "nearest" vectors according + * to some kind of distance metric. This parameter controls which distance metric to + * use. See @see {@link IvfPqOptions.distanceType} for more details on the different + * distance metrics available. + * + * Note: if there is a vector index then the distance type used MUST match the distance + * type used to train the vector index. If this is not done then the results will be + * invalid. + * + * By default "l2" is used. + */ + distanceType(distanceType: string): VectorQuery { + this.inner.distanceType(distanceType); + return this; + } + + /** + * A multiplier to control how many additional rows are taken during the refine step + * + * This argument is only used when the vector column has an IVF PQ index. + * If there is no index then this value is ignored. + * + * An IVF PQ index stores compressed (quantized) values. They query vector is compared + * against these values and, since they are compressed, the comparison is inaccurate. + * + * This parameter can be used to refine the results. It can improve both improve recall + * and correct the ordering of the nearest results. + * + * To refine results LanceDb will first perform an ANN search to find the nearest + * `limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and + * `limit` is the default (10) then the first 30 results will be selected. 
LanceDb + * then fetches the full, uncompressed, values for these 30 results. The results are + * then reordered by the true distance and only the nearest 10 are kept. + * + * Note: there is a difference between calling this method with a value of 1 and never + * calling this method at all. Calling this method with any value will have an impact + * on your search latency. When you call this method with a `refine_factor` of 1 then + * LanceDb still needs to fetch the full, uncompressed, values so that it can potentially + * reorder the results. + * + * Note: if this method is NOT called then the distances returned in the _distance column + * will be approximate distances based on the comparison of the quantized query vector + * and the quantized result vectors. This can be considerably different than the true + * distance between the query vector and the actual uncompressed vector. + */ + refineFactor(refineFactor: number): VectorQuery { + this.inner.refineFactor(refineFactor); + return this; + } + + /** + * If this is called then filtering will happen after the vector search instead of + * before. + * + * By default filtering will be performed before the vector search. This is how + * filtering is typically understood to work. This prefilter step does add some + * additional latency. Creating a scalar index on the filter column(s) can + * often improve this latency. However, sometimes a filter is too complex or scalar + * indices cannot be applied to the column. In these cases postfiltering can be + * used instead of prefiltering to improve latency. + * + * Post filtering applies the filter to the results of the vector search. This means + * we only run the filter on a much smaller set of data. However, it can cause the + * query to return fewer than `limit` results (or even no results) if none of the nearest + * results match the filter. 
+ * + * Post filtering happens during the "refine stage" (described in more detail in + * @see {@link VectorQuery#refineFactor}). This means that setting a higher refine + * factor can often help restore some of the results lost by post filtering. + */ + postfilter(): VectorQuery { + this.inner.postfilter(); + return this; + } + + /** + * If this is called then any vector index is skipped + * + * An exhaustive (flat) search will be performed. The query vector will + * be compared to every vector in the table. At high scales this can be + * expensive. However, this is often still useful. For example, skipping + * the vector index can give you ground truth results which you can use to + * calculate your recall to select an appropriate value for nprobes. + */ + bypassVectorIndex(): VectorQuery { + this.inner.bypassVectorIndex(); + return this; + } +} + +/** A builder for LanceDB queries. */ +export class Query extends QueryBase { + constructor(tbl: NativeTable) { + super(tbl.query()); + } + + /** + * Find the nearest vectors to the given query vector. + * + * This converts the query from a plain query to a vector query. + * + * This method will attempt to convert the input to the query vector + * expected by the embedding model. If the input cannot be converted + * then an error will be thrown. + * + * By default, there is no embedding model, and the input should be + * an array-like object of numbers (something that can be used as input + * to Float32Array.from) + * + * If there is only one vector column (a column whose data type is a + * fixed size list of floats) then the column does not need to be specified. + * If there is more than one vector column you must use + * @see {@link VectorQuery#column} to specify which column you would like + * to compare with. + * + * If no index has been created on the vector column then a vector query + * will perform a distance comparison between the query vector and every + * vector in the database and then sort the results. 
This is sometimes + * called a "flat search" + * + * For small databases, with a few hundred thousand vectors or less, this can + * be reasonably fast. In larger databases you should create a vector index + * on the column. If there is a vector index then an "approximate" nearest + * neighbor search (frequently called an ANN search) will be performed. This + * search is much faster, but the results will be approximate. + * + * The query can be further parameterized using the returned builder. There + * are various ANN search parameters that will let you fine tune your recall + * accuracy vs search latency. + * + * Vector searches always have a `limit`. If `limit` has not been called then + * a default `limit` of 10 will be used. @see {@link Query#limit} + */ + nearestTo(vector: unknown): VectorQuery { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any)); + return new VectorQuery(vectorQuery); } } diff --git a/nodejs/lancedb/sanitize.ts b/nodejs/lancedb/sanitize.ts index 92ec01e6..9afebe54 100644 --- a/nodejs/lancedb/sanitize.ts +++ b/nodejs/lancedb/sanitize.ts @@ -481,6 +481,13 @@ function sanitizeField(fieldLike: unknown): Field { return new Field(name, type, nullable, metadata); } +/** + * Convert something schemaLike into a Schema instance + * + * This method is often needed even when the caller is using a Schema + * instance because they might be using a different instance of apache-arrow + * than lancedb is using. 
+ */ export function sanitizeSchema(schemaLike: unknown): Schema { if (schemaLike instanceof Schema) { return schemaLike; diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 4fad4f77..0ff8a13f 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -19,7 +19,7 @@ import { IndexConfig, Table as _NativeTable, } from "./native"; -import { Query } from "./query"; +import { Query, VectorQuery } from "./query"; import { IndexOptions } from "./indices"; import { Data, fromDataToBuffer } from "./arrow"; @@ -28,7 +28,8 @@ export { IndexConfig } from "./native"; * Options for adding data to a table. */ export interface AddDataOptions { - /** If "append" (the default) then the new data will be added to the table + /** + * If "append" (the default) then the new data will be added to the table * * If "overwrite" then the new data will replace the existing data in the table. */ @@ -74,7 +75,8 @@ export class Table { return this.inner.isOpen(); } - /** Close the table, releasing any underlying resources. + /** + * Close the table, releasing any underlying resources. * * It is safe to call this method multiple times. * @@ -98,9 +100,7 @@ export class Table { /** * Insert records into this Table. - * * @param {Data} data Records to be inserted into the Table - * @return The number of rows added to the table */ async add(data: Data, options?: Partial): Promise { const mode = options?.mode ?? "append"; @@ -124,15 +124,15 @@ export class Table { * you are updating many rows (with different ids) then you will get * better performance with a single [`merge_insert`] call instead of * repeatedly calilng this method. - * - * @param updates the columns to update + * @param {Map | Record} updates - the + * columns to update * * Keys in the map should specify the name of the column to update. * Values in the map provide the new value of the column. These can * be SQL literal strings (e.g. 
"7" or "'foo'") or they can be expressions * based on the row being updated (e.g. "my_col + 1") - * - * @param options additional options to control the update behavior + * @param {Partial} options - additional options to control + * the update behavior */ async update( updates: Map | Record, @@ -158,37 +158,28 @@ export class Table { await this.inner.delete(predicate); } - /** Create an index to speed up queries. + /** + * Create an index to speed up queries. * * Indices can be created on vector columns or scalar columns. * Indices on vector columns will speed up vector searches. * Indices on scalar columns will speed up filtering (in both * vector and non-vector searches) - * * @example - * - * If the column has a vector (fixed size list) data type then - * an IvfPq vector index will be created. - * - * ```typescript + * // If the column has a vector (fixed size list) data type then + * // an IvfPq vector index will be created. * const table = await conn.openTable("my_table"); * await table.createIndex(["vector"]); - * ``` - * - * For advanced control over vector index creation you can specify - * the index type and options. - * ```typescript + * @example + * // For advanced control over vector index creation you can specify + * // the index type and options. * const table = await conn.openTable("my_table"); * await table.createIndex(["vector"], I) * .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 }) * .build(); - * ``` - * - * Or create a Scalar index - * - * ```typescript + * @example + * // Or create a Scalar index * await table.createIndex("my_float_col").build(); - * ``` */ async createIndex(column: string, options?: Partial) { // Bit of a hack to get around the fact that TS has no package-scope. @@ -198,69 +189,74 @@ export class Table { } /** - * Create a generic {@link Query} Builder. + * Create a {@link Query} Builder. + * + * Queries allow you to search your existing data. 
By default the query will + * return all the data in the table in no particular order. The builder + * returned by this method can be used to control the query using filtering, + * vector similarity, sorting, and more. + * + * Note: By default, all columns are returned. For best performance, you should + * only fetch the columns you need. See [`Query::select_with_projection`] for + * more details. * * When appropriate, various indices and statistics based pruning will be used to * accelerate the query. - * * @example - * - * ### Run a SQL-style query - * ```typescript + * // SQL-style filtering + * // + * // This query will return up to 1000 rows whose value in the `id` column + * // is greater than 5. LanceDb supports a broad set of filtering functions. * for await (const batch of table.query() * .filter("id > 1").select(["id"]).limit(20)) { * console.log(batch); * } - * ``` - * - * ### Run Top-10 vector similarity search - * ```typescript + * @example + * // Vector Similarity Search + * // + * // This example will find the 10 rows whose value in the "vector" column are + * // closest to the query vector [1.0, 2.0, 3.0]. If an index has been created + * // on the "vector" column then this will perform an ANN search. + * // + * // The `refine_factor` and `nprobes` methods are used to control the recall / + * // latency tradeoff of the search. * for await (const batch of table.query() * .nearestTo([1, 2, 3]) * .refineFactor(5).nprobe(10) * .limit(10)) { * console.log(batch); * } - *``` - * - * ### Scan the full dataset - * ```typescript + * @example + * // Scan the full dataset + * // + * // This query will return everything in the table in no particular order. 
* for await (const batch of table.query()) { * console.log(batch); * } - * - * ### Return the full dataset as Arrow Table - * ```typescript - * let arrowTbl = await table.query().nearestTo([1.0, 2.0, 0.5, 6.7]).toArrow(); - * ``` - * - * @returns {@link Query} + * @returns {Query} A builder that can be used to parameterize the query */ query(): Query { return new Query(this.inner); } - /** Search the table with a given query vector. + /** + * Search the table with a given query vector. * - * This is a convenience method for preparing an ANN {@link Query}. + * This is a convenience method for preparing a vector query and + * is the same thing as calling `nearestTo` on the builder returned + * by `query`. @see {@link Query#nearestTo} for more details. */ - search(vector: number[], column?: string): Query { - const q = this.query(); - q.nearestTo(vector); - if (column !== undefined) { - q.column(column); - } - return q; + vectorSearch(vector: unknown): VectorQuery { + return this.query().nearestTo(vector); } // TODO: Support BatchUDF /** * Add new columns with defined values. - * - * @param newColumnTransforms pairs of column names and the SQL expression to use - * to calculate the value of the new column. These - * expressions will be evaluated for each row in the - * table, and can reference existing columns in the table. + * @param {AddColumnsSql[]} newColumnTransforms pairs of column names and + * the SQL expression to use to calculate the value of the new column. These + * expressions will be evaluated for each row in the table, and can + * reference existing columns in the table. */ async addColumns(newColumnTransforms: AddColumnsSql[]): Promise { await this.inner.addColumns(newColumnTransforms); @@ -268,8 +264,8 @@ export class Table { /** * Alter the name or nullability of columns. - * - * @param columnAlterations One or more alterations to apply to columns. + * @param {ColumnAlteration[]} columnAlterations One or more alterations to + * apply to columns. 
*/ async alterColumns(columnAlterations: ColumnAlteration[]): Promise { await this.inner.alterColumns(columnAlterations); @@ -282,16 +278,16 @@ export class Table { * underlying storage. In order to remove the data, you must subsequently * call ``compact_files`` to rewrite the data without the removed columns and * then call ``cleanup_files`` to remove the old files. - * - * @param columnNames The names of the columns to drop. These can be nested - * column references (e.g. "a.b.c") or top-level column - * names (e.g. "a"). + * @param {string[]} columnNames The names of the columns to drop. These can + * be nested column references (e.g. "a.b.c") or top-level column names + * (e.g. "a"). */ async dropColumns(columnNames: string[]): Promise { await this.inner.dropColumns(columnNames); } - /** Retrieve the version of the table + /** + * Retrieve the version of the table * * LanceDb supports versioning. Every operation that modifies the table increases * version. As long as a version hasn't been deleted you can `[Self::checkout]` that @@ -302,7 +298,8 @@ export class Table { return await this.inner.version(); } - /** Checks out a specific version of the Table + /** + * Checks out a specific version of the Table * * Any read operation on the table will now access the data at the checked out version. 
* As a consequence, calling this method will disable any read consistency interval @@ -321,7 +318,8 @@ export class Table { await this.inner.checkout(version); } - /** Ensures the table is pointing at the latest version + /** + * Ensures the table is pointing at the latest version * * This can be used to manually update a table when the read_consistency_interval is None * It can also be used to undo a `[Self::checkout]` operation @@ -330,7 +328,8 @@ export class Table { await this.inner.checkoutLatest(); } - /** Restore the table to the currently checked out version + /** + * Restore the table to the currently checked out version * * This operation will fail if checkout has not been called previously * diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index 6f3938cd..b8d05a7e 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -26,6 +26,7 @@ "apache-arrow-old": "npm:apache-arrow@13.0.0", "eslint": "^8.57.0", "eslint-config-prettier": "^9.1.0", + "eslint-plugin-jsdoc": "^48.2.1", "jest": "^29.7.0", "prettier": "^3.1.0", "tmp": "^0.2.3", @@ -755,6 +756,20 @@ "integrity": "sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==", "dev": true }, + "node_modules/@es-joy/jsdoccomment": { + "version": "0.42.0", + "resolved": "https://registry.npmjs.org/@es-joy/jsdoccomment/-/jsdoccomment-0.42.0.tgz", + "integrity": "sha512-R1w57YlVA6+YE01wch3GPYn6bCsrOV3YW/5oGGE2tmX6JcL9Nr+b5IikrjMPF+v9CV3ay+obImEdsDhovhJrzw==", + "dev": true, + "dependencies": { + "comment-parser": "1.4.1", + "esquery": "^1.5.0", + "jsdoc-type-pratt-parser": "~4.0.0" + }, + "engines": { + "node": ">=16" + } + }, "node_modules/@eslint-community/eslint-utils": { "version": "4.4.0", "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz", @@ -1948,6 +1963,15 @@ "integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ==", "dev": true }, + 
"node_modules/are-docs-informative": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/are-docs-informative/-/are-docs-informative-0.0.2.tgz", + "integrity": "sha512-ixiS0nLNNG5jNQzgZJNoUpBKdo9yTYZMGJ+QgT2jmjR7G7+QHRCc4v6LQ3NgE7EBJq+o0ams3waJwkrlBom8Ig==", + "dev": true, + "engines": { + "node": ">=14" + } + }, "node_modules/argparse": { "version": "1.0.10", "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz", @@ -2189,6 +2213,18 @@ "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", "dev": true }, + "node_modules/builtin-modules": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/builtin-modules/-/builtin-modules-3.3.0.tgz", + "integrity": "sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw==", + "dev": true, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/camelcase": { "version": "5.3.1", "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-5.3.1.tgz", @@ -2373,6 +2409,15 @@ "node": ">=12.17" } }, + "node_modules/comment-parser": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/comment-parser/-/comment-parser-1.4.1.tgz", + "integrity": "sha512-buhp5kePrmda3vhc5B9t7pUQXAb2Tnd0qgpkIhPhkHXxJpiPJ11H0ZEU0oBpJ2QztSbzG/ZxMj/CHsYJqRHmyg==", + "dev": true, + "engines": { + "node": ">= 12.0.0" + } + }, "node_modules/concat-map": { "version": "0.0.1", "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", @@ -2660,6 +2705,29 @@ "eslint": ">=7.0.0" } }, + "node_modules/eslint-plugin-jsdoc": { + "version": "48.2.1", + "resolved": "https://registry.npmjs.org/eslint-plugin-jsdoc/-/eslint-plugin-jsdoc-48.2.1.tgz", + "integrity": "sha512-iUvbcyDZSO/9xSuRv2HQBw++8VkV/pt3UWtX9cpPH0l7GKPq78QC/6+PmyQHHvNZaTjAce6QVciEbnc6J/zH5g==", + "dev": true, + "dependencies": { + "@es-joy/jsdoccomment": "~0.42.0", + 
"are-docs-informative": "^0.0.2", + "comment-parser": "1.4.1", + "debug": "^4.3.4", + "escape-string-regexp": "^4.0.0", + "esquery": "^1.5.0", + "is-builtin-module": "^3.2.1", + "semver": "^7.6.0", + "spdx-expression-parse": "^4.0.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "eslint": "^7.0.0 || ^8.0.0 || ^9.0.0" + } + }, "node_modules/eslint-scope": { "version": "7.2.2", "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz", @@ -3299,6 +3367,21 @@ "integrity": "sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w==", "optional": true }, + "node_modules/is-builtin-module": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/is-builtin-module/-/is-builtin-module-3.2.1.tgz", + "integrity": "sha512-BSLE3HnV2syZ0FK0iMA/yUGplUeMmNz4AW5fnTunbCIqZi4vG3WjJT9FHMy5D69xmAYBHXQhJdALdpwVxV501A==", + "dev": true, + "dependencies": { + "builtin-modules": "^3.3.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/is-core-module": { "version": "2.13.1", "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.1.tgz", @@ -4172,6 +4255,15 @@ "js-yaml": "bin/js-yaml.js" } }, + "node_modules/jsdoc-type-pratt-parser": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/jsdoc-type-pratt-parser/-/jsdoc-type-pratt-parser-4.0.0.tgz", + "integrity": "sha512-YtOli5Cmzy3q4dP26GraSOeAhqecewG04hoO8DY56CH4KJ9Fvv5qKWUCCo3HZob7esJQHCv6/+bnTy72xZZaVQ==", + "dev": true, + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/jsesc": { "version": "2.5.2", "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-2.5.2.tgz", @@ -5018,9 +5110,9 @@ } }, "node_modules/semver": { - "version": "7.5.4", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", - "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", + 
"version": "7.6.0", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.6.0.tgz", + "integrity": "sha512-EnwXhrlwXMk9gKu5/flx5sv/an57AkRplG3hTK68W7FRDN+k+OWBj65M7719OkA82XLBxrcX0KSHj+X5COhOVg==", "dev": true, "dependencies": { "lru-cache": "^6.0.0" @@ -5105,6 +5197,28 @@ "source-map": "^0.6.0" } }, + "node_modules/spdx-exceptions": { + "version": "2.5.0", + "resolved": "https://registry.npmjs.org/spdx-exceptions/-/spdx-exceptions-2.5.0.tgz", + "integrity": "sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==", + "dev": true + }, + "node_modules/spdx-expression-parse": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/spdx-expression-parse/-/spdx-expression-parse-4.0.0.tgz", + "integrity": "sha512-Clya5JIij/7C6bRR22+tnGXbc4VKlibKSVj2iHvVeX5iMW7s1SIQlqu699JkODJJIhh/pUu8L0/VLh8xflD+LQ==", + "dev": true, + "dependencies": { + "spdx-exceptions": "^2.1.0", + "spdx-license-ids": "^3.0.0" + } + }, + "node_modules/spdx-license-ids": { + "version": "3.0.17", + "resolved": "https://registry.npmjs.org/spdx-license-ids/-/spdx-license-ids-3.0.17.tgz", + "integrity": "sha512-sh8PWc/ftMqAAdFiBu6Fy6JUOYjqDJBJvIhpfDMyHrr0Rbp5liZqd4TjtQ/RgfLjKFZb+LMx5hpml5qOWy0qvg==", + "dev": true + }, "node_modules/sprintf-js": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz", diff --git a/nodejs/package.json b/nodejs/package.json index e023e356..d7d7580f 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -25,6 +25,7 @@ "apache-arrow-old": "npm:apache-arrow@13.0.0", "eslint": "^8.57.0", "eslint-config-prettier": "^9.1.0", + "eslint-plugin-jsdoc": "^48.2.1", "jest": "^29.7.0", "prettier": "^3.1.0", "tmp": "^0.2.3", diff --git a/nodejs/src/index.rs b/nodejs/src/index.rs index bba79b14..cef43505 100644 --- a/nodejs/src/index.rs +++ b/nodejs/src/index.rs @@ -17,9 +17,10 @@ use std::sync::Mutex; use lancedb::index::scalar::BTreeIndexBuilder; use 
lancedb::index::vector::IvfPqIndexBuilder; use lancedb::index::Index as LanceDbIndex; -use lancedb::DistanceType; use napi_derive::napi; +use crate::util::parse_distance_type; + #[napi] pub struct Index { inner: Mutex>, @@ -49,15 +50,7 @@ impl Index { ) -> napi::Result { let mut ivf_pq_builder = IvfPqIndexBuilder::default(); if let Some(distance_type) = distance_type { - let distance_type = match distance_type.as_str() { - "l2" => Ok(DistanceType::L2), - "cosine" => Ok(DistanceType::Cosine), - "dot" => Ok(DistanceType::Dot), - _ => Err(napi::Error::from_reason(format!( - "Invalid distance type '{}'. Must be one of l2, cosine, or dot", - distance_type - ))), - }?; + let distance_type = parse_distance_type(distance_type)?; ivf_pq_builder = ivf_pq_builder.distance_type(distance_type); } if let Some(num_partitions) = num_partitions { diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs index 08a0045c..37de4885 100644 --- a/nodejs/src/lib.rs +++ b/nodejs/src/lib.rs @@ -21,6 +21,7 @@ mod index; mod iterator; mod query; mod table; +mod util; #[napi(object)] #[derive(Debug)] diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 07017d18..f74faa86 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -12,36 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use lancedb::query::Query as LanceDBQuery; +use lancedb::query::ExecutableQuery; +use lancedb::query::Query as LanceDbQuery; +use lancedb::query::QueryBase; +use lancedb::query::Select; +use lancedb::query::VectorQuery as LanceDbVectorQuery; use napi::bindgen_prelude::*; use napi_derive::napi; +use crate::error::NapiErrorExt; use crate::iterator::RecordBatchIterator; +use crate::util::parse_distance_type; #[napi] pub struct Query { - inner: LanceDBQuery, + inner: LanceDbQuery, } #[napi] impl Query { - pub fn new(query: LanceDBQuery) -> Self { + pub fn new(query: LanceDbQuery) -> Self { Self { inner: query } } + // We cannot call this r#where because NAPI gets confused by the r# #[napi] - pub fn column(&mut self, column: String) { - self.inner = self.inner.clone().column(&column); + pub fn only_if(&mut self, predicate: String) { + self.inner = self.inner.clone().only_if(predicate); } #[napi] - pub fn filter(&mut self, filter: String) { - self.inner = self.inner.clone().filter(filter); - } - - #[napi] - pub fn select(&mut self, columns: Vec) { - self.inner = self.inner.clone().select(&columns); + pub fn select(&mut self, columns: Vec<(String, String)>) { + self.inner = self.inner.clone().select(Select::dynamic(&columns)); } #[napi] @@ -50,13 +52,46 @@ impl Query { } #[napi] - pub fn prefilter(&mut self, prefilter: bool) { - self.inner = self.inner.clone().prefilter(prefilter); + pub fn nearest_to(&mut self, vector: Float32Array) -> Result { + let inner = self + .inner + .clone() + .nearest_to(vector.as_ref()) + .default_error()?; + Ok(VectorQuery { inner }) } #[napi] - pub fn nearest_to(&mut self, vector: Float32Array) { - self.inner = self.inner.clone().nearest_to(&vector); + pub async fn execute(&self) -> napi::Result { + let inner_stream = self.inner.execute().await.map_err(|e| { + napi::Error::from_reason(format!("Failed to execute query stream: {}", e)) + })?; + Ok(RecordBatchIterator::new(inner_stream)) + } +} + +#[napi] +pub struct VectorQuery { + inner: 
LanceDbVectorQuery, +} + +#[napi] +impl VectorQuery { + #[napi] + pub fn column(&mut self, column: String) { + self.inner = self.inner.clone().column(&column); + } + + #[napi] + pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> { + let distance_type = parse_distance_type(distance_type)?; + self.inner = self.inner.clone().distance_type(distance_type); + Ok(()) + } + + #[napi] + pub fn postfilter(&mut self) { + self.inner = self.inner.clone().postfilter(); } #[napi] @@ -70,8 +105,28 @@ impl Query { } #[napi] - pub async fn execute_stream(&self) -> napi::Result { - let inner_stream = self.inner.execute_stream().await.map_err(|e| { + pub fn bypass_vector_index(&mut self) { + self.inner = self.inner.clone().bypass_vector_index() + } + + #[napi] + pub fn only_if(&mut self, predicate: String) { + self.inner = self.inner.clone().only_if(predicate); + } + + #[napi] + pub fn select(&mut self, columns: Vec<(String, String)>) { + self.inner = self.inner.clone().select(Select::dynamic(&columns)); + } + + #[napi] + pub fn limit(&mut self, limit: u32) { + self.inner = self.inner.clone().limit(limit as usize); + } + + #[napi] + pub async fn execute(&self) -> napi::Result { + let inner_stream = self.inner.execute().await.map_err(|e| { napi::Error::from_reason(format!("Failed to execute query stream: {}", e)) })?; Ok(RecordBatchIterator::new(inner_stream)) diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index f1a0ee58..4c3e4d33 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -23,7 +23,7 @@ use napi_derive::napi; use crate::error::NapiErrorExt; use crate::index::Index; -use crate::query::Query; +use crate::query::{Query, VectorQuery}; #[napi] pub struct Table { @@ -171,6 +171,11 @@ impl Table { Ok(Query::new(self.inner_ref()?.query())) } + #[napi] + pub fn vector_search(&self, vector: Float32Array) -> napi::Result { + self.query()?.nearest_to(vector) + } + #[napi] pub async fn add_columns(&self, transforms: Vec) -> napi::Result<()> { let 
transforms = transforms diff --git a/nodejs/src/util.rs b/nodejs/src/util.rs new file mode 100644 index 00000000..7cca8752 --- /dev/null +++ b/nodejs/src/util.rs @@ -0,0 +1,13 @@ +use lancedb::DistanceType; + +pub fn parse_distance_type(distance_type: impl AsRef) -> napi::Result { + match distance_type.as_ref().to_lowercase().as_str() { + "l2" => Ok(DistanceType::L2), + "cosine" => Ok(DistanceType::Cosine), + "dot" => Ok(DistanceType::Dot), + _ => Err(napi::Error::from_reason(format!( + "Invalid distance type '{}'. Must be one of l2, cosine, or dot", + distance_type.as_ref() + ))), + } +} diff --git a/python/Cargo.toml b/python/Cargo.toml index 58c8fe75..0811b692 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -22,6 +22,9 @@ pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] } # Prevent dynamic linking of lzma, which comes from datafusion lzma-sys = { version = "*", features = ["static"] } +pin-project = "1.1.5" +futures.workspace = true +tokio = { version = "1.36.0", features = ["sync"] } [build-dependencies] pyo3-build-config = { version = "0.20.3", features = [ diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 1591d252..d16f8a1a 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import pyarrow as pa @@ -40,6 +40,8 @@ class Table: async def checkout_latest(self): ... async def restore(self): ... async def list_indices(self) -> List[IndexConfig]: ... + def query(self) -> Query: ... + def vector_search(self) -> VectorQuery: ... class IndexConfig: index_type: str @@ -52,3 +54,27 @@ async def connect( host_override: Optional[str], read_consistency_interval: Optional[float], ) -> Connection: ... + +class RecordBatchStream: + def schema(self) -> pa.Schema: ... + async def next(self) -> Optional[pa.RecordBatch]: ... 
+ +class Query: + def where(self, filter: str): ... + def select(self, columns: Tuple[str, str]): ... + def limit(self, limit: int): ... + def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... + async def execute(self) -> RecordBatchStream: ... + +class VectorQuery: + async def execute(self) -> RecordBatchStream: ... + def where(self, filter: str): ... + def select(self, columns: List[str]): ... + def select_with_projection(self, columns: Tuple[str, str]): ... + def limit(self, limit: int): ... + def column(self, column: str): ... + def distance_type(self, distance_type: str): ... + def postfilter(self): ... + def refine_factor(self, refine_factor: int): ... + def nprobes(self, nprobes: int): ... + def bypass_vector_index(self): ... diff --git a/python/python/lancedb/arrow.py b/python/python/lancedb/arrow.py new file mode 100644 index 00000000..06393e66 --- /dev/null +++ b/python/python/lancedb/arrow.py @@ -0,0 +1,44 @@ +from typing import List + +import pyarrow as pa + +from ._lancedb import RecordBatchStream + + +class AsyncRecordBatchReader: + """ + An async iterator over a stream of RecordBatches. 
+ + Also allows access to the schema of the stream + """ + + def __init__(self, inner: RecordBatchStream): + self.inner_ = inner + + @property + def schema(self) -> pa.Schema: + """ + Get the schema of the batches produced by the stream + + Accessing the schema does not consume any data from the stream + """ + return self.inner_.schema() + + async def read_all(self) -> List[pa.RecordBatch]: + """ + Read all the record batches from the stream + + This consumes the entire stream and returns a list of record batches + + If there are a lot of results this may consume a lot of memory + """ + return [batch async for batch in self] + + def __aiter__(self): + return self + + async def __anext__(self) -> pa.RecordBatch: + next = await self.inner_.next() + if next is None: + raise StopAsyncIteration + return next diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 2c2ef71b..2addcb24 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -24,6 +24,7 @@ import pyarrow as pa import pydantic from . import __version__ +from .arrow import AsyncRecordBatchReader from .common import VEC from .rerankers.base import Reranker from .rerankers.linear_combination import LinearCombinationReranker @@ -33,6 +34,8 @@ if TYPE_CHECKING: import PIL import polars as pl + from ._lancedb import Query as LanceQuery + from ._lancedb import VectorQuery as LanceVectorQuery from .pydantic import LanceModel from .table import Table @@ -921,3 +924,334 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): """ self._vector_query.refine_factor(refine_factor) return self + + +class AsyncQueryBase(object): + def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]): + """ + Construct an AsyncQueryBase + + This method is not intended to be called directly. Instead, use the + [Table.query][] method to create a query. 
+ """ + self._inner = inner + + def where(self, predicate: str) -> AsyncQuery: + """ + Only return rows matching the given predicate + + The predicate should be supplied as an SQL query string. For example: + + >>> predicate = "x > 10" + >>> predicate = "y > 0 AND y < 100" + >>> predicate = "x > 5 OR y = 'test'" + + Filtering performance can often be improved by creating a scalar index + on the filter column(s). + """ + self._inner.where(predicate) + return self + + def select(self, columns: Union[List[str], dict[str, str]]) -> AsyncQuery: + """ + Return only the specified columns. + + By default a query will return all columns from the table. However, this can + have a very significant impact on latency. LanceDb stores data in a columnar + fashion. This + means we can finely tune our I/O to select exactly the columns we need. + + As a best practice you should always limit queries to the columns that you need. + If you pass in a list of column names then only those columns will be + returned. + + You can also use this method to create new "dynamic" columns based on your + existing columns. For example, you may not care about "a" or "b" but instead + simply want "a + b". This is often seen in the SELECT clause of an SQL query + (e.g. `SELECT a+b FROM my_table`). + + To create dynamic columns you can pass in a dict[str, str]. A column will be + returned for each entry in the map. The key provides the name of the column. + The value is an SQL string used to specify how the column is calculated. + + For example, an SQL query might state `SELECT a + b AS combined, c`. The + equivalent input to this method would be `{"combined": "a + b", "c": "c"}`. + + Columns will always be returned in the order given, even if that order is + different than the order used when adding the data. 
+ """ + if isinstance(columns, dict): + column_tuples = list(columns.items()) + else: + try: + column_tuples = [(c, c) for c in columns] + except TypeError: + raise TypeError("columns must be a list of column names or a dict") + self._inner.select(column_tuples) + return self + + def limit(self, limit: int) -> AsyncQuery: + """ + Set the maximum number of results to return. + + By default, a plain search has no limit. If this method is not + called then every valid row from the table will be returned. + """ + self._inner.limit(limit) + return self + + async def to_batches(self) -> AsyncRecordBatchReader: + """ + Execute the query and return the results as an Apache Arrow RecordBatchReader. + """ + return AsyncRecordBatchReader(await self._inner.execute()) + + async def to_arrow(self) -> pa.Table: + """ + Execute the query and collect the results into an Apache Arrow Table. + + This method will collect all results into memory before returning. If + you expect a large number of results, you may want to use [to_batches][] + """ + batch_iter = await self.to_batches() + return pa.Table.from_batches( + await batch_iter.read_all(), schema=batch_iter.schema + ) + + async def to_pandas(self) -> "pd.DataFrame": + """ + Execute the query and collect the results into a pandas DataFrame. + + This method will collect all results into memory before returning. If + you expect a large number of results, you may want to use [to_batches][] + and convert each batch to pandas separately. + + Example + ------- + + >>> import asyncio + >>> from lancedb import connect_async + >>> async def doctest_example(): + ... conn = await connect_async("./.lancedb") + ... table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}]) + ... async for batch in await table.query().to_batches(): + ... 
batch_df = batch.to_pandas() + >>> asyncio.run(doctest_example()) + """ + return (await self.to_arrow()).to_pandas() + + +class AsyncQuery(AsyncQueryBase): + def __init__(self, inner: LanceQuery): + """ + Construct an AsyncQuery + + This method is not intended to be called directly. Instead, use the + [Table.query][] method to create a query. + """ + super().__init__(inner) + self._inner = inner + + @classmethod + def _query_vec_to_array(self, vec: Union[VEC, Tuple]): + if isinstance(vec, list): + return pa.array(vec) + if isinstance(vec, np.ndarray): + return pa.array(vec) + if isinstance(vec, pa.Array): + return vec + if isinstance(vec, pa.ChunkedArray): + return vec.combine_chunks() + if isinstance(vec, tuple): + return pa.array(vec) + # We've checked everything we formally support in our typings + # but, as a fallback, let pyarrow try and convert it anyway. + # This can allow for some more exotic things like iterables + return pa.array(vec) + + def nearest_to( + self, query_vector: Optional[Union[VEC, Tuple]] = None + ) -> AsyncVectorQuery: + """ + Find the nearest vectors to the given query vector. + + This converts the query from a plain query to a vector query. + + This method will attempt to convert the input to the query vector + expected by the embedding model. If the input cannot be converted + then an error will be thrown. + + By default, there is no embedding model, and the input should be + something that can be converted to a pyarrow array of floats. This + includes lists, numpy arrays, and tuples. + + If there is only one vector column (a column whose data type is a + fixed size list of floats) then the column does not need to be specified. + If there is more than one vector column you must use + [AsyncVectorQuery::column][] to specify which column you would like to + compare with. 
+ + If no index has been created on the vector column then a vector query + will perform a distance comparison between the query vector and every + vector in the database and then sort the results. This is sometimes + called a "flat search" + + For small databases, with tens of thousands of vectors or less, this can + be reasonably fast. In larger databases you should create a vector index + on the column. If there is a vector index then an "approximate" nearest + neighbor search (frequently called an ANN search) will be performed. This + search is much faster, but the results will be approximate. + + The query can be further parameterized using the returned builder. There + are various ANN search parameters that will let you fine tune your recall + accuracy vs search latency. + + Vector searches always have a [limit][]. If `limit` has not been called then + a default `limit` of 10 will be used. + """ + return AsyncVectorQuery( + self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)) + ) + + +class AsyncVectorQuery(AsyncQueryBase): + def __init__(self, inner: LanceVectorQuery): + """ + Construct an AsyncVectorQuery + + This method is not intended to be called directly. Instead, create + a query first with [Table.query][] and then use [AsyncQuery.nearest_to][] + to convert to a vector query. + """ + super().__init__(inner) + self._inner = inner + + def column(self, column: str) -> AsyncVectorQuery: + """ + Set the vector column to query + + This controls which column is compared to the query vector supplied in + the call to [Query.nearest_to][]. + + This parameter must be specified if the table has more than one column + whose data type is a fixed-size-list of floats. + """ + self._inner.column(column) + return self + + def nprobes(self, nprobes: int) -> AsyncVectorQuery: + """ + Set the number of partitions to search (probe) + + This argument is only used when the vector column has an IVF PQ index. + If there is no index then this value is ignored. 
+ + The IVF stage of IVF PQ divides the input into partitions (clusters) of + related values. + + The partition whose centroids are closest to the query vector will be + exhaustiely searched to find matches. This parameter controls how many + partitions should be searched. + + Increasing this value will increase the recall of your query but will + also increase the latency of your query. The default value is 20. This + default is good for many cases but the best value to use will depend on + your data and the recall that you need to achieve. + + For best results we recommend tuning this parameter with a benchmark against + your actual data to find the smallest possible value that will still give + you the desired recall. + """ + self._inner.nprobes(nprobes) + return self + + def refine_factor(self, refine_factor: int) -> AsyncVectorQuery: + """ + A multiplier to control how many additional rows are taken during the refine + step + + This argument is only used when the vector column has an IVF PQ index. + If there is no index then this value is ignored. + + An IVF PQ index stores compressed (quantized) values. They query vector is + compared against these values and, since they are compressed, the comparison is + inaccurate. + + This parameter can be used to refine the results. It can improve both improve + recall and correct the ordering of the nearest results. + + To refine results LanceDb will first perform an ANN search to find the nearest + `limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and + `limit` is the default (10) then the first 30 results will be selected. LanceDb + then fetches the full, uncompressed, values for these 30 results. The results + are then reordered by the true distance and only the nearest 10 are kept. + + Note: there is a difference between calling this method with a value of 1 and + never calling this method at all. Calling this method with any value will have + an impact on your search latency. 
When you call this method with a + `refine_factor` of 1 then LanceDb still needs to fetch the full, uncompressed, + values so that it can potentially reorder the results. + + Note: if this method is NOT called then the distances returned in the _distance + column will be approximate distances based on the comparison of the quantized + query vector and the quantized result vectors. This can be considerably + different than the true distance between the query vector and the actual + uncompressed vector. + """ + self._inner.refine_factor(refine_factor) + return self + + def distance_type(self, distance_type: str) -> AsyncVectorQuery: + """ + Set the distance metric to use + + When performing a vector search we try and find the "nearest" vectors according + to some kind of distance metric. This parameter controls which distance metric + to use. See @see {@link IvfPqOptions.distanceType} for more details on the + different distance metrics available. + + Note: if there is a vector index then the distance type used MUST match the + distance type used to train the vector index. If this is not done then the + results will be invalid. + + By default "l2" is used. + """ + self._inner.distance_type(distance_type) + return self + + def postfilter(self) -> AsyncVectorQuery: + """ + If this is called then filtering will happen after the vector search instead of + before. + + By default filtering will be performed before the vector search. This is how + filtering is typically understood to work. This prefilter step does add some + additional latency. Creating a scalar index on the filter column(s) can + often improve this latency. However, sometimes a filter is too complex or + scalar indices cannot be applied to the column. In these cases postfiltering + can be used instead of prefiltering to improve latency. + + Post filtering applies the filter to the results of the vector search. This + means we only run the filter on a much smaller set of data. 
However, it can + cause the query to return fewer than `limit` results (or even no results) if + none of the nearest results match the filter. + + Post filtering happens during the "refine stage" (described in more detail in + @see {@link VectorQuery#refineFactor}). This means that setting a higher refine + factor can often help restore some of the results lost by post filtering. + """ + self._inner.postfilter() + return self + + def bypass_vector_index(self) -> AsyncVectorQuery: + """ + If this is called then any vector index is skipped + + An exhaustive (flat) search will be performed. The query vector will + be compared to every vector in the table. At high scales this can be + expensive. However, this is often still useful. For example, skipping + the vector index can give you ground truth results which you can use to + calculate your recall to select an appropriate value for nprobes. + """ + self._inner.bypass_vector_index() + return self diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index efd3f4d2..3501ae60 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -43,7 +43,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .merge import LanceMergeInsertBuilder from .pydantic import LanceModel, model_to_dict -from .query import LanceQueryBuilder, Query +from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query from .util import ( fs_from_uri, inf_vector_column_query, @@ -1899,6 +1899,9 @@ class AsyncTable: """ return await self._inner.count_rows(filter) + def query(self) -> AsyncQuery: + return AsyncQuery(self._inner.query()) + async def to_pandas(self) -> "pd.DataFrame": """Return the table as a pandas DataFrame. 
@@ -1906,7 +1909,7 @@ class AsyncTable: ------- pd.DataFrame """ - return self.to_arrow().to_pandas() + return (await self.to_arrow()).to_pandas() async def to_arrow(self) -> pa.Table: """Return the table as a pyarrow Table. @@ -1915,7 +1918,7 @@ class AsyncTable: ------- pa.Table """ - raise NotImplementedError + return await self.query().to_arrow() async def create_index( self, @@ -2068,90 +2071,18 @@ class AsyncTable: return LanceMergeInsertBuilder(self, on) - async def search( + def vector_search( self, - query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, - vector_column_name: Optional[str] = None, - query_type: str = "auto", - ) -> LanceQueryBuilder: - """Create a search query to find the nearest neighbors - of the given query vector. We currently support [vector search][search] - and [full-text search][experimental-full-text-search]. - - All query options are defined in [Query][lancedb.query.Query]. - - Examples - -------- - >>> import lancedb - >>> db = lancedb.connect("./.lancedb") - >>> data = [ - ... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]}, - ... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]}, - ... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]} - ... ] - >>> table = db.create_table("my_table", data) - >>> query = [0.4, 1.4, 2.4] - >>> (table.search(query) - ... .where("original_width > 1000", prefilter=True) - ... .select(["caption", "original_width", "vector"]) - ... .limit(2) - ... .to_pandas()) - caption original_width vector _distance - 0 foo 2000 [0.5, 3.4, 1.3] 5.220000 - 1 test 3000 [0.3, 6.2, 2.6] 23.089996 - - Parameters - ---------- - query: list/np.ndarray/str/PIL.Image.Image, default None - The targetted vector to search for. - - - *default None*. 
- Acceptable types are: list, np.ndarray, PIL.Image.Image - - - If None then the select/where/limit clauses are applied to filter - the table - vector_column_name: str, optional - The name of the vector column to search. - - The vector column needs to be a pyarrow fixed size list type - - - If not specified then the vector column is inferred from - the table schema - - - If the table has multiple vector columns then the *vector_column_name* - needs to be specified. Otherwise, an error is raised. - query_type: str - *default "auto"*. - Acceptable types are: "vector", "fts", "hybrid", or "auto" - - - If "auto" then the query type is inferred from the query; - - - If `query` is a list/np.ndarray then the query type is - "vector"; - - - If `query` is a PIL.Image.Image then either do vector search, - or raise an error if no corresponding embedding function is found. - - - If `query` is a string, then the query type is "vector" if the - table has embedding functions else the query type is "fts" - - Returns - ------- - LanceQueryBuilder - A query builder object representing the query. - Once executed, the query returns - - - selected columns - - - the vector - - - and also the "_distance" column which is the distance between the query - vector and the returned vector. + query_vector: Optional[Union[VEC, Tuple]] = None, + ) -> AsyncVectorQuery: """ - raise NotImplementedError + Search the table with a given query vector. - async def _execute_query(self, query: Query) -> pa.Table: - pass + This is a convenience method for preparing a vector query and + is the same thing as calling `nearestTo` on the builder returned + by `query`. Seer [nearest_to][AsyncQuery.nearest_to] for more details. 
+ """ + return self.query().nearest_to(query_vector) async def _do_merge( self, diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index d1a08666..49d207b8 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -12,16 +12,19 @@ # limitations under the License. import unittest.mock as mock +from datetime import timedelta import lance +import lancedb import numpy as np import pandas.testing as tm import pyarrow as pa import pytest +import pytest_asyncio from lancedb.db import LanceDBConnection from lancedb.pydantic import LanceModel, Vector -from lancedb.query import LanceVectorQueryBuilder, Query -from lancedb.table import LanceTable +from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query +from lancedb.table import AsyncTable, LanceTable class MockTable: @@ -65,6 +68,24 @@ def table(tmp_path) -> MockTable: return MockTable(tmp_path) +@pytest_asyncio.fixture +async def table_async(tmp_path) -> AsyncTable: + conn = await lancedb.connect_async( + tmp_path, read_consistency_interval=timedelta(seconds=0) + ) + data = pa.table( + { + "vector": pa.array( + [[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2) + ), + "id": pa.array([1, 2]), + "str_field": pa.array(["a", "b"]), + "float_field": pa.array([1.0, 2.0]), + } + ) + return await conn.create_table("test", data) + + def test_cast(table): class TestModel(LanceModel): vector: Vector(2) @@ -184,3 +205,109 @@ def test_query_builder_with_different_vector_column(): def cosine_distance(vec1, vec2): return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) + + +async def check_query( + query: AsyncQueryBase, *, expected_num_rows=None, expected_columns=None +): + num_rows = 0 + results = await query.to_batches() + async for batch in results: + if expected_columns is not None: + assert batch.schema.names == expected_columns + num_rows += batch.num_rows + if expected_num_rows is not None: + assert num_rows == 
expected_num_rows + + +@pytest.mark.asyncio +async def test_query_async(table_async: AsyncTable): + await check_query( + table_async.query(), + expected_num_rows=2, + expected_columns=["vector", "id", "str_field", "float_field"], + ) + await check_query(table_async.query().where("id = 2"), expected_num_rows=1) + await check_query( + table_async.query().select(["id", "vector"]), expected_columns=["id", "vector"] + ) + await check_query( + table_async.query().select({"foo": "id", "bar": "id + 1"}), + expected_columns=["foo", "bar"], + ) + await check_query(table_async.query().limit(1), expected_num_rows=1) + await check_query( + table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2 + ) + # Support different types of inputs for the vector query + for vector_query in [ + [1, 2], + [1.0, 2.0], + np.array([1, 2]), + (1, 2), + ]: + await check_query( + table_async.query().nearest_to(vector_query), expected_num_rows=2 + ) + + # No easy way to check these vector query parameters are doing what they say. We + # just check that they don't raise exceptions and assume this is tested at a lower + # level. + await check_query( + table_async.query().where("id = 2").nearest_to(pa.array([1, 2])).postfilter(), + expected_num_rows=1, + ) + await check_query( + table_async.query().nearest_to(pa.array([1, 2])).refine_factor(1), + expected_num_rows=2, + ) + await check_query( + table_async.query().nearest_to(pa.array([1, 2])).nprobes(10), + expected_num_rows=2, + ) + await check_query( + table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(), + expected_num_rows=2, + ) + await check_query( + table_async.query().nearest_to(pa.array([1, 2])).distance_type("dot"), + expected_num_rows=2, + ) + await check_query( + table_async.query().nearest_to(pa.array([1, 2])).distance_type("DoT"), + expected_num_rows=2, + ) + + # Make sure we can use a vector query as a base query (e.g. 
call limit on it) + # Also make sure `vector_search` works + await check_query(table_async.vector_search([1, 2]).limit(1), expected_num_rows=1) + + # Also check an empty query + await check_query(table_async.query().where("id < 0"), expected_num_rows=0) + + +@pytest.mark.asyncio +async def test_query_to_arrow_async(table_async: AsyncTable): + table = await table_async.to_arrow() + assert table.num_rows == 2 + assert table.num_columns == 4 + + table = await table_async.query().to_arrow() + assert table.num_rows == 2 + assert table.num_columns == 4 + + table = await table_async.query().where("id < 0").to_arrow() + assert table.num_rows == 0 + assert table.num_columns == 4 + + +@pytest.mark.asyncio +async def test_query_to_pandas_async(table_async: AsyncTable): + df = await table_async.to_pandas() + assert df.shape == (2, 4) + + df = await table_async.query().to_pandas() + assert df.shape == (2, 4) + + df = await table_async.query().where("id < 0").to_pandas() + assert df.shape == (0, 4) diff --git a/python/src/arrow.rs b/python/src/arrow.rs new file mode 100644 index 00000000..81bdaa03 --- /dev/null +++ b/python/src/arrow.rs @@ -0,0 +1,51 @@ +// use arrow::datatypes::SchemaRef; +// use lancedb::arrow::SendableRecordBatchStream; + +use std::sync::Arc; + +use arrow::{ + datatypes::SchemaRef, + pyarrow::{IntoPyArrow, ToPyArrow}, +}; +use futures::stream::StreamExt; +use lancedb::arrow::SendableRecordBatchStream; +use pyo3::{pyclass, pymethods, PyAny, PyObject, PyRef, PyResult, Python}; +use pyo3_asyncio::tokio::future_into_py; + +use crate::error::PythonErrorExt; + +#[pyclass] +pub struct RecordBatchStream { + schema: SchemaRef, + inner: Arc>, +} + +impl RecordBatchStream { + pub fn new(inner: SendableRecordBatchStream) -> Self { + let schema = inner.schema().clone(); + Self { + schema, + inner: Arc::new(tokio::sync::Mutex::new(inner)), + } + } +} + +#[pymethods] +impl RecordBatchStream { + pub fn schema(&self, py: Python) -> PyResult { + 
(*self.schema).clone().into_pyarrow(py) + } + + pub fn next(self_: PyRef<'_, Self>) -> PyResult<&PyAny> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + let inner_next = inner.lock().await.next().await; + inner_next + .map(|item| { + let item = item.infer_error()?; + Python::with_gil(|py| item.to_pyarrow(py)) + }) + .transpose() + }) + } +} diff --git a/python/src/lib.rs b/python/src/lib.rs index bf9006fc..558668cb 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -12,15 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +use arrow::RecordBatchStream; use connection::{connect, Connection}; use env_logger::Env; use index::{Index, IndexConfig}; use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python}; +use query::{Query, VectorQuery}; use table::Table; +pub mod arrow; pub mod connection; pub mod error; pub mod index; +pub mod query; pub mod table; pub mod util; @@ -34,6 +38,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::
()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(connect, m)?)?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; Ok(()) diff --git a/python/src/query.rs b/python/src/query.rs new file mode 100644 index 00000000..cdcac654 --- /dev/null +++ b/python/src/query.rs @@ -0,0 +1,125 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow::array::make_array; +use arrow::array::ArrayData; +use arrow::pyarrow::FromPyArrow; +use lancedb::query::{ + ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery, +}; +use pyo3::pyclass; +use pyo3::pymethods; +use pyo3::PyAny; +use pyo3::PyRef; +use pyo3::PyResult; +use pyo3_asyncio::tokio::future_into_py; + +use crate::arrow::RecordBatchStream; +use crate::error::PythonErrorExt; +use crate::util::parse_distance_type; + +#[pyclass] +pub struct Query { + inner: LanceDbQuery, +} + +impl Query { + pub fn new(query: LanceDbQuery) -> Self { + Self { inner: query } + } +} + +#[pymethods] +impl Query { + pub fn r#where(&mut self, predicate: String) { + self.inner = self.inner.clone().only_if(predicate); + } + + pub fn select(&mut self, columns: Vec<(String, String)>) { + self.inner = self.inner.clone().select(Select::dynamic(&columns)); + } + + pub fn limit(&mut self, limit: u32) { + self.inner = self.inner.clone().limit(limit as usize); + } + + pub fn nearest_to(&mut 
self, vector: &PyAny) -> PyResult { + let data: ArrayData = ArrayData::from_pyarrow(vector)?; + let array = make_array(data); + let inner = self.inner.clone().nearest_to(array).infer_error()?; + Ok(VectorQuery { inner }) + } + + pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + let inner_stream = inner.execute().await.infer_error()?; + Ok(RecordBatchStream::new(inner_stream)) + }) + } +} + +#[pyclass] +pub struct VectorQuery { + inner: LanceDbVectorQuery, +} + +#[pymethods] +impl VectorQuery { + pub fn r#where(&mut self, predicate: String) { + self.inner = self.inner.clone().only_if(predicate); + } + + pub fn select(&mut self, columns: Vec<(String, String)>) { + self.inner = self.inner.clone().select(Select::dynamic(&columns)); + } + + pub fn limit(&mut self, limit: u32) { + self.inner = self.inner.clone().limit(limit as usize); + } + + pub fn column(&mut self, column: String) { + self.inner = self.inner.clone().column(&column); + } + + pub fn distance_type(&mut self, distance_type: String) -> PyResult<()> { + let distance_type = parse_distance_type(distance_type)?; + self.inner = self.inner.clone().distance_type(distance_type); + Ok(()) + } + + pub fn postfilter(&mut self) { + self.inner = self.inner.clone().postfilter(); + } + + pub fn refine_factor(&mut self, refine_factor: u32) { + self.inner = self.inner.clone().refine_factor(refine_factor); + } + + pub fn nprobes(&mut self, nprobe: u32) { + self.inner = self.inner.clone().nprobes(nprobe as usize); + } + + pub fn bypass_vector_index(&mut self) { + self.inner = self.inner.clone().bypass_vector_index() + } + + pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + let inner_stream = inner.execute().await.infer_error()?; + Ok(RecordBatchStream::new(inner_stream)) + }) + } +} diff --git a/python/src/table.rs b/python/src/table.rs index 
f58b0d0c..0f4ed73d 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -14,6 +14,7 @@ use pyo3_asyncio::tokio::future_into_py; use crate::{ error::PythonErrorExt, index::{Index, IndexConfig}, + query::Query, }; #[pyclass] @@ -179,4 +180,8 @@ impl Table { async move { inner.restore().await.infer_error() }, ) } + + pub fn query(&self) -> Query { + Query::new(self.inner_ref().unwrap().query()) + } } diff --git a/python/src/util.rs b/python/src/util.rs index df9ab5d0..893e8089 100644 --- a/python/src/util.rs +++ b/python/src/util.rs @@ -1,6 +1,10 @@ use std::sync::Mutex; -use pyo3::{exceptions::PyRuntimeError, PyResult}; +use lancedb::DistanceType; +use pyo3::{ + exceptions::{PyRuntimeError, PyValueError}, + PyResult, +}; /// A wrapper around a rust builder /// @@ -33,3 +37,15 @@ impl BuilderWrapper { Ok(result) } } + +pub fn parse_distance_type(distance_type: impl AsRef) -> PyResult { + match distance_type.as_ref().to_lowercase().as_str() { + "l2" => Ok(DistanceType::L2), + "cosine" => Ok(DistanceType::Cosine), + "dot" => Ok(DistanceType::Dot), + _ => Err(PyValueError::new_err(format!( + "Invalid distance type '{}'. 
Must be one of l2, cosine, or dot", + distance_type.as_ref() + ))), + } +} diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs index 6b63593f..7e00ac21 100644 --- a/rust/ffi/node/src/query.rs +++ b/rust/ffi/node/src/query.rs @@ -3,6 +3,7 @@ use std::ops::Deref; use futures::{TryFutureExt, TryStreamExt}; use lance_linalg::distance::MetricType; +use lancedb::query::{ExecutableQuery, QueryBase, Select}; use neon::context::FunctionContext; use neon::handle::Handle; use neon::prelude::*; @@ -56,53 +57,72 @@ impl JsQuery { let channel = cx.channel(); let table = js_table.table.clone(); - let query_vector = query_obj.get_opt::(&mut cx, "_queryVector")?; let mut builder = table.query(); - if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) { - builder = builder.nearest_to(&query); - if let Some(metric_type) = query_obj - .get_opt::(&mut cx, "_metricType")? - .map(|s| s.value(&mut cx)) - .map(|s| MetricType::try_from(s.as_str()).unwrap()) - { - builder = builder.metric_type(metric_type); - } - - let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?; - builder = builder.nprobes(nprobes); - }; - if let Some(filter) = query_obj .get_opt::(&mut cx, "_filter")? .map(|s| s.value(&mut cx)) { - builder = builder.filter(filter); + builder = builder.only_if(filter); } if let Some(select) = select { - builder = builder.select(select.as_slice()); + builder = builder.select(Select::columns(select.as_slice())); } if let Some(limit) = limit { builder = builder.limit(limit as usize); }; - builder = builder.prefilter(prefilter); + let query_vector = query_obj.get_opt::(&mut cx, "_queryVector")?; + if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) { + let mut vector_builder = builder.nearest_to(query).unwrap(); + if let Some(metric_type) = query_obj + .get_opt::(&mut cx, "_metricType")? 
+ .map(|s| s.value(&mut cx)) + .map(|s| MetricType::try_from(s.as_str()).unwrap()) + { + vector_builder = vector_builder.distance_type(metric_type); + } - rt.spawn(async move { - let record_batch_stream = builder.execute_stream(); - let results = record_batch_stream - .and_then(|stream| { - stream - .try_collect::>() - .map_err(lancedb::error::Error::from) - }) - .await; + let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?; + vector_builder = vector_builder.nprobes(nprobes); - deferred.settle_with(&channel, move |mut cx| { - let results = results.or_throw(&mut cx)?; - let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?; - convert::new_js_buffer(buffer, &mut cx, is_electron) + if !prefilter { + vector_builder = vector_builder.postfilter(); + } + rt.spawn(async move { + let results = vector_builder + .execute() + .and_then(|stream| { + stream + .try_collect::>() + .map_err(lancedb::error::Error::from) + }) + .await; + + deferred.settle_with(&channel, move |mut cx| { + let results = results.or_throw(&mut cx)?; + let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?; + convert::new_js_buffer(buffer, &mut cx, is_electron) + }); }); - }); + } else { + rt.spawn(async move { + let results = builder + .execute() + .and_then(|stream| { + stream + .try_collect::>() + .map_err(lancedb::error::Error::from) + }) + .await; + + deferred.settle_with(&channel, move |mut cx| { + let results = results.or_throw(&mut cx)?; + let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?; + convert::new_js_buffer(buffer, &mut cx, is_electron) + }); + }); + }; + Ok(promise) } } diff --git a/rust/lancedb/examples/simple.rs b/rust/lancedb/examples/simple.rs index 79dd7012..e148f483 100644 --- a/rust/lancedb/examples/simple.rs +++ b/rust/lancedb/examples/simple.rs @@ -21,6 +21,7 @@ use futures::TryStreamExt; use lancedb::connection::Connection; use lancedb::index::Index; +use lancedb::query::{ExecutableQuery, QueryBase}; use lancedb::{connect, 
Result, Table as LanceDbTable}; #[tokio::main] @@ -150,9 +151,10 @@ async fn create_index(table: &LanceDbTable) -> Result<()> { async fn search(table: &LanceDbTable) -> Result> { // --8<-- [start:search] table - .search(&[1.0; 128]) + .query() .limit(2) - .execute_stream() + .nearest_to(&[1.0; 128])? + .execute() .await? .try_collect::>() .await diff --git a/rust/lancedb/src/io/object_store.rs b/rust/lancedb/src/io/object_store.rs index e7dc3d78..4e052d65 100644 --- a/rust/lancedb/src/io/object_store.rs +++ b/rust/lancedb/src/io/object_store.rs @@ -342,7 +342,11 @@ mod test { use object_store::local::LocalFileSystem; use tempfile; - use crate::{connect, table::WriteOptions}; + use crate::{ + connect, + query::{ExecutableQuery, QueryBase}, + table::WriteOptions, + }; #[tokio::test] async fn test_e2e() { @@ -381,9 +385,11 @@ mod test { assert_eq!(t.count_rows(None).await.unwrap(), 100); let q = t - .search(&[0.1, 0.1, 0.1, 0.1]) + .query() .limit(10) - .execute_stream() + .nearest_to(&[0.1, 0.1, 0.1, 0.1]) + .unwrap() + .execute() .await .unwrap(); diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index a467a636..689262d9 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -150,6 +150,7 @@ //! # use arrow_schema::{DataType, Schema, Field}; //! # use arrow_array::{RecordBatch, RecordBatchIterator}; //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type}; +//! # use lancedb::query::{ExecutableQuery, QueryBase}; //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # let tmpdir = tempfile::tempdir().unwrap(); //! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap(); @@ -170,8 +171,10 @@ //! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap(); //! # let table = db.open_table("my_table").execute().await.unwrap(); //! let results = table -//! .search(&[1.0; 128]) -//! .execute_stream() +//! .query() +//! .nearest_to(&[1.0; 128]) +//! 
.unwrap() +//! .execute() //! .await //! .unwrap() //! .try_collect::>() diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 137015a7..bf3f8668 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -12,193 +12,725 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::future::Future; use std::sync::Arc; -use arrow_array::Float32Array; -use lance_linalg::distance::MetricType; +use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array}; +use arrow_schema::DataType; +use half::f16; use crate::arrow::SendableRecordBatchStream; -use crate::error::Result; +use crate::error::{Error, Result}; use crate::table::TableInternal; +use crate::DistanceType; pub(crate) const DEFAULT_TOP_K: usize = 10; +/// Which columns should be retrieved from the database #[derive(Debug, Clone)] pub enum Select { + /// Select all columns + /// + /// Warning: This will always be slower than selecting only the columns you need. All, - Simple(Vec), - Projection(Vec<(String, String)>), + /// Select the provided columns + Columns(Vec), + /// Advanced selection which allows for dynamic column calculations + /// + /// The first item in each tuple is a name to assign to the output column. + /// The second item in each tuple is an SQL expression to evaluate the result. + /// + /// See [`Query::select`] for more details and examples + Dynamic(Vec<(String, String)>), } -/// A builder for nearest neighbor queries for LanceDB. 
-#[derive(Clone)] +impl Select { + /// Create a simple selection that only selects the given columns + /// + /// This method is a convenience method for creating a [`Select::Columns`] variant + /// from either Vec<&str> or Vec + pub fn columns(columns: &[impl AsRef]) -> Self { + Self::Columns(columns.iter().map(|c| c.as_ref().to_string()).collect()) + } + /// Create a dynamic selection that allows for advanced column selection + /// + /// This method is a convenience method for creating a [`Select::Dynamic`] variant + /// from either &str or String tuples + pub fn dynamic(columns: &[(impl AsRef, impl AsRef)]) -> Self { + Self::Dynamic( + columns + .iter() + .map(|(name, value)| (name.as_ref().to_string(), value.as_ref().to_string())) + .collect(), + ) + } +} + +/// A trait for converting a type to a query vector +/// +/// This is primarily intended to allow rust users that are unfamiliar with Arrow +/// a chance to use native types such as Vec instead of arrow arrays. It also +/// serves as an integration point for other rust libraries such as polars. +/// +/// By accepting the query vector as an array we are potentially allowing any data +/// type to be used as the query vector. In the future, custom embedding models +/// may be installed. These models may accept something other than f32. For example, +/// sentence transformers typically expect the query to be a string. This means that +/// any kind of conversion library should expect to convert more than just f32. +pub trait ToQueryVector { + /// Convert the user's query vector input to a query vector + /// + /// This trait exists to allow users to provide many different types as + /// input to the [`crate::query::QueryBuilder::nearest_to`] method. + /// + /// By default, there is no embedding model registered, and the input should + /// be the vector that the user wants to search with. LanceDb expects a + /// fixed-size-list of floats. 
This means the input will need to be something + /// that can be converted to a fixed-size-list of floats (e.g. a Vec) + /// + /// This crate provides a variety of default impls for common types. + /// + /// On the other hand, if an embedding model is registered, then the embedding + /// model will determine the input type. For example, sentence transformers expect + /// the input to be strings. The input should be converted to an array with + /// a single string value. + /// + /// Trait impls should try and convert the source data to the requested data + /// type if they can and fail with a meaningful error if they cannot. An + /// embedding model label is provided to help provide useful error messages. For + /// example, "failed to create query vector, the sentence transformer model + /// expects strings but the input was a list of integers". + /// + /// Note that the output is an array but, in most cases, this will be an array of + /// length one. The query vector is considered a single "item" and arrays of + /// length one are how arrow represents scalars. + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result>; +} + +// TODO: perhaps support some casts like f32->f64 and maybe even f64->f32? +impl ToQueryVector for Arc { + fn to_query_vector( + self, + data_type: &DataType, + _embedding_model_label: &str, + ) -> Result> { + if data_type != self.data_type() { + match data_type { + // If the embedding wants floating point data we can try and cast + DataType::Float16 | DataType::Float32 | DataType::Float64 => { + arrow_cast::cast(&self, data_type).map_err(|e| { + Error::InvalidInput { + message: format!( + "failed to create query vector, the input data type was {:?} but the expected data type was {:?}. Attempt to cast yielded: {}", + self.data_type(), + data_type, + e + ), + } + }) + }, + // TODO: Should we try and cast even if the embedding wants non-numeric data? 
+ _ => Err(Error::InvalidInput { + message: format!( + "failed to create query vector, the input data type was {:?} but the expected data type was {:?}", + self.data_type(), + data_type + )}) + } + } else { + Ok(self.clone()) + } + } +} + +impl ToQueryVector for &dyn Array { + fn to_query_vector( + self, + data_type: &DataType, + _embedding_model_label: &str, + ) -> Result> { + if data_type != self.data_type() { + Err(Error::InvalidInput { + message: format!( + "failed to create query vector, the input data type was {:?} but the expected data type was {:?}", + self.data_type(), + data_type + )}) + } else { + let data = self.to_data(); + Ok(make_array(data)) + } + } +} + +impl ToQueryVector for &[f16] { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + match data_type { + DataType::Float16 => { + let arr: Vec = self.to_vec(); + Ok(Arc::new(Float16Array::from(arr))) + } + DataType::Float32 => { + let arr: Vec = self.iter().map(|x| f32::from(*x)).collect(); + Ok(Arc::new(Float32Array::from(arr))) + }, + DataType::Float64 => { + let arr: Vec = self.iter().map(|x| f64::from(*x)).collect(); + Ok(Arc::new(Float64Array::from(arr))) + } + _ => Err(Error::InvalidInput { + message: format!( + "failed to create query vector, the input data type was &[f16] but the embedding model \"{}\" expected data type {:?}", + embedding_model_label, + data_type + ), + }), + } + } +} + +impl ToQueryVector for &[f32] { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + match data_type { + DataType::Float16 => { + let arr: Vec = self.iter().map(|x| f16::from_f32(*x)).collect(); + Ok(Arc::new(Float16Array::from(arr))) + } + DataType::Float32 => { + let arr: Vec = self.to_vec(); + Ok(Arc::new(Float32Array::from(arr))) + }, + DataType::Float64 => { + let arr: Vec = self.iter().map(|x| *x as f64).collect(); + Ok(Arc::new(Float64Array::from(arr))) + } + _ => Err(Error::InvalidInput { + 
message: format!( + "failed to create query vector, the input data type was &[f32] but the embedding model \"{}\" expected data type {:?}", + embedding_model_label, + data_type + ), + }), + } + } +} + +impl ToQueryVector for &[f64] { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + match data_type { + DataType::Float16 => { + let arr: Vec = self.iter().map(|x| f16::from_f64(*x)).collect(); + Ok(Arc::new(Float16Array::from(arr))) + } + DataType::Float32 => { + let arr: Vec = self.iter().map(|x| *x as f32).collect(); + Ok(Arc::new(Float32Array::from(arr))) + }, + DataType::Float64 => { + let arr: Vec = self.to_vec(); + Ok(Arc::new(Float64Array::from(arr))) + } + _ => Err(Error::InvalidInput { + message: format!( + "failed to create query vector, the input data type was &[f64] but the embedding model \"{}\" expected data type {:?}", + embedding_model_label, + data_type + ), + }), + } + } +} + +impl ToQueryVector for &[f16; N] { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + self.as_slice() + .to_query_vector(data_type, embedding_model_label) + } +} + +impl ToQueryVector for &[f32; N] { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + self.as_slice() + .to_query_vector(data_type, embedding_model_label) + } +} + +impl ToQueryVector for &[f64; N] { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + self.as_slice() + .to_query_vector(data_type, embedding_model_label) + } +} + +impl ToQueryVector for Vec { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + self.as_slice() + .to_query_vector(data_type, embedding_model_label) + } +} + +impl ToQueryVector for Vec { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + self.as_slice() + 
.to_query_vector(data_type, embedding_model_label) + } +} + +impl ToQueryVector for Vec { + fn to_query_vector( + self, + data_type: &DataType, + embedding_model_label: &str, + ) -> Result> { + self.as_slice() + .to_query_vector(data_type, embedding_model_label) + } +} + +/// Common parameters that can be applied to scans and vector queries +pub trait QueryBase { + /// Set the maximum number of results to return. + /// + /// By default, a plain search has no limit. If this method is not + /// called then every valid row from the table will be returned. + /// + /// A vector search always has a limit. If this is not called then + /// it will default to 10. + fn limit(self, limit: usize) -> Self; + + /// Only return rows which match the filter. + /// + /// The filter should be supplied as an SQL query string. For example: + /// + /// ```ignore + /// x > 10 + /// y > 0 AND y < 100 + /// x > 5 OR y = 'test' + /// ``` + /// + /// Filtering performance can often be improved by creating a scalar index + /// on the filter column(s). + fn only_if(self, filter: impl AsRef) -> Self; + + /// Return only the specified columns. + /// + /// By default a query will return all columns from the table. However, this can have + /// a very significant impact on latency. LanceDb stores data in a columnar fashion. This + /// means we can finely tune our I/O to select exactly the columns we need. + /// + /// As a best practice you should always limit queries to the columns that you need. + /// + /// You can also use this method to create new "dynamic" columns based on your existing columns. + /// For example, you may not care about "a" or "b" but instead simply want "a + b". This is often + /// seen in the SELECT clause of an SQL query (e.g. `SELECT a+b FROM my_table`). + /// + /// To create dynamic columns use [`Select::Dynamic`] (it might be easier to create this with the + /// helper method [`Select::dynamic`]). A column will be returned for each tuple provided. 
The + /// first value in that tuple provides the name of the column. The second value in the tuple is + /// an SQL string used to specify how the column is calculated. + /// + /// For example, an SQL query might state `SELECT a + b AS combined, c`. The equivalent + /// input to [`Select::dynamic`] would be `&[("combined", "a + b"), ("c", "c")]`. + /// + /// Columns will always be returned in the order given, even if that order is different than + /// the order used when adding the data. + fn select(self, selection: Select) -> Self; +} + +pub trait HasQuery { + fn mut_query(&mut self) -> &mut Query; +} + +impl QueryBase for T { + fn limit(mut self, limit: usize) -> Self { + self.mut_query().limit = Some(limit); + self + } + + fn only_if(mut self, filter: impl AsRef) -> Self { + self.mut_query().filter = Some(filter.as_ref().to_string()); + self + } + + fn select(mut self, select: Select) -> Self { + self.mut_query().select = select; + self + } +} + +/// Options for controlling the execution of a query +#[non_exhaustive] +pub struct QueryExecutionOptions { + /// The maximum number of rows that will be contained in a single + /// `RecordBatch` delivered by the query. + /// + /// Note: This is a maximum only. The query may return smaller + /// batches, even in the middle of a query, to avoid forcing + /// memory copies due to concatenation. + /// + /// Note: Slicing an Arrow RecordBatch is a zero-copy operation + /// and so the performance penalty of reading smaller batches + /// is typically very small. + /// + /// By default, this is 1024 + pub max_batch_length: u32, +} + +impl Default for QueryExecutionOptions { + fn default() -> Self { + Self { + max_batch_length: 1024, + } + } +} + +/// A trait for a query object that can be executed to get results +/// +/// There are various kinds of queries but they all return results +/// in the same way. 
+pub trait ExecutableQuery { + /// Execute the query with default options and return results + /// + /// See [`ExecutableQuery::execute_with_options`] for more details. + fn execute(&self) -> impl Future> + Send { + self.execute_with_options(QueryExecutionOptions::default()) + } + + /// Execute the query and return results + /// + /// The query results are returned as a [`SendableRecordBatchStream`]. This is + /// a Stream of Arrow [`arrow_array::RecordBatch`] (and you can also independently + /// access the [`arrow_schema::Schema`] without polling the stream). + /// + /// Note: The size of the returned batches and the order of individual rows is + /// not deterministic. + /// + /// LanceDb will use many threads to calculate results and, when + /// the result set is large, multiple batches will be processed at one time. + /// This readahead is limited however and backpressure will be applied if this + /// stream is consumed slowly (this constrains the maximum memory used by a + /// single query). + /// + /// For simpler access or row-based access we recommend creating extension traits + /// to convert Arrow data into your internal data model. + fn execute_with_options( + &self, + options: QueryExecutionOptions, + ) -> impl Future> + Send; +} + +/// A builder for LanceDB queries. +/// +/// See [`crate::Table::query`] for more details on queries +/// +/// See [`QueryBase`] for methods that can be used to parameterize +/// the query. +/// +/// See [`ExecutableQuery`] for methods that can be used to execute +/// the query and retrieve results. +/// +/// This query object can be reused to issue the same query multiple +/// times. +#[derive(Debug, Clone)] pub struct Query { parent: Arc, - // The column to run the query on. If not specified, we will attempt to guess - // the column based on the dataset's schema. - pub(crate) column: Option, - - // IVF PQ - ANN search. 
- pub(crate) query_vector: Option, - pub(crate) nprobes: usize, - pub(crate) refine_factor: Option, - pub(crate) metric_type: Option, - /// limit the number of rows to return. pub(crate) limit: Option, /// Apply filter to the returned rows. pub(crate) filter: Option, /// Select column projection. pub(crate) select: Select, +} +impl Query { + pub(crate) fn new(parent: Arc) -> Self { + Self { + parent, + limit: None, + filter: None, + select: Select::All, + } + } + + /// Helper method to convert the query to a VectorQuery with a `query_vector` + /// of None. This retrofits to some existing inner paths that work with a + /// single query object for both vector and plain queries. + pub(crate) fn into_vector(self) -> VectorQuery { + VectorQuery::new(self) + } + + /// Find the nearest vectors to the given query vector. + /// + /// This converts the query from a plain query to a vector query. + /// + /// This method will attempt to convert the input to the query vector + /// expected by the embedding model. If the input cannot be converted + /// then an error will be returned. + /// + /// By default, there is no embedding model, and the input should be + /// vector/slice of floats. + /// + /// If there is only one vector column (a column whose data type is a + /// fixed size list of floats) then the column does not need to be specified. + /// If there is more than one vector column you must use [`Query::column`] + /// to specify which column you would like to compare with. + /// + /// If no index has been created on the vector column then a vector query + /// will perform a distance comparison between the query vector and every + /// vector in the database and then sort the results. This is sometimes + /// called a "flat search" + /// + /// For small databases, with a few hundred thousand vectors or less, this can + /// be reasonably fast. In larger databases you should create a vector index + /// on the column. 
If there is a vector index then an "approximate" nearest + /// neighbor search (frequently called an ANN search) will be performed. This + /// search is much faster, but the results will be approximate. + /// + /// The query can be further parameterized using the returned builder. There + /// are various search parameters that will let you fine tune your recall + /// accuracy vs search latency. + /// + /// # Arguments + /// + /// * `vector` - The vector that will be used for search. + pub fn nearest_to(self, vector: impl ToQueryVector) -> Result { + let mut vector_query = self.into_vector(); + let query_vector = vector.to_query_vector(&DataType::Float32, "default")?; + vector_query.query_vector = Some(query_vector); + Ok(vector_query) + } +} + +impl HasQuery for Query { + fn mut_query(&mut self) -> &mut Query { + self + } +} + +impl ExecutableQuery for Query { + async fn execute_with_options( + &self, + options: QueryExecutionOptions, + ) -> Result { + Ok(SendableRecordBatchStream::from( + self.parent.clone().plain_query(self, options).await?, + )) + } +} + +/// A builder for vector searches +/// +/// This builder contains methods specific to vector searches. +/// +/// See [`QueryBase`] for additional methods that can be used to +/// parameterize the query. +/// +/// See [`ExecutableQuery`] for methods that can be used to execute +/// the query and retrieve results. +#[derive(Debug, Clone)] +pub struct VectorQuery { + pub(crate) base: Query, + // The column to run the query on. If not specified, we will attempt to guess + // the column based on the dataset's schema. + pub(crate) column: Option, + // IVF PQ - ANN search. + pub(crate) query_vector: Option>, + pub(crate) nprobes: usize, + pub(crate) refine_factor: Option, + pub(crate) distance_type: Option, /// Default is true. Set to false to enforce a brute force search. 
pub(crate) use_index: bool, /// Apply filter before ANN search/ pub(crate) prefilter: bool, } -impl Query { - /// Creates a new Query object - /// - /// # Arguments - /// - /// * `parent` - the table to run the query on. - /// - pub(crate) fn new(parent: Arc) -> Self { +impl VectorQuery { + fn new(base: Query) -> Self { Self { - parent, - query_vector: None, + base, column: None, - limit: None, + query_vector: None, nprobes: 20, refine_factor: None, - metric_type: None, + distance_type: None, use_index: true, - filter: None, - select: Select::All, - prefilter: false, + prefilter: true, } } - /// Convert the query plan to a [`SendableRecordBatchStream`] + /// Set the vector column to query /// - /// # Returns + /// This controls which column is compared to the query vector supplied in + /// the call to [`Query::nearest_to`] /// - /// * A [SendableRecordBatchStream] with the query's results. - pub async fn execute_stream(&self) -> Result { - Ok(SendableRecordBatchStream::from( - self.parent.clone().query(self).await?, - )) - } - - /// Set the column to query - /// - /// # Arguments - /// - /// * `column` - The column name + /// This parameter must be specified if the table has more than one column + /// whose data type is a fixed-size-list of floats. pub fn column(mut self, column: &str) -> Self { self.column = Some(column.to_string()); self } - /// Set the maximum number of results to return. + /// Set the number of partitions to search (probe) /// - /// # Arguments + /// This argument is only used when the vector column has an IVF PQ index. + /// If there is no index then this value is ignored. /// - /// * `limit` - The maximum number of results to return. - pub fn limit(mut self, limit: usize) -> Self { - self.limit = Some(limit); - self - } - - /// Find the nearest vectors to the given query vector. + /// The IVF stage of IVF PQ divides the input into partitions (clusters) of + /// related values. 
/// - /// # Arguments + /// The partition whose centroids are closest to the query vector will be + /// exhaustiely searched to find matches. This parameter controls how many + /// partitions should be searched. /// - /// * `vector` - The vector that will be used for search. - pub fn nearest_to(mut self, vector: &[f32]) -> Self { - self.query_vector = Some(Float32Array::from(vector.to_vec())); - self - } - - /// Set the number of probes to use. + /// Increasing this value will increase the recall of your query but will + /// also increase the latency of your query. The default value is 20. This + /// default is good for many cases but the best value to use will depend on + /// your data and the recall that you need to achieve. /// - /// # Arguments - /// - /// * `nprobes` - The number of probes to use. + /// For best results we recommend tuning this parameter with a benchmark against + /// your actual data to find the smallest possible value that will still give + /// you the desired recall. pub fn nprobes(mut self, nprobes: usize) -> Self { self.nprobes = nprobes; self } - /// Set the refine factor to use. + /// A multiplier to control how many additional rows are taken during the refine step /// - /// # Arguments + /// This argument is only used when the vector column has an IVF PQ index. + /// If there is no index then this value is ignored. /// - /// * `refine_factor` - The refine factor to use. + /// An IVF PQ index stores compressed (quantized) values. They query vector is compared + /// against these values and, since they are compressed, the comparison is inaccurate. + /// + /// This parameter can be used to refine the results. It can improve both improve recall + /// and correct the ordering of the nearest results. + /// + /// To refine results LanceDb will first perform an ANN search to find the nearest + /// `limit` * `refine_factor` results. 
In other words, if `refine_factor` is 3 and + /// `limit` is the default (10) then the first 30 results will be selected. LanceDb + /// then fetches the full, uncompressed, values for these 30 results. The results are + /// then reordered by the true distance and only the nearest 10 are kept. + /// + /// Note: there is a difference between calling this method with a value of 1 and never + /// calling this method at all. Calling this method with any value will have an impact + /// on your search latency. When you call this method with a `refine_factor` of 1 then + /// LanceDb still needs to fetch the full, uncompressed, values so that it can potentially + /// reorder the results. + /// + /// Note: if this method is NOT called then the distances returned in the _distance column + /// will be approximate distances based on the comparison of the quantized query vector + /// and the quantized result vectors. This can be considerably different than the true + /// distance between the query vector and the actual uncompressed vector. pub fn refine_factor(mut self, refine_factor: u32) -> Self { self.refine_factor = Some(refine_factor); self } - /// Set the distance metric to use. + /// Set the distance metric to use /// - /// # Arguments + /// When performing a vector search we try and find the "nearest" vectors according + /// to some kind of distance metric. This parameter controls which distance metric to + /// use. See [`DistanceType`] for more details on the different distance metrics + /// available. /// - /// * `metric_type` - The distance metric to use. By default [MetricType::L2] is used. - pub fn metric_type(mut self, metric_type: MetricType) -> Self { - self.metric_type = Some(metric_type); + /// Note: if there is a vector index then the distance type used MUST match the distance + /// type used to train the vector index. If this is not done then the results will be + /// invalid. + /// + /// By default [`DistanceType::L2`] is used. 
+ pub fn distance_type(mut self, distance_type: DistanceType) -> Self { + self.distance_type = Some(distance_type); self } - /// Whether to use an ANN index if available + /// If this is called then filtering will happen after the vector search instead of + /// before. /// - /// # Arguments + /// By default filtering will be performed before the vector search. This is how + /// filtering is typically understood to work. This prefilter step does add some + /// additional latency. Creating a scalar index on the filter column(s) can + /// often improve this latency. However, sometimes a filter is too complex or scalar + /// indices cannot be applied to the column. In these cases postfiltering can be + /// used instead of prefiltering to improve latency. /// - /// * `use_index` - Sets Whether to use an ANN index if available - pub fn use_index(mut self, use_index: bool) -> Self { - self.use_index = use_index; + /// Post filtering applies the filter to the results of the vector search. This means + /// we only run the filter on a much smaller set of data. However, it can cause the + /// query to return fewer than `limit` results (or even no results) if none of the nearest + /// results match the filter. + /// + /// Post filtering happens during the "refine stage" (described in more detail in + /// [`Self::refine_factor`]). This means that setting a higher refine factor can often + /// help restore some of the results lost by post filtering. + pub fn postfilter(mut self) -> Self { + self.prefilter = false; self } - /// A filter statement to be applied to this query. + /// If this is called then any vector index is skipped /// - /// # Arguments - /// - /// * `filter` - SQL filter - pub fn filter(mut self, filter: impl AsRef) -> Self { - self.filter = Some(filter.as_ref().to_string()); + /// An exhaustive (flat) search will be performed. The query vector will + /// be compared to every vector in the table. At high scales this can be + /// expensive. 
However, this is often still useful. For example, skipping + /// the vector index can give you ground truth results which you can use to + /// calculate your recall to select an appropriate value for nprobes. + pub fn bypass_vector_index(mut self) -> Self { + self.use_index = false; self } +} - /// Return only the specified columns. - /// - /// Only select the specified columns. If not specified, all columns will be returned. - pub fn select(mut self, columns: &[impl AsRef]) -> Self { - self.select = Select::Simple(columns.iter().map(|c| c.as_ref().to_string()).collect()); - self +impl ExecutableQuery for VectorQuery { + async fn execute_with_options( + &self, + options: QueryExecutionOptions, + ) -> Result { + Ok(SendableRecordBatchStream::from( + self.base.parent.clone().vector_query(self, options).await?, + )) } +} - /// Return only the specified columns. - /// - /// Only select the specified columns. If not specified, all columns will be returned. - pub fn select_with_projection( - mut self, - columns: &[(impl AsRef, impl AsRef)], - ) -> Self { - self.select = Select::Projection( - columns - .iter() - .map(|(c, t)| (c.as_ref().to_string(), t.as_ref().to_string())) - .collect(), - ); - self - } - - pub fn prefilter(mut self, prefilter: bool) -> Self { - self.prefilter = prefilter; - self +impl HasQuery for VectorQuery { + fn mut_query(&mut self) -> &mut Query { + &mut self.base } } @@ -216,7 +748,7 @@ mod tests { use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector}; use tempfile::tempdir; - use crate::connect; + use crate::{connect, Table}; #[tokio::test] async fn test_setters_getters() { @@ -234,25 +766,30 @@ mod tests { .await .unwrap(); - let vector = Some(Float32Array::from_iter_values([0.1, 0.2])); - let query = table.query().nearest_to(&[0.1, 0.2]); - assert_eq!(query.query_vector, vector); + let vector = Float32Array::from_iter_values([0.1, 0.2]); + let query = table.query().nearest_to(&[0.1, 0.2]).unwrap(); + 
assert_eq!(*query.query_vector.unwrap().as_ref().as_primitive(), vector); let new_vector = Float32Array::from_iter_values([9.8, 8.7]); - let query = query - .nearest_to(&[9.8, 8.7]) + let query = table + .query() .limit(100) + .nearest_to(&[9.8, 8.7]) + .unwrap() .nprobes(1000) - .use_index(true) - .metric_type(MetricType::Cosine) + .postfilter() + .distance_type(DistanceType::Cosine) .refine_factor(999); - assert_eq!(query.query_vector.unwrap(), new_vector); - assert_eq!(query.limit.unwrap(), 100); + assert_eq!( + *query.query_vector.unwrap().as_ref().as_primitive(), + new_vector + ); + assert_eq!(query.base.limit.unwrap(), 100); assert_eq!(query.nprobes, 1000); assert!(query.use_index); - assert_eq!(query.metric_type, Some(MetricType::Cosine)); + assert_eq!(query.distance_type, Some(DistanceType::Cosine)); assert_eq!(query.refine_factor, Some(999)); } @@ -272,8 +809,14 @@ mod tests { .await .unwrap(); - let query = table.query().nearest_to(&[0.1; 4]); - let result = query.limit(10).filter("id % 2 == 0").execute_stream().await; + let query = table + .query() + .limit(10) + .only_if("id % 2 == 0") + .nearest_to(&[0.1; 4]) + .unwrap() + .postfilter(); + let result = query.execute().await; let mut stream = result.expect("should have result"); // should only have one batch while let Some(batch) = stream.next().await { @@ -281,13 +824,13 @@ mod tests { assert!(batch.expect("should be Ok").num_rows() < 10); } - let query = table.query().nearest_to(&[0.1; 4]); - let result = query + let query = table + .query() .limit(10) - .filter(String::from("id % 2 == 0")) // Work with String too - .prefilter(true) - .execute_stream() - .await; + .only_if(String::from("id % 2 == 0")) + .nearest_to(&[0.1; 4]) + .unwrap(); + let result = query.execute().await; let mut stream = result.expect("should have result"); // should only have one batch while let Some(batch) = stream.next().await { @@ -315,8 +858,8 @@ mod tests { let query = table .query() .limit(10) - 
.select_with_projection(&[("id2", "id * 2"), ("id", "id")]); - let result = query.execute_stream().await; + .select(Select::dynamic(&[("id2", "id * 2"), ("id", "id")])); + let result = query.execute().await; let mut batches = result .expect("should have result") .try_collect::>() @@ -356,7 +899,7 @@ mod tests { .unwrap(); let query = table.query(); - let result = query.filter("id % 2 == 0").execute_stream().await; + let result = query.only_if("id % 2 == 0").execute().await; let mut stream = result.expect("should have result"); // should only have one batch while let Some(batch) = stream.next().await { @@ -367,7 +910,7 @@ mod tests { } // Reject bad filter - let result = table.query().filter("id = 0 AND").execute_stream().await; + let result = table.query().only_if("id = 0 AND").execute().await; assert!(result.is_err()); } @@ -399,21 +942,52 @@ mod tests { ) } - #[tokio::test] - async fn test_search() { - let tmp_dir = tempdir().unwrap(); + async fn make_test_table(tmp_dir: &tempfile::TempDir) -> Table { let dataset_path = tmp_dir.path().join("test.lance"); let uri = dataset_path.to_str().unwrap(); - let batches = make_test_batches(); + let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); - let table = conn - .create_table("my_table", Box::new(batches)) + conn.create_table("my_table", Box::new(batches)) + .execute() + .await + .unwrap() + } + + #[tokio::test] + async fn test_execute_with_options() { + let tmp_dir = tempdir().unwrap(); + let table = make_test_table(&tmp_dir).await; + + let mut results = table + .query() + .execute_with_options(QueryExecutionOptions { + max_batch_length: 10, + }) + .await + .unwrap(); + + while let Some(batch) = results.next().await { + assert!(batch.unwrap().num_rows() <= 10); + } + } + + #[tokio::test] + async fn query_base_methods_on_vector_query() { + // Make sure VectorQuery can be used as a QueryBase + let tmp_dir = tempdir().unwrap(); + let table = make_test_table(&tmp_dir).await; + + let 
mut results = table + .vector_search(&[1.0, 2.0, 3.0, 4.0]) + .unwrap() + .limit(1) .execute() .await .unwrap(); - let query = table.search(&[0.1, 0.2]); - assert_eq!(&[0.1, 0.2], query.query_vector.unwrap().values()); + let first_batch = results.next().await.unwrap().unwrap(); + assert_eq!(first_batch.num_rows(), 1); + assert!(results.next().await.is_none()); } } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 6160ff50..433ef08e 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -6,7 +6,7 @@ use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewCol use crate::{ error::Result, index::{IndexBuilder, IndexConfig}, - query::Query, + query::{Query, QueryExecutionOptions, VectorQuery}, table::{ merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats, TableInternal, UpdateBuilder, @@ -66,7 +66,18 @@ impl TableInternal for RemoteTable { async fn add(&self, _add: AddDataBuilder) -> Result<()> { todo!() } - async fn query(&self, _query: &Query) -> Result { + async fn plain_query( + &self, + _query: &Query, + _options: QueryExecutionOptions, + ) -> Result { + todo!() + } + async fn vector_query( + &self, + _query: &VectorQuery, + _options: QueryExecutionOptions, + ) -> Result { todo!() } async fn update(&self, _update: UpdateBuilder) -> Result<()> { diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 520d1685..a9eb1146 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -17,6 +17,8 @@ use std::path::Path; use std::sync::Arc; +use arrow::array::AsArray; +use arrow::datatypes::Float32Type; use arrow_array::{RecordBatchIterator, RecordBatchReader}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; @@ -47,7 +49,9 @@ use crate::index::{ vector::{suggested_num_partitions, suggested_num_sub_vectors}, Index, IndexBuilder, }; -use crate::query::{Query, Select, DEFAULT_TOP_K}; 
+use crate::query::{ + Query, QueryExecutionOptions, Select, ToQueryVector, VectorQuery, DEFAULT_TOP_K, +}; use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam}; use self::dataset::DatasetConsistencyWrapper; @@ -230,7 +234,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn /// Count the number of rows in this table. async fn count_rows(&self, filter: Option) -> Result; async fn add(&self, add: AddDataBuilder) -> Result<()>; - async fn query(&self, query: &Query) -> Result; + async fn plain_query( + &self, + query: &Query, + options: QueryExecutionOptions, + ) -> Result; + async fn vector_query( + &self, + query: &VectorQuery, + options: QueryExecutionOptions, + ) -> Result; async fn delete(&self, predicate: &str) -> Result<()>; async fn update(&self, update: UpdateBuilder) -> Result<()>; async fn create_index(&self, index: IndexBuilder) -> Result<()>; @@ -528,21 +541,30 @@ impl Table { ) } - /// Search the table with a given query vector. + /// Create a [`Query`] Builder. /// - /// This is a convenience method for preparing an ANN query. - pub fn search(&self, query: &[f32]) -> Query { - self.query().nearest_to(query) - } - - /// Create a generic [`Query`] Builder. + /// Queries allow you to search your existing data. By default the query will + /// return all the data in the table in no particular order. The builder + /// returned by this method can be used to control the query using filtering, + /// vector similarity, sorting, and more. /// - /// When appropriate, various indices and statistics based pruning will be used to - /// accelerate the query. + /// Note: By default, all columns are returned. For best performance, you should + /// only fetch the columns you need. See [`Query::select_with_projection`] for + /// more details. + /// + /// When appropriate, various indices and statistics will be used to accelerate + /// the query. /// /// # Examples /// - /// ## Run a vector search (ANN) query. 
+ /// ## Vector search + /// + /// This example will find the 10 rows whose value in the "vector" column are + /// closest to the query vector [1.0, 2.0, 3.0]. If an index has been created + /// on the "vector" column then this will perform an ANN search. + /// + /// The [`Query::refine_factor`] and [`Query::nprobes`] methods are used to + /// control the recall / latency tradeoff of the search. /// /// ```no_run /// # use arrow_array::RecordBatch; @@ -551,19 +573,25 @@ impl Table { /// # let conn = lancedb::connect("/tmp").execute().await.unwrap(); /// # let tbl = conn.open_table("tbl").execute().await.unwrap(); /// use crate::lancedb::Table; + /// use crate::lancedb::query::ExecutableQuery; /// let stream = tbl /// .query() /// .nearest_to(&[1.0, 2.0, 3.0]) + /// .unwrap() /// .refine_factor(5) /// .nprobes(10) - /// .execute_stream() + /// .execute() /// .await /// .unwrap(); /// let batches: Vec = stream.try_collect().await.unwrap(); /// # }); /// ``` /// - /// ## Run a SQL-style filter + /// ## SQL-style filter + /// + /// This query will return up to 1000 rows whose value in the `id` column + /// is greater than 5. LanceDb supports a broad set of filtering functions. + /// /// ```no_run /// # use arrow_array::RecordBatch; /// # use futures::TryStreamExt; @@ -571,18 +599,23 @@ impl Table { /// # let conn = lancedb::connect("/tmp").execute().await.unwrap(); /// # let tbl = conn.open_table("tbl").execute().await.unwrap(); /// use crate::lancedb::Table; + /// use crate::lancedb::query::{ExecutableQuery, QueryBase}; /// let stream = tbl /// .query() - /// .filter("id > 5") + /// .only_if("id > 5") /// .limit(1000) - /// .execute_stream() + /// .execute() /// .await /// .unwrap(); /// let batches: Vec = stream.try_collect().await.unwrap(); /// # }); /// ``` /// - /// ## Run a full scan query. + /// ## Full scan + /// + /// This query will return everything in the table in no particular + /// order. 
+ /// /// ```no_run /// # use arrow_array::RecordBatch; /// # use futures::TryStreamExt; @@ -590,7 +623,8 @@ impl Table { /// # let conn = lancedb::connect("/tmp").execute().await.unwrap(); /// # let tbl = conn.open_table("tbl").execute().await.unwrap(); /// use crate::lancedb::Table; - /// let stream = tbl.query().execute_stream().await.unwrap(); + /// use crate::lancedb::query::ExecutableQuery; + /// let stream = tbl.query().execute().await.unwrap(); /// let batches: Vec = stream.try_collect().await.unwrap(); /// # }); /// ``` @@ -598,6 +632,15 @@ impl Table { Query::new(self.inner.clone()) } + /// Search the table with a given query vector. + /// + /// This is a convenience method for preparing a vector query and + /// is the same thing as calling `nearest_to` on the builder returned + /// by `query`. See [`Query::nearest_to`] for more details. + pub fn vector_search(&self, query: impl ToQueryVector) -> Result { + self.query().nearest_to(query) + } + /// Optimize the on-disk data and indices for better performance. /// ///
Experimental API
@@ -1107,6 +1150,75 @@ impl NativeTable { .await?; Ok(()) } + + async fn generic_query( + &self, + query: &VectorQuery, + options: QueryExecutionOptions, + ) -> Result { + let ds_ref = self.dataset.get().await?; + let mut scanner: Scanner = ds_ref.scan(); + + if let Some(query_vector) = query.query_vector.as_ref() { + // If there is a vector query, default to limit=10 if unspecified + let column = if let Some(col) = query.column.as_ref() { + col.clone() + } else { + // Infer a vector column with the same dimension of the query vector. + let arrow_schema = Schema::from(ds_ref.schema()); + default_vector_column(&arrow_schema, Some(query_vector.len() as i32))? + }; + let field = ds_ref.schema().field(&column).ok_or(Error::Schema { + message: format!("Column {} not found in dataset schema", column), + })?; + if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32) + { + return Err(Error::Schema { + message: format!( + "Vector column '{}' does not match the dimension of the query vector: dim={}", + column, + query_vector.len(), + ), + }); + } + let query_vector = query_vector.as_primitive::(); + scanner.nearest( + &column, + query_vector, + query.base.limit.unwrap_or(DEFAULT_TOP_K), + )?; + } else { + // If there is no vector query, it's ok to not have a limit + scanner.limit(query.base.limit.map(|limit| limit as i64), None)?; + } + scanner.nprobs(query.nprobes); + scanner.use_index(query.use_index); + scanner.prefilter(query.prefilter); + scanner.batch_size(options.max_batch_length as usize); + + match &query.base.select { + Select::Columns(select) => { + scanner.project(select.as_slice())?; + } + Select::Dynamic(select_with_transform) => { + scanner.project_with_transform(select_with_transform.as_slice())?; + } + Select::All => { /* Do nothing */ } + } + + if let Some(filter) = &query.base.filter { + scanner.filter(filter)?; + } + + if let Some(refine_factor) = query.refine_factor 
{ + scanner.refine(refine_factor); + } + + if let Some(distance_type) = query.distance_type { + scanner.distance_metric(distance_type); + } + Ok(scanner.try_into_stream().await?) + } } #[async_trait::async_trait] @@ -1232,63 +1344,21 @@ impl TableInternal for NativeTable { Ok(()) } - async fn query(&self, query: &Query) -> Result { - let ds_ref = self.dataset.get().await?; - let mut scanner: Scanner = ds_ref.scan(); + async fn plain_query( + &self, + query: &Query, + options: QueryExecutionOptions, + ) -> Result { + self.generic_query(&query.clone().into_vector(), options) + .await + } - if let Some(query_vector) = query.query_vector.as_ref() { - // If there is a vector query, default to limit=10 if unspecified - let column = if let Some(col) = query.column.as_ref() { - col.clone() - } else { - // Infer a vector column with the same dimension of the query vector. - let arrow_schema = Schema::from(ds_ref.schema()); - default_vector_column(&arrow_schema, Some(query_vector.len() as i32))? - }; - let field = ds_ref.schema().field(&column).ok_or(Error::Schema { - message: format!("Column {} not found in dataset schema", column), - })?; - if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32) - { - return Err(Error::Schema { - message: format!( - "Vector column '{}' does not match the dimension of the query vector: dim={}", - column, - query_vector.len(), - ), - }); - } - scanner.nearest(&column, query_vector, query.limit.unwrap_or(DEFAULT_TOP_K))?; - } else { - // If there is no vector query, it's ok to not have a limit - scanner.limit(query.limit.map(|limit| limit as i64), None)?; - } - scanner.nprobs(query.nprobes); - scanner.use_index(query.use_index); - scanner.prefilter(query.prefilter); - - match &query.select { - Select::Simple(select) => { - scanner.project(select.as_slice())?; - } - Select::Projection(select_with_transform) => { - 
scanner.project_with_transform(select_with_transform.as_slice())?; - } - Select::All => { /* Do nothing */ } - } - - if let Some(filter) = &query.filter { - scanner.filter(filter)?; - } - - if let Some(refine_factor) = query.refine_factor { - scanner.refine(refine_factor); - } - - if let Some(metric_type) = query.metric_type { - scanner.distance_metric(metric_type); - } - Ok(scanner.try_into_stream().await?) + async fn vector_query( + &self, + query: &VectorQuery, + options: QueryExecutionOptions, + ) -> Result { + self.generic_query(query, options).await } async fn merge_insert( @@ -1450,6 +1520,7 @@ mod tests { use crate::connect; use crate::connection::ConnectBuilder; use crate::index::scalar::BTreeIndexBuilder; + use crate::query::{ExecutableQuery, QueryBase}; use super::*; @@ -1689,8 +1760,8 @@ mod tests { let mut batches = table .query() - .select(&["id", "name"]) - .execute_stream() + .select(Select::columns(&["id", "name"])) + .execute() .await .unwrap() .try_collect::>() @@ -1841,7 +1912,7 @@ mod tests { let mut batches = table .query() - .select(&[ + .select(Select::columns(&[ "string", "large_string", "int32", @@ -1855,8 +1926,8 @@ mod tests { "timestamp_ms", "vec_f32", "vec_f64", - ]) - .execute_stream() + ])) + .execute() .await .unwrap() .try_collect::>()