mirror of https://github.com/lancedb/lancedb.git (synced 2026-05-14 02:20:40 +00:00)
feat(python): streaming larger-than-memory writes (#2094)
Makes our preprocessing pipeline perform transforms in a streaming fashion, so users can do larger-than-memory writes. Closes #2082
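A minimal sketch of the kind of usage this enables, assuming the public create_table API accepts an iterator of pyarrow RecordBatches together with an explicit schema; the URI, table name, and batch contents below are illustrative only. Batches are pulled lazily from the generator, so the full dataset never has to fit in memory at once.

import pyarrow as pa
import lancedb

schema = pa.schema({"text": pa.string(), "vector": pa.list_(pa.float32(), 2)})

def batches():
    # Each batch is produced lazily; the whole dataset is never materialized.
    for i in range(1_000):
        yield pa.record_batch(
            [
                pa.array([f"doc-{i}"]),
                pa.array([[float(i), float(i) + 1.0]], type=pa.list_(pa.float32(), 2)),
            ],
            schema=schema,
        )

db = lancedb.connect("data/example-db")  # illustrative URI
tbl = db.create_table("streamed", data=batches(), schema=schema)

Appending to an existing table (e.g. tbl.add(batches())) should follow the same streaming path, though that is an assumption based on this change rather than something shown in the diff below.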
@@ -14,7 +14,7 @@ from lancedb.table import (
     _append_vector_columns,
     _cast_to_target_schema,
     _handle_bad_vectors,
-    _into_pyarrow_table,
+    _into_pyarrow_reader,
     _sanitize_data,
     _infer_target_schema,
 )
@@ -145,19 +145,19 @@ def test_append_vector_columns():
         schema=schema,
     )
     output = _append_vector_columns(
-        data,
+        data.to_reader(),
         schema,  # metadata passed separate from schema
         metadata=metadata,
-    )
+    ).read_all()
     assert output.schema == schema
     assert output["vector"].null_count == 0
 
     # Adds if missing
     data = pa.table({"text": ["hello"]})
     output = _append_vector_columns(
-        data,
+        data.to_reader(),
         schema.with_metadata(metadata),
-    )
+    ).read_all()
     assert output.schema == schema
     assert output["vector"].null_count == 0
 
@@ -170,9 +170,9 @@ def test_append_vector_columns():
         schema=schema,
     )
     output = _append_vector_columns(
-        data,
+        data.to_reader(),
         schema.with_metadata(metadata),
-    )
+    ).read_all()
     assert output == data  # No change
 
     # No provided schema
@@ -182,9 +182,9 @@ def test_append_vector_columns():
         }
     )
     output = _append_vector_columns(
-        data,
+        data.to_reader(),
         metadata=metadata,
-    )
+    ).read_all()
     expected_schema = pa.schema(
         {
             "text": pa.string(),
@@ -204,9 +204,9 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
     if on_bad_vectors == "error":
         with pytest.raises(ValueError) as e:
             output = _handle_bad_vectors(
-                data,
+                data.to_reader(),
                 on_bad_vectors=on_bad_vectors,
-            )
+            ).read_all()
         output = exception_output(e)
         assert output == (
             "ValueError: Vector column 'vector' has variable length vectors. Set "
@@ -217,10 +217,10 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
         return
     else:
         output = _handle_bad_vectors(
-            data,
+            data.to_reader(),
             on_bad_vectors=on_bad_vectors,
             fill_value=42.0,
-        )
+        ).read_all()
 
     if on_bad_vectors == "drop":
         expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
@@ -240,9 +240,9 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
     if on_bad_vectors == "error":
         with pytest.raises(ValueError) as e:
             output = _handle_bad_vectors(
-                data,
+                data.to_reader(),
                 on_bad_vectors=on_bad_vectors,
-            )
+            ).read_all()
         output = exception_output(e)
         assert output == (
             "ValueError: Vector column 'vector' has NaNs. Set "
@@ -253,10 +253,10 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
         return
     else:
         output = _handle_bad_vectors(
-            data,
+            data.to_reader(),
             on_bad_vectors=on_bad_vectors,
             fill_value=42.0,
-        )
+        ).read_all()
 
     if on_bad_vectors == "drop":
         expected = pa.array([[3.0, 4.0]])
@@ -274,7 +274,7 @@ def test_handle_bad_vectors_noop():
         [[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
     )
     data = pa.table({"vector": vector})
-    output = _handle_bad_vectors(data)
+    output = _handle_bad_vectors(data.to_reader()).read_all()
     assert output["vector"] == vector
 
 
@@ -325,7 +325,7 @@ class TestModel(lancedb.pydantic.LanceModel):
 )
 def test_into_pyarrow_table(data):
     expected = pa.table({"a": [1], "b": [2]})
-    output = _into_pyarrow_table(data())
+    output = _into_pyarrow_reader(data()).read_all()
     assert output == expected
 
 
@@ -349,7 +349,7 @@ def test_infer_target_schema():
             "vector": pa.list_(pa.float32(), 2),
         }
     )
-    output = _infer_target_schema(data)
+    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected
 
     # Handle large list and use modal size
@@ -370,7 +370,7 @@ def test_infer_target_schema():
             "vector": pa.list_(pa.float32(), 2),
         }
     )
-    output = _infer_target_schema(data)
+    output, _ = _infer_target_schema(data.to_reader())
     assert output == expected
 
     # ignore if not list
@@ -386,7 +386,7 @@ def test_infer_target_schema():
         schema=example,
     )
     expected = example
-    output = _infer_target_schema(data)
+    output, _ = _infer_target_schema(data.to_reader())
     assert output == expected
 
 
@@ -476,7 +476,7 @@ def test_sanitize_data(
         target_schema=schema,
         metadata=metadata,
        allow_subschema=True,
-    )
+    ).read_all()
 
     assert output_data == expected
 
@@ -519,7 +519,7 @@ def test_cast_to_target_schema():
             "vec2": pa.list_(pa.float32(), 2),
         }
     )
-    output = _cast_to_target_schema(data, target)
+    output = _cast_to_target_schema(data.to_reader(), target)
     expected = pa.table(
         {
             "id": [1],
@@ -550,8 +550,10 @@ def test_cast_to_target_schema():
         }
     )
     with pytest.raises(Exception):
-        _cast_to_target_schema(data, target)
-    output = _cast_to_target_schema(data, target, allow_subschema=True)
+        _cast_to_target_schema(data.to_reader(), target)
+    output = _cast_to_target_schema(
+        data.to_reader(), target, allow_subschema=True
+    ).read_all()
     expected_schema = pa.schema(
         {
             "id": pa.int64(),
@@ -576,3 +578,22 @@ def test_cast_to_target_schema():
         schema=expected_schema,
     )
     assert output == expected
+
+
+def test_sanitize_data_stream():
+    # Make sure we don't collect the whole stream when running sanitize_data
+    schema = pa.schema({"a": pa.int32()})
+
+    def stream():
+        yield pa.record_batch([pa.array([1, 2, 3])], schema=schema)
+        raise ValueError("error")
+
+    reader = pa.RecordBatchReader.from_batches(schema, stream())
+
+    output = _sanitize_data(reader)
+
+    first = next(output)
+    assert first == pa.record_batch([pa.array([1, 2, 3])], schema=schema)
+
+    with pytest.raises(ValueError):
+        next(output)