Files
lancedb/python/python/lancedb/remote/db.py
Colin Patrick McCabe 2f6d525802 fix: support exist_ok in RemoteDBConnection.create_table (#2901)
RemoteDBConnection should support passing exist_ok to create_table, just
like LanceDBConnection (the non-remote form) does. It can support this
by passing 'exist_ok' as the mode parameter.
2026-01-07 12:29:45 -08:00

572 lines
18 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from datetime import timedelta
import logging
from concurrent.futures import ThreadPoolExecutor
import sys
from typing import Any, Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse
import warnings
if sys.version_info >= (3, 12):
from typing import override
else:
from overrides import override
# Remove this import to fix circular dependency
# from lancedb import connect_async
from lancedb.remote import ClientConfig
import pyarrow as pa
from ..common import DATA
from ..db import DBConnection, LOOP
from ..embeddings import EmbeddingFunctionConfig
from lance_namespace import (
CreateNamespaceResponse,
DescribeNamespaceResponse,
DropNamespaceResponse,
ListNamespacesResponse,
ListTablesResponse,
)
from ..pydantic import LanceModel
from ..table import Table
from ..util import validate_table_name
class RemoteDBConnection(DBConnection):
"""A connection to a remote LanceDB database."""
def __init__(
self,
db_url: str,
api_key: str,
region: str,
host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None,
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
connection_timeout: Optional[float] = None,
read_timeout: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
):
"""Connect to a remote LanceDB database."""
if isinstance(client_config, dict):
client_config = ClientConfig(**client_config)
elif client_config is None:
client_config = ClientConfig()
# These are legacy options from the old Python-based client. We keep them
# here for backwards compatibility, but will remove them in a future release.
if request_thread_pool is not None:
warnings.warn(
"request_thread_pool is no longer used and will be removed in "
"a future release.",
DeprecationWarning,
)
if connection_timeout is not None:
warnings.warn(
"connection_timeout is deprecated and will be removed in a future "
"release. Please use client_config.timeout_config.connect_timeout "
"instead.",
DeprecationWarning,
)
client_config.timeout_config.connect_timeout = timedelta(
seconds=connection_timeout
)
if read_timeout is not None:
warnings.warn(
"read_timeout is deprecated and will be removed in a future release. "
"Please use client_config.timeout_config.read_timeout instead.",
DeprecationWarning,
)
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
parsed = urlparse(db_url)
if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self.db_name = parsed.netloc
self.client_config = client_config
# Import connect_async here to avoid circular import
from lancedb import connect_async
self._conn = LOOP.run(
connect_async(
db_url,
api_key=api_key,
region=region,
host_override=host_override,
client_config=client_config,
storage_options=storage_options,
)
)
def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})"
@override
def list_namespaces(
self,
namespace: Optional[List[str]] = None,
page_token: Optional[str] = None,
limit: Optional[int] = None,
) -> ListNamespacesResponse:
"""List immediate child namespace names in the given namespace.
Parameters
----------
namespace: List[str], optional
The parent namespace to list namespaces in.
None or empty list represents root namespace.
page_token: str, optional
Token for pagination. Use the token from a previous response
to get the next page of results.
limit: int, optional
The maximum number of results to return.
Returns
-------
ListNamespacesResponse
Response containing namespace names and optional page_token for pagination.
"""
if namespace is None:
namespace = []
return LOOP.run(
self._conn.list_namespaces(
namespace=namespace, page_token=page_token, limit=limit
)
)
@override
def create_namespace(
self,
namespace: List[str],
mode: Optional[str] = None,
properties: Optional[Dict[str, str]] = None,
) -> CreateNamespaceResponse:
"""Create a new namespace.
Parameters
----------
namespace: List[str]
The namespace identifier to create.
mode: str, optional
Creation mode - "create" (fail if exists), "exist_ok" (skip if exists),
or "overwrite" (replace if exists). Case insensitive.
properties: Dict[str, str], optional
Properties to set on the namespace.
Returns
-------
CreateNamespaceResponse
Response containing the properties of the created namespace.
"""
return LOOP.run(
self._conn.create_namespace(
namespace=namespace, mode=mode, properties=properties
)
)
@override
def drop_namespace(
self,
namespace: List[str],
mode: Optional[str] = None,
behavior: Optional[str] = None,
) -> DropNamespaceResponse:
"""Drop a namespace.
Parameters
----------
namespace: List[str]
The namespace identifier to drop.
mode: str, optional
Whether to skip if not exists ("SKIP") or fail ("FAIL"). Case insensitive.
behavior: str, optional
Whether to restrict drop if not empty ("RESTRICT") or cascade ("CASCADE").
Case insensitive.
Returns
-------
DropNamespaceResponse
Response containing properties and transaction_id if applicable.
"""
return LOOP.run(
self._conn.drop_namespace(namespace=namespace, mode=mode, behavior=behavior)
)
@override
def describe_namespace(self, namespace: List[str]) -> DescribeNamespaceResponse:
"""Describe a namespace.
Parameters
----------
namespace: List[str]
The namespace identifier to describe.
Returns
-------
DescribeNamespaceResponse
Response containing the namespace properties.
"""
return LOOP.run(self._conn.describe_namespace(namespace=namespace))
@override
def list_tables(
self,
namespace: Optional[List[str]] = None,
page_token: Optional[str] = None,
limit: Optional[int] = None,
) -> ListTablesResponse:
"""List all tables in this database with pagination support.
Parameters
----------
namespace: List[str], optional
The namespace to list tables in.
None or empty list represents root namespace.
page_token: str, optional
Token for pagination. Use the token from a previous response
to get the next page of results.
limit: int, optional
The maximum number of results to return.
Returns
-------
ListTablesResponse
Response containing table names and optional page_token for pagination.
"""
if namespace is None:
namespace = []
return LOOP.run(
self._conn.list_tables(
namespace=namespace, page_token=page_token, limit=limit
)
)
@override
def table_names(
self,
page_token: Optional[str] = None,
limit: int = 10,
*,
namespace: Optional[List[str]] = None,
) -> Iterable[str]:
"""List the names of all tables in the database.
.. deprecated::
Use :meth:`list_tables` instead, which provides proper pagination support.
Parameters
----------
namespace: List[str], default []
The namespace to list tables in.
Empty list represents root namespace.
page_token: str
The last token to start the new page.
limit: int, default 10
The maximum number of tables to return for each page.
Returns
-------
An iterator of table names.
"""
import warnings
warnings.warn(
"table_names() is deprecated, use list_tables() instead",
DeprecationWarning,
stacklevel=2,
)
if namespace is None:
namespace = []
return LOOP.run(
self._conn.table_names(
namespace=namespace, start_after=page_token, limit=limit
)
)
@override
def open_table(
self,
name: str,
*,
namespace: Optional[List[str]] = None,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
namespace: List[str], optional
The namespace to open the table from.
None or empty list represents root namespace.
Returns
-------
A LanceTable object representing the table.
"""
from .table import RemoteTable
if namespace is None:
namespace = []
if index_cache_size is not None:
logging.info(
"index_cache_size is ignored in LanceDb Cloud"
" (there is no local cache to configure)"
)
table = LOOP.run(self._conn.open_table(name, namespace=namespace))
return RemoteTable(table, self.db_name)
def clone_table(
self,
target_table_name: str,
source_uri: str,
*,
target_namespace: Optional[List[str]] = None,
source_version: Optional[int] = None,
source_tag: Optional[str] = None,
is_shallow: bool = True,
) -> Table:
"""Clone a table from a source table.
Parameters
----------
target_table_name: str
The name of the target table to create.
source_uri: str
The URI of the source table to clone from.
target_namespace: List[str], optional
The namespace for the target table.
None or empty list represents root namespace.
source_version: int, optional
The version of the source table to clone.
source_tag: str, optional
The tag of the source table to clone.
is_shallow: bool, default True
Whether to perform a shallow clone (True) or deep clone (False).
Currently only shallow clone is supported.
Returns
-------
A RemoteTable object representing the cloned table.
"""
from .table import RemoteTable
if target_namespace is None:
target_namespace = []
table = LOOP.run(
self._conn.clone_table(
target_table_name,
source_uri,
target_namespace=target_namespace,
source_version=source_version,
source_tag=source_tag,
is_shallow=is_shallow,
)
)
return RemoteTable(table, self.db_name)
@override
def create_table(
self,
name: str,
data: DATA = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
mode: Optional[str] = None,
exist_ok: bool = False,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*,
namespace: Optional[List[str]] = None,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
Parameters
----------
name: str
The name of the table.
namespace: List[str], optional
The namespace to create the table in.
None or empty list represents root namespace.
data: The data to initialize the table, *optional*
User must provide at least one of `data` or `schema`.
Acceptable types are:
- dict or list-of-dict
- pandas.DataFrame
- pyarrow.Table or pyarrow.RecordBatch
schema: The schema of the table, *optional*
Acceptable types are:
- pyarrow.Schema
- [LanceModel][lancedb.pydantic.LanceModel]
mode: str, default "create"
The mode to use when creating the table.
Can be either "create", "overwrite", or "exist_ok".
exist_ok: bool, default False
If exist_ok is True, and mode is None or "create", mode will be changed
to "exist_ok".
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
Returns
-------
LanceTable
A reference to the newly created table.
!!! note
The vector index won't be created by default.
To create the index, call the `create_index` method on the table.
Examples
--------
Can create with list of tuples or dictionaries:
>>> import lancedb
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data) # doctest: +SKIP
LanceTable(my_table)
You can also pass a pandas DataFrame:
>>> import pandas as pd
>>> data = pd.DataFrame({
... "vector": [[1.1, 1.2], [0.2, 1.8]],
... "lat": [45.5, 40.1],
... "long": [-122.7, -74.1]
... })
>>> db.create_table("table2", data) # doctest: +SKIP
LanceTable(table2)
>>> custom_schema = pa.schema([
... pa.field("vector", pa.list_(pa.float32(), 2)),
... pa.field("lat", pa.float32()),
... pa.field("long", pa.float32())
... ])
>>> db.create_table("table3", data, schema = custom_schema) # doctest: +SKIP
LanceTable(table3)
It is also possible to create an table from `[Iterable[pa.RecordBatch]]`:
>>> import pyarrow as pa
>>> def make_batches():
... for i in range(5):
... yield pa.RecordBatch.from_arrays(
... [
... pa.array([[3.1, 4.1], [5.9, 26.5]],
... pa.list_(pa.float32(), 2)),
... pa.array(["foo", "bar"]),
... pa.array([10.0, 20.0]),
... ],
... ["vector", "item", "price"],
... )
>>> schema=pa.schema([
... pa.field("vector", pa.list_(pa.float32(), 2)),
... pa.field("item", pa.utf8()),
... pa.field("price", pa.float32()),
... ])
>>> db.create_table("table4", make_batches(), schema=schema) # doctest: +SKIP
LanceTable(table4)
"""
if exist_ok:
if mode == "create":
mode = "exist_ok"
elif not mode:
mode = "exist_ok"
if namespace is None:
namespace = []
validate_table_name(name)
if embedding_functions is not None:
logging.warning(
"embedding_functions is not yet supported on LanceDB Cloud."
"Please vote https://github.com/lancedb/lancedb/issues/626 "
"for this feature."
)
from .table import RemoteTable
table = LOOP.run(
self._conn.create_table(
name,
data,
namespace=namespace,
mode=mode,
schema=schema,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
)
return RemoteTable(table, self.db_name)
@override
def drop_table(self, name: str, namespace: Optional[List[str]] = None):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
namespace: List[str], optional
The namespace to drop the table from.
None or empty list represents root namespace.
"""
if namespace is None:
namespace = []
LOOP.run(self._conn.drop_table(name, namespace=namespace))
@override
def rename_table(
self,
cur_name: str,
new_name: str,
cur_namespace: Optional[List[str]] = None,
new_namespace: Optional[List[str]] = None,
):
"""Rename a table in the database.
Parameters
----------
cur_name: str
The current name of the table.
new_name: str
The new name of the table.
"""
if cur_namespace is None:
cur_namespace = []
if new_namespace is None:
new_namespace = []
LOOP.run(
self._conn.rename_table(
cur_name,
new_name,
cur_namespace=cur_namespace,
new_namespace=new_namespace,
)
)
async def close(self):
"""Close the connection to the database."""
self._client.close()