feat: change create table to accept Arrow table (#845)

This commit is contained in:
Lei Xu
2024-01-23 13:25:15 -08:00
committed by Weston Pace
parent 5ecbf971e2
commit 65c1d8bc4c
5 changed files with 586 additions and 160 deletions

View File

@@ -16,7 +16,8 @@ import { type Schema, Table as ArrowTable, tableFromIPC } from 'apache-arrow'
import {
createEmptyTable,
fromRecordsToBuffer,
fromTableToBuffer
fromTableToBuffer,
makeArrowTable
} from './arrow'
import type { EmbeddingFunction } from './embedding/embedding_function'
import { RemoteConnection } from './remote'
@@ -223,7 +224,7 @@ export interface Connection {
*/
createTable(
name: string,
data: Array<Record<string, unknown>>
data: Array<Record<string, unknown>> | ArrowTable
): Promise<Table>
/**
@@ -235,7 +236,7 @@ export interface Connection {
*/
createTable(
name: string,
data: Array<Record<string, unknown>>,
data: Array<Record<string, unknown>> | ArrowTable,
options: WriteOptions
): Promise<Table>
@@ -248,7 +249,7 @@ export interface Connection {
*/
createTable<T>(
name: string,
data: Array<Record<string, unknown>>,
data: Array<Record<string, unknown>> | ArrowTable,
embeddings: EmbeddingFunction<T>
): Promise<Table<T>>
/**
@@ -261,7 +262,7 @@ export interface Connection {
*/
createTable<T>(
name: string,
data: Array<Record<string, unknown>>,
data: Array<Record<string, unknown>> | ArrowTable,
embeddings: EmbeddingFunction<T>,
options: WriteOptions
): Promise<Table<T>>
@@ -291,7 +292,7 @@ export interface Table<T = number[]> {
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
add: (data: Array<Record<string, unknown>>) => Promise<number>
add: (data: Array<Record<string, unknown>> | ArrowTable) => Promise<number>
/**
* Insert records into this Table, replacing its contents.
@@ -299,7 +300,9 @@ export interface Table<T = number[]> {
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
overwrite: (data: Array<Record<string, unknown>>) => Promise<number>
overwrite: (
data: Array<Record<string, unknown>> | ArrowTable
) => Promise<number>
/**
* Create an ANN index on this Table vector index.
@@ -544,7 +547,7 @@ export class LocalConnection implements Connection {
async createTable<T>(
name: string | CreateTableOptions<T>,
data?: Array<Record<string, unknown>>,
data?: Array<Record<string, unknown>> | ArrowTable,
optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>,
opt?: WriteOptions
): Promise<Table<T>> {
@@ -696,12 +699,20 @@ export class LocalTable<T = number[]> implements Table<T> {
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
async add (
data: Array<Record<string, unknown>> | ArrowTable
): Promise<number> {
const schema = await this.schema
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, { schema })
}
return tableAdd
.call(
this._tbl,
await fromRecordsToBuffer(data, this._embeddings, schema),
await fromTableToBuffer(tbl, this._embeddings, schema),
WriteMode.Append.toString(),
...getAwsArgs(this._options())
)
@@ -716,11 +727,19 @@ export class LocalTable<T = number[]> implements Table<T> {
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
async overwrite (
data: Array<Record<string, unknown>> | ArrowTable
): Promise<number> {
let buffer: Buffer
if (data instanceof ArrowTable) {
buffer = await fromTableToBuffer(data, this._embeddings)
} else {
buffer = await fromRecordsToBuffer(data, this._embeddings)
}
return tableAdd
.call(
this._tbl,
await fromRecordsToBuffer(data, this._embeddings),
buffer,
WriteMode.Overwrite.toString(),
...getAwsArgs(this._options())
)

View File

@@ -129,7 +129,8 @@ export class Query<T = number[]> {
const newObject: Record<string, unknown> = {}
Object.keys(entry).forEach((key: string) => {
if (entry[key] instanceof Vector) {
newObject[key] = (entry[key] as Vector).toArray()
// toJSON() returns f16 array correctly
newObject[key] = (entry[key] as Vector).toJSON()
} else {
newObject[key] = entry[key]
}

View File

@@ -13,18 +13,29 @@
// limitations under the License.
import {
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
type EmbeddingFunction,
type Table,
type VectorIndexParams,
type Connection,
type ConnectionOptions,
type CreateTableOptions,
type VectorIndex,
type WriteOptions,
type IndexStats,
type UpdateArgs, type UpdateSqlArgs
type UpdateArgs,
type UpdateSqlArgs,
makeArrowTable
} from '../index'
import { Query } from '../query'
import { Vector, Table as ArrowTable } from 'apache-arrow'
import { HttpLancedbClient } from './client'
import { isEmbeddingFunction } from '../embedding/embedding_function'
import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
import {
createEmptyTable,
fromRecordsToStreamBuffer,
fromTableToStreamBuffer
} from '../arrow'
import { toSQL } from '../util'
/**
@@ -54,7 +65,11 @@ export class RemoteConnection implements Connection {
} else {
server = opts.hostOverride
}
this._client = new HttpLancedbClient(server, opts.apiKey, opts.hostOverride === undefined ? undefined : this._dbName)
this._client = new HttpLancedbClient(
server,
opts.apiKey,
opts.hostOverride === undefined ? undefined : this._dbName
)
}
get uri (): string {
@@ -62,14 +77,26 @@ export class RemoteConnection implements Connection {
return 'db://' + this._client.uri
}
async tableNames (pageToken: string = '', limit: number = 10): Promise<string[]> {
const response = await this._client.get('/v1/table/', { limit, page_token: pageToken })
async tableNames (
pageToken: string = '',
limit: number = 10
): Promise<string[]> {
const response = await this._client.get('/v1/table/', {
limit,
page_token: pageToken
})
return response.data.tables
}
async openTable (name: string): Promise<Table>
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
async openTable<T>(
name: string,
embeddings: EmbeddingFunction<T>
): Promise<Table<T>>
async openTable<T>(
name: string,
embeddings?: EmbeddingFunction<T>
): Promise<Table<T>> {
if (embeddings !== undefined) {
return new RemoteTable(this._client, name, embeddings)
} else {
@@ -77,13 +104,21 @@ export class RemoteConnection implements Connection {
}
}
async createTable<T> (nameOrOpts: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
async createTable<T>(
nameOrOpts: string | CreateTableOptions<T>,
data?: Array<Record<string, unknown>> | ArrowTable,
optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>,
opt?: WriteOptions
): Promise<Table<T>> {
// Logic copied from LocatlConnection, refactor these to a base class + connectionImpl pattern
let schema
let embeddings: undefined | EmbeddingFunction<T>
let tableName: string
if (typeof nameOrOpts === 'string') {
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) {
if (
optsOrEmbedding !== undefined &&
isEmbeddingFunction(optsOrEmbedding)
) {
embeddings = optsOrEmbedding
}
tableName = nameOrOpts
@@ -95,14 +130,16 @@ export class RemoteConnection implements Connection {
let buffer: Buffer
function isEmpty (data: Array<Record<string, unknown>> | ArrowTable<any>): boolean {
function isEmpty (
data: Array<Record<string, unknown>> | ArrowTable<any>
): boolean {
if (data instanceof ArrowTable) {
return data.data.length === 0
return data.numRows === 0
}
return data.length === 0
}
if ((data === undefined) || isEmpty(data)) {
if (data === undefined || isEmpty(data)) {
if (schema === undefined) {
throw new Error('Either data or schema needs to defined')
}
@@ -121,9 +158,11 @@ export class RemoteConnection implements Connection {
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
if (embeddings === undefined) {
@@ -139,8 +178,12 @@ export class RemoteConnection implements Connection {
}
export class RemoteQuery<T = number[]> extends Query<T> {
constructor (query: T, private readonly _client: HttpLancedbClient,
private readonly _name: string, embeddings?: EmbeddingFunction<T>) {
constructor (
query: T,
private readonly _client: HttpLancedbClient,
private readonly _name: string,
embeddings?: EmbeddingFunction<T>
) {
super(query, undefined, embeddings)
}
@@ -189,8 +232,16 @@ export class RemoteTable<T = number[]> implements Table<T> {
private readonly _name: string
constructor (client: HttpLancedbClient, name: string)
constructor (client: HttpLancedbClient, name: string, embeddings: EmbeddingFunction<T>)
constructor (client: HttpLancedbClient, name: string, embeddings?: EmbeddingFunction<T>) {
constructor (
client: HttpLancedbClient,
name: string,
embeddings: EmbeddingFunction<T>
)
constructor (
client: HttpLancedbClient,
name: string,
embeddings?: EmbeddingFunction<T>
) {
this._client = client
this._name = name
this._embeddings = embeddings
@@ -201,22 +252,33 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
get schema (): Promise<any> {
return this._client.post(`/v1/table/${this._name}/describe/`).then(res => {
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
}
return res.data?.schema
})
return this._client
.post(`/v1/table/${this._name}/describe/`)
.then((res) => {
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
return res.data?.schema
})
}
search (query: T): Query<T> {
return new RemoteQuery(query, this._client, this._name)//, this._embeddings_new)
return new RemoteQuery(query, this._client, this._name) //, this._embeddings_new)
}
async add (data: Array<Record<string, unknown>>): Promise<number> {
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, await this.schema)
}
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/insert/`,
buffer,
@@ -226,15 +288,23 @@ export class RemoteTable<T = number[]> implements Table<T> {
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
return data.length
return tbl.numRows
}
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
async overwrite (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data)
}
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/insert/`,
buffer,
@@ -244,11 +314,13 @@ export class RemoteTable<T = number[]> implements Table<T> {
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
return data.length
return tbl.numRows
}
async createIndex (indexParams: VectorIndexParams): Promise<void> {
@@ -280,11 +352,16 @@ export class RemoteTable<T = number[]> implements Table<T> {
metric_type: metricType,
index_cache_size: indexCacheSize
}
const res = await this._client.post(`/v1/table/${this._name}/create_index/`, data)
const res = await this._client.post(
`/v1/table/${this._name}/create_index/`,
data
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
}
@@ -298,7 +375,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
async delete (filter: string): Promise<void> {
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
await this._client.post(`/v1/table/${this._name}/delete/`, {
predicate: filter
})
}
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
@@ -322,7 +401,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
async listIndices (): Promise<VectorIndex[]> {
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
const results = await this._client.post(
`/v1/table/${this._name}/index/list/`
)
return results.data.indexes?.map((index: any) => ({
columns: index.columns,
name: index.index_name,
@@ -331,7 +412,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
async indexStats (indexUuid: string): Promise<IndexStats> {
const results = await this._client.post(`/v1/table/${this._name}/index/${indexUuid}/stats/`)
const results = await this._client.post(
`/v1/table/${this._name}/index/${indexUuid}/stats/`
)
return {
numIndexedRows: results.data.num_indexed_rows,
numUnindexedRows: results.data.num_unindexed_rows

View File

@@ -18,8 +18,28 @@ import * as chai from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import * as lancedb from '../index'
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions, type LocalTable } from '../index'
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
import {
type AwsCredentials,
type EmbeddingFunction,
MetricType,
Query,
WriteMode,
DefaultWriteOptions,
isWriteOptions,
type LocalTable
} from '../index'
import {
FixedSizeList,
Field,
Int32,
makeVector,
Schema,
Utf8,
Table as ArrowTable,
vectorFromArray,
Float32,
Float16
} from 'apache-arrow'
const expect = chai.expect
const assert = chai.assert
@@ -45,7 +65,10 @@ describe('LanceDB client', function () {
accessKeyId: '',
secretKey: ''
}
const con = await lancedb.connect({ uri, awsCredentials })
const con = await lancedb.connect({
uri,
awsCredentials
})
assert.equal(con.uri, uri)
})
@@ -125,13 +148,29 @@ describe('LanceDB client', function () {
const uri = await createTestDB(16, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
// post filter should return less than the limit
let results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(false).execute()
let results = await table
.search(new Array(16).fill(0.1))
.limit(10)
.filter('id >= 10')
.prefilter(false)
.execute()
assert.isTrue(results.length < 10)
// pre filter should return exactly the limit
results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(true).execute()
results = await table
.search(new Array(16).fill(0.1))
.limit(10)
.filter('id >= 10')
.prefilter(true)
.execute()
assert.isTrue(results.length === 10)
})
@@ -142,7 +181,12 @@ describe('LanceDB client', function () {
await table.createScalarIndex('id', true)
// Prefiltering should still work the same
const results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(true).execute()
const results = await table
.search(new Array(16).fill(0.1))
.limit(10)
.filter('id >= 10')
.prefilter(true)
.execute()
assert.isTrue(results.length === 10)
})
@@ -150,7 +194,10 @@ describe('LanceDB client', function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.1]).select(['is_active']).execute()
const results = await table
.search([0.1, 0.1])
.select(['is_active'])
.execute()
assert.equal(results.length, 2)
// vector and _distance are always returned
assert.isDefined(results[0].vector)
@@ -168,10 +215,14 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('id', new Int32()), new Field('name', new Utf8())]
)
const table = await con.createTable({ name: 'vectors', schema })
const schema = new Schema([
new Field('id', new Int32()),
new Field('name', new Utf8())
])
const table = await con.createTable({
name: 'vectors',
schema
})
assert.equal(table.name, 'vectors')
assert.deepEqual(await con.tableNames(), ['vectors'])
})
@@ -180,18 +231,33 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('id', new Int32()),
new Field('name', new Utf8()),
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), false)
]
)
const schema = new Schema([
new Field('id', new Int32()),
new Field('name', new Utf8()),
new Field(
'vector',
new FixedSizeList(2, new Field('item', new Float32(), true)),
false
)
])
const data = [
{ vector: [0.5, 0.2], name: 'foo', id: 0 },
{ vector: [0.3, 0.1], name: 'bar', id: 1 }
{
vector: [0.5, 0.2],
name: 'foo',
id: 0
},
{
vector: [0.3, 0.1],
name: 'bar',
id: 1
}
]
// even thought the keys in data is out of order it should still work
const table = await con.createTable({ name: 'vectors', data, schema })
const table = await con.createTable({
name: 'vectors',
data,
schema
})
assert.equal(table.name, 'vectors')
assert.deepEqual(await con.tableNames(), ['vectors'])
})
@@ -200,10 +266,15 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('id', new Int32()), new Field('name', new Utf8())]
)
const table = await con.createTable({ name: 'vectors', schema, data: [] })
const schema = new Schema([
new Field('id', new Int32()),
new Field('name', new Utf8())
])
const table = await con.createTable({
name: 'vectors',
schema,
data: []
})
assert.equal(table.name, 'vectors')
assert.deepEqual(await con.tableNames(), ['vectors'])
})
@@ -217,7 +288,10 @@ describe('LanceDB client', function () {
const data = new ArrowTable({ vector: i32 })
const table = await con.createTable({ name: 'vectors', data })
const table = await con.createTable({
name: 'vectors',
data
})
assert.equal(table.name, 'vectors')
assert.equal(await table.countRows(), 10)
assert.deepEqual(await con.tableNames(), ['vectors'])
@@ -229,7 +303,11 @@ describe('LanceDB client', function () {
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 }
{
id: 2,
vector: [1.1, 1.2],
price: 50
}
]
const tableName = `vectors_${Math.floor(Math.random() * 100)}`
@@ -243,30 +321,92 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(dir)
const data = [
{ id: 1, vector: [0.1, 0.2], list_of_str: ['a', 'b', 'c'], list_of_num: [1, 2, 3] },
{ id: 2, vector: [1.1, 1.2], list_of_str: ['x', 'y'], list_of_num: [4, 5, 6] }
{
id: 1,
vector: [0.1, 0.2],
list_of_str: ['a', 'b', 'c'],
list_of_num: [1, 2, 3]
},
{
id: 2,
vector: [1.1, 1.2],
list_of_str: ['x', 'y'],
list_of_num: [4, 5, 6]
}
]
const tableName = 'with_variable_sized_list'
const table = await con.createTable(tableName, data) as LocalTable
const table = (await con.createTable(tableName, data)) as LocalTable
assert.equal(table.name, tableName)
assert.equal(await table.countRows(), 2)
const rs = await table.filter('id>1').execute()
assert.equal(rs.length, 1)
assert.deepEqual(rs[0].list_of_str, ['x', 'y'])
assert.isTrue(rs[0].list_of_num instanceof Float64Array)
assert.isTrue(rs[0].list_of_num instanceof Array)
})
it('create table from arrow table', async () => {
const dim = 128
const total = 256
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema([
new Field('id', new Int32()),
new Field(
'vector',
new FixedSizeList(dim, new Field('item', new Float16(), true)),
false
)
])
const data = lancedb.makeArrowTable(
Array.from(Array(total), (_, i) => ({
id: i,
vector: Array.from(Array(dim), Math.random)
})),
{ schema }
)
const table = await con.createTable('f16', data)
assert.equal(table.name, 'f16')
assert.equal(await table.countRows(), total)
assert.deepEqual(await con.tableNames(), ['f16'])
assert.deepEqual(await table.schema, schema)
await table.createIndex({
num_sub_vectors: 2,
num_partitions: 2,
type: 'ivf_pq'
})
const q = Array.from(Array(dim), Math.random)
const r = await table.search(q).limit(5).execute()
assert.equal(r.length, 5)
r.forEach((v) => {
assert.equal(Object.prototype.hasOwnProperty.call(v, 'vector'), true)
assert.equal(
v.vector?.constructor.name,
'Array',
'vector column is list of floats'
)
})
}).timeout(120000)
it('fails to create a new table when the vector column is missing', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [
{ id: 1, price: 10 }
{
id: 1,
price: 10
}
]
const create = con.createTable('missing_vector', data)
await expect(create).to.be.rejectedWith(Error, 'column \'vector\' is missing')
await expect(create).to.be.rejectedWith(
Error,
"column 'vector' is missing"
)
})
it('use overwrite flag to overwrite existing table', async function () {
@@ -275,7 +415,11 @@ describe('LanceDB client', function () {
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 }
{
id: 2,
vector: [1.1, 1.2],
price: 50
}
]
const tableName = 'overwrite'
@@ -284,12 +428,21 @@ describe('LanceDB client', function () {
const newData = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 },
{ id: 3, vector: [1.1, 1.2], price: 50 }
{
id: 3,
vector: [1.1, 1.2],
price: 50
}
]
await expect(con.createTable(tableName, newData)).to.be.rejectedWith(Error, 'already exists')
await expect(con.createTable(tableName, newData)).to.be.rejectedWith(
Error,
'already exists'
)
const table = await con.createTable(tableName, newData, { writeMode: WriteMode.Overwrite })
const table = await con.createTable(tableName, newData, {
writeMode: WriteMode.Overwrite
})
assert.equal(table.name, tableName)
assert.equal(await table.countRows(), 3)
})
@@ -299,16 +452,36 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(dir)
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10, name: 'a' },
{ id: 2, vector: [1.1, 1.2], price: 50, name: 'b' }
{
id: 1,
vector: [0.1, 0.2],
price: 10,
name: 'a'
},
{
id: 2,
vector: [1.1, 1.2],
price: 50,
name: 'b'
}
]
const table = await con.createTable('vectors', data)
assert.equal(await table.countRows(), 2)
const dataAdd = [
{ id: 3, vector: [2.1, 2.2], price: 10, name: 'c' },
{ id: 4, vector: [3.1, 3.2], price: 50, name: 'd' }
{
id: 3,
vector: [2.1, 2.2],
price: 10,
name: 'c'
},
{
id: 4,
vector: [3.1, 3.2],
price: 50,
name: 'd'
}
]
await table.add(dataAdd)
assert.equal(await table.countRows(), 4)
@@ -319,15 +492,35 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(dir)
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10, name: 'a' },
{ id: 2, vector: [1.1, 1.2], price: 50, name: 'b' }
{
id: 1,
vector: [0.1, 0.2],
price: 10,
name: 'a'
},
{
id: 2,
vector: [1.1, 1.2],
price: 50,
name: 'b'
}
]
const table = await con.createTable('vectors', data)
const dataAdd = [
{ id: 3, vector: [2.1, 2.2], name: 'c', price: 10 },
{ id: 4, vector: [3.1, 3.2], name: 'd', price: 50 }
{
id: 3,
vector: [2.1, 2.2],
name: 'c',
price: 10
},
{
id: 4,
vector: [3.1, 3.2],
name: 'd',
price: 50
}
]
await table.add(dataAdd)
assert.equal(await table.countRows(), 4)
@@ -341,8 +534,16 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
const dataOver = [
{ vector: [2.1, 2.2], price: 10, name: 'foo' },
{ vector: [3.1, 3.2], price: 50, name: 'bar' }
{
vector: [2.1, 2.2],
price: 10,
name: 'foo'
},
{
vector: [3.1, 3.2],
price: 50,
name: 'bar'
}
]
await table.overwrite(dataOver)
assert.equal(await table.countRows(), 2)
@@ -355,7 +556,10 @@ describe('LanceDB client', function () {
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ where: 'price = 10', valuesSql: { price: '100' } })
await table.update({
where: 'price = 10',
valuesSql: { price: '100' }
})
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 11)
@@ -368,7 +572,10 @@ describe('LanceDB client', function () {
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ where: 'price = 10', values: { price: 100 } })
await table.update({
where: 'price = 10',
values: { price: 100 }
})
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 11)
@@ -405,10 +612,16 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
)
const table = await con.createTable({ name: 'vectors', schema })
const schema = new Schema([
new Field(
'vector',
new FixedSizeList(128, new Field('float32', new Float32()))
)
])
const table = await con.createTable({
name: 'vectors',
schema
})
const result = await table.search(Array(128).fill(0.1)).execute()
assert.isEmpty(result)
})
@@ -419,10 +632,16 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
)
const table = await con.createTable({ name: 'vectors', schema })
const schema = new Schema([
new Field(
'vector',
new FixedSizeList(128, new Field('float32', new Float32()))
)
])
const table = await con.createTable({
name: 'vectors',
schema
})
await table.add([{ vector: Array(128).fill(0.1) }])
// https://github.com/lancedb/lance/issues/1635
await table.delete('true')
@@ -436,7 +655,13 @@ describe('LanceDB client', function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
}).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow
it('replace an existing index', async function () {
@@ -444,39 +669,79 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
// Replace should fail if the index already exists
await expect(table.createIndex({
type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2, replace: false
})
await expect(
table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2,
replace: false
})
).to.be.rejectedWith('LanceError(Index)')
// Default replace = true
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
}).timeout(50_000)
it('it should fail when the column is not a vector', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await expect(createIndex).to.be.rejectedWith(/VectorIndex requires the column data type to be fixed size list of float32s/)
const createIndex = table.createIndex({
type: 'ivf_pq',
column: 'name',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
await expect(createIndex).to.be.rejectedWith(
/VectorIndex requires the column data type to be fixed size list of float32s/
)
})
it('it should fail when the column is not a vector', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
const createIndex = table.createIndex({
type: 'ivf_pq',
column: 'name',
num_partitions: -1,
max_iters: 2,
num_sub_vectors: 2
})
await expect(createIndex).to.be.rejectedWith(
'num_partitions: must be > 0'
)
})
it('should be able to list index and stats', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
const indices = await table.listIndices()
expect(indices).to.have.lengthOf(1)
@@ -505,7 +770,9 @@ describe('LanceDB client', function () {
])
async embed (data: string[]): Promise<number[][]> {
return data.map(datum => this._embedding_map.get(datum) ?? [0.0, 0.0])
return data.map(
(datum) => this._embedding_map.get(datum) ?? [0.0, 0.0]
)
}
}
@@ -515,10 +782,18 @@ describe('LanceDB client', function () {
const embeddings = new TextEmbedding('name')
const data = [
{ price: 10, name: 'foo' },
{ price: 50, name: 'bar' }
{
price: 10,
name: 'foo'
},
{
price: 50,
name: 'bar'
}
]
const table = await con.createTable('vectors', data, embeddings, { writeMode: WriteMode.Create })
const table = await con.createTable('vectors', data, embeddings, {
writeMode: WriteMode.Create
})
const results = await table.search('foo').execute()
assert.equal(results.length, 2)
})
@@ -531,7 +806,11 @@ describe('LanceDB client', function () {
const names = vectorFromArray(['foo', 'bar'], new Utf8())
const data = new ArrowTable({ name: names })
const table = await con.createTable({ name: 'vectors', data, embeddingFunction })
const table = await con.createTable({
name: 'vectors',
data,
embeddingFunction
})
assert.equal(table.name, 'vectors')
const results = await table.search('foo').execute()
assert.equal(results.length, 2)
@@ -543,13 +822,14 @@ describe('LanceDB client', function () {
const uri = await createTestDB()
const db = await lancedb.connect(uri)
// the fsl inner field must be named 'item' and be nullable
const expectedSchema = new Schema(
[
new Field('id', new Int32()),
new Field('vector', new FixedSizeList(128, new Field('item', new Float32(), true))),
new Field('s', new Utf8())
]
)
const expectedSchema = new Schema([
new Field('id', new Int32()),
new Field(
'vector',
new FixedSizeList(128, new Field('item', new Float32(), true))
),
new Field('s', new Utf8())
])
const table = await db.createTable({
name: 'some_table',
schema: expectedSchema
@@ -573,14 +853,23 @@ describe('Remote LanceDB client', function () {
try {
await con.tableNames()
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
expect(err).to.have.property(
'message',
'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com'
)
}
// POST
try {
await con.createTable({ name: 'vectors', schema: new Schema([]) })
await con.createTable({
name: 'vectors',
schema: new Schema([])
})
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
expect(err).to.have.property(
'message',
'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com'
)
}
// Search
@@ -588,7 +877,10 @@ describe('Remote LanceDB client', function () {
try {
await table.search([0.1, 0.3]).execute()
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
expect(err).to.have.property(
'message',
'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com'
)
}
})
})
@@ -610,7 +902,10 @@ describe('Query object', function () {
})
})
async function createTestDB (numDimensions: number = 2, numRows: number = 2): Promise<string> {
async function createTestDB (
numDimensions: number = 2,
numRows: number = 2
): Promise<string> {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
@@ -618,9 +913,15 @@ async function createTestDB (numDimensions: number = 2, numRows: number = 2): Pr
for (let i = 0; i < numRows; i++) {
const vector = []
for (let j = 0; j < numDimensions; j++) {
vector.push(i + (j * 0.1))
vector.push(i + j * 0.1)
}
data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector })
data.push({
id: i + 1,
name: `name_${i}`,
price: i + 10,
is_active: i % 2 === 0,
vector
})
}
await con.createTable('vectors', data)
@@ -633,8 +934,16 @@ describe('Drop table', function () {
const con = await lancedb.connect(dir)
const data = [
{ price: 10, name: 'foo', vector: [1, 2, 3] },
{ price: 50, name: 'bar', vector: [4, 5, 6] }
{
price: 10,
name: 'foo',
vector: [1, 2, 3]
},
{
price: 50,
name: 'bar',
vector: [4, 5, 6]
}
]
await con.createTable('t1', data)
await con.createTable('t2', data)
@@ -669,13 +978,25 @@ describe('Compact and cleanup', function () {
const con = await lancedb.connect(dir)
const data = [
{ price: 10, name: 'foo', vector: [1, 2, 3] },
{ price: 50, name: 'bar', vector: [4, 5, 6] }
{
price: 10,
name: 'foo',
vector: [1, 2, 3]
},
{
price: 50,
name: 'bar',
vector: [4, 5, 6]
}
]
const table = await con.createTable('t1', data) as LocalTable
const table = (await con.createTable('t1', data)) as LocalTable
const newData = [
{ price: 30, name: 'baz', vector: [7, 8, 9] }
{
price: 30,
name: 'baz',
vector: [7, 8, 9]
}
]
await table.add(newData)