Skip to main content

movement_sdk/account/
keyless.rs

1//! Keyless (OIDC-based) account support.
2
3use crate::account::account::{Account, AuthenticationKey};
4use crate::crypto::{Ed25519PrivateKey, Ed25519PublicKey, KEYLESS_SCHEME};
5use crate::error::{MovementError, MovementResult};
6use crate::types::AccountAddress;
7use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
8use rand::RngCore;
9use serde::{Deserialize, Serialize};
10use sha3::{Digest, Sha3_256};
11use std::fmt;
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13use url::Url;
14
15// Re-export JwkSet for use with from_jwt_with_jwks and refresh_proof_with_jwks
16pub use jsonwebtoken::jwk::JwkSet;
17
18/// Keyless signature payload for transaction authentication.
19#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
20pub struct KeylessSignature {
21    /// Ephemeral public key bytes.
22    pub ephemeral_public_key: Vec<u8>,
23    /// Signature produced by the ephemeral key.
24    pub ephemeral_signature: Vec<u8>,
25    /// Zero-knowledge proof bytes.
26    pub proof: Vec<u8>,
27}
28
29impl KeylessSignature {
30    /// Serializes the signature using BCS.
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if BCS serialization fails.
35    pub fn to_bcs(&self) -> MovementResult<Vec<u8>> {
36        aptos_bcs::to_bytes(self).map_err(MovementError::bcs)
37    }
38}
39
40/// Short-lived key pair used for keyless signing.
41#[derive(Clone)]
42pub struct EphemeralKeyPair {
43    private_key: Ed25519PrivateKey,
44    public_key: Ed25519PublicKey,
45    expiry: SystemTime,
46    nonce: String,
47}
48
49impl EphemeralKeyPair {
50    /// Generates a new ephemeral key pair with the given expiry (in seconds).
51    pub fn generate(expiry_secs: u64) -> Self {
52        let private_key = Ed25519PrivateKey::generate();
53        let public_key = private_key.public_key();
54        let nonce = {
55            let mut bytes = [0u8; 16];
56            rand::rngs::OsRng.fill_bytes(&mut bytes);
57            const_hex::encode(bytes)
58        };
59        Self {
60            private_key,
61            public_key,
62            expiry: SystemTime::now() + Duration::from_secs(expiry_secs),
63            nonce,
64        }
65    }
66
67    /// Returns true if the key pair has expired.
68    pub fn is_expired(&self) -> bool {
69        SystemTime::now() >= self.expiry
70    }
71
72    /// Returns the nonce associated with this key pair.
73    pub fn nonce(&self) -> &str {
74        &self.nonce
75    }
76
77    /// Returns the public key.
78    pub fn public_key(&self) -> &Ed25519PublicKey {
79        &self.public_key
80    }
81}
82
83impl fmt::Debug for EphemeralKeyPair {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("EphemeralKeyPair")
86            .field("public_key", &self.public_key)
87            .field("expiry", &self.expiry)
88            .field("nonce", &self.nonce)
89            .finish_non_exhaustive()
90    }
91}
92
/// Supported OIDC providers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OidcProvider {
    /// Google identity provider.
    Google,
    /// Apple identity provider.
    Apple,
    /// Microsoft identity provider.
    Microsoft,
    /// Custom OIDC provider.
    Custom {
        /// Issuer URL.
        issuer: String,
        /// JWKS URL.
        jwks_url: String,
    },
}

impl OidcProvider {
    /// Returns the issuer URL.
    pub fn issuer(&self) -> &str {
        match self {
            OidcProvider::Google => "https://accounts.google.com",
            OidcProvider::Apple => "https://appleid.apple.com",
            OidcProvider::Microsoft => "https://login.microsoftonline.com/common/v2.0",
            OidcProvider::Custom { issuer, .. } => issuer,
        }
    }

    /// Returns the JWKS URL.
    pub fn jwks_url(&self) -> &str {
        match self {
            OidcProvider::Google => "https://www.googleapis.com/oauth2/v3/certs",
            OidcProvider::Apple => "https://appleid.apple.com/auth/keys",
            OidcProvider::Microsoft => {
                "https://login.microsoftonline.com/common/discovery/v2.0/keys"
            }
            OidcProvider::Custom { jwks_url, .. } => jwks_url,
        }
    }

