mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-30 12:30:39 +00:00
19 lines
539 B
Python
19 lines
539 B
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
import pyarrow as pa
|
|
import pytest
|
|
from lancedb.util import tbl_to_tensor
|
|
|
|
torch = pytest.importorskip("torch")
|
|
|
|
|
|
def test_table_dataloader(mem_db):
    """Check that a LanceDB table can be fed straight into a torch DataLoader.

    A 1000-row, single-column table is created through the ``mem_db``
    fixture and wrapped in a shuffling ``DataLoader`` that collates with
    ``tbl_to_tensor`` in batches of 10. Every collated batch must report
    size 1 along dim 0 and size 10 along dim 1.
    """
    arrow_data = pa.table({"a": range(1000)})
    tbl = mem_db.create_table("test_table", arrow_data)

    loader = torch.utils.data.DataLoader(
        tbl,
        batch_size=10,
        shuffle=True,
        collate_fn=tbl_to_tensor,
    )

    # 1000 rows divide evenly by the batch size, so no short final batch
    # is expected and the same size check applies to every batch.
    # NOTE(review): dim 0 presumably corresponds to the single column "a"
    # and dim 1 to the rows in the batch -- confirm against tbl_to_tensor.
    for collated in loader:
        assert collated.size(0) == 1
        assert collated.size(1) == 10
|