diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py
index bb6c65d56..e69161586 100644
--- a/python/python/lancedb/__init__.py
+++ b/python/python/lancedb/__init__.py
@@ -215,6 +215,85 @@ def connect(
     )
 
 
+WORKER_PROPERTY_PREFIX = "_lancedb_worker_"
+
+
+def _apply_worker_overrides(props: dict[str, str]) -> dict[str, str]:
+    """Apply worker property overrides.
+
+    Any key starting with ``_lancedb_worker_`` is extracted, the prefix
+    is stripped, and the resulting key-value pair is put back into the
+    map (overriding the existing value if present). The original
+    prefixed key is removed.
+    """
+    worker_keys = [k for k in props if k.startswith(WORKER_PROPERTY_PREFIX)]
+    if not worker_keys:
+        return props
+    result = dict(props)
+    for key in worker_keys:
+        value = result.pop(key)
+        real_key = key[len(WORKER_PROPERTY_PREFIX) :]
+        result[real_key] = value
+    return result
+
+
+def deserialize_conn(
+    data: str,
+    *,
+    for_worker: bool = False,
+) -> DBConnection:
+    """Reconstruct a DBConnection from a serialized string.
+
+    The string must have been produced by
+    :meth:`DBConnection.serialize`.
+
+    Parameters
+    ----------
+    data : str
+        String produced by ``serialize()``.
+    for_worker : bool, default False
+        When ``True``, any namespace client property whose key starts
+        with ``_lancedb_worker_`` has that prefix stripped and the
+        value overrides the corresponding property. For example,
+        ``_lancedb_worker_uri`` replaces ``uri``.
+
+    Returns
+    -------
+    DBConnection
+        A new connection matching the serialized state.
+    """
+    import json
+
+    parsed = json.loads(data)
+    connection_type = parsed.get("connection_type")
+
+    rci_secs = parsed.get("read_consistency_interval_seconds")
+    rci = timedelta(seconds=rci_secs) if rci_secs is not None else None
+    storage_options = parsed.get("storage_options")
+
+    if connection_type == "namespace":
+        props = dict(parsed.get("namespace_client_properties") or {})
+        if for_worker:
+            props = _apply_worker_overrides(props)
+        return connect_namespace(
+            namespace_client_impl=parsed["namespace_client_impl"],
+            namespace_client_properties=props,
+            read_consistency_interval=rci,
+            storage_options=storage_options,
+            namespace_client_pushdown_operations=parsed.get(
+                "namespace_client_pushdown_operations"
+            ),
+        )
+    elif connection_type == "local":
+        return LanceDBConnection(
+            parsed["uri"],
+            read_consistency_interval=rci,
+            storage_options=storage_options,
+        )
+    else:
+        raise ValueError(f"Unknown connection_type: {connection_type}")
+
+
 async def connect_async(
     uri: URI,
     *,
diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py
index 869f1481f..81b177e61 100644
--- a/python/python/lancedb/db.py
+++ b/python/python/lancedb/db.py
@@ -529,6 +529,19 @@ class DBConnection(EnforceOverrides):
             "namespace_client is not supported for this connection type"
         )
 
+    def serialize(self) -> str:
+        """Serialize this connection for reconstruction.
+
+        The returned string can be passed to :func:`lancedb.deserialize_conn`
+        to recreate an equivalent connection, e.g. in a remote worker.
+
+        Returns
+        -------
+        str
+            Serialized representation of this connection.
+        """
+        raise NotImplementedError("serialize is not supported for this connection type")
+
 
 class LanceDBConnection(DBConnection):
     """
@@ -581,6 +594,7 @@ class LanceDBConnection(DBConnection):
     ):
         if _inner is not None:
             self._conn = _inner
+            self._cached_namespace_client = None
             return
 
         if not isinstance(uri, Path):
@@ -628,6 +642,7 @@ class LanceDBConnection(DBConnection):
         # beyond _conn.
         self.storage_options = storage_options
         self._conn = AsyncConnection(LOOP.run(do_connect()))
+        self._cached_namespace_client: Optional[LanceNamespace] = None
 
     @property
     def read_consistency_interval(self) -> Optional[timedelta]:
@@ -652,6 +667,22 @@ class LanceDBConnection(DBConnection):
             val += ")"
         return val
 
+    @override
+    def serialize(self) -> str:
+        import json
+
+        rci = self.read_consistency_interval
+        return json.dumps(
+            {
+                "connection_type": "local",
+                "uri": self.uri,
+                "storage_options": self.storage_options,
+                "read_consistency_interval_seconds": (
+                    rci.total_seconds() if rci is not None else None
+                ),
+            }
+        )
+
     async def _async_get_table_names(self, start_after: Optional[str], limit: int):
         conn = AsyncConnection(await lancedb_connect(self.uri))
         return await conn.table_names(start_after=start_after, limit=limit)
@@ -687,10 +718,10 @@ class LanceDBConnection(DBConnection):
         """
         if namespace_path is None:
             namespace_path = []
