From 2a02d1394b77c452607a09614ea62a4311e68d10 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 29 Feb 2024 13:29:29 -0800 Subject: [PATCH] feat: port create_table to the async python API and the remote rust API (#1031) I've also started `ASYNC_MIGRATION.MD` to keep track of the breaking changes from sync to async python. --- nodejs/__test__/connection.test.ts | 34 ++ nodejs/__test__/table.test.ts | 8 +- nodejs/__test__/tsconfig.json | 15 + nodejs/lancedb/connection.ts | 30 +- nodejs/lancedb/native.d.ts | 4 +- nodejs/lancedb/table.ts | 2 +- nodejs/package.json | 3 +- nodejs/src/connection.rs | 22 +- nodejs/src/table.rs | 18 +- python/ASYNC_MIGRATION.md | 24 + python/Cargo.toml | 6 +- python/python/lancedb/_lancedb.pyi | 12 + python/python/lancedb/common.py | 98 +++- python/python/lancedb/db.py | 72 ++- python/python/lancedb/table.py | 714 ++++++++++++++++++++++++++++ python/python/tests/test_db.py | 72 +++ python/src/connection.rs | 67 ++- python/src/error.rs | 1 + python/src/lib.rs | 3 +- python/src/table.rs | 34 ++ rust/lancedb/Cargo.toml | 1 + rust/lancedb/src/connection.rs | 10 +- rust/lancedb/src/error.rs | 30 +- rust/lancedb/src/remote.rs | 2 + rust/lancedb/src/remote/client.rs | 7 +- rust/lancedb/src/remote/db.rs | 31 +- rust/lancedb/src/remote/table.rs | 89 ++++ rust/lancedb/src/remote/util.rs | 21 + rust/lancedb/tests/lancedb_cloud.rs | 29 +- 29 files changed, 1406 insertions(+), 53 deletions(-) create mode 100644 nodejs/__test__/connection.test.ts create mode 100644 nodejs/__test__/tsconfig.json create mode 100644 python/ASYNC_MIGRATION.md create mode 100644 python/src/table.rs create mode 100644 rust/lancedb/src/remote/table.rs create mode 100644 rust/lancedb/src/remote/util.rs diff --git a/nodejs/__test__/connection.test.ts b/nodejs/__test__/connection.test.ts new file mode 100644 index 00000000..4ffcb906 --- /dev/null +++ b/nodejs/__test__/connection.test.ts @@ -0,0 +1,34 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+import * as os from "os";
+import * as path from "path";
+import * as fs from "fs";
+
+import { connect } from "../dist/index.js";
+
+describe("when working with a connection", () => {
+
+  const tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "test-connection"));
+
+  it("should fail if creating table twice, unless overwrite is true", async () => {
+    const db = await connect(tmpDir);
+    let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
+    await expect(tbl.countRows()).resolves.toBe(2);
+    await expect(db.createTable("test", [{ id: 1 }, { id: 2 }])).rejects.toThrow();
+    tbl = await db.createTable("test", [{ id: 3 }], { mode: "overwrite" });
+    await expect(tbl.countRows()).resolves.toBe(1);
+  });
+
+});
diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts
index 50a15247..6de039ad 100644
--- a/nodejs/__test__/table.test.ts
+++ b/nodejs/__test__/table.test.ts
@@ -201,17 +201,17 @@ describe("Read consistency interval", () => {
     await table.add([{ id: 2 }]);
 
     if (interval === undefined) {
-      expect(await table2.countRows()).toEqual(1n);
+      expect(await table2.countRows()).toEqual(1);
       // TODO: once we implement time travel we can uncomment this part of the test.
       // await table2.checkout_latest();
       // expect(await table2.countRows()).toEqual(2);
     } else if (interval === 0) {
-      expect(await table2.countRows()).toEqual(2n);
+      expect(await table2.countRows()).toEqual(2);
     } else {
       // interval == 0.1
-      expect(await table2.countRows()).toEqual(1n);
+      expect(await table2.countRows()).toEqual(1);
       await new Promise(r => setTimeout(r, 100));
-      expect(await table2.countRows()).toEqual(2n);
+      expect(await table2.countRows()).toEqual(2);
     }
   });
 });
diff --git a/nodejs/__test__/tsconfig.json b/nodejs/__test__/tsconfig.json
new file mode 100644
index 00000000..f127268c
--- /dev/null
+++ b/nodejs/__test__/tsconfig.json
@@ -0,0 +1,15 @@
+{
+  "extends": "../tsconfig.json",
+  "compilerOptions": {
+    "outDir": "./dist/spec",
+    "module": "commonjs",
+    "target": "es2022",
+    "types": [
+      "jest",
+      "node"
+    ]
+  },
+  "include": [
+    "**/*",
+  ]
+}
diff --git a/nodejs/lancedb/connection.ts b/nodejs/lancedb/connection.ts
index 723d8c85..46b1109f 100644
--- a/nodejs/lancedb/connection.ts
+++ b/nodejs/lancedb/connection.ts
@@ -17,6 +17,24 @@ import { Connection as _NativeConnection } from "./native";
 import { Table } from "./table";
 import { Table as ArrowTable } from "apache-arrow";
 
+export interface CreateTableOptions {
+  /**
+   * The mode to use when creating the table.
+   *
+   * If this is set to "create" and the table already exists then either
+   * an error will be thrown or, if existOk is true, then nothing will
+   * happen. Any provided data will be ignored.
+   *
+   * If this is set to "overwrite" then any existing table will be replaced.
+   */
+  mode: "create" | "overwrite";
+  /**
+   * If this is true and the table already exists and the mode is "create"
+   * then no error will be raised.
+   */
+  existOk: boolean;
+}
+
 /**
  * A LanceDB Connection that allows you to open tables and create new ones.
  *
@@ -53,10 +71,18 @@ export class Connection {
    */
   async createTable(
     name: string,
-    data: Record<string, unknown>[] | ArrowTable
+    data: Record<string, unknown>[] | ArrowTable,
+    options?: Partial<CreateTableOptions>
   ): Promise<Table> {
+    let mode: string = options?.mode ?? "create";
+    const existOk = options?.existOk ?? false;
+
+    if (mode === "create" && existOk) {
+      mode = "exist_ok";
+    }
+
     const buf = toBuffer(data);
-    const innerTable = await this.inner.createTable(name, buf);
+    const innerTable = await this.inner.createTable(name, buf, mode);
     return new Table(innerTable);
   }
 
diff --git a/nodejs/lancedb/native.d.ts b/nodejs/lancedb/native.d.ts
index 573b70c7..e72b54cb 100644
--- a/nodejs/lancedb/native.d.ts
+++ b/nodejs/lancedb/native.d.ts
@@ -85,7 +85,7 @@ export class Connection {
    * - buf: The buffer containing the IPC file.
    *
    */
-  createTable(name: string, buf: Buffer): Promise<Table>
+  createTable(name: string, buf: Buffer, mode: string): Promise<Table>
  openTable(name: string): Promise<Table>
  /** Drop table with the name. Or raise an error if the table does not exist. */
  dropTable(name: string): Promise<void>
@@ -117,7 +117,7 @@ export class Table {
   /** Return Schema as empty Arrow IPC file. */
   schema(): Promise<Buffer>
   add(buf: Buffer): Promise<void>
-  countRows(filter?: string | undefined | null): Promise<bigint>
+  countRows(filter?: string | undefined | null): Promise<number>
   delete(predicate: string): Promise<void>
   createIndex(): IndexBuilder
   query(): Query
diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts
index d3b9ea37..e2ef723a 100644
--- a/nodejs/lancedb/table.ts
+++ b/nodejs/lancedb/table.ts
@@ -50,7 +50,7 @@ export class Table {
   }
 
   /** Count the total number of rows in the dataset. */
-  async countRows(filter?: string): Promise<bigint> {
+  async countRows(filter?: string): Promise<number> {
     return await this.inner.countRows(filter);
   }
 
diff --git a/nodejs/package.json b/nodejs/package.json
index ba91b0b8..39473320 100644
--- a/nodejs/package.json
+++ b/nodejs/package.json
@@ -51,8 +51,7 @@
     "docs": "typedoc --plugin typedoc-plugin-markdown lancedb/index.ts",
     "lint": "eslint lancedb --ext .js,.ts",
     "prepublishOnly": "napi prepublish -t npm",
-    "//": "maxWorkers=1 is workaround for bigint issue in jest: https://github.com/jestjs/jest/issues/11617#issuecomment-1068732414",
-    "test": "npm run build && jest --maxWorkers=1",
+    "test": "npm run build && jest --verbose",
     "universal": "napi universal",
     "version": "napi version"
   },
diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs
index 1942d701..9bef5eec 100644
--- a/nodejs/src/connection.rs
+++ b/nodejs/src/connection.rs
@@ -17,7 +17,7 @@ use napi_derive::*;
 
 use crate::table::Table;
 use crate::ConnectionOptions;
-use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection};
+use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection, CreateTableMode};
 use lancedb::ipc::ipc_file_to_batches;
 
 #[napi]
@@ -25,6 +25,17 @@ pub struct Connection {
     conn: LanceDBConnection,
 }
 
+impl Connection {
+    fn parse_create_mode_str(mode: &str) -> napi::Result<CreateTableMode> {
+        match mode {
+            "create" => Ok(CreateTableMode::Create),
+            "overwrite" => Ok(CreateTableMode::Overwrite),
+            "exist_ok" => Ok(CreateTableMode::exist_ok(|builder| builder)),
+            _ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
+        }
+    }
+}
+
 #[napi]
 impl Connection {
     /// Create a new Connection instance from the given URI.
@@ -65,12 +76,19 @@ impl Connection {
     /// - buf: The buffer containing the IPC file.
     ///
     #[napi]
-    pub async fn create_table(&self, name: String, buf: Buffer) -> napi::Result<Table> {
+    pub async fn create_table(
+        &self,
+        name: String,
+        buf: Buffer,
+        mode: String,
+    ) -> napi::Result<Table> {
         let batches = ipc_file_to_batches(buf.to_vec())
             .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
+        let mode = Self::parse_create_mode_str(&mode)?;
         let tbl = self
             .conn
             .create_table(&name, Box::new(batches))
+            .mode(mode)
             .execute()
             .await
             .map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs
index 29e03f3a..6d46e466 100644
--- a/nodejs/src/table.rs
+++ b/nodejs/src/table.rs
@@ -68,13 +68,17 @@ impl Table {
     }
 
     #[napi]
-    pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<u64> {
-        self.table.count_rows(filter).await.map_err(|e| {
-            napi::Error::from_reason(format!(
-                "Failed to count rows in table {}: {}",
-                self.table, e
-            ))
-        })
+    pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<i64> {
+        self.table
+            .count_rows(filter)
+            .await
+            .map(|val| val as i64)
+            .map_err(|e| {
+                napi::Error::from_reason(format!(
+                    "Failed to count rows in table {}: {}",
+                    self.table, e
+                ))
+            })
     }
 
     #[napi]
diff --git a/python/ASYNC_MIGRATION.md b/python/ASYNC_MIGRATION.md
new file mode 100644
index 00000000..6a9231c4
--- /dev/null
+++ b/python/ASYNC_MIGRATION.md
@@ -0,0 +1,24 @@
+# Migration from Sync to Async API
+
+A new asynchronous API has been added to LanceDB. This API is built
+on top of the Rust lancedb crate (instead of being built on top of
+pylance), which will help keep the various language bindings in sync.
+There are some slight differences between the synchronous and the
+asynchronous APIs, mostly in the Connection and Table classes. This
+document will help you migrate.
+
+## Almost all functions are async
+
+The most important change is that almost all functions are now async.
+These functions return `asyncio` coroutines, so you will need to
+`await` them.
+
+## Connection
+
+No changes yet.
+
+## Table
+
+* Previously `Table.schema` was a property. Now it is an async method.
+* The method `Table.__len__` was removed and `len(table)` will no longer
+  work. Use `Table.count_rows` instead.
diff --git a/python/Cargo.toml b/python/Cargo.toml
index e6b83206..9d1101ea 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -14,6 +14,7 @@ name = "_lancedb"
 crate-type = ["cdylib"]
 
 [dependencies]
+arrow = { version = "50.0.0", features = ["pyarrow"] }
 lancedb = { path = "../rust/lancedb" }
 env_logger = "0.10"
 pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
@@ -23,4 +24,7 @@ pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
 lzma-sys = { version = "*", features = ["static"] }
 
 [build-dependencies]
-pyo3-build-config = { version = "0.20.3", features = ["extension-module", "abi3-py38"] }
+pyo3-build-config = { version = "0.20.3", features = [
+    "extension-module",
+    "abi3-py38",
+] }
diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi
index 308220d5..d1351084 100644
--- a/python/python/lancedb/_lancedb.pyi
+++ b/python/python/lancedb/_lancedb.pyi
@@ -1,7 +1,19 @@
 from typing import Optional
 
+import pyarrow as pa
+
 class Connection(object):
     async def table_names(self) -> list[str]: ...
+    async def create_table(
+        self, name: str, mode: str, data: pa.RecordBatchReader
+    ) -> Table: ...
+    async def create_empty_table(
+        self, name: str, mode: str, schema: pa.Schema
+    ) -> Table: ...
+
+class Table(object):
+    def name(self) -> str: ...
+    async def schema(self) -> pa.Schema: ...
+
async def connect( uri: str, diff --git a/python/python/lancedb/common.py b/python/python/lancedb/common.py index ff3b3636..cc894a72 100644 --- a/python/python/lancedb/common.py +++ b/python/python/lancedb/common.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pathlib import Path -from typing import Iterable, List, Union +from typing import Iterable, List, Optional, Union import numpy as np import pyarrow as pa @@ -38,3 +38,99 @@ class Credential(str): def sanitize_uri(uri: URI) -> str: return str(uri) + + +def _casting_recordbatch_iter( + input_iter: Iterable[pa.RecordBatch], schema: pa.Schema +) -> Iterable[pa.RecordBatch]: + """ + Wrapper around an iterator of record batches. If the batches don't match the + schema, try to cast them to the schema. If that fails, raise an error. + + This is helpful for users who might have written the iterator with default + data types in PyArrow, but specified more specific types in the schema. For + example, PyArrow defaults to float64 for floating point types, but Lance + uses float32 for vectors. + """ + for batch in input_iter: + if not isinstance(batch, pa.RecordBatch): + raise TypeError(f"Expected RecordBatch, got {type(batch)}") + if batch.schema != schema: + try: + # RecordBatch doesn't have a cast method, but table does. + batch = pa.Table.from_batches([batch]).cast(schema).to_batches()[0] + except pa.lib.ArrowInvalid: + raise ValueError( + f"Input RecordBatch iterator yielded a batch with schema that " + f"does not match the expected schema.\nExpected:\n{schema}\n" + f"Got:\n{batch.schema}" + ) + yield batch + + +def data_to_reader( + data: DATA, schema: Optional[pa.Schema] = None +) -> pa.RecordBatchReader: + """Convert various types of input into a RecordBatchReader""" + if pd is not None and isinstance(data, pd.DataFrame): + return pa.Table.from_pandas(data, schema=schema).to_reader() + elif isinstance(data, pa.Table): + return data.to_reader() + elif isinstance(data, pa.RecordBatch): + return pa.Table.from_batches([data]).to_reader() + # elif isinstance(data, LanceDataset): + # return data_obj.scanner().to_reader() + elif isinstance(data, pa.dataset.Dataset): + return pa.dataset.Scanner.from_dataset(data).to_reader() + elif isinstance(data, pa.dataset.Scanner): + return data.to_reader() + elif isinstance(data, pa.RecordBatchReader): + return data + elif ( + type(data).__module__.startswith("polars") + and data.__class__.__name__ == "DataFrame" + ): + return data.to_arrow().to_reader() + # for other iterables, assume they are of type Iterable[RecordBatch] + elif isinstance(data, Iterable): + if schema is not None: + data = _casting_recordbatch_iter(data, schema) + return pa.RecordBatchReader.from_batches(schema, data) + else: + raise ValueError( + "Must provide schema to write dataset from RecordBatch iterable" + ) + else: + raise TypeError( + f"Unknown data type {type(data)}. " + "Please check " + "https://lancedb.github.io/lance/read_and_write.html " + "to see supported types." 
+ ) + + +def validate_schema(schema: pa.Schema): + """ + Make sure the metadata is valid utf8 + """ + if schema.metadata is not None: + _validate_metadata(schema.metadata) + + +def _validate_metadata(metadata: dict): + """ + Make sure the metadata values are valid utf8 (can be nested) + + Raises ValueError if not valid utf8 + """ + for k, v in metadata.items(): + if isinstance(v, bytes): + try: + v.decode("utf8") + except UnicodeDecodeError: + raise ValueError( + f"Metadata key {k} is not valid utf8. " + "Consider base64 encode for generic binary metadata." + ) + elif isinstance(v, dict): + _validate_metadata(v) diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index a7dd4f4d..d87983da 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -13,6 +13,7 @@ from __future__ import annotations +import inspect import os from abc import abstractmethod from pathlib import Path @@ -22,7 +23,12 @@ import pyarrow as pa from overrides import EnforceOverrides, override from pyarrow import fs -from .table import LanceTable, Table +from lancedb.common import data_to_reader, validate_schema +from lancedb.embeddings.registry import EmbeddingFunctionRegistry +from lancedb.utils.events import register_event + +from .pydantic import LanceModel +from .table import AsyncLanceTable, LanceTable, Table, _sanitize_data from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri if TYPE_CHECKING: @@ -31,7 +37,6 @@ if TYPE_CHECKING: from ._lancedb import Connection as LanceDbConnection from .common import DATA, URI from .embeddings import EmbeddingFunctionConfig - from .pydantic import LanceModel class DBConnection(EnforceOverrides): @@ -644,6 +649,7 @@ class AsyncLanceDBConnection(AsyncConnection): page_token=None, limit=None, ) -> Iterable[str]: + # TODO: hook in page_token and limit return await self._inner.table_names() @override @@ -657,8 +663,66 @@ class AsyncLanceDBConnection(AsyncConnection): on_bad_vectors: str = "error", fill_value: float = 0.0, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, - ) -> LanceTable: - raise NotImplementedError + ) -> Table: + if mode.lower() not in ["create", "overwrite"]: + raise ValueError("mode must be either 'create' or 'overwrite'") + + if inspect.isclass(schema) and issubclass(schema, LanceModel): + # convert LanceModel to pyarrow schema + # note that it's possible this contains + # embedding function metadata already + schema = schema.to_arrow_schema() + + metadata = None + if embedding_functions is not None: + # If we passed in embedding functions explicitly + # then we'll override any schema metadata that + # may was implicitly specified by the LanceModel schema + registry = EmbeddingFunctionRegistry.get_instance() + metadata = registry.get_table_metadata(embedding_functions) + + if data is not None: + data = _sanitize_data( + data, + schema, + metadata=metadata, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) + + if schema is None: + if data is None: + raise ValueError("Either data or schema must be provided") + elif hasattr(data, "schema"): + schema = data.schema + elif isinstance(data, Iterable): + if metadata: + raise TypeError( + ( + "Persistent embedding functions not yet " + "supported for generator data input" + ) + ) + + if metadata: + schema = schema.with_metadata(metadata) + validate_schema(schema) + + if mode == "create" and exist_ok: + mode = "exist_ok" + + if data is None: + new_table = await self._inner.create_empty_table(name, mode, schema) + else: + data = 
data_to_reader(data, schema) + new_table = await self._inner.create_table( + name, + mode, + data, + ) + + register_event("create_table") + return AsyncLanceTable(new_table) @override async def open_table(self, name: str) -> LanceTable: diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index fe624793..7243f940 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -28,6 +28,7 @@ import pyarrow.compute as pc import pyarrow.fs as pa_fs from lance import LanceDataset from lance.vector import vec_to_table +from overrides import override from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry @@ -48,6 +49,7 @@ if TYPE_CHECKING: import PIL from lance.dataset import CleanupStats, ReaderLike + from ._lancedb import Table as LanceDBTable from .db import LanceDBConnection @@ -1780,3 +1782,715 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1) data = data.filter(is_full) return data + + +class AsyncTable(ABC): + """ + A Table is a collection of Records in a LanceDB Database. + + Examples + -------- + + Create using [DBConnection.create_table][lancedb.DBConnection.create_table] + (more examples in that method's documentation). + + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}]) + >>> table.head() + pyarrow.Table + vector: fixed_size_list[2] + child 0, item: float + b: int64 + ---- + vector: [[[1.1,1.2]]] + b: [[2]] + + Can append new data with [Table.add()][lancedb.table.Table.add]. + + >>> table.add([{"vector": [0.5, 1.3], "b": 4}]) + + Can query the table with [Table.search][lancedb.table.Table.search]. + + >>> table.search([0.4, 0.4]).select(["b", "vector"]).to_pandas() + b vector _distance + 0 4 [0.5, 1.3] 0.82 + 1 2 [1.1, 1.2] 1.13 + + Search queries are much faster when an index is created. See + [Table.create_index][lancedb.table.Table.create_index]. + """ + + @property + @abstractmethod + def name(self) -> str: + """The name of the table.""" + raise NotImplementedError + + @abstractmethod + async def schema(self) -> pa.Schema: + """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) + of this Table + + """ + raise NotImplementedError + + @abstractmethod + async def count_rows(self, filter: Optional[str] = None) -> int: + """ + Count the number of rows in the table. + + Parameters + ---------- + filter: str, optional + A SQL where clause to filter the rows to count. + """ + raise NotImplementedError + + async def to_pandas(self) -> "pd.DataFrame": + """Return the table as a pandas DataFrame. + + Returns + ------- + pd.DataFrame + """ + return self.to_arrow().to_pandas() + + @abstractmethod + async def to_arrow(self) -> pa.Table: + """Return the table as a pyarrow Table. + + Returns + ------- + pa.Table + """ + raise NotImplementedError + + async def create_index( + self, + metric="L2", + num_partitions=256, + num_sub_vectors=96, + vector_column_name: str = VECTOR_COLUMN_NAME, + replace: bool = True, + accelerator: Optional[str] = None, + index_cache_size: Optional[int] = None, + ): + """Create an index on the table. + + Parameters + ---------- + metric: str, default "L2" + The distance metric to use when creating the index. + Valid values are "L2", "cosine", or "dot". + L2 is euclidean distance. 
+ num_partitions: int, default 256 + The number of IVF partitions to use when creating the index. + Default is 256. + num_sub_vectors: int, default 96 + The number of PQ sub-vectors to use when creating the index. + Default is 96. + vector_column_name: str, default "vector" + The vector column name to create the index. + replace: bool, default True + - If True, replace the existing index if it exists. + + - If False, raise an error if duplicate index exists. + accelerator: str, default None + If set, use the given accelerator to create the index. + Only support "cuda" for now. + index_cache_size : int, optional + The size of the index cache in number of entries. Default value is 256. + """ + raise NotImplementedError + + @abstractmethod + async def create_scalar_index( + self, + column: str, + *, + replace: bool = True, + ): + """Create a scalar index on a column. + + Scalar indices, like vector indices, can be used to speed up scans. A scalar + index can speed up scans that contain filter expressions on the indexed column. + For example, the following scan will be faster if the column ``my_col`` has + a scalar index: + + .. code-block:: python + + import lancedb + + db = lancedb.connect("/data/lance") + img_table = db.open_table("images") + my_df = img_table.search().where("my_col = 7", prefilter=True).to_pandas() + + Scalar indices can also speed up scans containing a vector search and a + prefilter: + + .. code-block::python + + import lancedb + + db = lancedb.connect("/data/lance") + img_table = db.open_table("images") + img_table.search([1, 2, 3, 4], vector_column_name="vector") + .where("my_col != 7", prefilter=True) + .to_pandas() + + Scalar indices can only speed up scans for basic filters using + equality, comparison, range (e.g. ``my_col BETWEEN 0 AND 100``), and set + membership (e.g. `my_col IN (0, 1, 2)`) + + Scalar indices can be used if the filter contains multiple indexed columns and + the filter criteria are AND'd or OR'd together + (e.g. ``my_col < 0 AND other_col> 100``) + + Scalar indices may be used if the filter contains non-indexed columns but, + depending on the structure of the filter, they may not be usable. For example, + if the column ``not_indexed`` does not have a scalar index then the filter + ``my_col = 0 OR not_indexed = 1`` will not be able to use any scalar index on + ``my_col``. + + **Experimental API** + + Parameters + ---------- + column : str + The column to be indexed. Must be a boolean, integer, float, + or string column. + replace : bool, default True + Replace the existing index if it exists. + + Examples + -------- + + .. code-block:: python + + import lance + + dataset = lance.dataset("./images.lance") + dataset.create_scalar_index("category") + """ + raise NotImplementedError + + @abstractmethod + async def add( + self, + data: DATA, + mode: str = "append", + on_bad_vectors: str = "error", + fill_value: float = 0.0, + ): + """Add more data to the [Table](Table). + + Parameters + ---------- + data: DATA + The data to insert into the table. Acceptable types are: + + - dict or list-of-dict + + - pandas.DataFrame + + - pyarrow.Table or pyarrow.RecordBatch + mode: str + The mode to use when writing the data. Valid values are + "append" and "overwrite". + on_bad_vectors: str, default "error" + What to do if any of the vectors are not the same size or contains NaNs. + One of "error", "drop", "fill". + fill_value: float, default 0. + The value to use when filling vectors. Only used if on_bad_vectors="fill". 
+ + """ + raise NotImplementedError + + def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: + """ + Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] + that can be used to create a "merge insert" operation + + This operation can add rows, update rows, and remove rows all in a single + transaction. It is a very generic tool that can be used to create + behaviors like "insert if not exists", "update or insert (i.e. upsert)", + or even replace a portion of existing data with new data (e.g. replace + all data where month="january") + + The merge insert operation works by combining new data from a + **source table** with existing data in a **target table** by using a + join. There are three categories of records. + + "Matched" records are records that exist in both the source table and + the target table. "Not matched" records exist only in the source table + (e.g. these are new data) "Not matched by source" records exist only + in the target table (this is old data) + + The builder returned by this method can be used to customize what + should happen for each category of data. + + Please note that the data may appear to be reordered as part of this + operation. This is because updated rows will be deleted from the + dataset and then reinserted at the end with the new values. + + Parameters + ---------- + + on: Union[str, Iterable[str]] + A column (or columns) to join on. This is how records from the + source table and target table are matched. Typically this is some + kind of key or id column. + + Examples + -------- + >>> import lancedb + >>> data = pa.table({"a": [2, 1, 3], "b": ["a", "b", "c"]}) + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data) + >>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) + >>> # Perform a "upsert" operation + >>> table.merge_insert("a") \\ + ... .when_matched_update_all() \\ + ... .when_not_matched_insert_all() \\ + ... .execute(new_data) + >>> # The order of new rows is non-deterministic since we use + >>> # a hash-join as part of this operation and so we sort here + >>> table.to_arrow().sort_by("a").to_pandas() + a b + 0 1 b + 1 2 x + 2 3 y + 3 4 z + """ + on = [on] if isinstance(on, str) else list(on.iter()) + + return LanceMergeInsertBuilder(self, on) + + @abstractmethod + async def search( + self, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, + vector_column_name: Optional[str] = None, + query_type: str = "auto", + ) -> LanceQueryBuilder: + """Create a search query to find the nearest neighbors + of the given query vector. We currently support [vector search][search] + and [full-text search][experimental-full-text-search]. + + All query options are defined in [Query][lancedb.query.Query]. + + Examples + -------- + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> data = [ + ... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]}, + ... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]}, + ... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]} + ... ] + >>> table = db.create_table("my_table", data) + >>> query = [0.4, 1.4, 2.4] + >>> (table.search(query) + ... .where("original_width > 1000", prefilter=True) + ... .select(["caption", "original_width", "vector"]) + ... .limit(2) + ... 
.to_pandas()) + caption original_width vector _distance + 0 foo 2000 [0.5, 3.4, 1.3] 5.220000 + 1 test 3000 [0.3, 6.2, 2.6] 23.089996 + + Parameters + ---------- + query: list/np.ndarray/str/PIL.Image.Image, default None + The targetted vector to search for. + + - *default None*. + Acceptable types are: list, np.ndarray, PIL.Image.Image + + - If None then the select/where/limit clauses are applied to filter + the table + vector_column_name: str, optional + The name of the vector column to search. + + The vector column needs to be a pyarrow fixed size list type + + - If not specified then the vector column is inferred from + the table schema + + - If the table has multiple vector columns then the *vector_column_name* + needs to be specified. Otherwise, an error is raised. + query_type: str + *default "auto"*. + Acceptable types are: "vector", "fts", "hybrid", or "auto" + + - If "auto" then the query type is inferred from the query; + + - If `query` is a list/np.ndarray then the query type is + "vector"; + + - If `query` is a PIL.Image.Image then either do vector search, + or raise an error if no corresponding embedding function is found. + + - If `query` is a string, then the query type is "vector" if the + table has embedding functions else the query type is "fts" + + Returns + ------- + LanceQueryBuilder + A query builder object representing the query. + Once executed, the query returns + + - selected columns + + - the vector + + - and also the "_distance" column which is the distance between the query + vector and the returned vector. + """ + raise NotImplementedError + + @abstractmethod + async def _execute_query(self, query: Query) -> pa.Table: + pass + + @abstractmethod + async def _do_merge( + self, + merge: LanceMergeInsertBuilder, + new_data: DATA, + on_bad_vectors: str, + fill_value: float, + ): + pass + + @abstractmethod + async def delete(self, where: str): + """Delete rows from the table. + + This can be used to delete a single row, many rows, all rows, or + sometimes no rows (if your predicate matches nothing). + + Parameters + ---------- + where: str + The SQL where clause to use when deleting rows. + + - For example, 'x = 2' or 'x IN (1, 2, 3)'. + + The filter must not be empty, or it will error. + + Examples + -------- + >>> import lancedb + >>> data = [ + ... {"x": 1, "vector": [1, 2]}, + ... {"x": 2, "vector": [3, 4]}, + ... {"x": 3, "vector": [5, 6]} + ... ] + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data) + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 2 [3.0, 4.0] + 2 3 [5.0, 6.0] + >>> table.delete("x = 2") + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 3 [5.0, 6.0] + + If you have a list of values to delete, you can combine them into a + stringified list and use the `IN` operator: + + >>> to_remove = [1, 5] + >>> to_remove = ", ".join([str(v) for v in to_remove]) + >>> to_remove + '1, 5' + >>> table.delete(f"x IN ({to_remove})") + >>> table.to_pandas() + x vector + 0 3 [5.0, 6.0] + """ + raise NotImplementedError + + @abstractmethod + async def update( + self, + where: Optional[str] = None, + values: Optional[dict] = None, + *, + values_sql: Optional[Dict[str, str]] = None, + ): + """ + This can be used to update zero to all rows depending on how many + rows match the where clause. If no where clause is provided, then + all rows will be updated. + + Either `values` or `values_sql` must be provided. You cannot provide + both. 
+ + Parameters + ---------- + where: str, optional + The SQL where clause to use when updating rows. For example, 'x = 2' + or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error. + values: dict, optional + The values to update. The keys are the column names and the values + are the values to set. + values_sql: dict, optional + The values to update, expressed as SQL expression strings. These can + reference existing columns. For example, {"x": "x + 1"} will increment + the x column by 1. + + Examples + -------- + >>> import lancedb + >>> import pandas as pd + >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data) + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 2 [3.0, 4.0] + 2 3 [5.0, 6.0] + >>> table.update(where="x = 2", values={"vector": [10, 10]}) + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 3 [5.0, 6.0] + 2 2 [10.0, 10.0] + >>> table.update(values_sql={"x": "x + 1"}) + >>> table.to_pandas() + x vector + 0 2 [1.0, 2.0] + 1 4 [5.0, 6.0] + 2 3 [10.0, 10.0] + """ + raise NotImplementedError + + @abstractmethod + async def cleanup_old_versions( + self, + older_than: Optional[timedelta] = None, + *, + delete_unverified: bool = False, + ) -> CleanupStats: + """ + Clean up old versions of the table, freeing disk space. + + Note: This function is not available in LanceDb Cloud (since LanceDb + Cloud manages cleanup for you automatically) + + Parameters + ---------- + older_than: timedelta, default None + The minimum age of the version to delete. If None, then this defaults + to two weeks. + delete_unverified: bool, default False + Because they may be part of an in-progress transaction, files newer + than 7 days old are not deleted by default. If you are sure that + there are no in-progress transactions, then you can set this to True + to delete all files older than `older_than`. + + Returns + ------- + CleanupStats + The stats of the cleanup operation, including how many bytes were + freed. + """ + raise NotImplementedError + + @abstractmethod + async def compact_files(self, *args, **kwargs): + """ + Run the compaction process on the table. + + Note: This function is not available in LanceDb Cloud (since LanceDb + Cloud manages compaction for you automatically) + + This can be run after making several small appends to optimize the table + for faster reads. + + Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`. + For most cases, the default should be fine. + """ + raise NotImplementedError + + @abstractmethod + async def add_columns(self, transforms: Dict[str, str]): + """ + Add new columns with defined values. + + This is not yet available in LanceDB Cloud. + + Parameters + ---------- + transforms: Dict[str, str] + A map of column name to a SQL expression to use to calculate the + value of the new column. These expressions will be evaluated for + each row in the table, and can reference existing columns. + """ + raise NotImplementedError + + @abstractmethod + async def alter_columns(self, alterations: Iterable[Dict[str, str]]): + """ + Alter column names and nullability. + + This is not yet available in LanceDB Cloud. + + alterations : Iterable[Dict[str, Any]] + A sequence of dictionaries, each with the following keys: + - "path": str + The column path to alter. For a top-level column, this is the name. + For a nested column, this is the dot-separated path, e.g. "a.b.c". 
+ - "name": str, optional + The new name of the column. If not specified, the column name is + not changed. + - "nullable": bool, optional + Whether the column should be nullable. If not specified, the column + nullability is not changed. Only non-nullable columns can be changed + to nullable. Currently, you cannot change a nullable column to + non-nullable. + """ + raise NotImplementedError + + @abstractmethod + async def drop_columns(self, columns: Iterable[str]): + """ + Drop columns from the table. + + This is not yet available in LanceDB Cloud. + + Parameters + ---------- + columns : Iterable[str] + The names of the columns to drop. + """ + raise NotImplementedError + + +class AsyncLanceTable(AsyncTable): + def __init__(self, table: LanceDBTable): + self._inner = table + + @property + @override + def name(self) -> str: + return self._inner.name() + + @override + async def schema(self) -> pa.Schema: + return await self._inner.schema() + + @override + async def count_rows(self, filter: Optional[str] = None) -> int: + raise NotImplementedError + + async def to_pandas(self) -> "pd.DataFrame": + return self.to_arrow().to_pandas() + + @override + async def to_arrow(self) -> pa.Table: + raise NotImplementedError + + async def create_index( + self, + metric="L2", + num_partitions=256, + num_sub_vectors=96, + vector_column_name: str = VECTOR_COLUMN_NAME, + replace: bool = True, + accelerator: Optional[str] = None, + index_cache_size: Optional[int] = None, + ): + raise NotImplementedError + + @override + async def create_scalar_index( + self, + column: str, + *, + replace: bool = True, + ): + raise NotImplementedError + + @override + async def add( + self, + data: DATA, + mode: str = "append", + on_bad_vectors: str = "error", + fill_value: float = 0.0, + ): + raise NotImplementedError + + def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: + on = [on] if isinstance(on, str) else list(on.iter()) + + return LanceMergeInsertBuilder(self, on) + + @override + async def search( + self, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, + vector_column_name: Optional[str] = None, + query_type: str = "auto", + ) -> LanceQueryBuilder: + raise NotImplementedError + + @override + async def _execute_query(self, query: Query) -> pa.Table: + pass + + @override + async def _do_merge( + self, + merge: LanceMergeInsertBuilder, + new_data: DATA, + on_bad_vectors: str, + fill_value: float, + ): + pass + + @override + async def delete(self, where: str): + raise NotImplementedError + + @override + async def update( + self, + where: Optional[str] = None, + values: Optional[dict] = None, + *, + values_sql: Optional[Dict[str, str]] = None, + ): + raise NotImplementedError + + @override + async def cleanup_old_versions( + self, + older_than: Optional[timedelta] = None, + *, + delete_unverified: bool = False, + ) -> CleanupStats: + raise NotImplementedError + + @override + async def compact_files(self, *args, **kwargs): + raise NotImplementedError + + @override + async def add_columns(self, transforms: Dict[str, str]): + raise NotImplementedError + + @override + async def alter_columns(self, alterations: Iterable[Dict[str, str]]): + raise NotImplementedError + + @override + async def drop_columns(self, columns: Iterable[str]): + raise NotImplementedError diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 6f2fd1b4..c66131cf 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -250,6 +250,78 @@ def 
test_create_exist_ok(tmp_path):
         db.create_table("test", schema=bad_schema, exist_ok=True)
 
 
+@pytest.mark.asyncio
+async def test_create_mode_async(tmp_path):
+    db = await lancedb.connect_async(tmp_path)
+    data = pd.DataFrame(
+        {
+            "vector": [[3.1, 4.1], [5.9, 26.5]],
+            "item": ["foo", "bar"],
+            "price": [10.0, 20.0],
+        }
+    )
+    await db.create_table("test", data=data)
+
+    with pytest.raises(RuntimeError):
+        await db.create_table("test", data=data)
+
+    new_data = pd.DataFrame(
+        {
+            "vector": [[3.1, 4.1], [5.9, 26.5]],
+            "item": ["fizz", "buzz"],
+            "price": [10.0, 20.0],
+        }
+    )
+    _tbl = await db.create_table("test", data=new_data, mode="overwrite")
+
+    # MIGRATION: to_pandas() is not available in async
+    # assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
+
+
+@pytest.mark.asyncio
+async def test_create_exist_ok_async(tmp_path):
+    db = await lancedb.connect_async(tmp_path)
+    data = pd.DataFrame(
+        {
+            "vector": [[3.1, 4.1], [5.9, 26.5]],
+            "item": ["foo", "bar"],
+            "price": [10.0, 20.0],
+        }
+    )
+    tbl = await db.create_table("test", data=data)
+
+    with pytest.raises(RuntimeError):
+        await db.create_table("test", data=data)
+
+    # open the table but don't add more rows
+    tbl2 = await db.create_table("test", data=data, exist_ok=True)
+    assert tbl.name == tbl2.name
+    assert await tbl.schema() == await tbl2.schema()
+
+    schema = pa.schema(
+        [
+            pa.field("vector", pa.list_(pa.float32(), list_size=2)),
+            pa.field("item", pa.utf8()),
+            pa.field("price", pa.float64()),
+        ]
+    )
+    tbl3 = await db.create_table("test", schema=schema, exist_ok=True)
+    assert await tbl3.schema() == schema
+
+    # MIGRATION: creating a table that already exists with a different
+    # schema should raise an error, but that check is not implemented yet.
+    # bad_schema = pa.schema(
+    #     [
+    #         pa.field("vector", pa.list_(pa.float32(), list_size=2)),
+    #         pa.field("item", pa.utf8()),
+    #         pa.field("price", pa.float64()),
+    #         pa.field("extra", pa.float32()),
+    #     ]
+    # )
+    # with pytest.raises(ValueError):
+    #     await db.create_table("test", schema=bad_schema, exist_ok=True)
+
+
 def test_delete_table(tmp_path):
     db = lancedb.connect(tmp_path)
     data = pd.DataFrame(
diff --git a/python/src/connection.rs b/python/src/connection.rs
index 96d3c7cb..1f0fa759 100644
--- a/python/src/connection.rs
+++ b/python/src/connection.rs
@@ -12,19 +12,33 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use std::time::Duration;
+use std::{sync::Arc, time::Duration};
 
-use lancedb::connection::Connection as LanceConnection;
-use pyo3::{pyclass, pyfunction, pymethods, PyAny, PyRef, PyResult, Python};
+use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
+use lancedb::connection::{Connection as LanceConnection, CreateTableMode};
+use pyo3::{
+    exceptions::PyValueError, pyclass, pyfunction, pymethods, PyAny, PyRef, PyResult, Python,
+};
 use pyo3_asyncio::tokio::future_into_py;
 
-use crate::error::PythonErrorExt;
+use crate::{error::PythonErrorExt, table::Table};
 
 #[pyclass]
 pub struct Connection {
     inner: LanceConnection,
 }
 
+impl Connection {
+    fn parse_create_mode_str(mode: &str) -> PyResult<CreateTableMode> {
+        match mode {
+            "create" => Ok(CreateTableMode::Create),
+            "overwrite" => Ok(CreateTableMode::Overwrite),
+            "exist_ok" => Ok(CreateTableMode::exist_ok(|builder| builder)),
+            _ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
+        }
+    }
+}
+
 #[pymethods]
 impl Connection {
     pub fn table_names(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
@@ -33,6 +47,51 @@ impl Connection {
             inner.table_names().await.infer_error()
         })
     }
+
+    pub fn create_table<'a>(
+        self_: PyRef<'a, Self>,
+        name: String,
+        mode: &str,
+        data: &PyAny,
+    ) -> PyResult<&'a PyAny> {
+        let inner = self_.inner.clone();
+
+        let mode = Self::parse_create_mode_str(mode)?;
+
+        let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?);
+        future_into_py(self_.py(), async move {
+            let table = inner
+                .create_table(name, batches)
+                .mode(mode)
+                .execute()
+                .await
+                .infer_error()?;
+            Ok(Table::new(table))
+        })
+    }
+
+    pub fn create_empty_table<'a>(
+        self_: PyRef<'a, Self>,
+        name: String,
+        mode: &str,
+        schema: &PyAny,
+    ) -> PyResult<&'a PyAny> {
+        let inner = self_.inner.clone();
+
+        let mode = Self::parse_create_mode_str(mode)?;
+
+        let schema = Schema::from_pyarrow(schema)?;
+
+        future_into_py(self_.py(), async move {
+            let table = inner
+                .create_empty_table(name, Arc::new(schema))
+                .mode(mode)
+                .execute()
+                .await
+                .infer_error()?;
+            Ok(Table::new(table))
+        })
+    }
 }
 
 #[pyfunction]
diff --git a/python/src/error.rs b/python/src/error.rs
index 20ae7c2a..67add21b 100644
--- a/python/src/error.rs
+++ b/python/src/error.rs
@@ -45,6 +45,7 @@ impl PythonErrorExt for std::result::Result {
             LanceError::Lance { .. } => self.runtime_error(),
             LanceError::Runtime { .. } => self.runtime_error(),
             LanceError::Http { .. } => self.runtime_error(),
+            LanceError::Arrow { .. } => self.runtime_error(),
         },
     }
 }
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 2a66810d..fa2f5fc4 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -17,7 +17,8 @@ use env_logger::Env;
 use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
 
 pub mod connection;
-pub(crate) mod error;
+pub mod error;
+pub mod table;
 
 #[pymodule]
 pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
diff --git a/python/src/table.rs b/python/src/table.rs
new file mode 100644
index 00000000..23bda9a3
--- /dev/null
+++ b/python/src/table.rs
@@ -0,0 +1,34 @@
+use std::sync::Arc;
+
+use arrow::pyarrow::ToPyArrow;
+use lancedb::table::Table as LanceTable;
+use pyo3::{pyclass, pymethods, PyAny, PyRef, PyResult, Python};
+use pyo3_asyncio::tokio::future_into_py;
+
+use crate::error::PythonErrorExt;
+
+#[pyclass]
+pub struct Table {
+    inner: Arc<dyn LanceTable>,
+}
+
+impl Table {
+    pub(crate) fn new(inner: Arc<dyn LanceTable>) -> Self {
+        Self { inner }
+    }
+}
+
+#[pymethods]
+impl Table {
+    pub fn name(&self) -> String {
+        self.inner.name().to_string()
+    }
+
+    pub fn schema(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+}
diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index e0febf1f..3ca835ab 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -42,6 +42,7 @@ reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
 [dev-dependencies]
 tempfile = "3.5.0"
 rand = { version = "0.8.3", features = ["small_rng"] }
+uuid = { version = "1.7.0", features = ["v4"] }
 walkdir = "2"
 
 [features]
diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs
index 11d72cde..37b663f4 100644
--- a/rust/lancedb/src/connection.rs
+++ b/rust/lancedb/src/connection.rs
@@ -80,11 +80,11 @@ enum BadVectorHandling {
 /// A builder for configuring a [`Connection::create_table`] operation
 pub struct CreateTableBuilder<const HAS_DATA: bool> {
     parent: Arc<dyn ConnectionInternal>,
-    name: String,
-    data: Option<Box<dyn RecordBatchReader + Send>>,
-    schema: Option<SchemaRef>,
-    mode: CreateTableMode,
-    write_options: WriteOptions,
+    pub(crate) name: String,
+    pub(crate) data: Option<Box<dyn RecordBatchReader + Send>>,
+    pub(crate) schema: Option<SchemaRef>,
+    pub(crate) mode: CreateTableMode,
+    pub(crate) write_options: WriteOptions,
 }
 
 // Builder methods that only apply when we have initial data
diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs
index f8961b85..611f171b 100644
--- a/rust/lancedb/src/error.rs
+++ b/rust/lancedb/src/error.rs
@@ -20,36 +20,40 @@ use snafu::Snafu;
 #[derive(Debug, Snafu)]
 #[snafu(visibility(pub(crate)))]
 pub enum Error {
-    #[snafu(display("LanceDBError: Invalid table name: {name}"))]
+    #[snafu(display("Invalid table name: {name}"))]
     InvalidTableName { name: String },
-    #[snafu(display("LanceDBError: Invalid input, {message}"))]
+    #[snafu(display("Invalid input, {message}"))]
     InvalidInput { message: String },
-    #[snafu(display("LanceDBError: Table '{name}' was not found"))]
+    #[snafu(display("Table '{name}' was not found"))]
     TableNotFound { name: String },
-    #[snafu(display("LanceDBError: Table '{name}' already exists"))]
+    #[snafu(display("Table '{name}' already exists"))]
     TableAlreadyExists { name: String },
-    #[snafu(display("LanceDBError: Unable to created lance dataset at {path}: {source}"))]
+    #[snafu(display("Unable to create lance dataset at {path}: {source}"))]
     CreateDir {
         path: String,
         source: std::io::Error,
     },
-    #[snafu(display("LanceDBError: Http error: {message}"))]
-    Http { message: String },
-    #[snafu(display("LanceDBError: {message}"))]
-    Store { message: String },
-    #[snafu(display("LanceDBError: {message}"))]
-    Lance { message: String },
-    #[snafu(display("LanceDB Schema Error: {message}"))]
+    #[snafu(display("Schema Error: {message}"))]
     Schema { message: String },
     #[snafu(display("Runtime error: {message}"))]
     Runtime { message: String },
+
+    // 3rd party / external errors
+    #[snafu(display("object_store error: {message}"))]
+    Store { message: String },
+    #[snafu(display("lance error: {message}"))]
+    Lance { message: String },
+    #[snafu(display("Http error: {message}"))]
+    Http { message: String },
+    #[snafu(display("Arrow error: {message}"))]
+    Arrow { message: String },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
 
 impl From<ArrowError> for Error {
     fn from(e: ArrowError) -> Self {
-        Self::Lance {
+        Self::Arrow {
             message: e.to_string(),
         }
     }
diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs
index 57a86a92..dfdf6224 100644
--- a/rust/lancedb/src/remote.rs
+++ b/rust/lancedb/src/remote.rs
@@ -19,3 +19,5 @@
 
 pub mod client;
 pub mod db;
+pub mod table;
+pub mod util;
diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs
index 8b516d20..6ff9811b 100644
--- a/rust/lancedb/src/remote/client.rs
+++ b/rust/lancedb/src/remote/client.rs
@@ -21,7 +21,7 @@ use reqwest::{
 
 use crate::error::{Error, Result};
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct RestfulLanceDbClient {
     client: reqwest::Client,
     host: String,
@@ -97,6 +97,11 @@ impl RestfulLanceDbClient {
         self.client.get(full_uri)
     }
 
+    pub fn post(&self, uri: &str) -> RequestBuilder {
+        let full_uri = format!("{}{}", self.host, uri);
+        self.client.post(full_uri)
+    }
+
     async fn rsp_to_str(response: Response) -> String {
         let status = response.status();
         response.text().await.unwrap_or_else(|_| status.to_string())
diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs
index db7ee00d..948db4fd 100644
--- a/rust/lancedb/src/remote/db.rs
+++ b/rust/lancedb/src/remote/db.rs
@@ -12,14 +12,22 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::sync::Arc;
+
 use async_trait::async_trait;
+use reqwest::header::CONTENT_TYPE;
 use serde::Deserialize;
+use tokio::task::spawn_blocking;
 
 use crate::connection::{ConnectionInternal, CreateTableBuilder, OpenTableBuilder};
 use crate::error::Result;
 use crate::TableRef;
 
 use super::client::RestfulLanceDbClient;
+use super::table::RemoteTable;
+use super::util::batches_to_ipc_bytes;
+
+const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
 
 #[derive(Deserialize)]
 struct ListTablesResponse {
@@ -57,8 +65,27 @@ impl ConnectionInternal for RemoteDatabase {
         Ok(rsp.json::<ListTablesResponse>().await?.tables)
     }
 
-    async fn do_create_table(&self, _options: CreateTableBuilder<true>) -> Result<TableRef> {
-        todo!()
+    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef> {
+        let data = options.data.unwrap();
+        // TODO: https://github.com/lancedb/lancedb/issues/1026
+        // We should accept data from an async source. In the meantime, spawn this as blocking
+        // to make sure we don't block the tokio runtime if the source is slow.
+        let data_buffer = spawn_blocking(move || batches_to_ipc_bytes(data))
+            .await
+            .unwrap()?;
+
+        self.client
+            .post(&format!("/v1/table/{}/create", options.name))
+            .body(data_buffer)
+            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+            .header("x-request-id", "na")
+            .send()
+            .await?;
+
+        Ok(Arc::new(RemoteTable::new(
+            self.client.clone(),
+            options.name,
+        )))
     }
 
     async fn do_open_table(&self, _options: OpenTableBuilder) -> Result<TableRef> {
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
new file mode 100644
index 00000000..dfbf337f
--- /dev/null
+++ b/rust/lancedb/src/remote/table.rs
@@ -0,0 +1,89 @@
+use arrow_array::RecordBatchReader;
+use arrow_schema::SchemaRef;
+use async_trait::async_trait;
+use lance::dataset::{ColumnAlteration, NewColumnTransform};
+
+use crate::{
+    error::Result,
+    index::IndexBuilder,
+    query::Query,
+    table::{
+        merge::MergeInsertBuilder, AddDataOptions, NativeTable, OptimizeAction, OptimizeStats,
+    },
+    Table,
+};
+
+use super::client::RestfulLanceDbClient;
+
+#[derive(Debug)]
+pub struct RemoteTable {
+    #[allow(dead_code)]
+    client: RestfulLanceDbClient,
+    name: String,
+}
+
+impl RemoteTable {
+    pub fn new(client: RestfulLanceDbClient, name: String) -> Self {
+        Self { client, name }
+    }
+}
+
+impl std::fmt::Display for RemoteTable {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "RemoteTable({})", self.name)
+    }
+}
+
+#[async_trait]
+impl Table for RemoteTable {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+    fn as_native(&self) -> Option<&NativeTable> {
+        None
+    }
+    fn name(&self) -> &str {
+        &self.name
+    }
+    async fn schema(&self) -> Result<SchemaRef> {
+        todo!()
+    }
+    async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
+        todo!()
+    }
+    async fn add(
+        &self,
+        _batches: Box<dyn RecordBatchReader + Send>,
+        _options: AddDataOptions,
+    ) -> Result<()> {
+        todo!()
+    }
+    async fn delete(&self, _predicate: &str) -> Result<()> {
+        todo!()
+    }
+    fn create_index(&self, _column: &[&str]) -> IndexBuilder {
+        todo!()
+    }
+    fn merge_insert(&self, _on: &[&str]) -> MergeInsertBuilder {
+        todo!()
+    }
+    fn query(&self) -> Query {
+        todo!()
+    }
+    async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
+        todo!()
+    }
+    async fn add_columns(
+        &self,
+        _transforms: NewColumnTransform,
+        _read_columns: Option<Vec<String>>,
+    ) -> Result<()> {
+        todo!()
+    }
+    async fn alter_columns(&self, _alterations: &[ColumnAlteration]) -> Result<()> {
+        todo!()
+    }
+    async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
+        todo!()
+    }
+}
diff --git a/rust/lancedb/src/remote/util.rs b/rust/lancedb/src/remote/util.rs
new file mode 100644
index 00000000..b594ed6e
--- /dev/null
+++ b/rust/lancedb/src/remote/util.rs
@@ -0,0 +1,21 @@
+use std::io::Cursor;
+
+use arrow_array::RecordBatchReader;
+
+use crate::Result;
+
+pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>> {
+    const WRITE_BUF_SIZE: usize = 4096;
+    let buf = Vec::with_capacity(WRITE_BUF_SIZE);
+    let mut buf = Cursor::new(buf);
+    {
+        let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batches.schema())?;
+
+        for batch in batches {
+            let batch = batch?;
+            writer.write(&batch)?;
+        }
+        writer.finish()?;
+    }
+    Ok(buf.into_inner())
+}
diff --git a/rust/lancedb/tests/lancedb_cloud.rs b/rust/lancedb/tests/lancedb_cloud.rs
index 88fae82d..9bf75e91 100644
--- a/rust/lancedb/tests/lancedb_cloud.rs
+++ b/rust/lancedb/tests/lancedb_cloud.rs
@@ -12,6 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+use std::sync::Arc; + +use arrow_array::RecordBatchIterator; + #[tokio::test] #[ignore] async fn cloud_integration_test() { @@ -36,5 +40,28 @@ async fn cloud_integration_test() { } let db = builder.execute().await.unwrap(); - db.table_names().await.unwrap(); + let schema = Arc::new(arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("id", arrow_schema::DataType::Int64, false), + arrow_schema::Field::new("name", arrow_schema::DataType::Utf8, false), + ])); + let initial_data = arrow::record_batch::RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow_array::Int64Array::from(vec![1, 2, 3])), + Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])), + ], + ); + let rbr = RecordBatchIterator::new(vec![initial_data], schema); + + let name = uuid::Uuid::new_v4().to_string(); + let tbl = db + .create_table(name.clone(), Box::new(rbr)) + .execute() + .await + .unwrap(); + + assert_eq!(tbl.name(), name); + + let table_names = db.table_names().await.unwrap(); + assert!(table_names.contains(&name)); }
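
---

For reference, a minimal end-to-end sketch of the new async Python API
exercised by the tests in this patch (assumes a build of this branch;
`connect_async`, `create_table` with `mode`/`exist_ok`, `Table.name`, and the
now-async `Table.schema()` are the entry points added here — `count_rows` and
the query paths are still unimplemented on the async table):

    import asyncio
    import lancedb
    import pyarrow as pa

    async def main():
        db = await lancedb.connect_async("./.lancedb")

        # create; raises RuntimeError if "demo" already exists
        data = pa.table({"vector": [[3.1, 4.1], [5.9, 26.5]], "item": ["foo", "bar"]})
        tbl = await db.create_table("demo", data=data)

        # exist_ok=True opens the existing table instead of failing
        tbl = await db.create_table("demo", data=data, exist_ok=True)

        # mode="overwrite" replaces the table contents
        tbl = await db.create_table("demo", data=data, mode="overwrite")

        # schema is now an async method, not a property (see ASYNC_MIGRATION.md)
        print(tbl.name, await tbl.schema())
        print(await db.table_names())

    asyncio.run(main())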