diff --git a/python/python/lancedb/namespace.py b/python/python/lancedb/namespace.py index 0996c63fd..14f6dd25a 100644 --- a/python/python/lancedb/namespace.py +++ b/python/python/lancedb/namespace.py @@ -10,7 +10,6 @@ through a namespace abstraction. from __future__ import annotations -import asyncio import sys from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union @@ -25,7 +24,24 @@ if TYPE_CHECKING: from datetime import timedelta import pyarrow as pa -from lancedb.db import DBConnection, LanceDBConnection +from lance_namespace_urllib3_client.models.json_arrow_data_type import JsonArrowDataType +from lance_namespace_urllib3_client.models.json_arrow_field import JsonArrowField +from lance_namespace_urllib3_client.models.json_arrow_schema import JsonArrowSchema +from lance_namespace_urllib3_client.models.query_table_request import QueryTableRequest +from lance_namespace_urllib3_client.models.query_table_request_columns import ( + QueryTableRequestColumns, +) +from lance_namespace_urllib3_client.models.query_table_request_full_text_query import ( + QueryTableRequestFullTextQuery, +) +from lance_namespace_urllib3_client.models.query_table_request_vector import ( + QueryTableRequestVector, +) +from lance_namespace_urllib3_client.models.string_fts_query import StringFtsQuery +from lance_namespace.errors import TableNotFoundError +from lancedb._lancedb import connect_namespace_client as _connect_namespace_client +from lancedb.background_loop import LOOP +from lancedb.db import AsyncConnection, DBConnection from lancedb.namespace_utils import ( _normalize_create_namespace_mode, _normalize_drop_namespace_mode, @@ -40,14 +56,11 @@ from lance_namespace import ( ListNamespacesResponse, ListTablesResponse, ListTablesRequest, - DescribeTableRequest, DescribeNamespaceRequest, DropTableRequest, ListNamespacesRequest, CreateNamespaceRequest, DropNamespaceRequest, - DeclareTableRequest, - CreateTableRequest, ) from lancedb.table import AsyncTable, LanceTable, Table from lancedb.util import validate_table_name @@ -56,21 +69,6 @@ from lancedb.pydantic import LanceModel from lancedb.embeddings import EmbeddingFunctionConfig from ._lancedb import Session -from lance_namespace_urllib3_client.models.json_arrow_schema import JsonArrowSchema -from lance_namespace_urllib3_client.models.json_arrow_field import JsonArrowField -from lance_namespace_urllib3_client.models.json_arrow_data_type import JsonArrowDataType -from lance_namespace_urllib3_client.models.query_table_request import QueryTableRequest -from lance_namespace_urllib3_client.models.query_table_request_vector import ( - QueryTableRequestVector, -) -from lance_namespace_urllib3_client.models.query_table_request_columns import ( - QueryTableRequestColumns, -) -from lance_namespace_urllib3_client.models.query_table_request_full_text_query import ( - QueryTableRequestFullTextQuery, -) -from lance_namespace_urllib3_client.models.string_fts_query import StringFtsQuery - def _query_to_namespace_request( table_id: List[str], @@ -424,6 +422,23 @@ class LanceNamespaceDBConnection(DBConnection): ) self._namespace_client_impl = namespace_client_impl self._namespace_client_properties = namespace_client_properties + self._inner = AsyncConnection( + _connect_namespace_client( + namespace_client, + read_consistency_interval=( + read_consistency_interval.total_seconds() + if read_consistency_interval is not None + else None + ), + storage_options=self.storage_options or None, + session=session, + namespace_client_pushdown_operations=( + list(self._namespace_client_pushdown_operations) + ), + namespace_client_impl=namespace_client_impl, + namespace_client_properties=namespace_client_properties, + ) + ) @override def serialize(self) -> str: @@ -497,13 +512,10 @@ class LanceNamespaceDBConnection(DBConnection): if mode.lower() not in ["create", "overwrite"]: raise ValueError("mode must be either 'create' or 'overwrite'") validate_table_name(name) - - table_id = namespace_path + [name] - - if "CreateTable" in self._namespace_client_pushdown_operations: - return self._create_table_server_side( - name=name, - data=data, + async_table = LOOP.run( + self._inner.create_table( + name, + data, schema=schema, mode=mode, exist_ok=exist_ok, @@ -513,130 +525,15 @@ class LanceNamespaceDBConnection(DBConnection): namespace_path=namespace_path, storage_options=storage_options, ) - - # Local create path: declare_table + local write - # Step 1: Get the table location and storage options from namespace - # In overwrite mode, if table exists, use describe_table to get - # existing location. Otherwise, call create_empty_table to reserve - # a new location - location = None - namespace_storage_options = None - if mode.lower() == "overwrite": - # Try to describe the table first to see if it exists - try: - describe_request = DescribeTableRequest(id=table_id) - describe_response = self._namespace_client.describe_table( - describe_request - ) - location = describe_response.location - namespace_storage_options = describe_response.storage_options - except Exception: - # Table doesn't exist, will create a new one below - pass - - if location is None: - # Table doesn't exist or mode is "create", reserve a new location - declare_request = DeclareTableRequest( - id=table_id, - location=None, - properties=self.storage_options if self.storage_options else None, - ) - declare_response = self._namespace_client.declare_table(declare_request) - - if not declare_response.location: - raise ValueError( - "Table location is missing from declare_table response" - ) - - location = declare_response.location - namespace_storage_options = declare_response.storage_options - - # Merge storage options: self.storage_options < user options < namespace options - merged_storage_options = dict(self.storage_options) - if storage_options: - merged_storage_options.update(storage_options) - if namespace_storage_options: - merged_storage_options.update(namespace_storage_options) - - # Step 2: Create table using LanceTable.create with the location - # We need a temporary connection for the LanceTable.create method - temp_conn = LanceDBConnection( - location, # Use the actual location as the connection URI - read_consistency_interval=self.read_consistency_interval, - storage_options=merged_storage_options, - session=self.session, ) - # Note: storage_options_provider is auto-created in Rust from namespace_client - tbl = LanceTable.create( - temp_conn, + return LanceTable( + self, name, - data, - schema, - mode=mode, - exist_ok=exist_ok, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - embedding_functions=embedding_functions, namespace_path=namespace_path, - storage_options=merged_storage_options, - location=location, namespace_client=self._namespace_client, pushdown_operations=self._namespace_client_pushdown_operations, - ) - - return tbl - - def _create_table_server_side( - self, - name: str, - data: Optional[DATA], - schema: Optional[Union[pa.Schema, LanceModel]], - mode: str, - exist_ok: bool, - on_bad_vectors: str, - fill_value: float, - embedding_functions: Optional[List[EmbeddingFunctionConfig]], - namespace_path: Optional[List[str]], - storage_options: Optional[Dict[str, str]], - ) -> Table: - """Create a table using server-side namespace.create_table().""" - if namespace_path is None: - namespace_path = [] - table_id = namespace_path + [name] - - arrow_ipc_bytes = _data_to_arrow_ipc( - data=data, - schema=schema, - embedding_functions=embedding_functions, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - - merged = dict(self.storage_options or {}) - if storage_options: - merged.update(storage_options) - request = CreateTableRequest( - id=table_id, - mode=_normalize_create_table_mode(mode), - properties=merged or None, - ) - - try: - self._namespace_client.create_table(request, arrow_ipc_bytes) - except Exception as e: - if exist_ok and "already exists" in str(e).lower(): - return self.open_table( - name, - namespace_path=namespace_path, - storage_options=storage_options, - ) - raise - - return self.open_table( - name, - namespace_path=namespace_path, - storage_options=storage_options, + _async=async_table, ) @override @@ -650,30 +547,28 @@ class LanceNamespaceDBConnection(DBConnection): ) -> Table: if namespace_path is None: namespace_path = [] - table_id = namespace_path + [name] - request = DescribeTableRequest(id=table_id) - response = self._namespace_client.describe_table(request) + try: + async_table = LOOP.run( + self._inner.open_table( + name, + namespace_path=namespace_path, + storage_options=storage_options, + index_cache_size=index_cache_size, + ) + ) + except RuntimeError as e: + if "Table not found" in str(e): + table_id = namespace_path + [name] + raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}") + raise - # Merge storage options: self.storage_options < user options < namespace options - merged_storage_options = dict(self.storage_options) - if storage_options: - merged_storage_options.update(storage_options) - if response.storage_options: - merged_storage_options.update(response.storage_options) - - # Pass managed_versioning to avoid redundant describe_table call in Rust. - # Convert None to False since we already have the answer from describe_table. - managed_versioning = response.managed_versioning is True - - # Note: storage_options_provider is auto-created in Rust from namespace_client - return self._lance_table_from_uri( + return LanceTable( + self, name, - response.location, namespace_path=namespace_path, - storage_options=merged_storage_options, - index_cache_size=index_cache_size, namespace_client=self._namespace_client, - managed_versioning=managed_versioning, + pushdown_operations=self._namespace_client_pushdown_operations, + _async=async_table, ) @override @@ -897,33 +792,34 @@ class LanceNamespaceDBConnection(DBConnection): namespace_client: Optional[Any] = None, managed_versioning: Optional[bool] = None, ) -> LanceTable: - # Open a table directly from a URI using the location parameter - # Note: storage_options should already be merged by the caller - # Note: storage_options_provider is auto-created in Rust from namespace_client + # Open a table directly from the namespace-resolved physical location. + # + # Open the table through the Rust namespace-backed connection. The Rust + # layer keeps the logical namespace path and namespace client intact. if namespace_path is None: namespace_path = [] - temp_conn = LanceDBConnection( - table_uri, # Use the table location as the connection URI - read_consistency_interval=self.read_consistency_interval, - storage_options=storage_options if storage_options is not None else {}, - session=self.session, + + async_table = LOOP.run( + self._inner.open_table( + name, + namespace_path=namespace_path, + storage_options=storage_options, + index_cache_size=index_cache_size, + location=None, + namespace_client=namespace_client, + managed_versioning=managed_versioning, + ) ) - # Open the table using the temporary connection with the location parameter - # Pass namespace_client to enable managed versioning support and auto-create - # storage options provider - # Pass managed_versioning to avoid redundant describe_table call - # Pass pushdown_operations if configured on this connection - return LanceTable.open( - temp_conn, + return LanceTable( + self, name, namespace_path=namespace_path, - storage_options=storage_options, - index_cache_size=index_cache_size, location=table_uri, namespace_client=namespace_client, managed_versioning=managed_versioning, pushdown_operations=self._namespace_client_pushdown_operations, + _async=async_table, ) @override @@ -990,6 +886,23 @@ class AsyncLanceNamespaceDBConnection: self._namespace_client_pushdown_operations = set( namespace_client_pushdown_operations or [] ) + self._inner = AsyncConnection( + _connect_namespace_client( + namespace_client, + read_consistency_interval=( + read_consistency_interval.total_seconds() + if read_consistency_interval is not None + else None + ), + storage_options=self.storage_options or None, + session=session, + namespace_client_pushdown_operations=( + list(self._namespace_client_pushdown_operations) + ), + namespace_client_impl=None, + namespace_client_properties=None, + ) + ) async def table_names( self, @@ -1041,148 +954,16 @@ class AsyncLanceNamespaceDBConnection: if mode.lower() not in ["create", "overwrite"]: raise ValueError("mode must be either 'create' or 'overwrite'") validate_table_name(name) - - table_id = namespace_path + [name] - - if "CreateTable" in self._namespace_client_pushdown_operations: - return await self._create_table_server_side( - name=name, - data=data, - schema=schema, - mode=mode, - exist_ok=exist_ok, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - embedding_functions=embedding_functions, - namespace_path=namespace_path, - storage_options=storage_options, - ) - - # Local create path: declare_table + local write - # Step 1: Get the table location and storage options from namespace - location = None - namespace_storage_options = None - if mode.lower() == "overwrite": - # Try to describe the table first to see if it exists - try: - describe_request = DescribeTableRequest(id=table_id) - describe_response = self._namespace_client.describe_table( - describe_request - ) - location = describe_response.location - namespace_storage_options = describe_response.storage_options - except Exception: - # Table doesn't exist, will create a new one below - pass - - if location is None: - # Table doesn't exist or mode is "create", reserve a new location - declare_request = DeclareTableRequest( - id=table_id, - location=None, - properties=self.storage_options if self.storage_options else None, - ) - declare_response = self._namespace_client.declare_table(declare_request) - - if not declare_response.location: - raise ValueError( - "Table location is missing from declare_table response" - ) - - location = declare_response.location - namespace_storage_options = declare_response.storage_options - - # Merge storage options: self.storage_options < user options < namespace options - merged_storage_options = dict(self.storage_options) - if storage_options: - merged_storage_options.update(storage_options) - if namespace_storage_options: - merged_storage_options.update(namespace_storage_options) - - # Step 2: Create table using LanceTable.create with the location - # Run the sync operation in a thread - def _create_table(): - temp_conn = LanceDBConnection( - location, - read_consistency_interval=self.read_consistency_interval, - storage_options=merged_storage_options, - session=self.session, - ) - - # storage_options_provider is auto-created in Rust from namespace_client - return LanceTable.create( - temp_conn, - name, - data, - schema, - mode=mode, - exist_ok=exist_ok, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - embedding_functions=embedding_functions, - namespace_path=namespace_path, - storage_options=merged_storage_options, - location=location, - namespace_client=self._namespace_client, - pushdown_operations=self._namespace_client_pushdown_operations, - ) - - lance_table = await asyncio.to_thread(_create_table) - # Get the underlying async table from LanceTable - return lance_table._table - - async def _create_table_server_side( - self, - name: str, - data: Optional[DATA], - schema: Optional[Union[pa.Schema, LanceModel]], - mode: str, - exist_ok: bool, - on_bad_vectors: str, - fill_value: float, - embedding_functions: Optional[List[EmbeddingFunctionConfig]], - namespace_path: Optional[List[str]], - storage_options: Optional[Dict[str, str]], - ) -> AsyncTable: - """Create a table using server-side namespace.create_table().""" - if namespace_path is None: - namespace_path = [] - table_id = namespace_path + [name] - - def _prepare_and_create(): - arrow_ipc_bytes = _data_to_arrow_ipc( - data=data, - schema=schema, - embedding_functions=embedding_functions, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - - merged = dict(self.storage_options or {}) - if storage_options: - merged.update(storage_options) - request = CreateTableRequest( - id=table_id, - mode=_normalize_create_table_mode(mode), - properties=merged or None, - ) - - self._namespace_client.create_table(request, arrow_ipc_bytes) - - try: - await asyncio.to_thread(_prepare_and_create) - except Exception as e: - if exist_ok and "already exists" in str(e).lower(): - return await self.open_table( - name, - namespace_path=namespace_path, - storage_options=storage_options, - ) - raise - - return await self.open_table( + return await self._inner.create_table( name, + data, + schema=schema, + mode=mode, + exist_ok=exist_ok, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, namespace_path=namespace_path, + embedding_functions=embedding_functions, storage_options=storage_options, ) @@ -1197,45 +978,18 @@ class AsyncLanceNamespaceDBConnection: """Open an existing table from the namespace.""" if namespace_path is None: namespace_path = [] - table_id = namespace_path + [name] - request = DescribeTableRequest(id=table_id) - response = self._namespace_client.describe_table(request) - - # Merge storage options: self.storage_options < user options < namespace options - merged_storage_options = dict(self.storage_options) - if storage_options: - merged_storage_options.update(storage_options) - if response.storage_options: - merged_storage_options.update(response.storage_options) - - # Capture managed_versioning from describe response. - # Convert None to False since we already have the answer from describe_table. - managed_versioning = response.managed_versioning is True - - # Open table in a thread - # Note: storage_options_provider is auto-created in Rust from namespace_client - def _open_table(): - temp_conn = LanceDBConnection( - response.location, - read_consistency_interval=self.read_consistency_interval, - storage_options=merged_storage_options, - session=self.session, - ) - - return LanceTable.open( - temp_conn, + try: + return await self._inner.open_table( name, namespace_path=namespace_path, - storage_options=merged_storage_options, + storage_options=storage_options, index_cache_size=index_cache_size, - location=response.location, - namespace_client=self._namespace_client, - managed_versioning=managed_versioning, - pushdown_operations=self._namespace_client_pushdown_operations, ) - - lance_table = await asyncio.to_thread(_open_table) - return lance_table._table + except RuntimeError as e: + if "Table not found" in str(e): + table_id = namespace_path + [name] + raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}") + raise async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None): """Drop a table from the namespace.""" diff --git a/python/src/connection.rs b/python/src/connection.rs index f19bfba97..9c67f38c7 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -1,11 +1,17 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::Duration, +}; use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow}; use lancedb::{ connection::Connection as LanceConnection, + connection::NamespaceClientPushdownOperation, + database::namespace::LanceNamespaceDatabase, database::{CreateTableMode, Database, ReadConsistency}, }; use pyo3::{ @@ -39,6 +45,29 @@ impl Connection { } } +fn parse_namespace_client_pushdown_operations( + operations: Option>, +) -> PyResult> { + let mut parsed = HashSet::new(); + for operation in operations.unwrap_or_default() { + match operation.as_str() { + "QueryTable" => { + parsed.insert(NamespaceClientPushdownOperation::QueryTable); + } + "CreateTable" => { + parsed.insert(NamespaceClientPushdownOperation::CreateTable); + } + _ => { + return Err(PyValueError::new_err(format!( + "Invalid pushdown operation: {}", + operation + ))); + } + } + } + Ok(parsed) +} + impl Connection { fn parse_create_mode_str(mode: &str) -> PyResult { match mode { @@ -538,6 +567,52 @@ pub fn connect( }) } +#[pyfunction] +#[pyo3(signature = ( + namespace_client, + read_consistency_interval=None, + storage_options=None, + session=None, + namespace_client_pushdown_operations=None, + namespace_client_impl=None, + namespace_client_properties=None, +))] +#[allow(clippy::too_many_arguments)] +pub fn connect_namespace_client( + py: Python<'_>, + namespace_client: Py, + read_consistency_interval: Option, + storage_options: Option>, + session: Option, + namespace_client_pushdown_operations: Option>, + namespace_client_impl: Option, + namespace_client_properties: Option>, +) -> PyResult { + let namespace_client = extract_namespace_arc(py, namespace_client)?; + let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64); + let namespace_client_pushdown_operations = + parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?; + let ns_impl = namespace_client_impl.unwrap_or_else(|| "python".to_string()); + let ns_properties = namespace_client_properties.unwrap_or_default(); + let storage_options = storage_options.unwrap_or_default(); + let session = session.map(|s| s.inner.clone()); + + let database = LanceNamespaceDatabase::from_namespace_client( + namespace_client, + ns_impl, + ns_properties, + storage_options, + read_consistency_interval, + session, + namespace_client_pushdown_operations, + ); + + Ok(Connection::new(LanceConnection::new( + Arc::new(database), + Arc::new(lancedb::embeddings::MemoryRegistry::new()), + ))) +} + #[derive(FromPyObject)] pub struct PyClientConfig { user_agent: String, diff --git a/python/src/lib.rs b/python/src/lib.rs index e6294cd14..7dd52bdc2 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: Copyright The LanceDB Authors use arrow::RecordBatchStream; -use connection::{Connection, connect}; +use connection::{Connection, connect, connect_namespace_client}; use env_logger::Env; use expr::{PyExpr, expr_col, expr_func, expr_lit}; use index::IndexConfig; @@ -58,6 +58,7 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(connect, m)?)?; + m.add_function(wrap_pyfunction!(connect_namespace_client, m)?)?; m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?; m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?; m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?; diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 35863aa60..9e0d3ea3f 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -915,7 +915,7 @@ use std::collections::HashSet; /// These operations will be executed on the namespace server instead of locally /// when enabled via [`ConnectNamespaceBuilder::pushdown_operations`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum PushdownOperation { +pub enum NamespaceClientPushdownOperation { /// Execute queries on the namespace server via `query_table()` instead of locally. QueryTable, /// Execute table creation on the namespace server via `create_table()` @@ -931,7 +931,7 @@ pub struct ConnectNamespaceBuilder { read_consistency_interval: Option, embedding_registry: Option>, session: Option>, - pushdown_operations: HashSet, + pushdown_operations: HashSet, } impl ConnectNamespaceBuilder { @@ -1029,11 +1029,11 @@ impl ConnectNamespaceBuilder { /// and leveraging server-side compute resources. /// /// Available operations: - /// - [`PushdownOperation::QueryTable`]: Execute queries via `namespace.query_table()` - /// - [`PushdownOperation::CreateTable`]: Execute table creation via `namespace.create_table()` + /// - [`NamespaceClientPushdownOperation::QueryTable`]: Execute queries via `namespace.query_table()` + /// - [`NamespaceClientPushdownOperation::CreateTable`]: Execute table creation via `namespace.create_table()` /// /// By default, no operations are pushed down (all executed locally). - pub fn pushdown_operation(mut self, operation: PushdownOperation) -> Self { + pub fn pushdown_operation(mut self, operation: NamespaceClientPushdownOperation) -> Self { self.pushdown_operations.insert(operation); self } @@ -1043,7 +1043,7 @@ impl ConnectNamespaceBuilder { /// See [`Self::pushdown_operation`] for details. pub fn pushdown_operations( mut self, - operations: impl IntoIterator, + operations: impl IntoIterator, ) -> Self { self.pushdown_operations.extend(operations); self diff --git a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs index 6b0d19054..07f7e9ea4 100644 --- a/rust/lancedb/src/database/namespace.rs +++ b/rust/lancedb/src/database/namespace.rs @@ -22,10 +22,11 @@ use lance_namespace_impls::ConnectBuilder; use lance_table::io::commit::CommitHandler; use lance_table::io::commit::external_manifest::ExternalManifestCommitHandler; -use crate::connection::PushdownOperation; +use crate::connection::NamespaceClientPushdownOperation; use crate::database::ReadConsistency; use crate::error::{Error, Result}; use crate::table::NativeTable; +use lance::dataset::WriteMode; use super::{ BaseTable, CloneTableRequest, CreateTableMode, CreateTableRequest as DbCreateTableRequest, @@ -44,7 +45,7 @@ pub struct LanceNamespaceDatabase { // database URI uri: String, // Operations to push down to the namespace server - pushdown_operations: HashSet, + pushdown_operations: HashSet, // Namespace implementation type (e.g., "dir", "rest") ns_impl: String, // Namespace properties used to construct the namespace client @@ -52,13 +53,34 @@ pub struct LanceNamespaceDatabase { } impl LanceNamespaceDatabase { + pub fn from_namespace_client( + namespace_client: Arc, + namespace_client_impl: String, + namespace_client_properties: HashMap, + storage_options: HashMap, + read_consistency_interval: Option, + session: Option>, + namespace_client_pushdown_operations: HashSet, + ) -> Self { + Self { + namespace: namespace_client, + storage_options, + read_consistency_interval, + session, + uri: format!("namespace://{}", namespace_client_impl), + pushdown_operations: namespace_client_pushdown_operations, + ns_impl: namespace_client_impl, + ns_properties: namespace_client_properties, + } + } + pub async fn connect( ns_impl: &str, ns_properties: HashMap, storage_options: HashMap, read_consistency_interval: Option, session: Option>, - pushdown_operations: HashSet, + pushdown_operations: HashSet, ) -> Result { let mut builder = ConnectBuilder::new(ns_impl); for (key, value) in ns_properties.clone() { @@ -163,37 +185,23 @@ impl Database for LanceNamespaceDatabase { async fn create_table(&self, request: DbCreateTableRequest) -> Result> { let mut table_id = request.namespace_path.clone(); table_id.push(request.name.clone()); - let describe_request = DescribeTableRequest { - id: Some(table_id.clone()), - ..Default::default() - }; - - let describe_result = self.namespace.describe_table(describe_request).await; + let mut existing_table = None; match request.mode { - CreateTableMode::Create => { - if describe_result.is_ok() { - return Err(Error::TableAlreadyExists { - name: request.name.clone(), - }); - } - } + CreateTableMode::Create => {} CreateTableMode::Overwrite => { - if describe_result.is_ok() { - // Drop the existing table - must succeed - let drop_request = DropTableRequest { - id: Some(table_id.clone()), - ..Default::default() - }; - self.namespace - .drop_table(drop_request) - .await - .map_err(|e| Error::Runtime { - message: format!("Failed to drop existing table for overwrite: {}", e), - })?; - } + let describe_request = DescribeTableRequest { + id: Some(table_id.clone()), + ..Default::default() + }; + existing_table = self.namespace.describe_table(describe_request).await.ok(); } CreateTableMode::ExistOk(_) => { + let describe_request = DescribeTableRequest { + id: Some(table_id.clone()), + ..Default::default() + }; + let describe_result = self.namespace.describe_table(describe_request).await; if describe_result.is_ok() { let native_table = NativeTable::open_from_namespace( self.namespace.clone(), @@ -221,21 +229,55 @@ impl Database for LanceNamespaceDatabase { }; let (location, initial_storage_options, managed_versioning) = { - let response = self.namespace.declare_table(declare_request).await?; - let loc = response.location.ok_or_else(|| Error::Runtime { - message: "Table location is missing from declare_table response".to_string(), - })?; - // Use storage options from response, fall back to self.storage_options - let opts = response - .storage_options - .or_else(|| Some(self.storage_options.clone())) - .filter(|o| !o.is_empty()); - (loc, opts, response.managed_versioning) + if let Some(response) = existing_table { + let loc = response.location.ok_or_else(|| Error::Runtime { + message: "Table location is missing from describe_table response".to_string(), + })?; + let opts = response + .storage_options + .or_else(|| Some(self.storage_options.clone())) + .filter(|o| !o.is_empty()); + (loc, opts, response.managed_versioning) + } else { + let response = self + .namespace + .declare_table(declare_request) + .await + .map_err(|e| { + let err_str = e.to_string(); + if matches!(request.mode, CreateTableMode::Create) + && (err_str.contains("already exists") + || err_str.contains("TableAlreadyExists") + || err_str.contains("table already exists")) + { + Error::TableAlreadyExists { + name: request.name.clone(), + } + } else { + Error::Runtime { + message: format!("Failed to declare table: {}", e), + } + } + })?; + let loc = response.location.ok_or_else(|| Error::Runtime { + message: "Table location is missing from declare_table response".to_string(), + })?; + // Use storage options from response, fall back to self.storage_options + let opts = response + .storage_options + .or_else(|| Some(self.storage_options.clone())) + .filter(|o| !o.is_empty()); + (loc, opts, response.managed_versioning) + } }; // Build write params with storage options and commit handler let mut params = request.write_options.lance_write_params.unwrap_or_default(); + if matches!(request.mode, CreateTableMode::Overwrite) { + params.mode = WriteMode::Overwrite; + } + // Set up storage options if provided if let Some(storage_opts) = initial_storage_options { let store_params = params diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index b83a880be..73415e89b 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -47,7 +47,7 @@ use std::format; use std::path::Path; use std::sync::Arc; -use crate::connection::PushdownOperation; +use crate::connection::NamespaceClientPushdownOperation; use crate::data::scannable::{PeekedScannable, Scannable, estimate_write_partitions}; use crate::database::Database; @@ -1272,7 +1272,7 @@ pub struct NativeTable { pub(crate) namespace_client: Option>, // Operations to push down to the namespace server. // pub(crate) so query.rs can access the field for server-side query execution. - pub(crate) pushdown_operations: HashSet, + pub(crate) pushdown_operations: HashSet, } impl std::fmt::Debug for NativeTable { @@ -1359,7 +1359,7 @@ impl NativeTable { params: Option, read_consistency_interval: Option, namespace_client: Option>, - pushdown_operations: HashSet, + pushdown_operations: HashSet, managed_versioning: Option, ) -> Result { let params = params.unwrap_or_default(); @@ -1470,7 +1470,7 @@ impl NativeTable { write_store_wrapper: Option>, params: Option, read_consistency_interval: Option, - pushdown_operations: HashSet, + pushdown_operations: HashSet, session: Option>, ) -> Result { let mut params = params.unwrap_or_default(); @@ -1518,7 +1518,7 @@ impl NativeTable { let id = Self::build_id(&namespace, name); let stored_namespace_client = - if pushdown_operations.contains(&PushdownOperation::QueryTable) { + if pushdown_operations.contains(&NamespaceClientPushdownOperation::QueryTable) { Some(namespace_client) } else { None @@ -1588,7 +1588,7 @@ impl NativeTable { params: Option, read_consistency_interval: Option, namespace_client: Option>, - pushdown_operations: HashSet, + pushdown_operations: HashSet, ) -> Result { // Default params uses format v1. let params = params.unwrap_or(WriteParams { @@ -1635,7 +1635,7 @@ impl NativeTable { params: Option, read_consistency_interval: Option, namespace_client: Option>, - pushdown_operations: HashSet, + pushdown_operations: HashSet, ) -> Result { let data: Box = Box::new(RecordBatch::new_empty(schema)); Self::create( @@ -1685,7 +1685,7 @@ impl NativeTable { write_store_wrapper: Option>, params: Option, read_consistency_interval: Option, - pushdown_operations: HashSet, + pushdown_operations: HashSet, session: Option>, ) -> Result { // Build table_id from namespace + name for the storage options provider @@ -1738,7 +1738,7 @@ impl NativeTable { let id = Self::build_id(&namespace, name); let stored_namespace_client = - if pushdown_operations.contains(&PushdownOperation::QueryTable) { + if pushdown_operations.contains(&NamespaceClientPushdownOperation::QueryTable) { Some(namespace_client) } else { None diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs index ae6a8cd08..e7e66a901 100644 --- a/rust/lancedb/src/table/query.rs +++ b/rust/lancedb/src/table/query.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use super::NativeTable; -use crate::connection::PushdownOperation; +use crate::connection::NamespaceClientPushdownOperation; use crate::error::{Error, Result}; use crate::expr::expr_to_sql_string; use crate::query::{ @@ -44,7 +44,7 @@ pub async fn execute_query( // If QueryTable pushdown is enabled and namespace client is configured, use server-side query execution if table .pushdown_operations - .contains(&PushdownOperation::QueryTable) + .contains(&NamespaceClientPushdownOperation::QueryTable) && let Some(ref namespace_client) = table.namespace_client { return execute_namespace_query(table, namespace_client.clone(), query, options).await;