From 33b402c8610fc9c1145120c0a26118da134d979d Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 1 Oct 2024 09:16:18 -0700 Subject: [PATCH] fix: `list_indices` returns correct index type (#1715) Fixes https://github.com/lancedb/lancedb/issues/1711 Doesn't address this https://github.com/lancedb/lance/issues/2039 Instead we load the index statistics, which seems to contain the index type. However, this involves more IO than previously. I'm not sure whether we care that much. If we do, we can fix that upstream Lance issue. --- python/python/tests/test_index.py | 10 ++++------ rust/lancedb/src/index.rs | 21 +++++++++++++++++++- rust/lancedb/src/table.rs | 33 ++++++++++++++++--------------- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/python/python/tests/test_index.py b/python/python/tests/test_index.py index b0646afe..1245997e 100644 --- a/python/python/tests/test_index.py +++ b/python/python/tests/test_index.py @@ -63,9 +63,8 @@ async def test_create_scalar_index(some_table: AsyncTable): @pytest.mark.asyncio async def test_create_bitmap_index(some_table: AsyncTable): await some_table.create_index("id", config=Bitmap()) - # TODO: Fix via https://github.com/lancedb/lance/issues/2039 - # indices = await some_table.list_indices() - # assert str(indices) == '[Index(Bitmap, columns=["id"])]' + indices = await some_table.list_indices() + assert str(indices) == '[Index(Bitmap, columns=["id"])]' indices = await some_table.list_indices() assert len(indices) == 1 index_name = indices[0].name @@ -80,9 +79,8 @@ async def test_create_bitmap_index(some_table: AsyncTable): @pytest.mark.asyncio async def test_create_label_list_index(some_table: AsyncTable): await some_table.create_index("tags", config=LabelList()) - # TODO: Fix via https://github.com/lancedb/lance/issues/2039 - # indices = await some_table.list_indices() - # assert str(indices) == '[Index(LabelList, columns=["id"])]' + indices = await some_table.list_indices() + assert str(indices) == '[Index(LabelList, columns=["tags"])]' @pytest.mark.asyncio diff --git a/rust/lancedb/src/index.rs b/rust/lancedb/src/index.rs index 21301a2b..1ff80137 100644 --- a/rust/lancedb/src/index.rs +++ b/rust/lancedb/src/index.rs @@ -18,7 +18,7 @@ use scalar::FtsIndexBuilder; use serde::Deserialize; use serde_with::skip_serializing_none; -use crate::{table::TableInternal, DistanceType, Result}; +use crate::{table::TableInternal, DistanceType, Error, Result}; use self::{ scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder}, @@ -136,6 +136,25 @@ impl std::fmt::Display for IndexType { } } +impl std::str::FromStr for IndexType { + type Err = Error; + + fn from_str(value: &str) -> Result { + match value.to_uppercase().as_str() { + "BTREE" => Ok(Self::BTree), + "BITMAP" => Ok(Self::Bitmap), + "LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList), + "FTS" => Ok(Self::FTS), + "IVF_PQ" => Ok(Self::IvfPq), + "IVF_HNSW_PQ" => Ok(Self::IvfHnswPq), + "IVF_HNSW_SQ" => Ok(Self::IvfHnswSq), + _ => Err(Error::InvalidInput { + message: format!("the input value {} is not a valid IndexType", value), + }), + } + } +} + /// A description of an index currently configured on a column #[derive(Debug, PartialEq, Clone)] pub struct IndexConfig { diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index cd52c601..5f286cf7 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -25,6 +25,7 @@ use arrow_schema::{Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_physical_plan::ExecutionPlan; +use futures::{StreamExt, TryStreamExt}; use lance::dataset::builder::DatasetBuilder; use lance::dataset::cleanup::RemovalStats; use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions}; @@ -2023,28 +2024,28 @@ impl TableInternal for NativeTable { async fn list_indices(&self) -> Result> { let dataset = self.dataset.get().await?; let indices = dataset.load_indices().await?; - indices.iter().map(|idx| { - let mut is_vector = false; + futures::stream::iter(indices.as_slice()).then(|idx| async { + let stats = dataset.index_statistics(idx.name.as_str()).await?; + let stats: serde_json::Value = serde_json::from_str(&stats).map_err(|e| Error::Runtime { + message: format!("error deserializing index statistics: {}", e), + })?; + let index_type = stats.get("index_type").and_then(|v| v.as_str()) + .ok_or_else(|| Error::Runtime { + message: "index statistics was missing index type".to_string(), + })?; + let index_type: crate::index::IndexType = index_type.parse().map_err(|e| Error::Runtime { + message: format!("error parsing index type: {}", e), + })?; + let mut columns = Vec::with_capacity(idx.fields.len()); for field_id in &idx.fields { let field = dataset.schema().field_by_id(*field_id).ok_or_else(|| Error::Runtime { message: format!("The index with name {} and uuid {} referenced a field with id {} which does not exist in the schema", idx.name, idx.uuid, field_id) })?; - if field.data_type().is_nested() { - // Temporary hack to determine if an index is scalar or vector - // Should be removed in https://github.com/lancedb/lance/issues/2039 - is_vector = true; - } columns.push(field.name.clone()); } - let index_type = if is_vector { - crate::index::IndexType::IvfPq - } else { - crate::index::IndexType::BTree - }; - let name = idx.name.clone(); Ok(IndexConfig { index_type, columns, name }) - }).collect::>>() + }).try_collect::>().await } fn dataset_uri(&self) -> &str { @@ -2803,7 +2804,7 @@ mod tests { let index_configs = table.list_indices().await.unwrap(); assert_eq!(index_configs.len(), 1); let index = index_configs.into_iter().next().unwrap(); - assert_eq!(index.index_type, crate::index::IndexType::IvfPq); + assert_eq!(index.index_type, crate::index::IndexType::IvfHnswSq); assert_eq!(index.columns, vec!["embeddings".to_string()]); assert_eq!(table.count_rows(None).await.unwrap(), 512); assert_eq!(table.name(), "test"); @@ -2867,7 +2868,7 @@ mod tests { let index_configs = table.list_indices().await.unwrap(); assert_eq!(index_configs.len(), 1); let index = index_configs.into_iter().next().unwrap(); - assert_eq!(index.index_type, crate::index::IndexType::IvfPq); + assert_eq!(index.index_type, crate::index::IndexType::IvfHnswPq); assert_eq!(index.columns, vec!["embeddings".to_string()]); assert_eq!(table.count_rows(None).await.unwrap(), 512); assert_eq!(table.name(), "test");