diff --git a/.github/workflows/nodejs.yml b/.github/workflows/nodejs.yml index 53b5e001..5dfb0c76 100644 --- a/.github/workflows/nodejs.yml +++ b/.github/workflows/nodejs.yml @@ -22,6 +22,7 @@ env: jobs: lint: + if: false name: Lint runs-on: ubuntu-22.04 defaults: @@ -111,4 +112,3 @@ jobs: - name: Test run: | npm run test - diff --git a/nodejs/__test__/arrow.test.ts b/nodejs/__test__/arrow.test.ts index cb4a300f..c4a02850 100644 --- a/nodejs/__test__/arrow.test.ts +++ b/nodejs/__test__/arrow.test.ts @@ -12,9 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { makeArrowTable, toBuffer } from "../lancedb/arrow"; +import { convertToTable, fromTableToBuffer, makeArrowTable, makeEmptyTable } from '../dist/arrow' import { - Int64, Field, FixedSizeList, Float16, @@ -23,98 +22,301 @@ import { tableFromIPC, Schema, Float64, -} from "apache-arrow"; + type Table, + Binary, + Bool, + Utf8, + Struct, + List, + DataType, + Dictionary, + Int64, + Float, + Precision +} from 'apache-arrow' +import { type EmbeddingFunction } from '../dist/embedding/embedding_function' -test("customized schema", function () { - const schema = new Schema([ - new Field("a", new Int32(), true), - new Field("b", new Float32(), true), - new Field( - "c", - new FixedSizeList(3, new Field("item", new Float16())), - true - ), - ]); - const table = makeArrowTable( - [ - { a: 1, b: 2, c: [1, 2, 3] }, - { a: 4, b: 5, c: [4, 5, 6] }, - { a: 7, b: 8, c: [7, 8, 9] }, - ], - { schema } - ); - - expect(table.schema.toString()).toEqual(schema.toString()); - - const buf = toBuffer(table); - expect(buf.byteLength).toBeGreaterThan(0); - - const actual = tableFromIPC(buf); - expect(actual.numRows).toBe(3); - const actualSchema = actual.schema; - expect(actualSchema.toString()).toStrictEqual(schema.toString()); -}); - -test("default vector column", function () { - const schema = new Schema([ - new Field("a", new Float64(), true), - new Field("b", new Float64(), true), - new Field("vector", new FixedSizeList(3, new Field("item", new Float32()))), - ]); - const table = makeArrowTable([ - { a: 1, b: 2, vector: [1, 2, 3] }, - { a: 4, b: 5, vector: [4, 5, 6] }, - { a: 7, b: 8, vector: [7, 8, 9] }, - ]); - - const buf = toBuffer(table); - expect(buf.byteLength).toBeGreaterThan(0); - - const actual = tableFromIPC(buf); - expect(actual.numRows).toBe(3); - const actualSchema = actual.schema; - expect(actualSchema.toString()).toEqual(actualSchema.toString()); -}); - -test("2 vector columns", function () { - const schema = new Schema([ - new Field("a", new Float64()), - new Field("b", new Float64()), - new Field("vec1", new FixedSizeList(3, new Field("item", new Float16()))), - new Field("vec2", new FixedSizeList(3, new Field("item", new Float16()))), - ]); - const table = makeArrowTable( - [ - { a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] }, - { a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] }, - { a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] }, - ], +function sampleRecords (): Array> { + return [ { - vectorColumns: { - vec1: { type: new Float16() }, - vec2: { type: new Float16() }, - }, + binary: Buffer.alloc(5), + boolean: false, + number: 7, + string: 'hello', + struct: { x: 0, y: 0 }, + list: ['anime', 'action', 'comedy'] } - ); + ] +} - const buf = toBuffer(table); - expect(buf.byteLength).toBeGreaterThan(0); - - const actual = tableFromIPC(buf); - expect(actual.numRows).toBe(3); - const actualSchema = actual.schema; - 
expect(actualSchema.toString()).toEqual(schema.toString()); -}); - -test("handles int64", function() { - // https://github.com/lancedb/lancedb/issues/960 +// Helper method to verify various ways to create a table +async function checkTableCreation (tableCreationMethod: (records: any, recordsReversed: any, schema: Schema) => Promise, infersTypes: boolean): Promise { + const records = sampleRecords() + const recordsReversed = [{ + list: ['anime', 'action', 'comedy'], + struct: { x: 0, y: 0 }, + string: 'hello', + number: 7, + boolean: false, + binary: Buffer.alloc(5) + }] const schema = new Schema([ - new Field("x", new Int64(), true) - ]); - const table = makeArrowTable([ - { x: 1 }, - { x: 2 }, - { x: 3 } - ], { schema }); - expect(table.schema).toEqual(schema); -}) \ No newline at end of file + new Field('binary', new Binary(), false), + new Field('boolean', new Bool(), false), + new Field('number', new Float64(), false), + new Field('string', new Utf8(), false), + new Field('struct', new Struct([ + new Field('x', new Float64(), false), + new Field('y', new Float64(), false) + ])), + new Field('list', new List(new Field('item', new Utf8(), false)), false) + ]) + + const table = await tableCreationMethod(records, recordsReversed, schema) + schema.fields.forEach((field, idx) => { + const actualField = table.schema.fields[idx] + // Type inference always assumes nullable=true + if (infersTypes) { + expect(actualField.nullable).toBe(true) + } else { + expect(actualField.nullable).toBe(false) + } + expect(table.getChild(field.name)?.type.toString()).toEqual(field.type.toString()) + expect(table.getChildAt(idx)?.type.toString()).toEqual(field.type.toString()) + }) +} + +describe('The function makeArrowTable', function () { + it('will use data types from a provided schema instead of inference', async function () { + const schema = new Schema([ + new Field('a', new Int32()), + new Field('b', new Float32()), + new Field('c', new FixedSizeList(3, new Field('item', new Float16()))), + new Field('d', new Int64()) + ]) + const table = makeArrowTable( + [ + { a: 1, b: 2, c: [1, 2, 3], d: 9 }, + { a: 4, b: 5, c: [4, 5, 6], d: 10 }, + { a: 7, b: 8, c: [7, 8, 9], d: null } + ], + { schema } + ) + + const buf = await fromTableToBuffer(table) + expect(buf.byteLength).toBeGreaterThan(0) + + const actual = tableFromIPC(buf) + expect(actual.numRows).toBe(3) + const actualSchema = actual.schema + expect(actualSchema).toEqual(schema) + }) + + it('will assume the column `vector` is FixedSizeList by default', async function () { + const schema = new Schema([ + new Field('a', new Float(Precision.DOUBLE), true), + new Field('b', new Float(Precision.DOUBLE), true), + new Field( + 'vector', + new FixedSizeList(3, new Field('item', new Float(Precision.SINGLE), true)), + true + ) + ]) + const table = makeArrowTable([ + { a: 1, b: 2, vector: [1, 2, 3] }, + { a: 4, b: 5, vector: [4, 5, 6] }, + { a: 7, b: 8, vector: [7, 8, 9] } + ]) + + const buf = await fromTableToBuffer(table) + expect(buf.byteLength).toBeGreaterThan(0) + + const actual = tableFromIPC(buf) + expect(actual.numRows).toBe(3) + const actualSchema = actual.schema + expect(actualSchema).toEqual(schema) + }) + + it('can support multiple vector columns', async function () { + const schema = new Schema([ + new Field('a', new Float(Precision.DOUBLE), true), + new Field('b', new Float(Precision.DOUBLE), true), + new Field('vec1', new FixedSizeList(3, new Field('item', new Float16(), true)), true), + new Field('vec2', new FixedSizeList(3, new Field('item', new 
Float16(), true)), true) + ]) + const table = makeArrowTable( + [ + { a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] }, + { a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] }, + { a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] } + ], + { + vectorColumns: { + vec1: { type: new Float16() }, + vec2: { type: new Float16() } + } + } + ) + + const buf = await fromTableToBuffer(table) + expect(buf.byteLength).toBeGreaterThan(0) + + const actual = tableFromIPC(buf) + expect(actual.numRows).toBe(3) + const actualSchema = actual.schema + expect(actualSchema).toEqual(schema) + }) + + it('will allow different vector column types', async function () { + const table = makeArrowTable( + [ + { fp16: [1], fp32: [1], fp64: [1] } + ], + { + vectorColumns: { + fp16: { type: new Float16() }, + fp32: { type: new Float32() }, + fp64: { type: new Float64() } + } + } + ) + + expect(table.getChild('fp16')?.type.children[0].type.toString()).toEqual(new Float16().toString()) + expect(table.getChild('fp32')?.type.children[0].type.toString()).toEqual(new Float32().toString()) + expect(table.getChild('fp64')?.type.children[0].type.toString()).toEqual(new Float64().toString()) + }) + + it('will use dictionary encoded strings if asked', async function () { + const table = makeArrowTable([{ str: 'hello' }]) + expect(DataType.isUtf8(table.getChild('str')?.type)).toBe(true) + + const tableWithDict = makeArrowTable([{ str: 'hello' }], { dictionaryEncodeStrings: true }) + expect(DataType.isDictionary(tableWithDict.getChild('str')?.type)).toBe(true) + + const schema = new Schema([ + new Field('str', new Dictionary(new Utf8(), new Int32())) + ]) + + const tableWithDict2 = makeArrowTable([{ str: 'hello' }], { schema }) + expect(DataType.isDictionary(tableWithDict2.getChild('str')?.type)).toBe(true) + }) + + it('will infer data types correctly', async function () { + await checkTableCreation(async (records) => makeArrowTable(records), true) + }) + + it('will allow a schema to be provided', async function () { + await checkTableCreation(async (records, _, schema) => makeArrowTable(records, { schema }), false) + }) + + it('will use the field order of any provided schema', async function () { + await checkTableCreation(async (_, recordsReversed, schema) => makeArrowTable(recordsReversed, { schema }), false) + }) + + it('will make an empty table', async function () { + await checkTableCreation(async (_, __, schema) => makeArrowTable([], { schema }), false) + }) +}) + +class DummyEmbedding implements EmbeddingFunction { + public readonly sourceColumn = 'string' + public readonly embeddingDimension = 2 + public readonly embeddingDataType = new Float16() + + async embed (data: string[]): Promise { + return data.map( + () => [0.0, 0.0] + ) + } +} + +class DummyEmbeddingWithNoDimension implements EmbeddingFunction { + public readonly sourceColumn = 'string' + + async embed (data: string[]): Promise { + return data.map( + () => [0.0, 0.0] + ) + } +} + +describe('convertToTable', function () { + it('will infer data types correctly', async function () { + await checkTableCreation(async (records) => await convertToTable(records), true) + }) + + it('will allow a schema to be provided', async function () { + await checkTableCreation(async (records, _, schema) => await convertToTable(records, undefined, { schema }), false) + }) + + it('will use the field order of any provided schema', async function () { + await checkTableCreation(async (_, recordsReversed, schema) => await convertToTable(recordsReversed, undefined, { schema }), false) + }) + + 
it('will make an empty table', async function () { + await checkTableCreation(async (_, __, schema) => await convertToTable([], undefined, { schema }), false) + }) + + it('will apply embeddings', async function () { + const records = sampleRecords() + const table = await convertToTable(records, new DummyEmbedding()) + expect(DataType.isFixedSizeList(table.getChild('vector')?.type)).toBe(true) + expect(table.getChild('vector')?.type.children[0].type.toString()).toEqual(new Float16().toString()) + }) + + it('will fail if missing the embedding source column', async function () { + await expect(convertToTable([{ id: 1 }], new DummyEmbedding())).rejects.toThrow("'string' was not present") + }) + + it('use embeddingDimension if embedding missing from table', async function () { + const schema = new Schema([ + new Field('string', new Utf8(), false) + ]) + // Simulate getting an empty Arrow table (minus embedding) from some other source + // In other words, we aren't starting with records + const table = makeEmptyTable(schema) + + // If the embedding specifies the dimension we are fine + await fromTableToBuffer(table, new DummyEmbedding()) + + // We can also supply a schema and should be ok + const schemaWithEmbedding = new Schema([ + new Field('string', new Utf8(), false), + new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false) + ]) + await fromTableToBuffer(table, new DummyEmbeddingWithNoDimension(), schemaWithEmbedding) + + // Otherwise we will get an error + await expect(fromTableToBuffer(table, new DummyEmbeddingWithNoDimension())).rejects.toThrow('does not specify `embeddingDimension`') + }) + + it('will apply embeddings to an empty table', async function () { + const schema = new Schema([ + new Field('string', new Utf8(), false), + new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false) + ]) + const table = await convertToTable([], new DummyEmbedding(), { schema }) + expect(DataType.isFixedSizeList(table.getChild('vector')?.type)).toBe(true) + expect(table.getChild('vector')?.type.children[0].type.toString()).toEqual(new Float16().toString()) + }) + + it('will complain if embeddings present but schema missing embedding column', async function () { + const schema = new Schema([ + new Field('string', new Utf8(), false) + ]) + await expect(convertToTable([], new DummyEmbedding(), { schema })).rejects.toThrow('column vector was missing') + }) + + it('will provide a nice error if run twice', async function () { + const records = sampleRecords() + const table = await convertToTable(records, new DummyEmbedding()) + // fromTableToBuffer will try and apply the embeddings again + await expect(fromTableToBuffer(table, new DummyEmbedding())).rejects.toThrow('already existed') + }) +}) + +describe('makeEmptyTable', function () { + it('will make an empty table', async function () { + await checkTableCreation(async (_, __, schema) => makeEmptyTable(schema), false) + }) +}) diff --git a/nodejs/__test__/connection.test.ts b/nodejs/__test__/connection.test.ts index 4ffcb906..202f9435 100644 --- a/nodejs/__test__/connection.test.ts +++ b/nodejs/__test__/connection.test.ts @@ -12,18 +12,49 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-import * as os from "os"; -import * as path from "path"; -import * as fs from "fs"; +import * as tmp from "tmp"; -import { connect } from "../dist/index.js"; +import { Connection, connect } from "../dist/index.js"; -describe("when working with a connection", () => { +describe("when connecting", () => { - const tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "test-connection")); + let tmpDir: tmp.DirResult; + beforeEach(() => tmpDir = tmp.dirSync({ unsafeCleanup: true })); + afterEach(() => tmpDir.removeCallback()); + + it("should connect", async() => { + const db = await connect(tmpDir.name); + expect(db.display()).toBe(`NativeDatabase(uri=${tmpDir.name}, read_consistency_interval=None)`); + }) + + it("should allow read consistency interval to be specified", async() => { + const db = await connect(tmpDir.name, { readConsistencyInterval: 5}); + expect(db.display()).toBe(`NativeDatabase(uri=${tmpDir.name}, read_consistency_interval=5s)`); + }) +}); + +describe("given a connection", () => { + + let tmpDir: tmp.DirResult + let db: Connection + beforeEach(async () => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + db = await connect(tmpDir.name) + }); + afterEach(() => tmpDir.removeCallback()); + + it("should raise an error if opening a non-existent table", async() => { + await expect(db.openTable("non-existent")).rejects.toThrow("was not found"); + }) + + it("should raise an error if any operation is tried after it is closed", async() => { + expect(db.isOpen()).toBe(true); + await db.close(); + expect(db.isOpen()).toBe(false); + await expect(db.tableNames()).rejects.toThrow("Connection is closed"); + }) it("should fail if creating table twice, unless overwrite is true", async() => { - const db = await connect(tmpDir); let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]); await expect(tbl.countRows()).resolves.toBe(2); await expect(db.createTable("test", [{ id: 1 }, { id: 2 }])).rejects.toThrow(); @@ -31,4 +62,10 @@ describe("when working with a connection", () => { await expect(tbl.countRows()).resolves.toBe(1); }) + it("should list tables", async() => { + await db.createTable("test2", [{ id: 1 }, { id: 2 }]); + await db.createTable("test1", [{ id: 1 }, { id: 2 }]); + expect(await db.tableNames()).toEqual(["test1", "test2"]); + }) + }); diff --git a/nodejs/__test__/index.test.ts b/nodejs/__test__/index.test.ts deleted file mode 100644 index dd7266ec..00000000 --- a/nodejs/__test__/index.test.ts +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 Lance Developers. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -import * as os from "os"; -import * as path from "path"; -import * as fs from "fs"; - -import { Schema, Field, Float64 } from "apache-arrow"; -import { connect } from "../dist/index.js"; - -test("open database", async () => { - const tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "test-open")); - - const db = await connect(tmpDir); - let tableNames = await db.tableNames(); - expect(tableNames).toStrictEqual([]); - - const tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]); - expect(await db.tableNames()).toStrictEqual(["test"]); - - const schema = await tbl.schema(); - expect(schema).toEqual(new Schema([new Field("id", new Float64(), true)])); -}); diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 6de039ad..29840fa7 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -12,27 +12,68 @@ // See the License for the specific language governing permissions and // limitations under the License. -import * as os from "os"; -import * as path from "path"; import * as fs from "fs"; +import * as path from "path"; +import * as tmp from "tmp"; -import { connect } from "../dist"; +import { Table, connect } from "../dist"; import { Schema, Field, Float32, Int32, FixedSizeList, Int64, Float64 } from "apache-arrow"; import { makeArrowTable } from "../dist/arrow"; +describe("Given a table", () => { + let tmpDir: tmp.DirResult; + let table: Table; + const schema = new Schema([ + new Field("id", new Float64(), true), + ]); + beforeEach(async () => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + const conn = await connect(tmpDir.name); + table = await conn.createEmptyTable("some_table", schema); + }); + afterEach(() => tmpDir.removeCallback()); + + it("be displayable", async () => { + expect(table.display()).toMatch(/NativeTable\(some_table, uri=.*, read_consistency_interval=None\)/); + table.close() + expect(table.display()).toBe("ClosedTable(some_table)") + }) + + it("should let me add data", async () => { + await table.add([{ id: 1 }, { id: 2 }]); + await table.add([{ id: 1 }]); + await expect(table.countRows()).resolves.toBe(3); + }) + + it("should overwrite data if asked", async () => { + await table.add([{ id: 1 }, { id: 2 }]); + await table.add([{ id: 1 }], { mode: "overwrite" }); + await expect(table.countRows()).resolves.toBe(1); + }) + + it("should let me close the table", async () => { + expect(table.isOpen()).toBe(true); + table.close(); + expect(table.isOpen()).toBe(false); + expect(table.countRows()).rejects.toThrow("Table some_table is closed"); + }) + +}) + describe("Test creating index", () => { - let tmpDir: string; + let tmpDir: tmp.DirResult; const schema = new Schema([ new Field("id", new Int32(), true), new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))), ]); beforeEach(() => { - tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "index-")); + tmpDir = tmp.dirSync({ unsafeCleanup: true }); }); + afterEach(() => tmpDir.removeCallback()); test("create vector index with no column", async () => { - const db = await connect(tmpDir); + const db = await connect(tmpDir.name); const data = makeArrowTable( Array(300) .fill(1) @@ -50,7 +91,7 @@ describe("Test creating index", () => { await tbl.createIndex().build(); // check index directory - const indexDir = path.join(tmpDir, "test.lance", "_indices"); + const indexDir = path.join(tmpDir.name, "test.lance", "_indices"); expect(fs.readdirSync(indexDir)).toHaveLength(1); // TODO: check index type. 
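For reference, the flow these index tests exercise can be strung together as a single sketch. This is illustrative only and not part of the diff: the table name and row count are made up, and it assumes the same `../dist` build and `tmp` helper the tests import.

```ts
import * as tmp from "tmp";
import { connect } from "../dist";

// Create a table, build a scalar index on "id", then run a filtered query.
async function indexAndQuery(): Promise<void> {
  const tmpDir = tmp.dirSync({ unsafeCleanup: true });
  try {
    const db = await connect(tmpDir.name);
    const data = Array(300)
      .fill(1)
      .map((_, i) => ({ id: i }));
    const tbl = await db.createTable("example", data);
    await tbl.createIndex("id").build();

    // Query results arrive as an async iterable of record batches.
    for await (const batch of tbl.query().filter("id > 1").select(["id"])) {
      console.log(batch);
    }
  } finally {
    tmpDir.removeCallback();
  }
}
```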
@@ -66,7 +107,7 @@ describe("Test creating index", () => { }); test("no vector column available", async () => { - const db = await connect(tmpDir); + const db = await connect(tmpDir.name); const tbl = await db.createTable( "no_vec", makeArrowTable([ @@ -79,7 +120,7 @@ describe("Test creating index", () => { ); await tbl.createIndex("val").build(); - const indexDir = path.join(tmpDir, "no_vec.lance", "_indices"); + const indexDir = path.join(tmpDir.name, "no_vec.lance", "_indices"); expect(fs.readdirSync(indexDir)).toHaveLength(1); for await (const r of tbl.query().filter("id > 1").select(["id"])) { @@ -88,7 +129,7 @@ describe("Test creating index", () => { }); test("two columns with different dimensions", async () => { - const db = await connect(tmpDir); + const db = await connect(tmpDir.name); const schema = new Schema([ new Field("id", new Int32(), true), new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))), @@ -158,7 +199,7 @@ describe("Test creating index", () => { }); test("create scalar index", async () => { - const db = await connect(tmpDir); + const db = await connect(tmpDir.name); const data = makeArrowTable( Array(300) .fill(1) @@ -176,25 +217,27 @@ describe("Test creating index", () => { await tbl.createIndex("id").build(); // check index directory - const indexDir = path.join(tmpDir, "test.lance", "_indices"); + const indexDir = path.join(tmpDir.name, "test.lance", "_indices"); expect(fs.readdirSync(indexDir)).toHaveLength(1); // TODO: check index type. }); }); describe("Read consistency interval", () => { - let tmpDir: string; + + let tmpDir: tmp.DirResult; beforeEach(() => { - tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "read-consistency-")); + tmpDir = tmp.dirSync({ unsafeCleanup: true }); }); + afterEach(() => tmpDir.removeCallback()); // const intervals = [undefined, 0, 0.1]; const intervals = [0]; test.each(intervals)("read consistency interval %p", async (interval) => { - const db = await connect({ uri: tmpDir }); + const db = await connect(tmpDir.name); const table = await db.createTable("my_table", [{ id: 1 }]); - const db2 = await connect({ uri: tmpDir, readConsistencyInterval: interval }); + const db2 = await connect(tmpDir.name, { readConsistencyInterval: interval }); const table2 = await db2.openTable("my_table"); expect(await table2.countRows()).toEqual(await table.countRows()); @@ -218,14 +261,18 @@ describe("Read consistency interval", () => { describe('schema evolution', function () { - let tmpDir: string; + + let tmpDir: tmp.DirResult; beforeEach(() => { - tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "schema-evolution-")); + tmpDir = tmp.dirSync({ unsafeCleanup: true }); }); + afterEach(() => { + tmpDir.removeCallback(); + }) // Create a new sample table it('can add a new column to the schema', async function () { - const con = await connect(tmpDir) + const con = await connect(tmpDir.name) const table = await con.createTable('vectors', [ { id: 1n, vector: [0.1, 0.2] } ]) @@ -241,7 +288,7 @@ describe('schema evolution', function () { }); it('can alter the columns in the schema', async function () { - const con = await connect(tmpDir) + const con = await connect(tmpDir.name) const schema = new Schema([ new Field('id', new Int64(), true), new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), true), @@ -268,7 +315,7 @@ describe('schema evolution', function () { }); it('can drop a column from the schema', async function () { - const con = await connect(tmpDir) + const con = await connect(tmpDir.name) const table 
= await con.createTable('vectors', [ { id: 1n, vector: [0.1, 0.2] } ]) @@ -279,4 +326,4 @@ describe('schema evolution', function () { ]) expect(await table.schema()).toEqual(expectedSchema) }); -}); \ No newline at end of file +}); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 1923eaf0..6b3c58bb 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -1,4 +1,4 @@ -// Copyright 2024 Lance Developers. +// Copyright 2023 Lance Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,44 +13,91 @@ // limitations under the License. import { - Int64, Field, + makeBuilder, + RecordBatchFileWriter, + Utf8, + type Vector, FixedSizeList, - Float, - Float32, - Schema, - Table as ArrowTable, - Table, - Vector, vectorFromArray, - tableToIPC, + type Schema, + Table as ArrowTable, + RecordBatchStreamWriter, + List, + RecordBatch, + makeData, + Struct, + type Float, DataType, -} from "apache-arrow"; + Binary, + Float32 +} from 'apache-arrow' +import { type EmbeddingFunction } from './embedding/embedding_function' +import { Table } from './native'; /** Data type accepted by NodeJS SDK */ export type Data = Record[] | ArrowTable; +/* + * Options to control how a column should be converted to a vector array + */ export class VectorColumnOptions { /** Vector column type. */ - type: Float = new Float32(); + type: Float = new Float32() - constructor(values?: Partial) { - Object.assign(this, values); + constructor (values?: Partial) { + Object.assign(this, values) } } /** Options to control the makeArrowTable call. */ export class MakeArrowTableOptions { - /** Provided schema. */ - schema?: Schema; + /* + * Schema of the data. + * + * If this is not provided then the data type will be inferred from the + * JS type. Integer numbers will become int64, floating point numbers + * will become float64 and arrays will become variable sized lists with + * the data type inferred from the first element in the array. + * + * The schema must be specified if there are no records (e.g. to make + * an empty table) + */ + schema?: Schema - /** Vector columns */ + /* + * Mapping from vector column name to expected type + * + * Lance expects vector columns to be fixed size list arrays (i.e. tensors) + * However, `makeArrowTable` will not infer this by default (it creates + * variable size list arrays). This field can be used to indicate that a column + * should be treated as a vector column and converted to a fixed size list. + * + * The keys should be the names of the vector columns. The value specifies the + * expected data type of the vector columns. + * + * If `schema` is provided then this field is ignored. + * + * By default, the column named "vector" will be assumed to be a float32 + * vector column. + */ vectorColumns: Record = { - vector: new VectorColumnOptions(), - }; + vector: new VectorColumnOptions() + } - constructor(values?: Partial) { - Object.assign(this, values); + /** + * If true then string columns will be encoded with dictionary encoding + * + * Set this to true if your string columns tend to repeat the same values + * often. For more precise control use the `schema` property to specify the + * data type for individual columns. + * + * If `schema` is provided then this property is ignored. 
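To illustrate the flag documented here, a short sketch drawn from the accompanying tests rather than from the diff itself (the `./arrow` import path is assumed):

```ts
import { DataType } from "apache-arrow";
import { makeArrowTable } from "./arrow";

// Strings are inferred as Utf8 by default...
const plain = makeArrowTable([{ str: "hello" }]);
console.log(DataType.isUtf8(plain.getChild("str")?.type)); // true

// ...but can be dictionary encoded when values repeat often.
const dict = makeArrowTable([{ str: "hello" }], { dictionaryEncodeStrings: true });
console.log(DataType.isDictionary(dict.getChild("str")?.type)); // true
```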
+ */ + dictionaryEncodeStrings: boolean = false + + constructor (values?: Partial) { + Object.assign(this, values) } } @@ -58,8 +105,30 @@ export class MakeArrowTableOptions { * An enhanced version of the {@link makeTable} function from Apache Arrow * that supports nested fields and embeddings columns. * + * This function converts an array of Record (row-major JS objects) + * to an Arrow Table (a columnar structure) + * * Note that it currently does not support nulls. * + * If a schema is provided then it will be used to determine the resulting array + * types. Fields will also be reordered to fit the order defined by the schema. + * + * If a schema is not provided then the types will be inferred and the field order + * will be controlled by the order of properties in the first record. If a type + * is inferred it will always be nullable. + * + * If the input is empty then a schema must be provided to create an empty table. + * + * When a schema is not specified then data types will be inferred. The inference + * rules are as follows: + * + * - boolean => Bool + * - number => Float64 + * - String => Utf8 + * - Buffer => Binary + * - Record => Struct + * - Array => List + * * @param data input data * @param options options to control the makeArrowTable call. * @@ -82,25 +151,27 @@ export class MakeArrowTableOptions { * ], { schema }); * ``` * - * It guesses the vector columns if the schema is not provided. For example, - * by default it assumes that the column named `vector` is a vector column. + * By default it assumes that the column named `vector` is a vector column + * and it will be converted into a fixed size list array of type float32. + * The `vectorColumns` option can be used to support other vector column + * names and data types. * * ```ts * * const schema = new Schema([ - new Field("a", new Float64()), - new Field("b", new Float64()), - new Field( - "vector", - new FixedSizeList(3, new Field("item", new Float32())) - ), - ]); - const table = makeArrowTable([ - { a: 1, b: 2, vector: [1, 2, 3] }, - { a: 4, b: 5, vector: [4, 5, 6] }, - { a: 7, b: 8, vector: [7, 8, 9] }, - ]); - assert.deepEqual(table.schema, schema); + new Field("a", new Float64()), + new Field("b", new Float64()), + new Field( + "vector", + new FixedSizeList(3, new Field("item", new Float32())) + ), + ]); + const table = makeArrowTable([ + { a: 1, b: 2, vector: [1, 2, 3] }, + { a: 4, b: 5, vector: [4, 5, 6] }, + { a: 7, b: 8, vector: [7, 8, 9] }, + ]); + assert.deepEqual(table.schema, schema); * ``` * * You can specify the vector column types and names using the options as well @@ -108,81 +179,372 @@ export class MakeArrowTableOptions { * ```typescript * * const schema = new Schema([ - new Field('a', new Float64()), - new Field('b', new Float64()), - new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))), - new Field('vec2', new FixedSizeList(3, new Field('item', new Float16()))) - ]); + new Field('a', new Float64()), + new Field('b', new Float64()), + new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))), + new Field('vec2', new FixedSizeList(3, new Field('item', new Float16()))) + ]); * const table = makeArrowTable([ - { a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] }, - { a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] }, - { a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] } - ], { - vectorColumns: { - vec1: { type: new Float16() }, - vec2: { type: new Float16() } - } - } + { a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] }, + { a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] 
}, + { a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] } + ], { + vectorColumns: { + vec1: { type: new Float16() }, + vec2: { type: new Float16() } + } + } * assert.deepEqual(table.schema, schema) * ``` */ -export function makeArrowTable( - data: Record[], +export function makeArrowTable ( + data: Array>, options?: Partial -): Table { - if (data.length === 0) { - throw new Error("At least one record needs to be provided"); +): ArrowTable { + if (data.length === 0 && (options?.schema === undefined || options?.schema === null)) { + throw new Error('At least one record or a schema needs to be provided') } - const opt = new MakeArrowTableOptions(options ?? {}); - const columns: Record = {}; - // TODO: sample dataset to find missing columns - const columnNames = Object.keys(data[0]); - for (const colName of columnNames) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-return - let values = data.map((datum) => datum[colName]); - let vector: Vector; + const opt = new MakeArrowTableOptions(options !== undefined ? options : {}) + const columns: Record = {} + // TODO: sample dataset to find missing columns + // Prefer the field ordering of the schema, if present + const columnNames = ((options?.schema) != null) ? (options?.schema?.names as string[]) : Object.keys(data[0]) + for (const colName of columnNames) { + if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) { + // The field is present in the schema, but not in the data, skip it + continue + } + // Extract a single column from the records (transpose from row-major to col-major) + let values = data.map((datum) => datum[colName]) + + // By default (type === undefined) arrow will infer the type from the JS type + let type if (opt.schema !== undefined) { - // Explicit schema is provided, highest priority - const fieldType: DataType | undefined = opt.schema.fields.filter((f) => f.name === colName)[0]?.type as DataType; - if (fieldType instanceof Int64) { + // If there is a schema provided, then use that for the type instead + type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type + if (DataType.isInt(type) && type.bitWidth === 64) { // wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051 - // eslint-disable-next-line @typescript-eslint/no-unsafe-argument - values = values.map((v) => BigInt(v)); + values = values.map((v) => { + if (v === null) { + return v + } + return BigInt(v) + }) } - vector = vectorFromArray(values, fieldType); } else { - const vectorColumnOptions = opt.vectorColumns[colName]; + // Otherwise, check to see if this column is one of the vector columns + // defined by opt.vectorColumns and, if so, use the fixed size list type + const vectorColumnOptions = opt.vectorColumns[colName] if (vectorColumnOptions !== undefined) { - const fslType = new FixedSizeList( - (values[0] as any[]).length, - new Field("item", vectorColumnOptions.type, false) - ); - vector = vectorFromArray(values, fslType); - } else { - // Normal case - vector = vectorFromArray(values); + type = newVectorType(values[0].length, vectorColumnOptions.type) } } - columns[colName] = vector; + + try { + // Convert an Array of JS values to an arrow vector + columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings) + } catch (error: unknown) { + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + throw Error(`Could not convert column "${colName}" to Arrow: ${error}`) + } } - return new Table(columns); + if (opt.schema != null) { + // `new ArrowTable(columns)` 
infers a schema which may sometimes have + // incorrect nullability (it assumes nullable=true always) + // + // `new ArrowTable(schema, columns)` will also fail because it will create a + // batch with an inferred schema and then complain that the batch schema + // does not match the provided schema. + // + // To work around this we first create a table with the wrong schema and + // then patch the schema of the batches so we can use + // `new ArrowTable(schema, batches)` which does not do any schema inference + const firstTable = new ArrowTable(columns) + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const batchesFixed = firstTable.batches.map(batch => new RecordBatch(opt.schema!, batch.data)) + return new ArrowTable(opt.schema, batchesFixed) + } else { + return new ArrowTable(columns) + } } /** - * Convert an Arrow Table to a Buffer. - * - * @param data Arrow Table - * @param schema Arrow Schema, optional - * @returns Buffer node + * Create an empty Arrow table with the provided schema */ -export function toBuffer(data: Data, schema?: Schema): Buffer { - let tbl: Table; - if (data instanceof Table) { - tbl = data; - } else { - tbl = makeArrowTable(data, { schema }); - } - return Buffer.from(tableToIPC(tbl)); +export function makeEmptyTable (schema: Schema): ArrowTable { + return makeArrowTable([], { schema }) +} + +// Helper function to convert Array> to a variable sized list array +function makeListVector (lists: any[][]): Vector { + if (lists.length === 0 || lists[0].length === 0) { + throw Error('Cannot infer list vector from empty array or empty list') + } + const sampleList = lists[0] + let inferredType + try { + const sampleVector = makeVector(sampleList) + inferredType = sampleVector.type + } catch (error: unknown) { + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`) + } + + const listBuilder = makeBuilder({ + type: new List(new Field('item', inferredType, true)) + }) + for (const list of lists) { + listBuilder.append(list) + } + return listBuilder.finish().toVector() +} + +// Helper function to convert an Array of JS values to an Arrow Vector +function makeVector (values: any[], type?: DataType, stringAsDictionary?: boolean): Vector { + if (type !== undefined) { + // No need for inference, let Arrow create it + return vectorFromArray(values, type) + } + if (values.length === 0) { + throw Error('makeVector requires at least one value or the type must be specfied') + } + const sampleValue = values.find(val => val !== null && val !== undefined) + if (sampleValue === undefined) { + throw Error('makeVector cannot infer the type if all values are null or undefined') + } + if (Array.isArray(sampleValue)) { + // Default Arrow inference doesn't handle list types + return makeListVector(values) + } else if (Buffer.isBuffer(sampleValue)) { + // Default Arrow inference doesn't handle Buffer + return vectorFromArray(values, new Binary()) + } else if (!(stringAsDictionary ?? 
false) && (typeof sampleValue === 'string' || sampleValue instanceof String)) { + // If the type is string then don't use Arrow's default inference unless dictionaries are requested + // because it will always use dictionary encoding for strings + return vectorFromArray(values, new Utf8()) + } else { + // Convert a JS array of values to an arrow vector + return vectorFromArray(values) + } +} + +async function applyEmbeddings (table: ArrowTable, embeddings?: EmbeddingFunction, schema?: Schema): Promise { + if (embeddings == null) { + return table + } + + // Convert from ArrowTable to Record + const colEntries = [...Array(table.numCols).keys()].map((_, idx) => { + const name = table.schema.fields[idx].name + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const vec = table.getChildAt(idx)! + return [name, vec] + }) + const newColumns = Object.fromEntries(colEntries) + + const sourceColumn = newColumns[embeddings.sourceColumn] + const destColumn = embeddings.destColumn ?? 'vector' + const innerDestType = embeddings.embeddingDataType ?? new Float32() + if (sourceColumn === undefined) { + throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`) + } + + if (table.numRows === 0) { + if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) { + // We have an empty table and it already has the embedding column so no work needs to be done + // Note: we don't return an error like we did below because this is a common occurrence. For example, + // if we call convertToTable with 0 records and a schema that includes the embedding + return table + } + if (embeddings.embeddingDimension !== undefined) { + const destType = newVectorType(embeddings.embeddingDimension, innerDestType) + newColumns[destColumn] = makeVector([], destType) + } else if (schema != null) { + const destField = schema.fields.find(f => f.name === destColumn) + if (destField != null) { + newColumns[destColumn] = makeVector([], destField.type) + } else { + throw new Error(`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`) + } + } else { + throw new Error('Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`') + } + } else { + if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) { + throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`) + } + if (table.batches.length > 1) { + throw new Error('Internal error: `makeArrowTable` unexpectedly created a table with more than one batch') + } + const values = sourceColumn.toArray() + const vectors = await embeddings.embed(values as T[]) + if (vectors.length !== values.length) { + throw new Error('Embedding function did not return an embedding for each input element') + } + const destType = newVectorType(vectors[0].length, innerDestType) + newColumns[destColumn] = makeVector(vectors, destType) + } + + const newTable = new ArrowTable(newColumns) + if (schema != null) { + if (schema.fields.find(f => f.name === destColumn) === undefined) { + throw new Error(`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`) + } + return alignTable(newTable, schema) + } + return newTable +} + +/* + * Convert an Array of records into an Arrow Table, optionally applying an + * embeddings function to it. 
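As a concrete illustration of this conversion-plus-embedding path, the sketch below mirrors the `DummyEmbedding` used in the tests; the class name, sample data, and import paths are illustrative, not part of the change:

```ts
import { Float16 } from "apache-arrow";
import { convertToTable } from "./arrow";
import { type EmbeddingFunction } from "./embedding/embedding_function";

// A toy embedding: reads the "string" column and emits 2-dimensional Float16 vectors.
class ToyEmbedding implements EmbeddingFunction<string> {
  public readonly sourceColumn = "string";
  public readonly embeddingDimension = 2;
  public readonly embeddingDataType = new Float16();

  async embed(data: string[]): Promise<number[][]> {
    return data.map(() => [0.0, 0.0]);
  }
}

async function demo(): Promise<void> {
  // With no schema given, the embedding is appended as a "vector" FixedSizeList column.
  const table = await convertToTable([{ string: "hello" }], new ToyEmbedding());
  console.log(table.schema.fields.map((f) => f.name)); // ["string", "vector"]
}
```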
+ * + * This function calls `makeArrowTable` first to create the Arrow Table. + * Any provided `makeTableOptions` (e.g. a schema) will be passed on to + * that call. + * + * The embedding function will be passed a column of values (based on the + * `sourceColumn` of the embedding function) and expects to receive back + * number[][] which will be converted into a fixed size list column. By + * default this will be a fixed size list of Float32 but that can be + * customized by the `embeddingDataType` property of the embedding function. + * + * If a schema is provided in `makeTableOptions` then it should include the + * embedding columns. If no schema is provded then embedding columns will + * be placed at the end of the table, after all of the input columns. + */ +export async function convertToTable ( + data: Array>, + embeddings?: EmbeddingFunction, + makeTableOptions?: Partial +): Promise { + const table = makeArrowTable(data, makeTableOptions) + return await applyEmbeddings(table, embeddings, makeTableOptions?.schema) +} + +// Creates the Arrow Type for a Vector column with dimension `dim` +function newVectorType (dim: number, innerType: T): FixedSizeList { + // in Lance we always default to have the elements nullable, so we need to set it to true + // otherwise we often get schema mismatches because the stored data always has schema with nullable elements + const children = new Field('item', innerType, true) + return new FixedSizeList(dim, children) +} + +/** + * Serialize an Array of records into a buffer using the Arrow IPC File serialization + * + * This function will call `convertToTable` and pass on `embeddings` and `schema` + * + * `schema` is required if data is empty + */ +export async function fromRecordsToBuffer ( + data: Array>, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { + const table = await convertToTable(data, embeddings, { schema }) + const writer = RecordBatchFileWriter.writeAll(table) + return Buffer.from(await writer.toUint8Array()) +} + +/** + * Serialize an Array of records into a buffer using the Arrow IPC Stream serialization + * + * This function will call `convertToTable` and pass on `embeddings` and `schema` + * + * `schema` is required if data is empty + */ +export async function fromRecordsToStreamBuffer ( + data: Array>, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { + const table = await convertToTable(data, embeddings, { schema }) + const writer = RecordBatchStreamWriter.writeAll(table) + return Buffer.from(await writer.toUint8Array()) +} + +/** + * Serialize an Arrow Table into a buffer using the Arrow IPC File serialization + * + * This function will apply `embeddings` to the table in a manner similar to + * `convertToTable`. 
+ * + * `schema` is required if the table is empty + */ +export async function fromTableToBuffer ( + table: ArrowTable, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { + const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema) + const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings) + return Buffer.from(await writer.toUint8Array()) +} + +export async function fromDataToBuffer ( + data: Data, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { + if (data instanceof ArrowTable) { + return fromTableToBuffer(data, embeddings, schema) + } else { + const table = await convertToTable(data); + return fromTableToBuffer(table, embeddings, schema); + } +} + +/** + * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization + * + * This function will apply `embeddings` to the table in a manner similar to + * `convertToTable`. + * + * `schema` is required if the table is empty + */ +export async function fromTableToStreamBuffer ( + table: ArrowTable, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { + const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema) + const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings) + return Buffer.from(await writer.toUint8Array()) +} + +function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch { + const alignedChildren = [] + for (const field of schema.fields) { + const indexInBatch = batch.schema.fields?.findIndex( + (f) => f.name === field.name + ) + if (indexInBatch < 0) { + throw new Error( + `The column ${field.name} was not found in the Arrow Table` + ) + } + alignedChildren.push(batch.data.children[indexInBatch]) + } + const newData = makeData({ + type: new Struct(schema.fields), + length: batch.numRows, + nullCount: batch.nullCount, + children: alignedChildren + }) + return new RecordBatch(schema, newData) +} + +function alignTable (table: ArrowTable, schema: Schema): ArrowTable { + const alignedBatches = table.batches.map((batch) => + alignBatch(batch, schema) + ) + return new ArrowTable(schema, alignedBatches) +} + +// Creates an empty Arrow Table +export function createEmptyTable (schema: Schema): ArrowTable { + return new ArrowTable(schema) } diff --git a/nodejs/lancedb/connection.ts b/nodejs/lancedb/connection.ts index 46b1109f..3f24d42d 100644 --- a/nodejs/lancedb/connection.ts +++ b/nodejs/lancedb/connection.ts @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { toBuffer } from "./arrow"; -import { Connection as _NativeConnection } from "./native"; +import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow"; +import { Connection as LanceDbConnection } from "./native"; import { Table } from "./table"; -import { Table as ArrowTable } from "apache-arrow"; +import { Table as ArrowTable, Schema } from "apache-arrow"; export interface CreateTableOptions { /** @@ -39,14 +39,47 @@ export interface CreateTableOptions { * A LanceDB Connection that allows you to open tables and create new ones. * * Connection could be local against filesystem or remote against a server. + * + * A Connection is intended to be a long lived object and may hold open + * resources such as HTTP connection pools. This is generally fine and + * a single connection should be shared if it is going to be used many + * times. However, if you are finished with a connection, you may call + * close to eagerly free these resources. 
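The connection lifecycle this comment describes looks roughly as follows in practice; a minimal sketch based on the behaviour asserted in `connection.test.ts` (the import path and URI are placeholders):

```ts
import { connect } from "./index";

async function lifecycle(uri: string): Promise<void> {
  const db = await connect(uri, { readConsistencyInterval: 5 });
  console.log(db.display()); // NativeDatabase(uri=..., read_consistency_interval=5s)
  console.log(await db.tableNames());

  db.close(); // eagerly release resources; safe to call more than once
  console.log(db.isOpen()); // false
  // Any further call, e.g. db.tableNames(), now rejects with "Connection is closed".
}
```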
Any call to a Connection
+ * method after it has been closed will result in an error.
+ *
+ * Closing a connection is optional. Connections will automatically
+ * be closed when they are garbage collected.
+ *
+ * Any created tables are independent and will continue to work even if
+ * the underlying connection has been closed.
  */
 export class Connection {
-  readonly inner: _NativeConnection;
+  readonly inner: LanceDbConnection;

-  constructor(inner: _NativeConnection) {
+  constructor(inner: LanceDbConnection) {
     this.inner = inner;
   }

+  /** Return true if the connection has not been closed */
+  isOpen(): boolean {
+    return this.inner.isOpen();
+  }
+
+  /** Close the connection, releasing any underlying resources.
+   *
+   * It is safe to call this method multiple times.
+   *
+   * Any attempt to use the connection after it is closed will result in an error.
+   */
+  close(): void {
+    this.inner.close();
+  }
+
+  /** Return a brief description of the connection */
+  display(): string {
+    return this.inner.display();
+  }
+
   /** List all the table names in this database. */
   async tableNames(): Promise<string[]> {
     return this.inner.tableNames();
@@ -81,11 +114,41 @@ export class Connection {
       mode = "exist_ok";
     }

-    const buf = toBuffer(data);
+    let table: ArrowTable;
+    if (data instanceof ArrowTable) {
+      table = data;
+    } else {
+      table = makeArrowTable(data);
+    }
+    const buf = await fromTableToBuffer(table);
     const innerTable = await this.inner.createTable(name, buf, mode);
     return new Table(innerTable);
   }

+  /**
+   * Creates a new empty Table
+   *
+   * @param {string} name - The name of the table.
+   * @param schema - The schema of the table
+   */
+  async createEmptyTable(
+    name: string,
+    schema: Schema,
+    options?: Partial<CreateTableOptions>
+  ): Promise<Table>
{ + let mode: string = options?.mode ?? "create"; + const existOk = options?.existOk ?? false; + + if (mode === "create" && existOk) { + mode = "exist_ok"; + } + + const table = makeEmptyTable(schema); + const buf = await fromTableToBuffer(table); + const innerTable = await this.inner.createEmptyTable(name, buf, mode); + return new Table(innerTable); + } + /** * Drop an existing table. * @param name The name of the table to drop. diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts new file mode 100644 index 00000000..4900a976 --- /dev/null +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -0,0 +1,68 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { type Float } from 'apache-arrow' + +/** + * An embedding function that automatically creates vector representation for a given column. + */ +export interface EmbeddingFunction { + /** + * The name of the column that will be used as input for the Embedding Function. + */ + sourceColumn: string + + /** + * The data type of the embedding + * + * The embedding function should return `number`. This will be converted into + * an Arrow float array. By default this will be Float32 but this property can + * be used to control the conversion. + */ + embeddingDataType?: Float + + /** + * The dimension of the embedding + * + * This is optional, normally this can be determined by looking at the results of + * `embed`. If this is not specified, and there is an attempt to apply the embedding + * to an empty table, then that process will fail. + */ + embeddingDimension?: number + + /** + * The name of the column that will contain the embedding + * + * By default this is "vector" + */ + destColumn?: string + + /** + * Should the source column be excluded from the resulting table + * + * By default the source column is included. Set this to true and + * only the embedding will be stored. + */ + excludeSource?: boolean + + /** + * Creates a vector representation for the given values. + */ + embed: (data: T[]) => Promise +} + +export function isEmbeddingFunction (value: any): value is EmbeddingFunction { + return typeof value.sourceColumn === 'string' && + typeof value.embed === 'function' +} diff --git a/nodejs/lancedb/embedding/openai.ts b/nodejs/lancedb/embedding/openai.ts new file mode 100644 index 00000000..354e470f --- /dev/null +++ b/nodejs/lancedb/embedding/openai.ts @@ -0,0 +1,57 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +import { type EmbeddingFunction } from './embedding_function' +import type OpenAI from 'openai' + +export class OpenAIEmbeddingFunction implements EmbeddingFunction { + private readonly _openai: OpenAI + private readonly _modelName: string + + constructor (sourceColumn: string, openAIKey: string, modelName: string = 'text-embedding-ada-002') { + /** + * @type {import("openai").default} + */ + let Openai + try { + // eslint-disable-next-line @typescript-eslint/no-var-requires + Openai = require('openai') + } catch { + throw new Error('please install openai@^4.24.1 using npm install openai') + } + + this.sourceColumn = sourceColumn + const configuration = { + apiKey: openAIKey + } + + this._openai = new Openai(configuration) + this._modelName = modelName + } + + async embed (data: string[]): Promise { + const response = await this._openai.embeddings.create({ + model: this._modelName, + input: data + }) + + const embeddings: number[][] = [] + for (let i = 0; i < response.data.length; i++) { + embeddings.push(response.data[i].embedding) + } + return embeddings + } + + sourceColumn: string +} diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index fab396c4..4503bdc9 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -13,7 +13,7 @@ // limitations under the License. import { Connection } from "./connection"; -import { Connection as NativeConnection, ConnectionOptions } from "./native.js"; +import { Connection as LanceDbConnection, ConnectionOptions } from "./native.js"; export { ConnectionOptions, @@ -23,7 +23,6 @@ export { } from "./native.js"; export { Connection } from "./connection"; export { Table } from "./table"; -export { Data } from "./arrow"; export { IvfPQOptions, IndexBuilder } from "./indexer"; /** @@ -39,26 +38,9 @@ export { IvfPQOptions, IndexBuilder } from "./indexer"; * * @see {@link ConnectionOptions} for more details on the URI format. */ -export async function connect(uri: string): Promise; -export async function connect( - opts: Partial -): Promise; -export async function connect( - args: string | Partial -): Promise { - let opts: ConnectionOptions; - if (typeof args === "string") { - opts = { uri: args }; - } else { - opts = Object.assign( - { - uri: "", - apiKey: undefined, - hostOverride: undefined, - }, - args - ); - } - const nativeConn = await NativeConnection.new(opts); +export async function connect(uri: string, opts?: Partial): Promise +{ + opts = opts ?? {}; + const nativeConn = await LanceDbConnection.new(uri, opts); return new Connection(nativeConn); } diff --git a/nodejs/lancedb/native.d.ts b/nodejs/lancedb/native.d.ts index e72b54cb..baa2199e 100644 --- a/nodejs/lancedb/native.d.ts +++ b/nodejs/lancedb/native.d.ts @@ -45,7 +45,6 @@ export interface AddColumnsSql { valueSql: string } export interface ConnectionOptions { - uri: string apiKey?: string hostOverride?: string /** @@ -71,10 +70,13 @@ export const enum WriteMode { export interface WriteOptions { mode?: WriteMode } -export function connect(options: ConnectionOptions): Promise +export function connect(uri: string, options: ConnectionOptions): Promise export class Connection { /** Create a new Connection instance from the given URI. */ - static new(options: ConnectionOptions): Promise + static new(uri: string, options: ConnectionOptions): Promise + display(): string + isOpen(): boolean + close(): void /** List all tables in the dataset. 
*/
   tableNames(): Promise<Array<string>>
   /**
@@ -86,6 +88,7 @@ export class Connection {
    *
    */
   createTable(name: string, buf: Buffer, mode: string): Promise<Table>
+  createEmptyTable(name: string, schemaBuf: Buffer, mode: string): Promise<Table>
   openTable(name: string): Promise<Table>
/** Drop table with the name. Or raise an error if the table does not exist. */ dropTable(name: string): Promise @@ -114,9 +117,12 @@ export class Query { executeStream(): Promise } export class Table { + display(): string + isOpen(): boolean + close(): void /** Return Schema as empty Arrow IPC file. */ schema(): Promise - add(buf: Buffer): Promise + add(buf: Buffer, mode: string): Promise countRows(filter?: string | undefined | null): Promise delete(predicate: string): Promise createIndex(): IndexBuilder diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index e2ef723a..342ca631 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -14,14 +14,32 @@ import { Schema, tableFromIPC } from "apache-arrow"; import { AddColumnsSql, ColumnAlteration, Table as _NativeTable } from "./native"; -import { toBuffer, Data } from "./arrow"; import { Query } from "./query"; import { IndexBuilder } from "./indexer"; +import { Data, fromDataToBuffer } from "./arrow"; /** - * A LanceDB Table is the collection of Records. + * Options for adding data to a table. + */ +export interface AddDataOptions { + /** If "append" (the default) then the new data will be added to the table + * + * If "overwrite" then the new data will replace the existing data in the table. + */ + mode: "append" | "overwrite"; +} + +/** + * A Table is a collection of Records in a LanceDB Database. * - * Each Record has one or more vector fields. + * A Table object is expected to be long lived and reused for multiple operations. + * Table objects will cache a certain amount of index data in memory. This cache + * will be freed when the Table is garbage collected. To eagerly free the cache you + * can call the `close` method. Once the Table is closed, it cannot be used for any + * further operations. + * + * Closing a table is optional. It not closed, it will be closed when it is garbage + * collected. */ export class Table { private readonly inner: _NativeTable; @@ -31,6 +49,27 @@ export class Table { this.inner = inner; } + + /** Return true if the table has not been closed */ + isOpen(): boolean { + return this.inner.isOpen(); + } + + /** Close the table, releasing any underlying resources. + * + * It is safe to call this method multiple times. + * + * Any attempt to use the table after it is closed will result in an error. + */ + close(): void { + this.inner.close(); + } + + /** Return a brief description of the table */ + display(): string { + return this.inner.display(); + } + /** Get the schema of the table. */ async schema(): Promise { const schemaBuf = await this.inner.schema(); @@ -44,9 +83,11 @@ export class Table { * @param {Data} data Records to be inserted into the Table * @return The number of rows added to the table */ - async add(data: Data): Promise { - const buffer = toBuffer(data); - await this.inner.add(buffer); + async add(data: Data, options?: Partial): Promise { + let mode = options?.mode ?? "append"; + + const buffer = await fromDataToBuffer(data); + await this.inner.add(buffer, mode); } /** Count the total number of rows in the dataset. 
*/ diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index b0f33580..e9fa2c8e 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -1,11 +1,11 @@ { - "name": "vectordb", + "name": "lancedb", "version": "0.4.3", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "vectordb", + "name": "lancedb", "version": "0.4.3", "cpu": [ "x64", @@ -17,16 +17,15 @@ "linux", "windows" ], - "dependencies": { - "apache-arrow": "^15.0.0" - }, "devDependencies": { "@napi-rs/cli": "^2.18.0", "@types/jest": "^29.1.2", + "@types/tmp": "^0.2.6", "@typescript-eslint/eslint-plugin": "^6.19.0", "@typescript-eslint/parser": "^6.19.0", "eslint": "^8.56.0", "jest": "^29.7.0", + "tmp": "^0.2.3", "ts-jest": "^29.1.2", "typedoc": "^0.25.7", "typedoc-plugin-markdown": "^3.17.1", @@ -36,16 +35,21 @@ "node": ">= 18" }, "optionalDependencies": { - "vectordb-darwin-arm64": "0.4.3", - "vectordb-darwin-x64": "0.4.3", - "vectordb-linux-arm64-gnu": "0.4.3", - "vectordb-linux-x64-gnu": "0.4.3" + "lancedb-darwin-arm64": "0.4.3", + "lancedb-darwin-x64": "0.4.3", + "lancedb-linux-arm64-gnu": "0.4.3", + "lancedb-linux-x64-gnu": "0.4.3", + "openai": "^4.28.4" + }, + "peerDependencies": { + "apache-arrow": "^15.0.0" } }, "node_modules/@75lb/deep-merge": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@75lb/deep-merge/-/deep-merge-1.1.1.tgz", "integrity": "sha512-xvgv6pkMGBA6GwdyJbNAnDmfAIR/DfWhrj9jgWh3TY7gRm3KO46x/GPjRg6wJ0nOepwqrNxFfojebh0Df4h4Tw==", + "peer": true, "dependencies": { "lodash.assignwith": "^4.2.0", "typical": "^7.1.1" @@ -58,6 +62,7 @@ "version": "7.1.1", "resolved": "https://registry.npmjs.org/typical/-/typical-7.1.1.tgz", "integrity": "sha512-T+tKVNs6Wu7IWiAce5BgMd7OZfNYUndHwc5MknN+UHOudi7sGZzuHdCadllRuqJ3fPtgFtIH9+lt9qRv6lmpfA==", + "peer": true, "engines": { "node": ">=12.17" } @@ -1416,9 +1421,10 @@ } }, "node_modules/@swc/helpers": { - "version": "0.5.3", - "resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.3.tgz", - "integrity": "sha512-FaruWX6KdudYloq1AHD/4nU+UsMTdNE8CKyrseXWEcgjDAbvkwJg2QGPAnfIJLIWsjZOSPLOAykK6fuYp4vp4A==", + "version": "0.5.6", + "resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.6.tgz", + "integrity": "sha512-aYX01Ke9hunpoCexYAgQucEpARGQ5w/cqHFrIR+e9gdKb1QWTsVJuTJ2ozQzIAxLyRQe/m+2RqzkyOOGiMKRQA==", + "peer": true, "dependencies": { "tslib": "^2.4.0" } @@ -1467,12 +1473,14 @@ "node_modules/@types/command-line-args": { "version": "5.2.3", "resolved": "https://registry.npmjs.org/@types/command-line-args/-/command-line-args-5.2.3.tgz", - "integrity": "sha512-uv0aG6R0Y8WHZLTamZwtfsDLVRnOa+n+n5rEvFWL5Na5gZ8V2Teab/duDPFzIIIhs9qizDpcavCusCLJZu62Kw==" + "integrity": "sha512-uv0aG6R0Y8WHZLTamZwtfsDLVRnOa+n+n5rEvFWL5Na5gZ8V2Teab/duDPFzIIIhs9qizDpcavCusCLJZu62Kw==", + "peer": true }, "node_modules/@types/command-line-usage": { - "version": "5.0.4", - "resolved": "https://registry.npmjs.org/@types/command-line-usage/-/command-line-usage-5.0.4.tgz", - "integrity": "sha512-BwR5KP3Es/CSht0xqBcUXS3qCAUVXwpRKsV2+arxeb65atasuXG9LykC9Ab10Cw3s2raH92ZqOeILaQbsB2ACg==" + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/@types/command-line-usage/-/command-line-usage-5.0.2.tgz", + "integrity": "sha512-n7RlEEJ+4x4TS7ZQddTmNSxP+zziEG0TNsMfiRIxcIVXt71ENJ9ojeXmGO3wPoTdn7pJcU2xc3CJYMktNT6DPg==", + "peer": true }, "node_modules/@types/graceful-fs": { "version": "4.1.9", @@ -1531,6 +1539,16 @@ "undici-types": "~5.26.4" } }, + "node_modules/@types/node-fetch": { + "version": "2.6.11", + "resolved": 
"https://registry.npmjs.org/@types/node-fetch/-/node-fetch-2.6.11.tgz", + "integrity": "sha512-24xFj9R5+rfQJLRyM56qh+wnVSYhyXC2tkoBndtY0U+vubqNsYXGjufB2nn8Q6gt0LrARwL6UBtMCSVCwl4B1g==", + "optional": true, + "dependencies": { + "@types/node": "*", + "form-data": "^4.0.0" + } + }, "node_modules/@types/semver": { "version": "7.5.6", "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.5.6.tgz", @@ -1543,6 +1561,12 @@ "integrity": "sha512-9aEbYZ3TbYMznPdcdr3SmIrLXwC/AKZXQeCf9Pgao5CKb8CyHuEX5jzWPTkvregvhRJHcpRO6BFoGW9ycaOkYw==", "dev": true }, + "node_modules/@types/tmp": { + "version": "0.2.6", + "resolved": "https://registry.npmjs.org/@types/tmp/-/tmp-0.2.6.tgz", + "integrity": "sha512-chhaNf2oKHlRkDGt+tiKE2Z5aJ6qalm7Z9rlLdBwmOiAAf09YQvvoLXjWK4HWPF1xU/fqvMgfNfpVoBscA/tKA==", + "dev": true + }, "node_modules/@types/yargs": { "version": "17.0.32", "resolved": "https://registry.npmjs.org/@types/yargs/-/yargs-17.0.32.tgz", @@ -1807,6 +1831,18 @@ "integrity": "sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==", "dev": true }, + "node_modules/abort-controller": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/abort-controller/-/abort-controller-3.0.0.tgz", + "integrity": "sha512-h8lQ8tacZYnR3vNQTgibj+tODHI5/+l06Au2Pcriv/Gmet0eaj4TwWH41sO9wnHDiQsEj19q0drzdWdeAHtweg==", + "optional": true, + "dependencies": { + "event-target-shim": "^5.0.0" + }, + "engines": { + "node": ">=6.5" + } + }, "node_modules/acorn": { "version": "8.11.3", "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.11.3.tgz", @@ -1828,6 +1864,18 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, + "node_modules/agentkeepalive": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/agentkeepalive/-/agentkeepalive-4.5.0.tgz", + "integrity": "sha512-5GG/5IbQQpC9FpkRGsSvZI5QYeSCzlJHdpBQntCsuTOxhKD8lqKhrleg2Yi7yvMIf82Ycmmqln9U8V9qwEiJew==", + "optional": true, + "dependencies": { + "humanize-ms": "^1.2.1" + }, + "engines": { + "node": ">= 8.0.0" + } + }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -1917,6 +1965,7 @@ "version": "15.0.0", "resolved": "https://registry.npmjs.org/apache-arrow/-/apache-arrow-15.0.0.tgz", "integrity": "sha512-e6aunxNKM+woQf137ny3tp/xbLjFJS2oGQxQhYGqW6dGeIwNV1jOeEAeR6sS2jwAI2qLO83gYIP2MBz02Gw5Xw==", + "peer": true, "dependencies": { "@swc/helpers": "^0.5.2", "@types/command-line-args": "^5.2.1", @@ -1945,6 +1994,7 @@ "version": "3.1.0", "resolved": "https://registry.npmjs.org/array-back/-/array-back-3.1.0.tgz", "integrity": "sha512-TkuxA4UCOvxuDK6NZYXCalszEzj+TLszyASooky+i742l9TqsOdYCMJJupxRic61hwquNtppB3hgcuq9SVSH1Q==", + "peer": true, "engines": { "node": ">=6" } @@ -1958,6 +2008,12 @@ "node": ">=8" } }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", + "optional": true + }, "node_modules/babel-jest": { "version": "29.7.0", "resolved": "https://registry.npmjs.org/babel-jest/-/babel-jest-29.7.0.tgz", @@ -2089,6 +2145,12 @@ "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", "dev": true }, + "node_modules/base-64": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/base-64/-/base-64-0.1.0.tgz", + "integrity": 
"sha512-Y5gU45svrR5tI2Vt/X9GPd3L0HNIKzGu202EjxrXMpuc2V2CiKgemAbUUsqYmZJvPtCXoUKjNZwBJzsNScUbXA==", + "optional": true + }, "node_modules/brace-expansion": { "version": "1.1.11", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", @@ -2218,6 +2280,7 @@ "version": "0.4.0", "resolved": "https://registry.npmjs.org/chalk-template/-/chalk-template-0.4.0.tgz", "integrity": "sha512-/ghrgmhfY8RaSdeo43hNXxpoHAtxdbskUHjPpfqUWGttFgycUhYPGx3YZBCnUCvOa7Doivn1IZec3DEGFoMgLg==", + "peer": true, "dependencies": { "chalk": "^4.1.2" }, @@ -2237,6 +2300,15 @@ "node": ">=10" } }, + "node_modules/charenc": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/charenc/-/charenc-0.0.2.tgz", + "integrity": "sha512-yrLQ/yVUFXkzg7EDQsPieE/53+0RlaWTs+wBrvW36cyilJ2SaDWfl4Yj7MtLTXleV9uEKefbAGUPv2/iWSooRA==", + "optional": true, + "engines": { + "node": "*" + } + }, "node_modules/cjs-module-lexer": { "version": "1.2.3", "resolved": "https://registry.npmjs.org/cjs-module-lexer/-/cjs-module-lexer-1.2.3.tgz", @@ -2289,10 +2361,23 @@ "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==" }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "optional": true, + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/command-line-args": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/command-line-args/-/command-line-args-5.2.1.tgz", "integrity": "sha512-H4UfQhZyakIjC74I9d34fGYDwk3XpSr17QhEd0Q3I9Xq1CETHo4Hcuo87WyWHpAF1aSLjLRf5lD9ZGX2qStUvg==", + "peer": true, "dependencies": { "array-back": "^3.1.0", "find-replace": "^3.0.0", @@ -2307,6 +2392,7 @@ "version": "7.0.1", "resolved": "https://registry.npmjs.org/command-line-usage/-/command-line-usage-7.0.1.tgz", "integrity": "sha512-NCyznE//MuTjwi3y84QVUGEOT+P5oto1e1Pk/jFPVdPPfsG03qpTIl3yw6etR+v73d0lXsoojRpvbru2sqePxQ==", + "peer": true, "dependencies": { "array-back": "^6.2.2", "chalk-template": "^0.4.0", @@ -2321,6 +2407,7 @@ "version": "6.2.2", "resolved": "https://registry.npmjs.org/array-back/-/array-back-6.2.2.tgz", "integrity": "sha512-gUAZ7HPyb4SJczXAMUXMGAvI976JoK3qEx9v1FTmeYuJj0IBiaKttG1ydtGKdkfqWkIkouke7nG8ufGy77+Cvw==", + "peer": true, "engines": { "node": ">=12.17" } @@ -2329,6 +2416,7 @@ "version": "7.1.1", "resolved": "https://registry.npmjs.org/typical/-/typical-7.1.1.tgz", "integrity": "sha512-T+tKVNs6Wu7IWiAce5BgMd7OZfNYUndHwc5MknN+UHOudi7sGZzuHdCadllRuqJ3fPtgFtIH9+lt9qRv6lmpfA==", + "peer": true, "engines": { "node": ">=12.17" } @@ -2380,6 +2468,15 @@ "node": ">= 8" } }, + "node_modules/crypt": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/crypt/-/crypt-0.0.2.tgz", + "integrity": "sha512-mCxBlsHFYh9C+HVpiEacem8FEBnMXgU9gy4zmNC+SXAZNB/1idgp/aulFJ4FgCi7GPEVbfyng092GqL2k2rmow==", + "optional": true, + "engines": { + "node": "*" + } + }, "node_modules/debug": { "version": "4.3.4", "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", @@ -2432,6 +2529,15 @@ "node": ">=0.10.0" } }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": 
"sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "optional": true, + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/detect-newline": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/detect-newline/-/detect-newline-3.1.0.tgz", @@ -2450,6 +2556,16 @@ "node": "^14.15.0 || ^16.10.0 || >=18.0.0" } }, + "node_modules/digest-fetch": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/digest-fetch/-/digest-fetch-1.3.0.tgz", + "integrity": "sha512-CGJuv6iKNM7QyZlM2T3sPAdZWd/p9zQiRNS9G+9COUCwzWFTs0Xp8NF5iePx7wtvhDykReiRRrSeNb4oMmB8lA==", + "optional": true, + "dependencies": { + "base-64": "^0.1.0", + "md5": "^2.3.0" + } + }, "node_modules/dir-glob": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", @@ -2710,6 +2826,15 @@ "node": ">=0.10.0" } }, + "node_modules/event-target-shim": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/event-target-shim/-/event-target-shim-5.0.1.tgz", + "integrity": "sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==", + "optional": true, + "engines": { + "node": ">=6" + } + }, "node_modules/exit": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/exit/-/exit-0.1.2.tgz", @@ -2815,6 +2940,7 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/find-replace/-/find-replace-3.0.0.tgz", "integrity": "sha512-6Tb2myMioCAgv5kfvP5/PkZZ/ntTpVK39fHY7WkWBgvbeE+VHd/tZuZ4mrC+bxh4cfOZeYKVPaJIZtZXV7GNCQ==", + "peer": true, "dependencies": { "array-back": "^3.0.1" }, @@ -2855,7 +2981,8 @@ "node_modules/flatbuffers": { "version": "23.5.26", "resolved": "https://registry.npmjs.org/flatbuffers/-/flatbuffers-23.5.26.tgz", - "integrity": "sha512-vE+SI9vrJDwi1oETtTIFldC/o9GsVKRM+s6EL0nQgxXlYV1Vc4Tk30hj4xGICftInKQKj1F3up2n8UbIVobISQ==" + "integrity": "sha512-vE+SI9vrJDwi1oETtTIFldC/o9GsVKRM+s6EL0nQgxXlYV1Vc4Tk30hj4xGICftInKQKj1F3up2n8UbIVobISQ==", + "peer": true }, "node_modules/flatted": { "version": "3.2.9", @@ -2863,6 +2990,48 @@ "integrity": "sha512-36yxDn5H7OFZQla0/jFJmbIKTdZAQHngCedGxiMmpNfEZM0sdEeT+WczLQrjK6D7o2aiyLYDnkw0R3JK0Qv1RQ==", "dev": true }, + "node_modules/form-data": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz", + "integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==", + "optional": true, + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/form-data-encoder": { + "version": "1.7.2", + "resolved": "https://registry.npmjs.org/form-data-encoder/-/form-data-encoder-1.7.2.tgz", + "integrity": "sha512-qfqtYan3rxrnCk1VYaA4H+Ms9xdpPqvLZa6xmMgFvhO32x7/3J/ExcTd6qpxM0vH2GdMI+poehyBZvqfMTto8A==", + "optional": true + }, + "node_modules/formdata-node": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/formdata-node/-/formdata-node-4.4.1.tgz", + "integrity": "sha512-0iirZp3uVDjVGt9p49aTaqjk84TrglENEDuqfdlZQ1roC9CWlPk6Avf8EEnZNcAqPonwkG35x4n3ww/1THYAeQ==", + "optional": true, + "dependencies": { + "node-domexception": "1.0.0", + "web-streams-polyfill": "4.0.0-beta.3" + }, + "engines": { + "node": ">= 12.20" + } + }, + "node_modules/formdata-node/node_modules/web-streams-polyfill": { + "version": "4.0.0-beta.3", + "resolved": "https://registry.npmjs.org/web-streams-polyfill/-/web-streams-polyfill-4.0.0-beta.3.tgz", + "integrity": 
"sha512-QW95TCTaHmsYfHDybGMwO5IJIM93I/6vTRk+daHTWFPhwh+C8Cg7j7XyKrwrj8Ib6vYXe0ocYNrmzY4xAAN6ug==", + "optional": true, + "engines": { + "node": ">= 14" + } + }, "node_modules/fs.realpath": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", @@ -3049,6 +3218,15 @@ "integrity": "sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==", "dev": true }, + "node_modules/humanize-ms": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/humanize-ms/-/humanize-ms-1.2.1.tgz", + "integrity": "sha512-Fl70vYtsAFb/C06PTS9dZBo7ihau+Tu/DNCk/OyHhea07S+aeMWpFFkUaXRa8fI+ScZbEI8dfSxwY7gxZ9SAVQ==", + "optional": true, + "dependencies": { + "ms": "^2.0.0" + } + }, "node_modules/ignore": { "version": "5.3.0", "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.0.tgz", @@ -3133,6 +3311,12 @@ "integrity": "sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==", "dev": true }, + "node_modules/is-buffer": { + "version": "1.1.6", + "resolved": "https://registry.npmjs.org/is-buffer/-/is-buffer-1.1.6.tgz", + "integrity": "sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w==", + "optional": true + }, "node_modules/is-core-module": { "version": "2.13.1", "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.1.tgz", @@ -4067,6 +4251,7 @@ "version": "0.0.3", "resolved": "https://registry.npmjs.org/json-bignum/-/json-bignum-0.0.3.tgz", "integrity": "sha512-2WHyXj3OfHSgNyuzDbSxI1w2jgw5gkWSWhS7Qg4bWXx1nLk3jnbwfUeS0PSba3IzpTUWdHxBieELUzXRjQB2zg==", + "peer": true, "engines": { "node": ">=0.8" } @@ -4177,12 +4362,14 @@ "node_modules/lodash.assignwith": { "version": "4.2.0", "resolved": "https://registry.npmjs.org/lodash.assignwith/-/lodash.assignwith-4.2.0.tgz", - "integrity": "sha512-ZznplvbvtjK2gMvnQ1BR/zqPFZmS6jbK4p+6Up4xcRYA7yMIwxHCfbTcrYxXKzzqLsQ05eJPVznEW3tuwV7k1g==" + "integrity": "sha512-ZznplvbvtjK2gMvnQ1BR/zqPFZmS6jbK4p+6Up4xcRYA7yMIwxHCfbTcrYxXKzzqLsQ05eJPVznEW3tuwV7k1g==", + "peer": true }, "node_modules/lodash.camelcase": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/lodash.camelcase/-/lodash.camelcase-4.3.0.tgz", - "integrity": "sha512-TwuEnCnxbc3rAvhf/LbG7tJUDzhqXyFnv3dtzLOPgCG/hODL7WFnsbwktkD7yUV0RrreP/l1PALq/YSg6VvjlA==" + "integrity": "sha512-TwuEnCnxbc3rAvhf/LbG7tJUDzhqXyFnv3dtzLOPgCG/hODL7WFnsbwktkD7yUV0RrreP/l1PALq/YSg6VvjlA==", + "peer": true }, "node_modules/lodash.memoize": { "version": "4.1.2", @@ -4241,6 +4428,17 @@ "node": ">= 12" } }, + "node_modules/md5": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/md5/-/md5-2.3.0.tgz", + "integrity": "sha512-T1GITYmFaKuO91vxyoQMFETst+O71VUPEU3ze5GNzDm0OWdP8v1ziTaAEPUr/3kLsY3Sftgz242A1SetQiDL7g==", + "optional": true, + "dependencies": { + "charenc": "0.0.2", + "crypt": "0.0.2", + "is-buffer": "~1.1.6" + } + }, "node_modules/merge-stream": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/merge-stream/-/merge-stream-2.0.0.tgz", @@ -4269,6 +4467,27 @@ "node": ">=8.6" } }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + 
"integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "optional": true, + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, "node_modules/minimatch": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", @@ -4290,6 +4509,12 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "optional": true + }, "node_modules/natural-compare": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", @@ -4302,6 +4527,45 @@ "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", "dev": true }, + "node_modules/node-domexception": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/node-domexception/-/node-domexception-1.0.0.tgz", + "integrity": "sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/jimmywarting" + }, + { + "type": "github", + "url": "https://paypal.me/jimmywarting" + } + ], + "optional": true, + "engines": { + "node": ">=10.5.0" + } + }, + "node_modules/node-fetch": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz", + "integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==", + "optional": true, + "dependencies": { + "whatwg-url": "^5.0.0" + }, + "engines": { + "node": "4.x || >=6.0.0" + }, + "peerDependencies": { + "encoding": "^0.1.0" + }, + "peerDependenciesMeta": { + "encoding": { + "optional": true + } + } + }, "node_modules/node-int64": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/node-int64/-/node-int64-0.4.0.tgz", @@ -4332,6 +4596,35 @@ "wrappy": "1" } }, + "node_modules/openai": { + "version": "4.28.4", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.28.4.tgz", + "integrity": "sha512-RNIwx4MT/F0zyizGcwS+bXKLzJ8QE9IOyigDG/ttnwB220d58bYjYFp0qjvGwEFBO6+pvFVIDABZPGDl46RFsg==", + "optional": true, + "dependencies": { + "@types/node": "^18.11.18", + "@types/node-fetch": "^2.6.4", + "abort-controller": "^3.0.0", + "agentkeepalive": "^4.2.1", + "digest-fetch": "^1.3.0", + "form-data-encoder": "1.7.2", + "formdata-node": "^4.3.2", + "node-fetch": "^2.6.7", + "web-streams-polyfill": "^3.2.1" + }, + "bin": { + "openai": "bin/cli" + } + }, + "node_modules/openai/node_modules/@types/node": { + "version": "18.19.20", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.20.tgz", + "integrity": "sha512-SKXZvI375jkpvAj8o+5U2518XQv76mAsixqfXiVyWyXZbVWQK25RurFovYpVIxVzul0rZoH58V/3SkEnm7s3qA==", + "optional": true, + "dependencies": { + "undici-types": "~5.26.4" + } + }, "node_modules/optionator": { "version": "0.9.3", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.3.tgz", @@ -4864,6 +5157,7 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/stream-read-all/-/stream-read-all-3.0.1.tgz", "integrity": "sha512-EWZT9XOceBPlVJRrYcykW8jyRSZYbkb/0ZK36uLEmoWVO5gxBOnntNTseNzfREsqxqdfEGQrD8SXQ3QWbBmq8A==", + "peer": true, "engines": { "node": ">=10" } @@ -4955,6 +5249,7 @@ "version": "3.0.2", "resolved": 
"https://registry.npmjs.org/table-layout/-/table-layout-3.0.2.tgz", "integrity": "sha512-rpyNZYRw+/C+dYkcQ3Pr+rLxW4CfHpXjPDnG7lYhdRoUcZTUt+KEsX+94RGp/aVp/MQU35JCITv2T/beY4m+hw==", + "peer": true, "dependencies": { "@75lb/deep-merge": "^1.1.1", "array-back": "^6.2.2", @@ -4975,6 +5270,7 @@ "version": "6.2.2", "resolved": "https://registry.npmjs.org/array-back/-/array-back-6.2.2.tgz", "integrity": "sha512-gUAZ7HPyb4SJczXAMUXMGAvI976JoK3qEx9v1FTmeYuJj0IBiaKttG1ydtGKdkfqWkIkouke7nG8ufGy77+Cvw==", + "peer": true, "engines": { "node": ">=12.17" } @@ -4983,6 +5279,7 @@ "version": "7.1.1", "resolved": "https://registry.npmjs.org/typical/-/typical-7.1.1.tgz", "integrity": "sha512-T+tKVNs6Wu7IWiAce5BgMd7OZfNYUndHwc5MknN+UHOudi7sGZzuHdCadllRuqJ3fPtgFtIH9+lt9qRv6lmpfA==", + "peer": true, "engines": { "node": ">=12.17" } @@ -5007,6 +5304,15 @@ "integrity": "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==", "dev": true }, + "node_modules/tmp": { + "version": "0.2.3", + "resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.3.tgz", + "integrity": "sha512-nZD7m9iCPC5g0pYmcaxogYKggSfLsdxl8of3Q/oIbqCqLLIO9IAF0GWjX1z9NZRHPiXv8Wex4yDCaZsgEw0Y8w==", + "dev": true, + "engines": { + "node": ">=14.14" + } + }, "node_modules/tmpl": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/tmpl/-/tmpl-1.0.5.tgz", @@ -5034,6 +5340,12 @@ "node": ">=8.0" } }, + "node_modules/tr46": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-0.0.3.tgz", + "integrity": "sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==", + "optional": true + }, "node_modules/ts-api-utils": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.0.3.tgz", @@ -5092,7 +5404,8 @@ "node_modules/tslib": { "version": "2.6.2", "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", - "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==" + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==", + "peer": true }, "node_modules/type-check": { "version": "0.4.0", @@ -5189,6 +5502,7 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/typical/-/typical-4.0.0.tgz", "integrity": "sha512-VAH4IvQ7BDFYglMd7BPRDfLgxZZX4O4TFcRDA6EN5X7erNJJq+McIEp8np9aVtxrCJ6qx4GTYVfOWNjcqwZgRw==", + "peer": true, "engines": { "node": ">=8" } @@ -5285,6 +5599,31 @@ "makeerror": "1.0.12" } }, + "node_modules/web-streams-polyfill": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/web-streams-polyfill/-/web-streams-polyfill-3.3.3.tgz", + "integrity": "sha512-d2JWLCivmZYTSIoge9MsgFCZrt571BikcWGYkjC1khllbTeDlGqZ2D8vD8E/lJa8WGWbb7Plm8/XJYV7IJHZZw==", + "optional": true, + "engines": { + "node": ">= 8" + } + }, + "node_modules/webidl-conversions": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz", + "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==", + "optional": true + }, + "node_modules/whatwg-url": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", + "integrity": "sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==", + "optional": true, + "dependencies": { + "tr46": "~0.0.3", + "webidl-conversions": "^3.0.0" + } + }, "node_modules/which": { "version": "2.0.2", 
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -5310,6 +5649,7 @@ "version": "5.1.0", "resolved": "https://registry.npmjs.org/wordwrapjs/-/wordwrapjs-5.1.0.tgz", "integrity": "sha512-JNjcULU2e4KJwUNv6CHgI46UvDGitb6dGryHajXTDiLgg1/RiGoPSDw4kZfYnwGtEXf2ZMeIewDQgFGzkCB2Sg==", + "peer": true, "engines": { "node": ">=12.17" } diff --git a/nodejs/package.json b/nodejs/package.json index 39473320..d0dcdb8c 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -19,10 +19,12 @@ "devDependencies": { "@napi-rs/cli": "^2.18.0", "@types/jest": "^29.1.2", + "@types/tmp": "^0.2.6", "@typescript-eslint/eslint-plugin": "^6.19.0", "@typescript-eslint/parser": "^6.19.0", "eslint": "^8.56.0", "jest": "^29.7.0", + "tmp": "^0.2.3", "ts-jest": "^29.1.2", "typedoc": "^0.25.7", "typedoc-plugin-markdown": "^3.17.1", @@ -59,7 +61,8 @@ "lancedb-darwin-arm64": "0.4.3", "lancedb-darwin-x64": "0.4.3", "lancedb-linux-arm64-gnu": "0.4.3", - "lancedb-linux-x64-gnu": "0.4.3" + "lancedb-linux-x64-gnu": "0.4.3", + "openai": "^4.28.4" }, "peerDependencies": { "apache-arrow": "^15.0.0" diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs index 9bef5eec..afb7787e 100644 --- a/nodejs/src/connection.rs +++ b/nodejs/src/connection.rs @@ -18,11 +18,23 @@ use napi_derive::*; use crate::table::Table; use crate::ConnectionOptions; use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection, CreateTableMode}; -use lancedb::ipc::ipc_file_to_batches; +use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema}; #[napi] pub struct Connection { - conn: LanceDBConnection, + inner: Option, +} + +impl Connection { + pub(crate) fn inner_new(inner: LanceDBConnection) -> Self { + Self { inner: Some(inner) } + } + + fn get_inner(&self) -> napi::Result<&LanceDBConnection> { + self.inner + .as_ref() + .ok_or_else(|| napi::Error::from_reason("Connection is closed")) + } } impl Connection { @@ -40,8 +52,8 @@ impl Connection { impl Connection { /// Create a new Connection instance from the given URI. #[napi(factory)] - pub async fn new(options: ConnectionOptions) -> napi::Result { - let mut builder = ConnectBuilder::new(&options.uri); + pub async fn new(uri: String, options: ConnectionOptions) -> napi::Result { + let mut builder = ConnectBuilder::new(&uri); if let Some(api_key) = options.api_key { builder = builder.api_key(&api_key); } @@ -52,18 +64,33 @@ impl Connection { builder = builder.read_consistency_interval(std::time::Duration::from_secs_f64(interval)); } - Ok(Self { - conn: builder + Ok(Self::inner_new( + builder .execute() .await .map_err(|e| napi::Error::from_reason(format!("{}", e)))?, - }) + )) + } + + #[napi] + pub fn display(&self) -> napi::Result { + Ok(self.get_inner()?.to_string()) + } + + #[napi] + pub fn is_open(&self) -> bool { + self.inner.is_some() + } + + #[napi] + pub fn close(&mut self) { + self.inner.take(); } /// List all tables in the dataset. #[napi] pub async fn table_names(&self) -> napi::Result> { - self.conn + self.get_inner()? .table_names() .await .map_err(|e| napi::Error::from_reason(format!("{}", e))) @@ -86,7 +113,7 @@ impl Connection { .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?; let mode = Self::parse_create_mode_str(&mode)?; let tbl = self - .conn + .get_inner()? .create_table(&name, Box::new(batches)) .mode(mode) .execute() @@ -95,10 +122,31 @@ impl Connection { Ok(Table::new(tbl)) } + #[napi] + pub async fn create_empty_table( + &self, + name: String, + schema_buf: Buffer, + mode: String, + ) -> napi::Result
{ + let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| { + napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e)) + })?; + let mode = Self::parse_create_mode_str(&mode)?; + let tbl = self + .get_inner()? + .create_empty_table(&name, schema) + .mode(mode) + .execute() + .await + .map_err(|e| napi::Error::from_reason(format!("{}", e)))?; + Ok(Table::new(tbl)) + } + #[napi] pub async fn open_table(&self, name: String) -> napi::Result
{ let tbl = self - .conn + .get_inner()? .open_table(&name) .execute() .await @@ -109,7 +157,7 @@ impl Connection { /// Drop table with the name. Or raise an error if the table does not exist. #[napi] pub async fn drop_table(&self, name: String) -> napi::Result<()> { - self.conn + self.get_inner()? .drop_table(&name) .await .map_err(|e| napi::Error::from_reason(format!("{}", e))) diff --git a/nodejs/src/index.rs b/nodejs/src/index.rs index 91d3a7d6..7c9864f0 100644 --- a/nodejs/src/index.rs +++ b/nodejs/src/index.rs @@ -12,7 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Mutex; + use lance_linalg::distance::MetricType as LanceMetricType; +use lancedb::index::IndexBuilder as LanceDbIndexBuilder; +use lancedb::Table as LanceDbTable; use napi_derive::napi; #[napi] @@ -40,58 +44,93 @@ impl From for LanceMetricType { #[napi] pub struct IndexBuilder { - inner: lancedb::index::IndexBuilder, + inner: Mutex>, +} + +impl IndexBuilder { + fn modify( + &self, + mod_fn: impl Fn(LanceDbIndexBuilder) -> LanceDbIndexBuilder, + ) -> napi::Result<()> { + let mut inner = self.inner.lock().unwrap(); + let inner_builder = inner.take().ok_or_else(|| { + napi::Error::from_reason("IndexBuilder has already been consumed".to_string()) + })?; + let inner_builder = mod_fn(inner_builder); + inner.replace(inner_builder); + Ok(()) + } } #[napi] impl IndexBuilder { - pub fn new(tbl: &dyn lancedb::Table) -> Self { + pub fn new(tbl: &LanceDbTable) -> Self { let inner = tbl.create_index(&[]); - Self { inner } + Self { + inner: Mutex::new(Some(inner)), + } } #[napi] - pub unsafe fn replace(&mut self, v: bool) { - self.inner.replace(v); + pub fn replace(&self, v: bool) -> napi::Result<()> { + self.modify(|b| b.replace(v)) } #[napi] - pub unsafe fn column(&mut self, c: String) { - self.inner.columns(&[c.as_str()]); + pub fn column(&self, c: String) -> napi::Result<()> { + self.modify(|b| b.columns(&[c.as_str()])) } #[napi] - pub unsafe fn name(&mut self, name: String) { - self.inner.name(name.as_str()); + pub fn name(&self, name: String) -> napi::Result<()> { + self.modify(|b| b.name(name.as_str())) } #[napi] - pub unsafe fn ivf_pq( - &mut self, + pub fn ivf_pq( + &self, metric_type: Option, num_partitions: Option, num_sub_vectors: Option, num_bits: Option, max_iterations: Option, sample_rate: Option, - ) { - self.inner.ivf_pq(); - metric_type.map(|m| self.inner.metric_type(m.into())); - num_partitions.map(|p| self.inner.num_partitions(p)); - num_sub_vectors.map(|s| self.inner.num_sub_vectors(s)); - num_bits.map(|b| self.inner.num_bits(b)); - max_iterations.map(|i| self.inner.max_iterations(i)); - sample_rate.map(|s| self.inner.sample_rate(s)); + ) -> napi::Result<()> { + self.modify(|b| { + let mut b = b.ivf_pq(); + if let Some(metric_type) = metric_type { + b = b.metric_type(metric_type.into()); + } + if let Some(num_partitions) = num_partitions { + b = b.num_partitions(num_partitions); + } + if let Some(num_sub_vectors) = num_sub_vectors { + b = b.num_sub_vectors(num_sub_vectors); + } + if let Some(num_bits) = num_bits { + b = b.num_bits(num_bits); + } + if let Some(max_iterations) = max_iterations { + b = b.max_iterations(max_iterations); + } + if let Some(sample_rate) = sample_rate { + b = b.sample_rate(sample_rate); + } + b + }) } #[napi] - pub unsafe fn scalar(&mut self) { - self.inner.scalar(); + pub fn scalar(&self) -> napi::Result<()> { + self.modify(|b| b.scalar()) } #[napi] pub async fn build(&self) -> napi::Result<()> { - 
self.inner + let inner = self.inner.lock().unwrap().take().ok_or_else(|| { + napi::Error::from_reason("IndexBuilder has already been consumed".to_string()) + })?; + inner .build() .await .map_err(|e| napi::Error::from_reason(format!("Failed to build index: {}", e)))?; diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs index 8913e1d5..13e7453f 100644 --- a/nodejs/src/lib.rs +++ b/nodejs/src/lib.rs @@ -24,7 +24,6 @@ mod table; #[napi(object)] #[derive(Debug)] pub struct ConnectionOptions { - pub uri: String, pub api_key: Option, pub host_override: Option, /// (For LanceDB OSS only): The interval, in seconds, at which to check for @@ -54,6 +53,6 @@ pub struct WriteOptions { } #[napi] -pub async fn connect(options: ConnectionOptions) -> napi::Result { - Connection::new(options).await +pub async fn connect(uri: String, options: ConnectionOptions) -> napi::Result { + Connection::new(uri, options).await } diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 891a6454..6710fd03 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -16,7 +16,7 @@ use lancedb::query::Query as LanceDBQuery; use napi::bindgen_prelude::*; use napi_derive::napi; -use crate::{iterator::RecordBatchIterator, table::Table}; +use crate::iterator::RecordBatchIterator; #[napi] pub struct Query { @@ -25,10 +25,8 @@ pub struct Query { #[napi] impl Query { - pub fn new(table: &Table) -> Self { - Self { - inner: table.table.query(), - } + pub fn new(query: LanceDBQuery) -> Self { + Self { inner: query } } #[napi] diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 6d46e466..66489b28 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -14,10 +14,8 @@ use arrow_ipc::writer::FileWriter; use lance::dataset::ColumnAlteration as LanceColumnAlteration; -use lancedb::{ - ipc::ipc_file_to_batches, - table::{AddDataOptions, TableRef}, -}; +use lancedb::ipc::ipc_file_to_batches; +use lancedb::table::{AddDataMode, Table as LanceDbTable}; use napi::bindgen_prelude::*; use napi_derive::napi; @@ -26,20 +24,52 @@ use crate::query::Query; #[napi] pub struct Table { - pub(crate) table: TableRef, + // We keep a duplicate of the table name so we can use it for error + // messages even if the table has been closed + name: String, + pub(crate) inner: Option, +} + +impl Table { + fn inner_ref(&self) -> napi::Result<&LanceDbTable> { + self.inner + .as_ref() + .ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name))) + } } #[napi] impl Table { - pub(crate) fn new(table: TableRef) -> Self { - Self { table } + pub(crate) fn new(table: LanceDbTable) -> Self { + Self { + name: table.name().to_string(), + inner: Some(table), + } + } + + #[napi] + pub fn display(&self) -> String { + match &self.inner { + None => format!("ClosedTable({})", self.name), + Some(inner) => inner.to_string(), + } + } + + #[napi] + pub fn is_open(&self) -> bool { + self.inner.is_some() + } + + #[napi] + pub fn close(&mut self) { + self.inner.take(); } /// Return Schema as empty Arrow IPC file. 
#[napi] pub async fn schema(&self) -> napi::Result { let schema = - self.table.schema().await.map_err(|e| { + self.inner_ref()?.schema().await.map_err(|e| { napi::Error::from_reason(format!("Failed to create IPC file: {}", e)) })?; let mut writer = FileWriter::try_new(vec![], &schema) @@ -53,52 +83,59 @@ impl Table { } #[napi] - pub async fn add(&self, buf: Buffer) -> napi::Result<()> { + pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> { let batches = ipc_file_to_batches(buf.to_vec()) .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?; - self.table - .add(Box::new(batches), AddDataOptions::default()) - .await - .map_err(|e| { - napi::Error::from_reason(format!( - "Failed to add batches to table {}: {}", - self.table, e - )) - }) + let mut op = self.inner_ref()?.add(Box::new(batches)); + + op = if mode == "append" { + op.mode(AddDataMode::Append) + } else if mode == "overwrite" { + op.mode(AddDataMode::Overwrite) + } else { + return Err(napi::Error::from_reason(format!("Invalid mode: {}", mode))); + }; + + op.execute().await.map_err(|e| { + napi::Error::from_reason(format!( + "Failed to add batches to table {}: {}", + self.name, e + )) + }) } #[napi] pub async fn count_rows(&self, filter: Option) -> napi::Result { - self.table + self.inner_ref()? .count_rows(filter) .await .map(|val| val as i64) .map_err(|e| { napi::Error::from_reason(format!( "Failed to count rows in table {}: {}", - self.table, e + self.name, e )) }) } #[napi] pub async fn delete(&self, predicate: String) -> napi::Result<()> { - self.table.delete(&predicate).await.map_err(|e| { + self.inner_ref()?.delete(&predicate).await.map_err(|e| { napi::Error::from_reason(format!( "Failed to delete rows in table {}: predicate={}", - self.table, e + self.name, e )) }) } #[napi] - pub fn create_index(&self) -> IndexBuilder { - IndexBuilder::new(self.table.as_ref()) + pub fn create_index(&self) -> napi::Result { + Ok(IndexBuilder::new(self.inner_ref()?)) } #[napi] - pub fn query(&self) -> Query { - Query::new(self) + pub fn query(&self) -> napi::Result { + Ok(Query::new(self.inner_ref()?.query())) } #[napi] @@ -108,13 +145,13 @@ impl Table { .map(|sql| (sql.name, sql.value_sql)) .collect::>(); let transforms = lance::dataset::NewColumnTransform::SqlExpressions(transforms); - self.table + self.inner_ref()? .add_columns(transforms, None) .await .map_err(|err| { napi::Error::from_reason(format!( "Failed to add columns to table {}: {}", - self.table, err + self.name, err )) })?; Ok(()) @@ -134,13 +171,13 @@ impl Table { .map(LanceColumnAlteration::from) .collect::>(); - self.table + self.inner_ref()? .alter_columns(&alterations) .await .map_err(|err| { napi::Error::from_reason(format!( "Failed to alter columns in table {}: {}", - self.table, err + self.name, err )) })?; Ok(()) @@ -149,12 +186,15 @@ impl Table { #[napi] pub async fn drop_columns(&self, columns: Vec) -> napi::Result<()> { let col_refs = columns.iter().map(String::as_str).collect::>(); - self.table.drop_columns(&col_refs).await.map_err(|err| { - napi::Error::from_reason(format!( - "Failed to drop columns from table {}: {}", - self.table, err - )) - })?; + self.inner_ref()? 
+ .drop_columns(&col_refs) + .await + .map_err(|err| { + napi::Error::from_reason(format!( + "Failed to drop columns from table {}: {}", + self.name, err + )) + })?; Ok(()) } } diff --git a/python/ASYNC_MIGRATION.md b/python/ASYNC_MIGRATION.md index 6a9231c4..7acf75a6 100644 --- a/python/ASYNC_MIGRATION.md +++ b/python/ASYNC_MIGRATION.md @@ -15,10 +15,24 @@ need to use `await` to call these functions. ## Connection -No changes yet. +* The connection now has a `close` method. You can call this when + you are done with the connection to eagerly free resources. Currently + this is limited to freeing/closing the HTTP connection for remote + connections. In the future we may add caching or other resources to + native connections so this is probably a good practice even if you aren't using remote connections. + + In addition, the connection can be used as a context manager which may + be a more convenient way to ensure the connection is closed. + + It is not mandatory to call the `close` method. If you don't call it + the connection will be closed when the object is garbage collected. ## Table +* The table now has a `close` method, similar to the connection. This + can be used to eagerly free the cache used by a Table object. Similar + to the connection, it can be used as a context manager and it is not + mandatory to call the `close` method. * Previously `Table.schema` was a property. Now it is an async method. * The method `Table.__len__` was removed and `len(table)` will no longer work. Use `Table.count_rows` instead. diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 5464f654..103720b3 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -21,7 +21,7 @@ __version__ = importlib.metadata.version("lancedb") from ._lancedb import connect as lancedb_connect from .common import URI, sanitize_uri -from .db import AsyncConnection, AsyncLanceDBConnection, DBConnection, LanceDBConnection +from .db import AsyncConnection, DBConnection, LanceDBConnection from .remote.db import RemoteDBConnection from .schema import vector # noqa: F401 @@ -167,8 +167,17 @@ async def connect_async( conn : DBConnection A connection to a LanceDB database. """ - return AsyncLanceDBConnection( + if read_consistency_interval is not None: + read_consistency_interval_secs = read_consistency_interval.total_seconds() + else: + read_consistency_interval_secs = None + + return AsyncConnection( await lancedb_connect( - sanitize_uri(uri), api_key, region, host_override, read_consistency_interval + sanitize_uri(uri), + api_key, + region, + host_override, + read_consistency_interval_secs, ) ) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index d1351084..8ac62fd1 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -13,6 +13,7 @@ class Connection(object): class Table(object): def name(self) -> str: ... + def __repr__(self) -> str: ... async def schema(self) -> pa.Schema: ... 
async def connect( diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index d87983da..c18656c3 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -17,7 +17,7 @@ import inspect import os from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Union import pyarrow as pa from overrides import EnforceOverrides, override @@ -28,7 +28,7 @@ from lancedb.embeddings.registry import EmbeddingFunctionRegistry from lancedb.utils.events import register_event from .pydantic import LanceModel -from .table import AsyncLanceTable, LanceTable, Table, _sanitize_data +from .table import AsyncTable, LanceTable, Table, _sanitize_data from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri if TYPE_CHECKING: @@ -427,12 +427,64 @@ class LanceDBConnection(DBConnection): filesystem.delete_dir(path) -class AsyncConnection(EnforceOverrides): - """An active LanceDB connection interface.""" +class AsyncConnection(object): + """An active LanceDB connection + + To obtain a connection you can use the [connect] function. + + This could be a native connection (using lance) or a remote connection (e.g. for + connecting to LanceDb Cloud) + + Local connections do not currently hold any open resources but they may do so in the + future (for example, for shared cache or connections to catalog services) Remote + connections represent an open connection to the remote server. The [close] method + can be used to release any underlying resources eagerly. The connection can also + be used as a context manager: + + Connections can be shared on multiple threads and are expected to be long lived. + Connections can also be used as a context manager, however, in many cases a single + connection can be used for the lifetime of the application and so this is often + not needed. Closing a connection is optional. If it is not closed then it will + be automatically closed when the connection object is deleted. + + Examples + -------- + + >>> import asyncio + >>> import lancedb + >>> async def my_connect(): + ... with await lancedb.connect("/tmp/my_dataset") as conn: + ... # do something with the connection + ... pass + ... # conn is closed here + """ + + def __init__(self, connection: LanceDbConnection): + self._inner = connection + + def __repr__(self): + return self._inner.__repr__() + + def __enter__(self): + self + + def __exit__(self, *_): + self.close() + + def is_open(self): + """Return True if the connection is open.""" + return self._inner.is_open() + + def close(self): + """Close the connection, releasing any underlying resources. + + It is safe to call this method multiple times. 
+ + Any attempt to use the connection after it is closed will result in an error.""" + self._inner.close() - @abstractmethod async def table_names( - self, *, page_token: Optional[str] = None, limit: int = 10 + self, *, page_token: Optional[str] = None, limit: Optional[int] = None ) -> Iterable[str]: """List all tables in this database, in sorted order @@ -450,18 +502,18 @@ class AsyncConnection(EnforceOverrides): ------- Iterable of str """ - pass + # TODO: hook in page_token and limit + return await self._inner.table_names() - @abstractmethod async def create_table( self, name: str, data: Optional[DATA] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None, - mode: str = "create", - exist_ok: bool = False, - on_bad_vectors: str = "error", - fill_value: float = 0.0, + mode: Optional[Literal["create", "overwrite"]] = None, + exist_ok: Optional[bool] = None, + on_bad_vectors: Optional[str] = None, + fill_value: Optional[float] = None, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, ) -> Table: """Create a [Table][lancedb.table.Table] in the database. @@ -485,7 +537,7 @@ class AsyncConnection(EnforceOverrides): - pyarrow.Schema - [LanceModel][lancedb.pydantic.LanceModel] - mode: str; default "create" + mode: Literal["create", "overwrite"]; default "create" The mode to use when creating the table. Can be either "create" or "overwrite". By default, if the table already exists, an exception is raised. @@ -601,72 +653,6 @@ class AsyncConnection(EnforceOverrides): LanceTable(connection=..., name="table4") """ - raise NotImplementedError - - async def open_table(self, name: str) -> Table: - """Open a Lance Table in the database. - - Parameters - ---------- - name: str - The name of the table. - - Returns - ------- - A LanceTable object representing the table. - """ - raise NotImplementedError - - async def drop_table(self, name: str): - """Drop a table from the database. - - Parameters - ---------- - name: str - The name of the table. - """ - raise NotImplementedError - - async def drop_database(self): - """ - Drop database - This is the same thing as dropping all the tables - """ - raise NotImplementedError - - -class AsyncLanceDBConnection(AsyncConnection): - def __init__(self, connection: LanceDbConnection): - self._inner = connection - - async def __repr__(self) -> str: - pass - - @override - async def table_names( - self, - *, - page_token=None, - limit=None, - ) -> Iterable[str]: - # TODO: hook in page_token and limit - return await self._inner.table_names() - - @override - async def create_table( - self, - name: str, - data: Optional[DATA] = None, - schema: Optional[Union[pa.Schema, LanceModel]] = None, - mode: str = "create", - exist_ok: bool = False, - on_bad_vectors: str = "error", - fill_value: float = 0.0, - embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, - ) -> Table: - if mode.lower() not in ["create", "overwrite"]: - raise ValueError("mode must be either 'create' or 'overwrite'") - if inspect.isclass(schema) and issubclass(schema, LanceModel): # convert LanceModel to pyarrow schema # note that it's possible this contains @@ -681,6 +667,14 @@ class AsyncLanceDBConnection(AsyncConnection): registry = EmbeddingFunctionRegistry.get_instance() metadata = registry.get_table_metadata(embedding_functions) + # Defining defaults here and not in function prototype. In the future + # these defaults will move into rust so better to keep them as None. 
+ if on_bad_vectors is None: + on_bad_vectors = "error" + + if fill_value is None: + fill_value = 0.0 + if data is not None: data = _sanitize_data( data, @@ -708,6 +702,10 @@ class AsyncLanceDBConnection(AsyncConnection): schema = schema.with_metadata(metadata) validate_schema(schema) + if exist_ok is None: + exist_ok = False + if mode is None: + mode = "create" if mode == "create" and exist_ok: mode = "exist_ok" @@ -722,16 +720,37 @@ class AsyncLanceDBConnection(AsyncConnection): ) register_event("create_table") - return AsyncLanceTable(new_table) + return AsyncTable(new_table) - @override - async def open_table(self, name: str) -> LanceTable: + async def open_table(self, name: str) -> Table: + """Open a Lance Table in the database. + + Parameters + ---------- + name: str + The name of the table. + + Returns + ------- + A LanceTable object representing the table. + """ + table = await self._inner.open_table(name) + register_event("open_table") + return AsyncTable(table) + + async def drop_table(self, name: str): + """Drop a table from the database. + + Parameters + ---------- + name: str + The name of the table. + """ raise NotImplementedError - @override - async def drop_table(self, name: str, ignore_missing: bool = False): - raise NotImplementedError - - @override async def drop_database(self): + """ + Drop database + This is the same thing as dropping all the tables + """ raise NotImplementedError diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 9c2ee743..d70eda9a 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -19,7 +19,17 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import timedelta from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Literal, + Optional, + Tuple, + Union, +) import lance import numpy as np @@ -28,7 +38,6 @@ import pyarrow.compute as pc import pyarrow.fs as pa_fs from lance import LanceDataset from lance.vector import vec_to_table -from overrides import override from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry @@ -1776,9 +1785,23 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name return data -class AsyncTable(ABC): +class AsyncTable: """ - A Table is a collection of Records in a LanceDB Database. + An AsyncTable is a collection of Records in a LanceDB Database. + + An AsyncTable can be obtained from the + [AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and + [AsyncConnection.open_table][lancedb.AsyncConnection.open_table] methods. + + An AsyncTable object is expected to be long lived and reused for multiple + operations. AsyncTable objects will cache a certain amount of index data in memory. + This cache will be freed when the Table is garbage collected. To eagerly free the + cache you can call the [close][AsyncTable.close] method. Once the AsyncTable is + closed, it cannot be used for any further operations. + + An AsyncTable can also be used as a context manager, and will automatically close + when the context is exited. Closing a table is optional. If you do not close the + table, it will be closed when the AsyncTable object is garbage collected. Examples -------- @@ -1813,21 +1836,49 @@ class AsyncTable(ABC): [Table.create_index][lancedb.table.Table.create_index]. 
""" + def __init__(self, table: LanceDBTable): + """Create a new Table object. + + You should not create Table objects directly. + + Use [AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and + [AsyncConnection.open_table][lancedb.AsyncConnection.open_table] to obtain + Table objects.""" + self._inner = table + + def __repr__(self): + return self._inner.__repr__() + + def __enter__(self): + return self + + def __exit__(self, *_): + self.close() + + def is_open(self) -> bool: + """Return True if the table is closed.""" + return self._inner.is_open() + + def close(self): + """Close the table and free any resources associated with it. + + It is safe to call this method multiple times. + + Any attempt to use the table after it has been closed will raise an error.""" + return self._inner.close() + @property - @abstractmethod def name(self) -> str: """The name of the table.""" - raise NotImplementedError + return self._inner.name() - @abstractmethod async def schema(self) -> pa.Schema: """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of this Table """ - raise NotImplementedError + return await self._inner.schema() - @abstractmethod async def count_rows(self, filter: Optional[str] = None) -> int: """ Count the number of rows in the table. @@ -1837,7 +1888,7 @@ class AsyncTable(ABC): filter: str, optional A SQL where clause to filter the rows to count. """ - raise NotImplementedError + return await self._inner.count_rows(filter) async def to_pandas(self) -> "pd.DataFrame": """Return the table as a pandas DataFrame. @@ -1848,7 +1899,6 @@ class AsyncTable(ABC): """ return self.to_arrow().to_pandas() - @abstractmethod async def to_arrow(self) -> pa.Table: """Return the table as a pyarrow Table. @@ -1896,7 +1946,6 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def create_scalar_index( self, column: str, @@ -1967,13 +2016,13 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def add( self, data: DATA, - mode: str = "append", - on_bad_vectors: str = "error", - fill_value: float = 0.0, + *, + mode: Optional[Literal["append", "overwrite"]] = "append", + on_bad_vectors: Optional[str] = None, + fill_value: Optional[float] = None, ): """Add more data to the [Table](Table). @@ -1997,7 +2046,20 @@ class AsyncTable(ABC): The value to use when filling vectors. Only used if on_bad_vectors="fill". """ - raise NotImplementedError + schema = await self.schema() + if on_bad_vectors is None: + on_bad_vectors = "error" + if fill_value is None: + fill_value = 0.0 + data = _sanitize_data( + data, + schema, + metadata=schema.metadata, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) + await self._inner.add(data, mode) + register_event("add") def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: """ @@ -2059,7 +2121,6 @@ class AsyncTable(ABC): return LanceMergeInsertBuilder(self, on) - @abstractmethod async def search( self, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, @@ -2142,11 +2203,9 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def _execute_query(self, query: Query) -> pa.Table: pass - @abstractmethod async def _do_merge( self, merge: LanceMergeInsertBuilder, @@ -2156,7 +2215,6 @@ class AsyncTable(ABC): ): pass - @abstractmethod async def delete(self, where: str): """Delete rows from the table. 
@@ -2207,7 +2265,6 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def update( self, where: Optional[str] = None, @@ -2263,7 +2320,6 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def cleanup_old_versions( self, older_than: Optional[timedelta] = None, @@ -2295,7 +2351,6 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def compact_files(self, *args, **kwargs): """ Run the compaction process on the table. @@ -2311,7 +2366,6 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def add_columns(self, transforms: Dict[str, str]): """ Add new columns with defined values. @@ -2327,7 +2381,6 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def alter_columns(self, alterations: Iterable[Dict[str, str]]): """ Alter column names and nullability. @@ -2350,7 +2403,6 @@ class AsyncTable(ABC): """ raise NotImplementedError - @abstractmethod async def drop_columns(self, columns: Iterable[str]): """ Drop columns from the table. @@ -2363,126 +2415,3 @@ class AsyncTable(ABC): The names of the columns to drop. """ raise NotImplementedError - - -class AsyncLanceTable(AsyncTable): - def __init__(self, table: LanceDBTable): - self._inner = table - - @property - @override - def name(self) -> str: - return self._inner.name() - - @override - async def schema(self) -> pa.Schema: - return await self._inner.schema() - - @override - async def count_rows(self, filter: Optional[str] = None) -> int: - raise NotImplementedError - - async def to_pandas(self) -> "pd.DataFrame": - return self.to_arrow().to_pandas() - - @override - async def to_arrow(self) -> pa.Table: - raise NotImplementedError - - async def create_index( - self, - metric="L2", - num_partitions=256, - num_sub_vectors=96, - vector_column_name: str = VECTOR_COLUMN_NAME, - replace: bool = True, - accelerator: Optional[str] = None, - index_cache_size: Optional[int] = None, - ): - raise NotImplementedError - - @override - async def create_scalar_index( - self, - column: str, - *, - replace: bool = True, - ): - raise NotImplementedError - - @override - async def add( - self, - data: DATA, - mode: str = "append", - on_bad_vectors: str = "error", - fill_value: float = 0.0, - ): - raise NotImplementedError - - def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: - on = [on] if isinstance(on, str) else list(on.iter()) - - return LanceMergeInsertBuilder(self, on) - - @override - async def search( - self, - query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, - vector_column_name: Optional[str] = None, - query_type: str = "auto", - ) -> LanceQueryBuilder: - raise NotImplementedError - - @override - async def _execute_query(self, query: Query) -> pa.Table: - pass - - @override - async def _do_merge( - self, - merge: LanceMergeInsertBuilder, - new_data: DATA, - on_bad_vectors: str, - fill_value: float, - ): - pass - - @override - async def delete(self, where: str): - raise NotImplementedError - - @override - async def update( - self, - where: Optional[str] = None, - values: Optional[dict] = None, - *, - values_sql: Optional[Dict[str, str]] = None, - ): - raise NotImplementedError - - @override - async def cleanup_old_versions( - self, - older_than: Optional[timedelta] = None, - *, - delete_unverified: bool = False, - ) -> CleanupStats: - raise NotImplementedError - - @override - async def compact_files(self, *args, **kwargs): - raise NotImplementedError - - @override - async def 
add_columns(self, transforms: Dict[str, str]): - raise NotImplementedError - - @override - async def alter_columns(self, alterations: Iterable[Dict[str, str]]): - raise NotImplementedError - - @override - async def drop_columns(self, columns: Iterable[str]): - raise NotImplementedError diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index c66131cf..06b1c326 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -11,6 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re +from datetime import timedelta + import lancedb import numpy as np import pandas as pd @@ -250,6 +253,28 @@ def test_create_exist_ok(tmp_path): db.create_table("test", schema=bad_schema, exist_ok=True) +@pytest.mark.asyncio +async def test_connect(tmp_path): + db = await lancedb.connect_async(tmp_path) + assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=None)" + + db = await lancedb.connect_async( + tmp_path, read_consistency_interval=timedelta(seconds=5) + ) + assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=5s)" + + +@pytest.mark.asyncio +async def test_close(tmp_path): + db = await lancedb.connect_async(tmp_path) + assert db.is_open() + db.close() + assert not db.is_open() + + with pytest.raises(RuntimeError, match="is closed"): + await db.table_names() + + @pytest.mark.asyncio async def test_create_mode_async(tmp_path): db = await lancedb.connect_async(tmp_path) @@ -322,6 +347,39 @@ async def test_create_exist_ok_async(tmp_path): # await db.create_table("test", schema=bad_schema, exist_ok=True) +@pytest.mark.asyncio +async def test_open_table(tmp_path): + db = await lancedb.connect_async(tmp_path) + data = pd.DataFrame( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["foo", "bar"], + "price": [10.0, 20.0], + } + ) + await db.create_table("test", data=data) + + tbl = await db.open_table("test") + assert tbl.name == "test" + assert ( + re.search( + r"NativeTable\(test, uri=.*test\.lance, read_consistency_interval=None\)", + str(tbl), + ) + is not None + ) + assert await tbl.schema() == pa.schema( + { + "vector": pa.list_(pa.float32(), list_size=2), + "item": pa.utf8(), + "price": pa.float64(), + } + ) + + with pytest.raises(ValueError, match="was not found"): + await db.open_table("does_not_exist") + + def test_delete_table(tmp_path): db = lancedb.connect(tmp_path) data = pd.DataFrame( diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 8b3029e0..564c3829 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -26,8 +26,9 @@ import pandas as pd import polars as pl import pyarrow as pa import pytest +import pytest_asyncio from lancedb.conftest import MockTextEmbeddingFunction -from lancedb.db import LanceDBConnection +from lancedb.db import AsyncConnection, LanceDBConnection from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from lancedb.pydantic import LanceModel, Vector from lancedb.table import LanceTable @@ -49,6 +50,13 @@ def db(tmp_path) -> MockDB: return MockDB(tmp_path) +@pytest_asyncio.fixture +async def db_async(tmp_path) -> AsyncConnection: + return await lancedb.connect_async( + tmp_path, read_consistency_interval=timedelta(seconds=0) + ) + + def test_basic(db): ds = LanceTable.create( db, @@ -65,6 +73,18 @@ def test_basic(db): assert table.to_lance().to_table() == ds.to_table() +@pytest.mark.asyncio +async def 
test_close(db_async: AsyncConnection): + table = await db_async.create_table("some_table", data=[{"id": 0}]) + assert table.is_open() + table.close() + assert not table.is_open() + + with pytest.raises(Exception, match="Table some_table is closed"): + await table.count_rows() + assert str(table) == "ClosedTable(some_table)" + + def test_create_table(db): schema = pa.schema( [ @@ -186,6 +206,25 @@ def test_add_pydantic_model(db): assert len(really_flattened.columns) == 7 +@pytest.mark.asyncio +async def test_add_async(db_async: AsyncConnection): + table = await db_async.create_table( + "test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ], + ) + assert await table.count_rows() == 2 + await table.add( + data=[ + {"vector": [10.0, 11.0], "item": "baz", "price": 30.0}, + ], + ) + table = await db_async.open_table("test") + assert await table.count_rows() == 3 + + def test_polars(db): data = { "vector": [[3.1, 4.1], [5.9, 26.5]], diff --git a/python/src/connection.rs b/python/src/connection.rs index 1f0fa759..7bfa2ae5 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -17,7 +17,8 @@ use std::{sync::Arc, time::Duration}; use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow}; use lancedb::connection::{Connection as LanceConnection, CreateTableMode}; use pyo3::{ - exceptions::PyValueError, pyclass, pyfunction, pymethods, PyAny, PyRef, PyResult, Python, + exceptions::{PyRuntimeError, PyValueError}, + pyclass, pyfunction, pymethods, PyAny, PyRef, PyResult, Python, }; use pyo3_asyncio::tokio::future_into_py; @@ -25,7 +26,19 @@ use crate::{error::PythonErrorExt, table::Table}; #[pyclass] pub struct Connection { - inner: LanceConnection, + inner: Option, +} + +impl Connection { + pub(crate) fn new(inner: LanceConnection) -> Self { + Self { inner: Some(inner) } + } + + fn get_inner(&self) -> PyResult<&LanceConnection> { + self.inner + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("Connection is closed")) + } } impl Connection { @@ -41,8 +54,23 @@ impl Connection { #[pymethods] impl Connection { + fn __repr__(&self) -> String { + match &self.inner { + Some(inner) => inner.to_string(), + None => "ClosedConnection".to_string(), + } + } + + fn is_open(&self) -> bool { + self.inner.is_some() + } + + fn close(&mut self) { + self.inner.take(); + } + pub fn table_names(self_: PyRef<'_, Self>) -> PyResult<&PyAny> { - let inner = self_.inner.clone(); + let inner = self_.get_inner()?.clone(); future_into_py(self_.py(), async move { inner.table_names().await.infer_error() }) @@ -54,7 +82,7 @@ impl Connection { mode: &str, data: &PyAny, ) -> PyResult<&'a PyAny> { - let inner = self_.inner.clone(); + let inner = self_.get_inner()?.clone(); let mode = Self::parse_create_mode_str(mode)?; @@ -76,7 +104,7 @@ impl Connection { mode: &str, schema: &PyAny, ) -> PyResult<&'a PyAny> { - let inner = self_.inner.clone(); + let inner = self_.get_inner()?.clone(); let mode = Self::parse_create_mode_str(mode)?; @@ -92,6 +120,14 @@ impl Connection { Ok(Table::new(table)) }) } + + pub fn open_table(self_: PyRef<'_, Self>, name: String) -> PyResult<&PyAny> { + let inner = self_.get_inner()?.clone(); + future_into_py(self_.py(), async move { + let table = inner.open_table(&name).execute().await.infer_error()?; + Ok(Table::new(table)) + }) + } } #[pyfunction] @@ -118,8 +154,6 @@ pub fn connect( let read_consistency_interval = Duration::from_secs_f64(read_consistency_interval); builder = 
builder.read_consistency_interval(read_consistency_interval); } - Ok(Connection { - inner: builder.execute().await.infer_error()?, - }) + Ok(Connection::new(builder.execute().await.infer_error()?)) }) } diff --git a/python/src/table.rs b/python/src/table.rs index 23bda9a3..7fcf1a5f 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -1,34 +1,90 @@ -use std::sync::Arc; - -use arrow::pyarrow::ToPyArrow; -use lancedb::table::Table as LanceTable; -use pyo3::{pyclass, pymethods, PyAny, PyRef, PyResult, Python}; +use arrow::{ + ffi_stream::ArrowArrayStreamReader, + pyarrow::{FromPyArrow, ToPyArrow}, +}; +use lancedb::table::{AddDataMode, Table as LanceDbTable}; +use pyo3::{ + exceptions::{PyRuntimeError, PyValueError}, + pyclass, pymethods, PyAny, PyRef, PyResult, Python, +}; use pyo3_asyncio::tokio::future_into_py; use crate::error::PythonErrorExt; #[pyclass] pub struct Table { - inner: Arc, + // We keep a copy of the name to use if the inner table is dropped + name: String, + inner: Option, } impl Table { - pub(crate) fn new(inner: Arc) -> Self { - Self { inner } + pub(crate) fn new(inner: LanceDbTable) -> Self { + Self { + name: inner.name().to_string(), + inner: Some(inner), + } + } +} + +impl Table { + fn inner_ref(&self) -> PyResult<&LanceDbTable> { + self.inner + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name))) } } #[pymethods] impl Table { pub fn name(&self) -> String { - self.inner.name().to_string() + self.name.clone() + } + + pub fn is_open(&self) -> bool { + self.inner.is_some() + } + + pub fn close(&mut self) { + self.inner.take(); } pub fn schema(self_: PyRef<'_, Self>) -> PyResult<&PyAny> { - let inner = self_.inner.clone(); + let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { let schema = inner.schema().await.infer_error()?; Python::with_gil(|py| schema.to_pyarrow(py)) }) } + + pub fn add<'a>(self_: PyRef<'a, Self>, data: &PyAny, mode: String) -> PyResult<&'a PyAny> { + let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?); + let mut op = self_.inner_ref()?.add(batches); + if mode == "append" { + op = op.mode(AddDataMode::Append); + } else if mode == "overwrite" { + op = op.mode(AddDataMode::Overwrite); + } else { + return Err(PyValueError::new_err(format!("Invalid mode: {}", mode))); + } + + future_into_py(self_.py(), async move { + op.execute().await.infer_error()?; + Ok(()) + }) + } + + pub fn count_rows(self_: PyRef<'_, Self>, filter: Option) -> PyResult<&PyAny> { + let inner = self_.inner_ref()?.clone(); + future_into_py(self_.py(), async move { + inner.count_rows(filter).await.infer_error() + }) + } + + pub fn __repr__(&self) -> String { + match &self.inner { + None => format!("ClosedTable({})", self.name), + Some(inner) => inner.to_string(), + } + } } diff --git a/rust/ffi/node/src/index/scalar.rs b/rust/ffi/node/src/index/scalar.rs index 6605364c..3babdda7 100644 --- a/rust/ffi/node/src/index/scalar.rs +++ b/rust/ffi/node/src/index/scalar.rs @@ -19,7 +19,6 @@ use neon::{ }; use crate::{error::ResultExt, runtime, table::JsTable}; -use lancedb::Table; pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult { let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; @@ -34,8 +33,6 @@ pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult rt.spawn(async move { let idx_result = table - .as_native() - .unwrap() .create_index(&[&column]) .replace(replace) .build() diff --git a/rust/ffi/node/src/index/vector.rs b/rust/ffi/node/src/index/vector.rs index 
4fb559dd..8c6698bf 100644 --- a/rust/ffi/node/src/index/vector.rs +++ b/rust/ffi/node/src/index/vector.rs @@ -40,8 +40,9 @@ pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult .unwrap_or("vector".to_string()); // Backward compatibility let tbl = table.clone(); - let mut index_builder = tbl.create_index(&[&column_name]); - get_index_params_builder(&mut cx, index_params, &mut index_builder).or_throw(&mut cx)?; + let index_builder = tbl.create_index(&[&column_name]); + let index_builder = + get_index_params_builder(&mut cx, index_params, index_builder).or_throw(&mut cx)?; rt.spawn(async move { let idx_result = index_builder.build().await; @@ -56,9 +57,9 @@ pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult fn get_index_params_builder( cx: &mut FunctionContext, obj: Handle, - builder: &mut IndexBuilder, -) -> crate::error::Result<()> { - match obj.get::(cx, "type")?.value(cx).as_str() { + builder: IndexBuilder, +) -> crate::error::Result { + let mut builder = match obj.get::(cx, "type")?.value(cx).as_str() { "ivf_pq" => builder.ivf_pq(), _ => { return Err(InvalidIndexType { @@ -67,28 +68,29 @@ fn get_index_params_builder( } }; - obj.get_opt::(cx, "index_name")? - .map(|s| builder.name(s.value(cx).as_str())); + if let Some(index_name) = obj.get_opt::(cx, "index_name")? { + builder = builder.name(index_name.value(cx).as_str()); + } if let Some(metric_type) = obj.get_opt::(cx, "metric_type")? { let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?; - builder.metric_type(metric_type); + builder = builder.metric_type(metric_type); } if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? { - builder.num_partitions(np); + builder = builder.num_partitions(np); } if let Some(ns) = obj.get_opt_u32(cx, "num_sub_vectors")? { - builder.num_sub_vectors(ns); + builder = builder.num_sub_vectors(ns); } if let Some(max_iters) = obj.get_opt_u32(cx, "max_iters")? { - builder.max_iterations(max_iters); + builder = builder.max_iterations(max_iters); } if let Some(num_bits) = obj.get_opt_u32(cx, "num_bits")? { - builder.num_bits(num_bits); + builder = builder.num_bits(num_bits); } if let Some(replace) = obj.get_opt::(cx, "replace")? 
{ - builder.replace(replace.value(cx)); + builder = builder.replace(replace.value(cx)); } - Ok(()) + Ok(builder) } diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index c687f849..1e0e71a3 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -18,10 +18,10 @@ use arrow_array::{RecordBatch, RecordBatchIterator}; use lance::dataset::optimize::CompactionOptions; use lance::dataset::{ColumnAlteration, NewColumnTransform, WriteMode, WriteParams}; use lance::io::ObjectStoreParams; -use lancedb::table::{AddDataOptions, OptimizeAction, WriteOptions}; +use lancedb::table::{OptimizeAction, WriteOptions}; use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer}; -use lancedb::TableRef; +use lancedb::table::Table as LanceDbTable; use neon::prelude::*; use neon::types::buffer::TypedArray; @@ -29,13 +29,13 @@ use crate::error::ResultExt; use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase}; pub struct JsTable { - pub table: TableRef, + pub table: LanceDbTable, } impl Finalize for JsTable {} -impl From for JsTable { - fn from(table: TableRef) -> Self { +impl From for JsTable { + fn from(table: LanceDbTable) -> Self { Self { table } } } @@ -125,13 +125,13 @@ impl JsTable { rt.spawn(async move { let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); - let opts = AddDataOptions { - write_options: WriteOptions { + let add_result = table + .add(Box::new(batch_reader)) + .write_options(WriteOptions { lance_write_params: Some(params), - }, - ..Default::default() - }; - let add_result = table.add(Box::new(batch_reader), opts).await; + }) + .execute() + .await; deferred.settle_with(&channel, move |mut cx| { add_result.or_throw(&mut cx)?; diff --git a/rust/lancedb/examples/simple.rs b/rust/lancedb/examples/simple.rs index a09eca97..51e8e44b 100644 --- a/rust/lancedb/examples/simple.rs +++ b/rust/lancedb/examples/simple.rs @@ -20,8 +20,7 @@ use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use lancedb::connection::Connection; -use lancedb::table::AddDataOptions; -use lancedb::{connect, Result, Table, TableRef}; +use lancedb::{connect, Result, Table as LanceDbTable}; #[tokio::main] async fn main() -> Result<()> { @@ -37,8 +36,8 @@ async fn main() -> Result<()> { println!("{:?}", db.table_names().await?); // --8<-- [end:list_names] let tbl = create_table(&db).await?; - create_index(tbl.as_ref()).await?; - let batches = search(tbl.as_ref()).await?; + create_index(&tbl).await?; + let batches = search(&tbl).await?; println!("{:?}", batches); create_empty_table(&db).await.unwrap(); @@ -63,7 +62,7 @@ async fn open_with_existing_tbl() -> Result<()> { Ok(()) } -async fn create_table(db: &Connection) -> Result { +async fn create_table(db: &Connection) -> Result { // --8<-- [start:create_table] const TOTAL: usize = 1000; const DIM: usize = 128; @@ -125,15 +124,13 @@ async fn create_table(db: &Connection) -> Result { schema.clone(), ); // --8<-- [start:add] - tbl.add(Box::new(new_batches), AddDataOptions::default()) - .await - .unwrap(); + tbl.add(Box::new(new_batches)).execute().await.unwrap(); // --8<-- [end:add] Ok(tbl) } -async fn create_empty_table(db: &Connection) -> Result { +async fn create_empty_table(db: &Connection) -> Result { // --8<-- [start:create_empty_table] let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -143,7 +140,7 @@ async fn create_empty_table(db: &Connection) -> Result { // --8<-- [end:create_empty_table] } -async fn 
create_index(table: &dyn Table) -> Result<()> { +async fn create_index(table: &LanceDbTable) -> Result<()> { // --8<-- [start:create_index] table .create_index(&["vector"]) @@ -154,7 +151,7 @@ async fn create_index(table: &dyn Table) -> Result<()> { // --8<-- [end:create_index] } -async fn search(table: &dyn Table) -> Result> { +async fn search(table: &LanceDbTable) -> Result> { // --8<-- [start:search] Ok(table .search(&[1.0; 128]) diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 37b663f4..cc87350d 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -29,7 +29,8 @@ use snafu::prelude::*; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; use crate::io::object_store::MirroringObjectStoreWrapper; -use crate::table::{NativeTable, TableRef, WriteOptions}; +use crate::table::{NativeTable, WriteOptions}; +use crate::Table; pub const LANCE_FILE_EXTENSION: &str = "lance"; @@ -111,7 +112,7 @@ impl CreateTableBuilder { } /// Execute the create table operation - pub async fn execute(self) -> Result { + pub async fn execute(self) -> Result
{ self.parent.clone().do_create_table(self).await } } @@ -130,7 +131,7 @@ impl CreateTableBuilder { } /// Execute the create table operation - pub async fn execute(self) -> Result { + pub async fn execute(self) -> Result
{ self.parent.clone().do_create_empty_table(self).await } } @@ -188,20 +189,22 @@ impl OpenTableBuilder { } /// Open the table - pub async fn execute(self) -> Result { + pub async fn execute(self) -> Result
{ self.parent.clone().do_open_table(self).await } } #[async_trait::async_trait] -pub(crate) trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static { +pub(crate) trait ConnectionInternal: + Send + Sync + std::fmt::Debug + std::fmt::Display + 'static +{ async fn table_names(&self) -> Result>; - async fn do_create_table(&self, options: CreateTableBuilder) -> Result; - async fn do_open_table(&self, options: OpenTableBuilder) -> Result; + async fn do_create_table(&self, options: CreateTableBuilder) -> Result
; + async fn do_open_table(&self, options: OpenTableBuilder) -> Result
; async fn drop_table(&self, name: &str) -> Result<()>; async fn drop_db(&self) -> Result<()>; - async fn do_create_empty_table(&self, options: CreateTableBuilder) -> Result { + async fn do_create_empty_table(&self, options: CreateTableBuilder) -> Result
{ let batches = RecordBatchIterator::new(vec![], options.schema.unwrap()); let opts = CreateTableBuilder::::new(options.parent, options.name, Box::new(batches)) .mode(options.mode) @@ -217,6 +220,12 @@ pub struct Connection { internal: Arc, } +impl std::fmt::Display for Connection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.internal) + } +} + impl Connection { /// Get the URI of the connection pub fn uri(&self) -> &str { @@ -431,6 +440,24 @@ struct Database { read_consistency_interval: Option, } +impl std::fmt::Display for Database { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "NativeDatabase(uri={}, read_consistency_interval={})", + self.uri, + match self.read_consistency_interval { + None => { + "None".to_string() + } + Some(duration) => { + format!("{}s", duration.as_secs_f64()) + } + } + ) + } +} + const LANCE_EXTENSION: &str = "lance"; const ENGINE: &str = "engine"; const MIRRORED_STORE: &str = "mirroredStore"; @@ -606,7 +633,7 @@ impl ConnectionInternal for Database { Ok(f) } - async fn do_create_table(&self, options: CreateTableBuilder) -> Result { + async fn do_create_table(&self, options: CreateTableBuilder) -> Result
{ let table_uri = self.table_uri(&options.name)?; let mut write_params = options.write_options.lance_write_params.unwrap_or_default(); @@ -624,7 +651,7 @@ impl ConnectionInternal for Database { ) .await { - Ok(table) => Ok(Arc::new(table)), + Ok(table) => Ok(Table::new(Arc::new(table))), Err(Error::TableAlreadyExists { name }) => match options.mode { CreateTableMode::Create => Err(Error::TableAlreadyExists { name }), CreateTableMode::ExistOk(callback) => { @@ -638,9 +665,9 @@ impl ConnectionInternal for Database { } } - async fn do_open_table(&self, options: OpenTableBuilder) -> Result { + async fn do_open_table(&self, options: OpenTableBuilder) -> Result
{ let table_uri = self.table_uri(&options.name)?; - Ok(Arc::new( + let native_table = Arc::new( NativeTable::open_with_params( &table_uri, &options.name, @@ -649,7 +676,8 @@ impl ConnectionInternal for Database { self.read_consistency_interval, ) .await?, - )) + ); + Ok(Table::new(native_table)) } async fn drop_table(&self, name: &str) -> Result<()> { diff --git a/rust/lancedb/src/index.rs b/rust/lancedb/src/index.rs index 6d2cbbb2..b55ff661 100644 --- a/rust/lancedb/src/index.rs +++ b/rust/lancedb/src/index.rs @@ -14,13 +14,12 @@ use std::{cmp::max, sync::Arc}; -use lance::index::scalar::ScalarIndexParams; -use lance_index::{DatasetIndexExt, IndexType}; +use lance_index::IndexType; pub use lance_linalg::distance::MetricType; pub mod vector; -use crate::{utils::default_vector_column, Error, Result, Table}; +use crate::{table::TableInternal, Result}; /// Index Parameters. pub enum IndexParams { @@ -41,36 +40,36 @@ pub enum IndexParams { /// Builder for Index Parameters. pub struct IndexBuilder { - table: Arc, - columns: Vec, + parent: Arc, + pub(crate) columns: Vec, // General parameters /// Index name. - name: Option, + pub(crate) name: Option, /// Replace the existing index. - replace: bool, + pub(crate) replace: bool, - index_type: IndexType, + pub(crate) index_type: IndexType, // Scalar index parameters // Nothing to set here. // IVF_PQ parameters - metric_type: MetricType, - num_partitions: Option, + pub(crate) metric_type: MetricType, + pub(crate) num_partitions: Option, // PQ related - num_sub_vectors: Option, - num_bits: u32, + pub(crate) num_sub_vectors: Option, + pub(crate) num_bits: u32, /// The rate to find samples to train kmeans. - sample_rate: u32, + pub(crate) sample_rate: u32, /// Max iteration to train kmeans. - max_iterations: u32, + pub(crate) max_iterations: u32, } impl IndexBuilder { - pub(crate) fn new(table: Arc, columns: &[&str]) -> Self { + pub(crate) fn new(parent: Arc, columns: &[&str]) -> Self { Self { - table, + parent, columns: columns.iter().map(|c| c.to_string()).collect(), name: None, replace: true, @@ -89,7 +88,7 @@ impl IndexBuilder { /// Accepted parameters: /// - `replace`: Replace the existing index. /// - `name`: Index name. Default: `None` - pub fn scalar(&mut self) -> &mut Self { + pub fn scalar(mut self) -> Self { self.index_type = IndexType::Scalar; self } @@ -105,25 +104,25 @@ impl IndexBuilder { /// - `num_bits`: Number of bits used for PQ centroids. /// - `sample_rate`: The rate to find samples to train kmeans. /// - `max_iterations`: Max iteration to train kmeans. - pub fn ivf_pq(&mut self) -> &mut Self { + pub fn ivf_pq(mut self) -> Self { self.index_type = IndexType::Vector; self } /// The columns to build index on. - pub fn columns(&mut self, cols: &[&str]) -> &mut Self { + pub fn columns(mut self, cols: &[&str]) -> Self { self.columns = cols.iter().map(|s| s.to_string()).collect(); self } /// Whether to replace the existing index, default is `true`. - pub fn replace(&mut self, v: bool) -> &mut Self { + pub fn replace(mut self, v: bool) -> Self { self.replace = v; self } /// Set the index name. - pub fn name(&mut self, name: &str) -> &mut Self { + pub fn name(mut self, name: &str) -> Self { self.name = Some(name.to_string()); self } @@ -131,156 +130,53 @@ impl IndexBuilder { /// [MetricType] to use to build Vector Index. /// /// Default value is [MetricType::L2]. 
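Because the IndexBuilder setters in this hunk now take `mut self` and return `Self` rather than `&mut Self`, index configuration chains by value and `build()` consumes the finished builder. A minimal sketch of the resulting call pattern, assuming an existing `lancedb::Table` named `tbl` with a FixedSizeList "vector" column (the helper name here is illustrative, not part of this change):

```rust
use lancedb::index::MetricType;
use lancedb::{Result, Table};

// Each setter consumes the builder and hands back a new one, so the whole
// configuration reads as a single chain ending in `build()`.
async fn build_vector_index(tbl: &Table) -> Result<()> {
    tbl.create_index(&["vector"])
        .ivf_pq()                    // vector (IVF_PQ) index, as opposed to .scalar()
        .metric_type(MetricType::L2) // the default, shown explicitly
        .num_partitions(256)
        .num_sub_vectors(96)
        .replace(true)
        .build()
        .await
}
```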
- pub fn metric_type(&mut self, metric_type: MetricType) -> &mut Self { + pub fn metric_type(mut self, metric_type: MetricType) -> Self { self.metric_type = metric_type; self } /// Number of IVF partitions. - pub fn num_partitions(&mut self, num_partitions: u32) -> &mut Self { + pub fn num_partitions(mut self, num_partitions: u32) -> Self { self.num_partitions = Some(num_partitions); self } /// Number of sub-vectors of PQ. - pub fn num_sub_vectors(&mut self, num_sub_vectors: u32) -> &mut Self { + pub fn num_sub_vectors(mut self, num_sub_vectors: u32) -> Self { self.num_sub_vectors = Some(num_sub_vectors); self } /// Number of bits used for PQ centroids. - pub fn num_bits(&mut self, num_bits: u32) -> &mut Self { + pub fn num_bits(mut self, num_bits: u32) -> Self { self.num_bits = num_bits; self } /// The rate to find samples to train kmeans. - pub fn sample_rate(&mut self, sample_rate: u32) -> &mut Self { + pub fn sample_rate(mut self, sample_rate: u32) -> Self { self.sample_rate = sample_rate; self } /// Max iteration to train kmeans. - pub fn max_iterations(&mut self, max_iterations: u32) -> &mut Self { + pub fn max_iterations(mut self, max_iterations: u32) -> Self { self.max_iterations = max_iterations; self } /// Build the parameters. - pub async fn build(&self) -> Result<()> { - let schema = self.table.schema().await?; - - // TODO: simplify this after GH lance#1864. - let mut index_type = &self.index_type; - let columns = if self.columns.is_empty() { - // By default we create vector index. - index_type = &IndexType::Vector; - vec![default_vector_column(&schema, None)?] - } else { - self.columns.clone() - }; - - if columns.len() != 1 { - return Err(Error::Schema { - message: "Only one column is supported for index".to_string(), - }); - } - let column = &columns[0]; - - let field = schema.field_with_name(column)?; - - let params = match index_type { - IndexType::Scalar => IndexParams::Scalar { - replace: self.replace, - }, - IndexType::Vector => { - let num_partitions = if let Some(n) = self.num_partitions { - n - } else { - suggested_num_partitions(self.table.count_rows(None).await?) - }; - let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors { - n - } else { - match field.data_type() { - arrow_schema::DataType::FixedSizeList(_, n) => { - Ok::(suggested_num_sub_vectors(*n as u32)) - } - _ => Err(Error::Schema { - message: format!( - "Column '{}' is not a FixedSizeList", - &self.columns[0] - ), - }), - }? - }; - IndexParams::IvfPq { - replace: self.replace, - metric_type: self.metric_type, - num_partitions: num_partitions as u64, - num_sub_vectors, - num_bits: self.num_bits, - sample_rate: self.sample_rate, - max_iterations: self.max_iterations, - } - } - }; - - let tbl = self - .table - .as_native() - .expect("Only native table is supported here"); - let mut dataset = tbl.dataset.get_mut().await?; - match params { - IndexParams::Scalar { replace } => { - dataset - .create_index( - &[&column], - IndexType::Scalar, - None, - &ScalarIndexParams::default(), - replace, - ) - .await? - } - IndexParams::IvfPq { - replace, - metric_type, - num_partitions, - num_sub_vectors, - num_bits, - max_iterations, - .. 
- } => { - let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq( - num_partitions as usize, - num_bits as u8, - num_sub_vectors as usize, - false, - metric_type, - max_iterations as usize, - ); - dataset - .create_index( - &[column], - IndexType::Vector, - None, - &lance_idx_params, - replace, - ) - .await?; - } - } - Ok(()) + pub async fn build(self) -> Result<()> { + self.parent.clone().do_create_index(self).await } } -fn suggested_num_partitions(rows: usize) -> u32 { +pub(crate) fn suggested_num_partitions(rows: usize) -> u32 { let num_partitions = (rows as f64).sqrt() as u32; max(1, num_partitions) } -fn suggested_num_sub_vectors(dim: u32) -> u32 { +pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 { if dim % 16 == 0 { // Should be more aggressive than this default. dim / 16 diff --git a/rust/lancedb/src/ipc.rs b/rust/lancedb/src/ipc.rs index 54a17a8a..1446832c 100644 --- a/rust/lancedb/src/ipc.rs +++ b/rust/lancedb/src/ipc.rs @@ -14,17 +14,18 @@ //! IPC support -use std::io::Cursor; +use std::{io::Cursor, sync::Arc}; use arrow_array::{RecordBatch, RecordBatchReader}; -use arrow_ipc::{reader::StreamReader, writer::FileWriter}; +use arrow_ipc::{reader::FileReader, writer::FileWriter}; +use arrow_schema::Schema; use crate::{Error, Result}; /// Convert a Arrow IPC file to a batch reader pub fn ipc_file_to_batches(buf: Vec) -> Result { let buf_reader = Cursor::new(buf); - let reader = StreamReader::try_new(buf_reader, None)?; + let reader = FileReader::try_new(buf_reader, None)?; Ok(reader) } @@ -44,6 +45,20 @@ pub fn batches_to_ipc_file(batches: &[RecordBatch]) -> Result> { Ok(writer.into_inner()?) } +/// Convert a schema to an Arrow IPC file with 0 batches +pub fn schema_to_ipc_file(schema: &Schema) -> Result> { + let mut writer = FileWriter::try_new(vec![], schema)?; + writer.finish()?; + Ok(writer.into_inner()?) +} + +/// Retrieve the schema from an Arrow IPC file +pub fn ipc_file_to_schema(buf: Vec) -> Result> { + let buf_reader = Cursor::new(buf); + let reader = FileReader::try_new(buf_reader, None)?; + Ok(reader.schema()) +} + #[cfg(test)] mod tests { @@ -71,7 +86,7 @@ mod tests { fn test_ipc_file_to_batches() -> Result<()> { let batch = create_record_batch()?; - let mut writer = StreamWriter::try_new(vec![], &batch.schema())?; + let mut writer = FileWriter::try_new(vec![], &batch.schema())?; writer.write(&batch)?; writer.finish()?; diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index a04826aa..817f0329 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -194,7 +194,7 @@ pub mod table; pub mod utils; pub use error::{Error, Result}; -pub use table::{Table, TableRef}; +pub use table::Table; /// Connect to a database pub use connection::connect; diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 75e5499b..49676a99 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::sync::Arc; + use arrow_array::Float32Array; -use arrow_schema::Schema; -use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; +use lance::dataset::scanner::DatasetRecordBatchStream; use lance_linalg::distance::MetricType; use crate::error::Result; -use crate::table::dataset::DatasetConsistencyWrapper; -use crate::utils::default_vector_column; -use crate::Error; +use crate::table::TableInternal; -const DEFAULT_TOP_K: usize = 10; +pub(crate) const DEFAULT_TOP_K: usize = 10; #[derive(Debug, Clone)] pub enum Select { @@ -34,29 +33,29 @@ pub enum Select { /// A builder for nearest neighbor queries for LanceDB. #[derive(Clone)] pub struct Query { - dataset: DatasetConsistencyWrapper, + parent: Arc, // The column to run the query on. If not specified, we will attempt to guess // the column based on the dataset's schema. - column: Option, + pub(crate) column: Option, // IVF PQ - ANN search. - query_vector: Option, - nprobes: usize, - refine_factor: Option, - metric_type: Option, + pub(crate) query_vector: Option, + pub(crate) nprobes: usize, + pub(crate) refine_factor: Option, + pub(crate) metric_type: Option, /// limit the number of rows to return. - limit: Option, + pub(crate) limit: Option, /// Apply filter to the returned rows. - filter: Option, + pub(crate) filter: Option, /// Select column projection. - select: Select, + pub(crate) select: Select, /// Default is true. Set to false to enforce a brute force search. - use_index: bool, + pub(crate) use_index: bool, /// Apply filter before ANN search/ - prefilter: bool, + pub(crate) prefilter: bool, } impl Query { @@ -64,11 +63,11 @@ impl Query { /// /// # Arguments /// - /// * `dataset` - Lance dataset. + /// * `parent` - the table to run the query on. /// - pub(crate) fn new(dataset: DatasetConsistencyWrapper) -> Self { + pub(crate) fn new(parent: Arc) -> Self { Self { - dataset, + parent, query_vector: None, column: None, limit: None, @@ -88,54 +87,7 @@ impl Query { /// /// * A [DatasetRecordBatchStream] with the query's results. pub async fn execute_stream(&self) -> Result { - let ds_ref = self.dataset.get().await?; - let mut scanner: Scanner = ds_ref.scan(); - - if let Some(query) = self.query_vector.as_ref() { - // If there is a vector query, default to limit=10 if unspecified - let column = if let Some(col) = self.column.as_ref() { - col.clone() - } else { - // Infer a vector column with the same dimension of the query vector. - let arrow_schema = Schema::from(ds_ref.schema()); - default_vector_column(&arrow_schema, Some(query.len() as i32))? 
- }; - let field = ds_ref.schema().field(&column).ok_or(Error::Store { - message: format!("Column {} not found in dataset schema", column), - })?; - if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query.len() as i32) - { - return Err(Error::Store { - message: format!( - "Vector column '{}' does not match the dimension of the query vector: dim={}", - column, - query.len(), - ), - }); - } - scanner.nearest(&column, query, self.limit.unwrap_or(DEFAULT_TOP_K))?; - } else { - // If there is no vector query, it's ok to not have a limit - scanner.limit(self.limit.map(|limit| limit as i64), None)?; - } - scanner.nprobs(self.nprobes); - scanner.use_index(self.use_index); - scanner.prefilter(self.prefilter); - - match &self.select { - Select::Simple(select) => { - scanner.project(select.as_slice())?; - } - Select::Projection(select_with_transform) => { - scanner.project_with_transform(select_with_transform.as_slice())?; - } - Select::All => { /* Do nothing */ } - } - - self.filter.as_ref().map(|f| scanner.filter(f)); - self.refine_factor.map(|rf| scanner.refine(rf)); - self.metric_type.map(|mt| scanner.distance_metric(mt)); - Ok(scanner.try_into_stream().await?) + self.parent.clone().do_query(self).await } /// Set the column to query @@ -259,22 +211,29 @@ mod tests { }; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use futures::{StreamExt, TryStreamExt}; - use lance::dataset::Dataset; use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector}; use tempfile::tempdir; - use crate::query::Query; - use crate::table::{NativeTable, Table}; + use crate::connect; #[tokio::test] async fn test_setters_getters() { - let batches = make_test_batches(); - let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); + // TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051 + // is fixed + let tmp_dir = tempdir().unwrap(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = dataset_path.to_str().unwrap(); - let ds = DatasetConsistencyWrapper::new_latest(ds, None); + let batches = make_test_batches(); + let conn = connect(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", Box::new(batches)) + .execute() + .await + .unwrap(); let vector = Some(Float32Array::from_iter_values([0.1, 0.2])); - let query = Query::new(ds).nearest_to(&[0.1, 0.2]); + let query = table.query().nearest_to(&[0.1, 0.2]); assert_eq!(query.query_vector, vector); let new_vector = Float32Array::from_iter_values([9.8, 8.7]); @@ -297,12 +256,21 @@ mod tests { #[tokio::test] async fn test_execute() { + // TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051 + // is fixed + let tmp_dir = tempdir().unwrap(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = dataset_path.to_str().unwrap(); + let batches = make_non_empty_batches(); - let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); + let conn = connect(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", Box::new(batches)) + .execute() + .await + .unwrap(); - let ds = DatasetConsistencyWrapper::new_latest(ds, None); - - let query = Query::new(ds.clone()).nearest_to(&[0.1; 4]); + let query = table.query().nearest_to(&[0.1; 4]); let result = query.limit(10).filter("id % 2 == 0").execute_stream().await; let mut stream = result.expect("should have result"); // should only have one batch @@ -311,7 +279,7 @@ mod 
tests { assert!(batch.expect("should be Ok").num_rows() < 10); } - let query = Query::new(ds).nearest_to(&[0.1; 4]); + let query = table.query().nearest_to(&[0.1; 4]); let result = query .limit(10) .filter(String::from("id % 2 == 0")) // Work with String too @@ -328,12 +296,22 @@ mod tests { #[tokio::test] async fn test_select_with_transform() { + // TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051 + // is fixed + let tmp_dir = tempdir().unwrap(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = dataset_path.to_str().unwrap(); + let batches = make_non_empty_batches(); - let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); + let conn = connect(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", Box::new(batches)) + .execute() + .await + .unwrap(); - let ds = DatasetConsistencyWrapper::new_latest(ds, None); - - let query = Query::new(ds) + let query = table + .query() .limit(10) .select_with_projection(&[("id2", "id * 2"), ("id", "id")]); let result = query.execute_stream().await; @@ -360,13 +338,22 @@ mod tests { #[tokio::test] async fn test_execute_no_vector() { + // TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051 + // is fixed + let tmp_dir = tempdir().unwrap(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = dataset_path.to_str().unwrap(); + // test that it's ok to not specify a query vector (just filter / limit) let batches = make_non_empty_batches(); - let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); + let conn = connect(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", Box::new(batches)) + .execute() + .await + .unwrap(); - let ds = DatasetConsistencyWrapper::new_latest(ds, None); - - let query = Query::new(ds); + let query = table.query(); let result = query.filter("id % 2 == 0").execute_stream().await; let mut stream = result.expect("should have result"); // should only have one batch @@ -413,12 +400,13 @@ mod tests { let uri = dataset_path.to_str().unwrap(); let batches = make_test_batches(); - Dataset::write(batches, dataset_path.to_str().unwrap(), None) + let conn = connect(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", Box::new(batches)) + .execute() .await .unwrap(); - let table = NativeTable::open(uri).await.unwrap(); - let query = table.search(&[0.1, 0.2]); assert_eq!(&[0.1, 0.2], query.query_vector.unwrap().values()); } diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 6ff9811b..2b2c1fd7 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -28,6 +28,10 @@ pub struct RestfulLanceDbClient { } impl RestfulLanceDbClient { + pub fn host(&self) -> &str { + &self.host + } + fn default_headers( api_key: &str, region: &str, diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 948db4fd..3dd17ec9 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -21,7 +21,7 @@ use tokio::task::spawn_blocking; use crate::connection::{ConnectionInternal, CreateTableBuilder, OpenTableBuilder}; use crate::error::Result; -use crate::TableRef; +use crate::Table; use super::client::RestfulLanceDbClient; use super::table::RemoteTable; @@ -51,6 +51,12 @@ impl RemoteDatabase { } } +impl std::fmt::Display for RemoteDatabase { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "RemoteDatabase(host={})", 
self.client.host()) + } +} + #[async_trait] impl ConnectionInternal for RemoteDatabase { async fn table_names(&self) -> Result> { @@ -65,7 +71,7 @@ impl ConnectionInternal for RemoteDatabase { Ok(rsp.json::().await?.tables) } - async fn do_create_table(&self, options: CreateTableBuilder) -> Result { + async fn do_create_table(&self, options: CreateTableBuilder) -> Result
{ let data = options.data.unwrap(); // TODO: https://github.com/lancedb/lancedb/issues/1026 // We should accept data from an async source. In the meantime, spawn this as blocking @@ -78,17 +84,18 @@ impl ConnectionInternal for RemoteDatabase { .post(&format!("/v1/table/{}/create", options.name)) .body(data_buffer) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) + // This is currently expected by LanceDb cloud but will be removed soon. .header("x-request-id", "na") .send() .await?; - Ok(Arc::new(RemoteTable::new( + Ok(Table::new(Arc::new(RemoteTable::new( self.client.clone(), options.name, - ))) + )))) } - async fn do_open_table(&self, _options: OpenTableBuilder) -> Result { + async fn do_open_table(&self, _options: OpenTableBuilder) -> Result
{ todo!() } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index dfbf337f..f258d9a8 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1,16 +1,16 @@ use arrow_array::RecordBatchReader; use arrow_schema::SchemaRef; use async_trait::async_trait; -use lance::dataset::{ColumnAlteration, NewColumnTransform}; +use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform}; use crate::{ error::Result, index::IndexBuilder, query::Query, table::{ - merge::MergeInsertBuilder, AddDataOptions, NativeTable, OptimizeAction, OptimizeStats, + merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats, + TableInternal, }, - Table, }; use super::client::RestfulLanceDbClient; @@ -35,7 +35,7 @@ impl std::fmt::Display for RemoteTable { } #[async_trait] -impl Table for RemoteTable { +impl TableInternal for RemoteTable { fn as_any(&self) -> &dyn std::any::Any { self } @@ -51,23 +51,23 @@ impl Table for RemoteTable { async fn count_rows(&self, _filter: Option) -> Result { todo!() } - async fn add( - &self, - _batches: Box, - _options: AddDataOptions, - ) -> Result<()> { + async fn do_add(&self, _add: AddDataBuilder) -> Result<()> { + todo!() + } + async fn do_query(&self, _query: &Query) -> Result { todo!() } async fn delete(&self, _predicate: &str) -> Result<()> { todo!() } - fn create_index(&self, _column: &[&str]) -> IndexBuilder { + async fn do_create_index(&self, _index: IndexBuilder) -> Result<()> { todo!() } - fn merge_insert(&self, _on: &[&str]) -> MergeInsertBuilder { - todo!() - } - fn query(&self) -> Query { + async fn do_merge_insert( + &self, + _params: MergeInsertBuilder, + _new_data: Box, + ) -> Result<()> { todo!() } async fn optimize(&self, _action: OptimizeAction) -> Result { diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index c9638f08..7a90950d 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -26,24 +26,29 @@ use lance::dataset::cleanup::RemovalStats; use lance::dataset::optimize::{ compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions, }; +use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; pub use lance::dataset::ReadParams; use lance::dataset::{ ColumnAlteration, Dataset, NewColumnTransform, UpdateBuilder, WhenMatched, WriteMode, WriteParams, }; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; +use lance::index::scalar::ScalarIndexParams; use lance::io::WrappingObjectStore; +use lance_index::IndexType; use lance_index::{optimize::OptimizeOptions, DatasetIndexExt}; use log::info; use crate::error::{Error, Result}; use crate::index::vector::{VectorIndex, VectorIndexStatistics}; -use crate::index::IndexBuilder; -use crate::query::Query; -use crate::utils::{PatchReadParam, PatchWriteParam}; +use crate::index::{ + suggested_num_partitions, suggested_num_sub_vectors, IndexBuilder, IndexParams, +}; +use crate::query::{Query, Select, DEFAULT_TOP_K}; +use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam}; use self::dataset::DatasetConsistencyWrapper; -use self::merge::{MergeInsert, MergeInsertBuilder}; +use self::merge::MergeInsertBuilder; pub(crate) mod dataset; pub mod merge; @@ -97,7 +102,7 @@ pub struct WriteOptions { // pub on_bad_vectors: BadVectorHandling, /// Advanced parameters that can be used to customize table creation /// - /// If set, these will take precedence over any overlapping `OpenTableOptions` options + /// If 
set, these will take precedence over any overlapping `OpenTableBuilder` options pub lance_write_params: Option, } @@ -110,36 +115,115 @@ pub enum AddDataMode { Overwrite, } -#[derive(Debug, Default, Clone)] -pub struct AddDataOptions { - /// Whether to add new rows (the default) or replace the existing data - pub mode: AddDataMode, - /// Options to use when writing the data - pub write_options: WriteOptions, +/// A builder for configuring a [`Table::add`] operation +pub struct AddDataBuilder { + parent: Arc, + pub(crate) data: Box, + pub(crate) mode: AddDataMode, + pub(crate) write_options: WriteOptions, +} + +impl std::fmt::Debug for AddDataBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AddDataBuilder") + .field("parent", &self.parent) + .field("mode", &self.mode) + .field("write_options", &self.write_options) + .finish() + } +} + +impl AddDataBuilder { + pub fn mode(mut self, mode: AddDataMode) -> Self { + self.mode = mode; + self + } + + pub fn write_options(mut self, options: WriteOptions) -> Self { + self.write_options = options; + self + } + + pub async fn execute(self) -> Result<()> { + self.parent.clone().do_add(self).await + } +} + +#[async_trait] +pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync { + fn as_any(&self) -> &dyn std::any::Any; + /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`]. + fn as_native(&self) -> Option<&NativeTable>; + /// Get the name of the table. + fn name(&self) -> &str; + /// Get the arrow [Schema] of the table. + async fn schema(&self) -> Result; + /// Count the number of rows in this table. + async fn count_rows(&self, filter: Option) -> Result; + async fn do_add(&self, add: AddDataBuilder) -> Result<()>; + async fn do_query(&self, query: &Query) -> Result; + async fn delete(&self, predicate: &str) -> Result<()>; + async fn do_create_index(&self, index: IndexBuilder) -> Result<()>; + async fn do_merge_insert( + &self, + params: MergeInsertBuilder, + new_data: Box, + ) -> Result<()>; + async fn optimize(&self, action: OptimizeAction) -> Result; + async fn add_columns( + &self, + transforms: NewColumnTransform, + read_columns: Option>, + ) -> Result<()>; + async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()>; + async fn drop_columns(&self, columns: &[&str]) -> Result<()>; } /// A Table is a collection of strong typed Rows. /// /// The type of the each row is defined in Apache Arrow [Schema]. -#[async_trait::async_trait] -pub trait Table: std::fmt::Display + Send + Sync { - fn as_any(&self) -> &dyn std::any::Any; +#[derive(Clone)] +pub struct Table { + inner: Arc, +} + +impl std::fmt::Display for Table { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.inner) + } +} + +impl Table { + pub(crate) fn new(inner: Arc) -> Self { + Self { inner } + } /// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`]. - fn as_native(&self) -> Option<&NativeTable>; + /// + /// Warning: This function will be removed soon (features exclusive to NativeTable + /// will be added to Table) + pub fn as_native(&self) -> Option<&NativeTable> { + self.inner.as_native() + } /// Get the name of the table. - fn name(&self) -> &str; + pub fn name(&self) -> &str { + self.inner.name() + } /// Get the arrow [Schema] of the table. - async fn schema(&self) -> Result; + pub async fn schema(&self) -> Result { + self.inner.schema().await + } /// Count the number of rows in this dataset.
/// /// # Arguments /// /// * `filter` if present, only count rows matching the filter - async fn count_rows(&self, filter: Option) -> Result; + pub async fn count_rows(&self, filter: Option) -> Result { + self.inner.count_rows(filter).await + } /// Insert new records into this Table /// @@ -147,11 +231,14 @@ pub trait Table: std::fmt::Display + Send + Sync { /// /// * `batches` data to be added to the Table /// * `options` options to control how data is added - async fn add( - &self, - batches: Box, - options: AddDataOptions, - ) -> Result<()>; + pub fn add(&self, batches: Box) -> AddDataBuilder { + AddDataBuilder { + parent: self.inner.clone(), + data: batches, + mode: AddDataMode::Append, + write_options: WriteOptions::default(), + } + } /// Delete the rows from table that match the predicate. /// @@ -202,7 +289,9 @@ pub trait Table: std::fmt::Display + Send + Sync { /// tbl.delete("id > 5").await.unwrap(); /// # }); /// ``` - async fn delete(&self, predicate: &str) -> Result<()>; + pub async fn delete(&self, predicate: &str) -> Result<()> { + self.inner.delete(predicate).await + } /// Create an index on the column name. /// @@ -228,7 +317,9 @@ pub trait Table: std::fmt::Display + Send + Sync { /// .unwrap(); /// # }); /// ``` - fn create_index(&self, column: &[&str]) -> IndexBuilder; + pub fn create_index(&self, column: &[&str]) -> IndexBuilder { + IndexBuilder::new(self.inner.clone(), column) + } /// Create a builder for a merge insert operation /// @@ -305,12 +396,17 @@ pub trait Table: std::fmt::Display + Send + Sync { /// merge_insert.execute(Box::new(new_data)).await.unwrap(); /// # }); /// ``` - fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder; + pub fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder { + MergeInsertBuilder::new( + self.inner.clone(), + on.iter().map(|s| s.to_string()).collect(), + ) + } /// Search the table with a given query vector. /// /// This is a convenience method for preparing an ANN query. 
- fn search(&self, query: &[f32]) -> Query { + pub fn search(&self, query: &[f32]) -> Query { self.query().nearest_to(query) } @@ -327,7 +423,8 @@ pub trait Table: std::fmt::Display + Send + Sync { /// # use arrow_array::RecordBatch; /// # use futures::TryStreamExt; /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// # let tbl = lancedb::table::NativeTable::open("/tmp/tbl").await.unwrap(); + /// # let conn = lancedb::connect("/tmp").execute().await.unwrap(); + /// # let tbl = conn.open_table("tbl").execute().await.unwrap(); /// use crate::lancedb::Table; /// let stream = tbl /// .query() @@ -346,7 +443,8 @@ pub trait Table: std::fmt::Display + Send + Sync { /// # use arrow_array::RecordBatch; /// # use futures::TryStreamExt; /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// # let tbl = lancedb::table::NativeTable::open("/tmp/tbl").await.unwrap(); + /// # let conn = lancedb::connect("/tmp").execute().await.unwrap(); + /// # let tbl = conn.open_table("tbl").execute().await.unwrap(); /// use crate::lancedb::Table; /// let stream = tbl /// .query() @@ -364,13 +462,16 @@ pub trait Table: std::fmt::Display + Send + Sync { /// # use arrow_array::RecordBatch; /// # use futures::TryStreamExt; /// # tokio::runtime::Runtime::new().unwrap().block_on(async { - /// # let tbl = lancedb::table::NativeTable::open("/tmp/tbl").await.unwrap(); + /// # let conn = lancedb::connect("/tmp").execute().await.unwrap(); + /// # let tbl = conn.open_table("tbl").execute().await.unwrap(); /// use crate::lancedb::Table; /// let stream = tbl.query().execute_stream().await.unwrap(); /// let batches: Vec = stream.try_collect().await.unwrap(); /// # }); /// ``` - fn query(&self) -> Query; + pub fn query(&self) -> Query { + Query::new(self.inner.clone()) + } /// Optimize the on-disk data and indices for better performance. /// @@ -378,25 +479,30 @@ pub trait Table: std::fmt::Display + Send + Sync { /// /// Modeled after ``VACUUM`` in PostgreSQL. /// Not all implementations support explicit optimization. - async fn optimize(&self, action: OptimizeAction) -> Result; + pub async fn optimize(&self, action: OptimizeAction) -> Result { + self.inner.optimize(action).await + } /// Add new columns to the table, providing values to fill in. - async fn add_columns( + pub async fn add_columns( &self, transforms: NewColumnTransform, read_columns: Option>, - ) -> Result<()>; + ) -> Result<()> { + self.inner.add_columns(transforms, read_columns).await + } /// Change a column's name or nullability. - async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()>; + pub async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()> { + self.inner.alter_columns(alterations).await + } /// Remove columns from the table. - async fn drop_columns(&self, columns: &[&str]) -> Result<()>; + pub async fn drop_columns(&self, columns: &[&str]) -> Result<()> { + self.inner.drop_columns(columns).await + } } -/// Reference to a Table pointer. -pub type TableRef = Arc; - /// A table in a LanceDB database. 
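With `TableRef` removed, callers now hold the concrete, cloneable `Table` handle returned by `create_table`/`open_table`, and inserts are configured through `AddDataBuilder` before `execute()` runs them. A minimal usage sketch under those assumptions (the table name and helper are illustrative; `batches` is any prepared reader supplied by the caller):

```rust
use arrow_array::RecordBatchReader;
use lancedb::connection::Connection;
use lancedb::table::AddDataMode;
use lancedb::Result;

// Open an existing table, append a batch of rows, and report the new row count.
async fn append_rows(
    conn: &Connection,
    batches: Box<dyn RecordBatchReader + Send>,
) -> Result<usize> {
    let tbl = conn.open_table("my_table").execute().await?; // a Table, not Arc<dyn Table>
    tbl.add(batches)
        .mode(AddDataMode::Append) // the default; AddDataMode::Overwrite replaces the data
        .execute()
        .await?;
    tbl.count_rows(None).await
}
```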
#[derive(Debug, Clone)] pub struct NativeTable { @@ -414,7 +520,20 @@ pub struct NativeTable { impl std::fmt::Display for NativeTable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Table({})", self.name) + write!( + f, + "NativeTable({}, uri={}, read_consistency_interval={})", + self.name, + self.uri, + match self.read_consistency_interval { + None => { + "None".to_string() + } + Some(duration) => { + format!("{}s", duration.as_secs_f64()) + } + } + ) } } @@ -758,8 +877,107 @@ impl NativeTable { } } -#[async_trait] -impl MergeInsert for NativeTable { +#[async_trait::async_trait] +impl TableInternal for NativeTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_native(&self) -> Option<&NativeTable> { + Some(self) + } + + fn name(&self) -> &str { + self.name.as_str() + } + + async fn schema(&self) -> Result { + let lance_schema = self.dataset.get().await?.schema().clone(); + Ok(Arc::new(Schema::from(&lance_schema))) + } + + async fn count_rows(&self, filter: Option) -> Result { + let dataset = self.dataset.get().await?; + if let Some(filter) = filter { + let mut scanner = dataset.scan(); + scanner.filter(&filter)?; + Ok(scanner.count_rows().await? as usize) + } else { + Ok(dataset.count_rows().await?) + } + } + + async fn do_add(&self, add: AddDataBuilder) -> Result<()> { + let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams { + mode: match add.mode { + AddDataMode::Append => WriteMode::Append, + AddDataMode::Overwrite => WriteMode::Overwrite, + }, + ..Default::default() + }); + + // patch the params if we have a write store wrapper + let lance_params = match self.store_wrapper.clone() { + Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?, + None => lance_params, + }; + + let dataset = Dataset::write(add.data, &self.uri, Some(lance_params)).await?; + self.dataset.set_latest(dataset).await; + Ok(()) + } + + async fn do_query(&self, query: &Query) -> Result { + let ds_ref = self.dataset.get().await?; + let mut scanner: Scanner = ds_ref.scan(); + + if let Some(query_vector) = query.query_vector.as_ref() { + // If there is a vector query, default to limit=10 if unspecified + let column = if let Some(col) = query.column.as_ref() { + col.clone() + } else { + // Infer a vector column with the same dimension of the query vector. + let arrow_schema = Schema::from(ds_ref.schema()); + default_vector_column(&arrow_schema, Some(query_vector.len() as i32))? 
+ }; + let field = ds_ref.schema().field(&column).ok_or(Error::Store { + message: format!("Column {} not found in dataset schema", column), + })?; + if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32) + { + return Err(Error::Store { + message: format!( + "Vector column '{}' does not match the dimension of the query vector: dim={}", + column, + query_vector.len(), + ), + }); + } + scanner.nearest(&column, query_vector, query.limit.unwrap_or(DEFAULT_TOP_K))?; + } else { + // If there is no vector query, it's ok to not have a limit + scanner.limit(query.limit.map(|limit| limit as i64), None)?; + } + scanner.nprobs(query.nprobes); + scanner.use_index(query.use_index); + scanner.prefilter(query.prefilter); + + match &query.select { + Select::Simple(select) => { + scanner.project(select.as_slice())?; + } + Select::Projection(select_with_transform) => { + scanner.project_with_transform(select_with_transform.as_slice())?; + } + Select::All => { /* Do nothing */ } + } + + query.filter.as_ref().map(|f| scanner.filter(f)); + query.refine_factor.map(|rf| scanner.refine(rf)); + query.metric_type.map(|mt| scanner.distance_metric(mt)); + Ok(scanner.try_into_stream().await?) + } + async fn do_merge_insert( &self, params: MergeInsertBuilder, @@ -795,78 +1013,113 @@ impl MergeInsert for NativeTable { self.dataset.set_latest(new_dataset.as_ref().clone()).await; Ok(()) } -} -#[async_trait::async_trait] -impl Table for NativeTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } + async fn do_create_index(&self, index: IndexBuilder) -> Result<()> { + let schema = self.schema().await?; - fn as_native(&self) -> Option<&NativeTable> { - Some(self) - } - - fn name(&self) -> &str { - self.name.as_str() - } - - async fn schema(&self) -> Result { - let lance_schema = self.dataset.get().await?.schema().clone(); - Ok(Arc::new(Schema::from(&lance_schema))) - } - - async fn count_rows(&self, filter: Option) -> Result { - let dataset = self.dataset.get().await?; - if let Some(filter) = filter { - let mut scanner = dataset.scan(); - scanner.filter(&filter)?; - Ok(scanner.count_rows().await? as usize) + // TODO: simplify this after GH lance#1864. + let mut index_type = &index.index_type; + let columns = if index.columns.is_empty() { + // By default we create vector index. + index_type = &IndexType::Vector; + vec![default_vector_column(&schema, None)?] } else { - Ok(dataset.count_rows().await?) 
- } - } - - async fn add( - &self, - batches: Box, - params: AddDataOptions, - ) -> Result<()> { - let lance_params = params - .write_options - .lance_write_params - .unwrap_or(WriteParams { - mode: match params.mode { - AddDataMode::Append => WriteMode::Append, - AddDataMode::Overwrite => WriteMode::Overwrite, - }, - ..Default::default() - }); - - // patch the params if we have a write store wrapper - let lance_params = match self.store_wrapper.clone() { - Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?, - None => lance_params, + index.columns.clone() }; - let dataset = Dataset::write(batches, &self.uri, Some(lance_params)).await?; - self.dataset.set_latest(dataset).await; + if columns.len() != 1 { + return Err(Error::Schema { + message: "Only one column is supported for index".to_string(), + }); + } + let column = &columns[0]; + + let field = schema.field_with_name(column)?; + + let params = match index_type { + IndexType::Scalar => IndexParams::Scalar { + replace: index.replace, + }, + IndexType::Vector => { + let num_partitions = if let Some(n) = index.num_partitions { + n + } else { + suggested_num_partitions(self.count_rows(None).await?) + }; + let num_sub_vectors: u32 = if let Some(n) = index.num_sub_vectors { + n + } else { + match field.data_type() { + arrow_schema::DataType::FixedSizeList(_, n) => { + Ok::(suggested_num_sub_vectors(*n as u32)) + } + _ => Err(Error::Schema { + message: format!( + "Column '{}' is not a FixedSizeList", + &index.columns[0] + ), + }), + }? + }; + IndexParams::IvfPq { + replace: index.replace, + metric_type: index.metric_type, + num_partitions: num_partitions as u64, + num_sub_vectors, + num_bits: index.num_bits, + sample_rate: index.sample_rate, + max_iterations: index.max_iterations, + } + } + }; + + let tbl = self + .as_native() + .expect("Only native table is supported here"); + let mut dataset = tbl.dataset.get_mut().await?; + match params { + IndexParams::Scalar { replace } => { + dataset + .create_index( + &[&column], + IndexType::Scalar, + None, + &ScalarIndexParams::default(), + replace, + ) + .await? + } + IndexParams::IvfPq { + replace, + metric_type, + num_partitions, + num_sub_vectors, + num_bits, + max_iterations, + .. 
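// Editor's note (sketch, not part of the patch): when left unset, `num_partitions` is
// derived from the current row count via `suggested_num_partitions`, and
// `num_sub_vectors` from the FixedSizeList dimension via `suggested_num_sub_vectors`.
// The test below only shows `create_index(&["embeddings"])` followed by `.await`; the
// intermediate builder calls fall between hunks, so the trailing `execute()` here is an
// assumption based on the convention the other builders in this patch use:
//
//     table.create_index(&["embeddings"]).execute().await?;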
+ } => { + let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq( + num_partitions as usize, + num_bits as u8, + num_sub_vectors as usize, + false, + metric_type, + max_iterations as usize, + ); + dataset + .create_index( + &[column], + IndexType::Vector, + None, + &lance_idx_params, + replace, + ) + .await?; + } + } Ok(()) } - fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder { - let on = Vec::from_iter(on.iter().map(|key| key.to_string())); - MergeInsertBuilder::new(Arc::new(self.clone()), on) - } - - fn create_index(&self, columns: &[&str]) -> IndexBuilder { - IndexBuilder::new(Arc::new(self.clone()), columns) - } - - fn query(&self) -> Query { - Query::new(self.dataset.clone()) - } - /// Delete rows from the table async fn delete(&self, predicate: &str) -> Result<()> { self.dataset.get_mut().await?.delete(predicate).await?; @@ -968,6 +1221,7 @@ mod tests { use rand::Rng; use tempfile::tempdir; + use crate::connect; use crate::connection::ConnectBuilder; use super::*; @@ -1027,10 +1281,13 @@ mod tests { async fn test_add() { let tmp_dir = tempdir().unwrap(); let uri = tmp_dir.path().to_str().unwrap(); + let conn = connect(uri).execute().await.unwrap(); let batches = make_test_batches(); let schema = batches.schema().clone(); - let table = NativeTable::create(uri, "test", batches, None, None, None) + let table = conn + .create_table("test", Box::new(batches)) + .execute() .await .unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 10); @@ -1046,22 +1303,22 @@ mod tests { schema.clone(), ); - table - .add(Box::new(new_batches), AddDataOptions::default()) - .await - .unwrap(); + table.add(Box::new(new_batches)).execute().await.unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 20); - assert_eq!(table.name, "test"); + assert_eq!(table.name(), "test"); } #[tokio::test] async fn test_merge_insert() { let tmp_dir = tempdir().unwrap(); let uri = tmp_dir.path().to_str().unwrap(); + let conn = connect(uri).execute().await.unwrap(); // Create a dataset with i=0..10 let batches = merge_insert_test_batches(0, 0); - let table = NativeTable::create(uri, "test", batches, None, None, None) + let table = conn + .create_table("my_table", Box::new(batches)) + .execute() .await .unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 10); @@ -1104,10 +1361,13 @@ mod tests { async fn test_add_overwrite() { let tmp_dir = tempdir().unwrap(); let uri = tmp_dir.path().to_str().unwrap(); + let conn = connect(uri).execute().await.unwrap(); let batches = make_test_batches(); let schema = batches.schema().clone(); - let table = NativeTable::create(uri, "test", batches, None, None, None) + let table = conn + .create_table("test", Box::new(batches)) + .execute() .await .unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 10); @@ -1124,17 +1384,13 @@ mod tests { // Can overwrite using AddDataOptions::mode table - .add( - Box::new(new_batches), - AddDataOptions { - mode: AddDataMode::Overwrite, - ..Default::default() - }, - ) + .add(Box::new(new_batches)) + .mode(AddDataMode::Overwrite) + .execute() .await .unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 10); - assert_eq!(table.name, "test"); + assert_eq!(table.name(), "test"); // Can overwrite using underlying WriteParams (which // take precedence over AddDataOptions::mode) @@ -1144,17 +1400,18 @@ mod tests { ..Default::default() }; - let opts = AddDataOptions { - write_options: WriteOptions { - lance_write_params: Some(param), - }, - mode: AddDataMode::Append, - }; - let new_batches = 
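// Editor's note (sketch, not part of the patch): table handles are now obtained through
// a connection rather than `NativeTable::create`/`open`, and data is appended or
// overwritten through the `add` builder. The calls below appear verbatim in the tests
// in this diff; `more_batches` is a placeholder for any `RecordBatchReader`:
//
//     let conn = connect(uri).execute().await?;
//     let table = conn.create_table("test", Box::new(batches)).execute().await?;
//     table
//         .add(Box::new(more_batches))
//         .mode(AddDataMode::Overwrite)
//         .execute()
//         .await?;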
RecordBatchIterator::new(batches.clone(), schema.clone()); - table.add(Box::new(new_batches), opts).await.unwrap(); + table + .add(Box::new(new_batches)) + .write_options(WriteOptions { + lance_write_params: Some(param), + }) + .mode(AddDataMode::Append) + .execute() + .await + .unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 10); - assert_eq!(table.name, "test"); + assert_eq!(table.name(), "test"); } #[tokio::test] @@ -1162,6 +1419,11 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test.lance"); let uri = dataset_path.to_str().unwrap(); + let conn = connect(uri) + .read_consistency_interval(Duration::from_secs(0)) + .execute() + .await + .unwrap(); let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -1184,20 +1446,23 @@ mod tests { schema.clone(), ); - Dataset::write(record_batch_iter, uri, None).await.unwrap(); - let table = NativeTable::open(uri).await.unwrap(); + let table = conn + .create_table("my_table", Box::new(record_batch_iter)) + .execute() + .await + .unwrap(); table + .as_native() + .unwrap() .update(Some("id > 5"), vec![("name", "'foo'")]) .await .unwrap(); - let ds_after = Dataset::open(uri).await.unwrap(); - let mut batches = ds_after - .scan() - .project(&["id", "name"]) - .unwrap() - .try_into_stream() + let mut batches = table + .query() + .select(&["id", "name"]) + .execute_stream() .await .unwrap() .try_collect::>() @@ -1236,6 +1501,11 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test.lance"); let uri = dataset_path.to_str().unwrap(); + let conn = connect(uri) + .read_consistency_interval(Duration::from_secs(0)) + .execute() + .await + .unwrap(); let schema = Arc::new(Schema::new(vec![ Field::new("int32", DataType::Int32, false), @@ -1312,8 +1582,11 @@ mod tests { schema.clone(), ); - Dataset::write(record_batch_iter, uri, None).await.unwrap(); - let table = NativeTable::open(uri).await.unwrap(); + let table = conn + .create_table("my_table", Box::new(record_batch_iter)) + .execute() + .await + .unwrap(); // check it can do update for each type let updates: Vec<(&str, &str)> = vec![ @@ -1333,12 +1606,16 @@ mod tests { ]; // for (column, value) in test_cases { - table.update(None, updates).await.unwrap(); + table + .as_native() + .unwrap() + .update(None, updates) + .await + .unwrap(); - let ds_after = Dataset::open(uri).await.unwrap(); - let mut batches = ds_after - .scan() - .project(&[ + let mut batches = table + .query() + .select(&[ "string", "large_string", "int32", @@ -1353,8 +1630,7 @@ mod tests { "vec_f32", "vec_f64", ]) - .unwrap() - .try_into_stream() + .execute_stream() .await .unwrap() .try_collect::>() @@ -1445,9 +1721,12 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test.lance"); let uri = dataset_path.to_str().unwrap(); + let conn = connect(uri).execute().await.unwrap(); let batches = make_test_batches(); - Dataset::write(batches, dataset_path.to_str().unwrap(), None) + + conn.create_table("my_table", Box::new(batches)) + .execute() .await .unwrap(); @@ -1462,7 +1741,9 @@ mod tests { ..Default::default() }; assert!(!wrapper.called()); - let _ = NativeTable::open_with_params(uri, "test", None, Some(param), None) + conn.open_table("my_table") + .lance_read_params(param) + .execute() .await .unwrap(); assert!(wrapper.called()); @@ -1510,6 +1791,7 @@ mod tests { let tmp_dir = tempdir().unwrap(); let uri = tmp_dir.path().to_str().unwrap(); + let conn = 
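// Editor's note (sketch, not part of the patch): the update tests pin
// `read_consistency_interval` to zero so every read observes the latest write
// immediately; the same interval is what the new `Display` impl reports, in seconds.
// Connecting with strong read consistency looks like:
//
//     let conn = connect(uri)
//         .read_consistency_interval(Duration::from_secs(0))
//         .execute()
//         .await?;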
connect(uri).execute().await.unwrap(); let dimension = 16; let schema = Arc::new(ArrowSchema::new(vec![Field::new( @@ -1536,12 +1818,30 @@ mod tests { schema, ); - let table = NativeTable::create(uri, "test", batches, None, None, None) + let table = conn + .create_table("test", Box::new(batches)) + .execute() .await .unwrap(); - assert_eq!(table.count_indexed_rows("my_index").await.unwrap(), None); - assert_eq!(table.count_unindexed_rows("my_index").await.unwrap(), None); + assert_eq!( + table + .as_native() + .unwrap() + .count_indexed_rows("my_index") + .await + .unwrap(), + None + ); + assert_eq!( + table + .as_native() + .unwrap() + .count_unindexed_rows("my_index") + .await + .unwrap(), + None + ); table .create_index(&["embeddings"]) @@ -1552,18 +1852,37 @@ mod tests { .await .unwrap(); - assert_eq!(table.load_indices().await.unwrap().len(), 1); + assert_eq!( + table + .as_native() + .unwrap() + .load_indices() + .await + .unwrap() + .len(), + 1 + ); assert_eq!(table.count_rows(None).await.unwrap(), 512); - assert_eq!(table.name, "test"); + assert_eq!(table.name(), "test"); - let indices = table.load_indices().await.unwrap(); + let indices = table.as_native().unwrap().load_indices().await.unwrap(); let index_uuid = &indices[0].index_uuid; assert_eq!( - table.count_indexed_rows(index_uuid).await.unwrap(), + table + .as_native() + .unwrap() + .count_indexed_rows(index_uuid) + .await + .unwrap(), Some(512) ); assert_eq!( - table.count_unindexed_rows(index_uuid).await.unwrap(), + table + .as_native() + .unwrap() + .count_unindexed_rows(index_uuid) + .await + .unwrap(), Some(0) ); } @@ -1618,13 +1937,11 @@ mod tests { assert_eq!(table2.count_rows(None).await.unwrap(), 0); table1 - .add( - Box::new(RecordBatchIterator::new( - vec![Ok(batch.clone())], - batch.schema(), - )), - AddDataOptions::default(), - ) + .add(Box::new(RecordBatchIterator::new( + vec![Ok(batch.clone())], + batch.schema(), + ))) + .execute() .await .unwrap(); assert_eq!(table1.count_rows(None).await.unwrap(), 1); diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index 38a8fa13..e7d141e1 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -15,24 +15,16 @@ use std::sync::Arc; use arrow_array::RecordBatchReader; -use async_trait::async_trait; use crate::Result; -#[async_trait] -pub(super) trait MergeInsert: Send + Sync { - async fn do_merge_insert( - &self, - params: MergeInsertBuilder, - new_data: Box, - ) -> Result<()>; -} +use super::TableInternal; /// A builder used to create and run a merge insert operation /// /// See [`super::Table::merge_insert`] for more context pub struct MergeInsertBuilder { - table: Arc, + table: Arc, pub(super) on: Vec, pub(super) when_matched_update_all: bool, pub(super) when_matched_update_all_filt: Option, @@ -42,7 +34,7 @@ pub struct MergeInsertBuilder { } impl MergeInsertBuilder { - pub(super) fn new(table: Arc, on: Vec) -> Self { + pub(super) fn new(table: Arc, on: Vec) -> Self { Self { table, on,
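// Editor's note (sketch, not part of the patch): `MergeInsertBuilder` now holds an
// `Arc<dyn TableInternal>` instead of the removed `MergeInsert` trait, so the same
// builder can presumably front any table implementation that provides
// `do_merge_insert`. Obtaining one is unchanged from the existing
// `Table::merge_insert` entry point; the configuration setters and the final execute
// call are not visible in this hunk, so they are deliberately not shown:
//
//     let builder = table.merge_insert(&["id"]);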