Skip to content

Commit 92dc4c4

Browse files
mshauneuhouqp
authored andcommitted
Add Decimal128 support. (returnString#297)
1 parent 5ddae0d commit 92dc4c4

File tree

7 files changed

+42
-13
lines changed

7 files changed

+42
-13
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/target
22
Cargo.lock
3+
.idea/

convergence-arrow/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,5 @@ async-trait = "0.1"
1313
datafusion = "43"
1414
convergence = { path = "../convergence", version = "0.16.0" }
1515
chrono = "0.4"
16-
17-
[dev-dependencies]
1816
tokio-postgres = { version = "0.7", features = [ "with-chrono-0_4" ] }
17+
rust_decimal = { version = "1.37.1", features = ["default", "db-postgres"] }

convergence-arrow/src/table.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState};
44
use convergence::protocol_ext::DataRowBatch;
55
use datafusion::arrow::array::{
6-
BooleanArray, Date32Array, Date64Array, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array,
7-
Int64Array, Int8Array, StringArray, StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
8-
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
6+
BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array, Float32Array, Float64Array, Int16Array,
7+
Int32Array, Int64Array, Int8Array, StringArray, StringViewArray, TimestampMicrosecondArray,
8+
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
9+
UInt8Array,
910
};
1011
use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
1112
use datafusion::arrow::record_batch::RecordBatch;
@@ -47,6 +48,7 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat
4748
DataType::Float16 => row.write_float4(array_val!(Float16Array, col, row_idx).to_f32()),
4849
DataType::Float32 => row.write_float4(array_val!(Float32Array, col, row_idx)),
4950
DataType::Float64 => row.write_float8(array_val!(Float64Array, col, row_idx)),
51+
DataType::Decimal128(p, s) => row.write_numeric_16(array_val!(Decimal128Array, col, row_idx), p, s),
5052
DataType::Utf8 => row.write_string(array_val!(StringArray, col, row_idx)),
5153
DataType::Utf8View => row.write_string(array_val!(StringViewArray, col, row_idx)),
5254
DataType::Date32 => {
@@ -103,8 +105,8 @@ pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
103105
DataType::UInt64 => DataTypeOid::Int8,
104106
DataType::Float16 | DataType::Float32 => DataTypeOid::Float4,
105107
DataType::Float64 => DataTypeOid::Float8,
106-
DataType::Utf8 => DataTypeOid::Text,
107-
DataType::Utf8View => DataTypeOid::Text,
108+
DataType::Decimal128(_, _) => DataTypeOid::Numeric,
109+
DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text,
108110
DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
109111
DataType::Timestamp(_, None) => DataTypeOid::Timestamp,
110112
other => {

convergence-arrow/tests/test_arrow.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use convergence::protocol_ext::DataRowBatch;
66
use convergence::server::{self, BindOptions};
77
use convergence::sqlparser::ast::Statement;
88
use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc};
9-
use datafusion::arrow::array::{ArrayRef, Date32Array, Float32Array, Int32Array, StringArray, TimestampSecondArray};
9+
use datafusion::arrow::array::{ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray, TimestampSecondArray};
1010
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
1111
use datafusion::arrow::record_batch::RecordBatch;
1212
use std::sync::Arc;
13+
use rust_decimal::Decimal;
1314
use tokio_postgres::{connect, NoTls};
1415

1516
struct ArrowPortal {
@@ -31,20 +32,22 @@ impl ArrowEngine {
3132
fn new() -> Self {
3233
let int_col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
3334
let float_col = Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])) as ArrayRef;
35+
let decimal_col = Arc::new(Decimal128Array::from(vec![11, 22, 33]).with_precision_and_scale(2, 0).unwrap()) as ArrayRef;
3436
let string_col = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
3537
let ts_col = Arc::new(TimestampSecondArray::from(vec![1577836800, 1580515200, 1583020800])) as ArrayRef;
3638
let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef;
3739

3840
let schema = Schema::new(vec![
3941
Field::new("int_col", DataType::Int32, true),
4042
Field::new("float_col", DataType::Float32, true),
43+
Field::new("decimal_col", DataType::Decimal128(2, 0), true),
4144
Field::new("string_col", DataType::Utf8, true),
4245
Field::new("ts_col", DataType::Timestamp(TimeUnit::Second, None), true),
4346
Field::new("date_col", DataType::Date32, true),
4447
]);
4548

4649
Self {
47-
batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, string_col, ts_col, date_col])
50+
batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, decimal_col, string_col, string_view_col, ts_col, date_col])
4851
.expect("failed to create batch"),
4952
}
5053
}
@@ -89,8 +92,8 @@ async fn basic_data_types() {
8992
let rows = client.query("select 1", &[]).await.unwrap();
9093
let get_row = |idx: usize| {
9194
let row = &rows[idx];
92-
let cols: (i32, f32, &str, NaiveDateTime, NaiveDate) =
93-
(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4));
95+
let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, NaiveDate) =
96+
(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4), row.get(5), row.get(6));
9497
cols
9598
};
9699

