mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 22:09:58 +00:00
Compare commits
2 Commits
type-reorg
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
600bfd7237 | ||
|
|
d087e7891d |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -21,15 +21,7 @@ import type { EmbeddingFunction } from './embedding/embedding_function'
|
||||
import { RemoteConnection } from './remote'
|
||||
import { Query } from './query'
|
||||
import { isEmbeddingFunction } from './embedding/embedding_function'
|
||||
import {
|
||||
type Connection, type CreateTableOptions, type Table,
|
||||
type VectorIndexParams, type UpdateArgs, type UpdateSqlArgs,
|
||||
type VectorIndex, type IndexStats,
|
||||
type ConnectionOptions, WriteMode, type WriteOptions
|
||||
} from './types'
|
||||
import { toSQL } from './util'
|
||||
|
||||
export { type WriteMode }
|
||||
import { type Literal, toSQL } from './util'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableUpdate, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
|
||||
@@ -38,6 +30,30 @@ export { Query }
|
||||
export type { EmbeddingFunction }
|
||||
export { OpenAIEmbeddingFunction } from './embedding/openai'
|
||||
|
||||
export interface AwsCredentials {
|
||||
accessKeyId: string
|
||||
|
||||
secretKey: string
|
||||
|
||||
sessionToken?: string
|
||||
}
|
||||
|
||||
export interface ConnectionOptions {
|
||||
uri: string
|
||||
|
||||
awsCredentials?: AwsCredentials
|
||||
|
||||
awsRegion?: string
|
||||
|
||||
// API key for the remote connections
|
||||
apiKey?: string
|
||||
// Region to connect
|
||||
region?: string
|
||||
|
||||
// override the host for the remote connections
|
||||
hostOverride?: string
|
||||
}
|
||||
|
||||
function getAwsArgs (opts: ConnectionOptions): any[] {
|
||||
const callArgs = []
|
||||
const awsCredentials = opts.awsCredentials
|
||||
@@ -55,6 +71,23 @@ function getAwsArgs (opts: ConnectionOptions): any[] {
|
||||
return callArgs
|
||||
}
|
||||
|
||||
export interface CreateTableOptions<T> {
|
||||
// Name of Table
|
||||
name: string
|
||||
|
||||
// Data to insert into the Table
|
||||
data?: Array<Record<string, unknown>> | ArrowTable | undefined
|
||||
|
||||
// Optional Arrow Schema for this table
|
||||
schema?: Schema | undefined
|
||||
|
||||
// Optional embedding function used to create embeddings
|
||||
embeddingFunction?: EmbeddingFunction<T> | undefined
|
||||
|
||||
// WriteOptions for this operation
|
||||
writeOptions?: WriteOptions | undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI
|
||||
* @param uri The uri of the database.
|
||||
@@ -83,6 +116,235 @@ export async function connect (arg: string | Partial<ConnectionOptions>): Promis
|
||||
return new LocalConnection(db, opts)
|
||||
}
|
||||
|
||||
/**
|
||||
* A LanceDB Connection that allows you to open tables and create new ones.
|
||||
*
|
||||
* Connection could be local against filesystem or remote against a server.
|
||||
*/
|
||||
export interface Connection {
|
||||
uri: string
|
||||
|
||||
tableNames(): Promise<string[]>
|
||||
|
||||
/**
|
||||
* Open a table in the database.
|
||||
*
|
||||
* @param name The name of the table.
|
||||
* @param embeddings An embedding function to use on this table
|
||||
*/
|
||||
openTable<T>(name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
|
||||
/**
|
||||
* Creates a new Table, optionally initializing it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Array of Records to be inserted into the table
|
||||
* @param schema - An Arrow Schema that describe this table columns
|
||||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table
|
||||
* @param {WriteOptions} writeOptions - The write options to use when creating the table.
|
||||
*/
|
||||
createTable<T> ({ name, data, schema, embeddingFunction, writeOptions }: CreateTableOptions<T>): Promise<Table<T>>
|
||||
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Non-empty Array of Records to be inserted into the table
|
||||
*/
|
||||
createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table>
|
||||
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Non-empty Array of Records to be inserted into the table
|
||||
* @param {WriteOptions} options - The write options to use when creating the table.
|
||||
*/
|
||||
createTable (name: string, data: Array<Record<string, unknown>>, options: WriteOptions): Promise<Table>
|
||||
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Non-empty Array of Records to be inserted into the table
|
||||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table
|
||||
*/
|
||||
createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Non-empty Array of Records to be inserted into the table
|
||||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table
|
||||
* @param {WriteOptions} options - The write options to use when creating the table.
|
||||
*/
|
||||
createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>, options: WriteOptions): Promise<Table<T>>
|
||||
|
||||
/**
|
||||
* Drop an existing table.
|
||||
* @param name The name of the table to drop.
|
||||
*/
|
||||
dropTable(name: string): Promise<void>
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* A LanceDB Table is the collection of Records. Each Record has one or more vector fields.
|
||||
*/
|
||||
export interface Table<T = number[]> {
|
||||
name: string
|
||||
|
||||
/**
|
||||
* Creates a search query to find the nearest neighbors of the given search term
|
||||
* @param query The query search term
|
||||
*/
|
||||
search: (query: T) => Query<T>
|
||||
|
||||
/**
|
||||
* Insert records into this Table.
|
||||
*
|
||||
* @param data Records to be inserted into the Table
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
add: (data: Array<Record<string, unknown>>) => Promise<number>
|
||||
|
||||
/**
|
||||
* Insert records into this Table, replacing its contents.
|
||||
*
|
||||
* @param data Records to be inserted into the Table
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
overwrite: (data: Array<Record<string, unknown>>) => Promise<number>
|
||||
|
||||
/**
|
||||
* Create an ANN index on this Table vector index.
|
||||
*
|
||||
* @param indexParams The parameters of this Index, @see VectorIndexParams.
|
||||
*/
|
||||
createIndex: (indexParams: VectorIndexParams) => Promise<any>
|
||||
|
||||
/**
|
||||
* Returns the number of rows in this table.
|
||||
*/
|
||||
countRows: () => Promise<number>
|
||||
|
||||
/**
|
||||
* Delete rows from this table.
|
||||
*
|
||||
* This can be used to delete a single row, many rows, all rows, or
|
||||
* sometimes no rows (if your predicate matches nothing).
|
||||
*
|
||||
* @param filter A filter in the same format used by a sql WHERE clause. The
|
||||
* filter must not be empty.
|
||||
*
|
||||
* @examples
|
||||
*
|
||||
* ```ts
|
||||
* const con = await lancedb.connect("./.lancedb")
|
||||
* const data = [
|
||||
* {id: 1, vector: [1, 2]},
|
||||
* {id: 2, vector: [3, 4]},
|
||||
* {id: 3, vector: [5, 6]},
|
||||
* ];
|
||||
* const tbl = await con.createTable("my_table", data)
|
||||
* await tbl.delete("id = 2")
|
||||
* await tbl.countRows() // Returns 2
|
||||
* ```
|
||||
*
|
||||
* If you have a list of values to delete, you can combine them into a
|
||||
* stringified list and use the `IN` operator:
|
||||
*
|
||||
* ```ts
|
||||
* const to_remove = [1, 5];
|
||||
* await tbl.delete(`id IN (${to_remove.join(",")})`)
|
||||
* await tbl.countRows() // Returns 1
|
||||
* ```
|
||||
*/
|
||||
delete: (filter: string) => Promise<void>
|
||||
|
||||
/**
|
||||
* Update rows in this table.
|
||||
*
|
||||
* This can be used to update a single row, many rows, all rows, or
|
||||
* sometimes no rows (if your predicate matches nothing).
|
||||
*
|
||||
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
|
||||
*
|
||||
* @examples
|
||||
*
|
||||
* ```ts
|
||||
* const con = await lancedb.connect("./.lancedb")
|
||||
* const data = [
|
||||
* {id: 1, vector: [3, 3], name: 'Ye'},
|
||||
* {id: 2, vector: [4, 4], name: 'Mike'},
|
||||
* ];
|
||||
* const tbl = await con.createTable("my_table", data)
|
||||
*
|
||||
* await tbl.update({
|
||||
* filter: "id = 2",
|
||||
* updates: { vector: [2, 2], name: "Michael" },
|
||||
* })
|
||||
*
|
||||
* let results = await tbl.search([1, 1]).execute();
|
||||
* // Returns [
|
||||
* // {id: 2, vector: [2, 2], name: 'Michael'}
|
||||
* // {id: 1, vector: [3, 3], name: 'Ye'}
|
||||
* // ]
|
||||
* ```
|
||||
*
|
||||
*/
|
||||
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
|
||||
|
||||
/**
|
||||
* List the indicies on this table.
|
||||
*/
|
||||
listIndices: () => Promise<VectorIndex[]>
|
||||
|
||||
/**
|
||||
* Get statistics about an index.
|
||||
*/
|
||||
indexStats: (indexUuid: string) => Promise<IndexStats>
|
||||
}
|
||||
|
||||
export interface UpdateArgs {
|
||||
/**
|
||||
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
|
||||
* in which case all rows will be updated.
|
||||
*/
|
||||
where?: string
|
||||
|
||||
/**
|
||||
* A key-value map of updates. The keys are the column names, and the values are the
|
||||
* new values to set
|
||||
*/
|
||||
values: Record<string, Literal>
|
||||
}
|
||||
|
||||
export interface UpdateSqlArgs {
|
||||
/**
|
||||
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
|
||||
* in which case all rows will be updated.
|
||||
*/
|
||||
where?: string
|
||||
|
||||
/**
|
||||
* A key-value map of updates. The keys are the column names, and the values are the
|
||||
* new values to set as SQL expressions.
|
||||
*/
|
||||
valuesSql: Record<string, string>
|
||||
}
|
||||
|
||||
export interface VectorIndex {
|
||||
columns: string[]
|
||||
name: string
|
||||
uuid: string
|
||||
}
|
||||
|
||||
export interface IndexStats {
|
||||
numIndexedRows: number | null
|
||||
numUnindexedRows: number | null
|
||||
}
|
||||
|
||||
/**
|
||||
* A connection to a LanceDB database.
|
||||
*/
|
||||
@@ -430,6 +692,83 @@ export interface CompactionMetrics {
|
||||
filesAdded: number
|
||||
}
|
||||
|
||||
/// Config to build IVF_PQ index.
|
||||
///
|
||||
export interface IvfPQIndexConfig {
|
||||
/**
|
||||
* The column to be indexed
|
||||
*/
|
||||
column?: string
|
||||
|
||||
/**
|
||||
* A unique name for the index
|
||||
*/
|
||||
index_name?: string
|
||||
|
||||
/**
|
||||
* Metric type, L2 or Cosine
|
||||
*/
|
||||
metric_type?: MetricType
|
||||
|
||||
/**
|
||||
* The number of partitions this index
|
||||
*/
|
||||
num_partitions?: number
|
||||
|
||||
/**
|
||||
* The max number of iterations for kmeans training.
|
||||
*/
|
||||
max_iters?: number
|
||||
|
||||
/**
|
||||
* Train as optimized product quantization.
|
||||
*/
|
||||
use_opq?: boolean
|
||||
|
||||
/**
|
||||
* Number of subvectors to build PQ code
|
||||
*/
|
||||
num_sub_vectors?: number
|
||||
/**
|
||||
* The number of bits to present one PQ centroid.
|
||||
*/
|
||||
num_bits?: number
|
||||
|
||||
/**
|
||||
* Max number of iterations to train OPQ, if `use_opq` is true.
|
||||
*/
|
||||
max_opq_iters?: number
|
||||
|
||||
/**
|
||||
* Replace an existing index with the same name if it exists.
|
||||
*/
|
||||
replace?: boolean
|
||||
|
||||
type: 'ivf_pq'
|
||||
}
|
||||
|
||||
export type VectorIndexParams = IvfPQIndexConfig
|
||||
|
||||
/**
|
||||
* Write mode for writing a table.
|
||||
*/
|
||||
export enum WriteMode {
|
||||
/** Create a new {@link Table}. */
|
||||
Create = 'create',
|
||||
/** Overwrite the existing {@link Table} if presented. */
|
||||
Overwrite = 'overwrite',
|
||||
/** Append new data to the table. */
|
||||
Append = 'append'
|
||||
}
|
||||
|
||||
/**
|
||||
* Write options when creating a Table.
|
||||
*/
|
||||
export interface WriteOptions {
|
||||
/** A {@link WriteMode} to use on this operation */
|
||||
writeMode?: WriteMode
|
||||
}
|
||||
|
||||
export class DefaultWriteOptions implements WriteOptions {
|
||||
writeMode = WriteMode.Create
|
||||
}
|
||||
@@ -438,3 +777,23 @@ export function isWriteOptions (value: any): value is WriteOptions {
|
||||
return Object.keys(value).length === 1 &&
|
||||
(value.writeMode === undefined || typeof value.writeMode === 'string')
|
||||
}
|
||||
|
||||
/**
|
||||
* Distance metrics type.
|
||||
*/
|
||||
export enum MetricType {
|
||||
/**
|
||||
* Euclidean distance
|
||||
*/
|
||||
L2 = 'l2',
|
||||
|
||||
/**
|
||||
* Cosine distance
|
||||
*/
|
||||
Cosine = 'cosine',
|
||||
|
||||
/**
|
||||
* Dot product
|
||||
*/
|
||||
Dot = 'dot'
|
||||
}
|
||||
|
||||
180
node/src/integration_test/test.ts
Normal file
180
node/src/integration_test/test.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import * as chai from 'chai'
|
||||
import * as chaiAsPromised from 'chai-as-promised'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { tmpdir } from 'os'
|
||||
import * as fs from 'fs'
|
||||
import * as path from 'path'
|
||||
|
||||
const assert = chai.assert
|
||||
chai.use(chaiAsPromised)
|
||||
|
||||
describe('LanceDB AWS Integration test', function () {
|
||||
it('s3+ddb schema is processed correctly', async function () {
|
||||
this.timeout(15000)
|
||||
|
||||
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
||||
// THE API WILL CHANGE
|
||||
const conn = await lancedb.connect('s3://lancedb-integtest?engine=ddb&ddbTableName=lancedb-integtest')
|
||||
const data = [{ vector: Array(128).fill(1.0) }]
|
||||
|
||||
const tableName = uuidv4()
|
||||
let table = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })
|
||||
|
||||
const futs = [table.add(data), table.add(data), table.add(data), table.add(data), table.add(data)]
|
||||
await Promise.allSettled(futs)
|
||||
|
||||
table = await conn.openTable(tableName)
|
||||
assert.equal(await table.countRows(), 6)
|
||||
})
|
||||
})
|
||||
|
||||
describe('LanceDB Mirrored Store Integration test', function () {
|
||||
it('s3://...?mirroredStore=... param is processed correctly', async function () {
|
||||
this.timeout(600000)
|
||||
|
||||
const dir = tmpdir()
|
||||
console.log(dir)
|
||||
const conn = await lancedb.connect(`s3://lancedb-integtest?mirroredStore=${dir}`)
|
||||
const data = Array(200).fill({ vector: Array(128).fill(1.0), id: 0 })
|
||||
data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 1 }))
|
||||
data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 2 }))
|
||||
data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 3 }))
|
||||
|
||||
const tableName = uuidv4()
|
||||
|
||||
// try create table and check if it's mirrored
|
||||
const t = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })
|
||||
|
||||
const mirroredPath = path.join(dir, `${tableName}.lance`)
|
||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
// there should be three dirs
|
||||
assert.equal(files.length, 3)
|
||||
assert.isTrue(files[0].isDirectory())
|
||||
assert.isTrue(files[1].isDirectory())
|
||||
|
||||
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].name.endsWith('.txn'))
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, '_versions'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].name.endsWith('.manifest'))
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].name.endsWith('.lance'))
|
||||
})
|
||||
})
|
||||
|
||||
// try create index and check if it's mirrored
|
||||
await t.createIndex({ column: 'vector', type: 'ivf_pq' })
|
||||
|
||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
// there should be four dirs
|
||||
assert.equal(files.length, 4)
|
||||
assert.isTrue(files[0].isDirectory())
|
||||
assert.isTrue(files[1].isDirectory())
|
||||
assert.isTrue(files[2].isDirectory())
|
||||
|
||||
// Two TXs now
|
||||
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 2)
|
||||
assert.isTrue(files[0].name.endsWith('.txn'))
|
||||
assert.isTrue(files[1].name.endsWith('.txn'))
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].name.endsWith('.lance'))
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, '_indices'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].isDirectory())
|
||||
|
||||
fs.readdir(path.join(mirroredPath, '_indices', files[0].name), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].isFile())
|
||||
assert.isTrue(files[0].name.endsWith('.idx'))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// try delete and check if it's mirrored
|
||||
await t.delete('id = 0')
|
||||
|
||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
// there should be five dirs
|
||||
assert.equal(files.length, 5)
|
||||
assert.isTrue(files[0].isDirectory())
|
||||
assert.isTrue(files[1].isDirectory())
|
||||
assert.isTrue(files[2].isDirectory())
|
||||
assert.isTrue(files[3].isDirectory())
|
||||
assert.isTrue(files[4].isDirectory())
|
||||
|
||||
// Three TXs now
|
||||
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 3)
|
||||
assert.isTrue(files[0].name.endsWith('.txn'))
|
||||
assert.isTrue(files[1].name.endsWith('.txn'))
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].name.endsWith('.lance'))
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, '_indices'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].isDirectory())
|
||||
|
||||
fs.readdir(path.join(mirroredPath, '_indices', files[0].name), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].isFile())
|
||||
assert.isTrue(files[0].name.endsWith('.idx'))
|
||||
})
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, '_deletions'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].name.endsWith('.arrow'))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -14,16 +14,10 @@
|
||||
|
||||
import { Vector, tableFromIPC } from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from './embedding/embedding_function'
|
||||
import { type MetricType } from './types'
|
||||
import { type MetricType } from '.'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||
// const { tableSearch } = require('../native.js')
|
||||
|
||||
const tableSearch = async function (args: any, arg2: any): Promise<any> {
|
||||
return await new Promise((resolve, reject) => {
|
||||
resolve('')
|
||||
})
|
||||
}
|
||||
const { tableSearch } = require('../native.js')
|
||||
|
||||
/**
|
||||
* A builder for nearest neighbor queries for LanceDB.
|
||||
|
||||
@@ -13,15 +13,12 @@
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
type Table, type VectorIndexParams,
|
||||
type VectorIndex,
|
||||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
||||
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
|
||||
type WriteOptions,
|
||||
type IndexStats,
|
||||
type UpdateArgs, type UpdateSqlArgs,
|
||||
type Connection,
|
||||
type ConnectionOptions, type CreateTableOptions,
|
||||
type WriteOptions
|
||||
} from '../types'
|
||||
import { type EmbeddingFunction } from '../embedding/embedding_function'
|
||||
type UpdateArgs, type UpdateSqlArgs
|
||||
} from '../index'
|
||||
import { Query } from '../query'
|
||||
|
||||
import { Vector, Table as ArrowTable } from 'apache-arrow'
|
||||
|
||||
57
node/src/test/embedding/openai.ts
Normal file
57
node/src/test/embedding/openai.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import { assert } from 'chai'
|
||||
|
||||
import { OpenAIEmbeddingFunction } from '../../embedding/openai'
|
||||
import { isEmbeddingFunction } from '../../embedding/embedding_function'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||
const { OpenAIApi } = require('openai')
|
||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||
const { stub } = require('sinon')
|
||||
|
||||
describe('OpenAPIEmbeddings', function () {
|
||||
const stubValue = {
|
||||
data: {
|
||||
data: [
|
||||
{
|
||||
embedding: Array(1536).fill(1.0)
|
||||
},
|
||||
{
|
||||
embedding: Array(1536).fill(2.0)
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
describe('#embed', function () {
|
||||
it('should create vector embeddings', async function () {
|
||||
const openAIStub = stub(OpenAIApi.prototype, 'createEmbedding').returns(stubValue)
|
||||
const f = new OpenAIEmbeddingFunction('text', 'sk-key')
|
||||
const vectors = await f.embed(['abc', 'def'])
|
||||
assert.isTrue(openAIStub.calledOnce)
|
||||
assert.equal(vectors.length, 2)
|
||||
assert.deepEqual(vectors[0], stubValue.data.data[0].embedding)
|
||||
assert.deepEqual(vectors[1], stubValue.data.data[1].embedding)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isEmbeddingFunction', function () {
|
||||
it('should match the isEmbeddingFunction guard', function () {
|
||||
assert.isTrue(isEmbeddingFunction(new OpenAIEmbeddingFunction('text', 'sk-key')))
|
||||
})
|
||||
})
|
||||
})
|
||||
76
node/src/test/io.ts
Normal file
76
node/src/test/io.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// IO tests
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import { assert } from 'chai'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { type ConnectionOptions } from '../index'
|
||||
|
||||
describe('LanceDB S3 client', function () {
|
||||
if (process.env.TEST_S3_BASE_URL != null) {
|
||||
const baseUri = process.env.TEST_S3_BASE_URL
|
||||
it('should have a valid url', async function () {
|
||||
const opts = { uri: `${baseUri}/valid_url` }
|
||||
const table = await createTestDB(opts, 2, 20)
|
||||
const con = await lancedb.connect(opts)
|
||||
assert.equal(con.uri, opts.uri)
|
||||
|
||||
const results = await table.search([0.1, 0.3]).limit(5).execute()
|
||||
assert.equal(results.length, 5)
|
||||
}).timeout(10_000)
|
||||
} else {
|
||||
describe.skip('Skip S3 test', function () {})
|
||||
}
|
||||
|
||||
if (process.env.TEST_S3_BASE_URL != null && process.env.TEST_AWS_ACCESS_KEY_ID != null && process.env.TEST_AWS_SECRET_ACCESS_KEY != null) {
|
||||
const baseUri = process.env.TEST_S3_BASE_URL
|
||||
it('use custom credentials', async function () {
|
||||
const opts: ConnectionOptions = {
|
||||
uri: `${baseUri}/custom_credentials`,
|
||||
awsCredentials: {
|
||||
accessKeyId: process.env.TEST_AWS_ACCESS_KEY_ID as string,
|
||||
secretKey: process.env.TEST_AWS_SECRET_ACCESS_KEY as string
|
||||
}
|
||||
}
|
||||
const table = await createTestDB(opts, 2, 20)
|
||||
console.log(table)
|
||||
const con = await lancedb.connect(opts)
|
||||
console.log(con)
|
||||
assert.equal(con.uri, opts.uri)
|
||||
|
||||
const results = await table.search([0.1, 0.3]).limit(5).execute()
|
||||
assert.equal(results.length, 5)
|
||||
}).timeout(10_000)
|
||||
} else {
|
||||
describe.skip('Skip S3 test', function () {})
|
||||
}
|
||||
})
|
||||
|
||||
async function createTestDB (opts: ConnectionOptions, numDimensions: number = 2, numRows: number = 2): Promise<lancedb.Table> {
|
||||
const con = await lancedb.connect(opts)
|
||||
|
||||
const data = []
|
||||
for (let i = 0; i < numRows; i++) {
|
||||
const vector = []
|
||||
for (let j = 0; j < numDimensions; j++) {
|
||||
vector.push(i + (j * 0.1))
|
||||
}
|
||||
data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector })
|
||||
}
|
||||
|
||||
return await con.createTable('vectors_2', data)
|
||||
}
|
||||
616
node/src/test/test.ts
Normal file
616
node/src/test/test.ts
Normal file
@@ -0,0 +1,616 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import { track } from 'temp'
|
||||
import * as chai from 'chai'
|
||||
import * as chaiAsPromised from 'chai-as-promised'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions, type LocalTable } from '../index'
|
||||
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
|
||||
|
||||
const expect = chai.expect
|
||||
const assert = chai.assert
|
||||
chai.use(chaiAsPromised)
|
||||
|
||||
describe('LanceDB client', function () {
|
||||
describe('when creating a connection to lancedb', function () {
|
||||
it('should have a valid url', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
assert.equal(con.uri, uri)
|
||||
})
|
||||
|
||||
it('should accept an options object', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect({ uri })
|
||||
assert.equal(con.uri, uri)
|
||||
})
|
||||
|
||||
it('should accept custom aws credentials', async function () {
|
||||
const uri = await createTestDB()
|
||||
const awsCredentials: AwsCredentials = {
|
||||
accessKeyId: '',
|
||||
secretKey: ''
|
||||
}
|
||||
const con = await lancedb.connect({ uri, awsCredentials })
|
||||
assert.equal(con.uri, uri)
|
||||
})
|
||||
|
||||
it('should return the existing table names', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('when querying an existing dataset', function () {
|
||||
it('should open a table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(table.name, 'vectors')
|
||||
})
|
||||
|
||||
it('execute a query', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table.search([0.1, 0.3]).execute()
|
||||
|
||||
assert.equal(results.length, 2)
|
||||
assert.equal(results[0].price, 10)
|
||||
const vector = results[0].vector as Float32Array
|
||||
assert.approximately(vector[0], 0.0, 0.2)
|
||||
assert.approximately(vector[0], 0.1, 0.3)
|
||||
})
|
||||
|
||||
it('limits # of results', async function () {
|
||||
const uri = await createTestDB(2, 100)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
let results = await table.search([0.1, 0.3]).limit(1).execute()
|
||||
assert.equal(results.length, 1)
|
||||
assert.equal(results[0].id, 1)
|
||||
|
||||
// there is a default limit if unspecified
|
||||
results = await table.search([0.1, 0.3]).execute()
|
||||
assert.equal(results.length, 10)
|
||||
})
|
||||
|
||||
it('uses a filter / where clause without vector search', async function () {
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
|
||||
const assertResults = (results: Array<Record<string, unknown>>) => {
|
||||
assert.equal(results.length, 50)
|
||||
}
|
||||
|
||||
const uri = await createTestDB(2, 100)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = (await con.openTable('vectors')) as LocalTable
|
||||
let results = await table.filter('id % 2 = 0').execute()
|
||||
assertResults(results)
|
||||
results = await table.where('id % 2 = 0').execute()
|
||||
assertResults(results)
|
||||
})
|
||||
|
||||
it('uses a filter / where clause', async function () {
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
|
||||
const assertResults = (results: Array<Record<string, unknown>>) => {
|
||||
assert.equal(results.length, 1)
|
||||
assert.equal(results[0].id, 2)
|
||||
}
|
||||
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
let results = await table.search([0.1, 0.1]).filter('id == 2').execute()
|
||||
assertResults(results)
|
||||
results = await table.search([0.1, 0.1]).where('id == 2').execute()
|
||||
assertResults(results)
|
||||
})
|
||||
|
||||
it('should correctly process prefilter/postfilter', async function () {
|
||||
const uri = await createTestDB(16, 300)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
|
||||
// post filter should return less than the limit
|
||||
let results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(false).execute()
|
||||
assert.isTrue(results.length < 10)
|
||||
|
||||
// pre filter should return exactly the limit
|
||||
results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(true).execute()
|
||||
assert.isTrue(results.length === 10)
|
||||
})
|
||||
|
||||
it('select only a subset of columns', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table.search([0.1, 0.1]).select(['is_active']).execute()
|
||||
assert.equal(results.length, 2)
|
||||
// vector and _distance are always returned
|
||||
assert.isDefined(results[0].vector)
|
||||
assert.isDefined(results[0]._distance)
|
||||
assert.isDefined(results[0].is_active)
|
||||
|
||||
assert.isUndefined(results[0].id)
|
||||
assert.isUndefined(results[0].name)
|
||||
assert.isUndefined(results[0].price)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when creating a new dataset', function () {
|
||||
it('create an empty table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('id', new Int32()), new Field('name', new Utf8())]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema })
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
it('create a table with a empty data array', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('id', new Int32()), new Field('name', new Utf8())]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema, data: [] })
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
it('create a table from an Arrow Table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const i32s = new Int32Array(new Array<number>(10))
|
||||
const i32 = makeVector(i32s)
|
||||
|
||||
const data = new ArrowTable({ vector: i32 })
|
||||
|
||||
const table = await con.createTable({ name: 'vectors', data })
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.equal(await table.countRows(), 10)
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
it('creates a new table from javascript objects', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [
|
||||
{ id: 1, vector: [0.1, 0.2], price: 10 },
|
||||
{ id: 2, vector: [1.1, 1.2], price: 50 }
|
||||
]
|
||||
|
||||
const tableName = `vectors_${Math.floor(Math.random() * 100)}`
|
||||
const table = await con.createTable(tableName, data)
|
||||
assert.equal(table.name, tableName)
|
||||
assert.equal(await table.countRows(), 2)
|
||||
})
|
||||
|
||||
it('fails to create a new table when the vector column is missing', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [
|
||||
{ id: 1, price: 10 }
|
||||
]
|
||||
|
||||
const create = con.createTable('missing_vector', data)
|
||||
await expect(create).to.be.rejectedWith(Error, 'column \'vector\' is missing')
|
||||
})
|
||||
|
||||
it('use overwrite flag to overwrite existing table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [
|
||||
{ id: 1, vector: [0.1, 0.2], price: 10 },
|
||||
{ id: 2, vector: [1.1, 1.2], price: 50 }
|
||||
]
|
||||
|
||||
const tableName = 'overwrite'
|
||||
await con.createTable(tableName, data, { writeMode: WriteMode.Create })
|
||||
|
||||
const newData = [
|
||||
{ id: 1, vector: [0.1, 0.2], price: 10 },
|
||||
{ id: 2, vector: [1.1, 1.2], price: 50 },
|
||||
{ id: 3, vector: [1.1, 1.2], price: 50 }
|
||||
]
|
||||
|
||||
await expect(con.createTable(tableName, newData)).to.be.rejectedWith(Error, 'already exists')
|
||||
|
||||
const table = await con.createTable(tableName, newData, { writeMode: WriteMode.Overwrite })
|
||||
assert.equal(table.name, tableName)
|
||||
assert.equal(await table.countRows(), 3)
|
||||
})
|
||||
|
||||
it('appends records to an existing table ', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [
|
||||
{ id: 1, vector: [0.1, 0.2], price: 10, name: 'a' },
|
||||
{ id: 2, vector: [1.1, 1.2], price: 50, name: 'b' }
|
||||
]
|
||||
|
||||
const table = await con.createTable('vectors', data)
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
const dataAdd = [
|
||||
{ id: 3, vector: [2.1, 2.2], price: 10, name: 'c' },
|
||||
{ id: 4, vector: [3.1, 3.2], price: 50, name: 'd' }
|
||||
]
|
||||
await table.add(dataAdd)
|
||||
assert.equal(await table.countRows(), 4)
|
||||
})
|
||||
|
||||
it('overwrite all records in a table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
const dataOver = [
|
||||
{ vector: [2.1, 2.2], price: 10, name: 'foo' },
|
||||
{ vector: [3.1, 3.2], price: 50, name: 'bar' }
|
||||
]
|
||||
await table.overwrite(dataOver)
|
||||
assert.equal(await table.countRows(), 2)
|
||||
})
|
||||
|
||||
it('can update records in the table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ where: 'price = 10', valuesSql: { price: '100' } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 11)
|
||||
})
|
||||
|
||||
it('can update the records using a literal value', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ where: 'price = 10', values: { price: 100 } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 11)
|
||||
})
|
||||
|
||||
it('can update every record in the table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ valuesSql: { price: '100' } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 100)
|
||||
})
|
||||
|
||||
it('can delete records from a table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.delete('price = 10')
|
||||
assert.equal(await table.countRows(), 1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when searching an empty dataset', function () {
|
||||
it('should not fail', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema })
|
||||
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||
assert.isEmpty(result)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when searching an empty-after-delete dataset', function () {
|
||||
it('should not fail', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema })
|
||||
await table.add([{ vector: Array(128).fill(0.1) }])
|
||||
// https://github.com/lancedb/lance/issues/1635
|
||||
await table.delete('true')
|
||||
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||
assert.isEmpty(result)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when creating a vector index', function () {
|
||||
it('overwrite all records in a table', async function () {
|
||||
const uri = await createTestDB(32, 300)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
|
||||
}).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow
|
||||
|
||||
it('replace an existing index', async function () {
|
||||
const uri = await createTestDB(16, 300)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
|
||||
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
|
||||
|
||||
// Replace should fail if the index already exists
|
||||
await expect(table.createIndex({
|
||||
type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2, replace: false
|
||||
})
|
||||
).to.be.rejectedWith('LanceError(Index)')
|
||||
|
||||
// Default replace = true
|
||||
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
|
||||
}).timeout(50_000)
|
||||
|
||||
it('it should fail when the column is not a vector', async function () {
|
||||
const uri = await createTestDB(32, 300)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
|
||||
await expect(createIndex).to.be.rejectedWith(/VectorIndex requires the column data type to be fixed size list of float32s/)
|
||||
})
|
||||
|
||||
it('it should fail when the column is not a vector', async function () {
|
||||
const uri = await createTestDB(32, 300)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
|
||||
await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
|
||||
})
|
||||
|
||||
it('should be able to list index and stats', async function () {
|
||||
const uri = await createTestDB(32, 300)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
|
||||
|
||||
const indices = await table.listIndices()
|
||||
expect(indices).to.have.lengthOf(1)
|
||||
expect(indices[0].name).to.equal('vector_idx')
|
||||
expect(indices[0].uuid).to.not.be.equal(undefined)
|
||||
expect(indices[0].columns).to.have.lengthOf(1)
|
||||
expect(indices[0].columns[0]).to.equal('vector')
|
||||
|
||||
const stats = await table.indexStats(indices[0].uuid)
|
||||
expect(stats.numIndexedRows).to.equal(300)
|
||||
expect(stats.numUnindexedRows).to.equal(0)
|
||||
}).timeout(50_000)
|
||||
})
|
||||
|
||||
describe('when using a custom embedding function', function () {
|
||||
class TextEmbedding implements EmbeddingFunction<string> {
|
||||
sourceColumn: string
|
||||
|
||||
constructor (targetColumn: string) {
|
||||
this.sourceColumn = targetColumn
|
||||
}
|
||||
|
||||
_embedding_map = new Map<string, number[]>([
|
||||
['foo', [2.1, 2.2]],
|
||||
['bar', [3.1, 3.2]]
|
||||
])
|
||||
|
||||
async embed (data: string[]): Promise<number[][]> {
|
||||
return data.map(datum => this._embedding_map.get(datum) ?? [0.0, 0.0])
|
||||
}
|
||||
}
|
||||
|
||||
it('should encode the original data into embeddings', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
const embeddings = new TextEmbedding('name')
|
||||
|
||||
const data = [
|
||||
{ price: 10, name: 'foo' },
|
||||
{ price: 50, name: 'bar' }
|
||||
]
|
||||
const table = await con.createTable('vectors', data, embeddings, { writeMode: WriteMode.Create })
|
||||
const results = await table.search('foo').execute()
|
||||
assert.equal(results.length, 2)
|
||||
})
|
||||
|
||||
it('should create embeddings for Arrow Table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
const embeddingFunction = new TextEmbedding('name')
|
||||
|
||||
const names = vectorFromArray(['foo', 'bar'], new Utf8())
|
||||
const data = new ArrowTable({ name: names })
|
||||
|
||||
const table = await con.createTable({ name: 'vectors', data, embeddingFunction })
|
||||
assert.equal(table.name, 'vectors')
|
||||
const results = await table.search('foo').execute()
|
||||
assert.equal(results.length, 2)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Remote LanceDB client', function () {
|
||||
describe('when the server is not reachable', function () {
|
||||
it('produces a network error', async function () {
|
||||
const con = await lancedb.connect({
|
||||
uri: 'db://test-1234',
|
||||
region: 'asdfasfasfdf',
|
||||
apiKey: 'some-api-key'
|
||||
})
|
||||
|
||||
// GET
|
||||
try {
|
||||
await con.tableNames()
|
||||
} catch (err) {
|
||||
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||
}
|
||||
|
||||
// POST
|
||||
try {
|
||||
await con.createTable({ name: 'vectors', schema: new Schema([]) })
|
||||
} catch (err) {
|
||||
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||
}
|
||||
|
||||
// Search
|
||||
const table = await con.openTable('vectors')
|
||||
try {
|
||||
await table.search([0.1, 0.3]).execute()
|
||||
} catch (err) {
|
||||
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Query object', function () {
|
||||
it('sets custom parameters', async function () {
|
||||
const query = new Query([0.1, 0.3])
|
||||
.limit(1)
|
||||
.metricType(MetricType.Cosine)
|
||||
.refineFactor(100)
|
||||
.select(['a', 'b'])
|
||||
.nprobes(20) as Record<string, any>
|
||||
assert.equal(query._limit, 1)
|
||||
assert.equal(query._metricType, MetricType.Cosine)
|
||||
assert.equal(query._refineFactor, 100)
|
||||
assert.equal(query._nprobes, 20)
|
||||
assert.deepEqual(query._select, ['a', 'b'])
|
||||
})
|
||||
})
|
||||
|
||||
async function createTestDB (numDimensions: number = 2, numRows: number = 2): Promise<string> {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = []
|
||||
for (let i = 0; i < numRows; i++) {
|
||||
const vector = []
|
||||
for (let j = 0; j < numDimensions; j++) {
|
||||
vector.push(i + (j * 0.1))
|
||||
}
|
||||
data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector })
|
||||
}
|
||||
|
||||
await con.createTable('vectors', data)
|
||||
return dir
|
||||
}
|
||||
|
||||
describe('Drop table', function () {
|
||||
it('drop a table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [
|
||||
{ price: 10, name: 'foo', vector: [1, 2, 3] },
|
||||
{ price: 50, name: 'bar', vector: [4, 5, 6] }
|
||||
]
|
||||
await con.createTable('t1', data)
|
||||
await con.createTable('t2', data)
|
||||
|
||||
assert.deepEqual(await con.tableNames(), ['t1', 't2'])
|
||||
|
||||
await con.dropTable('t1')
|
||||
assert.deepEqual(await con.tableNames(), ['t2'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('WriteOptions', function () {
|
||||
context('#isWriteOptions', function () {
|
||||
it('should not match empty object', function () {
|
||||
assert.equal(isWriteOptions({}), false)
|
||||
})
|
||||
it('should match write options', function () {
|
||||
assert.equal(isWriteOptions({ writeMode: WriteMode.Create }), true)
|
||||
})
|
||||
it('should match undefined write mode', function () {
|
||||
assert.equal(isWriteOptions({ writeMode: undefined }), true)
|
||||
})
|
||||
it('should match default write options', function () {
|
||||
assert.equal(isWriteOptions(new DefaultWriteOptions()), true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Compact and cleanup', function () {
|
||||
it('can cleanup after compaction', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [
|
||||
{ price: 10, name: 'foo', vector: [1, 2, 3] },
|
||||
{ price: 50, name: 'bar', vector: [4, 5, 6] }
|
||||
]
|
||||
const table = await con.createTable('t1', data) as LocalTable
|
||||
|
||||
const newData = [
|
||||
{ price: 30, name: 'baz', vector: [7, 8, 9] }
|
||||
]
|
||||
await table.add(newData)
|
||||
|
||||
const compactionMetrics = await table.compactFiles({
|
||||
numThreads: 2
|
||||
})
|
||||
assert.equal(compactionMetrics.fragmentsRemoved, 2)
|
||||
assert.equal(compactionMetrics.fragmentsAdded, 1)
|
||||
assert.equal(await table.countRows(), 3)
|
||||
|
||||
await table.cleanupOldVersions()
|
||||
assert.equal(await table.countRows(), 3)
|
||||
|
||||
// should have no effect, but this validates the arguments are parsed.
|
||||
await table.compactFiles({
|
||||
targetRowsPerFragment: 102410,
|
||||
maxRowsPerGroup: 1024,
|
||||
materializeDeletions: true,
|
||||
materializeDeletionsThreshold: 0.5,
|
||||
numThreads: 2
|
||||
})
|
||||
|
||||
const cleanupMetrics = await table.cleanupOldVersions(0, true)
|
||||
assert.isAtLeast(cleanupMetrics.bytesRemoved, 1)
|
||||
assert.isAtLeast(cleanupMetrics.oldVersions, 1)
|
||||
assert.equal(await table.countRows(), 3)
|
||||
})
|
||||
})
|
||||
45
node/src/test/util.ts
Normal file
45
node/src/test/util.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { toSQL } from '../util'
|
||||
import * as chai from 'chai'
|
||||
|
||||
const expect = chai.expect
|
||||
|
||||
describe('toSQL', function () {
|
||||
it('should turn string to SQL expression', function () {
|
||||
expect(toSQL('foo')).to.equal("'foo'")
|
||||
})
|
||||
|
||||
it('should turn number to SQL expression', function () {
|
||||
expect(toSQL(123)).to.equal('123')
|
||||
})
|
||||
|
||||
it('should turn boolean to SQL expression', function () {
|
||||
expect(toSQL(true)).to.equal('TRUE')
|
||||
})
|
||||
|
||||
it('should turn null to SQL expression', function () {
|
||||
expect(toSQL(null)).to.equal('NULL')
|
||||
})
|
||||
|
||||
it('should turn Date to SQL expression', function () {
|
||||
const date = new Date('05 October 2011 14:48 UTC')
|
||||
expect(toSQL(date)).to.equal("'2011-10-05T14:48:00.000Z'")
|
||||
})
|
||||
|
||||
it('should turn array to SQL expression', function () {
|
||||
expect(toSQL(['foo', 'bar', true, 1])).to.equal("['foo', 'bar', TRUE, 1]")
|
||||
})
|
||||
})
|
||||
@@ -1,375 +0,0 @@
|
||||
|
||||
import {
|
||||
type Schema,
|
||||
type Table as ArrowTable
|
||||
} from 'apache-arrow'
|
||||
|
||||
import { type Literal } from './util'
|
||||
import type { EmbeddingFunction } from './embedding/embedding_function'
|
||||
import { type Query } from './query'
|
||||
|
||||
export interface AwsCredentials {
|
||||
accessKeyId: string
|
||||
|
||||
secretKey: string
|
||||
|
||||
sessionToken?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Write options when creating a Table.
|
||||
*/
|
||||
export interface WriteOptions {
|
||||
/** A {@link WriteMode} to use on this operation */
|
||||
writeMode?: WriteMode
|
||||
}
|
||||
|
||||
/**
|
||||
* Write mode for writing a table.
|
||||
*/
|
||||
export enum WriteMode {
|
||||
/** Create a new {@link Table}. */
|
||||
Create = 'create',
|
||||
/** Overwrite the existing {@link Table} if presented. */
|
||||
Overwrite = 'overwrite',
|
||||
/** Append new data to the table. */
|
||||
Append = 'append'
|
||||
}
|
||||
|
||||
/**
|
||||
* A LanceDB Connection that allows you to open tables and create new ones.
|
||||
*
|
||||
* Connection could be local against filesystem or remote against a server.
|
||||
*/
|
||||
export interface Connection {
|
||||
uri: string
|
||||
|
||||
tableNames(): Promise<string[]>
|
||||
|
||||
/**
|
||||
* Open a table in the database.
|
||||
*
|
||||
* @param name The name of the table.
|
||||
* @param embeddings An embedding function to use on this table
|
||||
*/
|
||||
openTable<T>(name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
|
||||
/**
|
||||
* Creates a new Table, optionally initializing it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Array of Records to be inserted into the table
|
||||
* @param schema - An Arrow Schema that describe this table columns
|
||||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table
|
||||
* @param {WriteOptions} writeOptions - The write options to use when creating the table.
|
||||
*/
|
||||
createTable<T> ({ name, data, schema, embeddingFunction, writeOptions }: CreateTableOptions<T>): Promise<Table<T>>
|
||||
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Non-empty Array of Records to be inserted into the table
|
||||
*/
|
||||
createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table>
|
||||
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Non-empty Array of Records to be inserted into the table
|
||||
* @param {WriteOptions} options - The write options to use when creating the table.
|
||||
*/
|
||||
createTable (name: string, data: Array<Record<string, unknown>>, options: WriteOptions): Promise<Table>
|
||||
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Non-empty Array of Records to be inserted into the table
|
||||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table
|
||||
*/
|
||||
createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Non-empty Array of Records to be inserted into the table
|
||||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table
|
||||
* @param {WriteOptions} options - The write options to use when creating the table.
|
||||
*/
|
||||
createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>, options: WriteOptions): Promise<Table<T>>
|
||||
|
||||
/**
|
||||
* Drop an existing table.
|
||||
* @param name The name of the table to drop.
|
||||
*/
|
||||
dropTable(name: string): Promise<void>
|
||||
|
||||
}
|
||||
|
||||
export interface CreateTableOptions<T> {
|
||||
// Name of Table
|
||||
name: string
|
||||
|
||||
// Data to insert into the Table
|
||||
data?: Array<Record<string, unknown>> | ArrowTable | undefined
|
||||
|
||||
// Optional Arrow Schema for this table
|
||||
schema?: Schema | undefined
|
||||
|
||||
// Optional embedding function used to create embeddings
|
||||
embeddingFunction?: EmbeddingFunction<T> | undefined
|
||||
|
||||
// WriteOptions for this operation
|
||||
writeOptions?: WriteOptions | undefined
|
||||
}
|
||||
|
||||
export interface ConnectionOptions {
|
||||
uri: string
|
||||
|
||||
awsCredentials?: AwsCredentials
|
||||
|
||||
awsRegion?: string
|
||||
|
||||
// API key for the remote connections
|
||||
apiKey?: string
|
||||
// Region to connect
|
||||
region?: string
|
||||
|
||||
// override the host for the remote connections
|
||||
hostOverride?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Distance metrics type.
|
||||
*/
|
||||
export enum MetricType {
|
||||
/**
|
||||
* Euclidean distance
|
||||
*/
|
||||
L2 = 'l2',
|
||||
|
||||
/**
|
||||
* Cosine distance
|
||||
*/
|
||||
Cosine = 'cosine',
|
||||
|
||||
/**
|
||||
* Dot product
|
||||
*/
|
||||
Dot = 'dot'
|
||||
}
|
||||
|
||||
/// Config to build IVF_PQ index.
|
||||
///
|
||||
export interface IvfPQIndexConfig {
|
||||
/**
|
||||
* The column to be indexed
|
||||
*/
|
||||
column?: string
|
||||
|
||||
/**
|
||||
* A unique name for the index
|
||||
*/
|
||||
index_name?: string
|
||||
|
||||
/**
|
||||
* Metric type, L2 or Cosine
|
||||
*/
|
||||
metric_type?: MetricType
|
||||
|
||||
/**
|
||||
* The number of partitions this index
|
||||
*/
|
||||
num_partitions?: number
|
||||
|
||||
/**
|
||||
* The max number of iterations for kmeans training.
|
||||
*/
|
||||
max_iters?: number
|
||||
|
||||
/**
|
||||
* Train as optimized product quantization.
|
||||
*/
|
||||
use_opq?: boolean
|
||||
|
||||
/**
|
||||
* Number of subvectors to build PQ code
|
||||
*/
|
||||
num_sub_vectors?: number
|
||||
/**
|
||||
* The number of bits to present one PQ centroid.
|
||||
*/
|
||||
num_bits?: number
|
||||
|
||||
/**
|
||||
* Max number of iterations to train OPQ, if `use_opq` is true.
|
||||
*/
|
||||
max_opq_iters?: number
|
||||
|
||||
/**
|
||||
* Replace an existing index with the same name if it exists.
|
||||
*/
|
||||
replace?: boolean
|
||||
|
||||
type: 'ivf_pq'
|
||||
}
|
||||
|
||||
export type VectorIndexParams = IvfPQIndexConfig
|
||||
|
||||
/**
|
||||
* A LanceDB Table is the collection of Records. Each Record has one or more vector fields.
|
||||
*/
|
||||
export interface Table<T = number[]> {
|
||||
name: string
|
||||
|
||||
/**
|
||||
* Creates a search query to find the nearest neighbors of the given search term
|
||||
* @param query The query search term
|
||||
*/
|
||||
search: (query: T) => Query<T>
|
||||
|
||||
/**
|
||||
* Insert records into this Table.
|
||||
*
|
||||
* @param data Records to be inserted into the Table
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
add: (data: Array<Record<string, unknown>>) => Promise<number>
|
||||
|
||||
/**
|
||||
* Insert records into this Table, replacing its contents.
|
||||
*
|
||||
* @param data Records to be inserted into the Table
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
overwrite: (data: Array<Record<string, unknown>>) => Promise<number>
|
||||
|
||||
/**
|
||||
* Create an ANN index on this Table vector index.
|
||||
*
|
||||
* @param indexParams The parameters of this Index, @see VectorIndexParams.
|
||||
*/
|
||||
createIndex: (indexParams: VectorIndexParams) => Promise<any>
|
||||
|
||||
/**
|
||||
* Returns the number of rows in this table.
|
||||
*/
|
||||
countRows: () => Promise<number>
|
||||
|
||||
/**
|
||||
* Delete rows from this table.
|
||||
*
|
||||
* This can be used to delete a single row, many rows, all rows, or
|
||||
* sometimes no rows (if your predicate matches nothing).
|
||||
*
|
||||
* @param filter A filter in the same format used by a sql WHERE clause. The
|
||||
* filter must not be empty.
|
||||
*
|
||||
* @examples
|
||||
*
|
||||
* ```ts
|
||||
* const con = await lancedb.connect("./.lancedb")
|
||||
* const data = [
|
||||
* {id: 1, vector: [1, 2]},
|
||||
* {id: 2, vector: [3, 4]},
|
||||
* {id: 3, vector: [5, 6]},
|
||||
* ];
|
||||
* const tbl = await con.createTable("my_table", data)
|
||||
* await tbl.delete("id = 2")
|
||||
* await tbl.countRows() // Returns 2
|
||||
* ```
|
||||
*
|
||||
* If you have a list of values to delete, you can combine them into a
|
||||
* stringified list and use the `IN` operator:
|
||||
*
|
||||
* ```ts
|
||||
* const to_remove = [1, 5];
|
||||
* await tbl.delete(`id IN (${to_remove.join(",")})`)
|
||||
* await tbl.countRows() // Returns 1
|
||||
* ```
|
||||
*/
|
||||
delete: (filter: string) => Promise<void>
|
||||
|
||||
/**
|
||||
* Update rows in this table.
|
||||
*
|
||||
* This can be used to update a single row, many rows, all rows, or
|
||||
* sometimes no rows (if your predicate matches nothing).
|
||||
*
|
||||
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
|
||||
*
|
||||
* @examples
|
||||
*
|
||||
* ```ts
|
||||
* const con = await lancedb.connect("./.lancedb")
|
||||
* const data = [
|
||||
* {id: 1, vector: [3, 3], name: 'Ye'},
|
||||
* {id: 2, vector: [4, 4], name: 'Mike'},
|
||||
* ];
|
||||
* const tbl = await con.createTable("my_table", data)
|
||||
*
|
||||
* await tbl.update({
|
||||
* filter: "id = 2",
|
||||
* updates: { vector: [2, 2], name: "Michael" },
|
||||
* })
|
||||
*
|
||||
* let results = await tbl.search([1, 1]).execute();
|
||||
* // Returns [
|
||||
* // {id: 2, vector: [2, 2], name: 'Michael'}
|
||||
* // {id: 1, vector: [3, 3], name: 'Ye'}
|
||||
* // ]
|
||||
* ```
|
||||
*
|
||||
*/
|
||||
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
|
||||
|
||||
/**
|
||||
* List the indicies on this table.
|
||||
*/
|
||||
listIndices: () => Promise<VectorIndex[]>
|
||||
|
||||
/**
|
||||
* Get statistics about an index.
|
||||
*/
|
||||
indexStats: (indexUuid: string) => Promise<IndexStats>
|
||||
}
|
||||
export interface UpdateArgs {
|
||||
/**
|
||||
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
|
||||
* in which case all rows will be updated.
|
||||
*/
|
||||
where?: string
|
||||
|
||||
/**
|
||||
* A key-value map of updates. The keys are the column names, and the values are the
|
||||
* new values to set
|
||||
*/
|
||||
values: Record<string, Literal>
|
||||
}
|
||||
|
||||
export interface UpdateSqlArgs {
|
||||
/**
|
||||
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
|
||||
* in which case all rows will be updated.
|
||||
*/
|
||||
where?: string
|
||||
|
||||
/**
|
||||
* A key-value map of updates. The keys are the column names, and the values are the
|
||||
* new values to set as SQL expressions.
|
||||
*/
|
||||
valuesSql: Record<string, string>
|
||||
}
|
||||
|
||||
export interface VectorIndex {
|
||||
columns: string[]
|
||||
name: string
|
||||
uuid: string
|
||||
}
|
||||
|
||||
export interface IndexStats {
|
||||
numIndexedRows: number | null
|
||||
numUnindexedRows: number | null
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.3.4
|
||||
current_version = 0.3.5
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -17,7 +17,7 @@ import inspect
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
@@ -30,7 +30,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .pydantic import LanceModel
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
from .util import fs_from_uri, safe_import_pandas, value_to_sql
|
||||
from .utils.events import register_event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -913,30 +913,35 @@ class LanceTable(Table):
|
||||
def delete(self, where: str):
|
||||
self._dataset.delete(where)
|
||||
|
||||
def update(self, where: str, values: dict):
|
||||
def update(
|
||||
self,
|
||||
where: Optional[str] = None,
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
EXPERIMENTAL: Update rows in the table (not threadsafe).
|
||||
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
where: str, optional
|
||||
The SQL where clause to use when updating rows. For example, 'x = 2'
|
||||
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||
values: dict
|
||||
values: dict, optional
|
||||
The values to update. The keys are the column names and the values
|
||||
are the values to set.
|
||||
values_sql: dict, optional
|
||||
The values to update, expressed as SQL expression strings. These can
|
||||
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||
the x column by 1.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... ]
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
@@ -952,18 +957,15 @@ class LanceTable(Table):
|
||||
2 2 [10.0, 10.0]
|
||||
|
||||
"""
|
||||
orig_data = self._dataset.to_table(filter=where).combine_chunks()
|
||||
if len(orig_data) == 0:
|
||||
return
|
||||
for col, val in values.items():
|
||||
i = orig_data.column_names.index(col)
|
||||
if i < 0:
|
||||
raise ValueError(f"Column {col} does not exist")
|
||||
orig_data = orig_data.set_column(
|
||||
i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
|
||||
)
|
||||
self.delete(where)
|
||||
self.add(orig_data, mode="append")
|
||||
if values is not None and values_sql is not None:
|
||||
raise ValueError("Only one of values or values_sql can be provided")
|
||||
if values is None and values_sql is None:
|
||||
raise ValueError("Either values or values_sql must be provided")
|
||||
|
||||
if values is not None:
|
||||
values_sql = {k: value_to_sql(v) for k, v in values.items()}
|
||||
|
||||
self.to_lance().update(values_sql, where)
|
||||
self._reset_dataset()
|
||||
register_event("update")
|
||||
|
||||
|
||||
@@ -12,9 +12,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from datetime import date, datetime
|
||||
from functools import singledispatch
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import pyarrow.fs as pa_fs
|
||||
|
||||
|
||||
@@ -88,3 +91,53 @@ def safe_import_pandas():
|
||||
return pd
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
@singledispatch
|
||||
def value_to_sql(value):
|
||||
raise NotImplementedError("SQL conversion is not implemented for this type")
|
||||
|
||||
|
||||
@value_to_sql.register(str)
|
||||
def _(value: str):
|
||||
return f"'{value}'"
|
||||
|
||||
|
||||
@value_to_sql.register(int)
|
||||
def _(value: int):
|
||||
return str(value)
|
||||
|
||||
|
||||
@value_to_sql.register(float)
|
||||
def _(value: float):
|
||||
return str(value)
|
||||
|
||||
|
||||
@value_to_sql.register(bool)
|
||||
def _(value: bool):
|
||||
return str(value).upper()
|
||||
|
||||
|
||||
@value_to_sql.register(type(None))
|
||||
def _(value: type(None)):
|
||||
return "NULL"
|
||||
|
||||
|
||||
@value_to_sql.register(datetime)
|
||||
def _(value: datetime):
|
||||
return f"'{value.isoformat()}'"
|
||||
|
||||
|
||||
@value_to_sql.register(date)
|
||||
def _(value: date):
|
||||
return f"'{value.isoformat()}'"
|
||||
|
||||
|
||||
@value_to_sql.register(list)
|
||||
def _(value: list):
|
||||
return "[" + ", ".join(map(value_to_sql, value)) + "]"
|
||||
|
||||
|
||||
@value_to_sql.register(np.ndarray)
|
||||
def _(value: np.ndarray):
|
||||
return value_to_sql(value.tolist())
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.3.4"
|
||||
version = "0.3.5"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.8.17",
|
||||
"pylance==0.8.21",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.1.0",
|
||||
"tqdm>=4.27.0",
|
||||
"aiohttp",
|
||||
"pydantic>=1.10",
|
||||
"attrs>=21.3.0",
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from datetime import timedelta
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
@@ -348,14 +348,79 @@ def test_update(db):
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 2
|
||||
table.update(where="id=0", values={"vector": [1.1, 1.1]})
|
||||
assert len(table.list_versions()) == 4
|
||||
assert table.version == 4
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table) == 2
|
||||
v = table.to_arrow()["vector"].combine_chunks()
|
||||
v = v.values.to_numpy().reshape(2, 2)
|
||||
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
|
||||
|
||||
|
||||
def test_update_types(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[
|
||||
{
|
||||
"id": 0,
|
||||
"str": "foo",
|
||||
"float": 1.1,
|
||||
"timestamp": datetime(2021, 1, 1),
|
||||
"date": date(2021, 1, 1),
|
||||
"vector1": [1.0, 0.0],
|
||||
"vector2": [1.0, 1.0],
|
||||
}
|
||||
],
|
||||
)
|
||||
# Update with SQL
|
||||
table.update(
|
||||
values_sql=dict(
|
||||
id="1",
|
||||
str="'bar'",
|
||||
float="2.2",
|
||||
timestamp="TIMESTAMP '2021-01-02 00:00:00'",
|
||||
date="DATE '2021-01-02'",
|
||||
vector1="[2.0, 2.0]",
|
||||
vector2="[3.0, 3.0]",
|
||||
)
|
||||
)
|
||||
actual = table.to_arrow().to_pylist()[0]
|
||||
expected = dict(
|
||||
id=1,
|
||||
str="bar",
|
||||
float=2.2,
|
||||
timestamp=datetime(2021, 1, 2),
|
||||
date=date(2021, 1, 2),
|
||||
vector1=[2.0, 2.0],
|
||||
vector2=[3.0, 3.0],
|
||||
)
|
||||
assert actual == expected
|
||||
|
||||
# Update with values
|
||||
table.update(
|
||||
values=dict(
|
||||
id=2,
|
||||
str="baz",
|
||||
float=3.3,
|
||||
timestamp=datetime(2021, 1, 3),
|
||||
date=date(2021, 1, 3),
|
||||
vector1=[3.0, 3.0],
|
||||
vector2=np.array([4.0, 4.0]),
|
||||
)
|
||||
)
|
||||
actual = table.to_arrow().to_pylist()[0]
|
||||
expected = dict(
|
||||
id=2,
|
||||
str="baz",
|
||||
float=3.3,
|
||||
timestamp=datetime(2021, 1, 3),
|
||||
date=date(2021, 1, 3),
|
||||
vector1=[3.0, 3.0],
|
||||
vector2=[4.0, 4.0],
|
||||
)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
|
||||
Reference in New Issue
Block a user