# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

"""Tests for LanceDB tables through the sync and async Python APIs."""

import os
import sys
import threading
import warnings
from datetime import date, datetime, timedelta
from time import sleep
from typing import List
from unittest.mock import patch

import lancedb
from lancedb.dependencies import _PANDAS_AVAILABLE
from lancedb.index import BTree, FTS, HnswFlat, HnswPq, HnswSq, IvfPq
import numpy as np
import polars as pl
import pyarrow as pa
import pyarrow.dataset
import pytest
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import AsyncConnection, DBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.expr import col, lit
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable
from pydantic import BaseModel


def _blob_test_data():
    """Return a two-row table whose ``blob`` column carries the
    ``lance-encoding:blob`` field metadata."""
    return pa.table(
        {
            "id": pa.array([1, 2], pa.int64()),
            "blob": pa.array([b"hello", b"world"], pa.large_binary()),
        },
        schema=pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field(
                    "blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
                ),
            ]
        ),
    )


def _assert_lazy_blob(value, expected: bytes):
    """Assert *value* is file-like (has ``readall``) and reads back *expected*."""
    assert hasattr(value, "readall")
    assert value.readall() == expected


def test_basic(mem_db: DBConnection):
    """Creating a table from row dicts infers the schema and round-trips data."""
    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]
    table = mem_db.create_table("test", data=data)
    assert table.name == "test"
    assert "LanceTable(name='test', _conn=LanceDBConnection(" in repr(table)
    expected_schema = pa.schema(
        {
            "vector": pa.list_(pa.float32(), 2),
            "item": pa.string(),
            "price": pa.float64(),
        }
    )
    assert table.schema == expected_schema

    expected_data = pa.Table.from_pylist(data, schema=expected_schema)
    assert table.to_arrow() == expected_data


def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection):
    """``to_pandas()`` with no arguments matches ``pyarrow.Table.to_pandas()``."""
    pd = pytest.importorskip("pandas")
    data = pa.table({"id": [1, 2], "text": ["one", "two"]})
    table = tmp_db.create_table("test_to_pandas_old_call", data=data)

    expected = data.to_pandas()
    pd.testing.assert_frame_equal(table.to_pandas(), expected)


def test_table_to_pandas_invalid_blob_mode_non_blob_table(tmp_db: DBConnection):
    """An unknown ``blob_mode`` is rejected even if the table has no blobs."""
    data = pa.table({"id": [1, 2], "text": ["one", "two"]})
    table = tmp_db.create_table("test_to_pandas_invalid_blob_mode", data=data)

    with pytest.raises(ValueError, match="blob_mode must be one of"):
        table.to_pandas(blob_mode="invalid")


@pytest.mark.parametrize("blob_mode", ["lazy", "bytes", "descriptions"])
def test_table_to_pandas_blob_modes(tmp_db: DBConnection, blob_mode):
    """Each ``blob_mode`` yields its own representation of blob values."""
    pytest.importorskip("lance")
    table = tmp_db.create_table(f"test_to_pandas_blob_{blob_mode}", _blob_test_data())

    df = table.to_pandas(blob_mode=blob_mode)

    if blob_mode == "lazy":
        # file-like handles that read the payload on demand
        _assert_lazy_blob(df["blob"].iloc[0], b"hello")
        _assert_lazy_blob(df["blob"].iloc[1], b"world")
    elif blob_mode == "bytes":
        # raw payload bytes
        assert df["blob"].tolist() == [b"hello", b"world"]
    else:
        # "descriptions": neither the raw payload nor a file-like handle
        first = df["blob"].iloc[0]
        assert first != b"hello"
        assert not hasattr(first, "readall")


def test_table_to_pandas_kwargs(tmp_db: DBConnection):
    """Keyword arguments such as ``types_mapper`` take effect in ``to_pandas``."""
    pd = pytest.importorskip("pandas")
    data = pa.table({"id": pa.array([1, 2], pa.int64())})
    table = tmp_db.create_table("test_to_pandas_kwargs", data=data)

    df = table.to_pandas(types_mapper=pd.ArrowDtype)

    assert str(df["id"].dtype) == "int64[pyarrow]"


@pytest.mark.asyncio
async def test_async_table_to_pandas_blob_bytes(tmp_db_async: AsyncConnection):
    """Async ``to_pandas(blob_mode="bytes")`` returns raw blob payloads."""
    pytest.importorskip("lance")
    table = await tmp_db_async.create_table(
        "test_async_to_pandas_blob_bytes", data=_blob_test_data()
    )

    df = await table.to_pandas(blob_mode="bytes")

    assert df["blob"].tolist() == [b"hello", b"world"]


@pytest.mark.asyncio
async def test_async_table_to_pandas_invalid_blob_mode_non_blob_table(
    tmp_db_async: AsyncConnection,
):
    """Async ``to_pandas`` rejects an unknown ``blob_mode``."""
    table = await tmp_db_async.create_table(
        "test_async_to_pandas_invalid_blob_mode",
        data=pa.table({"id": [1, 2], "text": ["one", "two"]}),
    )

    with pytest.raises(ValueError, match="blob_mode must be one of"):
        await table.to_pandas(blob_mode="invalid")


@pytest.mark.asyncio
async def test_async_table_to_pandas_kwargs(tmp_db_async: AsyncConnection):
    """Async ``to_pandas`` honours keyword arguments such as ``types_mapper``."""
    pd = pytest.importorskip("pandas")
    data = pa.table({"id": pa.array([1, 2], pa.int64())})
    table = await tmp_db_async.create_table("test_async_to_pandas_kwargs", data=data)

    df = await table.to_pandas(types_mapper=pd.ArrowDtype)

    assert str(df["id"].dtype) == "int64[pyarrow]"


def test_create_table_infers_large_int_vectors(mem_db: DBConnection):
    """Integer vectors (e.g. containing 300) become float32 fixed-size lists
    with their values preserved."""
    data = [{"vector": [0, 300]}]
    table = mem_db.create_table(
        "int_vector_overflow", data=data, mode="overwrite", exist_ok=True
    )

    vector_field = table.schema.field("vector")
    assert vector_field.type == pa.list_(pa.float32(), 2)

    vector_column = table.to_arrow().column("vector")
    assert vector_column.type == pa.list_(pa.float32(), 2)
    assert vector_column.to_pylist() == [[0.0, 300.0]]


@pytest.mark.asyncio
async def test_create_table_async_infers_large_int_vectors(
    mem_db_async: AsyncConnection,
):
    """Async variant of :func:`test_create_table_infers_large_int_vectors`."""
    data = [{"vector": [256, 257]}]
    table = await mem_db_async.create_table(
        "int_vector_overflow_async", data=data, mode="overwrite", exist_ok=True
    )

    schema = await table.schema()
    assert schema.field("vector").type == pa.list_(pa.float32(), 2)

    vector_column = (await table.to_arrow()).column("vector")
    assert vector_column.type == pa.list_(pa.float32(), 2)
    assert vector_column.to_pylist() == [[256.0, 257.0]]


def test_input_data_type(mem_db: DBConnection, tmp_path):
    """``create_table`` accepts readers, batches, tables, datasets and scanners."""
    schema = pa.schema(
        {
            "id": pa.int64(),
            "name": pa.string(),
            "age": pa.int32(),
        }
    )

    data = {
        "id": [1, 2, 3, 4, 5],
        "name": ["Alice", "Bob", "Charlie", "David", "Eve"],
        "age": [25, 30, 35, 40, 45],
    }
    record_batch = pa.RecordBatch.from_pydict(data, schema=schema)
    pa_reader = pa.RecordBatchReader.from_batches(record_batch.schema, [record_batch])
    pa_table = pa.Table.from_batches([record_batch])

    def create_dataset(root):
        # Write the table as parquet and reopen it as a pyarrow Dataset.
        path = os.path.join(root, "test_source_dataset")
        pa.dataset.write_dataset(pa_table, path, format="parquet")
        return pa.dataset.dataset(path, format="parquet")

    pa_dataset = create_dataset(tmp_path)
    pa_scanner = pa_dataset.scanner()

    input_types = [
        ("RecordBatchReader", pa_reader),
        ("RecordBatch", record_batch),
        ("Table", pa_table),
        ("Dataset", pa_dataset),
        ("Scanner", pa_scanner),
    ]
    for input_type, input_data in input_types:
        table_name = f"test_{input_type.lower()}"
        table = mem_db.create_table(table_name, data=input_data)
        # The input type is used as the assertion message so a failure
        # points at the offending source kind.
        assert table.schema == schema, input_type
        assert table.count_rows() == 5, input_type
        assert table.to_arrow() == pa_table, input_type


@pytest.mark.asyncio
async def test_close(mem_db_async: AsyncConnection):
    """A closed async table reports itself closed and rejects further use."""
    table = await mem_db_async.create_table("some_table", data=[{"id": 0}])
    assert table.is_open()
    table.close()
    assert not table.is_open()

    with pytest.raises(Exception, match="Table some_table is closed"):
        await table.count_rows()
    assert str(table) == "ClosedTable(some_table)"


@pytest.mark.asyncio
async def test_update_async(mem_db_async: AsyncConnection):
    """Async ``update`` with values, SQL expressions and ``where`` filters
    reports the rows updated and bumps the version each time."""
    table = await mem_db_async.create_table("some_table", data=[{"id": 0}])
    assert await table.count_rows("id == 0") == 1
    assert await table.count_rows("id == 7") == 0
    update_res = await table.update({"id": 7})
    assert update_res.rows_updated == 1
    assert update_res.version == 2
    assert await table.count_rows("id == 7") == 1
    assert await table.count_rows("id == 0") == 0
    add_res = await table.add([{"id": 2}])
    assert add_res.version == 3
    update_res = await table.update(where="id % 2 == 0", updates_sql={"id": "5"})
    assert update_res.rows_updated == 1
    assert update_res.version == 4
    assert await table.count_rows("id == 7") == 1
    assert await table.count_rows("id == 2") == 0
    assert await table.count_rows("id == 5") == 1
    update_res = await table.update({"id": 10}, where="id == 5")
    assert update_res.rows_updated == 1
    assert update_res.version == 5
    assert await table.count_rows("id == 10") == 1


def test_create_table(mem_db: DBConnection):
    """Rows, Arrow tables and (if installed) pandas DataFrames all produce the
    same table when an explicit schema is given."""
    schema = pa.schema(
        {
            "vector": pa.list_(pa.float32(), 2),
            "item": pa.string(),
            "price": pa.float64(),
        }
    )
    expected = pa.table(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        },
        schema=schema,
    )
    rows = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]
    pa_table = pa.Table.from_pylist(rows, schema=schema)
    data = [
        ("Rows", rows),
        ("pa_Table", pa_table),
    ]
    if _PANDAS_AVAILABLE:
        import pandas as pd

        df = pd.DataFrame(rows)
        data.append(("pd_DataFrame", df))

    for name, d in data:
        tbl = mem_db.create_table(name, data=d, schema=schema).to_arrow()
        assert expected == tbl, name


def test_create_table_rejects_single_dictionary(mem_db: DBConnection):
    """A bare dict is rejected with a helpful error rather than read as rows."""
    data = {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}
    with pytest.raises(ValueError) as excep_info:
        mem_db.create_table("test", data=data)
    assert (
        str(excep_info.value)
        == "Cannot create or add rows from a single dictionary. "
        "Use a list of dictionaries instead."
    )


def test_empty_table(mem_db: DBConnection):
    """A table created from a schema alone can be appended to."""
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float32()),
        ]
    )
    tbl = mem_db.create_table("test", schema=schema)
    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]
    tbl.add(data=data)
    # Without this the test could not detect an add that silently wrote nothing.
    assert len(tbl) == 2


def test_add_dictionary(mem_db: DBConnection):
    """``add`` rejects a bare dict the same way ``create_table`` does."""
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float32()),
        ]
    )
    tbl = mem_db.create_table("test", schema=schema)
    data = {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}
    with pytest.raises(ValueError) as excep_info:
        tbl.add(data=data)
    assert (
        str(excep_info.value)
        == "Cannot create or add rows from a single dictionary. "
        "Use a list of dictionaries instead."
    )


