Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-27 23:12:58 +00:00)
feat: add to_list and to_pandas APIs (#556)
Add `to_list` to return query results as a list of Python dicts (so we're not too pandas-centric). Closes #555.
Add a `to_pandas` API and a deprecation warning on `to_df`. Closes #545.

Co-authored-by: Chang She <chang@lancedb.com>
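For orientation, a minimal sketch of the two new result paths; the path, table name, and data below are illustrative, not taken from this commit:

```python
import lancedb

db = lancedb.connect("/tmp/lancedb-demo")  # illustrative location
table = db.create_table(
    "my_table",
    data=[
        {"vector": [0.1, 0.3], "item": "foo"},
        {"vector": [5.9, 26.5], "item": "bar"},
    ],
    mode="overwrite",
)

# New: results as plain Python dicts, no pandas required (closes #555).
rows = table.search([0.1, 0.3]).limit(20).to_list()
print(rows[0]["item"], rows[0]["_distance"])

# New name for the DataFrame path; to_df() keeps working but now warns (closes #545).
df = table.search([0.1, 0.3]).limit(20).to_pandas()
print(df.columns)
```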
@@ -16,7 +16,7 @@ pip install lancedb
 import lancedb
 
 db = lancedb.connect('<PATH_TO_LANCEDB_DATASET>')
 table = db.open_table('my_table')
-results = table.search([0.1, 0.3]).limit(20).to_df()
+results = table.search([0.1, 0.3]).limit(20).to_list()
 print(results)
 ```
@@ -14,12 +14,12 @@
 import importlib.metadata
 from typing import Optional
 
+__version__ = importlib.metadata.version("lancedb")
+
 from .db import URI, DBConnection, LanceDBConnection
 from .remote.db import RemoteDBConnection
 from .schema import vector
 
-__version__ = importlib.metadata.version("lancedb")
-
 
 def connect(
     uri: URI,
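Presumably the `__version__` assignment moves above the `.db` import because `query.py` and `context.py` (changed below) now execute `from . import __version__` while the package is still initializing, so the name must already be bound at that point. Abridged, with the constraint spelled out as comments:

```python
# lancedb/__init__.py (new ordering, abridged)
import importlib.metadata

# Bind __version__ before importing .db: that import executes query.py,
# which does `from . import __version__` against this partially-initialized
# module and would otherwise fail at package load.
__version__ = importlib.metadata.version("lancedb")

from .db import URI, DBConnection, LanceDBConnection  # noqa: E402
```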
@@ -12,6 +12,9 @@
 # limitations under the License.
 from __future__ import annotations
 
+import deprecation
+
+from . import __version__
 from .exceptions import MissingColumnError, MissingValueError
 from .util import safe_import_pandas
 
@@ -43,7 +46,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
     this how many tokens, but depending on the input data, it could be sentences,
     paragraphs, messages, etc.
 
-    >>> contextualize(data).window(3).stride(1).text_col('token').to_df()
+    >>> contextualize(data).window(3).stride(1).text_col('token').to_pandas()
                  token document_id
     0  The quick brown            1
     1  quick brown fox            1
@@ -56,7 +59,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
     8          dog I love            1
     9   I love sandwiches            2
     10    love sandwiches            2
-    >>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_df()
+    >>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_pandas()
                                       token document_id
     0   The quick brown fox jumped over the            1
     1  quick brown fox jumped over the lazy            1
@@ -68,7 +71,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
     ``stride`` determines how many rows to skip between each window start. This can
     be used to reduce the total number of windows generated.
 
-    >>> contextualize(data).window(4).stride(2).text_col('token').to_df()
+    >>> contextualize(data).window(4).stride(2).text_col('token').to_pandas()
                        token document_id
     0    The quick brown fox            1
     2  brown fox jumped over            1
@@ -81,7 +84,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
     context windows that don't cross document boundaries. In this case, we can
     pass ``document_id`` as the group by.
 
-    >>> contextualize(data).window(4).stride(2).text_col('token').groupby('document_id').to_df()
+    >>> contextualize(data).window(4).stride(2).text_col('token').groupby('document_id').to_pandas()
                        token document_id
     0    The quick brown fox            1
     2  brown fox jumped over            1
@@ -93,14 +96,14 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
     This can be used to trim the last few context windows which have size less than
     ``min_window_size``. By default context windows of size 1 are skipped.
 
-    >>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_df()
+    >>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_pandas()
                                  token document_id
     0  The quick brown fox jumped over            1
     3     fox jumped over the lazy dog            1
     6                     the lazy dog            1
     9                I love sandwiches            2
 
-    >>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_df()
+    >>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_pandas()
                                  token document_id
     0  The quick brown fox jumped over            1
     3     fox jumped over the lazy dog            1
@@ -176,7 +179,16 @@ class Contextualizer:
         self._min_window_size = min_window_size
         return self
 
-    def to_df(self) -> "pd.DataFrame":
+    @deprecation.deprecated(
+        deprecated_in="0.3.1",
+        removed_in="0.4.0",
+        current_version=__version__,
+        details="Use to_pandas() instead",
+    )
+    def to_df(self) -> "pd.DataFrame":
+        return self.to_pandas()
+
+    def to_pandas(self) -> "pd.DataFrame":
         """Create the context windows and return a DataFrame."""
         if pd is None:
             raise ImportError(
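A small standalone sketch of what the `deprecation` decorator does at call time (the version strings are copied from the hunk above; the function body is illustrative):

```python
import warnings

import deprecation


@deprecation.deprecated(
    deprecated_in="0.3.1",
    removed_in="0.4.0",
    current_version="0.3.1",
    details="Use to_pandas() instead",
)
def to_df():
    return "frame"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    to_df()

# The wrapper emits a DeprecationWarning subclass whenever the deprecated
# function is called, and folds the details into its docstring.
print(caught[0].category.__name__, "->", caught[0].message)
```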
@@ -16,10 +16,12 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import List, Literal, Optional, Type, Union
 
+import deprecation
 import numpy as np
 import pyarrow as pa
 import pydantic
 
+from . import __version__
 from .common import VECTOR_COLUMN_NAME
 from .pydantic import LanceModel
 from .util import safe_import_pandas
@@ -127,7 +129,24 @@ class LanceQueryBuilder(ABC):
         self._columns = None
         self._where = None
 
-    def to_df(self) -> "pd.DataFrame":
+    @deprecation.deprecated(
+        deprecated_in="0.3.1",
+        removed_in="0.4.0",
+        current_version=__version__,
+        details="Use to_pandas() instead",
+    )
+    def to_df(self) -> "pd.DataFrame":
+        """
+        Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.
+
+        Execute the query and return the results as a pandas DataFrame.
+        In addition to the selected columns, LanceDB also returns a vector
+        and also the "_distance" column which is the distance between the query
+        vector and the returned vector.
+        """
+        return self.to_pandas()
+
+    def to_pandas(self) -> "pd.DataFrame":
         """
         Execute the query and return the results as a pandas DataFrame.
         In addition to the selected columns, LanceDB also returns a vector
@@ -148,6 +167,16 @@ class LanceQueryBuilder(ABC):
         """
         raise NotImplementedError
 
+    def to_list(self) -> List[dict]:
+        """
+        Execute the query and return the results as a list of dictionaries.
+
+        Each list entry is a dictionary with the selected column names as keys,
+        or all table columns if `select` is not called. The vector and the "_distance"
+        fields are returned whether or not they're explicitly selected.
+        """
+        return self.to_arrow().to_pylist()
+
     def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
         """Return the table as a list of pydantic models.
 
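`to_list` is a thin wrapper: `pyarrow.Table.to_pylist()` already produces one dict per row. A standalone sketch of that conversion, with illustrative data:

```python
import pyarrow as pa

# Stand-in for what to_arrow() returns for a two-row result.
result = pa.table(
    {
        "b": [6, 4],
        "vector": [[0.4, 0.4], [0.5, 1.3]],
        "_distance": [0.0, 0.82],
    }
)

# to_list() simply calls .to_pylist() on that table: one dict per row.
rows = result.to_pylist()
assert rows[0] == {"b": 6, "vector": [0.4, 0.4], "_distance": 0.0}
```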
@@ -232,7 +261,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
     ...           .where("b < 10")
     ...           .select(["b"])
     ...           .limit(2)
-    ...           .to_df())
+    ...           .to_pandas())
        b      vector  _distance
     0  6  [0.4, 0.4]        0.0
     """
@@ -136,7 +136,7 @@ class Table(ABC):
 
     Can query the table with [Table.search][lancedb.table.Table.search].
 
-    >>> table.search([0.4, 0.4]).select(["b"]).to_df()
+    >>> table.search([0.4, 0.4]).select(["b"]).to_pandas()
        b      vector  _distance
     0  4  [0.5, 1.3]       0.82
     1  2  [1.1, 1.2]       1.13
@@ -2,6 +2,7 @@
 name = "lancedb"
 version = "0.3.0"
 dependencies = [
+    "deprecation",
     "pylance==0.8.3",
     "ratelimiter~=1.0",
     "retry>=0.9.2",
@@ -47,7 +47,7 @@ def test_contextualizer(raw_df: pd.DataFrame):
         .stride(3)
         .text_col("token")
         .groupby("document_id")
-        .to_df()["token"]
+        .to_pandas()["token"]
         .to_list()
     )
 
@@ -67,7 +67,7 @@ def test_contextualizer_with_threshold(raw_df: pd.DataFrame):
         .text_col("token")
         .groupby("document_id")
         .min_window_size(4)
-        .to_df()["token"]
+        .to_pandas()["token"]
         .to_list()
    )
 
@@ -33,11 +33,11 @@ def test_basic(tmp_path):
             {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
         ],
     )
-    rs = table.search([100, 100]).limit(1).to_df()
+    rs = table.search([100, 100]).limit(1).to_pandas()
     assert len(rs) == 1
     assert rs["item"].iloc[0] == "bar"
 
-    rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
+    rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
     assert len(rs) == 1
     assert rs["item"].iloc[0] == "foo"
 
@@ -62,11 +62,11 @@ def test_ingest_pd(tmp_path):
         }
     )
     table = db.create_table("test", data=data)
-    rs = table.search([100, 100]).limit(1).to_df()
+    rs = table.search([100, 100]).limit(1).to_pandas()
     assert len(rs) == 1
     assert rs["item"].iloc[0] == "bar"
 
-    rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
+    rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
     assert len(rs) == 1
     assert rs["item"].iloc[0] == "foo"
 
@@ -137,8 +137,8 @@ def test_ingest_iterator(tmp_path):
     db = lancedb.connect(tmp_path)
     tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
     tbl.to_pandas()
-    assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
-    assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
+    assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
+    assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
     tbl_len = len(tbl)
     tbl.add(make_batches())
     assert tbl_len == 50
@@ -23,5 +23,5 @@ from lancedb import LanceDBConnection
 def test_against_local_server():
     conn = LanceDBConnection("lancedb+http://localhost:10024")
     table = conn.open_table("sift1m_ivf1024_pq16")
-    df = table.search(np.random.rand(128)).to_df()
+    df = table.search(np.random.rand(128)).to_pandas()
     assert len(df) == 10
@@ -71,14 +71,14 @@ def test_search_index(tmp_path, table):
 
 def test_create_index_from_table(tmp_path, table):
     table.create_fts_index("text")
-    df = table.search("puppy").limit(10).select(["text"]).to_df()
+    df = table.search("puppy").limit(10).select(["text"]).to_pandas()
     assert len(df) == 10
     assert "text" in df.columns
 
 
 def test_create_index_multiple_columns(tmp_path, table):
     table.create_fts_index(["text", "text2"])
-    df = table.search("puppy").limit(10).to_df()
+    df = table.search("puppy").limit(10).to_pandas()
     assert len(df) == 10
     assert "text" in df.columns
     assert "text2" in df.columns
@@ -87,5 +87,5 @@ def test_create_index_multiple_columns(tmp_path, table):
 def test_empty_rs(tmp_path, table, mocker):
     table.create_fts_index(["text", "text2"])
     mocker.patch("lancedb.fts.search_index", return_value=([], []))
-    df = table.search("puppy").limit(10).to_df()
+    df = table.search("puppy").limit(10).to_pandas()
     assert len(df) == 0
@@ -36,11 +36,11 @@ def test_s3_io():
             {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
         ],
     )
-    rs = table.search([100, 100]).limit(1).to_df()
+    rs = table.search([100, 100]).limit(1).to_pandas()
     assert len(rs) == 1
     assert rs["item"].iloc[0] == "bar"
 
-    rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
+    rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
     assert len(rs) == 1
     assert rs["item"].iloc[0] == "foo"
 
@@ -85,17 +85,20 @@ def test_cast(table):
 
 
 def test_query_builder(table):
-    df = (
-        LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
+    rs = (
+        LanceVectorQueryBuilder(table, [0, 0], "vector")
+        .limit(1)
+        .select(["id"])
+        .to_list()
     )
-    assert df["id"].values[0] == 1
-    assert all(df["vector"].values[0] == [1, 2])
+    assert rs[0]["id"] == 1
+    assert all(np.array(rs[0]["vector"]) == [1, 2])
 
 
 def test_query_builder_with_filter(table):
-    df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
-    assert df["id"].values[0] == 2
-    assert all(df["vector"].values[0] == [3, 4])
+    rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
+    assert rs[0]["id"] == 2
+    assert all(np.array(rs[0]["vector"]) == [3, 4])
 
 
 def test_query_builder_with_prefilter(table):
@@ -103,7 +106,7 @@ def test_query_builder_with_prefilter(table):
         LanceVectorQueryBuilder(table, [0, 0], "vector")
         .where("id = 2")
         .limit(1)
-        .to_df()
+        .to_pandas()
     )
     assert len(df) == 0
 
@@ -111,7 +114,7 @@ def test_query_builder_with_prefilter(table):
         LanceVectorQueryBuilder(table, [0, 0], "vector")
         .where("id = 2", prefilter=True)
         .limit(1)
-        .to_df()
+        .to_pandas()
     )
     assert df["id"].values[0] == 2
     assert all(df["vector"].values[0] == [3, 4])
@@ -120,9 +123,11 @@ def test_query_builder_with_prefilter(table):
 
 def test_query_builder_with_metric(table):
     query = [4, 8]
     vector_column_name = "vector"
-    df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
+    df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas()
     df_l2 = (
-        LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
+        LanceVectorQueryBuilder(table, query, vector_column_name)
+        .metric("L2")
+        .to_pandas()
     )
     tm.assert_frame_equal(df_default, df_l2)
@@ -130,7 +135,7 @@ def test_query_builder_with_metric(table):
         LanceVectorQueryBuilder(table, query, vector_column_name)
         .metric("cosine")
         .limit(1)
-        .to_df()
+        .to_pandas()
     )
     assert df_cosine._distance[0] == pytest.approx(
         cosine_distance(query, df_cosine.vector[0]),
@@ -86,7 +86,7 @@ async def test_e2e_with_mock_server():
                 columns=["id", "vector"],
             ),
         )
-    ).to_df()
+    ).to_pandas()
 
     assert "vector" in df.columns
     assert "id" in df.columns
@@ -32,4 +32,4 @@ def test_remote_db():
     setattr(conn, "_client", FakeLanceDBClient())
 
     table = conn["test"]
-    table.search([1.0, 2.0]).to_df()
+    table.search([1.0, 2.0]).to_pandas()
@@ -427,8 +427,8 @@ def test_multiple_vector_columns(db):
     table.add(df)
 
     q = np.random.randn(10)
-    result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
-    result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
+    result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
+    result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
 
     assert result1["text"].iloc[0] != result2["text"].iloc[0]
 
@@ -439,6 +439,6 @@ def test_empty_query(db):
         "my_table",
         data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
     )
-    df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
+    df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
     val = df.id.iloc[0]
     assert val == 1