1use crate::account::account::{Account, AuthenticationKey};
4use crate::crypto::{Ed25519PrivateKey, Ed25519PublicKey, KEYLESS_SCHEME};
5use crate::error::{MovementError, MovementResult};
6use crate::types::AccountAddress;
7use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
8use rand::RngCore;
9use serde::{Deserialize, Serialize};
10use sha3::{Digest, Sha3_256};
11use std::fmt;
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13use url::Url;
14
15pub use jsonwebtoken::jwk::JwkSet;
17
18#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
20pub struct KeylessSignature {
21 pub ephemeral_public_key: Vec<u8>,
23 pub ephemeral_signature: Vec<u8>,
25 pub proof: Vec<u8>,
27}
28
29impl KeylessSignature {
30 pub fn to_bcs(&self) -> MovementResult<Vec<u8>> {
36 aptos_bcs::to_bytes(self).map_err(MovementError::bcs)
37 }
38}
39
40#[derive(Clone)]
42pub struct EphemeralKeyPair {
43 private_key: Ed25519PrivateKey,
44 public_key: Ed25519PublicKey,
45 expiry: SystemTime,
46 nonce: String,
47}
48
49impl EphemeralKeyPair {
50 pub fn generate(expiry_secs: u64) -> Self {
52 let private_key = Ed25519PrivateKey::generate();
53 let public_key = private_key.public_key();
54 let nonce = {
55 let mut bytes = [0u8; 16];
56 rand::rngs::OsRng.fill_bytes(&mut bytes);
57 const_hex::encode(bytes)
58 };
59 Self {
60 private_key,
61 public_key,
62 expiry: SystemTime::now() + Duration::from_secs(expiry_secs),
63 nonce,
64 }
65 }
66
67 pub fn is_expired(&self) -> bool {
69 SystemTime::now() >= self.expiry
70 }
71
72 pub fn nonce(&self) -> &str {
74 &self.nonce
75 }
76
77 pub fn public_key(&self) -> &Ed25519PublicKey {
79 &self.public_key
80 }
81}
82
83impl fmt::Debug for EphemeralKeyPair {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_struct("EphemeralKeyPair")
86 .field("public_key", &self.public_key)
87 .field("expiry", &self.expiry)
88 .field("nonce", &self.nonce)
89 .finish_non_exhaustive()
90 }
91}
92
/// OpenID Connect identity provider that issues the JWTs backing a keyless account.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OidcProvider {
    /// Google (`https://accounts.google.com`).
    Google,
    /// Apple (`https://appleid.apple.com`).
    Apple,
    /// Microsoft (`https://login.microsoftonline.com/common/v2.0`).
    Microsoft,
    /// Any other issuer, with an explicit JWKS endpoint.
    Custom {
        /// Issuer exactly as it appears in the JWT `iss` claim.
        issuer: String,
        /// URL of the issuer's JSON Web Key Set; empty if none could be derived.
        jwks_url: String,
    },
}

impl OidcProvider {
    /// Issuer string (`iss` claim value) for this provider.
    pub fn issuer(&self) -> &str {
        match self {
            OidcProvider::Google => "https://accounts.google.com",
            OidcProvider::Apple => "https://appleid.apple.com",
            OidcProvider::Microsoft => "https://login.microsoftonline.com/common/v2.0",
            OidcProvider::Custom { issuer, .. } => issuer,
        }
    }

    /// URL from which this provider's signing keys (JWKS) are downloaded.
    pub fn jwks_url(&self) -> &str {
        match self {
            OidcProvider::Google => "https://www.googleapis.com/oauth2/v3/certs",
            OidcProvider::Apple => "https://appleid.apple.com/auth/keys",
            OidcProvider::Microsoft => {
                "https://login.microsoftonline.com/common/discovery/v2.0/keys"
            }
            OidcProvider::Custom { jwks_url, .. } => jwks_url,
        }
    }

