diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 46e3860d..d63c2084 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -94,6 +94,8 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table: schema = data[0].__class__.to_arrow_schema() data = [model_to_dict(d) for d in data] return pa.Table.from_pylist(data, schema=schema) + elif isinstance(data[0], pa.RecordBatch): + return pa.Table.from_batches(data, schema=schema) else: return pa.Table.from_pylist(data) elif isinstance(data, dict): @@ -173,6 +175,9 @@ def sanitize_create_table( on_bad_vectors=on_bad_vectors, fill_value=fill_value, ) + else: + if schema is not None: + data = pa.Table.from_pylist([], schema) if schema is None: if data is None: raise ValueError("Either data or schema must be provided") diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index a775d5c7..729fb550 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -1,15 +1,7 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +from unittest.mock import MagicMock import lancedb import pyarrow as pa @@ -39,3 +31,53 @@ def test_remote_db(): table = conn["test"] table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) table.search([1.0, 2.0]).to_pandas() + + +def test_create_empty_table(): + client = MagicMock() + conn = lancedb.connect("db://client-will-be-injected", api_key="fake") + + conn._client = client + + schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) + + client.post.return_value = {"status": "ok"} + table = conn.create_table("test", schema=schema) + assert table.name == "test" + assert client.post.call_args[0][0] == "/v1/table/test/create/" + + json_schema = { + "fields": [ + { + "name": "vector", + "nullable": True, + "type": { + "type": "fixed_size_list", + "fields": [ + {"name": "item", "nullable": True, "type": {"type": "float"}} + ], + "length": 2, + }, + }, + ] + } + client.post.return_value = {"schema": json_schema} + assert table.schema == schema + assert client.post.call_args[0][0] == "/v1/table/test/describe/" + + client.post.return_value = 0 + assert table.count_rows(None) == 0 + + +def test_create_table_with_recordbatches(): + client = MagicMock() + conn = lancedb.connect("db://client-will-be-injected", api_key="fake") + + conn._client = client + + batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"]) + + client.post.return_value = {"status": "ok"} + table = conn.create_table("test", [batch], schema=batch.schema) + assert table.name == "test" + assert client.post.call_args[0][0] == "/v1/table/test/create/"