1use base64ct::{Base64UrlUnpadded, Encoding};
8use chrono::{DateTime, Utc};
9use crc::{CRC_32_ISO_HDLC, Crc};
10use mas_iana::oauth::OAuthTokenTypeHint;
11use rand::{Rng, RngCore, distributions::Alphanumeric};
12use thiserror::Error;
13use ulid::Ulid;
14
15use crate::InvalidTransitionError;
16
/// Lifecycle state of an [`AccessToken`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AccessTokenState {
    /// The token has not been revoked. This is the default state.
    #[default]
    Valid,
    /// The token has been revoked.
    Revoked {
        /// When the token was revoked.
        revoked_at: DateTime<Utc>,
    },
}
25
26impl AccessTokenState {
27    fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
28        match self {
29            Self::Valid => Ok(Self::Revoked { revoked_at }),
30            Self::Revoked { .. } => Err(InvalidTransitionError),
31        }
32    }
33
34    #[must_use]
38    pub fn is_valid(&self) -> bool {
39        matches!(self, Self::Valid)
40    }
41
42    #[must_use]
46    pub fn is_revoked(&self) -> bool {
47        matches!(self, Self::Revoked { .. })
48    }
49}
50
/// An access token together with its lifecycle metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessToken {
    /// Unique identifier; its string form is returned by [`AccessToken::jti`].
    pub id: Ulid,
    /// Current lifecycle state (valid or revoked).
    pub state: AccessTokenState,
    /// ID of the associated session.
    pub session_id: Ulid,
    /// The token string itself.
    pub access_token: String,
    /// When the token was created.
    pub created_at: DateTime<Utc>,
    /// When the token expires; `None` means it never expires.
    pub expires_at: Option<DateTime<Utc>>,
    /// When the token was first used, or `None` if it has never been used.
    pub first_used_at: Option<DateTime<Utc>>,
}
61
62impl AccessToken {
63    #[must_use]
64    pub fn jti(&self) -> String {
65        self.id.to_string()
66    }
67
68    #[must_use]
74    pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
75        self.state.is_valid() && !self.is_expired(now)
76    }
77
78    #[must_use]
86    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
87        match self.expires_at {
88            Some(expires_at) => expires_at < now,
89            None => false,
90        }
91    }
92
93    #[must_use]
95    pub fn is_used(&self) -> bool {
96        self.first_used_at.is_some()
97    }
98
99    pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
109        self.state = self.state.revoke(revoked_at)?;
110        Ok(self)
111    }
112}
113
/// Lifecycle state of a [`RefreshToken`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RefreshTokenState {
    /// The token has been neither consumed nor revoked. This is the default state.
    #[default]
    Valid,
    /// The token has been consumed.
    Consumed {
        /// When the token was consumed.
        consumed_at: DateTime<Utc>,
        /// ID of the refresh token that replaced this one, if recorded.
        next_refresh_token_id: Option<Ulid>,
    },
    /// The token has been revoked.
    Revoked {
        /// When the token was revoked.
        revoked_at: DateTime<Utc>,
    },
}
126
127impl RefreshTokenState {
128    fn consume(
134        self,
135        consumed_at: DateTime<Utc>,
136        replaced_by: &RefreshToken,
137    ) -> Result<Self, InvalidTransitionError> {
138        match self {
139            Self::Valid | Self::Consumed { .. } => Ok(Self::Consumed {
140                consumed_at,
141                next_refresh_token_id: Some(replaced_by.id),
142            }),
143            Self::Revoked { .. } => Err(InvalidTransitionError),
144        }
145    }
146
147    pub fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
153        match self {
154            Self::Valid => Ok(Self::Revoked { revoked_at }),
155            Self::Consumed { .. } | Self::Revoked { .. } => Err(InvalidTransitionError),
156        }
157    }
158
159    #[must_use]
163    pub fn is_valid(&self) -> bool {
164        matches!(self, Self::Valid)
165    }
166
167    #[must_use]
169    pub fn next_refresh_token_id(&self) -> Option<Ulid> {
170        match self {
171            Self::Valid | Self::Revoked { .. } => None,
172            Self::Consumed {
173                next_refresh_token_id,
174                ..
175            } => *next_refresh_token_id,
176        }
177    }
178}
179
/// A refresh token together with its lifecycle state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefreshToken {
    /// Unique identifier; its string form is returned by [`RefreshToken::jti`].
    pub id: Ulid,
    /// Current lifecycle state; `RefreshToken` also derefs to it.
    pub state: RefreshTokenState,
    /// The token string itself.
    pub refresh_token: String,
    /// ID of the associated session.
    pub session_id: Ulid,
    /// When the token was created.
    pub created_at: DateTime<Utc>,
    /// ID of an associated access token, if any.
    pub access_token_id: Option<Ulid>,
}
189
190impl std::ops::Deref for RefreshToken {
191    type Target = RefreshTokenState;
192
193    fn deref(&self) -> &Self::Target {
194        &self.state
195    }
196}
197
198impl RefreshToken {
199    #[must_use]
200    pub fn jti(&self) -> String {
201        self.id.to_string()
202    }
203
204    pub fn consume(
210        mut self,
211        consumed_at: DateTime<Utc>,
212        replaced_by: &Self,
213    ) -> Result<Self, InvalidTransitionError> {
214        self.state = self.state.consume(consumed_at, replaced_by)?;
215        Ok(self)
216    }
217
218    pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
224        self.state = self.state.revoke(revoked_at)?;
225        Ok(self)
226    }
227}
228
/// The type of a token, as recognised from its prefix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// An access token; generated with the `mat` prefix.
    AccessToken,

    /// A refresh token; generated with the `mar` prefix.
    RefreshToken,

    /// A compatibility access token; generated with the `mct` prefix.
    ///
    /// Tokens prefixed with `syt_`, and tokens that look like Synapse
    /// macaroons, are also classified as this type.
    CompatAccessToken,

    /// A compatibility refresh token; generated with the `mcr` prefix.
    ///
    /// Tokens prefixed with `syr_` are also classified as this type.
    CompatRefreshToken,
}

