mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-29 18:00:40 +00:00
nodejs create_table (#75)
This commit is contained in:
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -1645,9 +1645,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance"
|
||||
version = "0.4.6"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c0be48532dda07b0cd5a5b087710bbbba6972a882f7c771a316940a4868d66a6"
|
||||
checksum = "fc96cf89139af6f439a0e28ccd04ddf81be795b79fda3105b7a8952fadeb778e"
|
||||
dependencies = [
|
||||
"accelerate-src",
|
||||
"arrow",
|
||||
@@ -3371,7 +3371,9 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-ipc",
|
||||
"arrow-schema",
|
||||
"futures",
|
||||
"lance",
|
||||
"neon",
|
||||
"once_cell",
|
||||
"tokio",
|
||||
|
||||
@@ -1,13 +1,26 @@
|
||||
# LanceDB
|
||||
|
||||
A JavaScript / Node.js library for [LanceDB](https://github.com/lancedb/lancedb).
|
||||
|
||||
## Quick Start
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm i vectordb
|
||||
npm install vectordb
|
||||
```
|
||||
|
||||
See the examples folder for usage.
|
||||
## Usage
|
||||
|
||||
### Basic Example
|
||||
|
||||
```javascript
|
||||
const lancedb = require('vectordb');
|
||||
const db = await lancedb.connect('<PATH_TO_LANCEDB_DATASET>');
|
||||
const table = await db.openTable('my_table');
|
||||
const results = await table.search([0.1, 0.3]).limit(20).execute();
|
||||
console.log(results);
|
||||
```
|
||||
|
||||
The [examples](./examples) folder contains complete examples.
|
||||
|
||||
## Development
|
||||
|
||||
|
||||
@@ -16,12 +16,17 @@
|
||||
|
||||
async function example() {
|
||||
const lancedb = require('vectordb');
|
||||
const db = lancedb.connect('../../sample-lancedb');
|
||||
const db = await lancedb.connect('data/sample-lancedb');
|
||||
|
||||
console.log(db.tableNames());
|
||||
const data = [
|
||||
{ id: 1, vector: [0.1, 0.2], price: 10 },
|
||||
{ id: 2, vector: [1.1, 1.2], price: 50 }
|
||||
]
|
||||
|
||||
const tbl = await db.openTable('my_table');
|
||||
const query = tbl.search([0.1, 0.3]);
|
||||
const table = await db.createTable('vectors', data)
|
||||
console.log(await db.tableNames());
|
||||
|
||||
const query = table.search([0.1, 0.3]);
|
||||
query.limit = 20;
|
||||
const results = await query.execute();
|
||||
console.log(results);
|
||||
|
||||
@@ -9,6 +9,6 @@
|
||||
"author": "",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"vectordb": "^0.0.6"
|
||||
"vectordb": "^0.1.0"
|
||||
}
|
||||
}
|
||||
|
||||
91
node/package-lock.json
generated
91
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.0.6",
|
||||
"version": "0.1.0",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.0.6",
|
||||
"version": "0.1.0",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@apache-arrow/ts": "^12.0.0",
|
||||
@@ -16,6 +16,7 @@
|
||||
"@types/chai": "^4.3.4",
|
||||
"@types/mocha": "^10.0.1",
|
||||
"@types/node": "^18.16.2",
|
||||
"@types/temp": "^0.9.1",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||
"cargo-cp-artifact": "^0.1",
|
||||
"chai": "^4.3.7",
|
||||
@@ -25,6 +26,7 @@
|
||||
"eslint-plugin-n": "^15.7.0",
|
||||
"eslint-plugin-promise": "^6.1.1",
|
||||
"mocha": "^10.2.0",
|
||||
"temp": "^0.9.4",
|
||||
"ts-node": "^10.9.1",
|
||||
"ts-node-dev": "^2.0.0",
|
||||
"typescript": "*"
|
||||
@@ -317,6 +319,15 @@
|
||||
"integrity": "sha512-7NQmHra/JILCd1QqpSzl8+mJRc8ZHz3uDm8YV1Ks9IhK0epEiTw8aIErbvH9PI+6XbqhyIQy3462nEsn7UVzjQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@types/temp": {
|
||||
"version": "0.9.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/temp/-/temp-0.9.1.tgz",
|
||||
"integrity": "sha512-yDQ8Y+oQi9V7VkexwE6NBSVyNuyNFeGI275yWXASc2DjmxNicMi9O50KxDpNlST1kBbV9jKYBHGXhgNYFMPqtA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/eslint-plugin": {
|
||||
"version": "5.59.1",
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
|
||||
@@ -3569,6 +3580,43 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/temp": {
|
||||
"version": "0.9.4",
|
||||
"resolved": "https://registry.npmjs.org/temp/-/temp-0.9.4.tgz",
|
||||
"integrity": "sha512-yYrrsWnrXMcdsnu/7YMYAofM1ktpL5By7vZhf15CrXijWWrEYZks5AXBudalfSWJLlnen/QUJUB5aoB0kqZUGA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"mkdirp": "^0.5.1",
|
||||
"rimraf": "~2.6.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=6.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/temp/node_modules/mkdirp": {
|
||||
"version": "0.5.6",
|
||||
"resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-0.5.6.tgz",
|
||||
"integrity": "sha512-FP+p8RB8OWpF3YZBCrP5gtADmtXApB5AMLn+vdyA+PyxCjrCs00mjyUozssO33cwDeT3wNGdLxJ5M//YqtHAJw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"minimist": "^1.2.6"
|
||||
},
|
||||
"bin": {
|
||||
"mkdirp": "bin/cmd.js"
|
||||
}
|
||||
},
|
||||
"node_modules/temp/node_modules/rimraf": {
|
||||
"version": "2.6.3",
|
||||
"resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.6.3.tgz",
|
||||
"integrity": "sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"glob": "^7.1.3"
|
||||
},
|
||||
"bin": {
|
||||
"rimraf": "bin.js"
|
||||
}
|
||||
},
|
||||
"node_modules/text-table": {
|
||||
"version": "0.2.0",
|
||||
"resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz",
|
||||
@@ -4256,6 +4304,15 @@
|
||||
"integrity": "sha512-7NQmHra/JILCd1QqpSzl8+mJRc8ZHz3uDm8YV1Ks9IhK0epEiTw8aIErbvH9PI+6XbqhyIQy3462nEsn7UVzjQ==",
|
||||
"dev": true
|
||||
},
|
||||
"@types/temp": {
|
||||
"version": "0.9.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/temp/-/temp-0.9.1.tgz",
|
||||
"integrity": "sha512-yDQ8Y+oQi9V7VkexwE6NBSVyNuyNFeGI275yWXASc2DjmxNicMi9O50KxDpNlST1kBbV9jKYBHGXhgNYFMPqtA==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
"@typescript-eslint/eslint-plugin": {
|
||||
"version": "5.59.1",
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
|
||||
@@ -6552,6 +6609,36 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"temp": {
|
||||
"version": "0.9.4",
|
||||
"resolved": "https://registry.npmjs.org/temp/-/temp-0.9.4.tgz",
|
||||
"integrity": "sha512-yYrrsWnrXMcdsnu/7YMYAofM1ktpL5By7vZhf15CrXijWWrEYZks5AXBudalfSWJLlnen/QUJUB5aoB0kqZUGA==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"mkdirp": "^0.5.1",
|
||||
"rimraf": "~2.6.2"
|
||||
},
|
||||
"dependencies": {
|
||||
"mkdirp": {
|
||||
"version": "0.5.6",
|
||||
"resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-0.5.6.tgz",
|
||||
"integrity": "sha512-FP+p8RB8OWpF3YZBCrP5gtADmtXApB5AMLn+vdyA+PyxCjrCs00mjyUozssO33cwDeT3wNGdLxJ5M//YqtHAJw==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"minimist": "^1.2.6"
|
||||
}
|
||||
},
|
||||
"rimraf": {
|
||||
"version": "2.6.3",
|
||||
"resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.6.3.tgz",
|
||||
"integrity": "sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"glob": "^7.1.3"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"text-table": {
|
||||
"version": "0.2.0",
|
||||
"resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.0.6",
|
||||
"version": "0.1.0",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -27,6 +27,7 @@
|
||||
"@types/chai": "^4.3.4",
|
||||
"@types/mocha": "^10.0.1",
|
||||
"@types/node": "^18.16.2",
|
||||
"@types/temp": "^0.9.1",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||
"cargo-cp-artifact": "^0.1",
|
||||
"chai": "^4.3.7",
|
||||
@@ -36,6 +37,7 @@
|
||||
"eslint-plugin-n": "^15.7.0",
|
||||
"eslint-plugin-promise": "^6.1.1",
|
||||
"mocha": "^10.2.0",
|
||||
"temp": "^0.9.4",
|
||||
"ts-node": "^10.9.1",
|
||||
"ts-node-dev": "^2.0.0",
|
||||
"typescript": "*"
|
||||
|
||||
@@ -12,16 +12,26 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { tableFromIPC, Vector } from 'apache-arrow'
|
||||
import {
|
||||
Field,
|
||||
Float32,
|
||||
List,
|
||||
makeBuilder,
|
||||
RecordBatchFileWriter,
|
||||
Table as ArrowTable,
|
||||
tableFromIPC,
|
||||
Vector,
|
||||
vectorFromArray
|
||||
} from 'apache-arrow'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, tableSearch } = require('../index.node')
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch } = require('../index.node')
|
||||
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI
|
||||
* @param uri The uri of the database.
|
||||
*/
|
||||
export function connect (uri: string): Connection {
|
||||
export async function connect (uri: string): Promise<Connection> {
|
||||
return new Connection(uri)
|
||||
}
|
||||
|
||||
@@ -44,7 +54,7 @@ export class Connection {
|
||||
/**
|
||||
* Get the names of all tables in the database.
|
||||
*/
|
||||
tableNames (): string[] {
|
||||
async tableNames (): Promise<string[]> {
|
||||
return databaseTableNames.call(this._db)
|
||||
}
|
||||
|
||||
@@ -56,6 +66,50 @@ export class Connection {
|
||||
const tbl = await databaseOpenTable.call(this._db, name)
|
||||
return new Table(tbl, name)
|
||||
}
|
||||
|
||||
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table> {
|
||||
if (data.length === 0) {
|
||||
throw new Error('At least one record needs to be provided')
|
||||
}
|
||||
|
||||
const columns = Object.keys(data[0])
|
||||
const records: Record<string, Vector> = {}
|
||||
|
||||
for (const columnsKey of columns) {
|
||||
if (columnsKey === 'vector') {
|
||||
const children = new Field<Float32>('item', new Float32())
|
||||
const list = new List(children)
|
||||
const listBuilder = makeBuilder({
|
||||
type: list
|
||||
})
|
||||
const vectorSize = (data[0].vector as any[]).length
|
||||
for (const datum of data) {
|
||||
if ((datum[columnsKey] as any[]).length !== vectorSize) {
|
||||
throw new Error(`Invalid vector size, expected ${vectorSize}`)
|
||||
}
|
||||
|
||||
listBuilder.append(datum[columnsKey])
|
||||
}
|
||||
records[columnsKey] = listBuilder.finish().toVector()
|
||||
} else {
|
||||
const values = []
|
||||
for (const datum of data) {
|
||||
values.push(datum[columnsKey])
|
||||
}
|
||||
records[columnsKey] = vectorFromArray(values)
|
||||
}
|
||||
}
|
||||
|
||||
const table = new ArrowTable(records)
|
||||
await this.createTableArrow(name, table)
|
||||
return await this.openTable(name)
|
||||
}
|
||||
|
||||
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
await tableCreate.call(this._db, name, Buffer.from(await writer.toUint8Array()))
|
||||
return await this.openTable(name)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -93,7 +147,7 @@ export class Query {
|
||||
private readonly _refine_factor?: number
|
||||
private readonly _nprobes: number
|
||||
private readonly _columns?: string[]
|
||||
private readonly _where?: string
|
||||
private _filter?: string
|
||||
private readonly _metric = 'L2'
|
||||
|
||||
constructor (tbl: any, queryVector: number[]) {
|
||||
@@ -103,22 +157,29 @@ export class Query {
|
||||
this._nprobes = 20
|
||||
this._refine_factor = undefined
|
||||
this._columns = undefined
|
||||
this._where = undefined
|
||||
this._filter = undefined
|
||||
}
|
||||
|
||||
set limit (value: number) {
|
||||
limit (value: number): Query {
|
||||
this._limit = value
|
||||
return this
|
||||
}
|
||||
|
||||
get limit (): number {
|
||||
return this._limit
|
||||
filter (value: string): Query {
|
||||
this._filter = value
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute the query and return the results as an Array of Objects
|
||||
*/
|
||||
async execute (): Promise<unknown[]> {
|
||||
const buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit)
|
||||
async execute<T = Record<string, unknown>> (): Promise<T[]> {
|
||||
let buffer;
|
||||
if (this._filter != null) {
|
||||
buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit, this._filter)
|
||||
} else {
|
||||
buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit)
|
||||
}
|
||||
const data = tableFromIPC(buffer)
|
||||
return data.toArray().map((entry: Record<string, unknown>) => {
|
||||
const newObject: Record<string, unknown> = {}
|
||||
@@ -129,14 +190,7 @@ export class Query {
|
||||
newObject[key] = entry[key]
|
||||
}
|
||||
})
|
||||
return newObject
|
||||
return newObject as unknown as T
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute the query and return the results as an Array of the generic type provided
|
||||
*/
|
||||
async execute_cast<T>(): Promise<T[]> {
|
||||
return await this.execute() as T[]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,67 +14,94 @@
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import { assert } from 'chai'
|
||||
import { track } from 'temp'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
|
||||
describe('LanceDB client', function () {
|
||||
describe('open a connection to lancedb', function () {
|
||||
const con = lancedb.connect('.../../sample-lancedb')
|
||||
|
||||
it('should have a valid url', function () {
|
||||
assert.equal(con.uri, '.../../sample-lancedb')
|
||||
describe('when creating a connection to lancedb', function () {
|
||||
it('should have a valid url', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
assert.equal(con.uri, uri)
|
||||
})
|
||||
|
||||
it('should return the existing table names', function () {
|
||||
assert.deepEqual(con.tableNames(), ['my_table'])
|
||||
it('should return the existing table names', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('when querying an existing dataset', function () {
|
||||
it('should open a table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(table.name, 'vectors')
|
||||
})
|
||||
|
||||
describe('open a table from a connection', function () {
|
||||
const tablePromise = con.openTable('my_table')
|
||||
it('execute a query', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table.search([0.1, 0.3]).execute()
|
||||
|
||||
it('should have a valid name', async function () {
|
||||
const table = await tablePromise
|
||||
assert.equal(table.name, 'my_table')
|
||||
})
|
||||
assert.equal(results.length, 2)
|
||||
assert.equal(results[0].price, 10)
|
||||
const vector = results[0].vector as Float32Array
|
||||
assert.approximately(vector[0], 0.0, 0.2)
|
||||
assert.approximately(vector[0], 0.1, 0.3)
|
||||
})
|
||||
|
||||
class MyResult {
|
||||
vector: Float32Array = new Float32Array(0)
|
||||
price: number = 0
|
||||
item: string = ''
|
||||
}
|
||||
it('limits # of results', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table.search([0.1, 0.3]).limit(1).execute()
|
||||
assert.equal(results.length, 1)
|
||||
assert.equal(results[0].id, 1)
|
||||
})
|
||||
|
||||
it('execute a query', async function () {
|
||||
const table = await tablePromise
|
||||
const builder = table.search([0.1, 0.3])
|
||||
const results = await builder.execute() as MyResult[]
|
||||
it('uses a filter', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table.search([0.1, 0.3]).filter('id == 2').execute()
|
||||
assert.equal(results.length, 1)
|
||||
assert.equal(results[0].id, 2)
|
||||
})
|
||||
})
|
||||
|
||||
assert.equal(results.length, 2)
|
||||
assert.equal(results[0].item, 'foo')
|
||||
assert.equal(results[0].price, 10)
|
||||
assert.approximately(results[0].vector[0], 3.1, 0.1)
|
||||
assert.approximately(results[0].vector[1], 4.1, 0.1)
|
||||
})
|
||||
describe('when creating a new dataset', function () {
|
||||
it('creates a new table from javascript objects', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
it('execute a query and type cast the result', async function () {
|
||||
const table = await tablePromise
|
||||
const data = [
|
||||
{ id: 1, vector: [0.1, 0.2], price: 10 },
|
||||
{ id: 2, vector: [1.1, 1.2], price: 50 }
|
||||
]
|
||||
|
||||
const builder = table.search([0.1, 0.3])
|
||||
const results = await builder.execute_cast<MyResult>()
|
||||
assert.equal(results.length, 2)
|
||||
assert.equal(results[0].item, 'foo')
|
||||
assert.equal(results[0].price, 10)
|
||||
assert.approximately(results[0].vector[0], 3.1, 0.1)
|
||||
assert.approximately(results[0].vector[1], 4.1, 0.1)
|
||||
})
|
||||
const tableName = `vectors_${Math.floor(Math.random() * 100)}`
|
||||
const table = await con.createTable(tableName, data)
|
||||
assert.equal(table.name, tableName)
|
||||
|
||||
it('limits # of results', async function () {
|
||||
const table = await tablePromise
|
||||
const builder = table.search([0.1, 0.3])
|
||||
builder.limit = 1
|
||||
const results = await builder.execute() as MyResult[]
|
||||
|
||||
assert.equal(results.length, 1)
|
||||
})
|
||||
const results = await table.search([0.1, 0.3]).execute()
|
||||
assert.equal(results.length, 2)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
async function createTestDB (): Promise<string> {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [
|
||||
{ id: 1, vector: [0.1, 0.2], name: 'foo', price: 10, is_active: true },
|
||||
{ id: 2, vector: [1.1, 1.2], name: 'bar', price: 50, is_active: false }
|
||||
]
|
||||
|
||||
await con.createTable('vectors', data)
|
||||
return dir
|
||||
}
|
||||
|
||||
@@ -12,8 +12,10 @@ crate-type = ["cdylib"]
|
||||
[dependencies]
|
||||
arrow-array = "37.0"
|
||||
arrow-ipc = "37.0"
|
||||
arrow-schema = "37.0"
|
||||
once_cell = "1"
|
||||
futures = "0.3"
|
||||
lance = "0.4.3"
|
||||
vectordb = { path = "../../vectordb" }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
|
||||
|
||||
47
rust/ffi/node/src/arrow.rs
Normal file
47
rust/ffi/node/src/arrow.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::cast::as_list_array;
|
||||
use arrow_array::{Array, FixedSizeListArray, RecordBatch};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
|
||||
|
||||
pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
|
||||
let column = record_batch
|
||||
.column_by_name("vector")
|
||||
.expect("vector column is missing");
|
||||
let arr = as_list_array(column.deref());
|
||||
let list_size = arr.values().len() / record_batch.num_rows();
|
||||
let r = FixedSizeListArray::try_new(arr.values(), list_size as i32).unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", DataType::Float32, true)),
|
||||
list_size as i32,
|
||||
),
|
||||
true,
|
||||
)]));
|
||||
|
||||
let mut new_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(r)]).unwrap();
|
||||
|
||||
if record_batch.num_columns() > 1 {
|
||||
let rb = record_batch.drop_column("vector").unwrap();
|
||||
new_batch = new_batch.merge(&rb).unwrap();
|
||||
}
|
||||
new_batch
|
||||
}
|
||||
@@ -12,21 +12,29 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod convert;
|
||||
|
||||
use std::io::Cursor;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchReader};
|
||||
use arrow_ipc::reader::FileReader;
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use futures::TryStreamExt;
|
||||
use futures::{TryFutureExt, TryStreamExt};
|
||||
use lance::arrow::RecordBatchBuffer;
|
||||
use neon::prelude::*;
|
||||
use neon::types::buffer::TypedArray;
|
||||
use once_cell::sync::OnceCell;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use vectordb::database::Database;
|
||||
use vectordb::error::Error;
|
||||
use vectordb::table::Table;
|
||||
|
||||
use crate::arrow::convert_record_batch;
|
||||
|
||||
mod arrow;
|
||||
mod convert;
|
||||
|
||||
struct JsDatabase {
|
||||
database: Arc<Database>,
|
||||
}
|
||||
@@ -90,6 +98,7 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let query_vector = cx.argument::<JsArray>(0)?; //. .as_value(&mut cx);
|
||||
let limit = cx.argument::<JsNumber>(1)?.value(&mut cx);
|
||||
let filter = cx.argument_opt(2).map(|f| f.downcast_or_throw::<JsString, _>(&mut cx).unwrap().value(&mut cx));
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
@@ -101,12 +110,11 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
rt.spawn(async move {
|
||||
let builder = table
|
||||
.search(Float32Array::from(query))
|
||||
.limit(limit as usize);
|
||||
let results = builder
|
||||
.execute()
|
||||
.await
|
||||
.unwrap() // FIXME unwrap
|
||||
.try_collect::<Vec<_>>()
|
||||
.limit(limit as usize)
|
||||
.filter(filter);
|
||||
let record_batch_stream = builder.execute();
|
||||
let results = record_batch_stream
|
||||
.and_then(|stream| stream.try_collect::<Vec<_>>().map_err(Error::from))
|
||||
.await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
@@ -135,11 +143,46 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let db = cx
|
||||
.this()
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let buffer = cx.argument::<JsBuffer>(1)?;
|
||||
let slice = buffer.as_slice(&mut cx);
|
||||
|
||||
let mut batches: Vec<RecordBatch> = Vec::new();
|
||||
let fr = FileReader::try_new(Cursor::new(slice), None);
|
||||
let file_reader = fr.unwrap();
|
||||
for b in file_reader {
|
||||
let record_batch = convert_record_batch(b.unwrap());
|
||||
batches.push(record_batch);
|
||||
}
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
|
||||
let (deferred, promise) = cx.promise();
|
||||
let database = db.database.clone();
|
||||
|
||||
rt.block_on(async move {
|
||||
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(batches));
|
||||
let table_rst = database.create_table(table_name, batch_reader).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let table = Arc::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?);
|
||||
Ok(cx.boxed(JsTable { table }))
|
||||
});
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
#[neon::main]
|
||||
fn main(mut cx: ModuleContext) -> NeonResult<()> {
|
||||
cx.export_function("databaseNew", database_new)?;
|
||||
cx.export_function("databaseTableNames", database_table_names)?;
|
||||
cx.export_function("databaseOpenTable", database_open_table)?;
|
||||
cx.export_function("tableSearch", table_search)?;
|
||||
cx.export_function("tableCreate", table_create)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use arrow_array::RecordBatchReader;
|
||||
use std::fs::create_dir_all;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
@@ -73,6 +74,14 @@ impl Database {
|
||||
Ok(f)
|
||||
}
|
||||
|
||||
pub async fn create_table(
|
||||
&self,
|
||||
name: String,
|
||||
batches: Box<dyn RecordBatchReader>,
|
||||
) -> Result<Table> {
|
||||
Table::create(self.path.clone(), name, batches).await
|
||||
}
|
||||
|
||||
/// Open a table in the database.
|
||||
///
|
||||
/// # Arguments
|
||||
|
||||
@@ -26,6 +26,7 @@ pub struct Query {
|
||||
pub dataset: Arc<Dataset>,
|
||||
pub query_vector: Float32Array,
|
||||
pub limit: usize,
|
||||
pub filter: Option<String>,
|
||||
pub nprobes: usize,
|
||||
pub refine_factor: Option<u32>,
|
||||
pub metric_type: MetricType,
|
||||
@@ -52,6 +53,7 @@ impl Query {
|
||||
refine_factor: None,
|
||||
metric_type: MetricType::L2,
|
||||
use_index: false,
|
||||
filter: None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,6 +73,7 @@ impl Query {
|
||||
scanner.nprobs(self.nprobes);
|
||||
scanner.distance_metric(self.metric_type);
|
||||
scanner.use_index(self.use_index);
|
||||
self.filter.as_ref().map(|f| scanner.filter(f));
|
||||
self.refine_factor.map(|rf| scanner.refine(rf));
|
||||
Ok(scanner.try_into_stream().await?)
|
||||
}
|
||||
@@ -134,6 +137,11 @@ impl Query {
|
||||
self.use_index = use_index;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn filter(mut self, filter: Option<String>) -> Query {
|
||||
self.filter = filter;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use lance::dataset::Dataset;
|
||||
use arrow_array::{Float32Array, RecordBatchReader};
|
||||
use lance::dataset::{Dataset, WriteParams};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::query::Query;
|
||||
@@ -55,6 +55,21 @@ impl Table {
|
||||
Ok(table)
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
base_path: Arc<PathBuf>,
|
||||
name: String,
|
||||
mut batches: Box<dyn RecordBatchReader>,
|
||||
) -> Result<Self> {
|
||||
let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
|
||||
let ds_uri = ds_path
|
||||
.to_str()
|
||||
.ok_or(Error::IO(format!("Unable to find table {}", name)))?;
|
||||
|
||||
let dataset =
|
||||
Arc::new(Dataset::write(&mut batches, ds_uri, Some(WriteParams::default())).await?);
|
||||
Ok(Table { name, dataset })
|
||||
}
|
||||
|
||||
/// Creates a new Query object that can be executed.
|
||||
///
|
||||
/// # Arguments
|
||||
|
||||
Reference in New Issue
Block a user