mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-23 15:00:39 +00:00
LanceDB default vector column discovery only considered top-level fields, so tables with a single nested vector leaf still required users to pass an explicit field path. This updates Rust and Python discovery to recurse into struct fields, return canonical field paths, and preserve actionable errors when no default or multiple defaults exist. The explicit nested path flow for index creation and search remains supported across Rust, Python, and Node, with regression coverage for single nested vector leaves, multiple candidate leaves, and schemas without vector leaves. Closes #3405.
497 lines
14 KiB
Python
497 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
|
|
import binascii
|
|
import functools
|
|
import importlib
|
|
import os
|
|
import pathlib
|
|
import warnings
|
|
from datetime import date, datetime
|
|
from functools import singledispatch
|
|
from typing import Tuple, Union, Optional, Any, List
|
|
from urllib.parse import urlparse
|
|
|
|
import numpy as np
|
|
import pyarrow as pa
|
|
import pyarrow.fs as pa_fs
|
|
|
|
from ._lancedb import validate_table_name as native_validate_table_name
|
|
|
|
|
|
def safe_import_adlfs():
    """
    Import the optional ``adlfs`` package if it can be imported.

    Returns
    -------
    module or None: The imported ``adlfs`` module, or None when importing
        it raises ImportError.
    """
    try:
        module = importlib.import_module("adlfs")
    except ImportError:
        module = None
    return module


# Resolved once at import time; None when adlfs cannot be imported.
adlfs = safe_import_adlfs()
|
|
|
|
|
|
def get_uri_scheme(uri: str) -> str:
    """
    Return the scheme of a URI, normalising a few special cases.

    Parameters
    ----------
    uri : str
        The URI to parse.

    Returns
    -------
    str: ``"file"`` when the URI has no scheme or a single-character scheme,
        ``"s3"`` for the ``s3a`` / ``s3n`` aliases, otherwise the scheme as
        parsed by ``urlparse``.
    """
    scheme = urlparse(uri).scheme

    # An empty scheme is a plain path. A one-character scheme is a Windows
    # drive letter, e.g. "c:\path" parses with scheme="c".
    if len(scheme) <= 1:
        return "file"

    if scheme in ("s3a", "s3n"):
        return "s3"

    return scheme
|
|
|
|
|
|
def get_uri_location(uri: str) -> str:
    """
    Return the location part of a URI, i.e. everything after the scheme.

    Parameters
    ----------
    uri : str
        The URI to parse. A value that is not a URL is treated as a path.

    Returns
    -------
    str: The network location followed by the path. The input is returned
        unchanged when its scheme is a single character.
    """
    pieces = urlparse(uri)

    # A one-character scheme is a Windows drive letter ("c:\path" parses
    # with scheme="c"), so the whole input is already the location.
    if len(pieces.scheme) == 1:
        return uri

    # With no network location netloc is empty, so this yields just the path.
    return pieces.netloc + pieces.path
|
|
|
|
|
|
def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
    """
    Build a PyArrow FileSystem, and the path inside it, for a URI.

    ``s3`` URIs get an ``S3FileSystem`` whose endpoint override is read from
    the ``AWS_ENDPOINT`` environment variable. ``az`` URIs are served through
    adlfs, when it was importable, using ``AZURE_STORAGE_ACCOUNT_NAME`` and
    ``AZURE_STORAGE_ACCOUNT_KEY``. Every other URI is handed to
    ``pa_fs.FileSystem.from_uri``.

    Parameters
    ----------
    uri : str
        The URI to open.

    Returns
    -------
    Tuple[pa_fs.FileSystem, str]: The filesystem and the path within it.
    """
    scheme = get_uri_scheme(uri)

    if scheme == "s3":
        s3_fs = pa_fs.S3FileSystem(
            endpoint_override=os.environ.get("AWS_ENDPOINT"),
            request_timeout=30,
            connect_timeout=30,
        )
        return s3_fs, get_uri_location(uri)

    if scheme == "az" and adlfs is not None:
        blob_fs = adlfs.AzureBlobFileSystem(
            account_name=os.environ.get("AZURE_STORAGE_ACCOUNT_NAME"),
            account_key=os.environ.get("AZURE_STORAGE_ACCOUNT_KEY"),
        )
        # Wrap the fsspec filesystem so PyArrow can use it.
        wrapped_fs = pa_fs.PyFileSystem(pa_fs.FSSpecHandler(blob_fs))
        return wrapped_fs, get_uri_location(uri)

    return pa_fs.FileSystem.from_uri(uri)
