diff --git a/python/python/tests/docs/test_cloud.py b/python/python/tests/docs/test_cloud.py index d15fbd9f..6fd65e07 100644 --- a/python/python/tests/docs/test_cloud.py +++ b/python/python/tests/docs/test_cloud.py @@ -3,47 +3,45 @@ import lancedb import pyarrow as pa import numpy as np + # --8<-- [end:imports] # --8<-- [start:gen_data] def gen_data(total_rows: int, ndims: int = 1536): - return pa.RecordBatch.from_pylist( - [ - { - "vector": np.random.rand(ndims).astype(np.float32).tolist(), - "id": i, - "name": "name_"+str(i), - } - for i in range(total_rows) - ], -).to_pandas() + return pa.RecordBatch.from_pylist( + [ + { + "vector": np.random.rand(ndims).astype(np.float32).tolist(), + "id": i, + "name": "name_" + str(i), + } + for i in range(total_rows) + ], + ).to_pandas() + + # --8<-- [end:gen_data] + def test_cloud_quickstart(): # --8<-- [start:connect] db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="your-cloud-region" + uri="db://your-project-slug", api_key="your-api-key", region="your-cloud-region" ) # --8<-- [end:connect] # --8<-- [start:create_table] - table_name = "myTable" + table_name = "myTable" table = db.create_table(table_name, data=gen_data(5000)) # --8<-- [end:create_table] # --8<-- [start:create_index_search] # create a vector index table.create_index("cosine", vector_column_name="vector") - result = ( - table.search([0.01, 0.02]) - .select(["vector", "item"]) - .limit(1) - .to_pandas() - ) + result = table.search([0.01, 0.02]).select(["vector", "item"]).limit(1).to_pandas() # --8<-- [end:create_index_search] # --8<-- [start:drop_table] db.drop_table(table_name) # --8<-- [end:drop_table] + def test_ingest_data(): # --8<-- [start:ingest_data] import lancedb @@ -51,9 +49,7 @@ def test_ingest_data(): # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) # create an empty table with schema @@ -64,59 +60,59 @@ def test_ingest_data(): {"vector": [10.2, 100.8], "item": "baz", "price": 30.0}, {"vector": [1.4, 9.5], "item": "fred", "price": 40.0}, ] - schema = pa.schema([ - pa.field("vector", pa.list_(pa.float32(), 2)), - pa.field("item", pa.utf8()), - pa.field("price", pa.float32()), - ]) + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2)), + pa.field("item", pa.utf8()), + pa.field("price", pa.float32()), + ] + ) table = db.create_table(table_name, schema=schema) table.add(data) + # --8<-- [end:ingest_data] # --8<-- [start:ingest_data_in_batch] def make_batches(): for i in range(5): yield pa.RecordBatch.from_arrays( [ - pa.array([[3.1, 4.1], [5.9, 26.5]], - pa.list_(pa.float32(), 2)), + pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)), pa.array(["foo", "bar"]), pa.array([10.0, 20.0]), ], ["vector", "item", "price"], ) - schema = pa.schema([ - pa.field("vector", pa.list_(pa.float32(), 2)), - pa.field("item", pa.utf8()), - pa.field("price", pa.float32()), - ]) + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2)), + pa.field("item", pa.utf8()), + pa.field("price", pa.float32()), + ] + ) db.create_table("table2", make_batches(), schema=schema) # --8<-- [end:ingest_data_in_batch] + def test_updates(): # --8<-- [start:update_data] import lancedb # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) table_name = "myTable" table = db.open_table(table_name) table.update(where="price < 20.0", values={"vector": [2, 2], "item": "foo-updated"}) - # --8<-- [end:update_data] + # --8<-- [end:update_data] # --8<-- [start:merge_insert] table = db.open_table(table_name) # upsert - new_data = [ - {"vector": [1, 1], "item": 'foo-updated', "price": 50.0} - ] - table.merge_insert("item") \ - .when_matched_update_all() \ - .when_not_matched_insert_all() \ - .execute(new_data) + new_data = [{"vector": [1, 1], "item": "foo-updated", "price": 50.0}] + table.merge_insert( + "item" + ).when_matched_update_all().when_not_matched_insert_all().execute(new_data) # --8<-- [end:merge_insert] # --8<-- [start:delete_data] table_name = "myTable" @@ -126,34 +122,32 @@ def test_updates(): table.delete(predicate) # --8<-- [end:delete_data] + def test_create_index(): # --8<-- [start:create_index] import lancedb # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) table_name = "myTable" table = db.open_table(table_name) - # the vector column only needs to be specified when there are + # the vector column only needs to be specified when there are # multiple vector columns or the column is not named as "vector" # L2 is used as the default distance metric table.create_index(metric="cosine", vector_column_name="vector") # --8<-- [end:create_index] + def test_create_scalar_index(): # --8<-- [start:create_scalar_index] import lancedb # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) table_name = "myTable" @@ -162,15 +156,14 @@ def test_create_scalar_index(): table.create_scalar_index("item", index_type="BITMAP") # --8<-- [end:create_scalar_index] + def test_create_fts_index(): # --8<-- [start:create_fts_index] import lancedb # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) table_name = "myTable" @@ -182,15 +175,14 @@ def test_create_fts_index(): table.create_fts_index("text") # --8<-- [end:create_fts_index] + def test_search(): # --8<-- [start:vector_search] import lancedb # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) table_name = "myTable" @@ -206,12 +198,10 @@ def test_search(): # --8<-- [end:vector_search] # --8<-- [start:full_text_search] import lancedb - + # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) table_name = "myTable" table = db.create_table( @@ -236,9 +226,7 @@ def test_search(): # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) # Configuring the environment variable OPENAI_API_KEY @@ -278,15 +266,14 @@ def test_search(): ) # --8<-- [end:hybrid_search] + def test_filtering(): # --8<-- [start:filtering] import lancedb # connect to LanceDB db = lancedb.connect( - uri="db://your-project-slug", - api_key="your-api-key", - region="us-east-1" + uri="db://your-project-slug", api_key="your-api-key", region="us-east-1" ) table_name = "myTable" table = db.open_table(table_name) @@ -297,8 +284,7 @@ def test_filtering(): ) # --8<-- [end:filtering] # --8<-- [start:sql_filtering] - table.search([100, 102]) \ - .where("(item IN ('foo', 'bar')) AND (price > 10.0)") \ - .to_arrow() + table.search([100, 102]).where( + "(item IN ('foo', 'bar')) AND (price > 10.0)" + ).to_arrow() # --8<-- [end:sql_filtering] - \ No newline at end of file