    /// Infers a provider from an issuer URL.
    ///
    /// # Security
    ///
    /// For unknown issuers, the JWKS URL is constructed as `{issuer}/.well-known/jwks.json`,
    /// with any trailing `/` on the issuer removed first. Some issuers (e.g. Auth0's
    /// `https://tenant.auth0.com/`) end in a slash, and without trimming the URL would
    /// contain `//.well-known`. The stored `issuer` is kept exactly as given.
    ///
    /// Non-HTTPS issuers are accepted at construction time but produce an empty
    /// JWKS URL, which causes a clear error at JWKS fetch time. This prevents SSRF via
    /// `http://`, `file://`, or other dangerous URL schemes without changing the
    /// function signature. Callers controlling issuer input should additionally
    /// validate the host (e.g., block private IP ranges) if SSRF is a concern.
    pub fn from_issuer(issuer: &str) -> Self {
        match issuer {
            "https://accounts.google.com" => OidcProvider::Google,
            "https://appleid.apple.com" => OidcProvider::Apple,
            "https://login.microsoftonline.com/common/v2.0" => OidcProvider::Microsoft,
            _ => {
                // SECURITY: Only accept HTTPS issuers to prevent SSRF attacks.
                // A malicious JWT could set `iss` to an internal URL (e.g.,
                // http://169.254.169.254/) causing the SDK to make requests to
                // attacker-chosen endpoints when fetching JWKS.
                let jwks_url = match issuer.strip_prefix("https://") {
                    Some(rest) => format!(
                        "https://{}/.well-known/jwks.json",
                        rest.trim_end_matches('/')
                    ),
                    // Non-HTTPS issuers get an invalid JWKS URL that will fail
                    // at fetch time with a clear error rather than making requests
                    // to potentially dangerous endpoints.
                    None => String::new(),
                };
                OidcProvider::Custom {
                    issuer: issuer.to_string(),
                    jwks_url,
                }
            }
        }
    }
}
170
171/// Pepper bytes used in keyless address derivation.
172///
173/// # Security
174///
175/// The pepper is secret material used to derive keyless account addresses.
176/// It is automatically zeroized when dropped to prevent key material from
177/// lingering in memory.
178#[derive(Clone, PartialEq, Eq, zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
179pub struct Pepper(Vec<u8>);
180
181impl std::fmt::Debug for Pepper {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        write!(f, "Pepper(REDACTED)")
184    }
185}
186
187impl Pepper {
188    /// Creates a new pepper from raw bytes.
189    pub fn new(bytes: Vec<u8>) -> Self {
190        Self(bytes)
191    }
192
193    /// Returns the pepper as bytes.
194    pub fn as_bytes(&self) -> &[u8] {
195        &self.0
196    }
197
198    /// Creates a pepper from hex.
199    ///
200    /// # Errors
201    ///
202    /// Returns an error if the hex string is invalid or cannot be decoded.
203    pub fn from_hex(hex_str: &str) -> MovementResult<Self> {
204        Ok(Self(const_hex::decode(hex_str)?))
205    }
206
207    /// Returns the pepper as hex.
208    pub fn to_hex(&self) -> String {
209        const_hex::encode_prefixed(&self.0)
210    }
211}
212
213/// Zero-knowledge proof bytes.
214#[derive(Clone, Debug, PartialEq, Eq)]
215pub struct ZkProof(Vec<u8>);
216
217impl ZkProof {
218    /// Creates a new proof from raw bytes.
219    pub fn new(bytes: Vec<u8>) -> Self {
220        Self(bytes)
221    }
222
223    /// Returns the proof as bytes.
224    pub fn as_bytes(&self) -> &[u8] {
225        &self.0
226    }
227
228    /// Creates a proof from hex.
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if the hex string is invalid or cannot be decoded.
233    pub fn from_hex(hex_str: &str) -> MovementResult<Self> {
234        Ok(Self(const_hex::decode(hex_str)?))
235    }
236
237    /// Returns the proof as hex.
238    pub fn to_hex(&self) -> String {
239        const_hex::encode_prefixed(&self.0)
240    }
241}
242
243/// Service for obtaining pepper values.
244pub trait PepperService: Send + Sync {
245    /// Fetches the pepper for a JWT.
246    fn get_pepper(
247        &self,
248        jwt: &str,
249    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<Pepper>> + Send + '_>>;
250}
251
252/// Service for generating zero-knowledge proofs.
253pub trait ProverService: Send + Sync {
254    /// Generates the proof for keyless authentication.
255    fn generate_proof<'a>(
256        &'a self,
257        jwt: &'a str,
258        ephemeral_key: &'a EphemeralKeyPair,
259        pepper: &'a Pepper,
260    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<ZkProof>> + Send + 'a>>;
261}
262
263/// HTTP pepper service client.
264#[derive(Clone, Debug)]
265pub struct HttpPepperService {
266    url: Url,
267    client: reqwest::Client,
268}
269
270impl HttpPepperService {
271    /// Creates a new HTTP pepper service client.
272    pub fn new(url: Url) -> Self {
273        Self {
274            url,
275            client: reqwest::Client::new(),
276        }
277    }
278}
279
280#[derive(Serialize)]
281struct PepperRequest<'a> {
282    jwt: &'a str,
283}
284
285#[derive(Deserialize)]
286struct PepperResponse {
287    pepper: String,
288}
289
290impl PepperService for HttpPepperService {
291    fn get_pepper(
292        &self,
293        jwt: &str,
294    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<Pepper>> + Send + '_>>
295    {
296        let jwt = jwt.to_owned();
297        Box::pin(async move {
298            let response = self
299                .client
300                .post(self.url.clone())
301                .json(&PepperRequest { jwt: &jwt })
302                .send()
303                .await?
304                .error_for_status()?;
305
306            // SECURITY: Stream body with size limit to prevent OOM
307            let bytes =
308                crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
309            let payload: PepperResponse = serde_json::from_slice(&bytes).map_err(|e| {
310                MovementError::InvalidJwt(format!("failed to parse pepper response: {e}"))
311            })?;
312            Pepper::from_hex(&payload.pepper)
313        })
314    }
315}
316
317/// HTTP prover service client.
318#[derive(Clone, Debug)]
319pub struct HttpProverService {
320    url: Url,
321    client: reqwest::Client,
322}
323
324impl HttpProverService {
325    /// Creates a new HTTP prover service client.
326    pub fn new(url: Url) -> Self {
327        Self {
328            url,
329            client: reqwest::Client::new(),
330        }
331    }
332}
333
334#[derive(Serialize)]
335struct ProverRequest<'a> {
336    jwt: &'a str,
337    ephemeral_public_key: String,
338    nonce: &'a str,
339    pepper: String,
340}
341
342#[derive(Deserialize)]
343struct ProverResponse {
344    proof: String,
345}
346
347impl ProverService for HttpProverService {
348    fn generate_proof<'a>(
349        &'a self,
350        jwt: &'a str,
351        ephemeral_key: &'a EphemeralKeyPair,
352        pepper: &'a Pepper,
353    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<ZkProof>> + Send + 'a>>
354    {
355        Box::pin(async move {
356            let request = ProverRequest {
357                jwt,
358                ephemeral_public_key: const_hex::encode_prefixed(
359                    ephemeral_key.public_key.to_bytes(),
360                ),
361                nonce: ephemeral_key.nonce(),
362                pepper: pepper.to_hex(),
363            };
364
365            let response = self
366                .client
367                .post(self.url.clone())
368                .json(&request)
369                .send()
370                .await?
371                .error_for_status()?;
372
373            // SECURITY: Stream body with size limit to prevent OOM
374            let bytes =
375                crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
376            let payload: ProverResponse = serde_json::from_slice(&bytes).map_err(|e| {
377                MovementError::InvalidJwt(format!("failed to parse prover response: {e}"))
378            })?;
379            ZkProof::from_hex(&payload.proof)
380        })
381    }
382}
383
384/// Account authenticated via OIDC.
385pub struct KeylessAccount {
386    ephemeral_key: EphemeralKeyPair,
387    provider: OidcProvider,
388    issuer: String,
389    audience: String,
390    user_id: String,
391    pepper: Pepper,
392    proof: ZkProof,
393    address: AccountAddress,
394    auth_key: AuthenticationKey,
395    jwt_expiration: Option<SystemTime>,
396}
397
398impl KeylessAccount {
399    /// Creates a keyless account from an OIDC JWT token.
400    ///
401    /// This method verifies the JWT signature using the OIDC provider's JWKS endpoint
402    /// before extracting claims and creating the account.
403    ///
404    /// # Network Requests
405    ///
406    /// This method makes HTTP requests to:
407    /// - The OIDC provider's JWKS endpoint to fetch signing keys
408    /// - The pepper service to obtain the pepper
409    /// - The prover service to generate a ZK proof
410    ///
411    /// For more control over network calls and caching, use [`Self::from_jwt_with_jwks`]
412    /// with pre-fetched JWKS.
413    ///
414    /// # Errors
415    ///
416    /// This function will return an error if:
417    /// - The JWT signature verification fails
418    /// - The JWT cannot be decoded or is missing required claims (iss, aud, sub, nonce)
419    /// - The JWT nonce doesn't match the ephemeral key's nonce
420    /// - The JWT is expired
421    /// - The JWKS cannot be fetched from the provider (network timeout, DNS failure,
422    ///   connection errors, HTTP errors, or invalid JWKS response)
423    /// - The pepper service fails to return a pepper
424    /// - The prover service fails to generate a proof
425    pub async fn from_jwt(
426        jwt: &str,
427        ephemeral_key: EphemeralKeyPair,
428        pepper_service: &dyn PepperService,
429        prover_service: &dyn ProverService,
430    ) -> MovementResult<Self> {
431        // First, decode without verification to get the issuer for JWKS lookup
432        let unverified_claims = decode_claims_unverified(jwt)?;
433        let issuer = unverified_claims
434            .iss
435            .as_ref()
436            .ok_or_else(|| MovementError::InvalidJwt("missing iss claim".into()))?;
437
438        // Determine provider and fetch JWKS
439        let provider = OidcProvider::from_issuer(issuer);
440        let client = reqwest::Client::builder()
441            .timeout(JWKS_FETCH_TIMEOUT)
442            .build()
443            .map_err(|e| MovementError::InvalidJwt(format!("failed to create HTTP client: {e}")))?;
444        let jwks = fetch_jwks(&client, provider.jwks_url()).await?;
445
446        // Now verify and decode the JWT properly
447        let claims = decode_and_verify_jwt(jwt, &jwks)?;
448        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
449
450        if nonce != ephemeral_key.nonce() {
451            return Err(MovementError::InvalidJwt("JWT nonce mismatch".into()));
452        }
453
454        let pepper = pepper_service.get_pepper(jwt).await?;
455        let proof = prover_service
456            .generate_proof(jwt, &ephemeral_key, &pepper)
457            .await?;
458
459        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
460        let auth_key = AuthenticationKey::new(address.to_bytes());
461
462        Ok(Self {
463            provider: OidcProvider::from_issuer(&issuer),
464            issuer,
465            audience,
466            user_id,
467            pepper,
468            proof,
469            address,
470            auth_key,
471            jwt_expiration: exp,
472            ephemeral_key,
473        })
474    }
475
476    /// Creates a keyless account from a JWT with pre-fetched JWKS.
477    ///
478    /// This method is useful when you want to:
479    /// - Cache the JWKS to avoid repeated network requests
480    /// - Have more control over HTTP client configuration
481    /// - Implement custom caching strategies based on HTTP cache headers
482    ///
483    /// # Errors
484    ///
485    /// This function will return an error if:
486    /// - The JWT signature verification fails
487    /// - The JWT cannot be decoded or is missing required claims (iss, aud, sub, nonce)
488    /// - The JWT nonce doesn't match the ephemeral key's nonce
489    /// - The JWT is expired
490    /// - The pepper service fails to return a pepper
491    /// - The prover service fails to generate a proof
492    pub async fn from_jwt_with_jwks(
493        jwt: &str,
494        jwks: &JwkSet,
495        ephemeral_key: EphemeralKeyPair,
496        pepper_service: &dyn PepperService,
497        prover_service: &dyn ProverService,
498    ) -> MovementResult<Self> {
499        // Verify and decode the JWT using the provided JWKS
500        let claims = decode_and_verify_jwt(jwt, jwks)?;
501        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
502
503        if nonce != ephemeral_key.nonce() {
504            return Err(MovementError::InvalidJwt("JWT nonce mismatch".into()));
505        }
506
507        let pepper = pepper_service.get_pepper(jwt).await?;
508        let proof = prover_service
509            .generate_proof(jwt, &ephemeral_key, &pepper)
510            .await?;
511
512        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
513        let auth_key = AuthenticationKey::new(address.to_bytes());
514
515        Ok(Self {
516            provider: OidcProvider::from_issuer(&issuer),
517            issuer,
518            audience,
519            user_id,
520            pepper,
521            proof,
522            address,
523            auth_key,
524            jwt_expiration: exp,
525            ephemeral_key,
526        })
527    }
528
529    /// Returns the OIDC provider.
530    pub fn provider(&self) -> &OidcProvider {
531        &self.provider
532    }
533
534    /// Returns the issuer.
535    pub fn issuer(&self) -> &str {
536        &self.issuer
537    }
538
539    /// Returns the audience.
540    pub fn audience(&self) -> &str {
541        &self.audience
542    }
543
544    /// Returns the user identifier (sub claim).
545    pub fn user_id(&self) -> &str {
546        &self.user_id
547    }
548
549    /// Returns the proof.
550    pub fn proof(&self) -> &ZkProof {
551        &self.proof
552    }
553
554    /// Returns true if the account is still valid.
555    pub fn is_valid(&self) -> bool {
556        if self.ephemeral_key.is_expired() {
557            return false;
558        }
559
560        match self.jwt_expiration {
561            Some(exp) => SystemTime::now() < exp,
562            None => true,
563        }
564    }
565
566    /// Refreshes the proof using a new JWT.
567    ///
568    /// This method verifies the JWT signature using the OIDC provider's JWKS endpoint.
569    ///
570    /// # Network Requests
571    ///
572    /// This method makes HTTP requests to fetch the JWKS from the OIDC provider.
573    /// For more control over network calls and caching, use [`Self::refresh_proof_with_jwks`].
574    ///
575    /// # Errors
576    ///
577    /// Returns an error if:
578    /// - The JWKS cannot be fetched (network timeout, DNS failure, connection errors)
579    /// - The JWT signature verification fails
580    /// - The JWT cannot be decoded
581    /// - The JWT nonce does not match the ephemeral key
582    /// - The JWT identity does not match the account
583    /// - The prover service fails to generate a new proof
584    pub async fn refresh_proof(
585        &mut self,
586        jwt: &str,
587        prover_service: &dyn ProverService,
588    ) -> MovementResult<()> {
589        // Fetch JWKS and verify JWT
590        let client = reqwest::Client::builder()
591            .timeout(JWKS_FETCH_TIMEOUT)
592            .build()
593            .map_err(|e| MovementError::InvalidJwt(format!("failed to create HTTP client: {e}")))?;
594        let jwks = fetch_jwks(&client, self.provider.jwks_url()).await?;
595        self.refresh_proof_with_jwks(jwt, &jwks, prover_service)
596            .await
597    }
598
599    /// Refreshes the proof using a new JWT with pre-fetched JWKS.
600    ///
601    /// This method is useful for caching the JWKS or using a custom HTTP client.
602    ///
603    /// # Errors
604    ///
605    /// Returns an error if:
606    /// - The JWT signature verification fails
607    /// - The JWT cannot be decoded
608    /// - The JWT nonce does not match the ephemeral key
609    /// - The JWT identity does not match the account
610    /// - The prover service fails to generate a new proof
611    pub async fn refresh_proof_with_jwks(
612        &mut self,
613        jwt: &str,
614        jwks: &JwkSet,
615        prover_service: &dyn ProverService,
616    ) -> MovementResult<()> {
617        let claims = decode_and_verify_jwt(jwt, jwks)?;
618        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
619
620        if nonce != self.ephemeral_key.nonce() {
621            return Err(MovementError::InvalidJwt("JWT nonce mismatch".into()));
622        }
623
624        if issuer != self.issuer || audience != self.audience || user_id != self.user_id {
625            return Err(MovementError::InvalidJwt(
626                "JWT identity does not match account".into(),
627            ));
628        }
629
630        let proof = prover_service
631            .generate_proof(jwt, &self.ephemeral_key, &self.pepper)
632            .await?;
633        self.proof = proof;
634        self.jwt_expiration = exp;
635        Ok(())
636    }
637
638    /// Signs a message and returns the structured keyless signature.
639    pub fn sign_keyless(&self, message: &[u8]) -> KeylessSignature {
640        let signature = self.ephemeral_key.private_key.sign(message).to_bytes();
641        KeylessSignature {
642            ephemeral_public_key: self.ephemeral_key.public_key.to_bytes().to_vec(),
643            ephemeral_signature: signature.to_vec(),
644            proof: self.proof.as_bytes().to_vec(),
645        }
646    }
647
648    /// Creates a keyless account from pre-verified JWT claims.
649    ///
650    /// This is useful for testing or when JWT verification is handled externally.
651    /// The caller is responsible for ensuring the JWT was properly verified.
652    ///
653    /// # Errors
654    ///
655    /// This function will return an error if:
656    /// - The nonce doesn't match the ephemeral key's nonce
657    /// - The pepper service fails to return a pepper
658    /// - The prover service fails to generate a proof
659    #[doc(hidden)]
660    #[allow(clippy::too_many_arguments)]
661    pub async fn from_verified_claims(
662        issuer: String,
663        audience: String,
664        user_id: String,
665        nonce: String,
666        exp: Option<SystemTime>,
667        ephemeral_key: EphemeralKeyPair,
668        pepper_service: &dyn PepperService,
669        prover_service: &dyn ProverService,
670        jwt_for_services: &str,
671    ) -> MovementResult<Self> {
672        if nonce != ephemeral_key.nonce() {
673            return Err(MovementError::InvalidJwt("nonce mismatch".into()));
674        }
675
676        let pepper = pepper_service.get_pepper(jwt_for_services).await?;
677        let proof = prover_service
678            .generate_proof(jwt_for_services, &ephemeral_key, &pepper)
679            .await?;
680
681        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
682        let auth_key = AuthenticationKey::new(address.to_bytes());
683
684        Ok(Self {
685            provider: OidcProvider::from_issuer(&issuer),
686            issuer,
687            audience,
688            user_id,
689            pepper,
690            proof,
691            address,
692            auth_key,
693            jwt_expiration: exp,
694            ephemeral_key,
695        })
696    }
697}
698
699impl Account for KeylessAccount {
700    fn address(&self) -> AccountAddress {
701        self.address
702    }
703
704    fn authentication_key(&self) -> AuthenticationKey {
705        self.auth_key
706    }
707
708    fn sign(&self, message: &[u8]) -> crate::error::MovementResult<Vec<u8>> {
709        let signature = self.sign_keyless(message);
710        signature
711            .to_bcs()
712            .map_err(|e| crate::error::MovementError::Bcs(e.to_string()))
713    }
714
715    fn public_key_bytes(&self) -> Vec<u8> {
716        self.ephemeral_key.public_key.to_bytes().to_vec()
717    }
718
719    fn signature_scheme(&self) -> u8 {
720        KEYLESS_SCHEME
721    }
722}
723
724impl fmt::Debug for KeylessAccount {
725    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
726        f.debug_struct("KeylessAccount")
727            .field("address", &self.address)
728            .field("provider", &self.provider)
729            .field("issuer", &self.issuer)
730            .field("audience", &self.audience)
731            .field("user_id", &self.user_id)
732            .finish_non_exhaustive()
733    }
734}
735
736#[derive(Debug, Deserialize)]
737struct JwtClaims {
738    iss: Option<String>,
739    aud: Option<AudClaim>,
740    sub: Option<String>,
741    exp: Option<u64>,
742    nonce: Option<String>,
743}
744
745#[derive(Debug, Deserialize)]
746#[serde(untagged)]
747enum AudClaim {
748    Single(String),
749    Multiple(Vec<String>),
750}
751
752impl AudClaim {
753    fn first(&self) -> Option<&str> {
754        match self {
755            AudClaim::Single(value) => Some(value.as_str()),
756            AudClaim::Multiple(values) => values.first().map(std::string::String::as_str),
757        }
758    }
759}
760
/// Timeout configured on the HTTP client that fetches JWKS (10 seconds).
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(10);

