diff --git a/docs/src/python/arrow.md b/docs/src/python/arrow.md index 91e7019d..10dd33ae 100644 --- a/docs/src/python/arrow.md +++ b/docs/src/python/arrow.md @@ -5,6 +5,8 @@ Built on top of [Apache Arrow](https://arrow.apache.org/), `LanceDB` is easy to integrate with the Python ecosystem, including [Pandas](https://pandas.pydata.org/) and PyArrow. +## Create dataset + First, we need to connect to a `LanceDB` database. ```py @@ -27,10 +29,42 @@ data = pd.DataFrame({ table = db.create_table("pd_table", data=data) ``` -You will find detailed instructions of creating dataset and index in -[Basic Operations](basic.md) and [Indexing](ann_indexes.md) +Similar to [`pyarrow.write_dataset()`](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.write_dataset.html), +[db.create_table()](../python/#lancedb.db.DBConnection.create_table) accepts a wide-range of forms of data. + +For example, if you have a dataset that is larger than memory size, you can create table with `Iterator[pyarrow.RecordBatch]`, +to lazily generate data: + +```py + +from typing import Iterable +import pyarrow as pa +import lancedb + +def make_batches() -> Iterable[pa.RecordBatch]: + for i in range(5): + yield pa.RecordBatch.from_arrays( + [ + pa.array([[3.1, 4.1], [5.9, 26.5]]), + pa.array(["foo", "bar"]), + pa.array([10.0, 20.0]), + ], + ["vector", "item", "price"]) + +schema=pa.schema([ + pa.field("vector", pa.list_(pa.float32())), + pa.field("item", pa.utf8()), + pa.field("price", pa.float32()), +]) + +table = db.create_table("iterable_table", data=make_batches(), schema=schema) +``` + +You will find detailed instructions of creating dataset in +[Basic Operations](../basic.md) and [API](../python/#lancedb.db.DBConnection.create_table) sections. +## Vector Search We can now perform similarity search via `LanceDB` Python API. diff --git a/docs/src/python/python.md b/docs/src/python/python.md index 99b6dd15..ae34d2a4 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -46,10 +46,6 @@ pip install lancedb ## Utilities -::: lancedb.schema.schema_to_dict - -::: lancedb.schema.dict_to_schema - ::: lancedb.vector ## Integrations diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 69afafb7..5ed247e7 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -13,11 +13,12 @@ from __future__ import annotations -import functools import os from abc import ABC, abstractmethod from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple, Union +import pandas as pd import pyarrow as pa from pyarrow import fs @@ -38,8 +39,10 @@ class DBConnection(ABC): def create_table( self, name: str, - data: DATA = None, - schema: pa.Schema = None, + data: Optional[ + Union[List[dict], dict, pd.DataFrame, pa.Table, Iterable[pa.RecordBatch]], + ] = None, + schema: Optional[pa.Schema] = None, mode: str = "create", on_bad_vectors: str = "error", fill_value: float = 0.0, @@ -51,7 +54,7 @@ class DBConnection(ABC): name: str The name of the table. data: list, tuple, dict, pd.DataFrame; optional - The data to insert into the table. + The data to initialize the table. User must provide at least one of `data` or `schema`. schema: pyarrow.Schema; optional The schema of the table. mode: str; default "create" @@ -64,16 +67,16 @@ class DBConnection(ABC): fill_value: float The value to use when filling vectors. Only used if on_bad_vectors="fill". - Note - ---- - The vector index won't be created by default. - To create the index, call the `create_index` method on the table. - Returns ------- LanceTable A reference to the newly created table. + !!! note + + The vector index won't be created by default. + To create the index, call the `create_index` method on the table. + Examples -------- @@ -119,7 +122,7 @@ class DBConnection(ABC): Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to - convert to or else provide a PyArrow table directly. + convert to or else provide a [PyArrow Table](pyarrow.Table) directly. >>> custom_schema = pa.schema([ ... pa.field("vector", pa.list_(pa.float32(), 2)), @@ -138,6 +141,30 @@ class DBConnection(ABC): vector: [[[1.1,1.2],[0.2,1.8]]] lat: [[45.5,40.1]] long: [[-122.7,-74.1]] + + + It is also possible to create an table from `[Iterable[pa.RecordBatch]]`: + + + >>> import pyarrow as pa + >>> def make_batches(): + ... for i in range(5): + ... yield pa.RecordBatch.from_arrays( + ... [ + ... pa.array([[3.1, 4.1], [5.9, 26.5]]), + ... pa.array(["foo", "bar"]), + ... pa.array([10.0, 20.0]), + ... ], + ... ["vector", "item", "price"], + ... ) + >>> schema=pa.schema([ + ... pa.field("vector", pa.list_(pa.float32())), + ... pa.field("item", pa.utf8()), + ... pa.field("price", pa.float32()), + ... ]) + >>> db.create_table("table4", make_batches(), schema=schema) + LanceTable(table4) + """ raise NotImplementedError @@ -252,7 +279,7 @@ class LanceDBConnection(DBConnection): def create_table( self, name: str, - data: DATA = None, + data: Optional[Union[List[dict], dict, pd.DataFrame]] = None, schema: pa.Schema = None, mode: str = "create", on_bad_vectors: str = "error", diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 23d4a20e..8d7a3b50 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -16,7 +16,7 @@ from __future__ import annotations import os from abc import ABC, abstractmethod from functools import cached_property -from typing import List, Union +from typing import Iterable, List, Union import lance import numpy as np @@ -44,7 +44,7 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value): data = _sanitize_schema( data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value ) - if not isinstance(data, pa.Table): + if not isinstance(data, (pa.Table, Iterable)): raise TypeError(f"Unsupported data type: {type(data)}") return data @@ -483,7 +483,7 @@ class LanceTable(Table): if schema is None: raise ValueError("Either data or schema must be provided") data = pa.Table.from_pylist([], schema=schema) - lance.write_dataset(data, tbl._dataset_uri, mode=mode) + lance.write_dataset(data, tbl._dataset_uri, schema=schema, mode=mode) return LanceTable(db, name) @classmethod diff --git a/python/tests/test_db.py b/python/tests/test_db.py index a3cb5ba5..96ef2dfd 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -76,6 +76,32 @@ def test_ingest_pd(tmp_path): assert db.open_table("test").name == db["test"].name +def test_ingest_record_batch_iterator(tmp_path): + def batch_reader(): + for i in range(5): + yield pa.RecordBatch.from_arrays( + [ + pa.array([[3.1, 4.1], [5.9, 26.5]]), + pa.array(["foo", "bar"]), + pa.array([10.0, 20.0]), + ], + ["vector", "item", "price"], + ) + + db = lancedb.connect(tmp_path) + tbl = db.create_table( + "test", + batch_reader(), + schema=pa.schema( + [ + pa.field("vector", pa.list_(pa.float32())), + pa.field("item", pa.utf8()), + pa.field("price", pa.float32()), + ] + ), + ) + + def test_create_mode(tmp_path): db = lancedb.connect(tmp_path) data = pd.DataFrame(