mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-03 20:30:42 +00:00
feat: add __getitems__ method impl for torch integration (#2596)
This allows a lancedb Table to act as a torch dataset.
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import List, Union
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
import lancedb
|
||||
from lancedb.db import AsyncConnection
|
||||
@@ -1355,6 +1356,27 @@ def test_take_queries(tmp_path):
|
||||
]
|
||||
|
||||
|
||||
def test_getitems(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table(
|
||||
{
|
||||
"idx": range(100),
|
||||
}
|
||||
)
|
||||
# Make two fragments
|
||||
table = db.create_table("test", data)
|
||||
table.add(pa.table({"idx": range(100, 200)}))
|
||||
|
||||
assert table.__getitems__([5, 2, 117]) == pa.table(
|
||||
{
|
||||
"idx": [5, 2, 117],
|
||||
}
|
||||
)
|
||||
|
||||
offsets = random.sample(range(200), 10)
|
||||
assert table.__getitems__(offsets) == pa.table({"idx": offsets})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_timeout_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
|
||||
26
python/python/tests/test_torch.py
Normal file
26
python/python/tests/test_torch.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
def tbl_to_tensor(tbl):
|
||||
def to_tensor(col: pa.ChunkedArray):
|
||||
if col.num_chunks > 1:
|
||||
raise Exception("Single batch was too large to fit into a one-chunk table")
|
||||
return torch.from_dlpack(col.chunk(0))
|
||||
|
||||
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
|
||||
|
||||
|
||||
def test_table_dataloader(mem_db):
|
||||
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
table, collate_fn=tbl_to_tensor, batch_size=10, shuffle=True
|
||||
)
|
||||
for batch in dataloader:
|
||||
assert batch.size(0) == 1
|
||||
assert batch.size(1) == 10
|
||||
Reference in New Issue
Block a user