mas_handlers/upstream_oauth2/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::string::FromUtf8Error;
8
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderTokenAuthMethod};
10use mas_iana::jose::JsonWebSignatureAlg;
11use mas_keystore::{DecryptError, Encrypter, Keystore};
12use mas_oidc_client::types::client_credentials::ClientCredentials;
13use pkcs8::DecodePrivateKey;
14use serde::Deserialize;
15use thiserror::Error;
16use url::Url;
17
18pub(crate) mod authorize;
19pub(crate) mod backchannel_logout;
20pub(crate) mod cache;
21pub(crate) mod callback;
22mod cookie;
23pub(crate) mod link;
24mod template;
25
26use self::cookie::UpstreamSessions as UpstreamSessionsCookie;
27
28#[derive(Debug, Error)]
29#[allow(clippy::enum_variant_names)]
30enum ProviderCredentialsError {
31    #[error("Provider doesn't have a client secret")]
32    MissingClientSecret,
33
34    #[error("Could not decrypt client secret")]
35    DecryptClientSecret {
36        #[from]
37        inner: DecryptError,
38    },
39
40    #[error("Client secret is invalid")]
41    InvalidClientSecret {
42        #[from]
43        inner: FromUtf8Error,
44    },
45
46    #[error("Invalid JSON in client secret")]
47    InvalidClientSecretJson {
48        #[from]
49        inner: serde_json::Error,
50    },
51
52    #[error("Could not parse PEM encoded private key")]
53    InvalidPrivateKey {
54        #[from]
55        inner: pkcs8::Error,
56    },
57}
58
59#[derive(Debug, Deserialize)]
60pub struct SignInWithApple {
61    pub private_key: String,
62    pub team_id: String,
63    pub key_id: String,
64}
65
66fn client_credentials_for_provider(
67    provider: &UpstreamOAuthProvider,
68    token_endpoint: &Url,
69    keystore: &Keystore,
70    encrypter: &Encrypter,
71) -> Result<ClientCredentials, ProviderCredentialsError> {
72    let client_id = provider.client_id.clone();
73
74    // Decrypt the client secret
75    let client_secret = provider
76        .encrypted_client_secret
77        .as_deref()
78        .map(|encrypted_client_secret| {
79            let decrypted = encrypter.decrypt_string(encrypted_client_secret)?;
80            let decrypted = String::from_utf8(decrypted)?;
81            Ok::<_, ProviderCredentialsError>(decrypted)
82        })
83        .transpose()?;
84
85    let client_credentials = match provider.token_endpoint_auth_method {
86        UpstreamOAuthProviderTokenAuthMethod::None => ClientCredentials::None { client_id },
87
88        UpstreamOAuthProviderTokenAuthMethod::ClientSecretPost => {
89            ClientCredentials::ClientSecretPost {
90                client_id,
91                client_secret: client_secret
92                    .ok_or(ProviderCredentialsError::MissingClientSecret)?,
93            }
94        }
95
96        UpstreamOAuthProviderTokenAuthMethod::ClientSecretBasic => {
97            ClientCredentials::ClientSecretBasic {
98                client_id,
99                client_secret: client_secret
100                    .ok_or(ProviderCredentialsError::MissingClientSecret)?,
101            }
102        }
103
104        UpstreamOAuthProviderTokenAuthMethod::ClientSecretJwt => {
105            ClientCredentials::ClientSecretJwt {
106                client_id,
107                client_secret: client_secret
108                    .ok_or(ProviderCredentialsError::MissingClientSecret)?,
109                signing_algorithm: provider
110                    .token_endpoint_signing_alg
111                    .clone()
112                    .unwrap_or(JsonWebSignatureAlg::Rs256),
113                token_endpoint: token_endpoint.clone(),
114            }
115        }
116
117        UpstreamOAuthProviderTokenAuthMethod::PrivateKeyJwt => ClientCredentials::PrivateKeyJwt {
118            client_id,
119            keystore: keystore.clone(),
120            signing_algorithm: provider
121                .token_endpoint_signing_alg
122                .clone()
123                .unwrap_or(JsonWebSignatureAlg::Rs256),
124            token_endpoint: token_endpoint.clone(),
125        },
126
127        UpstreamOAuthProviderTokenAuthMethod::SignInWithApple => {
128            let params = client_secret.ok_or(ProviderCredentialsError::MissingClientSecret)?;
129            let params: SignInWithApple = serde_json::from_str(&params)?;
130
131            let key = elliptic_curve::SecretKey::from_pkcs8_pem(&params.private_key)?;
132
133            ClientCredentials::SignInWithApple {
134                client_id,
135                key,
136                key_id: params.key_id,
137                team_id: params.team_id,
138            }
139        }
140    };
141
142    Ok(client_credentials)
143}