mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-05 05:10:41 +00:00
Compare commits
11 Commits
v0.30.1-be
...
feature/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d90febea3 | ||
|
|
39a9f3e1e9 | ||
|
|
952055d428 | ||
|
|
927ba2c948 | ||
|
|
415d199c15 | ||
|
|
a16676e05f | ||
|
|
4e44262499 | ||
|
|
632375faf1 | ||
|
|
9969191d0d | ||
|
|
1e7326cd8c | ||
|
|
9483b534af |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.30.1-beta.1"
|
||||
current_version = "0.30.1-beta.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
11
.github/dependabot.yml
vendored
11
.github/dependabot.yml
vendored
@@ -21,3 +21,14 @@ updates:
|
||||
update-types:
|
||||
- minor
|
||||
- patch
|
||||
|
||||
- package-ecosystem: pip
|
||||
directory: /python
|
||||
schedule:
|
||||
interval: weekly
|
||||
# Only update uv.lock, never widen version requirements in pyproject.toml.
|
||||
versioning-strategy: lockfile-only
|
||||
groups:
|
||||
python-deps:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -5128,7 +5128,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.30.1-beta.0"
|
||||
version = "0.30.1-beta.2"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
@@ -5211,7 +5211,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-nodejs"
|
||||
version = "0.30.1-beta.0"
|
||||
version = "0.30.1-beta.2"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -5234,7 +5234,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-python"
|
||||
version = "0.33.1-beta.0"
|
||||
version = "0.33.1-beta.2"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
|
||||
26
REVIEW.md
Normal file
26
REVIEW.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# Code review guidelines
|
||||
|
||||
Repo-specific guidance for automated PR reviews.
|
||||
|
||||
## Cross-SDK parity
|
||||
|
||||
LanceDB exposes the same core (`rust/lancedb`) through Python, TypeScript (`nodejs`),
|
||||
and Java bindings. Behavioral drift between SDKs is a recurring problem, so watch for
|
||||
parity gaps when reviewing — but only flag real ones:
|
||||
|
||||
* If the change adds or modifies user-facing API or behavior in the shared core
|
||||
(`rust/lancedb`), check whether each binding that should expose it (`python`,
|
||||
`nodejs`) does. A core change with no corresponding binding update is worth a note.
|
||||
* If the change adds or modifies a public API in one SDK but not the other, open the
|
||||
sibling SDK's corresponding module and state whether an equivalent exists. If not,
|
||||
note it as a possible parity gap and suggest a follow-up issue.
|
||||
* For bug fixes, first read the sibling SDK's analogous code path to check whether the
|
||||
same bug exists there. Only raise parity if it actually does. Do not ask to "port" a
|
||||
fix for a bug that only ever existed in one binding.
|
||||
* Stay silent on internal-only refactors, tests, docs, and changes with no cross-SDK
|
||||
surface.
|
||||
* Parity expectations apply to the Python and TypeScript (`nodejs`) SDKs. Java currently
|
||||
implements only the remote table, not the local/embedded backend, so it is expected to
|
||||
be partial — do not flag Java for missing local-only functionality.
|
||||
* Keep parity feedback to a short, clearly-labeled note (e.g. "Possible SDK parity
|
||||
gap: …"). It is advisory, not a merge blocker.
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.30.1-beta.1</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.30.1-beta.1</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.30.1-beta.1</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.30.1-beta.1"
|
||||
version = "0.30.1-beta.2"
|
||||
publish = false
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.30.1-beta.1",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.30.1-beta.1",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.30.1-beta.1",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.30.1-beta.1",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.30.1-beta.1",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.30.1-beta.1",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.30.1-beta.1",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.30.1-beta.1",
|
||||
"version": "0.30.1-beta.2",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.33.1-beta.1"
|
||||
current_version = "0.33.1-beta.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.33.1-beta.1"
|
||||
version = "0.33.1-beta.2"
|
||||
publish = false
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
|
||||
@@ -91,14 +91,14 @@ def _schema_has_blob_field(schema: pa.Schema) -> bool:
|
||||
|
||||
|
||||
def _blob_mode_requires_native_pandas(blob_mode: BlobMode, schema: pa.Schema) -> bool:
    """Return True when ``to_pandas`` must go through the Lance scanner path.

    That is the case when ``blob_mode`` is one of the keys of
    ``_BLOB_MODE_TO_HANDLING`` and ``schema`` contains at least one blob field
    (as reported by ``_schema_has_blob_field``).
    """
    return blob_mode in _BLOB_MODE_TO_HANDLING and _schema_has_blob_field(schema)
|
||||
|
||||
|
||||
def _unsupported_blob_pandas_error(reason: str) -> RuntimeError:
|
||||
return RuntimeError(
|
||||
"blob_mode='lazy' and blob_mode='bytes' require Lance native pandas "
|
||||
f"conversion for queries that return blob columns, but {reason}. "
|
||||
"Use blob_mode='descriptions' or remove blob columns from the projection."
|
||||
"blob columns require Lance native scanner conversion for query "
|
||||
f"to_pandas(), but {reason}. Use a plain scan query or remove blob "
|
||||
"columns from the projection."
|
||||
)
|
||||
|
||||
|
||||
@@ -149,19 +149,48 @@ def _projection_to_scanner_kwargs(
|
||||
return {"columns": projection}
|
||||
|
||||
|
||||
def _scanner_kwargs_for_query(
    query: Query, blob_mode: BlobMode, dataset: Optional[Any] = None
) -> Dict[str, Any]:
    """Translate a ``Query`` into keyword arguments for ``dataset.scanner(...)``.

    Parameters
    ----------
    query: Query
        The query whose projection, filter, limit, offset, row-id/row-address
        flags, fast_search flag and fragment selection are forwarded.
    blob_mode: BlobMode
        Looked up in ``_BLOB_MODE_TO_HANDLING`` to produce ``blob_handling``.
    dataset: optional
        Only needed to resolve ``query.fragment_ids`` into fragment objects.

    Returns
    -------
    Dict[str, Any]
        The scanner kwargs with every ``None``-valued entry removed.
    """
    fragments = _scanner_fragments_for_query(query, dataset)
    kwargs = {
        **_projection_to_scanner_kwargs(query.columns),
        "filter": _filter_to_sql(query.filter),
        "limit": query.limit,
        "offset": query.offset,
        "with_row_id": query.with_row_id,
        "with_row_address": query.with_row_address,
        "fast_search": query.fast_search,
        "blob_handling": _BLOB_MODE_TO_HANDLING[blob_mode],
        "fragments": fragments,
    }
    # Unset options are omitted entirely rather than passed as None.
    return {key: value for key, value in kwargs.items() if value is not None}
|
||||
|
||||
|
||||
def _scanner_fragments_for_query(query: Query, dataset: Optional[Any]) -> Optional[Any]:
|
||||
if query.fragments is not None and query.fragment_ids is not None:
|
||||
raise ValueError("fragments and fragment_ids cannot both be set")
|
||||
if query.fragments is not None:
|
||||
return query.fragments
|
||||
if query.fragment_ids is None:
|
||||
return None
|
||||
if dataset is None:
|
||||
raise ValueError("fragment_ids require a Lance dataset")
|
||||
|
||||
requested = set(query.fragment_ids)
|
||||
fragments = [
|
||||
fragment
|
||||
for fragment in dataset.get_fragments()
|
||||
if fragment.fragment_id in requested
|
||||
]
|
||||
found = {fragment.fragment_id for fragment in fragments}
|
||||
missing = requested - found
|
||||
if missing:
|
||||
missing_ids = ", ".join(str(fragment_id) for fragment_id in sorted(missing))
|
||||
raise ValueError(f"fragment_ids not found in dataset: {missing_ids}")
|
||||
return fragments
|
||||
|
||||
|
||||
def _ensure_lazy_blob_frame(
|
||||
df: "pd.DataFrame", schema: pa.Schema, blob_mode: BlobMode
|
||||
) -> "pd.DataFrame":
|
||||
@@ -179,6 +208,16 @@ def _ensure_lazy_blob_frame(
|
||||
return df
|
||||
|
||||
|
||||
def _scanner_to_table(scanner: Any) -> pa.Table:
|
||||
if hasattr(scanner, "to_pyarrow"):
|
||||
reader = scanner.to_pyarrow()
|
||||
return reader.read_all()
|
||||
if hasattr(scanner, "to_table"):
|
||||
return scanner.to_table()
|
||||
reader = scanner.to_reader()
|
||||
return reader.read_all()
|
||||
|
||||
|
||||
def _scanner_to_pandas(scanner: Any, blob_mode: BlobMode, **kwargs) -> "pd.DataFrame":
|
||||
schema = getattr(scanner, "projected_schema", None)
|
||||
if schema is None:
|
||||
@@ -199,14 +238,7 @@ def _scanner_to_pandas(scanner: Any, blob_mode: BlobMode, **kwargs) -> "pd.DataF
|
||||
return _ensure_lazy_blob_frame(df, schema, blob_mode)
|
||||
return df
|
||||
|
||||
if hasattr(scanner, "to_pyarrow"):
|
||||
reader = scanner.to_pyarrow()
|
||||
tbl = reader.read_all()
|
||||
elif hasattr(scanner, "to_table"):
|
||||
tbl = scanner.to_table()
|
||||
else:
|
||||
reader = scanner.to_reader()
|
||||
tbl = reader.read_all()
|
||||
tbl = _scanner_to_table(scanner)
|
||||
if blob_mode == "lazy" and _schema_has_blob_field(tbl.schema):
|
||||
raise _unsupported_blob_pandas_error(
|
||||
"the Lance scanner does not expose to_pandas"
|
||||
@@ -648,6 +680,13 @@ class Query(pydantic.BaseModel):
|
||||
# if true, include the row id in the results
|
||||
with_row_id: Optional[bool] = None
|
||||
|
||||
# if true, include the row address in the results
|
||||
with_row_address: Optional[bool] = None
|
||||
|
||||
# Lance fragments or fragment ids to scan on scanner-backed plain queries
|
||||
fragments: Optional[Any] = None
|
||||
fragment_ids: Optional[List[int]] = None
|
||||
|
||||
# offset to start fetching results from
|
||||
offset: Optional[int] = None
|
||||
|
||||
@@ -840,6 +879,9 @@ class LanceQueryBuilder(ABC):
|
||||
self._where = None
|
||||
self._postfilter = None
|
||||
self._with_row_id = None
|
||||
self._with_row_address = None
|
||||
self._fragments = None
|
||||
self._fragment_ids = None
|
||||
self._vector = None
|
||||
self._text = None
|
||||
self._ef = None
|
||||
@@ -901,9 +943,11 @@ class LanceQueryBuilder(ABC):
|
||||
schema = output_schema()
|
||||
if _blob_mode_requires_native_pandas(blob_mode, schema):
|
||||
native_error = None
|
||||
if flatten is None and timeout is None:
|
||||
if (flatten is None or blob_mode == "descriptions") and timeout is None:
|
||||
try:
|
||||
df = self._plain_scan_to_pandas(blob_mode, **kwargs)
|
||||
df = self._plain_scan_to_pandas(
|
||||
blob_mode, flatten=flatten, **kwargs
|
||||
)
|
||||
if df is not None:
|
||||
return df
|
||||
except Exception as err:
|
||||
@@ -1125,6 +1169,32 @@ class LanceQueryBuilder(ABC):
|
||||
self._with_row_id = with_row_id
|
||||
return self
|
||||
|
||||
def with_row_address(self, with_row_address: bool = True) -> Self:
    """Set whether to return row addresses.

    Parameters
    ----------
    with_row_address: bool, default True
        If True, return the _rowaddr column in the results.

    Returns
    -------
    LanceQueryBuilder
        The LanceQueryBuilder object.
    """
    self._with_row_address = with_row_address
    return self
|
||||
|
||||
def with_fragments(self, fragments: Any) -> Self:
    """Set the Lance fragments to scan for plain scanner-backed queries.

    The value is stored as given and the builder is returned for chaining.
    """
    self._fragments = fragments
    return self
|
||||
|
||||
def fragment_ids(self, fragment_ids: List[int]) -> Self:
    """Set the Lance fragment ids to scan for plain scanner-backed queries.

    The list is stored as given and the builder is returned for chaining.
    """
    self._fragment_ids = fragment_ids
    return self
|
||||
|
||||
def explain_plan(self, verbose: Optional[bool] = False) -> str:
|
||||
"""Return the execution plan for this query.
|
||||
|
||||
@@ -1267,6 +1337,7 @@ class LanceQueryBuilder(ABC):
|
||||
def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
query = self.to_query_object()
|
||||
@@ -1274,7 +1345,12 @@ class LanceQueryBuilder(ABC):
|
||||
return None
|
||||
|
||||
dataset = self._table.to_lance()
|
||||
scanner = dataset.scanner(**_scanner_kwargs_for_query(query, blob_mode))
|
||||
scanner = dataset.scanner(
|
||||
**_scanner_kwargs_for_query(query, blob_mode, dataset)
|
||||
)
|
||||
if flatten is not None:
|
||||
tbl = flatten_columns(_scanner_to_table(scanner), flatten)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
return _scanner_to_pandas(scanner, blob_mode, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
@@ -1548,6 +1624,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
refine_factor=self._refine_factor,
|
||||
vector_column=self._vector_column,
|
||||
with_row_id=self._with_row_id,
|
||||
with_row_address=self._with_row_address,
|
||||
fragments=self._fragments,
|
||||
fragment_ids=self._fragment_ids,
|
||||
offset=self._offset,
|
||||
fast_search=self._fast_search,
|
||||
ef=self._ef,
|
||||
@@ -1750,6 +1829,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
limit=self._limit,
|
||||
postfilter=self._postfilter,
|
||||
with_row_id=self._with_row_id,
|
||||
with_row_address=self._with_row_address,
|
||||
fragments=self._fragments,
|
||||
fragment_ids=self._fragment_ids,
|
||||
full_text_query=FullTextSearchQuery(
|
||||
query=self._query, columns=self._fts_columns
|
||||
),
|
||||
@@ -1820,6 +1902,9 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
with_row_id=self._with_row_id,
|
||||
with_row_address=self._with_row_address,
|
||||
fragments=self._fragments,
|
||||
fragment_ids=self._fragment_ids,
|
||||
offset=self._offset,
|
||||
order_by=self._order_by,
|
||||
)
|
||||
@@ -2411,6 +2496,9 @@ class AsyncQueryBase(object):
|
||||
"""
|
||||
self._inner = inner
|
||||
self._table = table
|
||||
self._with_row_address = None
|
||||
self._fragments = None
|
||||
self._fragment_ids = None
|
||||
|
||||
def to_query_object(self) -> Query:
|
||||
"""
|
||||
@@ -2419,7 +2507,11 @@ class AsyncQueryBase(object):
|
||||
This is currently experimental but can be useful as the query object is pure
|
||||
python and more easily serializable.
|
||||
"""
|
||||
return Query.from_inner(self._inner.to_query_request())
|
||||
query = Query.from_inner(self._inner.to_query_request())
|
||||
query.with_row_address = self._with_row_address
|
||||
query.fragments = self._fragments
|
||||
query.fragment_ids = self._fragment_ids
|
||||
return query
|
||||
|
||||
def select(self, columns: Union[List[str], dict[str, str]]) -> Self:
|
||||
"""
|
||||
@@ -2476,6 +2568,27 @@ class AsyncQueryBase(object):
|
||||
self._inner.with_row_id()
|
||||
return self
|
||||
|
||||
def with_row_address(self, with_row_address: bool = True) -> Self:
    """
    Include the _rowaddr column in scanner-backed plain query results.

    The flag is kept on this Python wrapper (``self._with_row_address``)
    rather than being forwarded to the inner query object.
    """
    self._with_row_address = with_row_address
    return self
|
||||
|
||||
def with_fragments(self, fragments: Any) -> Self:
    """
    Restrict scanner-backed plain query results to the given Lance fragments.

    The value is kept on this Python wrapper (``self._fragments``) rather
    than being forwarded to the inner query object.
    """
    self._fragments = fragments
    return self
|
||||
|
||||
def fragment_ids(self, fragment_ids: List[int]) -> Self:
    """
    Restrict scanner-backed plain query results to the given Lance fragment ids.

    The list is kept on this Python wrapper (``self._fragment_ids``) rather
    than being forwarded to the inner query object.
    """
    self._fragment_ids = fragment_ids
    return self
|
||||
|
||||
async def to_batches(
|
||||
self,
|
||||
*,
|
||||
@@ -2601,9 +2714,11 @@ class AsyncQueryBase(object):
|
||||
schema = await self.output_schema()
|
||||
if _blob_mode_requires_native_pandas(blob_mode, schema):
|
||||
native_error = None
|
||||
if flatten is None and timeout is None:
|
||||
if (flatten is None or blob_mode == "descriptions") and timeout is None:
|
||||
try:
|
||||
df = await self._plain_scan_to_pandas(blob_mode, **kwargs)
|
||||
df = await self._plain_scan_to_pandas(
|
||||
blob_mode, flatten=flatten, **kwargs
|
||||
)
|
||||
if df is not None:
|
||||
return df
|
||||
except Exception as err:
|
||||
@@ -2625,6 +2740,7 @@ class AsyncQueryBase(object):
|
||||
async def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
if self._table is None:
|
||||
@@ -2635,7 +2751,12 @@ class AsyncQueryBase(object):
|
||||
return None
|
||||
|
||||
dataset = await self._table._to_lance()
|
||||
scanner = dataset.scanner(**_scanner_kwargs_for_query(query, blob_mode))
|
||||
scanner = dataset.scanner(
|
||||
**_scanner_kwargs_for_query(query, blob_mode, dataset)
|
||||
)
|
||||
if flatten is not None:
|
||||
tbl = flatten_columns(_scanner_to_table(scanner), flatten)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
return _scanner_to_pandas(scanner, blob_mode, **kwargs)
|
||||
|
||||
async def to_polars(
|
||||
@@ -3522,6 +3643,7 @@ class AsyncTakeQuery(AsyncQueryBase):
|
||||
async def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
return None
|
||||
@@ -3576,6 +3698,27 @@ class BaseQueryBuilder(object):
|
||||
self._inner.with_row_id()
|
||||
return self
|
||||
|
||||
def with_row_address(self, with_row_address: bool = True) -> Self:
    """
    Include the _rowaddr column in scanner-backed plain query results.

    Delegates to ``self._inner.with_row_address`` and returns this builder
    for chaining.
    """
    self._inner.with_row_address(with_row_address)
    return self
|
||||
|
||||
def with_fragments(self, fragments: Any) -> Self:
    """
    Restrict scanner-backed plain query results to the given Lance fragments.

    Delegates to ``self._inner.with_fragments`` and returns this builder for
    chaining.
    """
    self._inner.with_fragments(fragments)
    return self
|
||||
|
||||
def fragment_ids(self, fragment_ids: List[int]) -> Self:
    """
    Restrict scanner-backed plain query results to the given Lance fragment ids.

    Delegates to ``self._inner.fragment_ids`` and returns this builder for
    chaining.
    """
    self._inner.fragment_ids(fragment_ids)
    return self
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
@@ -125,6 +125,9 @@ class MRRReranker(Reranker):
|
||||
This cannot reuse rerank_hybrid because MRR semantics require treating
|
||||
each vector result as a separate ranking system.
|
||||
"""
|
||||
if not vector_results:
|
||||
raise ValueError("vector_results must not be empty")
|
||||
|
||||
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
|
||||
raise ValueError(
|
||||
"All elements in vector_results should be of the same type"
|
||||
|
||||
@@ -82,6 +82,9 @@ class RRFReranker(Reranker):
|
||||
results from multiple vector searches as it doesn't support reranking
|
||||
vector results individually.
|
||||
"""
|
||||
if not vector_results:
|
||||
raise ValueError("vector_results must not be empty")
|
||||
|
||||
# Make sure all elements are of the same type
|
||||
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
|
||||
raise ValueError(
|
||||
|
||||
@@ -255,8 +255,9 @@ def test_plain_scan_query_to_pandas_blob_projection(tmp_db):
|
||||
assert df["double_id"].tolist() == [6, 8]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blob_mode", ["bytes", "descriptions"])
|
||||
def test_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow(
|
||||
tmp_db, monkeypatch
|
||||
tmp_db, monkeypatch, blob_mode
|
||||
):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
@@ -269,10 +270,69 @@ def test_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow(
|
||||
|
||||
monkeypatch.setattr(query, "to_arrow", fail_to_arrow)
|
||||
|
||||
df = query.to_pandas(blob_mode="bytes")
|
||||
df = query.to_pandas(blob_mode=blob_mode)
|
||||
|
||||
assert df["id"].tolist() == [1]
|
||||
assert df["blob"].tolist() == [b"one"]
|
||||
if blob_mode == "bytes":
|
||||
assert df["blob"].tolist() == [b"one"]
|
||||
else:
|
||||
first = df["blob"].iloc[0]
|
||||
assert first != b"one"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_plain_scan_query_to_pandas_blob_descriptions_flatten_uses_scanner(
    tmp_db, monkeypatch
):
    """``blob_mode="descriptions"`` with ``flatten=True`` must not call ``to_arrow``."""
    pytest.importorskip("lance")
    table = tmp_db.create_table(
        "test_query_to_pandas_blob_desc_flatten", _blob_query_data()
    )
    query = table.search().where("id = 1").select(["id", "blob"])

    def fail_to_arrow(*args, **kwargs):
        raise AssertionError("to_arrow should not be called before scanner pandas")

    # Any fallback to the Arrow collection path now fails the test loudly.
    monkeypatch.setattr(query, "to_arrow", fail_to_arrow)

    df = query.to_pandas(blob_mode="descriptions", flatten=True)

    assert df["id"].tolist() == [1]
    assert any(column == "blob" or column.startswith("blob.") for column in df.columns)
|
||||
|
||||
|
||||
def test_plain_scan_query_to_pandas_scanner_state(tmp_db):
    """Row-address and fragment selections reach the query object and the scan."""
    pytest.importorskip("lance")
    data = _blob_query_data()
    # Two separate writes; the assertion below checks this yields two fragments.
    table = tmp_db.create_table("test_query_to_pandas_scanner_state", data.slice(0, 2))
    table.add(data.slice(2, 2))

    fragments = table.to_lance().get_fragments()
    assert len(fragments) == 2

    query = (
        table.search()
        .select(["id", "blob"])
        .with_row_address()
        .fragment_ids([fragments[1].fragment_id])
    )
    query_obj = query.to_query_object()
    assert query_obj.with_row_address is True
    assert query_obj.fragment_ids == [fragments[1].fragment_id]

    df = query.to_pandas(blob_mode="descriptions")

    assert df["id"].tolist() == [3, 4]
    assert "_rowaddr" in df.columns
    # The bits above the low 32 of each row address are compared to the fragment id.
    assert {rowaddr >> 32 for rowaddr in df["_rowaddr"]} == {fragments[1].fragment_id}

    df_by_fragment = (
        table.search()
        .select(["id", "blob"])
        .with_fragments([fragments[0]])
        .to_pandas(blob_mode="descriptions")
    )
    assert df_by_fragment["id"].tolist() == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -312,8 +372,9 @@ async def test_async_plain_scan_query_to_pandas_blob_projection(tmp_db_async):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("blob_mode", ["bytes", "descriptions"])
|
||||
async def test_async_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow(
|
||||
tmp_db_async, monkeypatch
|
||||
tmp_db_async, monkeypatch, blob_mode
|
||||
):
|
||||
pytest.importorskip("lance")
|
||||
table = await tmp_db_async.create_table(
|
||||
@@ -326,10 +387,15 @@ async def test_async_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow
|
||||
|
||||
monkeypatch.setattr(query, "to_arrow", fail_to_arrow)
|
||||
|
||||
df = await query.to_pandas(blob_mode="bytes")
|
||||
df = await query.to_pandas(blob_mode=blob_mode)
|
||||
|
||||
assert df["id"].tolist() == [1]
|
||||
assert df["blob"].tolist() == [b"one"]
|
||||
if blob_mode == "bytes":
|
||||
assert df["blob"].tolist() == [b"one"]
|
||||
else:
|
||||
first = df["blob"].iloc[0]
|
||||
assert first != b"one"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_vector_query_to_pandas_blob_mode_requires_native_path(tmp_db):
|
||||
@@ -342,6 +408,18 @@ def test_vector_query_to_pandas_blob_mode_requires_native_path(tmp_db):
|
||||
)
|
||||
|
||||
|
||||
def test_vector_query_to_pandas_blob_descriptions_requires_plain_scan(tmp_db):
    """A vector query selecting a blob column raises with ``blob_mode="descriptions"``."""
    pytest.importorskip("lance")
    table = tmp_db.create_table(
        "test_vector_query_blob_descriptions", _blob_query_data()
    )

    with pytest.raises(RuntimeError, match="plain scan query"):
        table.search([1.0, 0.0]).select(["blob", "vector"]).limit(1).to_pandas(
            blob_mode="descriptions"
        )
|
||||
|
||||
|
||||
def test_order_by_plain_query(mem_db):
|
||||
table = mem_db.create_table(
|
||||
"test_order_by",
|
||||
|
||||
@@ -344,6 +344,12 @@ def test_mrr_reranker(tmp_path):
|
||||
assert len(result_deduped) == len(result)
|
||||
|
||||
|
||||
def test_mrr_reranker_empty_input():
    """``rerank_multivector([])`` raises ``ValueError`` matching "must not be empty"."""
    reranker = MRRReranker()
    with pytest.raises(ValueError, match="must not be empty"):
        reranker.rerank_multivector([])
|
||||
|
||||
|
||||
def test_rrf_reranker_distance():
|
||||
data = pa.table(
|
||||
{
|
||||
|
||||
@@ -1288,6 +1288,45 @@ def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
|
||||
assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16))
|
||||
|
||||
|
||||
def test_add_nullable_struct_with_none(mem_db: DBConnection):
    """Regression test for issue #2654: a nullable struct column whose
    first batch contains only None values must not crash in
    _align_field_types with AttributeError: 'pyarrow.lib.DataType'
    object has no attribute 'fields'.

    PyArrow infers an all-None struct column as `null` (not `struct`),
    so the type-alignment path needs to handle the case where the
    source field type is null and use the target type directly.
    """
    # Use the v2.1 file format so that nullable structs are supported.
    table = mem_db.create_table(
        "test_nullable_struct",
        schema=pa.schema(
            [
                pa.field("id", pa.string()),
                pa.field(
                    "data",
                    pa.struct([pa.field("x", pa.float32())]),
                    nullable=True,
                ),
            ]
        ),
        storage_options=dict(new_table_data_storage_version="2.1"),
    )

    # Adding a row with a non-null struct should work.
    table.add([{"id": "1", "data": {"x": 1.0}}])

    # Adding a row with None for the nullable struct field should also
    # work — this is what used to crash.
    table.add([{"id": "2", "data": None}])

    result = table.to_arrow()
    assert result.num_rows == 2
    assert result.column("id").to_pylist() == ["1", "2"]
    assert result.column("data").to_pylist() == [{"x": 1.0}, None]
|
||||
|
||||
|
||||
def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection):
|
||||
class Schema(LanceModel):
|
||||
text: str
|
||||
|
||||
4226
python/uv.lock
generated
4226
python/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.30.1-beta.1"
|
||||
version = "0.30.1-beta.2"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -23,6 +23,7 @@ use crate::table::DropColumnsResult;
|
||||
use crate::table::MergeResult;
|
||||
use crate::table::Tags;
|
||||
use crate::table::UpdateResult;
|
||||
use crate::table::merge::MergeFilter;
|
||||
use crate::table::query::create_multi_vector_plan;
|
||||
use crate::table::{AlterColumnsResult, FieldMetadataUpdate, UpdateFieldMetadataResult};
|
||||
use crate::table::{AnyQuery, Filter, Predicate, PreprocessingOutput, TableStatistics};
|
||||
@@ -1826,16 +1827,57 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
})
|
||||
}
|
||||
|
||||
async fn set_lsm_write_spec(&self, _spec: crate::table::LsmWriteSpec) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "set_lsm_write_spec is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
async fn set_lsm_write_spec(&self, spec: crate::table::LsmWriteSpec) -> Result<()> {
|
||||
use crate::table::LsmWriteSpec;
|
||||
self.check_mutable().await?;
|
||||
|
||||
// Map the spec onto the server's request DTO. `sharding` is internally
|
||||
// tagged on `mode` to mirror sophon's `Sharding` enum; `maintained_indexes`
|
||||
// and `writer_config_defaults` are sent verbatim (an empty list means "no
|
||||
// maintained indexes", not "default to all").
|
||||
let sharding = match &spec {
|
||||
LsmWriteSpec::Bucket {
|
||||
column,
|
||||
num_buckets,
|
||||
..
|
||||
} => serde_json::json!({
|
||||
"mode": "bucket",
|
||||
"column": column,
|
||||
"num_buckets": num_buckets,
|
||||
}),
|
||||
LsmWriteSpec::Identity { column, .. } => serde_json::json!({
|
||||
"mode": "identity",
|
||||
"column": column,
|
||||
}),
|
||||
LsmWriteSpec::Unsharded { .. } => serde_json::json!({ "mode": "unsharded" }),
|
||||
};
|
||||
let body = serde_json::json!({
|
||||
"sharding": sharding,
|
||||
"maintained_indexes": spec.maintained_indexes(),
|
||||
"writer_config_defaults": spec.writer_config_defaults(),
|
||||
});
|
||||
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!(
|
||||
"/v1/table/{}/set_lsm_write_spec/",
|
||||
self.identifier
|
||||
))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unset_lsm_write_spec(&self) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "unset_lsm_write_spec is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
self.check_mutable().await?;
|
||||
let request = self.client.post(&format!(
|
||||
"/v1/table/{}/unset_lsm_write_spec/",
|
||||
self.identifier
|
||||
));
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
|
||||
@@ -2266,13 +2308,34 @@ impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
|
||||
}
|
||||
let on = value.on[0].clone();
|
||||
|
||||
let when_matched_update_all_filt = match value.when_matched_update_all_filt {
|
||||
Some(MergeFilter::Sql(sql)) => Some(sql),
|
||||
Some(MergeFilter::Expr(_)) => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "DataFusion expressions are not supported on remote tables".into(),
|
||||
});
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let when_not_matched_by_source_delete_filt =
|
||||
match value.when_not_matched_by_source_delete_filt {
|
||||
Some(MergeFilter::Sql(sql)) => Some(sql),
|
||||
Some(MergeFilter::Expr(_)) => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "DataFusion expressions are not supported on remote tables".into(),
|
||||
});
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
on,
|
||||
when_matched_update_all: value.when_matched_update_all,
|
||||
when_matched_update_all_filt: value.when_matched_update_all_filt,
|
||||
when_matched_update_all_filt,
|
||||
when_not_matched_insert_all: value.when_not_matched_insert_all,
|
||||
when_not_matched_by_source_delete: value.when_not_matched_by_source_delete,
|
||||
when_not_matched_by_source_delete_filt: value.when_not_matched_by_source_delete_filt,
|
||||
when_not_matched_by_source_delete_filt,
|
||||
// Only serialize use_index when it's false for backwards compatibility
|
||||
use_index: value.use_index,
|
||||
})
|
||||
@@ -4406,6 +4469,91 @@ mod tests {
|
||||
assert!(matches!(e, Error::IndexNotFound { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_lsm_write_spec_unsharded() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/set_lsm_write_spec/"
|
||||
);
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
assert_eq!(body["sharding"], serde_json::json!({ "mode": "unsharded" }));
|
||||
assert_eq!(body["maintained_indexes"], serde_json::json!(["id_idx"]));
|
||||
assert_eq!(
|
||||
body["writer_config_defaults"],
|
||||
serde_json::json!({ "max_memtable_rows": "1000" })
|
||||
);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"maintained_indexes":["id_idx"]}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let spec = crate::table::LsmWriteSpec::unsharded()
|
||||
.with_maintained_indexes(["id_idx"])
|
||||
.with_writer_config_defaults([("max_memtable_rows", "1000")]);
|
||||
table.set_lsm_write_spec(spec).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_lsm_write_spec_bucket() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/set_lsm_write_spec/"
|
||||
);
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
assert_eq!(
|
||||
body["sharding"],
|
||||
serde_json::json!({ "mode": "bucket", "column": "id", "num_buckets": 16 })
|
||||
);
|
||||
assert_eq!(body["maintained_indexes"], serde_json::json!([]));
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
table
|
||||
.set_lsm_write_spec(crate::table::LsmWriteSpec::bucket("id", 16))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_lsm_write_spec_identity() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/set_lsm_write_spec/"
|
||||
);
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
assert_eq!(
|
||||
body["sharding"],
|
||||
serde_json::json!({ "mode": "identity", "column": "tenant" })
|
||||
);
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
table
|
||||
.set_lsm_write_spec(crate::table::LsmWriteSpec::identity("tenant"))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unset_lsm_write_spec() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/unset_lsm_write_spec/"
|
||||
);
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
table.unset_lsm_write_spec().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_wait_for_index() {
|
||||
let table = _make_table_with_indices(0);
|
||||
|
||||
@@ -53,6 +53,12 @@ pub struct MergeResult {
|
||||
pub num_rows: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MergeFilter {
|
||||
Sql(String),
|
||||
Expr(datafusion_expr::Expr),
|
||||
}
|
||||
|
||||
/// A builder used to create and run a merge insert operation
|
||||
///
|
||||
/// See [`super::Table::merge_insert`] for more context
|
||||
@@ -61,10 +67,10 @@ pub struct MergeInsertBuilder {
|
||||
table: Arc<dyn BaseTable>,
|
||||
pub(crate) on: Vec<String>,
|
||||
pub(crate) when_matched_update_all: bool,
|
||||
pub(crate) when_matched_update_all_filt: Option<String>,
|
||||
pub(crate) when_matched_update_all_filt: Option<MergeFilter>,
|
||||
pub(crate) when_not_matched_insert_all: bool,
|
||||
pub(crate) when_not_matched_by_source_delete: bool,
|
||||
pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
|
||||
pub(crate) when_not_matched_by_source_delete_filt: Option<MergeFilter>,
|
||||
pub(crate) timeout: Option<Duration>,
|
||||
pub(crate) use_index: bool,
|
||||
pub(crate) use_lsm_write: Option<bool>,
|
||||
@@ -110,7 +116,14 @@ impl MergeInsertBuilder {
|
||||
/// For example, "target.last_update < source.last_update"
|
||||
pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
|
||||
self.when_matched_update_all = true;
|
||||
self.when_matched_update_all_filt = condition;
|
||||
self.when_matched_update_all_filt = condition.map(MergeFilter::Sql);
|
||||
self
|
||||
}
|
||||
|
||||
/// Similar to [`Self::when_matched_update_all`] but accepts a DataFusion logical expression directly.
|
||||
pub fn when_matched_update_all_expr(&mut self, condition: datafusion_expr::Expr) -> &mut Self {
|
||||
self.when_matched_update_all = true;
|
||||
self.when_matched_update_all_filt = Some(MergeFilter::Expr(condition));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -132,7 +145,17 @@ impl MergeInsertBuilder {
|
||||
/// limit what rows are deleted.
|
||||
pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
|
||||
self.when_not_matched_by_source_delete = true;
|
||||
self.when_not_matched_by_source_delete_filt = filter;
|
||||
self.when_not_matched_by_source_delete_filt = filter.map(MergeFilter::Sql);
|
||||
self
|
||||
}
|
||||
|
||||
/// Similar to [`Self::when_not_matched_by_source_delete`] but accepts a DataFusion logical expression directly.
|
||||
pub fn when_not_matched_by_source_delete_expr(
|
||||
&mut self,
|
||||
filter: datafusion_expr::Expr,
|
||||
) -> &mut Self {
|
||||
self.when_not_matched_by_source_delete = true;
|
||||
self.when_not_matched_by_source_delete_filt = Some(MergeFilter::Expr(filter));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -234,7 +257,12 @@ pub(crate) async fn execute_merge_insert(
|
||||
) {
|
||||
(false, _) => builder.when_matched(WhenMatched::DoNothing),
|
||||
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
|
||||
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
|
||||
(true, Some(MergeFilter::Sql(filt))) => {
|
||||
builder.when_matched(WhenMatched::update_if(&dataset, &filt)?)
|
||||
}
|
||||
(true, Some(MergeFilter::Expr(expr))) => {
|
||||
builder.when_matched(WhenMatched::update_if_expr(expr))
|
||||
}
|
||||
};
|
||||
if params.when_not_matched_insert_all {
|
||||
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
|
||||
@@ -242,10 +270,12 @@ pub(crate) async fn execute_merge_insert(
|
||||
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
|
||||
}
|
||||
if params.when_not_matched_by_source_delete {
|
||||
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
|
||||
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
|
||||
} else {
|
||||
WhenNotMatchedBySource::Delete
|
||||
let behavior = match params.when_not_matched_by_source_delete_filt {
|
||||
Some(MergeFilter::Sql(filter)) => {
|
||||
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
|
||||
}
|
||||
Some(MergeFilter::Expr(expr)) => WhenNotMatchedBySource::DeleteIf(expr),
|
||||
None => WhenNotMatchedBySource::Delete,
|
||||
};
|
||||
builder.when_not_matched_by_source(behavior);
|
||||
} else {
|
||||
@@ -386,6 +416,45 @@ mod tests {
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 25);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_insert_expr() {
|
||||
use datafusion_expr::{col, lit};
|
||||
|
||||
let conn = connect("memory://").execute().await.unwrap();
|
||||
|
||||
// Create a dataset with i=0..10
|
||||
let batches = merge_insert_test_batches(0, 0);
|
||||
let table = conn
|
||||
.create_table("my_table_expr", batches)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
|
||||
// Conditional update that only replaces the age=0 data
|
||||
let new_batches = merge_insert_test_batches(5, 3);
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
// use expression: target.age = 0
|
||||
let expr = col("target.age").eq(lit(0));
|
||||
merge_insert_builder.when_matched_update_all_expr(expr);
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(
|
||||
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
|
||||
5
|
||||
);
|
||||
|
||||
// Delete with expression
|
||||
// Create new batches with i=10..20 (so target rows i=0..9 are not matched by source)
|
||||
let new_batches = merge_insert_test_batches(10, 0); // won't insert or update since we don't enable matched/unmatched actions
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
// delete if target.age = 3
|
||||
let delete_expr = col("target.age").eq(lit(3));
|
||||
merge_insert_builder.when_not_matched_by_source_delete_expr(delete_expr);
|
||||
let result = merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(result.num_deleted_rows, 5);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 5);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user