diff --git a/python/pyproject.toml b/python/pyproject.toml index 2e39948e..3a583076 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -78,6 +78,9 @@ embeddings = [ "awscli>=1.29.57", "botocore>=1.31.57", ] +azure = [ + "adlfs>=2024.2.0" +] [tool.maturin] python-source = "python" diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index 640dc4da..f5987c06 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -26,6 +26,18 @@ import pyarrow as pa import pyarrow.fs as pa_fs +def safe_import_adlfs(): + try: + import adlfs + + return adlfs + except ImportError: + return None + + +adlfs = safe_import_adlfs() + + def get_uri_scheme(uri: str) -> str: """ Get the scheme of a URI. If the URI does not have a scheme, assume it is a file URI. @@ -92,6 +104,17 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]: path = get_uri_location(uri) return fs, path + elif get_uri_scheme(uri) == "az" and adlfs is not None: + az_blob_fs = adlfs.AzureBlobFileSystem( + account_name=os.environ.get("AZURE_STORAGE_ACCOUNT_NAME"), + account_key=os.environ.get("AZURE_STORAGE_ACCOUNT_KEY"), + ) + + fs = pa_fs.PyFileSystem(pa_fs.FSSpecHandler(az_blob_fs)) + + path = get_uri_location(uri) + return fs, path + return pa_fs.FileSystem.from_uri(uri) diff --git a/python/python/tests/test_io.py b/python/python/tests/test_io.py index 10b749b2..af082180 100644 --- a/python/python/tests/test_io.py +++ b/python/python/tests/test_io.py @@ -16,16 +16,35 @@ import os import lancedb import pytest +# AWS: # You need to setup AWS credentials an a base path to run this test. Example # AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py +# +# Azure: +# You need to setup Azure credentials an a base path to run this test. Example +# export AZURE_STORAGE_ACCOUNT_NAME="" +# export AZURE_STORAGE_ACCOUNT_KEY="" +# export REMOTE_BASE_URL=az://my_blob/dataset +# pytest tests/test_io.py + + +@pytest.fixture(autouse=True, scope="module") +def setup(): + yield + + if remote_url := os.environ.get("REMOTE_BASE_URL"): + db = lancedb.connect(remote_url) + + for table in db.table_names(): + db.drop_table(table) @pytest.mark.skipif( - (os.environ.get("TEST_S3_BASE_URL") is None), - reason="please setup s3 base url", + (os.environ.get("REMOTE_BASE_URL") is None), + reason="please setup remote base url", ) -def test_s3_io(): - db = lancedb.connect(os.environ.get("TEST_S3_BASE_URL")) +def test_remote_io(): + db = lancedb.connect(os.environ.get("REMOTE_BASE_URL")) assert db.table_names() == [] table = db.create_table(