Mirror of https://github.com/lancedb/lancedb.git (synced 2026-01-04 02:42:57 +00:00)
feat(python): streaming larger-than-memory writes (#2094)
Makes our preprocessing pipeline apply transforms in a streaming fashion, so users can do larger-than-memory writes. Closes #2082
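A rough sketch of the usage this enables (the path, table name, and batch sizes below are illustrative, not part of this change): the data passed to create_table or add can be a generator of record batches, so each batch is embedded, validated, and written without ever concatenating the whole dataset into one in-memory table.

    import lancedb
    import pyarrow as pa

    schema = pa.schema({"id": pa.int64(), "text": pa.string()})

    def batches():
        # Batches are produced lazily; the full dataset never sits in memory at once.
        for start in range(0, 1_000_000, 10_000):
            ids = pa.array(list(range(start, start + 10_000)), type=pa.int64())
            text = pa.array([f"row-{i}" for i in range(start, start + 10_000)])
            yield pa.record_batch([ids, text], schema=schema)

    db = lancedb.connect("/tmp/lancedb-demo")  # illustrative path
    tbl = db.create_table("docs", data=batches(), schema=schema)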
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors

-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import pyarrow as pa

@@ -66,3 +66,17 @@ class AsyncRecordBatchReader:
         batches = table.to_batches(max_chunksize=max_batch_length)
         for batch in batches:
             yield batch
+
+
+def peek_reader(
+    reader: pa.RecordBatchReader,
+) -> Tuple[pa.RecordBatch, pa.RecordBatchReader]:
+    if not isinstance(reader, pa.RecordBatchReader):
+        raise TypeError("reader must be a RecordBatchReader")
+    batch = reader.read_next_batch()
+
+    def all_batches():
+        yield batch
+        yield from reader
+
+    return batch, pa.RecordBatchReader.from_batches(batch.schema, all_batches())
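For context, a small sketch of how the new peek_reader helper behaves (toy data, values illustrative): it reads one batch for inspection and returns a reader that still yields every batch in order.

    import pyarrow as pa
    from lancedb.arrow import peek_reader

    schema = pa.schema({"x": pa.int64()})
    batches = [
        pa.record_batch([pa.array([1, 2], type=pa.int64())], schema=schema),
        pa.record_batch([pa.array([3], type=pa.int64())], schema=schema),
    ]
    reader = pa.RecordBatchReader.from_batches(schema, iter(batches))

    first, reader = peek_reader(reader)
    # `first` can be inspected (e.g. to infer vector dimensions) while the
    # returned reader still produces both batches in order.
    assert first.num_rows == 2
    assert reader.read_all().num_rows == 3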
@@ -116,7 +116,7 @@ class EmbeddingFunction(BaseModel, ABC):
         )

     @abstractmethod
-    def ndims(self):
+    def ndims(self) -> int:
         """
         Return the dimensions of the vector column
         """
@@ -24,6 +24,7 @@ from typing import (
 from urllib.parse import urlparse

 import lance
+from lancedb.arrow import peek_reader
 from lancedb.background_loop import LOOP
 from .dependencies import _check_for_pandas
 import pyarrow as pa
@@ -74,17 +75,19 @@
 QueryType = Literal["vector", "fts", "hybrid", "auto"]


-def _into_pyarrow_table(data) -> pa.Table:
+def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
     if _check_for_hugging_face(data):
         # Huggingface datasets
         from lance.dependencies import datasets

         if isinstance(data, datasets.Dataset):
             schema = data.features.arrow_schema
-            return pa.Table.from_batches(data.data.to_batches(), schema=schema)
+            return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
         elif isinstance(data, datasets.dataset_dict.DatasetDict):
             schema = _schema_from_hf(data, schema)
-            return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
+            return pa.RecordBatchReader.from_batches(
+                schema, _to_batches_with_split(data)
+            )
     if isinstance(data, LanceModel):
         raise ValueError("Cannot add a single LanceModel to a table. Use a list.")

@@ -96,41 +99,41 @@ def _into_pyarrow_table(data) -> pa.Table:
         if isinstance(data[0], LanceModel):
             schema = data[0].__class__.to_arrow_schema()
             data = [model_to_dict(d) for d in data]
-            return pa.Table.from_pylist(data, schema=schema)
+            return pa.Table.from_pylist(data, schema=schema).to_reader()
         elif isinstance(data[0], pa.RecordBatch):
-            return pa.Table.from_batches(data)
+            return pa.Table.from_batches(data).to_reader()
         else:
-            return pa.Table.from_pylist(data)
+            return pa.Table.from_pylist(data).to_reader()
     elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
         table = pa.Table.from_pandas(data, preserve_index=False)
         # Do not serialize Pandas metadata
         meta = table.schema.metadata if table.schema.metadata is not None else {}
         meta = {k: v for k, v in meta.items() if k != b"pandas"}
-        return table.replace_schema_metadata(meta)
+        return table.replace_schema_metadata(meta).to_reader()
     elif isinstance(data, pa.Table):
-        return data
+        return data.to_reader()
     elif isinstance(data, pa.RecordBatch):
-        return pa.Table.from_batches([data])
+        return pa.RecordBatchReader.from_batches(data.schema, [data])
     elif isinstance(data, LanceDataset):
-        return data.scanner().to_table()
+        return data.scanner().to_reader()
     elif isinstance(data, pa.dataset.Dataset):
-        return data.to_table()
+        return data.scanner().to_reader()
     elif isinstance(data, pa.dataset.Scanner):
-        return data.to_table()
+        return data.to_reader()
     elif isinstance(data, pa.RecordBatchReader):
-        return data.read_all()
+        return data
     elif (
         type(data).__module__.startswith("polars")
         and data.__class__.__name__ == "DataFrame"
     ):
-        return data.to_arrow()
+        return data.to_arrow().to_reader()
     elif (
         type(data).__module__.startswith("polars")
         and data.__class__.__name__ == "LazyFrame"
     ):
-        return data.collect().to_arrow()
+        return data.collect().to_arrow().to_reader()
     elif isinstance(data, Iterable):
-        return _iterator_to_table(data)
+        return _iterator_to_reader(data)
     else:
         raise TypeError(
             f"Unknown data type {type(data)}. "
@@ -140,30 +143,28 @@ def _into_pyarrow_table(data) -> pa.Table:
         )


-def _iterator_to_table(data: Iterable) -> pa.Table:
-    batches = []
-    schema = None  # Will get schema from first batch
-    for batch in data:
-        batch_table = _into_pyarrow_table(batch)
-        if schema is not None:
-            if batch_table.schema != schema:
+def _iterator_to_reader(data: Iterable) -> pa.RecordBatchReader:
+    # Each batch is treated as it's own reader, mainly so we can
+    # re-use the _into_pyarrow_reader logic.
+    first = _into_pyarrow_reader(next(data))
+    schema = first.schema
+
+    def gen():
+        yield from first
+        for batch in data:
+            table: pa.Table = _into_pyarrow_reader(batch).read_all()
+            if table.schema != schema:
                 try:
-                    batch_table = batch_table.cast(schema)
+                    table = table.cast(schema)
                 except pa.lib.ArrowInvalid:
                     raise ValueError(
                         f"Input iterator yielded a batch with schema that "
                         f"does not match the schema of other batches.\n"
-                        f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
+                        f"Expected:\n{schema}\nGot:\n{batch.schema}"
                     )
-        else:
-            # Use the first schema for the remainder of the batches
-            schema = batch_table.schema
-        batches.append(batch_table)
+            yield from table.to_batches()

-    if batches:
-        return pa.concat_tables(batches)
-    else:
-        raise ValueError("Input iterable is empty")
+    return pa.RecordBatchReader.from_batches(schema, gen())


 def _sanitize_data(
@@ -174,7 +175,7 @@ def _sanitize_data(
     fill_value: float = 0.0,
     *,
     allow_subschema: bool = False,
-) -> pa.Table:
+) -> pa.RecordBatchReader:
     """
     Handle input data, applying all standard transformations.

@@ -207,20 +208,20 @@ def _sanitize_data(
     # 1. There might be embedding columns missing that will be added
     # in the add_embeddings step.
     # 2. If `allow_subschemas` is True, there might be columns missing.
-    table = _into_pyarrow_table(data)
+    reader = _into_pyarrow_reader(data)

-    table = _append_vector_columns(table, target_schema, metadata=metadata)
+    reader = _append_vector_columns(reader, target_schema, metadata=metadata)

     # This happens before the cast so we can fix vector columns with
     # incorrect lengths before they are cast to FSL.
-    table = _handle_bad_vectors(
-        table,
+    reader = _handle_bad_vectors(
+        reader,
         on_bad_vectors=on_bad_vectors,
         fill_value=fill_value,
     )

     if target_schema is None:
-        target_schema = _infer_target_schema(table)
+        target_schema, reader = _infer_target_schema(reader)

     if metadata:
         new_metadata = target_schema.metadata or {}
@@ -229,25 +230,25 @@ def _sanitize_data(

     _validate_schema(target_schema)

-    table = _cast_to_target_schema(table, target_schema, allow_subschema)
+    reader = _cast_to_target_schema(reader, target_schema, allow_subschema)

-    return table
+    return reader


 def _cast_to_target_schema(
-    table: pa.Table,
+    reader: pa.RecordBatchReader,
     target_schema: pa.Schema,
     allow_subschema: bool = False,
-) -> pa.Table:
+) -> pa.RecordBatchReader:
     # pa.Table.cast expects field order not to be changed.
     # Lance doesn't care about field order, so we don't need to rearrange fields
     # to match the target schema. We just need to correctly cast the fields.
-    if table.schema == target_schema:
+    if reader.schema == target_schema:
         # Fast path when the schemas are already the same
-        return table
+        return reader

     fields = []
-    for field in table.schema:
+    for field in reader.schema:
         target_field = target_schema.field(field.name)
         if target_field is None:
             raise ValueError(f"Field {field.name} not found in target schema")
@@ -260,12 +261,16 @@ def _cast_to_target_schema(

     if allow_subschema and len(reordered_schema) != len(target_schema):
         fields = _infer_subschema(
-            list(iter(table.schema)), list(iter(reordered_schema))
+            list(iter(reader.schema)), list(iter(reordered_schema))
         )
-        subschema = pa.schema(fields, metadata=target_schema.metadata)
-        return table.cast(subschema)
-    else:
-        return table.cast(reordered_schema)
+        reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
+
+    def gen():
+        for batch in reader:
+            # Table but not RecordBatch has cast.
+            yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0]
+
+    return pa.RecordBatchReader.from_batches(reordered_schema, gen())


 def _infer_subschema(
@@ -344,7 +349,10 @@ def sanitize_create_table(
     if metadata:
         schema = schema.with_metadata(metadata)
         # Need to apply metadata to the data as well
-        data = data.replace_schema_metadata(metadata)
+        if isinstance(data, pa.Table):
+            data = data.replace_schema_metadata(metadata)
+        elif isinstance(data, pa.RecordBatchReader):
+            data = pa.RecordBatchReader.from_batches(schema, data)

     return data, schema

@@ -381,11 +389,11 @@ def _to_batches_with_split(data):


 def _append_vector_columns(
-    data: pa.Table,
+    reader: pa.RecordBatchReader,
     schema: Optional[pa.Schema] = None,
     *,
     metadata: Optional[dict] = None,
-) -> pa.Table:
+) -> pa.RecordBatchReader:
     """
     Use the embedding function to automatically embed the source columns and add the
     vector columns to the table.
@@ -396,28 +404,43 @@ def _append_vector_columns(
     metadata = schema.metadata or metadata or {}
     functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)

-    for vector_column, conf in functions.items():
-        func = conf.function
-        no_vector_column = vector_column not in data.column_names
-        if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py():
-            col_data = func.compute_source_embeddings_with_retry(
-                data[conf.source_column]
-            )
-            if schema is not None:
-                dtype = schema.field(vector_column).type
-            else:
-                dtype = pa.list_(pa.float32(), len(col_data[0]))
-            if no_vector_column:
-                data = data.append_column(
-                    pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
-                )
-            else:
-                data = data.set_column(
-                    data.column_names.index(vector_column),
-                    pa.field(vector_column, type=dtype),
-                    pa.array(col_data, type=dtype),
-                )
-    return data
+    if not functions:
+        return reader
+
+    fields = list(reader.schema)
+    for vector_column, conf in functions.items():
+        if vector_column not in reader.schema.names:
+            if schema is not None:
+                field = schema.field(vector_column)
+            else:
+                dtype = pa.list_(pa.float32(), conf.function.ndims())
+                field = pa.field(vector_column, type=dtype, nullable=True)
+            fields.append(field)
+    schema = pa.schema(fields, metadata=reader.schema.metadata)
+
+    def gen():
+        for batch in reader:
+            for vector_column, conf in functions.items():
+                func = conf.function
+                no_vector_column = vector_column not in batch.column_names
+                if no_vector_column or pc.all(pc.is_null(batch[vector_column])).as_py():
+                    col_data = func.compute_source_embeddings_with_retry(
+                        batch[conf.source_column]
+                    )
+                    if no_vector_column:
+                        batch = batch.append_column(
+                            schema.field(vector_column),
+                            pa.array(col_data, type=schema.field(vector_column).type),
+                        )
+                    else:
+                        batch = batch.set_column(
+                            batch.column_names.index(vector_column),
+                            schema.field(vector_column),
+                            pa.array(col_data, type=schema.field(vector_column).type),
+                        )
+            yield batch
+
+    return pa.RecordBatchReader.from_batches(schema, gen())


 def _table_path(base: str, table_name: str) -> str:
@@ -2358,11 +2381,13 @@ class LanceTable(Table):


 def _handle_bad_vectors(
-    table: pa.Table,
+    reader: pa.RecordBatchReader,
     on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
     fill_value: float = 0.0,
-) -> pa.Table:
-    for field in table.schema:
+) -> pa.RecordBatchReader:
+    vector_columns = []
+
+    for field in reader.schema:
         # They can provide a 'vector' column that isn't yet a FSL
         named_vector_col = (
             (
@@ -2382,22 +2407,28 @@ def _handle_bad_vectors(
         )

         if named_vector_col or likely_vector_col:
-            table = _handle_bad_vector_column(
-                table,
-                vector_column_name=field.name,
-                on_bad_vectors=on_bad_vectors,
-                fill_value=fill_value,
-            )
+            vector_columns.append(field.name)

-    return table
+    def gen():
+        for batch in reader:
+            for name in vector_columns:
+                batch = _handle_bad_vector_column(
+                    batch,
+                    vector_column_name=name,
+                    on_bad_vectors=on_bad_vectors,
+                    fill_value=fill_value,
+                )
+            yield batch
+
+    return pa.RecordBatchReader.from_batches(reader.schema, gen())


 def _handle_bad_vector_column(
-    data: pa.Table,
+    data: pa.RecordBatch,
     vector_column_name: str,
     on_bad_vectors: str = "error",
     fill_value: float = 0.0,
-) -> pa.Table:
+) -> pa.RecordBatch:
     """
     Ensure that the vector column exists and has type fixed_size_list(float)

@@ -2485,8 +2516,11 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray
     return pc.is_in(indices, has_nan_indices)


-def _infer_target_schema(table: pa.Table) -> pa.Schema:
-    schema = table.schema
+def _infer_target_schema(
+    reader: pa.RecordBatchReader,
+) -> Tuple[pa.Schema, pa.RecordBatchReader]:
+    schema = reader.schema
+    peeked = None

     for i, field in enumerate(schema):
         if (
@@ -2494,8 +2528,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
             and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
             and pa.types.is_floating(field.type.value_type)
         ):
+            if peeked is None:
+                peeked, reader = peek_reader(reader)
             # Use the most common length of the list as the dimensions
-            dim = _modal_list_size(table.column(i))
+            dim = _modal_list_size(peeked.column(i))

             new_field = pa.field(
                 VECTOR_COLUMN_NAME,
@@ -2509,8 +2545,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
             and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
             and pa.types.is_integer(field.type.value_type)
         ):
+            if peeked is None:
+                peeked, reader = peek_reader(reader)
             # Use the most common length of the list as the dimensions
-            dim = _modal_list_size(table.column(i))
+            dim = _modal_list_size(peeked.column(i))
             new_field = pa.field(
                 VECTOR_COLUMN_NAME,
                 pa.list_(pa.uint8(), dim),
@@ -2519,7 +2557,7 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:

             schema = schema.set(i, new_field)

-    return schema
+    return schema, reader


 def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
@@ -107,7 +107,7 @@ def test_embedding_with_bad_results(tmp_path):
         vector: Vector(model.ndims()) = model.VectorField()

     table = db.create_table("test", schema=Schema, mode="overwrite")
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         # Default on_bad_vectors is "error"
         table.add([{"text": "hello world"}])

@@ -14,7 +14,7 @@ from lancedb.table import (
     _append_vector_columns,
     _cast_to_target_schema,
     _handle_bad_vectors,
-    _into_pyarrow_table,
+    _into_pyarrow_reader,
     _sanitize_data,
     _infer_target_schema,
 )
@@ -145,19 +145,19 @@ def test_append_vector_columns():
         schema=schema,
     )
     output = _append_vector_columns(
-        data,
+        data.to_reader(),
         schema,  # metadata passed separate from schema
         metadata=metadata,
-    )
+    ).read_all()
     assert output.schema == schema
     assert output["vector"].null_count == 0

     # Adds if missing
     data = pa.table({"text": ["hello"]})
     output = _append_vector_columns(
-        data,
+        data.to_reader(),
         schema.with_metadata(metadata),
-    )
+    ).read_all()
     assert output.schema == schema
     assert output["vector"].null_count == 0

@@ -170,9 +170,9 @@ def test_append_vector_columns():
         schema=schema,
     )
     output = _append_vector_columns(
-        data,
+        data.to_reader(),
         schema.with_metadata(metadata),
-    )
+    ).read_all()
     assert output == data  # No change

     # No provided schema
@@ -182,9 +182,9 @@ def test_append_vector_columns():
         }
     )
     output = _append_vector_columns(
-        data,
+        data.to_reader(),
         metadata=metadata,
-    )
+    ).read_all()
     expected_schema = pa.schema(
         {
             "text": pa.string(),
@@ -204,9 +204,9 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
     if on_bad_vectors == "error":
         with pytest.raises(ValueError) as e:
             output = _handle_bad_vectors(
-                data,
+                data.to_reader(),
                 on_bad_vectors=on_bad_vectors,
-            )
+            ).read_all()
         output = exception_output(e)
         assert output == (
             "ValueError: Vector column 'vector' has variable length vectors. Set "
@@ -217,10 +217,10 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
         return
     else:
        output = _handle_bad_vectors(
-            data,
+            data.to_reader(),
            on_bad_vectors=on_bad_vectors,
            fill_value=42.0,
-        )
+        ).read_all()

    if on_bad_vectors == "drop":
        expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
@@ -240,9 +240,9 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
     if on_bad_vectors == "error":
         with pytest.raises(ValueError) as e:
             output = _handle_bad_vectors(
-                data,
+                data.to_reader(),
                 on_bad_vectors=on_bad_vectors,
-            )
+            ).read_all()
         output = exception_output(e)
         assert output == (
             "ValueError: Vector column 'vector' has NaNs. Set "
@@ -253,10 +253,10 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
         return
     else:
         output = _handle_bad_vectors(
-            data,
+            data.to_reader(),
             on_bad_vectors=on_bad_vectors,
             fill_value=42.0,
-        )
+        ).read_all()

     if on_bad_vectors == "drop":
         expected = pa.array([[3.0, 4.0]])
@@ -274,7 +274,7 @@ def test_handle_bad_vectors_noop():
         [[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
     )
     data = pa.table({"vector": vector})
-    output = _handle_bad_vectors(data)
+    output = _handle_bad_vectors(data.to_reader()).read_all()
     assert output["vector"] == vector


@@ -325,7 +325,7 @@ class TestModel(lancedb.pydantic.LanceModel):
 )
 def test_into_pyarrow_table(data):
     expected = pa.table({"a": [1], "b": [2]})
-    output = _into_pyarrow_table(data())
+    output = _into_pyarrow_reader(data()).read_all()
     assert output == expected


@@ -349,7 +349,7 @@ def test_infer_target_schema():
             "vector": pa.list_(pa.float32(), 2),
         }
     )
-    output = _infer_target_schema(data)
+    output, _ = _infer_target_schema(data.to_reader())
     assert output == expected

     # Handle large list and use modal size
@@ -370,7 +370,7 @@ def test_infer_target_schema():
             "vector": pa.list_(pa.float32(), 2),
         }
     )
-    output = _infer_target_schema(data)
+    output, _ = _infer_target_schema(data.to_reader())
     assert output == expected

     # ignore if not list
@@ -386,7 +386,7 @@ def test_infer_target_schema():
         schema=example,
     )
     expected = example
-    output = _infer_target_schema(data)
+    output, _ = _infer_target_schema(data.to_reader())
     assert output == expected

@@ -476,7 +476,7 @@ def test_sanitize_data(
         target_schema=schema,
         metadata=metadata,
         allow_subschema=True,
-    )
+    ).read_all()

     assert output_data == expected

@@ -519,7 +519,7 @@ def test_cast_to_target_schema():
             "vec2": pa.list_(pa.float32(), 2),
         }
     )
-    output = _cast_to_target_schema(data, target)
+    output = _cast_to_target_schema(data.to_reader(), target)
     expected = pa.table(
         {
             "id": [1],
@@ -550,8 +550,10 @@ def test_cast_to_target_schema():
         }
     )
     with pytest.raises(Exception):
-        _cast_to_target_schema(data, target)
-    output = _cast_to_target_schema(data, target, allow_subschema=True)
+        _cast_to_target_schema(data.to_reader(), target)
+    output = _cast_to_target_schema(
+        data.to_reader(), target, allow_subschema=True
+    ).read_all()
     expected_schema = pa.schema(
         {
             "id": pa.int64(),
@@ -576,3 +578,22 @@ def test_cast_to_target_schema():
         schema=expected_schema,
     )
     assert output == expected
+
+
+def test_sanitize_data_stream():
+    # Make sure we don't collect the whole stream when running sanitize_data
+    schema = pa.schema({"a": pa.int32()})
+
+    def stream():
+        yield pa.record_batch([pa.array([1, 2, 3])], schema=schema)
+        raise ValueError("error")
+
+    reader = pa.RecordBatchReader.from_batches(schema, stream())
+
+    output = _sanitize_data(reader)
+
+    first = next(output)
+    assert first == pa.record_batch([pa.array([1, 2, 3])], schema=schema)
+
+    with pytest.raises(ValueError):
+        next(output)