diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 10892841..d490520a 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -44,12 +44,19 @@ jobs: run: pytest -m "not slow" -x -v --durations=30 tests - name: doctest run: pytest --doctest-modules lancedb - mac: + platform: + name: "Platform: ${{ matrix.config.name }}" timeout-minutes: 30 strategy: matrix: - mac-runner: [ "macos-13", "macos-13-xlarge" ] - runs-on: "${{ matrix.mac-runner }}" + config: + - name: x86 Mac + runner: macos-13 + - name: Arm Mac + runner: macos-13-xlarge + - name: x86 Windows + runner: windows-latest + runs-on: "${{ matrix.config.runner }}" defaults: run: shell: bash diff --git a/python/lancedb/db.py b/python/lancedb/db.py index bcfc73b8..e2204164 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -23,7 +23,7 @@ from overrides import EnforceOverrides, override from pyarrow import fs from .table import LanceTable, Table -from .util import fs_from_uri, get_uri_location, get_uri_scheme +from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri if TYPE_CHECKING: from .common import DATA, URI @@ -288,14 +288,13 @@ class LanceDBConnection(DBConnection): A list of table names. """ try: - filesystem, path = fs_from_uri(self.uri) + filesystem = fs_from_uri(self.uri)[0] except pa.ArrowInvalid: raise NotImplementedError("Unsupported scheme: " + self.uri) try: - paths = filesystem.get_file_info( - fs.FileSelector(get_uri_location(self.uri)) - ) + loc = get_uri_location(self.uri) + paths = filesystem.get_file_info(fs.FileSelector(loc)) except FileNotFoundError: # It is ok if the file does not exist since it will be created paths = [] @@ -373,7 +372,7 @@ class LanceDBConnection(DBConnection): """ try: filesystem, path = fs_from_uri(self.uri) - table_path = os.path.join(path, name + ".lance") + table_path = join_uri(path, name + ".lance") filesystem.delete_dir(table_path) except FileNotFoundError: if not ignore_missing: diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 0927a2be..9a2bf395 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -31,7 +31,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .pydantic import LanceModel, model_to_dict from .query import LanceQueryBuilder, Query -from .util import fs_from_uri, safe_import_pandas, value_to_sql +from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri from .utils.events import register_event if TYPE_CHECKING: @@ -552,7 +552,7 @@ class LanceTable(Table): @property def _dataset_uri(self) -> str: - return os.path.join(self._conn.uri, f"{self.name}.lance") + return join_uri(self._conn.uri, f"{self.name}.lance") def create_index( self, @@ -614,7 +614,7 @@ class LanceTable(Table): register_event("create_fts_index") def _get_fts_index_path(self): - return os.path.join(self._dataset_uri, "_indices", "tantivy") + return join_uri(self._dataset_uri, "_indices", "tantivy") @cached_property def _dataset(self) -> LanceDataset: diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 4774c429..f38ed49a 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -14,7 +14,8 @@ import os from datetime import date, datetime from functools import singledispatch -from typing import Tuple +import pathlib +from typing import Tuple, Union from urllib.parse import urlparse import numpy as np @@ -62,6 +63,12 @@ def get_uri_location(uri: str) -> str: str: Location part of the URL, without scheme """ parsed = urlparse(uri) + if len(parsed.scheme) == 1: + # Windows drive names are parsed as the scheme + # e.g. "c:\path" -> ParseResult(scheme="c", netloc="", path="/path", ...) + # So we add special handling here for schemes that are a single character + return uri + if not parsed.netloc: return parsed.path else: @@ -84,6 +91,29 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]: return pa_fs.FileSystem.from_uri(uri) +def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str: + """ + Join a URI with multiple parts, handles both local and remote paths + + Parameters + ---------- + base : str + The base URI + parts : str + The parts to join to the base URI, each separated by the + appropriate path separator for the URI scheme and OS + """ + if isinstance(base, pathlib.Path): + return base.joinpath(*parts) + base = str(base) + if get_uri_scheme(base) == "file": + # using pathlib for local paths make this windows compatible + # `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`) + return str(pathlib.Path(base, *parts)) + # for remote paths, just use os.path.join + return "/".join([p.rstrip("/") for p in [base, *parts]]) + + def safe_import_pandas(): try: import pandas as pd diff --git a/python/tests/test_util.py b/python/tests/test_util.py index 1090fa3d..1bf3e693 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -11,7 +11,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lancedb.util import get_uri_scheme +import os +import pathlib + +import pytest + +from lancedb.util import get_uri_scheme, join_uri def test_normalize_uri(): @@ -28,3 +33,55 @@ def test_normalize_uri(): for uri, expected_scheme in zip(uris, schemes): parsed_scheme = get_uri_scheme(uri) assert parsed_scheme == expected_scheme + + +def test_join_uri_remote(): + schemes = ["s3", "az", "gs"] + for scheme in schemes: + expected = f"{scheme}://bucket/path/to/table.lance" + base_uri = f"{scheme}://bucket/path/to/" + parts = ["table.lance"] + assert join_uri(base_uri, *parts) == expected + + base_uri = f"{scheme}://bucket" + parts = ["path", "to", "table.lance"] + assert join_uri(base_uri, *parts) == expected + + +# skip this test if on windows +@pytest.mark.skipif(os.name == "nt", reason="Windows paths are not POSIX") +def test_join_uri_posix(): + for base in [ + # relative path + "relative/path", + "relative/path/", + # an absolute path + "/absolute/path", + "/absolute/path/", + # a file URI + "file:///absolute/path", + "file:///absolute/path/", + ]: + joined = join_uri(base, "table.lance") + assert joined == str(pathlib.Path(base) / "table.lance") + joined = join_uri(pathlib.Path(base), "table.lance") + assert joined == pathlib.Path(base) / "table.lance" + + +# skip this test if not on windows +@pytest.mark.skipif(os.name != "nt", reason="Windows paths are not POSIX") +def test_local_join_uri_windows(): + # https://learn.microsoft.com/en-us/dotnet/standard/io/file-path-formats + for base in [ + # windows relative path + "relative\\path", + "relative\\path\\", + # windows absolute path from current drive + "c:\\absolute\\path", + # relative path from root of current drive + "\\relative\\path", + ]: + joined = join_uri(base, "table.lance") + assert joined == str(pathlib.Path(base) / "table.lance") + joined = join_uri(pathlib.Path(base), "table.lance") + assert joined == pathlib.Path(base) / "table.lance"