diff --git a/python/lancedb/remote/__init__.py b/python/lancedb/remote/__init__.py index 57cc98fa..d8c37464 100644 --- a/python/lancedb/remote/__init__.py +++ b/python/lancedb/remote/__init__.py @@ -14,7 +14,7 @@ import abc from typing import List, Optional -import attr +import attrs import pyarrow as pa from pydantic import BaseModel @@ -44,7 +44,7 @@ class VectorQuery(BaseModel): refine_factor: Optional[int] = None -@attr.define +@attrs.define class VectorQueryResult: # for now the response is directly seralized into a pandas dataframe tbl: pa.Table diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index a36a8702..375b7894 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -16,7 +16,7 @@ import functools from typing import Any, Callable, Dict, Optional, Union import aiohttp -import attr +import attrs import pyarrow as pa from pydantic import BaseModel @@ -43,14 +43,14 @@ async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table: return reader.read_all() -@attr.define(slots=False) +@attrs.define(slots=False) class RestfulLanceDBClient: db_name: str region: str api_key: Credential - host_override: Optional[str] = attr.field(default=None) + host_override: Optional[str] = attrs.field(default=None) - closed: bool = attr.field(default=False, init=False) + closed: bool = attrs.field(default=False, init=False) @functools.cached_property def session(self) -> aiohttp.ClientSession: diff --git a/python/tests/test_remote_client.py b/python/tests/test_remote_client.py index ee90f28a..1afd0b60 100644 --- a/python/tests/test_remote_client.py +++ b/python/tests/test_remote_client.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import attr +import attrs import numpy as np import pandas as pd import pyarrow as pa @@ -21,10 +21,10 @@ from aiohttp import web from lancedb.remote.client import RestfulLanceDBClient, VectorQuery -@attr.define +@attrs.define class MockLanceDBServer: - runner: web.AppRunner = attr.field(init=False) - site: web.TCPSite = attr.field(init=False) + runner: web.AppRunner = attrs.field(init=False) + site: web.TCPSite = attrs.field(init=False) async def query_handler(self, request: web.Request) -> web.Response: table_name = request.match_info["table_name"]