diff --git a/Cargo.toml b/Cargo.toml index bb982c30..d08e9ae8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,3 +39,5 @@ pin-project = "1.0.7" snafu = "0.7.4" url = "2" num-traits = "0.2" +regex = "1.10" +lazy_static = "1" diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 486cd30b..0bd98016 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -31,7 +31,13 @@ from lancedb.utils.events import register_event from ._lancedb import connect as lancedb_connect from .pydantic import LanceModel from .table import AsyncTable, LanceTable, Table, _sanitize_data -from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri +from .util import ( + fs_from_uri, + get_uri_location, + get_uri_scheme, + join_uri, + validate_table_name, +) if TYPE_CHECKING: from datetime import timedelta @@ -387,6 +393,7 @@ class LanceDBConnection(DBConnection): """ if mode.lower() not in ["create", "overwrite"]: raise ValueError("mode must be either 'create' or 'overwrite'") + validate_table_name(name) tbl = LanceTable.create( self, diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index f2ded712..9dff65c5 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -26,6 +26,7 @@ from ..db import DBConnection from ..embeddings import EmbeddingFunctionConfig from ..pydantic import LanceModel from ..table import Table, _sanitize_data +from ..util import validate_table_name from .arrow import to_ipc_binary from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient from .errors import LanceDBClientError @@ -223,6 +224,7 @@ class RemoteDBConnection(DBConnection): LanceTable(table4) """ + validate_table_name(name) if data is None and schema is None: raise ValueError("Either data or schema must be provided.") if embedding_functions is not None: diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index f5987c06..4470754d 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -25,6 +25,8 @@ import numpy as np import pyarrow as pa import pyarrow.fs as pa_fs +from ._lancedb import validate_table_name as native_validate_table_name + def safe_import_adlfs(): try: @@ -286,3 +288,8 @@ def deprecated(func): return func(*args, **kwargs) return new_func + + +def validate_table_name(name: str): + """Verify the table name is valid.""" + native_validate_table_name(name) diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index c84c0800..fc4420ba 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -521,3 +521,15 @@ def test_prefilter_with_index(tmp_path): .to_arrow() ) assert table.num_rows == 1 + + +def test_create_table_with_invalid_names(tmp_path): + db = lancedb.connect(uri=tmp_path) + data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)] + with pytest.raises(ValueError): + db.create_table("foo/bar", data) + with pytest.raises(ValueError): + db.create_table("foo bar", data) + with pytest.raises(ValueError): + db.create_table("foo$$bar", data) + db.create_table("foo.bar", data) diff --git a/python/src/lib.rs b/python/src/lib.rs index 558668cb..9d1f0a80 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -42,6 +42,7 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(connect, m)?)?; + m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; Ok(()) } diff --git a/python/src/util.rs b/python/src/util.rs index 893e8089..19662fac 100644 --- a/python/src/util.rs +++ b/python/src/util.rs @@ -3,7 +3,7 @@ use std::sync::Mutex; use lancedb::DistanceType; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, - PyResult, + pyfunction, PyResult, }; /// A wrapper around a rust builder @@ -49,3 +49,9 @@ pub fn parse_distance_type(distance_type: impl AsRef) -> PyResult PyResult<()> { + lancedb::utils::validate_table_name(table_name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 93f38691..f7317a79 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -22,6 +22,7 @@ chrono = { workspace = true } object_store = { workspace = true } snafu = { workspace = true } half = { workspace = true } +lazy_static.workspace = true lance = { workspace = true } lance-index = { workspace = true } lance-linalg = { workspace = true } @@ -34,11 +35,10 @@ bytes = "1" futures.workspace = true num-traits.workspace = true url.workspace = true +regex.workspace = true serde = { version = "^1" } serde_json = { version = "1" } - # For remote feature - reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true } [dev-dependencies] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 54ae8d27..06bf9c41 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -31,6 +31,7 @@ use crate::arrow::IntoArrow; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; use crate::io::object_store::MirroringObjectStoreWrapper; use crate::table::{NativeTable, WriteOptions}; +use crate::utils::validate_table_name; use crate::Table; pub const LANCE_FILE_EXTENSION: &str = "lance"; @@ -675,13 +676,18 @@ impl Database { /// Get the URI of a table in the database. fn table_uri(&self, name: &str) -> Result { + validate_table_name(name)?; + let path = Path::new(&self.uri); let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); let mut uri = table_uri .as_path() .to_str() - .context(InvalidTableNameSnafu { name })? + .context(InvalidTableNameSnafu { + name, + reason: "Name is not valid URL", + })? .to_string(); // If there are query string set on the connection, propagate to lance diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs index 8baed35d..a528a177 100644 --- a/rust/lancedb/src/error.rs +++ b/rust/lancedb/src/error.rs @@ -20,8 +20,8 @@ use snafu::Snafu; #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] pub enum Error { - #[snafu(display("Invalid table name: {name}"))] - InvalidTableName { name: String }, + #[snafu(display("Invalid table name (\"{name}\"): {reason}"))] + InvalidTableName { name: String, reason: String }, #[snafu(display("Invalid input, {message}"))] InvalidInput { message: String }, #[snafu(display("Table '{name}' was not found"))] diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 2848cf41..e3e3ab2d 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -854,6 +854,7 @@ impl NativeTable { .to_str() .ok_or(Error::InvalidTableName { name: uri.to_string(), + reason: "Table name is not valid URL".to_string(), })?; Ok(name.to_string()) } @@ -1197,7 +1198,7 @@ impl NativeTable { if dim != query_vector.len() as i32 { return Err(Error::InvalidInput { message: format!( - "The dimension of the query vector does not match with the dimension of the vector column '{}': + "The dimension of the query vector does not match with the dimension of the vector column '{}': query dim={}, expected vector dim={}", column, query_vector.len(), diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index 05499017..d6578f81 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -1,12 +1,30 @@ +// Copyright 2024 LanceDB Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::sync::Arc; use arrow_schema::Schema; - use lance::dataset::{ReadParams, WriteParams}; use lance::io::{ObjectStoreParams, WrappingObjectStore}; +use lazy_static::lazy_static; use crate::error::{Error, Result}; +lazy_static! { + static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap(); +} + pub trait PatchStoreParam { fn patch_with_store_wrapper( self, @@ -64,6 +82,25 @@ impl PatchReadParam for ReadParams { } } +/// Validate table name. +pub fn validate_table_name(name: &str) -> Result<()> { + if name.is_empty() { + return Err(Error::InvalidTableName { + name: name.to_string(), + reason: "Table names cannot be empty strings".to_string(), + }); + } + if !TABLE_NAME_REGEX.is_match(name) { + return Err(Error::InvalidTableName { + name: name.to_string(), + reason: + "Table names can only contain alphanumeric characters, underscores, hyphens, and periods" + .to_string(), + }); + } + Ok(()) +} + /// Find one default column to create index. pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result { // Try to find one fixed size list array column. @@ -145,4 +182,20 @@ mod tests { .to_string() .contains("More than one")); } + + #[test] + fn test_validate_table_name() { + assert!(validate_table_name("my_table").is_ok()); + assert!(validate_table_name("my_table_1").is_ok()); + assert!(validate_table_name("123mytable").is_ok()); + assert!(validate_table_name("_12345table").is_ok()); + assert!(validate_table_name("table.12345").is_ok()); + assert!(validate_table_name("table.._dot_..12345").is_ok()); + + assert!(validate_table_name("").is_err()); + assert!(validate_table_name("my_table!").is_err()); + assert!(validate_table_name("my/table").is_err()); + assert!(validate_table_name("my@table").is_err()); + assert!(validate_table_name("name with space").is_err()); + } }