feat: add to_list and to_pandas APIs (#556)

Add `to_list` to return query results as a list of Python dicts (so we're
not too pandas-centric). Closes #555

Add `to_pandas` API and add deprecation warning on `to_df`. Closes #545
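
A minimal sketch of the two result paths (the database path and table name are placeholders; the query vector is arbitrary):

```python
import lancedb

db = lancedb.connect("/tmp/lancedb")  # hypothetical local path
tbl = db.open_table("my_table")

# New: plain Python rows, no pandas required.
rows = tbl.search([0.1, 0.3]).limit(5).to_list()
assert isinstance(rows, list) and isinstance(rows[0], dict)

# Renamed: to_pandas() replaces to_df(), which now emits a deprecation warning.
df = tbl.search([0.1, 0.3]).limit(5).to_pandas()
```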

Co-authored-by: Chang She <chang@lancedb.com>
Chang She
2023-10-11 12:18:55 -07:00
committed by Weston Pace
parent a737bbff19
commit 8469d010f8
26 changed files with 125 additions and 71 deletions

View File

@@ -16,7 +16,7 @@ pip install lancedb
import lancedb
db = lancedb.connect('<PATH_TO_LANCEDB_DATASET>')
table = db.open_table('my_table')
results = table.search([0.1, 0.3]).limit(20).to_df()
results = table.search([0.1, 0.3]).limit(20).to_list()
print(results)
```

View File

@@ -14,12 +14,12 @@
import importlib.metadata
from typing import Optional
__version__ = importlib.metadata.version("lancedb")
from .db import URI, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector
__version__ = importlib.metadata.version("lancedb")
def connect(
uri: URI,

View File

@@ -12,6 +12,9 @@
# limitations under the License.
from __future__ import annotations
import deprecation
from . import __version__
from .exceptions import MissingColumnError, MissingValueError
from .util import safe_import_pandas
@@ -43,7 +46,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
this is how many tokens, but depending on the input data, it could be sentences,
paragraphs, messages, etc.
>>> contextualize(data).window(3).stride(1).text_col('token').to_df()
>>> contextualize(data).window(3).stride(1).text_col('token').to_pandas()
token document_id
0 The quick brown 1
1 quick brown fox 1
@@ -56,7 +59,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
8 dog I love 1
9 I love sandwiches 2
10 love sandwiches 2
>>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_df()
>>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_pandas()
token document_id
0 The quick brown fox jumped over the 1
1 quick brown fox jumped over the lazy 1
@@ -68,7 +71,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
``stride`` determines how many rows to skip between each window start. This can
be used to reduce the total number of windows generated.
>>> contextualize(data).window(4).stride(2).text_col('token').to_df()
>>> contextualize(data).window(4).stride(2).text_col('token').to_pandas()
token document_id
0 The quick brown fox 1
2 brown fox jumped over 1
@@ -81,7 +84,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
context windows that don't cross document boundaries. In this case, we can
pass ``document_id`` as the group by.
>>> contextualize(data).window(4).stride(2).text_col('token').groupby('document_id').to_df()
>>> contextualize(data).window(4).stride(2).text_col('token').groupby('document_id').to_pandas()
token document_id
0 The quick brown fox 1
2 brown fox jumped over 1
@@ -93,14 +96,14 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
This can be used to trim the last few context windows which have size less than
``min_window_size``. By default context windows of size 1 are skipped.
>>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_df()
>>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_pandas()
token document_id
0 The quick brown fox jumped over 1
3 fox jumped over the lazy dog 1
6 the lazy dog 1
9 I love sandwiches 2
>>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_df()
>>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_pandas()
token document_id
0 The quick brown fox jumped over 1
3 fox jumped over the lazy dog 1
@@ -176,7 +179,16 @@ class Contextualizer:
self._min_window_size = min_window_size
return self
@deprecation.deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
current_version=__version__,
details="Use to_pandas() instead",
)
def to_df(self) -> "pd.DataFrame":
return self.to_pandas()
def to_pandas(self) -> "pd.DataFrame":
"""Create the context windows and return a DataFrame."""
if pd is None:
raise ImportError(

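For context, a standalone sketch of what the `deprecation` decorator used above does at call time (module-level functions here stand in for the methods in the diff):

```python
import warnings

import deprecation


def to_pandas():
    return "a DataFrame, in the real API"


@deprecation.deprecated(
    deprecated_in="0.3.1",
    removed_in="0.4.0",
    current_version="0.3.1",
    details="Use to_pandas() instead",
)
def to_df():
    return to_pandas()


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    to_df()  # still works, but warns about the 0.4.0 removal
assert issubclass(caught[0].category, DeprecationWarning)
```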
View File

@@ -16,10 +16,12 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type, Union
import deprecation
import numpy as np
import pyarrow as pa
import pydantic
from . import __version__
from .common import VECTOR_COLUMN_NAME
from .pydantic import LanceModel
from .util import safe_import_pandas
@@ -127,7 +129,24 @@ class LanceQueryBuilder(ABC):
self._columns = None
self._where = None
@deprecation.deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
current_version=__version__,
details="Use to_pandas() instead",
)
def to_df(self) -> "pd.DataFrame":
"""
Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns the vector
and the "_distance" column, which is the distance between the query
vector and the returned vector.
"""
return self.to_pandas()
def to_pandas(self) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
@@ -148,6 +167,16 @@ class LanceQueryBuilder(ABC):
"""
raise NotImplementedError
def to_list(self) -> List[dict]:
"""
Execute the query and return the results as a list of dictionaries.
Each list entry is a dictionary with the selected column names as keys,
or all table columns if `select` is not called. The vector and the "_distance"
fields are returned whether or not they're explicitly selected.
"""
return self.to_arrow().to_pylist()
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
@@ -232,7 +261,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
... .where("b < 10")
... .select(["b"])
... .limit(2)
... .to_df())
... .to_pandas())
b vector _distance
0 6 [0.4, 0.4] 0.0
"""

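`to_list()` is a thin wrapper over the Arrow result; a sketch of the equivalence, using a hand-built table shaped like the doctest output above:

```python
import pyarrow as pa

# Mimics a query result: selected column, vector, and "_distance".
result = pa.table({"b": [6], "vector": [[0.4, 0.4]], "_distance": [0.0]})

# Each row becomes a dict keyed by column name, mirroring
# `return self.to_arrow().to_pylist()` in LanceQueryBuilder.to_list().
rows = result.to_pylist()
assert rows == [{"b": 6, "vector": [0.4, 0.4], "_distance": 0.0}]
```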
View File

@@ -136,7 +136,7 @@ class Table(ABC):
Can query the table with [Table.search][lancedb.table.Table.search].
>>> table.search([0.4, 0.4]).select(["b"]).to_df()
>>> table.search([0.4, 0.4]).select(["b"]).to_pandas()
b vector _distance
0 4 [0.5, 1.3] 0.82
1 2 [1.1, 1.2] 1.13

View File

@@ -2,6 +2,7 @@
name = "lancedb"
version = "0.3.0"
dependencies = [
"deprecation",
"pylance==0.8.3",
"ratelimiter~=1.0",
"retry>=0.9.2",

View File

@@ -47,7 +47,7 @@ def test_contextualizer(raw_df: pd.DataFrame):
.stride(3)
.text_col("token")
.groupby("document_id")
.to_df()["token"]
.to_pandas()["token"]
.to_list()
)
@@ -67,7 +67,7 @@ def test_contextualizer_with_threshold(raw_df: pd.DataFrame):
.text_col("token")
.groupby("document_id")
.min_window_size(4)
.to_df()["token"]
.to_pandas()["token"]
.to_list()
)

View File

@@ -33,11 +33,11 @@ def test_basic(tmp_path):
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
rs = table.search([100, 100]).limit(1).to_df()
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
@@ -62,11 +62,11 @@ def test_ingest_pd(tmp_path):
}
)
table = db.create_table("test", data=data)
rs = table.search([100, 100]).limit(1).to_df()
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
@@ -137,8 +137,8 @@ def test_ingest_iterator(tmp_path):
db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
tbl_len = len(tbl)
tbl.add(make_batches())
assert tbl_len == 50

View File

@@ -23,5 +23,5 @@ from lancedb import LanceDBConnection
def test_against_local_server():
conn = LanceDBConnection("lancedb+http://localhost:10024")
table = conn.open_table("sift1m_ivf1024_pq16")
df = table.search(np.random.rand(128)).to_df()
df = table.search(np.random.rand(128)).to_pandas()
assert len(df) == 10

View File

@@ -71,14 +71,14 @@ def test_search_index(tmp_path, table):
def test_create_index_from_table(tmp_path, table):
table.create_fts_index("text")
df = table.search("puppy").limit(10).select(["text"]).to_df()
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
assert len(df) == 10
assert "text" in df.columns
def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"])
df = table.search("puppy").limit(10).to_df()
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 10
assert "text" in df.columns
assert "text2" in df.columns
@@ -87,5 +87,5 @@ def test_create_index_multiple_columns(tmp_path, table):
def test_empty_rs(tmp_path, table, mocker):
table.create_fts_index(["text", "text2"])
mocker.patch("lancedb.fts.search_index", return_value=([], []))
df = table.search("puppy").limit(10).to_df()
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 0

View File

@@ -36,11 +36,11 @@ def test_s3_io():
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
rs = table.search([100, 100]).limit(1).to_df()
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"

View File

@@ -85,17 +85,20 @@ def test_cast(table):
def test_query_builder(table):
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(1)
.select(["id"])
.to_list()
)
assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2])
assert rs[0]["id"] == 1
assert all(np.array(rs[0]["vector"]) == [1, 2])
def test_query_builder_with_filter(table):
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
assert rs[0]["id"] == 2
assert all(np.array(rs[0]["vector"]) == [3, 4])
def test_query_builder_with_prefilter(table):
@@ -103,7 +106,7 @@ def test_query_builder_with_prefilter(table):
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2")
.limit(1)
.to_df()
.to_pandas()
)
assert len(df) == 0
@@ -111,7 +114,7 @@ def test_query_builder_with_prefilter(table):
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2", prefilter=True)
.limit(1)
.to_df()
.to_pandas()
)
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
@@ -120,9 +123,11 @@ def test_query_builder_with_prefilter(table):
def test_query_builder_with_metric(table):
query = [4, 8]
vector_column_name = "vector"
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas()
df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("L2")
.to_pandas()
)
tm.assert_frame_equal(df_default, df_l2)
@@ -130,7 +135,7 @@ def test_query_builder_with_metric(table):
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.limit(1)
.to_df()
.to_pandas()
)
assert df_cosine._distance[0] == pytest.approx(
cosine_distance(query, df_cosine.vector[0]),

View File

@@ -86,7 +86,7 @@ async def test_e2e_with_mock_server():
columns=["id", "vector"],
),
)
).to_df()
).to_pandas()
assert "vector" in df.columns
assert "id" in df.columns

View File

@@ -32,4 +32,4 @@ def test_remote_db():
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.search([1.0, 2.0]).to_df()
table.search([1.0, 2.0]).to_pandas()

View File

@@ -427,8 +427,8 @@ def test_multiple_vector_columns(db):
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
@@ -439,6 +439,6 @@ def test_empty_query(db):
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
val = df.id.iloc[0]
assert val == 1