mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
191 lines
5.5 KiB
Python
191 lines
5.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
# --8<-- [start:import-lancedb]
|
|
import lancedb
|
|
|
|
# --8<-- [end:import-lancedb]
|
|
# --8<-- [start:import-pandas]
|
|
import pandas as pd
|
|
|
|
# --8<-- [end:import-pandas]
|
|
# --8<-- [start:import-iterable]
|
|
from typing import Iterable
|
|
|
|
# --8<-- [end:import-iterable]
|
|
# --8<-- [start:import-pyarrow]
|
|
import pyarrow as pa
|
|
|
|
# --8<-- [end:import-pyarrow]
|
|
# --8<-- [start:import-polars]
|
|
import polars as pl
|
|
|
|
# --8<-- [end:import-polars]
|
|
# --8<-- [start:import-lancedb-pydantic]
|
|
from lancedb.pydantic import Vector, LanceModel
|
|
|
|
# --8<-- [end:import-lancedb-pydantic]
|
|
import pytest
|
|
|
|
|
|
# --8<-- [start:make_batches]
|
|
def make_batches() -> Iterable[pa.RecordBatch]:
    """Yield five identical two-row Arrow record batches.

    Every batch has three columns, in this order: ``vector`` (one list of
    two floats per row), ``item`` (a string per row) and ``price`` (a float
    per row).
    """
    column_names = ["vector", "item", "price"]
    # The loop counter is not used: each iteration emits the same rows.
    for _ in range(5):
        columns = [
            pa.array([[3.1, 4.1], [5.9, 26.5]]),
            pa.array(["foo", "bar"]),
            pa.array([10.0, 20.0]),
        ]
        yield pa.RecordBatch.from_arrays(columns, column_names)
|
|
|
|
|
|
# --8<-- [end:make_batches]
|
|
|
|
|
|
def test_pandas_and_pyarrow():
    """Run the synchronous pandas / PyArrow snippets end to end.

    Creates ``pd_table`` from a pandas DataFrame and ``iterable_table`` from
    the ``make_batches()`` generator with an explicit Arrow schema, then
    searches ``pd_table`` and checks that filtering on ``price < 15`` --
    either through LanceDB's ``where`` or through pandas -- leaves only the
    ``"foo"`` row.

    The ``# --8<-- [start:...]`` / ``[end:...]`` comments delimit named
    snippet regions (presumably consumed by the documentation build); keep
    them paired and in place when editing.

    NOTE(review): the tables are created under the relative path
    ``data/sample-lancedb`` without an explicit create mode. Presumably the
    test is run from a fresh working directory -- confirm before relying on
    repeated runs.
    """
    # --8<-- [start:connect_to_lancedb]
    uri = "data/sample-lancedb"
    db = lancedb.connect(uri)
    # --8<-- [end:connect_to_lancedb]
    # --8<-- [start:create_table_pandas]
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    table = db.create_table("pd_table", data=data)
    # --8<-- [end:create_table_pandas]
    # --8<-- [start:create_table_iterable]
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32())),
            pa.field("item", pa.utf8()),
            pa.field("price", pa.float32()),
        ]
    )
    table = db.create_table("iterable_table", data=make_batches(), schema=schema)
    # --8<-- [end:create_table_iterable]
    # --8<-- [start:vector_search]
    # Open the table previously created.
    table = db.open_table("pd_table")

    query_vector = [100, 100]
    # Pandas DataFrame
    df = table.search(query_vector).limit(1).to_pandas()
    print(df)
    # --8<-- [end:vector_search]
    # --8<-- [start:vector_search_with_filter]
    # Apply the filter via LanceDB
    results = table.search([100, 100]).where("price < 15").to_pandas()
    assert len(results) == 1
    assert results["item"].iloc[0] == "foo"

    # Apply the filter via Pandas
    df = table.search([100, 100]).to_pandas()
    results = df[df.price < 15]
    assert len(results) == 1
    assert results["item"].iloc[0] == "foo"
    # --8<-- [end:vector_search_with_filter]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pandas_and_pyarrow_async():
    """Run the asynchronous pandas / PyArrow snippets end to end.

    Async counterpart of ``test_pandas_and_pyarrow``: connects with
    ``lancedb.connect_async``, creates ``pd_table_async`` from a pandas
    DataFrame and ``iterable_table_async`` from ``make_batches()``, then
    queries ``pd_table_async`` via ``query().nearest_to(...)`` and checks
    that filtering on ``price < 15`` -- through LanceDB's ``where`` or
    through pandas -- leaves only the ``"foo"`` row.

    The ``# --8<-- [start:...]`` / ``[end:...]`` comments delimit named
    snippet regions (presumably consumed by the documentation build); keep
    them paired and in place when editing.

    NOTE(review): the tables are created under the relative path
    ``data/sample-lancedb`` without an explicit create mode. Presumably the
    test is run from a fresh working directory -- confirm before relying on
    repeated runs.
    """
    # --8<-- [start:connect_to_lancedb_async]
    uri = "data/sample-lancedb"
    async_db = await lancedb.connect_async(uri)
    # --8<-- [end:connect_to_lancedb_async]
    # --8<-- [start:create_table_pandas_async]
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    await async_db.create_table("pd_table_async", data=data)
    # --8<-- [end:create_table_pandas_async]
    # --8<-- [start:create_table_iterable_async]
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32())),
            pa.field("item", pa.utf8()),
            pa.field("price", pa.float32()),
        ]
    )
    await async_db.create_table(
        "iterable_table_async", data=make_batches(), schema=schema
    )
    # --8<-- [end:create_table_iterable_async]
    # --8<-- [start:vector_search_async]
    # Open the table previously created.
    async_tbl = await async_db.open_table("pd_table_async")

    query_vector = [100, 100]
    # Pandas DataFrame
    df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas()
    print(df)
    # --8<-- [end:vector_search_async]
    # --8<-- [start:vector_search_with_filter_async]
    # Apply the filter via LanceDB
    results = (
        await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas()
    )
    assert len(results) == 1
    assert results["item"].iloc[0] == "foo"

    # Apply the filter via Pandas
    df = await async_tbl.query().nearest_to([100, 100]).to_pandas()
    results = df[df.price < 15]
    assert len(results) == 1
    assert results["item"].iloc[0] == "foo"
    # --8<-- [end:vector_search_with_filter_async]
