1use std::borrow::Cow;
12
13use base64ct::{Base64UrlUnpadded, Encoding};
14use mas_iana::oauth::PkceCodeChallengeMethod;
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17use thiserror::Error;
18
19#[derive(Debug, Error, PartialEq, Eq)]
21pub enum CodeChallengeError {
22    #[error("code_verifier should be at least 43 characters long")]
24    TooShort,
25
26    #[error("code_verifier should be at most 128 characters long")]
28    TooLong,
29
30    #[error("code_verifier contains invalid characters")]
32    InvalidCharacters,
33
34    #[error("challenge verification failed")]
36    VerificationFailed,
37
38    #[error("unknown challenge method")]
40    UnknownChallengeMethod,
41}
42
43fn validate_verifier(verifier: &str) -> Result<(), CodeChallengeError> {
44    if verifier.len() < 43 {
45        return Err(CodeChallengeError::TooShort);
46    }
47
48    if verifier.len() > 128 {
49        return Err(CodeChallengeError::TooLong);
50    }
51
52    if !verifier
53        .chars()
54        .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~')
55    {
56        return Err(CodeChallengeError::InvalidCharacters);
57    }
58
59    Ok(())
60}
61
62pub trait CodeChallengeMethodExt {
64    fn compute_challenge<'a>(&self, verifier: &'a str) -> Result<Cow<'a, str>, CodeChallengeError>;
71
72    fn verify(&self, challenge: &str, verifier: &str) -> Result<(), CodeChallengeError>
80    where
81        Self: Sized,
82    {
83        if self.compute_challenge(verifier)? == challenge {
84            Ok(())
85        } else {
86            Err(CodeChallengeError::VerificationFailed)
87        }
88    }
89}
90
91impl CodeChallengeMethodExt for PkceCodeChallengeMethod {
92    fn compute_challenge<'a>(&self, verifier: &'a str) -> Result<Cow<'a, str>, CodeChallengeError> {
93        validate_verifier(verifier)?;
94
95        let challenge = match self {
96            Self::Plain => verifier.into(),
97            Self::S256 => {
98                let mut hasher = Sha256::new();
99                hasher.update(verifier.as_bytes());
100                let hash = hasher.finalize();
101                let verifier = Base64UrlUnpadded::encode_string(&hash);
102                verifier.into()
103            }
104            _ => return Err(CodeChallengeError::UnknownChallengeMethod),
105        };
106
107        Ok(challenge)
108    }
109}
110
111#[derive(Clone, Serialize, Deserialize)]
113pub struct AuthorizationRequest {
114    pub code_challenge_method: PkceCodeChallengeMethod,
116
117    pub code_challenge: String,
119}
120
121#[derive(Clone, Serialize, Deserialize)]
123pub struct TokenRequest {
124    pub code_challenge_verifier: String,
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `verify` with the RFC 7636 Appendix B example and with each failure mode.
    #[test]
    fn test_pkce_verification() {
        use PkceCodeChallengeMethod::{Plain, S256};

        // RFC 7636 Appendix B: hashing `rfc_verifier` with S256 yields `challenge`.
        let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
        let rfc_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";

        assert_eq!(S256.verify(challenge, rfc_verifier), Ok(()));

        // With `plain`, a challenge is its own verifier.
        assert_eq!(Plain.verify(challenge, challenge), Ok(()));

        // Every S256 verification below must fail with the paired error.
        let failures = [
            (
                "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
                CodeChallengeError::VerificationFailed,
            ),
            ("tooshort", CodeChallengeError::TooShort),
            (
                "toolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolong",
                CodeChallengeError::TooLong,
            ),
            (
                "this is long enough but has invalid characters in it",
                CodeChallengeError::InvalidCharacters,
            ),
        ];

        for (bad_verifier, expected) in failures {
            assert_eq!(
                S256.verify(challenge, bad_verifier),
                Err(expected),
                "unexpected result for verifier {:?}",
                bad_verifier
            );
        }
    }
}