|
|
|
|
|
|
def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
    """
    Append path segments to a base URI, for both local and remote locations.

    Parameters
    ----------
    base : str or pathlib.Path
        The base URI.
    parts : str
        The segments to append. Local paths are joined with pathlib, so the
        OS separator is used; other schemes are joined with ``/``.

    Returns
    -------
    The joined location. A ``pathlib.Path`` base is joined with
    ``Path.joinpath`` and comes back as a ``Path``; string inputs come back
    as ``str``.
    """
    if isinstance(base, pathlib.Path):
        return base.joinpath(*parts)

    base = str(base)
    if get_uri_scheme(base) != "file":
        # Only the path component is rewritten, so query parameters in the
        # base URI are kept.
        parsed = urlparse(base)
        segments = [parsed.path, *parts]
        joined = "/".join(segment.rstrip("/") for segment in segments)
        return parsed._replace(path=joined).geturl()

    # pathlib keeps local joins Windows compatible; `get_uri_scheme` reports
    # `file` for Windows drive names (e.g. `c:\path`).
    return str(pathlib.Path(base, *parts))
|
|
|
|
|
|
def attempt_import_or_raise(module: str, mitigation=None):
    """
    Import the specified module. If the module cannot be imported,
    raise an ImportError with a helpful message.

    Parameters
    ----------
    module : str
        The name of the module to import
    mitigation : Optional[str]
        The package(s) to install to mitigate the error.
        If not provided then the module name will be used.

    Returns
    -------
    module: The imported module object.

    Raises
    ------
    ImportError
        If the import fails. The original ImportError is attached as the
        cause.
    """
    try:
        return importlib.import_module(module)
    except ImportError as err:
        # Chain explicitly so the traceback shows the original failure as
        # the direct cause rather than as an error inside this handler.
        raise ImportError(f"Please install {mitigation or module}") from err
|
|
|
|
|
|
def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
    """
    Flatten struct columns in a table.

    Parameters
    ----------
    tbl : pa.Table
        The table to flatten.
    flatten: Optional[Union[int, bool]]
        If flatten is True, keep flattening until no struct column is left.
        If flatten is an integer, flatten that many levels.
        If unspecified, the table is returned unchanged.

    Returns
    -------
    pa.Table: The flattened table.

    Raises
    ------
    ValueError
        If flatten is an integer that is not positive. ``False`` is an
        ``int`` equal to 0, so it also takes this branch.
    """
    if flatten is True:
        # Flatten at least once, then repeat while any struct column remains.
        tbl = tbl.flatten()
        while any(pa.types.is_struct(column.type) for column in tbl.schema):
            tbl = tbl.flatten()
        return tbl

    if isinstance(flatten, int):
        if flatten <= 0:
            raise ValueError(
                "Please specify a positive integer for flatten or the boolean "
                "value `True`"
            )
        for _ in range(flatten):
            tbl = tbl.flatten()

    return tbl
