feat: change create table to accept Arrow table (#845)

This commit is contained in:
Lei Xu
2024-01-23 13:25:15 -08:00
committed by Weston Pace
parent 5ecbf971e2
commit 65c1d8bc4c
5 changed files with 586 additions and 160 deletions

View File

@@ -18,8 +18,28 @@ import * as chai from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import * as lancedb from '../index'
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions, type LocalTable } from '../index'
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
import {
type AwsCredentials,
type EmbeddingFunction,
MetricType,
Query,
WriteMode,
DefaultWriteOptions,
isWriteOptions,
type LocalTable
} from '../index'
import {
FixedSizeList,
Field,
Int32,
makeVector,
Schema,
Utf8,
Table as ArrowTable,
vectorFromArray,
Float32,
Float16
} from 'apache-arrow'
const expect = chai.expect
const assert = chai.assert
@@ -45,7 +65,10 @@ describe('LanceDB client', function () {
accessKeyId: '',
secretKey: ''
}
const con = await lancedb.connect({ uri, awsCredentials })
const con = await lancedb.connect({
uri,
awsCredentials
})
assert.equal(con.uri, uri)
})
@@ -125,13 +148,29 @@ describe('LanceDB client', function () {
const uri = await createTestDB(16, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
// post filter should return less than the limit
let results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(false).execute()
let results = await table
.search(new Array(16).fill(0.1))
.limit(10)
.filter('id >= 10')
.prefilter(false)
.execute()
assert.isTrue(results.length < 10)
// pre filter should return exactly the limit
results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(true).execute()
results = await table
.search(new Array(16).fill(0.1))
.limit(10)
.filter('id >= 10')
.prefilter(true)
.execute()
assert.isTrue(results.length === 10)
})
@@ -142,7 +181,12 @@ describe('LanceDB client', function () {
await table.createScalarIndex('id', true)
// Prefiltering should still work the same
const results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(true).execute()
const results = await table
.search(new Array(16).fill(0.1))
.limit(10)
.filter('id >= 10')
.prefilter(true)
.execute()
assert.isTrue(results.length === 10)
})
@@ -150,7 +194,10 @@ describe('LanceDB client', function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.1]).select(['is_active']).execute()
const results = await table
.search([0.1, 0.1])
.select(['is_active'])
.execute()
assert.equal(results.length, 2)
// vector and _distance are always returned
assert.isDefined(results[0].vector)
@@ -168,10 +215,14 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('id', new Int32()), new Field('name', new Utf8())]
)
const table = await con.createTable({ name: 'vectors', schema })
const schema = new Schema([
new Field('id', new Int32()),
new Field('name', new Utf8())
])
const table = await con.createTable({
name: 'vectors',
schema
})
assert.equal(table.name, 'vectors')
assert.deepEqual(await con.tableNames(), ['vectors'])
})
@@ -180,18 +231,33 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('id', new Int32()),
new Field('name', new Utf8()),
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), false)
]
)
const schema = new Schema([
new Field('id', new Int32()),
new Field('name', new Utf8()),
new Field(
'vector',
new FixedSizeList(2, new Field('item', new Float32(), true)),
false
)
])
const data = [
{ vector: [0.5, 0.2], name: 'foo', id: 0 },
{ vector: [0.3, 0.1], name: 'bar', id: 1 }
{
vector: [0.5, 0.2],
name: 'foo',
id: 0
},
{
vector: [0.3, 0.1],
name: 'bar',
id: 1
}
]
// even thought the keys in data is out of order it should still work
const table = await con.createTable({ name: 'vectors', data, schema })
const table = await con.createTable({
name: 'vectors',
data,
schema
})
assert.equal(table.name, 'vectors')
assert.deepEqual(await con.tableNames(), ['vectors'])
})
@@ -200,10 +266,15 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('id', new Int32()), new Field('name', new Utf8())]
)
const table = await con.createTable({ name: 'vectors', schema, data: [] })
const schema = new Schema([
new Field('id', new Int32()),
new Field('name', new Utf8())
])
const table = await con.createTable({
name: 'vectors',
schema,
data: []
})
assert.equal(table.name, 'vectors')
assert.deepEqual(await con.tableNames(), ['vectors'])
})
@@ -217,7 +288,10 @@ describe('LanceDB client', function () {
const data = new ArrowTable({ vector: i32 })
const table = await con.createTable({ name: 'vectors', data })
const table = await con.createTable({
name: 'vectors',
data
})
assert.equal(table.name, 'vectors')
assert.equal(await table.countRows(), 10)
assert.deepEqual(await con.tableNames(), ['vectors'])
@@ -229,7 +303,11 @@ describe('LanceDB client', function () {
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 }
{
id: 2,
vector: [1.1, 1.2],
price: 50
}
]
const tableName = `vectors_${Math.floor(Math.random() * 100)}`
@@ -243,30 +321,92 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(dir)
const data = [
{ id: 1, vector: [0.1, 0.2], list_of_str: ['a', 'b', 'c'], list_of_num: [1, 2, 3] },
{ id: 2, vector: [1.1, 1.2], list_of_str: ['x', 'y'], list_of_num: [4, 5, 6] }
{
id: 1,
vector: [0.1, 0.2],
list_of_str: ['a', 'b', 'c'],
list_of_num: [1, 2, 3]
},
{
id: 2,
vector: [1.1, 1.2],
list_of_str: ['x', 'y'],
list_of_num: [4, 5, 6]
}
]
const tableName = 'with_variable_sized_list'
const table = await con.createTable(tableName, data) as LocalTable
const table = (await con.createTable(tableName, data)) as LocalTable
assert.equal(table.name, tableName)
assert.equal(await table.countRows(), 2)
const rs = await table.filter('id>1').execute()
assert.equal(rs.length, 1)
assert.deepEqual(rs[0].list_of_str, ['x', 'y'])
assert.isTrue(rs[0].list_of_num instanceof Float64Array)
assert.isTrue(rs[0].list_of_num instanceof Array)
})
it('create table from arrow table', async () => {
const dim = 128
const total = 256
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema([
new Field('id', new Int32()),
new Field(
'vector',
new FixedSizeList(dim, new Field('item', new Float16(), true)),
false
)
])
const data = lancedb.makeArrowTable(
Array.from(Array(total), (_, i) => ({
id: i,
vector: Array.from(Array(dim), Math.random)
})),
{ schema }
)
const table = await con.createTable('f16', data)
assert.equal(table.name, 'f16')
assert.equal(await table.countRows(), total)
assert.deepEqual(await con.tableNames(), ['f16'])
assert.deepEqual(await table.schema, schema)
await table.createIndex({
num_sub_vectors: 2,
num_partitions: 2,
type: 'ivf_pq'
})
const q = Array.from(Array(dim), Math.random)
const r = await table.search(q).limit(5).execute()
assert.equal(r.length, 5)
r.forEach((v) => {
assert.equal(Object.prototype.hasOwnProperty.call(v, 'vector'), true)
assert.equal(
v.vector?.constructor.name,
'Array',
'vector column is list of floats'
)
})
}).timeout(120000)
it('fails to create a new table when the vector column is missing', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [
{ id: 1, price: 10 }
{
id: 1,
price: 10
}
]
const create = con.createTable('missing_vector', data)
await expect(create).to.be.rejectedWith(Error, 'column \'vector\' is missing')
await expect(create).to.be.rejectedWith(
Error,
"column 'vector' is missing"
)
})
it('use overwrite flag to overwrite existing table', async function () {
@@ -275,7 +415,11 @@ describe('LanceDB client', function () {
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 }
{
id: 2,
vector: [1.1, 1.2],
price: 50
}
]
const tableName = 'overwrite'
@@ -284,12 +428,21 @@ describe('LanceDB client', function () {
const newData = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 },
{ id: 3, vector: [1.1, 1.2], price: 50 }
{
id: 3,
vector: [1.1, 1.2],
price: 50
}
]
await expect(con.createTable(tableName, newData)).to.be.rejectedWith(Error, 'already exists')
await expect(con.createTable(tableName, newData)).to.be.rejectedWith(
Error,
'already exists'
)
const table = await con.createTable(tableName, newData, { writeMode: WriteMode.Overwrite })
const table = await con.createTable(tableName, newData, {
writeMode: WriteMode.Overwrite
})
assert.equal(table.name, tableName)
assert.equal(await table.countRows(), 3)
})
@@ -299,16 +452,36 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(dir)
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10, name: 'a' },
{ id: 2, vector: [1.1, 1.2], price: 50, name: 'b' }
{
id: 1,
vector: [0.1, 0.2],
price: 10,
name: 'a'
},
{
id: 2,
vector: [1.1, 1.2],
price: 50,
name: 'b'
}
]
const table = await con.createTable('vectors', data)
assert.equal(await table.countRows(), 2)
const dataAdd = [
{ id: 3, vector: [2.1, 2.2], price: 10, name: 'c' },
{ id: 4, vector: [3.1, 3.2], price: 50, name: 'd' }
{
id: 3,
vector: [2.1, 2.2],
price: 10,
name: 'c'
},
{
id: 4,
vector: [3.1, 3.2],
price: 50,
name: 'd'
}
]
await table.add(dataAdd)
assert.equal(await table.countRows(), 4)
@@ -319,15 +492,35 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(dir)
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10, name: 'a' },
{ id: 2, vector: [1.1, 1.2], price: 50, name: 'b' }
{
id: 1,
vector: [0.1, 0.2],
price: 10,
name: 'a'
},
{
id: 2,
vector: [1.1, 1.2],
price: 50,
name: 'b'
}
]
const table = await con.createTable('vectors', data)
const dataAdd = [
{ id: 3, vector: [2.1, 2.2], name: 'c', price: 10 },
{ id: 4, vector: [3.1, 3.2], name: 'd', price: 50 }
{
id: 3,
vector: [2.1, 2.2],
name: 'c',
price: 10
},
{
id: 4,
vector: [3.1, 3.2],
name: 'd',
price: 50
}
]
await table.add(dataAdd)
assert.equal(await table.countRows(), 4)
@@ -341,8 +534,16 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
const dataOver = [
{ vector: [2.1, 2.2], price: 10, name: 'foo' },
{ vector: [3.1, 3.2], price: 50, name: 'bar' }
{
vector: [2.1, 2.2],
price: 10,
name: 'foo'
},
{
vector: [3.1, 3.2],
price: 50,
name: 'bar'
}
]
await table.overwrite(dataOver)
assert.equal(await table.countRows(), 2)
@@ -355,7 +556,10 @@ describe('LanceDB client', function () {
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ where: 'price = 10', valuesSql: { price: '100' } })
await table.update({
where: 'price = 10',
valuesSql: { price: '100' }
})
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 11)
@@ -368,7 +572,10 @@ describe('LanceDB client', function () {
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ where: 'price = 10', values: { price: 100 } })
await table.update({
where: 'price = 10',
values: { price: 100 }
})
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 11)
@@ -405,10 +612,16 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
)
const table = await con.createTable({ name: 'vectors', schema })
const schema = new Schema([
new Field(
'vector',
new FixedSizeList(128, new Field('float32', new Float32()))
)
])
const table = await con.createTable({
name: 'vectors',
schema
})
const result = await table.search(Array(128).fill(0.1)).execute()
assert.isEmpty(result)
})
@@ -419,10 +632,16 @@ describe('LanceDB client', function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
)
const table = await con.createTable({ name: 'vectors', schema })
const schema = new Schema([
new Field(
'vector',
new FixedSizeList(128, new Field('float32', new Float32()))
)
])
const table = await con.createTable({
name: 'vectors',
schema
})
await table.add([{ vector: Array(128).fill(0.1) }])
// https://github.com/lancedb/lance/issues/1635
await table.delete('true')
@@ -436,7 +655,13 @@ describe('LanceDB client', function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
}).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow
it('replace an existing index', async function () {
@@ -444,39 +669,79 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
// Replace should fail if the index already exists
await expect(table.createIndex({
type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2, replace: false
})
await expect(
table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2,
replace: false
})
).to.be.rejectedWith('LanceError(Index)')
// Default replace = true
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
}).timeout(50_000)
it('it should fail when the column is not a vector', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await expect(createIndex).to.be.rejectedWith(/VectorIndex requires the column data type to be fixed size list of float32s/)
const createIndex = table.createIndex({
type: 'ivf_pq',
column: 'name',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
await expect(createIndex).to.be.rejectedWith(
/VectorIndex requires the column data type to be fixed size list of float32s/
)
})
it('it should fail when the column is not a vector', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
const createIndex = table.createIndex({
type: 'ivf_pq',
column: 'name',
num_partitions: -1,
max_iters: 2,
num_sub_vectors: 2
})
await expect(createIndex).to.be.rejectedWith(
'num_partitions: must be > 0'
)
})
it('should be able to list index and stats', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await table.createIndex({
type: 'ivf_pq',
column: 'vector',
num_partitions: 2,
max_iters: 2,
num_sub_vectors: 2
})
const indices = await table.listIndices()
expect(indices).to.have.lengthOf(1)
@@ -505,7 +770,9 @@ describe('LanceDB client', function () {
])
async embed (data: string[]): Promise<number[][]> {
return data.map(datum => this._embedding_map.get(datum) ?? [0.0, 0.0])
return data.map(
(datum) => this._embedding_map.get(datum) ?? [0.0, 0.0]
)
}
}
@@ -515,10 +782,18 @@ describe('LanceDB client', function () {
const embeddings = new TextEmbedding('name')
const data = [
{ price: 10, name: 'foo' },
{ price: 50, name: 'bar' }
{
price: 10,
name: 'foo'
},
{
price: 50,
name: 'bar'
}
]
const table = await con.createTable('vectors', data, embeddings, { writeMode: WriteMode.Create })
const table = await con.createTable('vectors', data, embeddings, {
writeMode: WriteMode.Create
})
const results = await table.search('foo').execute()
assert.equal(results.length, 2)
})
@@ -531,7 +806,11 @@ describe('LanceDB client', function () {
const names = vectorFromArray(['foo', 'bar'], new Utf8())
const data = new ArrowTable({ name: names })
const table = await con.createTable({ name: 'vectors', data, embeddingFunction })
const table = await con.createTable({
name: 'vectors',
data,
embeddingFunction
})
assert.equal(table.name, 'vectors')
const results = await table.search('foo').execute()
assert.equal(results.length, 2)
@@ -543,13 +822,14 @@ describe('LanceDB client', function () {
const uri = await createTestDB()
const db = await lancedb.connect(uri)
// the fsl inner field must be named 'item' and be nullable
const expectedSchema = new Schema(
[
new Field('id', new Int32()),
new Field('vector', new FixedSizeList(128, new Field('item', new Float32(), true))),
new Field('s', new Utf8())
]
)
const expectedSchema = new Schema([
new Field('id', new Int32()),
new Field(
'vector',
new FixedSizeList(128, new Field('item', new Float32(), true))
),
new Field('s', new Utf8())
])
const table = await db.createTable({
name: 'some_table',
schema: expectedSchema
@@ -573,14 +853,23 @@ describe('Remote LanceDB client', function () {
try {
await con.tableNames()
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
expect(err).to.have.property(
'message',
'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com'
)
}
// POST
try {
await con.createTable({ name: 'vectors', schema: new Schema([]) })
await con.createTable({
name: 'vectors',
schema: new Schema([])
})
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
expect(err).to.have.property(
'message',
'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com'
)
}
// Search
@@ -588,7 +877,10 @@ describe('Remote LanceDB client', function () {
try {
await table.search([0.1, 0.3]).execute()
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
expect(err).to.have.property(
'message',
'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com'
)
}
})
})
@@ -610,7 +902,10 @@ describe('Query object', function () {
})
})
async function createTestDB (numDimensions: number = 2, numRows: number = 2): Promise<string> {
async function createTestDB (
numDimensions: number = 2,
numRows: number = 2
): Promise<string> {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
@@ -618,9 +913,15 @@ async function createTestDB (numDimensions: number = 2, numRows: number = 2): Pr
for (let i = 0; i < numRows; i++) {
const vector = []
for (let j = 0; j < numDimensions; j++) {
vector.push(i + (j * 0.1))
vector.push(i + j * 0.1)
}
data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector })
data.push({
id: i + 1,
name: `name_${i}`,
price: i + 10,
is_active: i % 2 === 0,
vector
})
}
await con.createTable('vectors', data)
@@ -633,8 +934,16 @@ describe('Drop table', function () {
const con = await lancedb.connect(dir)
const data = [
{ price: 10, name: 'foo', vector: [1, 2, 3] },
{ price: 50, name: 'bar', vector: [4, 5, 6] }
{
price: 10,
name: 'foo',
vector: [1, 2, 3]
},
{
price: 50,
name: 'bar',
vector: [4, 5, 6]
}
]
await con.createTable('t1', data)
await con.createTable('t2', data)
@@ -669,13 +978,25 @@ describe('Compact and cleanup', function () {
const con = await lancedb.connect(dir)
const data = [
{ price: 10, name: 'foo', vector: [1, 2, 3] },
{ price: 50, name: 'bar', vector: [4, 5, 6] }
{
price: 10,
name: 'foo',
vector: [1, 2, 3]
},
{
price: 50,
name: 'bar',
vector: [4, 5, 6]
}
]
const table = await con.createTable('t1', data) as LocalTable
const table = (await con.createTable('t1', data)) as LocalTable
const newData = [
{ price: 30, name: 'baz', vector: [7, 8, 9] }
{
price: 30,
name: 'baz',
vector: [7, 8, 9]
}
]
await table.add(newData)