# Patch summary (text before the first "diff --git" header is skipped by
# patch tools).
#
# python/python/lancedb/db.py
#   * Adds `Generator` to the `typing` import, reformatted one name per line.
#   * Adds `LanceDBConnection._all_table_names()`, a generator that calls
#     `self.list_tables(page_token=...)` repeatedly, yields every name in
#     `response.tables`, and stops once `response.page_token` is falsy.
#   * `__len__` and `__contains__` now consume that generator instead of
#     calling `self.table_names()`. `__contains__` uses `in` on the generator,
#     so it stops requesting pages as soon as the name is found.
#   * Presumably motivated by `table_names()` returning only one page of
#     results -- TODO confirm against the `table_names` implementation.
#
# python/python/tests/test_db.py
#   * `test_db_contains_and_len_include_all_table_name_pages`: creates 20
#     tables and checks `len(db)` and `name in db` for each of them.
#   * `test_db_contains_stops_after_matching_table_page`: monkeypatches
#     `list_tables` with two fake pages and asserts which page tokens are
#     requested by `in` and by `len`.
#
# NOTE(review): the line structure and indentation below were reconstructed
# from a whitespace-collapsed copy of this patch. Leading whitespace on context
# lines is inferred from Python syntax -- verify against the target files
# before applying.
diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py
index 276116db7..41dbc0d83 100644
--- a/python/python/lancedb/db.py
+++ b/python/python/lancedb/db.py
@@ -8,7 +8,17 @@ from abc import abstractmethod
 from datetime import timedelta
 from pathlib import Path
 import sys
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Union,
+)
 
 if sys.version_info >= (3, 12):
     from typing import override
@@ -847,11 +857,20 @@ class LanceDBConnection(DBConnection):
             )
         )
 
+    def _all_table_names(self) -> Generator[str, None, None]:
+        page_token = None
+        while True:
+            response = self.list_tables(page_token=page_token)
+            yield from response.tables
+            page_token = response.page_token
+            if not page_token:
+                return
+
     def __len__(self) -> int:
-        return len(self.table_names())
+        return sum(1 for _ in self._all_table_names())
 
     def __contains__(self, name: str) -> bool:
-        return name in self.table_names()
+        return name in self._all_table_names()
 
     @override
     def create_table(
diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py
index f1173e2ff..d3db372de 100644
--- a/python/python/tests/test_db.py
+++ b/python/python/tests/test_db.py
@@ -6,6 +6,7 @@ import re
 import sys
 from datetime import timedelta
 import os
+from types import SimpleNamespace
 
 import lancedb
 import numpy as np
@@ -188,6 +189,43 @@ def test_table_names(tmp_db: lancedb.DBConnection):
     assert len(result) == 3
 
 
+def test_db_contains_and_len_include_all_table_name_pages(tmp_db: lancedb.DBConnection):
+    for idx in range(20):
+        tmp_db.create_table(f"table_{idx}", data=[{"id": idx}])
+
+    assert len(tmp_db) == 20
+    for idx in range(20):
+        assert f"table_{idx}" in tmp_db
+    assert "does_not_exist" not in tmp_db
+
+
+def test_db_contains_stops_after_matching_table_page(
+    tmp_db: lancedb.DBConnection, monkeypatch
+):
+    calls = []
+    pages = {
+        None: SimpleNamespace(tables=["table_0", "table_1"], page_token="next"),
+        "next": SimpleNamespace(tables=["table_2"], page_token=None),
+    }
+
+    def list_tables(*, page_token=None, **_kwargs):
+        calls.append(page_token)
+        return pages[page_token]
+
+    monkeypatch.setattr(tmp_db, "list_tables", list_tables)
+
+    assert "table_1" in tmp_db
+    assert calls == [None]
+
+    calls.clear()
+    assert "table_2" in tmp_db
+    assert calls == [None, "next"]
+
+    calls.clear()
+    assert len(tmp_db) == 3
+    assert calls == [None, "next"]
+
+
 @pytest.mark.asyncio
 async def test_table_names_async(tmp_path):
     db = lancedb.connect(tmp_path)