mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-01 21:40:41 +00:00
feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)
This commit is contained in:
@@ -3,19 +3,11 @@
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.util import tbl_to_tensor
|
||||
|
||||
# Skip every test in this module when PyTorch is not installed.
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
def tbl_to_tensor(tbl):
    """Convert each column of ``tbl`` to a tensor and stack the results.

    Columns are visited in positional order. Every column must hold at most
    one chunk; that chunk is passed to ``torch.from_dlpack`` and the
    per-column tensors are combined with ``torch.stack``.

    Raises:
        Exception: if any column is split across more than one chunk.
    """
    column_tensors = []
    for index in range(tbl.num_columns):
        column = tbl.column(index)
        if column.num_chunks > 1:
            raise Exception("Single batch was too large to fit into a one-chunk table")
        column_tensors.append(torch.from_dlpack(column.chunk(0)))
    return torch.stack(column_tensors)
|
||||
|
||||
|
||||
def test_table_dataloader(mem_db):
|
||||
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
||||
Reference in New Issue
Block a user