feat(python): streaming larger-than-memory writes (#2094)

Makes our preprocessing pipeline do transforms in streaming fashion, so
users can do larger-than-memory writes.

Closes #2082
This commit is contained in:
Will Jones
2025-02-06 16:37:30 -08:00
committed by GitHub
parent 4e5fbe6c99
commit 801a9e5f6f
5 changed files with 192 additions and 119 deletions

View File

@@ -14,7 +14,7 @@ from lancedb.table import (
_append_vector_columns,
_cast_to_target_schema,
_handle_bad_vectors,
_into_pyarrow_table,
_into_pyarrow_reader,
_sanitize_data,
_infer_target_schema,
)
@@ -145,19 +145,19 @@ def test_append_vector_columns():
schema=schema,
)
output = _append_vector_columns(
data,
data.to_reader(),
schema, # metadata passed separate from schema
metadata=metadata,
)
).read_all()
assert output.schema == schema
assert output["vector"].null_count == 0
# Adds if missing
data = pa.table({"text": ["hello"]})
output = _append_vector_columns(
data,
data.to_reader(),
schema.with_metadata(metadata),
)
).read_all()
assert output.schema == schema
assert output["vector"].null_count == 0
@@ -170,9 +170,9 @@ def test_append_vector_columns():
schema=schema,
)
output = _append_vector_columns(
data,
data.to_reader(),
schema.with_metadata(metadata),
)
).read_all()
assert output == data # No change
# No provided schema
@@ -182,9 +182,9 @@ def test_append_vector_columns():
}
)
output = _append_vector_columns(
data,
data.to_reader(),
metadata=metadata,
)
).read_all()
expected_schema = pa.schema(
{
"text": pa.string(),
@@ -204,9 +204,9 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
if on_bad_vectors == "error":
with pytest.raises(ValueError) as e:
output = _handle_bad_vectors(
data,
data.to_reader(),
on_bad_vectors=on_bad_vectors,
)
).read_all()
output = exception_output(e)
assert output == (
"ValueError: Vector column 'vector' has variable length vectors. Set "
@@ -217,10 +217,10 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
return
else:
output = _handle_bad_vectors(
data,
data.to_reader(),
on_bad_vectors=on_bad_vectors,
fill_value=42.0,
)
).read_all()
if on_bad_vectors == "drop":
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
@@ -240,9 +240,9 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
if on_bad_vectors == "error":
with pytest.raises(ValueError) as e:
output = _handle_bad_vectors(
data,
data.to_reader(),
on_bad_vectors=on_bad_vectors,
)
).read_all()
output = exception_output(e)
assert output == (
"ValueError: Vector column 'vector' has NaNs. Set "
@@ -253,10 +253,10 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
return
else:
output = _handle_bad_vectors(
data,
data.to_reader(),
on_bad_vectors=on_bad_vectors,
fill_value=42.0,
)
).read_all()
if on_bad_vectors == "drop":
expected = pa.array([[3.0, 4.0]])
@@ -274,7 +274,7 @@ def test_handle_bad_vectors_noop():
[[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
)
data = pa.table({"vector": vector})
output = _handle_bad_vectors(data)
output = _handle_bad_vectors(data.to_reader()).read_all()
assert output["vector"] == vector
@@ -325,7 +325,7 @@ class TestModel(lancedb.pydantic.LanceModel):
)
def test_into_pyarrow_table(data):
expected = pa.table({"a": [1], "b": [2]})
output = _into_pyarrow_table(data())
output = _into_pyarrow_reader(data()).read_all()
assert output == expected
@@ -349,7 +349,7 @@ def test_infer_target_schema():
"vector": pa.list_(pa.float32(), 2),
}
)
output = _infer_target_schema(data)
output, _ = _infer_target_schema(data.to_reader())
assert output == expected
# Handle large list and use modal size
@@ -370,7 +370,7 @@ def test_infer_target_schema():
"vector": pa.list_(pa.float32(), 2),
}
)
output = _infer_target_schema(data)
output, _ = _infer_target_schema(data.to_reader())
assert output == expected
# ignore if not list
@@ -386,7 +386,7 @@ def test_infer_target_schema():
schema=example,
)
expected = example
output = _infer_target_schema(data)
output, _ = _infer_target_schema(data.to_reader())
assert output == expected
@@ -476,7 +476,7 @@ def test_sanitize_data(
target_schema=schema,
metadata=metadata,
allow_subschema=True,
)
).read_all()
assert output_data == expected
@@ -519,7 +519,7 @@ def test_cast_to_target_schema():
"vec2": pa.list_(pa.float32(), 2),
}
)
output = _cast_to_target_schema(data, target)
output = _cast_to_target_schema(data.to_reader(), target)
expected = pa.table(
{
"id": [1],
@@ -550,8 +550,10 @@ def test_cast_to_target_schema():
}
)
with pytest.raises(Exception):
_cast_to_target_schema(data, target)
output = _cast_to_target_schema(data, target, allow_subschema=True)
_cast_to_target_schema(data.to_reader(), target)
output = _cast_to_target_schema(
data.to_reader(), target, allow_subschema=True
).read_all()
expected_schema = pa.schema(
{
"id": pa.int64(),
@@ -576,3 +578,22 @@ def test_cast_to_target_schema():
schema=expected_schema,
)
assert output == expected
def test_sanitize_data_stream():
    """Verify _sanitize_data streams batches lazily instead of collecting them.

    The source generator raises after yielding a single batch; if
    _sanitize_data eagerly drained the whole reader up front, the error would
    surface before the first batch could ever be observed.
    """
    schema = pa.schema({"a": pa.int32()})

    def failing_batches():
        # One good batch, then an error — a lazy pipeline must still deliver
        # the good batch before hitting the failure.
        yield pa.record_batch([pa.array([1, 2, 3])], schema=schema)
        raise ValueError("error")

    source = pa.RecordBatchReader.from_batches(schema, failing_batches())
    sanitized = _sanitize_data(source)
    # The first batch arrives untouched before the stream errors out.
    assert next(sanitized) == pa.record_batch([pa.array([1, 2, 3])], schema=schema)
    with pytest.raises(ValueError):
        next(sanitized)