code for cloud doc

This commit is contained in:
qzhu
2024-11-13 22:05:09 -08:00
parent b70fa3892e
commit 955a295026

View File

@@ -3,47 +3,45 @@ import lancedb
import pyarrow as pa
import numpy as np
# --8<-- [end:imports]
# --8<-- [start:gen_data]
def gen_data(total_rows: int, ndims: int = 1536):
return pa.RecordBatch.from_pylist(
[
{
"vector": np.random.rand(ndims).astype(np.float32).tolist(),
"id": i,
"name": "name_"+str(i),
}
for i in range(total_rows)
],
).to_pandas()
return pa.RecordBatch.from_pylist(
[
{
"vector": np.random.rand(ndims).astype(np.float32).tolist(),
"id": i,
"name": "name_" + str(i),
}
for i in range(total_rows)
],
).to_pandas()
# --8<-- [end:gen_data]
def test_cloud_quickstart():
# --8<-- [start:connect]
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="your-cloud-region"
uri="db://your-project-slug", api_key="your-api-key", region="your-cloud-region"
)
# --8<-- [end:connect]
# --8<-- [start:create_table]
table_name = "myTable"
table_name = "myTable"
table = db.create_table(table_name, data=gen_data(5000))
# --8<-- [end:create_table]
# --8<-- [start:create_index_search]
# create a vector index
table.create_index("cosine", vector_column_name="vector")
result = (
table.search([0.01, 0.02])
.select(["vector", "item"])
.limit(1)
.to_pandas()
)
result = table.search([0.01, 0.02]).select(["vector", "item"]).limit(1).to_pandas()
# --8<-- [end:create_index_search]
# --8<-- [start:drop_table]
db.drop_table(table_name)
# --8<-- [end:drop_table]
def test_ingest_data():
# --8<-- [start:ingest_data]
import lancedb
@@ -51,9 +49,7 @@ def test_ingest_data():
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
# create an empty table with schema
@@ -64,59 +60,59 @@ def test_ingest_data():
{"vector": [10.2, 100.8], "item": "baz", "price": 30.0},
{"vector": [1.4, 9.5], "item": "fred", "price": 40.0},
]
schema = pa.schema([
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float32()),
])
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float32()),
]
)
table = db.create_table(table_name, schema=schema)
table.add(data)
# --8<-- [end:ingest_data]
# --8<-- [start:ingest_data_in_batch]
def make_batches():
for i in range(5):
yield pa.RecordBatch.from_arrays(
[
pa.array([[3.1, 4.1], [5.9, 26.5]],
pa.list_(pa.float32(), 2)),
pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]),
],
["vector", "item", "price"],
)
schema = pa.schema([
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float32()),
])
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float32()),
]
)
db.create_table("table2", make_batches(), schema=schema)
# --8<-- [end:ingest_data_in_batch]
def test_updates():
# --8<-- [start:update_data]
import lancedb
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
table_name = "myTable"
table = db.open_table(table_name)
table.update(where="price < 20.0", values={"vector": [2, 2], "item": "foo-updated"})
# --8<-- [end:update_data]
# --8<-- [end:update_data]
# --8<-- [start:merge_insert]
table = db.open_table(table_name)
# upsert
new_data = [
{"vector": [1, 1], "item": 'foo-updated', "price": 50.0}
]
table.merge_insert("item") \
.when_matched_update_all() \
.when_not_matched_insert_all() \
.execute(new_data)
new_data = [{"vector": [1, 1], "item": "foo-updated", "price": 50.0}]
table.merge_insert(
"item"
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
# --8<-- [end:merge_insert]
# --8<-- [start:delete_data]
table_name = "myTable"
@@ -126,34 +122,32 @@ def test_updates():
table.delete(predicate)
# --8<-- [end:delete_data]
def test_create_index():
# --8<-- [start:create_index]
import lancedb
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
table_name = "myTable"
table = db.open_table(table_name)
# the vector column only needs to be specified when there are
# the vector column only needs to be specified when there are
# multiple vector columns or the column is not named as "vector"
# L2 is used as the default distance metric
table.create_index(metric="cosine", vector_column_name="vector")
# --8<-- [end:create_index]
def test_create_scalar_index():
# --8<-- [start:create_scalar_index]
import lancedb
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
table_name = "myTable"
@@ -162,15 +156,14 @@ def test_create_scalar_index():
table.create_scalar_index("item", index_type="BITMAP")
# --8<-- [end:create_scalar_index]
def test_create_fts_index():
# --8<-- [start:create_fts_index]
import lancedb
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
table_name = "myTable"
@@ -182,15 +175,14 @@ def test_create_fts_index():
table.create_fts_index("text")
# --8<-- [end:create_fts_index]
def test_search():
# --8<-- [start:vector_search]
import lancedb
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
table_name = "myTable"
@@ -206,12 +198,10 @@ def test_search():
# --8<-- [end:vector_search]
# --8<-- [start:full_text_search]
import lancedb
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
table_name = "myTable"
table = db.create_table(
@@ -236,9 +226,7 @@ def test_search():
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
# Configuring the environment variable OPENAI_API_KEY
@@ -278,15 +266,14 @@ def test_search():
)
# --8<-- [end:hybrid_search]
def test_filtering():
# --8<-- [start:filtering]
import lancedb
# connect to LanceDB
db = lancedb.connect(
uri="db://your-project-slug",
api_key="your-api-key",
region="us-east-1"
uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
)
table_name = "myTable"
table = db.open_table(table_name)
@@ -297,8 +284,7 @@ def test_filtering():
)
# --8<-- [end:filtering]
# --8<-- [start:sql_filtering]
table.search([100, 102]) \
.where("(item IN ('foo', 'bar')) AND (price > 10.0)") \
.to_arrow()
table.search([100, 102]).where(
"(item IN ('foo', 'bar')) AND (price > 10.0)"
).to_arrow()
# --8<-- [end:sql_filtering]