mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 05:12:58 +00:00
Bugfix for LanceTable.add: convert Python lists into Arrow fixed-size lists
- Fixed the `add` unit test to create the correct expected result
- Added a unit test for `LanceTable.add`
- Need to discuss whether `len(LanceTable)` is handled correctly
This commit is contained in:
@@ -171,6 +171,7 @@ def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
|
|||||||
return data
|
return data
|
||||||
# cast the columns to the expected types
|
# cast the columns to the expected types
|
||||||
data = data.combine_chunks()
|
data = data.combine_chunks()
|
||||||
|
data = _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
|
||||||
return pa.Table.from_arrays(
|
return pa.Table.from_arrays(
|
||||||
[data[name] for name in schema.names], schema=schema
|
[data[name] for name in schema.names], schema=schema
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -49,14 +49,14 @@ def test_basic(db):
|
|||||||
def test_add(db):
|
def test_add(db):
|
||||||
schema = pa.schema(
|
schema = pa.schema(
|
||||||
[
|
[
|
||||||
pa.field("vector", pa.list_(pa.float32())),
|
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||||
pa.field("item", pa.string()),
|
pa.field("item", pa.string()),
|
||||||
pa.field("price", pa.float32()),
|
pa.field("price", pa.float32()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected = pa.Table.from_arrays(
|
expected = pa.Table.from_arrays(
|
||||||
[
|
[
|
||||||
pa.array([[3.1, 4.1], [5.9, 26.5]]),
|
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
||||||
pa.array(["foo", "bar"]),
|
pa.array(["foo", "bar"]),
|
||||||
pa.array([10.0, 20.0]),
|
pa.array([10.0, 20.0]),
|
||||||
],
|
],
|
||||||
@@ -79,3 +79,35 @@ def test_add(db):
|
|||||||
.to_table()
|
.to_table()
|
||||||
)
|
)
|
||||||
assert expected == tbl
|
assert expected == tbl
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_items(db):
|
||||||
|
table = LanceTable.create(
|
||||||
|
db,
|
||||||
|
"test",
|
||||||
|
data=[
|
||||||
|
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||||
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# table = LanceTable(db, "test")
|
||||||
|
assert len(table) == 2
|
||||||
|
|
||||||
|
count = table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||||
|
assert count == 3
|
||||||
|
#assert len(table) == 3 #FAILS! len(table) == 2, since add creates a new ds
|
||||||
|
|
||||||
|
expected = pa.Table.from_arrays(
|
||||||
|
[
|
||||||
|
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
||||||
|
pa.array(["foo", "bar"]),
|
||||||
|
pa.array([10.0, 20.0]),
|
||||||
|
],
|
||||||
|
schema=pa.schema([
|
||||||
|
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||||
|
pa.field("item", pa.string()),
|
||||||
|
pa.field("price", pa.float64()),
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
assert expected == table.to_arrow()
|
||||||
|
|||||||
Reference in New Issue
Block a user