Merge branch 'main' of https://github.com/lancedb/lancedb into docs_march

This commit is contained in:
ayush chaurasia
2024-03-23 20:35:03 +05:30
21 changed files with 534 additions and 138 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.4
current_version = 0.6.5
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.6.4"
version = "0.6.5"
dependencies = [
"deprecation",
"pylance==0.10.4",
"pylance==0.10.5",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",
@@ -94,13 +94,11 @@ lancedb = "lancedb.cli.cli:cli"
requires = ["maturin>=1.4"]
build-backend = "maturin"
[tool.ruff.lint]
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
[tool.pytest.ini_options]
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio",

View File

@@ -31,7 +31,13 @@ from lancedb.utils.events import register_event
from ._lancedb import connect as lancedb_connect
from .pydantic import LanceModel
from .table import AsyncTable, LanceTable, Table, _sanitize_data
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
from .util import (
fs_from_uri,
get_uri_location,
get_uri_scheme,
join_uri,
validate_table_name,
)
if TYPE_CHECKING:
from datetime import timedelta
@@ -387,6 +393,7 @@ class LanceDBConnection(DBConnection):
"""
if mode.lower() not in ["create", "overwrite"]:
raise ValueError("mode must be either 'create' or 'overwrite'")
validate_table_name(name)
tbl = LanceTable.create(
self,

View File

@@ -26,6 +26,7 @@ from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, _sanitize_data
from ..util import validate_table_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
from .errors import LanceDBClientError
@@ -223,6 +224,7 @@ class RemoteDBConnection(DBConnection):
LanceTable(table4)
"""
validate_table_name(name)
if data is None and schema is None:
raise ValueError("Either data or schema must be provided.")
if embedding_functions is not None:

View File

@@ -25,6 +25,8 @@ import numpy as np
import pyarrow as pa
import pyarrow.fs as pa_fs
from ._lancedb import validate_table_name as native_validate_table_name
def safe_import_adlfs():
try:
@@ -286,3 +288,8 @@ def deprecated(func):
return func(*args, **kwargs)
return new_func
def validate_table_name(name: str):
"""Verify the table name is valid."""
native_validate_table_name(name)

View File

@@ -521,3 +521,15 @@ def test_prefilter_with_index(tmp_path):
.to_arrow()
)
assert table.num_rows == 1
def test_create_table_with_invalid_names(tmp_path):
db = lancedb.connect(uri=tmp_path)
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
with pytest.raises(ValueError):
db.create_table("foo/bar", data)
with pytest.raises(ValueError):
db.create_table("foo bar", data)
with pytest.raises(ValueError):
db.create_table("foo$$bar", data)
db.create_table("foo.bar", data)

View File

@@ -42,6 +42,7 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<VectorQuery>()?;
m.add_class::<RecordBatchStream>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())
}

View File

@@ -3,7 +3,7 @@ use std::sync::Mutex;
use lancedb::DistanceType;
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
PyResult,
pyfunction, PyResult,
};
/// A wrapper around a rust builder
@@ -49,3 +49,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
))),
}
}
#[pyfunction]
pub(crate) fn validate_table_name(table_name: &str) -> PyResult<()> {
lancedb::utils::validate_table_name(table_name)
.map_err(|e| PyValueError::new_err(e.to_string()))
}