def test_add(mem_db: DBConnection):
    """Appending works for tables created with data and tables created empty."""
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float64()),
        ]
    )

    def _add(table, schema):
        # Append one row to a two-row table and check the full contents.
        assert len(table) == 2

        table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
        assert len(table) == 3

        expected = pa.table(
            {
                "vector": [[3.1, 4.1], [5.9, 26.5], [6.3, 100.5]],
                "item": ["foo", "bar", "new"],
                "price": [10.0, 20.0, 30.0],
            },
            schema=schema,
        )
        assert expected == table.to_arrow()

    # Append to table created with data
    table = mem_db.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    _add(table, schema)

    # Append to table created empty with schema
    table = mem_db.create_table("test2", schema=schema)
    table.add(
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    _add(table, schema)


def test_add_struct(mem_db: DBConnection):
    """Struct fields are matched by name and missing struct fields become null,
    both for top-level structs and structs inside lists."""
    # https://github.com/lancedb/lancedb/issues/2114
    schema = pa.schema(
        [
            (
                "stuff",
                pa.struct(
                    [
                        ("b", pa.int64()),
                        ("a", pa.int64()),
                        # TODO: also test subset of nested.
                    ]
                ),
            )
        ]
    )

    # Create test data with fields in same order
    data = [{"stuff": {"b": 1, "a": 2}}]
    # pa.Table.from_pylist() will reorder the fields. We need to make sure
    # we fix the field order later, before casting.
    table = mem_db.create_table("test", schema=schema)
    table.add(data)

    data = [{"stuff": {"b": 4}}]
    table.add(data)

    expected = pa.table(
        {
            "stuff": [{"b": 1, "a": 2}, {"b": 4, "a": None}],
        },
        schema=schema,
    )
    assert table.to_arrow() == expected

    # Also check struct in list
    schema = pa.schema(
        {
            "s_list": pa.list_(
                pa.struct(
                    [
                        ("b", pa.int64()),
                        ("a", pa.int64()),
                    ]
                )
            )
        }
    )
    data = [{"s_list": [{"b": 1, "a": 2}, {"b": 4}]}]
    table = mem_db.create_table("test2", schema=schema)
    table.add(data)

    struct_type = pa.struct(
        [
            ("b", pa.int64()),
            ("a", pa.int64()),
        ]
    )
    expected = pa.table(
        {
            "s_list": [
                [
                    pa.scalar({"b": 1, "a": 2}, type=struct_type),
                    pa.scalar({"b": 4, "a": None}, type=struct_type),
                ]
            ],
        }
    )
    assert table.to_arrow() == expected


def test_add_subschema(mem_db: DBConnection):
    """Appends may omit nullable columns (filled with nulls); omitting a
    non-nullable column fails until that column is made nullable."""
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
            pa.field("item", pa.string(), nullable=True),
            pa.field("price", pa.float64(), nullable=False),
        ]
    )
    table = mem_db.create_table("test", schema=schema)

    data = {"price": 10.0, "item": "foo"}
    table.add([data])
    data = pa.Table.from_pydict({"price": [2.0], "vector": [[3.1, 4.1]]})
    table.add(data)
    data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
    table.add([data])

    expected = pa.table(
        {
            "vector": [None, [3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", None, "bar"],
            "price": [10.0, 2.0, 3.0],
        },
        schema=schema,
    )
    assert table.to_arrow() == expected

    data = {"item": "foo"}
    # We can't omit a column if it's not nullable
    with pytest.raises(RuntimeError, match="Append with different schema"):
        table.add([data])

    # We can add it if we make the column nullable
    table.alter_columns(dict(path="price", nullable=True))
    table.add([data])

    expected_schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
            pa.field("item", pa.string(), nullable=True),
            pa.field("price", pa.float64(), nullable=True),
        ]
    )
    expected = pa.table(
        {
            "vector": [None, [3.1, 4.1], [5.9, 26.5], None],
            "item": ["foo", None, "bar", "foo"],
            "price": [10.0, 2.0, 3.0, None],
        },
        schema=expected_schema,
    )
    assert table.to_arrow() == expected


def test_add_nullability(mem_db: DBConnection):
    """Nullable-typed input can go into non-nullable columns only while it
    contains no nulls."""
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2), nullable=False),
            pa.field("id", pa.string(), nullable=False),
        ]
    )
    table = mem_db.create_table("test", schema=schema)
    assert table.schema.field("vector").nullable is False

    nullable_schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
            pa.field("id", pa.string(), nullable=True),
        ]
    )
    data = pa.table(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "id": ["foo", "bar"],
        },
        schema=nullable_schema,
    )
    # We can add nullable schema if it doesn't actually contain nulls
    table.add(data)

    expected = data.cast(schema)
    assert table.to_arrow() == expected

    data = pa.table(
        {
            "vector": [None],
            "id": ["baz"],
        },
        schema=nullable_schema,
    )
    # We can't add nullable schema if it contains nulls
    with pytest.raises(
        Exception,
        match=(
            "The field `vector` contained null values even though "
            "the field is marked non-null in the schema"
        ),
    ):
        table.add(data)

    # But we can make it nullable
    table.alter_columns(dict(path="vector", nullable=True))
    table.add(data)

    expected_schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
            pa.field("id", pa.string(), nullable=False),
        ]
    )
    expected = pa.table(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5], None],
            "id": ["foo", "bar", "baz"],
        },
        schema=expected_schema,
    )
    assert table.to_arrow() == expected


def test_add_pydantic_model(mem_db: DBConnection):
    """Nested pydantic models round-trip, and ``flatten`` expands nested
    columns one level (``1``) or fully (``True``)."""
    pytest.importorskip("pandas")
    # https://github.com/lancedb/lancedb/issues/562

    class Metadata(BaseModel):
        source: str
        timestamp: datetime

    class Document(BaseModel):
        content: str
        meta: Metadata

    class LanceSchema(LanceModel):
        id: str
        vector: Vector(2)
        li: List[int]
        payload: Document

    tbl = mem_db.create_table("mytable", schema=LanceSchema, mode="overwrite")
    assert tbl.schema == LanceSchema.to_arrow_schema()

    # add works
    expected = LanceSchema(
        id="id",
        vector=[0.0, 0.0],
        li=[1, 2, 3],
        payload=Document(
            content="foo", meta=Metadata(source="bar", timestamp=datetime.now())
        ),
    )
    add_res = tbl.add([expected])
    assert add_res.version == 2

    result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
    assert result == expected

    flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=1)
    assert len(flattened.columns) == 6  # _distance is automatically added

    really_flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=True)
    assert len(really_flattened.columns) == 7


@pytest.mark.asyncio
async def test_add_async(mem_db_async: AsyncConnection):
    """Async ``add`` appends rows and reports the new version."""
    table = await mem_db_async.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    assert await table.count_rows() == 2
    add_res = await table.add(
        data=[
            {"vector": [10.0, 11.0], "item": "baz", "price": 30.0},
        ],
    )
    assert add_res.version == 2
    assert await table.count_rows() == 3


def test_add_overwrite_infers_vector_schema(mem_db: DBConnection):
    """Overwrite should infer vector columns the same way create_table does.

    Regression test for https://github.com/lancedb/lancedb/issues/3183
    """
    table = mem_db.create_table(
        "test_overwrite_vec",
        data=[
            {"vector": [1.0, 2.0, 3.0, 4.0], "item": "foo"},
            {"vector": [5.0, 6.0, 7.0, 8.0], "item": "bar"},
        ],
    )
    # create_table infers vector as fixed_size_list
    original_type = table.schema.field("vector").type
    assert pa.types.is_fixed_size_list(original_type)

    # overwrite with plain Python lists (PyArrow infers list)
    table.add(
        [
            {"vector": [10.0, 20.0, 30.0, 40.0], "item": "baz"},
        ],
        mode="overwrite",
    )
    # overwrite replaces the previous rows rather than appending
    assert len(table) == 1

    # overwrite should infer vector column the same way as create_table
    new_type = table.schema.field("vector").type
    assert pa.types.is_fixed_size_list(new_type), (
        f"Expected fixed_size_list after overwrite, got {new_type}"
    )


def test_add_progress_callback(mem_db: DBConnection):
    """A callable ``progress`` receives dict snapshots ending with done=True."""
    table = mem_db.create_table(
        "test",
        data=[{"id": 1}, {"id": 2}],
    )

    updates = []
    table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))

    assert len(table) == 4
    # The done callback always fires, so we should always get at least one.
    assert len(updates) >= 1, "expected at least one progress callback"
    for p in updates:
        assert "output_rows" in p
        assert "output_bytes" in p
        assert "total_rows" in p
        assert "elapsed_seconds" in p
        assert "active_tasks" in p
        assert "total_tasks" in p
        assert "done" in p
    # The last callback should have done=True.
    assert updates[-1]["done"] is True


def test_add_progress_tqdm_like(mem_db: DBConnection):
    """A tqdm-like object (``update``/``set_postfix_str``/``refresh``) is
    accepted as ``progress`` and the write still succeeds."""

    class FakeBar:
        def __init__(self):
            self.total = None
            self.n = 0
            self.postfix = None

        def update(self, n):
            self.n += n

        def set_postfix_str(self, s):
            self.postfix = s

        def refresh(self):
            pass

    table = mem_db.create_table(
        "test",
        data=[{"id": 1}, {"id": 2}],
    )

    bar = FakeBar()
    table.add([{"id": 3}, {"id": 4}], progress=bar)

    assert len(table) == 4
    # Only the postfix format is checked (when one is written); bar.total and
    # bar.n are not asserted by this test.
    # Postfix should contain throughput and worker count
    if bar.postfix is not None:
        assert "MB/s" in bar.postfix
        assert "workers" in bar.postfix


def test_add_progress_bool(mem_db: DBConnection):
    """Test that progress=True creates and closes a tqdm bar automatically."""
    table = mem_db.create_table(
        "test",
        data=[{"id": 1}, {"id": 2}],
    )

    table.add([{"id": 3}, {"id": 4}], progress=True)
    assert len(table) == 4

    # progress=False should be the same as None
    table.add([{"id": 5}], progress=False)
    assert len(table) == 5


@pytest.mark.asyncio
async def test_add_progress_callback_async(mem_db_async: AsyncConnection):
    """Progress callbacks work through the async path too."""
    table = await mem_db_async.create_table("test", data=[{"id": 1}, {"id": 2}])

    updates = []
    await table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))

    assert await table.count_rows() == 4
    assert len(updates) >= 1
    assert updates[-1]["done"] is True


def test_add_progress_callback_error(mem_db: DBConnection):
    """A failing callback must not prevent the write from succeeding."""
    table = mem_db.create_table("test", data=[{"id": 1}, {"id": 2}])

    def bad_callback(p):
        raise RuntimeError("boom")

    table.add([{"id": 3}, {"id": 4}], progress=bad_callback)
    assert len(table) == 4


def test_polars(mem_db: DBConnection):
    """Polars DataFrames can be ingested; search results come back as an eager
    polars frame and ``to_polars()`` returns a lazy frame."""
    data = {
        "vector": [[3.1, 4.1], [5.9, 26.5]],
        "item": ["foo", "bar"],
        "price": [10.0, 20.0],
    }
    # Ingest polars dataframe
    table = mem_db.create_table("test", data=pl.DataFrame(data))
    assert len(table) == 2

    result = table.to_arrow()
    assert np.allclose(result["vector"].to_pylist(), data["vector"])
    assert result["item"].to_pylist() == data["item"]
    assert np.allclose(result["price"].to_pylist(), data["price"])

    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.large_string()),
            pa.field("price", pa.float64()),
        ]
    )
    assert table.schema == schema

    # search results to polars dataframe
    q = [3.1, 4.1]
    result = table.search(q).limit(1).to_polars()
    assert np.allclose(result["vector"][0], q)
    assert result["item"][0] == "foo"
    assert np.allclose(result["price"][0], 10.0)

    # entire table to a (lazy) polars dataframe
    result = table.to_polars()
    assert np.allclose(result.collect()["vector"].to_list(), data["vector"])

    # make sure filtering isn't broken
    filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect()
    assert len(filtered_result) == 2


