Skip to main content

operator/statement/
copy_table_from.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::path::Path;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll};
21
22use client::{Output, OutputData, OutputMeta};
23use common_base::readable_size::ReadableSize;
24use common_datasource::file_format::csv::{
25    CsvFormat, is_skippable_arrow_error, tolerant_csv_stream,
26};
27use common_datasource::file_format::json::JsonFormat;
28use common_datasource::file_format::orc::{ReaderAdapter, infer_orc_schema, new_orc_stream_reader};
29use common_datasource::file_format::{FileFormat, Format, file_to_stream};
30use common_datasource::lister::{Lister, Source};
31use common_datasource::object_store::{FS_SCHEMA, build_backend, parse_url};
32use common_datasource::util::find_dir_and_filename;
33use common_query::{OutputCost, OutputRows};
34use common_recordbatch::DfSendableRecordBatchStream;
35use common_recordbatch::adapter::RecordBatchStreamTypeAdapter;
36use common_telemetry::{debug, tracing};
37use datafusion::datasource::physical_plan::{CsvSource, FileSource, JsonSource};
38use datafusion::parquet::arrow::ParquetRecordBatchStreamBuilder;
39use datafusion::parquet::arrow::arrow_reader::ArrowReaderMetadata;
40use datafusion_common::DataFusionError;
41use datafusion_common::arrow::error::ArrowError;
42use datafusion_common::config::CsvOptions;
43use datafusion_expr::Expr;
44use datatypes::arrow::compute::can_cast_types;
45use datatypes::arrow::datatypes::{DataType as ArrowDataType, Schema, SchemaRef};
46use datatypes::arrow::record_batch::RecordBatch;
47use datatypes::vectors::Helper;
48use futures_util::StreamExt;
49use object_store::{Entry, EntryMode, ObjectStore};
50use regex::Regex;
51use session::context::QueryContextRef;
52use snafu::{ResultExt, ensure};
53use table::requests::{CopyTableRequest, InsertRequest};
54use table::table_reference::TableReference;
55use tokio_util::compat::FuturesAsyncReadCompatExt;
56
57use crate::error::{self, IntoVectorsSnafu, PathNotFoundSnafu, Result};
58use crate::statement::StatementExecutor;
59
/// Batch size handed to the CSV and JSON file sources via `with_batch_size`.
const DEFAULT_BATCH_SIZE: usize = 8 * 1024;
/// Chunk size, in bytes, used for the object-store readers of Parquet and ORC files.
const DEFAULT_READ_BUFFER: usize = 256 << 10;
62
63enum FileMetadata {
64    Parquet {
65        schema: SchemaRef,
66        metadata: ArrowReaderMetadata,
67        path: String,
68    },
69    Orc {
70        schema: SchemaRef,
71        path: String,
72    },
73    Json {
74        schema: SchemaRef,
75        format: JsonFormat,
76        path: String,
77    },
78    Csv {
79        schema: SchemaRef,
80        format: CsvFormat,
81        path: String,
82    },
83}
84
85impl FileMetadata {
86    /// Returns the [SchemaRef]
87    pub fn schema(&self) -> &SchemaRef {
88        match self {
89            FileMetadata::Parquet { schema, .. } => schema,
90            FileMetadata::Orc { schema, .. } => schema,
91            FileMetadata::Json { schema, .. } => schema,
92            FileMetadata::Csv { schema, .. } => schema,
93        }
94    }
95}
96
97impl StatementExecutor {
98    async fn list_copy_from_entries(
99        &self,
100        req: &CopyTableRequest,
101    ) -> Result<(ObjectStore, Vec<Entry>)> {
102        let (schema, _host, path) = parse_url(&req.location).context(error::ParseUrlSnafu)?;
103
104        if schema.to_uppercase() == FS_SCHEMA {
105            ensure!(Path::new(&path).exists(), PathNotFoundSnafu { path });
106        }
107
108        let object_store =
109            build_backend(&req.location, &req.connection).context(error::BuildBackendSnafu)?;
110
111        let (dir, filename) = find_dir_and_filename(&path);
112        let regex = req
113            .pattern
114            .as_ref()
115            .map(|x| Regex::new(x))
116            .transpose()
117            .context(error::BuildRegexSnafu)?;
118
119        let source = if let Some(filename) = filename {
120            Source::Filename(filename)
121        } else {
122            Source::Dir
123        };
124
125        let lister = Lister::new(object_store.clone(), source.clone(), dir.clone(), regex);
126
127        let entries = lister.list().await.context(error::ListObjectsSnafu)?;
128        debug!("Copy from dir: {dir:?}, {source:?}, entries: {entries:?}");
129        Ok((object_store, entries))
130    }
131
132    async fn collect_metadata(
133        &self,
134        object_store: &ObjectStore,
135        format: Format,
136        path: String,
137    ) -> Result<FileMetadata> {
138        match format {
139            Format::Csv(format) => Ok(FileMetadata::Csv {
140                schema: Arc::new(
141                    format
142                        .infer_schema(object_store, &path)
143                        .await
144                        .context(error::InferSchemaSnafu { path: &path })?,
145                ),
146                format,
147                path,
148            }),
149            Format::Json(format) => Ok(FileMetadata::Json {
150                schema: Arc::new(
151                    format
152                        .infer_schema(object_store, &path)
153                        .await
154                        .context(error::InferSchemaSnafu { path: &path })?,
155                ),
156                format,
157                path,
158            }),
159            Format::Parquet(_) => {
160                let meta = object_store
161                    .stat(&path)
162                    .await
163                    .context(error::ReadObjectSnafu { path: &path })?;
164                let mut reader = object_store
165                    .reader(&path)
166                    .await
167                    .context(error::ReadObjectSnafu { path: &path })?
168                    .into_futures_async_read(0..meta.content_length())
169                    .await
170                    .context(error::ReadObjectSnafu { path: &path })?
171                    .compat();
172                let metadata = ArrowReaderMetadata::load_async(&mut reader, Default::default())
173                    .await
174                    .context(error::ReadParquetMetadataSnafu)?;
175
176                Ok(FileMetadata::Parquet {
177                    schema: metadata.schema().clone(),
178                    metadata,
179                    path,
180                })
181            }
182            Format::Orc(_) => {
183                let meta = object_store
184                    .stat(&path)
185                    .await
186                    .context(error::ReadObjectSnafu { path: &path })?;
187
188                let reader = object_store
189                    .reader(&path)
190                    .await
191                    .context(error::ReadObjectSnafu { path: &path })?;
192
193                let schema = infer_orc_schema(ReaderAdapter::new(reader, meta.content_length()))
194                    .await
195                    .context(error::ReadOrcSnafu)?;
196
197                Ok(FileMetadata::Orc {
198                    schema: Arc::new(schema),
199                    path,
200                })
201            }
202        }
203    }
204
205    async fn build_read_stream(
206        &self,
207        compat_schema: SchemaRef,
208        object_store: &ObjectStore,
209        file_metadata: &FileMetadata,
210        projection: Vec<usize>,
211        filters: Vec<Expr>,
212    ) -> Result<DfSendableRecordBatchStream> {
213        match file_metadata {
214            FileMetadata::Csv {
215                format,
216                path,
217                schema,
218            } => {
219                let output_schema = Arc::new(
220                    compat_schema
221                        .project(&projection)
222                        .context(error::ProjectSchemaSnafu)?,
223                );
224
225                let options = CsvOptions::default()
226                    .with_has_header(format.has_header)
227                    .with_delimiter(format.delimiter);
228                let csv_source = CsvSource::new(schema.clone())
229                    .with_csv_options(options)
230                    .with_batch_size(DEFAULT_BATCH_SIZE);
231                let stream = if format.skip_bad_records {
232                    let reader_schema =
233                        csv_reader_schema_for_skip_bad_records(schema, &compat_schema);
234                    tolerant_csv_stream(
235                        object_store,
236                        path,
237                        Arc::new(reader_schema),
238                        projection.clone(),
239                        format,
240                    )
241                    .await
242                    .context(error::BuildFileStreamSnafu)?
243                } else {
244                    file_to_stream(
245                        object_store,
246                        path,
247                        csv_source,
248                        Some(projection),
249                        format.compression_type,
250                    )
251                    .await
252                    .context(error::BuildFileStreamSnafu)?
253                };
254
255                let stream = Box::pin(
256                    // The projection is already applied in the CSV reader when we created the stream,
257                    // so we pass None here to avoid double projection which would cause schema mismatch errors.
258                    RecordBatchStreamTypeAdapter::new(output_schema, stream, None)
259                        .with_filter(filters)
260                        .context(error::PhysicalExprSnafu)?,
261                );
262                if format.skip_bad_records {
263                    Ok(Box::pin(SkipBadRecordsStream::new(stream, path)))
264                } else {
265                    Ok(stream)
266                }
267            }
268            FileMetadata::Json {
269                path,
270                format,
271                schema,
272            } => {
273                let output_schema = Arc::new(
274                    compat_schema
275                        .project(&projection)
276                        .context(error::ProjectSchemaSnafu)?,
277                );
278
279                let json_source =
280                    JsonSource::new(schema.clone()).with_batch_size(DEFAULT_BATCH_SIZE);
281                let stream = file_to_stream(
282                    object_store,
283                    path,
284                    json_source,
285                    Some(projection),
286                    format.compression_type,
287                )
288                .await
289                .context(error::BuildFileStreamSnafu)?;
290
291                Ok(Box::pin(
292                    // The projection is already applied in the JSON reader when we created the stream,
293                    // so we pass None here to avoid double projection which would cause schema mismatch errors.
294                    RecordBatchStreamTypeAdapter::new(output_schema, stream, None)
295                        .with_filter(filters)
296                        .context(error::PhysicalExprSnafu)?,
297                ))
298            }
299            FileMetadata::Parquet { metadata, path, .. } => {
300                let meta = object_store
301                    .stat(path)
302                    .await
303                    .context(error::ReadObjectSnafu { path })?;
304                let reader = object_store
305                    .reader_with(path)
306                    .chunk(DEFAULT_READ_BUFFER)
307                    .await
308                    .context(error::ReadObjectSnafu { path })?
309                    .into_futures_async_read(0..meta.content_length())
310                    .await
311                    .context(error::ReadObjectSnafu { path })?
312                    .compat();
313                let builder =
314                    ParquetRecordBatchStreamBuilder::new_with_metadata(reader, metadata.clone());
315                let stream = builder
316                    .build()
317                    .context(error::BuildParquetRecordBatchStreamSnafu)?;
318
319                let output_schema = Arc::new(
320                    compat_schema
321                        .project(&projection)
322                        .context(error::ProjectSchemaSnafu)?,
323                );
324                Ok(Box::pin(
325                    RecordBatchStreamTypeAdapter::new(output_schema, stream, Some(projection))
326                        .with_filter(filters)
327                        .context(error::PhysicalExprSnafu)?,
328                ))
329            }
330            FileMetadata::Orc { path, .. } => {
331                let meta = object_store
332                    .stat(path)
333                    .await
334                    .context(error::ReadObjectSnafu { path })?;
335
336                let reader = object_store
337                    .reader_with(path)
338                    .chunk(DEFAULT_READ_BUFFER)
339                    .await
340                    .context(error::ReadObjectSnafu { path })?;
341                let stream =
342                    new_orc_stream_reader(ReaderAdapter::new(reader, meta.content_length()))
343                        .await
344                        .context(error::ReadOrcSnafu)?;
345
346                let output_schema = Arc::new(
347                    compat_schema
348                        .project(&projection)
349                        .context(error::ProjectSchemaSnafu)?,
350                );
351
352                Ok(Box::pin(
353                    RecordBatchStreamTypeAdapter::new(output_schema, stream, Some(projection))
354                        .with_filter(filters)
355                        .context(error::PhysicalExprSnafu)?,
356                ))
357            }
358        }
359    }
360
361    #[tracing::instrument(skip_all)]
362    pub async fn copy_table_from(
363        &self,
364        req: CopyTableRequest,
365        query_ctx: QueryContextRef,
366    ) -> Result<Output> {
367        let table_ref = TableReference {
368            catalog: &req.catalog_name,
369            schema: &req.schema_name,
370            table: &req.table_name,
371        };
372        let table = self.get_table(&table_ref).await?;
373        let format = Format::try_from(&req.with).context(error::ParseFileFormatSnafu)?;
374        let (object_store, entries) = self.list_copy_from_entries(&req).await?;
375        let mut files = Vec::with_capacity(entries.len());
376        let table_schema = table.schema().arrow_schema().clone();
377        let filters = table
378            .schema()
379            .timestamp_column()
380            .and_then(|c| {
381                common_query::logical_plan::build_same_type_ts_filter(c, req.timestamp_range)
382            })
383            .into_iter()
384            .collect::<Vec<_>>();
385
386        for entry in entries.iter() {
387            if entry.metadata().mode() != EntryMode::FILE {
388                continue;
389            }
390            let path = entry.path();
391            let file_metadata = self
392                .collect_metadata(&object_store, format.clone(), path.to_string())
393                .await?;
394
395            let schema_mapping = copy_from_schema_mapping(&file_metadata, &table_schema);
396            let projected_file_schema = Arc::new(
397                file_metadata
398                    .schema()
399                    .project(&schema_mapping.file_projection)
400                    .context(error::ProjectSchemaSnafu)?,
401            );
402            let projected_table_schema = Arc::new(
403                table_schema
404                    .project(&schema_mapping.table_projection)
405                    .context(error::ProjectSchemaSnafu)?,
406            );
407            ensure_schema_compatible(&projected_file_schema, &projected_table_schema)?;
408
409            files.push((
410                Arc::new(schema_mapping.compat_file_schema),
411                schema_mapping.file_projection,
412                projected_table_schema,
413                file_metadata,
414            ))
415        }
416
417        let mut rows_inserted = 0;
418        let mut insert_cost = 0;
419        let max_insert_rows = req.limit.map(|n| n as usize);
420        for (compat_schema, file_schema_projection, projected_table_schema, file_metadata) in files
421        {
422            let mut stream = self
423                .build_read_stream(
424                    compat_schema,
425                    &object_store,
426                    &file_metadata,
427                    file_schema_projection,
428                    filters.clone(),
429                )
430                .await?;
431
432            let fields = projected_table_schema
433                .fields()
434                .iter()
435                .map(|f| f.name().clone())
436                .collect::<Vec<_>>();
437
438            // TODO(hl): make this configurable through options.
439            let pending_mem_threshold = ReadableSize::mb(32).as_bytes();
440            let mut pending_mem_size = 0;
441            let mut pending = vec![];
442
443            while let Some(r) = stream.next().await {
444                let record_batch = r.context(error::ReadDfRecordBatchSnafu)?;
445                let vectors =
446                    Helper::try_into_vectors(record_batch.columns()).context(IntoVectorsSnafu)?;
447
448                pending_mem_size += vectors.iter().map(|v| v.memory_size()).sum::<usize>();
449
450                let columns_values = fields
451                    .iter()
452                    .cloned()
453                    .zip(vectors)
454                    .collect::<HashMap<_, _>>();
455
456                pending.push(self.inserter.handle_table_insert(
457                    InsertRequest {
458                        catalog_name: req.catalog_name.clone(),
459                        schema_name: req.schema_name.clone(),
460                        table_name: req.table_name.clone(),
461                        columns_values,
462                    },
463                    query_ctx.clone(),
464                ));
465
466                if pending_mem_size as u64 >= pending_mem_threshold {
467                    let (rows, cost) = batch_insert(&mut pending, &mut pending_mem_size).await?;
468                    rows_inserted += rows;
469                    insert_cost += cost;
470                }
471
472                if let Some(max_insert_rows) = max_insert_rows
473                    && rows_inserted >= max_insert_rows
474                {
475                    return Ok(gen_insert_output(rows_inserted, insert_cost));
476                }
477            }
478
479            if !pending.is_empty() {
480                let (rows, cost) = batch_insert(&mut pending, &mut pending_mem_size).await?;
481                rows_inserted += rows;
482                insert_cost += cost;
483            }
484        }
485
486        Ok(gen_insert_output(rows_inserted, insert_cost))
487    }
488}
489
490fn gen_insert_output(rows_inserted: usize, insert_cost: usize) -> Output {
491    Output::new(
492        OutputData::AffectedRows(rows_inserted),
493        OutputMeta::new_with_cost(insert_cost),
494    )
495}
496
497struct SkipBadRecordsStream {
498    inner: DfSendableRecordBatchStream,
499    path: String,
500}
501
502impl SkipBadRecordsStream {
503    fn new(inner: DfSendableRecordBatchStream, path: impl Into<String>) -> Self {
504        Self {
505            inner,
506            path: path.into(),
507        }
508    }
509}
510
511impl datafusion::physical_plan::RecordBatchStream for SkipBadRecordsStream {
512    fn schema(&self) -> SchemaRef {
513        self.inner.schema()
514    }
515}
516
517impl futures::Stream for SkipBadRecordsStream {
518    type Item = datafusion_common::Result<RecordBatch>;
519
520    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
521        let this = self.get_mut();
522        loop {
523            match this.inner.as_mut().poll_next(cx) {
524                Poll::Ready(Some(Err(error))) if is_skippable_record_error(&error) => {
525                    common_telemetry::warn!(
526                        "Skipping bad record while copying from {}: {}",
527                        this.path,
528                        error
529                    );
530                    continue;
531                }
532                other => return other,
533            }
534        }
535    }
536}
537
538fn is_skippable_record_error(error: &DataFusionError) -> bool {
539    match error {
540        DataFusionError::ArrowError(error, _) => is_skippable_arrow_error(error),
541        DataFusionError::External(error) => error
542            .downcast_ref::<ArrowError>()
543            .is_some_and(is_skippable_arrow_error),
544        DataFusionError::Context(_, error) => is_skippable_record_error(error),
545        _ => false,
546    }
547}
548
549/// Executes all pending inserts all at once, drain pending requests and reset pending bytes.
550async fn batch_insert(
551    pending: &mut Vec<impl Future<Output = Result<Output>>>,
552    pending_bytes: &mut usize,
553) -> Result<(OutputRows, OutputCost)> {
554    let batch = pending.drain(..);
555    let result = futures::future::try_join_all(batch)
556        .await?
557        .iter()
558        .map(|o| o.extract_rows_and_cost())
559        .reduce(|(a, b), (c, d)| (a + c, b + d))
560        .unwrap_or((0, 0));
561    *pending_bytes = 0;
562    Ok(result)
563}
564
565/// Custom type compatibility check for GreptimeDB that handles Map -> Binary (JSON) conversion
566fn can_cast_types_for_greptime(from: &ArrowDataType, to: &ArrowDataType) -> bool {
567    // Handle Map -> Binary conversion for JSON types
568    if let ArrowDataType::Map(_, _) = from
569        && let ArrowDataType::Binary = to
570    {
571        return true;
572    }
573
574    // For all other cases, use Arrow's built-in can_cast_types
575    can_cast_types(from, to)
576}
577
578fn csv_reader_schema_for_skip_bad_records(file: &SchemaRef, compat: &SchemaRef) -> Schema {
579    let fields = file
580        .fields()
581        .iter()
582        .enumerate()
583        .map(|(idx, file_field)| match compat.fields().get(idx) {
584            Some(compat_field) if can_csv_reader_parse_type(compat_field.data_type()) => {
585                compat_field.clone()
586            }
587            _ => file_field.clone(),
588        })
589        .collect::<Vec<_>>();
590
591    Schema::new_with_metadata(fields, file.metadata().clone())
592}
593
594fn can_csv_reader_parse_type(data_type: &ArrowDataType) -> bool {
595    match data_type {
596        ArrowDataType::Boolean
597        | ArrowDataType::Decimal32(_, _)
598        | ArrowDataType::Decimal64(_, _)
599        | ArrowDataType::Decimal128(_, _)
600        | ArrowDataType::Decimal256(_, _)
601        | ArrowDataType::Int8
602        | ArrowDataType::Int16
603        | ArrowDataType::Int32
604        | ArrowDataType::Int64
605        | ArrowDataType::UInt8
606        | ArrowDataType::UInt16
607        | ArrowDataType::UInt32
608        | ArrowDataType::UInt64
609        | ArrowDataType::Float32
610        | ArrowDataType::Float64
611        | ArrowDataType::Date32
612        | ArrowDataType::Date64
613        | ArrowDataType::Time32(_)
614        | ArrowDataType::Time64(_)
615        | ArrowDataType::Timestamp(_, _)
616        | ArrowDataType::Null
617        | ArrowDataType::Utf8
618        | ArrowDataType::Utf8View => true,
619        ArrowDataType::Dictionary(_, value_type) => value_type.as_ref() == &ArrowDataType::Utf8,
620        _ => false,
621    }
622}
623
624fn ensure_schema_compatible(from: &SchemaRef, to: &SchemaRef) -> Result<()> {
625    let not_match = from
626        .fields
627        .iter()
628        .zip(to.fields.iter())
629        .map(|(l, r)| (l.data_type(), r.data_type()))
630        .enumerate()
631        .find(|(_, (l, r))| !can_cast_types_for_greptime(l, r));
632
633    if let Some((index, _)) = not_match {
634        error::InvalidSchemaSnafu {
635            index,
636            table_schema: to.to_string(),
637            file_schema: from.to_string(),
638        }
639        .fail()
640    } else {
641        Ok(())
642    }
643}
644
645/// Generates a maybe compatible schema of the file schema.
646///
647/// If there is a field is found in table schema,
648/// copy the field data type to maybe compatible schema(`compatible_fields`).
649fn generated_schema_projection_and_compatible_file_schema(
650    file: &SchemaRef,
651    table: &SchemaRef,
652) -> (Vec<usize>, Vec<usize>, Schema) {
653    let mut file_projection = Vec::with_capacity(file.fields.len());
654    let mut table_projection = Vec::with_capacity(file.fields.len());
655    let mut compatible_fields = file.fields.iter().cloned().collect::<Vec<_>>();
656    for (file_idx, file_field) in file.fields.iter().enumerate() {
657        if let Some((table_idx, table_field)) = table.fields.find(file_field.name()) {
658            file_projection.push(file_idx);
659            table_projection.push(table_idx);
660
661            // Safety: the compatible_fields has same length as file schema
662            compatible_fields[file_idx] = table_field.clone();
663        }
664    }
665
666    (
667        file_projection,
668        table_projection,
669        Schema::new(compatible_fields),
670    )
671}
672
673struct CopyFromSchemaMapping {
674    file_projection: Vec<usize>,
675    table_projection: Vec<usize>,
676    compat_file_schema: Schema,
677}
678
679fn copy_from_schema_mapping(
680    file_metadata: &FileMetadata,
681    table: &SchemaRef,
682) -> CopyFromSchemaMapping {
683    match file_metadata {
684        FileMetadata::Csv { schema, format, .. } if !format.has_header => {
685            generated_positional_schema_projection_and_compatible_file_schema(schema, table)
686        }
687        _ => {
688            let (file_projection, table_projection, compat_file_schema) =
689                generated_schema_projection_and_compatible_file_schema(
690                    file_metadata.schema(),
691                    table,
692                );
693            CopyFromSchemaMapping {
694                file_projection,
695                table_projection,
696                compat_file_schema,
697            }
698        }
699    }
700}
701
702fn generated_positional_schema_projection_and_compatible_file_schema(
703    file: &SchemaRef,
704    table: &SchemaRef,
705) -> CopyFromSchemaMapping {
706    let len = file.fields.len().min(table.fields.len());
707    let file_projection = (0..len).collect::<Vec<_>>();
708    let table_projection = (0..len).collect::<Vec<_>>();
709    let compatible_fields = file
710        .fields
711        .iter()
712        .enumerate()
713        .map(|(idx, file_field)| {
714            if idx < len {
715                table.fields[idx].clone()
716            } else {
717                file_field.clone()
718            }
719        })
720        .collect::<Vec<_>>();
721
722    CopyFromSchemaMapping {
723        file_projection,
724        table_projection,
725        compat_file_schema: Schema::new(compatible_fields),
726    }
727}
728
729#[cfg(test)]
730mod tests {
731    use std::sync::Arc;
732
733    use datatypes::arrow::datatypes::{DataType, Field, Schema};
734
735    use super::*;
736
737    fn test_schema_matches(from: (DataType, bool), to: (DataType, bool), matches: bool) {
738        let s1 = Arc::new(Schema::new(vec![Field::new("col", from.0.clone(), from.1)]));
739        let s2 = Arc::new(Schema::new(vec![Field::new("col", to.0.clone(), to.1)]));
740        let res = ensure_schema_compatible(&s1, &s2);
741        assert_eq!(
742            matches,
743            res.is_ok(),
744            "from data type: {}, to data type: {}, expected: {}, but got: {}",
745            from.0,
746            to.0,
747            matches,
748            res.is_ok()
749        )
750    }
751
752    #[test]
753    fn test_ensure_datatype_matches_ignore_timezone() {
754        test_schema_matches(
755            (
756                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Second, None),
757                true,
758            ),
759            (
760                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Second, None),
761                true,
762            ),
763            true,
764        );
765
766        test_schema_matches(
767            (
768                DataType::Timestamp(
769                    datatypes::arrow::datatypes::TimeUnit::Second,
770                    Some("UTC".into()),
771                ),
772                true,
773            ),
774            (
775                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Second, None),
776                true,
777            ),
778            true,
779        );
780
781        test_schema_matches(
782            (
783                DataType::Timestamp(
784                    datatypes::arrow::datatypes::TimeUnit::Second,
785                    Some("UTC".into()),
786                ),
787                true,
788            ),
789            (
790                DataType::Timestamp(
791                    datatypes::arrow::datatypes::TimeUnit::Second,
792                    Some("PDT".into()),
793                ),
794                true,
795            ),
796            true,
797        );
798
799        test_schema_matches(
800            (
801                DataType::Timestamp(
802                    datatypes::arrow::datatypes::TimeUnit::Second,
803                    Some("UTC".into()),
804                ),
805                true,
806            ),
807            (
808                DataType::Timestamp(
809                    datatypes::arrow::datatypes::TimeUnit::Millisecond,
810                    Some("UTC".into()),
811                ),
812                true,
813            ),
814            true,
815        );
816
817        test_schema_matches((DataType::Int8, true), (DataType::Int8, true), true);
818
819        test_schema_matches((DataType::Int8, true), (DataType::Int16, true), true);
820    }
821
822    #[test]
823    fn test_data_type_equals_ignore_timezone_with_options() {
824        test_schema_matches(
825            (
826                DataType::Timestamp(
827                    datatypes::arrow::datatypes::TimeUnit::Microsecond,
828                    Some("UTC".into()),
829                ),
830                true,
831            ),
832            (
833                DataType::Timestamp(
834                    datatypes::arrow::datatypes::TimeUnit::Millisecond,
835                    Some("PDT".into()),
836                ),
837                true,
838            ),
839            true,
840        );
841
842        test_schema_matches(
843            (DataType::Utf8, true),
844            (
845                DataType::Timestamp(
846                    datatypes::arrow::datatypes::TimeUnit::Millisecond,
847                    Some("PDT".into()),
848                ),
849                true,
850            ),
851            true,
852        );
853
854        test_schema_matches(
855            (
856                DataType::Timestamp(
857                    datatypes::arrow::datatypes::TimeUnit::Millisecond,
858                    Some("PDT".into()),
859                ),
860                true,
861            ),
862            (DataType::Utf8, true),
863            true,
864        );
865    }
866
867    #[test]
868    fn test_map_to_binary_json_compatibility() {
869        // Test Map -> Binary conversion for JSON types
870        let map_type = DataType::Map(
871            Arc::new(Field::new(
872                "key_value",
873                DataType::Struct(
874                    vec![
875                        Field::new("key", DataType::Utf8, false),
876                        Field::new("value", DataType::Utf8, false),
877                    ]
878                    .into(),
879                ),
880                false,
881            )),
882            false,
883        );
884
885        test_schema_matches((map_type, false), (DataType::Binary, true), true);
886
887        test_schema_matches((DataType::Int8, true), (DataType::Int16, true), true);
888        test_schema_matches((DataType::Utf8, true), (DataType::Binary, true), true);
889    }
890
891    fn make_test_schema(v: &[Field]) -> Arc<Schema> {
892        Arc::new(Schema::new(v.to_vec()))
893    }
894
895    #[test]
896    fn test_compatible_file_schema() {
897        let file_schema0 = make_test_schema(&[
898            Field::new("c1", DataType::UInt8, true),
899            Field::new("c2", DataType::UInt8, true),
900        ]);
901
902        let table_schema = make_test_schema(&[
903            Field::new("c1", DataType::Int16, true),
904            Field::new("c2", DataType::Int16, true),
905            Field::new("c3", DataType::Int16, true),
906        ]);
907
908        let compat_schema = make_test_schema(&[
909            Field::new("c1", DataType::Int16, true),
910            Field::new("c2", DataType::Int16, true),
911        ]);
912
913        let (_, tp, _) =
914            generated_schema_projection_and_compatible_file_schema(&file_schema0, &table_schema);
915
916        assert_eq!(table_schema.project(&tp).unwrap(), *compat_schema);
917    }
918
919    #[test]
920    fn test_schema_projection() {
921        let file_schema0 = make_test_schema(&[
922            Field::new("c1", DataType::UInt8, true),
923            Field::new("c2", DataType::UInt8, true),
924            Field::new("c3", DataType::UInt8, true),
925        ]);
926
927        let file_schema1 = make_test_schema(&[
928            Field::new("c3", DataType::UInt8, true),
929            Field::new("c4", DataType::UInt8, true),
930        ]);
931
932        let file_schema2 = make_test_schema(&[
933            Field::new("c3", DataType::UInt8, true),
934            Field::new("c4", DataType::UInt8, true),
935            Field::new("c5", DataType::UInt8, true),
936        ]);
937
938        let file_schema3 = make_test_schema(&[
939            Field::new("c1", DataType::UInt8, true),
940            Field::new("c2", DataType::UInt8, true),
941        ]);
942
943        let table_schema = make_test_schema(&[
944            Field::new("c3", DataType::UInt8, true),
945            Field::new("c4", DataType::UInt8, true),
946            Field::new("c5", DataType::UInt8, true),
947        ]);
948
949        let tests = [
950            (&file_schema0, &table_schema, true), // intersection
951            (&file_schema1, &table_schema, true), // subset
952            (&file_schema2, &table_schema, true), // full-eq
953            (&file_schema3, &table_schema, true), // non-intersection
954        ];
955
956        for test in tests {
957            let (fp, tp, _) =
958                generated_schema_projection_and_compatible_file_schema(test.0, test.1);
959            assert_eq!(test.0.project(&fp).unwrap(), test.1.project(&tp).unwrap());
960        }
961    }
962
963    #[test]
964    fn test_csv_reader_schema_for_skip_bad_records() {
965        let file_schema = make_test_schema(&[
966            Field::new("id", DataType::Utf8, true),
967            Field::new("jsons", DataType::Utf8, true),
968            Field::new("ts", DataType::Utf8, true),
969        ]);
970        let compat_schema = make_test_schema(&[
971            Field::new("id", DataType::UInt32, true),
972            Field::new("jsons", DataType::Binary, true),
973            Field::new(
974                "ts",
975                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, None),
976                true,
977            ),
978        ]);
979
980        let reader_schema = csv_reader_schema_for_skip_bad_records(&file_schema, &compat_schema);
981
982        assert_eq!(reader_schema.field(0).data_type(), &DataType::UInt32);
983        assert_eq!(reader_schema.field(1).data_type(), &DataType::Utf8);
984        assert_eq!(
985            reader_schema.field(2).data_type(),
986            compat_schema.field(2).data_type()
987        );
988    }
989
990    fn make_csv_metadata(schema: Arc<Schema>, has_header: bool) -> FileMetadata {
991        FileMetadata::Csv {
992            schema,
993            format: CsvFormat {
994                has_header,
995                ..CsvFormat::default()
996            },
997            path: "test.csv".to_string(),
998        }
999    }
1000
1001    fn assert_field(schema: &Schema, idx: usize, name: &str, data_type: &DataType) {
1002        let field = schema.field(idx);
1003        assert_eq!(field.name(), name);
1004        assert_eq!(field.data_type(), data_type);
1005    }
1006
1007    #[test]
1008    fn test_headerless_csv_schema_projection_is_positional() {
1009        let file_schema = make_test_schema(&[
1010            Field::new("column_1", DataType::UInt8, true),
1011            Field::new("column_2", DataType::Float64, true),
1012            Field::new("column_3", DataType::Utf8, true),
1013        ]);
1014        let table_schema = make_test_schema(&[
1015            Field::new("host_id", DataType::UInt32, true),
1016            Field::new("reading_value", DataType::Float64, true),
1017            Field::new(
1018                "ts",
1019                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, None),
1020                true,
1021            ),
1022        ]);
1023
1024        let mapping =
1025            copy_from_schema_mapping(&make_csv_metadata(file_schema, false), &table_schema);
1026
1027        assert_eq!(mapping.file_projection, vec![0, 1, 2]);
1028        assert_eq!(mapping.table_projection, vec![0, 1, 2]);
1029        assert_field(&mapping.compat_file_schema, 0, "host_id", &DataType::UInt32);
1030        assert_field(
1031            &mapping.compat_file_schema,
1032            1,
1033            "reading_value",
1034            &DataType::Float64,
1035        );
1036        assert_field(
1037            &mapping.compat_file_schema,
1038            2,
1039            "ts",
1040            table_schema.field(2).data_type(),
1041        );
1042        assert_eq!(
1043            mapping
1044                .compat_file_schema
1045                .project(&mapping.file_projection)
1046                .unwrap(),
1047            table_schema.project(&mapping.table_projection).unwrap()
1048        );
1049    }
1050
1051    #[test]
1052    fn test_headerless_csv_schema_projection_ignores_extra_file_columns() {
1053        let file_schema = make_test_schema(&[
1054            Field::new("column_1", DataType::UInt8, true),
1055            Field::new("column_2", DataType::Float64, true),
1056            Field::new("column_3", DataType::Utf8, true),
1057            Field::new("column_4", DataType::Utf8, true),
1058        ]);
1059        let table_schema = make_test_schema(&[
1060            Field::new("host_id", DataType::UInt32, true),
1061            Field::new("reading_value", DataType::Float64, true),
1062            Field::new("ts", DataType::Utf8, true),
1063        ]);
1064
1065        let mapping =
1066            copy_from_schema_mapping(&make_csv_metadata(file_schema, false), &table_schema);
1067
1068        assert_eq!(mapping.file_projection, vec![0, 1, 2]);
1069        assert_eq!(mapping.table_projection, vec![0, 1, 2]);
1070        assert_eq!(mapping.compat_file_schema.fields().len(), 4);
1071        assert_field(&mapping.compat_file_schema, 0, "host_id", &DataType::UInt32);
1072        assert_field(
1073            &mapping.compat_file_schema,
1074            1,
1075            "reading_value",
1076            &DataType::Float64,
1077        );
1078        assert_field(&mapping.compat_file_schema, 2, "ts", &DataType::Utf8);
1079        assert_field(&mapping.compat_file_schema, 3, "column_4", &DataType::Utf8);
1080    }
1081
1082    #[test]
1083    fn test_headerless_csv_schema_projection_supports_prefix_import() {
1084        let file_schema = make_test_schema(&[
1085            Field::new("column_1", DataType::UInt8, true),
1086            Field::new("column_2", DataType::Float64, true),
1087        ]);
1088        let table_schema = make_test_schema(&[
1089            Field::new("host_id", DataType::UInt32, true),
1090            Field::new("reading_value", DataType::Float64, true),
1091            Field::new("ts", DataType::Utf8, true),
1092        ]);
1093
1094        let mapping =
1095            copy_from_schema_mapping(&make_csv_metadata(file_schema, false), &table_schema);
1096
1097        assert_eq!(mapping.file_projection, vec![0, 1]);
1098        assert_eq!(mapping.table_projection, vec![0, 1]);
1099        assert_field(&mapping.compat_file_schema, 0, "host_id", &DataType::UInt32);
1100        assert_field(
1101            &mapping.compat_file_schema,
1102            1,
1103            "reading_value",
1104            &DataType::Float64,
1105        );
1106        assert_eq!(
1107            mapping
1108                .compat_file_schema
1109                .project(&mapping.file_projection)
1110                .unwrap(),
1111            table_schema.project(&mapping.table_projection).unwrap()
1112        );
1113    }
1114
1115    #[test]
1116    fn test_csv_reader_schema_for_skip_bad_records_uses_positional_mapping() {
1117        let file_schema = make_test_schema(&[
1118            Field::new("column_1", DataType::Utf8, true),
1119            Field::new("column_2", DataType::Utf8, true),
1120            Field::new("column_3", DataType::Utf8, true),
1121        ]);
1122        let table_schema = make_test_schema(&[
1123            Field::new("host_id", DataType::UInt32, true),
1124            Field::new("jsons", DataType::Binary, true),
1125            Field::new(
1126                "ts",
1127                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, None),
1128                true,
1129            ),
1130        ]);
1131        let mapping = copy_from_schema_mapping(
1132            &make_csv_metadata(file_schema.clone(), false),
1133            &table_schema,
1134        );
1135        let compat_schema = Arc::new(mapping.compat_file_schema);
1136
1137        let reader_schema = csv_reader_schema_for_skip_bad_records(&file_schema, &compat_schema);
1138
1139        assert_eq!(reader_schema.field(0).data_type(), &DataType::UInt32);
1140        assert_eq!(reader_schema.field(1).data_type(), &DataType::Utf8);
1141        assert_eq!(
1142            reader_schema.field(2).data_type(),
1143            table_schema.field(2).data_type()
1144        );
1145    }
1146}