-        return LOOP.run(
-            self._conn.list_namespaces(
-                namespace_path=namespace_path, page_token=page_token, limit=limit
-            )
+        return self._namespace_conn().list_namespaces(
+            namespace_path=namespace_path,
+            page_token=page_token,
+            limit=limit,
         )
 
     @override
@@ -700,27 +731,10 @@ class LanceDBConnection(DBConnection):
         mode: Optional[str] = None,
         properties: Optional[Dict[str, str]] = None,
     ) -> CreateNamespaceResponse:
-        """Create a new namespace.
-
-        Parameters
-        ----------
-        namespace_path: List[str]
-            The namespace identifier to create.
-        mode: str, optional
-            Creation mode - "create" (fail if exists), "exist_ok" (skip if exists),
-            or "overwrite" (replace if exists). Case insensitive.
-        properties: Dict[str, str], optional
-            Properties to set on the namespace.
-
-        Returns
-        -------
-        CreateNamespaceResponse
-            Response containing the properties of the created namespace.
-        """
-        return LOOP.run(
-            self._conn.create_namespace(
-                namespace_path=namespace_path, mode=mode, properties=properties
-            )
+        return self._namespace_conn().create_namespace(
+            namespace_path=namespace_path,
+            mode=mode,
+            properties=properties,
         )
 
     @override
@@ -730,46 +744,19 @@ class LanceDBConnection(DBConnection):
         mode: Optional[str] = None,
         behavior: Optional[str] = None,
     ) -> DropNamespaceResponse:
