mas_storage_pg/personal/
session.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    Clock, User,
12    personal::session::{PersonalSession, PersonalSessionOwner, SessionState},
13};
14use mas_storage::{
15    Page, Pagination,
16    pagination::Node,
17    personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
18};
19use oauth2_types::scope::Scope;
20use rand::RngCore;
21use sea_query::{
22    Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
23    extension::postgres::PgExpr as _,
24};
25use sea_query_binder::SqlxBinder as _;
26use sqlx::PgConnection;
27use ulid::Ulid;
28use uuid::Uuid;
29
30use crate::{
31    DatabaseError,
32    errors::DatabaseInconsistencyError,
33    filter::{Filter, StatementExt as _},
34    iden::PersonalSessions,
35    pagination::QueryBuilderExt as _,
36    tracing::ExecuteExt as _,
37};
38
39/// An implementation of [`PersonalSessionRepository`] for a PostgreSQL
40/// connection
41pub struct PgPersonalSessionRepository<'c> {
42    conn: &'c mut PgConnection,
43}
44
45impl<'c> PgPersonalSessionRepository<'c> {
46    /// Create a new [`PgPersonalSessionRepository`] from an active PostgreSQL
47    /// connection
48    pub fn new(conn: &'c mut PgConnection) -> Self {
49        Self { conn }
50    }
51}
52
53#[derive(sqlx::FromRow)]
54#[enum_def]
55struct PersonalSessionLookup {
56    personal_session_id: Uuid,
57    owner_user_id: Option<Uuid>,
58    owner_oauth2_client_id: Option<Uuid>,
59    actor_user_id: Uuid,
60    human_name: String,
61    scope_list: Vec<String>,
62    created_at: DateTime<Utc>,
63    revoked_at: Option<DateTime<Utc>>,
64    last_active_at: Option<DateTime<Utc>>,
65    last_active_ip: Option<IpAddr>,
66}
67
68impl Node<Ulid> for PersonalSessionLookup {
69    fn cursor(&self) -> Ulid {
70        self.personal_session_id.into()
71    }
72}
73
74impl TryFrom<PersonalSessionLookup> for PersonalSession {
75    type Error = DatabaseInconsistencyError;
76
77    fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
78        let id = Ulid::from(value.personal_session_id);
79        let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
80        let scope = scope.map_err(|e| {
81            DatabaseInconsistencyError::on("personal_sessions")
82                .column("scope")
83                .row(id)
84                .source(e)
85        })?;
86
87        let state = match value.revoked_at {
88            None => SessionState::Valid,
89            Some(revoked_at) => SessionState::Revoked { revoked_at },
90        };
91
92        let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
93            (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
94            (None, Some(owner_oauth2_client_id)) => {
95                PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
96            }
97            _ => {
98                // should be impossible (CHECK constraint in Postgres prevents it)
99                return Err(DatabaseInconsistencyError::on("personal_sessions")
100                    .column("owner_user_id, owner_oauth2_client_id")
101                    .row(id));
102            }
103        };
104
105        Ok(PersonalSession {
106            id,
107            state,
108            owner,
109            actor_user_id: Ulid::from(value.actor_user_id),
110            human_name: value.human_name,
111            scope,
112            created_at: value.created_at,
113            last_active_at: value.last_active_at,
114            last_active_ip: value.last_active_ip,
115        })
116    }
117}
118
119#[async_trait]
120impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
121    type Error = DatabaseError;
122
123    #[tracing::instrument(
124        name = "db.personal_session.lookup",
125        skip_all,
126        fields(
127            db.query.text,
128            session.id = %id,
129        ),
130        err,
131    )]
132    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
133        let res = sqlx::query_as!(
134            PersonalSessionLookup,
135            r#"
136                SELECT personal_session_id
137                     , owner_user_id
138                     , owner_oauth2_client_id
139                     , actor_user_id
140                     , scope_list
141                     , created_at
142                     , revoked_at
143                     , human_name
144                     , last_active_at
145                     , last_active_ip as "last_active_ip: IpAddr"
146                FROM personal_sessions
147
148                WHERE personal_session_id = $1
149            "#,
150            Uuid::from(id),
151        )
152        .traced()
153        .fetch_optional(&mut *self.conn)
154        .await?;
155
156        let Some(session) = res else { return Ok(None) };
157
158        Ok(Some(session.try_into()?))
159    }
160
161    #[tracing::instrument(
162        name = "db.personal_session.add",
163        skip_all,
164        fields(
165            db.query.text,
166            session.id,
167            session.scope = %scope,
168        ),
169        err,
170    )]
171    async fn add(
172        &mut self,
173        rng: &mut (dyn RngCore + Send),
174        clock: &dyn Clock,
175        owner: PersonalSessionOwner,
176        actor_user: &User,
177        human_name: String,
178        scope: Scope,
179    ) -> Result<PersonalSession, Self::Error> {
180        let created_at = clock.now();
181        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
182        tracing::Span::current().record("session.id", tracing::field::display(id));
183
184        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
185
186        let (owner_user_id, owner_oauth2_client_id) = match owner {
187            PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
188            PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
189        };
190
191        sqlx::query!(
192            r#"
193                INSERT INTO personal_sessions
194                    ( personal_session_id
195                    , owner_user_id
196                    , owner_oauth2_client_id
197                    , actor_user_id
198                    , human_name
199                    , scope_list
200                    , created_at
201                    )
202                VALUES ($1, $2, $3, $4, $5, $6, $7)
203            "#,
204            Uuid::from(id),
205            owner_user_id,
206            owner_oauth2_client_id,
207            Uuid::from(actor_user.id),
208            &human_name,
209            &scope_list,
210            created_at,
211        )
212        .traced()
213        .execute(&mut *self.conn)
214        .await?;
215
216        Ok(PersonalSession {
217            id,
218            state: SessionState::Valid,
219            owner,
220            actor_user_id: actor_user.id,
221            human_name,
222            scope,
223            created_at,
224            last_active_at: None,
225            last_active_ip: None,
226        })
227    }
228
229    #[tracing::instrument(
230        name = "db.personal_session.revoke",
231        skip_all,
232        fields(
233            db.query.text,
234            %session.id,
235            %session.scope,
236        ),
237        err,
238    )]
239    async fn revoke(
240        &mut self,
241        clock: &dyn Clock,
242        session: PersonalSession,
243    ) -> Result<PersonalSession, Self::Error> {
244        let finished_at = clock.now();
245        let res = sqlx::query!(
246            r#"
247                UPDATE personal_sessions
248                SET revoked_at = $2
249                WHERE personal_session_id = $1
250            "#,
251            Uuid::from(session.id),
252            finished_at,
253        )
254        .traced()
255        .execute(&mut *self.conn)
256        .await?;
257
258        DatabaseError::ensure_affected_rows(&res, 1)?;
259
260        session
261            .finish(finished_at)
262            .map_err(DatabaseError::to_invalid_operation)
263    }
264
265    #[tracing::instrument(
266        name = "db.personal_session.list",
267        skip_all,
268        fields(
269            db.query.text,
270        ),
271        err,
272    )]
273    async fn list(
274        &mut self,
275        filter: PersonalSessionFilter<'_>,
276        pagination: Pagination,
277    ) -> Result<Page<PersonalSession>, Self::Error> {
278        let (sql, arguments) = Query::select()
279            .expr_as(
280                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
281                PersonalSessionLookupIden::PersonalSessionId,
282            )
283            .expr_as(
284                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
285                PersonalSessionLookupIden::OwnerUserId,
286            )
287            .expr_as(
288                Expr::col((
289                    PersonalSessions::Table,
290                    PersonalSessions::OwnerOAuth2ClientId,
291                )),
292                PersonalSessionLookupIden::OwnerOauth2ClientId,
293            )
294            .expr_as(
295                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
296                PersonalSessionLookupIden::ActorUserId,
297            )
298            .expr_as(
299                Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
300                PersonalSessionLookupIden::HumanName,
301            )
302            .expr_as(
303                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
304                PersonalSessionLookupIden::ScopeList,
305            )
306            .expr_as(
307                Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
308                PersonalSessionLookupIden::CreatedAt,
309            )
310            .expr_as(
311                Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
312                PersonalSessionLookupIden::RevokedAt,
313            )
314            .expr_as(
315                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
316                PersonalSessionLookupIden::LastActiveAt,
317            )
318            .expr_as(
319                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
320                PersonalSessionLookupIden::LastActiveIp,
321            )
322            .from(PersonalSessions::Table)
323            .apply_filter(filter)
324            .generate_pagination(
325                (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
326                pagination,
327            )
328            .build_sqlx(PostgresQueryBuilder);
329
330        let edges: Vec<PersonalSessionLookup> = sqlx::query_as_with(&sql, arguments)
331            .traced()
332            .fetch_all(&mut *self.conn)
333            .await?;
334
335        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
336
337        Ok(page)
338    }
339
340    #[tracing::instrument(
341        name = "db.personal_session.count",
342        skip_all,
343        fields(
344            db.query.text,
345        ),
346        err,
347    )]
348    async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
349        let (sql, arguments) = Query::select()
350            .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
351            .from(PersonalSessions::Table)
352            .apply_filter(filter)
353            .build_sqlx(PostgresQueryBuilder);
354
355        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
356            .traced()
357            .fetch_one(&mut *self.conn)
358            .await?;
359
360        count
361            .try_into()
362            .map_err(DatabaseError::to_invalid_operation)
363    }
364
365    #[tracing::instrument(
366        name = "db.personal_session.record_batch_activity",
367        skip_all,
368        fields(
369            db.query.text,
370        ),
371        err,
372    )]
373    async fn record_batch_activity(
374        &mut self,
375        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
376    ) -> Result<(), Self::Error> {
377        // Sort the activity by ID, so that when batching the updates, Postgres
378        // locks the rows in a stable order, preventing deadlocks
379        activities.sort_unstable();
380        let mut ids = Vec::with_capacity(activities.len());
381        let mut last_activities = Vec::with_capacity(activities.len());
382        let mut ips = Vec::with_capacity(activities.len());
383
384        for (id, last_activity, ip) in activities {
385            ids.push(Uuid::from(id));
386            last_activities.push(last_activity);
387            ips.push(ip);
388        }
389
390        let res = sqlx::query!(
391            r#"
392                UPDATE personal_sessions
393                SET last_active_at = GREATEST(t.last_active_at, personal_sessions.last_active_at)
394                  , last_active_ip = COALESCE(t.last_active_ip, personal_sessions.last_active_ip)
395                FROM (
396                    SELECT *
397                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
398                        AS t(personal_session_id, last_active_at, last_active_ip)
399                ) AS t
400                WHERE personal_sessions.personal_session_id = t.personal_session_id
401            "#,
402            &ids,
403            &last_activities,
404            &ips as &[Option<IpAddr>],
405        )
406        .traced()
407        .execute(&mut *self.conn)
408        .await?;
409
410        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
411
412        Ok(())
413    }
414}
415
416impl Filter for PersonalSessionFilter<'_> {
417    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
418        sea_query::Condition::all()
419            .add_option(self.owner_user().map(|user| {
420                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
421                    .eq(Uuid::from(user.id))
422            }))
423            .add_option(self.owner_oauth2_client().map(|client| {
424                Expr::col((
425                    PersonalSessions::Table,
426                    PersonalSessions::OwnerOAuth2ClientId,
427                ))
428                .eq(Uuid::from(client.id))
429            }))
430            .add_option(self.actor_user().map(|user| {
431                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
432                    .eq(Uuid::from(user.id))
433            }))
434            .add_option(self.device().map(|device| -> SimpleExpr {
435                if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
436                    Condition::any()
437                        .add(
438                            Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
439                                PersonalSessions::Table,
440                                PersonalSessions::ScopeList,
441                            )))),
442                        )
443                        .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
444                            Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
445                        )))
446                        .into()
447                } else {
448                    // If the device ID can't be encoded as a scope token, match no rows
449                    Expr::val(false).into()
450                }
451            }))
452            .add_option(self.state().map(|state| match state {
453                PersonalSessionState::Active => {
454                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
455                }
456                PersonalSessionState::Revoked => {
457                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
458                }
459            }))
460            .add_option(self.scope().map(|scope| {
461                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
462                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
463            }))
464            .add_option(self.last_active_before().map(|last_active_before| {
465                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
466                    .lt(last_active_before)
467            }))
468            .add_option(self.last_active_after().map(|last_active_after| {
469                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
470                    .gt(last_active_after)
471            }))
472    }
473}