feat: make it possible to opt in to using the v2 format (#1352)

This also exposed the max_batch_length configuration option in
python/node (it was needed to verify if we are actually in v2 mode or
not)
This commit is contained in:
Weston Pace
2024-06-04 21:52:14 -07:00
committed by GitHub
parent d39e7d23f4
commit d5586c9c32
17 changed files with 310 additions and 33 deletions

View File

@@ -24,6 +24,7 @@ class Connection(object):
mode: str,
data: pa.RecordBatchReader,
storage_options: Optional[Dict[str, str]] = None,
use_legacy_format: Optional[bool] = None,
) -> Table: ...
async def create_empty_table(
self,
@@ -31,6 +32,7 @@ class Connection(object):
mode: str,
schema: pa.Schema,
storage_options: Optional[Dict[str, str]] = None,
use_legacy_format: Optional[bool] = None,
) -> Table: ...
class Table:
@@ -72,7 +74,7 @@ class Query:
def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
async def execute(self) -> RecordBatchStream: ...
async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ...
class VectorQuery:
async def execute(self) -> RecordBatchStream: ...

View File

@@ -558,6 +558,8 @@ class AsyncConnection(object):
on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
*,
use_legacy_format: Optional[bool] = None,
) -> AsyncTable:
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
@@ -600,6 +602,9 @@ class AsyncConnection(object):
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
use_legacy_format: bool, optional, default True
If True, use the legacy format for the table. If False, use the new format.
The default is True while the new format is in beta.
Returns
@@ -761,7 +766,11 @@ class AsyncConnection(object):
if data is None:
new_table = await self._inner.create_empty_table(
name, mode, schema, storage_options=storage_options
name,
mode,
schema,
storage_options=storage_options,
use_legacy_format=use_legacy_format,
)
else:
data = data_to_reader(data, schema)
@@ -770,6 +779,7 @@ class AsyncConnection(object):
mode,
data,
storage_options=storage_options,
use_legacy_format=use_legacy_format,
)
return AsyncTable(new_table)

View File

@@ -1113,11 +1113,22 @@ class AsyncQueryBase(object):
self._inner.limit(limit)
return self
async def to_batches(self) -> AsyncRecordBatchReader:
async def to_batches(
self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader:
"""
Execute the query and return the results as an Apache Arrow RecordBatchReader.
Parameters
----------
max_batch_length: Optional[int]
The maximum number of selected records in a single RecordBatch object.
If not specified, a default batch length is used.
It is possible for batches to be smaller than the provided length if the
underlying data is stored in smaller chunks.
"""
return AsyncRecordBatchReader(await self._inner.execute())
return AsyncRecordBatchReader(await self._inner.execute(max_batch_length))
async def to_arrow(self) -> pa.Table:
"""

View File

@@ -507,6 +507,52 @@ def test_empty_or_nonexistent_table(tmp_path):
assert test.schema == test2.schema
@pytest.mark.asyncio
async def test_create_in_v2_mode(tmp_path):
def make_data():
for i in range(10):
yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
def make_table():
return pa.table([pa.array([x for x in range(10 * 1024)])], names=["x"])
schema = pa.schema([pa.field("x", pa.int64())])
db = await lancedb.connect_async(tmp_path)
# Create table in v1 mode
tbl = await db.create_table("test", data=make_data(), schema=schema)
async def is_in_v2_mode(tbl):
batches = await tbl.query().to_batches(max_batch_length=1024 * 10)
num_batches = 0
async for batch in batches:
num_batches += 1
return num_batches < 10
assert not await is_in_v2_mode(tbl)
# Create table in v2 mode
tbl = await db.create_table(
"test_v2", data=make_data(), schema=schema, use_legacy_format=False
)
assert await is_in_v2_mode(tbl)
# Add data (should remain in v2 mode)
await tbl.add(make_table())
assert await is_in_v2_mode(tbl)
# Create empty table in v2 mode and add data
tbl = await db.create_table(
"test_empty_v2", data=None, schema=schema, use_legacy_format=False
)
await tbl.add(make_table())
assert await is_in_v2_mode(tbl)
def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
table = db.create_table(