mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
feat: add __getitems__ method impl for torch integration (#2596)
This allows a lancedb Table to act as a torch dataset.
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import List, Union
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
import lancedb
|
||||
from lancedb.db import AsyncConnection
|
||||
@@ -1355,6 +1356,27 @@ def test_take_queries(tmp_path):
|
||||
]
|
||||
|
||||
|
||||
def test_getitems(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table(
|
||||
{
|
||||
"idx": range(100),
|
||||
}
|
||||
)
|
||||
# Make two fragments
|
||||
table = db.create_table("test", data)
|
||||
table.add(pa.table({"idx": range(100, 200)}))
|
||||
|
||||
assert table.__getitems__([5, 2, 117]) == pa.table(
|
||||
{
|
||||
"idx": [5, 2, 117],
|
||||
}
|
||||
)
|
||||
|
||||
offsets = random.sample(range(200), 10)
|
||||
assert table.__getitems__(offsets) == pa.table({"idx": offsets})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_timeout_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
|
||||
Reference in New Issue
Block a user