mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-11 06:12:58 +00:00
@@ -39,3 +39,5 @@ pin-project = "1.0.7"
|
||||
snafu = "0.7.4"
|
||||
url = "2"
|
||||
num-traits = "0.2"
|
||||
regex = "1.10"
|
||||
lazy_static = "1"
|
||||
|
||||
@@ -31,7 +31,13 @@ from lancedb.utils.events import register_event
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .pydantic import LanceModel
|
||||
from .table import AsyncTable, LanceTable, Table, _sanitize_data
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
get_uri_location,
|
||||
get_uri_scheme,
|
||||
join_uri,
|
||||
validate_table_name,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
@@ -387,6 +393,7 @@ class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
if mode.lower() not in ["create", "overwrite"]:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
validate_table_name(name)
|
||||
|
||||
tbl = LanceTable.create(
|
||||
self,
|
||||
|
||||
@@ -26,6 +26,7 @@ from ..db import DBConnection
|
||||
from ..embeddings import EmbeddingFunctionConfig
|
||||
from ..pydantic import LanceModel
|
||||
from ..table import Table, _sanitize_data
|
||||
from ..util import validate_table_name
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
from .errors import LanceDBClientError
|
||||
@@ -223,6 +224,7 @@ class RemoteDBConnection(DBConnection):
|
||||
LanceTable(table4)
|
||||
|
||||
"""
|
||||
validate_table_name(name)
|
||||
if data is None and schema is None:
|
||||
raise ValueError("Either data or schema must be provided.")
|
||||
if embedding_functions is not None:
|
||||
|
||||
@@ -25,6 +25,8 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.fs as pa_fs
|
||||
|
||||
from ._lancedb import validate_table_name as native_validate_table_name
|
||||
|
||||
|
||||
def safe_import_adlfs():
|
||||
try:
|
||||
@@ -286,3 +288,8 @@ def deprecated(func):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
|
||||
def validate_table_name(name: str):
|
||||
"""Verify the table name is valid."""
|
||||
native_validate_table_name(name)
|
||||
|
||||
@@ -521,3 +521,15 @@ def test_prefilter_with_index(tmp_path):
|
||||
.to_arrow()
|
||||
)
|
||||
assert table.num_rows == 1
|
||||
|
||||
|
||||
def test_create_table_with_invalid_names(tmp_path):
|
||||
db = lancedb.connect(uri=tmp_path)
|
||||
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo/bar", data)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo bar", data)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo$$bar", data)
|
||||
db.create_table("foo.bar", data)
|
||||
|
||||
@@ -42,6 +42,7 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<VectorQuery>()?;
|
||||
m.add_class::<RecordBatchStream>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Mutex;
|
||||
use lancedb::DistanceType;
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
PyResult,
|
||||
pyfunction, PyResult,
|
||||
};
|
||||
|
||||
/// A wrapper around a rust builder
|
||||
@@ -49,3 +49,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub(crate) fn validate_table_name(table_name: &str) -> PyResult<()> {
|
||||
lancedb::utils::validate_table_name(table_name)
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ chrono = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
snafu = { workspace = true }
|
||||
half = { workspace = true }
|
||||
lazy_static.workspace = true
|
||||
lance = { workspace = true }
|
||||
lance-index = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
@@ -34,11 +35,10 @@ bytes = "1"
|
||||
futures.workspace = true
|
||||
num-traits.workspace = true
|
||||
url.workspace = true
|
||||
regex.workspace = true
|
||||
serde = { version = "^1" }
|
||||
serde_json = { version = "1" }
|
||||
|
||||
# For remote feature
|
||||
|
||||
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -31,6 +31,7 @@ use crate::arrow::IntoArrow;
|
||||
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||
use crate::table::{NativeTable, WriteOptions};
|
||||
use crate::utils::validate_table_name;
|
||||
use crate::Table;
|
||||
|
||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
@@ -675,13 +676,18 @@ impl Database {
|
||||
|
||||
/// Get the URI of a table in the database.
|
||||
fn table_uri(&self, name: &str) -> Result<String> {
|
||||
validate_table_name(name)?;
|
||||
|
||||
let path = Path::new(&self.uri);
|
||||
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
|
||||
|
||||
let mut uri = table_uri
|
||||
.as_path()
|
||||
.to_str()
|
||||
.context(InvalidTableNameSnafu { name })?
|
||||
.context(InvalidTableNameSnafu {
|
||||
name,
|
||||
reason: "Name is not valid URL",
|
||||
})?
|
||||
.to_string();
|
||||
|
||||
// If there are query string set on the connection, propagate to lance
|
||||
|
||||
@@ -20,8 +20,8 @@ use snafu::Snafu;
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub(crate)))]
|
||||
pub enum Error {
|
||||
#[snafu(display("Invalid table name: {name}"))]
|
||||
InvalidTableName { name: String },
|
||||
#[snafu(display("Invalid table name (\"{name}\"): {reason}"))]
|
||||
InvalidTableName { name: String, reason: String },
|
||||
#[snafu(display("Invalid input, {message}"))]
|
||||
InvalidInput { message: String },
|
||||
#[snafu(display("Table '{name}' was not found"))]
|
||||
|
||||
@@ -854,6 +854,7 @@ impl NativeTable {
|
||||
.to_str()
|
||||
.ok_or(Error::InvalidTableName {
|
||||
name: uri.to_string(),
|
||||
reason: "Table name is not valid URL".to_string(),
|
||||
})?;
|
||||
Ok(name.to_string())
|
||||
}
|
||||
@@ -1197,7 +1198,7 @@ impl NativeTable {
|
||||
if dim != query_vector.len() as i32 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The dimension of the query vector does not match with the dimension of the vector column '{}':
|
||||
"The dimension of the query vector does not match with the dimension of the vector column '{}':
|
||||
query dim={}, expected vector dim={}",
|
||||
column,
|
||||
query_vector.len(),
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
// Copyright 2024 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::Schema;
|
||||
|
||||
use lance::dataset::{ReadParams, WriteParams};
|
||||
use lance::io::{ObjectStoreParams, WrappingObjectStore};
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
lazy_static! {
|
||||
static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
|
||||
}
|
||||
|
||||
pub trait PatchStoreParam {
|
||||
fn patch_with_store_wrapper(
|
||||
self,
|
||||
@@ -64,6 +82,25 @@ impl PatchReadParam for ReadParams {
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate table name.
|
||||
pub fn validate_table_name(name: &str) -> Result<()> {
|
||||
if name.is_empty() {
|
||||
return Err(Error::InvalidTableName {
|
||||
name: name.to_string(),
|
||||
reason: "Table names cannot be empty strings".to_string(),
|
||||
});
|
||||
}
|
||||
if !TABLE_NAME_REGEX.is_match(name) {
|
||||
return Err(Error::InvalidTableName {
|
||||
name: name.to_string(),
|
||||
reason:
|
||||
"Table names can only contain alphanumeric characters, underscores, hyphens, and periods"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find one default column to create index.
|
||||
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
|
||||
// Try to find one fixed size list array column.
|
||||
@@ -145,4 +182,20 @@ mod tests {
|
||||
.to_string()
|
||||
.contains("More than one"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_table_name() {
|
||||
assert!(validate_table_name("my_table").is_ok());
|
||||
assert!(validate_table_name("my_table_1").is_ok());
|
||||
assert!(validate_table_name("123mytable").is_ok());
|
||||
assert!(validate_table_name("_12345table").is_ok());
|
||||
assert!(validate_table_name("table.12345").is_ok());
|
||||
assert!(validate_table_name("table.._dot_..12345").is_ok());
|
||||
|
||||
assert!(validate_table_name("").is_err());
|
||||
assert!(validate_table_name("my_table!").is_err());
|
||||
assert!(validate_table_name("my/table").is_err());
|
||||
assert!(validate_table_name("my@table").is_err());
|
||||
assert!(validate_table_name("name with space").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user