feat: refactor the query API and add query support to the python async API (#1113)

In addition, there are a number of changes to the docstrings of existing
methods in nodejs, since this PR adds a JSDoc linter.
Weston Pace
2024-03-18 12:36:49 -07:00
parent 2db257ca29
commit 4180b44472
38 changed files with 2609 additions and 754 deletions
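
At a high level, the nodejs refactor splits the query builder into a plain `Query` and a `VectorQuery` (returned by `nearestTo`), renames `filter` to `where`, and replaces `table.search(vector, column)` with `table.vectorSearch(vector)` plus a `column()` builder method. A minimal sketch of the new call shape (table and column names are illustrative):

```typescript
// Plain SQL-style scan: filter, project, limit, collect.
const rows = await tbl
  .query()
  .where("id > 10") // previously `filter(...)`
  .select(["id", "vec"])
  .limit(20)
  .toArrow();

// Vector search: `nearestTo` converts the Query into a VectorQuery.
const neighbors = await tbl
  .query()
  .limit(2)
  .nearestTo([1.0, 2.0, 3.0])
  .column("vec") // only needed if the table has multiple vector columns
  .toArrow();

// Convenience form, equivalent to query().nearestTo(...).
const same = await tbl.vectorSearch([1.0, 2.0, 3.0]).limit(2).toArrow();
```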

View File

@@ -79,7 +79,7 @@ import {
import type { IntBitWidth, TimeBitWidth } from "apache-arrow/type";
function sanitizeMetadata(
metadataLike?: unknown
metadataLike?: unknown,
): Map<string, string> | undefined {
if (metadataLike === undefined || metadataLike === null) {
return undefined;
@@ -90,7 +90,7 @@ function sanitizeMetadata(
for (const item of metadataLike) {
if (typeof item[0] !== "string" || typeof item[1] !== "string") {
throw Error(
"Expected metadata, if present, to be a Map<string, string> but it had non-string keys or values"
"Expected metadata, if present, to be a Map<string, string> but it had non-string keys or values",
);
}
}
@@ -105,7 +105,7 @@ function sanitizeInt(typeLike: object) {
typeof typeLike.isSigned !== "boolean"
) {
throw Error(
"Expected an Int Type to have a `bitWidth` and `isSigned` property"
"Expected an Int Type to have a `bitWidth` and `isSigned` property",
);
}
return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
@@ -128,7 +128,7 @@ function sanitizeDecimal(typeLike: object) {
typeof typeLike.bitWidth !== "number"
) {
throw Error(
"Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties"
"Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties",
);
}
return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
@@ -149,7 +149,7 @@ function sanitizeTime(typeLike: object) {
typeof typeLike.bitWidth !== "number"
) {
throw Error(
"Expected a Time type to have `unit` and `bitWidth` properties"
"Expected a Time type to have `unit` and `bitWidth` properties",
);
}
return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
@@ -172,7 +172,7 @@ function sanitizeTypedTimestamp(
| typeof TimestampNanosecond
| typeof TimestampMicrosecond
| typeof TimestampMillisecond
| typeof TimestampSecond
| typeof TimestampSecond,
) {
let timezone = null;
if ("timezone" in typeLike && typeof typeLike.timezone === "string") {
@@ -191,7 +191,7 @@ function sanitizeInterval(typeLike: object) {
function sanitizeList(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a List type to have an array-like `children` property"
"Expected a List type to have an array-like `children` property",
);
}
if (typeLike.children.length !== 1) {
@@ -203,7 +203,7 @@ function sanitizeList(typeLike: object) {
function sanitizeStruct(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Struct type to have an array-like `children` property"
"Expected a Struct type to have an array-like `children` property",
);
}
return new Struct(typeLike.children.map((child) => sanitizeField(child)));
@@ -216,47 +216,47 @@ function sanitizeUnion(typeLike: object) {
typeof typeLike.mode !== "number"
) {
throw Error(
"Expected a Union type to have `typeIds` and `mode` properties"
"Expected a Union type to have `typeIds` and `mode` properties",
);
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Union type to have an array-like `children` property"
"Expected a Union type to have an array-like `children` property",
);
}
return new Union(
typeLike.mode,
typeLike.typeIds as any,
typeLike.children.map((child) => sanitizeField(child))
typeLike.children.map((child) => sanitizeField(child)),
);
}
function sanitizeTypedUnion(
typeLike: object,
UnionType: typeof DenseUnion | typeof SparseUnion
UnionType: typeof DenseUnion | typeof SparseUnion,
) {
if (!("typeIds" in typeLike)) {
throw Error(
"Expected a DenseUnion/SparseUnion type to have a `typeIds` property"
"Expected a DenseUnion/SparseUnion type to have a `typeIds` property",
);
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a DenseUnion/SparseUnion type to have an array-like `children` property"
"Expected a DenseUnion/SparseUnion type to have an array-like `children` property",
);
}
return new UnionType(
typeLike.typeIds as any,
typeLike.children.map((child) => sanitizeField(child))
typeLike.children.map((child) => sanitizeField(child)),
);
}
function sanitizeFixedSizeBinary(typeLike: object) {
if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
throw Error(
"Expected a FixedSizeBinary type to have a `byteWidth` property"
"Expected a FixedSizeBinary type to have a `byteWidth` property",
);
}
return new FixedSizeBinary(typeLike.byteWidth);
@@ -268,7 +268,7 @@ function sanitizeFixedSizeList(typeLike: object) {
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a FixedSizeList type to have an array-like `children` property"
"Expected a FixedSizeList type to have an array-like `children` property",
);
}
if (typeLike.children.length !== 1) {
@@ -276,14 +276,14 @@ function sanitizeFixedSizeList(typeLike: object) {
}
return new FixedSizeList(
typeLike.listSize,
sanitizeField(typeLike.children[0])
sanitizeField(typeLike.children[0]),
);
}
function sanitizeMap(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Map type to have an array-like `children` property"
"Expected a Map type to have an array-like `children` property",
);
}
if (!("keysSorted" in typeLike) || typeof typeLike.keysSorted !== "boolean") {
@@ -291,7 +291,7 @@ function sanitizeMap(typeLike: object) {
}
return new Map_(
typeLike.children.map((field) => sanitizeField(field)) as any,
typeLike.keysSorted
typeLike.keysSorted,
);
}
@@ -319,7 +319,7 @@ function sanitizeDictionary(typeLike: object) {
sanitizeType(typeLike.dictionary),
sanitizeType(typeLike.indices) as any,
typeLike.id,
typeLike.isOrdered
typeLike.isOrdered,
);
}
@@ -454,7 +454,7 @@ function sanitizeField(fieldLike: unknown): Field {
!("nullable" in fieldLike)
) {
throw Error(
"The field passed in is missing a `type`/`name`/`nullable` property"
"The field passed in is missing a `type`/`name`/`nullable` property",
);
}
const type = sanitizeType(fieldLike.type);
@@ -473,6 +473,13 @@ function sanitizeField(fieldLike: unknown): Field {
return new Field(name, type, nullable, metadata);
}
/**
* Convert something schemaLike into a Schema instance
*
* This method is often needed even when the caller is using a Schema
* instance because they might be using a different instance of apache-arrow
* than lancedb is using.
*/
export function sanitizeSchema(schemaLike: unknown): Schema {
if (schemaLike instanceof Schema) {
return schemaLike;
@@ -482,7 +489,7 @@ export function sanitizeSchema(schemaLike: unknown): Schema {
}
if (!("fields" in schemaLike)) {
throw Error(
"The schema passed in does not appear to be a schema (no 'fields' property)"
"The schema passed in does not appear to be a schema (no 'fields' property)",
);
}
let metadata;
@@ -491,11 +498,11 @@ export function sanitizeSchema(schemaLike: unknown): Schema {
}
if (!Array.isArray(schemaLike.fields)) {
throw Error(
"The schema passed in had a 'fields' property but it was not an array"
"The schema passed in had a 'fields' property but it was not an array",
);
}
const sanitizedFields = schemaLike.fields.map((field) =>
sanitizeField(field)
sanitizeField(field),
);
return new Schema(sanitizedFields, metadata);
}
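
Since `sanitizeSchema` accepts anything schema-like, callers holding a `Schema` built against a different copy of apache-arrow can still pass it in and have it rebuilt field by field. A minimal sketch (field names and types are illustrative):

```typescript
import { Field, Float32, Int32, Schema } from "apache-arrow";

// A Schema from the same apache-arrow instance passes straight through.
const schema = new Schema([new Field("id", new Int32(), false)]);
sanitizeSchema(schema); // returns the instance unchanged

// A plain object with a `fields` array is validated and rebuilt. Each
// field must carry `name`, `type`, and `nullable` or an Error is thrown.
const rebuilt = sanitizeSchema({
  fields: [{ name: "x", type: new Float32(), nullable: true }],
});
```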

View File

@@ -129,11 +129,25 @@ describe("When creating an index", () => {
});
// Search without specifying the column
const rst = await tbl.query().nearestTo(queryVec).limit(2).toArrow();
let rst = await tbl
.query()
.limit(2)
.nearestTo(queryVec)
.distanceType("DoT")
.toArrow();
expect(rst.numRows).toBe(2);
// Search using `vectorSearch`
rst = await tbl.vectorSearch(queryVec).limit(2).toArrow();
expect(rst.numRows).toBe(2);
// Search with specifying the column
const rst2 = await tbl.search(queryVec, "vec").limit(2).toArrow();
const rst2 = await tbl
.query()
.limit(2)
.nearestTo(queryVec)
.column("vec")
.toArrow();
expect(rst2.numRows).toBe(2);
expect(rst.toString()).toEqual(rst2.toString());
});
@@ -163,7 +177,7 @@ describe("When creating an index", () => {
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
for await (const r of tbl.query().filter("id > 1").select(["id"])) {
for await (const r of tbl.query().where("id > 1").select(["id"])) {
expect(r.numRows).toBe(298);
}
});
@@ -205,33 +219,39 @@ describe("When creating an index", () => {
const rst = await tbl
.query()
.limit(2)
.nearestTo(
Array(32)
.fill(1)
.map(() => Math.random()),
)
.limit(2)
.toArrow();
expect(rst.numRows).toBe(2);
// Search with specifying the column
await expect(
tbl
.search(
.query()
.limit(2)
.nearestTo(
Array(64)
.fill(1)
.map(() => Math.random()),
"vec",
)
.limit(2)
.column("vec")
.toArrow(),
).rejects.toThrow(/.*does not match the dimension.*/);
const query64 = Array(64)
.fill(1)
.map(() => Math.random());
const rst64Query = await tbl.query().nearestTo(query64).limit(2).toArrow();
const rst64Search = await tbl.search(query64, "vec2").limit(2).toArrow();
const rst64Query = await tbl.query().limit(2).nearestTo(query64).toArrow();
const rst64Search = await tbl
.query()
.limit(2)
.nearestTo(query64)
.column("vec2")
.toArrow();
expect(rst64Query.toString()).toEqual(rst64Search.toString());
expect(rst64Query.numRows).toBe(2);
});

View File

@@ -4,14 +4,25 @@
const eslint = require("@eslint/js");
const tseslint = require("typescript-eslint");
const eslintConfigPrettier = require("eslint-config-prettier");
const jsdoc = require("eslint-plugin-jsdoc");
module.exports = tseslint.config(
eslint.configs.recommended,
jsdoc.configs["flat/recommended"],
eslintConfigPrettier,
...tseslint.configs.recommended,
{
rules: {
"@typescript-eslint/naming-convention": "error",
"jsdoc/require-returns": "off",
"jsdoc/require-param": "off",
"jsdoc/require-jsdoc": [
"error",
{
publicOnly: true,
},
],
},
plugins: { jsdoc },
},
);
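
Under `flat/recommended` plus the `publicOnly` setting above, the linter demands a JSDoc block on exported declarations but leaves internal helpers alone. Roughly what it now enforces (the functions themselves are made-up examples):

```typescript
/**
 * Suggest an IVF partition count for a table with `rowCount` rows.
 * @param rowCount - total number of rows in the table
 * @returns the suggested number of partitions
 */
export function suggestPartitions(rowCount: number): number {
  return Math.max(1, Math.round(Math.sqrt(rowCount)));
}

// Not exported, so `jsdoc/require-jsdoc` (publicOnly) does not flag it.
function clamp(x: number, lo: number, hi: number): number {
  return Math.min(hi, Math.max(lo, x));
}
```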

View File

@@ -31,6 +31,7 @@ import {
DataType,
Binary,
Float32,
type makeTable,
} from "apache-arrow";
import { type EmbeddingFunction } from "./embedding/embedding_function";
import { sanitizeSchema } from "./sanitize";
@@ -128,14 +129,7 @@ export class MakeArrowTableOptions {
* - Buffer => Binary
* - Record<String, any> => Struct
* - Array<any> => List
*
* @param data input data
* @param options options to control the makeArrowTable call.
*
* @example
*
* ```ts
*
* import { fromTableToBuffer, makeArrowTable } from "../arrow";
* import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow";
*
@@ -307,7 +301,9 @@ export function makeEmptyTable(schema: Schema): ArrowTable {
return makeArrowTable([], { schema });
}
// Helper function to convert Array<Array<any>> to a variable sized list array
/**
* Helper function to convert Array<Array<any>> to a variable sized list array
*/
// @ts-expect-error (Vector<unknown> is not assignable to Vector<any>)
function makeListVector(lists: unknown[][]): Vector<unknown> {
if (lists.length === 0 || lists[0].length === 0) {
@@ -333,7 +329,7 @@ function makeListVector(lists: unknown[][]): Vector<unknown> {
return listBuilder.finish().toVector();
}
// Helper function to convert an Array of JS values to an Arrow Vector
/** Helper function to convert an Array of JS values to an Arrow Vector */
function makeVector(
values: unknown[],
type?: DataType,
@@ -374,6 +370,7 @@ function makeVector(
}
}
/** Helper function to apply embeddings to an input table */
async function applyEmbeddings<T>(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
@@ -466,7 +463,7 @@ async function applyEmbeddings<T>(
return newTable;
}
/*
/**
* Convert an Array of records into an Arrow Table, optionally applying an
* embeddings function to it.
*
@@ -493,7 +490,7 @@ export async function convertToTable<T>(
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema);
}
// Creates the Arrow Type for a Vector column with dimension `dim`
/** Creates the Arrow Type for a Vector column with dimension `dim` */
function newVectorType<T extends Float>(
dim: number,
innerType: T,
@@ -565,6 +562,14 @@ export async function fromTableToBuffer<T>(
return Buffer.from(await writer.toUint8Array());
}
/**
* Serialize an Arrow Table into a buffer using the Arrow IPC File serialization
*
* This function will apply `embeddings` to the table in a manner similar to
* `convertToTable`.
*
* `schema` is required if the table is empty
*/
export async function fromDataToBuffer<T>(
data: Data,
embeddings?: EmbeddingFunction<T>,
@@ -599,6 +604,9 @@ export async function fromTableToStreamBuffer<T>(
return Buffer.from(await writer.toUint8Array());
}
/**
* Reorder the columns in `batch` so that they agree with the field order in `schema`
*/
function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch {
const alignedChildren = [];
for (const field of schema.fields) {
@@ -621,6 +629,9 @@ function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch {
return new RecordBatch(schema, newData);
}
/**
* Reorder the columns in `table` so that they agree with the field order in `schema`
*/
function alignTable(table: ArrowTable, schema: Schema): ArrowTable {
const alignedBatches = table.batches.map((batch) =>
alignBatch(batch, schema),
@@ -628,7 +639,9 @@ function alignTable(table: ArrowTable, schema: Schema): ArrowTable {
return new ArrowTable(schema, alignedBatches);
}
// Creates an empty Arrow Table
/**
* Create an empty table with the given schema
*/
export function createEmptyTable(schema: Schema): ArrowTable {
return new ArrowTable(sanitizeSchema(schema));
}
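
Taken together, the documented conversions mean `makeArrowTable` can infer a usable schema straight from plain records. A small sketch (record shape is illustrative; pass `{ schema }` in the options to pin types down explicitly):

```typescript
import { makeArrowTable } from "./arrow";

// Buffers become Binary, nested records become Structs, and arrays
// become Lists, per the mapping documented above.
const table = makeArrowTable([
  { id: 1, vector: [0.1, 0.2, 0.3], meta: { tag: "a" } },
  { id: 2, vector: [0.4, 0.5, 0.6], meta: { tag: "b" } },
]);
```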

View File

@@ -78,7 +78,8 @@ export class Connection {
return this.inner.isOpen();
}
/** Close the connection, releasing any underlying resources.
/**
* Close the connection, releasing any underlying resources.
*
* It is safe to call this method multiple times.
*
@@ -93,11 +94,12 @@ export class Connection {
return this.inner.display();
}
/** List all the table names in this database.
/**
* List all the table names in this database.
*
* Tables will be returned in lexicographical order.
*
* @param options Optional parameters to control the listing.
* @param {Partial<TableNamesOptions>} options - options to control the
* paging / start point
*/
async tableNames(options?: Partial<TableNamesOptions>): Promise<string[]> {
return this.inner.tableNames(options?.startAfter, options?.limit);
@@ -105,9 +107,7 @@ export class Connection {
/**
* Open a table in the database.
*
* @param name The name of the table.
* @param embeddings An embedding function to use on this table
* @param {string} name - The name of the table
*/
async openTable(name: string): Promise<Table> {
const innerTable = await this.inner.openTable(name);
@@ -116,9 +116,9 @@ export class Connection {
/**
* Creates a new Table and initialize it with new data.
*
* @param {string} name - The name of the table.
* @param data - Non-empty Array of Records to be inserted into the table
* @param {Record<string, unknown>[] | ArrowTable} data - Non-empty Array of Records
* to be inserted into the table
*/
async createTable(
name: string,
@@ -145,9 +145,8 @@ export class Connection {
/**
* Creates a new empty Table
*
* @param {string} name - The name of the table.
* @param schema - The schema of the table
* @param {Schema} schema - The schema of the table
*/
async createEmptyTable(
name: string,
@@ -169,7 +168,7 @@ export class Connection {
/**
* Drop an existing table.
* @param name The name of the table to drop.
* @param {string} name The name of the table to drop.
*/
async dropTable(name: string): Promise<void> {
return this.inner.dropTable(name);
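
The reworked Connection docstrings describe a small lifecycle; a sketch of how the pieces fit together (the URI and package import are placeholder assumptions, see `connect` below):

```typescript
import { connect } from "@lancedb/lancedb"; // package name assumed

async function demo() {
  const db = await connect("/tmp/sample-lancedb");

  // Create a table from records, then page through table names
  // lexicographically using TableNamesOptions.
  await db.createTable("my_table", [
    { id: 1, vector: [0.1, 0.2] },
    { id: 2, vector: [1.1, 1.2] },
  ]);
  const names = await db.tableNames({ startAfter: "a", limit: 10 });

  const tbl = await db.openTable("my_table");
  await db.dropTable("my_table");
  db.close(); // safe to call multiple times
}
```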

View File

@@ -62,6 +62,7 @@ export interface EmbeddingFunction<T> {
embed: (data: T[]) => Promise<number[][]>;
}
/** Test if the input seems to be an embedding function */
export function isEmbeddingFunction<T>(
value: unknown,
): value is EmbeddingFunction<T> {
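
`isEmbeddingFunction` is a type guard over the `EmbeddingFunction` shape; a tiny sketch of an object that would satisfy it (the `sourceColumn` field and the dummy embedding are assumptions of this sketch):

```typescript
const dummyEmbedding: EmbeddingFunction<string> = {
  sourceColumn: "text", // assumed: the column whose values get embedded
  embed: async (data: string[]) => data.map(() => [0.0, 0.0, 0.0]),
};
isEmbeddingFunction(dummyEmbedding); // true
```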

View File

@@ -30,9 +30,8 @@ export { Table, AddDataOptions } from "./table";
* - `/path/to/database` - local database
* - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
* - `db://host:port` - remote database (LanceDB cloud)
*
* @param uri The uri of the database. If the database uri starts with `db://` then it connects to a remote database.
*
* @param {string} uri - The uri of the database. If the database uri starts
* with `db://` then it connects to a remote database.
* @see {@link ConnectionOptions} for more details on the URI format.
*/
export async function connect(

View File

@@ -18,7 +18,8 @@ import { Index as LanceDbIndex } from "./native";
* Options to create an `IVF_PQ` index
*/
export interface IvfPqOptions {
/** The number of IVF partitions to create.
/**
* The number of IVF partitions to create.
*
* This value should generally scale with the number of rows in the dataset.
* By default the number of partitions is the square root of the number of
@@ -30,7 +31,8 @@ export interface IvfPqOptions {
*/
numPartitions?: number;
/** Number of sub-vectors of PQ.
/**
* Number of sub-vectors of PQ.
*
* This value controls how much the vector is compressed during the quantization step.
* The more sub vectors there are the less the vector is compressed. The default is
@@ -45,9 +47,10 @@ export interface IvfPqOptions {
*/
numSubVectors?: number;
/** [DistanceType] to use to build the index.
/**
* Distance type to use to build the index.
*
* Default value is [DistanceType::L2].
* Default value is "l2".
*
* This is used when training the index to calculate the IVF partitions
* (vectors are grouped in partitions with similar vectors according to this
@@ -79,7 +82,8 @@ export interface IvfPqOptions {
*/
distanceType?: "l2" | "cosine" | "dot";
/** Max iteration to train IVF kmeans.
/**
* Max iteration to train IVF kmeans.
*
* When training an IVF PQ index we use kmeans to calculate the partitions. This parameter
* controls how many iterations of kmeans to run.
@@ -91,7 +95,8 @@ export interface IvfPqOptions {
*/
maxIterations?: number;
/** The number of vectors, per partition, to sample when training IVF kmeans.
/**
* The number of vectors, per partition, to sample when training IVF kmeans.
*
* When an IVF PQ index is trained, we need to calculate partitions. These are groups
* of vectors that are similar to each other. To do this we use an algorithm called kmeans.
@@ -148,7 +153,8 @@ export class Index {
);
}
/** Create a btree index
/**
* Create a btree index
*
* A btree index is an index on a scalar columns. The index stores a copy of the column
* in sorted order. A header entry is created for each block of rows (currently the
@@ -172,7 +178,8 @@ export class Index {
}
export interface IndexOptions {
/** Advanced index configuration
/**
* Advanced index configuration
*
* This option allows you to specify a specific index to create and also
* allows you to pass in configuration for training the index.
@@ -183,7 +190,8 @@ export interface IndexOptions {
* will be used to determine the most useful kind of index to create.
*/
config?: Index;
/** Whether to replace the existing index
/**
* Whether to replace the existing index
*
* If this is false, and another index already exists on the same columns
* and the same name, then an error will be returned. This is true even if
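
For context, these options feed index creation through `IndexOptions.config`. A hedged sketch, assuming a static `Index.ivfPq(options)` constructor analogous to the btree creator above (`table` is an already-opened Table):

```typescript
import { Index } from "./indices";

const config = Index.ivfPq({
  numPartitions: 256, // default is roughly sqrt(row count)
  numSubVectors: 16, // more sub-vectors => less compression
  distanceType: "cosine", // must match the distanceType used at query time
  maxIterations: 50, // kmeans iterations for IVF training
});
await table.createIndex("vector", { config, replace: true });
```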

View File

@@ -105,15 +105,23 @@ export class RecordBatchIterator {
next(): Promise<Buffer | null>
}
export class Query {
column(column: string): void
filter(filter: string): void
select(columns: Array<string>): void
onlyIf(predicate: string): void
select(columns: Array<[string, string]>): void
limit(limit: number): void
prefilter(prefilter: boolean): void
nearestTo(vector: Float32Array): void
nearestTo(vector: Float32Array): VectorQuery
execute(): Promise<RecordBatchIterator>
}
export class VectorQuery {
column(column: string): void
distanceType(distanceType: string): void
postfilter(): void
refineFactor(refineFactor: number): void
nprobes(nprobe: number): void
executeStream(): Promise<RecordBatchIterator>
bypassVectorIndex(): void
onlyIf(predicate: string): void
select(columns: Array<[string, string]>): void
limit(limit: number): void
execute(): Promise<RecordBatchIterator>
}
export class Table {
display(): string
@@ -127,6 +135,7 @@ export class Table {
createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void>
update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
query(): Query
vectorSearch(vector: Float32Array): VectorQuery
addColumns(transforms: Array<AddColumnsSql>): Promise<void>
alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
dropColumns(columns: Array<string>): Promise<void>

View File

@@ -5,302 +5,325 @@
/* auto-generated by NAPI-RS */
const { existsSync, readFileSync } = require('fs')
const { join } = require('path')
const { join } = require("path");
const { platform, arch } = process
const { platform, arch } = process;
let nativeBinding = null
let localFileExisted = false
let loadError = null
let nativeBinding = null;
let localFileExisted = false;
let loadError = null;
function isMusl() {
// For Node 10
if (!process.report || typeof process.report.getReport !== 'function') {
if (!process.report || typeof process.report.getReport !== "function") {
try {
const lddPath = require('child_process').execSync('which ldd').toString().trim()
return readFileSync(lddPath, 'utf8').includes('musl')
const lddPath = require("child_process")
.execSync("which ldd")
.toString()
.trim();
return readFileSync(lddPath, "utf8").includes("musl");
} catch (e) {
return true
return true;
}
} else {
const { glibcVersionRuntime } = process.report.getReport().header
return !glibcVersionRuntime
const { glibcVersionRuntime } = process.report.getReport().header;
return !glibcVersionRuntime;
}
}
switch (platform) {
case 'android':
case "android":
switch (arch) {
case 'arm64':
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm64.node'))
case "arm64":
localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.android-arm64.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.android-arm64.node')
nativeBinding = require("./lancedb-nodejs.android-arm64.node");
} else {
nativeBinding = require('lancedb-android-arm64')
nativeBinding = require("lancedb-android-arm64");
}
} catch (e) {
loadError = e
loadError = e;
}
break
case 'arm':
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm-eabi.node'))
break;
case "arm":
localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.android-arm-eabi.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.android-arm-eabi.node')
nativeBinding = require("./lancedb-nodejs.android-arm-eabi.node");
} else {
nativeBinding = require('lancedb-android-arm-eabi')
nativeBinding = require("lancedb-android-arm-eabi");
}
} catch (e) {
loadError = e
loadError = e;
}
break
break;
default:
throw new Error(`Unsupported architecture on Android ${arch}`)
throw new Error(`Unsupported architecture on Android ${arch}`);
}
break
case 'win32':
break;
case "win32":
switch (arch) {
case 'x64':
case "x64":
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.win32-x64-msvc.node')
)
join(__dirname, "lancedb-nodejs.win32-x64-msvc.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.win32-x64-msvc.node')
nativeBinding = require("./lancedb-nodejs.win32-x64-msvc.node");
} else {
nativeBinding = require('lancedb-win32-x64-msvc')
nativeBinding = require("lancedb-win32-x64-msvc");
}
} catch (e) {
loadError = e
loadError = e;
}
break
case 'ia32':
break;
case "ia32":
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.win32-ia32-msvc.node')
)
join(__dirname, "lancedb-nodejs.win32-ia32-msvc.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.win32-ia32-msvc.node')
nativeBinding = require("./lancedb-nodejs.win32-ia32-msvc.node");
} else {
nativeBinding = require('lancedb-win32-ia32-msvc')
nativeBinding = require("lancedb-win32-ia32-msvc");
}
} catch (e) {
loadError = e
loadError = e;
}
break
case 'arm64':
break;
case "arm64":
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.win32-arm64-msvc.node')
)
join(__dirname, "lancedb-nodejs.win32-arm64-msvc.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.win32-arm64-msvc.node')
nativeBinding = require("./lancedb-nodejs.win32-arm64-msvc.node");
} else {
nativeBinding = require('lancedb-win32-arm64-msvc')
nativeBinding = require("lancedb-win32-arm64-msvc");
}
} catch (e) {
loadError = e
loadError = e;
}
break
break;
default:
throw new Error(`Unsupported architecture on Windows: ${arch}`)
throw new Error(`Unsupported architecture on Windows: ${arch}`);
}
break
case 'darwin':
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-universal.node'))
break;
case "darwin":
localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.darwin-universal.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.darwin-universal.node')
nativeBinding = require("./lancedb-nodejs.darwin-universal.node");
} else {
nativeBinding = require('lancedb-darwin-universal')
nativeBinding = require("lancedb-darwin-universal");
}
break
break;
} catch {}
switch (arch) {
case 'x64':
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-x64.node'))
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.darwin-x64.node')
} else {
nativeBinding = require('lancedb-darwin-x64')
}
} catch (e) {
loadError = e
}
break
case 'arm64':
case "x64":
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.darwin-arm64.node')
)
join(__dirname, "lancedb-nodejs.darwin-x64.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.darwin-arm64.node')
nativeBinding = require("./lancedb-nodejs.darwin-x64.node");
} else {
nativeBinding = require('lancedb-darwin-arm64')
nativeBinding = require("lancedb-darwin-x64");
}
} catch (e) {
loadError = e
loadError = e;
}
break
break;
case "arm64":
localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.darwin-arm64.node"),
);
try {
if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.darwin-arm64.node");
} else {
nativeBinding = require("lancedb-darwin-arm64");
}
} catch (e) {
loadError = e;
}
break;
default:
throw new Error(`Unsupported architecture on macOS: ${arch}`)
throw new Error(`Unsupported architecture on macOS: ${arch}`);
}
break
case 'freebsd':
if (arch !== 'x64') {
throw new Error(`Unsupported architecture on FreeBSD: ${arch}`)
break;
case "freebsd":
if (arch !== "x64") {
throw new Error(`Unsupported architecture on FreeBSD: ${arch}`);
}
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.freebsd-x64.node'))
localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.freebsd-x64.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.freebsd-x64.node')
nativeBinding = require("./lancedb-nodejs.freebsd-x64.node");
} else {
nativeBinding = require('lancedb-freebsd-x64')
nativeBinding = require("lancedb-freebsd-x64");
}
} catch (e) {
loadError = e
loadError = e;
}
break
case 'linux':
break;
case "linux":
switch (arch) {
case 'x64':
case "x64":
if (isMusl()) {
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.linux-x64-musl.node')
)
join(__dirname, "lancedb-nodejs.linux-x64-musl.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.linux-x64-musl.node')
nativeBinding = require("./lancedb-nodejs.linux-x64-musl.node");
} else {
nativeBinding = require('lancedb-linux-x64-musl')
nativeBinding = require("lancedb-linux-x64-musl");
}
} catch (e) {
loadError = e
loadError = e;
}
} else {
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.linux-x64-gnu.node')
)
join(__dirname, "lancedb-nodejs.linux-x64-gnu.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.linux-x64-gnu.node')
nativeBinding = require("./lancedb-nodejs.linux-x64-gnu.node");
} else {
nativeBinding = require('lancedb-linux-x64-gnu')
nativeBinding = require("lancedb-linux-x64-gnu");
}
} catch (e) {
loadError = e
loadError = e;
}
}
break
case 'arm64':
break;
case "arm64":
if (isMusl()) {
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.linux-arm64-musl.node')
)
join(__dirname, "lancedb-nodejs.linux-arm64-musl.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.linux-arm64-musl.node')
nativeBinding = require("./lancedb-nodejs.linux-arm64-musl.node");
} else {
nativeBinding = require('lancedb-linux-arm64-musl')
nativeBinding = require("lancedb-linux-arm64-musl");
}
} catch (e) {
loadError = e
loadError = e;
}
} else {
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.linux-arm64-gnu.node')
)
join(__dirname, "lancedb-nodejs.linux-arm64-gnu.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.linux-arm64-gnu.node')
nativeBinding = require("./lancedb-nodejs.linux-arm64-gnu.node");
} else {
nativeBinding = require('lancedb-linux-arm64-gnu')
nativeBinding = require("lancedb-linux-arm64-gnu");
}
} catch (e) {
loadError = e
loadError = e;
}
}
break
case 'arm':
break;
case "arm":
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.linux-arm-gnueabihf.node')
)
join(__dirname, "lancedb-nodejs.linux-arm-gnueabihf.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.linux-arm-gnueabihf.node')
nativeBinding = require("./lancedb-nodejs.linux-arm-gnueabihf.node");
} else {
nativeBinding = require('lancedb-linux-arm-gnueabihf')
nativeBinding = require("lancedb-linux-arm-gnueabihf");
}
} catch (e) {
loadError = e
loadError = e;
}
break
case 'riscv64':
break;
case "riscv64":
if (isMusl()) {
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.linux-riscv64-musl.node')
)
join(__dirname, "lancedb-nodejs.linux-riscv64-musl.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.linux-riscv64-musl.node')
nativeBinding = require("./lancedb-nodejs.linux-riscv64-musl.node");
} else {
nativeBinding = require('lancedb-linux-riscv64-musl')
nativeBinding = require("lancedb-linux-riscv64-musl");
}
} catch (e) {
loadError = e
loadError = e;
}
} else {
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.linux-riscv64-gnu.node')
)
join(__dirname, "lancedb-nodejs.linux-riscv64-gnu.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.linux-riscv64-gnu.node')
nativeBinding = require("./lancedb-nodejs.linux-riscv64-gnu.node");
} else {
nativeBinding = require('lancedb-linux-riscv64-gnu')
nativeBinding = require("lancedb-linux-riscv64-gnu");
}
} catch (e) {
loadError = e
loadError = e;
}
}
break
case 's390x':
break;
case "s390x":
localFileExisted = existsSync(
join(__dirname, 'lancedb-nodejs.linux-s390x-gnu.node')
)
join(__dirname, "lancedb-nodejs.linux-s390x-gnu.node"),
);
try {
if (localFileExisted) {
nativeBinding = require('./lancedb-nodejs.linux-s390x-gnu.node')
nativeBinding = require("./lancedb-nodejs.linux-s390x-gnu.node");
} else {
nativeBinding = require('lancedb-linux-s390x-gnu')
nativeBinding = require("lancedb-linux-s390x-gnu");
}
} catch (e) {
loadError = e
loadError = e;
}
break
break;
default:
throw new Error(`Unsupported architecture on Linux: ${arch}`)
throw new Error(`Unsupported architecture on Linux: ${arch}`);
}
break
break;
default:
throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`)
throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`);
}
if (!nativeBinding) {
if (loadError) {
throw loadError
throw loadError;
}
throw new Error(`Failed to load native binding`)
throw new Error(`Failed to load native binding`);
}
const { Connection, Index, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding
const {
Connection,
Index,
RecordBatchIterator,
Query,
VectorQuery,
Table,
WriteMode,
connect,
} = nativeBinding;
module.exports.Connection = Connection
module.exports.Index = Index
module.exports.RecordBatchIterator = RecordBatchIterator
module.exports.Query = Query
module.exports.Table = Table
module.exports.WriteMode = WriteMode
module.exports.connect = connect
module.exports.Connection = Connection;
module.exports.Index = Index;
module.exports.RecordBatchIterator = RecordBatchIterator;
module.exports.Query = Query;
module.exports.VectorQuery = VectorQuery;
module.exports.Table = Table;
module.exports.WriteMode = WriteMode;
module.exports.connect = connect;

View File

@@ -17,18 +17,15 @@ import {
RecordBatchIterator as NativeBatchIterator,
Query as NativeQuery,
Table as NativeTable,
VectorQuery as NativeVectorQuery,
} from "./native";
import { type IvfPqOptions } from "./indices";
class RecordBatchIterator implements AsyncIterator<RecordBatch> {
private promisedInner?: Promise<NativeBatchIterator>;
private inner?: NativeBatchIterator;
constructor(
inner?: NativeBatchIterator,
promise?: Promise<NativeBatchIterator>,
) {
constructor(promise?: Promise<NativeBatchIterator>) {
// TODO: check promise reliably so we don't need to pass two arguments.
this.inner = inner;
this.promisedInner = promise;
}
@@ -53,82 +50,113 @@ class RecordBatchIterator implements AsyncIterator<RecordBatch> {
}
/* eslint-enable */
/** Query executor */
export class Query implements AsyncIterable<RecordBatch> {
private readonly inner: NativeQuery;
/** Common methods supported by all query types */
export class QueryBase<
NativeQueryType extends NativeQuery | NativeVectorQuery,
QueryType,
> implements AsyncIterable<RecordBatch>
{
protected constructor(protected inner: NativeQueryType) {}
constructor(tbl: NativeTable) {
this.inner = tbl.query();
/**
* A filter statement to be applied to this query.
*
* The filter should be supplied as an SQL query string. For example:
* @example
* x > 10
* y > 0 AND y < 100
* x > 5 OR y = 'test'
*
* Filtering performance can often be improved by creating a scalar index
* on the filter column(s).
*/
where(predicate: string): QueryType {
this.inner.onlyIf(predicate);
return this as unknown as QueryType;
}
/** Set the column to run query. */
column(column: string): Query {
this.inner.column(column);
return this;
/**
* Return only the specified columns.
*
* By default a query will return all columns from the table. However, this can have
* a very significant impact on latency. LanceDb stores data in a columnar fashion. This
* means we can finely tune our I/O to select exactly the columns we need.
*
* As a best practice you should always limit queries to the columns that you need. If you
* pass in an array of column names then only those columns will be returned.
*
* You can also use this method to create new "dynamic" columns based on your existing columns.
* For example, you may not care about "a" or "b" but instead simply want "a + b". This is often
* seen in the SELECT clause of an SQL query (e.g. `SELECT a+b FROM my_table`).
*
* To create dynamic columns you can pass in a Map<string, string>. A column will be returned
* for each entry in the map. The key provides the name of the column. The value is
* an SQL string used to specify how the column is calculated.
*
* For example, an SQL query might state `SELECT a + b AS combined, c`. The equivalent
* input to this method would be:
* @example
* new Map([["combined", "a + b"], ["c", "c"]])
*
* Columns will always be returned in the order given, even if that order is different than
* the order used when adding the data.
*
* Note that you can pass in a `Record<string, string>` (e.g. an object literal). This method
* uses `Object.entries` which should preserve the insertion order of the object. However,
* object insertion order is easy to get wrong and `Map` is more foolproof.
*/
select(
columns: string[] | Map<string, string> | Record<string, string>,
): QueryType {
let columnTuples: [string, string][];
if (Array.isArray(columns)) {
columnTuples = columns.map((c) => [c, c]);
} else if (columns instanceof Map) {
columnTuples = Array.from(columns.entries());
} else {
columnTuples = Object.entries(columns);
}
this.inner.select(columnTuples);
return this as unknown as QueryType;
}
/** Set the filter predicate, only returns the results that satisfy the filter.
/**
* Set the maximum number of results to return.
*
* By default, a plain search has no limit. If this method is not
* called then every valid row from the table will be returned.
*/
limit(limit: number): QueryType {
this.inner.limit(limit);
return this as unknown as QueryType;
}
protected nativeExecute(): Promise<NativeBatchIterator> {
return this.inner.execute();
}
/**
* Execute the query and return the results as an @see {@link AsyncIterator}
* of @see {@link RecordBatch}.
*
* By default, LanceDb will use many threads to calculate results and, when
* the result set is large, multiple batches will be processed at one time.
* This readahead is limited however and backpressure will be applied if this
* stream is consumed slowly (this constrains the maximum memory used by a
* single query)
*
*/
filter(predicate: string): Query {
this.inner.filter(predicate);
return this;
protected execute(): RecordBatchIterator {
return new RecordBatchIterator(this.nativeExecute());
}
/**
* Select the columns to return. If not set, all columns are returned.
*/
select(columns: string[]): Query {
this.inner.select(columns);
return this;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
const promise = this.nativeExecute();
return new RecordBatchIterator(promise);
}
/**
* Set the limit of rows to return.
*/
limit(limit: number): Query {
this.inner.limit(limit);
return this;
}
prefilter(prefilter: boolean): Query {
this.inner.prefilter(prefilter);
return this;
}
/**
* Set the query vector.
*/
nearestTo(vector: number[]): Query {
this.inner.nearestTo(Float32Array.from(vector));
return this;
}
/**
* Set the number of IVF partitions to use for the query.
*/
nprobes(nprobes: number): Query {
this.inner.nprobes(nprobes);
return this;
}
/**
* Set the refine factor for the query.
*/
refineFactor(refineFactor: number): Query {
this.inner.refineFactor(refineFactor);
return this;
}
/**
* Execute the query and return the results as an AsyncIterator.
*/
async executeStream(): Promise<RecordBatchIterator> {
const inner = await this.inner.executeStream();
return new RecordBatchIterator(inner);
}
/** Collect the results as an Arrow Table. */
/** Collect the results as an Arrow @see {@link ArrowTable}. */
async toArrow(): Promise<ArrowTable> {
const batches = [];
for await (const batch of this) {
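
The shared `QueryBase` therefore gives every query type the same `where`/`select`/`limit` surface plus two consumption styles: async iteration (streaming, with backpressure) and `toArrow`/`toArray` collection. A short sketch:

```typescript
// Stream batch-by-batch; readahead is bounded, so slow consumers apply
// backpressure instead of buffering the whole result set.
for await (const batch of tbl.query().where("id > 1").select(["id"]).limit(100)) {
  console.log(batch.numRows);
}

// `select` also takes a Map (or Record) for dynamic SQL-computed columns.
const projected = await tbl
  .query()
  .select(new Map([["combined", "a + b"], ["c", "c"]]))
  .toArrow();
```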
@@ -137,18 +165,211 @@ export class Query implements AsyncIterable<RecordBatch> {
return new ArrowTable(batches);
}
/** Returns a JSON Array of All results.
*
*/
/** Collect the results as an array of objects. */
async toArray(): Promise<unknown[]> {
const tbl = await this.toArrow();
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return tbl.toArray();
}
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
const promise = this.inner.executeStream();
return new RecordBatchIterator(undefined, promise);
/**
* An interface for a query that can be executed
*
* Supported by all query types
*/
export interface ExecutableQuery {}
/**
* A builder used to construct a vector search
*
* This builder can be reused to execute the query many times.
*/
export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
constructor(inner: NativeVectorQuery) {
super(inner);
}
/**
* Set the number of partitions to search (probe)
*
* This argument is only used when the vector column has an IVF PQ index.
* If there is no index then this value is ignored.
*
* The IVF stage of IVF PQ divides the input into partitions (clusters) of
* related values.
*
* The partition whose centroids are closest to the query vector will be
* exhaustively searched to find matches. This parameter controls how many
* partitions should be searched.
*
* Increasing this value will increase the recall of your query but will
* also increase the latency of your query. The default value is 20. This
* default is good for many cases but the best value to use will depend on
* your data and the recall that you need to achieve.
*
* For best results we recommend tuning this parameter with a benchmark against
* your actual data to find the smallest possible value that will still give
* you the desired recall.
*/
nprobes(nprobes: number): VectorQuery {
this.inner.nprobes(nprobes);
return this;
}
/**
* Set the vector column to query
*
* This controls which column is compared to the query vector supplied in
* the call to @see {@link Query#nearestTo}
*
* This parameter must be specified if the table has more than one column
* whose data type is a fixed-size-list of floats.
*/
column(column: string): VectorQuery {
this.inner.column(column);
return this;
}
/**
* Set the distance metric to use
*
* When performing a vector search we try and find the "nearest" vectors according
* to some kind of distance metric. This parameter controls which distance metric to
* use. See @see {@link IvfPqOptions.distanceType} for more details on the different
* distance metrics available.
*
* Note: if there is a vector index then the distance type used MUST match the distance
* type used to train the vector index. If this is not done then the results will be
* invalid.
*
* By default "l2" is used.
*/
distanceType(distanceType: string): VectorQuery {
this.inner.distanceType(distanceType);
return this;
}
/**
* A multiplier to control how many additional rows are taken during the refine step
*
* This argument is only used when the vector column has an IVF PQ index.
* If there is no index then this value is ignored.
*
* An IVF PQ index stores compressed (quantized) values. The query vector is compared
* against these values and, since they are compressed, the comparison is inaccurate.
*
* This parameter can be used to refine the results. It can both improve recall
* and correct the ordering of the nearest results.
*
* To refine results LanceDb will first perform an ANN search to find the nearest
* `limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and
* `limit` is the default (10) then the first 30 results will be selected. LanceDb
* then fetches the full, uncompressed, values for these 30 results. The results are
* then reordered by the true distance and only the nearest 10 are kept.
*
* Note: there is a difference between calling this method with a value of 1 and never
* calling this method at all. Calling this method with any value will have an impact
* on your search latency. When you call this method with a `refine_factor` of 1 then
* LanceDb still needs to fetch the full, uncompressed, values so that it can potentially
* reorder the results.
*
* Note: if this method is NOT called then the distances returned in the _distance column
* will be approximate distances based on the comparison of the quantized query vector
* and the quantized result vectors. This can be considerably different than the true
* distance between the query vector and the actual uncompressed vector.
*/
refineFactor(refineFactor: number): VectorQuery {
this.inner.refineFactor(refineFactor);
return this;
}
/**
* If this is called then filtering will happen after the vector search instead of
* before.
*
* By default filtering will be performed before the vector search. This is how
* filtering is typically understood to work. This prefilter step does add some
* additional latency. Creating a scalar index on the filter column(s) can
* often improve this latency. However, sometimes a filter is too complex or scalar
* indices cannot be applied to the column. In these cases postfiltering can be
* used instead of prefiltering to improve latency.
*
* Post filtering applies the filter to the results of the vector search. This means
* we only run the filter on a much smaller set of data. However, it can cause the
* query to return fewer than `limit` results (or even no results) if none of the nearest
* results match the filter.
*
* Post filtering happens during the "refine stage" (described in more detail in
* @see {@link VectorQuery#refineFactor}). This means that setting a higher refine
* factor can often help restore some of the results lost by post filtering.
*/
postfilter(): VectorQuery {
this.inner.postfilter();
return this;
}
/**
* If this is called then any vector index is skipped
*
* An exhaustive (flat) search will be performed. The query vector will
* be compared to every vector in the table. At high scales this can be
* expensive. However, this is often still useful. For example, skipping
* the vector index can give you ground truth results which you can use to
* calculate your recall to select an appropriate value for nprobes.
*/
bypassVectorIndex(): VectorQuery {
this.inner.bypassVectorIndex();
return this;
}
}
/** A builder for LanceDB queries. */
export class Query extends QueryBase<NativeQuery, Query> {
constructor(tbl: NativeTable) {
super(tbl.query());
}
/**
* Find the nearest vectors to the given query vector.
*
* This converts the query from a plain query to a vector query.
*
* This method will attempt to convert the input to the query vector
* expected by the embedding model. If the input cannot be converted
* then an error will be thrown.
*
* By default, there is no embedding model, and the input should be
* an array-like object of numbers (something that can be used as input
* to Float32Array.from)
*
* If there is only one vector column (a column whose data type is a
* fixed size list of floats) then the column does not need to be specified.
* If there is more than one vector column you must use
* @see {@link VectorQuery#column} to specify which column you would like
* to compare with.
*
* If no index has been created on the vector column then a vector query
* will perform a distance comparison between the query vector and every
* vector in the database and then sort the results. This is sometimes
* called a "flat search"
*
* For small databases, with a few hundred thousand vectors or less, this can
* be reasonably fast. In larger databases you should create a vector index
* on the column. If there is a vector index then an "approximate" nearest
* neighbor search (frequently called an ANN search) will be performed. This
* search is much faster, but the results will be approximate.
*
* The query can be further parameterized using the returned builder. There
* are various ANN search parameters that will let you fine tune your recall
* accuracy vs search latency.
*
* Vector searches always have a `limit`. If `limit` has not been called then
* a default `limit` of 10 will be used. @see {@link Query#limit}
*/
nearestTo(vector: unknown): VectorQuery {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
return new VectorQuery(vectorQuery);
}
}
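
Putting the `VectorQuery` knobs together, a typical tuned ANN search might look like the following (values are illustrative; tune `nprobes` and `refineFactor` against your own data):

```typescript
// queryVector: number[] matching the dimension of the vector column.
const results = await tbl
  .query()
  .nearestTo(queryVector)
  .column("vec") // required only when several vector columns exist
  .distanceType("l2") // must match the index's training distance type
  .nprobes(20) // partitions to probe: recall vs latency
  .refineFactor(3) // re-rank 3 * limit rows using exact distances
  .postfilter() // run the filter after the ANN search
  .where("price < 100")
  .limit(10)
  .toArrow();

// Ground-truth baseline for measuring recall: skip the index entirely.
const exact = await tbl
  .query()
  .nearestTo(queryVector)
  .bypassVectorIndex()
  .limit(10)
  .toArrow();
```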

View File

@@ -481,6 +481,13 @@ function sanitizeField(fieldLike: unknown): Field {
return new Field(name, type, nullable, metadata);
}
/**
* Convert something schemaLike into a Schema instance
*
* This method is often needed even when the caller is using a Schema
* instance because they might be using a different instance of apache-arrow
* than lancedb is using.
*/
export function sanitizeSchema(schemaLike: unknown): Schema {
if (schemaLike instanceof Schema) {
return schemaLike;

View File

@@ -19,7 +19,7 @@ import {
IndexConfig,
Table as _NativeTable,
} from "./native";
import { Query } from "./query";
import { Query, VectorQuery } from "./query";
import { IndexOptions } from "./indices";
import { Data, fromDataToBuffer } from "./arrow";
@@ -28,7 +28,8 @@ export { IndexConfig } from "./native";
* Options for adding data to a table.
*/
export interface AddDataOptions {
/** If "append" (the default) then the new data will be added to the table
/**
* If "append" (the default) then the new data will be added to the table
*
* If "overwrite" then the new data will replace the existing data in the table.
*/
@@ -74,7 +75,8 @@ export class Table {
return this.inner.isOpen();
}
/** Close the table, releasing any underlying resources.
/**
* Close the table, releasing any underlying resources.
*
* It is safe to call this method multiple times.
*
@@ -98,9 +100,7 @@ export class Table {
/**
* Insert records into this Table.
*
* @param {Data} data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
const mode = options?.mode ?? "append";
@@ -124,15 +124,15 @@ export class Table {
* you are updating many rows (with different ids) then you will get
* better performance with a single [`merge_insert`] call instead of
* repeatedly calling this method.
*
* @param updates the columns to update
* @param {Map<string, string> | Record<string, string>} updates - the
* columns to update
*
* Keys in the map should specify the name of the column to update.
* Values in the map provide the new value of the column. These can
* be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions
* based on the row being updated (e.g. "my_col + 1")
*
* @param options additional options to control the update behavior
* @param {Partial<UpdateOptions>} options - additional options to control
* the update behavior
*/
async update(
updates: Map<string, string> | Record<string, string>,
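
As described above, `update` takes the new values as a Map or object of SQL expressions; a brief sketch (the `where` option name on `UpdateOptions` is an assumption of this sketch, mirroring the native `onlyIf` parameter):

```typescript
// Values are SQL strings: literals ("7", "'foo'") or expressions
// over the row being updated ("price + 1").
await tbl.update({ price: "price + 1" }, { where: "id = 7" });

// Equivalent Map form, which avoids object key-ordering surprises.
await tbl.update(new Map([["name", "'unknown'"]]));
```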
@@ -158,37 +158,28 @@ export class Table {
await this.inner.delete(predicate);
}
/** Create an index to speed up queries.
/**
* Create an index to speed up queries.
*
* Indices can be created on vector columns or scalar columns.
* Indices on vector columns will speed up vector searches.
* Indices on scalar columns will speed up filtering (in both
* vector and non-vector searches)
*
* @example
*
* If the column has a vector (fixed size list) data type then
* an IvfPq vector index will be created.
*
* ```typescript
* // If the column has a vector (fixed size list) data type then
* // an IvfPq vector index will be created.
* const table = await conn.openTable("my_table");
* await table.createIndex("vector");
* ```
*
* For advanced control over vector index creation you can specify
* the index type and options.
* ```typescript
* @example
* // For advanced control over vector index creation you can specify
* // the index type and options.
* const table = await conn.openTable("my_table");
* await table.createIndex("vector", {
*   config: Index.ivfPq({ numPartitions: 128, numSubVectors: 16 }),
* });
* ```
*
* Or create a Scalar index
*
* ```typescript
* @example
* // Or create a Scalar index
* await table.createIndex("my_float_col");
* ```
*/
async createIndex(column: string, options?: Partial<IndexOptions>) {
// Bit of a hack to get around the fact that TS has no package-scope.
@@ -198,69 +189,74 @@ export class Table {
}
/**
* Create a generic {@link Query} Builder.
* Create a {@link Query} Builder.
*
* Queries allow you to search your existing data. By default the query will
* return all the data in the table in no particular order. The builder
* returned by this method can be used to control the query using filtering,
* vector similarity, sorting, and more.
*
* Note: By default, all columns are returned. For best performance, you should
* only fetch the columns you need. See {@link Query#select} for
* more details.
*
* When appropriate, various indices and statistics based pruning will be used to
* accelerate the query.
*
* @example
*
* ### Run a SQL-style query
* ```typescript
* // SQL-style filtering
* //
* // This query will return up to 20 rows whose value in the `id` column
* // is greater than 1. LanceDb supports a broad set of filtering functions.
* for await (const batch of table.query()
* .where("id > 1").select(["id"]).limit(20)) {
* console.log(batch);
* }
* ```
*
* ### Run Top-10 vector similarity search
* ```typescript
* @example
* // Vector Similarity Search
* //
* // This example will find the 10 rows whose value in the "vector" column are
* // closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
* // on the "vector" column then this will perform an ANN search.
* //
* // The `refine_factor` and `nprobes` methods are used to control the recall /
* // latency tradeoff of the search.
* for await (const batch of table.query()
* .nearestTo([1, 2, 3])
* .refineFactor(5).nprobes(10)
* .limit(10)) {
* console.log(batch);
* }
*```
*
* ### Scan the full dataset
* ```typescript
* @example
* // Scan the full dataset
* //
* // This query will return everything in the table in no particular order.
* for await (const batch of table.query()) {
* console.log(batch);
* }
*
* ### Return the full dataset as Arrow Table
* ```typescript
* let arrowTbl = await table.query().nearestTo([1.0, 2.0, 0.5, 6.7]).toArrow();
* ```
*
* @returns {@link Query}
* @returns {Query} A builder that can be used to parameterize the query
*/
query(): Query {
return new Query(this.inner);
}
/** Search the table with a given query vector.
/**
* Search the table with a given query vector.
*
* This is a convenience method for preparing an ANN {@link Query}.
* This is a convenience method for preparing a vector query and
* is the same thing as calling `nearestTo` on the builder returned
* by `query`. @see {@link Query#nearestTo} for more details.
*/
search(vector: number[], column?: string): Query {
const q = this.query();
q.nearestTo(vector);
if (column !== undefined) {
q.column(column);
}
return q;
vectorSearch(vector: unknown): VectorQuery {
return this.query().nearestTo(vector);
}
// TODO: Support BatchUDF
/**
* Add new columns with defined values.
*
* @param newColumnTransforms pairs of column names and the SQL expression to use
* to calculate the value of the new column. These
* expressions will be evaluated for each row in the
* table, and can reference existing columns in the table.
* @param {AddColumnsSql[]} newColumnTransforms pairs of column names and
* the SQL expression to use to calculate the value of the new column. These
* expressions will be evaluated for each row in the table, and can
* reference existing columns in the table.
*/
async addColumns(newColumnTransforms: AddColumnsSql[]): Promise<void> {
await this.inner.addColumns(newColumnTransforms);
@@ -268,8 +264,8 @@ export class Table {
/**
* Alter the name or nullability of columns.
*
* @param columnAlterations One or more alterations to apply to columns.
* @param {ColumnAlteration[]} columnAlterations One or more alterations to
* apply to columns.
*/
async alterColumns(columnAlterations: ColumnAlteration[]): Promise<void> {
await this.inner.alterColumns(columnAlterations);
@@ -282,16 +278,16 @@ export class Table {
* underlying storage. In order to remove the data, you must subsequently
* call ``compact_files`` to rewrite the data without the removed columns and
* then call ``cleanup_files`` to remove the old files.
*
* @param columnNames The names of the columns to drop. These can be nested
* column references (e.g. "a.b.c") or top-level column
* names (e.g. "a").
* @param {string[]} columnNames The names of the columns to drop. These can
* be nested column references (e.g. "a.b.c") or top-level column names
* (e.g. "a").
*/
async dropColumns(columnNames: string[]): Promise<void> {
await this.inner.dropColumns(columnNames);
}
/** Retrieve the version of the table
/**
* Retrieve the version of the table
*
* LanceDb supports versioning. Every operation that modifies the table increases the
* version. As long as a version hasn't been deleted you can {@link Table#checkout} that
@@ -302,7 +298,8 @@ export class Table {
return await this.inner.version();
}
/** Checks out a specific version of the Table
/**
* Checks out a specific version of the Table
*
* Any read operation on the table will now access the data at the checked out version.
* As a consequence, calling this method will disable any read consistency interval
@@ -321,7 +318,8 @@ export class Table {
await this.inner.checkout(version);
}
/** Ensures the table is pointing at the latest version
/**
* Ensures the table is pointing at the latest version
*
* This can be used to manually update a table when the read_consistency_interval is None
* It can also be used to undo a {@link Table#checkout} operation
@@ -330,7 +328,8 @@ export class Table {
await this.inner.checkoutLatest();
}
/** Restore the table to the currently checked out version
/**
* Restore the table to the currently checked out version
*
* This operation will fail if checkout has not been called previously
*

nodejs/package-lock.json (generated, 120 lines changed)
View File

@@ -26,6 +26,7 @@
"apache-arrow-old": "npm:apache-arrow@13.0.0",
"eslint": "^8.57.0",
"eslint-config-prettier": "^9.1.0",
"eslint-plugin-jsdoc": "^48.2.1",
"jest": "^29.7.0",
"prettier": "^3.1.0",
"tmp": "^0.2.3",
@@ -755,6 +756,20 @@
"integrity": "sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==",
"dev": true
},
"node_modules/@es-joy/jsdoccomment": {
"version": "0.42.0",
"resolved": "https://registry.npmjs.org/@es-joy/jsdoccomment/-/jsdoccomment-0.42.0.tgz",
"integrity": "sha512-R1w57YlVA6+YE01wch3GPYn6bCsrOV3YW/5oGGE2tmX6JcL9Nr+b5IikrjMPF+v9CV3ay+obImEdsDhovhJrzw==",
"dev": true,
"dependencies": {
"comment-parser": "1.4.1",
"esquery": "^1.5.0",
"jsdoc-type-pratt-parser": "~4.0.0"
},
"engines": {
"node": ">=16"
}
},
"node_modules/@eslint-community/eslint-utils": {
"version": "4.4.0",
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz",
@@ -1948,6 +1963,15 @@
"integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ==",
"dev": true
},
"node_modules/are-docs-informative": {
"version": "0.0.2",
"resolved": "https://registry.npmjs.org/are-docs-informative/-/are-docs-informative-0.0.2.tgz",
"integrity": "sha512-ixiS0nLNNG5jNQzgZJNoUpBKdo9yTYZMGJ+QgT2jmjR7G7+QHRCc4v6LQ3NgE7EBJq+o0ams3waJwkrlBom8Ig==",
"dev": true,
"engines": {
"node": ">=14"
}
},
"node_modules/argparse": {
"version": "1.0.10",
"resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz",
@@ -2189,6 +2213,18 @@
"integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
"dev": true
},
"node_modules/builtin-modules": {
"version": "3.3.0",
"resolved": "https://registry.npmjs.org/builtin-modules/-/builtin-modules-3.3.0.tgz",
"integrity": "sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw==",
"dev": true,
"engines": {
"node": ">=6"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/camelcase": {
"version": "5.3.1",
"resolved": "https://registry.npmjs.org/camelcase/-/camelcase-5.3.1.tgz",
@@ -2373,6 +2409,15 @@
"node": ">=12.17"
}
},
"node_modules/comment-parser": {
"version": "1.4.1",
"resolved": "https://registry.npmjs.org/comment-parser/-/comment-parser-1.4.1.tgz",
"integrity": "sha512-buhp5kePrmda3vhc5B9t7pUQXAb2Tnd0qgpkIhPhkHXxJpiPJ11H0ZEU0oBpJ2QztSbzG/ZxMj/CHsYJqRHmyg==",
"dev": true,
"engines": {
"node": ">= 12.0.0"
}
},
"node_modules/concat-map": {
"version": "0.0.1",
"resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz",
@@ -2660,6 +2705,29 @@
"eslint": ">=7.0.0"
}
},
"node_modules/eslint-plugin-jsdoc": {
"version": "48.2.1",
"resolved": "https://registry.npmjs.org/eslint-plugin-jsdoc/-/eslint-plugin-jsdoc-48.2.1.tgz",
"integrity": "sha512-iUvbcyDZSO/9xSuRv2HQBw++8VkV/pt3UWtX9cpPH0l7GKPq78QC/6+PmyQHHvNZaTjAce6QVciEbnc6J/zH5g==",
"dev": true,
"dependencies": {
"@es-joy/jsdoccomment": "~0.42.0",
"are-docs-informative": "^0.0.2",
"comment-parser": "1.4.1",
"debug": "^4.3.4",
"escape-string-regexp": "^4.0.0",
"esquery": "^1.5.0",
"is-builtin-module": "^3.2.1",
"semver": "^7.6.0",
"spdx-expression-parse": "^4.0.0"
},
"engines": {
"node": ">=18"
},
"peerDependencies": {
"eslint": "^7.0.0 || ^8.0.0 || ^9.0.0"
}
},
"node_modules/eslint-scope": {
"version": "7.2.2",
"resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz",
@@ -3299,6 +3367,21 @@
"integrity": "sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w==",
"optional": true
},
"node_modules/is-builtin-module": {
"version": "3.2.1",
"resolved": "https://registry.npmjs.org/is-builtin-module/-/is-builtin-module-3.2.1.tgz",
"integrity": "sha512-BSLE3HnV2syZ0FK0iMA/yUGplUeMmNz4AW5fnTunbCIqZi4vG3WjJT9FHMy5D69xmAYBHXQhJdALdpwVxV501A==",
"dev": true,
"dependencies": {
"builtin-modules": "^3.3.0"
},
"engines": {
"node": ">=6"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/is-core-module": {
"version": "2.13.1",
"resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.1.tgz",
@@ -4172,6 +4255,15 @@
"js-yaml": "bin/js-yaml.js"
}
},
"node_modules/jsdoc-type-pratt-parser": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/jsdoc-type-pratt-parser/-/jsdoc-type-pratt-parser-4.0.0.tgz",
"integrity": "sha512-YtOli5Cmzy3q4dP26GraSOeAhqecewG04hoO8DY56CH4KJ9Fvv5qKWUCCo3HZob7esJQHCv6/+bnTy72xZZaVQ==",
"dev": true,
"engines": {
"node": ">=12.0.0"
}
},
"node_modules/jsesc": {
"version": "2.5.2",
"resolved": "https://registry.npmjs.org/jsesc/-/jsesc-2.5.2.tgz",
@@ -5018,9 +5110,9 @@
}
},
"node_modules/semver": {
"version": "7.5.4",
"resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
"integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
"version": "7.6.0",
"resolved": "https://registry.npmjs.org/semver/-/semver-7.6.0.tgz",
"integrity": "sha512-EnwXhrlwXMk9gKu5/flx5sv/an57AkRplG3hTK68W7FRDN+k+OWBj65M7719OkA82XLBxrcX0KSHj+X5COhOVg==",
"dev": true,
"dependencies": {
"lru-cache": "^6.0.0"
@@ -5105,6 +5197,28 @@
"source-map": "^0.6.0"
}
},
"node_modules/spdx-exceptions": {
"version": "2.5.0",
"resolved": "https://registry.npmjs.org/spdx-exceptions/-/spdx-exceptions-2.5.0.tgz",
"integrity": "sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==",
"dev": true
},
"node_modules/spdx-expression-parse": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/spdx-expression-parse/-/spdx-expression-parse-4.0.0.tgz",
"integrity": "sha512-Clya5JIij/7C6bRR22+tnGXbc4VKlibKSVj2iHvVeX5iMW7s1SIQlqu699JkODJJIhh/pUu8L0/VLh8xflD+LQ==",
"dev": true,
"dependencies": {
"spdx-exceptions": "^2.1.0",
"spdx-license-ids": "^3.0.0"
}
},
"node_modules/spdx-license-ids": {
"version": "3.0.17",
"resolved": "https://registry.npmjs.org/spdx-license-ids/-/spdx-license-ids-3.0.17.tgz",
"integrity": "sha512-sh8PWc/ftMqAAdFiBu6Fy6JUOYjqDJBJvIhpfDMyHrr0Rbp5liZqd4TjtQ/RgfLjKFZb+LMx5hpml5qOWy0qvg==",
"dev": true
},
"node_modules/sprintf-js": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz",

View File

@@ -25,6 +25,7 @@
"apache-arrow-old": "npm:apache-arrow@13.0.0",
"eslint": "^8.57.0",
"eslint-config-prettier": "^9.1.0",
"eslint-plugin-jsdoc": "^48.2.1",
"jest": "^29.7.0",
"prettier": "^3.1.0",
"tmp": "^0.2.3",

View File

@@ -17,9 +17,10 @@ use std::sync::Mutex;
use lancedb::index::scalar::BTreeIndexBuilder;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::Index as LanceDbIndex;
use lancedb::DistanceType;
use napi_derive::napi;
use crate::util::parse_distance_type;
#[napi]
pub struct Index {
inner: Mutex<Option<LanceDbIndex>>,
@@ -49,15 +50,7 @@ impl Index {
) -> napi::Result<Self> {
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = match distance_type.as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(napi::Error::from_reason(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type
))),
}?;
let distance_type = parse_distance_type(distance_type)?;
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {

View File

@@ -21,6 +21,7 @@ mod index;
mod iterator;
mod query;
mod table;
mod util;
#[napi(object)]
#[derive(Debug)]

View File

@@ -12,36 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use lancedb::query::Query as LanceDBQuery;
use lancedb::query::ExecutableQuery;
use lancedb::query::Query as LanceDbQuery;
use lancedb::query::QueryBase;
use lancedb::query::Select;
use lancedb::query::VectorQuery as LanceDbVectorQuery;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use crate::error::NapiErrorExt;
use crate::iterator::RecordBatchIterator;
use crate::util::parse_distance_type;
#[napi]
pub struct Query {
inner: LanceDBQuery,
inner: LanceDbQuery,
}
#[napi]
impl Query {
pub fn new(query: LanceDBQuery) -> Self {
pub fn new(query: LanceDbQuery) -> Self {
Self { inner: query }
}
// We cannot call this r#where because NAPI gets confused by the r#
#[napi]
pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column);
pub fn only_if(&mut self, predicate: String) {
self.inner = self.inner.clone().only_if(predicate);
}
#[napi]
pub fn filter(&mut self, filter: String) {
self.inner = self.inner.clone().filter(filter);
}
#[napi]
pub fn select(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(&columns);
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
#[napi]
@@ -50,13 +52,46 @@ impl Query {
}
#[napi]
pub fn prefilter(&mut self, prefilter: bool) {
self.inner = self.inner.clone().prefilter(prefilter);
pub fn nearest_to(&mut self, vector: Float32Array) -> Result<VectorQuery> {
let inner = self
.inner
.clone()
.nearest_to(vector.as_ref())
.default_error()?;
Ok(VectorQuery { inner })
}
#[napi]
pub fn nearest_to(&mut self, vector: Float32Array) {
self.inner = self.inner.clone().nearest_to(&vector);
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream))
}
}
#[napi]
pub struct VectorQuery {
inner: LanceDbVectorQuery,
}
#[napi]
impl VectorQuery {
#[napi]
pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column);
}
#[napi]
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
let distance_type = parse_distance_type(distance_type)?;
self.inner = self.inner.clone().distance_type(distance_type);
Ok(())
}
#[napi]
pub fn postfilter(&mut self) {
self.inner = self.inner.clone().postfilter();
}
#[napi]
@@ -70,8 +105,28 @@ impl Query {
}
#[napi]
pub async fn execute_stream(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute_stream().await.map_err(|e| {
pub fn bypass_vector_index(&mut self) {
self.inner = self.inner.clone().bypass_vector_index()
}
#[napi]
pub fn only_if(&mut self, predicate: String) {
self.inner = self.inner.clone().only_if(predicate);
}
#[napi]
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
#[napi]
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
#[napi]
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream))

View File

@@ -23,7 +23,7 @@ use napi_derive::napi;
use crate::error::NapiErrorExt;
use crate::index::Index;
use crate::query::Query;
use crate::query::{Query, VectorQuery};
#[napi]
pub struct Table {
@@ -171,6 +171,11 @@ impl Table {
Ok(Query::new(self.inner_ref()?.query()))
}
#[napi]
pub fn vector_search(&self, vector: Float32Array) -> napi::Result<VectorQuery> {
self.query()?.nearest_to(vector)
}
#[napi]
pub async fn add_columns(&self, transforms: Vec<AddColumnsSql>) -> napi::Result<()> {
let transforms = transforms

nodejs/src/util.rs (new file, 13 lines)
View File

@@ -0,0 +1,13 @@
use lancedb::DistanceType;
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
match distance_type.as_ref().to_lowercase().as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(napi::Error::from_reason(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type.as_ref()
))),
}
}

View File

@@ -22,6 +22,9 @@ pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }
pin-project = "1.1.5"
futures.workspace = true
tokio = { version = "1.36.0", features = ["sync"] }
[build-dependencies]
pyo3-build-config = { version = "0.20.3", features = [

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import pyarrow as pa
@@ -40,6 +40,8 @@ class Table:
async def checkout_latest(self): ...
async def restore(self): ...
async def list_indices(self) -> List[IndexConfig]: ...
def query(self) -> Query: ...
def vector_search(self) -> VectorQuery: ...
class IndexConfig:
index_type: str
@@ -52,3 +54,27 @@ async def connect(
host_override: Optional[str],
read_consistency_interval: Optional[float],
) -> Connection: ...
class RecordBatchStream:
def schema(self) -> pa.Schema: ...
async def next(self) -> Optional[pa.RecordBatch]: ...
class Query:
def where(self, filter: str): ...
def select(self, columns: List[Tuple[str, str]]): ...
def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
async def execute(self) -> RecordBatchStream: ...
class VectorQuery:
async def execute(self) -> RecordBatchStream: ...
def where(self, filter: str): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_with_projection(self, columns: List[Tuple[str, str]]): ...
def limit(self, limit: int): ...
def column(self, column: str): ...
def distance_type(self, distance_type: str): ...
def postfilter(self): ...
def refine_factor(self, refine_factor: int): ...
def nprobes(self, nprobes: int): ...
def bypass_vector_index(self): ...

View File

@@ -0,0 +1,44 @@
from typing import List
import pyarrow as pa
from ._lancedb import RecordBatchStream
class AsyncRecordBatchReader:
"""
An async iterator over a stream of RecordBatches.
Also allows access to the schema of the stream.
"""
def __init__(self, inner: RecordBatchStream):
self.inner_ = inner
@property
def schema(self) -> pa.Schema:
"""
Get the schema of the batches produced by the stream.
Accessing the schema does not consume any data from the stream.
"""
return self.inner_.schema()
async def read_all(self) -> List[pa.RecordBatch]:
"""
Read all the record batches from the stream.
This consumes the entire stream and returns a list of record batches.
If there are a lot of results, this may consume a lot of memory.
"""
return [batch async for batch in self]
def __aiter__(self):
return self
async def __anext__(self) -> pa.RecordBatch:
batch = await self.inner_.next()
if batch is None:
    raise StopAsyncIteration
return batch
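As a usage sketch (the connection path and table name below are hypothetical; the API calls are the ones added in this PR), the reader can be consumed incrementally:

import asyncio

import lancedb


async def consume_stream():
    conn = await lancedb.connect_async("./.lancedb")
    table = await conn.create_table("stream_demo", data=[{"a": 1}, {"a": 2}])
    reader = await table.query().to_batches()  # AsyncRecordBatchReader
    print(reader.schema)  # inspecting the schema consumes no data
    async for batch in reader:  # batches arrive incrementally
        print(batch.num_rows)


asyncio.run(consume_stream())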

View File

@@ -24,6 +24,7 @@ import pyarrow as pa
import pydantic
from . import __version__
from .arrow import AsyncRecordBatchReader
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
@@ -33,6 +34,8 @@ if TYPE_CHECKING:
import PIL
import polars as pl
from ._lancedb import Query as LanceQuery
from ._lancedb import VectorQuery as LanceVectorQuery
from .pydantic import LanceModel
from .table import Table
@@ -921,3 +924,334 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
"""
self._vector_query.refine_factor(refine_factor)
return self
class AsyncQueryBase(object):
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
"""
Construct an AsyncQueryBase
This method is not intended to be called directly. Instead, use the
[Table.query][] method to create a query.
"""
self._inner = inner
def where(self, predicate: str) -> AsyncQuery:
"""
Only return rows matching the given predicate
The predicate should be supplied as an SQL query string. For example:
>>> predicate = "x > 10"
>>> predicate = "y > 0 AND y < 100"
>>> predicate = "x > 5 OR y = 'test'"
Filtering performance can often be improved by creating a scalar index
on the filter column(s).
"""
self._inner.where(predicate)
return self
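For example, a predicate composes with the rest of the builder (a minimal sketch; `table` is assumed to be an AsyncTable like the fixture in the tests further down):

async def filtered_rows(table):
    # Rows matching the predicate, projected to "id", capped at 10 rows.
    return await table.query().where("id = 2").select(["id"]).limit(10).to_arrow()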
def select(self, columns: Union[List[str], dict[str, str]]) -> AsyncQuery:
"""
Return only the specified columns.
By default a query will return all columns from the table. However, this can
have a very significant impact on latency. LanceDb stores data in a columnar
fashion. This means we can finely tune our I/O to select exactly the columns we need.
As a best practice you should always limit queries to the columns that you need.
If you pass in a list of column names then only those columns will be
returned.
You can also use this method to create new "dynamic" columns based on your
existing columns. For example, you may not care about "a" or "b" but instead
simply want "a + b". This is often seen in the SELECT clause of an SQL query
(e.g. `SELECT a+b FROM my_table`).
To create dynamic columns you can pass in a dict[str, str]. A column will be
returned for each entry in the map. The key provides the name of the column.
The value is an SQL string used to specify how the column is calculated.
For example, an SQL query might state `SELECT a + b AS combined, c`. The
equivalent input to this method would be `{"combined": "a + b", "c": "c"}`.
Columns will always be returned in the order given, even if that order is
different than the order used when adding the data.
"""
if isinstance(columns, dict):
column_tuples = list(columns.items())
else:
try:
column_tuples = [(c, c) for c in columns]
except TypeError:
raise TypeError("columns must be a list of column names or a dict")
self._inner.select(column_tuples)
return self
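To make the normalization above concrete, here is a small self-contained sketch of how both accepted forms reduce to the (name, SQL expression) tuples handed to the inner query:

columns_as_list = ["id", "vector"]
columns_as_dict = {"combined": "a + b", "c": "c"}

tuples_from_list = [(c, c) for c in columns_as_list]
assert tuples_from_list == [("id", "id"), ("vector", "vector")]

tuples_from_dict = list(columns_as_dict.items())
assert tuples_from_dict == [("combined", "a + b"), ("c", "c")]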
def limit(self, limit: int) -> AsyncQuery:
"""
Set the maximum number of results to return.
By default, a plain search has no limit. If this method is not
called then every valid row from the table will be returned.
"""
self._inner.limit(limit)
return self
async def to_batches(self) -> AsyncRecordBatchReader:
"""
Execute the query and return the results as an Apache Arrow RecordBatchReader.
"""
return AsyncRecordBatchReader(await self._inner.execute())
async def to_arrow(self) -> pa.Table:
"""
Execute the query and collect the results into an Apache Arrow Table.
This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use [to_batches][].
"""
batch_iter = await self.to_batches()
return pa.Table.from_batches(
await batch_iter.read_all(), schema=batch_iter.schema
)
async def to_pandas(self) -> "pd.DataFrame":
"""
Execute the query and collect the results into a pandas DataFrame.
This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use [to_batches][]
and convert each batch to pandas separately.
Example
-------
>>> import asyncio
>>> from lancedb import connect_async
>>> async def doctest_example():
... conn = await connect_async("./.lancedb")
... table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}])
... async for batch in await table.query().to_batches():
... batch_df = batch.to_pandas()
>>> asyncio.run(doctest_example())
"""
return (await self.to_arrow()).to_pandas()
class AsyncQuery(AsyncQueryBase):
def __init__(self, inner: LanceQuery):
"""
Construct an AsyncQuery
This method is not intended to be called directly. Instead, use the
[Table.query][] method to create a query.
"""
super().__init__(inner)
self._inner = inner
@classmethod
def _query_vec_to_array(cls, vec: Union[VEC, Tuple]):
if isinstance(vec, list):
return pa.array(vec)
if isinstance(vec, np.ndarray):
return pa.array(vec)
if isinstance(vec, pa.Array):
return vec
if isinstance(vec, pa.ChunkedArray):
return vec.combine_chunks()
if isinstance(vec, tuple):
return pa.array(vec)
# We've checked everything we formally support in our typings
# but, as a fallback, let pyarrow try and convert it anyway.
# This can allow for some more exotic things like iterables
return pa.array(vec)
def nearest_to(
self, query_vector: Optional[Union[VEC, Tuple]] = None
) -> AsyncVectorQuery:
"""
Find the nearest vectors to the given query vector.
This converts the query from a plain query to a vector query.
This method will attempt to convert the input to the query vector
expected by the embedding model. If the input cannot be converted
then an error will be thrown.
By default, there is no embedding model, and the input should be
something that can be converted to a pyarrow array of floats. This
includes lists, numpy arrays, and tuples.
If there is only one vector column (a column whose data type is a
fixed size list of floats) then the column does not need to be specified.
If there is more than one vector column you must use
[AsyncVectorQuery.column][] to specify which column you would like to
compare with.
If no index has been created on the vector column then a vector query
will perform a distance comparison between the query vector and every
vector in the database and then sort the results. This is sometimes
called a "flat search".
For small databases, with tens of thousands of vectors or less, this can
be reasonably fast. In larger databases you should create a vector index
on the column. If there is a vector index then an "approximate" nearest
neighbor search (frequently called an ANN search) will be performed. This
search is much faster, but the results will be approximate.
The query can be further parameterized using the returned builder. There
are various ANN search parameters that will let you fine tune your recall
accuracy vs search latency.
Vector searches always have a [limit][]. If `limit` has not been called then
a default `limit` of 10 will be used.
"""
return AsyncVectorQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
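A minimal sketch contrasting a plain query with a vector query (assuming a table with a 2-dimensional vector column, like the fixture in the tests further down; the function name is illustrative):

async def plain_vs_vector(table):
    # Plain query: filter and scan, no ordering by distance.
    rows = await table.query().where("id > 0").to_arrow()
    # Vector query: nearest rows first. Lists, tuples, numpy arrays, and
    # pyarrow arrays are all accepted; the default limit of 10 applies
    # unless .limit(n) is called.
    nearest = await table.query().nearest_to([1.0, 2.0]).to_arrow()
    return rows, nearest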
class AsyncVectorQuery(AsyncQueryBase):
def __init__(self, inner: LanceVectorQuery):
"""
Construct an AsyncVectorQuery
This method is not intended to be called directly. Instead, create
a query first with [Table.query][] and then use [AsyncQuery.nearest_to][]
to convert to a vector query.
"""
super().__init__(inner)
self._inner = inner
def column(self, column: str) -> AsyncVectorQuery:
"""
Set the vector column to query
This controls which column is compared to the query vector supplied in
the call to [Query.nearest_to][].
This parameter must be specified if the table has more than one column
whose data type is a fixed-size-list of floats.
"""
self._inner.column(column)
return self
def nprobes(self, nprobes: int) -> AsyncVectorQuery:
"""
Set the number of partitions to search (probe)
This argument is only used when the vector column has an IVF PQ index.
If there is no index then this value is ignored.
The IVF stage of IVF PQ divides the input into partitions (clusters) of
related values.
The partitions whose centroids are closest to the query vector will be
exhaustively searched to find matches. This parameter controls how many
partitions should be searched.
Increasing this value will increase the recall of your query but will
also increase the latency of your query. The default value is 20. This
default is good for many cases but the best value to use will depend on
your data and the recall that you need to achieve.
For best results we recommend tuning this parameter with a benchmark against
your actual data to find the smallest possible value that will still give
you the desired recall.
"""
self._inner.nprobes(nprobes)
return self
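An illustrative sketch of the tradeoff (the values 5 and 50 are arbitrary, not recommendations; tune against your own data as described above):

def probe_tradeoff(table, query_vec):
    fast = table.query().nearest_to(query_vec).nprobes(5)  # fewer partitions, lower recall
    thorough = table.query().nearest_to(query_vec).nprobes(50)  # more partitions, higher recall
    return fast, thorough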
def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
"""
A multiplier to control how many additional rows are taken during the refine
step
This argument is only used when the vector column has an IVF PQ index.
If there is no index then this value is ignored.
An IVF PQ index stores compressed (quantized) values. The query vector is
compared against these values and, since they are compressed, the comparison is
inaccurate.
This parameter can be used to refine the results. It can both improve
recall and correct the ordering of the nearest results.
To refine results LanceDb will first perform an ANN search to find the nearest
`limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and
`limit` is the default (10) then the first 30 results will be selected. LanceDb
then fetches the full, uncompressed values for these 30 results. The results
are then reordered by the true distance and only the nearest 10 are kept.
Note: there is a difference between calling this method with a value of 1 and
never calling this method at all. Calling this method with any value will have
an impact on your search latency. When you call this method with a
`refine_factor` of 1 then LanceDb still needs to fetch the full, uncompressed
values so that it can potentially reorder the results.
Note: if this method is NOT called then the distances returned in the _distance
column will be approximate distances based on the comparison of the quantized
query vector and the quantized result vectors. This can be considerably
different than the true distance between the query vector and the actual
uncompressed vector.
"""
self._inner.refine_factor(refine_factor)
return self
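A sketch of the arithmetic described above (hypothetical helper; `table` and `query_vec` are assumed):

def refined_query(table, query_vec):
    limit = 10
    refine_factor = 3
    # The ANN stage fetches limit * refine_factor = 30 candidates, re-ranks
    # them by true (uncompressed) distance, and keeps the nearest 10.
    return table.query().nearest_to(query_vec).refine_factor(refine_factor).limit(limit)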
def distance_type(self, distance_type: str) -> AsyncVectorQuery:
"""
Set the distance metric to use
When performing a vector search we try to find the "nearest" vectors according
to some kind of distance metric. This parameter controls which distance metric
to use. See `IvfPqOptions.distanceType` for more details on the different
distance metrics available.
Note: if there is a vector index then the distance type used MUST match the
distance type used to train the vector index. If this is not done then the
results will be invalid.
By default "l2" is used.
"""
self._inner.distance_type(distance_type)
return self
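For example (a sketch; as the tests further down show, the string is case-insensitive):

def cosine_query(table, query_vec):
    # Must match the distance type used to train any existing vector index.
    return table.query().nearest_to(query_vec).distance_type("cosine")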
def postfilter(self) -> AsyncVectorQuery:
"""
If this is called then filtering will happen after the vector search instead of
before.
By default filtering will be performed before the vector search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the vector search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
[refine_factor][AsyncVectorQuery.refine_factor]). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
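A sketch contrasting the default prefilter with an explicit postfilter (mirroring the pattern used in the tests further down):

def pre_vs_post(table, query_vec):
    # Default: the filter runs before the vector search (prefilter).
    pre = table.query().where("id = 2").nearest_to(query_vec)
    # Postfilter: the filter runs on the nearest results instead, so the
    # query may return fewer than `limit` rows.
    post = table.query().where("id = 2").nearest_to(query_vec).postfilter()
    return pre, post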
def bypass_vector_index(self) -> AsyncVectorQuery:
"""
If this is called then any vector index is skipped
An exhaustive (flat) search will be performed. The query vector will
be compared to every vector in the table. At high scales this can be
expensive. However, this is often still useful. For example, skipping
the vector index can give you ground truth results which you can use to
calculate your recall to select an appropriate value for nprobes.
"""
self._inner.bypass_vector_index()
return self
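A hypothetical recall-measurement sketch built on this method (assumes the table has an "id" column to compare result sets by):

async def measure_recall(table, query_vec, k=10):
    ann = await table.query().nearest_to(query_vec).limit(k).to_arrow()
    truth = await (
        table.query().nearest_to(query_vec).bypass_vector_index().limit(k).to_arrow()
    )
    ann_ids = set(ann["id"].to_pylist())  # assumes an "id" column
    truth_ids = set(truth["id"].to_pylist())
    return len(ann_ids & truth_ids) / k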

View File

@@ -43,7 +43,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
from .util import (
fs_from_uri,
inf_vector_column_query,
@@ -1899,6 +1899,9 @@ class AsyncTable:
"""
return await self._inner.count_rows(filter)
def query(self) -> AsyncQuery:
return AsyncQuery(self._inner.query())
async def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
@@ -1906,7 +1909,7 @@ class AsyncTable:
-------
pd.DataFrame
"""
return self.to_arrow().to_pandas()
return (await self.to_arrow()).to_pandas()
async def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table.
@@ -1915,7 +1918,7 @@ class AsyncTable:
-------
pa.Table
"""
raise NotImplementedError
return await self.query().to_arrow()
async def create_index(
self,
@@ -2068,90 +2071,18 @@ class AsyncTable:
return LanceMergeInsertBuilder(self, on)
async def search(
def vector_search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
and [full-text search][experimental-full-text-search].
All query options are defined in [Query][lancedb.query.Query].
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> data = [
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
... ]
>>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query)
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width", "vector"])
... .limit(2)
... .to_pandas())
caption original_width vector _distance
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
1 test 3000 [0.3, 6.2, 2.6] 23.089996
Parameters
----------
query: list/np.ndarray/str/PIL.Image.Image, default None
The targeted vector to search for.
- *default None*.
Acceptable types are: list, np.ndarray, PIL.Image.Image
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str, optional
The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
query_type: str
*default "auto"*.
Acceptable types are: "vector", "fts", "hybrid", or "auto"
- If "auto" then the query type is inferred from the query;
- If `query` is a list/np.ndarray then the query type is
"vector";
- If `query` is a PIL.Image.Image then either do vector search,
or raise an error if no corresponding embedding function is found.
- If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns
-------
LanceQueryBuilder
A query builder object representing the query.
Once executed, the query returns
- selected columns
- the vector
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
query_vector: Optional[Union[VEC, Tuple]] = None,
) -> AsyncVectorQuery:
"""
raise NotImplementedError
Search the table with a given query vector.
async def _execute_query(self, query: Query) -> pa.Table:
pass
This is a convenience method for preparing a vector query and
is the same thing as calling `nearest_to` on the builder returned
by `query`. See [nearest_to][AsyncQuery.nearest_to] for more details.
"""
return self.query().nearest_to(query_vector)
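As a sketch, the two spellings below build the same vector query (see also the test coverage for `vector_search` further down):

def equivalent_queries(table):
    q1 = table.vector_search([1.0, 2.0]).limit(1)
    q2 = table.query().nearest_to([1.0, 2.0]).limit(1)
    return q1, q2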
async def _do_merge(
self,

View File

@@ -12,16 +12,19 @@
# limitations under the License.
import unittest.mock as mock
from datetime import timedelta
import lance
import lancedb
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
from lancedb.table import AsyncTable, LanceTable
class MockTable:
@@ -65,6 +68,24 @@ def table(tmp_path) -> MockTable:
return MockTable(tmp_path)
@pytest_asyncio.fixture
async def table_async(tmp_path) -> AsyncTable:
conn = await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=0)
)
data = pa.table(
{
"vector": pa.array(
[[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
),
"id": pa.array([1, 2]),
"str_field": pa.array(["a", "b"]),
"float_field": pa.array([1.0, 2.0]),
}
)
return await conn.create_table("test", data)
def test_cast(table):
class TestModel(LanceModel):
vector: Vector(2)
@@ -184,3 +205,109 @@ def test_query_builder_with_different_vector_column():
def cosine_distance(vec1, vec2):
return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
async def check_query(
query: AsyncQueryBase, *, expected_num_rows=None, expected_columns=None
):
num_rows = 0
results = await query.to_batches()
async for batch in results:
if expected_columns is not None:
assert batch.schema.names == expected_columns
num_rows += batch.num_rows
if expected_num_rows is not None:
assert num_rows == expected_num_rows
@pytest.mark.asyncio
async def test_query_async(table_async: AsyncTable):
await check_query(
table_async.query(),
expected_num_rows=2,
expected_columns=["vector", "id", "str_field", "float_field"],
)
await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
await check_query(
table_async.query().select(["id", "vector"]), expected_columns=["id", "vector"]
)
await check_query(
table_async.query().select({"foo": "id", "bar": "id + 1"}),
expected_columns=["foo", "bar"],
)
await check_query(table_async.query().limit(1), expected_num_rows=1)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2
)
# Support different types of inputs for the vector query
for vector_query in [
[1, 2],
[1.0, 2.0],
np.array([1, 2]),
(1, 2),
]:
await check_query(
table_async.query().nearest_to(vector_query), expected_num_rows=2
)
# No easy way to check these vector query parameters are doing what they say. We
# just check that they don't raise exceptions and assume this is tested at a lower
# level.
await check_query(
table_async.query().where("id = 2").nearest_to(pa.array([1, 2])).postfilter(),
expected_num_rows=1,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).refine_factor(1),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).nprobes(10),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).distance_type("dot"),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).distance_type("DoT"),
expected_num_rows=2,
)
# Make sure we can use a vector query as a base query (e.g. call limit on it)
# Also make sure `vector_search` works
await check_query(table_async.vector_search([1, 2]).limit(1), expected_num_rows=1)
# Also check an empty query
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
table = await table_async.to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
table = await table_async.query().to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
table = await table_async.query().where("id < 0").to_arrow()
assert table.num_rows == 0
assert table.num_columns == 4
@pytest.mark.asyncio
async def test_query_to_pandas_async(table_async: AsyncTable):
df = await table_async.to_pandas()
assert df.shape == (2, 4)
df = await table_async.query().to_pandas()
assert df.shape == (2, 4)
df = await table_async.query().where("id < 0").to_pandas()
assert df.shape == (0, 4)

python/src/arrow.rs (new file, 51 lines)
View File

@@ -0,0 +1,51 @@
use std::sync::Arc;
use arrow::{
datatypes::SchemaRef,
pyarrow::{IntoPyArrow, ToPyArrow},
};
use futures::stream::StreamExt;
use lancedb::arrow::SendableRecordBatchStream;
use pyo3::{pyclass, pymethods, PyAny, PyObject, PyRef, PyResult, Python};
use pyo3_asyncio::tokio::future_into_py;
use crate::error::PythonErrorExt;
#[pyclass]
pub struct RecordBatchStream {
schema: SchemaRef,
inner: Arc<tokio::sync::Mutex<SendableRecordBatchStream>>,
}
impl RecordBatchStream {
pub fn new(inner: SendableRecordBatchStream) -> Self {
let schema = inner.schema().clone();
Self {
schema,
inner: Arc::new(tokio::sync::Mutex::new(inner)),
}
}
}
#[pymethods]
impl RecordBatchStream {
pub fn schema(&self, py: Python) -> PyResult<PyObject> {
(*self.schema).clone().into_pyarrow(py)
}
pub fn next(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let inner_next = inner.lock().await.next().await;
inner_next
.map(|item| {
let item = item.infer_error()?;
Python::with_gil(|py| item.to_pyarrow(py))
})
.transpose()
})
}
}

View File

@@ -12,15 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::{Index, IndexConfig};
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
use query::{Query, VectorQuery};
use table::Table;
pub mod arrow;
pub mod connection;
pub mod error;
pub mod index;
pub mod query;
pub mod table;
pub mod util;
@@ -34,6 +38,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Table>()?;
m.add_class::<Index>()?;
m.add_class::<IndexConfig>()?;
m.add_class::<Query>()?;
m.add_class::<VectorQuery>()?;
m.add_class::<RecordBatchStream>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

python/src/query.rs (new file, 125 lines)
View File

@@ -0,0 +1,125 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::array::make_array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
use pyo3::pyclass;
use pyo3::pymethods;
use pyo3::PyAny;
use pyo3::PyRef;
use pyo3::PyResult;
use pyo3_asyncio::tokio::future_into_py;
use crate::arrow::RecordBatchStream;
use crate::error::PythonErrorExt;
use crate::util::parse_distance_type;
#[pyclass]
pub struct Query {
inner: LanceDbQuery,
}
impl Query {
pub fn new(query: LanceDbQuery) -> Self {
Self { inner: query }
}
}
#[pymethods]
impl Query {
pub fn r#where(&mut self, predicate: String) {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
pub fn nearest_to(&mut self, vector: &PyAny) -> PyResult<VectorQuery> {
let data: ArrayData = ArrayData::from_pyarrow(vector)?;
let array = make_array(data);
let inner = self.inner.clone().nearest_to(array).infer_error()?;
Ok(VectorQuery { inner })
}
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let inner_stream = inner.execute().await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
}
}
#[pyclass]
pub struct VectorQuery {
inner: LanceDbVectorQuery,
}
#[pymethods]
impl VectorQuery {
pub fn r#where(&mut self, predicate: String) {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column);
}
pub fn distance_type(&mut self, distance_type: String) -> PyResult<()> {
let distance_type = parse_distance_type(distance_type)?;
self.inner = self.inner.clone().distance_type(distance_type);
Ok(())
}
pub fn postfilter(&mut self) {
self.inner = self.inner.clone().postfilter();
}
pub fn refine_factor(&mut self, refine_factor: u32) {
self.inner = self.inner.clone().refine_factor(refine_factor);
}
pub fn nprobes(&mut self, nprobe: u32) {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
pub fn bypass_vector_index(&mut self) {
self.inner = self.inner.clone().bypass_vector_index()
}
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let inner_stream = inner.execute().await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
}
}

View File

@@ -14,6 +14,7 @@ use pyo3_asyncio::tokio::future_into_py;
use crate::{
error::PythonErrorExt,
index::{Index, IndexConfig},
query::Query,
};
#[pyclass]
@@ -179,4 +180,8 @@ impl Table {
async move { inner.restore().await.infer_error() },
)
}
pub fn query(&self) -> Query {
Query::new(self.inner_ref().unwrap().query())
}
}

View File

@@ -1,6 +1,10 @@
use std::sync::Mutex;
use pyo3::{exceptions::PyRuntimeError, PyResult};
use lancedb::DistanceType;
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
PyResult,
};
/// A wrapper around a rust builder
///
@@ -33,3 +37,15 @@ impl<T> BuilderWrapper<T> {
Ok(result)
}
}
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceType> {
match distance_type.as_ref().to_lowercase().as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(PyValueError::new_err(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type.as_ref()
))),
}
}

View File

@@ -3,6 +3,7 @@ use std::ops::Deref;
use futures::{TryFutureExt, TryStreamExt};
use lance_linalg::distance::MetricType;
use lancedb::query::{ExecutableQuery, QueryBase, Select};
use neon::context::FunctionContext;
use neon::handle::Handle;
use neon::prelude::*;
@@ -56,53 +57,72 @@ impl JsQuery {
let channel = cx.channel();
let table = js_table.table.clone();
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
let mut builder = table.query();
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
builder = builder.nearest_to(&query);
if let Some(metric_type) = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap())
{
builder = builder.metric_type(metric_type);
}
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
builder = builder.nprobes(nprobes);
};
if let Some(filter) = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
.map(|s| s.value(&mut cx))
{
builder = builder.filter(filter);
builder = builder.only_if(filter);
}
if let Some(select) = select {
builder = builder.select(select.as_slice());
builder = builder.select(Select::columns(select.as_slice()));
}
if let Some(limit) = limit {
builder = builder.limit(limit as usize);
};
builder = builder.prefilter(prefilter);
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
let mut vector_builder = builder.nearest_to(query).unwrap();
if let Some(metric_type) = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap())
{
vector_builder = vector_builder.distance_type(metric_type);
}
rt.spawn(async move {
let record_batch_stream = builder.execute_stream();
let results = record_batch_stream
.and_then(|stream| {
stream
.try_collect::<Vec<_>>()
.map_err(lancedb::error::Error::from)
})
.await;
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
vector_builder = vector_builder.nprobes(nprobes);
deferred.settle_with(&channel, move |mut cx| {
let results = results.or_throw(&mut cx)?;
let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
convert::new_js_buffer(buffer, &mut cx, is_electron)
if !prefilter {
vector_builder = vector_builder.postfilter();
}
rt.spawn(async move {
let results = vector_builder
.execute()
.and_then(|stream| {
stream
.try_collect::<Vec<_>>()
.map_err(lancedb::error::Error::from)
})
.await;
deferred.settle_with(&channel, move |mut cx| {
let results = results.or_throw(&mut cx)?;
let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
convert::new_js_buffer(buffer, &mut cx, is_electron)
});
});
});
} else {
rt.spawn(async move {
let results = builder
.execute()
.and_then(|stream| {
stream
.try_collect::<Vec<_>>()
.map_err(lancedb::error::Error::from)
})
.await;
deferred.settle_with(&channel, move |mut cx| {
let results = results.or_throw(&mut cx)?;
let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
convert::new_js_buffer(buffer, &mut cx, is_electron)
});
});
};
Ok(promise)
}
}

View File

@@ -21,6 +21,7 @@ use futures::TryStreamExt;
use lancedb::connection::Connection;
use lancedb::index::Index;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::{connect, Result, Table as LanceDbTable};
#[tokio::main]
@@ -150,9 +151,10 @@ async fn create_index(table: &LanceDbTable) -> Result<()> {
async fn search(table: &LanceDbTable) -> Result<Vec<RecordBatch>> {
// --8<-- [start:search]
table
.search(&[1.0; 128])
.query()
.limit(2)
.execute_stream()
.nearest_to(&[1.0; 128])?
.execute()
.await?
.try_collect::<Vec<_>>()
.await

View File

@@ -342,7 +342,11 @@ mod test {
use object_store::local::LocalFileSystem;
use tempfile;
use crate::{connect, table::WriteOptions};
use crate::{
connect,
query::{ExecutableQuery, QueryBase},
table::WriteOptions,
};
#[tokio::test]
async fn test_e2e() {
@@ -381,9 +385,11 @@ mod test {
assert_eq!(t.count_rows(None).await.unwrap(), 100);
let q = t
.search(&[0.1, 0.1, 0.1, 0.1])
.query()
.limit(10)
.execute_stream()
.nearest_to(&[0.1, 0.1, 0.1, 0.1])
.unwrap()
.execute()
.await
.unwrap();

View File

@@ -150,6 +150,7 @@
//! # use arrow_schema::{DataType, Schema, Field};
//! # use arrow_array::{RecordBatch, RecordBatchIterator};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use lancedb::query::{ExecutableQuery, QueryBase};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
@@ -170,8 +171,10 @@
//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap();
//! # let table = db.open_table("my_table").execute().await.unwrap();
//! let results = table
//! .search(&[1.0; 128])
//! .execute_stream()
//! .query()
//! .nearest_to(&[1.0; 128])
//! .unwrap()
//! .execute()
//! .await
//! .unwrap()
//! .try_collect::<Vec<_>>()

File diff suppressed because it is too large.

View File

@@ -6,7 +6,7 @@ use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewCol
use crate::{
error::Result,
index::{IndexBuilder, IndexConfig},
query::Query,
query::{Query, QueryExecutionOptions, VectorQuery},
table::{
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
TableInternal, UpdateBuilder,
@@ -66,7 +66,18 @@ impl TableInternal for RemoteTable {
async fn add(&self, _add: AddDataBuilder) -> Result<()> {
todo!()
}
async fn query(&self, _query: &Query) -> Result<DatasetRecordBatchStream> {
async fn plain_query(
&self,
_query: &Query,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn vector_query(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn update(&self, _update: UpdateBuilder) -> Result<()> {

View File

@@ -17,6 +17,8 @@
use std::path::Path;
use std::sync::Arc;
use arrow::array::AsArray;
use arrow::datatypes::Float32Type;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
@@ -47,7 +49,9 @@ use crate::index::{
vector::{suggested_num_partitions, suggested_num_sub_vectors},
Index, IndexBuilder,
};
use crate::query::{Query, Select, DEFAULT_TOP_K};
use crate::query::{
Query, QueryExecutionOptions, Select, ToQueryVector, VectorQuery, DEFAULT_TOP_K,
};
use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam};
use self::dataset::DatasetConsistencyWrapper;
@@ -230,7 +234,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
/// Count the number of rows in this table.
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
async fn add(&self, add: AddDataBuilder) -> Result<()>;
async fn query(&self, query: &Query) -> Result<DatasetRecordBatchStream>;
async fn plain_query(
&self,
query: &Query,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn delete(&self, predicate: &str) -> Result<()>;
async fn update(&self, update: UpdateBuilder) -> Result<()>;
async fn create_index(&self, index: IndexBuilder) -> Result<()>;
@@ -528,21 +541,30 @@ impl Table {
)
}
/// Search the table with a given query vector.
/// Create a [`Query`] Builder.
///
/// This is a convenience method for preparing an ANN query.
pub fn search(&self, query: &[f32]) -> Query {
self.query().nearest_to(query)
}
/// Create a generic [`Query`] Builder.
/// Queries allow you to search your existing data. By default the query will
/// return all the data in the table in no particular order. The builder
/// returned by this method can be used to control the query using filtering,
/// vector similarity, sorting, and more.
///
/// When appropriate, various indices and statistics based pruning will be used to
/// accelerate the query.
/// Note: By default, all columns are returned. For best performance, you should
/// only fetch the columns you need. See [`Query::select_with_projection`] for
/// more details.
///
/// When appropriate, various indices and statistics will be used to accelerate
/// the query.
///
/// # Examples
///
/// ## Run a vector search (ANN) query.
/// ## Vector search
///
/// This example will find the 10 rows whose value in the "vector" column are
/// closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
/// on the "vector" column then this will perform an ANN search.
///
/// The [`Query::refine_factor`] and [`Query::nprobes`] methods are used to
/// control the recall / latency tradeoff of the search.
///
/// ```no_run
/// # use arrow_array::RecordBatch;
@@ -551,19 +573,25 @@ impl Table {
/// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
/// # let tbl = conn.open_table("tbl").execute().await.unwrap();
/// use crate::lancedb::Table;
/// use crate::lancedb::query::ExecutableQuery;
/// let stream = tbl
/// .query()
/// .nearest_to(&[1.0, 2.0, 3.0])
/// .unwrap()
/// .refine_factor(5)
/// .nprobes(10)
/// .execute_stream()
/// .execute()
/// .await
/// .unwrap();
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
/// # });
/// ```
///
/// ## Run a SQL-style filter
/// ## SQL-style filter
///
/// This query will return up to 1000 rows whose value in the `id` column
/// is greater than 5. LanceDb supports a broad set of filtering functions.
///
/// ```no_run
/// # use arrow_array::RecordBatch;
/// # use futures::TryStreamExt;
@@ -571,18 +599,23 @@ impl Table {
/// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
/// # let tbl = conn.open_table("tbl").execute().await.unwrap();
/// use crate::lancedb::Table;
/// use crate::lancedb::query::{ExecutableQuery, QueryBase};
/// let stream = tbl
/// .query()
/// .filter("id > 5")
/// .only_if("id > 5")
/// .limit(1000)
/// .execute_stream()
/// .execute()
/// .await
/// .unwrap();
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
/// # });
/// ```
///
/// ## Run a full scan query.
/// ## Full scan
///
/// This query will return everything in the table in no particular
/// order.
///
/// ```no_run
/// # use arrow_array::RecordBatch;
/// # use futures::TryStreamExt;
@@ -590,7 +623,8 @@ impl Table {
/// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
/// # let tbl = conn.open_table("tbl").execute().await.unwrap();
/// use crate::lancedb::Table;
/// let stream = tbl.query().execute_stream().await.unwrap();
/// use crate::lancedb::query::ExecutableQuery;
/// let stream = tbl.query().execute().await.unwrap();
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
/// # });
/// ```
@@ -598,6 +632,15 @@ impl Table {
Query::new(self.inner.clone())
}
/// Search the table with a given query vector.
///
/// This is a convenience method for preparing a vector query and
/// is the same thing as calling `nearest_to` on the builder returned
/// by `query`. See [`Query::nearest_to`] for more details.
pub fn vector_search(&self, query: impl ToQueryVector) -> Result<VectorQuery> {
self.query().nearest_to(query)
}
/// Optimize the on-disk data and indices for better performance.
///
/// <section class="warning">Experimental API</section>
@@ -1107,6 +1150,75 @@ impl NativeTable {
.await?;
Ok(())
}
async fn generic_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32)
{
return Err(Error::Schema {
message: format!(
"Vector column '{}' does not match the dimension of the query vector: dim={}",
column,
query_vector.len(),
),
});
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.batch_size(options.max_batch_length as usize);
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
}
Select::Dynamic(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type);
}
Ok(scanner.try_into_stream().await?)
}
}
#[async_trait::async_trait]
@@ -1232,63 +1344,21 @@ impl TableInternal for NativeTable {
Ok(())
}
async fn query(&self, query: &Query) -> Result<DatasetRecordBatchStream> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
async fn plain_query(
&self,
query: &Query,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
self.generic_query(&query.clone().into_vector(), options)
.await
}
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32)
{
return Err(Error::Schema {
message: format!(
"Vector column '{}' does not match the dimension of the query vector: dim={}",
column,
query_vector.len(),
),
});
}
scanner.nearest(&column, query_vector, query.limit.unwrap_or(DEFAULT_TOP_K))?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
match &query.select {
Select::Simple(select) => {
scanner.project(select.as_slice())?;
}
Select::Projection(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(metric_type) = query.metric_type {
scanner.distance_metric(metric_type);
}
Ok(scanner.try_into_stream().await?)
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
self.generic_query(query, options).await
}
async fn merge_insert(
@@ -1450,6 +1520,7 @@ mod tests {
use crate::connect;
use crate::connection::ConnectBuilder;
use crate::index::scalar::BTreeIndexBuilder;
use crate::query::{ExecutableQuery, QueryBase};
use super::*;
@@ -1689,8 +1760,8 @@ mod tests {
let mut batches = table
.query()
.select(&["id", "name"])
.execute_stream()
.select(Select::columns(&["id", "name"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
@@ -1841,7 +1912,7 @@ mod tests {
let mut batches = table
.query()
.select(&[
.select(Select::columns(&[
"string",
"large_string",
"int32",
@@ -1855,8 +1926,8 @@ mod tests {
"timestamp_ms",
"vec_f32",
"vec_f64",
])
.execute_stream()
]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()