feat(python): transition Python remote sdk to use Rust implementation (#1701)

* Replaces Python implementation of Remote SDK with Rust one.
* Drops dependency on `attrs` and `cachetools`. Makes `requests` an
optional dependency used only for embeddings feature.
* Adds dependency on `nest-asyncio`. This was required to get hybrid
search working.
* Deprecate `request_thread_pool` parameter. We now use the tokio
threadpool.
* Stop caching the `schema` on a remote table. Schema is mutable and
there's no mechanism in place to invalidate the cache.
* Removed the client-side resolution of the vector column. We should
already be resolving this server-side.
This commit is contained in:
Will Jones
2024-11-05 13:44:39 -08:00
committed by GitHub
parent c61bfc3af8
commit 91cab3b556
21 changed files with 521 additions and 901 deletions

View File

@@ -2,91 +2,19 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import contextlib
from datetime import timedelta
import http.server
import json
import threading
from unittest.mock import MagicMock
import uuid
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.remote import ClientConfig
from lancedb.remote.errors import HttpError, RetryError
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult
import pytest
class FakeLanceDBClient:
def close(self):
pass
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
assert table_name == "test"
t = pa.schema([]).empty_table()
return VectorQueryResult(t)
def post(self, path: str):
pass
def mount_retry_adapter_for_table(self, table_name: str):
pass
def test_remote_db():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas()
def test_create_empty_table():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", schema=schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
json_schema = {
"fields": [
{
"name": "vector",
"nullable": True,
"type": {
"type": "fixed_size_list",
"fields": [
{"name": "item", "nullable": True, "type": {"type": "float"}}
],
"length": 2,
},
},
]
}
client.post.return_value = {"schema": json_schema}
assert table.schema == schema
assert client.post.call_args[0][0] == "/v1/table/test/describe/"
client.post.return_value = 0
assert table.count_rows(None) == 0
def test_create_table_with_recordbatches():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", [batch], schema=batch.schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
import pyarrow as pa
def make_mock_http_handler(handler):
@@ -100,8 +28,35 @@ def make_mock_http_handler(handler):
return MockLanceDBHandler
@contextlib.contextmanager
def mock_lancedb_connection(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
handle = threading.Thread(target=server.serve_forever)
handle.start()
db = lancedb.connect(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
try:
yield db
finally:
server.shutdown()
handle.join()
@contextlib.asynccontextmanager
async def mock_lancedb_connection(handler):
async def mock_lancedb_connection_async(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
@@ -143,7 +98,7 @@ async def test_async_remote_db():
request.end_headers()
request.wfile.write(b'{"tables": []}')
async with mock_lancedb_connection(handler) as db:
async with mock_lancedb_connection_async(handler) as db:
table_names = await db.table_names()
assert table_names == []
@@ -159,12 +114,12 @@ async def test_http_error():
request.end_headers()
request.wfile.write(b"Internal Server Error")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
async with mock_lancedb_connection_async(handler) as db:
with pytest.raises(HttpError) as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 507
assert "Internal Server Error" in str(exc_info.value)
@pytest.mark.asyncio
@@ -178,15 +133,225 @@ async def test_retry_error():
request.end_headers()
request.wfile.write(b"Try again later")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
async with mock_lancedb_connection_async(handler) as db:
with pytest.raises(RetryError) as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 429
cause = exc_info.value.__cause__
assert isinstance(cause, HttpError)
assert "Try again later" in str(cause)
assert cause.request_id == request_id_holder["request_id"]
assert cause.status_code == 429
@contextlib.contextmanager
def query_test_table(query_handler):
def handler(request):
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/query/":
content_len = int(request.headers.get("Content-Length"))
body = request.rfile.read(content_len)
body = json.loads(body)
data = query_handler(body)
request.send_response(200)
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
request.end_headers()
with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
f.write_table(data)
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
assert repr(db) == "RemoteConnect(name=dev)"
table = db.open_table("test")
assert repr(table) == "RemoteTable(dev.test)"
yield table
def test_query_sync_minimal():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": False,
"refine_factor": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
def test_query_sync_maximal():
def handler(body):
assert body == {
"distance_type": "cosine",
"k": 42,
"prefilter": True,
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"filter": "id > 0",
"columns": ["id", "name"],
"vector_column": "vector2",
"fast_search": True,
"with_row_id": True,
}
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
with query_test_table(handler) as table:
(
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
.metric("cosine")
.limit(42)
.refine_factor(10)
.nprobes(5)
.where("id > 0", prefilter=True)
.with_row_id(True)
.select(["id", "name"])
.to_list()
)
def test_query_sync_fts():
def handler(body):
assert body == {
"full_text_query": {
"query": "puppy",
"columns": [],
},
"k": 10,
"vector": [],
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
(table.search("puppy", query_type="fts").to_list())
def handler(body):
assert body == {
"full_text_query": {
"query": "puppy",
"columns": ["name", "description"],
},
"k": 42,
"vector": [],
"with_row_id": True,
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
(
table.search("puppy", query_type="fts", fts_columns=["name", "description"])
.with_row_id(True)
.limit(42)
.to_list()
)
def test_query_sync_hybrid():
def handler(body):
if "full_text_query" in body:
# FTS query
assert body == {
"full_text_query": {
"query": "puppy",
"columns": [],
},
"k": 42,
"vector": [],
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
else:
# Vector query
assert body == {
"distance_type": "l2",
"k": 42,
"prefilter": False,
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
with query_test_table(handler) as table:
embedding_func = MockTextEmbeddingFunction()
embedding_config = MagicMock()
embedding_config.function = embedding_func
embedding_funcs = MagicMock()
embedding_funcs.get = MagicMock(return_value=embedding_config)
table.embedding_functions = embedding_funcs
(table.search("puppy", query_type="hybrid").limit(42).to_list())
def test_create_client():
mandatory_args = {
"uri": "db://dev",
"api_key": "fake-api-key",
"region": "us-east-1",
}
db = lancedb.connect(**mandatory_args)
assert isinstance(db.client_config, ClientConfig)
db = lancedb.connect(**mandatory_args, client_config={})
assert isinstance(db.client_config, ClientConfig)
db = lancedb.connect(
**mandatory_args,
client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
db = lancedb.connect(
**mandatory_args,
client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
db = lancedb.connect(
**mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.retry_config.retries == 42
db = lancedb.connect(
**mandatory_args, client_config={"retry_config": {"retries": 42}}
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.retry_config.retries == 42
with pytest.warns(DeprecationWarning):
db = lancedb.connect(**mandatory_args, connection_timeout=42)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
with pytest.warns(DeprecationWarning):
db = lancedb.connect(**mandatory_args, read_timeout=42)
assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)
with pytest.warns(DeprecationWarning):
lancedb.connect(**mandatory_args, request_thread_pool=10)