1use chrono::{DateTime, Utc};
8use mas_iana::oauth::PkceCodeChallengeMethod;
9use oauth2_types::{
10    pkce::{CodeChallengeError, CodeChallengeMethodExt},
11    requests::ResponseMode,
12    scope::{OPENID, PROFILE, Scope},
13};
14use rand::{
15    RngCore,
16    distributions::{Alphanumeric, DistString},
17};
18use ruma_common::UserId;
19use serde::Serialize;
20use ulid::Ulid;
21use url::Url;
22
23use super::session::Session;
24use crate::InvalidTransitionError;
25
/// A PKCE (Proof Key for Code Exchange) challenge attached to an
/// [`AuthorizationCode`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Pkce {
    /// The method used to check a verifier against [`Pkce::challenge`].
    pub challenge_method: PkceCodeChallengeMethod,
    /// The code challenge that a verifier is checked against in
    /// [`Pkce::verify`].
    pub challenge: String,
}
31
32impl Pkce {
33    #[must_use]
35    pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
36        Pkce {
37            challenge_method,
38            challenge,
39        }
40    }
41
42    pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
48        self.challenge_method.verify(&self.challenge, verifier)
49    }
50}
51
/// An authorization code attached to an [`AuthorizationGrant`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationCode {
    /// The code value itself.
    pub code: String,
    /// The PKCE challenge bound to this code, if any.
    pub pkce: Option<Pkce>,
}
57
/// The lifecycle stage of an [`AuthorizationGrant`].
///
/// Allowed transitions: `Pending` -> `Fulfilled` -> `Exchanged`, or
/// `Pending` -> `Cancelled`.
///
/// Serialized as an internally-tagged enum: a `stage` field carries the
/// lowercased variant name (e.g. `"pending"`), next to the variant's fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
#[serde(tag = "stage", rename_all = "lowercase")]
pub enum AuthorizationGrantStage {
    /// Initial stage: neither fulfilled nor cancelled yet. This is the
    /// `Default`.
    #[default]
    Pending,
    /// The grant was fulfilled by a session.
    Fulfilled {
        /// ID of the session that fulfilled the grant.
        session_id: Ulid,
        /// When the grant was fulfilled.
        fulfilled_at: DateTime<Utc>,
    },
    /// The grant was exchanged after being fulfilled; keeps the fulfilment
    /// data.
    Exchanged {
        /// ID of the session that fulfilled the grant.
        session_id: Ulid,
        /// When the grant was fulfilled.
        fulfilled_at: DateTime<Utc>,
        /// When the grant was exchanged.
        exchanged_at: DateTime<Utc>,
    },
    /// The grant was cancelled while still pending.
    Cancelled {
        /// When the grant was cancelled.
        cancelled_at: DateTime<Utc>,
    },
}
76
77impl AuthorizationGrantStage {
78    #[must_use]
79    pub fn new() -> Self {
80        Self::Pending
81    }
82
83    fn fulfill(
84        self,
85        fulfilled_at: DateTime<Utc>,
86        session: &Session,
87    ) -> Result<Self, InvalidTransitionError> {
88        match self {
89            Self::Pending => Ok(Self::Fulfilled {
90                fulfilled_at,
91                session_id: session.id,
92            }),
93            _ => Err(InvalidTransitionError),
94        }
95    }
96
97    fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
98        match self {
99            Self::Fulfilled {
100                fulfilled_at,
101                session_id,
102            } => Ok(Self::Exchanged {
103                fulfilled_at,
104                exchanged_at,
105                session_id,
106            }),
107            _ => Err(InvalidTransitionError),
108        }
109    }
110
111    fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
112        match self {
113            Self::Pending => Ok(Self::Cancelled { cancelled_at }),
114            _ => Err(InvalidTransitionError),
115        }
116    }
117
118    #[must_use]
122    pub fn is_pending(&self) -> bool {
123        matches!(self, Self::Pending)
124    }
125
126    #[must_use]
130    pub fn is_fulfilled(&self) -> bool {
131        matches!(self, Self::Fulfilled { .. })
132    }
133
134    #[must_use]
138    pub fn is_exchanged(&self) -> bool {
139        matches!(self, Self::Exchanged { .. })
140    }
141}
142
143pub enum LoginHint<'a> {
144    MXID(&'a UserId),
145    None,
146}
147
/// An authorization grant, tracked through its [`AuthorizationGrantStage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationGrant {
    /// Unique ID of the grant.
    pub id: Ulid,
    /// Current lifecycle stage. Flattened, so its `stage` tag and fields are
    /// serialized inline with the grant's own fields.
    #[serde(flatten)]
    pub stage: AuthorizationGrantStage,
    /// The authorization code attached to the grant, if any.
    pub code: Option<AuthorizationCode>,
    /// ID of the client the grant belongs to.
    pub client_id: Ulid,
    /// The redirect URI associated with the grant.
    pub redirect_uri: Url,
    /// The scope associated with the grant.
    pub scope: Scope,
    /// The `state` value associated with the grant, if any.
    pub state: Option<String>,
    /// The `nonce` value associated with the grant, if any.
    pub nonce: Option<String>,
    /// The response mode associated with the grant.
    pub response_mode: ResponseMode,
    /// Whether the requested response type includes an ID token (as the name
    /// suggests — confirm against the code that sets it).
    pub response_type_id_token: bool,
    /// When the grant was created.
    pub created_at: DateTime<Utc>,
    /// The raw login hint, if any; interpret it with
    /// [`AuthorizationGrant::parse_login_hint`].
    pub login_hint: Option<String>,
}
164
165impl std::ops::Deref for AuthorizationGrant {
166    type Target = AuthorizationGrantStage;
167
168    fn deref(&self) -> &Self::Target {
169        &self.stage
170    }
171}
172
173impl AuthorizationGrant {
174    #[must_use]
175    pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint {
176        let Some(login_hint) = &self.login_hint else {
177            return LoginHint::None;
178        };
179
180        let Some((prefix, value)) = login_hint.split_once(':') else {
182            return LoginHint::None;
183        };
184
185        match prefix {
186            "mxid" => {
187                let Ok(mxid) = <&UserId>::try_from(value) else {
189                    return LoginHint::None;
190                };
191
192                if mxid.server_name() != homeserver {
194                    return LoginHint::None;
195                }
196
197                LoginHint::MXID(mxid)
198            }
199            _ => LoginHint::None,
201        }
202    }
203
204    pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
212        self.stage = self.stage.exchange(exchanged_at)?;
213        Ok(self)
214    }
215
216    pub fn fulfill(
224        mut self,
225        fulfilled_at: DateTime<Utc>,
226        session: &Session,
227    ) -> Result<Self, InvalidTransitionError> {
228        self.stage = self.stage.fulfill(fulfilled_at, session)?;
229        Ok(self)
230    }
231
232    pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
244        self.stage = self.stage.cancel(canceld_at)?;
245        Ok(self)
246    }
247
248    #[doc(hidden)]
249    pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
250        Self {
251            id: Ulid::from_datetime_with_source(now.into(), rng),
252            stage: AuthorizationGrantStage::Pending,
253            code: Some(AuthorizationCode {
254                code: Alphanumeric.sample_string(rng, 10),
255                pkce: None,
256            }),
257            client_id: Ulid::from_datetime_with_source(now.into(), rng),
258            redirect_uri: Url::parse("http://localhost:8080").unwrap(),
259            scope: Scope::from_iter([OPENID, PROFILE]),
260            state: Some(Alphanumeric.sample_string(rng, 10)),
261            nonce: Some(Alphanumeric.sample_string(rng, 10)),
262            response_mode: ResponseMode::Query,
263            response_type_id_token: false,
264            created_at: now,
265            login_hint: Some(String::from("mxid:@example-user:example.com")),
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;

    /// Builds a sample grant whose `login_hint` is replaced by `login_hint`.
    fn grant_with_login_hint(login_hint: Option<&str>) -> AuthorizationGrant {
        // `thread_rng` and `Utc::now` are flagged by the `disallowed_methods`
        // lint (presumably configured project-wide); a non-deterministic rng
        // and the wall clock are acceptable in these tests.
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        #[allow(clippy::disallowed_methods)]
        let now = Utc::now();

        AuthorizationGrant {
            login_hint: login_hint.map(String::from),
            ..AuthorizationGrant::sample(now, &mut rng)
        }
    }

    #[test]
    fn no_login_hint() {
        let grant = grant_with_login_hint(None);

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:example.com"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn invalid_login_hint() {
        let grant = grant_with_login_hint(Some("example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn malformed_mxid_login_hint() {
        // Correct `mxid:` prefix, but the value is not a valid user ID.
        let grant = grant_with_login_hint(Some("mxid:example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:matrix.org"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        let grant = grant_with_login_hint(Some("something:anything"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }
}