From 73c69a6b9a986c1e0f072c5128d23a75678498e3 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 5 Mar 2024 08:38:18 -0800 Subject: [PATCH] feat: page_token / limit to native table_names function. Use async table_names function from sync table_names function (#1059) The synchronous table_names function in python lancedb relies on arrow's filesystem which behaves slightly differently than object_store. As a result, the function would not work properly in GCS. However, the async table_names function uses object_store directly and thus is accurate. In most cases we can fallback to using the async table_names function and so this PR does so. The one case we cannot is if the user is already in an async context (we can't start a new async event loop). Soon, we can just redirect those users to use the async API instead of the sync API and so that case will eventually go away. For now, we fallback to the old behavior. --- .pre-commit-config.yaml | 1 + nodejs/__test__/arrow.test.ts | 4 +- nodejs/__test__/connection.test.ts | 22 +++++- nodejs/__test__/table.test.ts | 15 ++-- nodejs/eslint.config.js | 6 ++ nodejs/lancedb/connection.ts | 24 +++++- nodejs/lancedb/embedding/openai.ts | 1 + nodejs/lancedb/indexer.ts | 3 + nodejs/lancedb/native.d.ts | 2 +- nodejs/lancedb/query.ts | 10 +-- nodejs/lancedb/sanitize.ts | 2 + nodejs/src/connection.rs | 16 +++- python/python/lancedb/_lancedb.pyi | 4 +- python/python/lancedb/db.py | 64 +++++++++------- python/python/tests/test_db.py | 4 + python/src/connection.rs | 17 ++++- rust/ffi/node/src/lib.rs | 2 +- rust/lancedb/examples/simple.rs | 2 +- rust/lancedb/src/connection.rs | 111 ++++++++++++++++++++++++---- rust/lancedb/src/remote/db.rs | 21 +++--- rust/lancedb/tests/lancedb_cloud.rs | 2 +- 21 files changed, 250 insertions(+), 83 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccece35f..46f78607 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,3 +15,4 @@ repos: hooks: - id: prettier files: "nodejs/.*" + exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.* diff --git a/nodejs/__test__/arrow.test.ts b/nodejs/__test__/arrow.test.ts index 4a267c57..11317b37 100644 --- a/nodejs/__test__/arrow.test.ts +++ b/nodejs/__test__/arrow.test.ts @@ -457,8 +457,8 @@ describe("when using two versions of arrow", function () { expect(lhs.nullable).toEqual(rhs.nullable); expect(lhs.typeId).toEqual(rhs.typeId); if ("children" in lhs.type && lhs.type.children !== null) { - const lhs_children = lhs.type.children as Field[]; - lhs_children.forEach((child: Field, idx) => { + const lhsChildren = lhs.type.children as Field[]; + lhsChildren.forEach((child: Field, idx) => { compareFields(child, rhs.type.children[idx]); }); } diff --git a/nodejs/__test__/connection.test.ts b/nodejs/__test__/connection.test.ts index 8791705f..ebd7b757 100644 --- a/nodejs/__test__/connection.test.ts +++ b/nodejs/__test__/connection.test.ts @@ -66,9 +66,23 @@ describe("given a connection", () => { await expect(tbl.countRows()).resolves.toBe(1); }); - it("should list tables", async () => { - await db.createTable("test2", [{ id: 1 }, { id: 2 }]); - await db.createTable("test1", [{ id: 1 }, { id: 2 }]); - expect(await db.tableNames()).toEqual(["test1", "test2"]); + it("should respect limit and page token when listing tables", async () => { + const db = await connect(tmpDir.name); + + await db.createTable("b", [{ id: 1 }]); + await db.createTable("a", [{ id: 1 }]); + await db.createTable("c", [{ id: 1 }]); + + let tables = await db.tableNames(); + expect(tables).toEqual(["a", "b", "c"]); + + tables = await db.tableNames({ limit: 1 }); + expect(tables).toEqual(["a"]); + + tables = await db.tableNames({ limit: 1, startAfter: "a" }); + expect(tables).toEqual(["b"]); + + tables = await db.tableNames({ startAfter: "a" }); + expect(tables).toEqual(["b", "c"]); }); }); diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index c06fccae..1c4c8879 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -103,12 +103,12 @@ describe("Test creating index", () => { // TODO: check index type. // Search without specifying the column - const query_vector = data.toArray()[5].vec.toJSON(); - const rst = await tbl.query().nearestTo(query_vector).limit(2).toArrow(); + const queryVector = data.toArray()[5].vec.toJSON(); + const rst = await tbl.query().nearestTo(queryVector).limit(2).toArrow(); expect(rst.numRows).toBe(2); // Search with specifying the column - const rst2 = await tbl.search(query_vector, "vec").limit(2).toArrow(); + const rst2 = await tbl.search(queryVector, "vec").limit(2).toArrow(); expect(rst2.numRows).toBe(2); expect(rst.toString()).toEqual(rst2.toString()); }); @@ -169,6 +169,7 @@ describe("Test creating index", () => { ); tbl .createIndex("vec") + // eslint-disable-next-line @typescript-eslint/naming-convention .ivf_pq({ num_partitions: 2, num_sub_vectors: 2 }) .build(); @@ -199,10 +200,10 @@ describe("Test creating index", () => { const query64 = Array(64) .fill(1) .map(() => Math.random()); - const rst64_1 = await tbl.query().nearestTo(query64).limit(2).toArrow(); - const rst64_2 = await tbl.search(query64, "vec2").limit(2).toArrow(); - expect(rst64_1.toString()).toEqual(rst64_2.toString()); - expect(rst64_1.numRows).toBe(2); + const rst64Query = await tbl.query().nearestTo(query64).limit(2).toArrow(); + const rst64Search = await tbl.search(query64, "vec2").limit(2).toArrow(); + expect(rst64Query.toString()).toEqual(rst64Search.toString()); + expect(rst64Query.numRows).toBe(2); }); test("create scalar index", async () => { diff --git a/nodejs/eslint.config.js b/nodejs/eslint.config.js index 5485ff71..73afdbdc 100644 --- a/nodejs/eslint.config.js +++ b/nodejs/eslint.config.js @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/naming-convention */ // @ts-check const eslint = require("@eslint/js"); @@ -8,4 +9,9 @@ module.exports = tseslint.config( eslint.configs.recommended, eslintConfigPrettier, ...tseslint.configs.recommended, + { + rules: { + "@typescript-eslint/naming-convention": "error", + }, + }, ); diff --git a/nodejs/lancedb/connection.ts b/nodejs/lancedb/connection.ts index 157cb08c..b42bc3ba 100644 --- a/nodejs/lancedb/connection.ts +++ b/nodejs/lancedb/connection.ts @@ -35,6 +35,19 @@ export interface CreateTableOptions { existOk: boolean; } +export interface TableNamesOptions { + /** + * If present, only return names that come lexicographically after the + * supplied value. + * + * This can be combined with limit to implement pagination by setting this to + * the last table name from the previous page. + */ + startAfter?: string; + /** An optional limit to the number of results to return. */ + limit?: number; +} + /** * A LanceDB Connection that allows you to open tables and create new ones. * @@ -80,9 +93,14 @@ export class Connection { return this.inner.display(); } - /** List all the table names in this database. */ - async tableNames(): Promise { - return this.inner.tableNames(); + /** List all the table names in this database. + * + * Tables will be returned in lexicographical order. + * + * @param options Optional parameters to control the listing. + */ + async tableNames(options?: Partial): Promise { + return this.inner.tableNames(options?.startAfter, options?.limit); } /** diff --git a/nodejs/lancedb/embedding/openai.ts b/nodejs/lancedb/embedding/openai.ts index deade19e..a61079f3 100644 --- a/nodejs/lancedb/embedding/openai.ts +++ b/nodejs/lancedb/embedding/openai.ts @@ -27,6 +27,7 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction { /** * @type {import("openai").default} */ + // eslint-disable-next-line @typescript-eslint/naming-convention let Openai; try { // eslint-disable-next-line @typescript-eslint/no-var-requires diff --git a/nodejs/lancedb/indexer.ts b/nodejs/lancedb/indexer.ts index 57ced46d..4c193940 100644 --- a/nodejs/lancedb/indexer.ts +++ b/nodejs/lancedb/indexer.ts @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +// TODO: Re-enable this as part of https://github.com/lancedb/lancedb/pull/1052 +/* eslint-disable @typescript-eslint/naming-convention */ + import { MetricType, IndexBuilder as NativeBuilder, diff --git a/nodejs/lancedb/native.d.ts b/nodejs/lancedb/native.d.ts index baa2199e..ebcc7329 100644 --- a/nodejs/lancedb/native.d.ts +++ b/nodejs/lancedb/native.d.ts @@ -78,7 +78,7 @@ export class Connection { isOpen(): boolean close(): void /** List all tables in the dataset. */ - tableNames(): Promise> + tableNames(startAfter?: string | undefined | null, limit?: number | undefined | null): Promise> /** * Create table from a Apache Arrow IPC (file) buffer. * diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index e57e61e8..cd86a310 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -20,7 +20,7 @@ import { } from "./native"; class RecordBatchIterator implements AsyncIterator { - private promised_inner?: Promise; + private promisedInner?: Promise; private inner?: NativeBatchIterator; constructor( @@ -29,13 +29,13 @@ class RecordBatchIterator implements AsyncIterator { ) { // TODO: check promise reliably so we dont need to pass two arguments. this.inner = inner; - this.promised_inner = promise; + this.promisedInner = promise; } // eslint-disable-next-line @typescript-eslint/no-explicit-any async next(): Promise>> { if (this.inner === undefined) { - this.inner = await this.promised_inner; + this.inner = await this.promisedInner; } if (this.inner === undefined) { throw new Error("Invalid iterator state state"); @@ -115,8 +115,8 @@ export class Query implements AsyncIterable { /** * Set the refine factor for the query. */ - refineFactor(refine_factor: number): Query { - this.inner.refineFactor(refine_factor); + refineFactor(refineFactor: number): Query { + this.inner.refineFactor(refineFactor); return this; } diff --git a/nodejs/lancedb/sanitize.ts b/nodejs/lancedb/sanitize.ts index 9a5face3..92ec01e6 100644 --- a/nodejs/lancedb/sanitize.ts +++ b/nodejs/lancedb/sanitize.ts @@ -168,6 +168,7 @@ function sanitizeTimestamp(typeLike: object) { function sanitizeTypedTimestamp( typeLike: object, + // eslint-disable-next-line @typescript-eslint/naming-convention Datatype: | typeof TimestampNanosecond | typeof TimestampMicrosecond @@ -235,6 +236,7 @@ function sanitizeUnion(typeLike: object) { function sanitizeTypedUnion( typeLike: object, + // eslint-disable-next-line @typescript-eslint/naming-convention UnionType: typeof DenseUnion | typeof SparseUnion, ) { if (!("typeIds" in typeLike)) { diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs index afb7787e..295fccf0 100644 --- a/nodejs/src/connection.rs +++ b/nodejs/src/connection.rs @@ -89,9 +89,19 @@ impl Connection { /// List all tables in the dataset. #[napi] - pub async fn table_names(&self) -> napi::Result> { - self.get_inner()? - .table_names() + pub async fn table_names( + &self, + start_after: Option, + limit: Option, + ) -> napi::Result> { + let mut op = self.get_inner()?.table_names(); + if let Some(start_after) = start_after { + op = op.start_after(start_after); + } + if let Some(limit) = limit { + op = op.limit(limit); + } + op.execute() .await .map_err(|e| napi::Error::from_reason(format!("{}", e))) } diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 8ac62fd1..2c9733b8 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -3,7 +3,9 @@ from typing import Optional import pyarrow as pa class Connection(object): - async def table_names(self) -> list[str]: ... + async def table_names( + self, start_after: Optional[str], limit: Optional[int] + ) -> list[str]: ... async def create_table( self, name: str, mode: str, data: pa.RecordBatchReader ) -> Table: ... diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index c18656c3..f4f2a429 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -13,6 +13,7 @@ from __future__ import annotations +import asyncio import inspect import os from abc import abstractmethod @@ -27,6 +28,7 @@ from lancedb.common import data_to_reader, validate_schema from lancedb.embeddings.registry import EmbeddingFunctionRegistry from lancedb.utils.events import register_event +from ._lancedb import connect as lancedb_connect from .pydantic import LanceModel from .table import AsyncTable, LanceTable, Table, _sanitize_data from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri @@ -317,6 +319,10 @@ class LanceDBConnection(DBConnection): def uri(self) -> str: return self._uri + async def _async_get_table_names(self, start_after: Optional[str], limit: int): + conn = AsyncConnection(await lancedb_connect(self.uri)) + return await conn.table_names(start_after=start_after, limit=limit) + @override def table_names( self, page_token: Optional[str] = None, limit: int = 10 @@ -329,23 +335,31 @@ class LanceDBConnection(DBConnection): A list of table names. """ try: - filesystem = fs_from_uri(self.uri)[0] - except pa.ArrowInvalid: - raise NotImplementedError("Unsupported scheme: " + self.uri) + asyncio.get_running_loop() + # User application is async. Soon we will just tell them to use the + # async version. Until then fallback to the old sync implementation. + try: + filesystem = fs_from_uri(self.uri)[0] + except pa.ArrowInvalid: + raise NotImplementedError("Unsupported scheme: " + self.uri) - try: - loc = get_uri_location(self.uri) - paths = filesystem.get_file_info(fs.FileSelector(loc)) - except FileNotFoundError: - # It is ok if the file does not exist since it will be created - paths = [] - tables = [ - os.path.splitext(file_info.base_name)[0] - for file_info in paths - if file_info.extension == "lance" - ] - tables.sort() - return tables + try: + loc = get_uri_location(self.uri) + paths = filesystem.get_file_info(fs.FileSelector(loc)) + except FileNotFoundError: + # It is ok if the file does not exist since it will be created + paths = [] + tables = [ + os.path.splitext(file_info.base_name)[0] + for file_info in paths + if file_info.extension == "lance" + ] + tables.sort() + return tables + except RuntimeError: + # User application is sync. It is safe to use the async implementation + # under the hood. + return asyncio.run(self._async_get_table_names(page_token, limit)) def __len__(self) -> int: return len(self.table_names()) @@ -484,26 +498,26 @@ class AsyncConnection(object): self._inner.close() async def table_names( - self, *, page_token: Optional[str] = None, limit: Optional[int] = None + self, *, start_after: Optional[str] = None, limit: Optional[int] = None ) -> Iterable[str]: """List all tables in this database, in sorted order Parameters ---------- - page_token: str, optional - The token to use for pagination. If not present, start from the beginning. - Typically, this token is last table name from the previous page. - Only supported by LanceDb Cloud. + start_after: str, optional + If present, only return names that come lexicographically after the supplied + value. + + This can be combined with limit to implement pagination by setting this to + the last table name from the previous page. limit: int, default 10 - The size of the page to return. - Only supported by LanceDb Cloud. + The number of results to return. Returns ------- Iterable of str """ - # TODO: hook in page_token and limit - return await self._inner.table_names() + return await self._inner.table_names(start_after=start_after, limit=limit) async def create_table( self, diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 06b1c326..c84c0800 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -185,6 +185,10 @@ async def test_table_names_async(tmp_path): db = await lancedb.connect_async(tmp_path) assert await db.table_names() == ["test1", "test2", "test3"] + assert await db.table_names(limit=1) == ["test1"] + assert await db.table_names(start_after="test1", limit=1) == ["test2"] + assert await db.table_names(start_after="test1") == ["test2", "test3"] + def test_create_mode(tmp_path): db = lancedb.connect(tmp_path) diff --git a/python/src/connection.rs b/python/src/connection.rs index 7bfa2ae5..93f5332a 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -69,11 +69,20 @@ impl Connection { self.inner.take(); } - pub fn table_names(self_: PyRef<'_, Self>) -> PyResult<&PyAny> { + pub fn table_names( + self_: PyRef<'_, Self>, + start_after: Option, + limit: Option, + ) -> PyResult<&PyAny> { let inner = self_.get_inner()?.clone(); - future_into_py(self_.py(), async move { - inner.table_names().await.infer_error() - }) + let mut op = inner.table_names(); + if let Some(start_after) = start_after { + op = op.start_after(start_after); + } + if let Some(limit) = limit { + op = op.limit(limit); + } + future_into_py(self_.py(), async move { op.execute().await.infer_error() }) } pub fn create_table<'a>( diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index 070e2afc..eee75144 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -132,7 +132,7 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult { let database = db.database.clone(); rt.spawn(async move { - let tables_rst = database.table_names().await; + let tables_rst = database.table_names().execute().await; deferred.settle_with(&channel, move |mut cx| { let tables = tables_rst.or_throw(&mut cx)?; diff --git a/rust/lancedb/examples/simple.rs b/rust/lancedb/examples/simple.rs index 51e8e44b..f5b4af8c 100644 --- a/rust/lancedb/examples/simple.rs +++ b/rust/lancedb/examples/simple.rs @@ -33,7 +33,7 @@ async fn main() -> Result<()> { // --8<-- [end:connect] // --8<-- [start:list_names] - println!("{:?}", db.table_names().await?); + println!("{:?}", db.table_names().execute().await?); // --8<-- [end:list_names] let tbl = create_table(&db).await?; create_index(&tbl).await?; diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 80957cd3..842539cd 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -78,6 +78,44 @@ enum BadVectorHandling { Fill(f32), } +/// A builder for configuring a [`Connection::table_names`] operation +pub struct TableNamesBuilder { + parent: Arc, + pub(crate) start_after: Option, + pub(crate) limit: Option, +} + +impl TableNamesBuilder { + fn new(parent: Arc) -> Self { + Self { + parent, + start_after: None, + limit: None, + } + } + + /// If present, only return names that come lexicographically after the supplied + /// value. + /// + /// This can be combined with limit to implement pagination by setting this to + /// the last table name from the previous page. + pub fn start_after(mut self, start_after: String) -> Self { + self.start_after = Some(start_after); + self + } + + /// The maximum number of table names to return + pub fn limit(mut self, limit: u32) -> Self { + self.limit = Some(limit); + self + } + + /// Execute the table names operation + pub async fn execute(self) -> Result> { + self.parent.clone().table_names(self).await + } +} + /// A builder for configuring a [`Connection::create_table`] operation pub struct CreateTableBuilder { parent: Arc, @@ -198,7 +236,7 @@ impl OpenTableBuilder { pub(crate) trait ConnectionInternal: Send + Sync + std::fmt::Debug + std::fmt::Display + 'static { - async fn table_names(&self) -> Result>; + async fn table_names(&self, options: TableNamesBuilder) -> Result>; async fn do_create_table(&self, options: CreateTableBuilder) -> Result; async fn do_open_table(&self, options: OpenTableBuilder) -> Result
; async fn drop_table(&self, name: &str) -> Result<()>; @@ -232,9 +270,13 @@ impl Connection { self.uri.as_str() } - /// Get the names of all tables in the database. - pub async fn table_names(&self) -> Result> { - self.internal.table_names().await + /// Get the names of all tables in the database + /// + /// The names will be returned in lexicographical order (ascending) + /// + /// The parameters `page_token` and `limit` can be used to paginate the results + pub fn table_names(&self) -> TableNamesBuilder { + TableNamesBuilder::new(self.internal.clone()) } /// Create a new table from data @@ -613,7 +655,7 @@ impl Database { #[async_trait::async_trait] impl ConnectionInternal for Database { - async fn table_names(&self) -> Result> { + async fn table_names(&self, options: TableNamesBuilder) -> Result> { let mut f = self .object_store .read_dir(self.base_path.clone()) @@ -630,6 +672,16 @@ impl ConnectionInternal for Database { .filter_map(|p| p.file_stem().and_then(|s| s.to_str().map(String::from))) .collect::>(); f.sort(); + if let Some(start_after) = options.start_after { + let index = f + .iter() + .position(|name| name.as_str() > start_after.as_str()) + .unwrap_or(f.len()); + f.drain(0..index); + } + if let Some(limit) = options.limit { + f.truncate(limit as usize); + } Ok(f) } @@ -742,16 +794,43 @@ mod tests { #[tokio::test] async fn test_table_names() { let tmp_dir = tempdir().unwrap(); - create_dir_all(tmp_dir.path().join("table1.lance")).unwrap(); - create_dir_all(tmp_dir.path().join("table2.lance")).unwrap(); - create_dir_all(tmp_dir.path().join("invalidlance")).unwrap(); + let mut names = Vec::with_capacity(100); + for _ in 0..100 { + let name = uuid::Uuid::new_v4().to_string(); + names.push(name.clone()); + let table_name = name + ".lance"; + create_dir_all(tmp_dir.path().join(&table_name)).unwrap(); + } + names.sort(); let uri = tmp_dir.path().to_str().unwrap(); let db = connect(uri).execute().await.unwrap(); - let tables = db.table_names().await.unwrap(); - assert_eq!(tables.len(), 2); - assert!(tables[0].eq(&String::from("table1"))); - assert!(tables[1].eq(&String::from("table2"))); + let tables = db.table_names().execute().await.unwrap(); + + assert_eq!(tables, names); + + let tables = db + .table_names() + .start_after(names[30].clone()) + .execute() + .await + .unwrap(); + + assert_eq!(tables, names[31..]); + + let tables = db + .table_names() + .start_after(names[30].clone()) + .limit(7) + .execute() + .await + .unwrap(); + + assert_eq!(tables, names[31..38]); + + let tables = db.table_names().limit(7).execute().await.unwrap(); + + assert_eq!(tables, names[..7]); } #[tokio::test] @@ -766,14 +845,14 @@ mod tests { let uri = tmp_dir.path().to_str().unwrap(); let db = connect(uri).execute().await.unwrap(); - assert_eq!(db.table_names().await.unwrap().len(), 0); + assert_eq!(db.table_names().execute().await.unwrap().len(), 0); // open non-exist table assert!(matches!( db.open_table("invalid_table").execute().await, Err(crate::Error::TableNotFound { .. }) )); - assert_eq!(db.table_names().await.unwrap().len(), 0); + assert_eq!(db.table_names().execute().await.unwrap().len(), 0); let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])); db.create_empty_table("table1", schema) @@ -781,7 +860,7 @@ mod tests { .await .unwrap(); db.open_table("table1").execute().await.unwrap(); - let tables = db.table_names().await.unwrap(); + let tables = db.table_names().execute().await.unwrap(); assert_eq!(tables, vec!["table1".to_owned()]); } @@ -801,7 +880,7 @@ mod tests { create_dir_all(tmp_dir.path().join("table1.lance")).unwrap(); db.drop_table("table1").await.unwrap(); - let tables = db.table_names().await.unwrap(); + let tables = db.table_names().execute().await.unwrap(); assert_eq!(tables.len(), 0); } diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 3dd17ec9..fe430f3d 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -19,7 +19,9 @@ use reqwest::header::CONTENT_TYPE; use serde::Deserialize; use tokio::task::spawn_blocking; -use crate::connection::{ConnectionInternal, CreateTableBuilder, OpenTableBuilder}; +use crate::connection::{ + ConnectionInternal, CreateTableBuilder, OpenTableBuilder, TableNamesBuilder, +}; use crate::error::Result; use crate::Table; @@ -59,14 +61,15 @@ impl std::fmt::Display for RemoteDatabase { #[async_trait] impl ConnectionInternal for RemoteDatabase { - async fn table_names(&self) -> Result> { - let rsp = self - .client - .get("/v1/table/") - .query(&[("limit", 10)]) - .query(&[("page_token", "")]) - .send() - .await?; + async fn table_names(&self, options: TableNamesBuilder) -> Result> { + let mut req = self.client.get("/v1/table/"); + if let Some(limit) = options.limit { + req = req.query(&[("limit", limit)]); + } + if let Some(start_after) = options.start_after { + req = req.query(&[("page_token", start_after)]); + } + let rsp = req.send().await?; let rsp = self.client.check_response(rsp).await?; Ok(rsp.json::().await?.tables) } diff --git a/rust/lancedb/tests/lancedb_cloud.rs b/rust/lancedb/tests/lancedb_cloud.rs index 9bf75e91..84bd7158 100644 --- a/rust/lancedb/tests/lancedb_cloud.rs +++ b/rust/lancedb/tests/lancedb_cloud.rs @@ -62,6 +62,6 @@ async fn cloud_integration_test() { assert_eq!(tbl.name(), name); - let table_names = db.table_names().await.unwrap(); + let table_names = db.table_names().execute().await.unwrap(); assert!(table_names.contains(&name)); }