diff --git a/python/lancedb/db.py b/python/lancedb/db.py
index 5ed247e7..4af40d8a 100644
--- a/python/lancedb/db.py
+++ b/python/lancedb/db.py
@@ -24,7 +24,7 @@ from pyarrow import fs
 
 from .common import DATA, URI
 from .table import LanceTable, Table
-from .util import get_uri_location, get_uri_scheme
+from .util import fs_from_uri, get_uri_location, get_uri_scheme
 
 
 class DBConnection(ABC):
@@ -252,7 +252,7 @@ class LanceDBConnection(DBConnection):
             A list of table names.
         """
         try:
-            filesystem, path = fs.FileSystem.from_uri(self.uri)
+            filesystem, path = fs_from_uri(self.uri)
         except pa.ArrowInvalid:
             raise NotImplementedError("Unsupported scheme: " + self.uri)
 
diff --git a/python/lancedb/util.py b/python/lancedb/util.py
index 47865b07..88f85d27 100644
--- a/python/lancedb/util.py
+++ b/python/lancedb/util.py
@@ -11,8 +11,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+from typing import Tuple
 from urllib.parse import urlparse
 
+import pyarrow as pa
+import pyarrow.fs as pa_fs
+
 
 def get_uri_scheme(uri: str) -> str:
     """
@@ -59,3 +64,22 @@ def get_uri_location(uri: str) -> str:
         return parsed.path
     else:
         return parsed.netloc + parsed.path
+
+
+def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
+    """
+    Get a PyArrow FileSystem from a URI, handling extra environment variables.
+
+    For "s3" URIs, a non-empty AWS_ENDPOINT environment variable is appended
+    to the URI as the "endpoint_override" query option before it is handed to
+    PyArrow. AWS_ENDPOINT is optional; when unset the URI is used as-is.
+    """
+    if get_uri_scheme(uri) == "s3":
+        # AWS_ENDPOINT is optional: .get() avoids a KeyError when it is unset.
+        endpoint = os.environ.get("AWS_ENDPOINT")
+        if endpoint:
+            # Use "&" when the URI already carries a query string.
+            separator = "&" if urlparse(uri).query else "?"
+            uri += separator + "endpoint_override=" + endpoint
+
+    return pa_fs.FileSystem.from_uri(uri)