From 9ee152eb42e55ddfef8418ebce0556cf7509a0d1 Mon Sep 17 00:00:00 2001 From: Wyatt Alt Date: Wed, 7 May 2025 17:23:39 -0700 Subject: [PATCH] fix: support __len__ on remote table (#2379) This moves the __len__ method from LanceTable and RemoteTable to Table so that child classes don't need to implement their own. In the process, it fixes the implementation of RemoteTable's length method, which was previously missing a return statement. ## Summary by CodeRabbit - **Refactor** - Centralized the table length functionality in the base table class, simplifying subclass behavior. - Removed redundant or non-functional length methods from specific table classes. - **Tests** - Added a new test to verify correct table length reporting for remote tables. --- python/python/lancedb/remote/table.py | 3 --- python/python/lancedb/table.py | 7 ++++--- python/python/tests/test_remote_db.py | 18 ++++++++++++++++++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index d8aae374..2de418da 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -47,9 +47,6 @@ class RemoteTable(Table): def __repr__(self) -> str: return f"RemoteTable({self.db_name}.{self.name})" - def __len__(self) -> int: - self.count_rows(None) - @property def schema(self) -> pa.Schema: """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 7bf43c9d..8fb84ff3 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -620,6 +620,10 @@ class Table(ABC): """ raise NotImplementedError + def __len__(self) -> int: + """The number of rows in this Table""" + return self.count_rows(None) + @property @abstractmethod def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]: @@ -1762,9 +1766,6 @@ class LanceTable(Table): def count_rows(self, filter: Optional[str] = None) -> int: return LOOP.run(self._table.count_rows(filter)) - def __len__(self) -> int: - return self.count_rows() - def __repr__(self) -> str: val = f"{self.__class__.__name__}(name={self.name!r}, version={self.version}" if self._conn.read_consistency_interval is not None: diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 5a2584ee..a5f3feda 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -149,6 +149,24 @@ async def test_async_checkout(): assert await table.count_rows() == 300 +def test_table_len_sync(): + def handler(request): + if request.path == "/v1/table/test/create/?mode=create": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"{}") + + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(json.dumps(1).encode()) + + with mock_lancedb_connection(handler) as db: + table = db.create_table("test", [{"id": 1}]) + assert len(table) == 1 + + @pytest.mark.asyncio async def test_http_error(): request_id_holder = {"request_id": None}