chore: validate table name (#1146)

Closes #1129
This commit is contained in:
Lei Xu
2024-03-21 14:46:13 -07:00
committed by GitHub
parent c0dd98c798
commit 25988d23cd
12 changed files with 106 additions and 9 deletions

View File

@@ -39,3 +39,5 @@ pin-project = "1.0.7"
snafu = "0.7.4"
url = "2"
num-traits = "0.2"
regex = "1.10"
lazy_static = "1"

View File

@@ -31,7 +31,13 @@ from lancedb.utils.events import register_event
from ._lancedb import connect as lancedb_connect
from .pydantic import LanceModel
from .table import AsyncTable, LanceTable, Table, _sanitize_data
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
from .util import (
fs_from_uri,
get_uri_location,
get_uri_scheme,
join_uri,
validate_table_name,
)
if TYPE_CHECKING:
from datetime import timedelta
@@ -387,6 +393,7 @@ class LanceDBConnection(DBConnection):
"""
if mode.lower() not in ["create", "overwrite"]:
raise ValueError("mode must be either 'create' or 'overwrite'")
validate_table_name(name)
tbl = LanceTable.create(
self,

View File

@@ -26,6 +26,7 @@ from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, _sanitize_data
from ..util import validate_table_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
from .errors import LanceDBClientError
@@ -223,6 +224,7 @@ class RemoteDBConnection(DBConnection):
LanceTable(table4)
"""
validate_table_name(name)
if data is None and schema is None:
raise ValueError("Either data or schema must be provided.")
if embedding_functions is not None:

View File

@@ -25,6 +25,8 @@ import numpy as np
import pyarrow as pa
import pyarrow.fs as pa_fs
from ._lancedb import validate_table_name as native_validate_table_name
def safe_import_adlfs():
try:
@@ -286,3 +288,8 @@ def deprecated(func):
return func(*args, **kwargs)
return new_func
def validate_table_name(name: str):
    """Check that ``name`` is acceptable as a table name.

    This is a thin Python entry point: the check itself is performed by the
    native ``validate_table_name`` function imported from ``._lancedb``.

    Parameters
    ----------
    name: str
        The candidate table name.

    Raises
    ------
    ValueError
        Raised by the native binding when it rejects the name.
    """
    native_validate_table_name(name)

View File

@@ -521,3 +521,15 @@ def test_prefilter_with_index(tmp_path):
.to_arrow()
)
assert table.num_rows == 1
def test_create_table_with_invalid_names(tmp_path):
    """Table names with '/', a space, or '$' are rejected; a '.' is accepted."""
    db = lancedb.connect(uri=tmp_path)
    data = [{"vector": np.random.rand(128), "item": "foo"} for _ in range(10)]

    # Each of these contains a character outside the allowed set.
    for rejected_name in ("foo/bar", "foo bar", "foo$$bar"):
        with pytest.raises(ValueError):
            db.create_table(rejected_name, data)

    # A period is permitted, so this call must not raise.
    db.create_table("foo.bar", data)

View File

@@ -42,6 +42,7 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<VectorQuery>()?;
m.add_class::<RecordBatchStream>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())
}

View File

