fix: support pyarrow input types (#1628)

fixes #1625 
Support PyArrow.RecordBatch, pa.dataset.Dataset, pa.dataset.Scanner,
paRecordBatchReader
This commit is contained in:
LuQQiu
2024-09-12 10:59:18 -07:00
committed by GitHub
parent b3bf6386c3
commit c7732585bf
3 changed files with 419 additions and 51 deletions

View File

@@ -64,6 +64,55 @@ def test_basic(db):
assert table.to_lance().to_table() == ds.to_table()
def test_input_data_type(db, tmp_path):
schema = pa.schema(
[
pa.field("id", pa.int64()),
pa.field("name", pa.string()),
pa.field("age", pa.int32()),
]
)
data = {
"id": [1, 2, 3, 4, 5],
"name": ["Alice", "Bob", "Charlie", "David", "Eve"],
"age": [25, 30, 35, 40, 45],
}
record_batch = pa.RecordBatch.from_pydict(data, schema=schema)
pa_reader = pa.RecordBatchReader.from_batches(record_batch.schema, [record_batch])
pa_table = pa.Table.from_batches([record_batch])
def create_dataset(tmp_path):
path = os.path.join(tmp_path, "test_source_dataset")
pa.dataset.write_dataset(pa_table, path, format="parquet")
return pa.dataset.dataset(path, format="parquet")
pa_dataset = create_dataset(tmp_path)
pa_scanner = pa_dataset.scanner()
input_types = [
("RecordBatchReader", pa_reader),
("RecordBatch", record_batch),
("Table", pa_table),
("Dataset", pa_dataset),
("Scanner", pa_scanner),
]
for input_type, input_data in input_types:
table_name = f"test_{input_type.lower()}"
ds = LanceTable.create(db, table_name, data=input_data).to_lance()
assert ds.schema == schema
assert ds.count_rows() == 5
assert ds.schema.field("id").type == pa.int64()
assert ds.schema.field("name").type == pa.string()
assert ds.schema.field("age").type == pa.int32()
result_table = ds.to_table()
assert result_table.column("id").to_pylist() == data["id"]
assert result_table.column("name").to_pylist() == data["name"]
assert result_table.column("age").to_pylist() == data["age"]
@pytest.mark.asyncio
async def test_close(db_async: AsyncConnection):
table = await db_async.create_table("some_table", data=[{"id": 0}])
@@ -274,7 +323,6 @@ def test_polars(db):
def _add(table, schema):
# table = LanceTable(db, "test")
assert len(table) == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])