Skip to main content

common_datasource/file_format/
csv.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::io;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::task::Poll;
20
21use arrow::csv::reader::Format;
22use arrow::csv::{self, WriterBuilder};
23use arrow::error::ArrowError;
24use arrow::record_batch::RecordBatch;
25use arrow_schema::{Schema, SchemaRef};
26use async_trait::async_trait;
27use bytes::{Buf, Bytes};
28use common_runtime;
29use common_telemetry::warn;
30use datafusion::physical_plan::SendableRecordBatchStream;
31use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
32use futures::StreamExt;
33use futures::stream::BoxStream;
34use object_store::ObjectStore;
35use snafu::ResultExt;
36use tokio_util::compat::FuturesAsyncReadCompatExt;
37use tokio_util::io::SyncIoBridge;
38
39use crate::buffered_writer::DfRecordBatchEncoder;
40use crate::compression::CompressionType;
41use crate::error::{self, Result};
42use crate::file_format::{self, FileFormat, stream_to_file};
43use crate::share_buffer::SharedBuffer;
44use crate::util::normalize_infer_schema;
45
/// Batch size used by the tolerant (skip-bad-records) CSV decoder.
///
/// One record per batch means a failed flush discards only the offending
/// record rather than its neighbours.
const SKIP_BAD_RECORDS_BATCH_SIZE: usize = 1;
47
/// CSV read/write options: header handling, delimiter, schema-inference
/// limit, compression, and optional temporal format strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvFormat {
    /// Whether the first row of the file is a header row rather than data.
    pub has_header: bool,
    /// Whether bad records may be skipped instead of failing the read.
    /// NOTE(review): not read by the code visible here; presumably it selects
    /// the tolerant reader at call sites — confirm there.
    pub skip_bad_records: bool,
    /// Field delimiter byte (`b','` by default).
    pub delimiter: u8,
    /// Maximum number of records examined during schema inference; forwarded
    /// unchanged to arrow's `Format::infer_schema`.
    pub schema_infer_max_record: Option<usize>,
    /// Compression applied to the file contents.
    pub compression_type: CompressionType,
    /// Optional format string for timestamp columns, applied by the CSV writer.
    pub timestamp_format: Option<String>,
    /// Optional format string for time columns, applied by the CSV writer.
    pub time_format: Option<String>,
    /// Optional format string for date columns, applied by the CSV writer.
    pub date_format: Option<String>,
}
59
60impl TryFrom<&HashMap<String, String>> for CsvFormat {
61    type Error = error::Error;
62
63    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
64        let mut format = CsvFormat::default();
65        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
66            // TODO(weny): considers to support parse like "\t" (not only b'\t')
67            format.delimiter = u8::from_str(delimiter).map_err(|_| {
68                error::ParseFormatSnafu {
69                    key: file_format::FORMAT_DELIMITER,
70                    value: delimiter,
71                }
72                .build()
73            })?;
74        };
75        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
76            format.compression_type = CompressionType::from_str(compression_type)?;
77        };
78        if let Some(schema_infer_max_record) =
79            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
80        {
81            format.schema_infer_max_record =
82                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
83                    error::ParseFormatSnafu {
84                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
85                        value: schema_infer_max_record,
86                    }
87                    .build()
88                })?);
89        };
90        let headers = value
91            .get(file_format::FORMAT_HEADERS)
92            .map(|headers| parse_bool(file_format::FORMAT_HEADERS, headers))
93            .transpose()?;
94        let has_header = value
95            .get(file_format::FORMAT_HAS_HEADER)
96            .map(|has_header| parse_bool(file_format::FORMAT_HAS_HEADER, has_header))
97            .transpose()?;
98        match (headers, has_header) {
99            (Some(headers), Some(has_header)) if headers != has_header => {
100                return error::ParseFormatSnafu {
101                    key: file_format::FORMAT_HEADERS,
102                    value: format!("headers={headers}, has_header={has_header}"),
103                }
104                .fail();
105            }
106            (Some(headers), _) => format.has_header = headers,
107            (_, Some(has_header)) => format.has_header = has_header,
108            _ => {}
109        }
110        if let Some(skip_bad_records) = value.get(file_format::FORMAT_SKIP_BAD_RECORDS) {
111            format.skip_bad_records =
112                parse_bool(file_format::FORMAT_SKIP_BAD_RECORDS, skip_bad_records)?;
113        };
114        if let Some(timestamp_format) = value.get(file_format::TIMESTAMP_FORMAT) {
115            format.timestamp_format = Some(timestamp_format.clone());
116        }
117        if let Some(time_format) = value.get(file_format::TIME_FORMAT) {
118            format.time_format = Some(time_format.clone());
119        }
120        if let Some(date_format) = value.get(file_format::DATE_FORMAT) {
121            format.date_format = Some(date_format.clone());
122        }
123        Ok(format)
124    }
125}
126
127fn parse_bool(key: &'static str, value: &str) -> Result<bool> {
128    value
129        .parse()
130        .map_err(|_| error::ParseFormatSnafu { key, value }.build())
131}
132
133impl Default for CsvFormat {
134    fn default() -> Self {
135        Self {
136            has_header: true,
137            skip_bad_records: false,
138            delimiter: b',',
139            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
140            compression_type: CompressionType::Uncompressed,
141            timestamp_format: None,
142            time_format: None,
143            date_format: None,
144        }
145    }
146}
147
148#[async_trait]
149impl FileFormat for CsvFormat {
150    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
151        let meta = store
152            .stat(path)
153            .await
154            .context(error::ReadObjectSnafu { path })?;
155
156        let reader = store
157            .reader(path)
158            .await
159            .context(error::ReadObjectSnafu { path })?
160            .into_futures_async_read(0..meta.content_length())
161            .await
162            .context(error::ReadObjectSnafu { path })?
163            .compat();
164
165        let decoded = self.compression_type.convert_async_read(reader);
166
167        let delimiter = self.delimiter;
168        let schema_infer_max_record = self.schema_infer_max_record;
169        let has_header = self.has_header;
170
171        common_runtime::spawn_blocking_global(move || {
172            let reader = SyncIoBridge::new(decoded);
173
174            let format = Format::default()
175                .with_delimiter(delimiter)
176                .with_header(has_header);
177            let (schema, _records_read) = format
178                .infer_schema(reader, schema_infer_max_record)
179                .context(error::InferSchemaSnafu)?;
180
181            Ok(normalize_infer_schema(schema))
182        })
183        .await
184        .context(error::JoinHandleSnafu)?
185    }
186}
187
188pub async fn stream_to_csv(
189    stream: SendableRecordBatchStream,
190    store: ObjectStore,
191    path: &str,
192    threshold: usize,
193    concurrency: usize,
194    format: &CsvFormat,
195) -> Result<usize> {
196    stream_to_file(
197        stream,
198        store,
199        path,
200        threshold,
201        concurrency,
202        format.compression_type,
203        |buffer| {
204            let mut builder = WriterBuilder::new();
205            if let Some(timestamp_format) = &format.timestamp_format {
206                builder = builder.with_timestamp_format(timestamp_format.to_owned())
207            }
208            if let Some(date_format) = &format.date_format {
209                builder = builder.with_date_format(date_format.to_owned())
210            }
211            if let Some(time_format) = &format.time_format {
212                builder = builder.with_time_format(time_format.to_owned())
213            }
214            builder.build(buffer)
215        },
216    )
217    .await
218}
219
220impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
221    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
222        self.write(batch).context(error::WriteRecordBatchSnafu)
223    }
224}
225
226/// Builds a CSV stream that can skip selected record-level parse/cast errors.
227///
228/// This recovery path intentionally uses one-record batches. It is slower than
229/// normal CSV scanning, but keeps each parse/cast failure isolated to a single
230/// record. Arrow's CSV decoder clears buffered rows before type parsing, so a
231/// failed multi-row flush cannot be safely retried row by row without replaying
232/// input bytes.
233pub async fn tolerant_csv_stream(
234    store: &ObjectStore,
235    path: &str,
236    schema: SchemaRef,
237    projection: Vec<usize>,
238    format: &CsvFormat,
239) -> Result<SendableRecordBatchStream> {
240    let meta = store
241        .stat(path)
242        .await
243        .context(error::ReadObjectSnafu { path })?;
244
245    let reader = store
246        .reader(path)
247        .await
248        .context(error::ReadObjectSnafu { path })?
249        .into_bytes_stream(0..meta.content_length())
250        .await
251        .context(error::ReadObjectSnafu { path })?;
252
253    let reader = format.compression_type.convert_stream(reader).boxed();
254    tolerant_csv_stream_from_reader(
255        reader,
256        path,
257        schema,
258        projection,
259        format.has_header,
260        format.delimiter,
261    )
262}
263
/// Adapts a raw stream of CSV bytes into a record-batch stream.
///
/// Flushes that fail with an error accepted by `is_skippable_arrow_error` are
/// logged with `warn!` and skipped.
///
/// `schema` describes every CSV column. `projection` selects the columns that
/// are decoded, and the returned stream reports the projected schema. The
/// decoder uses `SKIP_BAD_RECORDS_BATCH_SIZE`, so each emitted batch and each
/// skipped failure covers at most that many records.
///
/// # Errors
///
/// Fails up front only if `projection` cannot be applied to `schema`. Upstream
/// I/O errors, decode errors and non-skippable flush errors are yielded as
/// stream items instead.
fn tolerant_csv_stream_from_reader(
    reader: BoxStream<'static, io::Result<Bytes>>,
    path: &str,
    schema: SchemaRef,
    projection: Vec<usize>,
    has_header: bool,
    delimiter: u8,
) -> Result<SendableRecordBatchStream> {
    let projected_schema = Arc::new(
        schema
            .project(&projection)
            .context(error::InferSchemaSnafu)?,
    );
    let mut decoder = csv::ReaderBuilder::new(schema)
        .with_header(has_header)
        .with_delimiter(delimiter)
        .with_batch_size(SKIP_BAD_RECORDS_BATCH_SIZE)
        .with_projection(projection)
        .build_decoder();

    // Owned copy for log messages; the stream outlives the borrowed `path`.
    let path = path.to_string();
    let mut upstream = reader.fuse();
    // Bytes received from `upstream` that the decoder has not consumed yet.
    let mut buffered = Bytes::new();
    // Set once `upstream` has returned `None`.
    let mut input_finished = false;
    let stream = futures::stream::poll_fn(move |cx| {
        loop {
            // Phase 1: feed bytes to the decoder until it holds a full batch
            // (capacity 0) or the input is exhausted.
            while !input_finished {
                if buffered.is_empty() {
                    // `ready!` returns `Poll::Pending` from this poll if upstream
                    // has nothing yet; the decoder keeps its state for next time.
                    match futures::ready!(upstream.poll_next_unpin(cx)) {
                        Some(Ok(bytes)) if bytes.is_empty() => continue,
                        Some(Ok(bytes)) => buffered = bytes,
                        Some(Err(error)) => return Poll::Ready(Some(Err(error.into()))),
                        None => input_finished = true,
                    }
                }

                // Right after upstream ends, `buffered` is empty here. Passing
                // an empty slice tells arrow's decoder the input is over, which
                // lets a final record without a trailing newline complete.
                // `?` yields a decode error as a stream item.
                let decoded = decoder.decode(buffered.as_ref())?;
                if decoded > 0 {
                    buffered.advance(decoded);
                    continue;
                }

                // Nothing consumed: the decoder already holds a full batch, or
                // the input is exhausted. Either way, go flush.
                if decoder.capacity() == 0 || input_finished {
                    break;
                }

                // Defensive: no bytes left, so go back and poll for more.
                if buffered.is_empty() {
                    continue;
                }

                // Bytes remain but the decoder will not take them; bail out
                // instead of spinning forever.
                return Poll::Ready(Some(Err(ArrowError::ParseError(
                    "CSV decoder made no progress while input bytes remain".to_string(),
                ))));
            }

            // Phase 2: turn the buffered record(s) into a batch.
            match decoder.flush() {
                Ok(Some(batch)) => return Poll::Ready(Some(Ok(batch))),
                Ok(None) if input_finished => return Poll::Ready(None),
                // Nothing buffered yet, but more input may arrive.
                Ok(None) => continue,
                Err(error) if is_skippable_arrow_error(&error) => {
                    // Arrow clears buffered rows before type parsing, so the
                    // failed row is already gone: log it and keep reading.
                    warn!(
                        "Skipping bad CSV record while copying from {}: {}",
                        path, error
                    );
                }
                Err(error) => return Poll::Ready(Some(Err(error))),
            }
        }
    })
    // Convert `ArrowError` items into the error type expected by
    // `RecordBatchStreamAdapter` (presumably `DataFusionError` — confirm).
    .map(|result: std::result::Result<RecordBatch, ArrowError>| result.map_err(Into::into));

    Ok(Box::pin(RecordBatchStreamAdapter::new(
        projected_schema,
        stream,
    )))
}
340
341pub fn is_skippable_arrow_error(error: &ArrowError) -> bool {
342    matches!(
343        error,
344        ArrowError::ParseError(_)
345            | ArrowError::CastError(_)
346            | ArrowError::ComputeError(_)
347            | ArrowError::InvalidArgumentError(_)
348    )
349}
350
#[cfg(test)]
mod tests {
    // Unit tests for CSV option parsing, schema inference, compressed export
    // round-trips, and the tolerant (skip-bad-records) reader.