def test_versioning(mem_db: DBConnection):
    """Each write creates a version and ``checkout`` reads an older one."""
    table = mem_db.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )

    assert len(table.list_versions()) == 1
    assert table.version == 1

    table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
    assert len(table.list_versions()) == 2
    assert table.version == 2
    assert len(table) == 3

    table.checkout(1)
    assert table.version == 1
    assert len(table) == 2


def test_tags(mem_db: DBConnection):
    """Tags can be created, listed, updated, deleted and checked out."""
    table = mem_db.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )

    table.tags.create("tag1", 1)
    tags = table.tags.list()
    assert "tag1" in tags
    assert tags["tag1"]["version"] == 1

    table.add(
        data=[
            {"vector": [10.0, 11.0], "item": "baz", "price": 30.0},
        ],
    )

    table.tags.create("tag2", 2)
    tags = table.tags.list()
    assert "tag1" in tags
    assert "tag2" in tags
    assert tags["tag1"]["version"] == 1
    assert tags["tag2"]["version"] == 2

    table.tags.delete("tag2")
    table.tags.update("tag1", 2)
    tags = table.tags.list()
    assert "tag1" in tags
    assert tags["tag1"]["version"] == 2

    table.tags.update("tag1", 1)
    tags = table.tags.list()
    assert "tag1" in tags
    assert tags["tag1"]["version"] == 1

    table.checkout("tag1")
    assert table.version == 1
    assert table.count_rows() == 2
    table.tags.create("tag2", 2)
    table.checkout("tag2")
    assert table.version == 2
    assert table.count_rows() == 3

    # After returning to the latest version the table accepts writes again.
    table.checkout_latest()
    table.add(
        data=[
            {"vector": [12.0, 13.0], "item": "baz", "price": 40.0},
        ],
    )
    assert table.count_rows() == 4


@pytest.mark.asyncio
async def test_async_tags(mem_db_async: AsyncConnection):
    """Async variant of :func:`test_tags`."""
    table = await mem_db_async.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )

    await table.tags.create("tag1", 1)
    tags = await table.tags.list()
    assert "tag1" in tags
    assert tags["tag1"]["version"] == 1

    await table.add(
        data=[
            {"vector": [10.0, 11.0], "item": "baz", "price": 30.0},
        ],
    )

    await table.tags.create("tag2", 2)
    tags = await table.tags.list()
    assert "tag1" in tags
    assert "tag2" in tags
    assert tags["tag1"]["version"] == 1
    assert tags["tag2"]["version"] == 2

    await table.tags.delete("tag2")
    await table.tags.update("tag1", 2)
    tags = await table.tags.list()
    assert "tag1" in tags
    assert tags["tag1"]["version"] == 2

    await table.tags.update("tag1", 1)
    tags = await table.tags.list()
    assert "tag1" in tags
    assert tags["tag1"]["version"] == 1

    await table.checkout("tag1")
    assert await table.version() == 1
    assert await table.count_rows() == 2
    await table.tags.create("tag2", 2)
    await table.checkout("tag2")
    assert await table.version() == 2
    assert await table.count_rows() == 3

    # After returning to the latest version the table accepts writes again.
    await table.checkout_latest()
    await table.add(
        data=[
            {"vector": [12.0, 13.0], "item": "baz", "price": 40.0},
        ],
    )
    assert await table.count_rows() == 4


def test_branches(tmp_path):
    """Branches fork from main, isolate writes, and can be listed, checked
    out and deleted."""
    db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(0))
    table = db.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    assert table.count_rows() == 2

    # fork an isolated, writable branch from main
    branch = table.branches.create("exp")
    assert branch.count_rows() == 2
    branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}])

    # writes on the branch do not touch main
    assert branch.count_rows() == 3
    assert table.count_rows() == 2

    # the branch is listed, with main (None) as its parent
    branches = table.branches.list()
    assert "exp" in branches
    assert branches["exp"]["parent_branch"] is None

    # from_ref="main" is equivalent to the default
    table.branches.create("exp2", from_ref="main")
    assert table.branches.list()["exp2"]["parent_branch"] is None

    # checkout returns a handle scoped to the branch's latest
    checked_out = table.branches.checkout("exp")
    assert checked_out.count_rows() == 3

    # delete removes both branches
    table.branches.delete("exp")
    table.branches.delete("exp2")
    remaining = table.branches.list()
    assert "exp" not in remaining
    assert "exp2" not in remaining


def test_branch_handle_tracks_concurrent_writes(tmp_path):
    """Two handles on the same branch see each other's writes."""
    db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(0))
    table = db.create_table("t", [{"id": 1}])

    # two independent handles on the same branch
    writer = table.branches.create("exp")
    reader = db.open_table("t", branch="exp")
    assert reader.count_rows() == 1

    # a concurrent write on the branch is visible to the other handle
    writer.add([{"id": 2}])
    assert reader.count_rows() == 2
    # main is unaffected
    assert table.count_rows() == 1


def test_branch_name_validation(tmp_path):
    """Empty branch names are rejected by create, checkout and delete."""
    db = lancedb.connect(tmp_path)
    table = db.create_table("t", [{"id": 1}])
    with pytest.raises(ValueError, match="non-empty"):
        table.branches.create("")
    with pytest.raises(ValueError, match="non-empty"):
        table.branches.checkout("")
    with pytest.raises(ValueError, match="non-empty"):
        table.branches.delete("")


def test_branches_preserve_namespace(tmp_path):
    """Branch handles keep the namespace and id of their source table."""
    # namespace_path routes through lance's DirectoryNamespace
    pytest.importorskip("lance")
    db = lancedb.connect(tmp_path)
    table = db.create_table("t", [{"id": 1}], namespace_path=["ns1"])
    assert table.namespace == ["ns1"]

    branch = table.branches.create("exp")
    assert branch.namespace == ["ns1"]
    assert branch.id == table.id

    # opening the branch directly also preserves namespace identity
    opened = db.open_table("t", namespace_path=["ns1"], branch="exp")
    assert opened.namespace == ["ns1"]


def test_open_table_with_branch(tmp_path):
    """``open_table(branch=...)`` opens the branch; without it, main."""
    db = lancedb.connect(tmp_path)
    table = db.create_table("t", [{"i": 1}])
    table.branches.create("exp").add([{"i": 2}])

    # open_table(branch=...) returns a handle scoped to the branch
    assert db.open_table("t", branch="exp").count_rows() == 2
    # opening without branch still tracks main
    assert db.open_table("t").count_rows() == 1


def test_open_table_with_branch_version(tmp_path):
    """``(branch, version)`` resolves against the branch's history, while a
    bare ``version`` resolves against main's."""
    db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(0))
    # main: a single fork-point row
    t = db.create_table("t", [{"i": 0}])
    main_v1 = t.version

    # fork "exp", then advance exp AND main independently past the fork so they
    # diverge while sharing version numbers
    exp = t.branches.create("exp")
    exp.add([{"i": 1}])  # exp: {0, 1}
    exp_v2 = exp.version
    exp.add([{"i": 2}])  # exp HEAD: {0, 1, 2}
    t.add([{"i": 100}, {"i": 101}, {"i": 102}])  # main HEAD: {0, 100, 101, 102}
    assert exp_v2 == t.version, "branch and main must share the version number"

    # open exp at the shared version: the data must be exp's, not main's. count
    # alone cannot prove this (main@v2 also exists), so assert provenance by
    # content.
    pinned = db.open_table("t", branch="exp", version=exp_v2)
    assert pinned.current_branch() == "exp"
    assert pinned.count_rows() == 2  # not exp HEAD (3), not main@v2 (4)
    assert pinned.count_rows("i = 1") == 1  # exp's post-fork row is visible
    assert pinned.count_rows("i = 100") == 0  # main's divergent rows are invisible

    # the same coordinate is reachable directly via branches.checkout(name, version)
    pinned_direct = t.branches.checkout("exp", exp_v2)
    assert pinned_direct.current_branch() == "exp"
    assert pinned_direct.count_rows() == 2

    # the HEADs are unaffected
    assert db.open_table("t", branch="exp").count_rows() == 3
    assert db.open_table("t").count_rows() == 4

    # version-only (no branch) time-travels main itself: its fork-point version
    # holds only main's first row, and the shared version number resolves to
    # main's data, not the branch's ("opens main at the version")
    old_main = db.open_table("t", version=main_v1)
    assert old_main.current_branch() is None
    assert old_main.count_rows() == 1
    shared_on_main = db.open_table("t", version=exp_v2)
    assert shared_on_main.current_branch() is None
    assert shared_on_main.count_rows() == 4

    # detached head: writing to a pinned version is rejected
    with pytest.raises((ValueError, RuntimeError), match="cannot be modified"):
        pinned.add([{"i": 9}])

    # a nonexistent version is rejected -- on main, and on a branch (a distinct
    # resolution path, on the branch's manifests)
    with pytest.raises((ValueError, RuntimeError)):
        db.open_table("t", version=9999)
    with pytest.raises((ValueError, RuntimeError)):
        db.open_table("t", branch="exp", version=9999)

    # checkout_latest re-attaches the pinned handle to the BRANCH's HEAD
    # (writable again), not main's HEAD, and not staying pinned
    pinned.checkout_latest()
    assert pinned.current_branch() == "exp"
    assert pinned.count_rows() == 3  # exp HEAD, not main's 4
    pinned.add([{"i": 3}])
    assert pinned.count_rows() == 4  # writable again


@pytest.mark.asyncio
async def test_async_namespace_open_table_with_branch(tmp_path):
    """The async namespace connection supports ``open_table(branch=...)``."""
    pytest.importorskip("lance")  # "dir" impl is lance.namespace.DirectoryNamespace
    db = lancedb.connect_namespace_async("dir", {"root": str(tmp_path)})
    await db.create_namespace(["ns1"])
    table = await db.create_table("t", [{"id": 1}], namespace_path=["ns1"])
    branch = await table.branches.create("exp")
    await branch.add([{"id": 2}])

    # open_table(branch=...) on the async namespace connection must work
    opened = await db.open_table("t", namespace_path=["ns1"], branch="exp")
    assert await opened.count_rows() == 2


def test_namespace_open_table_with_branch_version(tmp_path):
    """The namespace connection resolves ``(branch, version)`` to the
    branch's data at that version."""
    pytest.importorskip("lance")  # "dir" impl is lance.namespace.DirectoryNamespace
    db = lancedb.connect_namespace("dir", {"root": str(tmp_path)})
    db.create_namespace(["ns1"])
    t = db.create_table("t", [{"i": 0}], namespace_path=["ns1"])

    # fork "exp", then advance exp AND main past the fork so they diverge while
    # sharing version numbers
    exp = t.branches.create("exp")
    exp.add([{"i": 1}])
    exp_v2 = exp.version
    exp.add([{"i": 2}])
    t.add([{"i": 100}, {"i": 101}, {"i": 102}])
    assert exp_v2 == t.version, "branch and main must share the version number"

    # open_table(branch=, version=) on the namespace connection reads the
    # branch's data at that version, not main's
    pinned = db.open_table("t", namespace_path=["ns1"], branch="exp", version=exp_v2)
    assert pinned.current_branch() == "exp"
    assert pinned.count_rows() == 2  # not exp HEAD (3), not main@v2 (4)
    assert pinned.count_rows("i = 1") == 1  # exp's post-fork row is visible
    assert pinned.count_rows("i = 100") == 0  # main's divergent rows are invisible
    assert db.open_table("t", namespace_path=["ns1"], branch="exp").count_rows() == 3


