1use chrono::{DateTime, Utc};
8use mas_iana::oauth::PkceCodeChallengeMethod;
9use oauth2_types::{
10 pkce::{CodeChallengeError, CodeChallengeMethodExt},
11 requests::ResponseMode,
12 scope::{OPENID, PROFILE, Scope},
13};
14use rand::{
15 RngCore,
16 distributions::{Alphanumeric, DistString},
17};
18use ruma_common::UserId;
19use serde::Serialize;
20use ulid::Ulid;
21use url::Url;
22
23use super::session::Session;
24use crate::InvalidTransitionError;
25
/// A PKCE code challenge, paired with the method used to check a verifier
/// against it (see [`Pkce::verify`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Pkce {
    /// Challenge method; [`Pkce::verify`] delegates the check to it.
    pub challenge_method: PkceCodeChallengeMethod,

    /// The challenge value a verifier is checked against.
    pub challenge: String,
}
31
32impl Pkce {
33 #[must_use]
35 pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
36 Pkce {
37 challenge_method,
38 challenge,
39 }
40 }
41
42 pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
48 self.challenge_method.verify(&self.challenge, verifier)
49 }
50}
51
/// An authorization code, optionally bound to a PKCE challenge.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationCode {
    /// The authorization code value.
    pub code: String,

    /// The PKCE challenge associated with the code, if there is one.
    pub pkce: Option<Pkce>,
}
57
/// The lifecycle stage of an authorization grant.
///
/// The transitions implemented on this type are `Pending` → `Fulfilled` →
/// `Exchanged`, and `Pending` → `Cancelled`.
///
/// When serialized, the variant name is written in lowercase under a `stage`
/// key, alongside the variant's own fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
#[serde(tag = "stage", rename_all = "lowercase")]
pub enum AuthorizationGrantStage {
    /// Initial stage; also the [`Default`] value.
    #[default]
    Pending,

    /// The grant was fulfilled with a session.
    Fulfilled {
        /// ID of the session the grant was fulfilled with.
        session_id: Ulid,
        /// When the grant was fulfilled.
        fulfilled_at: DateTime<Utc>,
    },

    /// The grant was fulfilled and then exchanged.
    Exchanged {
        /// ID of the session the grant was fulfilled with.
        session_id: Ulid,
        /// When the grant was fulfilled.
        fulfilled_at: DateTime<Utc>,
        /// When the grant was exchanged.
        exchanged_at: DateTime<Utc>,
    },

