feat: add a filterable count_rows to all the lancedb APIs (#913)

A `count_rows` method that takes a filter was recently added to
`LanceTable`. This PR adds it everywhere else except `RemoteTable` (that
will come soon).
This commit is contained in:
Weston Pace
2024-02-08 09:40:29 -08:00
parent 2c3f982f4f
commit 138fc3f66b
11 changed files with 86 additions and 30 deletions

View File

@@ -37,6 +37,9 @@ class RemoteTable(Table):
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self._name})"
def __len__(self) -> int:
self.count_rows(None)
@cached_property
def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
@@ -409,6 +412,13 @@ class RemoteTable(Table):
"compact_files() is not supported on the LanceDB cloud"
)
def count_rows(self, filter: Optional[str] = None) -> int:
# payload = {"filter": filter}
# self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
return NotImplementedError(
"count_rows() is not yet supported on the LanceDB cloud"
)
def add_index(tbl: pa.Table, i: int) -> pa.Table:
return tbl.add_column(

View File

@@ -176,6 +176,18 @@ class Table(ABC):
"""
raise NotImplementedError
@abstractmethod
def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
Parameters
----------
filter: str, optional
A SQL where clause to filter the rows to count.
"""
raise NotImplementedError
def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
@@ -925,14 +937,6 @@ class LanceTable(Table):
self._ref.dataset = ds
def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
Parameters
----------
filter: str, optional
A SQL where clause to filter the rows to count.
"""
return self._dataset.count_rows(filter)
def __len__(self):