mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
Compare commits
3 Commits
add-python
...
changhiskh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9eca8e7cd1 | ||
|
|
587fe6ffc1 | ||
|
|
89c8e5839b |
@@ -397,14 +397,6 @@ class LanceTable(Table):
|
||||
self.name = name
|
||||
self._version = version
|
||||
|
||||
def _reset_dataset(self, version=None):
|
||||
try:
|
||||
if "_dataset" in self.__dict__:
|
||||
del self.__dict__["_dataset"]
|
||||
self._version = version
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
"""Return the schema of the table.
|
||||
@@ -413,16 +405,16 @@ class LanceTable(Table):
|
||||
-------
|
||||
pa.Schema
|
||||
A PyArrow schema object."""
|
||||
return self._dataset.schema
|
||||
return self.to_lance().schema
|
||||
|
||||
def list_versions(self):
|
||||
"""List all versions of the table"""
|
||||
return self._dataset.versions()
|
||||
return self.to_lance().versions()
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version of the table"""
|
||||
return self._dataset.version
|
||||
return self.to_lance().version
|
||||
|
||||
def checkout(self, version: int):
|
||||
"""Checkout a version of the table. This is an in-place operation.
|
||||
@@ -455,14 +447,12 @@ class LanceTable(Table):
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
max_ver = max([v["version"] for v in self.to_lance().versions()])
|
||||
if version < 1 or version > max_ver:
|
||||
raise ValueError(f"Invalid version {version}")
|
||||
self._reset_dataset(version=version)
|
||||
|
||||
try:
|
||||
# Accessing the property updates the cached value
|
||||
_ = self._dataset
|
||||
self.to_lance().checkout(version)
|
||||
except Exception as e:
|
||||
if "not found" in str(e):
|
||||
raise ValueError(
|
||||
@@ -505,7 +495,7 @@ class LanceTable(Table):
|
||||
>>> len(table.list_versions())
|
||||
4
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
max_ver = max([v["version"] for v in self.to_lance().versions()])
|
||||
if version is None:
|
||||
version = self.version
|
||||
elif version < 1 or version > max_ver:
|
||||
@@ -517,11 +507,10 @@ class LanceTable(Table):
|
||||
# no-op if restoring the latest version
|
||||
return
|
||||
|
||||
self._dataset.restore()
|
||||
self._reset_dataset()
|
||||
self.to_lance().restore()
|
||||
|
||||
def __len__(self):
|
||||
return self._dataset.count_rows()
|
||||
return self.to_lance().count_rows()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LanceTable({self.name})"
|
||||
@@ -531,7 +520,7 @@ class LanceTable(Table):
|
||||
|
||||
def head(self, n=5) -> pa.Table:
|
||||
"""Return the first n rows of the table."""
|
||||
return self._dataset.head(n)
|
||||
return self.to_lance().head(n)
|
||||
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
@@ -548,7 +537,7 @@ class LanceTable(Table):
|
||||
Returns
|
||||
-------
|
||||
pa.Table"""
|
||||
return self._dataset.to_table()
|
||||
return self.to_lance().to_table()
|
||||
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
@@ -575,7 +564,6 @@ class LanceTable(Table):
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
self._reset_dataset()
|
||||
register_event("create_index")
|
||||
|
||||
def create_fts_index(
|
||||
@@ -607,7 +595,11 @@ class LanceTable(Table):
|
||||
raise ValueError(
|
||||
f"Index already exists. Use replace=True to overwrite."
|
||||
)
|
||||
fs.delete_dir(path)
|
||||
try:
|
||||
fs.delete_dir(path)
|
||||
except FileNotFoundError as e:
|
||||
if "Cannot get information for path" in str(e):
|
||||
pass
|
||||
|
||||
index = create_index(self._get_fts_index_path(), field_names)
|
||||
populate_index(index, self, field_names)
|
||||
@@ -662,8 +654,7 @@ class LanceTable(Table):
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
||||
self._reset_dataset()
|
||||
self.to_lance().write(data, mode=mode)
|
||||
register_event("add")
|
||||
|
||||
def merge(
|
||||
@@ -724,10 +715,9 @@ class LanceTable(Table):
|
||||
other_table = other_table.to_lance()
|
||||
if isinstance(other_table, LanceDataset):
|
||||
other_table = other_table.to_table()
|
||||
self._dataset.merge(
|
||||
self.to_lance().merge(
|
||||
other_table, left_on=left_on, right_on=right_on, schema=schema
|
||||
)
|
||||
self._reset_dataset()
|
||||
register_event("merge")
|
||||
|
||||
@cached_property
|
||||
@@ -930,7 +920,7 @@ class LanceTable(Table):
|
||||
return tbl
|
||||
|
||||
def delete(self, where: str):
|
||||
self._dataset.delete(where)
|
||||
self.to_lance().delete(where)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -985,7 +975,6 @@ class LanceTable(Table):
|
||||
values_sql = {k: value_to_sql(v) for k, v in values.items()}
|
||||
|
||||
self.to_lance().update(values_sql, where)
|
||||
self._reset_dataset()
|
||||
register_event("update")
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
|
||||
@@ -95,12 +95,12 @@ def test_create_index_from_table(tmp_path, table):
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
table.create_fts_index("text")
|
||||
|
||||
table.create_fts_index("text", replace=True)
|
||||
assert len(table.search("gorilla").limit(1).to_pandas()) == 1
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
table.create_fts_index("text")
|
||||
|
||||
|
||||
def test_create_index_multiple_columns(tmp_path, table):
|
||||
table.create_fts_index(["text", "text2"])
|
||||
|
||||
@@ -226,39 +226,38 @@ def test_versioning(db):
|
||||
|
||||
|
||||
def test_create_index_method():
|
||||
with patch.object(LanceTable, "_reset_dataset", return_value=None):
|
||||
with patch.object(
|
||||
LanceTable, "_dataset", new_callable=PropertyMock
|
||||
) as mock_dataset:
|
||||
# Setup mock responses
|
||||
mock_dataset.return_value.create_index.return_value = None
|
||||
with patch.object(
|
||||
LanceTable, "_dataset", new_callable=PropertyMock
|
||||
) as mock_dataset:
|
||||
# Setup mock responses
|
||||
mock_dataset.return_value.create_index.return_value = None
|
||||
|
||||
# Create a LanceTable object
|
||||
connection = LanceDBConnection(uri="mock.uri")
|
||||
table = LanceTable(connection, "test_table")
|
||||
# Create a LanceTable object
|
||||
connection = LanceDBConnection(uri="mock.uri")
|
||||
table = LanceTable(connection, "test_table")
|
||||
|
||||
# Call the create_index method
|
||||
table.create_index(
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name="vector",
|
||||
replace=True,
|
||||
index_cache_size=256,
|
||||
)
|
||||
# Call the create_index method
|
||||
table.create_index(
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name="vector",
|
||||
replace=True,
|
||||
index_cache_size=256,
|
||||
)
|
||||
|
||||
# Check that the _dataset.create_index method was called
|
||||
# with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
replace=True,
|
||||
accelerator=None,
|
||||
index_cache_size=256,
|
||||
)
|
||||
# Check that the _dataset.create_index method was called
|
||||
# with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
replace=True,
|
||||
accelerator=None,
|
||||
index_cache_size=256,
|
||||
)
|
||||
|
||||
|
||||
def test_add_with_nans(db):
|
||||
|
||||
Reference in New Issue
Block a user