From 39a18baf594439594918248300607cf0bdc702ae Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 8 Dec 2025 17:10:25 +0800 Subject: [PATCH] feat: infer vector type to float32 if integers are out of uint8 range (#2856) ## Summary - infer integer vector columns as float32 when any value exceeds uint8 range or is negative - keep uint8 for integer vectors within range and nulls only - add sync/async tests covering large integer vector inference ## Testing - ./.venv/bin/pytest python/python/tests/test_table.py -k "large_int_vectors" --- python/python/lancedb/table.py | 22 ++++++++++++++++++++- python/python/tests/test_table.py | 33 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 230daa89..96c6a13a 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -3208,7 +3208,27 @@ def _infer_target_schema( if pa.types.is_floating(field.type.value_type): target_type = pa.list_(pa.float32(), dim) elif pa.types.is_integer(field.type.value_type): - target_type = pa.list_(pa.uint8(), dim) + values = peeked.column(i) + + if isinstance(values, pa.ChunkedArray): + values = values.combine_chunks() + + flattened = values.flatten() + valid_count = pc.count(flattened, mode="only_valid").as_py() + + if valid_count == 0: + target_type = pa.list_(pa.uint8(), dim) + else: + min_max = pc.min_max(flattened) + min_value = min_max["min"].as_py() + max_value = min_max["max"].as_py() + + if (min_value is not None and min_value < 0) or ( + max_value is not None and max_value > 255 + ): + target_type = pa.list_(pa.float32(), dim) + else: + target_type = pa.list_(pa.uint8(), dim) else: continue # Skip non-numeric types diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 9d40c2fa..f145d3c4 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -46,6 +46,39 @@ def test_basic(mem_db: DBConnection): assert table.to_arrow() == expected_data +def test_create_table_infers_large_int_vectors(mem_db: DBConnection): + data = [{"vector": [0, 300]}] + + table = mem_db.create_table( + "int_vector_overflow", data=data, mode="overwrite", exist_ok=True + ) + + vector_field = table.schema.field("vector") + assert vector_field.type == pa.list_(pa.float32(), 2) + + vector_column = table.to_arrow().column("vector") + assert vector_column.type == pa.list_(pa.float32(), 2) + assert vector_column.to_pylist() == [[0.0, 300.0]] + + +@pytest.mark.asyncio +async def test_create_table_async_infers_large_int_vectors( + mem_db_async: AsyncConnection, +): + data = [{"vector": [256, 257]}] + + table = await mem_db_async.create_table( + "int_vector_overflow_async", data=data, mode="overwrite", exist_ok=True + ) + + schema = await table.schema() + assert schema.field("vector").type == pa.list_(pa.float32(), 2) + + vector_column = (await table.to_arrow()).column("vector") + assert vector_column.type == pa.list_(pa.float32(), 2) + assert vector_column.to_pylist() == [[256.0, 257.0]] + + def test_input_data_type(mem_db: DBConnection, tmp_path): schema = pa.schema( {