Use fsspec to implement table_names with cloud storage support (#117)

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
gsilvestrin
2023-06-01 16:56:26 -07:00
committed by GitHub
parent 45b3a14f26
commit f765a453cf
3 changed files with 86 additions and 6 deletions

View File

@@ -13,14 +13,16 @@
from __future__ import annotations
import os
from pathlib import Path
import os
import pyarrow as pa
from pyarrow import fs
from .common import DATA, URI
from .table import LanceTable
from .util import get_uri_scheme
from .util import get_uri_scheme, get_uri_location
class LanceDBConnection:
@@ -48,11 +50,20 @@ class LanceDBConnection:
-------
A list of table names.
"""
if get_uri_scheme(self.uri) == "file":
return [p.stem for p in Path(self.uri).glob("*.lance")]
raise NotImplementedError(
"List table_names is only supported for local filesystem for now"
)
try:
filesystem, path = fs.FileSystem.from_uri(self.uri)
except pa.ArrowInvalid:
raise NotImplementedError(
"Unsupported scheme: " + self.uri
)
try:
paths = filesystem.get_file_info(fs.FileSelector(get_uri_location(self.uri)))
except FileNotFoundError:
# It is ok if the file does not exist since it will be created
paths = []
tables = [os.path.splitext(file_info.base_name)[0] for file_info in paths if file_info.extension == 'lance']
return tables
def __len__(self) -> int:
return len(self.table_names())

View File

@@ -41,3 +41,23 @@ def get_uri_scheme(uri: str) -> str:
# So we add special handling here for schemes that are a single character
scheme = "file"
return scheme
def get_uri_location(uri: str) -> str:
"""
Get the location of a URI. If the parameter is not a url, assumes it is just a path
Parameters
----------
uri : str
The URI to parse.
Returns
-------
str: Location part of the URL, without scheme
"""
parsed = urlparse(uri)
if not parsed.netloc:
return parsed.path
else:
return parsed.netloc + parsed.path

49
python/tests/test_io.py Normal file
View File

@@ -0,0 +1,49 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import lancedb
# You need to setup AWS credentials an a base path to run this test. Example
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
@pytest.mark.skipif(
(os.environ.get("TEST_S3_BASE_URL") is None),
reason="please setup s3 base url",
)
def test_s3_io():
db = lancedb.connect(os.environ.get("TEST_S3_BASE_URL"))
assert db.table_names() == []
table = db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
rs = table.search([100, 100]).limit(1).to_df()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
assert db.table_names() == ["test"]
assert "test" in db
assert len(db) == 1
assert db.open_table("test").name == db["test"].name