mas_storage/upstream_oauth2/
provider.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11    UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
12    UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
13    UpstreamOAuthProviderResponseMode, UpstreamOAuthProviderTokenAuthMethod,
14};
15use mas_iana::jose::JsonWebSignatureAlg;
16use oauth2_types::scope::Scope;
17use rand_core::RngCore;
18use ulid::Ulid;
19use url::Url;
20
21use crate::{Clock, Pagination, pagination::Page, repository_impl};
22
23/// Structure which holds parameters when inserting or updating an upstream
24/// OAuth 2.0 provider
25pub struct UpstreamOAuthProviderParams {
26    /// The OIDC issuer of the provider
27    pub issuer: Option<String>,
28
29    /// A human-readable name for the provider
30    pub human_name: Option<String>,
31
32    /// A brand identifier, e.g. "apple" or "google"
33    pub brand_name: Option<String>,
34
35    /// The scope to request during the authorization flow
36    pub scope: Scope,
37
38    /// The token endpoint authentication method
39    pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
40
41    /// The JWT signing algorithm to use when then `client_secret_jwt` or
42    /// `private_key_jwt` authentication methods are used
43    pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
44
45    /// Expected signature for the JWT payload returned by the token
46    /// authentication endpoint.
47    ///
48    /// Defaults to `RS256`.
49    pub id_token_signed_response_alg: JsonWebSignatureAlg,
50
51    /// Whether to fetch the user profile from the userinfo endpoint,
52    /// or to rely on the data returned in the `id_token` from the
53    /// `token_endpoint`.
54    pub fetch_userinfo: bool,
55
56    /// Expected signature for the JWT payload returned by the userinfo
57    /// endpoint.
58    ///
59    /// If not specified, the response is expected to be an unsigned JSON
60    /// payload. Defaults to `None`.
61    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
62
63    /// The client ID to use when authenticating to the upstream
64    pub client_id: String,
65
66    /// The encrypted client secret to use when authenticating to the upstream
67    pub encrypted_client_secret: Option<String>,
68
69    /// How claims should be imported from the upstream provider
70    pub claims_imports: UpstreamOAuthProviderClaimsImports,
71
72    /// The URL to use as the authorization endpoint. If `None`, the URL will be
73    /// discovered
74    pub authorization_endpoint_override: Option<Url>,
75
76    /// The URL to use as the token endpoint. If `None`, the URL will be
77    /// discovered
78    pub token_endpoint_override: Option<Url>,
79
80    /// The URL to use as the userinfo endpoint. If `None`, the URL will be
81    /// discovered
82    pub userinfo_endpoint_override: Option<Url>,
83
84    /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
85    pub jwks_uri_override: Option<Url>,
86
87    /// How the provider metadata should be discovered
88    pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
89
90    /// How should PKCE be used
91    pub pkce_mode: UpstreamOAuthProviderPkceMode,
92
93    /// What response mode it should ask
94    pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
95
96    /// Additional parameters to include in the authorization request
97    pub additional_authorization_parameters: Vec<(String, String)>,
98
99    /// Whether to forward the login hint to the upstream provider.
100    pub forward_login_hint: bool,
101
102    /// The position of the provider in the UI
103    pub ui_order: i32,
104
105    /// The behavior when receiving a backchannel logout notification
106    pub on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout,
107}
108
/// Filter parameters for listing upstream OAuth 2.0 providers
///
/// Start from [`UpstreamOAuthProviderFilter::new`] (equivalently,
/// [`Default::default`]) and narrow it with the builder-style methods.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Restriction on the enabled state of the providers.
    ///
    /// `Some(true)` keeps only enabled providers, `Some(false)` keeps only
    /// disabled ones, and `None` applies no restriction at all.
    enabled: Option<bool>,

    /// Marker letting the struct carry the `'a` lifetime parameter without
    /// storing any borrowed data.
    _lifetime: PhantomData<&'a ()>,
}

