mas_storage_pg/compat/
mod.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
//! A module containing the PostgreSQL implementation of the repositories for
//! the compatibility layer
9
10mod access_token;
11mod refresh_token;
12mod session;
13mod sso_login;
14
15pub use self::{
16    access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
17    session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
18};
19
20#[cfg(test)]
21mod tests {
22    use chrono::Duration;
23    use mas_data_model::Device;
24    use mas_storage::{
25        Clock, Pagination, RepositoryAccess,
26        clock::MockClock,
27        compat::{
28            CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
29            CompatSessionRepository, CompatSsoLoginFilter,
30        },
31        user::UserRepository,
32    };
33    use rand::SeedableRng;
34    use rand_chacha::ChaChaRng;
35    use sqlx::PgPool;
36    use ulid::Ulid;
37
38    use crate::PgRepository;
39
40    #[sqlx::test(migrator = "crate::MIGRATOR")]
41    async fn test_session_repository(pool: PgPool) {
42        let mut rng = ChaChaRng::seed_from_u64(42);
43        let clock = MockClock::default();
44        let mut repo = PgRepository::from_pool(&pool).await.unwrap();
45
46        // Create a user
47        let user = repo
48            .user()
49            .add(&mut rng, &clock, "john".to_owned())
50            .await
51            .unwrap();
52
53        let all = CompatSessionFilter::new().for_user(&user);
54        let active = all.active_only();
55        let finished = all.finished_only();
56        let pagination = Pagination::first(10);
57
58        assert_eq!(repo.compat_session().count(all).await.unwrap(), 0);
59        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
60        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
61
62        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
63        assert!(full_list.edges.is_empty());
64        let active_list = repo
65            .compat_session()
66            .list(active, pagination)
67            .await
68            .unwrap();
69        assert!(active_list.edges.is_empty());
70        let finished_list = repo
71            .compat_session()
72            .list(finished, pagination)
73            .await
74            .unwrap();
75        assert!(finished_list.edges.is_empty());
76
77        // Start a compat session for that user
78        let device = Device::generate(&mut rng);
79        let device_str = device.as_str().to_owned();
80        let session = repo
81            .compat_session()
82            .add(&mut rng, &clock, &user, device.clone(), None, false, None)
83            .await
84            .unwrap();
85        assert_eq!(session.user_id, user.id);
86        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
87        assert!(session.is_valid());
88        assert!(!session.is_finished());
89
90        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
91        assert_eq!(repo.compat_session().count(active).await.unwrap(), 1);
92        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
93
94        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
95        assert_eq!(full_list.edges.len(), 1);
96        assert_eq!(full_list.edges[0].0.id, session.id);
97        let active_list = repo
98            .compat_session()
99            .list(active, pagination)
100            .await
101            .unwrap();
102        assert_eq!(active_list.edges.len(), 1);
103        assert_eq!(active_list.edges[0].0.id, session.id);
104        let finished_list = repo
105            .compat_session()
106            .list(finished, pagination)
107            .await
108            .unwrap();
109        assert!(finished_list.edges.is_empty());
110
111        // Lookup the session and check it didn't change
112        let session_lookup = repo
113            .compat_session()
114            .lookup(session.id)
115            .await
116            .unwrap()
117            .expect("compat session not found");
118        assert_eq!(session_lookup.id, session.id);
119        assert_eq!(session_lookup.user_id, user.id);
120        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
121        assert!(session_lookup.is_valid());
122        assert!(!session_lookup.is_finished());
123
124        // Record a user-agent for the session
125        assert!(session_lookup.user_agent.is_none());
126        let session = repo
127            .compat_session()
128            .record_user_agent(session_lookup, "Mozilla/5.0".to_owned())
129            .await
130            .unwrap();
131        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));
132
133        // Reload the session and check again
134        let session_lookup = repo
135            .compat_session()
136            .lookup(session.id)
137            .await
138            .unwrap()
139            .expect("compat session not found");
140        assert_eq!(session_lookup.user_agent.as_deref(), Some("Mozilla/5.0"));
141
142        // Look up the session by device
143        let list = repo
144            .compat_session()
145            .list(
146                CompatSessionFilter::new()
147                    .for_user(&user)
148                    .for_device(&device),
149                pagination,
150            )
151            .await
152            .unwrap();
153        assert_eq!(list.edges.len(), 1);
154        let session_lookup = &list.edges[0].0;
155        assert_eq!(session_lookup.id, session.id);
156        assert_eq!(session_lookup.user_id, user.id);
157        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
158        assert!(session_lookup.is_valid());
159        assert!(!session_lookup.is_finished());
160
161        // Finish the session
162        let session = repo.compat_session().finish(&clock, session).await.unwrap();
163        assert!(!session.is_valid());
164        assert!(session.is_finished());
165
166        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
167        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
168        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 1);
169
170        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
171        assert_eq!(full_list.edges.len(), 1);
172        assert_eq!(full_list.edges[0].0.id, session.id);
173        let active_list = repo
174            .compat_session()
175            .list(active, pagination)
176            .await
177            .unwrap();
178        assert!(active_list.edges.is_empty());
179        let finished_list = repo
180            .compat_session()
181            .list(finished, pagination)
182            .await
183            .unwrap();
184        assert_eq!(finished_list.edges.len(), 1);
185        assert_eq!(finished_list.edges[0].0.id, session.id);
186
187        // Reload the session and check again
188        let session_lookup = repo
189            .compat_session()
190            .lookup(session.id)
191            .await
192            .unwrap()
193            .expect("compat session not found");
194        assert!(!session_lookup.is_valid());
195        assert!(session_lookup.is_finished());
196
197        // Now add another session, with an SSO login this time
198        let unknown_session = session;
199        // Start a new SSO login
200        let login = repo
201            .compat_sso_login()
202            .add(
203                &mut rng,
204                &clock,
205                "login-token".to_owned(),
206                "https://example.com/callback".parse().unwrap(),
207            )
208            .await
209            .unwrap();
210        assert!(login.is_pending());
211
212        // Start a browser session for the user
213        let browser_session = repo
214            .browser_session()
215            .add(&mut rng, &clock, &user, None)
216            .await
217            .unwrap();
218
219        // Start a compat session for that user
220        let device = Device::generate(&mut rng);
221        let sso_login_session = repo
222            .compat_session()
223            .add(
224                &mut rng,
225                &clock,
226                &user,
227                device,
228                Some(&browser_session),
229                false,
230                None,
231            )
232            .await
233            .unwrap();
234
235        // Associate the login with the session
236        let login = repo
237            .compat_sso_login()
238            .fulfill(&clock, login, &browser_session)
239            .await
240            .unwrap();
241        assert!(login.is_fulfilled());
242        let login = repo
243            .compat_sso_login()
244            .exchange(&clock, login, &sso_login_session)
245            .await
246            .unwrap();
247        assert!(login.is_exchanged());
248
249        // Now query the session list with both the unknown and SSO login session type
250        // filter
251        let all = CompatSessionFilter::new().for_user(&user);
252        let sso_login = all.sso_login_only();
253        let unknown = all.unknown_only();
254        assert_eq!(repo.compat_session().count(all).await.unwrap(), 2);
255        assert_eq!(repo.compat_session().count(sso_login).await.unwrap(), 1);
256        assert_eq!(repo.compat_session().count(unknown).await.unwrap(), 1);
257
258        let list = repo
259            .compat_session()
260            .list(sso_login, pagination)
261            .await
262            .unwrap();
263        assert_eq!(list.edges.len(), 1);
264        assert_eq!(list.edges[0].0.id, sso_login_session.id);
265        let list = repo
266            .compat_session()
267            .list(unknown, pagination)
268            .await
269            .unwrap();
270        assert_eq!(list.edges.len(), 1);
271        assert_eq!(list.edges[0].0.id, unknown_session.id);
272
273        // Check that combining the two filters works
274        // At this point, there is one active SSO login session and one finished unknown
275        // session
276        assert_eq!(
277            repo.compat_session()
278                .count(all.sso_login_only().active_only())
279                .await
280                .unwrap(),
281            1
282        );
283        assert_eq!(
284            repo.compat_session()
285                .count(all.sso_login_only().finished_only())
286                .await
287                .unwrap(),
288            0
289        );
290        assert_eq!(
291            repo.compat_session()
292                .count(all.unknown_only().active_only())
293                .await
294                .unwrap(),
295            0
296        );
297        assert_eq!(
298            repo.compat_session()
299                .count(all.unknown_only().finished_only())
300                .await
301                .unwrap(),
302            1
303        );
304
305        // Check that we can batch finish sessions
306        let affected = repo
307            .compat_session()
308            .finish_bulk(&clock, all.sso_login_only().active_only())
309            .await
310            .unwrap();
311        assert_eq!(affected, 1);
312        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2);
313        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
314    }
315
316    #[sqlx::test(migrator = "crate::MIGRATOR")]
317    async fn test_access_token_repository(pool: PgPool) {
318        const FIRST_TOKEN: &str = "first_access_token";
319        const SECOND_TOKEN: &str = "second_access_token";
320        let mut rng = ChaChaRng::seed_from_u64(42);
321        let clock = MockClock::default();
322        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
323
324        // Create a user
325        let user = repo
326            .user()
327            .add(&mut rng, &clock, "john".to_owned())
328            .await
329            .unwrap();
330
331        // Start a compat session for that user
332        let device = Device::generate(&mut rng);
333        let session = repo
334            .compat_session()
335            .add(&mut rng, &clock, &user, device, None, false, None)
336            .await
337            .unwrap();
338
339        // Add an access token to that session
340        let token = repo
341            .compat_access_token()
342            .add(
343                &mut rng,
344                &clock,
345                &session,
346                FIRST_TOKEN.to_owned(),
347                Some(Duration::try_minutes(1).unwrap()),
348            )
349            .await
350            .unwrap();
351        assert_eq!(token.session_id, session.id);
352        assert_eq!(token.token, FIRST_TOKEN);
353
354        // Commit the txn and grab a new transaction, to test a conflict
355        repo.save().await.unwrap();
356
357        {
358            let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
359            // Adding the same token a second time should conflict
360            assert!(
361                repo.compat_access_token()
362                    .add(
363                        &mut rng,
364                        &clock,
365                        &session,
366                        FIRST_TOKEN.to_owned(),
367                        Some(Duration::try_minutes(1).unwrap()),
368                    )
369                    .await
370                    .is_err()
371            );
372            repo.cancel().await.unwrap();
373        }
374
375        // Grab a new repo
376        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
377
378        // Looking up via ID works
379        let token_lookup = repo
380            .compat_access_token()
381            .lookup(token.id)
382            .await
383            .unwrap()
384            .expect("compat access token not found");
385        assert_eq!(token.id, token_lookup.id);
386        assert_eq!(token_lookup.session_id, session.id);
387
388        // Looking up via the token value works
389        let token_lookup = repo
390            .compat_access_token()
391            .find_by_token(FIRST_TOKEN)
392            .await
393            .unwrap()
394            .expect("compat access token not found");
395        assert_eq!(token.id, token_lookup.id);
396        assert_eq!(token_lookup.session_id, session.id);
397
398        // Token is currently valid
399        assert!(token.is_valid(clock.now()));
400
401        clock.advance(Duration::try_minutes(1).unwrap());
402        // Token should have expired
403        assert!(!token.is_valid(clock.now()));
404
405        // Add a second access token, this time without expiration
406        let token = repo
407            .compat_access_token()
408            .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
409            .await
410            .unwrap();
411        assert_eq!(token.session_id, session.id);
412        assert_eq!(token.token, SECOND_TOKEN);
413
414        // Token is currently valid
415        assert!(token.is_valid(clock.now()));
416
417        // Make it expire
418        repo.compat_access_token()
419            .expire(&clock, token)
420            .await
421            .unwrap();
422
423        // Reload it
424        let token = repo
425            .compat_access_token()
426            .find_by_token(SECOND_TOKEN)
427            .await
428            .unwrap()
429            .expect("compat access token not found");
430
431        // Token is not valid anymore
432        assert!(!token.is_valid(clock.now()));
433
434        repo.save().await.unwrap();
435    }
436
437    #[sqlx::test(migrator = "crate::MIGRATOR")]
438    async fn test_refresh_token_repository(pool: PgPool) {
439        const ACCESS_TOKEN: &str = "access_token";
440        const REFRESH_TOKEN: &str = "refresh_token";
441        let mut rng = ChaChaRng::seed_from_u64(42);
442        let clock = MockClock::default();
443        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
444
445        // Create a user
446        let user = repo
447            .user()
448            .add(&mut rng, &clock, "john".to_owned())
449            .await
450            .unwrap();
451
452        // Start a compat session for that user
453        let device = Device::generate(&mut rng);
454        let session = repo
455            .compat_session()
456            .add(&mut rng, &clock, &user, device, None, false, None)
457            .await
458            .unwrap();
459
460        // Add an access token to that session
461        let access_token = repo
462            .compat_access_token()
463            .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
464            .await
465            .unwrap();
466
467        let refresh_token = repo
468            .compat_refresh_token()
469            .add(
470                &mut rng,
471                &clock,
472                &session,
473                &access_token,
474                REFRESH_TOKEN.to_owned(),
475            )
476            .await
477            .unwrap();
478        assert_eq!(refresh_token.session_id, session.id);
479        assert_eq!(refresh_token.access_token_id, access_token.id);
480        assert_eq!(refresh_token.token, REFRESH_TOKEN);
481        assert!(refresh_token.is_valid());
482        assert!(!refresh_token.is_consumed());
483
484        // Look it up by ID and check everything matches
485        let refresh_token_lookup = repo
486            .compat_refresh_token()
487            .lookup(refresh_token.id)
488            .await
489            .unwrap()
490            .expect("refresh token not found");
491        assert_eq!(refresh_token_lookup.id, refresh_token.id);
492        assert_eq!(refresh_token_lookup.session_id, session.id);
493        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
494        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
495        assert!(refresh_token_lookup.is_valid());
496        assert!(!refresh_token_lookup.is_consumed());
497
498        // Look it up by token and check everything matches
499        let refresh_token_lookup = repo
500            .compat_refresh_token()
501            .find_by_token(REFRESH_TOKEN)
502            .await
503            .unwrap()
504            .expect("refresh token not found");
505        assert_eq!(refresh_token_lookup.id, refresh_token.id);
506        assert_eq!(refresh_token_lookup.session_id, session.id);
507        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
508        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
509        assert!(refresh_token_lookup.is_valid());
510        assert!(!refresh_token_lookup.is_consumed());
511
512        // Consume it
513        let refresh_token = repo
514            .compat_refresh_token()
515            .consume(&clock, refresh_token)
516            .await
517            .unwrap();
518        assert!(!refresh_token.is_valid());
519        assert!(refresh_token.is_consumed());
520
521        // Reload it and check again
522        let refresh_token_lookup = repo
523            .compat_refresh_token()
524            .find_by_token(REFRESH_TOKEN)
525            .await
526            .unwrap()
527            .expect("refresh token not found");
528        assert!(!refresh_token_lookup.is_valid());
529        assert!(refresh_token_lookup.is_consumed());
530
531        // Consuming it again should not work
532        assert!(
533            repo.compat_refresh_token()
534                .consume(&clock, refresh_token)
535                .await
536                .is_err()
537        );
538
539        repo.save().await.unwrap();
540    }
541
542    #[sqlx::test(migrator = "crate::MIGRATOR")]
543    async fn test_compat_sso_login_repository(pool: PgPool) {
544        let mut rng = ChaChaRng::seed_from_u64(42);
545        let clock = MockClock::default();
546        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
547
548        // Create a user
549        let user = repo
550            .user()
551            .add(&mut rng, &clock, "john".to_owned())
552            .await
553            .unwrap();
554
555        // Lookup an unknown SSO login
556        let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
557        assert_eq!(login, None);
558
559        let all = CompatSsoLoginFilter::new();
560        let for_user = all.for_user(&user);
561        let pending = all.pending_only();
562        let fulfilled = all.fulfilled_only();
563        let exchanged = all.exchanged_only();
564
565        // Check the initial counts
566        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0);
567        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
568        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
569        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
570        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
571
572        // Lookup an unknown login token
573        let login = repo
574            .compat_sso_login()
575            .find_by_token("login-token")
576            .await
577            .unwrap();
578        assert_eq!(login, None);
579
580        // Start a new SSO login
581        let login = repo
582            .compat_sso_login()
583            .add(
584                &mut rng,
585                &clock,
586                "login-token".to_owned(),
587                "https://example.com/callback".parse().unwrap(),
588            )
589            .await
590            .unwrap();
591        assert!(login.is_pending());
592
593        // Check the counts
594        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
595        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
596        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1);
597        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
598        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
599
600        // Lookup the login by ID
601        let login_lookup = repo
602            .compat_sso_login()
603            .lookup(login.id)
604            .await
605            .unwrap()
606            .expect("login not found");
607        assert_eq!(login_lookup, login);
608
609        // Find the login by token
610        let login_lookup = repo
611            .compat_sso_login()
612            .find_by_token("login-token")
613            .await
614            .unwrap()
615            .expect("login not found");
616        assert_eq!(login_lookup, login);
617
618        // Start a compat session for that user
619        let device = Device::generate(&mut rng);
620        let compat_session = repo
621            .compat_session()
622            .add(&mut rng, &clock, &user, device, None, false, None)
623            .await
624            .unwrap();
625
626        // Exchanging before fulfilling should not work
627        // Note: It should also not poison the SQL transaction
628        let res = repo
629            .compat_sso_login()
630            .exchange(&clock, login.clone(), &compat_session)
631            .await;
632        assert!(res.is_err());
633
634        // Start a browser session for that user
635        let browser_session = repo
636            .browser_session()
637            .add(&mut rng, &clock, &user, None)
638            .await
639            .unwrap();
640
641        // Associate the login with the session
642        let login = repo
643            .compat_sso_login()
644            .fulfill(&clock, login, &browser_session)
645            .await
646            .unwrap();
647        assert!(login.is_fulfilled());
648
649        // Check the counts
650        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
651        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
652        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
653        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1);
654        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
655
656        // Fulfilling again should not work
657        // Note: It should also not poison the SQL transaction
658        let res = repo
659            .compat_sso_login()
660            .fulfill(&clock, login.clone(), &browser_session)
661            .await;
662        assert!(res.is_err());
663
664        // Exchange that login
665        let login = repo
666            .compat_sso_login()
667            .exchange(&clock, login, &compat_session)
668            .await
669            .unwrap();
670        assert!(login.is_exchanged());
671
672        // Check the counts
673        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
674        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
675        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
676        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
677        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1);
678
679        // Exchange again should not work
680        // Note: It should also not poison the SQL transaction
681        let res = repo
682            .compat_sso_login()
683            .exchange(&clock, login.clone(), &compat_session)
684            .await;
685        assert!(res.is_err());
686
687        // Fulfilling after exchanging should not work
688        // Note: It should also not poison the SQL transaction
689        let res = repo
690            .compat_sso_login()
691            .fulfill(&clock, login.clone(), &browser_session)
692            .await;
693        assert!(res.is_err());
694
695        let pagination = Pagination::first(10);
696
697        // List all logins
698        let logins = repo.compat_sso_login().list(all, pagination).await.unwrap();
699        assert!(!logins.has_next_page);
700        assert_eq!(logins.edges, &[login.clone()]);
701
702        // List the logins for the user
703        let logins = repo
704            .compat_sso_login()
705            .list(for_user, pagination)
706            .await
707            .unwrap();
708        assert!(!logins.has_next_page);
709        assert_eq!(logins.edges, &[login.clone()]);
710
711        // List only the pending logins for the user
712        let logins = repo
713            .compat_sso_login()
714            .list(for_user.pending_only(), pagination)
715            .await
716            .unwrap();
717        assert!(!logins.has_next_page);
718        assert!(logins.edges.is_empty());
719
720        // List only the fulfilled logins for the user
721        let logins = repo
722            .compat_sso_login()
723            .list(for_user.fulfilled_only(), pagination)
724            .await
725            .unwrap();
726        assert!(!logins.has_next_page);
727        assert!(logins.edges.is_empty());
728
729        // List only the exchanged logins for the user
730        let logins = repo
731            .compat_sso_login()
732            .list(for_user.exchanged_only(), pagination)
733            .await
734            .unwrap();
735        assert!(!logins.has_next_page);
736        assert_eq!(logins.edges, &[login]);
737    }
738}