impl std::fmt::Display for TokenType {
    /// Write a short human-readable name, e.g. `access token`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::AccessToken => "access token",
            Self::RefreshToken => "refresh token",
            Self::CompatAccessToken => "compat access token",
            Self::CompatRefreshToken => "compat refresh token",
        };
        f.write_str(label)
    }
}
255
256impl TokenType {
257    fn prefix(self) -> &'static str {
258        match self {
259            TokenType::AccessToken => "mat",
260            TokenType::RefreshToken => "mar",
261            TokenType::CompatAccessToken => "mct",
262            TokenType::CompatRefreshToken => "mcr",
263        }
264    }
265
266    fn match_prefix(prefix: &str) -> Option<Self> {
267        match prefix {
268            "mat" => Some(TokenType::AccessToken),
269            "mar" => Some(TokenType::RefreshToken),
270            "mct" | "syt" => Some(TokenType::CompatAccessToken),
271            "mcr" | "syr" => Some(TokenType::CompatRefreshToken),
272            _ => None,
273        }
274    }
275
276    pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String {
278        let random_part: String = rng
279            .sample_iter(&Alphanumeric)
280            .take(30)
281            .map(char::from)
282            .collect();
283
284        let base = format!("{prefix}_{random_part}", prefix = self.prefix());
285        let crc = CRC.checksum(base.as_bytes());
286        let crc = base62_encode(crc);
287        format!("{base}_{crc}")
288    }
289
290    pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
296        if token.starts_with("syt_") || is_likely_synapse_macaroon(token) {
299            return Ok(TokenType::CompatAccessToken);
300        }
301        if token.starts_with("syr_") {
302            return Ok(TokenType::CompatRefreshToken);
303        }
304
305        let split: Vec<&str> = token.split('_').collect();
306        let [prefix, random_part, crc]: [&str; 3] = split
307            .try_into()
308            .map_err(|_| TokenFormatError::InvalidFormat)?;
309
310        if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
311            return Err(TokenFormatError::InvalidFormat);
312        }
313
314        let token_type =
315            TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
316                prefix: prefix.to_owned(),
317            })?;
318
319        let base = format!("{prefix}_{random_part}", prefix = token_type.prefix());
320        let expected_crc = CRC.checksum(base.as_bytes());
321        let expected_crc = base62_encode(expected_crc);
322        if crc != expected_crc {
323            return Err(TokenFormatError::InvalidCrc {
324                expected: expected_crc,
325                got: crc.to_owned(),
326            });
327        }
328
329        Ok(token_type)
330    }
331}
332
333impl PartialEq<OAuthTokenTypeHint> for TokenType {
334    fn eq(&self, other: &OAuthTokenTypeHint) -> bool {
335        matches!(
336            (self, other),
337            (
338                TokenType::AccessToken | TokenType::CompatAccessToken,
339                OAuthTokenTypeHint::AccessToken
340            ) | (
341                TokenType::RefreshToken | TokenType::CompatRefreshToken,
342                OAuthTokenTypeHint::RefreshToken
343            )
344        )
345    }
346}
347
348fn is_likely_synapse_macaroon(token: &str) -> bool {
356    let Ok(decoded) = Base64UrlUnpadded::decode_vec(token) else {
357        return false;
358    };
359    decoded.get(4..13) == Some(b"location ")
360}
361
/// Alphabet used by [`base62_encode`], indexed by digit value:
/// `0-9`, then `A-Z`, then `a-z`.
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

