1use async_trait::async_trait;
10use mas_jose::{
11    claims::ClaimError,
12    jwa::InvalidAlgorithm,
13    jwt::{JwtDecodeError, JwtSignatureError, NoKeyWorked},
14};
15use oauth2_types::{oidc::ProviderMetadataVerificationError, pkce::CodeChallengeError};
16use serde::Deserialize;
17use thiserror::Error;
18
19#[derive(Debug, Error)]
21#[error(transparent)]
22pub enum Error {
23    Discovery(#[from] DiscoveryError),
25
26    Jwks(#[from] JwksError),
28
29    Authorization(#[from] AuthorizationError),
31
32    TokenAuthorizationCode(#[from] TokenAuthorizationCodeError),
34
35    TokenClientCredentials(#[from] TokenRequestError),
37
38    TokenRefresh(#[from] TokenRefreshError),
40
41    UserInfo(#[from] UserInfoError),
43}
44
45#[derive(Debug, Error)]
47#[error("Fetching provider metadata failed")]
48pub enum DiscoveryError {
49    IntoUrl(#[from] url::ParseError),
51
52    Http(#[from] reqwest::Error),
54
55    Validation(#[from] ProviderMetadataVerificationError),
57
58    #[error("Provider doesn't have an issuer set")]
61    MissingIssuer,
62
63    #[error("Discovery is disabled for this provider")]
65    Disabled,
66}
67
68#[derive(Debug, Error)]
70#[error("Building the authorization URL failed")]
71pub enum AuthorizationError {
72    Pkce(#[from] CodeChallengeError),
74
75    UrlEncoded(#[from] serde_urlencoded::ser::Error),
77}
78
79#[derive(Debug, Error)]
81#[error("Request to the token endpoint failed")]
82pub enum TokenRequestError {
83    Http(#[from] reqwest::Error),
85
86    OAuth2(#[from] OAuth2Error),
88
89    Credentials(#[from] CredentialsError),
91}
92
93#[derive(Debug, Error)]
95pub enum TokenAuthorizationCodeError {
96    #[error(transparent)]
98    Token(#[from] TokenRequestError),
99
100    #[error("Verifying the 'id_token' returned by the provider failed")]
102    IdToken(#[from] IdTokenError),
103}
104
105#[derive(Debug, Error)]
107pub enum TokenRefreshError {
108    #[error(transparent)]
110    Token(#[from] TokenRequestError),
111
112    #[error("Verifying the 'id_token' returned by the provider failed")]
114    IdToken(#[from] IdTokenError),
115}
116
117#[derive(Debug, Error)]
119pub enum UserInfoError {
120    #[error("missing response content-type")]
122    MissingResponseContentType,
123
124    #[error("invalid response content-type")]
126    InvalidResponseContentTypeValue,
127
128    #[error("unexpected response content-type {got:?}, expected {expected:?}")]
130    UnexpectedResponseContentType {
131        expected: String,
133        got: String,
135    },
136
137    #[error("Verifying the 'id_token' returned by the provider failed")]
139    IdToken(#[from] IdTokenError),
140
141    #[error(transparent)]
143    Http(#[from] reqwest::Error),
144
145    #[error(transparent)]
147    OAuth2(#[from] OAuth2Error),
148}
149
150#[derive(Debug, Error)]
152#[error("Failed to fetch JWKS")]
153pub enum JwksError {
154    Http(#[from] reqwest::Error),
156}
157
158#[derive(Debug, Error)]
160pub enum JwtVerificationError {
161    #[error(transparent)]
163    JwtDecode(#[from] JwtDecodeError),
164
165    #[error(transparent)]
167    JwtSignature(#[from] NoKeyWorked),
168
169    #[error(transparent)]
171    Claim(#[from] ClaimError),
172
173    #[error("wrong signature alg")]
176    WrongSignatureAlg,
177}
178
179#[derive(Debug, Error)]
181pub enum IdTokenError {
182    #[error("ID token is missing")]
184    MissingIdToken,
185
186    #[error("Authorization ID token is missing")]
189    MissingAuthIdToken,
190
191    #[error(transparent)]
192    Jwt(#[from] JwtVerificationError),
194
195    #[error(transparent)]
196    Claim(#[from] ClaimError),
198
199    #[error("wrong subject identifier")]
202    WrongSubjectIdentifier,
203
204    #[error("wrong authentication time")]
207    WrongAuthTime,
208}
209
210#[derive(Debug, Error)]
212pub enum CredentialsError {
213    #[error("unsupported authentication method")]
215    UnsupportedMethod,
216
217    #[error("no private key was found for the given algorithm")]
220    NoPrivateKeyFound,
221
222    #[error("invalid algorithm: {0}")]
224    InvalidSigningAlgorithm(#[from] InvalidAlgorithm),
225
226    #[error(transparent)]
228    JwtClaims(#[from] ClaimError),
229
230    #[error("Wrong algorithm for key")]
232    JwtWrongAlgorithm,
233
234    #[error(transparent)]
236    JwtSignature(#[from] JwtSignatureError),
237}
238
/// The body of an OAuth 2.0 error response, deserialized from the provider's
/// JSON reply.
///
/// The field names mirror the error response parameters of RFC 6749 §5.2.
#[derive(Debug, Deserialize)]
struct OAuth2ErrorResponse {
    /// The error code.
    error: String,
    /// Optional human-readable description of the error.
    error_description: Option<String>,
    /// Optional URI pointing to more information about the error.
    error_uri: Option<String>,
}
245
246impl std::fmt::Display for OAuth2ErrorResponse {
247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        write!(f, "{:?}", self.error)?;
249
250        if let Some(error_uri) = &self.error_uri {
251            write!(f, " (See {error_uri})")?;
252        }
253
254        if let Some(error_description) = &self.error_description {
255            write!(f, ": {error_description}")?;
256        }
257
258        Ok(())
259    }
260}
261
/// A failed HTTP request to the provider, optionally carrying the OAuth 2.0
/// error parsed from the response body.
#[derive(Debug, Error)]
pub struct OAuth2Error {
    /// The parsed OAuth 2.0 error body, when one is available.
    error: Option<OAuth2ErrorResponse>,

    /// The underlying HTTP client error, exposed as this error's source.
    #[source]
    inner: reqwest::Error,
}
270
271impl std::fmt::Display for OAuth2Error {
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        if let Some(error) = &self.error {
274            write!(
275                f,
276                "Request to the provider failed with the following error: {error}"
277            )
278        } else {
279            write!(f, "Request to the provider failed")
280        }
281    }
282}
283
284impl From<reqwest::Error> for OAuth2Error {
285    fn from(inner: reqwest::Error) -> Self {
286        Self { error: None, inner }
287    }
288}
289
/// Extension trait turning error HTTP responses into [`OAuth2Error`]s.
#[async_trait]
pub(crate) trait ResponseExt {
    /// Returns `self` unchanged when its status is not an error status.
    /// Otherwise, consumes the response and returns an [`OAuth2Error`] that
    /// includes the parsed OAuth 2.0 error body when one is available.
    ///
    /// # Errors
    ///
    /// Returns an [`OAuth2Error`] whenever the response has an error status.
    async fn error_from_oauth2_error_response(self) -> Result<Self, OAuth2Error>
    where
        Self: Sized;
}
297
298#[async_trait]
299impl ResponseExt for reqwest::Response {
300    async fn error_from_oauth2_error_response(self) -> Result<Self, OAuth2Error> {
301        let Err(inner) = self.error_for_status_ref() else {
302            return Ok(self);
303        };
304
305        let error: OAuth2ErrorResponse = self.json().await?;
306
307        Err(OAuth2Error {
308            error: Some(error),
309            inner,
310        })
311    }
312}