From 801a9e5f6ffd004d5ac31160714ceba3a4fc5db4 Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Thu, 6 Feb 2025 16:37:30 -0800
Subject: [PATCH] feat(python): streaming larger-than-memory writes (#2094)

Makes our preprocessing pipeline do transforms in streaming fashion, so
users can do larger-than-memory writes.

Closes #2082
---
 python/python/lancedb/arrow.py           |  16 +-
 python/python/lancedb/embeddings/base.py |   2 +-
 python/python/lancedb/table.py           | 218 +++++++++++++----------
 python/python/tests/test_embeddings.py   |   2 +-
 python/python/tests/test_util.py         |  73 +++++---
 5 files changed, 192 insertions(+), 119 deletions(-)

diff --git a/python/python/lancedb/arrow.py b/python/python/lancedb/arrow.py
index 7cf876d9..ccb62ec5 100644
--- a/python/python/lancedb/arrow.py
+++ b/python/python/lancedb/arrow.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import pyarrow as pa
 
@@ -66,3 +66,17 @@ class AsyncRecordBatchReader:
             batches = table.to_batches(max_chunksize=max_batch_length)
             for batch in batches:
                 yield batch
+
+
+def peek_reader(
+    reader: pa.RecordBatchReader,
+) -> Tuple[pa.RecordBatch, pa.RecordBatchReader]:
+    if not isinstance(reader, pa.RecordBatchReader):
+        raise TypeError("reader must be a RecordBatchReader")
+    batch = reader.read_next_batch()
+
+    def all_batches():
+        yield batch
+        yield from reader
+
+    return batch, pa.RecordBatchReader.from_batches(batch.schema, all_batches())
diff --git a/python/python/lancedb/embeddings/base.py b/python/python/lancedb/embeddings/base.py
index d7ab93ee..cd68aa70 100644
--- a/python/python/lancedb/embeddings/base.py
+++ b/python/python/lancedb/embeddings/base.py
@@ -116,7 +116,7 @@ class EmbeddingFunction(BaseModel, ABC):
         )
 
     @abstractmethod
-    def ndims(self):
+    def ndims(self) -> int:
         """
         Return the dimensions of the vector column
         """
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 86a1d500..48513e7f 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -24,6 +24,7 @@ from typing import (
 from urllib.parse import urlparse
 
 import lance
+from lancedb.arrow import peek_reader
 from lancedb.background_loop import LOOP
 from .dependencies import _check_for_pandas
 import pyarrow as pa
@@ -74,17 +75,19 @@ pl = safe_import_polars()
 QueryType = Literal["vector", "fts", "hybrid", "auto"]
 
 
-def _into_pyarrow_table(data) -> pa.Table:
+def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
     if _check_for_hugging_face(data):
         # Huggingface datasets
         from lance.dependencies import datasets
 
         if isinstance(data, datasets.Dataset):
             schema = data.features.arrow_schema
-            return pa.Table.from_batches(data.data.to_batches(), schema=schema)
+            return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
         elif isinstance(data, datasets.dataset_dict.DatasetDict):
             schema = _schema_from_hf(data, schema)
-            return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
+            return pa.RecordBatchReader.from_batches(
+                schema, _to_batches_with_split(data)
+            )
 
     if isinstance(data, LanceModel):
         raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
@@ -96,41 +99,41 @@ def _into_pyarrow_table(data) -> pa.Table:
         if isinstance(data[0], LanceModel):
             schema = data[0].__class__.to_arrow_schema()
             data = [model_to_dict(d) for d in data]
-            return pa.Table.from_pylist(data, schema=schema)
+            return pa.Table.from_pylist(data, schema=schema).to_reader()
         elif isinstance(data[0], pa.RecordBatch):
-            return pa.Table.from_batches(data)
+            return pa.Table.from_batches(data).to_reader()
         else:
-            return pa.Table.from_pylist(data)
+            return pa.Table.from_pylist(data).to_reader()
     elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
         table = pa.Table.from_pandas(data, preserve_index=False)
         # Do not serialize Pandas metadata
         meta = table.schema.metadata if table.schema.metadata is not None else {}
         meta = {k: v for k, v in meta.items() if k != b"pandas"}
-        return table.replace_schema_metadata(meta)
+        return table.replace_schema_metadata(meta).to_reader()
     elif isinstance(data, pa.Table):
-        return data
+        return data.to_reader()
     elif isinstance(data, pa.RecordBatch):
-        return pa.Table.from_batches([data])
+        return pa.RecordBatchReader.from_batches(data.schema, [data])
     elif isinstance(data, LanceDataset):
-        return data.scanner().to_table()
+        return data.scanner().to_reader()
    elif isinstance(data, pa.dataset.Dataset):
-        return data.to_table()
+        return data.scanner().to_reader()
     elif isinstance(data, pa.dataset.Scanner):
-        return data.to_table()
+        return data.to_reader()
     elif isinstance(data, pa.RecordBatchReader):
-        return data.read_all()
+        return data
     elif (
         type(data).__module__.startswith("polars")
         and data.__class__.__name__ == "DataFrame"
     ):
-        return data.to_arrow()
+        return data.to_arrow().to_reader()
     elif (
         type(data).__module__.startswith("polars")
         and data.__class__.__name__ == "LazyFrame"
     ):
-        return data.collect().to_arrow()
+        return data.collect().to_arrow().to_reader()
     elif isinstance(data, Iterable):
-        return _iterator_to_table(data)
+        return _iterator_to_reader(data)
     else:
         raise TypeError(
             f"Unknown data type {type(data)}. "
@@ -140,30 +143,28 @@
         )
 
 
-def _iterator_to_table(data: Iterable) -> pa.Table:
-    batches = []
-    schema = None  # Will get schema from first batch
-    for batch in data:
-        batch_table = _into_pyarrow_table(batch)
-        if schema is not None:
-            if batch_table.schema != schema:
+def _iterator_to_reader(data: Iterable) -> pa.RecordBatchReader:
+    # Each batch is treated as its own reader, mainly so we can
+    # re-use the _into_pyarrow_reader logic.
+ first = _into_pyarrow_reader(next(data)) + schema = first.schema + + def gen(): + yield from first + for batch in data: + table: pa.Table = _into_pyarrow_reader(batch).read_all() + if table.schema != schema: try: - batch_table = batch_table.cast(schema) + table = table.cast(schema) except pa.lib.ArrowInvalid: raise ValueError( f"Input iterator yielded a batch with schema that " f"does not match the schema of other batches.\n" - f"Expected:\n{schema}\nGot:\n{batch_table.schema}" + f"Expected:\n{schema}\nGot:\n{batch.schema}" ) - else: - # Use the first schema for the remainder of the batches - schema = batch_table.schema - batches.append(batch_table) + yield from table.to_batches() - if batches: - return pa.concat_tables(batches) - else: - raise ValueError("Input iterable is empty") + return pa.RecordBatchReader.from_batches(schema, gen()) def _sanitize_data( @@ -174,7 +175,7 @@ def _sanitize_data( fill_value: float = 0.0, *, allow_subschema: bool = False, -) -> pa.Table: +) -> pa.RecordBatchReader: """ Handle input data, applying all standard transformations. @@ -207,20 +208,20 @@ def _sanitize_data( # 1. There might be embedding columns missing that will be added # in the add_embeddings step. # 2. If `allow_subschemas` is True, there might be columns missing. - table = _into_pyarrow_table(data) + reader = _into_pyarrow_reader(data) - table = _append_vector_columns(table, target_schema, metadata=metadata) + reader = _append_vector_columns(reader, target_schema, metadata=metadata) # This happens before the cast so we can fix vector columns with # incorrect lengths before they are cast to FSL. - table = _handle_bad_vectors( - table, + reader = _handle_bad_vectors( + reader, on_bad_vectors=on_bad_vectors, fill_value=fill_value, ) if target_schema is None: - target_schema = _infer_target_schema(table) + target_schema, reader = _infer_target_schema(reader) if metadata: new_metadata = target_schema.metadata or {} @@ -229,25 +230,25 @@ def _sanitize_data( _validate_schema(target_schema) - table = _cast_to_target_schema(table, target_schema, allow_subschema) + reader = _cast_to_target_schema(reader, target_schema, allow_subschema) - return table + return reader def _cast_to_target_schema( - table: pa.Table, + reader: pa.RecordBatchReader, target_schema: pa.Schema, allow_subschema: bool = False, -) -> pa.Table: +) -> pa.RecordBatchReader: # pa.Table.cast expects field order not to be changed. # Lance doesn't care about field order, so we don't need to rearrange fields # to match the target schema. We just need to correctly cast the fields. - if table.schema == target_schema: + if reader.schema == target_schema: # Fast path when the schemas are already the same - return table + return reader fields = [] - for field in table.schema: + for field in reader.schema: target_field = target_schema.field(field.name) if target_field is None: raise ValueError(f"Field {field.name} not found in target schema") @@ -260,12 +261,16 @@ def _cast_to_target_schema( if allow_subschema and len(reordered_schema) != len(target_schema): fields = _infer_subschema( - list(iter(table.schema)), list(iter(reordered_schema)) + list(iter(reader.schema)), list(iter(reordered_schema)) ) - subschema = pa.schema(fields, metadata=target_schema.metadata) - return table.cast(subschema) - else: - return table.cast(reordered_schema) + reordered_schema = pa.schema(fields, metadata=target_schema.metadata) + + def gen(): + for batch in reader: + # Table but not RecordBatch has cast. 
+ yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0] + + return pa.RecordBatchReader.from_batches(reordered_schema, gen()) def _infer_subschema( @@ -344,7 +349,10 @@ def sanitize_create_table( if metadata: schema = schema.with_metadata(metadata) # Need to apply metadata to the data as well - data = data.replace_schema_metadata(metadata) + if isinstance(data, pa.Table): + data = data.replace_schema_metadata(metadata) + elif isinstance(data, pa.RecordBatchReader): + data = pa.RecordBatchReader.from_batches(schema, data) return data, schema @@ -381,11 +389,11 @@ def _to_batches_with_split(data): def _append_vector_columns( - data: pa.Table, + reader: pa.RecordBatchReader, schema: Optional[pa.Schema] = None, *, metadata: Optional[dict] = None, -) -> pa.Table: +) -> pa.RecordBatchReader: """ Use the embedding function to automatically embed the source columns and add the vector columns to the table. @@ -396,28 +404,43 @@ def _append_vector_columns( metadata = schema.metadata or metadata or {} functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata) + if not functions: + return reader + + fields = list(reader.schema) for vector_column, conf in functions.items(): - func = conf.function - no_vector_column = vector_column not in data.column_names - if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py(): - col_data = func.compute_source_embeddings_with_retry( - data[conf.source_column] - ) + if vector_column not in reader.schema.names: if schema is not None: - dtype = schema.field(vector_column).type + field = schema.field(vector_column) else: - dtype = pa.list_(pa.float32(), len(col_data[0])) - if no_vector_column: - data = data.append_column( - pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype) - ) - else: - data = data.set_column( - data.column_names.index(vector_column), - pa.field(vector_column, type=dtype), - pa.array(col_data, type=dtype), - ) - return data + dtype = pa.list_(pa.float32(), conf.function.ndims()) + field = pa.field(vector_column, type=dtype, nullable=True) + fields.append(field) + schema = pa.schema(fields, metadata=reader.schema.metadata) + + def gen(): + for batch in reader: + for vector_column, conf in functions.items(): + func = conf.function + no_vector_column = vector_column not in batch.column_names + if no_vector_column or pc.all(pc.is_null(batch[vector_column])).as_py(): + col_data = func.compute_source_embeddings_with_retry( + batch[conf.source_column] + ) + if no_vector_column: + batch = batch.append_column( + schema.field(vector_column), + pa.array(col_data, type=schema.field(vector_column).type), + ) + else: + batch = batch.set_column( + batch.column_names.index(vector_column), + schema.field(vector_column), + pa.array(col_data, type=schema.field(vector_column).type), + ) + yield batch + + return pa.RecordBatchReader.from_batches(schema, gen()) def _table_path(base: str, table_name: str) -> str: @@ -2358,11 +2381,13 @@ class LanceTable(Table): def _handle_bad_vectors( - table: pa.Table, + reader: pa.RecordBatchReader, on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error", fill_value: float = 0.0, -) -> pa.Table: - for field in table.schema: +) -> pa.RecordBatchReader: + vector_columns = [] + + for field in reader.schema: # They can provide a 'vector' column that isn't yet a FSL named_vector_col = ( ( @@ -2382,22 +2407,28 @@ def _handle_bad_vectors( ) if named_vector_col or likely_vector_col: - table = _handle_bad_vector_column( - table, - vector_column_name=field.name, - 
on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) + vector_columns.append(field.name) - return table + def gen(): + for batch in reader: + for name in vector_columns: + batch = _handle_bad_vector_column( + batch, + vector_column_name=name, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) + yield batch + + return pa.RecordBatchReader.from_batches(reader.schema, gen()) def _handle_bad_vector_column( - data: pa.Table, + data: pa.RecordBatch, vector_column_name: str, on_bad_vectors: str = "error", fill_value: float = 0.0, -) -> pa.Table: +) -> pa.RecordBatch: """ Ensure that the vector column exists and has type fixed_size_list(float) @@ -2485,8 +2516,11 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray return pc.is_in(indices, has_nan_indices) -def _infer_target_schema(table: pa.Table) -> pa.Schema: - schema = table.schema +def _infer_target_schema( + reader: pa.RecordBatchReader, +) -> Tuple[pa.Schema, pa.RecordBatchReader]: + schema = reader.schema + peeked = None for i, field in enumerate(schema): if ( @@ -2494,8 +2528,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema: and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type)) and pa.types.is_floating(field.type.value_type) ): + if peeked is None: + peeked, reader = peek_reader(reader) # Use the most common length of the list as the dimensions - dim = _modal_list_size(table.column(i)) + dim = _modal_list_size(peeked.column(i)) new_field = pa.field( VECTOR_COLUMN_NAME, @@ -2509,8 +2545,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema: and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type)) and pa.types.is_integer(field.type.value_type) ): + if peeked is None: + peeked, reader = peek_reader(reader) # Use the most common length of the list as the dimensions - dim = _modal_list_size(table.column(i)) + dim = _modal_list_size(peeked.column(i)) new_field = pa.field( VECTOR_COLUMN_NAME, pa.list_(pa.uint8(), dim), @@ -2519,7 +2557,7 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema: schema = schema.set(i, new_field) - return schema + return schema, reader def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int: diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index 8778d069..4fbc263b 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -107,7 +107,7 @@ def test_embedding_with_bad_results(tmp_path): vector: Vector(model.ndims()) = model.VectorField() table = db.create_table("test", schema=Schema, mode="overwrite") - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): # Default on_bad_vectors is "error" table.add([{"text": "hello world"}]) diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py index 3b31bcea..84c2f560 100644 --- a/python/python/tests/test_util.py +++ b/python/python/tests/test_util.py @@ -14,7 +14,7 @@ from lancedb.table import ( _append_vector_columns, _cast_to_target_schema, _handle_bad_vectors, - _into_pyarrow_table, + _into_pyarrow_reader, _sanitize_data, _infer_target_schema, ) @@ -145,19 +145,19 @@ def test_append_vector_columns(): schema=schema, ) output = _append_vector_columns( - data, + data.to_reader(), schema, # metadata passed separate from schema metadata=metadata, - ) + ).read_all() assert output.schema == schema assert output["vector"].null_count == 0 # Adds if missing data = pa.table({"text": ["hello"]}) output = _append_vector_columns( - data, + 
data.to_reader(), schema.with_metadata(metadata), - ) + ).read_all() assert output.schema == schema assert output["vector"].null_count == 0 @@ -170,9 +170,9 @@ def test_append_vector_columns(): schema=schema, ) output = _append_vector_columns( - data, + data.to_reader(), schema.with_metadata(metadata), - ) + ).read_all() assert output == data # No change # No provided schema @@ -182,9 +182,9 @@ def test_append_vector_columns(): } ) output = _append_vector_columns( - data, + data.to_reader(), metadata=metadata, - ) + ).read_all() expected_schema = pa.schema( { "text": pa.string(), @@ -204,9 +204,9 @@ def test_handle_bad_vectors_jagged(on_bad_vectors): if on_bad_vectors == "error": with pytest.raises(ValueError) as e: output = _handle_bad_vectors( - data, + data.to_reader(), on_bad_vectors=on_bad_vectors, - ) + ).read_all() output = exception_output(e) assert output == ( "ValueError: Vector column 'vector' has variable length vectors. Set " @@ -217,10 +217,10 @@ def test_handle_bad_vectors_jagged(on_bad_vectors): return else: output = _handle_bad_vectors( - data, + data.to_reader(), on_bad_vectors=on_bad_vectors, fill_value=42.0, - ) + ).read_all() if on_bad_vectors == "drop": expected = pa.array([[1.0, 2.0], [4.0, 5.0]]) @@ -240,9 +240,9 @@ def test_handle_bad_vectors_nan(on_bad_vectors): if on_bad_vectors == "error": with pytest.raises(ValueError) as e: output = _handle_bad_vectors( - data, + data.to_reader(), on_bad_vectors=on_bad_vectors, - ) + ).read_all() output = exception_output(e) assert output == ( "ValueError: Vector column 'vector' has NaNs. Set " @@ -253,10 +253,10 @@ def test_handle_bad_vectors_nan(on_bad_vectors): return else: output = _handle_bad_vectors( - data, + data.to_reader(), on_bad_vectors=on_bad_vectors, fill_value=42.0, - ) + ).read_all() if on_bad_vectors == "drop": expected = pa.array([[3.0, 4.0]]) @@ -274,7 +274,7 @@ def test_handle_bad_vectors_noop(): [[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2) ) data = pa.table({"vector": vector}) - output = _handle_bad_vectors(data) + output = _handle_bad_vectors(data.to_reader()).read_all() assert output["vector"] == vector @@ -325,7 +325,7 @@ class TestModel(lancedb.pydantic.LanceModel): ) def test_into_pyarrow_table(data): expected = pa.table({"a": [1], "b": [2]}) - output = _into_pyarrow_table(data()) + output = _into_pyarrow_reader(data()).read_all() assert output == expected @@ -349,7 +349,7 @@ def test_infer_target_schema(): "vector": pa.list_(pa.float32(), 2), } ) - output = _infer_target_schema(data) + output, _ = _infer_target_schema(data.to_reader()) assert output == expected # Handle large list and use modal size @@ -370,7 +370,7 @@ def test_infer_target_schema(): "vector": pa.list_(pa.float32(), 2), } ) - output = _infer_target_schema(data) + output, _ = _infer_target_schema(data.to_reader()) assert output == expected # ignore if not list @@ -386,7 +386,7 @@ def test_infer_target_schema(): schema=example, ) expected = example - output = _infer_target_schema(data) + output, _ = _infer_target_schema(data.to_reader()) assert output == expected @@ -476,7 +476,7 @@ def test_sanitize_data( target_schema=schema, metadata=metadata, allow_subschema=True, - ) + ).read_all() assert output_data == expected @@ -519,7 +519,7 @@ def test_cast_to_target_schema(): "vec2": pa.list_(pa.float32(), 2), } ) - output = _cast_to_target_schema(data, target) + output = _cast_to_target_schema(data.to_reader(), target) expected = pa.table( { "id": [1], @@ -550,8 +550,10 @@ def test_cast_to_target_schema(): } ) with 
pytest.raises(Exception): - _cast_to_target_schema(data, target) - output = _cast_to_target_schema(data, target, allow_subschema=True) + _cast_to_target_schema(data.to_reader(), target) + output = _cast_to_target_schema( + data.to_reader(), target, allow_subschema=True + ).read_all() expected_schema = pa.schema( { "id": pa.int64(), @@ -576,3 +578,22 @@ def test_cast_to_target_schema(): schema=expected_schema, ) assert output == expected + + +def test_sanitize_data_stream(): + # Make sure we don't collect the whole stream when running sanitize_data + schema = pa.schema({"a": pa.int32()}) + + def stream(): + yield pa.record_batch([pa.array([1, 2, 3])], schema=schema) + raise ValueError("error") + + reader = pa.RecordBatchReader.from_batches(schema, stream()) + + output = _sanitize_data(reader) + + first = next(output) + assert first == pa.record_batch([pa.array([1, 2, 3])], schema=schema) + + with pytest.raises(ValueError): + next(output)
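
With this patch applied, `create_table` and `add` keep iterator and reader inputs as a `pa.RecordBatchReader` all the way through `_sanitize_data`, so only one batch has to be materialized at a time. The following is a minimal sketch of how a caller might exercise that streaming path; the URI, table name, and batch generator are illustrative only, not part of the patch:

import lancedb
import pyarrow as pa

schema = pa.schema({"id": pa.int64(), "vector": pa.list_(pa.float32(), 2)})

def batches():
    # Yield batches lazily; the full dataset never lives in memory at once.
    for i in range(1_000):
        ids = pa.array(range(i * 1_000, (i + 1) * 1_000), type=pa.int64())
        vecs = pa.array([[1.0, 2.0]] * 1_000, type=pa.list_(pa.float32(), 2))
        yield pa.record_batch([ids, vecs], schema=schema)

db = lancedb.connect("./example-lancedb")  # illustrative local URI
# The generator is wrapped by _into_pyarrow_reader, so each batch is
# sanitized and handed to the writer as it is produced.
tbl = db.create_table("streaming_demo", data=batches(), schema=schema)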
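
Schema inference can no longer scan a fully materialized table, so `_infer_target_schema` uses the new `peek_reader` helper to look at only the first batch and then hands back a reader that still replays every batch. A small sketch of that behaviour, assuming the patch is applied; the column name and batch contents are made up for the example:

import pyarrow as pa
from lancedb.arrow import peek_reader

schema = pa.schema({"x": pa.int32()})
batches = [
    pa.record_batch([pa.array([1, 2], type=pa.int32())], schema=schema),
    pa.record_batch([pa.array([3], type=pa.int32())], schema=schema),
]
reader = pa.RecordBatchReader.from_batches(schema, iter(batches))

first, reader = peek_reader(reader)
assert first.num_rows == 2              # the first batch is available up front
assert reader.read_all().num_rows == 3  # the rebuilt reader replays all rows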