1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Duration, Utc};
11use mas_data_model::{Clock, UserEmail, UserRecoverySession, UserRecoveryTicket};
12use mas_storage::user::UserRecoveryRepository;
13use rand::RngCore;
14use sqlx::PgConnection;
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, ExecuteExt};
19
20pub struct PgUserRecoveryRepository<'c> {
22    conn: &'c mut PgConnection,
23}
24
25impl<'c> PgUserRecoveryRepository<'c> {
26    pub fn new(conn: &'c mut PgConnection) -> Self {
29        Self { conn }
30    }
31}
32
33struct UserRecoverySessionRow {
34    user_recovery_session_id: Uuid,
35    email: String,
36    user_agent: String,
37    ip_address: Option<IpAddr>,
38    locale: String,
39    created_at: DateTime<Utc>,
40    consumed_at: Option<DateTime<Utc>>,
41}
42
43impl From<UserRecoverySessionRow> for UserRecoverySession {
44    fn from(row: UserRecoverySessionRow) -> Self {
45        UserRecoverySession {
46            id: row.user_recovery_session_id.into(),
47            email: row.email,
48            user_agent: row.user_agent,
49            ip_address: row.ip_address,
50            locale: row.locale,
51            created_at: row.created_at,
52            consumed_at: row.consumed_at,
53        }
54    }
55}
56
57struct UserRecoveryTicketRow {
58    user_recovery_ticket_id: Uuid,
59    user_recovery_session_id: Uuid,
60    user_email_id: Uuid,
61    ticket: String,
62    created_at: DateTime<Utc>,
63    expires_at: DateTime<Utc>,
64}
65
66impl From<UserRecoveryTicketRow> for UserRecoveryTicket {
67    fn from(row: UserRecoveryTicketRow) -> Self {
68        Self {
69            id: row.user_recovery_ticket_id.into(),
70            user_recovery_session_id: row.user_recovery_session_id.into(),
71            user_email_id: row.user_email_id.into(),
72            ticket: row.ticket,
73            created_at: row.created_at,
74            expires_at: row.expires_at,
75        }
76    }
77}
78
79#[async_trait]
80impl UserRecoveryRepository for PgUserRecoveryRepository<'_> {
81    type Error = DatabaseError;
82
83    #[tracing::instrument(
84        name = "db.user_recovery.lookup_session",
85        skip_all,
86        fields(
87            db.query.text,
88            user_recovery_session.id = %id,
89        ),
90        err,
91    )]
92    async fn lookup_session(
93        &mut self,
94        id: Ulid,
95    ) -> Result<Option<UserRecoverySession>, Self::Error> {
96        let row = sqlx::query_as!(
97            UserRecoverySessionRow,
98            r#"
99                SELECT
100                      user_recovery_session_id
101                    , email
102                    , user_agent
103                    , ip_address as "ip_address: IpAddr"
104                    , locale
105                    , created_at
106                    , consumed_at
107                FROM user_recovery_sessions
108                WHERE user_recovery_session_id = $1
109            "#,
110            Uuid::from(id),
111        )
112        .traced()
113        .fetch_optional(&mut *self.conn)
114        .await?;
115
116        let Some(row) = row else {
117            return Ok(None);
118        };
119
120        Ok(Some(row.into()))
121    }
122
123    #[tracing::instrument(
124        name = "db.user_recovery.add_session",
125        skip_all,
126        fields(
127            db.query.text,
128            user_recovery_session.id,
129            user_recovery_session.email = email,
130            user_recovery_session.user_agent = user_agent,
131            user_recovery_session.ip_address = ip_address.map(|ip| ip.to_string()),
132        )
133    )]
134    async fn add_session(
135        &mut self,
136        rng: &mut (dyn RngCore + Send),
137        clock: &dyn Clock,
138        email: String,
139        user_agent: String,
140        ip_address: Option<IpAddr>,
141        locale: String,
142    ) -> Result<UserRecoverySession, Self::Error> {
143        let created_at = clock.now();
144        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
145        tracing::Span::current().record("user_recovery_session.id", tracing::field::display(id));
146        sqlx::query!(
147            r#"
148                INSERT INTO user_recovery_sessions (
149                      user_recovery_session_id
150                    , email
151                    , user_agent
152                    , ip_address
153                    , locale
154                    , created_at
155                )
156                VALUES ($1, $2, $3, $4, $5, $6)
157            "#,
158            Uuid::from(id),
159            &email,
160            &*user_agent,
161            ip_address as Option<IpAddr>,
162            &locale,
163            created_at,
164        )
165        .traced()
166        .execute(&mut *self.conn)
167        .await?;
168
169        let user_recovery_session = UserRecoverySession {
170            id,
171            email,
172            user_agent,
173            ip_address,
174            locale,
175            created_at,
176            consumed_at: None,
177        };
178
179        Ok(user_recovery_session)
180    }
181
182    #[tracing::instrument(
183        name = "db.user_recovery.find_ticket",
184        skip_all,
185        fields(
186            db.query.text,
187            user_recovery_ticket.id = ticket,
188        ),
189        err,
190    )]
191    async fn find_ticket(
192        &mut self,
193        ticket: &str,
194    ) -> Result<Option<UserRecoveryTicket>, Self::Error> {
195        let row = sqlx::query_as!(
196            UserRecoveryTicketRow,
197            r#"
198                SELECT
199                      user_recovery_ticket_id
200                    , user_recovery_session_id
201                    , user_email_id
202                    , ticket
203                    , created_at
204                    , expires_at
205                FROM user_recovery_tickets
206                WHERE ticket = $1
207            "#,
208            ticket,
209        )
210        .traced()
211        .fetch_optional(&mut *self.conn)
212        .await?;
213
214        let Some(row) = row else {
215            return Ok(None);
216        };
217
218        Ok(Some(row.into()))
219    }
220
221    #[tracing::instrument(
222        name = "db.user_recovery.add_ticket",
223        skip_all,
224        fields(
225            db.query.text,
226            user_recovery_ticket.id,
227            user_recovery_ticket.id = ticket,
228            %user_recovery_session.id,
229            %user_email.id,
230        )
231    )]
232    async fn add_ticket(
233        &mut self,
234        rng: &mut (dyn RngCore + Send),
235        clock: &dyn Clock,
236        user_recovery_session: &UserRecoverySession,
237        user_email: &UserEmail,
238        ticket: String,
239    ) -> Result<UserRecoveryTicket, Self::Error> {
240        let created_at = clock.now();
241        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
242        tracing::Span::current().record("user_recovery_ticket.id", tracing::field::display(id));
243
244        let expires_at = created_at + Duration::minutes(10);
246
247        sqlx::query!(
248            r#"
249                INSERT INTO user_recovery_tickets (
250                      user_recovery_ticket_id
251                    , user_recovery_session_id
252                    , user_email_id
253                    , ticket
254                    , created_at
255                    , expires_at
256                )
257                VALUES ($1, $2, $3, $4, $5, $6)
258            "#,
259            Uuid::from(id),
260            Uuid::from(user_recovery_session.id),
261            Uuid::from(user_email.id),
262            &ticket,
263            created_at,
264            expires_at,
265        )
266        .traced()
267        .execute(&mut *self.conn)
268        .await?;
269
270        let ticket = UserRecoveryTicket {
271            id,
272            user_recovery_session_id: user_recovery_session.id,
273            user_email_id: user_email.id,
274            ticket,
275            created_at,
276            expires_at,
277        };
278
279        Ok(ticket)
280    }
281
282    #[tracing::instrument(
283        name = "db.user_recovery.consume_ticket",
284        skip_all,
285        fields(
286            db.query.text,
287            %user_recovery_ticket.id,
288            user_email.id = %user_recovery_ticket.user_email_id,
289            %user_recovery_session.id,
290            %user_recovery_session.email,
291        ),
292        err,
293    )]
294    async fn consume_ticket(
295        &mut self,
296        clock: &dyn Clock,
297        user_recovery_ticket: UserRecoveryTicket,
298        mut user_recovery_session: UserRecoverySession,
299    ) -> Result<UserRecoverySession, Self::Error> {
300        let _ = user_recovery_ticket;
302
303        if user_recovery_session.consumed_at.is_some() {
305            return Err(DatabaseError::invalid_operation());
306        }
307
308        let consumed_at = clock.now();
309
310        let res = sqlx::query!(
311            r#"
312                UPDATE user_recovery_sessions
313                SET consumed_at = $1
314                WHERE user_recovery_session_id = $2
315            "#,
316            consumed_at,
317            Uuid::from(user_recovery_session.id),
318        )
319        .traced()
320        .execute(&mut *self.conn)
321        .await?;
322
323        user_recovery_session.consumed_at = Some(consumed_at);
324
325        DatabaseError::ensure_affected_rows(&res, 1)?;
326
327        Ok(user_recovery_session)
328    }
329}