@@ -99,6 +102,7 @@ async fn basic_data_types() {
99102
(
100103
1,
101104
1.5,
105+
Decimal::from(11),
102106
"a",
103107
NaiveDate::from_ymd_opt(2020, 1, 1)
104108
.unwrap()
@@ -112,6 +116,7 @@ async fn basic_data_types() {
112116
(
113117
2,
114118
2.5,
119+
Decimal::from(22),
115120
"b",
116121
NaiveDate::from_ymd_opt(2020, 2, 1)
117122
.unwrap()
@@ -125,6 +130,7 @@ async fn basic_data_types() {
125130
(
126131
3,
127132
3.5,
133+
Decimal::from(33),
128134
"c",
129135
NaiveDate::from_ymd_opt(2020, 3, 1)
130136
.unwrap()

convergence/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,5 @@ futures = "0.3"
1616
sqlparser = "0.46"
1717
async-trait = "0.1"
1818
chrono = "0.4"
19-
20-
[dev-dependencies]
19+
rust_decimal = { version = "1.37.1", features = ["default", "db-postgres"] }
2120
tokio-postgres = "0.7"

convergence/src/protocol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ data_types! {
7575
Float4 = 700, 4
7676
Float8 = 701, 8
7777

78+
Numeric = 1700, -1
79+
7880
Date = 1082, 4
7981
Timestamp = 1114, 8
8082

convergence/src/protocol_ext.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription};
44
use bytes::{BufMut, BytesMut};
55
use chrono::{NaiveDate, NaiveDateTime};
6+
use rust_decimal::Decimal;
7+
use tokio_postgres::types::{ToSql, Type};
68
use tokio_util::codec::Encoder;
79

810
/// Supports batched rows for e.g. returning portal result sets.
@@ -131,6 +133,24 @@ impl<'a> DataRowWriter<'a> {
131133
}
132134
}
133135

136+
/// Writes a numeric value for the next column.
137+
pub fn write_numeric_16(&mut self, val: i128, _p: &u8, s: &i8) {
138+
let decimal = Decimal::from_i128_with_scale(val, *s as u32);
139+
match self.parent.format_code {
140+
FormatCode::Text => {
141+
self.write_string(&decimal.to_string())
142+
}
143+
FormatCode::Binary => {
144+
let numeric_type = Type::from_oid(1700).expect("failed to create numeric type");
145+
let mut buf = BytesMut::new();
146+
decimal.to_sql(&numeric_type, &mut buf)
147+
.expect("failed to write numeric");
148+
149+
self.write_value(&buf.freeze())
150+
}
151+
};
152+
}
153+
134154
primitive_write!(write_int2, i16);
135155
primitive_write!(write_int4, i32);
136156
primitive_write!(write_int8, i64);

0 commit comments

Comments
 (0)