feat: add __getitems__ method impl for torch integration (#2596)
This allows a lancedb Table to act as a torch dataset.
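For context, the intended usage looks roughly like the following. This is a minimal sketch based on the new test file below; the database path, table name, and column name are illustrative, not part of this commit:

    import lancedb
    import pyarrow as pa
    import torch

    def collate(tbl: pa.Table) -> torch.Tensor:
        # __getitems__ returns the whole batch as one Arrow table whose
        # columns are single-chunk; dlpack converts each chunk to a tensor.
        return torch.stack(
            [torch.from_dlpack(tbl.column(i).chunk(0)) for i in range(tbl.num_columns)]
        )

    db = lancedb.connect("/tmp/example-db")  # illustrative path
    table = db.create_table("example", pa.table({"a": range(1000)}))

    # The Table itself is passed to DataLoader; the loader hands a list of
    # offsets to table.__getitems__ and collates the resulting Arrow batch.
    loader = torch.utils.data.DataLoader(
        table, batch_size=10, shuffle=True, collate_fn=collate
    )
    for batch in loader:  # each batch has shape (num_columns, batch_size)
        ...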
@@ -1113,7 +1113,9 @@ class Table(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
+    def take_offsets(
+        self, offsets: list[int], *, with_row_id: bool = False
+    ) -> LanceTakeQueryBuilder:
         """
         Take a list of offsets from the table.
 
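The new flag is keyword-only, so existing positional callers are unaffected. A hedged sketch of a call site follows; the exact column semantics are an assumption based on how with_row_id behaves elsewhere in the lancedb query API:

    # Assumption: with_row_id=True includes the internal _rowid column in
    # the results. The flag must be passed by name.
    batch = table.take_offsets([5, 2, 117], with_row_id=True).to_arrow()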
@@ -1139,8 +1141,60 @@ class Table(ABC):
         A record batch containing the rows at the given offsets.
         """
 
+    def __getitems__(self, offsets: list[int]) -> pa.RecordBatch:
+        """
+        Take a list of offsets from the table and return them as a record batch.
+
+        This method uses `take_offsets` to fetch the rows, then aligns the
+        results to the order of the passed-in offsets. The entire result is
+        materialized as a single batch, so users should take care not to
+        pass in too many offsets at once.
+
+        Note: this method is primarily intended to fulfill the Dataset
+        contract for pytorch.
+
+        Parameters
+        ----------
+        offsets: list[int]
+            The offsets to take.
+
+        Returns
+        -------
+        pa.RecordBatch
+            A record batch containing the rows at the given offsets.
+        """
+        # We don't know the order of the results at all, so we calculate a
+        # permutation that sorts the given offsets. We then load the data with
+        # the _rowoffset column included, sort by _rowoffset, and apply the
+        # inverse of the permutation we calculated.
+        #
+        # Note: this is potentially a lot of memory copying if we're operating
+        # on large batches :(
+        num_offsets = len(offsets)
+        indices = list(range(num_offsets))
+        permutation = sorted(indices, key=lambda idx: offsets[idx])
+        permutation_inv = [0] * num_offsets
+        for i in range(num_offsets):
+            permutation_inv[permutation[i]] = i
+
+        columns = self.schema.names
+        columns.append("_rowoffset")
+        tbl = (
+            self.take_offsets(offsets)
+            .select(columns)
+            .to_arrow()
+            .sort_by("_rowoffset")
+            .take(permutation_inv)
+            .combine_chunks()
+            .drop_columns(["_rowoffset"])
+        )
+
+        return tbl
+
     @abstractmethod
-    def take_row_ids(self, row_ids: list[int]) -> LanceTakeQueryBuilder:
+    def take_row_ids(
+        self, row_ids: list[int], *, with_row_id: bool = False
+    ) -> LanceTakeQueryBuilder:
         """
         Take a list of row ids from the table.
 
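The order-restoring trick in `__getitems__` can be illustrated standalone. A minimal worked example in plain Python, with values chosen purely for illustration:

    # Suppose the caller asks for offsets [5, 117, 2]. The storage layer
    # returns rows sorted by offset, so we must undo that sort afterwards.
    offsets = [5, 117, 2]

    indices = list(range(len(offsets)))
    permutation = sorted(indices, key=lambda idx: offsets[idx])  # [2, 0, 1]

    permutation_inv = [0] * len(offsets)
    for i, p in enumerate(permutation):
        permutation_inv[p] = i                                   # [1, 2, 0]

    # Rows come back sorted by offset: [2, 5, 117]. Taking them in
    # permutation_inv order (what Table.take does) restores the
    # requested order.
    sorted_rows = sorted(offsets)
    restored = [sorted_rows[i] for i in permutation_inv]
    assert restored == offsets  # [5, 117, 2]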
@@ -5,6 +5,7 @@ from typing import List, Union
 import unittest.mock as mock
 from datetime import timedelta
 from pathlib import Path
+import random
 
 import lancedb
 from lancedb.db import AsyncConnection
@@ -1355,6 +1356,27 @@ def test_take_queries(tmp_path):
     ]
 
 
+def test_getitems(tmp_path):
+    db = lancedb.connect(tmp_path)
+    data = pa.table(
+        {
+            "idx": range(100),
+        }
+    )
+    # Make two fragments
+    table = db.create_table("test", data)
+    table.add(pa.table({"idx": range(100, 200)}))
+
+    assert table.__getitems__([5, 2, 117]) == pa.table(
+        {
+            "idx": [5, 2, 117],
+        }
+    )
+
+    offsets = random.sample(range(200), 10)
+    assert table.__getitems__(offsets) == pa.table({"idx": offsets})
+
+
 @pytest.mark.asyncio
 async def test_query_timeout_async(tmp_path):
     db = await lancedb.connect_async(tmp_path)
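Note that the second assertion samples offsets in random order, which exercises the inverse-permutation path in `__getitems__`, and the two-fragment setup presumably checks offsets that resolve into different fragments (e.g. offset 117 lands in the second fragment).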
python/python/tests/test_torch.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import pyarrow as pa
+import pytest
+
+torch = pytest.importorskip("torch")
+
+
+def tbl_to_tensor(tbl):
+    def to_tensor(col: pa.ChunkedArray):
+        if col.num_chunks > 1:
+            raise Exception("Single batch was too large to fit into a one-chunk table")
+        return torch.from_dlpack(col.chunk(0))
+
+    return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
+
+
+def test_table_dataloader(mem_db):
+    table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
+    dataloader = torch.utils.data.DataLoader(
+        table, collate_fn=tbl_to_tensor, batch_size=10, shuffle=True
+    )
+    for batch in dataloader:
+        assert batch.size(0) == 1
+        assert batch.size(1) == 10
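For reference on the two shape assertions: `tbl_to_tensor` stacks one tensor per column, so for this single-column table every collated batch has shape (1, 10), i.e. `size(0)` is the column count and `size(1)` is the DataLoader batch size.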