mas_storage_pg/user/
email.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    BrowserSession, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
11    UserRegistration,
12};
13use mas_storage::{
14    Clock, Page, Pagination,
15    user::{UserEmailFilter, UserEmailRepository},
16};
17use rand::RngCore;
18use sea_query::{Expr, Func, PostgresQueryBuilder, Query, SimpleExpr, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25    DatabaseError,
26    filter::{Filter, StatementExt},
27    iden::UserEmails,
28    pagination::QueryBuilderExt,
29    tracing::ExecuteExt,
30};
31
32/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection
33pub struct PgUserEmailRepository<'c> {
34    conn: &'c mut PgConnection,
35}
36
37impl<'c> PgUserEmailRepository<'c> {
38    /// Create a new [`PgUserEmailRepository`] from an active PostgreSQL
39    /// connection
40    pub fn new(conn: &'c mut PgConnection) -> Self {
41        Self { conn }
42    }
43}
44
45#[derive(Debug, Clone, sqlx::FromRow)]
46#[enum_def]
47struct UserEmailLookup {
48    user_email_id: Uuid,
49    user_id: Uuid,
50    email: String,
51    created_at: DateTime<Utc>,
52}
53
54impl From<UserEmailLookup> for UserEmail {
55    fn from(e: UserEmailLookup) -> UserEmail {
56        UserEmail {
57            id: e.user_email_id.into(),
58            user_id: e.user_id.into(),
59            email: e.email,
60            created_at: e.created_at,
61        }
62    }
63}
64
65struct UserEmailAuthenticationLookup {
66    user_email_authentication_id: Uuid,
67    user_session_id: Option<Uuid>,
68    user_registration_id: Option<Uuid>,
69    email: String,
70    created_at: DateTime<Utc>,
71    completed_at: Option<DateTime<Utc>>,
72}
73
74impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
75    fn from(value: UserEmailAuthenticationLookup) -> Self {
76        UserEmailAuthentication {
77            id: value.user_email_authentication_id.into(),
78            user_session_id: value.user_session_id.map(Ulid::from),
79            user_registration_id: value.user_registration_id.map(Ulid::from),
80            email: value.email,
81            created_at: value.created_at,
82            completed_at: value.completed_at,
83        }
84    }
85}
86
87struct UserEmailAuthenticationCodeLookup {
88    user_email_authentication_code_id: Uuid,
89    user_email_authentication_id: Uuid,
90    code: String,
91    created_at: DateTime<Utc>,
92    expires_at: DateTime<Utc>,
93}
94
95impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
96    fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
97        UserEmailAuthenticationCode {
98            id: value.user_email_authentication_code_id.into(),
99            user_email_authentication_id: value.user_email_authentication_id.into(),
100            code: value.code,
101            created_at: value.created_at,
102            expires_at: value.expires_at,
103        }
104    }
105}
106
107impl Filter for UserEmailFilter<'_> {
108    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109        sea_query::Condition::all()
110            .add_option(self.user().map(|user| {
111                Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
112            }))
113            .add_option(self.email().map(|email| {
114                SimpleExpr::from(Func::lower(Expr::col((
115                    UserEmails::Table,
116                    UserEmails::Email,
117                ))))
118                .eq(Func::lower(email))
119            }))
120    }
121}
122
123#[async_trait]
124impl UserEmailRepository for PgUserEmailRepository<'_> {
125    type Error = DatabaseError;
126
127    #[tracing::instrument(
128        name = "db.user_email.lookup",
129        skip_all,
130        fields(
131            db.query.text,
132            user_email.id = %id,
133        ),
134        err,
135    )]
136    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
137        let res = sqlx::query_as!(
138            UserEmailLookup,
139            r#"
140                SELECT user_email_id
141                     , user_id
142                     , email
143                     , created_at
144                FROM user_emails
145
146                WHERE user_email_id = $1
147            "#,
148            Uuid::from(id),
149        )
150        .traced()
151        .fetch_optional(&mut *self.conn)
152        .await?;
153
154        let Some(user_email) = res else {
155            return Ok(None);
156        };
157
158        Ok(Some(user_email.into()))
159    }
160
161    #[tracing::instrument(
162        name = "db.user_email.find",
163        skip_all,
164        fields(
165            db.query.text,
166            %user.id,
167            user_email.email = email,
168        ),
169        err,
170    )]
171    async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
172        let res = sqlx::query_as!(
173            UserEmailLookup,
174            r#"
175                SELECT user_email_id
176                     , user_id
177                     , email
178                     , created_at
179                FROM user_emails
180
181                WHERE user_id = $1 AND LOWER(email) = LOWER($2)
182            "#,
183            Uuid::from(user.id),
184            email,
185        )
186        .traced()
187        .fetch_optional(&mut *self.conn)
188        .await?;
189
190        let Some(user_email) = res else {
191            return Ok(None);
192        };
193
194        Ok(Some(user_email.into()))
195    }
196
197    #[tracing::instrument(
198        name = "db.user_email.find_by_email",
199        skip_all,
200        fields(
201            db.query.text,
202            user_email.email = email,
203        ),
204        err,
205    )]
206    async fn find_by_email(&mut self, email: &str) -> Result<Option<UserEmail>, Self::Error> {
207        let res = sqlx::query_as!(
208            UserEmailLookup,
209            r#"
210                SELECT user_email_id
211                     , user_id
212                     , email
213                     , created_at
214                FROM user_emails
215                WHERE LOWER(email) = LOWER($1)
216            "#,
217            email,
218        )
219        .traced()
220        .fetch_all(&mut *self.conn)
221        .await?;
222
223        if res.len() != 1 {
224            return Ok(None);
225        }
226
227        let Some(user_email) = res.into_iter().next() else {
228            return Ok(None);
229        };
230
231        Ok(Some(user_email.into()))
232    }
233
234    #[tracing::instrument(
235        name = "db.user_email.all",
236        skip_all,
237        fields(
238            db.query.text,
239            %user.id,
240        ),
241        err,
242    )]
243    async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
244        let res = sqlx::query_as!(
245            UserEmailLookup,
246            r#"
247                SELECT user_email_id
248                     , user_id
249                     , email
250                     , created_at
251                FROM user_emails
252
253                WHERE user_id = $1
254
255                ORDER BY email ASC
256            "#,
257            Uuid::from(user.id),
258        )
259        .traced()
260        .fetch_all(&mut *self.conn)
261        .await?;
262
263        Ok(res.into_iter().map(Into::into).collect())
264    }
265
266    #[tracing::instrument(
267        name = "db.user_email.list",
268        skip_all,
269        fields(
270            db.query.text,
271        ),
272        err,
273    )]
274    async fn list(
275        &mut self,
276        filter: UserEmailFilter<'_>,
277        pagination: Pagination,
278    ) -> Result<Page<UserEmail>, DatabaseError> {
279        let (sql, arguments) = Query::select()
280            .expr_as(
281                Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
282                UserEmailLookupIden::UserEmailId,
283            )
284            .expr_as(
285                Expr::col((UserEmails::Table, UserEmails::UserId)),
286                UserEmailLookupIden::UserId,
287            )
288            .expr_as(
289                Expr::col((UserEmails::Table, UserEmails::Email)),
290                UserEmailLookupIden::Email,
291            )
292            .expr_as(
293                Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
294                UserEmailLookupIden::CreatedAt,
295            )
296            .from(UserEmails::Table)
297            .apply_filter(filter)
298            .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
299            .build_sqlx(PostgresQueryBuilder);
300
301        let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
302            .traced()
303            .fetch_all(&mut *self.conn)
304            .await?;
305
306        let page = pagination.process(edges).map(UserEmail::from);
307
308        Ok(page)
309    }
310
311    #[tracing::instrument(
312        name = "db.user_email.count",
313        skip_all,
314        fields(
315            db.query.text,
316        ),
317        err,
318    )]
319    async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
320        let (sql, arguments) = Query::select()
321            .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
322            .from(UserEmails::Table)
323            .apply_filter(filter)
324            .build_sqlx(PostgresQueryBuilder);
325
326        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
327            .traced()
328            .fetch_one(&mut *self.conn)
329            .await?;
330
331        count
332            .try_into()
333            .map_err(DatabaseError::to_invalid_operation)
334    }
335
336    #[tracing::instrument(
337        name = "db.user_email.add",
338        skip_all,
339        fields(
340            db.query.text,
341            %user.id,
342            user_email.id,
343            user_email.email = email,
344        ),
345        err,
346    )]
347    async fn add(
348        &mut self,
349        rng: &mut (dyn RngCore + Send),
350        clock: &dyn Clock,
351        user: &User,
352        email: String,
353    ) -> Result<UserEmail, Self::Error> {
354        let created_at = clock.now();
355        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
356        tracing::Span::current().record("user_email.id", tracing::field::display(id));
357
358        sqlx::query!(
359            r#"
360                INSERT INTO user_emails (user_email_id, user_id, email, created_at)
361                VALUES ($1, $2, $3, $4)
362            "#,
363            Uuid::from(id),
364            Uuid::from(user.id),
365            &email,
366            created_at,
367        )
368        .traced()
369        .execute(&mut *self.conn)
370        .await?;
371
372        Ok(UserEmail {
373            id,
374            user_id: user.id,
375            email,
376            created_at,
377        })
378    }
379
380    #[tracing::instrument(
381        name = "db.user_email.remove",
382        skip_all,
383        fields(
384            db.query.text,
385            user.id = %user_email.user_id,
386            %user_email.id,
387            %user_email.email,
388        ),
389        err,
390    )]
391    async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
392        let res = sqlx::query!(
393            r#"
394                DELETE FROM user_emails
395                WHERE user_email_id = $1
396            "#,
397            Uuid::from(user_email.id),
398        )
399        .traced()
400        .execute(&mut *self.conn)
401        .await?;
402
403        DatabaseError::ensure_affected_rows(&res, 1)?;
404
405        Ok(())
406    }
407
408    #[tracing::instrument(
409        name = "db.user_email.remove_bulk",
410        skip_all,
411        fields(
412            db.query.text,
413        ),
414        err,
415    )]
416    async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
417        let (sql, arguments) = Query::delete()
418            .from_table(UserEmails::Table)
419            .apply_filter(filter)
420            .build_sqlx(PostgresQueryBuilder);
421
422        let res = sqlx::query_with(&sql, arguments)
423            .traced()
424            .execute(&mut *self.conn)
425            .await?;
426
427        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
428    }
429
430    #[tracing::instrument(
431        name = "db.user_email.add_authentication_for_session",
432        skip_all,
433        fields(
434            db.query.text,
435            %session.id,
436            user_email_authentication.id,
437            user_email_authentication.email = email,
438        ),
439        err,
440    )]
441    async fn add_authentication_for_session(
442        &mut self,
443        rng: &mut (dyn RngCore + Send),
444        clock: &dyn Clock,
445        email: String,
446        session: &BrowserSession,
447    ) -> Result<UserEmailAuthentication, Self::Error> {
448        let created_at = clock.now();
449        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
450        tracing::Span::current()
451            .record("user_email_authentication.id", tracing::field::display(id));
452
453        sqlx::query!(
454            r#"
455                INSERT INTO user_email_authentications
456                  ( user_email_authentication_id
457                  , user_session_id
458                  , email
459                  , created_at
460                  )
461                VALUES ($1, $2, $3, $4)
462            "#,
463            Uuid::from(id),
464            Uuid::from(session.id),
465            &email,
466            created_at,
467        )
468        .traced()
469        .execute(&mut *self.conn)
470        .await?;
471
472        Ok(UserEmailAuthentication {
473            id,
474            user_session_id: Some(session.id),
475            user_registration_id: None,
476            email,
477            created_at,
478            completed_at: None,
479        })
480    }
481
482    #[tracing::instrument(
483        name = "db.user_email.add_authentication_for_registration",
484        skip_all,
485        fields(
486            db.query.text,
487            %user_registration.id,
488            user_email_authentication.id,
489            user_email_authentication.email = email,
490        ),
491        err,
492    )]
493    async fn add_authentication_for_registration(
494        &mut self,
495        rng: &mut (dyn RngCore + Send),
496        clock: &dyn Clock,
497        email: String,
498        user_registration: &UserRegistration,
499    ) -> Result<UserEmailAuthentication, Self::Error> {
500        let created_at = clock.now();
501        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
502        tracing::Span::current()
503            .record("user_email_authentication.id", tracing::field::display(id));
504
505        sqlx::query!(
506            r#"
507                INSERT INTO user_email_authentications
508                  ( user_email_authentication_id
509                  , user_registration_id
510                  , email
511                  , created_at
512                  )
513                VALUES ($1, $2, $3, $4)
514            "#,
515            Uuid::from(id),
516            Uuid::from(user_registration.id),
517            &email,
518            created_at,
519        )
520        .traced()
521        .execute(&mut *self.conn)
522        .await?;
523
524        Ok(UserEmailAuthentication {
525            id,
526            user_session_id: None,
527            user_registration_id: Some(user_registration.id),
528            email,
529            created_at,
530            completed_at: None,
531        })
532    }
533
534    #[tracing::instrument(
535        name = "db.user_email.add_authentication_code",
536        skip_all,
537        fields(
538            db.query.text,
539            %user_email_authentication.id,
540            %user_email_authentication.email,
541            user_email_authentication_code.id,
542            user_email_authentication_code.code = code,
543        ),
544        err,
545    )]
546    async fn add_authentication_code(
547        &mut self,
548        rng: &mut (dyn RngCore + Send),
549        clock: &dyn Clock,
550        duration: chrono::Duration,
551        user_email_authentication: &UserEmailAuthentication,
552        code: String,
553    ) -> Result<UserEmailAuthenticationCode, Self::Error> {
554        let created_at = clock.now();
555        let expires_at = created_at + duration;
556        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
557        tracing::Span::current().record(
558            "user_email_authentication_code.id",
559            tracing::field::display(id),
560        );
561
562        sqlx::query!(
563            r#"
564                INSERT INTO user_email_authentication_codes
565                  ( user_email_authentication_code_id
566                  , user_email_authentication_id
567                  , code
568                  , created_at
569                  , expires_at
570                  )
571                VALUES ($1, $2, $3, $4, $5)
572            "#,
573            Uuid::from(id),
574            Uuid::from(user_email_authentication.id),
575            &code,
576            created_at,
577            expires_at,
578        )
579        .traced()
580        .execute(&mut *self.conn)
581        .await?;
582
583        Ok(UserEmailAuthenticationCode {
584            id,
585            user_email_authentication_id: user_email_authentication.id,
586            code,
587            created_at,
588            expires_at,
589        })
590    }
591
592    #[tracing::instrument(
593        name = "db.user_email.lookup_authentication",
594        skip_all,
595        fields(
596            db.query.text,
597            user_email_authentication.id = %id,
598        ),
599        err,
600    )]
601    async fn lookup_authentication(
602        &mut self,
603        id: Ulid,
604    ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
605        let res = sqlx::query_as!(
606            UserEmailAuthenticationLookup,
607            r#"
608                SELECT user_email_authentication_id
609                     , user_session_id
610                     , user_registration_id
611                     , email
612                     , created_at
613                     , completed_at
614                FROM user_email_authentications
615                WHERE user_email_authentication_id = $1
616            "#,
617            Uuid::from(id),
618        )
619        .traced()
620        .fetch_optional(&mut *self.conn)
621        .await?;
622
623        Ok(res.map(UserEmailAuthentication::from))
624    }
625
626    #[tracing::instrument(
627        name = "db.user_email.find_authentication_by_code",
628        skip_all,
629        fields(
630            db.query.text,
631            %authentication.id,
632            user_email_authentication_code.code = code,
633        ),
634        err,
635    )]
636    async fn find_authentication_code(
637        &mut self,
638        authentication: &UserEmailAuthentication,
639        code: &str,
640    ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
641        let res = sqlx::query_as!(
642            UserEmailAuthenticationCodeLookup,
643            r#"
644                SELECT user_email_authentication_code_id
645                     , user_email_authentication_id
646                     , code
647                     , created_at
648                     , expires_at
649                FROM user_email_authentication_codes
650                WHERE user_email_authentication_id = $1
651                  AND code = $2
652            "#,
653            Uuid::from(authentication.id),
654            code,
655        )
656        .traced()
657        .fetch_optional(&mut *self.conn)
658        .await?;
659
660        Ok(res.map(UserEmailAuthenticationCode::from))
661    }
662
663    #[tracing::instrument(
664        name = "db.user_email.complete_email_authentication",
665        skip_all,
666        fields(
667            db.query.text,
668            %user_email_authentication.id,
669            %user_email_authentication.email,
670            %user_email_authentication_code.id,
671            %user_email_authentication_code.code,
672        ),
673        err,
674    )]
675    async fn complete_authentication(
676        &mut self,
677        clock: &dyn Clock,
678        mut user_email_authentication: UserEmailAuthentication,
679        user_email_authentication_code: &UserEmailAuthenticationCode,
680    ) -> Result<UserEmailAuthentication, Self::Error> {
681        // We technically don't use the authentication code here (other than
682        // recording it in the span), but this is to make sure the caller has
683        // fetched one before calling this
684        let completed_at = clock.now();
685
686        // We'll assume the caller has checked that completed_at is None, so in case
687        // they haven't, the update will not affect any rows, which will raise
688        // an error
689        let res = sqlx::query!(
690            r#"
691                UPDATE user_email_authentications
692                SET completed_at = $2
693                WHERE user_email_authentication_id = $1
694                  AND completed_at IS NULL
695            "#,
696            Uuid::from(user_email_authentication.id),
697            completed_at,
698        )
699        .traced()
700        .execute(&mut *self.conn)
701        .await?;
702
703        DatabaseError::ensure_affected_rows(&res, 1)?;
704
705        user_email_authentication.completed_at = Some(completed_at);
706        Ok(user_email_authentication)
707    }
708}