feat(python): streaming larger-than-memory writes (#2094)

Makes our preprocessing pipeline apply transforms in a streaming fashion, so
users can perform larger-than-memory writes.
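
The core idea: each preprocessing step now wraps the incoming pa.RecordBatchReader in a generator and returns a new reader, so only one batch is materialized at a time. A minimal sketch of that pattern (not code from this PR; the fill_null transform and column layout are illustrative):

import pyarrow as pa
import pyarrow.compute as pc

def streaming_fill_null(reader: pa.RecordBatchReader) -> pa.RecordBatchReader:
    # Lazily transform batches; the full input is never collected into a table.
    schema = reader.schema
    def gen():
        for batch in reader:
            arrays = [pc.fill_null(batch.column(i), 0) for i in range(batch.num_columns)]
            yield pa.record_batch(arrays, schema=schema)
    return pa.RecordBatchReader.from_batches(schema, gen())

# Usage: chain the transform over a stream without reading it all into memory.
schema = pa.schema({"x": pa.int64()})
batches = (pa.record_batch([pa.array([i, None], type=pa.int64())], schema=schema) for i in range(3))
for out in streaming_fill_null(pa.RecordBatchReader.from_batches(schema, batches)):
    print(out)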

Closes #2082
Will Jones
2025-02-06 16:37:30 -08:00
committed by GitHub
parent 4e5fbe6c99
commit 801a9e5f6f
5 changed files with 192 additions and 119 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import pyarrow as pa
@@ -66,3 +66,17 @@ class AsyncRecordBatchReader:
batches = table.to_batches(max_chunksize=max_batch_length)
for batch in batches:
yield batch
def peek_reader(
reader: pa.RecordBatchReader,
) -> Tuple[pa.RecordBatch, pa.RecordBatchReader]:
if not isinstance(reader, pa.RecordBatchReader):
raise TypeError("reader must be a RecordBatchReader")
batch = reader.read_next_batch()
def all_batches():
yield batch
yield from reader
return batch, pa.RecordBatchReader.from_batches(batch.schema, all_batches())
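
peek_reader reads a single batch ahead and then rebuilds a reader that still yields every batch, so callers can inspect the data (e.g. to infer vector dimensions) without consuming the stream. A small usage sketch, with illustrative setup around the new helper:

import pyarrow as pa

schema = pa.schema({"vector": pa.list_(pa.float32())})
batch = pa.record_batch(
    [pa.array([[1.0, 2.0]], type=pa.list_(pa.float32()))], schema=schema
)
reader = pa.RecordBatchReader.from_batches(schema, iter([batch]))

first, reader = peek_reader(reader)
dim = len(first.column(0)[0].as_py())   # inspect the first batch, e.g. vector length
assert reader.read_all().num_rows == 1  # the rebuilt reader still yields all rows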

View File

@@ -116,7 +116,7 @@ class EmbeddingFunction(BaseModel, ABC):
)
@abstractmethod
def ndims(self):
def ndims(self) -> int:
"""
Return the dimensions of the vector column
"""

View File

@@ -24,6 +24,7 @@ from typing import (
from urllib.parse import urlparse
import lance
from lancedb.arrow import peek_reader
from lancedb.background_loop import LOOP
from .dependencies import _check_for_pandas
import pyarrow as pa
@@ -74,17 +75,19 @@ pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"]
def _into_pyarrow_table(data) -> pa.Table:
def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
if _check_for_hugging_face(data):
# Huggingface datasets
from lance.dependencies import datasets
if isinstance(data, datasets.Dataset):
schema = data.features.arrow_schema
return pa.Table.from_batches(data.data.to_batches(), schema=schema)
return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
elif isinstance(data, datasets.dataset_dict.DatasetDict):
schema = _schema_from_hf(data, schema)
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
return pa.RecordBatchReader.from_batches(
schema, _to_batches_with_split(data)
)
if isinstance(data, LanceModel):
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
@@ -96,41 +99,41 @@ def _into_pyarrow_table(data) -> pa.Table:
if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema()
data = [model_to_dict(d) for d in data]
return pa.Table.from_pylist(data, schema=schema)
return pa.Table.from_pylist(data, schema=schema).to_reader()
elif isinstance(data[0], pa.RecordBatch):
return pa.Table.from_batches(data)
return pa.Table.from_batches(data).to_reader()
else:
return pa.Table.from_pylist(data)
return pa.Table.from_pylist(data).to_reader()
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
table = pa.Table.from_pandas(data, preserve_index=False)
# Do not serialize Pandas metadata
meta = table.schema.metadata if table.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
return table.replace_schema_metadata(meta)
return table.replace_schema_metadata(meta).to_reader()
elif isinstance(data, pa.Table):
return data
return data.to_reader()
elif isinstance(data, pa.RecordBatch):
return pa.Table.from_batches([data])
return pa.RecordBatchReader.from_batches(data.schema, [data])
elif isinstance(data, LanceDataset):
return data.scanner().to_table()
return data.scanner().to_reader()
elif isinstance(data, pa.dataset.Dataset):
return data.to_table()
return data.scanner().to_reader()
elif isinstance(data, pa.dataset.Scanner):
return data.to_table()
return data.to_reader()
elif isinstance(data, pa.RecordBatchReader):
return data.read_all()
return data
elif (
type(data).__module__.startswith("polars")
and data.__class__.__name__ == "DataFrame"
):
return data.to_arrow()
return data.to_arrow().to_reader()
elif (
type(data).__module__.startswith("polars")
and data.__class__.__name__ == "LazyFrame"
):
return data.collect().to_arrow()
return data.collect().to_arrow().to_reader()
elif isinstance(data, Iterable):
return _iterator_to_table(data)
return _iterator_to_reader(data)
else:
raise TypeError(
f"Unknown data type {type(data)}. "
@@ -140,30 +143,28 @@ def _into_pyarrow_table(data) -> pa.Table:
)
def _iterator_to_table(data: Iterable) -> pa.Table:
batches = []
schema = None # Will get schema from first batch
for batch in data:
batch_table = _into_pyarrow_table(batch)
if schema is not None:
if batch_table.schema != schema:
def _iterator_to_reader(data: Iterable) -> pa.RecordBatchReader:
# Each batch is treated as its own reader, mainly so we can
# re-use the _into_pyarrow_reader logic.
first = _into_pyarrow_reader(next(data))
schema = first.schema
def gen():
yield from first
for batch in data:
table: pa.Table = _into_pyarrow_reader(batch).read_all()
if table.schema != schema:
try:
batch_table = batch_table.cast(schema)
table = table.cast(schema)
except pa.lib.ArrowInvalid:
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the schema of other batches.\n"
f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
f"Expected:\n{schema}\nGot:\n{batch.schema}"
)
else:
# Use the first schema for the remainder of the batches
schema = batch_table.schema
batches.append(batch_table)
yield from table.to_batches()
if batches:
return pa.concat_tables(batches)
else:
raise ValueError("Input iterable is empty")
return pa.RecordBatchReader.from_batches(schema, gen())
def _sanitize_data(
@@ -174,7 +175,7 @@ def _sanitize_data(
fill_value: float = 0.0,
*,
allow_subschema: bool = False,
) -> pa.Table:
) -> pa.RecordBatchReader:
"""
Handle input data, applying all standard transformations.
@@ -207,20 +208,20 @@ def _sanitize_data(
# 1. There might be embedding columns missing that will be added
# in the add_embeddings step.
# 2. If `allow_subschemas` is True, there might be columns missing.
table = _into_pyarrow_table(data)
reader = _into_pyarrow_reader(data)
table = _append_vector_columns(table, target_schema, metadata=metadata)
reader = _append_vector_columns(reader, target_schema, metadata=metadata)
# This happens before the cast so we can fix vector columns with
# incorrect lengths before they are cast to FSL.
table = _handle_bad_vectors(
table,
reader = _handle_bad_vectors(
reader,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
if target_schema is None:
target_schema = _infer_target_schema(table)
target_schema, reader = _infer_target_schema(reader)
if metadata:
new_metadata = target_schema.metadata or {}
@@ -229,25 +230,25 @@ def _sanitize_data(
_validate_schema(target_schema)
table = _cast_to_target_schema(table, target_schema, allow_subschema)
reader = _cast_to_target_schema(reader, target_schema, allow_subschema)
return table
return reader
def _cast_to_target_schema(
table: pa.Table,
reader: pa.RecordBatchReader,
target_schema: pa.Schema,
allow_subschema: bool = False,
) -> pa.Table:
) -> pa.RecordBatchReader:
# pa.Table.cast expects field order not to be changed.
# Lance doesn't care about field order, so we don't need to rearrange fields
# to match the target schema. We just need to correctly cast the fields.
if table.schema == target_schema:
if reader.schema == target_schema:
# Fast path when the schemas are already the same
return table
return reader
fields = []
for field in table.schema:
for field in reader.schema:
target_field = target_schema.field(field.name)
if target_field is None:
raise ValueError(f"Field {field.name} not found in target schema")
@@ -260,12 +261,16 @@ def _cast_to_target_schema(
if allow_subschema and len(reordered_schema) != len(target_schema):
fields = _infer_subschema(
list(iter(table.schema)), list(iter(reordered_schema))
list(iter(reader.schema)), list(iter(reordered_schema))
)
subschema = pa.schema(fields, metadata=target_schema.metadata)
return table.cast(subschema)
else:
return table.cast(reordered_schema)
reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
def gen():
for batch in reader:
# Table but not RecordBatch has cast.
yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0]
return pa.RecordBatchReader.from_batches(reordered_schema, gen())
def _infer_subschema(
@@ -344,7 +349,10 @@ def sanitize_create_table(
if metadata:
schema = schema.with_metadata(metadata)
# Need to apply metadata to the data as well
data = data.replace_schema_metadata(metadata)
if isinstance(data, pa.Table):
data = data.replace_schema_metadata(metadata)
elif isinstance(data, pa.RecordBatchReader):
data = pa.RecordBatchReader.from_batches(schema, data)
return data, schema
@@ -381,11 +389,11 @@ def _to_batches_with_split(data):
def _append_vector_columns(
data: pa.Table,
reader: pa.RecordBatchReader,
schema: Optional[pa.Schema] = None,
*,
metadata: Optional[dict] = None,
) -> pa.Table:
) -> pa.RecordBatchReader:
"""
Use the embedding function to automatically embed the source columns and add the
vector columns to the table.
@@ -396,28 +404,43 @@ def _append_vector_columns(
metadata = schema.metadata or metadata or {}
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
if not functions:
return reader
fields = list(reader.schema)
for vector_column, conf in functions.items():
func = conf.function
no_vector_column = vector_column not in data.column_names
if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py():
col_data = func.compute_source_embeddings_with_retry(
data[conf.source_column]
)
if vector_column not in reader.schema.names:
if schema is not None:
dtype = schema.field(vector_column).type
field = schema.field(vector_column)
else:
dtype = pa.list_(pa.float32(), len(col_data[0]))
if no_vector_column:
data = data.append_column(
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
)
else:
data = data.set_column(
data.column_names.index(vector_column),
pa.field(vector_column, type=dtype),
pa.array(col_data, type=dtype),
)
return data
dtype = pa.list_(pa.float32(), conf.function.ndims())
field = pa.field(vector_column, type=dtype, nullable=True)
fields.append(field)
schema = pa.schema(fields, metadata=reader.schema.metadata)
def gen():
for batch in reader:
for vector_column, conf in functions.items():
func = conf.function
no_vector_column = vector_column not in batch.column_names
if no_vector_column or pc.all(pc.is_null(batch[vector_column])).as_py():
col_data = func.compute_source_embeddings_with_retry(
batch[conf.source_column]
)
if no_vector_column:
batch = batch.append_column(
schema.field(vector_column),
pa.array(col_data, type=schema.field(vector_column).type),
)
else:
batch = batch.set_column(
batch.column_names.index(vector_column),
schema.field(vector_column),
pa.array(col_data, type=schema.field(vector_column).type),
)
yield batch
return pa.RecordBatchReader.from_batches(schema, gen())
def _table_path(base: str, table_name: str) -> str:
@@ -2358,11 +2381,13 @@ class LanceTable(Table):
def _handle_bad_vectors(
table: pa.Table,
reader: pa.RecordBatchReader,
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
fill_value: float = 0.0,
) -> pa.Table:
for field in table.schema:
) -> pa.RecordBatchReader:
vector_columns = []
for field in reader.schema:
# They can provide a 'vector' column that isn't yet a FSL
named_vector_col = (
(
@@ -2382,22 +2407,28 @@ def _handle_bad_vectors(
)
if named_vector_col or likely_vector_col:
table = _handle_bad_vector_column(
table,
vector_column_name=field.name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
vector_columns.append(field.name)
return table
def gen():
for batch in reader:
for name in vector_columns:
batch = _handle_bad_vector_column(
batch,
vector_column_name=name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
yield batch
return pa.RecordBatchReader.from_batches(reader.schema, gen())
def _handle_bad_vector_column(
data: pa.Table,
data: pa.RecordBatch,
vector_column_name: str,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> pa.Table:
) -> pa.RecordBatch:
"""
Ensure that the vector column exists and has type fixed_size_list(float)
@@ -2485,8 +2516,11 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray
return pc.is_in(indices, has_nan_indices)
def _infer_target_schema(table: pa.Table) -> pa.Schema:
schema = table.schema
def _infer_target_schema(
reader: pa.RecordBatchReader,
) -> Tuple[pa.Schema, pa.RecordBatchReader]:
schema = reader.schema
peeked = None
for i, field in enumerate(schema):
if (
@@ -2494,8 +2528,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
and pa.types.is_floating(field.type.value_type)
):
if peeked is None:
peeked, reader = peek_reader(reader)
# Use the most common length of the list as the dimensions
dim = _modal_list_size(table.column(i))
dim = _modal_list_size(peeked.column(i))
new_field = pa.field(
VECTOR_COLUMN_NAME,
@@ -2509,8 +2545,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
and pa.types.is_integer(field.type.value_type)
):
if peeked is None:
peeked, reader = peek_reader(reader)
# Use the most common length of the list as the dimensions
dim = _modal_list_size(table.column(i))
dim = _modal_list_size(peeked.column(i))
new_field = pa.field(
VECTOR_COLUMN_NAME,
pa.list_(pa.uint8(), dim),
@@ -2519,7 +2557,7 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
schema = schema.set(i, new_field)
return schema
return schema, reader
def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:

View File

@@ -107,7 +107,7 @@ def test_embedding_with_bad_results(tmp_path):
vector: Vector(model.ndims()) = model.VectorField()
table = db.create_table("test", schema=Schema, mode="overwrite")
with pytest.raises(ValueError):
with pytest.raises(RuntimeError):
# Default on_bad_vectors is "error"
table.add([{"text": "hello world"}])

View File

@@ -14,7 +14,7 @@ from lancedb.table import (
_append_vector_columns,
_cast_to_target_schema,
_handle_bad_vectors,
_into_pyarrow_table,
_into_pyarrow_reader,
_sanitize_data,
_infer_target_schema,
)
@@ -145,19 +145,19 @@ def test_append_vector_columns():
schema=schema,
)
output = _append_vector_columns(
data,
data.to_reader(),
schema, # metadata passed separate from schema
metadata=metadata,
)
).read_all()
assert output.schema == schema
assert output["vector"].null_count == 0
# Adds if missing
data = pa.table({"text": ["hello"]})
output = _append_vector_columns(
data,
data.to_reader(),
schema.with_metadata(metadata),
)
).read_all()
assert output.schema == schema
assert output["vector"].null_count == 0
@@ -170,9 +170,9 @@ def test_append_vector_columns():
schema=schema,
)
output = _append_vector_columns(
data,
data.to_reader(),
schema.with_metadata(metadata),
)
).read_all()
assert output == data # No change
# No provided schema
@@ -182,9 +182,9 @@ def test_append_vector_columns():
}
)
output = _append_vector_columns(
data,
data.to_reader(),
metadata=metadata,
)
).read_all()
expected_schema = pa.schema(
{
"text": pa.string(),
@@ -204,9 +204,9 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
if on_bad_vectors == "error":
with pytest.raises(ValueError) as e:
output = _handle_bad_vectors(
data,
data.to_reader(),
on_bad_vectors=on_bad_vectors,
)
).read_all()
output = exception_output(e)
assert output == (
"ValueError: Vector column 'vector' has variable length vectors. Set "
@@ -217,10 +217,10 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
return
else:
output = _handle_bad_vectors(
data,
data.to_reader(),
on_bad_vectors=on_bad_vectors,
fill_value=42.0,
)
).read_all()
if on_bad_vectors == "drop":
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
@@ -240,9 +240,9 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
if on_bad_vectors == "error":
with pytest.raises(ValueError) as e:
output = _handle_bad_vectors(
data,
data.to_reader(),
on_bad_vectors=on_bad_vectors,
)
).read_all()
output = exception_output(e)
assert output == (
"ValueError: Vector column 'vector' has NaNs. Set "
@@ -253,10 +253,10 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
return
else:
output = _handle_bad_vectors(
data,
data.to_reader(),
on_bad_vectors=on_bad_vectors,
fill_value=42.0,
)
).read_all()
if on_bad_vectors == "drop":
expected = pa.array([[3.0, 4.0]])
@@ -274,7 +274,7 @@ def test_handle_bad_vectors_noop():
[[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
)
data = pa.table({"vector": vector})
output = _handle_bad_vectors(data)
output = _handle_bad_vectors(data.to_reader()).read_all()
assert output["vector"] == vector
@@ -325,7 +325,7 @@ class TestModel(lancedb.pydantic.LanceModel):
)
def test_into_pyarrow_table(data):
expected = pa.table({"a": [1], "b": [2]})
output = _into_pyarrow_table(data())
output = _into_pyarrow_reader(data()).read_all()
assert output == expected
@@ -349,7 +349,7 @@ def test_infer_target_schema():
"vector": pa.list_(pa.float32(), 2),
}
)
output = _infer_target_schema(data)
output, _ = _infer_target_schema(data.to_reader())
assert output == expected
# Handle large list and use modal size
@@ -370,7 +370,7 @@ def test_infer_target_schema():
"vector": pa.list_(pa.float32(), 2),
}
)
output = _infer_target_schema(data)
output, _ = _infer_target_schema(data.to_reader())
assert output == expected
# ignore if not list
@@ -386,7 +386,7 @@ def test_infer_target_schema():
schema=example,
)
expected = example
output = _infer_target_schema(data)
output, _ = _infer_target_schema(data.to_reader())
assert output == expected
@@ -476,7 +476,7 @@ def test_sanitize_data(
target_schema=schema,
metadata=metadata,
allow_subschema=True,
)
).read_all()
assert output_data == expected
@@ -519,7 +519,7 @@ def test_cast_to_target_schema():
"vec2": pa.list_(pa.float32(), 2),
}
)
output = _cast_to_target_schema(data, target)
output = _cast_to_target_schema(data.to_reader(), target)
expected = pa.table(
{
"id": [1],
@@ -550,8 +550,10 @@ def test_cast_to_target_schema():
}
)
with pytest.raises(Exception):
_cast_to_target_schema(data, target)
output = _cast_to_target_schema(data, target, allow_subschema=True)
_cast_to_target_schema(data.to_reader(), target)
output = _cast_to_target_schema(
data.to_reader(), target, allow_subschema=True
).read_all()
expected_schema = pa.schema(
{
"id": pa.int64(),
@@ -576,3 +578,22 @@ def test_cast_to_target_schema():
schema=expected_schema,
)
assert output == expected
def test_sanitize_data_stream():
# Make sure we don't collect the whole stream when running sanitize_data
schema = pa.schema({"a": pa.int32()})
def stream():
yield pa.record_batch([pa.array([1, 2, 3])], schema=schema)
raise ValueError("error")
reader = pa.RecordBatchReader.from_batches(schema, stream())
output = _sanitize_data(reader)
first = next(output)
assert first == pa.record_batch([pa.array([1, 2, 3])], schema=schema)
with pytest.raises(ValueError):
next(output)