From bc582bb7024819d35c1fda2b33b1b04547245e14 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 14 May 2024 08:43:39 -0500 Subject: [PATCH] fix(nodejs): add better error handling when missing embedding functions (#1290) note: running the default lint command `npm run lint -- --fix` seems to have made a lot of unrelated changes. --- node/src/arrow.ts | 342 ++++++---- node/src/index.ts | 433 ++++++------ node/src/sanitize.ts | 54 +- node/src/test/test.ts | 1386 +++++++++++++++++++++------------------ nodejs/lancedb/arrow.ts | 44 +- 5 files changed, 1242 insertions(+), 1017 deletions(-) diff --git a/node/src/arrow.ts b/node/src/arrow.ts index 792c68f2..6aac34b9 100644 --- a/node/src/arrow.ts +++ b/node/src/arrow.ts @@ -27,23 +27,23 @@ import { RecordBatch, makeData, Struct, - Float, + type Float, DataType, Binary, Float32 -} from 'apache-arrow' -import { type EmbeddingFunction } from './index' -import { sanitizeSchema } from './sanitize' +} from "apache-arrow"; +import { type EmbeddingFunction } from "./index"; +import { sanitizeSchema } from "./sanitize"; /* * Options to control how a column should be converted to a vector array */ export class VectorColumnOptions { /** Vector column type. */ - type: Float = new Float32() + type: Float = new Float32(); - constructor (values?: Partial) { - Object.assign(this, values) + constructor(values?: Partial) { + Object.assign(this, values); } } @@ -60,7 +60,7 @@ export class MakeArrowTableOptions { * The schema must be specified if there are no records (e.g. 
to make * an empty table) */ - schema?: Schema + schema?: Schema; /* * Mapping from vector column name to expected type @@ -80,7 +80,9 @@ export class MakeArrowTableOptions { */ vectorColumns: Record = { vector: new VectorColumnOptions() - } + }; + + embeddings?: EmbeddingFunction; /** * If true then string columns will be encoded with dictionary encoding @@ -91,10 +93,10 @@ export class MakeArrowTableOptions { * * If `schema` is provided then this property is ignored. */ - dictionaryEncodeStrings: boolean = false + dictionaryEncodeStrings: boolean = false; - constructor (values?: Partial) { - Object.assign(this, values) + constructor(values?: Partial) { + Object.assign(this, values); } } @@ -193,59 +195,68 @@ export class MakeArrowTableOptions { * assert.deepEqual(table.schema, schema) * ``` */ -export function makeArrowTable ( +export function makeArrowTable( data: Array>, options?: Partial ): ArrowTable { - if (data.length === 0 && (options?.schema === undefined || options?.schema === null)) { - throw new Error('At least one record or a schema needs to be provided') + if ( + data.length === 0 && + (options?.schema === undefined || options?.schema === null) + ) { + throw new Error("At least one record or a schema needs to be provided"); } - const opt = new MakeArrowTableOptions(options !== undefined ? options : {}) + const opt = new MakeArrowTableOptions(options !== undefined ? options : {}); if (opt.schema !== undefined && opt.schema !== null) { - opt.schema = sanitizeSchema(opt.schema) + opt.schema = sanitizeSchema(opt.schema); + opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings); } - const columns: Record = {} + + const columns: Record = {}; // TODO: sample dataset to find missing columns // Prefer the field ordering of the schema, if present - const columnNames = ((opt.schema) != null) ? (opt.schema.names as string[]) : Object.keys(data[0]) + const columnNames = + opt.schema != null ? 
(opt.schema.names as string[]) : Object.keys(data[0]); for (const colName of columnNames) { - if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) { + if ( + data.length !== 0 && + !Object.prototype.hasOwnProperty.call(data[0], colName) + ) { // The field is present in the schema, but not in the data, skip it - continue + continue; } // Extract a single column from the records (transpose from row-major to col-major) - let values = data.map((datum) => datum[colName]) + let values = data.map((datum) => datum[colName]); // By default (type === undefined) arrow will infer the type from the JS type - let type + let type; if (opt.schema !== undefined) { // If there is a schema provided, then use that for the type instead - type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type + type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type; if (DataType.isInt(type) && type.bitWidth === 64) { // wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051 values = values.map((v) => { if (v === null) { - return v + return v; } - return BigInt(v) - }) + return BigInt(v); + }); } } else { // Otherwise, check to see if this column is one of the vector columns // defined by opt.vectorColumns and, if so, use the fixed size list type - const vectorColumnOptions = opt.vectorColumns[colName] + const vectorColumnOptions = opt.vectorColumns[colName]; if (vectorColumnOptions !== undefined) { - type = newVectorType(values[0].length, vectorColumnOptions.type) + type = newVectorType(values[0].length, vectorColumnOptions.type); } } try { // Convert an Array of JS values to an arrow vector - columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings) + columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings); } catch (error: unknown) { // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - throw Error(`Could not convert column "${colName}" to Arrow: ${error}`) + throw 
Error(`Could not convert column "${colName}" to Arrow: ${error}`); } } @@ -260,97 +271,116 @@ export function makeArrowTable ( // To work around this we first create a table with the wrong schema and // then patch the schema of the batches so we can use // `new ArrowTable(schema, batches)` which does not do any schema inference - const firstTable = new ArrowTable(columns) - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const batchesFixed = firstTable.batches.map(batch => new RecordBatch(opt.schema!, batch.data)) - return new ArrowTable(opt.schema, batchesFixed) + const firstTable = new ArrowTable(columns); + const batchesFixed = firstTable.batches.map( + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + (batch) => new RecordBatch(opt.schema!, batch.data) + ); + return new ArrowTable(opt.schema, batchesFixed); } else { - return new ArrowTable(columns) + return new ArrowTable(columns); } } /** * Create an empty Arrow table with the provided schema */ -export function makeEmptyTable (schema: Schema): ArrowTable { - return makeArrowTable([], { schema }) +export function makeEmptyTable(schema: Schema): ArrowTable { + return makeArrowTable([], { schema }); } // Helper function to convert Array> to a variable sized list array -function makeListVector (lists: any[][]): Vector { +function makeListVector(lists: any[][]): Vector { if (lists.length === 0 || lists[0].length === 0) { - throw Error('Cannot infer list vector from empty array or empty list') + throw Error("Cannot infer list vector from empty array or empty list"); } - const sampleList = lists[0] - let inferredType + const sampleList = lists[0]; + let inferredType; try { - const sampleVector = makeVector(sampleList) - inferredType = sampleVector.type + const sampleVector = makeVector(sampleList); + inferredType = sampleVector.type; } catch (error: unknown) { // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - throw Error(`Cannot infer list 
vector. Cannot infer inner type: ${error}`) + throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`); } const listBuilder = makeBuilder({ - type: new List(new Field('item', inferredType, true)) - }) + type: new List(new Field("item", inferredType, true)) + }); for (const list of lists) { - listBuilder.append(list) + listBuilder.append(list); } - return listBuilder.finish().toVector() + return listBuilder.finish().toVector(); } // Helper function to convert an Array of JS values to an Arrow Vector -function makeVector (values: any[], type?: DataType, stringAsDictionary?: boolean): Vector { +function makeVector( + values: any[], + type?: DataType, + stringAsDictionary?: boolean +): Vector { if (type !== undefined) { // No need for inference, let Arrow create it - return vectorFromArray(values, type) + return vectorFromArray(values, type); } if (values.length === 0) { - throw Error('makeVector requires at least one value or the type must be specfied') + throw Error( + "makeVector requires at least one value or the type must be specfied" + ); } - const sampleValue = values.find(val => val !== null && val !== undefined) + const sampleValue = values.find((val) => val !== null && val !== undefined); if (sampleValue === undefined) { - throw Error('makeVector cannot infer the type if all values are null or undefined') + throw Error( + "makeVector cannot infer the type if all values are null or undefined" + ); } if (Array.isArray(sampleValue)) { // Default Arrow inference doesn't handle list types - return makeListVector(values) + return makeListVector(values); } else if (Buffer.isBuffer(sampleValue)) { // Default Arrow inference doesn't handle Buffer - return vectorFromArray(values, new Binary()) - } else if (!(stringAsDictionary ?? false) && (typeof sampleValue === 'string' || sampleValue instanceof String)) { + return vectorFromArray(values, new Binary()); + } else if ( + !(stringAsDictionary ?? 
false) && + (typeof sampleValue === "string" || sampleValue instanceof String) + ) { // If the type is string then don't use Arrow's default inference unless dictionaries are requested // because it will always use dictionary encoding for strings - return vectorFromArray(values, new Utf8()) + return vectorFromArray(values, new Utf8()); } else { // Convert a JS array of values to an arrow vector - return vectorFromArray(values) + return vectorFromArray(values); } } -async function applyEmbeddings (table: ArrowTable, embeddings?: EmbeddingFunction, schema?: Schema): Promise { +async function applyEmbeddings( + table: ArrowTable, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { if (embeddings == null) { - return table + return table; } if (schema !== undefined && schema !== null) { - schema = sanitizeSchema(schema) + schema = sanitizeSchema(schema); } // Convert from ArrowTable to Record const colEntries = [...Array(table.numCols).keys()].map((_, idx) => { - const name = table.schema.fields[idx].name + const name = table.schema.fields[idx].name; // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const vec = table.getChildAt(idx)! - return [name, vec] - }) - const newColumns = Object.fromEntries(colEntries) + const vec = table.getChildAt(idx)!; + return [name, vec]; + }); + const newColumns = Object.fromEntries(colEntries); - const sourceColumn = newColumns[embeddings.sourceColumn] - const destColumn = embeddings.destColumn ?? 'vector' - const innerDestType = embeddings.embeddingDataType ?? new Float32() + const sourceColumn = newColumns[embeddings.sourceColumn]; + const destColumn = embeddings.destColumn ?? "vector"; + const innerDestType = embeddings.embeddingDataType ?? 
new Float32(); if (sourceColumn === undefined) { - throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`) + throw new Error( + `Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data` + ); } if (table.numRows === 0) { @@ -358,45 +388,60 @@ async function applyEmbeddings (table: ArrowTable, embeddings?: EmbeddingFunc // We have an empty table and it already has the embedding column so no work needs to be done // Note: we don't return an error like we did below because this is a common occurrence. For example, // if we call convertToTable with 0 records and a schema that includes the embedding - return table + return table; } if (embeddings.embeddingDimension !== undefined) { - const destType = newVectorType(embeddings.embeddingDimension, innerDestType) - newColumns[destColumn] = makeVector([], destType) + const destType = newVectorType( + embeddings.embeddingDimension, + innerDestType + ); + newColumns[destColumn] = makeVector([], destType); } else if (schema != null) { - const destField = schema.fields.find(f => f.name === destColumn) + const destField = schema.fields.find((f) => f.name === destColumn); if (destField != null) { - newColumns[destColumn] = makeVector([], destField.type) + newColumns[destColumn] = makeVector([], destField.type); } else { - throw new Error(`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`) + throw new Error( + `Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'` + ); } } else { - throw new Error('Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`') + throw new Error( + "Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`" + ); } } else { if 
(Object.prototype.hasOwnProperty.call(newColumns, destColumn)) { - throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`) + throw new Error( + `Attempt to apply embeddings to table failed because column ${destColumn} already existed` + ); } if (table.batches.length > 1) { - throw new Error('Internal error: `makeArrowTable` unexpectedly created a table with more than one batch') + throw new Error( + "Internal error: `makeArrowTable` unexpectedly created a table with more than one batch" + ); } - const values = sourceColumn.toArray() - const vectors = await embeddings.embed(values as T[]) + const values = sourceColumn.toArray(); + const vectors = await embeddings.embed(values as T[]); if (vectors.length !== values.length) { - throw new Error('Embedding function did not return an embedding for each input element') + throw new Error( + "Embedding function did not return an embedding for each input element" + ); } - const destType = newVectorType(vectors[0].length, innerDestType) - newColumns[destColumn] = makeVector(vectors, destType) + const destType = newVectorType(vectors[0].length, innerDestType); + newColumns[destColumn] = makeVector(vectors, destType); } - const newTable = new ArrowTable(newColumns) + const newTable = new ArrowTable(newColumns); if (schema != null) { - if (schema.fields.find(f => f.name === destColumn) === undefined) { - throw new Error(`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`) + if (schema.fields.find((f) => f.name === destColumn) === undefined) { + throw new Error( + `When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing` + ); } - return alignTable(newTable, schema) + return alignTable(newTable, schema); } - return newTable + return newTable; } /* @@ -417,21 +462,24 @@ async function applyEmbeddings 
(table: ArrowTable, embeddings?: EmbeddingFunc * embedding columns. If no schema is provded then embedding columns will * be placed at the end of the table, after all of the input columns. */ -export async function convertToTable ( +export async function convertToTable( data: Array>, embeddings?: EmbeddingFunction, makeTableOptions?: Partial ): Promise { - const table = makeArrowTable(data, makeTableOptions) - return await applyEmbeddings(table, embeddings, makeTableOptions?.schema) + const table = makeArrowTable(data, makeTableOptions); + return await applyEmbeddings(table, embeddings, makeTableOptions?.schema); } // Creates the Arrow Type for a Vector column with dimension `dim` -function newVectorType (dim: number, innerType: T): FixedSizeList { +function newVectorType( + dim: number, + innerType: T +): FixedSizeList { // Somewhere we always default to have the elements nullable, so we need to set it to true // otherwise we often get schema mismatches because the stored data always has schema with nullable elements - const children = new Field('item', innerType, true) - return new FixedSizeList(dim, children) + const children = new Field("item", innerType, true); + return new FixedSizeList(dim, children); } /** @@ -441,17 +489,17 @@ function newVectorType (dim: number, innerType: T): FixedSizeL * * `schema` is required if data is empty */ -export async function fromRecordsToBuffer ( +export async function fromRecordsToBuffer( data: Array>, embeddings?: EmbeddingFunction, schema?: Schema ): Promise { if (schema !== undefined && schema !== null) { - schema = sanitizeSchema(schema) + schema = sanitizeSchema(schema); } - const table = await convertToTable(data, embeddings, { schema }) - const writer = RecordBatchFileWriter.writeAll(table) - return Buffer.from(await writer.toUint8Array()) + const table = await convertToTable(data, embeddings, { schema, embeddings }); + const writer = RecordBatchFileWriter.writeAll(table); + return Buffer.from(await 
writer.toUint8Array()); } /** @@ -461,17 +509,17 @@ export async function fromRecordsToBuffer ( * * `schema` is required if data is empty */ -export async function fromRecordsToStreamBuffer ( +export async function fromRecordsToStreamBuffer( data: Array>, embeddings?: EmbeddingFunction, schema?: Schema ): Promise { if (schema !== null && schema !== undefined) { - schema = sanitizeSchema(schema) + schema = sanitizeSchema(schema); } - const table = await convertToTable(data, embeddings, { schema }) - const writer = RecordBatchStreamWriter.writeAll(table) - return Buffer.from(await writer.toUint8Array()) + const table = await convertToTable(data, embeddings, { schema }); + const writer = RecordBatchStreamWriter.writeAll(table); + return Buffer.from(await writer.toUint8Array()); } /** @@ -482,17 +530,17 @@ export async function fromRecordsToStreamBuffer ( * * `schema` is required if the table is empty */ -export async function fromTableToBuffer ( +export async function fromTableToBuffer( table: ArrowTable, embeddings?: EmbeddingFunction, schema?: Schema ): Promise { if (schema !== null && schema !== undefined) { - schema = sanitizeSchema(schema) + schema = sanitizeSchema(schema); } - const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema) - const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings) - return Buffer.from(await writer.toUint8Array()) + const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema); + const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings); + return Buffer.from(await writer.toUint8Array()); } /** @@ -503,49 +551,87 @@ export async function fromTableToBuffer ( * * `schema` is required if the table is empty */ -export async function fromTableToStreamBuffer ( +export async function fromTableToStreamBuffer( table: ArrowTable, embeddings?: EmbeddingFunction, schema?: Schema ): Promise { if (schema !== null && schema !== undefined) { - schema = sanitizeSchema(schema) + schema = 
sanitizeSchema(schema); } - const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema) - const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings) - return Buffer.from(await writer.toUint8Array()) + const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema); + const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings); + return Buffer.from(await writer.toUint8Array()); } -function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch { - const alignedChildren = [] +function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch { + const alignedChildren = []; for (const field of schema.fields) { const indexInBatch = batch.schema.fields?.findIndex( (f) => f.name === field.name - ) + ); if (indexInBatch < 0) { throw new Error( `The column ${field.name} was not found in the Arrow Table` - ) + ); } - alignedChildren.push(batch.data.children[indexInBatch]) + alignedChildren.push(batch.data.children[indexInBatch]); } const newData = makeData({ type: new Struct(schema.fields), length: batch.numRows, nullCount: batch.nullCount, children: alignedChildren - }) - return new RecordBatch(schema, newData) + }); + return new RecordBatch(schema, newData); } -function alignTable (table: ArrowTable, schema: Schema): ArrowTable { +function alignTable(table: ArrowTable, schema: Schema): ArrowTable { const alignedBatches = table.batches.map((batch) => alignBatch(batch, schema) - ) - return new ArrowTable(schema, alignedBatches) + ); + return new ArrowTable(schema, alignedBatches); } // Creates an empty Arrow Table -export function createEmptyTable (schema: Schema): ArrowTable { - return new ArrowTable(sanitizeSchema(schema)) +export function createEmptyTable(schema: Schema): ArrowTable { + return new ArrowTable(sanitizeSchema(schema)); +} + +function validateSchemaEmbeddings( + schema: Schema, + data: Array>, + embeddings: EmbeddingFunction | undefined +) { + const fields = []; + const missingEmbeddingFields = []; + + 
// First we check if the field is a `FixedSizeList` + // Then we check if the data contains the field + // if it does not, we add it to the list of missing embedding fields + // Finally, we check if those missing embedding fields are `this._embeddings` + // if they are not, we throw an error + for (const field of schema.fields) { + if (field.type instanceof FixedSizeList) { + if (data.length !== 0 && data?.[0]?.[field.name] === undefined) { + missingEmbeddingFields.push(field); + } else { + fields.push(field); + } + } else { + fields.push(field); + } + } + + if (missingEmbeddingFields.length > 0 && embeddings === undefined) { + console.log({ missingEmbeddingFields, embeddings }); + + throw new Error( + `Table has embeddings: "${missingEmbeddingFields + .map((f) => f.name) + .join(",")}", but no embedding function was provided` + ); + } + + return new Schema(fields); } diff --git a/node/src/index.ts b/node/src/index.ts index f018e3fb..7a487da2 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -12,19 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-import { type Schema, Table as ArrowTable, tableFromIPC } from 'apache-arrow' +import { type Schema, Table as ArrowTable, tableFromIPC } from "apache-arrow"; import { createEmptyTable, fromRecordsToBuffer, fromTableToBuffer, makeArrowTable -} from './arrow' -import type { EmbeddingFunction } from './embedding/embedding_function' -import { RemoteConnection } from './remote' -import { Query } from './query' -import { isEmbeddingFunction } from './embedding/embedding_function' -import { type Literal, toSQL } from './util' -import { type HttpMiddleware } from './middleware' +} from "./arrow"; +import type { EmbeddingFunction } from "./embedding/embedding_function"; +import { RemoteConnection } from "./remote"; +import { Query } from "./query"; +import { isEmbeddingFunction } from "./embedding/embedding_function"; +import { type Literal, toSQL } from "./util"; + +import { type HttpMiddleware } from "./middleware"; const { databaseNew, @@ -48,14 +49,18 @@ const { tableAlterColumns, tableDropColumns // eslint-disable-next-line @typescript-eslint/no-var-requires -} = require('../native.js') +} = require("../native.js"); -export { Query } -export type { EmbeddingFunction } -export { OpenAIEmbeddingFunction } from './embedding/openai' -export { convertToTable, makeArrowTable, type MakeArrowTableOptions } from './arrow' +export { Query }; +export type { EmbeddingFunction }; +export { OpenAIEmbeddingFunction } from "./embedding/openai"; +export { + convertToTable, + makeArrowTable, + type MakeArrowTableOptions +} from "./arrow"; -const defaultAwsRegion = 'us-west-2' +const defaultAwsRegion = "us-west-2"; export interface AwsCredentials { accessKeyId: string @@ -128,19 +133,19 @@ export interface ConnectionOptions { readConsistencyInterval?: number } -function getAwsArgs (opts: ConnectionOptions): any[] { - const callArgs: any[] = [] - const awsCredentials = opts.awsCredentials +function getAwsArgs(opts: ConnectionOptions): any[] { + const callArgs: any[] = []; + const 
awsCredentials = opts.awsCredentials; if (awsCredentials !== undefined) { - callArgs.push(awsCredentials.accessKeyId) - callArgs.push(awsCredentials.secretKey) - callArgs.push(awsCredentials.sessionToken) + callArgs.push(awsCredentials.accessKeyId); + callArgs.push(awsCredentials.secretKey); + callArgs.push(awsCredentials.sessionToken); } else { - callArgs.fill(undefined, 0, 3) + callArgs.fill(undefined, 0, 3); } - callArgs.push(opts.awsRegion) - return callArgs + callArgs.push(opts.awsRegion); + return callArgs; } export interface CreateTableOptions { @@ -173,56 +178,56 @@ export interface CreateTableOptions { * * @see {@link ConnectionOptions} for more details on the URI format. */ -export async function connect (uri: string): Promise +export async function connect(uri: string): Promise; /** * Connect to a LanceDB instance with connection options. * * @param opts The {@link ConnectionOptions} to use when connecting to the database. */ -export async function connect ( +export async function connect( opts: Partial -): Promise -export async function connect ( +): Promise; +export async function connect( arg: string | Partial ): Promise { - let opts: ConnectionOptions - if (typeof arg === 'string') { - opts = { uri: arg } + let opts: ConnectionOptions; + if (typeof arg === "string") { + opts = { uri: arg }; } else { - const keys = Object.keys(arg) - if (keys.length === 1 && keys[0] === 'uri' && typeof arg.uri === 'string') { - opts = { uri: arg.uri } + const keys = Object.keys(arg); + if (keys.length === 1 && keys[0] === "uri" && typeof arg.uri === "string") { + opts = { uri: arg.uri }; } else { opts = Object.assign( { - uri: '', + uri: "", awsCredentials: undefined, awsRegion: defaultAwsRegion, apiKey: undefined, region: defaultAwsRegion }, arg - ) + ); } } - if (opts.uri.startsWith('db://')) { + if (opts.uri.startsWith("db://")) { // Remote connection - return new RemoteConnection(opts) + return new RemoteConnection(opts); } const storageOptions = 
opts.storageOptions ?? {}; if (opts.awsCredentials?.accessKeyId !== undefined) { - storageOptions.aws_access_key_id = opts.awsCredentials.accessKeyId + storageOptions.aws_access_key_id = opts.awsCredentials.accessKeyId; } if (opts.awsCredentials?.secretKey !== undefined) { - storageOptions.aws_secret_access_key = opts.awsCredentials.secretKey + storageOptions.aws_secret_access_key = opts.awsCredentials.secretKey; } if (opts.awsCredentials?.sessionToken !== undefined) { - storageOptions.aws_session_token = opts.awsCredentials.sessionToken + storageOptions.aws_session_token = opts.awsCredentials.sessionToken; } if (opts.awsRegion !== undefined) { - storageOptions.region = opts.awsRegion + storageOptions.region = opts.awsRegion; } // It's a pain to pass a record to Rust, so we convert it to an array of key-value pairs const storageOptionsArr = Object.entries(storageOptions); @@ -231,8 +236,8 @@ export async function connect ( opts.uri, storageOptionsArr, opts.readConsistencyInterval - ) - return new LocalConnection(db, opts) + ); + return new LocalConnection(db, opts); } /** @@ -533,7 +538,11 @@ export interface Table { * @param data the new data to insert * @param args parameters controlling how the operation should behave */ - mergeInsert: (on: string, data: Array> | ArrowTable, args: MergeInsertArgs) => Promise + mergeInsert: ( + on: string, + data: Array> | ArrowTable, + args: MergeInsertArgs + ) => Promise /** * List the indicies on this table. @@ -558,7 +567,9 @@ export interface Table { * expressions will be evaluated for each row in the * table, and can reference existing columns in the table. */ - addColumns(newColumnTransforms: Array<{ name: string, valueSql: string }>): Promise + addColumns( + newColumnTransforms: Array<{ name: string, valueSql: string }> + ): Promise /** * Alter the name or nullability of columns. @@ -699,23 +710,23 @@ export interface IndexStats { * A connection to a LanceDB database. 
*/ export class LocalConnection implements Connection { - private readonly _options: () => ConnectionOptions - private readonly _db: any + private readonly _options: () => ConnectionOptions; + private readonly _db: any; - constructor (db: any, options: ConnectionOptions) { - this._options = () => options - this._db = db + constructor(db: any, options: ConnectionOptions) { + this._options = () => options; + this._db = db; } - get uri (): string { - return this._options().uri + get uri(): string { + return this._options().uri; } /** * Get the names of all tables in the database. */ - async tableNames (): Promise { - return databaseTableNames.call(this._db) + async tableNames(): Promise { + return databaseTableNames.call(this._db); } /** @@ -723,7 +734,7 @@ export class LocalConnection implements Connection { * * @param name The name of the table. */ - async openTable (name: string): Promise + async openTable(name: string): Promise
; /** * Open a table in the database. @@ -734,23 +745,20 @@ export class LocalConnection implements Connection { async openTable( name: string, embeddings: EmbeddingFunction - ): Promise> + ): Promise>; async openTable( name: string, embeddings?: EmbeddingFunction - ): Promise> + ): Promise>; async openTable( name: string, embeddings?: EmbeddingFunction ): Promise> { - const tbl = await databaseOpenTable.call( - this._db, - name, - ) + const tbl = await databaseOpenTable.call(this._db, name); if (embeddings !== undefined) { - return new LocalTable(tbl, name, this._options(), embeddings) + return new LocalTable(tbl, name, this._options(), embeddings); } else { - return new LocalTable(tbl, name, this._options()) + return new LocalTable(tbl, name, this._options()); } } @@ -760,32 +768,32 @@ export class LocalConnection implements Connection { optsOrEmbedding?: WriteOptions | EmbeddingFunction, opt?: WriteOptions ): Promise> { - if (typeof name === 'string') { - let writeOptions: WriteOptions = new DefaultWriteOptions() + if (typeof name === "string") { + let writeOptions: WriteOptions = new DefaultWriteOptions(); if (opt !== undefined && isWriteOptions(opt)) { - writeOptions = opt + writeOptions = opt; } else if ( optsOrEmbedding !== undefined && isWriteOptions(optsOrEmbedding) ) { - writeOptions = optsOrEmbedding + writeOptions = optsOrEmbedding; } - let embeddings: undefined | EmbeddingFunction + let embeddings: undefined | EmbeddingFunction; if ( optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding) ) { - embeddings = optsOrEmbedding + embeddings = optsOrEmbedding; } return await this.createTableImpl({ name, data, embeddingFunction: embeddings, writeOptions - }) + }); } - return await this.createTableImpl(name) + return await this.createTableImpl(name); } private async createTableImpl({ @@ -801,27 +809,27 @@ export class LocalConnection implements Connection { embeddingFunction?: EmbeddingFunction | undefined writeOptions?: WriteOptions | undefined 
}): Promise> { - let buffer: Buffer + let buffer: Buffer; - function isEmpty ( + function isEmpty( data: Array> | ArrowTable ): boolean { if (data instanceof ArrowTable) { - return data.data.length === 0 + return data.data.length === 0; } - return data.length === 0 + return data.length === 0; } if (data === undefined || isEmpty(data)) { if (schema === undefined) { - throw new Error('Either data or schema needs to defined') + throw new Error("Either data or schema needs to defined"); } - buffer = await fromTableToBuffer(createEmptyTable(schema)) + buffer = await fromTableToBuffer(createEmptyTable(schema)); } else if (data instanceof ArrowTable) { - buffer = await fromTableToBuffer(data, embeddingFunction, schema) + buffer = await fromTableToBuffer(data, embeddingFunction, schema); } else { // data is Array> - buffer = await fromRecordsToBuffer(data, embeddingFunction, schema) + buffer = await fromRecordsToBuffer(data, embeddingFunction, schema); } const tbl = await tableCreate.call( @@ -830,11 +838,11 @@ export class LocalConnection implements Connection { buffer, writeOptions?.writeMode?.toString(), ...getAwsArgs(this._options()) - ) + ); if (embeddingFunction !== undefined) { - return new LocalTable(tbl, name, this._options(), embeddingFunction) + return new LocalTable(tbl, name, this._options(), embeddingFunction); } else { - return new LocalTable(tbl, name, this._options()) + return new LocalTable(tbl, name, this._options()); } } @@ -842,69 +850,69 @@ export class LocalConnection implements Connection { * Drop an existing table. * @param name The name of the table to drop. 
*/ - async dropTable (name: string): Promise { - await databaseDropTable.call(this._db, name) + async dropTable(name: string): Promise { + await databaseDropTable.call(this._db, name); } - withMiddleware (middleware: HttpMiddleware): Connection { - return this + withMiddleware(middleware: HttpMiddleware): Connection { + return this; } } export class LocalTable implements Table { - private _tbl: any - private readonly _name: string - private readonly _isElectron: boolean - private readonly _embeddings?: EmbeddingFunction - private readonly _options: () => ConnectionOptions + private _tbl: any; + private readonly _name: string; + private readonly _isElectron: boolean; + private readonly _embeddings?: EmbeddingFunction; + private readonly _options: () => ConnectionOptions; - constructor (tbl: any, name: string, options: ConnectionOptions) + constructor(tbl: any, name: string, options: ConnectionOptions); /** * @param tbl * @param name * @param options * @param embeddings An embedding function to use when interacting with this table */ - constructor ( + constructor( tbl: any, name: string, options: ConnectionOptions, embeddings: EmbeddingFunction - ) - constructor ( + ); + constructor( tbl: any, name: string, options: ConnectionOptions, embeddings?: EmbeddingFunction ) { - this._tbl = tbl - this._name = name - this._embeddings = embeddings - this._options = () => options - this._isElectron = this.checkElectron() + this._tbl = tbl; + this._name = name; + this._embeddings = embeddings; + this._options = () => options; + this._isElectron = this.checkElectron(); } - get name (): string { - return this._name + get name(): string { + return this._name; } /** * Creates a search query to find the nearest neighbors of the given search term * @param query The query search term */ - search (query: T): Query { - return new Query(query, this._tbl, this._embeddings) + search(query: T): Query { + return new Query(query, this._tbl, this._embeddings); } /** * Creates a filter query to 
find all rows matching the specified criteria * @param value The filter criteria (like SQL where clause syntax) */ - filter (value: string): Query { - return new Query(undefined, this._tbl, this._embeddings).filter(value) + filter(value: string): Query { + return new Query(undefined, this._tbl, this._embeddings).filter(value); } - where = this.filter + where = this.filter; /** * Insert records into this Table. @@ -912,16 +920,19 @@ export class LocalTable implements Table { * @param data Records to be inserted into the Table * @return The number of rows added to the table */ - async add ( + async add( data: Array> | ArrowTable ): Promise { - const schema = await this.schema - let tbl: ArrowTable + const schema = await this.schema; + + let tbl: ArrowTable; + if (data instanceof ArrowTable) { - tbl = data + tbl = data; } else { - tbl = makeArrowTable(data, { schema }) + tbl = makeArrowTable(data, { schema, embeddings: this._embeddings }); } + return tableAdd .call( this._tbl, @@ -930,8 +941,8 @@ export class LocalTable implements Table { ...getAwsArgs(this._options()) ) .then((newTable: any) => { - this._tbl = newTable - }) + this._tbl = newTable; + }); } /** @@ -940,14 +951,14 @@ export class LocalTable implements Table { * @param data Records to be inserted into the Table * @return The number of rows added to the table */ - async overwrite ( + async overwrite( data: Array> | ArrowTable ): Promise { - let buffer: Buffer + let buffer: Buffer; if (data instanceof ArrowTable) { - buffer = await fromTableToBuffer(data, this._embeddings) + buffer = await fromTableToBuffer(data, this._embeddings); } else { - buffer = await fromRecordsToBuffer(data, this._embeddings) + buffer = await fromRecordsToBuffer(data, this._embeddings); } return tableAdd .call( @@ -957,8 +968,8 @@ export class LocalTable implements Table { ...getAwsArgs(this._options()) ) .then((newTable: any) => { - this._tbl = newTable - }) + this._tbl = newTable; + }); } /** @@ -966,26 +977,26 @@ export class 
LocalTable implements Table { * * @param indexParams The parameters of this Index, @see VectorIndexParams. */ - async createIndex (indexParams: VectorIndexParams): Promise { + async createIndex(indexParams: VectorIndexParams): Promise { return tableCreateVectorIndex .call(this._tbl, indexParams) .then((newTable: any) => { - this._tbl = newTable - }) + this._tbl = newTable; + }); } - async createScalarIndex (column: string, replace?: boolean): Promise { + async createScalarIndex(column: string, replace?: boolean): Promise { if (replace === undefined) { - replace = true + replace = true; } - return tableCreateScalarIndex.call(this._tbl, column, replace) + return tableCreateScalarIndex.call(this._tbl, column, replace); } /** * Returns the number of rows in this table. */ - async countRows (filter?: string): Promise { - return tableCountRows.call(this._tbl, filter) + async countRows(filter?: string): Promise { + return tableCountRows.call(this._tbl, filter); } /** @@ -993,10 +1004,10 @@ export class LocalTable implements Table { * * @param filter A filter in the same format used by a sql WHERE clause. */ - async delete (filter: string): Promise { + async delete(filter: string): Promise { return tableDelete.call(this._tbl, filter).then((newTable: any) => { - this._tbl = newTable - }) + this._tbl = newTable; + }); } /** @@ -1006,55 +1017,65 @@ export class LocalTable implements Table { * * @returns */ - async update (args: UpdateArgs | UpdateSqlArgs): Promise { - let filter: string | null - let updates: Record + async update(args: UpdateArgs | UpdateSqlArgs): Promise { + let filter: string | null; + let updates: Record; - if ('valuesSql' in args) { - filter = args.where ?? null - updates = args.valuesSql + if ("valuesSql" in args) { + filter = args.where ?? null; + updates = args.valuesSql; } else { - filter = args.where ?? null - updates = {} + filter = args.where ?? 
null; + updates = {}; for (const [key, value] of Object.entries(args.values)) { - updates[key] = toSQL(value) + updates[key] = toSQL(value); } } return tableUpdate .call(this._tbl, filter, updates) .then((newTable: any) => { - this._tbl = newTable - }) + this._tbl = newTable; + }); } - async mergeInsert (on: string, data: Array> | ArrowTable, args: MergeInsertArgs): Promise { - let whenMatchedUpdateAll = false - let whenMatchedUpdateAllFilt = null - if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) { - whenMatchedUpdateAll = true + async mergeInsert( + on: string, + data: Array> | ArrowTable, + args: MergeInsertArgs + ): Promise { + let whenMatchedUpdateAll = false; + let whenMatchedUpdateAllFilt = null; + if ( + args.whenMatchedUpdateAll !== undefined && + args.whenMatchedUpdateAll !== null + ) { + whenMatchedUpdateAll = true; if (args.whenMatchedUpdateAll !== true) { - whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll + whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll; } } - const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false - let whenNotMatchedBySourceDelete = false - let whenNotMatchedBySourceDeleteFilt = null - if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) { - whenNotMatchedBySourceDelete = true + const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? 
false; + let whenNotMatchedBySourceDelete = false; + let whenNotMatchedBySourceDeleteFilt = null; + if ( + args.whenNotMatchedBySourceDelete !== undefined && + args.whenNotMatchedBySourceDelete !== null + ) { + whenNotMatchedBySourceDelete = true; if (args.whenNotMatchedBySourceDelete !== true) { - whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete + whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete; } } - const schema = await this.schema - let tbl: ArrowTable + const schema = await this.schema; + let tbl: ArrowTable; if (data instanceof ArrowTable) { - tbl = data + tbl = data; } else { - tbl = makeArrowTable(data, { schema }) + tbl = makeArrowTable(data, { schema }); } - const buffer = await fromTableToBuffer(tbl, this._embeddings, schema) + const buffer = await fromTableToBuffer(tbl, this._embeddings, schema); this._tbl = await tableMergeInsert.call( this._tbl, @@ -1065,7 +1086,7 @@ export class LocalTable implements Table { whenNotMatchedBySourceDelete, whenNotMatchedBySourceDeleteFilt, buffer - ) + ); } /** @@ -1083,16 +1104,16 @@ export class LocalTable implements Table { * uphold this promise can lead to corrupted tables. * @returns */ - async cleanupOldVersions ( + async cleanupOldVersions( olderThan?: number, deleteUnverified?: boolean ): Promise { return tableCleanupOldVersions .call(this._tbl, olderThan, deleteUnverified) .then((res: { newTable: any, metrics: CleanupStats }) => { - this._tbl = res.newTable - return res.metrics - }) + this._tbl = res.newTable; + return res.metrics; + }); } /** @@ -1106,62 +1127,64 @@ export class LocalTable implements Table { * for most tables. * @returns Metrics about the compaction operation. */ - async compactFiles (options?: CompactionOptions): Promise { - const optionsArg = options ?? {} + async compactFiles(options?: CompactionOptions): Promise { + const optionsArg = options ?? 
{}; return tableCompactFiles .call(this._tbl, optionsArg) .then((res: { newTable: any, metrics: CompactionMetrics }) => { - this._tbl = res.newTable - return res.metrics - }) + this._tbl = res.newTable; + return res.metrics; + }); } - async listIndices (): Promise { - return tableListIndices.call(this._tbl) + async listIndices(): Promise { + return tableListIndices.call(this._tbl); } - async indexStats (indexUuid: string): Promise { - return tableIndexStats.call(this._tbl, indexUuid) + async indexStats(indexUuid: string): Promise { + return tableIndexStats.call(this._tbl, indexUuid); } - get schema (): Promise { + get schema(): Promise { // empty table - return this.getSchema() + return this.getSchema(); } - private async getSchema (): Promise { - const buffer = await tableSchema.call(this._tbl, this._isElectron) - const table = tableFromIPC(buffer) - return table.schema + private async getSchema(): Promise { + const buffer = await tableSchema.call(this._tbl, this._isElectron); + const table = tableFromIPC(buffer); + return table.schema; } // See https://github.com/electron/electron/issues/2288 - private checkElectron (): boolean { + private checkElectron(): boolean { try { // eslint-disable-next-line no-prototype-builtins return ( - Object.prototype.hasOwnProperty.call(process?.versions, 'electron') || - navigator?.userAgent?.toLowerCase()?.includes(' electron') - ) + Object.prototype.hasOwnProperty.call(process?.versions, "electron") || + navigator?.userAgent?.toLowerCase()?.includes(" electron") + ); } catch (e) { - return false + return false; } } - async addColumns (newColumnTransforms: Array<{ name: string, valueSql: string }>): Promise { - return tableAddColumns.call(this._tbl, newColumnTransforms) + async addColumns( + newColumnTransforms: Array<{ name: string, valueSql: string }> + ): Promise { + return tableAddColumns.call(this._tbl, newColumnTransforms); } - async alterColumns (columnAlterations: ColumnAlteration[]): Promise { - return 
 tableAlterColumns.call(this._tbl, columnAlterations) + async alterColumns(columnAlterations: ColumnAlteration[]): Promise { + return tableAlterColumns.call(this._tbl, columnAlterations); } - async dropColumns (columnNames: string[]): Promise { - return tableDropColumns.call(this._tbl, columnNames) + async dropColumns(columnNames: string[]): Promise { + return tableDropColumns.call(this._tbl, columnNames); } - withMiddleware (middleware: HttpMiddleware): Table { - return this + withMiddleware(middleware: HttpMiddleware): Table { + return this; } } @@ -1184,7 +1207,7 @@ export interface CompactionOptions { */ targetRowsPerFragment?: number /** - * The maximum number of rows per group. Defaults to 1024. + * The maximum number of rows per group. Defaults to 1024. */ maxRowsPerGroup?: number /** @@ -1284,21 +1307,21 @@ export interface IvfPQIndexConfig { */ index_cache_size?: number - type: 'ivf_pq' + type: "ivf_pq" } -export type VectorIndexParams = IvfPQIndexConfig +export type VectorIndexParams = IvfPQIndexConfig; /** * Write mode for writing a table. */ export enum WriteMode { /** Create a new {@link Table}. */ - Create = 'create', + Create = "create", /** Overwrite the existing {@link Table} if presented. */ - Overwrite = 'overwrite', + Overwrite = "overwrite", /** Append new data to the table. 
*/ - Append = 'append', + Append = "append", } /** @@ -1310,14 +1333,14 @@ export interface WriteOptions { } export class DefaultWriteOptions implements WriteOptions { - writeMode = WriteMode.Create + writeMode = WriteMode.Create; } -export function isWriteOptions (value: any): value is WriteOptions { +export function isWriteOptions(value: any): value is WriteOptions { return ( Object.keys(value).length === 1 && - (value.writeMode === undefined || typeof value.writeMode === 'string') - ) + (value.writeMode === undefined || typeof value.writeMode === "string") + ); } /** @@ -1327,15 +1350,15 @@ export enum MetricType { /** * Euclidean distance */ - L2 = 'l2', + L2 = "l2", /** * Cosine distance */ - Cosine = 'cosine', + Cosine = "cosine", /** * Dot product */ - Dot = 'dot', + Dot = "dot", } diff --git a/node/src/sanitize.ts b/node/src/sanitize.ts index 5a10b5d7..cc5d958d 100644 --- a/node/src/sanitize.ts +++ b/node/src/sanitize.ts @@ -32,7 +32,7 @@ import { Bool, Date_, Decimal, - DataType, + type DataType, Dictionary, Binary, Float32, @@ -74,12 +74,12 @@ import { DurationNanosecond, DurationMicrosecond, DurationMillisecond, - DurationSecond, + DurationSecond } from "apache-arrow"; import type { IntBitWidth, TimeBitWidth } from "apache-arrow/type"; function sanitizeMetadata( - metadataLike?: unknown, + metadataLike?: unknown ): Map | undefined { if (metadataLike === undefined || metadataLike === null) { return undefined; @@ -90,7 +90,7 @@ function sanitizeMetadata( for (const item of metadataLike) { if (!(typeof item[0] === "string" || !(typeof item[1] === "string"))) { throw Error( - "Expected metadata, if present, to be a Map but it had non-string keys or values", + "Expected metadata, if present, to be a Map but it had non-string keys or values" ); } } @@ -105,7 +105,7 @@ function sanitizeInt(typeLike: object) { typeof typeLike.isSigned !== "boolean" ) { throw Error( - "Expected an Int Type to have a `bitWidth` and `isSigned` property", + "Expected an Int Type to 
have a `bitWidth` and `isSigned` property" ); } return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth); @@ -128,7 +128,7 @@ function sanitizeDecimal(typeLike: object) { typeof typeLike.bitWidth !== "number" ) { throw Error( - "Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties", + "Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties" ); } return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth); @@ -149,7 +149,7 @@ function sanitizeTime(typeLike: object) { typeof typeLike.bitWidth !== "number" ) { throw Error( - "Expected a Time type to have `unit` and `bitWidth` properties", + "Expected a Time type to have `unit` and `bitWidth` properties" ); } return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth); @@ -172,7 +172,7 @@ function sanitizeTypedTimestamp( | typeof TimestampNanosecond | typeof TimestampMicrosecond | typeof TimestampMillisecond - | typeof TimestampSecond, + | typeof TimestampSecond ) { let timezone = null; if ("timezone" in typeLike && typeof typeLike.timezone === "string") { @@ -191,7 +191,7 @@ function sanitizeInterval(typeLike: object) { function sanitizeList(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a List type to have an array-like `children` property", + "Expected a List type to have an array-like `children` property" ); } if (typeLike.children.length !== 1) { @@ -203,7 +203,7 @@ function sanitizeList(typeLike: object) { function sanitizeStruct(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a Struct type to have an array-like `children` property", + "Expected a Struct type to have an array-like `children` property" ); } return new Struct(typeLike.children.map((child) => sanitizeField(child))); @@ -216,47 +216,47 @@ function sanitizeUnion(typeLike: object) { typeof typeLike.mode !== "number" ) { throw Error( - 
"Expected a Union type to have `typeIds` and `mode` properties", + "Expected a Union type to have `typeIds` and `mode` properties" ); } if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a Union type to have an array-like `children` property", + "Expected a Union type to have an array-like `children` property" ); } return new Union( typeLike.mode, typeLike.typeIds as any, - typeLike.children.map((child) => sanitizeField(child)), + typeLike.children.map((child) => sanitizeField(child)) ); } function sanitizeTypedUnion( typeLike: object, - UnionType: typeof DenseUnion | typeof SparseUnion, + UnionType: typeof DenseUnion | typeof SparseUnion ) { if (!("typeIds" in typeLike)) { throw Error( - "Expected a DenseUnion/SparseUnion type to have a `typeIds` property", + "Expected a DenseUnion/SparseUnion type to have a `typeIds` property" ); } if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a DenseUnion/SparseUnion type to have an array-like `children` property", + "Expected a DenseUnion/SparseUnion type to have an array-like `children` property" ); } return new UnionType( typeLike.typeIds as any, - typeLike.children.map((child) => sanitizeField(child)), + typeLike.children.map((child) => sanitizeField(child)) ); } function sanitizeFixedSizeBinary(typeLike: object) { if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") { throw Error( - "Expected a FixedSizeBinary type to have a `byteWidth` property", + "Expected a FixedSizeBinary type to have a `byteWidth` property" ); } return new FixedSizeBinary(typeLike.byteWidth); @@ -268,7 +268,7 @@ function sanitizeFixedSizeList(typeLike: object) { } if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a FixedSizeList type to have an array-like `children` property", + "Expected a FixedSizeList type to have an array-like `children` property" ); } if (typeLike.children.length !== 1) { 
@@ -276,14 +276,14 @@ function sanitizeFixedSizeList(typeLike: object) { } return new FixedSizeList( typeLike.listSize, - sanitizeField(typeLike.children[0]), + sanitizeField(typeLike.children[0]) ); } function sanitizeMap(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( - "Expected a Map type to have an array-like `children` property", + "Expected a Map type to have an array-like `children` property" ); } if (!("keysSorted" in typeLike) || typeof typeLike.keysSorted !== "boolean") { @@ -291,7 +291,7 @@ function sanitizeMap(typeLike: object) { } return new Map_( typeLike.children.map((field) => sanitizeField(field)) as any, - typeLike.keysSorted, + typeLike.keysSorted ); } @@ -319,7 +319,7 @@ function sanitizeDictionary(typeLike: object) { sanitizeType(typeLike.dictionary), sanitizeType(typeLike.indices) as any, typeLike.id, - typeLike.isOrdered, + typeLike.isOrdered ); } @@ -454,7 +454,7 @@ function sanitizeField(fieldLike: unknown): Field { !("nullable" in fieldLike) ) { throw Error( - "The field passed in is missing a `type`/`name`/`nullable` property", + "The field passed in is missing a `type`/`name`/`nullable` property" ); } const type = sanitizeType(fieldLike.type); @@ -489,7 +489,7 @@ export function sanitizeSchema(schemaLike: unknown): Schema { } if (!("fields" in schemaLike)) { throw Error( - "The schema passed in does not appear to be a schema (no 'fields' property)", + "The schema passed in does not appear to be a schema (no 'fields' property)" ); } let metadata; @@ -498,11 +498,11 @@ export function sanitizeSchema(schemaLike: unknown): Schema { } if (!Array.isArray(schemaLike.fields)) { throw Error( - "The schema passed in had a 'fields' property but it was not an array", + "The schema passed in had a 'fields' property but it was not an array" ); } const sanitizedFields = schemaLike.fields.map((field) => - sanitizeField(field), + sanitizeField(field) ); return new Schema(sanitizedFields, metadata); 
} diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 75d6351b..32377a39 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { describe } from 'mocha' -import { track } from 'temp' -import * as chai from 'chai' -import * as chaiAsPromised from 'chai-as-promised' +import { describe } from "mocha"; +import { track } from "temp"; +import * as chai from "chai"; +import * as chaiAsPromised from "chai-as-promised"; -import * as lancedb from '../index' +import * as lancedb from "../index"; import { type AwsCredentials, type EmbeddingFunction, @@ -27,7 +27,7 @@ import { DefaultWriteOptions, isWriteOptions, type LocalTable -} from '../index' +} from "../index"; import { FixedSizeList, Field, @@ -41,288 +41,288 @@ import { Float32, Float16, Int64 -} from 'apache-arrow' -import type { RemoteRequest, RemoteResponse } from '../middleware' +} from "apache-arrow"; +import type { RemoteRequest, RemoteResponse } from "../middleware"; -const expect = chai.expect -const assert = chai.assert -chai.use(chaiAsPromised) +const expect = chai.expect; +const assert = chai.assert; +chai.use(chaiAsPromised); -describe('LanceDB client', function () { - describe('when creating a connection to lancedb', function () { - it('should have a valid url', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) - assert.equal(con.uri, uri) - }) +describe("LanceDB client", function () { + describe("when creating a connection to lancedb", function () { + it("should have a valid url", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); + assert.equal(con.uri, uri); + }); - it('should accept an options object', async function () { - const uri = await createTestDB() - const con = await lancedb.connect({ uri }) - assert.equal(con.uri, uri) - }) + it("should accept an 
options object", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect({ uri }); + assert.equal(con.uri, uri); + }); - it('should accept custom aws credentials', async function () { - const uri = await createTestDB() + it("should accept custom aws credentials", async function () { + const uri = await createTestDB(); const awsCredentials: AwsCredentials = { - accessKeyId: '', - secretKey: '' - } + accessKeyId: "", + secretKey: "" + }; const con = await lancedb.connect({ uri, awsCredentials - }) - assert.equal(con.uri, uri) - }) + }); + assert.equal(con.uri, uri); + }); - it('should accept custom storage options', async function () { - const uri = await createTestDB() + it("should accept custom storage options", async function () { + const uri = await createTestDB(); const storageOptions = { - region: 'us-west-2', - timeout: '30s' + region: "us-west-2", + timeout: "30s" }; const con = await lancedb.connect({ uri, storageOptions - }) - assert.equal(con.uri, uri) - }) + }); + assert.equal(con.uri, uri); + }); - it('should return the existing table names', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) - assert.deepEqual(await con.tableNames(), ['vectors']) - }) - }) + it("should return the existing table names", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); + assert.deepEqual(await con.tableNames(), ["vectors"]); + }); + }); - describe('when querying an existing dataset', function () { - it('should open a table', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') - assert.equal(table.name, 'vectors') - }) + describe("when querying an existing dataset", function () { + it("should open a table", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); + 
assert.equal(table.name, "vectors"); + }); - it('execute a query', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') - const results = await table.search([0.1, 0.3]).execute() + it("execute a query", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); + const results = await table.search([0.1, 0.3]).execute(); - assert.equal(results.length, 2) - assert.equal(results[0].price, 10) - const vector = results[0].vector as Float32Array - assert.approximately(vector[0], 0.0, 0.2) - assert.approximately(vector[0], 0.1, 0.3) - }) + assert.equal(results.length, 2); + assert.equal(results[0].price, 10); + const vector = results[0].vector as Float32Array; + assert.approximately(vector[0], 0.0, 0.2); + assert.approximately(vector[0], 0.1, 0.3); + }); - it('limits # of results', async function () { - const uri = await createTestDB(2, 100) - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') - let results = await table.search([0.1, 0.3]).limit(1).execute() - assert.equal(results.length, 1) - assert.equal(results[0].id, 1) + it("limits # of results", async function () { + const uri = await createTestDB(2, 100); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); + let results = await table.search([0.1, 0.3]).limit(1).execute(); + assert.equal(results.length, 1); + assert.equal(results[0].id, 1); // there is a default limit if unspecified - results = await table.search([0.1, 0.3]).execute() - assert.equal(results.length, 10) - }) + results = await table.search([0.1, 0.3]).execute(); + assert.equal(results.length, 10); + }); - it('uses a filter / where clause without vector search', async function () { + it("uses a filter / where clause without vector search", async function () { // eslint-disable-next-line 
@typescript-eslint/explicit-function-return-type const assertResults = (results: Array>) => { - assert.equal(results.length, 50) - } + assert.equal(results.length, 50); + }; - const uri = await createTestDB(2, 100) - const con = await lancedb.connect(uri) - const table = (await con.openTable('vectors')) as LocalTable - let results = await table.filter('id % 2 = 0').limit(100).execute() - assertResults(results) - results = await table.where('id % 2 = 0').limit(100).execute() - assertResults(results) + const uri = await createTestDB(2, 100); + const con = await lancedb.connect(uri); + const table = (await con.openTable("vectors")) as LocalTable; + let results = await table.filter("id % 2 = 0").limit(100).execute(); + assertResults(results); + results = await table.where("id % 2 = 0").limit(100).execute(); + assertResults(results); // Should reject a bad filter - await expect(table.filter('id % 2 = 0 AND').execute()).to.be.rejectedWith( + await expect(table.filter("id % 2 = 0 AND").execute()).to.be.rejectedWith( /.*sql parser error: Expected an expression:, found: EOF.*/ - ) - }) + ); + }); - it('uses a filter / where clause', async function () { + it("uses a filter / where clause", async function () { // eslint-disable-next-line @typescript-eslint/explicit-function-return-type const assertResults = (results: Array>) => { - assert.equal(results.length, 1) - assert.equal(results[0].id, 2) - } + assert.equal(results.length, 1); + assert.equal(results[0].id, 2); + }; - const uri = await createTestDB() - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') - let results = await table.search([0.1, 0.1]).filter('id == 2').execute() - assertResults(results) - results = await table.search([0.1, 0.1]).where('id == 2').execute() - assertResults(results) - }) + const uri = await createTestDB(); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); + let results = await table.search([0.1, 0.1]).filter("id == 
2").execute(); + assertResults(results); + results = await table.search([0.1, 0.1]).where("id == 2").execute(); + assertResults(results); + }); - it('should correctly process prefilter/postfilter', async function () { - const uri = await createTestDB(16, 300) - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') + it("should correctly process prefilter/postfilter", async function () { + const uri = await createTestDB(16, 300); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); await table.createIndex({ - type: 'ivf_pq', - column: 'vector', + type: "ivf_pq", + column: "vector", num_partitions: 2, max_iters: 2, num_sub_vectors: 2 - }) + }); // post filter should return less than the limit let results = await table .search(new Array(16).fill(0.1)) .limit(10) - .filter('id >= 10') + .filter("id >= 10") .prefilter(false) - .execute() - assert.isTrue(results.length < 10) + .execute(); + assert.isTrue(results.length < 10); // pre filter should return exactly the limit results = await table .search(new Array(16).fill(0.1)) .limit(10) - .filter('id >= 10') + .filter("id >= 10") .prefilter(true) - .execute() - assert.isTrue(results.length === 10) - }) + .execute(); + assert.isTrue(results.length === 10); + }); - it('should allow creation and use of scalar indices', async function () { - const uri = await createTestDB(16, 300) - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') - await table.createScalarIndex('id', true) + it("should allow creation and use of scalar indices", async function () { + const uri = await createTestDB(16, 300); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); + await table.createScalarIndex("id", true); // Prefiltering should still work the same const results = await table .search(new Array(16).fill(0.1)) .limit(10) - .filter('id >= 10') + .filter("id >= 10") .prefilter(true) - .execute() - 
assert.isTrue(results.length === 10) - }) + .execute(); + assert.isTrue(results.length === 10); + }); - it('select only a subset of columns', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') + it("select only a subset of columns", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); const results = await table .search([0.1, 0.1]) - .select(['is_active', 'vector']) - .execute() - assert.equal(results.length, 2) + .select(["is_active", "vector"]) + .execute(); + assert.equal(results.length, 2); // vector and _distance are always returned - assert.isDefined(results[0].vector) - assert.isDefined(results[0]._distance) - assert.isDefined(results[0].is_active) + assert.isDefined(results[0].vector); + assert.isDefined(results[0]._distance); + assert.isDefined(results[0].is_active); - assert.isUndefined(results[0].id) - assert.isUndefined(results[0].name) - assert.isUndefined(results[0].price) - }) - }) + assert.isUndefined(results[0].id); + assert.isUndefined(results[0].name); + assert.isUndefined(results[0].price); + }); + }); - describe('when creating a new dataset', function () { - it('create an empty table', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + describe("when creating a new dataset", function () { + it("create an empty table", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const schema = new Schema([ - new Field('id', new Int32()), - new Field('name', new Utf8()) - ]) + new Field("id", new Int32()), + new Field("name", new Utf8()) + ]); const table = await con.createTable({ - name: 'vectors', + name: "vectors", schema - }) - assert.equal(table.name, 'vectors') - assert.deepEqual(await con.tableNames(), ['vectors']) - }) + }); + assert.equal(table.name, 
"vectors"); + assert.deepEqual(await con.tableNames(), ["vectors"]); + }); - it('create a table with a schema and records', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("create a table with a schema and records", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const schema = new Schema([ - new Field('id', new Int32()), - new Field('name', new Utf8()), + new Field("id", new Int32()), + new Field("name", new Utf8()), new Field( - 'vector', - new FixedSizeList(2, new Field('item', new Float32(), true)), + "vector", + new FixedSizeList(2, new Field("item", new Float32(), true)), false ) - ]) + ]); const data = [ { vector: [0.5, 0.2], - name: 'foo', + name: "foo", id: 0 }, { vector: [0.3, 0.1], - name: 'bar', + name: "bar", id: 1 } - ] + ]; // even thought the keys in data is out of order it should still work const table = await con.createTable({ - name: 'vectors', + name: "vectors", data, schema - }) - assert.equal(table.name, 'vectors') - assert.deepEqual(await con.tableNames(), ['vectors']) - }) + }); + assert.equal(table.name, "vectors"); + assert.deepEqual(await con.tableNames(), ["vectors"]); + }); - it('create a table with a empty data array', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("create a table with a empty data array", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const schema = new Schema([ - new Field('id', new Int32()), - new Field('name', new Utf8()) - ]) + new Field("id", new Int32()), + new Field("name", new Utf8()) + ]); const table = await con.createTable({ - name: 'vectors', + name: "vectors", schema, data: [] - }) - assert.equal(table.name, 'vectors') - assert.deepEqual(await con.tableNames(), ['vectors']) - }) + }); + assert.equal(table.name, "vectors"); + assert.deepEqual(await 
con.tableNames(), ["vectors"]); + }); - it('create a table from an Arrow Table', async function () { - const dir = await track().mkdir('lancejs') + it("create a table from an Arrow Table", async function () { + const dir = await track().mkdir("lancejs"); // Also test the connect function with an object - const con = await lancedb.connect({ uri: dir }) + const con = await lancedb.connect({ uri: dir }); - const i32s = new Int32Array(new Array(10)) - const i32 = makeVector(i32s) + const i32s = new Int32Array(new Array(10)); + const i32 = makeVector(i32s); - const data = new ArrowTable({ vector: i32 }) + const data = new ArrowTable({ vector: i32 }); const table = await con.createTable({ - name: 'vectors', + name: "vectors", data - }) - assert.equal(table.name, 'vectors') - assert.equal(await table.countRows(), 10) - assert.equal(await table.countRows('vector IS NULL'), 0) - assert.deepEqual(await con.tableNames(), ['vectors']) - }) + }); + assert.equal(table.name, "vectors"); + assert.equal(await table.countRows(), 10); + assert.equal(await table.countRows("vector IS NULL"), 0); + assert.deepEqual(await con.tableNames(), ["vectors"]); + }); - it('creates a new table from javascript objects', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("creates a new table from javascript objects", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const data = [ { id: 1, vector: [0.1, 0.2], price: 10 }, @@ -331,93 +331,93 @@ describe('LanceDB client', function () { vector: [1.1, 1.2], price: 50 } - ] + ]; - const tableName = `vectors_${Math.floor(Math.random() * 100)}` - const table = await con.createTable(tableName, data) - assert.equal(table.name, tableName) - assert.equal(await table.countRows(), 2) - }) + const tableName = `vectors_${Math.floor(Math.random() * 100)}`; + const table = await con.createTable(tableName, data); + assert.equal(table.name, 
tableName); + assert.equal(await table.countRows(), 2); + }); - it('creates a new table from javascript objects with variable sized list', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("creates a new table from javascript objects with variable sized list", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const data = [ { id: 1, vector: [0.1, 0.2], - list_of_str: ['a', 'b', 'c'], + list_of_str: ["a", "b", "c"], list_of_num: [1, 2, 3] }, { id: 2, vector: [1.1, 1.2], - list_of_str: ['x', 'y'], + list_of_str: ["x", "y"], list_of_num: [4, 5, 6] } - ] + ]; - const tableName = 'with_variable_sized_list' - const table = (await con.createTable(tableName, data)) as LocalTable - assert.equal(table.name, tableName) - assert.equal(await table.countRows(), 2) - const rs = await table.filter('id>1').execute() - assert.equal(rs.length, 1) - assert.deepEqual(rs[0].list_of_str, ['x', 'y']) - assert.isTrue(rs[0].list_of_num instanceof Array) - }) + const tableName = "with_variable_sized_list"; + const table = (await con.createTable(tableName, data)) as LocalTable; + assert.equal(table.name, tableName); + assert.equal(await table.countRows(), 2); + const rs = await table.filter("id>1").execute(); + assert.equal(rs.length, 1); + assert.deepEqual(rs[0].list_of_str, ["x", "y"]); + assert.isTrue(rs[0].list_of_num instanceof Array); + }); - it('create table from arrow table', async () => { - const dim = 128 - const total = 256 - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("create table from arrow table", async () => { + const dim = 128; + const total = 256; + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const schema = new Schema([ - new Field('id', new Int32()), + new Field("id", new Int32()), new Field( - 'vector', - new FixedSizeList(dim, new Field('item', new Float16(), true)), + 
"vector", + new FixedSizeList(dim, new Field("item", new Float16(), true)), false ) - ]) + ]); const data = lancedb.makeArrowTable( Array.from(Array(total), (_, i) => ({ id: i, vector: Array.from(Array(dim), Math.random) })), { schema } - ) - const table = await con.createTable('f16', data) - assert.equal(table.name, 'f16') - assert.equal(await table.countRows(), total) - assert.equal(await table.countRows('id < 5'), 5) - assert.deepEqual(await con.tableNames(), ['f16']) - assert.deepEqual(await table.schema, schema) + ); + const table = await con.createTable("f16", data); + assert.equal(table.name, "f16"); + assert.equal(await table.countRows(), total); + assert.equal(await table.countRows("id < 5"), 5); + assert.deepEqual(await con.tableNames(), ["f16"]); + assert.deepEqual(await table.schema, schema); await table.createIndex({ num_sub_vectors: 2, num_partitions: 2, - type: 'ivf_pq' - }) + type: "ivf_pq" + }); - const q = Array.from(Array(dim), Math.random) - const r = await table.search(q).limit(5).execute() - assert.equal(r.length, 5) + const q = Array.from(Array(dim), Math.random); + const r = await table.search(q).limit(5).execute(); + assert.equal(r.length, 5); r.forEach((v) => { - assert.equal(Object.prototype.hasOwnProperty.call(v, 'vector'), true) + assert.equal(Object.prototype.hasOwnProperty.call(v, "vector"), true); assert.equal( v.vector?.constructor.name, - 'Array', - 'vector column is list of floats' - ) - }) - }).timeout(120000) + "Array", + "vector column is list of floats" + ); + }); + }).timeout(120000); - it('use overwrite flag to overwrite existing table', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("use overwrite flag to overwrite existing table", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const data = [ { id: 1, vector: [0.1, 0.2], price: 10 }, @@ -426,10 +426,10 @@ describe('LanceDB client', function () { 
vector: [1.1, 1.2], price: 50 } - ] + ]; - const tableName = 'overwrite' - await con.createTable(tableName, data, { writeMode: WriteMode.Create }) + const tableName = "overwrite"; + await con.createTable(tableName, data, { writeMode: WriteMode.Create }); const newData = [ { id: 1, vector: [0.1, 0.2], price: 10 }, @@ -439,550 +439,617 @@ describe('LanceDB client', function () { vector: [1.1, 1.2], price: 50 } - ] + ]; await expect(con.createTable(tableName, newData)).to.be.rejectedWith( Error, - 'already exists' - ) + "already exists" + ); const table = await con.createTable(tableName, newData, { writeMode: WriteMode.Overwrite - }) - assert.equal(table.name, tableName) - assert.equal(await table.countRows(), 3) - }) + }); + assert.equal(table.name, tableName); + assert.equal(await table.countRows(), 3); + }); - it('appends records to an existing table ', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("appends records to an existing table ", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const data = [ { id: 1, vector: [0.1, 0.2], price: 10, - name: 'a' + name: "a" }, { id: 2, vector: [1.1, 1.2], price: 50, - name: 'b' + name: "b" } - ] + ]; - const table = await con.createTable('vectors', data) - assert.equal(await table.countRows(), 2) + const table = await con.createTable("vectors", data); + assert.equal(await table.countRows(), 2); const dataAdd = [ { id: 3, vector: [2.1, 2.2], price: 10, - name: 'c' + name: "c" }, { id: 4, vector: [3.1, 3.2], price: 50, - name: 'd' + name: "d" } - ] - await table.add(dataAdd) - assert.equal(await table.countRows(), 4) - }) + ]; + await table.add(dataAdd); + assert.equal(await table.countRows(), 4); + }); - it('appends records with fields in a different order', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("appends records with fields in a 
different order", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const data = [ { id: 1, vector: [0.1, 0.2], price: 10, - name: 'a' + name: "a" }, { id: 2, vector: [1.1, 1.2], price: 50, - name: 'b' + name: "b" } - ] + ]; - const table = await con.createTable('vectors', data) + const table = await con.createTable("vectors", data); const dataAdd = [ { id: 3, vector: [2.1, 2.2], - name: 'c', + name: "c", price: 10 }, { id: 4, vector: [3.1, 3.2], - name: 'd', + name: "d", price: 50 } - ] - await table.add(dataAdd) - assert.equal(await table.countRows(), 4) - }) + ]; + await table.add(dataAdd); + assert.equal(await table.countRows(), 4); + }); - it('overwrite all records in a table', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) + it("overwrite all records in a table", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); - const table = await con.openTable('vectors') - assert.equal(await table.countRows(), 2) + const table = await con.openTable("vectors"); + assert.equal(await table.countRows(), 2); const dataOver = [ { vector: [2.1, 2.2], price: 10, - name: 'foo' + name: "foo" }, { vector: [3.1, 3.2], price: 50, - name: 'bar' + name: "bar" } - ] - await table.overwrite(dataOver) - assert.equal(await table.countRows(), 2) - }) + ]; + await table.overwrite(dataOver); + assert.equal(await table.countRows(), 2); + }); - it('can merge insert records into the table', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("can merge insert records into the table", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); - const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }] - const table = await con.createTable('my_table', data) + const data = [ + { id: 1, age: 1 }, + { id: 2, age: 1 } + ]; + const table = await 
con.createTable("my_table", data); // insert if not exists - let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }] - await table.mergeInsert('id', newData, { + let newData = [ + { id: 2, age: 2 }, + { id: 3, age: 2 } + ]; + await table.mergeInsert("id", newData, { whenNotMatchedInsertAll: true - }) - assert.equal(await table.countRows(), 3) - assert.equal(await table.countRows('age = 2'), 1) + }); + assert.equal(await table.countRows(), 3); + assert.equal(await table.countRows("age = 2"), 1); // conditional update - newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }] - await table.mergeInsert('id', newData, { - whenMatchedUpdateAll: 'target.age = 1' - }) - assert.equal(await table.countRows(), 3) - assert.equal(await table.countRows('age = 1'), 1) - assert.equal(await table.countRows('age = 3'), 1) + newData = [ + { id: 2, age: 3 }, + { id: 3, age: 3 } + ]; + await table.mergeInsert("id", newData, { + whenMatchedUpdateAll: "target.age = 1" + }); + assert.equal(await table.countRows(), 3); + assert.equal(await table.countRows("age = 1"), 1); + assert.equal(await table.countRows("age = 3"), 1); - newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }] - await table.mergeInsert('id', newData, { + newData = [ + { id: 3, age: 4 }, + { id: 4, age: 4 } + ]; + await table.mergeInsert("id", newData, { whenNotMatchedInsertAll: true, whenMatchedUpdateAll: true - }) - assert.equal(await table.countRows(), 4) - assert.equal((await table.filter('age = 4').execute()).length, 2) + }); + assert.equal(await table.countRows(), 4); + assert.equal((await table.filter("age = 4").execute()).length, 2); - newData = [{ id: 5, age: 5 }] - await table.mergeInsert('id', newData, { + newData = [{ id: 5, age: 5 }]; + await table.mergeInsert("id", newData, { whenNotMatchedInsertAll: true, whenMatchedUpdateAll: true, - whenNotMatchedBySourceDelete: 'age < 4' - }) - assert.equal(await table.countRows(), 3) + whenNotMatchedBySourceDelete: "age < 4" + }); + assert.equal(await table.countRows(), 3); - await 
table.mergeInsert('id', newData, { + await table.mergeInsert("id", newData, { whenNotMatchedInsertAll: true, whenMatchedUpdateAll: true, whenNotMatchedBySourceDelete: true - }) - assert.equal(await table.countRows(), 1) - }) + }); + assert.equal(await table.countRows(), 1); + }); - it('can update records in the table', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) + it("can update records in the table", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); - const table = await con.openTable('vectors') - assert.equal(await table.countRows(), 2) + const table = await con.openTable("vectors"); + assert.equal(await table.countRows(), 2); await table.update({ - where: 'price = 10', - valuesSql: { price: '100' } - }) - const results = await table.search([0.1, 0.2]).execute() - assert.equal(results[0].price, 100) - assert.equal(results[1].price, 11) - }) + where: "price = 10", + valuesSql: { price: "100" } + }); + const results = await table.search([0.1, 0.2]).execute(); + assert.equal(results[0].price, 100); + assert.equal(results[1].price, 11); + }); - it('can update the records using a literal value', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) + it("can update the records using a literal value", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); - const table = await con.openTable('vectors') - assert.equal(await table.countRows(), 2) + const table = await con.openTable("vectors"); + assert.equal(await table.countRows(), 2); await table.update({ - where: 'price = 10', + where: "price = 10", values: { price: 100 } - }) - const results = await table.search([0.1, 0.2]).execute() - assert.equal(results[0].price, 100) - assert.equal(results[1].price, 11) - }) + }); + const results = await table.search([0.1, 0.2]).execute(); + assert.equal(results[0].price, 100); + 
assert.equal(results[1].price, 11); + }); - it('can update every record in the table', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) + it("can update every record in the table", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); - const table = await con.openTable('vectors') - assert.equal(await table.countRows(), 2) + const table = await con.openTable("vectors"); + assert.equal(await table.countRows(), 2); - await table.update({ valuesSql: { price: '100' } }) - const results = await table.search([0.1, 0.2]).execute() + await table.update({ valuesSql: { price: "100" } }); + const results = await table.search([0.1, 0.2]).execute(); - assert.equal(results[0].price, 100) - assert.equal(results[1].price, 100) - }) + assert.equal(results[0].price, 100); + assert.equal(results[1].price, 100); + }); - it('can delete records from a table', async function () { - const uri = await createTestDB() - const con = await lancedb.connect(uri) + it("can delete records from a table", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); - const table = await con.openTable('vectors') - assert.equal(await table.countRows(), 2) + const table = await con.openTable("vectors"); + assert.equal(await table.countRows(), 2); - await table.delete('price = 10') - assert.equal(await table.countRows(), 1) - }) - }) + await table.delete("price = 10"); + assert.equal(await table.countRows(), 1); + }); + it("can manually provide embedding columns", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); + const schema = new Schema([ + new Field("id", new Int32()), + new Field("text", new Utf8()), + new Field( + "vector", + new FixedSizeList(2, new Field("item", new Float32(), true)) + ) + ]); + const data = [ + { id: 1, text: "foo", vector: [0.1, 0.2] }, + { id: 2, text: "bar", vector: [0.3, 0.4] } + ]; + let table = 
await con.createTable({ + name: "embed_vectors", + data, + schema + }); + assert.equal(table.name, "embed_vectors"); + table = await con.openTable("embed_vectors"); + assert.equal(await table.countRows(), 2); + }); - describe('when searching an empty dataset', function () { - it('should not fail', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("will error if no implementation for embedding column found", async function () { + const uri = await createTestDB(); + const con = await lancedb.connect(uri); + const schema = new Schema([ + new Field("id", new Int32()), + new Field("text", new Utf8()), + new Field( + "vector", + new FixedSizeList(2, new Field("item", new Float32(), true)) + ) + ]); + const data = [ + { id: 1, text: "foo" }, + { id: 2, text: "bar" } + ]; + + const table = con.createTable({ + name: "embed_vectors", + data, + schema + }); + await assert.isRejected(table); + }); + }); + + describe("when searching an empty dataset", function () { + it("should not fail", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const schema = new Schema([ new Field( - 'vector', - new FixedSizeList(128, new Field('float32', new Float32())) + "vector", + new FixedSizeList(128, new Field("float32", new Float32())) ) - ]) + ]); const table = await con.createTable({ - name: 'vectors', + name: "vectors", schema - }) - const result = await table.search(Array(128).fill(0.1)).execute() - assert.isEmpty(result) - }) - }) + }); + const result = await table.search(Array(128).fill(0.1)).execute(); + assert.isEmpty(result); + }); + }); - describe('when searching an empty-after-delete dataset', function () { - it('should not fail', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + describe("when searching an empty-after-delete dataset", function () { + it("should not fail", async function () { + const dir = 
await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const schema = new Schema([ new Field( - 'vector', - new FixedSizeList(128, new Field('float32', new Float32())) + "vector", + new FixedSizeList(128, new Field("float32", new Float32())) ) - ]) + ]); const table = await con.createTable({ - name: 'vectors', + name: "vectors", schema - }) - await table.add([{ vector: Array(128).fill(0.1) }]) + }); + await table.add([{ vector: Array(128).fill(0.1) }]); // https://github.com/lancedb/lance/issues/1635 - await table.delete('true') - const result = await table.search(Array(128).fill(0.1)).execute() - assert.isEmpty(result) - }) - }) + await table.delete("true"); + const result = await table.search(Array(128).fill(0.1)).execute(); + assert.isEmpty(result); + }); + }); - describe('when creating a vector index', function () { - it('overwrite all records in a table', async function () { - const uri = await createTestDB(32, 300) - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') + describe("when creating a vector index", function () { + it("overwrite all records in a table", async function () { + const uri = await createTestDB(32, 300); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); await table.createIndex({ - type: 'ivf_pq', - column: 'vector', + type: "ivf_pq", + column: "vector", num_partitions: 2, max_iters: 2, num_sub_vectors: 2 - }) - }).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow + }); + }).timeout(10_000); // Timeout is high partially because GH macos runner is pretty slow - it('replace an existing index', async function () { - const uri = await createTestDB(16, 300) - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') + it("replace an existing index", async function () { + const uri = await createTestDB(16, 300); + const con = await lancedb.connect(uri); + const table = await 
con.openTable("vectors"); await table.createIndex({ - type: 'ivf_pq', - column: 'vector', + type: "ivf_pq", + column: "vector", num_partitions: 2, max_iters: 2, num_sub_vectors: 2 - }) + }); // Replace should fail if the index already exists await expect( table.createIndex({ - type: 'ivf_pq', - column: 'vector', + type: "ivf_pq", + column: "vector", num_partitions: 2, max_iters: 2, num_sub_vectors: 2, replace: false }) - ).to.be.rejectedWith('LanceError(Index)') + ).to.be.rejectedWith("LanceError(Index)"); // Default replace = true await table.createIndex({ - type: 'ivf_pq', - column: 'vector', + type: "ivf_pq", + column: "vector", num_partitions: 2, max_iters: 2, num_sub_vectors: 2 - }) - }).timeout(50_000) + }); + }).timeout(50_000); - it('it should fail when the column is not a vector', async function () { - const uri = await createTestDB(32, 300) - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') + it("it should fail when the column is not a vector", async function () { + const uri = await createTestDB(32, 300); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); const createIndex = table.createIndex({ - type: 'ivf_pq', - column: 'name', + type: "ivf_pq", + column: "name", num_partitions: 2, max_iters: 2, num_sub_vectors: 2 - }) + }); await expect(createIndex).to.be.rejectedWith( "index cannot be created on the column `name` which has data type Utf8" - ) - }) + ); + }); - it('it should fail when num_partitions is invalid', async function () { - const uri = await createTestDB(32, 300) - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') + it("it should fail when num_partitions is invalid", async function () { + const uri = await createTestDB(32, 300); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); const createIndex = table.createIndex({ - type: 'ivf_pq', - column: 'name', + type: "ivf_pq", + column: "name", 
num_partitions: -1, max_iters: 2, num_sub_vectors: 2 - }) + }); await expect(createIndex).to.be.rejectedWith( - 'num_partitions: must be > 0' - ) - }) + "num_partitions: must be > 0" + ); + }); - it('should be able to list index and stats', async function () { - const uri = await createTestDB(32, 300) - const con = await lancedb.connect(uri) - const table = await con.openTable('vectors') + it("should be able to list index and stats", async function () { + const uri = await createTestDB(32, 300); + const con = await lancedb.connect(uri); + const table = await con.openTable("vectors"); await table.createIndex({ - type: 'ivf_pq', - column: 'vector', + type: "ivf_pq", + column: "vector", num_partitions: 2, max_iters: 2, num_sub_vectors: 2 - }) + }); - const indices = await table.listIndices() - expect(indices).to.have.lengthOf(1) - expect(indices[0].name).to.equal('vector_idx') - expect(indices[0].uuid).to.not.be.equal(undefined) - expect(indices[0].columns).to.have.lengthOf(1) - expect(indices[0].columns[0]).to.equal('vector') + const indices = await table.listIndices(); + expect(indices).to.have.lengthOf(1); + expect(indices[0].name).to.equal("vector_idx"); + expect(indices[0].uuid).to.not.be.equal(undefined); + expect(indices[0].columns).to.have.lengthOf(1); + expect(indices[0].columns[0]).to.equal("vector"); - const stats = await table.indexStats(indices[0].uuid) - expect(stats.numIndexedRows).to.equal(300) - expect(stats.numUnindexedRows).to.equal(0) - }).timeout(50_000) - }) + const stats = await table.indexStats(indices[0].uuid); + expect(stats.numIndexedRows).to.equal(300); + expect(stats.numUnindexedRows).to.equal(0); + }).timeout(50_000); + }); - describe('when using a custom embedding function', function () { + describe("when using a custom embedding function", function () { class TextEmbedding implements EmbeddingFunction { - sourceColumn: string + sourceColumn: string; - constructor (targetColumn: string) { - this.sourceColumn = targetColumn + 
constructor(targetColumn: string) { + this.sourceColumn = targetColumn; } _embedding_map = new Map([ - ['foo', [2.1, 2.2]], - ['bar', [3.1, 3.2]] - ]) + ["foo", [2.1, 2.2]], + ["bar", [3.1, 3.2]] + ]); - async embed (data: string[]): Promise { + async embed(data: string[]): Promise { return data.map( (datum) => this._embedding_map.get(datum) ?? [0.0, 0.0] - ) + ); } } - it('should encode the original data into embeddings', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) - const embeddings = new TextEmbedding('name') + it("should encode the original data into embeddings", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); + const embeddings = new TextEmbedding("name"); const data = [ { price: 10, - name: 'foo' + name: "foo" }, { price: 50, - name: 'bar' + name: "bar" } - ] - const table = await con.createTable('vectors', data, embeddings, { + ]; + const table = await con.createTable("vectors", data, embeddings, { writeMode: WriteMode.Create - }) - const results = await table.search('foo').execute() - assert.equal(results.length, 2) - }) + }); + const results = await table.search("foo").execute(); + assert.equal(results.length, 2); + }); - it('should create embeddings for Arrow Table', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) - const embeddingFunction = new TextEmbedding('name') + it("should create embeddings for Arrow Table", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); + const embeddingFunction = new TextEmbedding("name"); - const names = vectorFromArray(['foo', 'bar'], new Utf8()) - const data = new ArrowTable({ name: names }) + const names = vectorFromArray(["foo", "bar"], new Utf8()); + const data = new ArrowTable({ name: names }); const table = await con.createTable({ - name: 'vectors', + name: "vectors", data, 
embeddingFunction - }) - assert.equal(table.name, 'vectors') - const results = await table.search('foo').execute() - assert.equal(results.length, 2) - }) - }) + }); + assert.equal(table.name, "vectors"); + const results = await table.search("foo").execute(); + assert.equal(results.length, 2); + }); + }); - describe('when inspecting the schema', function () { - it('should return the schema', async function () { - const uri = await createTestDB() - const db = await lancedb.connect(uri) + describe("when inspecting the schema", function () { + it("should return the schema", async function () { + const uri = await createTestDB(); + const db = await lancedb.connect(uri); // the fsl inner field must be named 'item' and be nullable const expectedSchema = new Schema([ - new Field('id', new Int32()), + new Field("id", new Int32()), new Field( - 'vector', - new FixedSizeList(128, new Field('item', new Float32(), true)) + "vector", + new FixedSizeList(128, new Field("item", new Float32(), true)) ), - new Field('s', new Utf8()) - ]) + new Field("s", new Utf8()) + ]); const table = await db.createTable({ - name: 'some_table', + name: "some_table", schema: expectedSchema - }) - const schema = await table.schema - assert.deepEqual(expectedSchema, schema) - }) - }) -}) + }); + const schema = await table.schema; + assert.deepEqual(expectedSchema, schema); + }); + }); +}); -describe('Remote LanceDB client', function () { - describe('when the server is not reachable', function () { - it('produces a network error', async function () { +describe("Remote LanceDB client", function () { + describe("when the server is not reachable", function () { + it("produces a network error", async function () { const con = await lancedb.connect({ - uri: 'db://test-1234', - region: 'asdfasfasfdf', - apiKey: 'some-api-key' - }) + uri: "db://test-1234", + region: "asdfasfasfdf", + apiKey: "some-api-key" + }); // GET try { - await con.tableNames() + await con.tableNames(); } catch (err) { 
expect(err).to.have.property( - 'message', - 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com' - ) + "message", + "Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com" + ); } // POST try { await con.createTable({ - name: 'vectors', + name: "vectors", schema: new Schema([]) - }) + }); } catch (err) { expect(err).to.have.property( - 'message', - 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com' - ) + "message", + "Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com" + ); } // Search - const table = await con.withMiddleware(new (class { - async onRemoteRequest(req: RemoteRequest, next: (req: RemoteRequest) => Promise) { - // intercept call to check if the table exists and make the call succeed - if (req.uri.endsWith('/describe/')) { - return { - status: 200, - statusText: 'OK', - headers: new Map(), - body: async () => ({}) - } - } + const table = await con + .withMiddleware( + new (class { + async onRemoteRequest( + req: RemoteRequest, + next: (req: RemoteRequest) => Promise + ) { + // intercept call to check if the table exists and make the call succeed + if (req.uri.endsWith("/describe/")) { + return { + status: 200, + statusText: "OK", + headers: new Map(), + body: async () => ({}) + }; + } - return await next(req) - } - })()).openTable('vectors') + return await next(req); + } + })() + ) + .openTable("vectors"); try { - await table.search([0.1, 0.3]).execute() + await table.search([0.1, 0.3]).execute(); } catch (err) { expect(err).to.have.property( - 'message', - 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com' - ) + "message", + "Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com" + ); } - }) - }) -}) + }); + }); +}); -describe('Query object', function () { - it('sets custom parameters', async function () { +describe("Query object", function () { + it("sets custom parameters", async function () { const query 
= new Query([0.1, 0.3]) .limit(1) .metricType(MetricType.Cosine) .refineFactor(100) - .select(['a', 'b']) - .nprobes(20) as Record - assert.equal(query._limit, 1) - assert.equal(query._metricType, MetricType.Cosine) - assert.equal(query._refineFactor, 100) - assert.equal(query._nprobes, 20) - assert.deepEqual(query._select, ['a', 'b']) - }) -}) + .select(["a", "b"]) + .nprobes(20) as Record; + assert.equal(query._limit, 1); + assert.equal(query._metricType, MetricType.Cosine); + assert.equal(query._refineFactor, 100); + assert.equal(query._nprobes, 20); + assert.deepEqual(query._select, ["a", "b"]); + }); +}); -async function createTestDB ( +async function createTestDB( numDimensions: number = 2, numRows: number = 2 ): Promise { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); - const data = [] + const data = []; for (let i = 0; i < numRows; i++) { - const vector = [] + const vector = []; for (let j = 0; j < numDimensions; j++) { - vector.push(i + j * 0.1) + vector.push(i + j * 0.1); } data.push({ id: i + 1, @@ -990,94 +1057,94 @@ async function createTestDB ( price: i + 10, is_active: i % 2 === 0, vector - }) + }); } - await con.createTable('vectors', data) - return dir + await con.createTable("vectors", data); + return dir; } -describe('Drop table', function () { - it('drop a table', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) +describe("Drop table", function () { + it("drop a table", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const data = [ { price: 10, - name: 'foo', + name: "foo", vector: [1, 2, 3] }, { price: 50, - name: 'bar', + name: "bar", vector: [4, 5, 6] } - ] - await con.createTable('t1', data) - await con.createTable('t2', data) + ]; + await con.createTable("t1", data); + await con.createTable("t2", 
data); - assert.deepEqual(await con.tableNames(), ['t1', 't2']) + assert.deepEqual(await con.tableNames(), ["t1", "t2"]); - await con.dropTable('t1') - assert.deepEqual(await con.tableNames(), ['t2']) - }) -}) + await con.dropTable("t1"); + assert.deepEqual(await con.tableNames(), ["t2"]); + }); +}); -describe('WriteOptions', function () { - context('#isWriteOptions', function () { - it('should not match empty object', function () { - assert.equal(isWriteOptions({}), false) - }) - it('should match write options', function () { - assert.equal(isWriteOptions({ writeMode: WriteMode.Create }), true) - }) - it('should match undefined write mode', function () { - assert.equal(isWriteOptions({ writeMode: undefined }), true) - }) - it('should match default write options', function () { - assert.equal(isWriteOptions(new DefaultWriteOptions()), true) - }) - }) -}) +describe("WriteOptions", function () { + context("#isWriteOptions", function () { + it("should not match empty object", function () { + assert.equal(isWriteOptions({}), false); + }); + it("should match write options", function () { + assert.equal(isWriteOptions({ writeMode: WriteMode.Create }), true); + }); + it("should match undefined write mode", function () { + assert.equal(isWriteOptions({ writeMode: undefined }), true); + }); + it("should match default write options", function () { + assert.equal(isWriteOptions(new DefaultWriteOptions()), true); + }); + }); +}); -describe('Compact and cleanup', function () { - it('can cleanup after compaction', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) +describe("Compact and cleanup", function () { + it("can cleanup after compaction", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const data = [ { price: 10, - name: 'foo', + name: "foo", vector: [1, 2, 3] }, { price: 50, - name: 'bar', + name: "bar", vector: [4, 5, 6] } - ] - const table = (await 
con.createTable('t1', data)) as LocalTable + ]; + const table = (await con.createTable("t1", data)) as LocalTable; const newData = [ { price: 30, - name: 'baz', + name: "baz", vector: [7, 8, 9] } - ] - await table.add(newData) + ]; + await table.add(newData); const compactionMetrics = await table.compactFiles({ numThreads: 2 - }) - assert.equal(compactionMetrics.fragmentsRemoved, 2) - assert.equal(compactionMetrics.fragmentsAdded, 1) - assert.equal(await table.countRows(), 3) + }); + assert.equal(compactionMetrics.fragmentsRemoved, 2); + assert.equal(compactionMetrics.fragmentsAdded, 1); + assert.equal(await table.countRows(), 3); - await table.cleanupOldVersions() - assert.equal(await table.countRows(), 3) + await table.cleanupOldVersions(); + assert.equal(await table.countRows(), 3); // should have no effect, but this validates the arguments are parsed. await table.compactFiles({ @@ -1086,71 +1153,80 @@ describe('Compact and cleanup', function () { materializeDeletions: true, materializeDeletionsThreshold: 0.5, numThreads: 2 - }) + }); - const cleanupMetrics = await table.cleanupOldVersions(0, true) - assert.isAtLeast(cleanupMetrics.bytesRemoved, 1) - assert.isAtLeast(cleanupMetrics.oldVersions, 1) - assert.equal(await table.countRows(), 3) - }) -}) + const cleanupMetrics = await table.cleanupOldVersions(0, true); + assert.isAtLeast(cleanupMetrics.bytesRemoved, 1); + assert.isAtLeast(cleanupMetrics.oldVersions, 1); + assert.equal(await table.countRows(), 3); + }); +}); -describe('schema evolution', function () { +describe("schema evolution", function () { // Create a new sample table - it('can add a new column to the schema', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) - const table = await con.createTable('vectors', [ + it("can add a new column to the schema", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); + const table = await 
con.createTable("vectors", [ { id: 1n, vector: [0.1, 0.2] } - ]) + ]); - await table.addColumns([{ name: 'price', valueSql: 'cast(10.0 as float)' }]) + await table.addColumns([ + { name: "price", valueSql: "cast(10.0 as float)" } + ]); const expectedSchema = new Schema([ - new Field('id', new Int64()), - new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true))), - new Field('price', new Float32()) - ]) - expect(await table.schema).to.deep.equal(expectedSchema) - }) + new Field("id", new Int64()), + new Field( + "vector", + new FixedSizeList(2, new Field("item", new Float32(), true)) + ), + new Field("price", new Float32()) + ]); + expect(await table.schema).to.deep.equal(expectedSchema); + }); - it('can alter the columns in the schema', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) + it("can alter the columns in the schema", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); const schema = new Schema([ - new Field('id', new Int64(), false), - new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true))), - new Field('price', new Float64(), false) - ]) - const table = await con.createTable('vectors', [ + new Field("id", new Int64(), false), + new Field( + "vector", + new FixedSizeList(2, new Field("item", new Float32(), true)) + ), + new Field("price", new Float64(), false) + ]); + const table = await con.createTable("vectors", [ { id: 1n, vector: [0.1, 0.2], price: 10.0 } - ]) - expect(await table.schema).to.deep.equal(schema) + ]); + expect(await table.schema).to.deep.equal(schema); await table.alterColumns([ - { path: 'id', rename: 'new_id' }, - { path: 'price', nullable: true } - ]) + { path: "id", rename: "new_id" }, + { path: "price", nullable: true } + ]); const expectedSchema = new Schema([ - new Field('new_id', new Int64(), false), - new Field('vector', new FixedSizeList(2, new Field('item', new 
Float32(), true))), - new Field('price', new Float64(), true) - ]) - expect(await table.schema).to.deep.equal(expectedSchema) - }) + new Field("new_id", new Int64(), false), + new Field( + "vector", + new FixedSizeList(2, new Field("item", new Float32(), true)) + ), + new Field("price", new Float64(), true) + ]); + expect(await table.schema).to.deep.equal(expectedSchema); + }); - it('can drop a column from the schema', async function () { - const dir = await track().mkdir('lancejs') - const con = await lancedb.connect(dir) - const table = await con.createTable('vectors', [ + it("can drop a column from the schema", async function () { + const dir = await track().mkdir("lancejs"); + const con = await lancedb.connect(dir); + const table = await con.createTable("vectors", [ { id: 1n, vector: [0.1, 0.2] } - ]) - await table.dropColumns(['vector']) + ]); + await table.dropColumns(["vector"]); - const expectedSchema = new Schema([ - new Field('id', new Int64(), false) - ]) - expect(await table.schema).to.deep.equal(expectedSchema) - }) -}) + const expectedSchema = new Schema([new Field("id", new Int64(), false)]); + expect(await table.schema).to.deep.equal(expectedSchema); + }); +}); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 24902974..690f18f0 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -20,7 +20,7 @@ import { type Vector, FixedSizeList, vectorFromArray, - type Schema, + Schema, Table as ArrowTable, RecordBatchStreamWriter, List, @@ -85,6 +85,7 @@ export class MakeArrowTableOptions { vectorColumns: Record = { vector: new VectorColumnOptions(), }; + embeddings?: EmbeddingFunction; /** * If true then string columns will be encoded with dictionary encoding @@ -208,6 +209,7 @@ export function makeArrowTable( const opt = new MakeArrowTableOptions(options !== undefined ? 
options : {}); if (opt.schema !== undefined && opt.schema !== null) { opt.schema = sanitizeSchema(opt.schema); + opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings); } const columns: Record = {}; // TODO: sample dataset to find missing columns @@ -287,8 +289,8 @@ export function makeArrowTable( // then patch the schema of the batches so we can use // `new ArrowTable(schema, batches)` which does not do any schema inference const firstTable = new ArrowTable(columns); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const batchesFixed = firstTable.batches.map( + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion (batch) => new RecordBatch(opt.schema!, batch.data), ); return new ArrowTable(opt.schema, batchesFixed); @@ -648,3 +650,41 @@ function alignTable(table: ArrowTable, schema: Schema): ArrowTable { export function createEmptyTable(schema: Schema): ArrowTable { return new ArrowTable(sanitizeSchema(schema)); } + +function validateSchemaEmbeddings( + schema: Schema, + data: Array>, + embeddings: EmbeddingFunction | undefined, +) { + const fields = []; + const missingEmbeddingFields = []; + + // First we check if the field is a `FixedSizeList` + // Then we check if the data contains the field + // if it does not, we add it to the list of missing embedding fields + // Finally, we check if those missing embedding fields are `this._embeddings` + // if they are not, we throw an error + for (const field of schema.fields) { + if (field.type instanceof FixedSizeList) { + if (data.length !== 0 && data?.[0]?.[field.name] === undefined) { + missingEmbeddingFields.push(field); + } else { + fields.push(field); + } + } else { + fields.push(field); + } + } + + if (missingEmbeddingFields.length > 0 && embeddings === undefined) { + console.log({ missingEmbeddingFields, embeddings }); + + throw new Error( + `Table has embeddings: "${missingEmbeddingFields + .map((f) => f.name) + .join(",")}", but no embedding function was 
provided`, + ); + } + + return new Schema(fields); +}