|
|
|
|
|
|
def _format_field_path(path: List[str]) -> str:
|
|
def format_segment(segment: str) -> str:
|
|
if all(char.isalnum() or char == "_" for char in segment):
|
|
return segment
|
|
return f"`{segment.replace('`', '``')}`"
|
|
|
|
return ".".join(format_segment(segment) for segment in path)
|
|
|
|
|
|
def _iter_vector_columns(
    field: pa.Field, path: List[str], dim: Optional[int] = None
) -> List[str]:
    """
    Collect the formatted paths of vector columns at or below ``field``.

    Parameters
    ----------
    field : pa.Field
        The field to inspect.
    path : List[str]
        Names of the enclosing fields, outermost first.
    dim : Optional[int]
        When given, only vector columns whose inferred dimension equals
        this value are returned.

    Returns
    -------
    List[str]: Field paths of the matching vector columns. Struct fields are
        searched recursively; any other non-vector field yields nothing.
    """
    dtype = field.type
    qualified = [*path, field.name]

    if is_vector_column(dtype):
        if dim is not None and infer_vector_column_dim(dtype) != dim:
            return []
        return [_format_field_path(qualified)]

    if not pa.types.is_struct(dtype):
        return []

    return [
        column
        for idx in range(dtype.num_fields)
        for column in _iter_vector_columns(dtype.field(idx), qualified, dim)
    ]
|
|
|
|
|
|
def inf_vector_column_query(schema: pa.Schema, dim: Optional[int] = None) -> str:
    """
    Find the single vector column in a schema.

    Parameters
    ----------
    schema : pa.Schema
        The schema to search, including fields nested inside structs.
    dim : Optional[int]
        When given, only vector columns of this dimension are considered.

    Returns
    -------
    str: the field path of the vector column.

    Raises
    ------
    ValueError
        If the schema has no candidate vector column, or more than one.
    """
    candidates = [
        column for field in schema for column in _iter_vector_columns(field, [], dim)
    ]

    if not candidates:
        raise ValueError(
            "There is no vector column in the data. "
            "Please specify the vector column name for vector search"
        )

    if len(candidates) > 1:
        raise ValueError(
            "Schema has more than one vector column. "
            "Please specify the vector column name "
            f"for vector search. Candidates: {candidates}"
        )

    return candidates[0]
|
|
|
|
|
|
def is_vector_column(data_type: pa.DataType) -> bool:
    """
    Tell whether a data type describes a vector column.

    Parameters
    ----------
    data_type : pa.DataType
        The data type of the column.

    Returns
    -------
    bool: True for a fixed-size list of floating point or uint8 values,
        optionally wrapped in any number of variable-length lists.
    """
    # Peel off variable-length list wrappers to reach the innermost type.
    while pa.types.is_list(data_type):
        data_type = data_type.value_type

    if not pa.types.is_fixed_size_list(data_type):
        return False

    element_type = data_type.value_type
    return pa.types.is_floating(element_type) or pa.types.is_uint8(element_type)
|
|
|
|
|
|
def infer_vector_column_dim(data_type: pa.DataType) -> Optional[int]:
    """
    Return the vector dimension encoded in a data type, if any.

    Variable-length list wrappers are unwrapped first. The result is the
    ``list_size`` of the fixed-size list found underneath, or None when the
    innermost type is not a fixed-size list.
    """
    while pa.types.is_list(data_type):
        data_type = data_type.value_type

    if pa.types.is_fixed_size_list(data_type):
        return data_type.list_size
    return None
|
|
|
|
|
|
def _query_vector_dim(query: Optional[Any]) -> Optional[int]:
|
|
if query is None:
|
|
return None
|
|
if isinstance(query, np.ndarray):
|
|
if query.ndim == 0:
|
|
return None
|
|
return query.shape[-1]
|
|
if isinstance(query, list) and query:
|
|
first = query[0]
|
|
if isinstance(first, (list, tuple, np.ndarray)):
|
|
return len(first)
|
|
return len(query)
|
|
return None
|
|
|
|
|
|
def infer_vector_column_name(
    schema: pa.Schema,
    query_type: str,
    query: Optional[Any],  # inferred later in query builder
    vector_column_name: Optional[str],
) -> Optional[str]:
    """
    Decide which vector column a query should use.

    Parameters
    ----------
    schema : pa.Schema
        The schema of the table being queried.
    query_type : str
        The query type. ``"fts"`` never needs a vector column, and
        ``"hybrid"`` always triggers inference.
    query : Optional[Any]
        The query value. Its dimension, when it can be determined, narrows
        the candidate columns.
    vector_column_name : Optional[str]
        An explicit column name. When given it is returned as-is.

    Returns
    -------
    Optional[str]: The explicit name if provided, otherwise the inferred
        column path, or None when no vector column is needed.

    Raises
    ------
    ValueError
        Propagated from ``inf_vector_column_query`` when inference finds no
        candidate column or more than one.
    """
    if vector_column_name is not None:
        return vector_column_name

    if query_type == "fts":
        # FTS queries do not require a vector column
        return None

    if query is not None or query_type == "hybrid":
        # Inference errors propagate unchanged; the previous
        # `except Exception as e: raise e` wrapper did nothing.
        return inf_vector_column_query(schema, dim=_query_vector_dim(query))

    return None