    use std::sync::Arc;

    use arrow_schema::{DataType, Field};
    use common_recordbatch::adapter::DfRecordBatchStreamAdapter;
    use common_recordbatch::{RecordBatch, RecordBatches};
    use common_test_util::find_workspace_path;
    use datafusion::datasource::physical_plan::{CsvSource, FileSource};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Float64Vector, StringVector, UInt32Vector, VectorRef};
    use futures::TryStreamExt;

    use super::*;
    use crate::file_format::{
        FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER, FORMAT_HEADERS,
        FORMAT_SCHEMA_INFER_MAX_RECORD, FORMAT_SKIP_BAD_RECORDS, FileFormat, file_to_stream,
    };
    use crate::test_util::{format_schema, test_store};

    /// Path of the CSV fixture directory (`src/common/datasource/tests/csv`
    /// under the workspace).
    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    /// Default options infer one column type per field of `simple.csv`.
    #[tokio::test]
    async fn infer_schema_basic() {
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "c1: Utf8: NULL",
                "c2: Int64: NULL",
                "c3: Int64: NULL",
                "c4: Int64: NULL",
                "c5: Int64: NULL",
                "c6: Int64: NULL",
                "c7: Int64: NULL",
                "c8: Int64: NULL",
                "c9: Int64: NULL",
                "c10: Utf8: NULL",
                "c11: Float64: NULL",
                "c12: Float64: NULL",
                "c13: Utf8: NULL"
            ],
            formatted,
        );
    }

    /// The inferred schema is passed through `normalize_infer_schema` (the
    /// expected `ts`/`t` columns come out as `Utf8`).
    #[tokio::test]
    async fn normalize_infer_schema() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "max_infer.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "num: Int64: NULL",
                "str: Utf8: NULL",
                "ts: Utf8: NULL",
                "t: Utf8: NULL",
                "date: Date32: NULL"
            ],
            formatted,
        );
    }

    /// `schema_infer_max_record` limits how many records are sampled, so
    /// column `d` changes type between a 3-record and a default-limit run.
    #[tokio::test]
    async fn infer_schema_with_limit() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            formatted
        );

        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            formatted
        );
    }

    /// Option maps override defaults key by key; `headers` and `has_header`
    /// behave as aliases.
    #[test]
    fn test_try_from() {
        let map = HashMap::new();
        let format: CsvFormat = CsvFormat::try_from(&map).unwrap();

        assert_eq!(format, CsvFormat::default());

        let map = HashMap::from([
            (
                FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
                "2000".to_string(),
            ),
            (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
            (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                compression_type: CompressionType::Zstd,
                schema_infer_max_record: Some(2000),
                delimiter: b'\t',
                has_header: false,
                skip_bad_records: false,
                timestamp_format: None,
                time_format: None,
                date_format: None
            }
        );

        let map = HashMap::from([(FORMAT_SKIP_BAD_RECORDS.to_string(), "true".to_string())]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                skip_bad_records: true,
                ..CsvFormat::default()
            }
        );

        let map = HashMap::from([(FORMAT_HEADERS.to_string(), "true".to_string())]);
        let format = CsvFormat::try_from(&map).unwrap();
        assert_eq!(format, CsvFormat::default());

        let map = HashMap::from([(FORMAT_HEADERS.to_string(), "false".to_string())]);
        let format = CsvFormat::try_from(&map).unwrap();
        assert_eq!(
            format,
            CsvFormat {
                has_header: false,
                ..CsvFormat::default()
            }
        );

        let map = HashMap::from([
            (FORMAT_HEADERS.to_string(), "false".to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();
        assert_eq!(
            format,
            CsvFormat {
                has_header: false,
                ..CsvFormat::default()
            }
        );
    }

    /// Boolean options accept only `true`/`false`.
    #[test]
    fn test_try_from_rejects_invalid_bool_options() {
        let map = HashMap::from([(FORMAT_SKIP_BAD_RECORDS.to_string(), "yes".to_string())]);
        assert!(CsvFormat::try_from(&map).is_err());

        let map = HashMap::from([(FORMAT_HEADERS.to_string(), "yes".to_string())]);
        assert!(CsvFormat::try_from(&map).is_err());
    }

    /// `headers` and `has_header` with different values are an error.
    #[test]
    fn test_try_from_rejects_conflicting_header_options() {
        let map = HashMap::from([
            (FORMAT_HEADERS.to_string(), "false".to_string()),
            (FORMAT_HAS_HEADER.to_string(), "true".to_string()),
        ]);
        assert!(CsvFormat::try_from(&map).is_err());
    }

    /// Exports the same rows with each compression codec, checks the codec's
    /// magic bytes, then reads the file back and compares the contents.
    #[tokio::test]
    async fn test_compressed_csv() {
        // Create test data
        let column_schemas = vec![
            ColumnSchema::new("id", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("value", ConcreteDataType::float64_datatype(), false),
        ];
        let schema = Arc::new(Schema::new(column_schemas));

        // Create multiple record batches with different data
        let batch1_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3])),
            Arc::new(StringVector::from(vec!["Alice", "Bob", "Charlie"])),
            Arc::new(Float64Vector::from_slice(vec![10.5, 20.3, 30.7])),
        ];
        let batch1 = RecordBatch::new(schema.clone(), batch1_columns).unwrap();

        let batch2_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![4, 5, 6])),
            Arc::new(StringVector::from(vec!["David", "Eva", "Frank"])),
            Arc::new(Float64Vector::from_slice(vec![40.1, 50.2, 60.3])),
        ];
        let batch2 = RecordBatch::new(schema.clone(), batch2_columns).unwrap();

        let batch3_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![7, 8, 9])),
            Arc::new(StringVector::from(vec!["Grace", "Henry", "Ivy"])),
            Arc::new(Float64Vector::from_slice(vec![70.4, 80.5, 90.6])),
        ];
        let batch3 = RecordBatch::new(schema.clone(), batch3_columns).unwrap();

        // Combine all batches into a RecordBatches collection
        let recordbatches = RecordBatches::try_new(schema, vec![batch1, batch2, batch3]).unwrap();

        // Test with different compression types
        let compression_types = vec![
            CompressionType::Gzip,
            CompressionType::Bzip2,
            CompressionType::Xz,
            CompressionType::Zstd,
        ];

        // Scratch directory holding one output file per codec
        let temp_dir = common_test_util::temp_dir::create_temp_dir("test_compressed_csv");
        for compression_type in compression_types {
            let format = CsvFormat {
                compression_type,
                ..CsvFormat::default()
            };

            // Name the output after the codec's file extension
            let compressed_file_name =
                format!("test_compressed_csv.{}", compression_type.file_extension());
            let compressed_file_path = temp_dir.path().join(&compressed_file_name);
            let compressed_file_path_str = compressed_file_path.to_str().unwrap();

            // Create a simple file store for testing
            let store = test_store("/");

            // Export CSV with compression
            let rows = stream_to_csv(
                Box::pin(DfRecordBatchStreamAdapter::new(recordbatches.as_stream())),
                store,
                compressed_file_path_str,
                1024,
                1,
                &format,
            )
            .await
            .unwrap();

            assert_eq!(rows, 9);

            // Verify compressed file was created and has content
            assert!(compressed_file_path.exists());
            let file_size = std::fs::metadata(&compressed_file_path).unwrap().len();
            assert!(file_size > 0);

            // Verify the file is actually compressed
            let file_content = std::fs::read(&compressed_file_path).unwrap();
            // Compressed files should not start with CSV header
            // They should have compression magic bytes
            match compression_type {
                CompressionType::Gzip => {
                    // Gzip magic bytes: 0x1f 0x8b
                    assert_eq!(file_content[0], 0x1f, "Gzip file should start with 0x1f");
                    assert_eq!(
                        file_content[1], 0x8b,
                        "Gzip file should have 0x8b as second byte"
                    );
                }
                CompressionType::Bzip2 => {
                    // Bzip2 magic bytes: 'BZ'
                    assert_eq!(file_content[0], b'B', "Bzip2 file should start with 'B'");
                    assert_eq!(
                        file_content[1], b'Z',
                        "Bzip2 file should have 'Z' as second byte"
                    );
                }
                CompressionType::Xz => {
                    // XZ magic bytes: 0xFD '7zXZ'
                    assert_eq!(file_content[0], 0xFD, "XZ file should start with 0xFD");
                }
                CompressionType::Zstd => {
                    // Zstd magic bytes: 0x28 0xB5 0x2F 0xFD
                    assert_eq!(file_content[0], 0x28, "Zstd file should start with 0x28");
                    assert_eq!(
                        file_content[1], 0xB5,
                        "Zstd file should have 0xB5 as second byte"
                    );
                }
                _ => {}
            }

            // Verify the compressed file can be decompressed and content matches original data
            let store = test_store("/");
            let schema = Arc::new(
                CsvFormat {
                    compression_type,
                    ..Default::default()
                }
                .infer_schema(&store, compressed_file_path_str)
                .await
                .unwrap(),
            );
            let csv_source = CsvSource::new(schema).with_batch_size(8192);

            let stream = file_to_stream(
                &store,
                compressed_file_path_str,
                csv_source.clone(),
                None,
                compression_type,
            )
            .await
            .unwrap();

            let batches = stream.try_collect::<Vec<_>>().await.unwrap();
            let pretty_print = arrow::util::pretty::pretty_format_batches(&batches)
                .unwrap()
                .to_string();
            let expected = r#"+----+---------+-------+
| id | name    | value |
+----+---------+-------+
| 1  | Alice   | 10.5  |
| 2  | Bob     | 20.3  |
| 3  | Charlie | 30.7  |
| 4  | David   | 40.1  |
| 5  | Eva     | 50.2  |
| 6  | Frank   | 60.3  |
| 7  | Grace   | 70.4  |
| 8  | Henry   | 80.5  |
| 9  | Ivy     | 90.6  |
+----+---------+-------+"#;
            assert_eq!(expected, pretty_print);
        }
    }

    /// Rows whose values fail to parse are dropped; the good rows around them,
    /// including a final row without a trailing newline, are kept.
    #[tokio::test]
    async fn test_tolerant_csv_stream_continues_after_parse_error() {
        let temp_dir = common_test_util::temp_dir::create_temp_dir("test_tolerant_csv_stream");
        let csv_file_path = temp_dir.path().join("input.csv");
        std::fs::write(
            &csv_file_path,
            "id,name,value\n1,Alice,10.5\nbad,Bad,20.0\nworse,Bad,21.0\n2,Bob,30.5",
        )
        .unwrap();

        let store = test_store("/");
        let schema = Arc::new(arrow_schema::Schema::new(vec![
            Field::new("id", DataType::UInt32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Float64, false),
        ]));
        let path = csv_file_path.to_str().unwrap();

        let stream =
            tolerant_csv_stream(&store, path, schema, vec![0, 1, 2], &CsvFormat::default())
                .await
                .unwrap();
        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
        let pretty_print = arrow::util::pretty::pretty_format_batches(&batches)
            .unwrap()
            .to_string();
        let expected = r#"+----+-------+-------+
| id | name  | value |
+----+-------+-------+
| 1  | Alice | 10.5  |
| 2  | Bob   | 30.5  |
+----+-------+-------+"#;
        assert_eq!(expected, pretty_print);
    }

    /// A row with the wrong number of fields is structural, not skippable, so
    /// the stream yields an error.
    #[tokio::test]
    async fn test_tolerant_csv_stream_fails_on_structural_csv_error() {
        let temp_dir =
            common_test_util::temp_dir::create_temp_dir("test_tolerant_csv_stream_csv_error");
        let csv_file_path = temp_dir.path().join("input.csv");
        std::fs::write(&csv_file_path, "id,name,value\n1,Alice,10.5\n2,Bob\n").unwrap();

        let store = test_store("/");
        let schema = Arc::new(arrow_schema::Schema::new(vec![
            Field::new("id", DataType::UInt32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Float64, false),
        ]));
        let path = csv_file_path.to_str().unwrap();

        let stream =
            tolerant_csv_stream(&store, path, schema, vec![0, 1, 2], &CsvFormat::default())
                .await
                .unwrap();
        let error = stream.try_collect::<Vec<_>>().await.unwrap_err();

        assert!(error.to_string().contains("incorrect number of fields"));
    }
}