feat: add __getitems__ method impl for torch integration (#2596)

This allows a lancedb Table to act as a torch dataset.
This commit is contained in:
Weston Pace
2025-08-25 13:23:22 -07:00
committed by GitHub
parent 6839ac3509
commit fabe37274f
3 changed files with 104 additions and 2 deletions

View File

@@ -5,6 +5,7 @@ from typing import List, Union
import unittest.mock as mock
from datetime import timedelta
from pathlib import Path
import random
import lancedb
from lancedb.db import AsyncConnection
@@ -1355,6 +1356,27 @@ def test_take_queries(tmp_path):
]
def test_getitems(tmp_path):
db = lancedb.connect(tmp_path)
data = pa.table(
{
"idx": range(100),
}
)
# Make two fragments
table = db.create_table("test", data)
table.add(pa.table({"idx": range(100, 200)}))
assert table.__getitems__([5, 2, 117]) == pa.table(
{
"idx": [5, 2, 117],
}
)
offsets = random.sample(range(200), 10)
assert table.__getitems__(offsets) == pa.table({"idx": offsets})
@pytest.mark.asyncio
async def test_query_timeout_async(tmp_path):
db = await lancedb.connect_async(tmp_path)

View File

@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pyarrow as pa
import pytest
torch = pytest.importorskip("torch")
def tbl_to_tensor(tbl):
def to_tensor(col: pa.ChunkedArray):
if col.num_chunks > 1:
raise Exception("Single batch was too large to fit into a one-chunk table")
return torch.from_dlpack(col.chunk(0))
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
def test_table_dataloader(mem_db):
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
dataloader = torch.utils.data.DataLoader(
table, collate_fn=tbl_to_tensor, batch_size=10, shuffle=True
)
for batch in dataloader:
assert batch.size(0) == 1
assert batch.size(1) == 10