mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-11 06:12:58 +00:00
fix: support pyarrow input types (#1628)
fixes #1625 Support PyArrow.RecordBatch, pa.dataset.Dataset, pa.dataset.Scanner, paRecordBatchReader
This commit is contained in:
@@ -64,6 +64,55 @@ def test_basic(db):
|
||||
assert table.to_lance().to_table() == ds.to_table()
|
||||
|
||||
|
||||
def test_input_data_type(db, tmp_path):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field("name", pa.string()),
|
||||
pa.field("age", pa.int32()),
|
||||
]
|
||||
)
|
||||
|
||||
data = {
|
||||
"id": [1, 2, 3, 4, 5],
|
||||
"name": ["Alice", "Bob", "Charlie", "David", "Eve"],
|
||||
"age": [25, 30, 35, 40, 45],
|
||||
}
|
||||
record_batch = pa.RecordBatch.from_pydict(data, schema=schema)
|
||||
pa_reader = pa.RecordBatchReader.from_batches(record_batch.schema, [record_batch])
|
||||
pa_table = pa.Table.from_batches([record_batch])
|
||||
|
||||
def create_dataset(tmp_path):
|
||||
path = os.path.join(tmp_path, "test_source_dataset")
|
||||
pa.dataset.write_dataset(pa_table, path, format="parquet")
|
||||
return pa.dataset.dataset(path, format="parquet")
|
||||
|
||||
pa_dataset = create_dataset(tmp_path)
|
||||
pa_scanner = pa_dataset.scanner()
|
||||
|
||||
input_types = [
|
||||
("RecordBatchReader", pa_reader),
|
||||
("RecordBatch", record_batch),
|
||||
("Table", pa_table),
|
||||
("Dataset", pa_dataset),
|
||||
("Scanner", pa_scanner),
|
||||
]
|
||||
for input_type, input_data in input_types:
|
||||
table_name = f"test_{input_type.lower()}"
|
||||
ds = LanceTable.create(db, table_name, data=input_data).to_lance()
|
||||
assert ds.schema == schema
|
||||
assert ds.count_rows() == 5
|
||||
|
||||
assert ds.schema.field("id").type == pa.int64()
|
||||
assert ds.schema.field("name").type == pa.string()
|
||||
assert ds.schema.field("age").type == pa.int32()
|
||||
|
||||
result_table = ds.to_table()
|
||||
assert result_table.column("id").to_pylist() == data["id"]
|
||||
assert result_table.column("name").to_pylist() == data["name"]
|
||||
assert result_table.column("age").to_pylist() == data["age"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(db_async: AsyncConnection):
|
||||
table = await db_async.create_table("some_table", data=[{"id": 0}])
|
||||
@@ -274,7 +323,6 @@ def test_polars(db):
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
# table = LanceTable(db, "test")
|
||||
assert len(table) == 2
|
||||
|
||||
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||
|
||||
Reference in New Issue
Block a user