|
|
|
|
|
|
@singledispatch
def value_to_sql(value):
    """
    Render a Python value as a SQL literal.

    Dispatches on the type of ``value``. Handlers are registered below for
    str, bytes, int, float, bool, None, datetime, date, list, dict and
    numpy arrays.

    Raises
    ------
    NotImplementedError
        If no handler is registered for the value's type.
    """
    raise NotImplementedError("SQL conversion is not implemented for this type")


@value_to_sql.register(str)
def _(value: str):
    # Double embedded single quotes so the literal cannot be closed early.
    value = value.replace("'", "''")
    return f"'{value}'"


@value_to_sql.register(bytes)
def _(value: bytes):
    """Convert bytes to a hex string literal.

    See https://datafusion.apache.org/user-guide/sql/data_types.html#binary-types
    """
    return f"X'{binascii.hexlify(value).decode()}'"


@value_to_sql.register(int)
def _(value: int):
    return str(value)


@value_to_sql.register(float)
def _(value: float):
    return str(value)


@value_to_sql.register(bool)
def _(value: bool):
    # bool is a subclass of int; this more specific handler is the one
    # chosen for True/False, giving TRUE / FALSE.
    return str(value).upper()


@value_to_sql.register(type(None))
def _(value: None):
    return "NULL"


@value_to_sql.register(datetime)
def _(value: datetime):
    return f"'{value.isoformat()}'"


@value_to_sql.register(date)
def _(value: date):
    return f"'{value.isoformat()}'"


@value_to_sql.register(list)
def _(value: list):
    return "[" + ", ".join(map(value_to_sql, value)) + "]"


@value_to_sql.register(dict)
def _(value: dict):
    # https://datafusion.apache.org/user-guide/sql/scalar_functions.html#named-struct
    # Keys are rendered as text and sent through the str handler, so a
    # single quote in a key is escaped the same way as in a string value.
    return (
        "named_struct("
        + ", ".join(
            f"{value_to_sql(f'{k}')}, {value_to_sql(v)}" for k, v in value.items()
        )
        + ")"
    )


@value_to_sql.register(np.ndarray)
def _(value: np.ndarray):
    # Convert to nested Python lists and reuse the list handler.
    return value_to_sql(value.tolist())