@pytest.mark.asyncio
async def test_async_namespace_open_table_with_branch_version(tmp_path):
    """Async variant of :func:`test_namespace_open_table_with_branch_version`."""
    pytest.importorskip("lance")  # "dir" impl is lance.namespace.DirectoryNamespace
    db = lancedb.connect_namespace_async("dir", {"root": str(tmp_path)})
    await db.create_namespace(["ns1"])
    t = await db.create_table("t", [{"i": 0}], namespace_path=["ns1"])

    # fork "exp", then advance exp AND main past the fork so they diverge while
    # sharing version numbers
    exp = await t.branches.create("exp")
    await exp.add([{"i": 1}])
    exp_v2 = await exp.version()
    await exp.add([{"i": 2}])
    await t.add([{"i": 100}, {"i": 101}, {"i": 102}])
    assert exp_v2 == await t.version(), "branch and main must share the version number"

    # open_table(branch=, version=) on the async namespace connection reads the
    # branch's data at that version, not main's
    pinned = await db.open_table(
        "t", namespace_path=["ns1"], branch="exp", version=exp_v2
    )
    assert pinned.current_branch() == "exp"
    assert await pinned.count_rows() == 2  # not exp HEAD (3), not main@v2 (4)
    assert await pinned.count_rows("i = 1") == 1  # exp's post-fork row is visible
    assert await pinned.count_rows("i = 100") == 0  # main's rows are invisible
    assert (
        await (
            await db.open_table("t", namespace_path=["ns1"], branch="exp")
        ).count_rows()
        == 3
    )


def test_branch_to_lance_targets_branch(tmp_path):
    """``to_lance()`` on a branch handle reads the branch, not main."""
    pytest.importorskip("lance")
    db = lancedb.connect(tmp_path)
    table = db.create_table("t", [{"i": 1}])
    branch = table.branches.create("exp")
    branch.add([{"i": 2}])  # branch: 2 rows, main: 1 row

    assert branch.to_lance().count_rows() == 2
    assert table.to_lance().count_rows() == 1


@pytest.mark.asyncio
async def test_async_branches(tmp_path):
    """Async variant of :func:`test_branches`."""
    db = await lancedb.connect_async(tmp_path)
    table = await db.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    assert await table.count_rows() == 2

    branch = await table.branches.create("exp")
    assert await branch.count_rows() == 2
    await branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}])
    assert await branch.count_rows() == 3
    assert await table.count_rows() == 2

    branches = await table.branches.list()
    assert "exp" in branches
    assert branches["exp"]["parent_branch"] is None

    await table.branches.create("exp2", from_ref="main")
    assert (await table.branches.list())["exp2"]["parent_branch"] is None

    checked_out = await table.branches.checkout("exp")
    assert await checked_out.count_rows() == 3

    # delete removes both branches
    await table.branches.delete("exp")
    await table.branches.delete("exp2")
    remaining = await table.branches.list()
    assert "exp" not in remaining
    assert "exp2" not in remaining


@pytest.mark.asyncio
async def test_async_open_table_with_branch_version(tmp_path):
    """Async variant of :func:`test_open_table_with_branch_version`."""
    db = await lancedb.connect_async(tmp_path, read_consistency_interval=timedelta(0))
    # main: a single fork-point row
    t = await db.create_table("t", [{"i": 0}])
    main_v1 = await t.version()

    # fork "exp", then advance exp AND main independently past the fork so they
    # diverge while sharing version numbers
    exp = await t.branches.create("exp")
    await exp.add([{"i": 1}])  # exp: {0, 1}
    exp_v2 = await exp.version()
    await exp.add([{"i": 2}])  # exp HEAD: {0, 1, 2}
    await t.add([{"i": 100}, {"i": 101}, {"i": 102}])  # main HEAD: {0, 100, 101, 102}
    assert exp_v2 == await t.version(), "branch and main must share the version number"

    # open exp at the shared version: the data must be exp's, not main's. count
    # alone cannot prove this (main@v2 also exists), so assert provenance by
    # content.
    pinned = await db.open_table("t", branch="exp", version=exp_v2)
    assert pinned.current_branch() == "exp"
    assert await pinned.count_rows() == 2  # not exp HEAD (3), not main@v2 (4)
    assert await pinned.count_rows("i = 1") == 1  # exp's post-fork row is visible
    assert await pinned.count_rows("i = 100") == 0  # main's rows are invisible

    # the same coordinate is reachable directly via branches.checkout(name, version)
    pinned_direct = await t.branches.checkout("exp", exp_v2)
    assert pinned_direct.current_branch() == "exp"
    assert await pinned_direct.count_rows() == 2

    # the HEADs are unaffected
    assert await (await db.open_table("t", branch="exp")).count_rows() == 3
    assert await (await db.open_table("t")).count_rows() == 4

    # version-only (no branch) time-travels main itself: its fork-point version
    # holds only main's first row, and the shared version number resolves to
    # main's data, not the branch's ("opens main at the version")
    old_main = await db.open_table("t", version=main_v1)
    assert old_main.current_branch() is None
    assert await old_main.count_rows() == 1
    shared_on_main = await db.open_table("t", version=exp_v2)
    assert shared_on_main.current_branch() is None
    assert await shared_on_main.count_rows() == 4

    # detached head: writing to a pinned version is rejected
    with pytest.raises((ValueError, RuntimeError), match="cannot be modified"):
        await pinned.add([{"i": 9}])

    # a nonexistent version is rejected -- on main, and on a branch
    with pytest.raises((ValueError, RuntimeError)):
        await db.open_table("t", version=9999)
    with pytest.raises((ValueError, RuntimeError)):
        await db.open_table("t", branch="exp", version=9999)

    # checkout_latest re-attaches the pinned handle to the BRANCH's HEAD
    # (writable again), not main's HEAD, and not staying pinned
    await pinned.checkout_latest()
    assert pinned.current_branch() == "exp"
    assert await pinned.count_rows() == 3  # exp HEAD, not main's 4
    await pinned.add([{"i": 3}])
    assert await pinned.count_rows() == 4  # writable again
@patch("lancedb.table.AsyncTable.create_index")
def test_create_index_method(mock_create_index, mem_db: DBConnection):
    """Check how the legacy keyword form of ``create_index`` maps to configs.

    ``AsyncTable.create_index`` is patched out, so no index is built. Each case
    calls the sync ``create_index`` and then inspects the arguments of the most
    recent call the mock received.
    """
    table = mem_db.create_table(
        "test",
        data=[{"vector": [3.1, 4.1]}, {"vector": [5.9, 26.5]}],
    )

    # (kwargs for table.create_index, expected column, expected replace flag,
    #  expected index config forwarded to AsyncTable.create_index)
    cases = [
        # IVF_PQ with an explicit partition count
        (
            dict(
                metric="l2",
                num_partitions=256,
                num_sub_vectors=96,
                vector_column_name="vector",
                replace=True,
                index_cache_size=256,
                num_bits=4,
            ),
            "vector",
            True,
            IvfPq(
                distance_type="l2",
                num_partitions=256,
                num_sub_vectors=96,
                num_bits=4,
            ),
        ),
        # IVF_PQ sized through target_partition_size instead
        (
            dict(
                metric="l2",
                num_sub_vectors=96,
                vector_column_name="vector",
                replace=True,
                index_cache_size=256,
                num_bits=4,
                target_partition_size=8192,
            ),
            "vector",
            True,
            IvfPq(
                distance_type="l2",
                num_sub_vectors=96,
                num_bits=4,
                target_partition_size=8192,
            ),
        ),
        # target_partition_size has a default value, so neither it nor
        # num_partitions has to be given
        (
            dict(
                metric="l2",
                num_sub_vectors=96,
                vector_column_name="vector",
                replace=True,
                index_cache_size=256,
                num_bits=4,
            ),
            "vector",
            True,
            IvfPq(distance_type="l2", num_sub_vectors=96, num_bits=4),
        ),
        # HNSW + PQ on a differently named column, without replacing
        (
            dict(
                vector_column_name="my_vector",
                metric="dot",
                index_type="IVF_HNSW_PQ",
                replace=False,
            ),
            "my_vector",
            False,
            HnswPq(distance_type="dot"),
        ),
        # HNSW + SQ; replace is not passed, and the forwarded value is True
        (
            dict(
                vector_column_name="my_vector",
                metric="cosine",
                index_type="IVF_HNSW_SQ",
                sample_rate=0.1,
                m=29,
                ef_construction=10,
            ),
            "my_vector",
            True,
            HnswSq(
                distance_type="cosine",
                sample_rate=0.1,
                m=29,
                ef_construction=10,
            ),
        ),
        # HNSW without quantization; replace is again omitted
        (
            dict(
                vector_column_name="my_vector",
                metric="cosine",
                index_type="IVF_HNSW_FLAT",
                sample_rate=0.1,
                m=29,
                ef_construction=10,
            ),
            "my_vector",
            True,
            HnswFlat(
                distance_type="cosine",
                sample_rate=0.1,
                m=29,
                ef_construction=10,
            ),
        ),
    ]

    for kwargs, column, replace, config in cases:
        table.create_index(**kwargs)
        mock_create_index.assert_called_with(
            column,
            replace=replace,
            config=config,
            wait_timeout=None,
            name=None,
            train=True,
        )


@patch("lancedb.table.AsyncTable.create_index")
def test_create_index_name_and_train_parameters(
    mock_create_index, mem_db: DBConnection
):
    """``name`` and ``train`` must be forwarded to ``AsyncTable.create_index``.

    When a caller omits them, the forwarded values are ``name=None`` and
    ``train=True``.
    """
    table = mem_db.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "id": 1},
            {"vector": [5.9, 26.5], "id": 2},
        ],
    )
    default_config = IvfPq()  # no config is passed, so the default is expected

    # (extra kwargs passed, expected forwarded name, expected forwarded train)
    scenarios = (
        ({"name": "my_custom_index"}, "my_custom_index", True),
        ({"train": False}, None, False),
        ({"name": "my_index_name", "train": True}, "my_index_name", True),
    )
    for extra, expected_name, expected_train in scenarios:
        table.create_index(vector_column_name="vector", **extra)
        mock_create_index.assert_called_with(
            "vector",
            replace=True,
            config=default_config,
            wait_timeout=None,
            name=expected_name,
            train=expected_train,
        )


@patch("lancedb.table.AsyncTable.create_index")
def test_create_index_legacy_emits_deprecation_warning(
    mock_create_index, mem_db: DBConnection
):
    """Calling ``create_index`` with legacy keyword arguments warns."""
    rows = [{"vector": [3.1, 4.1]}, {"vector": [5.9, 26.5]}]
    table = mem_db.create_table("test", data=rows)

    # AsyncTable.create_index is patched, so no real index is built here
    with pytest.warns(DeprecationWarning, match="create_index"):
        table.create_index(
            metric="l2",
            num_partitions=8,
            vector_column_name="vector",
        )
