Add functionality for opening a table, introspection for db / table

This commit is contained in:
Chang She
2023-03-21 19:34:48 -07:00
parent d4f95a42a1
commit fd4870576e
4 changed files with 321 additions and 203 deletions

25
python/lancedb/common.py Normal file
View File

@@ -0,0 +1,25 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Union
import numpy as np
import pandas as pd
import pyarrow as pa
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
URI = Union[str, Path]
# TODO support generator
DATA = Union[list[dict], dict, pd.DataFrame]
VECTOR_COLUMN_NAME = "vector"

View File

@@ -13,23 +13,11 @@
from __future__ import annotations
from functools import cached_property
from pathlib import Path
from typing import Union
import lance
from lance import LanceDataset
from lance.vector import vec_to_table
import numpy as np
import pandas as pd
import pyarrow as pa
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
URI = Union[str, Path]
# TODO support generator
DATA = Union[list[dict], dict, pd.DataFrame]
VECTOR_COLUMN_NAME = "vector"
from .common import URI, DATA
from .table import LanceTable
class LanceDBConnection:
@@ -41,9 +29,32 @@ class LanceDBConnection:
if isinstance(uri, str):
uri = Path(uri)
uri = uri.expanduser().absolute()
self.uri = uri
self._uri = str(uri)
def create_table(self, name: str, data: DATA = None) -> LanceTable:
@property
def uri(self) -> str:
return self._uri
def table_names(self) -> list[str]:
"""Get the names of all tables in the database.
Returns
-------
A list of table names.
"""
return [p.stem for p in Path(self.uri).glob("*.lance")]
def __len__(self) -> int:
return len(self.table_names())
def __contains__(self, name: str) -> bool:
return name in self.table_names()
def __getitem__(self, name: str) -> LanceTable:
return self.open_table(name)
def create_table(self, name: str, data: DATA = None,
schema: pa.Schema = None) -> LanceTable:
"""Create a table in the database.
Parameters
@@ -52,205 +63,29 @@ class LanceDBConnection:
The name of the table.
data: list, tuple, dict, pd.DataFrame; optional
The data to insert into the table.
schema: pyarrow.Schema; optional
The schema of the table.
Returns
-------
A LanceTable object representing the table.
"""
tbl = LanceTable(self, name)
if data is not None:
tbl.add(data)
tbl = LanceTable.create(self, name, data, schema)
else:
tbl = LanceTable(self, name)
return tbl
class LanceTable:
"""
A table in a LanceDB database.
"""
def __init__(self, connection: LanceDBConnection, name: str, schema: pa.Schema = None):
self._conn = connection
self.name = name
self.schema = schema
@property
def _dataset_uri(self) -> str:
return str(self._conn.uri / f"{self.name}.lance")
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri)
def add(self, data: DATA) -> int:
"""Add data to the table.
def open_table(self, name: str) -> LanceTable:
"""Open a table in the database.
Parameters
----------
data: list-of-dict, dict, pd.DataFrame
The data to insert into the table.
name: str
The name of the table.
Returns
-------
The number of vectors added to the table.
A LanceTable object representing the table.
"""
if isinstance(data, list):
data = pa.Table.from_pylist(data)
data = _sanitize_schema(data, schema=self.schema)
if isinstance(data, dict):
data = vec_to_table(data)
if isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data)
data = _sanitize_schema(data, schema=self.schema)
if not isinstance(data, pa.Table):
raise TypeError(f"Unsupported data type: {type(data)}")
ds = lance.write_dataset(data, self._dataset_uri, mode="append")
return ds.count_rows()
def search(self, query: VEC) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
Parameters
----------
query: list, np.ndarray
The query vector.
Returns
-------
A LanceQueryBuilder object representing the query.
"""
if isinstance(query, list):
query = np.array(query)
if isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceQueryBuilder(self, query)
def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
"""Ensure that the table has the expected schema.
Parameters
----------
data: pa.Table
The table to sanitize.
schema: pa.Schema; optional
The expected schema. If not provided, this just converts the
vector column to fixed_size_list(float32) if necessary.
"""
if schema is not None:
if data.schema == schema:
return data
# cast the columns to the expected types
data = data.combine_chunks()
return pa.Table.from_arrays([
data[name].cast(schema.field(name).type)
for name in schema.names
], schema=schema)
# just check the vector column
return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table:
"""
Ensure that the vector column exists and has type fixed_size_list(float32)
Parameters
----------
data: pa.Table
The table to sanitize.
vector_column_name: str
The name of the vector column.
"""
i = data.column_names.index(vector_column_name)
if i < 0:
raise ValueError(f"Missing vector column: {vector_column_name}")
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_fixed_size_list(vec_arr.type):
return data
if not pa.types.is_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
values = vec_arr.values
if not pa.types.is_float32(values.type):
values = values.cast(pa.float32())
list_size = len(values) / len(data)
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
return data.set_column(i, vector_column_name, vec_arr)
class LanceQueryBuilder:
"""
A builder for nearest neighbor queries for LanceDB.
"""
def __init__(self, table: LanceTable, query: np.ndarray):
self._table = table
self._query = query
self._limit = 10
self._columns = None
self._where = None
def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
Returns
-------
The LanceQueryBuilder object.
"""
self._where = where
return self
def to_df(self) -> pd.DataFrame:
"""Execute the query and return the results as a pandas DataFrame.
"""
ds = self._table._dataset
# TODO indexed search
import pdb; pdb.set_trace()
tbl = ds.to_table(
columns=self._columns,
filter=self._where,
nearest={
"column": VECTOR_COLUMN_NAME,
"q": self._query,
"k": self._limit
}
)
return tbl.to_pandas()
return LanceTable(self, name)

92
python/lancedb/query.py Normal file
View File

@@ -0,0 +1,92 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import numpy as np
import pandas as pd
from .common import VECTOR_COLUMN_NAME
class LanceQueryBuilder:
"""
A builder for nearest neighbor queries for LanceDB.
"""
def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
self._table = table
self._query = query
self._limit = 10
self._columns = None
self._where = None
def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
Returns
-------
The LanceQueryBuilder object.
"""
self._where = where
return self
def to_df(self) -> pd.DataFrame:
"""Execute the query and return the results as a pandas DataFrame.
"""
ds = self._table._dataset
# TODO indexed search
tbl = ds.to_table(
columns=self._columns,
filter=self._where,
nearest={
"column": VECTOR_COLUMN_NAME,
"q": self._query,
"k": self._limit
}
)
return tbl.to_pandas()

