Allow creation of an empty table (#254)

It's inconvenient to always require data at table creation time.
This change lets you create an empty table from a schema alone and add the
data later; at least one of `data` or `schema` must still be provided.

---------

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-07-06 20:44:58 -07:00
committed by GitHub
parent 507eeae9c8
commit e2325c634b
5 changed files with 61 additions and 21 deletions

View File

@@ -61,6 +61,8 @@ jobs:
run: |
pip install -e .
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock
pip install pytest pytest-mock black
- name: Black
run: black --check --diff --no-color --quiet .
- name: Run tests
run: pytest -x -v --durations=30 tests
run: pytest -x -v --durations=30 tests

View File

@@ -1,10 +1,8 @@
import builtins
import os
import pytest
# import lancedb so we don't have to in every example
import lancedb
@pytest.fixture(autouse=True)

View File

@@ -13,7 +13,7 @@
from __future__ import annotations
import asyncio
from typing import Literal
from typing import Literal, Union
import numpy as np
import pandas as pd
@@ -48,7 +48,7 @@ class LanceQueryBuilder:
def __init__(
self,
table: "lancedb.table.LanceTable",
query: np.ndarray,
query: Union[np.ndarray, str],
vector_column_name: str = VECTOR_COLUMN_NAME,
):
self._metric = "L2"

View File

@@ -22,6 +22,7 @@ import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs
from lance import LanceDataset
from lance.vector import vec_to_table
@@ -95,7 +96,8 @@ class LanceTable:
def _reset_dataset(self):
try:
del self.__dict__["_dataset"]
if "_dataset" in self.__dict__:
del self.__dict__["_dataset"]
except AttributeError:
pass
@@ -281,6 +283,7 @@ class LanceTable:
int
The number of vectors in the table.
"""
# TODO: manage table listing and metadata separately
data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
@@ -326,7 +329,7 @@ class LanceTable:
cls,
db,
name,
data,
data=None,
schema=None,
mode="create",
on_bad_vectors: str = "error",
@@ -354,10 +357,12 @@ class LanceTable:
The LanceDB instance to create the table in.
name: str
The name of the table to create.
data: list-of-dict, dict, pd.DataFrame
data: list-of-dict, dict, pd.DataFrame, default None
The data to insert into the table.
At least one of `data` or `schema` must be provided.
schema: dict, optional
The schema of the table. If not provided, the schema is inferred from the data.
At least one of `data` or `schema` must be provided.
mode: str, default "create"
The mode to use when writing the data. Valid values are
"create", "overwrite", and "append".
@@ -368,11 +373,16 @@ class LanceTable:
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
tbl = LanceTable(db, name)
data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if data is not None:
data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
else:
if schema is None:
raise ValueError("Either data or schema must be provided")
data = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(data, tbl._dataset_uri, mode=mode)
return tbl
return LanceTable(db, name)
@classmethod
def open(cls, db, name):
@@ -384,7 +394,6 @@ class LanceTable:
raise FileNotFoundError(
f"Table {name} does not exist. Please first call db.create_table({name}, data)"
)
return tbl
def delete(self, where: str):

View File

@@ -19,6 +19,7 @@ import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lance.vector import vec_to_table
from lancedb.db import LanceDBConnection
from lancedb.table import LanceTable
@@ -89,7 +90,31 @@ def test_create_table(db):
assert expected == tbl
def test_empty_table(db):
    """Create a table from a schema alone, then append rows to it.

    Exercises the ``data=None`` path of ``LanceTable.create``: only ``schema``
    is supplied at creation time and the rows are added afterwards.
    """
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float32()),
        ]
    )
    tbl = LanceTable.create(db, "test", schema=schema)
    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]
    tbl.add(data=data)
    # Without an assertion the test only proved that nothing raised; confirm
    # that both rows actually landed in the initially empty table.
    assert len(tbl) == 2
def test_add(db):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float64()),
]
)
table = LanceTable.create(
db,
"test",
@@ -98,7 +123,19 @@ def test_add(db):
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
_add(table, schema)
table = LanceTable.create(db, "test2", schema=schema)
table.add(
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
_add(table, schema)
def _add(table, schema):
# table = LanceTable(db, "test")
assert len(table) == 2
@@ -113,13 +150,7 @@ def test_add(db):
pa.array(["foo", "bar", "new"]),
pa.array([10.0, 20.0, 30.0]),
],
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float64()),
]
),
schema=schema,
)
assert expected == table.to_arrow()