diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 748a74a8..bd7f2039 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -1113,7 +1113,9 @@ class Table(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
+    def take_offsets(
+        self, offsets: list[int], *, with_row_id: bool = False
+    ) -> LanceTakeQueryBuilder:
         """
         Take a list of offsets from the table.
 
@@ -1139,8 +1141,60 @@ class Table(ABC):
             A record batch containing the rows at the given offsets.
         """
 
+    def __getitems__(self, offsets: list[int]) -> pa.Table:
+        """
+        Take rows at the given offsets and return them as a single-chunk table.
+
+        This method uses the `take_offsets` method to fetch the rows, then
+        reorders the results to match the order of the passed-in offsets. The
+        whole result is materialized at once, so users should take care not to
+        pass in too many offsets.
+
+        Note: this method is primarily intended to fulfill the Dataset contract
+        for PyTorch.
+
+        Parameters
+        ----------
+        offsets: list[int]
+            The offsets to take.
+
+        Returns
+        -------
+        pa.Table
+            A single-chunk table containing the rows at the given offsets.
+        """
+        # We don't know the order of the results at all. So we calculate a
+        # permutation that sorts the given offsets. Then we load the data with
+        # the _rowoffset column, sort by _rowoffset, and apply the inverse of
+        # the permutation that we calculated.
+        #
+        # Note: this is potentially a lot of memory copying if we're operating
+        # on large batches :(
+        num_offsets = len(offsets)
+        indices = list(range(num_offsets))
+        permutation = sorted(indices, key=lambda idx: offsets[idx])
+        permutation_inv = [0] * num_offsets
+        for i in range(num_offsets):
+            permutation_inv[permutation[i]] = i
+
+        columns = self.schema.names
+        columns.append("_rowoffset")
+        tbl = (
+            self.take_offsets(offsets)
+            .select(columns)
+            .to_arrow()
+            .sort_by("_rowoffset")
+            .take(permutation_inv)
+            .combine_chunks()
+            .drop_columns(["_rowoffset"])
+        )
+
+        return tbl
+
     @abstractmethod
-    def take_row_ids(self, row_ids: list[int]) -> LanceTakeQueryBuilder:
+    def take_row_ids(
+        self, row_ids: list[int], *, with_row_id: bool = False
+    ) -> LanceTakeQueryBuilder:
         """
         Take a list of row ids from the table.
 
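The reordering in `__getitems__` is easier to follow outside the Arrow plumbing. Below is a standalone sketch of the same permutation trick, with plain Python lists standing in for the table and `rows_sorted` standing in for the result of `sort_by("_rowoffset")`:

```python
# Rows come back ordered by _rowoffset, i.e. as if we had asked for
# sorted(offsets). Sorting the request and inverting that permutation
# lets us put the rows back in the caller's original order.
offsets = [5, 2, 117]

indices = list(range(len(offsets)))
permutation = sorted(indices, key=lambda idx: offsets[idx])  # [1, 0, 2]
permutation_inv = [0] * len(offsets)
for i, p in enumerate(permutation):
    permutation_inv[p] = i  # also [1, 0, 2]; this transposition is its own inverse

rows_sorted = sorted(offsets)  # stand-in for sort_by("_rowoffset")
restored = [rows_sorted[i] for i in permutation_inv]  # stand-in for .take(...)
assert restored == offsets
```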
diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py
index 756aa009..b80984da 100644
--- a/python/python/tests/test_query.py
+++ b/python/python/tests/test_query.py
@@ -5,6 +5,7 @@ from typing import List, Union
 import unittest.mock as mock
 from datetime import timedelta
 from pathlib import Path
+import random
 
 import lancedb
 from lancedb.db import AsyncConnection
@@ -1355,6 +1356,27 @@ def test_take_queries(tmp_path):
     ]
 
 
+def test_getitems(tmp_path):
+    db = lancedb.connect(tmp_path)
+    data = pa.table(
+        {
+            "idx": range(100),
+        }
+    )
+    # Make two fragments
+    table = db.create_table("test", data)
+    table.add(pa.table({"idx": range(100, 200)}))
+
+    assert table.__getitems__([5, 2, 117]) == pa.table(
+        {
+            "idx": [5, 2, 117],
+        }
+    )
+
+    offsets = random.sample(range(200), 10)
+    assert table.__getitems__(offsets) == pa.table({"idx": offsets})
+
+
 @pytest.mark.asyncio
 async def test_query_timeout_async(tmp_path):
     db = await lancedb.connect_async(tmp_path)
diff --git a/python/python/tests/test_torch.py b/python/python/tests/test_torch.py
new file mode 100644
index 00000000..26c3ef5f
--- /dev/null
+++ b/python/python/tests/test_torch.py
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import pyarrow as pa
+import pytest
+
+torch = pytest.importorskip("torch")
+
+
+def tbl_to_tensor(tbl):
+    def to_tensor(col: pa.ChunkedArray):
+        if col.num_chunks > 1:
+            raise Exception("Single batch was too large to fit into a one-chunk table")
+        return torch.from_dlpack(col.chunk(0))
+
+    return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
+
+
+def test_table_dataloader(mem_db):
+    table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
+    dataloader = torch.utils.data.DataLoader(
+        table, collate_fn=tbl_to_tensor, batch_size=10, shuffle=True
+    )
+    for batch in dataloader:
+        assert batch.size(0) == 1
+        assert batch.size(1) == 10
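For context, `test_table_dataloader` above mirrors what an end user would write. A minimal usage sketch, assuming a local database at `./demo-db` (the path and table name are illustrative, not part of this change):

```python
import lancedb
import pyarrow as pa
import torch

db = lancedb.connect("./demo-db")  # illustrative path
table = db.create_table("numbers", pa.table({"a": range(1000)}))

def collate(tbl: pa.Table) -> torch.Tensor:
    # __getitems__ returns a combined (single-chunk) table, so chunk(0)
    # holds the whole batch and from_dlpack can wrap it without a copy.
    return torch.stack(
        [torch.from_dlpack(tbl.column(i).chunk(0)) for i in range(tbl.num_columns)]
    )

loader = torch.utils.data.DataLoader(
    table, collate_fn=collate, batch_size=10, shuffle=True
)
for batch in loader:
    assert batch.shape == (1, 10)  # one column stacked over ten rows
```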