bug(python): fix path handling in windows (#724)

Use pathlib for local paths so that pathlib
can handle the correct separator on windows.

Closes #703

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Chang She
2023-12-20 15:41:36 -08:00
committed by Weston Pace
parent 3f3acb48c6
commit 009297e900
5 changed files with 107 additions and 14 deletions

View File

@@ -44,12 +44,19 @@ jobs:
run: pytest -m "not slow" -x -v --durations=30 tests run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest - name: doctest
run: pytest --doctest-modules lancedb run: pytest --doctest-modules lancedb
mac: platform:
name: "Platform: ${{ matrix.config.name }}"
timeout-minutes: 30 timeout-minutes: 30
strategy: strategy:
matrix: matrix:
mac-runner: [ "macos-13", "macos-13-xlarge" ] config:
runs-on: "${{ matrix.mac-runner }}" - name: x86 Mac
runner: macos-13
- name: Arm Mac
runner: macos-13-xlarge
- name: x86 Windows
runner: windows-latest
runs-on: "${{ matrix.config.runner }}"
defaults: defaults:
run: run:
shell: bash shell: bash

View File

@@ -23,7 +23,7 @@ from overrides import EnforceOverrides, override
from pyarrow import fs from pyarrow import fs
from .table import LanceTable, Table from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
if TYPE_CHECKING: if TYPE_CHECKING:
from .common import DATA, URI from .common import DATA, URI
@@ -288,14 +288,13 @@ class LanceDBConnection(DBConnection):
A list of table names. A list of table names.
""" """
try: try:
filesystem, path = fs_from_uri(self.uri) filesystem = fs_from_uri(self.uri)[0]
except pa.ArrowInvalid: except pa.ArrowInvalid:
raise NotImplementedError("Unsupported scheme: " + self.uri) raise NotImplementedError("Unsupported scheme: " + self.uri)
try: try:
paths = filesystem.get_file_info( loc = get_uri_location(self.uri)
fs.FileSelector(get_uri_location(self.uri)) paths = filesystem.get_file_info(fs.FileSelector(loc))
)
except FileNotFoundError: except FileNotFoundError:
# It is ok if the file does not exist since it will be created # It is ok if the file does not exist since it will be created
paths = [] paths = []
@@ -373,7 +372,7 @@ class LanceDBConnection(DBConnection):
""" """
try: try:
filesystem, path = fs_from_uri(self.uri) filesystem, path = fs_from_uri(self.uri)
table_path = os.path.join(path, name + ".lance") table_path = join_uri(path, name + ".lance")
filesystem.delete_dir(table_path) filesystem.delete_dir(table_path)
except FileNotFoundError: except FileNotFoundError:
if not ignore_missing: if not ignore_missing:

View File

@@ -31,7 +31,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel, model_to_dict from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas, value_to_sql from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import timedelta from datetime import timedelta
@@ -551,7 +551,7 @@ class LanceTable(Table):
@property @property
def _dataset_uri(self) -> str: def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance") return join_uri(self._conn.uri, f"{self.name}.lance")
def create_index( def create_index(
self, self,
@@ -611,7 +611,7 @@ class LanceTable(Table):
populate_index(index, self, field_names) populate_index(index, self, field_names)
def _get_fts_index_path(self): def _get_fts_index_path(self):
return os.path.join(self._dataset_uri, "_indices", "tantivy") return join_uri(self._dataset_uri, "_indices", "tantivy")
@cached_property @cached_property
def _dataset(self) -> LanceDataset: def _dataset(self) -> LanceDataset:

View File

@@ -14,7 +14,8 @@
import os import os
from datetime import date, datetime from datetime import date, datetime
from functools import singledispatch from functools import singledispatch
from typing import Tuple import pathlib
from typing import Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
@@ -62,6 +63,12 @@ def get_uri_location(uri: str) -> str:
str: Location part of the URL, without scheme str: Location part of the URL, without scheme
""" """
parsed = urlparse(uri) parsed = urlparse(uri)
if len(parsed.scheme) == 1:
# Windows drive names are parsed as the scheme
# e.g. "c:\path" -> ParseResult(scheme="c", netloc="", path="/path", ...)
# So we add special handling here for schemes that are a single character
return uri
if not parsed.netloc: if not parsed.netloc:
return parsed.path return parsed.path
else: else:
@@ -84,6 +91,29 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
return pa_fs.FileSystem.from_uri(uri) return pa_fs.FileSystem.from_uri(uri)
def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
"""
Join a URI with multiple parts, handles both local and remote paths
Parameters
----------
base : str
The base URI
parts : str
The parts to join to the base URI, each separated by the
appropriate path separator for the URI scheme and OS
"""
if isinstance(base, pathlib.Path):
return base.joinpath(*parts)
base = str(base)
if get_uri_scheme(base) == "file":
# using pathlib for local paths make this windows compatible
# `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`)
return str(pathlib.Path(base, *parts))
# for remote paths, just use os.path.join
return "/".join([p.rstrip("/") for p in [base, *parts]])
def safe_import_pandas(): def safe_import_pandas():
try: try:
import pandas as pd import pandas as pd

View File

@@ -11,7 +11,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from lancedb.util import get_uri_scheme import os
import pathlib
import pytest
from lancedb.util import get_uri_scheme, join_uri
def test_normalize_uri(): def test_normalize_uri():
@@ -28,3 +33,55 @@ def test_normalize_uri():
for uri, expected_scheme in zip(uris, schemes): for uri, expected_scheme in zip(uris, schemes):
parsed_scheme = get_uri_scheme(uri) parsed_scheme = get_uri_scheme(uri)
assert parsed_scheme == expected_scheme assert parsed_scheme == expected_scheme
def test_join_uri_remote():
schemes = ["s3", "az", "gs"]
for scheme in schemes:
expected = f"{scheme}://bucket/path/to/table.lance"
base_uri = f"{scheme}://bucket/path/to/"
parts = ["table.lance"]
assert join_uri(base_uri, *parts) == expected
base_uri = f"{scheme}://bucket"
parts = ["path", "to", "table.lance"]
assert join_uri(base_uri, *parts) == expected
# skip this test if on windows
@pytest.mark.skipif(os.name == "nt", reason="Windows paths are not POSIX")
def test_join_uri_posix():
for base in [
# relative path
"relative/path",
"relative/path/",
# an absolute path
"/absolute/path",
"/absolute/path/",
# a file URI
"file:///absolute/path",
"file:///absolute/path/",
]:
joined = join_uri(base, "table.lance")
assert joined == str(pathlib.Path(base) / "table.lance")
joined = join_uri(pathlib.Path(base), "table.lance")
assert joined == pathlib.Path(base) / "table.lance"
# skip this test if not on windows
@pytest.mark.skipif(os.name != "nt", reason="Windows paths are not POSIX")
def test_local_join_uri_windows():
# https://learn.microsoft.com/en-us/dotnet/standard/io/file-path-formats
for base in [
# windows relative path
"relative\\path",
"relative\\path\\",
# windows absolute path from current drive
"c:\\absolute\\path",
# relative path from root of current drive
"\\relative\\path",
]:
joined = join_uri(base, "table.lance")
assert joined == str(pathlib.Path(base) / "table.lance")
joined = join_uri(pathlib.Path(base), "table.lance")
assert joined == pathlib.Path(base) / "table.lance"