@@ -3,7 +3,7 @@ use std::sync::Mutex;
use lancedb::DistanceType;
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
PyResult,
pyfunction, PyResult,
};
/// A wrapper around a rust builder
@@ -49,3 +49,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
))),
}
}
// Python-facing wrapper around `lancedb::utils::validate_table_name`.
//
// Any validation error is converted into a Python `ValueError` whose message
// is the error's string form. Plain `//` comments are used here on purpose so
// the function's Python-visible docstring is left untouched.
#[pyfunction]
pub(crate) fn validate_table_name(table_name: &str) -> PyResult<()> {
    match lancedb::utils::validate_table_name(table_name) {
        Ok(()) => Ok(()),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}

View File

@@ -22,6 +22,7 @@ chrono = { workspace = true }
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lazy_static.workspace = true
lance = { workspace = true }
lance-index = { workspace = true }
lance-linalg = { workspace = true }
@@ -34,11 +35,10 @@ bytes = "1"
futures.workspace = true
num-traits.workspace = true
url.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
[dev-dependencies]

View File

@@ -31,6 +31,7 @@ use crate::arrow::IntoArrow;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, WriteOptions};
use crate::utils::validate_table_name;
use crate::Table;
pub const LANCE_FILE_EXTENSION: &str = "lance";
@@ -675,13 +676,18 @@ impl Database {
/// Get the URI of a table in the database.
fn table_uri(&self, name: &str) -> Result<String> {
validate_table_name(name)?;
let path = Path::new(&self.uri);
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let mut uri = table_uri
.as_path()
.to_str()
.context(InvalidTableNameSnafu { name })?
.context(InvalidTableNameSnafu {
name,
reason: "Name is not valid URL",
})?
.to_string();
// If there are query string set on the connection, propagate to lance

View File

@@ -20,8 +20,8 @@ use snafu::Snafu;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum Error {
#[snafu(display("Invalid table name: {name}"))]
InvalidTableName { name: String },
#[snafu(display("Invalid table name (\"{name}\"): {reason}"))]
InvalidTableName { name: String, reason: String },
#[snafu(display("Invalid input, {message}"))]
InvalidInput { message: String },
#[snafu(display("Table '{name}' was not found"))]

View File

@@ -854,6 +854,7 @@ impl NativeTable {
.to_str()
.ok_or(Error::InvalidTableName {
name: uri.to_string(),
reason: "Table name is not valid URL".to_string(),
})?;
Ok(name.to_string())
}
@@ -1197,7 +1198,7 @@ impl NativeTable {
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
message: format!(
"The dimension of the query vector does not match with the dimension of the vector column '{}':
"The dimension of the query vector does not match with the dimension of the vector column '{}':
query dim={}, expected vector dim={}",
column,
query_vector.len(),

View File

@@ -1,12 +1,30 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow_schema::Schema;
use lance::dataset::{ReadParams, WriteParams};
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lazy_static::lazy_static;
use crate::error::{Error, Result};
// Pattern a table name must match in full: one or more ASCII letters, digits,
// underscores, hyphens, or periods, anchored at both ends.
// The regex is built once, on first access, by `lazy_static!`; the `unwrap()`
// is on a constant pattern, so a failure here would indicate a typo in the
// literal rather than a runtime condition.
lazy_static! {
static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
}
pub trait PatchStoreParam {
fn patch_with_store_wrapper(
self,
@@ -64,6 +82,25 @@ impl PatchReadParam for ReadParams {
}
}
/// Validate table name.
pub fn validate_table_name(name: &str) -> Result<()> {
if name.is_empty() {
return Err(Error::InvalidTableName {
name: name.to_string(),
reason: "Table names cannot be empty strings".to_string(),
});
}
if !TABLE_NAME_REGEX.is_match(name) {
return Err(Error::InvalidTableName {
name: name.to_string(),
reason:
"Table names can only contain alphanumeric characters, underscores, hyphens, and periods"
.to_string(),
});
}
Ok(())
}
/// Find one default column to create index.
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
// Try to find one fixed size list array column.
@@ -145,4 +182,20 @@ mod tests {
.to_string()
.contains("More than one"));
}
#[test]
fn test_validate_table_name() {
    // Names made only of letters, digits, underscores and periods pass.
    let accepted = [
        "my_table",
        "my_table_1",
        "123mytable",
        "_12345table",
        "table.12345",
        "table.._dot_..12345",
    ];
    for name in accepted {
        assert!(validate_table_name(name).is_ok());
    }

    // The empty string and names with other characters are rejected.
    let rejected = ["", "my_table!", "my/table", "my@table", "name with space"];
    for name in rejected {
        assert!(validate_table_name(name).is_err());
    }
}
}