    /// Maps an `iss` claim value to a provider.
    ///
    /// Known issuers map to their dedicated variants. Any other HTTPS issuer becomes
    /// [`OidcProvider::Custom`] with a JWKS URL of `<issuer>/.well-known/jwks.json`, with trailing
    /// slashes on the issuer ignored. Non-HTTPS issuers get an empty JWKS URL, which cannot be
    /// fetched.
    ///
    /// Security note: for custom issuers the JWKS location is derived from the issuer string
    /// itself. If that string comes from an unverified token, the token chooses where its own
    /// verification keys are loaded from.
    pub fn from_issuer(issuer: &str) -> Self {
        match issuer {
            "https://accounts.google.com" => OidcProvider::Google,
            "https://appleid.apple.com" => OidcProvider::Apple,
            "https://login.microsoftonline.com/common/v2.0" => OidcProvider::Microsoft,
            _ => {
                let jwks_url = if issuer.starts_with("https://") {
                    // Without trimming, `https://idp.example/` would yield `...//.well-known/...`.
                    format!("{}/.well-known/jwks.json", issuer.trim_end_matches('/'))
                } else {
                    String::new()
                };
                OidcProvider::Custom {
                    issuer: issuer.to_string(),
                    jwks_url,
                }
            }
        }
    }
}
170
171#[derive(Clone, PartialEq, Eq, zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
179pub struct Pepper(Vec<u8>);
180
181impl std::fmt::Debug for Pepper {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 write!(f, "Pepper(REDACTED)")
184 }
185}
186
187impl Pepper {
188 pub fn new(bytes: Vec<u8>) -> Self {
190 Self(bytes)
191 }
192
193 pub fn as_bytes(&self) -> &[u8] {
195 &self.0
196 }
197
198 pub fn from_hex(hex_str: &str) -> MovementResult<Self> {
204 Ok(Self(const_hex::decode(hex_str)?))
205 }
206
207 pub fn to_hex(&self) -> String {
209 const_hex::encode_prefixed(&self.0)
210 }
211}
212
213#[derive(Clone, Debug, PartialEq, Eq)]
215pub struct ZkProof(Vec<u8>);
216
217impl ZkProof {
218 pub fn new(bytes: Vec<u8>) -> Self {
220 Self(bytes)
221 }
222
223 pub fn as_bytes(&self) -> &[u8] {
225 &self.0
226 }
227
228 pub fn from_hex(hex_str: &str) -> MovementResult<Self> {
234 Ok(Self(const_hex::decode(hex_str)?))
235 }
236
237 pub fn to_hex(&self) -> String {
239 const_hex::encode_prefixed(&self.0)
240 }
241}
242
243pub trait PepperService: Send + Sync {
245 fn get_pepper(
247 &self,
248 jwt: &str,
249 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<Pepper>> + Send + '_>>;
250}
251
252pub trait ProverService: Send + Sync {
254 fn generate_proof<'a>(
256 &'a self,
257 jwt: &'a str,
258 ephemeral_key: &'a EphemeralKeyPair,
259 pepper: &'a Pepper,
260 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<ZkProof>> + Send + 'a>>;
261}
262
263#[derive(Clone, Debug)]
265pub struct HttpPepperService {
266 url: Url,
267 client: reqwest::Client,
268}
269
270impl HttpPepperService {
271 pub fn new(url: Url) -> Self {
273 Self {
274 url,
275 client: reqwest::Client::new(),
276 }
277 }
278}
279
280#[derive(Serialize)]
281struct PepperRequest<'a> {
282 jwt: &'a str,
283}
284
285#[derive(Deserialize)]
286struct PepperResponse {
287 pepper: String,
288}
289
290impl PepperService for HttpPepperService {
291 fn get_pepper(
292 &self,
293 jwt: &str,
294 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<Pepper>> + Send + '_>>
295 {
296 let jwt = jwt.to_owned();
297 Box::pin(async move {
298 let response = self
299 .client
300 .post(self.url.clone())
301 .json(&PepperRequest { jwt: &jwt })
302 .send()
303 .await?
304 .error_for_status()?;
305
306 let bytes =
308 crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
309 let payload: PepperResponse = serde_json::from_slice(&bytes).map_err(|e| {
310 MovementError::InvalidJwt(format!("failed to parse pepper response: {e}"))
311 })?;
312 Pepper::from_hex(&payload.pepper)
313 })
314 }
315}
316
317#[derive(Clone, Debug)]
319pub struct HttpProverService {
320 url: Url,
321 client: reqwest::Client,
322}
323
324impl HttpProverService {
325 pub fn new(url: Url) -> Self {
327 Self {
328 url,
329 client: reqwest::Client::new(),
330 }
331 }
332}
333
334#[derive(Serialize)]
335struct ProverRequest<'a> {
336 jwt: &'a str,
337 ephemeral_public_key: String,
338 nonce: &'a str,
339 pepper: String,
340}
341
342#[derive(Deserialize)]
343struct ProverResponse {
344 proof: String,
345}
346
347impl ProverService for HttpProverService {
348 fn generate_proof<'a>(
349 &'a self,
350 jwt: &'a str,
351 ephemeral_key: &'a EphemeralKeyPair,
352 pepper: &'a Pepper,
353 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<ZkProof>> + Send + 'a>>
354 {
355 Box::pin(async move {
356 let request = ProverRequest {
357 jwt,
358 ephemeral_public_key: const_hex::encode_prefixed(
359 ephemeral_key.public_key.to_bytes(),
360 ),
361 nonce: ephemeral_key.nonce(),
362 pepper: pepper.to_hex(),
363 };
364
365 let response = self
366 .client
367 .post(self.url.clone())
368 .json(&request)
369 .send()
370 .await?
371 .error_for_status()?;
372
373 let bytes =
375 crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
376 let payload: ProverResponse = serde_json::from_slice(&bytes).map_err(|e| {
377 MovementError::InvalidJwt(format!("failed to parse prover response: {e}"))
378 })?;
379 ZkProof::from_hex(&payload.proof)
380 })
381 }
382}
383
384pub struct KeylessAccount {
386 ephemeral_key: EphemeralKeyPair,
387 provider: OidcProvider,
388 issuer: String,
389 audience: String,
390 user_id: String,
391 pepper: Pepper,
392 proof: ZkProof,
393 address: AccountAddress,
394 auth_key: AuthenticationKey,
395 jwt_expiration: Option<SystemTime>,
396}
397
398impl KeylessAccount {
399 pub async fn from_jwt(
426 jwt: &str,
427 ephemeral_key: EphemeralKeyPair,
428 pepper_service: &dyn PepperService,
429 prover_service: &dyn ProverService,
430 ) -> MovementResult<Self> {
431 let unverified_claims = decode_claims_unverified(jwt)?;
433 let issuer = unverified_claims
434 .iss
435 .as_ref()
436 .ok_or_else(|| MovementError::InvalidJwt("missing iss claim".into()))?;
437
438 let provider = OidcProvider::from_issuer(issuer);
440 let client = reqwest::Client::builder()
441 .timeout(JWKS_FETCH_TIMEOUT)
442 .build()
443 .map_err(|e| MovementError::InvalidJwt(format!("failed to create HTTP client: {e}")))?;
444 let jwks = fetch_jwks(&client, provider.jwks_url()).await?;
445
446 let claims = decode_and_verify_jwt(jwt, &jwks)?;
448 let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
449
450 if nonce != ephemeral_key.nonce() {
451 return Err(MovementError::InvalidJwt("JWT nonce mismatch".into()));
452 }
453
454 let pepper = pepper_service.get_pepper(jwt).await?;
455 let proof = prover_service
456 .generate_proof(jwt, &ephemeral_key, &pepper)
457 .await?;
458
459 let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
460 let auth_key = AuthenticationKey::new(address.to_bytes());
461
462 Ok(Self {
463 provider: OidcProvider::from_issuer(&issuer),
464 issuer,
465 audience,
466 user_id,
467 pepper,
468 proof,
469 address,
470 auth_key,
471 jwt_expiration: exp,
472 ephemeral_key,
473 })
474 }
475
476 pub async fn from_jwt_with_jwks(
493 jwt: &str,
494 jwks: &JwkSet,
495 ephemeral_key: EphemeralKeyPair,
496 pepper_service: &dyn PepperService,
497 prover_service: &dyn ProverService,
498 ) -> MovementResult<Self> {
499 let claims = decode_and_verify_jwt(jwt, jwks)?;
501 let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
502
503 if nonce != ephemeral_key.nonce() {
504 return Err(MovementError::InvalidJwt("JWT nonce mismatch".into()));
505 }
506
507 let pepper = pepper_service.get_pepper(jwt).await?;
508 let proof = prover_service
509 .generate_proof(jwt, &ephemeral_key, &pepper)
510 .await?;
511
512 let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
513 let auth_key = AuthenticationKey::new(address.to_bytes());
514
515 Ok(Self {
516 provider: OidcProvider::from_issuer(&issuer),
517 issuer,
518 audience,
519 user_id,
520 pepper,
521 proof,
522 address,
523 auth_key,
524 jwt_expiration: exp,
525 ephemeral_key,
526 })
527 }
528
529 pub fn provider(&self) -> &OidcProvider {
531 &self.provider
532 }
533
534 pub fn issuer(&self) -> &str {
536 &self.issuer
537 }
538
539 pub fn audience(&self) -> &str {
541 &self.audience
542 }
543
544 pub fn user_id(&self) -> &str {
546 &self.user_id
547 }
548
549 pub fn proof(&self) -> &ZkProof {
551 &self.proof
552 }
553
554 pub fn is_valid(&self) -> bool {
556 if self.ephemeral_key.is_expired() {
557 return false;
558 }
559
560 match self.jwt_expiration {
561 Some(exp) => SystemTime::now() < exp,
562 None => true,
563 }
564 }
565
566 pub async fn refresh_proof(
585 &mut self,
586 jwt: &str,
587 prover_service: &dyn ProverService,
588 ) -> MovementResult<()> {
589 let client = reqwest::Client::builder()
591 .timeout(JWKS_FETCH_TIMEOUT)
592 .build()
593 .map_err(|e| MovementError::InvalidJwt(format!("failed to create HTTP client: {e}")))?;
594 let jwks = fetch_jwks(&client, self.provider.jwks_url()).await?;
595 self.refresh_proof_with_jwks(jwt, &jwks, prover_service)
596 .await
597 }
598
599 pub async fn refresh_proof_with_jwks(
612 &mut self,
613 jwt: &str,
614 jwks: &JwkSet,
615 prover_service: &dyn ProverService,
616 ) -> MovementResult<()> {
617 let claims = decode_and_verify_jwt(jwt, jwks)?;
618 let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
619
620 if nonce != self.ephemeral_key.nonce() {
621 return Err(MovementError::InvalidJwt("JWT nonce mismatch".into()));
622 }
623
624 if issuer != self.issuer || audience != self.audience || user_id != self.user_id {
625 return Err(MovementError::InvalidJwt(
626 "JWT identity does not match account".into(),
627 ));
628 }
629
630 let proof = prover_service
631 .generate_proof(jwt, &self.ephemeral_key, &self.pepper)
632 .await?;
633 self.proof = proof;
634 self.jwt_expiration = exp;
635 Ok(())
636 }
637
638 pub fn sign_keyless(&self, message: &[u8]) -> KeylessSignature {
640 let signature = self.ephemeral_key.private_key.sign(message).to_bytes();
641 KeylessSignature {
642 ephemeral_public_key: self.ephemeral_key.public_key.to_bytes().to_vec(),
643 ephemeral_signature: signature.to_vec(),
644 proof: self.proof.as_bytes().to_vec(),
645 }
646 }
647
648 #[doc(hidden)]
660 #[allow(clippy::too_many_arguments)]
661 pub async fn from_verified_claims(
662 issuer: String,
663 audience: String,
664 user_id: String,
665 nonce: String,
666 exp: Option<SystemTime>,
667 ephemeral_key: EphemeralKeyPair,
668 pepper_service: &dyn PepperService,
669 prover_service: &dyn ProverService,
670 jwt_for_services: &str,
671 ) -> MovementResult<Self> {
672 if nonce != ephemeral_key.nonce() {
673 return Err(MovementError::InvalidJwt("nonce mismatch".into()));
674 }
675
676 let pepper = pepper_service.get_pepper(jwt_for_services).await?;
677 let proof = prover_service
678 .generate_proof(jwt_for_services, &ephemeral_key, &pepper)
679 .await?;
680
681 let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
682 let auth_key = AuthenticationKey::new(address.to_bytes());
683
684 Ok(Self {
685 provider: OidcProvider::from_issuer(&issuer),
686 issuer,
687 audience,
688 user_id,
689 pepper,
690 proof,
691 address,
692 auth_key,
693 jwt_expiration: exp,
694 ephemeral_key,
695 })
696 }
697}
698
699impl Account for KeylessAccount {
700 fn address(&self) -> AccountAddress {
701 self.address
702 }
703
704 fn authentication_key(&self) -> AuthenticationKey {
705 self.auth_key
706 }
707
708 fn sign(&self, message: &[u8]) -> crate::error::MovementResult<Vec<u8>> {
709 let signature = self.sign_keyless(message);
710 signature
711 .to_bcs()
712 .map_err(|e| crate::error::MovementError::Bcs(e.to_string()))
713 }
714
715 fn public_key_bytes(&self) -> Vec<u8> {
716 self.ephemeral_key.public_key.to_bytes().to_vec()
717 }
718
719 fn signature_scheme(&self) -> u8 {
720 KEYLESS_SCHEME
721 }
722}
723
724impl fmt::Debug for KeylessAccount {
725 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
726 f.debug_struct("KeylessAccount")
727 .field("address", &self.address)
728 .field("provider", &self.provider)
729 .field("issuer", &self.issuer)
730 .field("audience", &self.audience)
731 .field("user_id", &self.user_id)
732 .finish_non_exhaustive()
733 }
734}
735
736#[derive(Debug, Deserialize)]
737struct JwtClaims {
738 iss: Option<String>,
739 aud: Option<AudClaim>,
740 sub: Option<String>,
741 exp: Option<u64>,
742 nonce: Option<String>,
743}
744
745#[derive(Debug, Deserialize)]
746#[serde(untagged)]
747enum AudClaim {
748 Single(String),
749 Multiple(Vec<String>),
750}
751
752impl AudClaim {
753 fn first(&self) -> Option<&str> {
754 match self {
755 AudClaim::Single(value) => Some(value.as_str()),
756 AudClaim::Multiple(values) => values.first().map(std::string::String::as_str),
757 }
758 }
759}
760
/// Timeout configured on the HTTP client that downloads JWKS documents.
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(10);

