diff --git a/.gitignore b/.gitignore index 82107c69c..40bd8a16e 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ python/build python/dist notebooks/.ipynb_checkpoints + +**/.hypothesis diff --git a/python/lancedb/db.py b/python/lancedb/db.py index f905aae3b..9782163fc 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -18,6 +18,7 @@ import pyarrow as pa from .common import URI, DATA from .table import LanceTable +from .util import get_uri_scheme class LanceDBConnection: @@ -26,10 +27,12 @@ class LanceDBConnection: """ def __init__(self, uri: URI): - if isinstance(uri, str): - uri = Path(uri) - uri = uri.expanduser().absolute() - Path(uri).mkdir(parents=True, exist_ok=True) + is_local = isinstance(uri, Path) or get_uri_scheme(uri) == "file" + if is_local: + if isinstance(uri, str): + uri = Path(uri) + uri = uri.expanduser().absolute() + Path(uri).mkdir(parents=True, exist_ok=True) self._uri = str(uri) @property @@ -43,7 +46,11 @@ class LanceDBConnection: ------- A list of table names. """ - return [p.stem for p in Path(self.uri).glob("*.lance")] + if get_uri_scheme(self.uri) == "file": + return [p.stem for p in Path(self.uri).glob("*.lance")] + raise NotImplementedError( + "List table_names is only supported for local filesystem for now" + ) def __len__(self) -> int: return len(self.table_names()) diff --git a/python/lancedb/util.py b/python/lancedb/util.py new file mode 100644 index 000000000..92bb93158 --- /dev/null +++ b/python/lancedb/util.py @@ -0,0 +1,43 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def get_uri_scheme(uri: str) -> str:
    """
    Get the scheme of a URI, treating scheme-less inputs as local files.

    Parameters
    ----------
    uri : str or os.PathLike
        The URI to parse. Path-like objects (e.g. ``pathlib.Path``) are
        converted with ``os.fspath`` before parsing.

    Returns
    -------
    str
        The normalized scheme: ``"file"`` when the URI has no scheme or when
        the parsed scheme is a single character (a Windows drive letter),
        ``"s3"`` for the ``s3a`` / ``s3n`` aliases, and otherwise the scheme
        exactly as reported by ``urlparse``.
    """
    # Function-scope import so the helper stays self-contained.
    import os

    if isinstance(uri, os.PathLike):
        # urlparse() only accepts str/bytes; a Path object would raise
        # AttributeError, so turn it into its string form first.
        uri = os.fspath(uri)

    scheme = urlparse(uri).scheme
    if not scheme:
        # No scheme at all: plain relative or absolute filesystem path.
        return "file"
    if scheme in ("s3a", "s3n"):
        # Collapse the S3 aliases into the single "s3" scheme.
        return "s3"
    if len(scheme) == 1:
        # Windows drive names are parsed as the scheme,
        # e.g. "c:\path" -> scheme="c". Treat any single-character
        # scheme as a local file path.
        return "file"
    return scheme
def test_normalize_uri():
    """Check that get_uri_scheme maps each sample URI to its expected scheme."""
    # Each case keeps the input next to its expected scheme.
    cases = [
        ("relative/path", "file"),
        ("/absolute/path", "file"),
        ("file:///absolute/path", "file"),
        ("s3://bucket/path", "s3"),
        ("gs://bucket/path", "gs"),
        ("c:\\windows\\path", "file"),
    ]

    for candidate, wanted in cases:
        actual = get_uri_scheme(candidate)
        assert actual == wanted