mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 11:22:58 +00:00
feat: infer vector type to float32 if integers are out of uint8 range (#2856)
## Summary - infer integer vector columns as float32 when any value exceeds uint8 range or is negative - keep uint8 for integer vectors within range and nulls only - add sync/async tests covering large integer vector inference ## Testing - ./.venv/bin/pytest python/python/tests/test_table.py -k "large_int_vectors"
This commit is contained in:
@@ -3208,7 +3208,27 @@ def _infer_target_schema(
|
||||
if pa.types.is_floating(field.type.value_type):
|
||||
target_type = pa.list_(pa.float32(), dim)
|
||||
elif pa.types.is_integer(field.type.value_type):
|
||||
target_type = pa.list_(pa.uint8(), dim)
|
||||
values = peeked.column(i)
|
||||
|
||||
if isinstance(values, pa.ChunkedArray):
|
||||
values = values.combine_chunks()
|
||||
|
||||
flattened = values.flatten()
|
||||
valid_count = pc.count(flattened, mode="only_valid").as_py()
|
||||
|
||||
if valid_count == 0:
|
||||
target_type = pa.list_(pa.uint8(), dim)
|
||||
else:
|
||||
min_max = pc.min_max(flattened)
|
||||
min_value = min_max["min"].as_py()
|
||||
max_value = min_max["max"].as_py()
|
||||
|
||||
if (min_value is not None and min_value < 0) or (
|
||||
max_value is not None and max_value > 255
|
||||
):
|
||||
target_type = pa.list_(pa.float32(), dim)
|
||||
else:
|
||||
target_type = pa.list_(pa.uint8(), dim)
|
||||
else:
|
||||
continue # Skip non-numeric types
|
||||
|
||||
|
||||
@@ -46,6 +46,39 @@ def test_basic(mem_db: DBConnection):
|
||||
assert table.to_arrow() == expected_data
|
||||
|
||||
|
||||
def test_create_table_infers_large_int_vectors(mem_db: DBConnection):
|
||||
data = [{"vector": [0, 300]}]
|
||||
|
||||
table = mem_db.create_table(
|
||||
"int_vector_overflow", data=data, mode="overwrite", exist_ok=True
|
||||
)
|
||||
|
||||
vector_field = table.schema.field("vector")
|
||||
assert vector_field.type == pa.list_(pa.float32(), 2)
|
||||
|
||||
vector_column = table.to_arrow().column("vector")
|
||||
assert vector_column.type == pa.list_(pa.float32(), 2)
|
||||
assert vector_column.to_pylist() == [[0.0, 300.0]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_table_async_infers_large_int_vectors(
|
||||
mem_db_async: AsyncConnection,
|
||||
):
|
||||
data = [{"vector": [256, 257]}]
|
||||
|
||||
table = await mem_db_async.create_table(
|
||||
"int_vector_overflow_async", data=data, mode="overwrite", exist_ok=True
|
||||
)
|
||||
|
||||
schema = await table.schema()
|
||||
assert schema.field("vector").type == pa.list_(pa.float32(), 2)
|
||||
|
||||
vector_column = (await table.to_arrow()).column("vector")
|
||||
assert vector_column.type == pa.list_(pa.float32(), 2)
|
||||
assert vector_column.to_pylist() == [[256.0, 257.0]]
|
||||
|
||||
|
||||
def test_input_data_type(mem_db: DBConnection, tmp_path):
|
||||
schema = pa.schema(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user