mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
@@ -31,7 +31,13 @@ from lancedb.utils.events import register_event
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .pydantic import LanceModel
|
||||
from .table import AsyncTable, LanceTable, Table, _sanitize_data
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
get_uri_location,
|
||||
get_uri_scheme,
|
||||
join_uri,
|
||||
validate_table_name,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
@@ -387,6 +393,7 @@ class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
if mode.lower() not in ["create", "overwrite"]:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
validate_table_name(name)
|
||||
|
||||
tbl = LanceTable.create(
|
||||
self,
|
||||
|
||||
@@ -26,6 +26,7 @@ from ..db import DBConnection
|
||||
from ..embeddings import EmbeddingFunctionConfig
|
||||
from ..pydantic import LanceModel
|
||||
from ..table import Table, _sanitize_data
|
||||
from ..util import validate_table_name
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
from .errors import LanceDBClientError
|
||||
@@ -223,6 +224,7 @@ class RemoteDBConnection(DBConnection):
|
||||
LanceTable(table4)
|
||||
|
||||
"""
|
||||
validate_table_name(name)
|
||||
if data is None and schema is None:
|
||||
raise ValueError("Either data or schema must be provided.")
|
||||
if embedding_functions is not None:
|
||||
|
||||
@@ -25,6 +25,8 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.fs as pa_fs
|
||||
|
||||
from ._lancedb import validate_table_name as native_validate_table_name
|
||||
|
||||
|
||||
def safe_import_adlfs():
|
||||
try:
|
||||
@@ -286,3 +288,8 @@ def deprecated(func):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
|
||||
def validate_table_name(name: str):
|
||||
"""Verify the table name is valid."""
|
||||
native_validate_table_name(name)
|
||||
|
||||
@@ -521,3 +521,15 @@ def test_prefilter_with_index(tmp_path):
|
||||
.to_arrow()
|
||||
)
|
||||
assert table.num_rows == 1
|
||||
|
||||
|
||||
def test_create_table_with_invalid_names(tmp_path):
|
||||
db = lancedb.connect(uri=tmp_path)
|
||||
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo/bar", data)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo bar", data)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo$$bar", data)
|
||||
db.create_table("foo.bar", data)
|
||||
|
||||
Reference in New Issue
Block a user