Compare commits

...

3 Commits

Author SHA1 Message Date
Chang She
9eca8e7cd1 tests pass; still need catalog 2023-12-21 20:13:08 -08:00
Chang She
587fe6ffc1 almost 2023-12-21 19:45:10 -08:00
Chang She
89c8e5839b initial changes to enable an in-memory dataset 2023-12-21 08:52:11 -08:00
3 changed files with 50 additions and 62 deletions

View File

@@ -397,14 +397,6 @@ class LanceTable(Table):
self.name = name
self._version = version
def _reset_dataset(self, version=None):
try:
if "_dataset" in self.__dict__:
del self.__dict__["_dataset"]
self._version = version
except AttributeError:
pass
@property
def schema(self) -> pa.Schema:
"""Return the schema of the table.
@@ -413,16 +405,16 @@ class LanceTable(Table):
-------
pa.Schema
A PyArrow schema object."""
return self._dataset.schema
return self.to_lance().schema
def list_versions(self):
"""List all versions of the table"""
return self._dataset.versions()
return self.to_lance().versions()
@property
def version(self) -> int:
"""Get the current version of the table"""
return self._dataset.version
return self.to_lance().version
def checkout(self, version: int):
"""Checkout a version of the table. This is an in-place operation.
@@ -455,14 +447,12 @@ class LanceTable(Table):
vector type
0 [1.1, 0.9] vector
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
max_ver = max([v["version"] for v in self.to_lance().versions()])
if version < 1 or version > max_ver:
raise ValueError(f"Invalid version {version}")
self._reset_dataset(version=version)
try:
# Accessing the property updates the cached value
_ = self._dataset
self.to_lance().checkout(version)
except Exception as e:
if "not found" in str(e):
raise ValueError(
@@ -505,7 +495,7 @@ class LanceTable(Table):
>>> len(table.list_versions())
4
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
max_ver = max([v["version"] for v in self.to_lance().versions()])
if version is None:
version = self.version
elif version < 1 or version > max_ver:
@@ -517,11 +507,10 @@ class LanceTable(Table):
# no-op if restoring the latest version
return
self._dataset.restore()
self._reset_dataset()
self.to_lance().restore()
def __len__(self):
return self._dataset.count_rows()
return self.to_lance().count_rows()
def __repr__(self) -> str:
return f"LanceTable({self.name})"
@@ -531,7 +520,7 @@ class LanceTable(Table):
def head(self, n=5) -> pa.Table:
"""Return the first n rows of the table."""
return self._dataset.head(n)
return self.to_lance().head(n)
def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
@@ -548,7 +537,7 @@ class LanceTable(Table):
Returns
-------
pa.Table"""
return self._dataset.to_table()
return self.to_lance().to_table()
@property
def _dataset_uri(self) -> str:
@@ -575,7 +564,6 @@ class LanceTable(Table):
accelerator=accelerator,
index_cache_size=index_cache_size,
)
self._reset_dataset()
register_event("create_index")
def create_fts_index(
@@ -607,7 +595,11 @@ class LanceTable(Table):
raise ValueError(
f"Index already exists. Use replace=True to overwrite."
)
fs.delete_dir(path)
try:
fs.delete_dir(path)
except FileNotFoundError as e:
if "Cannot get information for path" in str(e):
pass
index = create_index(self._get_fts_index_path(), field_names)
populate_index(index, self, field_names)
@@ -662,8 +654,7 @@ class LanceTable(Table):
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
self._reset_dataset()
self.to_lance().write(data, mode=mode)
register_event("add")
def merge(
@@ -724,10 +715,9 @@ class LanceTable(Table):
other_table = other_table.to_lance()
if isinstance(other_table, LanceDataset):
other_table = other_table.to_table()
self._dataset.merge(
self.to_lance().merge(
other_table, left_on=left_on, right_on=right_on, schema=schema
)
self._reset_dataset()
register_event("merge")
@cached_property
@@ -930,7 +920,7 @@ class LanceTable(Table):
return tbl
def delete(self, where: str):
self._dataset.delete(where)
self.to_lance().delete(where)
def update(
self,
@@ -985,7 +975,6 @@ class LanceTable(Table):
values_sql = {k: value_to_sql(v) for k, v in values.items()}
self.to_lance().update(values_sql, where)
self._reset_dataset()
register_event("update")
def _execute_query(self, query: Query) -> pa.Table:

View File

@@ -95,12 +95,12 @@ def test_create_index_from_table(tmp_path, table):
]
)
with pytest.raises(ValueError, match="already exists"):
table.create_fts_index("text")
table.create_fts_index("text", replace=True)
assert len(table.search("gorilla").limit(1).to_pandas()) == 1
with pytest.raises(ValueError, match="already exists"):
table.create_fts_index("text")
def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"])

View File

@@ -226,39 +226,38 @@ def test_versioning(db):
def test_create_index_method():
with patch.object(LanceTable, "_reset_dataset", return_value=None):
with patch.object(
LanceTable, "_dataset", new_callable=PropertyMock
) as mock_dataset:
# Setup mock responses
mock_dataset.return_value.create_index.return_value = None
with patch.object(
LanceTable, "_dataset", new_callable=PropertyMock
) as mock_dataset:
# Setup mock responses
mock_dataset.return_value.create_index.return_value = None
# Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table")
# Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table")
# Call the create_index method
table.create_index(
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name="vector",
replace=True,
index_cache_size=256,
)
# Call the create_index method
table.create_index(
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name="vector",
replace=True,
index_cache_size=256,
)
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
metric="L2",
num_partitions=256,
num_sub_vectors=96,
replace=True,
accelerator=None,
index_cache_size=256,
)
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
metric="L2",
num_partitions=256,
num_sub_vectors=96,
replace=True,
accelerator=None,
index_cache_size=256,
)
def test_add_with_nans(db):