impl UpstreamOAuthProviderFilter<'_> {
    /// Build a filter that does not restrict anything.
    #[must_use]
    pub fn new() -> Self {
        Self {
            enabled: None,
            _lifetime: PhantomData,
        }
    }

    /// Narrow the filter so that only enabled providers match.
    #[must_use]
    pub const fn enabled_only(self) -> Self {
        Self {
            enabled: Some(true),
            ..self
        }
    }

    /// Narrow the filter so that only disabled providers match.
    #[must_use]
    pub const fn disabled_only(self) -> Self {
        Self {
            enabled: Some(false),
            ..self
        }
    }

    /// The enabled-state restriction held by this filter.
    ///
    /// A `None` result means the filter does not restrict by enabled state.
    #[must_use]
    pub const fn enabled(&self) -> Option<bool> {
        self.enabled
    }
}
149
150/// An [`UpstreamOAuthProviderRepository`] helps interacting with
151/// [`UpstreamOAuthProvider`] saved in the storage backend
152#[async_trait]
153pub trait UpstreamOAuthProviderRepository: Send + Sync {
154    /// The error type returned by the repository
155    type Error;
156
157    /// Lookup an upstream OAuth provider by its ID
158    ///
159    /// Returns `None` if the provider was not found
160    ///
161    /// # Parameters
162    ///
163    /// * `id`: The ID of the provider to lookup
164    ///
165    /// # Errors
166    ///
167    /// Returns [`Self::Error`] if the underlying repository fails
168    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
169
170    /// Add a new upstream OAuth provider
171    ///
172    /// Returns the newly created provider
173    ///
174    /// # Parameters
175    ///
176    /// * `rng`: A random number generator
177    /// * `clock`: The clock used to generate timestamps
178    /// * `params`: The parameters of the provider to add
179    ///
180    /// # Errors
181    ///
182    /// Returns [`Self::Error`] if the underlying repository fails
183    async fn add(
184        &mut self,
185        rng: &mut (dyn RngCore + Send),
186        clock: &dyn Clock,
187        params: UpstreamOAuthProviderParams,
188    ) -> Result<UpstreamOAuthProvider, Self::Error>;
189
190    /// Delete an upstream OAuth provider
191    ///
192    /// # Parameters
193    ///
194    /// * `provider`: The provider to delete
195    ///
196    /// # Errors
197    ///
198    /// Returns [`Self::Error`] if the underlying repository fails
199    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
200        self.delete_by_id(provider.id).await
201    }
202
203    /// Delete an upstream OAuth provider by its ID
204    ///
205    /// # Parameters
206    ///
207    /// * `id`: The ID of the provider to delete
208    ///
209    /// # Errors
210    ///
211    /// Returns [`Self::Error`] if the underlying repository fails
212    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
213
214    /// Insert or update an upstream OAuth provider
215    ///
216    /// # Parameters
217    ///
218    /// * `clock`: The clock used to generate timestamps
219    /// * `id`: The ID of the provider to update
220    /// * `params`: The parameters of the provider to update
221    ///
222    /// # Errors
223    ///
224    /// Returns [`Self::Error`] if the underlying repository fails
225    async fn upsert(
226        &mut self,
227        clock: &dyn Clock,
228        id: Ulid,
229        params: UpstreamOAuthProviderParams,
230    ) -> Result<UpstreamOAuthProvider, Self::Error>;
231
232    /// Disable an upstream OAuth provider
233    ///
234    /// Returns the disabled provider
235    ///
236    /// # Parameters
237    ///
238    /// * `clock`: The clock used to generate timestamps
239    /// * `provider`: The provider to disable
240    ///
241    /// # Errors
242    ///
243    /// Returns [`Self::Error`] if the underlying repository fails
244    async fn disable(
245        &mut self,
246        clock: &dyn Clock,
247        provider: UpstreamOAuthProvider,
248    ) -> Result<UpstreamOAuthProvider, Self::Error>;
249
250    /// List [`UpstreamOAuthProvider`] with the given filter and pagination
251    ///
252    /// # Parameters
253    ///
254    /// * `filter`: The filter to apply
255    /// * `pagination`: The pagination parameters
256    ///
257    /// # Errors
258    ///
259    /// Returns [`Self::Error`] if the underlying repository fails
260    async fn list(
261        &mut self,
262        filter: UpstreamOAuthProviderFilter<'_>,
263        pagination: Pagination,
264    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
265
266    /// Count the number of [`UpstreamOAuthProvider`] with the given filter
267    ///
268    /// # Parameters
269    ///
270    /// * `filter`: The filter to apply
271    ///
272    /// # Errors
273    ///
274    /// Returns [`Self::Error`] if the underlying repository fails
275    async fn count(
276        &mut self,
277        filter: UpstreamOAuthProviderFilter<'_>,
278    ) -> Result<usize, Self::Error>;
279
280    /// Get all enabled upstream OAuth providers
281    ///
282    /// # Errors
283    ///
284    /// Returns [`Self::Error`] if the underlying repository fails
285    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
286}
287
// Generate the boilerplate implementations of `UpstreamOAuthProviderRepository`
// through the crate's `repository_impl!` macro.
// NOTE(review): the macro is defined elsewhere in the crate. Presumably it emits
// forwarding impls of the trait for wrapper types around a repository; confirm
// against its definition.
//
// Each signature listed here presumably has to mirror the corresponding trait
// method declared above. `delete` is listed even though the trait gives it a
// default body. Assuming the macro forwards every listed method, this makes a
// wrapper delegate to the inner repository's `delete` instead of falling back
// to the trait default.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);