@patch("lancedb.table.AsyncTable.create_index") def test_create_index_new_api(mock_create_index, mem_db: DBConnection): table = mem_db.create_table( "test", data=[ {"vector": [3.1, 4.1], "category": "a", "text": "hello world"}, {"vector": [5.9, 26.5], "category": "b", "text": "goodbye"}, ], ) # Vector index via new API should not warn with warnings.catch_warnings(): warnings.simplefilter("error", DeprecationWarning) table.create_index("vector", config=IvfPq(distance_type="l2")) mock_create_index.assert_called_with( "vector", replace=True, config=IvfPq(distance_type="l2"), wait_timeout=None, name=None, train=True, ) # Scalar index via new API table.create_index("category", config=BTree()) mock_create_index.assert_called_with( "category", replace=True, config=BTree(), wait_timeout=None, name=None, train=True, ) # FTS index via new API table.create_index("text", config=FTS(with_position=True)) mock_create_index.assert_called_with( "text", replace=True, config=FTS(with_position=True), wait_timeout=None, name=None, train=True, ) def test_create_with_nans(mem_db: DBConnection): # by default we raise an error on bad input vectors bad_data = [ {"vector": [np.nan], "item": "bar", "price": 20.0}, {"vector": [5], "item": "bar", "price": 20.0}, {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0}, ] for row in bad_data: with pytest.raises(ValueError): mem_db.create_table( "error_test", data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row], ) table = mem_db.create_table( "drop_test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [np.nan], "item": "bar", "price": 20.0}, {"vector": [5], "item": "bar", "price": 20.0}, {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, ], on_bad_vectors="drop", ) assert len(table) == 1 # We can fill bad input with some value table = mem_db.create_table( "fill_test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": 
[np.nan], "item": "bar", "price": 20.0}, {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, ], on_bad_vectors="fill", fill_value=0.0, ) assert len(table) == 3 arrow_tbl = table.search().where("item == 'bar'").to_arrow() v = arrow_tbl["vector"].to_pylist()[0] assert np.allclose(v, np.array([0.0, 0.0])) def test_add_with_nans(mem_db: DBConnection): schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 2), nullable=True), pa.field("item", pa.string(), nullable=True), pa.field("price", pa.float64(), nullable=False), ], ) table = mem_db.create_table("test", schema=schema) # by default we raise an error on bad input vectors bad_data = [ {"vector": [np.nan], "item": "bar", "price": 20.0}, {"vector": [5], "item": "bar", "price": 20.0}, {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0}, ] for row in bad_data: with pytest.raises(ValueError): table.add( data=[row], ) table.add( [ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [2.1, 4.1], "item": "foo", "price": 9.0}, {"vector": [np.nan], "item": "bar", "price": 20.0}, {"vector": [5], "item": "bar", "price": 20.0}, {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, ], on_bad_vectors="drop", ) assert len(table) == 2 table.delete("true") # We can fill bad input with some value table.add( data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [np.nan], "item": "bar", "price": 20.0}, {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, ], on_bad_vectors="fill", fill_value=0.0, ) assert len(table) == 3 arrow_tbl = table.search().where("item == 'bar'").to_arrow() v = arrow_tbl["vector"].to_pylist()[0] assert np.allclose(v, np.array([0.0, 0.0])) def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection): class Schema(LanceModel): text: str embedding: Vector(16) table = mem_db.create_table("test_empty_embeddings", schema=Schema) table.add( [ {"text": "hello", "embedding": []}, 
{"text": "bar", "embedding": [0.1] * 16}, ], on_bad_vectors="drop", ) data = table.to_arrow() assert data["text"].to_pylist() == ["bar"] assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16)) def test_add_nullable_struct_with_none(mem_db: DBConnection): """Regression test for issue #2654: a nullable struct column whose first batch contains only None values must not crash in _align_field_types with AttributeError: 'pyarrow.lib.DataType' object has no attribute 'fields'. PyArrow infers an all-None struct column as `null` (not `struct`), so the type-alignment path needs to handle the case where the source field type is null and use the target type directly. """ # Use the v2.1 file format so that nullable structs are supported. table = mem_db.create_table( "test_nullable_struct", schema=pa.schema( [ pa.field("id", pa.string()), pa.field( "data", pa.struct([pa.field("x", pa.float32())]), nullable=True, ), ] ), storage_options=dict(new_table_data_storage_version="2.1"), ) # Adding a row with a non-null struct should work. table.add([{"id": "1", "data": {"x": 1.0}}]) # Adding a row with None for the nullable struct field should also # work — this is what used to crash. 
table.add([{"id": "2", "data": None}]) result = table.to_arrow() assert result.num_rows == 2 assert result.column("id").to_pylist() == ["1", "2"] assert result.column("data").to_pylist() == [{"x": 1.0}, None] def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection): class Schema(LanceModel): text: str embedding: Vector(4) table = mem_db.create_table("test_integer_embeddings", schema=Schema) table.add( [{"text": "foo", "embedding": [1, 2, 3, 4]}], on_bad_vectors="drop", ) assert table.to_arrow()["embedding"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]] def test_on_bad_vectors_does_not_handle_non_vector_fixed_size_lists( mem_db: DBConnection, ): schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 4)), pa.field("bbox", pa.list_(pa.float32(), 4)), ] ) table = mem_db.create_table("test_bbox_schema", schema=schema) with pytest.raises(RuntimeError, match="FixedSizeListType"): table.add( [{"vector": [1.0, 2.0, 3.0, 4.0], "bbox": [0.0, 1.0]}], on_bad_vectors="drop", ) def test_on_bad_vectors_does_not_handle_custom_named_fixed_size_lists( mem_db: DBConnection, ): schema = pa.schema([pa.field("features", pa.list_(pa.float32(), 16))]) table = mem_db.create_table("test_custom_named_fixed_size_vector", schema=schema) with pytest.raises(RuntimeError, match="FixedSizeListType"): table.add( [ {"features": []}, {"features": [0.1] * 16}, ], on_bad_vectors="drop", ) def test_on_bad_vectors_with_schema_list_vector_still_sanitizes(mem_db: DBConnection): schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) table = mem_db.create_table("test_schema_list_vector", schema=schema) table.add( [ {"vector": [1.0, 2.0]}, {"vector": [3.0]}, {"vector": [4.0, 5.0]}, ], on_bad_vectors="drop", ) assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [4.0, 5.0]] def test_on_bad_vectors_handles_typed_custom_fixed_vectors_for_list_schema( mem_db: DBConnection, ): schema = pa.schema([pa.field("vec", pa.list_(pa.float32()))]) table = 
mem_db.create_table("test_typed_custom_fixed_vector", schema=schema) data = pa.table( { "vec": pa.array( [[float("nan")] * 16, [1.0] * 16], type=pa.list_(pa.float32(), 16), ) } ) table.add(data, on_bad_vectors="drop") assert table.to_arrow()["vec"].to_pylist() == [[1.0] * 16] def test_on_bad_vectors_fill_preserves_arrow_nested_vector_type(mem_db: DBConnection): schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) table = mem_db.create_table("test_fill_arrow_nested_type", schema=schema) data = pa.table( { "vector": pa.array( [[1.0, 2.0], [float("nan"), 3.0]], type=pa.list_(pa.float32(), 2), ) } ) table.add( data, on_bad_vectors="fill", fill_value=0.0, ) assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [0.0, 0.0]] @pytest.mark.parametrize( ("table_name", "batch1", "expected"), [ ( "test_schema_list_vector_empty_prefix", pa.record_batch({"vector": [[], []]}), [[], [], [1.0, 2.0], [3.0, 4.0]], ), ( "test_schema_list_vector_all_bad_prefix", pa.record_batch({"vector": [[float("nan")] * 3, [float("nan")] * 3]}), [[1.0, 2.0], [3.0, 4.0]], ), ], ) def test_on_bad_vectors_with_schema_list_vector_ignores_invalid_prefix_batches( mem_db: DBConnection, table_name: str, batch1: pa.RecordBatch, expected: list, ): schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) table = mem_db.create_table(table_name, schema=schema) batch2 = pa.record_batch({"vector": [[1.0, 2.0], [3.0, 4.0]]}) reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2]) table.add(reader, on_bad_vectors="drop") assert table.to_arrow()["vector"].to_pylist() == expected def test_on_bad_vectors_with_multiple_vectors_locks_dim_after_final_drop( mem_db: DBConnection, ): registry = EmbeddingFunctionRegistry.get_instance() func = MockTextEmbeddingFunction.create() metadata = registry.get_table_metadata( [ EmbeddingFunctionConfig( source_column="text1", vector_column="vec1", function=func ), EmbeddingFunctionConfig( source_column="text2", vector_column="vec2", 
function=func ), ] ) schema = pa.schema( [ pa.field("vec1", pa.list_(pa.float32())), pa.field("vec2", pa.list_(pa.float32())), ], metadata=metadata, ) table = mem_db.create_table("test_multi_vector_dim_lock", schema=schema) batch1 = pa.record_batch( { "vec1": [[1.0, 2.0, 3.0], [10.0, 11.0]], "vec2": [[float("nan"), 0.0], [5.0, 6.0]], } ) batch2 = pa.record_batch( { "vec1": [[20.0, 21.0], [30.0, 31.0]], "vec2": [[7.0, 8.0], [9.0, 10.0]], } ) reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2]) table.add(reader, on_bad_vectors="drop") data = table.to_arrow() assert data["vec1"].to_pylist() == [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]] assert data["vec2"].to_pylist() == [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]] def test_on_bad_vectors_does_not_handle_non_vector_list_columns(mem_db: DBConnection): schema = pa.schema([pa.field("embedding_history", pa.list_(pa.float32()))]) table = mem_db.create_table("test_non_vector_list_schema", schema=schema) table.add( [ {"embedding_history": [1.0, 2.0]}, {"embedding_history": [3.0]}, ], on_bad_vectors="drop", ) assert table.to_arrow()["embedding_history"].to_pylist() == [ [1.0, 2.0], [3.0], ] def test_on_bad_vectors_all_null_schema_vector_batches_do_not_crash( mem_db: DBConnection, ): schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2), nullable=True)]) table = mem_db.create_table("test_all_null_vector_batch", schema=schema) table.add([{"vector": None}], on_bad_vectors="drop") assert table.to_arrow()["vector"].to_pylist() == [None] def test_restore(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[{"vector": [1.1, 0.9], "type": "vector"}], ) table.add([{"vector": [0.5, 0.2], "type": "vector"}]) table.restore(1) assert len(table.list_versions()) == 3 assert len(table) == 1 expected = table.to_arrow() table.checkout(1) table.restore() assert len(table.list_versions()) == 4 assert table.to_arrow() == expected table.restore(4) # latest version should be no-op assert 
len(table.list_versions()) == 5 with pytest.raises(ValueError): table.restore(6) with pytest.raises(ValueError): table.restore(0) def test_restore_with_tags(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[{"vector": [1.1, 0.9], "type": "vector"}], ) tag = "tag1" table.tags.create(tag, 1) table.add([{"vector": [0.5, 0.2], "type": "vector"}]) table.restore(tag) assert len(table.list_versions()) == 3 assert len(table) == 1 expected = table.to_arrow() table.add([{"vector": [0.3, 0.3], "type": "vector"}]) table.checkout("tag1") table.restore() assert len(table.list_versions()) == 5 assert table.to_arrow() == expected with pytest.raises(ValueError): table.restore("tag_unknown") def test_merge(tmp_db: DBConnection, tmp_path): pytest.importorskip("lance") import lance table = tmp_db.create_table( "my_table", schema=pa.schema( { "vector": pa.list_(pa.float32(), 2), "id": pa.int64(), } ), ) table.add([{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}]) other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]}) table.merge(other_table, left_on="id") assert len(table.list_versions()) == 3 expected = pa.table( {"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]}, schema=table.schema, ) assert table.to_arrow() == expected other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance") table.restore(1) table.merge(other_dataset, left_on="id") def test_delete(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], ) assert len(table) == 2 assert len(table.list_versions()) == 1 delete_res = table.delete("id=0") assert delete_res.version == 2 assert len(table.list_versions()) == 2 assert table.version == 2 assert len(table) == 1 assert table.to_arrow()["id"].to_pylist() == [1] def test_delete_expr(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[ {"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 
1.9], "id": 1}, {"vector": [1.3, 2.9], "id": 2}, ], ) assert len(table) == 3 delete_res = table.delete(col("id") == lit(0)) assert delete_res.version == 2 assert len(table) == 2 assert sorted(table.to_arrow()["id"].to_pylist()) == [1, 2] @pytest.mark.asyncio async def test_delete_expr_async(mem_db_async: AsyncConnection): table = await mem_db_async.create_table( "my_table", data=[ {"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}, {"vector": [1.3, 2.9], "id": 2}, ], ) assert await table.count_rows() == 3 await table.delete(col("id") == lit(0)) assert await table.count_rows() == 2 assert sorted((await table.to_arrow())["id"].to_pylist()) == [1, 2] def test_update(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], ) assert len(table) == 2 assert len(table.list_versions()) == 1 update_res = table.update(where="id=0", values={"vector": [1.1, 1.1]}) assert update_res.version == 2 assert update_res.rows_updated == 1 assert len(table.list_versions()) == 2 assert table.version == 2 assert len(table) == 2 v = table.to_arrow()["vector"].combine_chunks() v = v.values.to_numpy().reshape(2, 2) assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]])) def test_update_types(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[ { "id": 0, "str": "foo", "float": 1.1, "timestamp": datetime(2021, 1, 1), "date": date(2021, 1, 1), "vector1": [1.0, 0.0], "vector2": [1.0, 1.0], "binary": b"abc", } ], ) # Update with SQL table.update( values_sql=dict( id="1", str="'bar'", float="2.2", timestamp="TIMESTAMP '2021-01-02 00:00:00'", date="DATE '2021-01-02'", vector1="[2.0, 2.0]", vector2="[3.0, 3.0]", binary="X'646566'", ) ) actual = table.to_arrow().to_pylist()[0] expected = dict( id=1, str="bar", float=2.2, timestamp=datetime(2021, 1, 2), date=date(2021, 1, 2), vector1=[2.0, 2.0], vector2=[3.0, 3.0], binary=b"def", ) assert actual == expected # Update with values 
table.update( values=dict( id=2, str="baz", float=3.3, timestamp=datetime(2021, 1, 3), date=date(2021, 1, 3), vector1=[3.0, 3.0], vector2=np.array([4.0, 4.0]), binary=b"def", ) ) actual = table.to_arrow().to_pylist()[0] expected = dict( id=2, str="baz", float=3.3, timestamp=datetime(2021, 1, 3), date=date(2021, 1, 3), vector1=[3.0, 3.0], vector2=[4.0, 4.0], binary=b"def", ) assert actual == expected def test_merge_insert(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}), ) assert len(table) == 3 version = table.version new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) # upsert merge_insert_res = ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .execute(new_data, timeout=timedelta(seconds=10)) ) assert merge_insert_res.version == 2 assert merge_insert_res.num_inserted_rows == 1 assert merge_insert_res.num_updated_rows == 2 assert merge_insert_res.num_deleted_rows == 0 expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]}) assert table.to_arrow().sort_by("a") == expected table.restore(version) # conditional update merge_insert_res = ( table.merge_insert("a") .when_matched_update_all(where="target.b = 'b'") .execute(new_data) ) assert merge_insert_res.version == 4 assert merge_insert_res.num_inserted_rows == 0 assert merge_insert_res.num_updated_rows == 1 assert merge_insert_res.num_deleted_rows == 0 expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]}) assert table.to_arrow().sort_by("a") == expected table.restore(version) # insert-if-not-exists merge_insert_res = ( table.merge_insert("a").when_not_matched_insert_all().execute(new_data) ) assert merge_insert_res.version == 6 assert merge_insert_res.num_inserted_rows == 1 assert merge_insert_res.num_updated_rows == 0 assert merge_insert_res.num_deleted_rows == 0 expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]}) assert table.to_arrow().sort_by("a") == expected 
table.restore(version) new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) # replace-range merge_insert_res = ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .when_not_matched_by_source_delete("a > 2") .execute(new_data) ) assert merge_insert_res.version == 8 assert merge_insert_res.num_inserted_rows == 1 assert merge_insert_res.num_updated_rows == 1 assert merge_insert_res.num_deleted_rows == 1 expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) assert table.to_arrow().sort_by("a") == expected table.restore(version) # replace-range no condition merge_insert_res = ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .when_not_matched_by_source_delete() .execute(new_data) ) assert merge_insert_res.version == 10 assert merge_insert_res.num_inserted_rows == 1 assert merge_insert_res.num_updated_rows == 1 assert merge_insert_res.num_deleted_rows == 2 expected = pa.table({"a": [2, 4], "b": ["x", "z"]}) assert table.to_arrow().sort_by("a") == expected # timeout with pytest.raises(Exception, match="merge insert timed out"): table.merge_insert("a").when_matched_update_all().execute( new_data, timeout=timedelta(0) ) def test_merge_insert_by_source_delete_expr(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}), ) new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) # replace-range, limiting the source-absent delete with an Expr condition merge_insert_res = ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .when_not_matched_by_source_delete(col("a") > lit(2)) .execute(new_data) ) assert merge_insert_res.num_inserted_rows == 1 assert merge_insert_res.num_updated_rows == 1 assert merge_insert_res.num_deleted_rows == 1 expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) assert table.to_arrow().sort_by("a") == expected @pytest.mark.asyncio async def test_merge_insert_by_source_delete_expr_async( 
mem_db_async: AsyncConnection, ): data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) table = await mem_db_async.create_table("some_table", data=data) new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) # replace-range, limiting the source-absent delete with an Expr condition await ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .when_not_matched_by_source_delete(col("a") > lit(2)) .execute(new_data) ) expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) assert (await table.to_arrow()).sort_by("a") == expected # We vary the data format because there are slight differences in how # subschemas are handled in different formats @pytest.mark.parametrize( "data_format", [ lambda table: table, lambda table: table.to_pandas(), lambda table: table.to_pylist(), ], ids=["pa.Table", "pd.DataFrame", "rows"], ) def test_merge_insert_subschema(mem_db: DBConnection, data_format): pytest.importorskip("pandas") initial_data = pa.table( {"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]} ) table = mem_db.create_table("my_table", data=initial_data) new_data = pa.table({"id": [2, 3], "c": ["y", "y"]}) new_data = data_format(new_data) ( table.merge_insert(on="id") .when_matched_update_all() .when_not_matched_insert_all() .execute(new_data) ) expected = pa.table( {"id": [0, 1, 2, 3], "a": [1.0, 2.0, 3.0, None], "c": ["x", "x", "y", "y"]} ) assert table.to_arrow().sort_by("id") == expected @pytest.mark.asyncio async def test_merge_insert_async(mem_db_async: AsyncConnection): data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) table = await mem_db_async.create_table("some_table", data=data) assert await table.count_rows() == 3 version = await table.version() new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) # upsert await ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .execute(new_data) ) expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]}) assert (await 
table.to_arrow()).sort_by("a") == expected await table.checkout(version) await table.restore() # conditional update await ( table.merge_insert("a") .when_matched_update_all(where="target.b = 'b'") .execute(new_data) ) expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]}) assert (await table.to_arrow()).sort_by("a") == expected await table.checkout(version) await table.restore() # insert-if-not-exists await table.merge_insert("a").when_not_matched_insert_all().execute(new_data) expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]}) assert (await table.to_arrow()).sort_by("a") == expected await table.checkout(version) await table.restore() # replace-range new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) await ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .when_not_matched_by_source_delete("a > 2") .execute(new_data) ) expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) assert (await table.to_arrow()).sort_by("a") == expected await table.checkout(version) await table.restore() # replace-range no condition await ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .when_not_matched_by_source_delete() .execute(new_data) ) expected = pa.table({"a": [2, 4], "b": ["x", "z"]}) assert (await table.to_arrow()).sort_by("a") == expected def test_create_with_embedding_function(mem_db: DBConnection): class MyTable(LanceModel): text: str vector: Vector(10) func = MockTextEmbeddingFunction.create() texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] df = pa.table({"text": texts, "vector": func.compute_source_embeddings(texts)}) conf = EmbeddingFunctionConfig( source_column="text", vector_column="vector", function=func ) table = mem_db.create_table( "my_table", schema=MyTable, embedding_functions=[conf], ) table.add(df) query_str = "hi how are you?" 
query_vector = func.compute_query_embeddings(query_str)[0] expected = table.search(query_vector).limit(2).to_arrow() actual = table.search(query_str).limit(2).to_arrow() assert actual == expected def test_create_f16_table(mem_db: DBConnection): class MyTable(LanceModel): text: str vector: Vector(32, value_type=pa.float16()) df = pa.table( { "text": [f"s-{i}" for i in range(512)], "vector": [np.random.randn(32).astype(np.float16) for _ in range(512)], } ) table = mem_db.create_table( "f16_tbl", schema=MyTable, ) table.add(df) table.create_index(num_partitions=2, num_sub_vectors=2) query = df["vector"][2].as_py() expected = table.search(query).limit(2).to_arrow() assert "s-2" in expected["text"].to_pylist() def test_add_with_embedding_function(mem_db: DBConnection): emb = EmbeddingFunctionRegistry.get_instance().get("test").create() class MyTable(LanceModel): text: str = emb.SourceField() vector: Vector(emb.ndims()) = emb.VectorField() table = mem_db.create_table("my_table", schema=MyTable) texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] df = pa.table({"text": texts}) table.add(df) texts = ["the quick brown fox", "jumped over the lazy dog"] table.add([{"text": t} for t in texts]) query_str = "hi how are you?" 
query_vector = emb.compute_query_embeddings(query_str)[0] expected = table.search(query_vector).limit(2).to_arrow() actual = table.search(query_str).limit(2).to_arrow() assert actual == expected def test_multiple_vector_columns(mem_db: DBConnection): class MyTable(LanceModel): text: str vector1: Vector(10) vector2: Vector(10) table = mem_db.create_table( "my_table", schema=MyTable, ) v1 = np.random.randn(10) v2 = np.random.randn(10) data = [ {"vector1": v1, "vector2": v2, "text": "foo"}, {"vector1": v2, "vector2": v1, "text": "bar"}, ] df = pa.Table.from_pylist(data) table.add(df) q = np.random.randn(10) result1 = table.search(q, vector_column_name="vector1").limit(1).to_arrow() result2 = table.search(q, vector_column_name="vector2").limit(1).to_arrow() assert result1["text"][0] != result2["text"][0] def test_create_scalar_index(mem_db: DBConnection): vec_array = pa.array( [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], pa.list_(pa.float32(), 2) ) test_data = pa.Table.from_pydict( {"x": ["c", "b", "a", "e", "b"], "y": [1, 2, 3, 4, 5], "vector": vec_array} ) table = mem_db.create_table( "my_table", data=test_data, ) # Test with default name; confirm DeprecationWarning fires with pytest.warns(DeprecationWarning, match="create_scalar_index"): table.create_scalar_index("x") indices = table.list_indices() assert len(indices) == 1 scalar_index = indices[0] assert scalar_index.index_type == "BTree" assert scalar_index.name == "x_idx" # Default name # Confirm that prefiltering still works with the scalar index column results = table.search().where("x = 'c'").to_arrow() assert results == test_data.slice(0, 1) results = table.search([5, 5]).to_arrow() assert results["_distance"][0].as_py() == 0 results = table.search([5, 5]).where("x != 'b'").to_arrow() assert results["_distance"][0].as_py() > 0 table.drop_index(scalar_index.name) indices = table.list_indices() assert len(indices) == 0 # Test with custom name table.create_scalar_index("y", name="custom_y_index") indices = 
table.list_indices() assert len(indices) == 1 scalar_index = indices[0] assert scalar_index.index_type == "BTree" assert scalar_index.name == "custom_y_index" def test_create_index_nested_field_paths(mem_db: DBConnection): schema = pa.schema( [ pa.field("rowId", pa.int32()), pa.field("row-id", pa.int32()), pa.field("userId", pa.int32()), pa.field("metadata", pa.struct([pa.field("user_id", pa.int32())])), pa.field("MetaData", pa.struct([pa.field("userId", pa.int32())])), pa.field( "image", pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]), ), pa.field("payload", pa.struct([pa.field("text", pa.string())])), pa.field("meta-data", pa.struct([pa.field("user-id", pa.int32())])), pa.field("literal", pa.struct([pa.field("a.b", pa.int32())])), ] ) data = pa.Table.from_pylist( [ { "rowId": i, "row-id": i, "userId": i, "metadata": {"user_id": i}, "MetaData": {"userId": i}, "image": {"embedding": [float(i), float(i + 1)]}, "payload": {"text": f"document {i}"}, "meta-data": {"user-id": i}, "literal": {"a.b": i}, } for i in range(256) ], schema=schema, ) table = mem_db.create_table("nested_index_paths", data=data) table.create_scalar_index("rowId", name="row_id_idx") table.create_scalar_index("`row-id`", name="row_dash_id_idx") table.create_scalar_index("userId", name="top_user_id_idx") table.create_scalar_index("metadata.user_id", name="metadata_user_id_idx") table.create_scalar_index("MetaData.userId", name="mixed_case_metadata_user_id_idx") table.create_scalar_index("`meta-data`.`user-id`", name="escaped_names_idx") table.create_scalar_index("literal.`a.b`", name="literal_dot_idx") table.create_index( vector_column_name="image.embedding", num_partitions=1, num_sub_vectors=1, name="image_embedding_idx", ) table.create_fts_index("payload.text", with_position=False, name="payload_text_idx") indices = sorted(table.list_indices(), key=lambda idx: idx.name) assert [(idx.name, idx.index_type, idx.columns) for idx in indices] == [ ("escaped_names_idx", "BTree", 
["`meta-data`.`user-id`"]), ("image_embedding_idx", "IvfPq", ["image.embedding"]), ("literal_dot_idx", "BTree", ["literal.`a.b`"]), ("metadata_user_id_idx", "BTree", ["metadata.user_id"]), ("mixed_case_metadata_user_id_idx", "BTree", ["MetaData.userId"]), ("payload_text_idx", "FTS", ["payload.text"]), ("row_dash_id_idx", "BTree", ["`row-id`"]), ("row_id_idx", "BTree", ["rowId"]), ("top_user_id_idx", "BTree", ["userId"]), ] for index in indices: stats = table.index_stats(index.name) assert stats is not None assert stats.num_indexed_rows == 256 vector_results = ( table.search([0.0, 1.0], vector_column_name="image.embedding") .limit(1) .to_list() ) assert len(vector_results) == 1 assert vector_results[0]["metadata"]["user_id"] == 0 default_vector_results = table.search([0.0, 1.0]).limit(1).to_list() assert len(default_vector_results) == 1 assert default_vector_results[0]["metadata"]["user_id"] == 0 filtered_results = table.search().where("metadata.user_id = 42").limit(1).to_list() assert len(filtered_results) == 1 assert filtered_results[0]["metadata"]["user_id"] == 42 escaped_results = table.search().where("`row-id` = 43").limit(1).to_list() assert len(escaped_results) == 1 assert escaped_results[0]["row-id"] == 43 fts_results = table.search("document 44", query_type="fts").limit(1).to_list() assert len(fts_results) == 1 assert fts_results[0]["payload"]["text"] == "document 44" def test_index_config_fields(mem_db: DBConnection): """Test that IndexConfig exposes the new rich metadata fields.""" vec_array = pa.array( [[float(i), float(i + 1)] for i in range(300)], pa.list_(pa.float32(), 2) ) data = pa.Table.from_pydict({"x": list(range(300)), "vector": vec_array}) table = mem_db.create_table("index_config_fields", data=data) table.create_scalar_index("x", index_type="BTREE") table.create_index( vector_column_name="vector", num_partitions=1, num_sub_vectors=1, ) indices = {idx.name: idx for idx in table.list_indices()} scalar_idx = indices["x_idx"] assert 
scalar_idx.index_uuid is not None assert isinstance(scalar_idx.index_uuid, str) assert scalar_idx.num_indexed_rows is not None assert scalar_idx.num_indexed_rows == 300 assert scalar_idx.num_unindexed_rows is not None assert scalar_idx.num_unindexed_rows == 0 assert scalar_idx.num_segments is not None assert scalar_idx.num_segments >= 1 assert scalar_idx.size_bytes is not None assert scalar_idx.size_bytes > 0 assert scalar_idx.created_at is not None from datetime import datetime, timezone assert isinstance(scalar_idx.created_at, datetime) assert scalar_idx.created_at.tzinfo == timezone.utc # __getitem__ compatibility assert scalar_idx["index_uuid"] == scalar_idx.index_uuid assert scalar_idx["num_indexed_rows"] == scalar_idx.num_indexed_rows assert scalar_idx["created_at"] == scalar_idx.created_at # index_details is parsed from JSON into a Python object assert scalar_idx.index_details is not None assert isinstance(scalar_idx.index_details, dict) assert scalar_idx["index_details"] == scalar_idx.index_details vector_idx = indices["vector_idx"] assert vector_idx.index_uuid is not None assert vector_idx.num_indexed_rows == 300 assert isinstance(vector_idx.index_details, dict) def test_empty_query(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}], ) df = table.search().select(["id"]).where("text='bar'").limit(1).to_arrow() val = df["id"][0].as_py() assert val == 1 table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)]) df = table.search().select(["id"]).to_arrow() assert df.num_rows == 100 # None is the same as default df = table.search().select(["id"]).limit(None).to_arrow() assert df.num_rows == 100 # invalid limist is the same as None, wihch is the same as default df = table.search().select(["id"]).limit(-1).to_arrow() assert df.num_rows == 100 # valid limit should work df = table.search().select(["id"]).limit(42).to_arrow() assert df.num_rows == 42 def 
test_search_with_schema_inf_single_vector(mem_db: DBConnection): class MyTable(LanceModel): text: str vector_col: Vector(10) table = mem_db.create_table( "my_table", schema=MyTable, ) v1 = np.random.randn(10) v2 = np.random.randn(10) data = [ {"vector_col": v1, "text": "foo"}, {"vector_col": v2, "text": "bar"}, ] df = pa.Table.from_pylist(data) table.add(df) q = np.random.randn(10) result1 = table.search(q, vector_column_name="vector_col").limit(1).to_arrow() result2 = table.search(q).limit(1).to_arrow() assert result1["text"][0].as_py() == result2["text"][0].as_py() def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection): class MyTable(LanceModel): text: str vector1: Vector(10) vector2: Vector(10) table = mem_db.create_table( "my_table", schema=MyTable, ) v1 = np.random.randn(10) v2 = np.random.randn(10) data = [ {"vector1": v1, "vector2": v2, "text": "foo"}, {"vector1": v2, "vector2": v1, "text": "bar"}, ] df = pa.Table.from_pylist(data) table.add(df) q = np.random.randn(10) with pytest.raises(ValueError): table.search(q).limit(1).to_arrow() def test_search_infers_single_nested_vector(mem_db: DBConnection): schema = pa.schema( [ pa.field("id", pa.int32()), pa.field( "image", pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]), ), ] ) data = pa.Table.from_pylist( [ {"id": 0, "image": {"embedding": [0.0, 1.0]}}, {"id": 1, "image": {"embedding": [10.0, 11.0]}}, ], schema=schema, ) table = mem_db.create_table("nested_vector_default_search", data=data) result = table.search([0.0, 1.0]).limit(1).to_list() assert result[0]["id"] == 0 def test_search_nested_vector_multiple_candidates(mem_db: DBConnection): schema = pa.schema( [ pa.field( "image", pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]), ), pa.field( "text", pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]), ), ] ) data = pa.Table.from_pylist( [ { "image": {"embedding": [0.0, 1.0]}, "text": {"embedding": [2.0, 3.0]}, } ], schema=schema, ) table = 
mem_db.create_table("nested_vector_multiple_candidates", data=data) with pytest.raises(ValueError, match="image.embedding.*text.embedding"): table.search([0.0, 1.0]).limit(1).to_arrow() def test_search_nested_vector_no_candidates(mem_db: DBConnection): schema = pa.schema( [ pa.field("id", pa.int32()), pa.field("metadata", pa.struct([pa.field("label", pa.string())])), ] ) data = pa.Table.from_pylist( [{"id": 0, "metadata": {"label": "cat"}}], schema=schema, ) table = mem_db.create_table("nested_vector_no_candidates", data=data) with pytest.raises(ValueError, match="no vector column"): table.search([0.0, 1.0]).limit(1).to_arrow() def test_compact_cleanup(tmp_db: DBConnection): pytest.importorskip("lance") table = tmp_db.create_table( "my_table", data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}], ) table.add([{"text": "baz", "id": 2}]) assert len(table) == 3 assert table.version == 2 stats = table.compact_files() assert len(table) == 3 # Compact_files bump 2 versions. assert table.version == 4 assert stats.fragments_removed > 0 assert stats.fragments_added == 1 stats = table.cleanup_old_versions() assert stats.bytes_removed == 0 stats = table.cleanup_old_versions(older_than=timedelta(0), delete_unverified=True) assert stats.bytes_removed > 0 assert table.version == 4 with pytest.raises(Exception, match="Version 3 no longer exists"): table.checkout(3) def test_count_rows(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}], ) assert len(table) == 2 assert table.count_rows() == 2 assert table.count_rows(filter="text='bar'") == 1 def setup_hybrid_search_table(db: DBConnection, embedding_func): # Create a LanceDB table schema with a vector and a text column emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func).create() class MyTable(LanceModel): text: str = emb.SourceField() vector: Vector(emb.ndims()) = emb.VectorField() # Initialize the table using the schema table = 
db.create_table( "my_table", schema=MyTable, ) # Create a list of 10 unique english phrases phrases = [ "great kid don't get cocky", "now that's a name I haven't heard in a long time", "if you strike me down I shall become more powerful than you imagine", "I find your lack of faith disturbing", "I've got a bad feeling about this", "never tell me the odds", "I am your father", "somebody has to save our skins", "New strategy R2 let the wookiee win", "Arrrrggghhhhhhh", ] # Add the phrases and vectors to the table table.add([{"text": p} for p in phrases]) # Create a fts index table.create_fts_index("text", with_position=True) return table, MyTable, emb def test_hybrid_search(tmp_db: DBConnection): # This test uses an FTS index pytest.importorskip("lance") table, MyTable, emb = setup_hybrid_search_table(tmp_db, "test") result1 = ( table.search("Our father who art in heaven", query_type="hybrid") .rerank(normalize="score") .to_pydantic(MyTable) ) result2 = ( # noqa table.search("Our father who art in heaven", query_type="hybrid") .rerank(normalize="rank") .to_pydantic(MyTable) ) result3 = table.search( "Our father who art in heaven", query_type="hybrid" ).to_pydantic(MyTable) # Test that double and single quote characters are handled with phrase_query() ( table.search( '"Aren\'t you a little short for a stormtrooper?" 
-- Leia', query_type="hybrid", ) .phrase_query(True) .to_pydantic(MyTable) ) assert result1 == result3 # with post filters result = ( table.search("Arrrrggghhhhhhh", query_type="hybrid") .where("text='Arrrrggghhhhhhh'") .to_list() ) assert len(result) == 1 # with explicit query type vector_query = list(range(emb.ndims())) result = ( table.search(query_type="hybrid") .vector(vector_query) .text("Arrrrggghhhhhhh") .to_arrow() ) assert len(result) > 0 assert "_relevance_score" in result.column_names # with vector_column_name result = ( table.search(query_type="hybrid", vector_column_name="vector") .vector(vector_query) .text("Arrrrggghhhhhhh") .to_arrow() ) assert len(result) > 0 assert "_relevance_score" in result.column_names # fail if only text or vector is provided with pytest.raises(ValueError): table.search(query_type="hybrid").to_list() with pytest.raises(ValueError): table.search(query_type="hybrid").vector(vector_query).to_list() with pytest.raises(ValueError): table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list() def test_hybrid_search_metric_type(tmp_db: DBConnection): # This test uses an FTS index pytest.importorskip("lance") # Need to use nonnorm as the embedding function so l2 and dot results # are different table, _, _ = setup_hybrid_search_table(tmp_db, "nonnorm") # with custom metric result_dot = ( table.search("feeling lucky", query_type="hybrid") .distance_type("dot") .to_arrow() ) result_l2 = table.search("feeling lucky", query_type="hybrid").to_arrow() assert len(result_dot) > 0 assert len(result_l2) > 0 assert result_dot["_relevance_score"] != result_l2["_relevance_score"] @pytest.mark.parametrize( "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)] ) @pytest.mark.skipif( sys.platform == "win32", reason=( "TODO: directory namespace is not supported on Windows yet; " "re-enable after that is fixed." 
), ) def test_consistency(tmp_path, consistency_interval): db = lancedb.connect(tmp_path) table = db.create_table("my_table", data=[{"id": 0}]) db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval) table2 = db2.open_table("my_table") if consistency_interval is not None: assert "read_consistency_interval=datetime.timedelta(" in repr(db2) assert "read_consistency_interval=datetime.timedelta(" in repr(table2) assert table2.version == table.version table.add([{"id": 1}]) if consistency_interval is None: assert table2.version == table.version - 1 table2.checkout_latest() assert table2.version == table.version elif consistency_interval == timedelta(seconds=0): assert table2.version == table.version else: assert table2.version == table.version - 1 sleep(0.1) assert table2.version == table.version def test_restore_consistency(tmp_path): db = lancedb.connect(tmp_path) table = db.create_table("my_table", data=[{"id": 0}]) assert table.version == 1 db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) table2 = db2.open_table("my_table") assert table2.version == table.version # If we call checkout, it should lose consistency table2.checkout(table.version) table.add([{"id": 2}]) assert table2.version == 1 # But if we call checkout_latest, it should be consistent again table2.checkout_latest() assert table2.version == table.version # Schema evolution def test_add_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1]}) table = LanceTable.create(mem_db, "my_table", data=data) add_columns_res = table.add_columns({"new_col": "id + 2"}) assert add_columns_res.version == 2 assert table.to_arrow().column_names == ["id", "new_col"] assert table.to_arrow()["new_col"].to_pylist() == [2, 3] add_columns_res = table.add_columns({"null_int": "cast(null as bigint)"}) assert add_columns_res.version == 3 assert table.schema.field("null_int").type == pa.int64() @pytest.mark.asyncio async def test_add_columns_async(mem_db_async: 
AsyncConnection): data = pa.table({"id": [0, 1]}) table = await mem_db_async.create_table("my_table", data=data) add_columns_res = await table.add_columns({"new_col": "id + 2"}) assert add_columns_res.version == 2 data = await table.to_arrow() assert data.column_names == ["id", "new_col"] assert data["new_col"].to_pylist() == [2, 3] @pytest.mark.asyncio async def test_add_columns_with_schema(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1]}) table = await mem_db_async.create_table("my_table", data=data) add_columns_res = await table.add_columns( [pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8))] ) assert add_columns_res.version == 2 assert await table.schema() == pa.schema( [ pa.field("id", pa.int64()), pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8)), ] ) table = await mem_db_async.create_table("table2", data=data) add_columns_res = await table.add_columns( pa.schema( [pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8))] ) ) assert add_columns_res.version == 2 assert await table.schema() == pa.schema( [ pa.field("id", pa.int64()), pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8)), ] ) def test_alter_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1]}) table = mem_db.create_table("my_table", data=data) alter_columns_res = table.alter_columns({"path": "id", "rename": "new_id"}) assert alter_columns_res.version == 2 assert table.to_arrow().column_names == ["new_id"] def test_update_field_metadata(mem_db: DBConnection): data = pa.table({"id": [0, 1], "category": ["a", "b"]}) table = mem_db.create_table("my_table", data=data) res = table.update_field_metadata( {"path": "category", "metadata": {"unit": "label", "pii": "false"}} ) assert res.version == 2 # Arrow field metadata is bytes-keyed assert table.schema.field("category").metadata == { b"unit": b"label", b"pii": b"false", } # merge: add a key, delete one via None, keep the rest table.update_field_metadata( 
{"path": "category", "metadata": {"source": "import", "pii": None}} ) assert table.schema.field("category").metadata == { b"unit": b"label", b"source": b"import", } @pytest.mark.asyncio async def test_alter_columns_async(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1]}) table = await mem_db_async.create_table("my_table", data=data) alter_columns_res = await table.alter_columns({"path": "id", "rename": "new_id"}) assert alter_columns_res.version == 2 assert (await table.to_arrow()).column_names == ["new_id"] alter_columns_res = await table.alter_columns( dict(path="new_id", data_type=pa.int16(), nullable=True) ) assert alter_columns_res.version == 3 data = await table.to_arrow() assert data.column(0).type == pa.int16() assert data.schema.field(0).nullable def test_drop_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1], "category": ["a", "b"]}) table = mem_db.create_table("my_table", data=data) drop_columns_res = table.drop_columns(["category"]) assert drop_columns_res.version == 2 assert table.to_arrow().column_names == ["id"] @pytest.mark.asyncio async def test_drop_columns_async(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1], "category": ["a", "b"]}) table = await mem_db_async.create_table("my_table", data=data) drop_columns_res = await table.drop_columns(["category"]) assert drop_columns_res.version == 2 assert (await table.to_arrow()).column_names == ["id"] @pytest.mark.asyncio async def test_time_travel(mem_db_async: AsyncConnection): # Setup table = await mem_db_async.create_table("some_table", data=[{"id": 0}]) version = await table.version() await table.add([{"id": 1}]) assert await table.count_rows() == 2 # Make sure we can rewind await table.checkout(version) assert await table.count_rows() == 1 # Can't add data in time travel mode with pytest.raises( ValueError, match="table cannot be modified when a specific version is checked out", ): await table.add([{"id": 2}]) # Can go back to normal mode await 
table.checkout_latest() assert await table.count_rows() == 2 # Should be able to add data again await table.add([{"id": 3}]) assert await table.count_rows() == 3 # Now checkout and restore await table.checkout(version) await table.restore() assert await table.count_rows() == 1 # Should be able to add data await table.add([{"id": 4}]) assert await table.count_rows() == 2 # Can't use restore if not checked out with pytest.raises(ValueError, match="checkout before running restore"): await table.restore() def test_sync_optimize(mem_db: DBConnection): table = mem_db.create_table( "test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, ], ) table.create_scalar_index("price", index_type="BTREE") stats = table.index_stats("price_idx") assert stats["num_indexed_rows"] == 2 table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}]) assert table.count_rows() == 3 table.optimize() stats = table.index_stats("price_idx") assert stats["num_indexed_rows"] == 3 @pytest.mark.asyncio async def test_sync_optimize_in_async(mem_db: DBConnection): table = mem_db.create_table( "test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, ], ) table.create_scalar_index("price", index_type="BTREE") stats = table.index_stats("price_idx") assert stats["num_indexed_rows"] == 2 table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}]) assert table.count_rows() == 3 table.optimize() @pytest.mark.asyncio async def test_optimize(mem_db_async: AsyncConnection): table = await mem_db_async.create_table( "test", data=[{"x": [1]}], ) await table.add( data=[ {"x": [2]}, ], ) stats = await table.optimize() expected = ( "OptimizeStats(compaction=CompactionStats { fragments_removed: 2, " "fragments_added: 1, files_removed: 2, files_added: 1 }, " "prune=RemovalStats { bytes_removed: 0, old_versions_removed: 0 })" ) assert str(stats) == expected assert 
stats.compaction.files_removed == 2 assert stats.compaction.files_added == 1 assert stats.compaction.fragments_added == 1 assert stats.compaction.fragments_removed == 2 assert stats.prune.bytes_removed == 0 assert stats.prune.old_versions_removed == 0 stats = await table.optimize(cleanup_older_than=timedelta(seconds=0)) assert stats.prune.bytes_removed > 0 assert stats.prune.old_versions_removed == 3 assert await table.query().to_arrow() == pa.table({"x": [[1], [2]]}) @pytest.mark.asyncio async def test_optimize_delete_unverified(tmp_db_async: AsyncConnection, tmp_path): table = await tmp_db_async.create_table( "test", data=[{"x": [1]}], ) await table.add( data=[ {"x": [2]}, ], ) version = await table.version() assert version == 2 # By removing a manifest file, we make the data files we just inserted unverified version_name = 18446744073709551615 - (version - 1) path = tmp_path / "test.lance" / "_versions" / f"{version_name:020}.manifest" os.remove(path) stats = await table.optimize(delete_unverified=False) assert stats.prune.old_versions_removed == 0 stats = await table.optimize( cleanup_older_than=timedelta(seconds=0), delete_unverified=True ) assert stats.prune.old_versions_removed == 2 def test_replace_field_metadata(tmp_path): db = lancedb.connect(tmp_path) table = db.create_table("my_table", data=[{"x": 0}]) table.replace_field_metadata("x", {"foo": "bar"}) schema = table.schema field = schema[0].metadata assert field == {b"foo": b"bar"} def test_stats(mem_db: DBConnection): table = mem_db.create_table( "my_table", data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}], ) assert len(table) == 2 stats = table.stats() print(f"{stats=}") assert stats == { "total_bytes": 60, "num_rows": 2, "num_indices": 0, "fragment_stats": { "num_fragments": 1, "num_small_fragments": 1, "lengths": { "min": 2, "max": 2, "mean": 2, "p25": 2, "p50": 2, "p75": 2, "p99": 2, }, }, } def test_create_table_empty_list_with_schema(mem_db: DBConnection): """Test creating table with 
empty list data and schema Regression test for IndexError: list index out of range when calling create_table(name, data=[], schema=schema) """ schema = pa.schema( [pa.field("vector", pa.list_(pa.float32(), 2)), pa.field("id", pa.int64())] ) table = mem_db.create_table("test_empty_list", data=[], schema=schema) assert table.count_rows() == 0 assert table.schema == schema def test_create_table_empty_list_no_schema_error(mem_db: DBConnection): """Test that creating table with empty list and no schema raises error""" with pytest.raises( ValueError, match="Cannot create table from empty list without a schema" ): mem_db.create_table("test_empty_no_schema", data=[]) def test_add_table_with_empty_embeddings(tmp_path): """Test exact scenario from issue #1968 Regression test for issue #1968: https://github.com/lancedb/lancedb/issues/1968 """ db = lancedb.connect(tmp_path) class MySchema(LanceModel): text: str embedding: Vector(16) table = db.create_table("test", schema=MySchema) table.add( [{"text": "bar", "embedding": [0.1] * 16}], on_bad_vectors="drop", ) assert table.count_rows() == 1 def test_table_uri(tmp_path): db = lancedb.connect(tmp_path) table = db.create_table("my_table", data=[{"x": 0}]) assert table.uri == str(tmp_path / "my_table.lance") def test_sanitize_data_metadata_not_stripped(): """Regression test: dict.update() returns None, so assigning its result would silently replace metadata with None, causing with_metadata(None) to strip all schema metadata from the target schema.""" from lancedb.table import _sanitize_data schema = pa.schema( [pa.field("x", pa.int64())], metadata={b"existing_key": b"existing_value"}, ) batch = pa.record_batch([pa.array([1, 2, 3])], schema=schema) # Use a different field type so the reader and target schemas differ, # forcing _cast_to_target_schema to rebuild the schema with the # target's metadata (instead of taking the fast-path). 
target_schema = pa.schema( [pa.field("x", pa.int32())], metadata={b"existing_key": b"existing_value"}, ) reader = pa.RecordBatchReader.from_batches(schema, [batch]) metadata = {b"new_key": b"new_value"} result = _sanitize_data(reader, target_schema=target_schema, metadata=metadata) result_schema = result.schema assert result_schema.metadata is not None assert result_schema.metadata[b"existing_key"] == b"existing_value" assert result_schema.metadata[b"new_key"] == b"new_value" @pytest.mark.asyncio async def test_async_search_runs_embedding_on_dedicated_executor( mem_db_async: AsyncConnection, ): # Regression test for #3310: AsyncTable.search() must run the (potentially # blocking) query-embedding call on the dedicated embedding executor, not # asyncio's default executor -- which is shared with other blocking I/O and # can be starved by a slow embedding call under concurrent load. func = MockTextEmbeddingFunction.create() class Schema(LanceModel): text: str = func.SourceField() vector: Vector(func.ndims()) = func.VectorField() table = await mem_db_async.create_table("embed_executor", schema=Schema) await table.add([{"text": "hello world"}]) captured_threads: List[str] = [] original = MockTextEmbeddingFunction.generate_embeddings def record_thread(self, texts): captured_threads.append(threading.current_thread().name) return original(self, texts) # Patch only around the search so we capture the query-embedding call, not # the add-time source-embedding call. with patch.object(MockTextEmbeddingFunction, "generate_embeddings", record_thread): await (await table.search("a query string")).limit(1).to_list() assert captured_threads, "search did not invoke the embedding function" assert all(name.startswith("lancedb-embedding") for name in captured_threads), ( f"embedding ran off the dedicated executor: {captured_threads}" )