[Python] Create table with Iterator[RecordBatch] and add docs (#316)

This commit is contained in:
Lei Xu
2023-07-16 21:45:55 -07:00
committed by GitHub
parent 7a57cddb2c
commit 088e745e1d
5 changed files with 103 additions and 20 deletions

View File

@@ -13,11 +13,12 @@
from __future__ import annotations
import functools
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
import pandas as pd
import pyarrow as pa
from pyarrow import fs
@@ -38,8 +39,10 @@ class DBConnection(ABC):
def create_table(
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
data: Optional[
Union[List[dict], dict, pd.DataFrame, pa.Table, Iterable[pa.RecordBatch]],
] = None,
schema: Optional[pa.Schema] = None,
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
@@ -51,7 +54,7 @@ class DBConnection(ABC):
name: str
The name of the table.
data: list, tuple, dict, pd.DataFrame, pa.Table, Iterable[pa.RecordBatch]; optional
The data to insert into the table.
The data to initialize the table. User must provide at least one of `data` or `schema`.
schema: pyarrow.Schema; optional
The schema of the table.
mode: str; default "create"
@@ -64,16 +67,16 @@ class DBConnection(ABC):
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
Note
----
The vector index won't be created by default.
To create the index, call the `create_index` method on the table.
Returns
-------
LanceTable
A reference to the newly created table.
!!! note
The vector index won't be created by default.
To create the index, call the `create_index` method on the table.
Examples
--------
@@ -119,7 +122,7 @@ class DBConnection(ABC):
Data is converted to Arrow before being written to disk. For maximum
control over how data is saved, either provide the PyArrow schema to
convert to or else provide a PyArrow table directly.
convert to or else provide a [PyArrow Table](pyarrow.Table) directly.
>>> custom_schema = pa.schema([
... pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -138,6 +141,30 @@ class DBConnection(ABC):
vector: [[[1.1,1.2],[0.2,1.8]]]
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]
It is also possible to create a table from an `Iterable[pa.RecordBatch]`:
>>> import pyarrow as pa
>>> def make_batches():
... for i in range(5):
... yield pa.RecordBatch.from_arrays(
... [
... pa.array([[3.1, 4.1], [5.9, 26.5]]),
... pa.array(["foo", "bar"]),
... pa.array([10.0, 20.0]),
... ],
... ["vector", "item", "price"],
... )
>>> schema=pa.schema([
... pa.field("vector", pa.list_(pa.float32())),
... pa.field("item", pa.utf8()),
... pa.field("price", pa.float32()),
... ])
>>> db.create_table("table4", make_batches(), schema=schema)
LanceTable(table4)
"""
raise NotImplementedError
@@ -252,7 +279,7 @@ class LanceDBConnection(DBConnection):
def create_table(
self,
name: str,
data: DATA = None,
data: Optional[Union[List[dict], dict, pd.DataFrame]] = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "error",

View File

@@ -16,7 +16,7 @@ from __future__ import annotations
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import List, Union
from typing import Iterable, List, Union
import lance
import numpy as np
@@ -44,7 +44,7 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value):
data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if not isinstance(data, pa.Table):
if not isinstance(data, (pa.Table, Iterable)):
raise TypeError(f"Unsupported data type: {type(data)}")
return data
@@ -483,7 +483,7 @@ class LanceTable(Table):
if schema is None:
raise ValueError("Either data or schema must be provided")
data = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(data, tbl._dataset_uri, mode=mode)
lance.write_dataset(data, tbl._dataset_uri, schema=schema, mode=mode)
return LanceTable(db, name)
@classmethod