/// Maximum response size in bytes (1 MiB) passed to `read_response_bounded`. Used for JWKS
/// downloads and also for the HTTP pepper and prover responses.
const MAX_JWKS_RESPONSE_SIZE: usize = 1 << 20;
766
767async fn fetch_jwks(client: &reqwest::Client, jwks_url: &str) -> MovementResult<JwkSet> {
777 let parsed_url = Url::parse(jwks_url)
781 .map_err(|e| MovementError::InvalidJwt(format!("invalid JWKS URL: {e}")))?;
782 if parsed_url.scheme() != "https" {
783 return Err(MovementError::InvalidJwt(format!(
784 "JWKS URL must use HTTPS scheme, got: {}",
785 parsed_url.scheme()
786 )));
787 }
788
789 let response = client.get(jwks_url).send().await?;
791
792 if !response.status().is_success() {
793 return Err(MovementError::InvalidJwt(format!(
794 "JWKS endpoint returned status: {}",
795 response.status()
796 )));
797 }
798
799 let bytes = crate::config::read_response_bounded(response, MAX_JWKS_RESPONSE_SIZE).await?;
802 let jwks: JwkSet = serde_json::from_slice(&bytes)
803 .map_err(|e| MovementError::InvalidJwt(format!("failed to parse JWKS: {e}")))?;
804 Ok(jwks)
805}
806
807fn decode_and_verify_jwt(jwt: &str, jwks: &JwkSet) -> MovementResult<JwtClaims> {
822 let header = decode_header(jwt)
824 .map_err(|e| MovementError::InvalidJwt(format!("failed to decode JWT header: {e}")))?;
825
826 let kid = header
827 .kid
828 .as_ref()
829 .ok_or_else(|| MovementError::InvalidJwt("JWT header missing 'kid' field".into()))?;
830
831 let signing_key = jwks.find(kid).ok_or_else(|| {
833 MovementError::InvalidJwt("no matching key found for provided key identifier".into())
834 })?;
835
836 let decoding_key = DecodingKey::from_jwk(signing_key)
838 .map_err(|e| MovementError::InvalidJwt(format!("failed to create decoding key: {e}")))?;
839
840 let jwk_alg = signing_key.common.key_algorithm.ok_or_else(|| {
842 MovementError::InvalidJwt("JWK missing 'alg' (key_algorithm) field".into())
843 })?;
844
845 let algorithm = match jwk_alg {
846 jsonwebtoken::jwk::KeyAlgorithm::RS256 => Algorithm::RS256,
848 jsonwebtoken::jwk::KeyAlgorithm::RS384 => Algorithm::RS384,
849 jsonwebtoken::jwk::KeyAlgorithm::RS512 => Algorithm::RS512,
850 jsonwebtoken::jwk::KeyAlgorithm::PS256 => Algorithm::PS256,
852 jsonwebtoken::jwk::KeyAlgorithm::PS384 => Algorithm::PS384,
853 jsonwebtoken::jwk::KeyAlgorithm::PS512 => Algorithm::PS512,
854 jsonwebtoken::jwk::KeyAlgorithm::ES256 => Algorithm::ES256,
856 jsonwebtoken::jwk::KeyAlgorithm::ES384 => Algorithm::ES384,
857 jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Algorithm::EdDSA,
859 _ => {
860 return Err(MovementError::InvalidJwt(format!(
861 "unsupported JWK algorithm: {jwk_alg:?}"
862 )));
863 }
864 };
865
866 if header.alg != algorithm {
868 return Err(MovementError::InvalidJwt(format!(
869 "JWT header algorithm ({:?}) does not match JWK algorithm ({:?})",
870 header.alg, algorithm
871 )));
872 }
873
874 let mut validation = Validation::new(algorithm);
876 validation.validate_exp = false;
877 validation.validate_aud = false; validation.set_required_spec_claims::<String>(&[]);
879
880 let data = decode::<JwtClaims>(jwt, &decoding_key, &validation)
881 .map_err(|e| MovementError::InvalidJwt(format!("JWT verification failed: {e}")))?;
882
883 Ok(data.claims)
884}
885
886fn decode_claims_unverified(jwt: &str) -> MovementResult<JwtClaims> {
894 let data = jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(jwt)
900 .map_err(|e| MovementError::InvalidJwt(format!("failed to decode JWT claims: {e}")))?;
901 Ok(data.claims)
902}
903
904fn extract_claims(
905 claims: &JwtClaims,
906) -> MovementResult<(String, String, String, Option<SystemTime>, String)> {
907 let issuer = claims
908 .iss
909 .clone()
910 .ok_or_else(|| MovementError::InvalidJwt("missing iss claim".into()))?;
911 let audience = claims
912 .aud
913 .as_ref()
914 .and_then(|aud| aud.first())
915 .map(std::string::ToString::to_string)
916 .ok_or_else(|| MovementError::InvalidJwt("missing aud claim".into()))?;
917 let user_id = claims
918 .sub
919 .clone()
920 .ok_or_else(|| MovementError::InvalidJwt("missing sub claim".into()))?;
921 let nonce = claims
922 .nonce
923 .clone()
924 .ok_or_else(|| MovementError::InvalidJwt("missing nonce claim".into()))?;
925
926 let exp_time = claims.exp.map(|exp| UNIX_EPOCH + Duration::from_secs(exp));
927 if let Some(exp) = exp_time
928 && SystemTime::now() >= exp
929 {
930 let exp_secs = claims.exp.unwrap_or(0);
931 return Err(MovementError::InvalidJwt(format!(
932 "JWT is expired (exp: {exp_secs} seconds since UNIX_EPOCH)"
933 )));
934 }
935
936 Ok((issuer, audience, user_id, exp_time, nonce))
937}
938
939fn derive_keyless_address(
940 issuer: &str,
941 audience: &str,
942 user_id: &str,
943 pepper: &Pepper,
944) -> AccountAddress {
945 let issuer_hash = sha3_256_bytes(issuer.as_bytes());
946 let audience_hash = sha3_256_bytes(audience.as_bytes());
947 let user_hash = sha3_256_bytes(user_id.as_bytes());
948
949 let mut hasher = Sha3_256::new();
950 hasher.update(issuer_hash);
951 hasher.update(audience_hash);
952 hasher.update(user_hash);
953 hasher.update(pepper.as_bytes());
954 hasher.update([KEYLESS_SCHEME]);
955 let result = hasher.finalize();
956
957 let mut address = [0u8; 32];
958 address.copy_from_slice(&result);
959 AccountAddress::new(address)
960}
961
962fn sha3_256_bytes(data: &[u8]) -> [u8; 32] {
963 let mut hasher = Sha3_256::new();
964 hasher.update(data);
965 let result = hasher.finalize();
966 let mut output = [0u8; 32];
967 output.copy_from_slice(&result);
968 output
969}
970
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};

    /// Issuer used by every fixture token.
    const GOOGLE_ISSUER: &str = "https://accounts.google.com";

    /// Pepper service stub that returns a fixed pepper whatever the JWT.
    struct StaticPepperService {
        pepper: Pepper,
    }

    impl PepperService for StaticPepperService {
        fn get_pepper(
            &self,
            _jwt: &str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<Pepper>> + Send + '_>>
        {
            Box::pin(async move { Ok(self.pepper.clone()) })
        }
    }

    /// Prover service stub that returns a fixed proof whatever the inputs.
    struct StaticProverService {
        proof: ZkProof,
    }

    impl ProverService for StaticProverService {
        fn generate_proof<'a>(
            &'a self,
            _jwt: &'a str,
            _ephemeral_key: &'a EphemeralKeyPair,
            _pepper: &'a Pepper,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MovementResult<ZkProof>> + Send + 'a>>
        {
            Box::pin(async move { Ok(self.proof.clone()) })
        }
    }

    /// Claim set serialized into fixture JWTs.
    #[derive(Serialize, Deserialize)]
    struct TestClaims {
        iss: String,
        aud: String,
        sub: String,
        exp: u64,
        nonce: String,
    }

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_secs()
    }

    /// Google-issued claims that expire one hour after `now`.
    fn google_claims(aud: &str, sub: &str, nonce: &str, now: u64) -> TestClaims {
        TestClaims {
            iss: GOOGLE_ISSUER.to_string(),
            aud: aud.to_string(),
            sub: sub.to_string(),
            exp: now + 3600,
            nonce: nonce.to_string(),
        }
    }

    /// Encodes `claims` under `header` with a fixed HMAC secret.
    fn encode_with_secret(header: &Header, claims: &TestClaims) -> String {
        encode(header, claims, &EncodingKey::from_secret(b"secret")).unwrap()
    }

    /// Stub services yielding pepper `[1, 2, 3, 4]` and proof `[9, 9, 9]`.
    fn stub_services() -> (StaticPepperService, StaticProverService) {
        let pepper_service = StaticPepperService {
            pepper: Pepper::new(vec![1, 2, 3, 4]),
        };
        let prover_service = StaticProverService {
            proof: ZkProof::new(vec![9, 9, 9]),
        };
        (pepper_service, prover_service)
    }

    #[tokio::test]
    async fn test_keyless_account_creation() {
        let ephemeral = EphemeralKeyPair::generate(3600);
        let now = unix_now();
        let nonce = ephemeral.nonce().to_string();
        let jwt = encode_with_secret(
            &Header::new(Algorithm::HS256),
            &google_claims("client-id", "user-123", &nonce, now),
        );
        let (pepper_service, prover_service) = stub_services();
        let exp_time = UNIX_EPOCH + std::time::Duration::from_secs(now + 3600);

        let account = KeylessAccount::from_verified_claims(
            GOOGLE_ISSUER.to_string(),
            "client-id".to_string(),
            "user-123".to_string(),
            nonce,
            Some(exp_time),
            ephemeral,
            &pepper_service,
            &prover_service,
            &jwt,
        )
        .await
        .unwrap();

        assert_eq!(account.issuer(), GOOGLE_ISSUER);
        assert_eq!(account.audience(), "client-id");
        assert_eq!(account.user_id(), "user-123");
        assert!(account.is_valid());
        assert!(!account.address().is_zero());
    }

    #[tokio::test]
    async fn test_keyless_account_nonce_mismatch() {
        let ephemeral = EphemeralKeyPair::generate(3600);
        let now = unix_now();
        let jwt = encode_with_secret(
            &Header::new(Algorithm::HS256),
            &google_claims("client-id", "user-123", ephemeral.nonce(), now),
        );
        let (pepper_service, prover_service) = stub_services();

        let result = KeylessAccount::from_verified_claims(
            GOOGLE_ISSUER.to_string(),
            "client-id".to_string(),
            "user-123".to_string(),
            "wrong-nonce".to_string(),
            None,
            ephemeral,
            &pepper_service,
            &prover_service,
            &jwt,
        )
        .await;

        assert!(result.is_err());
        assert!(matches!(result, Err(MovementError::InvalidJwt(_))));
    }

    #[test]
    fn test_decode_claims_unverified() {
        let jwt = encode_with_secret(
            &Header::new(Algorithm::HS256),
            &google_claims("test-aud", "test-sub", "test-nonce", unix_now()),
        );

        let decoded = decode_claims_unverified(&jwt).unwrap();
        assert_eq!(decoded.iss.unwrap(), GOOGLE_ISSUER);
        assert_eq!(decoded.sub.unwrap(), "test-sub");
        assert_eq!(decoded.nonce.unwrap(), "test-nonce");
    }

    #[test]
    fn test_oidc_provider_detection() {
        assert!(matches!(
            OidcProvider::from_issuer(GOOGLE_ISSUER),
            OidcProvider::Google
        ));
        assert!(matches!(
            OidcProvider::from_issuer("https://appleid.apple.com"),
            OidcProvider::Apple
        ));
        assert!(matches!(
            OidcProvider::from_issuer("https://unknown.example.com"),
            OidcProvider::Custom { .. }
        ));
    }

    #[test]
    fn test_decode_and_verify_jwt_missing_kid() {
        // Default header has no `kid`, so lookup must fail before any key is consulted.
        let jwt = encode_with_secret(
            &Header::new(Algorithm::HS256),
            &google_claims("test-aud", "test-sub", "test-nonce", unix_now()),
        );
        let jwks = JwkSet { keys: vec![] };

        let err = decode_and_verify_jwt(&jwt, &jwks).unwrap_err();
        assert!(
            matches!(&err, MovementError::InvalidJwt(msg) if msg.contains("kid")),
            "Expected error about missing kid, got: {err:?}"
        );
    }

    #[test]
    fn test_decode_and_verify_jwt_no_matching_key() {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-kid-123".to_string());
        let jwt = encode_with_secret(
            &header,
            &google_claims("test-aud", "test-sub", "test-nonce", unix_now()),
        );
        // Empty key set: the `kid` can never be found.
        let jwks = JwkSet { keys: vec![] };

        let err = decode_and_verify_jwt(&jwt, &jwks).unwrap_err();
        assert!(
            matches!(&err, MovementError::InvalidJwt(msg) if msg.contains("no matching key")),
            "Expected error about no matching key, got: {err:?}"
        );
    }

    #[test]
    fn test_decode_and_verify_jwt_invalid_jwt_format() {
        let jwks = JwkSet { keys: vec![] };

        for malformed in ["not-a-valid-jwt", "aaa.bbb.ccc"] {
            assert!(decode_and_verify_jwt(malformed, &jwks).is_err());
        }
    }
}