diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 5639b8bba..c50cf29f9 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -362,6 +362,22 @@ def test_table_create_indices(): schema=dict( fields=[ dict(name="id", type={"type": "int64"}, nullable=False), + dict(name="text", type={"type": "string"}, nullable=False), + dict( + name="vector", + type={ + "type": "fixed_size_list", + "fields": [ + dict( + name="item", + type={"type": "float"}, + nullable=True, + ) + ], + "length": 2, + }, + nullable=False, + ), ] ), ) diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index efc23415e..b50ae4e7a 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1528,8 +1528,10 @@ impl BaseTable for RemoteTable { }); } }; + let schema = self.schema().await?; + let (canonical_column, field) = resolve_arrow_field_path(&schema, &column)?; let mut body = serde_json::json!({ - "column": column + "column": canonical_column }); // Add name parameter if provided (for backwards compatibility, only include if Some) @@ -1564,8 +1566,6 @@ impl BaseTable for RemoteTable { Index::LabelList(p) => ("LABEL_LIST", Some(to_json(p)?)), Index::FTS(p) => ("FTS", Some(to_json(p)?)), Index::Auto => { - let schema = self.schema().await?; - let field = resolve_arrow_field_path(&schema, &column)?; if supported_vector_data_type(field.data_type()) { body[METRIC_TYPE_KEY] = serde_json::Value::String(DistanceType::L2.to_string().to_lowercase()); @@ -1862,16 +1862,26 @@ impl BaseTable for RemoteTable { status_code: None, })?; + let schema = self.schema().await?; + // Make request to get stats for each index, so we get the index type. // This is a bit inefficient, but it's the only way to get the index type. let mut futures = Vec::with_capacity(body.indexes.len()); for index in body.indexes { + let columns = index + .columns + .iter() + .map(|column| { + resolve_arrow_field_path(&schema, column) + .map(|(canonical_column, _)| canonical_column) + }) + .collect::>>()?; let future = async move { match self.index_stats(&index.index_name).await { Ok(Some(stats)) => Ok(Some(IndexConfig { name: index.index_name, index_type: stats.index_type, - columns: index.columns, + columns, })), Ok(None) => Ok(None), // The index must have been deleted since we listed it. Err(e) => Err(e), @@ -2313,6 +2323,38 @@ mod tests { .unwrap() } + fn nested_index_schema() -> Schema { + let vector_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 8); + Schema::new(vec![ + Field::new( + "metadata", + DataType::Struct(vec![Field::new("user_id", DataType::Int32, false)].into()), + false, + ), + Field::new( + "image", + DataType::Struct(vec![Field::new("embedding", vector_type, false)].into()), + false, + ), + Field::new( + "payload", + DataType::Struct(vec![Field::new("text", DataType::Utf8, false)].into()), + false, + ), + Field::new( + "meta-data", + DataType::Struct(vec![Field::new("user-id", DataType::Int32, false)].into()), + false, + ), + Field::new( + "literal", + DataType::Struct(vec![Field::new("a.b", DataType::Int32, false)].into()), + false, + ), + ]) + } + #[rstest] #[case("", 0)] #[case("{}", 0)] @@ -3079,6 +3121,59 @@ mod tests { .unwrap(); } + #[tokio::test] + async fn test_query_vector_nested_field_path() { + let expected_data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let expected_data_ref = expected_data.clone(); + + let table = Table::new_with_handler("my_table", move |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/query/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let mut expected_body = serde_json::json!({ + "vector_column": "image.embedding", + "prefilter": true, + "k": 10, + "nprobes": 20, + "minimum_nprobes": 20, + "maximum_nprobes": 20, + "lower_bound": Option::::None, + "upper_bound": Option::::None, + "ef": Option::::None, + "refine_factor": Option::::None, + "version": null, + }); + expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); + assert_eq!(body, expected_body); + + let response_body = write_ipc_file(&expected_data_ref); + http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) + .body(response_body) + .unwrap() + }); + + let _ = table + .query() + .nearest_to(vec![0.1, 0.2, 0.3]) + .unwrap() + .column("image.embedding") + .execute() + .await + .unwrap(); + } + #[tokio::test] async fn test_query_fts() { let table = Table::new_with_handler("my_table", |request| { @@ -3160,7 +3255,7 @@ mod tests { "query": { "match": { "terms": "hello world", - "column": "a", + "column": "payload.text", "boost": 1.0, "fuzziness": 0, "max_expansions": 50, @@ -3194,7 +3289,7 @@ mod tests { .query() .full_text_search(FullTextSearchQuery::new_query( MatchQuery::new("hello world".to_owned()) - .with_column(Some("a".to_owned())) + .with_column(Some("payload.text".to_owned())) .into(), )) .with_row_id() @@ -3465,32 +3560,152 @@ mod tests { for (index_type, expected_body, index) in cases { let table = Table::new_with_handler("my_table", move |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/v1/table/my_table/create_index/"); - assert_eq!( - request.headers().get("Content-Type").unwrap(), - JSON_CONTENT_TYPE - ); - let body = request.body().unwrap().as_bytes().unwrap(); - let body: serde_json::Value = serde_json::from_slice(body).unwrap(); - let mut expected_body = expected_body.clone(); - expected_body["column"] = "a".into(); - expected_body[INDEX_TYPE_KEY] = index_type.into(); + match request.url().path() { + "/v1/table/my_table/describe/" => { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + http::Response::builder() + .status(200) + .body(describe_response(&schema)) + .unwrap() + } + "/v1/table/my_table/create_index/" => { + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let mut expected_body = expected_body.clone(); + expected_body["column"] = "a".into(); + expected_body[INDEX_TYPE_KEY] = index_type.into(); - assert_eq!(body, expected_body); + assert_eq!(body, expected_body); - http::Response::builder().status(200).body("{}").unwrap() + http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap() + } + path => panic!("Unexpected path: {}", path), + } }); table.create_index(&["a"], index).execute().await.unwrap(); } } + #[tokio::test] + async fn test_create_index_nested_field_paths() { + let schema = nested_index_schema(); + let expected_requests = Arc::new(vec![ + json!({ + "column": "metadata.user_id", + "index_type": "BTREE", + }), + json!({ + "column": "image.embedding", + "index_type": "IVF_PQ", + "metric_type": "l2", + }), + { + let mut body = serde_json::to_value(InvertedIndexParams::default()).unwrap(); + body["column"] = "payload.text".into(); + body["index_type"] = "FTS".into(); + body + }, + json!({ + "column": "`meta-data`.`user-id`", + "index_type": "BTREE", + }), + json!({ + "column": "literal.`a.b`", + "index_type": "BTREE", + }), + ]); + let request_idx = Arc::new(AtomicUsize::new(0)); + let table = Table::new_with_handler("my_table", { + let schema = schema.clone(); + let expected_requests = expected_requests.clone(); + let request_idx = request_idx.clone(); + move |request| { + assert_eq!(request.method(), "POST"); + match request.url().path() { + "/v1/table/my_table/describe/" => http::Response::builder() + .status(200) + .body(describe_response(&schema)) + .unwrap(), + "/v1/table/my_table/create_index/" => { + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + let idx = request_idx.fetch_add(1, Ordering::SeqCst); + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + assert_eq!(body, expected_requests[idx]); + http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap() + } + path => panic!("Unexpected path: {}", path), + } + } + }); + + table + .create_index(&["Metadata.USER_ID"], Index::BTree(Default::default())) + .execute() + .await + .unwrap(); + table + .create_index(&["Image.Embedding"], Index::Auto) + .execute() + .await + .unwrap(); + table + .create_index(&["Payload.Text"], Index::FTS(Default::default())) + .execute() + .await + .unwrap(); + table + .create_index(&["`META-DATA`.`USER-ID`"], Index::BTree(Default::default())) + .execute() + .await + .unwrap(); + table + .create_index(&["literal.`A.B`"], Index::BTree(Default::default())) + .execute() + .await + .unwrap(); + + assert_eq!(request_idx.load(Ordering::SeqCst), expected_requests.len()); + } + #[tokio::test] async fn test_list_indices() { - let table = Table::new_with_handler("my_table", |request| { + let schema = Schema::new(vec![ + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 8), + false, + ), + Field::new( + "metadata", + DataType::Struct(vec![Field::new("my.column", DataType::Utf8, true)].into()), + false, + ), + ]); + let table = Table::new_with_handler("my_table", move |request| { assert_eq!(request.method(), "POST"); let response_body = match request.url().path() { + "/v1/table/my_table/describe/" => { + return http::Response::builder() + .status(200) + .body(describe_response(&schema)) + .unwrap(); + } "/v1/table/my_table/index/list/" => { serde_json::json!({ "indexes": [ @@ -4010,6 +4225,20 @@ mod tests { assert_eq!(request.method(), "POST"); let response_body = match request.url().path() { + "/v1/table/my_table/describe/" => { + let schema = Schema::new(vec![ + Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 8, + ), + false, + ), + Field::new("my_column", DataType::Utf8, false), + ]); + serde_json::from_str::(&describe_response(&schema)).unwrap() + } "/v1/table/my_table/index/list/" => { serde_json::json!({ "indexes": [ @@ -4171,13 +4400,23 @@ mod tests { assert_eq!(value["index_type"], "IVF_PQ"); } - http::Response::builder().status(200).body("").unwrap() - } - "/v1/table/dev$users/describe/" => { - // Needed for schema check in Auto index type http::Response::builder() .status(200) - .body(r#"{"version": 1, "schema": {"fields": [{"name": "embedding", "type": {"type": "list", "item": {"type": "float32"}}, "nullable": false}]}}"#) + .body("".to_string()) + .unwrap() + } + "/v1/table/dev$users/describe/" => { + let schema = Schema::new(vec![Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 8, + ), + false, + )]); + http::Response::builder() + .status(200) + .body(describe_response(&schema)) .unwrap() } _ => { diff --git a/rust/lancedb/src/utils/mod.rs b/rust/lancedb/src/utils/mod.rs index d43912058..c0823bfd3 100644 --- a/rust/lancedb/src/utils/mod.rs +++ b/rust/lancedb/src/utils/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod background_cache; use std::sync::Arc; use arrow_array::RecordBatch; -use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_execution::RecordBatchStream; use futures::{FutureExt, Stream}; @@ -199,38 +199,32 @@ fn collect_vector_columns( path.pop(); } -pub(crate) fn resolve_arrow_field_path(schema: &Schema, column: &str) -> Result { - let segments = - lance_core::datatypes::parse_field_path(column).map_err(|e| Error::InvalidInput { - message: format!("Invalid field path `{}`: {}", column, e), +pub(crate) fn resolve_arrow_field_path(schema: &Schema, column: &str) -> Result<(String, Field)> { + lance_core::datatypes::parse_field_path(column).map_err(|e| Error::InvalidInput { + message: format!("Invalid field path `{}`: {}", column, e), + })?; + + let lance_schema = + lance_core::datatypes::Schema::try_from(schema).map_err(|e| Error::Schema { + message: format!("Invalid schema: {}", e), })?; - let mut fields = schema.fields(); - - for (idx, segment) in segments.iter().enumerate() { - let field = find_field(fields, segment).ok_or_else(|| Error::Schema { - message: format!("Field path `{}` not found in schema", column), + let field_path = lance_schema + .resolve_case_insensitive(column) + .ok_or_else(|| Error::Schema { + message: format!( + "Field path `{}` not found in schema. Available field paths: {}", + column, + lance_schema.field_paths().join(", ") + ), })?; - if idx + 1 == segments.len() { - return Ok(field.clone()); - } - fields = match field.data_type() { - DataType::Struct(fields) => fields, - _ => { - return Err(Error::Schema { - message: format!("Field path `{}` not found in schema", column), - }); - } - }; - } - - unreachable!("parse_field_path returns at least one segment") -} - -fn find_field<'a>(fields: &'a Fields, name: &str) -> Option<&'a Field> { - fields + let field = field_path.last().expect("field path should be non-empty"); + let path_segments = field_path .iter() - .find(|field| field.name() == name) - .map(|field| field.as_ref()) + .map(|field| field.name.as_str()) + .collect::>(); + let canonical_path = lance_core::datatypes::format_field_path(&path_segments); + + Ok((canonical_path, Field::from(*field))) } pub fn supported_btree_data_type(dtype: &DataType) -> bool {