feat: support creating a table from a record batch iterator (#1593)

This commit is contained in:
BubbleCal
2024-09-06 10:41:38 +08:00
committed by GitHub
parent 1d61717d0e
commit 8dcd328dce
6 changed files with 119 additions and 91 deletions

View File

@@ -14,7 +14,6 @@
from __future__ import annotations
import asyncio
import inspect
import os
from abc import abstractmethod
from pathlib import Path
@@ -27,8 +26,13 @@ from pyarrow import fs
from lancedb.common import data_to_reader, validate_schema
from ._lancedb import connect as lancedb_connect
from .pydantic import LanceModel
from .table import AsyncTable, LanceTable, Table, _sanitize_data, _table_path
from .table import (
AsyncTable,
LanceTable,
Table,
_table_path,
sanitize_create_table,
)
from .util import (
fs_from_uri,
get_uri_location,
@@ -37,6 +41,7 @@ from .util import (
)
if TYPE_CHECKING:
from .pydantic import LanceModel
from datetime import timedelta
from ._lancedb import Connection as LanceDbConnection
@@ -722,12 +727,6 @@ class AsyncConnection(object):
... await db.create_table("table4", make_batches(), schema=schema)
>>> asyncio.run(iterable_example())
"""
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
metadata = None
# Defining defaults here and not in function prototype. In the future
@@ -738,31 +737,9 @@ class AsyncConnection(object):
if fill_value is None:
fill_value = 0.0
if data is not None:
data, schema = _sanitize_data(
data,
schema,
metadata=metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
if schema is None:
if data is None:
raise ValueError("Either data or schema must be provided")
elif hasattr(data, "schema"):
schema = data.schema
elif isinstance(data, Iterable):
if metadata:
raise TypeError(
(
"Persistent embedding functions not yet "
"supported for generator data input"
)
)
if metadata:
schema = schema.with_metadata(metadata)
data, schema = sanitize_create_table(
data, schema, metadata, on_bad_vectors, fill_value
)
validate_schema(schema)
if exist_ok is None:

View File

@@ -852,7 +852,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
)
if len(row_ids) == 0:
empty_schema = pa.schema([pa.field("_score", pa.float32())])
return pa.Table.from_pylist([], schema=empty_schema)
return pa.Table.from_batches([], schema=empty_schema)
scores = pa.array(scores)
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("_score", scores)

View File

@@ -11,12 +11,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Union
import pyarrow as pa
def to_ipc_binary(table: pa.Table) -> bytes:
def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes:
    """Serialize a PyArrow Table or a stream of RecordBatches to IPC binary.

    Parameters
    ----------
    table : pa.Table or Iterable[pa.RecordBatch]
        The data to serialize. An iterable of record batches is first
        collected into a single table.

    Returns
    -------
    bytes
        The Arrow IPC stream-format bytes.
    """
    sink = pa.BufferOutputStream()
    # Check for pa.Table explicitly rather than `isinstance(table, Iterable)`:
    # whether pa.Table registers as Iterable is pyarrow-version dependent, and
    # a Table that did would be wrongly re-fed through from_batches().
    if not isinstance(table, pa.Table):
        table = pa.Table.from_batches(table)
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)
    return sink.getvalue().to_pybytes()

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
@@ -26,7 +25,7 @@ from ..common import DATA
from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, _sanitize_data
from ..table import Table, sanitize_create_table
from ..util import validate_table_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
@@ -228,8 +227,6 @@ class RemoteDBConnection(DBConnection):
"""
validate_table_name(name)
if data is None and schema is None:
raise ValueError("Either data or schema must be provided.")
if embedding_functions is not None:
logging.warning(
"embedding_functions is not yet supported on LanceDB Cloud."
@@ -239,24 +236,9 @@ class RemoteDBConnection(DBConnection):
if mode is not None:
logging.warning("mode is not yet supported on LanceDB Cloud.")
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
if data is not None:
data, schema = _sanitize_data(
data,
schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
else:
if schema is None:
raise ValueError("Either data or schema must be provided")
data = pa.Table.from_pylist([], schema=schema)
data, schema = sanitize_create_table(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
from .table import RemoteTable

View File

@@ -117,15 +117,50 @@ def _sanitize_data(
data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if schema is None:
schema = data.schema
elif isinstance(data, Iterable):
data = _to_record_batch_generator(
data, schema, metadata, on_bad_vectors, fill_value
)
if schema is None:
data, schema = _generator_to_data_and_schema(data)
if schema is None:
raise ValueError("Cannot infer schema from generator data")
else:
raise TypeError(f"Unsupported data type: {type(data)}")
return data, schema
def sanitize_create_table(
    data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
):
    """Normalize the (data, schema) pair used when creating a table.

    Resolves a LanceModel class into its Arrow schema, sanitizes any provided
    data, infers the schema from the data when one was not given, and attaches
    embedding-function metadata to the schema when present.
    """
    # A pydantic LanceModel class carries its own Arrow schema (which may
    # already include embedding-function metadata).
    if inspect.isclass(schema) and issubclass(schema, LanceModel):
        schema = schema.to_arrow_schema()

    if data is not None:
        data, schema = _sanitize_data(
            data,
            schema,
            metadata=metadata,
            on_bad_vectors=on_bad_vectors,
            fill_value=fill_value,
        )
    elif schema is None:
        # Neither source of a schema was supplied.
        raise ValueError("Either data or schema must be provided")

    # Fall back to the data's own schema when sanitization left it unset.
    if schema is None and hasattr(data, "schema"):
        schema = data.schema

    if metadata:
        schema = schema.with_metadata(metadata)

    return data, schema
def _schema_from_hf(data, schema):
"""
Extract pyarrow schema from HuggingFace DatasetDict
@@ -187,8 +222,30 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
return data
def _generator_to_data_and_schema(
    data: Iterable,
) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
    """Infer a schema from a batch stream by peeking at its first element.

    Returns the (re-wrapped) stream together with the inferred schema. The
    schema is None when the stream is empty or its first element is neither a
    RecordBatch nor a Table.

    NOTE(review): assumes *data* is an iterator — `next()` would raise
    TypeError on a plain list; verify callers only pass generators.
    """

    def _with_first_generator(first, data):
        # Re-attach the already-consumed first element in front of the rest.
        yield first
        yield from data

    # Consume one element so we can inspect its schema.
    first = next(data, None)
    schema = None
    if isinstance(first, pa.RecordBatch):
        schema = first.schema
        data = _with_first_generator(first, data)
    elif isinstance(first, pa.Table):
        schema = first.schema
        # NOTE(review): this yields the *list* returned by to_batches() as a
        # single stream item, whereas the branch above yields one RecordBatch
        # at a time — confirm downstream consumers accept both shapes.
        data = _with_first_generator(first.to_batches(), data)
    return data, schema
def _to_record_batch_generator(
data: Iterable, schema, metadata, on_bad_vectors, fill_value
data: Iterable,
schema,
metadata,
on_bad_vectors,
fill_value,
):
for batch in data:
# always convert to table because we need to sanitize the data
@@ -1569,12 +1626,6 @@ class LanceTable(Table):
The embedding functions to use when creating the table.
"""
tbl = LanceTable(db, name)
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
metadata = None
if embedding_functions is not None:
# If we passed in embedding functions explicitly
@@ -1583,33 +1634,11 @@ class LanceTable(Table):
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)
if data is not None:
data, schema = _sanitize_data(
data,
schema,
metadata=metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
data, schema = sanitize_create_table(
data, schema, metadata, on_bad_vectors, fill_value
)
if schema is None:
if data is None:
raise ValueError("Either data or schema must be provided")
elif hasattr(data, "schema"):
schema = data.schema
elif isinstance(data, Iterable):
if metadata:
raise TypeError(
(
"Persistent embedding functions not yet "
"supported for generator data input"
)
)
if metadata:
schema = schema.with_metadata(metadata)
empty = pa.Table.from_pylist([], schema=schema)
empty = pa.Table.from_batches([], schema=schema)
try:
lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode)
except OSError as err:

View File

@@ -233,6 +233,43 @@ def test_create_mode(tmp_path):
assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
def test_create_table_from_iterator(tmp_path):
    """Creating a table from a RecordBatch generator infers the schema."""
    db = lancedb.connect(tmp_path)

    def gen_data():
        # RecordBatches are immutable, so one instance can be yielded
        # repeatedly.
        batch = pa.RecordBatch.from_arrays(
            [
                pa.array([[3.1, 4.1]], pa.list_(pa.float32(), 2)),
                pa.array(["foo"]),
                pa.array([10.0]),
            ],
            ["vector", "item", "price"],
        )
        for _ in range(10):
            yield batch

    table = db.create_table("test", data=gen_data())
    assert table.count_rows() == 10
@pytest.mark.asyncio
async def test_create_table_from_iterator_async(tmp_path):
    """Async table creation from a RecordBatch generator infers the schema."""
    db = await lancedb.connect_async(tmp_path)

    def gen_data():
        # RecordBatches are immutable, so one instance can be yielded
        # repeatedly.
        batch = pa.RecordBatch.from_arrays(
            [
                pa.array([[3.1, 4.1]], pa.list_(pa.float32(), 2)),
                pa.array(["foo"]),
                pa.array([10.0]),
            ],
            ["vector", "item", "price"],
        )
        for _ in range(10):
            yield batch

    table = await db.create_table("test", data=gen_data())
    assert await table.count_rows() == 10
def test_create_exist_ok(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(