mas_storage_pg/upstream_oauth2/
provider.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11    Clock, Page, Pagination,
12    upstream_oauth2::{
13        UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
14    },
15};
16use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::{PgConnection, types::Json};
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26    DatabaseError, DatabaseInconsistencyError,
27    filter::{Filter, StatementExt},
28    iden::UpstreamOAuthProviders,
29    pagination::QueryBuilderExt,
30    tracing::ExecuteExt,
31};
32
/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL
/// connection
///
/// The repository holds an exclusive (`&mut`) borrow of the connection for
/// the lifetime `'c`; every query it issues runs on that connection.
pub struct PgUpstreamOAuthProviderRepository<'c> {
    // Borrowed connection that all queries of this repository execute on.
    conn: &'c mut PgConnection,
}
38
39impl<'c> PgUpstreamOAuthProviderRepository<'c> {
40    /// Create a new [`PgUpstreamOAuthProviderRepository`] from an active
41    /// PostgreSQL connection
42    pub fn new(conn: &'c mut PgConnection) -> Self {
43        Self { conn }
44    }
45}
46
/// Raw row shape of the `upstream_oauth_providers` table.
///
/// `sqlx::FromRow` maps result columns onto these fields by name, so the
/// field names must match the selected column names/aliases. `#[enum_def]`
/// additionally generates the `ProviderLookupIden` enum, which dynamic
/// `sea_query` selects use as column aliases so that they decode into this
/// same struct. Text columns are parsed into typed values by the
/// `TryFrom<ProviderLookup>` implementation for `UpstreamOAuthProvider`.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Stored as text, parsed during conversion.
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    disabled_at: Option<DateTime<Utc>>,
    // JSON column decoded directly into the claims-import configuration.
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // JSON list of key/value pairs; NULL is treated as an empty list during
    // conversion.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
    on_backchannel_logout: String,
}
76
77impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
78    type Error = DatabaseInconsistencyError;
79
80    #[allow(clippy::too_many_lines)]
81    fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
82        let id = value.upstream_oauth_provider_id.into();
83        let scope = value.scope.parse().map_err(|e| {
84            DatabaseInconsistencyError::on("upstream_oauth_providers")
85                .column("scope")
86                .row(id)
87                .source(e)
88        })?;
89        let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
90            DatabaseInconsistencyError::on("upstream_oauth_providers")
91                .column("token_endpoint_auth_method")
92                .row(id)
93                .source(e)
94        })?;
95        let token_endpoint_signing_alg = value
96            .token_endpoint_signing_alg
97            .map(|x| x.parse())
98            .transpose()
99            .map_err(|e| {
100                DatabaseInconsistencyError::on("upstream_oauth_providers")
101                    .column("token_endpoint_signing_alg")
102                    .row(id)
103                    .source(e)
104            })?;
105        let id_token_signed_response_alg =
106            value.id_token_signed_response_alg.parse().map_err(|e| {
107                DatabaseInconsistencyError::on("upstream_oauth_providers")
108                    .column("id_token_signed_response_alg")
109                    .row(id)
110                    .source(e)
111            })?;
112
113        let userinfo_signed_response_alg = value
114            .userinfo_signed_response_alg
115            .map(|x| x.parse())
116            .transpose()
117            .map_err(|e| {
118                DatabaseInconsistencyError::on("upstream_oauth_providers")
119                    .column("userinfo_signed_response_alg")
120                    .row(id)
121                    .source(e)
122            })?;
123
124        let authorization_endpoint_override = value
125            .authorization_endpoint_override
126            .map(|x| x.parse())
127            .transpose()
128            .map_err(|e| {
129                DatabaseInconsistencyError::on("upstream_oauth_providers")
130                    .column("authorization_endpoint_override")
131                    .row(id)
132                    .source(e)
133            })?;
134
135        let token_endpoint_override = value
136            .token_endpoint_override
137            .map(|x| x.parse())
138            .transpose()
139            .map_err(|e| {
140                DatabaseInconsistencyError::on("upstream_oauth_providers")
141                    .column("token_endpoint_override")
142                    .row(id)
143                    .source(e)
144            })?;
145
146        let userinfo_endpoint_override = value
147            .userinfo_endpoint_override
148            .map(|x| x.parse())
149            .transpose()
150            .map_err(|e| {
151                DatabaseInconsistencyError::on("upstream_oauth_providers")
152                    .column("userinfo_endpoint_override")
153                    .row(id)
154                    .source(e)
155            })?;
156
157        let jwks_uri_override = value
158            .jwks_uri_override
159            .map(|x| x.parse())
160            .transpose()
161            .map_err(|e| {
162                DatabaseInconsistencyError::on("upstream_oauth_providers")
163                    .column("jwks_uri_override")
164                    .row(id)
165                    .source(e)
166            })?;
167
168        let discovery_mode = value.discovery_mode.parse().map_err(|e| {
169            DatabaseInconsistencyError::on("upstream_oauth_providers")
170                .column("discovery_mode")
171                .row(id)
172                .source(e)
173        })?;
174
175        let pkce_mode = value.pkce_mode.parse().map_err(|e| {
176            DatabaseInconsistencyError::on("upstream_oauth_providers")
177                .column("pkce_mode")
178                .row(id)
179                .source(e)
180        })?;
181
182        let response_mode = value
183            .response_mode
184            .map(|x| x.parse())
185            .transpose()
186            .map_err(|e| {
187                DatabaseInconsistencyError::on("upstream_oauth_providers")
188                    .column("response_mode")
189                    .row(id)
190                    .source(e)
191            })?;
192
193        let additional_authorization_parameters = value
194            .additional_parameters
195            .map(|Json(x)| x)
196            .unwrap_or_default();
197
198        let on_backchannel_logout = value.on_backchannel_logout.parse().map_err(|e| {
199            DatabaseInconsistencyError::on("upstream_oauth_providers")
200                .column("on_backchannel_logout")
201                .row(id)
202                .source(e)
203        })?;
204
205        Ok(UpstreamOAuthProvider {
206            id,
207            issuer: value.issuer,
208            human_name: value.human_name,
209            brand_name: value.brand_name,
210            scope,
211            client_id: value.client_id,
212            encrypted_client_secret: value.encrypted_client_secret,
213            token_endpoint_auth_method,
214            token_endpoint_signing_alg,
215            id_token_signed_response_alg,
216            fetch_userinfo: value.fetch_userinfo,
217            userinfo_signed_response_alg,
218            created_at: value.created_at,
219            disabled_at: value.disabled_at,
220            claims_imports: value.claims_imports.0,
221            authorization_endpoint_override,
222            token_endpoint_override,
223            userinfo_endpoint_override,
224            jwks_uri_override,
225            discovery_mode,
226            pkce_mode,
227            response_mode,
228            additional_authorization_parameters,
229            forward_login_hint: value.forward_login_hint,
230            on_backchannel_logout,
231        })
232    }
233}
234
235impl Filter for UpstreamOAuthProviderFilter<'_> {
236    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
237        sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
238            Expr::col((
239                UpstreamOAuthProviders::Table,
240                UpstreamOAuthProviders::DisabledAt,
241            ))
242            .is_null()
243            .eq(enabled)
244        }))
245    }
246}
247
248#[async_trait]
249impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
250    type Error = DatabaseError;
251
252    #[tracing::instrument(
253        name = "db.upstream_oauth_provider.lookup",
254        skip_all,
255        fields(
256            db.query.text,
257            upstream_oauth_provider.id = %id,
258        ),
259        err,
260    )]
261    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
262        let res = sqlx::query_as!(
263            ProviderLookup,
264            r#"
265                SELECT
266                    upstream_oauth_provider_id,
267                    issuer,
268                    human_name,
269                    brand_name,
270                    scope,
271                    client_id,
272                    encrypted_client_secret,
273                    token_endpoint_signing_alg,
274                    token_endpoint_auth_method,
275                    id_token_signed_response_alg,
276                    fetch_userinfo,
277                    userinfo_signed_response_alg,
278                    created_at,
279                    disabled_at,
280                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
281                    jwks_uri_override,
282                    authorization_endpoint_override,
283                    token_endpoint_override,
284                    userinfo_endpoint_override,
285                    discovery_mode,
286                    pkce_mode,
287                    response_mode,
288                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
289                    forward_login_hint,
290                    on_backchannel_logout
291                FROM upstream_oauth_providers
292                WHERE upstream_oauth_provider_id = $1
293            "#,
294            Uuid::from(id),
295        )
296        .traced()
297        .fetch_optional(&mut *self.conn)
298        .await?;
299
300        let res = res
301            .map(UpstreamOAuthProvider::try_from)
302            .transpose()
303            .map_err(DatabaseError::from)?;
304
305        Ok(res)
306    }
307
308    #[tracing::instrument(
309        name = "db.upstream_oauth_provider.add",
310        skip_all,
311        fields(
312            db.query.text,
313            upstream_oauth_provider.id,
314            upstream_oauth_provider.issuer = params.issuer,
315            upstream_oauth_provider.client_id = %params.client_id,
316        ),
317        err,
318    )]
319    async fn add(
320        &mut self,
321        rng: &mut (dyn RngCore + Send),
322        clock: &dyn Clock,
323        params: UpstreamOAuthProviderParams,
324    ) -> Result<UpstreamOAuthProvider, Self::Error> {
325        let created_at = clock.now();
326        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
327        tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
328
329        sqlx::query!(
330            r#"
331            INSERT INTO upstream_oauth_providers (
332                upstream_oauth_provider_id,
333                issuer,
334                human_name,
335                brand_name,
336                scope,
337                token_endpoint_auth_method,
338                token_endpoint_signing_alg,
339                id_token_signed_response_alg,
340                fetch_userinfo,
341                userinfo_signed_response_alg,
342                client_id,
343                encrypted_client_secret,
344                claims_imports,
345                authorization_endpoint_override,
346                token_endpoint_override,
347                userinfo_endpoint_override,
348                jwks_uri_override,
349                discovery_mode,
350                pkce_mode,
351                response_mode,
352                forward_login_hint,
353                on_backchannel_logout,
354                created_at
355            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
356                      $12, $13, $14, $15, $16, $17, $18, $19, $20,
357                      $21, $22, $23)
358        "#,
359            Uuid::from(id),
360            params.issuer.as_deref(),
361            params.human_name.as_deref(),
362            params.brand_name.as_deref(),
363            params.scope.to_string(),
364            params.token_endpoint_auth_method.to_string(),
365            params
366                .token_endpoint_signing_alg
367                .as_ref()
368                .map(ToString::to_string),
369            params.id_token_signed_response_alg.to_string(),
370            params.fetch_userinfo,
371            params
372                .userinfo_signed_response_alg
373                .as_ref()
374                .map(ToString::to_string),
375            &params.client_id,
376            params.encrypted_client_secret.as_deref(),
377            Json(&params.claims_imports) as _,
378            params
379                .authorization_endpoint_override
380                .as_ref()
381                .map(ToString::to_string),
382            params
383                .token_endpoint_override
384                .as_ref()
385                .map(ToString::to_string),
386            params
387                .userinfo_endpoint_override
388                .as_ref()
389                .map(ToString::to_string),
390            params.jwks_uri_override.as_ref().map(ToString::to_string),
391            params.discovery_mode.as_str(),
392            params.pkce_mode.as_str(),
393            params.response_mode.as_ref().map(ToString::to_string),
394            params.forward_login_hint,
395            params.on_backchannel_logout.as_str(),
396            created_at,
397        )
398        .traced()
399        .execute(&mut *self.conn)
400        .await?;
401
402        Ok(UpstreamOAuthProvider {
403            id,
404            issuer: params.issuer,
405            human_name: params.human_name,
406            brand_name: params.brand_name,
407            scope: params.scope,
408            client_id: params.client_id,
409            encrypted_client_secret: params.encrypted_client_secret,
410            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
411            token_endpoint_auth_method: params.token_endpoint_auth_method,
412            id_token_signed_response_alg: params.id_token_signed_response_alg,
413            fetch_userinfo: params.fetch_userinfo,
414            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
415            created_at,
416            disabled_at: None,
417            claims_imports: params.claims_imports,
418            authorization_endpoint_override: params.authorization_endpoint_override,
419            token_endpoint_override: params.token_endpoint_override,
420            userinfo_endpoint_override: params.userinfo_endpoint_override,
421            jwks_uri_override: params.jwks_uri_override,
422            discovery_mode: params.discovery_mode,
423            pkce_mode: params.pkce_mode,
424            response_mode: params.response_mode,
425            additional_authorization_parameters: params.additional_authorization_parameters,
426            on_backchannel_logout: params.on_backchannel_logout,
427            forward_login_hint: params.forward_login_hint,
428        })
429    }
430
431    #[tracing::instrument(
432        name = "db.upstream_oauth_provider.delete_by_id",
433        skip_all,
434        fields(
435            db.query.text,
436            upstream_oauth_provider.id = %id,
437        ),
438        err,
439    )]
440    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
441        // Delete the authorization sessions first, as they have a foreign key
442        // constraint on the links and the providers.
443        {
444            let span = info_span!(
445                "db.oauth2_client.delete_by_id.authorization_sessions",
446                upstream_oauth_provider.id = %id,
447                { DB_QUERY_TEXT } = tracing::field::Empty,
448            );
449            sqlx::query!(
450                r#"
451                    DELETE FROM upstream_oauth_authorization_sessions
452                    WHERE upstream_oauth_provider_id = $1
453                "#,
454                Uuid::from(id),
455            )
456            .record(&span)
457            .execute(&mut *self.conn)
458            .instrument(span)
459            .await?;
460        }
461
462        // Delete the links next, as they have a foreign key constraint on the
463        // providers.
464        {
465            let span = info_span!(
466                "db.oauth2_client.delete_by_id.links",
467                upstream_oauth_provider.id = %id,
468                { DB_QUERY_TEXT } = tracing::field::Empty,
469            );
470            sqlx::query!(
471                r#"
472                    DELETE FROM upstream_oauth_links
473                    WHERE upstream_oauth_provider_id = $1
474                "#,
475                Uuid::from(id),
476            )
477            .record(&span)
478            .execute(&mut *self.conn)
479            .instrument(span)
480            .await?;
481        }
482
483        let res = sqlx::query!(
484            r#"
485                DELETE FROM upstream_oauth_providers
486                WHERE upstream_oauth_provider_id = $1
487            "#,
488            Uuid::from(id),
489        )
490        .traced()
491        .execute(&mut *self.conn)
492        .await?;
493
494        DatabaseError::ensure_affected_rows(&res, 1)
495    }
496
497    #[tracing::instrument(
498        name = "db.upstream_oauth_provider.add",
499        skip_all,
500        fields(
501            db.query.text,
502            upstream_oauth_provider.id = %id,
503            upstream_oauth_provider.issuer = params.issuer,
504            upstream_oauth_provider.client_id = %params.client_id,
505        ),
506        err,
507    )]
508    async fn upsert(
509        &mut self,
510        clock: &dyn Clock,
511        id: Ulid,
512        params: UpstreamOAuthProviderParams,
513    ) -> Result<UpstreamOAuthProvider, Self::Error> {
514        let created_at = clock.now();
515
516        let created_at = sqlx::query_scalar!(
517            r#"
518                INSERT INTO upstream_oauth_providers (
519                    upstream_oauth_provider_id,
520                    issuer,
521                    human_name,
522                    brand_name,
523                    scope,
524                    token_endpoint_auth_method,
525                    token_endpoint_signing_alg,
526                    id_token_signed_response_alg,
527                    fetch_userinfo,
528                    userinfo_signed_response_alg,
529                    client_id,
530                    encrypted_client_secret,
531                    claims_imports,
532                    authorization_endpoint_override,
533                    token_endpoint_override,
534                    userinfo_endpoint_override,
535                    jwks_uri_override,
536                    discovery_mode,
537                    pkce_mode,
538                    response_mode,
539                    additional_parameters,
540                    forward_login_hint,
541                    ui_order,
542                    on_backchannel_logout,
543                    created_at
544                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
545                          $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
546                          $21, $22, $23, $24, $25)
547                ON CONFLICT (upstream_oauth_provider_id)
548                    DO UPDATE
549                    SET
550                        issuer = EXCLUDED.issuer,
551                        human_name = EXCLUDED.human_name,
552                        brand_name = EXCLUDED.brand_name,
553                        scope = EXCLUDED.scope,
554                        token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
555                        token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
556                        id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
557                        fetch_userinfo = EXCLUDED.fetch_userinfo,
558                        userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
559                        disabled_at = NULL,
560                        client_id = EXCLUDED.client_id,
561                        encrypted_client_secret = EXCLUDED.encrypted_client_secret,
562                        claims_imports = EXCLUDED.claims_imports,
563                        authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
564                        token_endpoint_override = EXCLUDED.token_endpoint_override,
565                        userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
566                        jwks_uri_override = EXCLUDED.jwks_uri_override,
567                        discovery_mode = EXCLUDED.discovery_mode,
568                        pkce_mode = EXCLUDED.pkce_mode,
569                        response_mode = EXCLUDED.response_mode,
570                        additional_parameters = EXCLUDED.additional_parameters,
571                        forward_login_hint = EXCLUDED.forward_login_hint,
572                        ui_order = EXCLUDED.ui_order,
573                        on_backchannel_logout = EXCLUDED.on_backchannel_logout
574                RETURNING created_at
575            "#,
576            Uuid::from(id),
577            params.issuer.as_deref(),
578            params.human_name.as_deref(),
579            params.brand_name.as_deref(),
580            params.scope.to_string(),
581            params.token_endpoint_auth_method.to_string(),
582            params
583                .token_endpoint_signing_alg
584                .as_ref()
585                .map(ToString::to_string),
586            params.id_token_signed_response_alg.to_string(),
587            params.fetch_userinfo,
588            params
589                .userinfo_signed_response_alg
590                .as_ref()
591                .map(ToString::to_string),
592            &params.client_id,
593            params.encrypted_client_secret.as_deref(),
594            Json(&params.claims_imports) as _,
595            params
596                .authorization_endpoint_override
597                .as_ref()
598                .map(ToString::to_string),
599            params
600                .token_endpoint_override
601                .as_ref()
602                .map(ToString::to_string),
603            params
604                .userinfo_endpoint_override
605                .as_ref()
606                .map(ToString::to_string),
607            params.jwks_uri_override.as_ref().map(ToString::to_string),
608            params.discovery_mode.as_str(),
609            params.pkce_mode.as_str(),
610            params.response_mode.as_ref().map(ToString::to_string),
611            Json(&params.additional_authorization_parameters) as _,
612            params.forward_login_hint,
613            params.ui_order,
614            params.on_backchannel_logout.as_str(),
615            created_at,
616        )
617        .traced()
618        .fetch_one(&mut *self.conn)
619        .await?;
620
621        Ok(UpstreamOAuthProvider {
622            id,
623            issuer: params.issuer,
624            human_name: params.human_name,
625            brand_name: params.brand_name,
626            scope: params.scope,
627            client_id: params.client_id,
628            encrypted_client_secret: params.encrypted_client_secret,
629            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
630            token_endpoint_auth_method: params.token_endpoint_auth_method,
631            id_token_signed_response_alg: params.id_token_signed_response_alg,
632            fetch_userinfo: params.fetch_userinfo,
633            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
634            created_at,
635            disabled_at: None,
636            claims_imports: params.claims_imports,
637            authorization_endpoint_override: params.authorization_endpoint_override,
638            token_endpoint_override: params.token_endpoint_override,
639            userinfo_endpoint_override: params.userinfo_endpoint_override,
640            jwks_uri_override: params.jwks_uri_override,
641            discovery_mode: params.discovery_mode,
642            pkce_mode: params.pkce_mode,
643            response_mode: params.response_mode,
644            additional_authorization_parameters: params.additional_authorization_parameters,
645            forward_login_hint: params.forward_login_hint,
646            on_backchannel_logout: params.on_backchannel_logout,
647        })
648    }
649
650    #[tracing::instrument(
651        name = "db.upstream_oauth_provider.disable",
652        skip_all,
653        fields(
654            db.query.text,
655            %upstream_oauth_provider.id,
656        ),
657        err,
658    )]
659    async fn disable(
660        &mut self,
661        clock: &dyn Clock,
662        mut upstream_oauth_provider: UpstreamOAuthProvider,
663    ) -> Result<UpstreamOAuthProvider, Self::Error> {
664        let disabled_at = clock.now();
665        let res = sqlx::query!(
666            r#"
667                UPDATE upstream_oauth_providers
668                SET disabled_at = $2
669                WHERE upstream_oauth_provider_id = $1
670            "#,
671            Uuid::from(upstream_oauth_provider.id),
672            disabled_at,
673        )
674        .traced()
675        .execute(&mut *self.conn)
676        .await?;
677
678        DatabaseError::ensure_affected_rows(&res, 1)?;
679
680        upstream_oauth_provider.disabled_at = Some(disabled_at);
681
682        Ok(upstream_oauth_provider)
683    }
684
685    #[tracing::instrument(
686        name = "db.upstream_oauth_provider.list",
687        skip_all,
688        fields(
689            db.query.text,
690        ),
691        err,
692    )]
693    async fn list(
694        &mut self,
695        filter: UpstreamOAuthProviderFilter<'_>,
696        pagination: Pagination,
697    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
698        let (sql, arguments) = Query::select()
699            .expr_as(
700                Expr::col((
701                    UpstreamOAuthProviders::Table,
702                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
703                )),
704                ProviderLookupIden::UpstreamOauthProviderId,
705            )
706            .expr_as(
707                Expr::col((
708                    UpstreamOAuthProviders::Table,
709                    UpstreamOAuthProviders::Issuer,
710                )),
711                ProviderLookupIden::Issuer,
712            )
713            .expr_as(
714                Expr::col((
715                    UpstreamOAuthProviders::Table,
716                    UpstreamOAuthProviders::HumanName,
717                )),
718                ProviderLookupIden::HumanName,
719            )
720            .expr_as(
721                Expr::col((
722                    UpstreamOAuthProviders::Table,
723                    UpstreamOAuthProviders::BrandName,
724                )),
725                ProviderLookupIden::BrandName,
726            )
727            .expr_as(
728                Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
729                ProviderLookupIden::Scope,
730            )
731            .expr_as(
732                Expr::col((
733                    UpstreamOAuthProviders::Table,
734                    UpstreamOAuthProviders::ClientId,
735                )),
736                ProviderLookupIden::ClientId,
737            )
738            .expr_as(
739                Expr::col((
740                    UpstreamOAuthProviders::Table,
741                    UpstreamOAuthProviders::EncryptedClientSecret,
742                )),
743                ProviderLookupIden::EncryptedClientSecret,
744            )
745            .expr_as(
746                Expr::col((
747                    UpstreamOAuthProviders::Table,
748                    UpstreamOAuthProviders::TokenEndpointSigningAlg,
749                )),
750                ProviderLookupIden::TokenEndpointSigningAlg,
751            )
752            .expr_as(
753                Expr::col((
754                    UpstreamOAuthProviders::Table,
755                    UpstreamOAuthProviders::TokenEndpointAuthMethod,
756                )),
757                ProviderLookupIden::TokenEndpointAuthMethod,
758            )
759            .expr_as(
760                Expr::col((
761                    UpstreamOAuthProviders::Table,
762                    UpstreamOAuthProviders::IdTokenSignedResponseAlg,
763                )),
764                ProviderLookupIden::IdTokenSignedResponseAlg,
765            )
766            .expr_as(
767                Expr::col((
768                    UpstreamOAuthProviders::Table,
769                    UpstreamOAuthProviders::FetchUserinfo,
770                )),
771                ProviderLookupIden::FetchUserinfo,
772            )
773            .expr_as(
774                Expr::col((
775                    UpstreamOAuthProviders::Table,
776                    UpstreamOAuthProviders::UserinfoSignedResponseAlg,
777                )),
778                ProviderLookupIden::UserinfoSignedResponseAlg,
779            )
780            .expr_as(
781                Expr::col((
782                    UpstreamOAuthProviders::Table,
783                    UpstreamOAuthProviders::CreatedAt,
784                )),
785                ProviderLookupIden::CreatedAt,
786            )
787            .expr_as(
788                Expr::col((
789                    UpstreamOAuthProviders::Table,
790                    UpstreamOAuthProviders::DisabledAt,
791                )),
792                ProviderLookupIden::DisabledAt,
793            )
794            .expr_as(
795                Expr::col((
796                    UpstreamOAuthProviders::Table,
797                    UpstreamOAuthProviders::ClaimsImports,
798                )),
799                ProviderLookupIden::ClaimsImports,
800            )
801            .expr_as(
802                Expr::col((
803                    UpstreamOAuthProviders::Table,
804                    UpstreamOAuthProviders::JwksUriOverride,
805                )),
806                ProviderLookupIden::JwksUriOverride,
807            )
808            .expr_as(
809                Expr::col((
810                    UpstreamOAuthProviders::Table,
811                    UpstreamOAuthProviders::TokenEndpointOverride,
812                )),
813                ProviderLookupIden::TokenEndpointOverride,
814            )
815            .expr_as(
816                Expr::col((
817                    UpstreamOAuthProviders::Table,
818                    UpstreamOAuthProviders::AuthorizationEndpointOverride,
819                )),
820                ProviderLookupIden::AuthorizationEndpointOverride,
821            )
822            .expr_as(
823                Expr::col((
824                    UpstreamOAuthProviders::Table,
825                    UpstreamOAuthProviders::UserinfoEndpointOverride,
826                )),
827                ProviderLookupIden::UserinfoEndpointOverride,
828            )
829            .expr_as(
830                Expr::col((
831                    UpstreamOAuthProviders::Table,
832                    UpstreamOAuthProviders::DiscoveryMode,
833                )),
834                ProviderLookupIden::DiscoveryMode,
835            )
836            .expr_as(
837                Expr::col((
838                    UpstreamOAuthProviders::Table,
839                    UpstreamOAuthProviders::PkceMode,
840                )),
841                ProviderLookupIden::PkceMode,
842            )
843            .expr_as(
844                Expr::col((
845                    UpstreamOAuthProviders::Table,
846                    UpstreamOAuthProviders::ResponseMode,
847                )),
848                ProviderLookupIden::ResponseMode,
849            )
850            .expr_as(
851                Expr::col((
852                    UpstreamOAuthProviders::Table,
853                    UpstreamOAuthProviders::AdditionalParameters,
854                )),
855                ProviderLookupIden::AdditionalParameters,
856            )
857            .expr_as(
858                Expr::col((
859                    UpstreamOAuthProviders::Table,
860                    UpstreamOAuthProviders::ForwardLoginHint,
861                )),
862                ProviderLookupIden::ForwardLoginHint,
863            )
864            .expr_as(
865                Expr::col((
866                    UpstreamOAuthProviders::Table,
867                    UpstreamOAuthProviders::OnBackchannelLogout,
868                )),
869                ProviderLookupIden::OnBackchannelLogout,
870            )
871            .from(UpstreamOAuthProviders::Table)
872            .apply_filter(filter)
873            .generate_pagination(
874                (
875                    UpstreamOAuthProviders::Table,
876                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
877                ),
878                pagination,
879            )
880            .build_sqlx(PostgresQueryBuilder);
881
882        let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
883            .traced()
884            .fetch_all(&mut *self.conn)
885            .await?;
886
887        let page = pagination
888            .process(edges)
889            .try_map(UpstreamOAuthProvider::try_from)?;
890
891        return Ok(page);
892    }
893
894    #[tracing::instrument(
895        name = "db.upstream_oauth_provider.count",
896        skip_all,
897        fields(
898            db.query.text,
899        ),
900        err,
901    )]
902    async fn count(
903        &mut self,
904        filter: UpstreamOAuthProviderFilter<'_>,
905    ) -> Result<usize, Self::Error> {
906        let (sql, arguments) = Query::select()
907            .expr(
908                Expr::col((
909                    UpstreamOAuthProviders::Table,
910                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
911                ))
912                .count(),
913            )
914            .from(UpstreamOAuthProviders::Table)
915            .apply_filter(filter)
916            .build_sqlx(PostgresQueryBuilder);
917
918        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
919            .traced()
920            .fetch_one(&mut *self.conn)
921            .await?;
922
923        count
924            .try_into()
925            .map_err(DatabaseError::to_invalid_operation)
926    }
927
928    #[tracing::instrument(
929        name = "db.upstream_oauth_provider.all_enabled",
930        skip_all,
931        fields(
932            db.query.text,
933        ),
934        err,
935    )]
936    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
937        let res = sqlx::query_as!(
938            ProviderLookup,
939            r#"
940                SELECT
941                    upstream_oauth_provider_id,
942                    issuer,
943                    human_name,
944                    brand_name,
945                    scope,
946                    client_id,
947                    encrypted_client_secret,
948                    token_endpoint_signing_alg,
949                    token_endpoint_auth_method,
950                    id_token_signed_response_alg,
951                    fetch_userinfo,
952                    userinfo_signed_response_alg,
953                    created_at,
954                    disabled_at,
955                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
956                    jwks_uri_override,
957                    authorization_endpoint_override,
958                    token_endpoint_override,
959                    userinfo_endpoint_override,
960                    discovery_mode,
961                    pkce_mode,
962                    response_mode,
963                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
964                    forward_login_hint,
965                    on_backchannel_logout
966                FROM upstream_oauth_providers
967                WHERE disabled_at IS NULL
968                ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
969            "#,
970        )
971        .traced()
972        .fetch_all(&mut *self.conn)
973        .await?;
974
975        let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
976        Ok(res?)
977    }
978}