mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 18:32:55 +00:00
Co-authored-by: Lance Release <lance-dev@lancedb.com> Co-authored-by: Rob Meng <rob.xu.meng@gmail.com> Co-authored-by: Will Jones <willjones127@gmail.com> Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com> Co-authored-by: rmeng <rob@lancedb.com> Co-authored-by: Chang She <chang@lancedb.com> Co-authored-by: Rok Mihevc <rok@mihevc.org>
446 lines
12 KiB
Python
446 lines
12 KiB
Python
# Copyright 2023 LanceDB Developers
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import contextlib
|
|
import importlib
|
|
import logging.config
|
|
import os
|
|
import platform
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
import requests
|
|
import yaml
|
|
|
|
LOGGING_NAME = "lancedb"
|
|
VERBOSE = (
|
|
str(os.getenv("LANCEDB_VERBOSE", True)).lower() == "true"
|
|
) # global verbose mode
|
|
|
|
|
|
def set_logging(name=LOGGING_NAME, verbose=True):
|
|
"""Sets up logging for the given name.
|
|
|
|
Parameters
|
|
----------
|
|
name : str, optional
|
|
The name of the logger. Default is 'lancedb'.
|
|
verbose : bool, optional
|
|
Whether to enable verbose logging. Default is True.
|
|
"""
|
|
|
|
rank = int(os.getenv("RANK", -1)) # rank in world for Multi-GPU trainings
|
|
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
|
|
logging.config.dictConfig(
|
|
{
|
|
"version": 1,
|
|
"disable_existing_loggers": False,
|
|
"formatters": {name: {"format": "%(message)s"}},
|
|
"handlers": {
|
|
name: {
|
|
"class": "logging.StreamHandler",
|
|
"formatter": name,
|
|
"level": level,
|
|
}
|
|
},
|
|
"loggers": {name: {"level": level, "handlers": [name], "propagate": False}},
|
|
}
|
|
)
|
|
|
|
|
|
set_logging(LOGGING_NAME, verbose=VERBOSE)
|
|
LOGGER = logging.getLogger(LOGGING_NAME)
|
|
|
|
|
|
def is_pip_package(filepath: str = __name__) -> bool:
|
|
"""Determines if the file at the given filepath is part of a pip package.
|
|
|
|
Parameters
|
|
----------
|
|
filepath : str, optional
|
|
The filepath to check. Default is the current file.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if the file is part of a pip package, False otherwise.
|
|
"""
|
|
# Get the spec for the module
|
|
spec = importlib.util.find_spec(filepath)
|
|
|
|
# Return whether the spec is not None and the origin is not None (indicating it is a package)
|
|
return spec is not None and spec.origin is not None
|
|
|
|
|
|
def is_pytest_running():
|
|
"""Determines whether pytest is currently running or not.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if pytest is running, False otherwise.
|
|
"""
|
|
return (
|
|
("PYTEST_CURRENT_TEST" in os.environ)
|
|
or ("pytest" in sys.modules)
|
|
or ("pytest" in Path(sys.argv[0]).stem)
|
|
)
|
|
|
|
|
|
def is_github_actions_ci() -> bool:
|
|
"""
|
|
Determine if the current environment is a GitHub Actions CI Python runner.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if the current environment is a GitHub Actions CI Python runner, False otherwise.
|
|
"""
|
|
|
|
return (
|
|
"GITHUB_ACTIONS" in os.environ
|
|
and "RUNNER_OS" in os.environ
|
|
and "RUNNER_TOOL_CACHE" in os.environ
|
|
)
|
|
|
|
|
|
def is_git_dir():
|
|
"""
|
|
Determines whether the current file is part of a git repository.
|
|
If the current file is not part of a git repository, returns None.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if current file is part of a git repository.
|
|
"""
|
|
return get_git_dir() is not None
|
|
|
|
|
|
def is_online() -> bool:
|
|
"""
|
|
Check internet connectivity by attempting to connect to a known online host.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if connection is successful, False otherwise.
|
|
"""
|
|
import socket
|
|
|
|
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
|
|
try:
|
|
test_connection = socket.create_connection(address=(host, 53), timeout=2)
|
|
except (socket.timeout, socket.gaierror, OSError):
|
|
continue
|
|
else:
|
|
# If the connection was successful, close it to avoid a ResourceWarning
|
|
test_connection.close()
|
|
return True
|
|
return False
|
|
|
|
|
|
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
|
|
"""Check if a directory is writeable.
|
|
|
|
Parameters
|
|
----------
|
|
dir_path : Union[str, Path]
|
|
The path to the directory.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if the directory is writeable, False otherwise.
|
|
"""
|
|
return os.access(str(dir_path), os.W_OK)
|
|
|
|
|
|
def is_colab():
|
|
"""Check if the current script is running inside a Google Colab notebook.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if running inside a Colab notebook, False otherwise.
|
|
"""
|
|
return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ
|
|
|
|
|
|
def is_kaggle():
|
|
"""Check if the current script is running inside a Kaggle kernel.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if running inside a Kaggle kernel, False otherwise.
|
|
"""
|
|
return (
|
|
os.environ.get("PWD") == "/kaggle/working"
|
|
and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
|
|
)
|
|
|
|
|
|
def is_jupyter():
|
|
"""Check if the current script is running inside a Jupyter Notebook.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if running inside a Jupyter Notebook, False otherwise.
|
|
"""
|
|
with contextlib.suppress(Exception):
|
|
from IPython import get_ipython
|
|
|
|
return get_ipython() is not None
|
|
return False
|
|
|
|
|
|
def is_docker() -> bool:
|
|
"""Determine if the script is running inside a Docker container.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if the script is running inside a Docker container, False otherwise.
|
|
"""
|
|
file = Path("/proc/self/cgroup")
|
|
if file.exists():
|
|
with open(file) as f:
|
|
return "docker" in f.read()
|
|
else:
|
|
return False
|
|
|
|
|
|
def get_git_dir():
|
|
"""Determine whether the current file is part of a git repository and if so, returns the repository root directory.
|
|
If the current file is not part of a git repository, returns None.
|
|
|
|
Returns
|
|
-------
|
|
Path | None
|
|
Git root directory if found or None if not found.
|
|
"""
|
|
for d in Path(__file__).parents:
|
|
if (d / ".git").is_dir():
|
|
return d
|
|
|
|
|
|
def get_git_origin_url():
|
|
"""Retrieve the origin URL of a git repository.
|
|
|
|
Returns
|
|
-------
|
|
str | None
|
|
The origin URL of the git repository or None if not git directory.
|
|
"""
|
|
if is_git_dir():
|
|
with contextlib.suppress(subprocess.CalledProcessError):
|
|
origin = subprocess.check_output(
|
|
["git", "config", "--get", "remote.origin.url"]
|
|
)
|
|
return origin.decode().strip()
|
|
|
|
|
|
def yaml_save(file="data.yaml", data=None, header=""):
|
|
"""Save YAML data to a file.
|
|
|
|
Parameters
|
|
----------
|
|
file : str, optional
|
|
File name, by default 'data.yaml'.
|
|
data : dict, optional
|
|
Data to save in YAML format, by default None.
|
|
header : str, optional
|
|
YAML header to add, by default "".
|
|
"""
|
|
if data is None:
|
|
data = {}
|
|
file = Path(file)
|
|
if not file.parent.exists():
|
|
# Create parent directories if they don't exist
|
|
file.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Convert Path objects to strings
|
|
for k, v in data.items():
|
|
if isinstance(v, Path):
|
|
data[k] = str(v)
|
|
|
|
# Dump data to file in YAML format
|
|
with open(file, "w", errors="ignore", encoding="utf-8") as f:
|
|
if header:
|
|
f.write(header)
|
|
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
|
|
|
|
|
|
def yaml_load(file="data.yaml", append_filename=False):
|
|
"""
|
|
Load YAML data from a file.
|
|
|
|
Parameters
|
|
----------
|
|
file : str, optional
|
|
File name. Default is 'data.yaml'.
|
|
append_filename : bool, optional
|
|
Add the YAML filename to the YAML dictionary. Default is False.
|
|
|
|
Returns
|
|
-------
|
|
dict
|
|
YAML data and file name.
|
|
"""
|
|
assert Path(file).suffix in (
|
|
".yaml",
|
|
".yml",
|
|
), f"Attempting to load non-YAML file {file} with yaml_load()"
|
|
with open(file, errors="ignore", encoding="utf-8") as f:
|
|
s = f.read() # string
|
|
|
|
# Add YAML filename to dict and return
|
|
data = (
|
|
yaml.safe_load(s) or {}
|
|
) # always return a dict (yaml.safe_load() may return None for empty files)
|
|
if append_filename:
|
|
data["yaml_file"] = str(file)
|
|
return data
|
|
|
|
|
|
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
|
"""
|
|
Pretty prints a YAML file or a YAML-formatted dictionary.
|
|
|
|
Parameters
|
|
----------
|
|
yaml_file : Union[str, Path, dict]
|
|
The file path of the YAML file or a YAML-formatted dictionary.
|
|
|
|
Returns
|
|
-------
|
|
None
|
|
"""
|
|
yaml_dict = (
|
|
yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
|
|
)
|
|
dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
|
|
LOGGER.info(f"Printing '{yaml_file}'\n\n{dump}")
|
|
|
|
|
|
PLATFORMS = [platform.system()]
|
|
if is_colab():
|
|
PLATFORMS.append("Colab")
|
|
if is_kaggle():
|
|
PLATFORMS.append("Kaggle")
|
|
if is_jupyter():
|
|
PLATFORMS.append("Jupyter")
|
|
if is_docker():
|
|
PLATFORMS.append("Docker")
|
|
|
|
PLATFORMS = "|".join(PLATFORMS)
|
|
|
|
|
|
class TryExcept(contextlib.ContextDecorator):
|
|
"""
|
|
TryExcept context manager.
|
|
Usage: @TryExcept() decorator or 'with TryExcept():' context manager.
|
|
"""
|
|
|
|
def __init__(self, msg="", verbose=True):
|
|
"""
|
|
Parameters
|
|
----------
|
|
msg : str, optional
|
|
Custom message to display in case of exception, by default "".
|
|
verbose : bool, optional
|
|
Whether to display the message, by default True.
|
|
"""
|
|
self.msg = msg
|
|
self.verbose = verbose
|
|
|
|
def __enter__(self):
|
|
pass
|
|
|
|
def __exit__(self, exc_type, value, traceback):
|
|
if self.verbose and value:
|
|
LOGGER.info(f"{self.msg}{': ' if self.msg else ''}{value}")
|
|
return True
|
|
|
|
|
|
def threaded_request(
|
|
method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, **kwargs
|
|
):
|
|
"""
|
|
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
|
|
|
Parameters
|
|
----------
|
|
method : str
|
|
The HTTP method to use for the request. Choices are 'post' and 'get'.
|
|
url : str
|
|
The URL to make the request to.
|
|
retry : int, optional
|
|
Number of retries to attempt before giving up, by default 3.
|
|
timeout : int, optional
|
|
Timeout in seconds after which the function will give up retrying, by default 30.
|
|
thread : bool, optional
|
|
Whether to execute the request in a separate daemon thread, by default True.
|
|
code : int, optional
|
|
An identifier for the request, used for logging purposes, by default -1.
|
|
verbose : bool, optional
|
|
A flag to determine whether to print out to console or not, by default True.
|
|
|
|
Returns
|
|
-------
|
|
requests.Response
|
|
The HTTP response object. If the request is executed in a separate thread, returns the thread itself.
|
|
"""
|
|
retry_codes = () # retry only these codes TODO: add codes if needed in future (500, 408)
|
|
|
|
@TryExcept(verbose=verbose)
|
|
def func(method, url, **kwargs):
|
|
"""Make HTTP requests with retries and timeouts, with optional progress tracking."""
|
|
response = None
|
|
t0 = time.time()
|
|
for i in range(retry + 1):
|
|
if (time.time() - t0) > timeout:
|
|
break
|
|
response = requests.request(method, url, **kwargs)
|
|
if response.status_code < 300: # good return codes in the 2xx range
|
|
break
|
|
try:
|
|
m = response.json().get("message", "No JSON message.")
|
|
except AttributeError:
|
|
m = "Unable to read JSON."
|
|
if i == 0:
|
|
if response.status_code in retry_codes:
|
|
m += f" Retrying {retry}x for {timeout}s." if retry else ""
|
|
elif response.status_code == 429: # rate limit
|
|
m = f"Rate limit reached"
|
|
if verbose:
|
|
LOGGER.warning(f"{response.status_code} #{code}")
|
|
if response.status_code not in retry_codes:
|
|
return response
|
|
time.sleep(2**i) # exponential standoff
|
|
return response
|
|
|
|
args = method, url
|
|
if thread:
|
|
return threading.Thread(
|
|
target=func, args=args, kwargs=kwargs, daemon=True
|
|
).start()
|
|
else:
|
|
return func(*args, **kwargs)
|