From d2c6759e7f7f6853b901eb224e5d655aaf929238 Mon Sep 17 00:00:00 2001 From: Tristan Zajonc Date: Wed, 23 Jul 2025 16:25:33 -0700 Subject: [PATCH] fix: use import stubs to prevent MLX doctest collection failures (#2536) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Add `create_import_stub()` helper to `embeddings/utils.py` for handling optional dependencies - Fix MLX doctest collection failures by using import stubs in `gte_mlx_model.py` - Module now imports successfully for doctest collection even when MLX is not installed ## Changes - **New utility function**: `create_import_stub()` creates placeholder objects that allow class inheritance but raise helpful errors when used - **Updated MLX model**: Uses import stubs instead of direct imports that fail immediately - **Graceful degradation**: Clear error messages when MLX functionality is accessed without MLX installed ## Test Results - ✅ `pytest --doctest-modules python/lancedb` now passes (with and without MLX installed) - ✅ All existing tests continue to pass - ✅ MLX functionality works normally when MLX is installed - ✅ Helpful error messages when MLX functionality is used without MLX installed Fixes #2538 --------- Co-authored-by: Will Jones --- .../lancedb/embeddings/gte_mlx_model.py | 7 +++- python/python/lancedb/embeddings/utils.py | 30 +++++++++++++ python/python/lancedb/query.py | 42 +++++++++---------- python/python/tests/test_hybrid_query.py | 2 +- python/python/tests/test_query.py | 5 ++- 5 files changed, 60 insertions(+), 26 deletions(-) diff --git a/python/python/lancedb/embeddings/gte_mlx_model.py b/python/python/lancedb/embeddings/gte_mlx_model.py index af6f5fbf..b877d5e1 100644 --- a/python/python/lancedb/embeddings/gte_mlx_model.py +++ b/python/python/lancedb/embeddings/gte_mlx_model.py @@ -9,11 +9,14 @@ from huggingface_hub import snapshot_download from pydantic import BaseModel from transformers import BertTokenizer +from .utils import create_import_stub + try: import mlx.core as mx import mlx.nn as nn except ImportError: - raise ImportError("You need to install MLX to use this model use - pip install mlx") + mx = create_import_stub("mlx.core", "mlx") + nn = create_import_stub("mlx.nn", "mlx") def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array: @@ -72,7 +75,7 @@ class TransformerEncoder(nn.Module): super().__init__() self.layers = [ TransformerEncoderLayer(dims, num_heads, mlp_dims) - for i in range(num_layers) + for _ in range(num_layers) ] def __call__(self, x, mask): diff --git a/python/python/lancedb/embeddings/utils.py b/python/python/lancedb/embeddings/utils.py index a9e4bfb8..8b892a06 100644 --- a/python/python/lancedb/embeddings/utils.py +++ b/python/python/lancedb/embeddings/utils.py @@ -21,6 +21,36 @@ from ..dependencies import pandas as pd from ..util import attempt_import_or_raise +def create_import_stub(module_name: str, package_name: str = None): + """ + Create a stub module that allows class definition but fails when used. + This allows modules to be imported for doctest collection even when + optional dependencies are not available. + + Parameters + ---------- + module_name : str + The name of the module to create a stub for + package_name : str, optional + The package name to suggest in the error message + + Returns + ------- + object + A stub object that can be used in place of the module + """ + + class _ImportStub: + def __getattr__(self, name): + return _ImportStub # Return stub for chained access like nn.Module + + def __call__(self, *args, **kwargs): + pkg = package_name or module_name + raise ImportError(f"You need to install {pkg} to use this functionality") + + return _ImportStub() + + # ruff: noqa: PERF203 def retry(tries=10, delay=1, max_delay=30, backoff=3, jitter=1): def wrapper(fn): diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 56f988e8..32af99c3 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -906,11 +906,11 @@ class LanceQueryBuilder(ABC): >>> plan = table.search(query).explain_plan(True) >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance] - GlobalLimitExec: skip=0, fetch=10 - FilterExec: _distance@2 IS NOT NULL - SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false] - KNNVectorDistance: metric=l2 - LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false + GlobalLimitExec: skip=0, fetch=10 + FilterExec: _distance@2 IS NOT NULL + SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false] + KNNVectorDistance: metric=l2 + LanceRead: uri=..., projection=[vector], ... Parameters ---------- @@ -940,19 +940,19 @@ class LanceQueryBuilder(ABC): >>> plan = table.search(query).analyze_plan() >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE AnalyzeExec verbose=true, metrics=[] - ProjectionExec: expr=[...], metrics=[...] - GlobalLimitExec: skip=0, fetch=10, metrics=[...] - FilterExec: _distance@2 IS NOT NULL, - metrics=[output_rows=..., elapsed_compute=...] - SortExec: TopK(fetch=10), expr=[...], - preserve_partitioning=[...], - metrics=[output_rows=..., elapsed_compute=..., row_replacements=...] - KNNVectorDistance: metric=l2, - metrics=[output_rows=..., elapsed_compute=..., output_batches=...] - LanceScan: uri=..., projection=[vector], row_id=true, - row_addr=false, ordered=false, - metrics=[output_rows=..., elapsed_compute=..., - bytes_read=..., iops=..., requests=...] + TracedExec, metrics=[] + ProjectionExec: expr=[...], metrics=[...] + GlobalLimitExec: skip=0, fetch=10, metrics=[...] + FilterExec: _distance@2 IS NOT NULL, + metrics=[output_rows=..., elapsed_compute=...] + SortExec: TopK(fetch=10), expr=[...], + preserve_partitioning=[...], + metrics=[output_rows=..., elapsed_compute=..., row_replacements=...] + KNNVectorDistance: metric=l2, + metrics=[output_rows=..., elapsed_compute=..., output_batches=...] + LanceRead: uri=..., projection=[vector], ... + metrics=[output_rows=..., elapsed_compute=..., + bytes_read=..., iops=..., requests=...] Returns ------- @@ -2043,7 +2043,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): FilterExec: _distance@2 IS NOT NULL SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false] KNNVectorDistance: metric=l2 - LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false + LanceRead: uri=..., projection=[vector], ... Parameters ---------- @@ -2429,7 +2429,7 @@ class AsyncQueryBase(object): FilterExec: _distance@2 IS NOT NULL SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false] KNNVectorDistance: metric=l2 - LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false + LanceRead: uri=..., projection=[vector], ... Parameters ---------- @@ -3054,7 +3054,7 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase): FilterExec: _distance@2 IS NOT NULL SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false] KNNVectorDistance: metric=l2 - LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false + LanceRead: uri=..., projection=[vector], ... FTS Search Plan: ProjectionExec: expr=[vector@2 as vector, text@3 as text, _score@1 as _score] diff --git a/python/python/tests/test_hybrid_query.py b/python/python/tests/test_hybrid_query.py index 33245f7f..3957568a 100644 --- a/python/python/tests/test_hybrid_query.py +++ b/python/python/tests/test_hybrid_query.py @@ -166,7 +166,7 @@ async def test_explain_plan(table: AsyncTable): assert "Vector Search Plan" in plan assert "KNNVectorDistance" in plan assert "FTS Search Plan" in plan - assert "LanceScan" in plan + assert "LanceRead" in plan @pytest.mark.asyncio diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 9f51686d..72a504c3 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -839,7 +839,7 @@ async def test_explain_plan_with_filters(table_async: AsyncTable): table_async.query().nearest_to(pa.array([1, 2])).where("id = 1").explain_plan() ) assert "KNN" in plan_with_filter - assert "FilterExec" in plan_with_filter + assert "LanceRead" in plan_with_filter # Test FTS query with filter from lancedb.index import FTS @@ -850,7 +850,8 @@ async def test_explain_plan_with_filters(table_async: AsyncTable): ) plan_fts_filter = await query_fts_filter.where("id = 1").explain_plan() assert "MatchQuery: query=dog" in plan_fts_filter - assert "FilterExec: id@" in plan_fts_filter # Should show filter details + assert "LanceRead" in plan_fts_filter + assert "full_filter=id = Int64(1)" in plan_fts_filter # Should show filter details @pytest.mark.asyncio