diff --git a/python/lancedb/common.py b/python/lancedb/common.py new file mode 100644 index 00000000..3154cdbb --- /dev/null +++ b/python/lancedb/common.py @@ -0,0 +1,25 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path +from typing import Union + +import numpy as np +import pandas as pd +import pyarrow as pa + +VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] +URI = Union[str, Path] + +# TODO support generator +DATA = Union[list[dict], dict, pd.DataFrame] +VECTOR_COLUMN_NAME = "vector" diff --git a/python/lancedb/db.py b/python/lancedb/db.py index a1bdb556..e0ab2fea 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -13,23 +13,11 @@ from __future__ import annotations -from functools import cached_property from pathlib import Path -from typing import Union - -import lance -from lance import LanceDataset -from lance.vector import vec_to_table -import numpy as np -import pandas as pd import pyarrow as pa -VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] -URI = Union[str, Path] - -# TODO support generator -DATA = Union[list[dict], dict, pd.DataFrame] -VECTOR_COLUMN_NAME = "vector" +from .common import URI, DATA +from .table import LanceTable class LanceDBConnection: @@ -41,9 +29,32 @@ class LanceDBConnection: if isinstance(uri, str): uri = Path(uri) uri = uri.expanduser().absolute() - self.uri = uri + self._uri = str(uri) - def create_table(self, name: str, data: DATA = None) -> LanceTable: + @property + def uri(self) -> str: + return self._uri + + def table_names(self) -> list[str]: + """Get the names of all tables in the database. + + Returns + ------- + A list of table names. + """ + return [p.stem for p in Path(self.uri).glob("*.lance")] + + def __len__(self) -> int: + return len(self.table_names()) + + def __contains__(self, name: str) -> bool: + return name in self.table_names() + + def __getitem__(self, name: str) -> LanceTable: + return self.open_table(name) + + def create_table(self, name: str, data: DATA = None, + schema: pa.Schema = None) -> LanceTable: """Create a table in the database. Parameters @@ -52,205 +63,29 @@ class LanceDBConnection: The name of the table. data: list, tuple, dict, pd.DataFrame; optional The data to insert into the table. + schema: pyarrow.Schema; optional + The schema of the table. Returns ------- A LanceTable object representing the table. """ - tbl = LanceTable(self, name) if data is not None: - tbl.add(data) + tbl = LanceTable.create(self, name, data, schema) + else: + tbl = LanceTable(self, name) return tbl - -class LanceTable: - """ - A table in a LanceDB database. - """ - - def __init__(self, connection: LanceDBConnection, name: str, schema: pa.Schema = None): - self._conn = connection - self.name = name - self.schema = schema - - @property - def _dataset_uri(self) -> str: - return str(self._conn.uri / f"{self.name}.lance") - - @cached_property - def _dataset(self) -> LanceDataset: - return lance.dataset(self._dataset_uri) - - def add(self, data: DATA) -> int: - """Add data to the table. + def open_table(self, name: str) -> LanceTable: + """Open a table in the database. Parameters ---------- - data: list-of-dict, dict, pd.DataFrame - The data to insert into the table. + name: str + The name of the table. Returns ------- - The number of vectors added to the table. + A LanceTable object representing the table. """ - if isinstance(data, list): - data = pa.Table.from_pylist(data) - data = _sanitize_schema(data, schema=self.schema) - if isinstance(data, dict): - data = vec_to_table(data) - if isinstance(data, pd.DataFrame): - data = pa.Table.from_pandas(data) - data = _sanitize_schema(data, schema=self.schema) - if not isinstance(data, pa.Table): - raise TypeError(f"Unsupported data type: {type(data)}") - ds = lance.write_dataset(data, self._dataset_uri, mode="append") - return ds.count_rows() - - def search(self, query: VEC) -> LanceQueryBuilder: - """Create a search query to find the nearest neighbors - of the given query vector. - - Parameters - ---------- - query: list, np.ndarray - The query vector. - - Returns - ------- - A LanceQueryBuilder object representing the query. - """ - if isinstance(query, list): - query = np.array(query) - if isinstance(query, np.ndarray): - query = query.astype(np.float32) - else: - raise TypeError(f"Unsupported query type: {type(query)}") - return LanceQueryBuilder(self, query) - - -def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table: - """Ensure that the table has the expected schema. - - Parameters - ---------- - data: pa.Table - The table to sanitize. - schema: pa.Schema; optional - The expected schema. If not provided, this just converts the - vector column to fixed_size_list(float32) if necessary. - """ - if schema is not None: - if data.schema == schema: - return data - # cast the columns to the expected types - data = data.combine_chunks() - return pa.Table.from_arrays([ - data[name].cast(schema.field(name).type) - for name in schema.names - ], schema=schema) - # just check the vector column - return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME) - - -def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table: - """ - Ensure that the vector column exists and has type fixed_size_list(float32) - - Parameters - ---------- - data: pa.Table - The table to sanitize. - vector_column_name: str - The name of the vector column. - """ - i = data.column_names.index(vector_column_name) - if i < 0: - raise ValueError(f"Missing vector column: {vector_column_name}") - vec_arr = data[vector_column_name].combine_chunks() - if pa.types.is_fixed_size_list(vec_arr.type): - return data - if not pa.types.is_list(vec_arr.type): - raise TypeError(f"Unsupported vector column type: {vec_arr.type}") - values = vec_arr.values - if not pa.types.is_float32(values.type): - values = values.cast(pa.float32()) - list_size = len(values) / len(data) - vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size) - return data.set_column(i, vector_column_name, vec_arr) - - -class LanceQueryBuilder: - """ - A builder for nearest neighbor queries for LanceDB. - """ - - def __init__(self, table: LanceTable, query: np.ndarray): - self._table = table - self._query = query - self._limit = 10 - self._columns = None - self._where = None - - def limit(self, limit: int) -> LanceQueryBuilder: - """Set the maximum number of results to return. - - Parameters - ---------- - limit: int - The maximum number of results to return. - - Returns - ------- - The LanceQueryBuilder object. - """ - self._limit = limit - return self - - def select(self, columns: list) -> LanceQueryBuilder: - """Set the columns to return. - - Parameters - ---------- - columns: list - The columns to return. - - Returns - ------- - The LanceQueryBuilder object. - """ - self._columns = columns - return self - - def where(self, where: str) -> LanceQueryBuilder: - """Set the where clause. - - Parameters - ---------- - where: str - The where clause. - - Returns - ------- - The LanceQueryBuilder object. - """ - self._where = where - return self - - def to_df(self) -> pd.DataFrame: - """Execute the query and return the results as a pandas DataFrame. - """ - ds = self._table._dataset - # TODO indexed search - import pdb; pdb.set_trace() - tbl = ds.to_table( - columns=self._columns, - filter=self._where, - nearest={ - "column": VECTOR_COLUMN_NAME, - "q": self._query, - "k": self._limit - } - ) - return tbl.to_pandas() - - + return LanceTable(self, name) diff --git a/python/lancedb/query.py b/python/lancedb/query.py new file mode 100644 index 00000000..9b970f2b --- /dev/null +++ b/python/lancedb/query.py @@ -0,0 +1,92 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import numpy as np +import pandas as pd + +from .common import VECTOR_COLUMN_NAME + + +class LanceQueryBuilder: + """ + A builder for nearest neighbor queries for LanceDB. + """ + + def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray): + self._table = table + self._query = query + self._limit = 10 + self._columns = None + self._where = None + + def limit(self, limit: int) -> LanceQueryBuilder: + """Set the maximum number of results to return. + + Parameters + ---------- + limit: int + The maximum number of results to return. + + Returns + ------- + The LanceQueryBuilder object. + """ + self._limit = limit + return self + + def select(self, columns: list) -> LanceQueryBuilder: + """Set the columns to return. + + Parameters + ---------- + columns: list + The columns to return. + + Returns + ------- + The LanceQueryBuilder object. + """ + self._columns = columns + return self + + def where(self, where: str) -> LanceQueryBuilder: + """Set the where clause. + + Parameters + ---------- + where: str + The where clause. + + Returns + ------- + The LanceQueryBuilder object. + """ + self._where = where + return self + + def to_df(self) -> pd.DataFrame: + """Execute the query and return the results as a pandas DataFrame. + """ + ds = self._table._dataset + # TODO indexed search + tbl = ds.to_table( + columns=self._columns, + filter=self._where, + nearest={ + "column": VECTOR_COLUMN_NAME, + "q": self._query, + "k": self._limit + } + ) + return tbl.to_pandas() \ No newline at end of file diff --git a/python/lancedb/table.py b/python/lancedb/table.py new file mode 100644 index 00000000..cbd1e01d --- /dev/null +++ b/python/lancedb/table.py @@ -0,0 +1,166 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from functools import cached_property + +import lance +import numpy as np +import pandas as pd +from lance import LanceDataset +import pyarrow as pa +from lance.vector import vec_to_table + +from .query import LanceQueryBuilder +from .common import DATA, VECTOR_COLUMN_NAME, VEC + + +def _sanitize_data(data, schema): + if isinstance(data, list): + data = pa.Table.from_pylist(data) + data = _sanitize_schema(data, schema=schema) + if isinstance(data, dict): + data = vec_to_table(data) + if isinstance(data, pd.DataFrame): + data = pa.Table.from_pandas(data) + data = _sanitize_schema(data, schema=schema) + if not isinstance(data, pa.Table): + raise TypeError(f"Unsupported data type: {type(data)}") + return data + + +class LanceTable: + """ + A table in a LanceDB database. + """ + + def __init__(self, connection: "lancedb.db.LanceDBConnection", name: str): + self._conn = connection + self.name = name + + @property + def schema(self) -> pa.Schema: + """Return the schema of the table.""" + return self._dataset.schema + + @property + def _dataset_uri(self) -> str: + return os.path.join(self._conn.uri, f"{self.name}.lance") + + @cached_property + def _dataset(self) -> LanceDataset: + return lance.dataset(self._dataset_uri) + + def to_lance(self) -> LanceDataset: + """Return the LanceDataset backing this table.""" + return self._dataset + + def add(self, data: DATA, mode: str = "append") -> int: + """Add data to the table. + + Parameters + ---------- + data: list-of-dict, dict, pd.DataFrame + The data to insert into the table. + mode: str + The mode to use when writing the data. Valid values are + "append" and "overwrite". + + Returns + ------- + The number of vectors added to the table. + """ + data = _sanitize_data(data, self.schema) + ds = lance.write_dataset(data, self._dataset_uri, mode=mode) + return ds.count_rows() + + def search(self, query: VEC) -> LanceQueryBuilder: + """Create a search query to find the nearest neighbors + of the given query vector. + + Parameters + ---------- + query: list, np.ndarray + The query vector. + + Returns + ------- + A LanceQueryBuilder object representing the query. + """ + if isinstance(query, list): + query = np.array(query) + if isinstance(query, np.ndarray): + query = query.astype(np.float32) + else: + raise TypeError(f"Unsupported query type: {type(query)}") + return LanceQueryBuilder(self, query) + + @classmethod + def create(cls, db, name, data, schema): + tbl = LanceTable(db, name) + data = _sanitize_data(data, schema) + lance.write_dataset(data, tbl._dataset_uri, mode="create") + return tbl + + +def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table: + """Ensure that the table has the expected schema. + + Parameters + ---------- + data: pa.Table + The table to sanitize. + schema: pa.Schema; optional + The expected schema. If not provided, this just converts the + vector column to fixed_size_list(float32) if necessary. + """ + if schema is not None: + if data.schema == schema: + return data + # cast the columns to the expected types + data = data.combine_chunks() + return pa.Table.from_arrays([ + data[name].cast(schema.field(name).type) + for name in schema.names + ], schema=schema) + # just check the vector column + return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME) + + +def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table: + """ + Ensure that the vector column exists and has type fixed_size_list(float32) + + Parameters + ---------- + data: pa.Table + The table to sanitize. + vector_column_name: str + The name of the vector column. + """ + i = data.column_names.index(vector_column_name) + if i < 0: + raise ValueError(f"Missing vector column: {vector_column_name}") + vec_arr = data[vector_column_name].combine_chunks() + if pa.types.is_fixed_size_list(vec_arr.type): + return data + if not pa.types.is_list(vec_arr.type): + raise TypeError(f"Unsupported vector column type: {vec_arr.type}") + values = vec_arr.values + if not pa.types.is_float32(values.type): + values = values.cast(pa.float32()) + list_size = len(values) / len(data) + vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size) + return data.set_column(i, vector_column_name, vec_arr)