Files
lancedb/python/lancedb/table.py
Rob Meng cbb56e25ab port remote connection client into lancedb (#194)
* to_df() is now async, added `to_df_blocking` to convenience
* add remote lancedb client to public lancedb
* make lancedb connection class understand url scheme
`lancedb+<connection_type>://<host>:<port>`.
2023-06-15 18:57:52 -04:00

346 lines
11 KiB
Python

# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from functools import cached_property
from typing import List, Union
import lance
import numpy as np
import pandas as pd
import pyarrow as pa
from lance import LanceDataset
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .query import LanceFtsQueryBuilder, LanceQueryBuilder
def _sanitize_data(data, schema):
if isinstance(data, list):
data = pa.Table.from_pylist(data)
data = _sanitize_schema(data, schema=schema)
if isinstance(data, dict):
data = vec_to_table(data)
if isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data)
data = _sanitize_schema(data, schema=schema)
if not isinstance(data, pa.Table):
raise TypeError(f"Unsupported data type: {type(data)}")
return data
class LanceTable:
"""
A table in a LanceDB database.
Examples
--------
Create using [LanceDBConnection.create_table][lancedb.LanceDBConnection.create_table]
(more examples in that method's documentation).
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}])
>>> table.head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
child 0, item: float
b: int64
----
vector: [[[1.1,1.2]]]
b: [[2]]
Can append new data with [LanceTable.add][lancedb.table.LanceTable.add].
>>> table.add([{"vector": [0.5, 1.3], "b": 4}])
2
Can query the table with [LanceTable.search][lancedb.table.LanceTable.search].
>>> table.search([0.4, 0.4]).select(["b"]).to_df()
b vector score
0 4 [0.5, 1.3] 0.82
1 2 [1.1, 1.2] 1.13
Search queries are much faster when an index is created. See
[LanceTable.create_index][lancedb.table.LanceTable.create_index].
"""
def __init__(
self, connection: "lancedb.db.LanceDBConnection", name: str, version: int = None
):
self._conn = connection
self.name = name
self._version = version
def _reset_dataset(self):
try:
del self.__dict__["_dataset"]
except AttributeError:
pass
@property
def schema(self) -> pa.Schema:
"""Return the schema of the table.
Returns
-------
pa.Schema
A PyArrow schema object."""
return self._dataset.schema
def list_versions(self):
"""List all versions of the table"""
return self._dataset.versions()
@property
def version(self) -> int:
"""Get the current version of the table"""
return self._dataset.version
def checkout(self, version: int):
"""Checkout a version of the table. This is an in-place operation.
This allows viewing previous versions of the table.
Parameters
----------
version : int
The version to checkout.
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version
1
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
2
>>> table.version
2
>>> table.checkout(1)
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
if version < 1 or version > max_ver:
raise ValueError(f"Invalid version {version}")
self._version = version
self._reset_dataset()
def __len__(self):
return self._dataset.count_rows()
def __repr__(self) -> str:
return f"LanceTable({self.name})"
def __str__(self) -> str:
return self.__repr__()
def head(self, n=5) -> pa.Table:
"""Return the first n rows of the table."""
return self._dataset.head(n)
def to_pandas(self) -> pd.DataFrame:
"""Return the table as a pandas DataFrame.
Returns
-------
pd.DataFrame
"""
return self.to_arrow().to_pandas()
def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table.
Returns
-------
pa.Table"""
return self._dataset.to_table()
@property
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")
def create_index(self, metric="L2", num_partitions=256, num_sub_vectors=96):
"""Create an index on the table.
Parameters
----------
metric: str, default "L2"
The distance metric to use when creating the index. Valid values are "L2" or "cosine".
L2 is euclidean distance.
num_partitions: int
The number of IVF partitions to use when creating the index.
Default is 256.
num_sub_vectors: int
The number of PQ sub-vectors to use when creating the index.
Default is 96.
"""
self._dataset.create_index(
column=VECTOR_COLUMN_NAME,
index_type="IVF_PQ",
metric=metric,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
)
self._reset_dataset()
def create_fts_index(self, field_names: Union[str, List[str]]):
"""Create a full-text search index on the table.
Warning - this API is highly experimental and is highly likely to change
in the future.
Parameters
----------
field_names: str or list of str
The name(s) of the field to index.
"""
from .fts import create_index, populate_index
if isinstance(field_names, str):
field_names = [field_names]
index = create_index(self._get_fts_index_path(), field_names)
populate_index(index, self, field_names)
def _get_fts_index_path(self):
return os.path.join(self._dataset_uri, "_indices", "tantivy")
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri, version=self._version)
def to_lance(self) -> LanceDataset:
"""Return the LanceDataset backing this table."""
return self._dataset
def add(self, data: DATA, mode: str = "append") -> int:
"""Add data to the table.
Parameters
----------
data: list-of-dict, dict, pd.DataFrame
The data to insert into the table.
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
Returns
-------
int
The number of vectors in the table.
"""
data = _sanitize_data(data, self.schema)
lance.write_dataset(data, self._dataset_uri, mode=mode)
self._reset_dataset()
return len(self)
def search(self, query: Union[VEC, str]) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
Parameters
----------
query: list, np.ndarray
The query vector.
Returns
-------
LanceQueryBuilder
A query builder object representing the query.
Once executed, the query returns selected columns, the vector,
and also the "score" column which is the distance between the query
vector and the returned vector.
"""
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(self, query)
if isinstance(query, list):
query = np.array(query)
if isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceQueryBuilder(self, query)
@classmethod
def create(cls, db, name, data, schema=None, mode="create"):
tbl = LanceTable(db, name)
data = _sanitize_data(data, schema)
lance.write_dataset(data, tbl._dataset_uri, mode=mode)
return tbl
def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
"""Ensure that the table has the expected schema.
Parameters
----------
data: pa.Table
The table to sanitize.
schema: pa.Schema; optional
The expected schema. If not provided, this just converts the
vector column to fixed_size_list(float32) if necessary.
"""
if schema is not None:
if data.schema == schema:
return data
# cast the columns to the expected types
data = data.combine_chunks()
data = _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
)
# just check the vector column
return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table:
"""
Ensure that the vector column exists and has type fixed_size_list(float32)
Parameters
----------
data: pa.Table
The table to sanitize.
vector_column_name: str
The name of the vector column.
"""
if vector_column_name not in data.column_names:
raise ValueError(f"Missing vector column: {vector_column_name}")
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_fixed_size_list(vec_arr.type):
return data
if not pa.types.is_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
values = vec_arr.values
if not pa.types.is_float32(values.type):
values = values.cast(pa.float32())
list_size = len(values) / len(data)
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
return data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)