// mas_storage_pg/upstream_oauth2/mod.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

//! A module containing the PostgreSQL implementation of the repositories
//! related to the upstream OAuth 2.0 providers.

// One submodule per upstream OAuth 2.0 repository (links, providers,
// sessions); only the PostgreSQL repository types are re-exported below.
mod link;
mod provider;
mod session;

pub use self::{
    link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
    session::PgUpstreamOAuthSessionRepository,
};

#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
        UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        clock::MockClock,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Build the provider parameters shared by every test in this module.
    ///
    /// Only the issuer and the client ID vary between tests. Everything else
    /// is fixed: an `openid`-only scope, OIDC discovery, `Auto` PKCE mode,
    /// RS256 ID token signatures, `None` as the token endpoint auth method,
    /// no client secret, no endpoint overrides, no userinfo fetching and
    /// `DoNothing` on back-channel logout.
    fn provider_params(
        issuer: Option<&str>,
        client_id: impl Into<String>,
    ) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id: client_id.into(),
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        }
    }

    /// Walk through the lifecycle of a provider, an authorization session and
    /// a link: creation, lookups, completing and consuming the session,
    /// associating the link with a user, filtering, disabling and finally
    /// deleting the provider.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // The provider list should be empty at the start
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Let's add a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id"),
            )
            .await
            .unwrap();

        // Look it up in the database
        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        // It should be in the list of all providers
        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // Start a session
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        // Look it up in the database
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // We can look it up by its ID...
        let found = repo
            .upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(found.id, link.id);
        assert_eq!(found.subject, "a-subject");

        // ...or by its subject
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None, None)
            .await
            .unwrap();
        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session)
            .await
            .unwrap();
        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        // Filter on every criterion at once
        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].id, link.id);
        assert_eq!(links.edges[0].user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // Filter on the user alone
        assert_eq!(
            repo.upstream_oauth_link()
                .count(UpstreamOAuthLinkFilter::new().for_user(&user))
                .await
                .unwrap(),
            1
        );

        // A subject nobody uses matches nothing
        assert_eq!(
            repo.upstream_oauth_link()
                .count(
                    UpstreamOAuthLinkFilter::new()
                        .for_provider(&provider)
                        .for_subject("another-subject")
                )
                .await
                .unwrap(),
            0
        );

        // There should be exactly one enabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // Disable the provider
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // There should be exactly one disabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // Test listing and counting sessions
        let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // Count the sessions for the provider
        let session_count = repo
            .upstream_oauth_session()
            .count(session_filter)
            .await
            .unwrap();
        assert_eq!(session_count, 1);

        // List the sessions for the provider
        let session_page = repo
            .upstream_oauth_session()
            .list(session_filter, Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(session_page.edges.len(), 1);
        assert_eq!(session_page.edges[0].id, session.id);
        assert!(!session_page.has_next_page);
        assert!(!session_page.has_previous_page);

        // Try deleting the provider
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// provider repository
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        // Count the number of providers before we start
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        // Create 20 providers, advancing the clock by ten seconds between each
        for idx in 0..20 {
            let provider = repo
                .upstream_oauth_provider()
                .add(&mut rng, &clock, provider_params(None, format!("client-{idx}")))
                .await
                .unwrap();
            ids.push(provider.id);
            clock.advance(Duration::seconds(10));
        }

        // Now we have 20 providers
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Getting the same page with the "enabled only" filter should return the same
        // results
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the items strictly between two IDs: only two of them exist,
        // so the page is not filled up to its limit of 10
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // There should not be any disabled providers
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// session repository
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id"),
            )
            .await
            .unwrap();

        let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // Count the number of sessions before we start
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            0
        );

        let mut links = Vec::with_capacity(3);
        for subject in ["alice", "bob", "charlie"] {
            let link = repo
                .upstream_oauth_link()
                .add(&mut rng, &clock, &provider, subject.to_owned(), None)
                .await
                .unwrap();
            links.push(link);
        }

        let mut ids = Vec::with_capacity(20);
        // Session `idx` gets link `idx % 3` and sid `idx % 2`, so each
        // (sub, sid) pair recurs every six sessions: indices 0, 6, 12, 18 are
        // (alice, one) and 1, 7, 13, 19 are (bob, two).
        let sids = ["one", "two"].into_iter().cycle();
        // Create 20 sessions, advancing the clock by ten seconds between each
        for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
            let state = format!("state-{idx}");
            let session = repo
                .upstream_oauth_session()
                .add(&mut rng, &clock, &provider, state, None, None)
                .await
                .unwrap();
            let id_token_claims = serde_json::json!({
                "sub": link.subject,
                "sid": sid,
                "aud": provider.client_id,
                "iss": "https://example.com/",
            });
            let session = repo
                .upstream_oauth_session()
                .complete_with_link(
                    &clock,
                    session,
                    link,
                    None,
                    Some(id_token_claims),
                    None,
                    None,
                )
                .await
                .unwrap();
            ids.push(session.id);
            clock.advance(Duration::seconds(10));
        }

        // Now we have 20 sessions
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the 5 items strictly between two IDs
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[6..11]);

        // Check the sub/sid filters
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("alice").with_sid_claim("one"))
                .await
                .unwrap(),
            4
        );
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("bob").with_sid_claim("two"))
                .await
                .unwrap(),
            4
        );

        let page = repo
            .upstream_oauth_session()
            .list(
                filter.with_sub_claim("alice").with_sid_claim("one"),
                Pagination::first(10),
            )
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 4);
        for edge in page.edges {
            assert_eq!(
                edge.id_token_claims().unwrap().get("sub").unwrap().as_str(),
                Some("alice")
            );
            assert_eq!(
                edge.id_token_claims().unwrap().get("sid").unwrap().as_str(),
                Some("one")
            );
        }
    }
}