mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-13 18:10:41 +00:00
fix(python): use namespace-backed rust connection for namespace tables (#3286)
So far, I have been using a hacky approach that creates and opens namespace-backed table, by getting its location and use a temporary lancedb connection to create or open it. This was working for features like credentials vending but is no longer fully working for the managed versioning feature, recently geneva tests have been failing here and there and various patches are not addressing the root cause. This PR fully fixes this and implements proper rust binding for it. Specifically: - build a real Rust namespace-backed connection from the Python namespace client - route namespace table create/open through that connection instead of resolved-location temp connections - keep namespace client naming consistent in the Rust bridge and preserve federated namespace + DuckDB behavior
This commit is contained in:
@@ -10,7 +10,6 @@ through a namespace abstraction.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
@@ -25,7 +24,24 @@ if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.db import DBConnection, LanceDBConnection
|
||||
from lance_namespace_urllib3_client.models.json_arrow_data_type import JsonArrowDataType
|
||||
from lance_namespace_urllib3_client.models.json_arrow_field import JsonArrowField
|
||||
from lance_namespace_urllib3_client.models.json_arrow_schema import JsonArrowSchema
|
||||
from lance_namespace_urllib3_client.models.query_table_request import QueryTableRequest
|
||||
from lance_namespace_urllib3_client.models.query_table_request_columns import (
|
||||
QueryTableRequestColumns,
|
||||
)
|
||||
from lance_namespace_urllib3_client.models.query_table_request_full_text_query import (
|
||||
QueryTableRequestFullTextQuery,
|
||||
)
|
||||
from lance_namespace_urllib3_client.models.query_table_request_vector import (
|
||||
QueryTableRequestVector,
|
||||
)
|
||||
from lance_namespace_urllib3_client.models.string_fts_query import StringFtsQuery
|
||||
from lance_namespace.errors import TableNotFoundError
|
||||
from lancedb._lancedb import connect_namespace_client as _connect_namespace_client
|
||||
from lancedb.background_loop import LOOP
|
||||
from lancedb.db import AsyncConnection, DBConnection
|
||||
from lancedb.namespace_utils import (
|
||||
_normalize_create_namespace_mode,
|
||||
_normalize_drop_namespace_mode,
|
||||
@@ -40,14 +56,11 @@ from lance_namespace import (
|
||||
ListNamespacesResponse,
|
||||
ListTablesResponse,
|
||||
ListTablesRequest,
|
||||
DescribeTableRequest,
|
||||
DescribeNamespaceRequest,
|
||||
DropTableRequest,
|
||||
ListNamespacesRequest,
|
||||
CreateNamespaceRequest,
|
||||
DropNamespaceRequest,
|
||||
DeclareTableRequest,
|
||||
CreateTableRequest,
|
||||
)
|
||||
from lancedb.table import AsyncTable, LanceTable, Table
|
||||
from lancedb.util import validate_table_name
|
||||
@@ -56,21 +69,6 @@ from lancedb.pydantic import LanceModel
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig
|
||||
from ._lancedb import Session
|
||||
|
||||
from lance_namespace_urllib3_client.models.json_arrow_schema import JsonArrowSchema
|
||||
from lance_namespace_urllib3_client.models.json_arrow_field import JsonArrowField
|
||||
from lance_namespace_urllib3_client.models.json_arrow_data_type import JsonArrowDataType
|
||||
from lance_namespace_urllib3_client.models.query_table_request import QueryTableRequest
|
||||
from lance_namespace_urllib3_client.models.query_table_request_vector import (
|
||||
QueryTableRequestVector,
|
||||
)
|
||||
from lance_namespace_urllib3_client.models.query_table_request_columns import (
|
||||
QueryTableRequestColumns,
|
||||
)
|
||||
from lance_namespace_urllib3_client.models.query_table_request_full_text_query import (
|
||||
QueryTableRequestFullTextQuery,
|
||||
)
|
||||
from lance_namespace_urllib3_client.models.string_fts_query import StringFtsQuery
|
||||
|
||||
|
||||
def _query_to_namespace_request(
|
||||
table_id: List[str],
|
||||
@@ -424,6 +422,23 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
)
|
||||
self._namespace_client_impl = namespace_client_impl
|
||||
self._namespace_client_properties = namespace_client_properties
|
||||
self._inner = AsyncConnection(
|
||||
_connect_namespace_client(
|
||||
namespace_client,
|
||||
read_consistency_interval=(
|
||||
read_consistency_interval.total_seconds()
|
||||
if read_consistency_interval is not None
|
||||
else None
|
||||
),
|
||||
storage_options=self.storage_options or None,
|
||||
session=session,
|
||||
namespace_client_pushdown_operations=(
|
||||
list(self._namespace_client_pushdown_operations)
|
||||
),
|
||||
namespace_client_impl=namespace_client_impl,
|
||||
namespace_client_properties=namespace_client_properties,
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def serialize(self) -> str:
|
||||
@@ -497,13 +512,10 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
if mode.lower() not in ["create", "overwrite"]:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
validate_table_name(name)
|
||||
|
||||
table_id = namespace_path + [name]
|
||||
|
||||
if "CreateTable" in self._namespace_client_pushdown_operations:
|
||||
return self._create_table_server_side(
|
||||
name=name,
|
||||
data=data,
|
||||
async_table = LOOP.run(
|
||||
self._inner.create_table(
|
||||
name,
|
||||
data,
|
||||
schema=schema,
|
||||
mode=mode,
|
||||
exist_ok=exist_ok,
|
||||
@@ -513,130 +525,15 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
namespace_path=namespace_path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
|
||||
# Local create path: declare_table + local write
|
||||
# Step 1: Get the table location and storage options from namespace
|
||||
# In overwrite mode, if table exists, use describe_table to get
|
||||
# existing location. Otherwise, call create_empty_table to reserve
|
||||
# a new location
|
||||
location = None
|
||||
namespace_storage_options = None
|
||||
if mode.lower() == "overwrite":
|
||||
# Try to describe the table first to see if it exists
|
||||
try:
|
||||
describe_request = DescribeTableRequest(id=table_id)
|
||||
describe_response = self._namespace_client.describe_table(
|
||||
describe_request
|
||||
)
|
||||
location = describe_response.location
|
||||
namespace_storage_options = describe_response.storage_options
|
||||
except Exception:
|
||||
# Table doesn't exist, will create a new one below
|
||||
pass
|
||||
|
||||
if location is None:
|
||||
# Table doesn't exist or mode is "create", reserve a new location
|
||||
declare_request = DeclareTableRequest(
|
||||
id=table_id,
|
||||
location=None,
|
||||
properties=self.storage_options if self.storage_options else None,
|
||||
)
|
||||
declare_response = self._namespace_client.declare_table(declare_request)
|
||||
|
||||
if not declare_response.location:
|
||||
raise ValueError(
|
||||
"Table location is missing from declare_table response"
|
||||
)
|
||||
|
||||
location = declare_response.location
|
||||
namespace_storage_options = declare_response.storage_options
|
||||
|
||||
# Merge storage options: self.storage_options < user options < namespace options
|
||||
merged_storage_options = dict(self.storage_options)
|
||||
if storage_options:
|
||||
merged_storage_options.update(storage_options)
|
||||
if namespace_storage_options:
|
||||
merged_storage_options.update(namespace_storage_options)
|
||||
|
||||
# Step 2: Create table using LanceTable.create with the location
|
||||
# We need a temporary connection for the LanceTable.create method
|
||||
temp_conn = LanceDBConnection(
|
||||
location, # Use the actual location as the connection URI
|
||||
read_consistency_interval=self.read_consistency_interval,
|
||||
storage_options=merged_storage_options,
|
||||
session=self.session,
|
||||
)
|
||||
|
||||
# Note: storage_options_provider is auto-created in Rust from namespace_client
|
||||
tbl = LanceTable.create(
|
||||
temp_conn,
|
||||
return LanceTable(
|
||||
self,
|
||||
name,
|
||||
data,
|
||||
schema,
|
||||
mode=mode,
|
||||
exist_ok=exist_ok,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
embedding_functions=embedding_functions,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=merged_storage_options,
|
||||
location=location,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
)
|
||||
|
||||
return tbl
|
||||
|
||||
def _create_table_server_side(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA],
|
||||
schema: Optional[Union[pa.Schema, LanceModel]],
|
||||
mode: str,
|
||||
exist_ok: bool,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]],
|
||||
namespace_path: Optional[List[str]],
|
||||
storage_options: Optional[Dict[str, str]],
|
||||
) -> Table:
|
||||
"""Create a table using server-side namespace.create_table()."""
|
||||
if namespace_path is None:
|
||||
namespace_path = []
|
||||
table_id = namespace_path + [name]
|
||||
|
||||
arrow_ipc_bytes = _data_to_arrow_ipc(
|
||||
data=data,
|
||||
schema=schema,
|
||||
embedding_functions=embedding_functions,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
merged = dict(self.storage_options or {})
|
||||
if storage_options:
|
||||
merged.update(storage_options)
|
||||
request = CreateTableRequest(
|
||||
id=table_id,
|
||||
mode=_normalize_create_table_mode(mode),
|
||||
properties=merged or None,
|
||||
)
|
||||
|
||||
try:
|
||||
self._namespace_client.create_table(request, arrow_ipc_bytes)
|
||||
except Exception as e:
|
||||
if exist_ok and "already exists" in str(e).lower():
|
||||
return self.open_table(
|
||||
name,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
raise
|
||||
|
||||
return self.open_table(
|
||||
name,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=storage_options,
|
||||
_async=async_table,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -650,30 +547,28 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
) -> Table:
|
||||
if namespace_path is None:
|
||||
namespace_path = []
|
||||
table_id = namespace_path + [name]
|
||||
request = DescribeTableRequest(id=table_id)
|
||||
response = self._namespace_client.describe_table(request)
|
||||
try:
|
||||
async_table = LOOP.run(
|
||||
self._inner.open_table(
|
||||
name,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "Table not found" in str(e):
|
||||
table_id = namespace_path + [name]
|
||||
raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}")
|
||||
raise
|
||||
|
||||
# Merge storage options: self.storage_options < user options < namespace options
|
||||
merged_storage_options = dict(self.storage_options)
|
||||
if storage_options:
|
||||
merged_storage_options.update(storage_options)
|
||||
if response.storage_options:
|
||||
merged_storage_options.update(response.storage_options)
|
||||
|
||||
# Pass managed_versioning to avoid redundant describe_table call in Rust.
|
||||
# Convert None to False since we already have the answer from describe_table.
|
||||
managed_versioning = response.managed_versioning is True
|
||||
|
||||
# Note: storage_options_provider is auto-created in Rust from namespace_client
|
||||
return self._lance_table_from_uri(
|
||||
return LanceTable(
|
||||
self,
|
||||
name,
|
||||
response.location,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=merged_storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
namespace_client=self._namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
_async=async_table,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -897,33 +792,34 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
namespace_client: Optional[Any] = None,
|
||||
managed_versioning: Optional[bool] = None,
|
||||
) -> LanceTable:
|
||||
# Open a table directly from a URI using the location parameter
|
||||
# Note: storage_options should already be merged by the caller
|
||||
# Note: storage_options_provider is auto-created in Rust from namespace_client
|
||||
# Open a table directly from the namespace-resolved physical location.
|
||||
#
|
||||
# Open the table through the Rust namespace-backed connection. The Rust
|
||||
# layer keeps the logical namespace path and namespace client intact.
|
||||
if namespace_path is None:
|
||||
namespace_path = []
|
||||
temp_conn = LanceDBConnection(
|
||||
table_uri, # Use the table location as the connection URI
|
||||
read_consistency_interval=self.read_consistency_interval,
|
||||
storage_options=storage_options if storage_options is not None else {},
|
||||
session=self.session,
|
||||
|
||||
async_table = LOOP.run(
|
||||
self._inner.open_table(
|
||||
name,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
location=None,
|
||||
namespace_client=namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
)
|
||||
)
|
||||
|
||||
# Open the table using the temporary connection with the location parameter
|
||||
# Pass namespace_client to enable managed versioning support and auto-create
|
||||
# storage options provider
|
||||
# Pass managed_versioning to avoid redundant describe_table call
|
||||
# Pass pushdown_operations if configured on this connection
|
||||
return LanceTable.open(
|
||||
temp_conn,
|
||||
return LanceTable(
|
||||
self,
|
||||
name,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
location=table_uri,
|
||||
namespace_client=namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
_async=async_table,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -990,6 +886,23 @@ class AsyncLanceNamespaceDBConnection:
|
||||
self._namespace_client_pushdown_operations = set(
|
||||
namespace_client_pushdown_operations or []
|
||||
)
|
||||
self._inner = AsyncConnection(
|
||||
_connect_namespace_client(
|
||||
namespace_client,
|
||||
read_consistency_interval=(
|
||||
read_consistency_interval.total_seconds()
|
||||
if read_consistency_interval is not None
|
||||
else None
|
||||
),
|
||||
storage_options=self.storage_options or None,
|
||||
session=session,
|
||||
namespace_client_pushdown_operations=(
|
||||
list(self._namespace_client_pushdown_operations)
|
||||
),
|
||||
namespace_client_impl=None,
|
||||
namespace_client_properties=None,
|
||||
)
|
||||
)
|
||||
|
||||
async def table_names(
|
||||
self,
|
||||
@@ -1041,148 +954,16 @@ class AsyncLanceNamespaceDBConnection:
|
||||
if mode.lower() not in ["create", "overwrite"]:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
validate_table_name(name)
|
||||
|
||||
table_id = namespace_path + [name]
|
||||
|
||||
if "CreateTable" in self._namespace_client_pushdown_operations:
|
||||
return await self._create_table_server_side(
|
||||
name=name,
|
||||
data=data,
|
||||
schema=schema,
|
||||
mode=mode,
|
||||
exist_ok=exist_ok,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
embedding_functions=embedding_functions,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
|
||||
# Local create path: declare_table + local write
|
||||
# Step 1: Get the table location and storage options from namespace
|
||||
location = None
|
||||
namespace_storage_options = None
|
||||
if mode.lower() == "overwrite":
|
||||
# Try to describe the table first to see if it exists
|
||||
try:
|
||||
describe_request = DescribeTableRequest(id=table_id)
|
||||
describe_response = self._namespace_client.describe_table(
|
||||
describe_request
|
||||
)
|
||||
location = describe_response.location
|
||||
namespace_storage_options = describe_response.storage_options
|
||||
except Exception:
|
||||
# Table doesn't exist, will create a new one below
|
||||
pass
|
||||
|
||||
if location is None:
|
||||
# Table doesn't exist or mode is "create", reserve a new location
|
||||
declare_request = DeclareTableRequest(
|
||||
id=table_id,
|
||||
location=None,
|
||||
properties=self.storage_options if self.storage_options else None,
|
||||
)
|
||||
declare_response = self._namespace_client.declare_table(declare_request)
|
||||
|
||||
if not declare_response.location:
|
||||
raise ValueError(
|
||||
"Table location is missing from declare_table response"
|
||||
)
|
||||
|
||||
location = declare_response.location
|
||||
namespace_storage_options = declare_response.storage_options
|
||||
|
||||
# Merge storage options: self.storage_options < user options < namespace options
|
||||
merged_storage_options = dict(self.storage_options)
|
||||
if storage_options:
|
||||
merged_storage_options.update(storage_options)
|
||||
if namespace_storage_options:
|
||||
merged_storage_options.update(namespace_storage_options)
|
||||
|
||||
# Step 2: Create table using LanceTable.create with the location
|
||||
# Run the sync operation in a thread
|
||||
def _create_table():
|
||||
temp_conn = LanceDBConnection(
|
||||
location,
|
||||
read_consistency_interval=self.read_consistency_interval,
|
||||
storage_options=merged_storage_options,
|
||||
session=self.session,
|
||||
)
|
||||
|
||||
# storage_options_provider is auto-created in Rust from namespace_client
|
||||
return LanceTable.create(
|
||||
temp_conn,
|
||||
name,
|
||||
data,
|
||||
schema,
|
||||
mode=mode,
|
||||
exist_ok=exist_ok,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
embedding_functions=embedding_functions,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=merged_storage_options,
|
||||
location=location,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
)
|
||||
|
||||
lance_table = await asyncio.to_thread(_create_table)
|
||||
# Get the underlying async table from LanceTable
|
||||
return lance_table._table
|
||||
|
||||
async def _create_table_server_side(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA],
|
||||
schema: Optional[Union[pa.Schema, LanceModel]],
|
||||
mode: str,
|
||||
exist_ok: bool,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]],
|
||||
namespace_path: Optional[List[str]],
|
||||
storage_options: Optional[Dict[str, str]],
|
||||
) -> AsyncTable:
|
||||
"""Create a table using server-side namespace.create_table()."""
|
||||
if namespace_path is None:
|
||||
namespace_path = []
|
||||
table_id = namespace_path + [name]
|
||||
|
||||
def _prepare_and_create():
|
||||
arrow_ipc_bytes = _data_to_arrow_ipc(
|
||||
data=data,
|
||||
schema=schema,
|
||||
embedding_functions=embedding_functions,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
merged = dict(self.storage_options or {})
|
||||
if storage_options:
|
||||
merged.update(storage_options)
|
||||
request = CreateTableRequest(
|
||||
id=table_id,
|
||||
mode=_normalize_create_table_mode(mode),
|
||||
properties=merged or None,
|
||||
)
|
||||
|
||||
self._namespace_client.create_table(request, arrow_ipc_bytes)
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_prepare_and_create)
|
||||
except Exception as e:
|
||||
if exist_ok and "already exists" in str(e).lower():
|
||||
return await self.open_table(
|
||||
name,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
raise
|
||||
|
||||
return await self.open_table(
|
||||
return await self._inner.create_table(
|
||||
name,
|
||||
data,
|
||||
schema=schema,
|
||||
mode=mode,
|
||||
exist_ok=exist_ok,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
namespace_path=namespace_path,
|
||||
embedding_functions=embedding_functions,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
|
||||
@@ -1197,45 +978,18 @@ class AsyncLanceNamespaceDBConnection:
|
||||
"""Open an existing table from the namespace."""
|
||||
if namespace_path is None:
|
||||
namespace_path = []
|
||||
table_id = namespace_path + [name]
|
||||
request = DescribeTableRequest(id=table_id)
|
||||
response = self._namespace_client.describe_table(request)
|
||||
|
||||
# Merge storage options: self.storage_options < user options < namespace options
|
||||
merged_storage_options = dict(self.storage_options)
|
||||
if storage_options:
|
||||
merged_storage_options.update(storage_options)
|
||||
if response.storage_options:
|
||||
merged_storage_options.update(response.storage_options)
|
||||
|
||||
# Capture managed_versioning from describe response.
|
||||
# Convert None to False since we already have the answer from describe_table.
|
||||
managed_versioning = response.managed_versioning is True
|
||||
|
||||
# Open table in a thread
|
||||
# Note: storage_options_provider is auto-created in Rust from namespace_client
|
||||
def _open_table():
|
||||
temp_conn = LanceDBConnection(
|
||||
response.location,
|
||||
read_consistency_interval=self.read_consistency_interval,
|
||||
storage_options=merged_storage_options,
|
||||
session=self.session,
|
||||
)
|
||||
|
||||
return LanceTable.open(
|
||||
temp_conn,
|
||||
try:
|
||||
return await self._inner.open_table(
|
||||
name,
|
||||
namespace_path=namespace_path,
|
||||
storage_options=merged_storage_options,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
location=response.location,
|
||||
namespace_client=self._namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
)
|
||||
|
||||
lance_table = await asyncio.to_thread(_open_table)
|
||||
return lance_table._table
|
||||
except RuntimeError as e:
|
||||
if "Table not found" in str(e):
|
||||
table_id = namespace_path + [name]
|
||||
raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}")
|
||||
raise
|
||||
|
||||
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
|
||||
"""Drop a table from the namespace."""
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
connection::NamespaceClientPushdownOperation,
|
||||
database::namespace::LanceNamespaceDatabase,
|
||||
database::{CreateTableMode, Database, ReadConsistency},
|
||||
};
|
||||
use pyo3::{
|
||||
@@ -39,6 +45,29 @@ impl Connection {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_namespace_client_pushdown_operations(
|
||||
operations: Option<Vec<String>>,
|
||||
) -> PyResult<HashSet<NamespaceClientPushdownOperation>> {
|
||||
let mut parsed = HashSet::new();
|
||||
for operation in operations.unwrap_or_default() {
|
||||
match operation.as_str() {
|
||||
"QueryTable" => {
|
||||
parsed.insert(NamespaceClientPushdownOperation::QueryTable);
|
||||
}
|
||||
"CreateTable" => {
|
||||
parsed.insert(NamespaceClientPushdownOperation::CreateTable);
|
||||
}
|
||||
_ => {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"Invalid pushdown operation: {}",
|
||||
operation
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(parsed)
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
fn parse_create_mode_str(mode: &str) -> PyResult<CreateTableMode> {
|
||||
match mode {
|
||||
@@ -538,6 +567,52 @@ pub fn connect(
|
||||
})
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (
|
||||
namespace_client,
|
||||
read_consistency_interval=None,
|
||||
storage_options=None,
|
||||
session=None,
|
||||
namespace_client_pushdown_operations=None,
|
||||
namespace_client_impl=None,
|
||||
namespace_client_properties=None,
|
||||
))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn connect_namespace_client(
|
||||
py: Python<'_>,
|
||||
namespace_client: Py<PyAny>,
|
||||
read_consistency_interval: Option<f64>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
session: Option<crate::session::Session>,
|
||||
namespace_client_pushdown_operations: Option<Vec<String>>,
|
||||
namespace_client_impl: Option<String>,
|
||||
namespace_client_properties: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Connection> {
|
||||
let namespace_client = extract_namespace_arc(py, namespace_client)?;
|
||||
let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64);
|
||||
let namespace_client_pushdown_operations =
|
||||
parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?;
|
||||
let ns_impl = namespace_client_impl.unwrap_or_else(|| "python".to_string());
|
||||
let ns_properties = namespace_client_properties.unwrap_or_default();
|
||||
let storage_options = storage_options.unwrap_or_default();
|
||||
let session = session.map(|s| s.inner.clone());
|
||||
|
||||
let database = LanceNamespaceDatabase::from_namespace_client(
|
||||
namespace_client,
|
||||
ns_impl,
|
||||
ns_properties,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
session,
|
||||
namespace_client_pushdown_operations,
|
||||
);
|
||||
|
||||
Ok(Connection::new(LanceConnection::new(
|
||||
Arc::new(database),
|
||||
Arc::new(lancedb::embeddings::MemoryRegistry::new()),
|
||||
)))
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
pub struct PyClientConfig {
|
||||
user_agent: String,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use arrow::RecordBatchStream;
|
||||
use connection::{Connection, connect};
|
||||
use connection::{Connection, connect, connect_namespace_client};
|
||||
use env_logger::Env;
|
||||
use expr::{PyExpr, expr_col, expr_func, expr_lit};
|
||||
use index::IndexConfig;
|
||||
@@ -58,6 +58,7 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyPermutationReader>()?;
|
||||
m.add_class::<PyExpr>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(connect_namespace_client, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
|
||||
|
||||
Reference in New Issue
Block a user