    /// The grant was cancelled while still pending.
    Cancelled {
        /// When the grant was cancelled.
        cancelled_at: DateTime<Utc>,
    },
}
76
77impl AuthorizationGrantStage {
78 #[must_use]
79 pub fn new() -> Self {
80 Self::Pending
81 }
82
83 fn fulfill(
84 self,
85 fulfilled_at: DateTime<Utc>,
86 session: &Session,
87 ) -> Result<Self, InvalidTransitionError> {
88 match self {
89 Self::Pending => Ok(Self::Fulfilled {
90 fulfilled_at,
91 session_id: session.id,
92 }),
93 _ => Err(InvalidTransitionError),
94 }
95 }
96
97 fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
98 match self {
99 Self::Fulfilled {
100 fulfilled_at,
101 session_id,
102 } => Ok(Self::Exchanged {
103 fulfilled_at,
104 exchanged_at,
105 session_id,
106 }),
107 _ => Err(InvalidTransitionError),
108 }
109 }
110
111 fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
112 match self {
113 Self::Pending => Ok(Self::Cancelled { cancelled_at }),
114 _ => Err(InvalidTransitionError),
115 }
116 }
117
118 #[must_use]
122 pub fn is_pending(&self) -> bool {
123 matches!(self, Self::Pending)
124 }
125
126 #[must_use]
130 pub fn is_fulfilled(&self) -> bool {
131 matches!(self, Self::Fulfilled { .. })
132 }
133
134 #[must_use]
138 pub fn is_exchanged(&self) -> bool {
139 matches!(self, Self::Exchanged { .. })
140 }
141}
142
143pub enum LoginHint<'a> {
144 MXID(&'a UserId),
145 None,
146}
147
/// An authorization grant and the parameters it was created with.
///
/// The grant dereferences to its [`AuthorizationGrantStage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationGrant {
    /// Identifier of this grant.
    pub id: Ulid,

    /// Current lifecycle stage; its fields are flattened into the grant when
    /// serialized.
    #[serde(flatten)]
    pub stage: AuthorizationGrantStage,

    /// Authorization code attached to this grant, if any.
    pub code: Option<AuthorizationCode>,

    /// Identifier of the client — presumably the one that requested the
    /// grant; confirm against the code that creates grants.
    pub client_id: Ulid,

    /// Redirect URI recorded for this grant.
    pub redirect_uri: Url,

    /// Scope recorded for this grant.
    pub scope: Scope,

    /// Optional `state` value recorded for this grant.
    pub state: Option<String>,

    /// Optional `nonce` value recorded for this grant.
    pub nonce: Option<String>,

    /// Response mode recorded for this grant.
    pub response_mode: ResponseMode,

    /// NOTE(review): presumably whether the `id_token` response type was
    /// requested — confirm against the authorization endpoint.
    pub response_type_id_token: bool,

    /// When the grant was created.
    pub created_at: DateTime<Utc>,

    /// Raw login hint, if any; see [`AuthorizationGrant::parse_login_hint`].
    pub login_hint: Option<String>,

    /// Optional locale recorded for this grant.
    pub locale: Option<String>,
}
165
166impl std::ops::Deref for AuthorizationGrant {
167 type Target = AuthorizationGrantStage;
168
169 fn deref(&self) -> &Self::Target {
170 &self.stage
171 }
172}
173
174impl AuthorizationGrant {
175 #[must_use]
176 pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint {
177 let Some(login_hint) = &self.login_hint else {
178 return LoginHint::None;
179 };
180
181 let Some((prefix, value)) = login_hint.split_once(':') else {
183 return LoginHint::None;
184 };
185
186 match prefix {
187 "mxid" => {
188 let Ok(mxid) = <&UserId>::try_from(value) else {
190 return LoginHint::None;
191 };
192
193 if mxid.server_name() != homeserver {
195 return LoginHint::None;
196 }
197
198 LoginHint::MXID(mxid)
199 }
200 _ => LoginHint::None,
202 }
203 }
204
205 pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
213 self.stage = self.stage.exchange(exchanged_at)?;
214 Ok(self)
215 }
216
217 pub fn fulfill(
225 mut self,
226 fulfilled_at: DateTime<Utc>,
227 session: &Session,
228 ) -> Result<Self, InvalidTransitionError> {
229 self.stage = self.stage.fulfill(fulfilled_at, session)?;
230 Ok(self)
231 }
232
233 pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
245 self.stage = self.stage.cancel(canceld_at)?;
246 Ok(self)
247 }
248
249 #[doc(hidden)]
250 pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
251 Self {
252 id: Ulid::from_datetime_with_source(now.into(), rng),
253 stage: AuthorizationGrantStage::Pending,
254 code: Some(AuthorizationCode {
255 code: Alphanumeric.sample_string(rng, 10),
256 pkce: None,
257 }),
258 client_id: Ulid::from_datetime_with_source(now.into(), rng),
259 redirect_uri: Url::parse("http://localhost:8080").unwrap(),
260 scope: Scope::from_iter([OPENID, PROFILE]),
261 state: Some(Alphanumeric.sample_string(rng, 10)),
262 nonce: Some(Alphanumeric.sample_string(rng, 10)),
263 response_mode: ResponseMode::Query,
264 response_type_id_token: false,
265 created_at: now,
266 login_hint: Some(String::from("mxid:@example-user:example.com")),
267 locale: Some(String::from("fr")),
268 }
269 }
270}
271
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;

    /// Homeserver name every test resolves login hints against.
    const HOMESERVER: &str = "example.com";

    /// Build a sample grant whose `login_hint` is replaced by the given value.
    fn grant_with_login_hint(login_hint: Option<&str>) -> AuthorizationGrant {
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        #[allow(clippy::disallowed_methods)]
        let now = Utc::now();

        AuthorizationGrant {
            login_hint: login_hint.map(String::from),
            ..AuthorizationGrant::sample(now, &mut rng)
        }
    }

    #[test]
    fn no_login_hint() {
        let grant = grant_with_login_hint(None);

        assert!(matches!(
            grant.parse_login_hint(HOMESERVER),
            LoginHint::None
        ));
    }

    #[test]
    fn valid_login_hint() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:example.com"));

        assert!(matches!(
            grant.parse_login_hint(HOMESERVER),
            LoginHint::MXID(mxid) if mxid.localpart() == "example-user"
        ));
    }

    #[test]
    fn invalid_login_hint() {
        // No `<type>:` prefix at all.
        let grant = grant_with_login_hint(Some("example-user"));

        assert!(matches!(
            grant.parse_login_hint(HOMESERVER),
            LoginHint::None
        ));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:matrix.org"));

        assert!(matches!(
            grant.parse_login_hint(HOMESERVER),
            LoginHint::None
        ));
    }

    #[test]
    fn unknown_login_hint_type() {
        let grant = grant_with_login_hint(Some("something:anything"));

        assert!(matches!(
            grant.parse_login_hint(HOMESERVER),
            LoginHint::None
        ));
    }
}