From def087fc856c65b8a30836d1f84471e0ff50bcf8 Mon Sep 17 00:00:00 2001 From: QianZhu Date: Thu, 23 May 2024 13:10:46 -0700 Subject: [PATCH] fix: parse index_stats for scalar index (#1319) parse the index stats for scalar index - it is different from the index stats for vector index --- rust/lancedb/Cargo.toml | 1 + rust/lancedb/src/index.rs | 19 ++++++++++++ rust/lancedb/src/index/vector.rs | 16 ---------- rust/lancedb/src/table.rs | 53 +++++++++++++++++++++++++++----- 4 files changed, 66 insertions(+), 23 deletions(-) diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 1f760899..6b7338c9 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -38,6 +38,7 @@ url.workspace = true regex.workspace = true serde = { version = "^1" } serde_json = { version = "1" } +serde_with = { version = "3.8.1" } # For remote feature reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true } polars-arrow = { version = ">=0.37,<0.40.0", optional = true } diff --git a/rust/lancedb/src/index.rs b/rust/lancedb/src/index.rs index d29ef9cf..b0a36430 100644 --- a/rust/lancedb/src/index.rs +++ b/rust/lancedb/src/index.rs @@ -14,6 +14,9 @@ use std::sync::Arc; +use serde::Deserialize; +use serde_with::skip_serializing_none; + use crate::{table::TableInternal, Result}; use self::{ @@ -83,3 +86,19 @@ pub struct IndexConfig { /// be more columns to represent composite indices. pub columns: Vec, } + +#[skip_serializing_none] +#[derive(Debug, Deserialize)] +pub struct IndexMetadata { + pub metric_type: Option, + pub index_type: Option, +} + +#[skip_serializing_none] +#[derive(Debug, Deserialize)] +pub struct IndexStatistics { + pub num_indexed_rows: usize, + pub num_unindexed_rows: usize, + pub index_type: Option, + pub indices: Vec, +} diff --git a/rust/lancedb/src/index/vector.rs b/rust/lancedb/src/index/vector.rs index 4be1a5e6..fe794394 100644 --- a/rust/lancedb/src/index/vector.rs +++ b/rust/lancedb/src/index/vector.rs @@ -19,8 +19,6 @@ //! values use std::cmp::max; -use serde::Deserialize; - use lance::table::format::{Index, Manifest}; use crate::DistanceType; @@ -46,20 +44,6 @@ impl VectorIndex { } } -#[derive(Debug, Deserialize)] -pub struct VectorIndexMetadata { - pub metric_type: String, - pub index_type: String, -} - -#[derive(Debug, Deserialize)] -pub struct VectorIndexStatistics { - pub num_indexed_rows: usize, - pub num_unindexed_rows: usize, - pub index_type: String, - pub indices: Vec, -} - /// Builder for an IVF PQ index. /// /// This index stores a compressed (quantized) copy of every vector. These vectors diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 4485d242..5ca1570b 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -48,10 +48,9 @@ use crate::arrow::IntoArrow; use crate::connection::NoData; use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry}; use crate::error::{Error, Result}; -use crate::index::vector::{ - IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics, -}; +use crate::index::vector::{IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex}; use crate::index::IndexConfig; +use crate::index::IndexStatistics; use crate::index::{ vector::{suggested_num_partitions, suggested_num_sub_vectors}, Index, IndexBuilder, @@ -1217,7 +1216,7 @@ impl NativeTable { pub async fn get_index_type(&self, index_uuid: &str) -> Result> { match self.load_index_stats(index_uuid).await? { - Some(stats) => Ok(Some(stats.index_type)), + Some(stats) => Ok(Some(stats.index_type.unwrap_or_default())), None => Ok(None), } } @@ -1228,7 +1227,7 @@ impl NativeTable { stats .indices .iter() - .map(|i| i.metric_type.clone()) + .filter_map(|i| i.metric_type.clone()) .collect(), )), None => Ok(None), @@ -1244,7 +1243,7 @@ impl NativeTable { .collect()) } - async fn load_index_stats(&self, index_uuid: &str) -> Result> { + async fn load_index_stats(&self, index_uuid: &str) -> Result> { let index = self .load_indices() .await? @@ -1255,7 +1254,7 @@ impl NativeTable { } let dataset = self.dataset.get().await?; let index_stats = dataset.index_statistics(&index.unwrap().index_name).await?; - let index_stats: VectorIndexStatistics = whatever!( + let index_stats: IndexStatistics = whatever!( serde_json::from_str(&index_stats), "error deserializing index statistics {index_stats}", ); @@ -2475,6 +2474,25 @@ mod tests { .unwrap(), Some(0) ); + assert_eq!( + table + .as_native() + .unwrap() + .get_index_type(index_uuid) + .await + .unwrap() + .map(|index_type| index_type.to_string()), + Some("IVF".to_string()) + ); + assert_eq!( + table + .as_native() + .unwrap() + .get_distance_type(index_uuid) + .await + .unwrap(), + Some(crate::DistanceType::L2.to_string()) + ); } #[tokio::test] @@ -2644,6 +2662,27 @@ mod tests { let index = index_configs.into_iter().next().unwrap(); assert_eq!(index.index_type, crate::index::IndexType::BTree); assert_eq!(index.columns, vec!["i".to_string()]); + + let indices = table.as_native().unwrap().load_indices().await.unwrap(); + let index_uuid = &indices[0].index_uuid; + assert_eq!( + table + .as_native() + .unwrap() + .count_indexed_rows(index_uuid) + .await + .unwrap(), + Some(1) + ); + assert_eq!( + table + .as_native() + .unwrap() + .count_unindexed_rows(index_uuid) + .await + .unwrap(), + Some(0) + ); } #[tokio::test]