mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-22 06:20:39 +00:00
fix(python): check all table pages for db membership (#3395)
## Summary - Fix `name in db` and `len(db)` for local Python connections with more than one page of tables. - Use `list_tables()` pagination instead of deprecated `table_names()` with its default 10-item page. - Add regression coverage with 20 tables so later pages are included. Fixes #2727. ## Validation - `python3 -m py_compile python/python/lancedb/db.py python/python/tests/test_db.py` - No-build Python harness that extracts and executes the edited `LanceDBConnection` pagination methods: passed - `uvx ruff check python/python/lancedb/db.py python/python/tests/test_db.py` - `uvx ruff format --check python/python/lancedb/db.py python/python/tests/test_db.py` Note: `uv run pytest python/tests/test_db.py::test_db_contains_and_len_include_all_table_name_pages -q` was attempted first, but it stayed in the broad Rust/PyO3 native extension build and was stopped before pytest started.
This commit is contained in:
@@ -8,7 +8,17 @@ from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -847,11 +857,20 @@ class LanceDBConnection(DBConnection):
|
||||
)
|
||||
)
|
||||
|
||||
def _all_table_names(self) -> Generator[str, None, None]:
|
||||
page_token = None
|
||||
while True:
|
||||
response = self.list_tables(page_token=page_token)
|
||||
yield from response.tables
|
||||
page_token = response.page_token
|
||||
if not page_token:
|
||||
return
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.table_names())
|
||||
return sum(1 for _ in self._all_table_names())
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self.table_names()
|
||||
return name in self._all_table_names()
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
@@ -188,6 +189,43 @@ def test_table_names(tmp_db: lancedb.DBConnection):
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
def test_db_contains_and_len_include_all_table_name_pages(tmp_db: lancedb.DBConnection):
|
||||
for idx in range(20):
|
||||
tmp_db.create_table(f"table_{idx}", data=[{"id": idx}])
|
||||
|
||||
assert len(tmp_db) == 20
|
||||
for idx in range(20):
|
||||
assert f"table_{idx}" in tmp_db
|
||||
assert "does_not_exist" not in tmp_db
|
||||
|
||||
|
||||
def test_db_contains_stops_after_matching_table_page(
|
||||
tmp_db: lancedb.DBConnection, monkeypatch
|
||||
):
|
||||
calls = []
|
||||
pages = {
|
||||
None: SimpleNamespace(tables=["table_0", "table_1"], page_token="next"),
|
||||
"next": SimpleNamespace(tables=["table_2"], page_token=None),
|
||||
}
|
||||
|
||||
def list_tables(*, page_token=None, **_kwargs):
|
||||
calls.append(page_token)
|
||||
return pages[page_token]
|
||||
|
||||
monkeypatch.setattr(tmp_db, "list_tables", list_tables)
|
||||
|
||||
assert "table_1" in tmp_db
|
||||
assert calls == [None]
|
||||
|
||||
calls.clear()
|
||||
assert "table_2" in tmp_db
|
||||
assert calls == [None, "next"]
|
||||
|
||||
calls.clear()
|
||||
assert len(tmp_db) == 3
|
||||
assert calls == [None, "next"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_names_async(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
Reference in New Issue
Block a user