166
python/lancedb/table.py Normal file
View File

@@ -0,0 +1,166 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from functools import cached_property
import lance
import numpy as np
import pandas as pd
from lance import LanceDataset
import pyarrow as pa
from lance.vector import vec_to_table
from .query import LanceQueryBuilder
from .common import DATA, VECTOR_COLUMN_NAME, VEC
def _sanitize_data(data, schema):
if isinstance(data, list):
data = pa.Table.from_pylist(data)
data = _sanitize_schema(data, schema=schema)
if isinstance(data, dict):
data = vec_to_table(data)
if isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data)
data = _sanitize_schema(data, schema=schema)
if not isinstance(data, pa.Table):
raise TypeError(f"Unsupported data type: {type(data)}")
return data
class LanceTable:
"""
A table in a LanceDB database.
"""
def __init__(self, connection: "lancedb.db.LanceDBConnection", name: str):
self._conn = connection
self.name = name
@property
def schema(self) -> pa.Schema:
"""Return the schema of the table."""
return self._dataset.schema
@property
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri)
def to_lance(self) -> LanceDataset:
"""Return the LanceDataset backing this table."""
return self._dataset
def add(self, data: DATA, mode: str = "append") -> int:
"""Add data to the table.
Parameters
----------
data: list-of-dict, dict, pd.DataFrame
The data to insert into the table.
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
Returns
-------
The number of vectors added to the table.
"""
data = _sanitize_data(data, self.schema)
ds = lance.write_dataset(data, self._dataset_uri, mode=mode)
return ds.count_rows()
def search(self, query: VEC) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
Parameters
----------
query: list, np.ndarray
The query vector.
Returns
-------
A LanceQueryBuilder object representing the query.
"""
if isinstance(query, list):
query = np.array(query)
if isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceQueryBuilder(self, query)
@classmethod
def create(cls, db, name, data, schema):
tbl = LanceTable(db, name)
data = _sanitize_data(data, schema)
lance.write_dataset(data, tbl._dataset_uri, mode="create")
return tbl
def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
"""Ensure that the table has the expected schema.
Parameters
----------
data: pa.Table
The table to sanitize.
schema: pa.Schema; optional
The expected schema. If not provided, this just converts the
vector column to fixed_size_list(float32) if necessary.
"""
if schema is not None:
if data.schema == schema:
return data
# cast the columns to the expected types
data = data.combine_chunks()
return pa.Table.from_arrays([
data[name].cast(schema.field(name).type)
for name in schema.names
], schema=schema)
# just check the vector column
return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table:
"""
Ensure that the vector column exists and has type fixed_size_list(float32)
Parameters
----------
data: pa.Table
The table to sanitize.
vector_column_name: str
The name of the vector column.
"""
i = data.column_names.index(vector_column_name)
if i < 0:
raise ValueError(f"Missing vector column: {vector_column_name}")
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_fixed_size_list(vec_arr.type):
return data
if not pa.types.is_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
values = vec_arr.values
if not pa.types.is_float32(values.type):
values = values.cast(pa.float32())
list_size = len(values) / len(data)
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
return data.set_column(i, vector_column_name, vec_arr)