diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 230daa89..96c6a13a 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -3208,7 +3208,27 @@ def _infer_target_schema( if pa.types.is_floating(field.type.value_type): target_type = pa.list_(pa.float32(), dim) elif pa.types.is_integer(field.type.value_type): - target_type = pa.list_(pa.uint8(), dim) + values = peeked.column(i) + + if isinstance(values, pa.ChunkedArray): + values = values.combine_chunks() + + flattened = values.flatten() + valid_count = pc.count(flattened, mode="only_valid").as_py() + + if valid_count == 0: + target_type = pa.list_(pa.uint8(), dim) + else: + min_max = pc.min_max(flattened) + min_value = min_max["min"].as_py() + max_value = min_max["max"].as_py() + + if (min_value is not None and min_value < 0) or ( + max_value is not None and max_value > 255 + ): + target_type = pa.list_(pa.float32(), dim) + else: + target_type = pa.list_(pa.uint8(), dim) else: continue # Skip non-numeric types diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 9d40c2fa..f145d3c4 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -46,6 +46,39 @@ def test_basic(mem_db: DBConnection): assert table.to_arrow() == expected_data +def test_create_table_infers_large_int_vectors(mem_db: DBConnection): + data = [{"vector": [0, 300]}] + + table = mem_db.create_table( + "int_vector_overflow", data=data, mode="overwrite", exist_ok=True + ) + + vector_field = table.schema.field("vector") + assert vector_field.type == pa.list_(pa.float32(), 2) + + vector_column = table.to_arrow().column("vector") + assert vector_column.type == pa.list_(pa.float32(), 2) + assert vector_column.to_pylist() == [[0.0, 300.0]] + + +@pytest.mark.asyncio +async def test_create_table_async_infers_large_int_vectors( + mem_db_async: AsyncConnection, +): + data = [{"vector": [256, 257]}] + + table = await mem_db_async.create_table( + "int_vector_overflow_async", data=data, mode="overwrite", exist_ok=True + ) + + schema = await table.schema() + assert schema.field("vector").type == pa.list_(pa.float32(), 2) + + vector_column = (await table.to_arrow()).column("vector") + assert vector_column.type == pa.list_(pa.float32(), 2) + assert vector_column.to_pylist() == [[256.0, 257.0]] + + def test_input_data_type(mem_db: DBConnection, tmp_path): schema = pa.schema( {