-        """Drop a namespace.
-
-        Parameters
-        ----------
-        namespace_path: List[str]
-            The namespace identifier to drop.
-        mode: str, optional
-            Whether to skip if not exists ("SKIP") or fail ("FAIL"). Case insensitive.
-        behavior: str, optional
-            Whether to restrict drop if not empty ("RESTRICT") or cascade ("CASCADE").
-            Case insensitive.
-
-        Returns
-        -------
-        DropNamespaceResponse
-            Response containing properties and transaction_id if applicable.
-        """
-        return LOOP.run(
-            self._conn.drop_namespace(
-                namespace_path=namespace_path, mode=mode, behavior=behavior
-            )
+        return self._namespace_conn().drop_namespace(
+            namespace_path=namespace_path,
+            mode=mode,
+            behavior=behavior,
         )
 
     @override
     def describe_namespace(
         self, namespace_path: List[str]
     ) -> DescribeNamespaceResponse:
-        """Describe a namespace.
-
-        Parameters
-        ----------
-        namespace_path: List[str]
-            The namespace identifier to describe.
-
-        Returns
-        -------
-        DescribeNamespaceResponse
-            Response containing the namespace properties.
-        """
-        return LOOP.run(self._conn.describe_namespace(namespace_path=namespace_path))
+        return self._namespace_conn().describe_namespace(
+            namespace_path=namespace_path,
+        )
 
     @override
     def list_tables(
@@ -798,6 +785,12 @@ class LanceDBConnection(DBConnection):
         """
         if namespace_path is None:
             namespace_path = []
+        if namespace_path:
+            return self._namespace_conn().list_tables(
+                namespace_path=namespace_path,
+                page_token=page_token,
+                limit=limit,
+            )
         return LOOP.run(
             self._conn.list_tables(
                 namespace_path=namespace_path, page_token=page_token, limit=limit
@@ -886,6 +879,22 @@ class LanceDBConnection(DBConnection):
             raise ValueError("mode must be either 'create' or 'overwrite'")
         validate_table_name(name)
 
+        if namespace_path:
+            return self._namespace_conn().create_table(
+                name,
+                data=data,
+                schema=schema,
+                mode=mode,
+                exist_ok=exist_ok,
+                on_bad_vectors=on_bad_vectors,
+                fill_value=fill_value,
+                embedding_functions=embedding_functions,
+                namespace_path=namespace_path,
+                storage_options=storage_options,
+                data_storage_version=data_storage_version,
+                enable_v2_manifest_paths=enable_v2_manifest_paths,
+            )
+
         tbl = LanceTable.create(
             self,
             name,
@@ -901,6 +910,19 @@ class LanceDBConnection(DBConnection):
         )
         return tbl
 
+    def _namespace_conn(self) -> DBConnection:
+        """Return a LanceNamespaceDBConnection backed by this connection's
+        directory namespace. Used to delegate child-namespace operations."""
+        from lancedb.namespace import LanceNamespaceDBConnection
+
+        return LanceNamespaceDBConnection(
+            self.namespace_client(),
+            read_consistency_interval=self.read_consistency_interval,
+            storage_options=self.storage_options,
+            namespace_client_impl=None,
+            namespace_client_properties=None,
+        )
+
     @override
     def open_table(
         self,
@@ -917,7 +939,8 @@ class LanceDBConnection(DBConnection):
         name: str
             The name of the table.
         namespace_path: List[str], optional
-            The namespace to open the table from.
+            The namespace to open the table from. When non-empty, the
+            table is resolved through the directory namespace client.
 
         Returns
         -------
@@ -936,6 +959,14 @@ class LanceDBConnection(DBConnection):
                 stacklevel=2,
             )
 
+        if namespace_path:
+            return self._namespace_conn().open_table(
+                name,
+                namespace_path=namespace_path,
+                storage_options=storage_options,
+                index_cache_size=index_cache_size,
+            )
+
         return LanceTable.open(
             self,
             name,
@@ -1020,6 +1051,9 @@ class LanceDBConnection(DBConnection):
         """
         if namespace_path is None:
             namespace_path = []
+        if namespace_path:
+            self._namespace_conn().drop_table(name, namespace_path=namespace_path)
+            return
         LOOP.run(
             self._conn.drop_table(
                 name, namespace_path=namespace_path, ignore_missing=ignore_missing
@@ -1071,14 +1105,17 @@ class LanceDBConnection(DBConnection):
         """Get the equivalent namespace client for this connection.
 
         Returns a DirectoryNamespace pointing to the same root with the
-        same storage options.
+        same storage options. The result is cached for the lifetime of
+        this connection.
 
         Returns
         -------
         LanceNamespace
             The namespace client for this connection.
         """
-        return LOOP.run(self._conn.namespace_client())
+        if self._cached_namespace_client is None:
+            self._cached_namespace_client = LOOP.run(self._conn.namespace_client())
+        return self._cached_namespace_client
 
     @deprecation.deprecated(
         deprecated_in="0.15.1",
diff --git a/python/python/lancedb/namespace.py b/python/python/lancedb/namespace.py
index 55df0c82b..0996c63fd 100644
--- a/python/python/lancedb/namespace.py
+++ b/python/python/lancedb/namespace.py
@@ -381,6 +381,8 @@ class LanceNamespaceDBConnection(DBConnection):
         storage_options: Optional[Dict[str, str]] = None,
         session: Optional[Session] = None,
         namespace_client_pushdown_operations: Optional[List[str]] = None,
+        namespace_client_impl: Optional[str] = None,
+        namespace_client_properties: Optional[Dict[str, str]] = None,
     ):
         """
         Initialize a namespace-based LanceDB connection.
@@ -406,12 +408,43 @@ class LanceNamespaceDBConnection(DBConnection):
               namespace.create_table() instead of using declare_table + local write.
 
             Default is None (no pushdown, all operations run locally).
+        namespace_client_impl : Optional[str]
+            The namespace implementation name used to create this connection.
+            Stored for serialization purposes.
+        namespace_client_properties : Optional[Dict[str, str]]
+            The namespace properties used to create this connection.
+            Stored for serialization purposes.
         """
         self._namespace_client = namespace_client
         self.read_consistency_interval = read_consistency_interval
         self.storage_options = storage_options or {}
         self.session = session
-        self._pushdown_operations = set(namespace_client_pushdown_operations or [])
+        self._namespace_client_pushdown_operations = set(
+            namespace_client_pushdown_operations or []
+        )
+        self._namespace_client_impl = namespace_client_impl
+        self._namespace_client_properties = namespace_client_properties
+
+    @override
+    def serialize(self) -> str:
+        import json
+
+        return json.dumps(
+            {
+                "connection_type": "namespace",
+                "namespace_client_impl": self._namespace_client_impl,
+                "namespace_client_properties": self._namespace_client_properties,
+                "namespace_client_pushdown_operations": sorted(
+                    self._namespace_client_pushdown_operations
+                ),
+                "storage_options": self.storage_options or None,
+                "read_consistency_interval_seconds": (
+                    self.read_consistency_interval.total_seconds()
+                    if self.read_consistency_interval is not None
+                    else None
+                ),
+            }
+        )
 
     @override
     def table_names(
@@ -467,7 +500,7 @@ class LanceNamespaceDBConnection(DBConnection):
 
         table_id = namespace_path + [name]
 
-        if "CreateTable" in self._pushdown_operations:
+        if "CreateTable" in self._namespace_client_pushdown_operations:
             return self._create_table_server_side(
                 name=name,
                 data=data,
@@ -549,7 +582,7 @@ class LanceNamespaceDBConnection(DBConnection):
             storage_options=merged_storage_options,
             location=location,
             namespace_client=self._namespace_client,
-            pushdown_operations=self._pushdown_operations,
+            pushdown_operations=self._namespace_client_pushdown_operations,
         )
 
         return tbl
@@ -580,10 +613,13 @@ class LanceNamespaceDBConnection(DBConnection):
             fill_value=fill_value,
         )
 
+        merged = dict(self.storage_options or {})
+        if storage_options:
+            merged.update(storage_options)
         request = CreateTableRequest(
             id=table_id,
             mode=_normalize_create_table_mode(mode),
-            properties=self.storage_options if self.storage_options else None,
+            properties=merged or None,
         )
 
         try:
@@ -887,7 +923,7 @@ class LanceNamespaceDBConnection(DBConnection):
             location=table_uri,
             namespace_client=namespace_client,
             managed_versioning=managed_versioning,
-            pushdown_operations=self._pushdown_operations,
+            pushdown_operations=self._namespace_client_pushdown_operations,
         )
 
     @override
@@ -951,7 +987,9 @@ class AsyncLanceNamespaceDBConnection:
         self.read_consistency_interval = read_consistency_interval
         self.storage_options = storage_options or {}
         self.session = session
-        self._pushdown_operations = set(namespace_client_pushdown_operations or [])
+        self._namespace_client_pushdown_operations = set(
+            namespace_client_pushdown_operations or []
+        )
 
     async def table_names(
         self,
@@ -1006,7 +1044,7 @@ class AsyncLanceNamespaceDBConnection:
 
         table_id = namespace_path + [name]
 
-        if "CreateTable" in self._pushdown_operations:
+        if "CreateTable" in self._namespace_client_pushdown_operations:
             return await self._create_table_server_side(
                 name=name,
                 data=data,
@@ -1086,7 +1124,7 @@ class AsyncLanceNamespaceDBConnection:
                 storage_options=merged_storage_options,
                 location=location,
                 namespace_client=self._namespace_client,
-                pushdown_operations=self._pushdown_operations,
+                pushdown_operations=self._namespace_client_pushdown_operations,
             )
 
         lance_table = await asyncio.to_thread(_create_table)
@@ -1120,10 +1158,13 @@ class AsyncLanceNamespaceDBConnection:
             fill_value=fill_value,
         )
 
+        merged = dict(self.storage_options or {})
+        if storage_options:
+            merged.update(storage_options)
         request = CreateTableRequest(
             id=table_id,
             mode=_normalize_create_table_mode(mode),
-            properties=self.storage_options if self.storage_options else None,
+            properties=merged or None,
         )
 
         self._namespace_client.create_table(request, arrow_ipc_bytes)
@@ -1190,7 +1231,7 @@ class AsyncLanceNamespaceDBConnection:
                 location=response.location,
                 namespace_client=self._namespace_client,
                 managed_versioning=managed_versioning,
-                pushdown_operations=self._pushdown_operations,
+                pushdown_operations=self._namespace_client_pushdown_operations,
             )
 
         lance_table = await asyncio.to_thread(_open_table)
@@ -1472,6 +1513,8 @@ def connect_namespace(
         storage_options=storage_options,
         session=session,
         namespace_client_pushdown_operations=namespace_client_pushdown_operations,
+        namespace_client_impl=namespace_client_impl,
+        namespace_client_properties=namespace_client_properties,
     )
 
 
diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py
index ebb42b61e..5f3fc60ec 100644
--- a/python/python/tests/test_db.py
+++ b/python/python/tests/test_db.py
@@ -897,42 +897,22 @@ def test_bypass_vector_index_sync(tmp_db: lancedb.DBConnection):
 
 
 def test_local_namespace_operations(tmp_path):
-    """Test that local mode namespace operations behave as expected."""
-    # Create a local database connection
+    """Test that local mode namespace operations work via directory namespace."""
     db = lancedb.connect(tmp_path)
 
-    # Test list_namespaces returns empty list for root namespace
-    namespaces = db.list_namespaces().namespaces
-    assert namespaces == []
+    # Root namespace starts empty
+    assert db.list_namespaces().namespaces == []
 
-    # Test list_namespaces with non-empty namespace raises NotImplementedError
-    with pytest.raises(
-        NotImplementedError,
-        match="Namespace operations are not supported for listing database",
-    ):
-        db.list_namespaces(namespace_path=["test"])
+    # Create and list child namespace
+    db.create_namespace(["child"])
+    assert "child" in db.list_namespaces().namespaces
 
+    # List namespaces under child
+    assert db.list_namespaces(namespace_path=["child"]).namespaces == []
 
-def test_local_create_namespace_not_supported(tmp_path):
-    """Test that create_namespace is not supported in local mode."""
-    db = lancedb.connect(tmp_path)
-
-    with pytest.raises(
-        NotImplementedError,
-        match="Namespace operations are not supported for listing database",
-    ):
-        db.create_namespace(["test_namespace"])
-
-
-def test_local_drop_namespace_not_supported(tmp_path):
-    """Test that drop_namespace is not supported in local mode."""
-    db = lancedb.connect(tmp_path)
-
-    with pytest.raises(
-        NotImplementedError,
-        match="Namespace operations are not supported for listing database",
-    ):
-        db.drop_namespace(["test_namespace"])
+    # Drop namespace
+    db.drop_namespace(["child"])
+    assert db.list_namespaces().namespaces == []
 
 
 def test_clone_table_latest_version(tmp_path):
diff --git a/python/python/tests/test_namespace.py b/python/python/tests/test_namespace.py
index 1f326e29a..bbf7e4c6f 100644
--- a/python/python/tests/test_namespace.py
+++ b/python/python/tests/test_namespace.py
@@ -681,7 +681,7 @@ class TestPushdownOperations:
             {"root": self.temp_dir},
             namespace_client_pushdown_operations=["QueryTable"],
         )
-        assert "QueryTable" in db._pushdown_operations
+        assert "QueryTable" in db._namespace_client_pushdown_operations
 
     def test_create_table_pushdown_stored(self):
         """Test that CreateTable pushdown is stored on sync connection."""
@@ -690,7 +690,7 @@ class TestPushdownOperations:
             {"root": self.temp_dir},
             namespace_client_pushdown_operations=["CreateTable"],
         )