|
|
|
|
|
|
# --8<-- [start:class_Item]
|
|
class Item(LanceModel):
    """Row model with a vector, an item name, and a price.

    Passed as ``schema=Item`` when ``test_polars`` creates
    ``pydantic_table``.
    """

    # Built by calling Vector(2); presumably a fixed 2-element vector,
    # matching the two-element vectors in the sample data -- see
    # lancedb.pydantic.Vector to confirm the element type.
    vector: Vector(2)
    item: str
    price: float
|
|
|
|
|
|
# --8<-- [end:class_Item]
|
|
|
|
|
|
def test_polars():
    """Run the Polars snippets end to end and verify their results.

    Creates ``pl_table`` from a Polars DataFrame and searches it, then
    creates ``pydantic_table`` from the ``Item`` model, appends the same
    Polars data, and reads the table back through ``table.to_polars()``.

    The ``# --8<-- [start:...]`` / ``[end:...]`` comments delimit named
    snippet regions (presumably consumed by the documentation build); keep
    them paired and in place when editing. The assertions below sit
    between regions so the snippet text itself is unchanged.

    NOTE(review): the tables are created under the relative path
    ``data/sample-lancedb`` without an explicit create mode. Presumably the
    test is run from a fresh working directory -- confirm before relying on
    repeated runs.
    """
    uri = "data/sample-lancedb"
    db = lancedb.connect(uri)

    # --8<-- [start:create_table_polars]
    data = pl.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    table = db.create_table("pl_table", data=data)
    # --8<-- [end:create_table_polars]
    # --8<-- [start:vector_search_polars]
    query = [3.0, 4.0]
    result = table.search(query).limit(1).to_polars()
    print(result)
    print(type(result))
    # --8<-- [end:vector_search_polars]
    # The query sits next to the "foo" row's vector [3.1, 4.1], so with
    # limit(1) that is the single row expected back.
    assert isinstance(result, pl.DataFrame)
    assert result["item"].to_list() == ["foo"]
    # --8<-- [start:create_table_pydantic]
    table = db.create_table("pydantic_table", schema=Item)
    df = pl.DataFrame(data)
    # Add Polars DataFrame to table
    table.add(df)
    # --8<-- [end:create_table_pydantic]
    # --8<-- [start:dump_table_lazyform]
    ldf = table.to_polars()
    print(type(ldf))
    # --8<-- [end:dump_table_lazyform]
    # The next snippet calls .collect(), which relies on a lazy frame.
    assert isinstance(ldf, pl.LazyFrame)
    # --8<-- [start:print_table_lazyform]
    print(ldf.first().collect())
    # --8<-- [end:print_table_lazyform]
|