mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-08 17:00:40 +00:00
Compare commits
3 Commits
python-v0.
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a813ce2f71 | ||
|
|
a898dc81c2 | ||
|
|
de3f8097e7 |
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.28.0-beta.0"
|
current_version = "0.28.0-beta.1"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -4630,7 +4630,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.28.0-beta.0"
|
version = "0.28.0-beta.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ahash",
|
"ahash",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
@@ -4712,7 +4712,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
version = "0.28.0-beta.0"
|
version = "0.28.0-beta.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-buffer",
|
"arrow-buffer",
|
||||||
@@ -4734,7 +4734,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.31.0-beta.0"
|
version = "0.31.0-beta.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-core</artifactId>
|
<artifactId>lancedb-core</artifactId>
|
||||||
<version>0.28.0-beta.0</version>
|
<version>0.28.0-beta.1</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -53,3 +53,18 @@ optional tlsConfig: TlsConfig;
|
|||||||
```ts
|
```ts
|
||||||
optional userAgent: string;
|
optional userAgent: string;
|
||||||
```
|
```
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
|
### userId?
|
||||||
|
|
||||||
|
```ts
|
||||||
|
optional userId: string;
|
||||||
|
```
|
||||||
|
|
||||||
|
User identifier for tracking purposes.
|
||||||
|
|
||||||
|
This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||||
|
It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||||
|
Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||||
|
variable that contains the user ID value.
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.28.0-beta.0</version>
|
<version>0.28.0-beta.1</version>
|
||||||
<relativePath>../pom.xml</relativePath>
|
<relativePath>../pom.xml</relativePath>
|
||||||
</parent>
|
</parent>
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.28.0-beta.0</version>
|
<version>0.28.0-beta.1</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
<name>${project.artifactId}</name>
|
<name>${project.artifactId}</name>
|
||||||
<description>LanceDB Java SDK Parent POM</description>
|
<description>LanceDB Java SDK Parent POM</description>
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
version = "0.28.0-beta.0"
|
version = "0.28.0-beta.1"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
description.workspace = true
|
description.workspace = true
|
||||||
repository.workspace = true
|
repository.workspace = true
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-arm64",
|
"name": "@lancedb/lancedb-darwin-arm64",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.darwin-arm64.node",
|
"main": "lancedb.darwin-arm64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-gnu.node",
|
"main": "lancedb.linux-arm64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-musl.node",
|
"main": "lancedb.linux-arm64-musl.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-gnu.node",
|
"main": "lancedb.linux-x64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-musl.node",
|
"main": "lancedb.linux-x64-musl.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"os": [
|
"os": [
|
||||||
"win32"
|
"win32"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"os": ["win32"],
|
"os": ["win32"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.win32-x64-msvc.node",
|
"main": "lancedb.win32-x64-msvc.node",
|
||||||
|
|||||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
"ann"
|
"ann"
|
||||||
],
|
],
|
||||||
"private": false,
|
"private": false,
|
||||||
"version": "0.28.0-beta.0",
|
"version": "0.28.0-beta.1",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"exports": {
|
"exports": {
|
||||||
".": "./dist/index.js",
|
".": "./dist/index.js",
|
||||||
|
|||||||
@@ -92,6 +92,13 @@ pub struct ClientConfig {
|
|||||||
pub extra_headers: Option<HashMap<String, String>>,
|
pub extra_headers: Option<HashMap<String, String>>,
|
||||||
pub id_delimiter: Option<String>,
|
pub id_delimiter: Option<String>,
|
||||||
pub tls_config: Option<TlsConfig>,
|
pub tls_config: Option<TlsConfig>,
|
||||||
|
/// User identifier for tracking purposes.
|
||||||
|
///
|
||||||
|
/// This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||||
|
/// It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||||
|
/// Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||||
|
/// variable that contains the user ID value.
|
||||||
|
pub user_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
|
impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
|
||||||
@@ -145,6 +152,7 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
|||||||
id_delimiter: config.id_delimiter,
|
id_delimiter: config.id_delimiter,
|
||||||
tls_config: config.tls_config.map(Into::into),
|
tls_config: config.tls_config.map(Into::into),
|
||||||
header_provider: None, // the header provider is set separately later
|
header_provider: None, // the header provider is set separately later
|
||||||
|
user_id: config.user_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -145,6 +145,33 @@ class TlsConfig:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClientConfig:
|
class ClientConfig:
|
||||||
|
"""Configuration for the LanceDB Cloud HTTP client.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
user_agent: str
|
||||||
|
User agent string sent with requests.
|
||||||
|
retry_config: RetryConfig
|
||||||
|
Configuration for retrying failed requests.
|
||||||
|
timeout_config: Optional[TimeoutConfig]
|
||||||
|
Configuration for request timeouts.
|
||||||
|
extra_headers: Optional[dict]
|
||||||
|
Additional headers to include in requests.
|
||||||
|
id_delimiter: Optional[str]
|
||||||
|
The delimiter to use when constructing object identifiers.
|
||||||
|
tls_config: Optional[TlsConfig]
|
||||||
|
TLS/mTLS configuration for secure connections.
|
||||||
|
header_provider: Optional[HeaderProvider]
|
||||||
|
Provider for dynamic headers to be added to each request.
|
||||||
|
user_id: Optional[str]
|
||||||
|
User identifier for tracking purposes. This is sent as the
|
||||||
|
`x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||||
|
|
||||||
|
This can also be set via the `LANCEDB_USER_ID` environment variable.
|
||||||
|
Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another
|
||||||
|
environment variable that contains the user ID value.
|
||||||
|
"""
|
||||||
|
|
||||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||||
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
||||||
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
||||||
@@ -152,6 +179,7 @@ class ClientConfig:
|
|||||||
id_delimiter: Optional[str] = None
|
id_delimiter: Optional[str] = None
|
||||||
tls_config: Optional[TlsConfig] = None
|
tls_config: Optional[TlsConfig] = None
|
||||||
header_provider: Optional["HeaderProvider"] = None
|
header_provider: Optional["HeaderProvider"] = None
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if isinstance(self.retry_config, dict):
|
if isinstance(self.retry_config, dict):
|
||||||
|
|||||||
@@ -270,15 +270,17 @@ def _sanitize_data(
|
|||||||
reader,
|
reader,
|
||||||
on_bad_vectors=on_bad_vectors,
|
on_bad_vectors=on_bad_vectors,
|
||||||
fill_value=fill_value,
|
fill_value=fill_value,
|
||||||
|
target_schema=target_schema,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if target_schema is None:
|
if target_schema is None:
|
||||||
target_schema, reader = _infer_target_schema(reader)
|
target_schema, reader = _infer_target_schema(reader)
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
new_metadata = target_schema.metadata or {}
|
target_schema = target_schema.with_metadata(
|
||||||
new_metadata.update(metadata)
|
_merge_metadata(target_schema.metadata, metadata)
|
||||||
target_schema = target_schema.with_metadata(new_metadata)
|
)
|
||||||
|
|
||||||
_validate_schema(target_schema)
|
_validate_schema(target_schema)
|
||||||
reader = _cast_to_target_schema(reader, target_schema, allow_subschema)
|
reader = _cast_to_target_schema(reader, target_schema, allow_subschema)
|
||||||
@@ -294,7 +296,7 @@ def _cast_to_target_schema(
|
|||||||
# pa.Table.cast expects field order not to be changed.
|
# pa.Table.cast expects field order not to be changed.
|
||||||
# Lance doesn't care about field order, so we don't need to rearrange fields
|
# Lance doesn't care about field order, so we don't need to rearrange fields
|
||||||
# to match the target schema. We just need to correctly cast the fields.
|
# to match the target schema. We just need to correctly cast the fields.
|
||||||
if reader.schema == target_schema:
|
if reader.schema.equals(target_schema, check_metadata=True):
|
||||||
# Fast path when the schemas are already the same
|
# Fast path when the schemas are already the same
|
||||||
return reader
|
return reader
|
||||||
|
|
||||||
@@ -314,7 +316,13 @@ def _cast_to_target_schema(
|
|||||||
def gen():
|
def gen():
|
||||||
for batch in reader:
|
for batch in reader:
|
||||||
# Table but not RecordBatch has cast.
|
# Table but not RecordBatch has cast.
|
||||||
yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0]
|
cast_batches = (
|
||||||
|
pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()
|
||||||
|
)
|
||||||
|
if cast_batches:
|
||||||
|
yield pa.RecordBatch.from_arrays(
|
||||||
|
cast_batches[0].columns, schema=reordered_schema
|
||||||
|
)
|
||||||
|
|
||||||
return pa.RecordBatchReader.from_batches(reordered_schema, gen())
|
return pa.RecordBatchReader.from_batches(reordered_schema, gen())
|
||||||
|
|
||||||
@@ -332,37 +340,51 @@ def _align_field_types(
|
|||||||
if target_field is None:
|
if target_field is None:
|
||||||
raise ValueError(f"Field '{field.name}' not found in target schema")
|
raise ValueError(f"Field '{field.name}' not found in target schema")
|
||||||
if pa.types.is_struct(target_field.type):
|
if pa.types.is_struct(target_field.type):
|
||||||
new_type = pa.struct(
|
if pa.types.is_struct(field.type):
|
||||||
_align_field_types(
|
new_type = pa.struct(
|
||||||
field.type.fields,
|
_align_field_types(
|
||||||
target_field.type.fields,
|
field.type.fields,
|
||||||
|
target_field.type.fields,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
|
new_type = target_field.type
|
||||||
elif pa.types.is_list(target_field.type):
|
elif pa.types.is_list(target_field.type):
|
||||||
new_type = pa.list_(
|
if _is_list_like(field.type):
|
||||||
_align_field_types(
|
new_type = pa.list_(
|
||||||
[field.type.value_field],
|
_align_field_types(
|
||||||
[target_field.type.value_field],
|
[field.type.value_field],
|
||||||
)[0]
|
[target_field.type.value_field],
|
||||||
)
|
)[0]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_type = target_field.type
|
||||||
elif pa.types.is_large_list(target_field.type):
|
elif pa.types.is_large_list(target_field.type):
|
||||||
new_type = pa.large_list(
|
if _is_list_like(field.type):
|
||||||
_align_field_types(
|
new_type = pa.large_list(
|
||||||
[field.type.value_field],
|
_align_field_types(
|
||||||
[target_field.type.value_field],
|
[field.type.value_field],
|
||||||
)[0]
|
[target_field.type.value_field],
|
||||||
)
|
)[0]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_type = target_field.type
|
||||||
elif pa.types.is_fixed_size_list(target_field.type):
|
elif pa.types.is_fixed_size_list(target_field.type):
|
||||||
new_type = pa.list_(
|
if _is_list_like(field.type):
|
||||||
_align_field_types(
|
new_type = pa.list_(
|
||||||
[field.type.value_field],
|
_align_field_types(
|
||||||
[target_field.type.value_field],
|
[field.type.value_field],
|
||||||
)[0],
|
[target_field.type.value_field],
|
||||||
target_field.type.list_size,
|
)[0],
|
||||||
)
|
target_field.type.list_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_type = target_field.type
|
||||||
else:
|
else:
|
||||||
new_type = target_field.type
|
new_type = target_field.type
|
||||||
new_fields.append(pa.field(field.name, new_type, field.nullable))
|
new_fields.append(
|
||||||
|
pa.field(field.name, new_type, field.nullable, target_field.metadata)
|
||||||
|
)
|
||||||
return new_fields
|
return new_fields
|
||||||
|
|
||||||
|
|
||||||
@@ -440,6 +462,7 @@ def sanitize_create_table(
|
|||||||
schema = data.schema
|
schema = data.schema
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
|
metadata = _merge_metadata(schema.metadata, metadata)
|
||||||
schema = schema.with_metadata(metadata)
|
schema = schema.with_metadata(metadata)
|
||||||
# Need to apply metadata to the data as well
|
# Need to apply metadata to the data as well
|
||||||
if isinstance(data, pa.Table):
|
if isinstance(data, pa.Table):
|
||||||
@@ -492,9 +515,9 @@ def _append_vector_columns(
|
|||||||
vector columns to the table.
|
vector columns to the table.
|
||||||
"""
|
"""
|
||||||
if schema is None:
|
if schema is None:
|
||||||
metadata = metadata or {}
|
metadata = _merge_metadata(metadata)
|
||||||
else:
|
else:
|
||||||
metadata = schema.metadata or metadata or {}
|
metadata = _merge_metadata(schema.metadata, metadata)
|
||||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||||
|
|
||||||
if not functions:
|
if not functions:
|
||||||
@@ -3211,43 +3234,157 @@ def _handle_bad_vectors(
|
|||||||
reader: pa.RecordBatchReader,
|
reader: pa.RecordBatchReader,
|
||||||
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
|
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
|
target_schema: Optional[pa.Schema] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
) -> pa.RecordBatchReader:
|
) -> pa.RecordBatchReader:
|
||||||
vector_columns = []
|
vector_columns = _find_vector_columns(reader.schema, target_schema, metadata)
|
||||||
|
if not vector_columns:
|
||||||
|
return reader
|
||||||
|
|
||||||
for field in reader.schema:
|
output_schema = _vector_output_schema(reader.schema, vector_columns)
|
||||||
# They can provide a 'vector' column that isn't yet a FSL
|
|
||||||
named_vector_col = (
|
|
||||||
(
|
|
||||||
pa.types.is_list(field.type)
|
|
||||||
or pa.types.is_large_list(field.type)
|
|
||||||
or pa.types.is_fixed_size_list(field.type)
|
|
||||||
)
|
|
||||||
and pa.types.is_floating(field.type.value_type)
|
|
||||||
and field.name == VECTOR_COLUMN_NAME
|
|
||||||
)
|
|
||||||
# TODO: we're making an assumption that fixed size list of 10 or more
|
|
||||||
# is a vector column. This is definitely a bit hacky.
|
|
||||||
likely_vector_col = (
|
|
||||||
pa.types.is_fixed_size_list(field.type)
|
|
||||||
and pa.types.is_floating(field.type.value_type)
|
|
||||||
and (field.type.list_size >= 10)
|
|
||||||
)
|
|
||||||
|
|
||||||
if named_vector_col or likely_vector_col:
|
|
||||||
vector_columns.append(field.name)
|
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for batch in reader:
|
for batch in reader:
|
||||||
for name in vector_columns:
|
pending_dims = []
|
||||||
|
for vector_column in vector_columns:
|
||||||
|
dim = vector_column["expected_dim"]
|
||||||
|
if target_schema is not None and dim is None:
|
||||||
|
dim = _infer_vector_dim(batch[vector_column["name"]])
|
||||||
|
pending_dims.append(vector_column)
|
||||||
batch = _handle_bad_vector_column(
|
batch = _handle_bad_vector_column(
|
||||||
batch,
|
batch,
|
||||||
vector_column_name=name,
|
vector_column_name=vector_column["name"],
|
||||||
on_bad_vectors=on_bad_vectors,
|
on_bad_vectors=on_bad_vectors,
|
||||||
fill_value=fill_value,
|
fill_value=fill_value,
|
||||||
|
expected_dim=dim,
|
||||||
|
expected_value_type=vector_column["expected_value_type"],
|
||||||
)
|
)
|
||||||
yield batch
|
for vector_column in pending_dims:
|
||||||
|
if vector_column["expected_dim"] is None:
|
||||||
|
vector_column["expected_dim"] = _infer_vector_dim(
|
||||||
|
batch[vector_column["name"]]
|
||||||
|
)
|
||||||
|
if batch.schema.equals(output_schema, check_metadata=True):
|
||||||
|
yield batch
|
||||||
|
continue
|
||||||
|
|
||||||
return pa.RecordBatchReader.from_batches(reader.schema, gen())
|
cast_batches = (
|
||||||
|
pa.Table.from_batches([batch]).cast(output_schema).to_batches()
|
||||||
|
)
|
||||||
|
if cast_batches:
|
||||||
|
yield pa.RecordBatch.from_arrays(
|
||||||
|
cast_batches[0].columns,
|
||||||
|
schema=output_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pa.RecordBatchReader.from_batches(output_schema, gen())
|
||||||
|
|
||||||
|
|
||||||
|
def _find_vector_columns(
|
||||||
|
reader_schema: pa.Schema,
|
||||||
|
target_schema: Optional[pa.Schema],
|
||||||
|
metadata: Optional[dict],
|
||||||
|
) -> List[dict]:
|
||||||
|
if target_schema is None:
|
||||||
|
vector_columns = []
|
||||||
|
for field in reader_schema:
|
||||||
|
named_vector_col = (
|
||||||
|
_is_list_like(field.type)
|
||||||
|
and pa.types.is_floating(field.type.value_type)
|
||||||
|
and field.name == VECTOR_COLUMN_NAME
|
||||||
|
)
|
||||||
|
likely_vector_col = (
|
||||||
|
pa.types.is_fixed_size_list(field.type)
|
||||||
|
and pa.types.is_floating(field.type.value_type)
|
||||||
|
and (field.type.list_size >= 10)
|
||||||
|
)
|
||||||
|
if named_vector_col or likely_vector_col:
|
||||||
|
vector_columns.append(
|
||||||
|
{
|
||||||
|
"name": field.name,
|
||||||
|
"expected_dim": None,
|
||||||
|
"expected_value_type": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return vector_columns
|
||||||
|
|
||||||
|
reader_column_names = set(reader_schema.names)
|
||||||
|
active_metadata = _merge_metadata(target_schema.metadata, metadata)
|
||||||
|
embedding_function_columns = set(
|
||||||
|
EmbeddingFunctionRegistry.get_instance().parse_functions(active_metadata).keys()
|
||||||
|
)
|
||||||
|
vector_columns = []
|
||||||
|
for field in target_schema:
|
||||||
|
if field.name not in reader_column_names:
|
||||||
|
continue
|
||||||
|
if not _is_list_like(field.type) or not pa.types.is_floating(
|
||||||
|
field.type.value_type
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
reader_field = reader_schema.field(field.name)
|
||||||
|
named_vector_col = (
|
||||||
|
field.name in embedding_function_columns
|
||||||
|
or field.name == VECTOR_COLUMN_NAME
|
||||||
|
or (field.name == "embedding" and pa.types.is_fixed_size_list(field.type))
|
||||||
|
)
|
||||||
|
typed_fixed_vector_col = (
|
||||||
|
pa.types.is_fixed_size_list(reader_field.type)
|
||||||
|
and pa.types.is_floating(reader_field.type.value_type)
|
||||||
|
and reader_field.type.list_size >= 10
|
||||||
|
)
|
||||||
|
|
||||||
|
if named_vector_col or typed_fixed_vector_col:
|
||||||
|
vector_columns.append(
|
||||||
|
{
|
||||||
|
"name": field.name,
|
||||||
|
"expected_dim": (
|
||||||
|
field.type.list_size
|
||||||
|
if pa.types.is_fixed_size_list(field.type)
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"expected_value_type": field.type.value_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return vector_columns
|
||||||
|
|
||||||
|
|
||||||
|
def _vector_output_schema(
|
||||||
|
reader_schema: pa.Schema,
|
||||||
|
vector_columns: List[dict],
|
||||||
|
) -> pa.Schema:
|
||||||
|
columns_by_name = {column["name"]: column for column in vector_columns}
|
||||||
|
fields = []
|
||||||
|
for field in reader_schema:
|
||||||
|
column = columns_by_name.get(field.name)
|
||||||
|
if column is None:
|
||||||
|
output_type = field.type
|
||||||
|
else:
|
||||||
|
output_type = _vector_output_type(field, column)
|
||||||
|
fields.append(pa.field(field.name, output_type, field.nullable, field.metadata))
|
||||||
|
return pa.schema(fields, metadata=reader_schema.metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def _vector_output_type(field: pa.Field, vector_column: dict) -> pa.DataType:
|
||||||
|
if not _is_list_like(field.type):
|
||||||
|
return field.type
|
||||||
|
|
||||||
|
if vector_column["expected_value_type"] is not None and (
|
||||||
|
pa.types.is_null(field.type.value_type)
|
||||||
|
or pa.types.is_integer(field.type.value_type)
|
||||||
|
or pa.types.is_unsigned_integer(field.type.value_type)
|
||||||
|
):
|
||||||
|
return pa.list_(vector_column["expected_value_type"])
|
||||||
|
|
||||||
|
if (
|
||||||
|
vector_column["expected_dim"] is not None
|
||||||
|
and pa.types.is_fixed_size_list(field.type)
|
||||||
|
and field.type.list_size != vector_column["expected_dim"]
|
||||||
|
):
|
||||||
|
return pa.list_(field.type.value_type)
|
||||||
|
|
||||||
|
return field.type
|
||||||
|
|
||||||
|
|
||||||
def _handle_bad_vector_column(
|
def _handle_bad_vector_column(
|
||||||
@@ -3255,6 +3392,8 @@ def _handle_bad_vector_column(
|
|||||||
vector_column_name: str,
|
vector_column_name: str,
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
|
expected_dim: Optional[int] = None,
|
||||||
|
expected_value_type: Optional[pa.DataType] = None,
|
||||||
) -> pa.RecordBatch:
|
) -> pa.RecordBatch:
|
||||||
"""
|
"""
|
||||||
Ensure that the vector column exists and has type fixed_size_list(float)
|
Ensure that the vector column exists and has type fixed_size_list(float)
|
||||||
@@ -3271,14 +3410,39 @@ def _handle_bad_vector_column(
|
|||||||
fill_value: float, default 0.0
|
fill_value: float, default 0.0
|
||||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
"""
|
"""
|
||||||
|
position = data.column_names.index(vector_column_name)
|
||||||
vec_arr = data[vector_column_name]
|
vec_arr = data[vector_column_name]
|
||||||
|
if not _is_list_like(vec_arr.type):
|
||||||
|
return data
|
||||||
|
|
||||||
has_nan = has_nan_values(vec_arr)
|
if (
|
||||||
|
expected_dim is not None
|
||||||
|
and pa.types.is_fixed_size_list(vec_arr.type)
|
||||||
|
and vec_arr.type.list_size != expected_dim
|
||||||
|
):
|
||||||
|
vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(vec_arr.type.value_type))
|
||||||
|
data = data.set_column(position, vector_column_name, vec_arr)
|
||||||
|
|
||||||
if pa.types.is_fixed_size_list(vec_arr.type):
|
if expected_value_type is not None and (
|
||||||
|
pa.types.is_integer(vec_arr.type.value_type)
|
||||||
|
or pa.types.is_unsigned_integer(vec_arr.type.value_type)
|
||||||
|
):
|
||||||
|
vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(expected_value_type))
|
||||||
|
data = data.set_column(position, vector_column_name, vec_arr)
|
||||||
|
|
||||||
|
if pa.types.is_floating(vec_arr.type.value_type):
|
||||||
|
has_nan = has_nan_values(vec_arr)
|
||||||
|
else:
|
||||||
|
has_nan = pa.array([False] * len(vec_arr))
|
||||||
|
|
||||||
|
if expected_dim is not None:
|
||||||
|
dim = expected_dim
|
||||||
|
elif pa.types.is_fixed_size_list(vec_arr.type):
|
||||||
dim = vec_arr.type.list_size
|
dim = vec_arr.type.list_size
|
||||||
else:
|
else:
|
||||||
dim = _modal_list_size(vec_arr)
|
dim = _infer_vector_dim(vec_arr)
|
||||||
|
if dim is None:
|
||||||
|
return data
|
||||||
has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
|
has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
|
||||||
|
|
||||||
has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
|
has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
|
||||||
@@ -3316,13 +3480,12 @@ def _handle_bad_vector_column(
|
|||||||
)
|
)
|
||||||
vec_arr = pc.if_else(
|
vec_arr = pc.if_else(
|
||||||
is_bad,
|
is_bad,
|
||||||
pa.scalar([fill_value] * dim),
|
pa.scalar([fill_value] * dim, type=vec_arr.type),
|
||||||
vec_arr,
|
vec_arr,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
|
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
|
||||||
|
|
||||||
position = data.column_names.index(vector_column_name)
|
|
||||||
return data.set_column(position, vector_column_name, vec_arr)
|
return data.set_column(position, vector_column_name, vec_arr)
|
||||||
|
|
||||||
|
|
||||||
@@ -3343,6 +3506,28 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray
|
|||||||
return pc.is_in(indices, has_nan_indices)
|
return pc.is_in(indices, has_nan_indices)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_list_like(data_type: pa.DataType) -> bool:
|
||||||
|
return (
|
||||||
|
pa.types.is_list(data_type)
|
||||||
|
or pa.types.is_large_list(data_type)
|
||||||
|
or pa.types.is_fixed_size_list(data_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_metadata(*metadata_dicts: Optional[dict]) -> dict:
|
||||||
|
merged = {}
|
||||||
|
for metadata in metadata_dicts:
|
||||||
|
if metadata is None:
|
||||||
|
continue
|
||||||
|
for key, value in metadata.items():
|
||||||
|
if isinstance(key, str):
|
||||||
|
key = key.encode("utf-8")
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = value.encode("utf-8")
|
||||||
|
merged[key] = value
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
def _name_suggests_vector_column(field_name: str) -> bool:
|
def _name_suggests_vector_column(field_name: str) -> bool:
|
||||||
"""Check if a field name indicates a vector column."""
|
"""Check if a field name indicates a vector column."""
|
||||||
name_lower = field_name.lower()
|
name_lower = field_name.lower()
|
||||||
@@ -3410,6 +3595,16 @@ def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
|
|||||||
return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
|
return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_vector_dim(arr: Union[pa.Array, pa.ChunkedArray]) -> Optional[int]:
|
||||||
|
if not _is_list_like(arr.type):
|
||||||
|
return None
|
||||||
|
lengths = pc.list_value_length(arr)
|
||||||
|
lengths = pc.filter(lengths, pc.greater(lengths, 0))
|
||||||
|
if len(lengths) == 0:
|
||||||
|
return None
|
||||||
|
return pc.mode(lengths)[0].as_py()["mode"]
|
||||||
|
|
||||||
|
|
||||||
def _validate_schema(schema: pa.Schema):
|
def _validate_schema(schema: pa.Schema):
|
||||||
"""
|
"""
|
||||||
Make sure the metadata is valid utf8
|
Make sure the metadata is valid utf8
|
||||||
|
|||||||
@@ -1049,6 +1049,231 @@ def test_add_with_nans(mem_db: DBConnection):
|
|||||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
|
||||||
|
class Schema(LanceModel):
|
||||||
|
text: str
|
||||||
|
embedding: Vector(16)
|
||||||
|
|
||||||
|
table = mem_db.create_table("test_empty_embeddings", schema=Schema)
|
||||||
|
table.add(
|
||||||
|
[
|
||||||
|
{"text": "hello", "embedding": []},
|
||||||
|
{"text": "bar", "embedding": [0.1] * 16},
|
||||||
|
],
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = table.to_arrow()
|
||||||
|
assert data["text"].to_pylist() == ["bar"]
|
||||||
|
assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16))
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection):
|
||||||
|
class Schema(LanceModel):
|
||||||
|
text: str
|
||||||
|
embedding: Vector(4)
|
||||||
|
|
||||||
|
table = mem_db.create_table("test_integer_embeddings", schema=Schema)
|
||||||
|
table.add(
|
||||||
|
[{"text": "foo", "embedding": [1, 2, 3, 4]}],
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert table.to_arrow()["embedding"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_bad_vectors_does_not_handle_non_vector_fixed_size_lists(
|
||||||
|
mem_db: DBConnection,
|
||||||
|
):
|
||||||
|
schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("vector", pa.list_(pa.float32(), 4)),
|
||||||
|
pa.field("bbox", pa.list_(pa.float32(), 4)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
table = mem_db.create_table("test_bbox_schema", schema=schema)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="FixedSizeListType"):
|
||||||
|
table.add(
|
||||||
|
[{"vector": [1.0, 2.0, 3.0, 4.0], "bbox": [0.0, 1.0]}],
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_bad_vectors_does_not_handle_custom_named_fixed_size_lists(
|
||||||
|
mem_db: DBConnection,
|
||||||
|
):
|
||||||
|
schema = pa.schema([pa.field("features", pa.list_(pa.float32(), 16))])
|
||||||
|
table = mem_db.create_table("test_custom_named_fixed_size_vector", schema=schema)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="FixedSizeListType"):
|
||||||
|
table.add(
|
||||||
|
[
|
||||||
|
{"features": []},
|
||||||
|
{"features": [0.1] * 16},
|
||||||
|
],
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_bad_vectors_with_schema_list_vector_still_sanitizes(mem_db: DBConnection):
|
||||||
|
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||||
|
table = mem_db.create_table("test_schema_list_vector", schema=schema)
|
||||||
|
table.add(
|
||||||
|
[
|
||||||
|
{"vector": [1.0, 2.0]},
|
||||||
|
{"vector": [3.0]},
|
||||||
|
{"vector": [4.0, 5.0]},
|
||||||
|
],
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [4.0, 5.0]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_bad_vectors_handles_typed_custom_fixed_vectors_for_list_schema(
|
||||||
|
mem_db: DBConnection,
|
||||||
|
):
|
||||||
|
schema = pa.schema([pa.field("vec", pa.list_(pa.float32()))])
|
||||||
|
table = mem_db.create_table("test_typed_custom_fixed_vector", schema=schema)
|
||||||
|
data = pa.table(
|
||||||
|
{
|
||||||
|
"vec": pa.array(
|
||||||
|
[[float("nan")] * 16, [1.0] * 16],
|
||||||
|
type=pa.list_(pa.float32(), 16),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
table.add(data, on_bad_vectors="drop")
|
||||||
|
|
||||||
|
assert table.to_arrow()["vec"].to_pylist() == [[1.0] * 16]
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_bad_vectors_fill_preserves_arrow_nested_vector_type(mem_db: DBConnection):
|
||||||
|
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||||
|
table = mem_db.create_table("test_fill_arrow_nested_type", schema=schema)
|
||||||
|
data = pa.table(
|
||||||
|
{
|
||||||
|
"vector": pa.array(
|
||||||
|
[[1.0, 2.0], [float("nan"), 3.0]],
|
||||||
|
type=pa.list_(pa.float32(), 2),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
table.add(
|
||||||
|
data,
|
||||||
|
on_bad_vectors="fill",
|
||||||
|
fill_value=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [0.0, 0.0]]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("table_name", "batch1", "expected"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"test_schema_list_vector_empty_prefix",
|
||||||
|
pa.record_batch({"vector": [[], []]}),
|
||||||
|
[[], [], [1.0, 2.0], [3.0, 4.0]],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"test_schema_list_vector_all_bad_prefix",
|
||||||
|
pa.record_batch({"vector": [[float("nan")] * 3, [float("nan")] * 3]}),
|
||||||
|
[[1.0, 2.0], [3.0, 4.0]],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_on_bad_vectors_with_schema_list_vector_ignores_invalid_prefix_batches(
|
||||||
|
mem_db: DBConnection,
|
||||||
|
table_name: str,
|
||||||
|
batch1: pa.RecordBatch,
|
||||||
|
expected: list,
|
||||||
|
):
|
||||||
|
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||||
|
table = mem_db.create_table(table_name, schema=schema)
|
||||||
|
batch2 = pa.record_batch({"vector": [[1.0, 2.0], [3.0, 4.0]]})
|
||||||
|
reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2])
|
||||||
|
|
||||||
|
table.add(reader, on_bad_vectors="drop")
|
||||||
|
|
||||||
|
assert table.to_arrow()["vector"].to_pylist() == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_bad_vectors_with_multiple_vectors_locks_dim_after_final_drop(
|
||||||
|
mem_db: DBConnection,
|
||||||
|
):
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
func = MockTextEmbeddingFunction.create()
|
||||||
|
metadata = registry.get_table_metadata(
|
||||||
|
[
|
||||||
|
EmbeddingFunctionConfig(
|
||||||
|
source_column="text1", vector_column="vec1", function=func
|
||||||
|
),
|
||||||
|
EmbeddingFunctionConfig(
|
||||||
|
source_column="text2", vector_column="vec2", function=func
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("vec1", pa.list_(pa.float32())),
|
||||||
|
pa.field("vec2", pa.list_(pa.float32())),
|
||||||
|
],
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
table = mem_db.create_table("test_multi_vector_dim_lock", schema=schema)
|
||||||
|
batch1 = pa.record_batch(
|
||||||
|
{
|
||||||
|
"vec1": [[1.0, 2.0, 3.0], [10.0, 11.0]],
|
||||||
|
"vec2": [[float("nan"), 0.0], [5.0, 6.0]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
batch2 = pa.record_batch(
|
||||||
|
{
|
||||||
|
"vec1": [[20.0, 21.0], [30.0, 31.0]],
|
||||||
|
"vec2": [[7.0, 8.0], [9.0, 10.0]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2])
|
||||||
|
|
||||||
|
table.add(reader, on_bad_vectors="drop")
|
||||||
|
|
||||||
|
data = table.to_arrow()
|
||||||
|
assert data["vec1"].to_pylist() == [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]]
|
||||||
|
assert data["vec2"].to_pylist() == [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_bad_vectors_does_not_handle_non_vector_list_columns(mem_db: DBConnection):
|
||||||
|
schema = pa.schema([pa.field("embedding_history", pa.list_(pa.float32()))])
|
||||||
|
table = mem_db.create_table("test_non_vector_list_schema", schema=schema)
|
||||||
|
table.add(
|
||||||
|
[
|
||||||
|
{"embedding_history": [1.0, 2.0]},
|
||||||
|
{"embedding_history": [3.0]},
|
||||||
|
],
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert table.to_arrow()["embedding_history"].to_pylist() == [
|
||||||
|
[1.0, 2.0],
|
||||||
|
[3.0],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_bad_vectors_all_null_schema_vector_batches_do_not_crash(
|
||||||
|
mem_db: DBConnection,
|
||||||
|
):
|
||||||
|
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2), nullable=True)])
|
||||||
|
table = mem_db.create_table("test_all_null_vector_batch", schema=schema)
|
||||||
|
|
||||||
|
table.add([{"vector": None}], on_bad_vectors="drop")
|
||||||
|
|
||||||
|
assert table.to_arrow()["vector"].to_pylist() == [None]
|
||||||
|
|
||||||
|
|
||||||
def test_restore(mem_db: DBConnection):
|
def test_restore(mem_db: DBConnection):
|
||||||
table = mem_db.create_table(
|
table = mem_db.create_table(
|
||||||
"my_table",
|
"my_table",
|
||||||
|
|||||||
@@ -15,8 +15,10 @@ from lancedb.table import (
|
|||||||
_cast_to_target_schema,
|
_cast_to_target_schema,
|
||||||
_handle_bad_vectors,
|
_handle_bad_vectors,
|
||||||
_into_pyarrow_reader,
|
_into_pyarrow_reader,
|
||||||
_sanitize_data,
|
|
||||||
_infer_target_schema,
|
_infer_target_schema,
|
||||||
|
_merge_metadata,
|
||||||
|
_sanitize_data,
|
||||||
|
sanitize_create_table,
|
||||||
)
|
)
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -304,6 +306,117 @@ def test_handle_bad_vectors_noop():
|
|||||||
assert output["vector"] == vector
|
assert output["vector"] == vector
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_bad_vectors_updates_reader_schema_for_target_schema():
|
||||||
|
data = pa.table({"vector": [[1, 2, 3, 4]]})
|
||||||
|
target_schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 4))])
|
||||||
|
|
||||||
|
output = _handle_bad_vectors(
|
||||||
|
data.to_reader(),
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
target_schema=target_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert output.schema == pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||||
|
assert output.read_all()["vector"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_data_keeps_target_field_metadata():
|
||||||
|
source_field = pa.field(
|
||||||
|
"vector",
|
||||||
|
pa.list_(pa.float32(), 2),
|
||||||
|
metadata={b"source": b"drop-me"},
|
||||||
|
)
|
||||||
|
target_field = pa.field(
|
||||||
|
"vector",
|
||||||
|
pa.list_(pa.float32(), 2),
|
||||||
|
metadata={b"target": b"keep-me"},
|
||||||
|
)
|
||||||
|
data = pa.table(
|
||||||
|
{"vector": pa.array([[1.0, 2.0]], type=pa.list_(pa.float32(), 2))},
|
||||||
|
schema=pa.schema([source_field]),
|
||||||
|
)
|
||||||
|
|
||||||
|
output = _sanitize_data(
|
||||||
|
data,
|
||||||
|
target_schema=pa.schema([target_field]),
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
).read_all()
|
||||||
|
|
||||||
|
assert output.schema.field("vector").metadata == {b"target": b"keep-me"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_data_uses_separate_embedding_metadata_for_bad_vectors():
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
conf = EmbeddingFunctionConfig(
|
||||||
|
source_column="text",
|
||||||
|
vector_column="custom_vector",
|
||||||
|
function=MockTextEmbeddingFunction.create(),
|
||||||
|
)
|
||||||
|
metadata = registry.get_table_metadata([conf])
|
||||||
|
schema = pa.schema(
|
||||||
|
{
|
||||||
|
"text": pa.string(),
|
||||||
|
"custom_vector": pa.list_(pa.float32(), 10),
|
||||||
|
},
|
||||||
|
metadata={b"note": b"keep-me"},
|
||||||
|
)
|
||||||
|
data = pa.table(
|
||||||
|
{
|
||||||
|
"text": ["bad", "good"],
|
||||||
|
"custom_vector": [[1.0] * 9, [2.0] * 10],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
output = _sanitize_data(
|
||||||
|
data,
|
||||||
|
target_schema=schema,
|
||||||
|
metadata=metadata,
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
).read_all()
|
||||||
|
|
||||||
|
assert output["text"].to_pylist() == ["good"]
|
||||||
|
assert output.schema.metadata[b"note"] == b"keep-me"
|
||||||
|
assert b"embedding_functions" in output.schema.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_create_table_merges_and_overrides_embedding_metadata():
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
old_conf = EmbeddingFunctionConfig(
|
||||||
|
source_column="text",
|
||||||
|
vector_column="old_vector",
|
||||||
|
function=MockTextEmbeddingFunction.create(),
|
||||||
|
)
|
||||||
|
new_conf = EmbeddingFunctionConfig(
|
||||||
|
source_column="text",
|
||||||
|
vector_column="custom_vector",
|
||||||
|
function=MockTextEmbeddingFunction.create(),
|
||||||
|
)
|
||||||
|
metadata = registry.get_table_metadata([new_conf])
|
||||||
|
schema = pa.schema(
|
||||||
|
{
|
||||||
|
"text": pa.string(),
|
||||||
|
"custom_vector": pa.list_(pa.float32(), 10),
|
||||||
|
},
|
||||||
|
metadata=_merge_metadata(
|
||||||
|
{b"note": b"keep-me"},
|
||||||
|
registry.get_table_metadata([old_conf]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
data, schema = sanitize_create_table(
|
||||||
|
pa.table({"text": ["good"]}),
|
||||||
|
schema,
|
||||||
|
metadata=metadata,
|
||||||
|
on_bad_vectors="drop",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert schema.metadata[b"note"] == b"keep-me"
|
||||||
|
assert b"embedding_functions" in schema.metadata
|
||||||
|
assert data.schema.metadata[b"note"] == b"keep-me"
|
||||||
|
funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)
|
||||||
|
assert set(funcs.keys()) == {"custom_vector"}
|
||||||
|
|
||||||
|
|
||||||
class TestModel(lancedb.pydantic.LanceModel):
|
class TestModel(lancedb.pydantic.LanceModel):
|
||||||
a: Optional[int]
|
a: Optional[int]
|
||||||
b: Optional[int]
|
b: Optional[int]
|
||||||
|
|||||||
@@ -547,6 +547,7 @@ pub struct PyClientConfig {
|
|||||||
id_delimiter: Option<String>,
|
id_delimiter: Option<String>,
|
||||||
tls_config: Option<PyClientTlsConfig>,
|
tls_config: Option<PyClientTlsConfig>,
|
||||||
header_provider: Option<Py<PyAny>>,
|
header_provider: Option<Py<PyAny>>,
|
||||||
|
user_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(FromPyObject)]
|
#[derive(FromPyObject)]
|
||||||
@@ -631,6 +632,7 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
|||||||
id_delimiter: value.id_delimiter,
|
id_delimiter: value.id_delimiter,
|
||||||
tls_config: value.tls_config.map(Into::into),
|
tls_config: value.tls_config.map(Into::into),
|
||||||
header_provider,
|
header_provider,
|
||||||
|
user_id: value.user_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.28.0-beta.0"
|
version = "0.28.0-beta.1"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|||||||
@@ -52,6 +52,13 @@ pub struct ClientConfig {
|
|||||||
pub tls_config: Option<TlsConfig>,
|
pub tls_config: Option<TlsConfig>,
|
||||||
/// Provider for custom headers to be added to each request
|
/// Provider for custom headers to be added to each request
|
||||||
pub header_provider: Option<Arc<dyn HeaderProvider>>,
|
pub header_provider: Option<Arc<dyn HeaderProvider>>,
|
||||||
|
/// User identifier for tracking purposes.
|
||||||
|
///
|
||||||
|
/// This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||||
|
/// It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||||
|
/// Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||||
|
/// variable that contains the user ID value.
|
||||||
|
pub user_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for ClientConfig {
|
impl std::fmt::Debug for ClientConfig {
|
||||||
@@ -67,6 +74,7 @@ impl std::fmt::Debug for ClientConfig {
|
|||||||
"header_provider",
|
"header_provider",
|
||||||
&self.header_provider.as_ref().map(|_| "Some(...)"),
|
&self.header_provider.as_ref().map(|_| "Some(...)"),
|
||||||
)
|
)
|
||||||
|
.field("user_id", &self.user_id)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,10 +89,41 @@ impl Default for ClientConfig {
|
|||||||
id_delimiter: None,
|
id_delimiter: None,
|
||||||
tls_config: None,
|
tls_config: None,
|
||||||
header_provider: None,
|
header_provider: None,
|
||||||
|
user_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ClientConfig {
|
||||||
|
/// Resolve the user ID from the config or environment variables.
|
||||||
|
///
|
||||||
|
/// Resolution order:
|
||||||
|
/// 1. If `user_id` is set in the config, use that value
|
||||||
|
/// 2. If `LANCEDB_USER_ID` environment variable is set, use that value
|
||||||
|
/// 3. If `LANCEDB_USER_ID_ENV_KEY` is set, read the env var it points to
|
||||||
|
/// 4. Otherwise, return None
|
||||||
|
pub fn resolve_user_id(&self) -> Option<String> {
|
||||||
|
if self.user_id.is_some() {
|
||||||
|
return self.user_id.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(user_id) = std::env::var("LANCEDB_USER_ID")
|
||||||
|
&& !user_id.is_empty()
|
||||||
|
{
|
||||||
|
return Some(user_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(env_key) = std::env::var("LANCEDB_USER_ID_ENV_KEY")
|
||||||
|
&& let Ok(user_id) = std::env::var(&env_key)
|
||||||
|
&& !user_id.is_empty()
|
||||||
|
{
|
||||||
|
return Some(user_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// How to handle timeouts for HTTP requests.
|
/// How to handle timeouts for HTTP requests.
|
||||||
#[derive(Clone, Default, Debug)]
|
#[derive(Clone, Default, Debug)]
|
||||||
pub struct TimeoutConfig {
|
pub struct TimeoutConfig {
|
||||||
@@ -464,6 +503,15 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(user_id) = config.resolve_user_id() {
|
||||||
|
headers.insert(
|
||||||
|
HeaderName::from_static("x-lancedb-user-id"),
|
||||||
|
HeaderValue::from_str(&user_id).map_err(|_| Error::InvalidInput {
|
||||||
|
message: format!("non-ascii user_id '{}' provided", user_id),
|
||||||
|
})?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(headers)
|
Ok(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1072,4 +1120,91 @@ mod tests {
|
|||||||
_ => panic!("Expected Runtime error"),
|
_ => panic!("Expected Runtime error"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_user_id_direct_value() {
|
||||||
|
let config = ClientConfig {
|
||||||
|
user_id: Some("direct-user-id".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
assert_eq!(config.resolve_user_id(), Some("direct-user-id".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_user_id_none() {
|
||||||
|
let config = ClientConfig::default();
|
||||||
|
// Clear env vars that might be set from other tests
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("LANCEDB_USER_ID");
|
||||||
|
std::env::remove_var("LANCEDB_USER_ID_ENV_KEY");
|
||||||
|
}
|
||||||
|
assert_eq!(config.resolve_user_id(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_user_id_from_env() {
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("LANCEDB_USER_ID", "env-user-id");
|
||||||
|
}
|
||||||
|
let config = ClientConfig::default();
|
||||||
|
assert_eq!(config.resolve_user_id(), Some("env-user-id".to_string()));
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("LANCEDB_USER_ID");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_user_id_from_env_key() {
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("LANCEDB_USER_ID");
|
||||||
|
std::env::set_var("LANCEDB_USER_ID_ENV_KEY", "MY_CUSTOM_USER_ID");
|
||||||
|
std::env::set_var("MY_CUSTOM_USER_ID", "custom-env-user-id");
|
||||||
|
}
|
||||||
|
let config = ClientConfig::default();
|
||||||
|
assert_eq!(
|
||||||
|
config.resolve_user_id(),
|
||||||
|
Some("custom-env-user-id".to_string())
|
||||||
|
);
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("LANCEDB_USER_ID_ENV_KEY");
|
||||||
|
std::env::remove_var("MY_CUSTOM_USER_ID");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_user_id_direct_takes_precedence() {
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("LANCEDB_USER_ID", "env-user-id");
|
||||||
|
}
|
||||||
|
let config = ClientConfig {
|
||||||
|
user_id: Some("direct-user-id".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
assert_eq!(config.resolve_user_id(), Some("direct-user-id".to_string()));
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("LANCEDB_USER_ID");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_user_id_empty_env_ignored() {
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("LANCEDB_USER_ID", "");
|
||||||
|
std::env::remove_var("LANCEDB_USER_ID_ENV_KEY");
|
||||||
|
}
|
||||||
|
let config = ClientConfig::default();
|
||||||
|
assert_eq!(config.resolve_user_id(), None);
|
||||||
|
// SAFETY: This is only called in tests
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("LANCEDB_USER_ID");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user