-        assert "CreateTable" in db._pushdown_operations
+        assert "CreateTable" in db._namespace_client_pushdown_operations
 
     def test_both_pushdowns_stored(self):
         """Test that both pushdown operations can be set together."""
@@ -699,13 +699,13 @@ class TestPushdownOperations:
             {"root": self.temp_dir},
             namespace_client_pushdown_operations=["QueryTable", "CreateTable"],
         )
-        assert "QueryTable" in db._pushdown_operations
-        assert "CreateTable" in db._pushdown_operations
+        assert "QueryTable" in db._namespace_client_pushdown_operations
+        assert "CreateTable" in db._namespace_client_pushdown_operations
 
     def test_pushdown_defaults_to_empty(self):
         """Test that pushdown operations default to empty."""
         db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
-        assert len(db._pushdown_operations) == 0
+        assert len(db._namespace_client_pushdown_operations) == 0
 
 
 @pytest.mark.asyncio
@@ -727,7 +727,7 @@ class TestAsyncPushdownOperations:
             {"root": self.temp_dir},
             namespace_client_pushdown_operations=["QueryTable"],
         )
-        assert "QueryTable" in db._pushdown_operations
+        assert "QueryTable" in db._namespace_client_pushdown_operations
 
     async def test_async_create_table_pushdown_stored(self):
         """Test that CreateTable pushdown is stored on async connection."""
@@ -736,9 +736,9 @@ class TestAsyncPushdownOperations:
             {"root": self.temp_dir},
             namespace_client_pushdown_operations=["CreateTable"],
         )
-        assert "CreateTable" in db._pushdown_operations
+        assert "CreateTable" in db._namespace_client_pushdown_operations
 
     async def test_async_pushdown_defaults_to_empty(self):
         """Test that pushdown operations default to empty on async connection."""
         db = lancedb.connect_namespace_async("dir", {"root": self.temp_dir})
-        assert len(db._pushdown_operations) == 0
+        assert len(db._namespace_client_pushdown_operations) == 0