/// Encode `num` as a six-character base62 string.
///
/// Digits are emitted least-significant first, then left-padded with `'0'`
/// to six characters. This is not the conventional big-endian representation,
/// and distinct inputs can map to the same output: `1` and `62` both give
/// `"000001"`.
///
/// The format is used for the checksum suffix that `TokenType::generate`
/// writes and `TokenType::check` verifies. Changing it would make previously
/// generated tokens fail the check. A `u32` needs at most six base62 digits,
/// so the result is always exactly six characters.
fn base62_encode(num: u32) -> String {
    // Successive quotients num, num / 62, …; stop before a quotient of zero.
    // An input of zero yields a single `0` digit.
    let quotients = std::iter::successors(Some(num), |&rest| Some(rest / 62).filter(|&q| q > 0));
    let digits: String = quotients
        .map(|rest| char::from(NUM[(rest % 62) as usize]))
        .collect();
    format!("{digits:0>6}")
}
373
/// CRC-32 (ISO-HDLC parameters).
///
/// Computes the checksum suffix of tokens in `TokenType::generate` and
/// `TokenType::check`.
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
375
/// Reasons why [`TokenType::check`] can reject a token.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum TokenFormatError {
    /// The token is not three `_`-separated parts, or the parts have the wrong
    /// lengths.
    #[error("invalid token format")]
    InvalidFormat,

    /// The token's prefix is not a known token prefix.
    #[error("unknown token prefix {prefix:?}")]
    UnknownPrefix {
        /// The prefix found in the token.
        prefix: String,
    },

    /// The token's checksum suffix does not match its prefix and random part.
    #[error("invalid crc {got:?}, expected {expected:?}")]
    InvalidCrc {
        /// The checksum computed from the token's contents.
        expected: String,
        /// The checksum present in the token.
        got: String,
    },
}
399
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rand::thread_rng;

    use super::*;

    #[test]
    fn test_prefix_match() {
        use TokenType::{AccessToken, CompatAccessToken, CompatRefreshToken, RefreshToken};
        assert_eq!(TokenType::match_prefix("syt"), Some(CompatAccessToken));
        assert_eq!(TokenType::match_prefix("syr"), Some(CompatRefreshToken));
        assert_eq!(TokenType::match_prefix("mct"), Some(CompatAccessToken));
        assert_eq!(TokenType::match_prefix("mcr"), Some(CompatRefreshToken));
        assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken));
        assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken));
        assert_eq!(TokenType::match_prefix("matt"), None);
        assert_eq!(TokenType::match_prefix("marr"), None);
        assert_eq!(TokenType::match_prefix("ma"), None);
        assert_eq!(
            TokenType::match_prefix(TokenType::CompatAccessToken.prefix()),
            Some(TokenType::CompatAccessToken)
        );
        assert_eq!(
            TokenType::match_prefix(TokenType::CompatRefreshToken.prefix()),
            Some(TokenType::CompatRefreshToken)
        );
        assert_eq!(
            TokenType::match_prefix(TokenType::AccessToken.prefix()),
            Some(TokenType::AccessToken)
        );
        assert_eq!(
            TokenType::match_prefix(TokenType::RefreshToken.prefix()),
            Some(TokenType::RefreshToken)
        );
    }

    #[test]
    fn test_is_likely_synapse_macaroon() {
        // Only the beginning of a macaroon: enough to reach the `location ` field.
        assert!(is_likely_synapse_macaroon(
            "MDAxYmxvY2F0aW9uIGxpYnJlcHVzaC5uZXQKMDAx"
        ));

        // A complete macaroon, including its identifier and signature fields.
        assert!(is_likely_synapse_macaroon(
            "MDAxY2xvY2F0aW9uIGh0dHA6Ly9teWJhbmsvCjAwMjZpZGVudGlmaWVyIHdlIHVzZWQgb3VyIHNlY3JldCBrZXkKMDAyZnNpZ25hdHVyZSDj2eApCFJsTAA5rhURQRXZf91ovyujebNCqvD2F9BVLwo"
        ));

        // Contains `.`, which is outside the URL-safe base64 alphabet.
        assert!(!is_likely_synapse_macaroon(
            "eyJARTOhearotnaeisahtoarsnhiasra.arsohenaor.oarnsteao"
        ));
        assert!(!is_likely_synapse_macaroon("...."));
        // Too short to contain the `location ` field.
        assert!(!is_likely_synapse_macaroon("aaa"));
    }

    #[test]
    fn test_base62_encode() {
        assert_eq!(base62_encode(0), "000000");
        assert_eq!(base62_encode(61), "00000z");
        // Digits come out least-significant first before padding. Pin this
        // down: changing it would break the checksum of existing tokens.
        assert_eq!(base62_encode(64), "000021");
        assert_eq!(base62_encode(u32::MAX), "3CFfg4");
    }

    #[test]
    fn test_check_legacy_synapse_tokens() {
        assert_eq!(
            TokenType::check("syt_anything"),
            Ok(TokenType::CompatAccessToken)
        );
        assert_eq!(
            TokenType::check("syr_anything"),
            Ok(TokenType::CompatRefreshToken)
        );
        assert_eq!(
            TokenType::check("MDAxYmxvY2F0aW9uIGxpYnJlcHVzaC5uZXQKMDAx"),
            Ok(TokenType::CompatAccessToken)
        );
    }

    #[test]
    fn test_check_rejects_malformed_tokens() {
        assert_eq!(TokenType::check(""), Err(TokenFormatError::InvalidFormat));
        assert_eq!(
            TokenType::check("mat_abc"),
            Err(TokenFormatError::InvalidFormat)
        );
        assert_eq!(
            TokenType::check("mat_abc_def"),
            Err(TokenFormatError::InvalidFormat)
        );

        let random = "a".repeat(30);
        assert_eq!(
            TokenType::check(&format!("mat_{random}_000000_extra")),
            Err(TokenFormatError::InvalidFormat)
        );
        assert_eq!(
            TokenType::check(&format!("xyz_{random}_000000")),
            Err(TokenFormatError::UnknownPrefix {
                prefix: "xyz".to_owned()
            })
        );
    }

    #[test]
    fn test_check_rejects_bad_crc() {
        // `thread_rng` is in the project's clippy disallowed-methods list
        // (presumably to keep RNGs injectable); using it in tests is fine.
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        let token = TokenType::AccessToken.generate(&mut rng);
        let (base, crc) = token
            .rsplit_once('_')
            .expect("generated tokens always contain `_`");
        let wrong_crc = if crc == "000000" { "000001" } else { "000000" };

        assert_eq!(
            TokenType::check(&format!("{base}_{wrong_crc}")),
            Err(TokenFormatError::InvalidCrc {
                expected: crc.to_owned(),
                got: wrong_crc.to_owned(),
            })
        );
    }

    #[test]
    fn test_oauth_token_type_hint_eq() {
        assert!(TokenType::AccessToken == OAuthTokenTypeHint::AccessToken);
        assert!(TokenType::CompatAccessToken == OAuthTokenTypeHint::AccessToken);
        assert!(TokenType::RefreshToken == OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::CompatRefreshToken == OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::AccessToken != OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::RefreshToken != OAuthTokenTypeHint::AccessToken);
    }

    #[test]
    fn test_generate_and_check() {
        const COUNT: usize = 500;

        // `thread_rng` is in the project's clippy disallowed-methods list
        // (presumably to keep RNGs injectable); using it in tests is fine.
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        for t in [
            TokenType::CompatAccessToken,
            TokenType::CompatRefreshToken,
            TokenType::AccessToken,
            TokenType::RefreshToken,
        ] {
            // Generate many tokens of this type, deduplicated through a set.
            let tokens: HashSet<String> = (0..COUNT).map(|_| t.generate(&mut rng)).collect();

            // No two generated tokens should collide.
            assert_eq!(tokens.len(), COUNT, "All tokens are unique");

            // Every generated token must be recognised as its own type.
            for token in tokens {
                assert_eq!(TokenType::check(&token).unwrap(), t);
            }
        }
    }
}