From f765a453cf1ad896139a4af7904fe6282eb1c273 Mon Sep 17 00:00:00 2001 From: gsilvestrin Date: Thu, 1 Jun 2023 16:56:26 -0700 Subject: [PATCH] Use fsspec to implement table_names with cloud storage support (#117) Co-authored-by: Will Jones --- python/lancedb/db.py | 23 ++++++++++++++----- python/lancedb/util.py | 20 +++++++++++++++++ python/tests/test_io.py | 49 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 6 deletions(-) create mode 100644 python/tests/test_io.py diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 45bf70c1..4f380e28 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -13,14 +13,16 @@ from __future__ import annotations +import os from pathlib import Path import os import pyarrow as pa +from pyarrow import fs from .common import DATA, URI from .table import LanceTable -from .util import get_uri_scheme +from .util import get_uri_scheme, get_uri_location class LanceDBConnection: @@ -48,11 +50,20 @@ class LanceDBConnection: ------- A list of table names. """ - if get_uri_scheme(self.uri) == "file": - return [p.stem for p in Path(self.uri).glob("*.lance")] - raise NotImplementedError( - "List table_names is only supported for local filesystem for now" - ) + try: + filesystem, path = fs.FileSystem.from_uri(self.uri) + except pa.ArrowInvalid: + raise NotImplementedError( + "Unsupported scheme: " + self.uri + ) + + try: + paths = filesystem.get_file_info(fs.FileSelector(get_uri_location(self.uri))) + except FileNotFoundError: + # It is ok if the file does not exist since it will be created + paths = [] + tables = [os.path.splitext(file_info.base_name)[0] for file_info in paths if file_info.extension == 'lance'] + return tables def __len__(self) -> int: return len(self.table_names()) diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 92bb9315..bc5cc7ba 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -41,3 +41,23 @@ def get_uri_scheme(uri: str) -> str: # So we add special handling here for schemes that are a single character scheme = "file" return scheme + + +def get_uri_location(uri: str) -> str: + """ + Get the location of a URI. If the parameter is not a url, assumes it is just a path + + Parameters + ---------- + uri : str + The URI to parse. + + Returns + ------- + str: Location part of the URL, without scheme + """ + parsed = urlparse(uri) + if not parsed.netloc: + return parsed.path + else: + return parsed.netloc + parsed.path diff --git a/python/tests/test_io.py b/python/tests/test_io.py new file mode 100644 index 00000000..f05d8154 --- /dev/null +++ b/python/tests/test_io.py @@ -0,0 +1,49 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pytest + +import lancedb + +# You need to setup AWS credentials an a base path to run this test. Example +# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py + +@pytest.mark.skipif( + (os.environ.get("TEST_S3_BASE_URL") is None), + reason="please setup s3 base url", +) +def test_s3_io(): + db = lancedb.connect(os.environ.get("TEST_S3_BASE_URL")) + assert db.table_names() == [] + + table = db.create_table( + "test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ], + ) + rs = table.search([100, 100]).limit(1).to_df() + assert len(rs) == 1 + assert rs["item"].iloc[0] == "bar" + + rs = table.search([100, 100]).where("price < 15").limit(2).to_df() + assert len(rs) == 1 + assert rs["item"].iloc[0] == "foo" + + assert db.table_names() == ["test"] + assert "test" in db + assert len(db) == 1 + + assert db.open_table("test").name == db["test"].name