Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sqlx-postgres/src/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ impl PgArguments {
&mut self,
conn: &mut PgConnection,
parameters: &[PgTypeInfo],
persistent: bool,
) -> Result<(), Error> {
let PgArgumentBuffer {
ref patches,
Expand All @@ -128,8 +129,8 @@ impl PgArguments {

for (offset, kind) in type_holes {
let oid = match kind {
HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
HoleKind::Type { name } => conn.fetch_type_id_by_name(persistent, name).await?,
HoleKind::Array(array) => conn.fetch_array_type_id(persistent, array).await?,
};
buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
}
Expand Down
121 changes: 88 additions & 33 deletions sqlx-postgres/src/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ impl PgConnection {
pub(super) async fn handle_row_description(
&mut self,
desc: Option<RowDescription>,
persistent: bool,
fetch_type_info: bool,
fetch_column_description: bool,
) -> Result<(Vec<PgColumn>, HashMap<UStr, usize>), Error> {
Expand All @@ -123,14 +124,19 @@ impl PgConnection {
let name = UStr::from(field.name);

let type_info = self
.maybe_fetch_type_info_by_oid(field.data_type_id, fetch_type_info)
.maybe_fetch_type_info_by_oid(field.data_type_id, persistent, fetch_type_info)
.await?;

let origin = if let (Some(relation_oid), Some(attribute_no)) =
(field.relation_id, field.relation_attribute_no)
{
self.maybe_fetch_column_origin(relation_oid, attribute_no, fetch_column_description)
.await?
self.maybe_fetch_column_origin(
relation_oid,
attribute_no,
persistent,
fetch_column_description,
)
.await?
} else {
ColumnOrigin::Expression
};
Expand All @@ -153,12 +159,16 @@ impl PgConnection {

pub(super) async fn handle_parameter_description(
&mut self,
persistent: bool,
desc: ParameterDescription,
) -> Result<Vec<PgTypeInfo>, Error> {
let mut params = Vec::with_capacity(desc.types.len());

for ty in desc.types {
params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
params.push(
self.maybe_fetch_type_info_by_oid(ty, persistent, true)
.await?,
);
}

Ok(params)
Expand All @@ -167,6 +177,7 @@ impl PgConnection {
async fn maybe_fetch_type_info_by_oid(
&mut self,
oid: Oid,
persistent: bool,
should_fetch: bool,
) -> Result<PgTypeInfo, Error> {
// first we check if this is a built-in type
Expand All @@ -183,7 +194,7 @@ impl PgConnection {
// fallback to asking the database directly for a type name
if should_fetch {
// we're boxing this future here so we can use async recursion
let info = Box::pin(async { self.fetch_type_by_oid(oid).await }).await?;
let info = Box::pin(async { self.fetch_type_by_oid(persistent, oid).await }).await?;

// cache the type name <-> oid relationship in a paired hashmap
// so we don't come down this road again
Expand All @@ -208,6 +219,7 @@ impl PgConnection {
&mut self,
relation_id: Oid,
attribute_no: i16,
persistent: bool,
should_fetch: bool,
) -> Result<ColumnOrigin, Error> {
if let Some(origin) = self
Expand Down Expand Up @@ -238,6 +250,7 @@ impl PgConnection {
FROM pg_catalog.pg_attribute \
WHERE attrelid = $1 AND attnum = $2",
)
.persistent(persistent)
.bind(relation_id)
.bind(attribute_no)
.fetch_optional(&mut *self)
Expand Down Expand Up @@ -267,7 +280,7 @@ impl PgConnection {
}))
}

async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result<PgTypeInfo, Error> {
async fn fetch_type_by_oid(&mut self, persistent: bool, oid: Oid) -> Result<PgTypeInfo, Error> {
let (name, typ_type, category, relation_id, element, base_type): (
String,
i8,
Expand All @@ -287,6 +300,7 @@ impl PgConnection {
FROM pg_catalog.pg_type \
WHERE oid = $1",
)
.persistent(persistent)
.bind(oid)
.fetch_one(&mut *self)
.await?;
Expand All @@ -295,12 +309,16 @@ impl PgConnection {
let category = TypCategory::try_from(category);

match (typ_type, category) {
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
(Ok(TypType::Domain), _) => {
self.fetch_domain_by_oid(oid, base_type, persistent, name)
.await
}

(Ok(TypType::Base), Ok(TypCategory::Array)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array(
self.maybe_fetch_type_info_by_oid(element, true).await?,
self.maybe_fetch_type_info_by_oid(element, persistent, true)
.await?,
),
name: name.into(),
oid,
Expand All @@ -316,13 +334,16 @@ impl PgConnection {
}

(Ok(TypType::Range), Ok(TypCategory::Range)) => {
self.fetch_range_by_oid(oid, name).await
self.fetch_range_by_oid(oid, persistent, name).await
}

(Ok(TypType::Enum), Ok(TypCategory::Enum)) => self.fetch_enum_by_oid(oid, name).await,
(Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
self.fetch_enum_by_oid(oid, persistent, name).await
}

(Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
self.fetch_composite_by_oid(oid, relation_id, name).await
self.fetch_composite_by_oid(oid, relation_id, persistent, name)
.await
}

_ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
Expand All @@ -333,7 +354,12 @@ impl PgConnection {
}
}

async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
async fn fetch_enum_by_oid(
&mut self,
oid: Oid,
persistent: bool,
name: String,
) -> Result<PgTypeInfo, Error> {
let variants: Vec<String> = query_scalar(
r#"
SELECT enumlabel
Expand All @@ -342,6 +368,7 @@ WHERE enumtypid = $1
ORDER BY enumsortorder
"#,
)
.persistent(persistent)
.bind(oid)
.fetch_all(self)
.await?;
Expand All @@ -357,6 +384,7 @@ ORDER BY enumsortorder
&mut self,
oid: Oid,
relation_id: Oid,
persistent: bool,
name: String,
) -> Result<PgTypeInfo, Error> {
let raw_fields: Vec<(String, Oid)> = query_as(
Expand All @@ -369,14 +397,17 @@ AND attnum > 0
ORDER BY attnum
"#,
)
.persistent(persistent)
.bind(relation_id)
.fetch_all(&mut *self)
.await?;

let mut fields = Vec::new();

for (field_name, field_oid) in raw_fields.into_iter() {
let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;
let field_type = self
.maybe_fetch_type_info_by_oid(field_oid, persistent, true)
.await?;

fields.push((field_name, field_type));
}
Expand All @@ -392,9 +423,12 @@ ORDER BY attnum
&mut self,
oid: Oid,
base_type: Oid,
persistent: bool,
name: String,
) -> Result<PgTypeInfo, Error> {
let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
let base_type = self
.maybe_fetch_type_info_by_oid(base_type, persistent, true)
.await?;

Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
oid,
Expand All @@ -403,7 +437,12 @@ ORDER BY attnum
}))))
}

async fn fetch_range_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
async fn fetch_range_by_oid(
&mut self,
oid: Oid,
persistent: bool,
name: String,
) -> Result<PgTypeInfo, Error> {
let element_oid: Oid = query_scalar(
r#"
SELECT rngsubtype
Expand All @@ -415,7 +454,9 @@ WHERE rngtypid = $1
.fetch_one(&mut *self)
.await?;

let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
let element = self
.maybe_fetch_type_info_by_oid(element_oid, persistent, true)
.await?;

Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Range(element),
Expand All @@ -424,26 +465,35 @@ WHERE rngtypid = $1
}))))
}

pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result<Oid, Error> {
pub(crate) async fn resolve_type_id(
&mut self,
persistent: bool,
ty: &PgType,
) -> Result<Oid, Error> {
if let Some(oid) = ty.try_oid() {
return Ok(oid);
}

match ty {
PgType::DeclareWithName(name) => self.fetch_type_id_by_name(name).await,
PgType::DeclareArrayOf(array) => self.fetch_array_type_id(array).await,
PgType::DeclareWithName(name) => self.fetch_type_id_by_name(persistent, name).await,
PgType::DeclareArrayOf(array) => self.fetch_array_type_id(persistent, array).await,
// `.try_oid()` should return `Some()` or it should be covered here
_ => unreachable!("(bug) OID should be resolvable for type {ty:?}"),
}
}

pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
pub(crate) async fn fetch_type_id_by_name(
&mut self,
persistent: bool,
name: &str,
) -> Result<Oid, Error> {
if let Some(oid) = self.inner.cache_type_oid.get(name) {
return Ok(*oid);
}

// language=SQL
let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid")
.persistent(persistent)
.bind(name)
.fetch_optional(&mut *self)
.await?
Expand All @@ -457,7 +507,11 @@ WHERE rngtypid = $1
Ok(oid)
}

pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
pub(crate) async fn fetch_array_type_id(
&mut self,
persistent: bool,
array: &PgArrayOf,
) -> Result<Oid, Error> {
if let Some(oid) = self
.inner
.cache_type_oid
Expand All @@ -470,6 +524,7 @@ WHERE rngtypid = $1
// language=SQL
let (elem_oid, array_oid): (Oid, Oid) =
query_as("SELECT oid, typarray FROM pg_catalog.pg_type WHERE oid = $1::regtype::oid")
.persistent(persistent)
.bind(&*array.elem_name)
.fetch_optional(&mut *self)
.await?
Expand Down Expand Up @@ -719,19 +774,19 @@ fn explain_parsing() {

// https://github.com/launchbadge/sqlx/issues/2622
let extra_field = r#"[
{
"Plan": {
"Node Type": "Result",
"Parallel Aware": false,
"Async Capable": false,
"Startup Cost": 0.00,
"Total Cost": 0.01,
"Plan Rows": 1,
"Plan Width": 4,
"Output": ["1"]
},
{
"Plan": {
"Node Type": "Result",
"Parallel Aware": false,
"Async Capable": false,
"Startup Cost": 0.00,
"Total Cost": 0.01,
"Plan Rows": 1,
"Plan Width": 4,
"Output": ["1"]
},
"Query Identifier": 1147616880456321454
}
}
]"#;

// https://github.com/launchbadge/sqlx/issues/1449
Expand Down
Loading