/// Upper bound in bytes (1 MiB) for response bodies read in this module.
/// JWKS payloads are typically under 10 KB.
const MAX_JWKS_RESPONSE_SIZE: usize = 1 << 20;
766
767/// Fetches the JWKS (JSON Web Key Set) from an OIDC provider.
768///
769/// # Errors
770///
771/// Returns an error if:
772/// - The JWKS cannot be fetched (network timeouts, DNS resolution failures,
773///   TLS/connection errors, or HTTP errors)
774/// - The JWKS endpoint returns a non-success status code
775/// - The response cannot be parsed as valid JWKS JSON
776async fn fetch_jwks(client: &reqwest::Client, jwks_url: &str) -> MovementResult<JwkSet> {
777    // SECURITY: Validate the JWKS URL scheme to prevent SSRF.
778    // The issuer comes from an untrusted JWT, so the derived JWKS URL could
779    // point to internal services (e.g., cloud metadata endpoints).
780    let parsed_url = Url::parse(jwks_url)
781        .map_err(|e| MovementError::InvalidJwt(format!("invalid JWKS URL: {e}")))?;
782    if parsed_url.scheme() != "https" {
783        return Err(MovementError::InvalidJwt(format!(
784            "JWKS URL must use HTTPS scheme, got: {}",
785            parsed_url.scheme()
786        )));
787    }
788
789    // Note: timeout is configured on the client, not per-request
790    let response = client.get(jwks_url).send().await?;
791
792    if !response.status().is_success() {
793        return Err(MovementError::InvalidJwt(format!(
794            "JWKS endpoint returned status: {}",
795            response.status()
796        )));
797    }
798
799    // SECURITY: Stream body with size limit to prevent OOM from a
800    // compromised or malicious JWKS endpoint (including chunked encoding).
801    let bytes = crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
802    let jwks: JwkSet = serde_json::from_slice(&bytes)
803        .map_err(|e| MovementError::InvalidJwt(format!("failed to parse JWKS: {e}")))?;
804    Ok(jwks)
805}
806
807/// Decodes and verifies a JWT using the provided JWKS.
808///
809/// This function:
810/// 1. Extracts the `kid` (key ID) from the JWT header
811/// 2. Finds the matching key in the JWKS
812/// 3. Verifies the signature and decodes the claims
813///
814/// # Errors
815///
816/// Returns an error if:
817/// - The JWT header cannot be decoded
818/// - No matching key is found in the JWKS
819/// - The signature verification fails
820/// - The claims cannot be decoded
821fn decode_and_verify_jwt(jwt: &str, jwks: &JwkSet) -> MovementResult<JwtClaims> {
822    // Decode header to get the key ID
823    let header = decode_header(jwt)
824        .map_err(|e| MovementError::InvalidJwt(format!("failed to decode JWT header: {e}")))?;
825
826    let kid = header
827        .kid
828        .as_ref()
829        .ok_or_else(|| MovementError::InvalidJwt("JWT header missing 'kid' field".into()))?;
830
831    // Find the matching key in the JWKS
832    let signing_key = jwks.find(kid).ok_or_else(|| {
833        MovementError::InvalidJwt("no matching key found for provided key identifier".into())
834    })?;
835
836    // Create decoding key from JWK
837    let decoding_key = DecodingKey::from_jwk(signing_key)
838        .map_err(|e| MovementError::InvalidJwt(format!("failed to create decoding key: {e}")))?;
839
840    // Determine the algorithm strictly from the JWK to prevent algorithm substitution attacks
841    let jwk_alg = signing_key.common.key_algorithm.ok_or_else(|| {
842        MovementError::InvalidJwt("JWK missing 'alg' (key_algorithm) field".into())
843    })?;
844
845    let algorithm = match jwk_alg {
846        // RSA algorithms
847        jsonwebtoken::jwk::KeyAlgorithm::RS256 => Algorithm::RS256,
848        jsonwebtoken::jwk::KeyAlgorithm::RS384 => Algorithm::RS384,
849        jsonwebtoken::jwk::KeyAlgorithm::RS512 => Algorithm::RS512,
850        // RSA-PSS algorithms
851        jsonwebtoken::jwk::KeyAlgorithm::PS256 => Algorithm::PS256,
852        jsonwebtoken::jwk::KeyAlgorithm::PS384 => Algorithm::PS384,
853        jsonwebtoken::jwk::KeyAlgorithm::PS512 => Algorithm::PS512,
854        // ECDSA algorithms
855        jsonwebtoken::jwk::KeyAlgorithm::ES256 => Algorithm::ES256,
856        jsonwebtoken::jwk::KeyAlgorithm::ES384 => Algorithm::ES384,
857        // EdDSA algorithm
858        jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Algorithm::EdDSA,
859        _ => {
860            return Err(MovementError::InvalidJwt(format!(
861                "unsupported JWK algorithm: {jwk_alg:?}"
862            )));
863        }
864    };
865
866    // Ensure the JWT header algorithm matches the JWK algorithm to prevent substitution
867    if header.alg != algorithm {
868        return Err(MovementError::InvalidJwt(format!(
869            "JWT header algorithm ({:?}) does not match JWK algorithm ({:?})",
870            header.alg, algorithm
871        )));
872    }
873
874    // Configure validation - we'll validate exp ourselves with more detailed errors
875    let mut validation = Validation::new(algorithm);
876    validation.validate_exp = false;
877    validation.validate_aud = false; // We'll check aud after decoding
878    validation.set_required_spec_claims::<String>(&[]);
879
880    let data = decode::<JwtClaims>(jwt, &decoding_key, &validation)
881        .map_err(|e| MovementError::InvalidJwt(format!("JWT verification failed: {e}")))?;
882
883    Ok(data.claims)
884}
885
886/// Decodes JWT claims without signature verification.
887///
888/// This is used only to extract the issuer (and other metadata) before we know
889/// which JWKS endpoint to fetch. This is safe because:
890/// 1. The extracted issuer is only used to determine which JWKS endpoint to fetch.
891/// 2. The JWT is fully verified immediately afterwards using `decode_and_verify_jwt`.
892/// 3. No security decisions are made based on these unverified claims.
893fn decode_claims_unverified(jwt: &str) -> MovementResult<JwtClaims> {
894    // Use dangerous decode only for initial issuer extraction to select the JWKS.
895    // The JWT is not trusted at this point: no authorization decisions are made
896    // based on these unverified claims, and the token is fully verified (including
897    // signature and claims validation) in `decode_and_verify_jwt` right after the
898    // appropriate JWKS has been fetched.
899    let data = jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(jwt)
900        .map_err(|e| MovementError::InvalidJwt(format!("failed to decode JWT claims: {e}")))?;
901    Ok(data.claims)
902}
903
904fn extract_claims(
905    claims: &JwtClaims,
906) -> MovementResult<(String, String, String, Option<SystemTime>, String)> {
907    let issuer = claims
908        .iss
909        .clone()
910        .ok_or_else(|| MovementError::InvalidJwt("missing iss claim".into()))?;
911    let audience = claims
912        .aud
913        .as_ref()
914        .and_then(|aud| aud.first())
915        .map(std::string::ToString::to_string)
916        .ok_or_else(|| MovementError::InvalidJwt("missing aud claim".into()))?;
917    let user_id = claims
918        .sub
919        .clone()
920        .ok_or_else(|| MovementError::InvalidJwt("missing sub claim".into()))?;
921    let nonce = claims
922        .nonce
923        .clone()
924        .ok_or_else(|| MovementError::InvalidJwt("missing nonce claim".into()))?;
925
926    let exp_time = claims.exp.map(|exp| UNIX_EPOCH + Duration::from_secs(exp));
927    if let Some(exp) = exp_time
928        && SystemTime::now() >= exp
929    {
930        let exp_secs = claims.exp.unwrap_or(0);
931        return Err(MovementError::InvalidJwt(format!(
932            "JWT is expired (exp: {exp_secs} seconds since UNIX_EPOCH)"
933        )));
934    }
935
936    Ok((issuer, audience, user_id, exp_time, nonce))
937}
938
939fn derive_keyless_address(
940    issuer: &str,
941    audience: &str,
942    user_id: &str,
943    pepper: &Pepper,
944) -> AccountAddress {
945    let issuer_hash = sha3_256_bytes(issuer.as_bytes());
946    let audience_hash = sha3_256_bytes(audience.as_bytes());
947    let user_hash = sha3_256_bytes(user_id.as_bytes());
948
949    let mut hasher = Sha3_256::new();
950    hasher.update(issuer_hash);
951    hasher.update(audience_hash);
952    hasher.update(user_hash);
953    hasher.update(pepper.as_bytes());
954    hasher.update([KEYLESS_SCHEME]);
955    let result = hasher.finalize();
956
957    let mut address = [0u8; 32];
958    address.copy_from_slice(&result);
959    AccountAddress::new(address)
960}
961
962fn sha3_256_bytes(data: &[u8]) -> [u8; 32] {
963    let mut hasher = Sha3_256::new();
964    hasher.update(data);
965    let result = hasher.finalize();
966    let mut output = [0u8; 32];
967    output.copy_from_slice(&result);
968    output
969}
970
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};

    /// Issuer used by every fixture token in this module.
    const TEST_ISSUER: &str = "https://accounts.google.com";

    /// Pepper service stub that ignores the JWT and always returns the same pepper.
    struct StaticPepperService {
        pepper: Pepper,
    }

    impl PepperService for StaticPepperService {
        fn get_pepper(
            &self,
            _jwt: &str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<Pepper>> + Send + '_>>
        {
            Box::pin(async move { Ok(self.pepper.clone()) })
        }
    }

    /// Prover service stub that ignores its inputs and always returns the same proof.
    struct StaticProverService {
        proof: ZkProof,
    }

    impl ProverService for StaticProverService {
        fn generate_proof<'a>(
            &'a self,
            _jwt: &'a str,
            _ephemeral_key: &'a EphemeralKeyPair,
            _pepper: &'a Pepper,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<ZkProof>> + Send + 'a>>
        {
            Box::pin(async move { Ok(self.proof.clone()) })
        }
    }

    /// Minimal claim set serialized into fixture tokens.
    #[derive(Serialize, Deserialize)]
    struct TestClaims {
        iss: String,
        aud: String,
        sub: String,
        exp: u64,
        nonce: String,
    }

    /// Current wall-clock time in whole seconds since the UNIX epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_secs()
    }

    /// Builds fixture claims issued by [`TEST_ISSUER`].
    fn test_claims(aud: &str, sub: &str, nonce: &str, exp: u64) -> TestClaims {
        TestClaims {
            iss: TEST_ISSUER.to_string(),
            aud: aud.to_string(),
            sub: sub.to_string(),
            exp,
            nonce: nonce.to_string(),
        }
    }

    /// Signs `claims` with HS256 and a fixed secret, optionally setting the header `kid`.
    ///
    /// Such a token never verifies against a real JWKS. It only exercises parsing
    /// and code paths that do not check the signature.
    fn encode_hs256(claims: &TestClaims, kid: Option<&str>) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = kid.map(str::to_string);
        encode(&header, claims, &EncodingKey::from_secret(b"secret")).unwrap()
    }

    /// Service stubs returning a fixed pepper and a fixed proof.
    fn static_services() -> (StaticPepperService, StaticProverService) {
        (
            StaticPepperService {
                pepper: Pepper::new(vec![1, 2, 3, 4]),
            },
            StaticProverService {
                proof: ZkProof::new(vec![9, 9, 9]),
            },
        )
    }

    #[tokio::test]
    async fn test_keyless_account_creation() {
        let ephemeral = EphemeralKeyPair::generate(3600);
        let exp = now_secs() + 3600;
        let nonce = ephemeral.nonce().to_string();

        // The stub services don't validate the JWT, so an HS256 token is enough.
        let jwt = encode_hs256(&test_claims("client-id", "user-123", &nonce, exp), None);
        let (pepper_service, prover_service) = static_services();

        // Use from_verified_claims for unit testing since we can't mock JWKS.
        let account = KeylessAccount::from_verified_claims(
            TEST_ISSUER.to_string(),
            "client-id".to_string(),
            "user-123".to_string(),
            nonce,
            Some(UNIX_EPOCH + Duration::from_secs(exp)),
            ephemeral,
            &pepper_service,
            &prover_service,
            &jwt,
        )
        .await
        .unwrap();

        assert_eq!(account.issuer(), TEST_ISSUER);
        assert_eq!(account.audience(), "client-id");
        assert_eq!(account.user_id(), "user-123");
        assert!(account.is_valid());
        assert!(!account.address().is_zero());
    }

    #[tokio::test]
    async fn test_keyless_account_nonce_mismatch() {
        let ephemeral = EphemeralKeyPair::generate(3600);
        let nonce = ephemeral.nonce().to_string();
        let jwt = encode_hs256(
            &test_claims("client-id", "user-123", &nonce, now_secs() + 3600),
            None,
        );
        let (pepper_service, prover_service) = static_services();

        // Use a different nonce to trigger mismatch.
        let result = KeylessAccount::from_verified_claims(
            TEST_ISSUER.to_string(),
            "client-id".to_string(),
            "user-123".to_string(),
            "wrong-nonce".to_string(), // This doesn't match ephemeral.nonce()
            None,
            ephemeral,
            &pepper_service,
            &prover_service,
            &jwt,
        )
        .await;

        assert!(result.is_err());
        assert!(matches!(result, Err(MovementError::InvalidJwt(_))));
    }

    #[test]
    fn test_decode_claims_unverified() {
        let jwt = encode_hs256(
            &test_claims("test-aud", "test-sub", "test-nonce", now_secs() + 3600),
            None,
        );

        let decoded = decode_claims_unverified(&jwt).unwrap();
        assert_eq!(decoded.iss.unwrap(), TEST_ISSUER);
        assert_eq!(decoded.sub.unwrap(), "test-sub");
        assert_eq!(decoded.nonce.unwrap(), "test-nonce");
    }

    #[test]
    fn test_extract_claims_valid() {
        let exp = now_secs() + 3600;
        let jwt = encode_hs256(&test_claims("test-aud", "test-sub", "test-nonce", exp), None);
        let claims = decode_claims_unverified(&jwt).unwrap();

        let (issuer, audience, user_id, exp_time, nonce) = extract_claims(&claims).unwrap();
        assert_eq!(issuer, TEST_ISSUER);
        assert_eq!(audience, "test-aud");
        assert_eq!(user_id, "test-sub");
        assert_eq!(nonce, "test-nonce");
        assert_eq!(exp_time, Some(UNIX_EPOCH + Duration::from_secs(exp)));
    }

    #[test]
    fn test_extract_claims_expired() {
        let jwt = encode_hs256(
            &test_claims("test-aud", "test-sub", "test-nonce", now_secs() - 60),
            None,
        );
        let claims = decode_claims_unverified(&jwt).unwrap();

        let result = extract_claims(&claims);
        assert!(
            matches!(&result, Err(MovementError::InvalidJwt(msg)) if msg.contains("expired")),
            "Expected expiry error, got: {result:?}"
        );
    }

    #[test]
    fn test_oidc_provider_detection() {
        assert!(matches!(
            OidcProvider::from_issuer("https://accounts.google.com"),
            OidcProvider::Google
        ));
        assert!(matches!(
            OidcProvider::from_issuer("https://appleid.apple.com"),
            OidcProvider::Apple
        ));
        assert!(matches!(
            OidcProvider::from_issuer("https://unknown.example.com"),
            OidcProvider::Custom { .. }
        ));
    }

    #[test]
    fn test_decode_and_verify_jwt_missing_kid() {
        // HS256 JWT without a kid in the header.
        let jwt = encode_hs256(
            &test_claims("test-aud", "test-sub", "test-nonce", now_secs() + 3600),
            None,
        );

        // Empty JWKS
        let jwks = JwkSet { keys: vec![] };

        let result = decode_and_verify_jwt(&jwt, &jwks);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(&err, MovementError::InvalidJwt(msg) if msg.contains("kid")),
            "Expected error about missing kid, got: {err:?}"
        );
    }

    #[test]
    fn test_decode_and_verify_jwt_no_matching_key() {
        // JWT with a kid in the header (signed with HS256).
        let jwt = encode_hs256(
            &test_claims("test-aud", "test-sub", "test-nonce", now_secs() + 3600),
            Some("test-kid-123"),
        );

        // Empty JWKS - no matching key
        let jwks = JwkSet { keys: vec![] };

        let result = decode_and_verify_jwt(&jwt, &jwks);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(&err, MovementError::InvalidJwt(msg) if msg.contains("no matching key")),
            "Expected error about no matching key, got: {err:?}"
        );
    }

    #[test]
    fn test_decode_and_verify_jwt_invalid_jwt_format() {
        let jwks = JwkSet { keys: vec![] };

        // Completely invalid JWT
        let result = decode_and_verify_jwt("not-a-valid-jwt", &jwks);
        assert!(result.is_err());

        // JWT with invalid base64
        let result = decode_and_verify_jwt("aaa.bbb.ccc", &jwks);
        assert!(result.is_err());
    }
}