diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 1c77b299..d2345e4a 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -14,7 +14,6 @@ from __future__ import annotations import asyncio -import inspect import os from abc import abstractmethod from pathlib import Path @@ -27,8 +26,13 @@ from pyarrow import fs from lancedb.common import data_to_reader, validate_schema from ._lancedb import connect as lancedb_connect -from .pydantic import LanceModel -from .table import AsyncTable, LanceTable, Table, _sanitize_data, _table_path +from .table import ( + AsyncTable, + LanceTable, + Table, + _table_path, + sanitize_create_table, +) from .util import ( fs_from_uri, get_uri_location, @@ -37,6 +41,7 @@ from .util import ( ) if TYPE_CHECKING: + from .pydantic import LanceModel from datetime import timedelta from ._lancedb import Connection as LanceDbConnection @@ -722,12 +727,6 @@ class AsyncConnection(object): ... await db.create_table("table4", make_batches(), schema=schema) >>> asyncio.run(iterable_example()) """ - if inspect.isclass(schema) and issubclass(schema, LanceModel): - # convert LanceModel to pyarrow schema - # note that it's possible this contains - # embedding function metadata already - schema = schema.to_arrow_schema() - metadata = None # Defining defaults here and not in function prototype. In the future @@ -738,31 +737,9 @@ class AsyncConnection(object): if fill_value is None: fill_value = 0.0 - if data is not None: - data, schema = _sanitize_data( - data, - schema, - metadata=metadata, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - - if schema is None: - if data is None: - raise ValueError("Either data or schema must be provided") - elif hasattr(data, "schema"): - schema = data.schema - elif isinstance(data, Iterable): - if metadata: - raise TypeError( - ( - "Persistent embedding functions not yet " - "supported for generator data input" - ) - ) - - if metadata: - schema = schema.with_metadata(metadata) + data, schema = sanitize_create_table( + data, schema, metadata, on_bad_vectors, fill_value + ) validate_schema(schema) if exist_ok is None: diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 9c9c69ae..c6b14c0f 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -852,7 +852,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): ) if len(row_ids) == 0: empty_schema = pa.schema([pa.field("_score", pa.float32())]) - return pa.Table.from_pylist([], schema=empty_schema) + return pa.Table.from_batches([], schema=empty_schema) scores = pa.array(scores) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = output_tbl.append_column("_score", scores) diff --git a/python/python/lancedb/remote/arrow.py b/python/python/lancedb/remote/arrow.py index 753087cf..ac39e247 100644 --- a/python/python/lancedb/remote/arrow.py +++ b/python/python/lancedb/remote/arrow.py @@ -11,12 +11,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterable, Union import pyarrow as pa -def to_ipc_binary(table: pa.Table) -> bytes: +def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes: """Serialize a PyArrow Table to IPC binary.""" sink = pa.BufferOutputStream() + if isinstance(table, Iterable): + table = pa.Table.from_batches(table) with pa.ipc.new_stream(sink, table.schema) as writer: writer.write_table(table) return sink.getvalue().to_pybytes() diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 0dd6bb6d..bb7554a4 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import uuid from concurrent.futures import ThreadPoolExecutor @@ -26,7 +25,7 @@ from ..common import DATA from ..db import DBConnection from ..embeddings import EmbeddingFunctionConfig from ..pydantic import LanceModel -from ..table import Table, _sanitize_data +from ..table import Table, sanitize_create_table from ..util import validate_table_name from .arrow import to_ipc_binary from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient @@ -228,8 +227,6 @@ class RemoteDBConnection(DBConnection): """ validate_table_name(name) - if data is None and schema is None: - raise ValueError("Either data or schema must be provided.") if embedding_functions is not None: logging.warning( "embedding_functions is not yet supported on LanceDB Cloud." @@ -239,24 +236,9 @@ class RemoteDBConnection(DBConnection): if mode is not None: logging.warning("mode is not yet supported on LanceDB Cloud.") - if inspect.isclass(schema) and issubclass(schema, LanceModel): - # convert LanceModel to pyarrow schema - # note that it's possible this contains - # embedding function metadata already - schema = schema.to_arrow_schema() - - if data is not None: - data, schema = _sanitize_data( - data, - schema, - metadata=None, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - else: - if schema is None: - raise ValueError("Either data or schema must be provided") - data = pa.Table.from_pylist([], schema=schema) + data, schema = sanitize_create_table( + data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) from .table import RemoteTable diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 7d3ebaa0..53e624a0 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -117,15 +117,50 @@ def _sanitize_data( data = _sanitize_schema( data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value ) + if schema is None: + schema = data.schema elif isinstance(data, Iterable): data = _to_record_batch_generator( data, schema, metadata, on_bad_vectors, fill_value ) + if schema is None: + data, schema = _generator_to_data_and_schema(data) + if schema is None: + raise ValueError("Cannot infer schema from generator data") else: raise TypeError(f"Unsupported data type: {type(data)}") return data, schema +def sanitize_create_table( + data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0 +): + if inspect.isclass(schema) and issubclass(schema, LanceModel): + # convert LanceModel to pyarrow schema + # note that it's possible this contains + # embedding function metadata already + schema = schema.to_arrow_schema() + + if data is not None: + data, schema = _sanitize_data( + data, + schema, + metadata=metadata, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) + if schema is None: + if data is None: + raise ValueError("Either data or schema must be provided") + elif hasattr(data, "schema"): + schema = data.schema + + if metadata: + schema = schema.with_metadata(metadata) + + return data, schema + + def _schema_from_hf(data, schema): """ Extract pyarrow schema from HuggingFace DatasetDict @@ -187,8 +222,30 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem return data +def _generator_to_data_and_schema( + data: Iterable, +) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]: + def _with_first_generator(first, data): + yield first + yield from data + + first = next(data, None) + schema = None + if isinstance(first, pa.RecordBatch): + schema = first.schema + data = _with_first_generator(first, data) + elif isinstance(first, pa.Table): + schema = first.schema + data = _with_first_generator(first.to_batches(), data) + return data, schema + + def _to_record_batch_generator( - data: Iterable, schema, metadata, on_bad_vectors, fill_value + data: Iterable, + schema, + metadata, + on_bad_vectors, + fill_value, ): for batch in data: # always convert to table because we need to sanitize the data @@ -1569,12 +1626,6 @@ class LanceTable(Table): The embedding functions to use when creating the table. """ tbl = LanceTable(db, name) - if inspect.isclass(schema) and issubclass(schema, LanceModel): - # convert LanceModel to pyarrow schema - # note that it's possible this contains - # embedding function metadata already - schema = schema.to_arrow_schema() - metadata = None if embedding_functions is not None: # If we passed in embedding functions explicitly @@ -1583,33 +1634,11 @@ class LanceTable(Table): registry = EmbeddingFunctionRegistry.get_instance() metadata = registry.get_table_metadata(embedding_functions) - if data is not None: - data, schema = _sanitize_data( - data, - schema, - metadata=metadata, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) + data, schema = sanitize_create_table( + data, schema, metadata, on_bad_vectors, fill_value + ) - if schema is None: - if data is None: - raise ValueError("Either data or schema must be provided") - elif hasattr(data, "schema"): - schema = data.schema - elif isinstance(data, Iterable): - if metadata: - raise TypeError( - ( - "Persistent embedding functions not yet " - "supported for generator data input" - ) - ) - - if metadata: - schema = schema.with_metadata(metadata) - - empty = pa.Table.from_pylist([], schema=schema) + empty = pa.Table.from_batches([], schema=schema) try: lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode) except OSError as err: diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 373ae2b6..5b7f3c42 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -233,6 +233,43 @@ def test_create_mode(tmp_path): assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"] +def test_create_table_from_iterator(tmp_path): + db = lancedb.connect(tmp_path) + + def gen_data(): + for _ in range(10): + yield pa.RecordBatch.from_arrays( + [ + pa.array([[3.1, 4.1]], pa.list_(pa.float32(), 2)), + pa.array(["foo"]), + pa.array([10.0]), + ], + ["vector", "item", "price"], + ) + + table = db.create_table("test", data=gen_data()) + assert table.count_rows() == 10 + + +@pytest.mark.asyncio +async def test_create_table_from_iterator_async(tmp_path): + db = await lancedb.connect_async(tmp_path) + + def gen_data(): + for _ in range(10): + yield pa.RecordBatch.from_arrays( + [ + pa.array([[3.1, 4.1]], pa.list_(pa.float32(), 2)), + pa.array(["foo"]), + pa.array([10.0]), + ], + ["vector", "item", "price"], + ) + + table = await db.create_table("test", data=gen_data()) + assert await table.count_rows() == 10 + + def test_create_exist_ok(tmp_path): db = lancedb.connect(tmp_path) data = pd.DataFrame(