|
|
|
|
|
|
def deprecated(func):
    """This is a decorator which can be used to mark functions
    as deprecated. It will result in a warning being emitted
    when the function is used.

    The DeprecationWarning is forced to display on every call. The filter
    override is limited to the emission of that one warning, so the
    caller's own warning filters are restored afterwards.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        # catch_warnings restores the previous filter list on exit. The
        # earlier version left a "default" DeprecationWarning filter
        # installed globally after every call.
        with warnings.catch_warnings():
            warnings.simplefilter("always", DeprecationWarning)
            warnings.warn(
                (
                    f"Function {func.__name__} is deprecated and will be "
                    "removed in a future version"
                ),
                category=DeprecationWarning,
                # Attribute the warning to the caller of the wrapper.
                stacklevel=2,
            )
        return func(*args, **kwargs)

    return new_func
|
|
|
|
|
|
def validate_table_name(name: str):
    """
    Check that ``name`` is an acceptable table name.

    The check is delegated to ``native_validate_table_name`` from the
    ``_lancedb`` extension module. Whatever that call raises propagates to
    the caller; nothing is returned.
    """
    native_validate_table_name(name)
|
|
|
|
|
|
def add_note(base_exception: BaseException, note: str):
    """
    Attach a note to an exception.

    Uses the exception's ``add_note`` method when it has one. Otherwise the
    note is appended, on a new line, to the first argument of the exception
    if that argument is a string.

    Parameters
    ----------
    base_exception : BaseException
        The exception to annotate. It is modified in place.
    note : str
        The text to attach.

    Raises
    ------
    ValueError
        If the exception has no ``add_note`` method and its ``args`` are
        empty or do not start with a string.
    """
    if hasattr(base_exception, "add_note"):
        base_exception.add_note(note)
        return

    args = base_exception.args
    # Check for empty args first: indexing args[0] on an exception created
    # with no arguments would raise IndexError instead of the ValueError.
    if args and isinstance(args[0], str):
        base_exception.args = (args[0] + "\n" + note, *args[1:])
    else:
        raise ValueError("Cannot add note to exception")
|
|
|
|
|
|
def tbl_to_tensor(tbl: pa.Table):
    """
    Turn a PyArrow Table into a single PyTorch tensor.

    Each column's only chunk is handed to ``torch.from_dlpack`` and the
    resulting per-column tensors are combined with ``torch.stack``.

    Fails if torch is not installed.
    Fails if any column is more than one chunk.
    Fails if a column's data type is not supported by PyTorch.

    Parameters
    ----------
    tbl : pa.Table
        The table to convert. Columns are read with ``tbl.column(i)`` and
        must expose ``num_chunks`` and ``chunk(0)``.

    Returns
    -------
    torch.Tensor: The tensor containing the columns of the table.
    """
    torch = attempt_import_or_raise("torch", "torch")

    column_tensors = []
    for idx in range(tbl.num_columns):
        column = tbl.column(idx)
        if column.num_chunks > 1:
            raise Exception("Single batch was too large to fit into a one-chunk table")
        column_tensors.append(torch.from_dlpack(column.chunk(0)))

    return torch.stack(column_tensors)
|
|
|
|
|
|
def batch_to_tensor(batch: pa.RecordBatch):
    """
    Turn a PyArrow RecordBatch into a single PyTorch tensor.

    Every column is handed to ``torch.from_dlpack`` and the per-column
    tensors are then combined with ``torch.stack``.

    Fails if torch is not installed.
    Fails if a column's data type is not supported by PyTorch.

    Parameters
    ----------
    batch : pa.RecordBatch
        The record batch to convert to a tensor.

    Returns
    -------
    torch.Tensor: The tensor containing the columns of the record batch.
    """
    torch = attempt_import_or_raise("torch", "torch")
    column_tensors = [torch.from_dlpack(column) for column in batch.columns]
    return torch.stack(column_tensors)
|
|
|
|
|
|
def batch_to_tensor_rows(batch: pa.RecordBatch):
    """
    Turn a PyArrow RecordBatch into a list of PyTorch tensors, one per row.

    Each column is converted with ``to_numpy(zero_copy_only=False)``, the
    columns are combined with ``numpy.column_stack``, and the result is
    wrapped in a tensor that is split along its first dimension.

    Fails if torch or numpy is not installed.
    Fails if a column's data type is not supported by PyTorch.

    Parameters
    ----------
    batch : pa.RecordBatch
        The record batch to convert.

    Returns
    -------
    list: One tensor per row of the batch.
    """
    torch = attempt_import_or_raise("torch", "torch")
    numpy = attempt_import_or_raise("numpy", "numpy")

    arrays = []
    for column in batch.columns:
        arrays.append(column.to_numpy(zero_copy_only=False))

    matrix = numpy.column_stack(arrays)
    return list(torch.tensor(matrix).unbind(dim=0))
|