From 028a6e433d99c7fe393995932841ef12f30e6c3f Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sat, 15 Jul 2023 17:39:37 -0700 Subject: [PATCH] [Python] Get table schema (#313) --- python/lancedb/db.py | 95 +------------ python/lancedb/remote/db.py | 24 +++- python/lancedb/remote/table.py | 10 +- python/lancedb/schema.py | 245 +-------------------------------- python/tests/test_schema.py | 109 --------------- 5 files changed, 35 insertions(+), 448 deletions(-) delete mode 100644 python/tests/test_schema.py diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 721091a4..69afafb7 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -260,98 +260,9 @@ class LanceDBConnection(DBConnection): ) -> LanceTable: """Create a table in the database. - Parameters - ---------- - name: str - The name of the table. - data: list, tuple, dict, pd.DataFrame; optional - The data to insert into the table. - schema: pyarrow.Schema; optional - The schema of the table. - mode: str; default "create" - The mode to use when creating the table. Can be either "create" or "overwrite". - By default, if the table already exists, an exception is raised. - If you want to overwrite the table, use mode="overwrite". - on_bad_vectors: str, default "error" - What to do if any of the vectors are not the same size or contains NaNs. - One of "error", "drop", "fill". - fill_value: float - The value to use when filling vectors. Only used if on_bad_vectors="fill". - - Note - ---- - The vector index won't be created by default. - To create the index, call the `create_index` method on the table. - - Returns - ------- - LanceTable - A reference to the newly created table. - - Examples - -------- - - Can create with list of tuples or dictionaries: - - >>> import lancedb - >>> db = lancedb.connect("./.lancedb") - >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7}, - ... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}] - >>> db.create_table("my_table", data) - LanceTable(my_table) - >>> db["my_table"].head() - pyarrow.Table - vector: fixed_size_list[2] - child 0, item: float - lat: double - long: double - ---- - vector: [[[1.1,1.2],[0.2,1.8]]] - lat: [[45.5,40.1]] - long: [[-122.7,-74.1]] - - You can also pass a pandas DataFrame: - - >>> import pandas as pd - >>> data = pd.DataFrame({ - ... "vector": [[1.1, 1.2], [0.2, 1.8]], - ... "lat": [45.5, 40.1], - ... "long": [-122.7, -74.1] - ... }) - >>> db.create_table("table2", data) - LanceTable(table2) - >>> db["table2"].head() - pyarrow.Table - vector: fixed_size_list[2] - child 0, item: float - lat: double - long: double - ---- - vector: [[[1.1,1.2],[0.2,1.8]]] - lat: [[45.5,40.1]] - long: [[-122.7,-74.1]] - - Data is converted to Arrow before being written to disk. For maximum - control over how data is saved, either provide the PyArrow schema to - convert to or else provide a PyArrow table directly. - - >>> custom_schema = pa.schema([ - ... pa.field("vector", pa.list_(pa.float32(), 2)), - ... pa.field("lat", pa.float32()), - ... pa.field("long", pa.float32()) - ... 
])
-        >>> db.create_table("table3", data, schema = custom_schema)
-        LanceTable(table3)
-        >>> db["table3"].head()
-        pyarrow.Table
-        vector: fixed_size_list[2]
-          child 0, item: float
-        lat: float
-        long: float
-        ----
-        vector: [[[1.1,1.2],[0.2,1.8]]]
-        lat: [[45.5,40.1]]
-        long: [[-122.7,-74.1]]
+        See Also
+        --------
+        DBConnection.create_table
         """
         if mode.lower() not in ["create", "overwrite"]:
             raise ValueError("mode must be either 'create' or 'overwrite'")
diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py
index d4ddc7bc..16672daf 100644
--- a/python/lancedb/remote/db.py
+++ b/python/lancedb/remote/db.py
@@ -19,7 +19,8 @@ import pyarrow as pa
 
 from lancedb.common import DATA
 from lancedb.db import DBConnection
-from lancedb.table import Table
+from lancedb.schema import schema_to_json
+from lancedb.table import Table, _sanitize_data
 
 from .client import RestfulLanceDBClient
 
@@ -75,4 +76,23 @@ class RemoteDBConnection(DBConnection):
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
     ) -> Table:
-        raise NotImplementedError
+        if data is None and schema is None:
+            raise ValueError("Either data or schema must be provided.")
+        if data is not None:
+            data = _sanitize_data(
+                data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
+            )
+        else:
+            if schema is None:
+                raise ValueError("Either data or schema must be provided.")
+            data = pa.Table.from_pylist([], schema=schema)
+
+        from .table import RemoteTable
+
+        payload = {
+            "name": name,
+            "schema": schema_to_json(data.schema),
+            "records": data.to_pydict(),
+        }
+        self._loop.run_until_complete(self._client.create_table("/table/", payload))
+        return RemoteTable(self, name)
diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py
index 08d7a055..1d79cecb 100644
--- a/python/lancedb/remote/table.py
+++ b/python/lancedb/remote/table.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import cached_property
 from typing import Union
 
 import pyarrow as pa
@@ -18,6 +19,7 @@ import pyarrow as pa
 from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
 
 from ..query import LanceQueryBuilder, Query
+from ..schema import json_to_schema
 from ..table import Query, Table
 from .db import RemoteDBConnection
 
@@ -30,8 +32,14 @@ class RemoteTable(Table):
     def __repr__(self) -> str:
         return f"RemoteTable({self._conn.db_name}.{self.name})"
 
+    @cached_property
     def schema(self) -> pa.Schema:
-        raise NotImplementedError
+        """Return the schema of the table."""
+        resp = self._conn._loop.run_until_complete(
+            self._conn._client.get(f"/table/{self._name}/describe")
+        )
+        schema = json_to_schema(resp["schema"])
+        return schema
 
     def to_arrow(self) -> pa.Table:
         raise NotImplementedError
diff --git a/python/lancedb/schema.py b/python/lancedb/schema.py
index 3ae0b872..d89c4301 100644
--- a/python/lancedb/schema.py
+++ b/python/lancedb/schema.py
@@ -17,6 +17,7 @@ import json
 from typing import Any, Dict, Type
 
 import pyarrow as pa
+from lance import json_to_schema, schema_to_json
 
 
 def vector(dimension: int, value_type: pa.DataType = pa.float32()) -> pa.DataType:
@@ -43,247 +44,3 @@ def vector(dimension: int, value_type: pa.DataType = pa.float32()) -> pa.DataTyp
 ...
 ... 
]) """ return pa.list_(value_type, dimension) - - -def _type_to_dict(dt: pa.DataType) -> Dict[str, Any]: - if pa.types.is_boolean(dt): - return {"type": "boolean"} - elif pa.types.is_int8(dt): - return {"type": "int8"} - elif pa.types.is_int16(dt): - return {"type": "int16"} - elif pa.types.is_int32(dt): - return {"type": "int32"} - elif pa.types.is_int64(dt): - return {"type": "int64"} - elif pa.types.is_uint8(dt): - return {"type": "uint8"} - elif pa.types.is_uint16(dt): - return {"type": "uint16"} - elif pa.types.is_uint32(dt): - return {"type": "uint32"} - elif pa.types.is_uint64(dt): - return {"type": "uint64"} - elif pa.types.is_float16(dt): - return {"type": "float16"} - elif pa.types.is_float32(dt): - return {"type": "float32"} - elif pa.types.is_float64(dt): - return {"type": "float64"} - elif pa.types.is_date32(dt): - return {"type": f"date32"} - elif pa.types.is_date64(dt): - return {"type": f"date64"} - elif pa.types.is_time32(dt): - return {"type": f"time32:{dt.unit}"} - elif pa.types.is_time64(dt): - return {"type": f"time64:{dt.unit}"} - elif pa.types.is_timestamp(dt): - return {"type": f"timestamp:{dt.unit}:{dt.tz if dt.tz is not None else ''}"} - elif pa.types.is_string(dt): - return {"type": "string"} - elif pa.types.is_binary(dt): - return {"type": "binary"} - elif pa.types.is_large_string(dt): - return {"type": "large_string"} - elif pa.types.is_large_binary(dt): - return {"type": "large_binary"} - elif pa.types.is_fixed_size_binary(dt): - return {"type": "fixed_size_binary", "width": dt.byte_width} - elif pa.types.is_fixed_size_list(dt): - return { - "type": "fixed_size_list", - "width": dt.list_size, - "value_type": _type_to_dict(dt.value_type), - } - elif pa.types.is_list(dt): - return { - "type": "list", - "value_type": _type_to_dict(dt.value_type), - } - elif pa.types.is_struct(dt): - return { - "type": "struct", - "fields": [_field_to_dict(dt.field(i)) for i in range(dt.num_fields)], - } - elif pa.types.is_dictionary(dt): - return { - "type": "dictionary", - "index_type": _type_to_dict(dt.index_type), - "value_type": _type_to_dict(dt.value_type), - } - # TODO: support extension types - - raise TypeError(f"Unsupported type: {dt}") - - -def _field_to_dict(field: pa.field) -> Dict[str, Any]: - ret = { - "name": field.name, - "type": _type_to_dict(field.type), - "nullable": field.nullable, - } - if field.metadata is not None: - ret["metadata"] = field.metadata - return ret - - -def schema_to_dict(schema: pa.Schema) -> Dict[str, Any]: - """Convert a PyArrow [Schema](pyarrow.Schema) to a dictionary. - - Parameters - ---------- - schema : pa.Schema - The PyArrow Schema to convert - - Returns - ------- - A dict of the data type. - - Examples - -------- - - >>> import pyarrow as pa - >>> import lancedb - >>> schema = pa.schema( - ... [ - ... pa.field("id", pa.int64()), - ... pa.field("vector", lancedb.vector(512), nullable=False), - ... pa.field( - ... "struct", - ... pa.struct( - ... [ - ... pa.field("a", pa.utf8()), - ... pa.field("b", pa.float32()), - ... ] - ... ), - ... True, - ... ), - ... ], - ... metadata={"key": "value"}, - ... ) - >>> json_schema = schema_to_dict(schema) - >>> assert json_schema == { - ... "fields": [ - ... {"name": "id", "type": {"type": "int64"}, "nullable": True}, - ... { - ... "name": "vector", - ... "type": { - ... "type": "fixed_size_list", - ... "value_type": {"type": "float32"}, - ... "width": 512, - ... }, - ... "nullable": False, - ... }, - ... { - ... "name": "struct", - ... "type": { - ... "type": "struct", - ... "fields": [ - ... 
{"name": "a", "type": {"type": "string"}, "nullable": True}, - ... {"name": "b", "type": {"type": "float32"}, "nullable": True}, - ... ], - ... }, - ... "nullable": True, - ... }, - ... ], - ... "metadata": {"key": "value"}, - ... } - - """ - fields = [] - for name in schema.names: - field = schema.field(name) - fields.append(_field_to_dict(field)) - json_schema = { - "fields": fields, - "metadata": { - k.decode("utf-8"): v.decode("utf-8") for (k, v) in schema.metadata.items() - } - if schema.metadata is not None - else {}, - } - return json_schema - - -def _dict_to_type(dt: Dict[str, Any]) -> pa.DataType: - type_name = dt["type"] - try: - return { - "boolean": pa.bool_(), - "int8": pa.int8(), - "int16": pa.int16(), - "int32": pa.int32(), - "int64": pa.int64(), - "uint8": pa.uint8(), - "uint16": pa.uint16(), - "uint32": pa.uint32(), - "uint64": pa.uint64(), - "float16": pa.float16(), - "float32": pa.float32(), - "float64": pa.float64(), - "string": pa.string(), - "binary": pa.binary(), - "large_string": pa.large_string(), - "large_binary": pa.large_binary(), - "date32": pa.date32(), - "date64": pa.date64(), - }[type_name] - except KeyError: - pass - - if type_name == "fixed_size_binary": - return pa.binary(dt["width"]) - elif type_name == "fixed_size_list": - return pa.list_(_dict_to_type(dt["value_type"]), dt["width"]) - elif type_name == "list": - return pa.list_(_dict_to_type(dt["value_type"])) - elif type_name == "struct": - fields = [] - for field in dt["fields"]: - fields.append(_dict_to_field(field)) - return pa.struct(fields) - elif type_name == "dictionary": - return pa.dictionary( - _dict_to_type(dt["index_type"]), _dict_to_type(dt["value_type"]) - ) - elif type_name.startswith("time32:"): - return pa.time32(type_name.split(":")[1]) - elif type_name.startswith("time64:"): - return pa.time64(type_name.split(":")[1]) - elif type_name.startswith("timestamp:"): - fields = type_name.split(":") - unit = fields[1] - tz = fields[2] if len(fields) > 2 else None - return pa.timestamp(unit, tz) - raise TypeError(f"Unsupported type: {dt}") - - -def _dict_to_field(field: Dict[str, Any]) -> pa.Field: - name = field["name"] - nullable = field["nullable"] if "nullable" in field else True - dt = _dict_to_type(field["type"]) - metadata = field.get("metadata", None) - return pa.field(name, dt, nullable, metadata) - - -def dict_to_schema(json: Dict[str, Any]) -> pa.Schema: - """Reconstruct a PyArrow Schema from a JSON dict. - - Parameters - ---------- - json : Dict[str, Any] - The JSON dict to reconstruct Schema from. - - Returns - ------- - A PyArrow Schema. - """ - fields = [] - for field in json["fields"]: - fields.append(_dict_to_field(field)) - metadata = { - k.encode("utf-8"): v.encode("utf-8") - for (k, v) in json.get("metadata", {}).items() - } - return pa.schema(fields, metadata) diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py deleted file mode 100644 index eb1d8e02..00000000 --- a/python/tests/test_schema.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import pyarrow as pa - -import lancedb -from lancedb.schema import dict_to_schema, schema_to_dict - - -def test_schema_to_dict(): - schema = pa.schema( - [ - pa.field("id", pa.int64()), - pa.field("vector", lancedb.vector(512), nullable=False), - pa.field( - "struct", - pa.struct( - [ - pa.field("a", pa.utf8()), - pa.field("b", pa.float32()), - ] - ), - True, - ), - pa.field("d", pa.dictionary(pa.int64(), pa.utf8()), False), - ], - metadata={"key": "value"}, - ) - - json_schema = schema_to_dict(schema) - assert json_schema == { - "fields": [ - {"name": "id", "type": {"type": "int64"}, "nullable": True}, - { - "name": "vector", - "type": { - "type": "fixed_size_list", - "value_type": {"type": "float32"}, - "width": 512, - }, - "nullable": False, - }, - { - "name": "struct", - "type": { - "type": "struct", - "fields": [ - {"name": "a", "type": {"type": "string"}, "nullable": True}, - {"name": "b", "type": {"type": "float32"}, "nullable": True}, - ], - }, - "nullable": True, - }, - { - "name": "d", - "type": { - "type": "dictionary", - "index_type": {"type": "int64"}, - "value_type": {"type": "string"}, - }, - "nullable": False, - }, - ], - "metadata": {"key": "value"}, - } - - actual_schema = dict_to_schema(json_schema) - assert actual_schema == schema - - -def test_temporal_types(): - schema = pa.schema( - [ - pa.field("t32", pa.time32("s")), - pa.field("t32ms", pa.time32("ms")), - pa.field("t64", pa.time64("ns")), - pa.field("ts", pa.timestamp("s")), - pa.field("ts_us_tz", pa.timestamp("us", tz="America/New_York")), - ], - ) - json_schema = schema_to_dict(schema) - - assert json_schema == { - "fields": [ - {"name": "t32", "type": {"type": "time32:s"}, "nullable": True}, - {"name": "t32ms", "type": {"type": "time32:ms"}, "nullable": True}, - {"name": "t64", "type": {"type": "time64:ns"}, "nullable": True}, - {"name": "ts", "type": {"type": "timestamp:s:"}, "nullable": True}, - { - "name": "ts_us_tz", - "type": {"type": "timestamp:us:America/New_York"}, - "nullable": True, - }, - ], - "metadata": {}, - } - - actual_schema = dict_to_schema(json_schema) - assert actual_schema == schema
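
Note on the serialization change: the hand-rolled schema_to_dict / dict_to_schema pair
(and its tests) is replaced by the schema_to_json / json_to_schema helpers re-exported
from pylance. A minimal round-trip sketch of the new helpers, assuming, as the payload
construction in RemoteDBConnection.create_table implies, that schema_to_json returns a
JSON-serializable dict:

    import pyarrow as pa
    from lance import json_to_schema, schema_to_json

    import lancedb

    # A schema with a fixed-size-list vector column, as LanceDB expects.
    schema = pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("vector", lancedb.vector(512), nullable=False),
        ],
        metadata={"key": "value"},
    )

    # Serialize for the wire (the "schema" field of the create_table payload)...
    json_schema = schema_to_json(schema)

    # ...and reconstruct on the receiving side, as RemoteTable.schema does.
    assert json_to_schema(json_schema) == schema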
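
Behavioral note on @cached_property: RemoteTable.schema is now an attribute
(table.schema, not table.schema()), and the /table/{name}/describe round trip runs at
most once per RemoteTable instance. A self-contained sketch of that memoization, with
a hypothetical _describe standing in for the HTTP call:

    from functools import cached_property


    class FakeRemoteTable:
        """Illustrates the per-instance caching RemoteTable.schema relies on."""

        def __init__(self) -> None:
            self.describe_calls = 0

        def _describe(self) -> dict:
            # Stand-in for the GET /table/{name}/describe request.
            self.describe_calls += 1
            return {"schema": {"fields": []}}

        @cached_property
        def schema(self) -> dict:
            # The body runs only on first access; the result is then stored
            # on the instance under the same attribute name.
            return self._describe()["schema"]


    t = FakeRemoteTable()
    _ = t.schema
    _ = t.schema
    assert t.describe_calls == 1  # the second access hit the cache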