From ce56909505b351a755ad0bc294c5ee01ed0ea050 Mon Sep 17 00:00:00 2001 From: Julien Cretel Date: Tue, 11 Mar 2025 18:26:19 +0000 Subject: [PATCH] jws: improve fix for CVE-2025-22868 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fix for CVE-2025-22868 relies on strings.Count, which isn't ideal because it precludes failing fast when the token contains an unexpected number of periods. Moreover, Verify still allocates more than necessary. Eschew strings.Count in favor of strings.Cut. Some benchmark results: goos: darwin goarch: amd64 pkg: golang.org/x/oauth2/jws cpu: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz │ old │ new │ │ sec/op │ sec/op vs base │ Verify/full_of_periods-8 24862.50n ± 1% 57.87n ± 0% -99.77% (p=0.000 n=20) Verify/two_trailing_periods-8 3.485m ± 1% 3.445m ± 1% -1.13% (p=0.003 n=20) geomean 294.3µ 14.12µ -95.20% │ old │ new │ │ B/op │ B/op vs base │ Verify/full_of_periods-8 16.00 ± 0% 16.00 ± 0% ~ (p=1.000 n=20) ¹ Verify/two_trailing_periods-8 2.001Mi ± 0% 1.001Mi ± 0% -49.98% (p=0.000 n=20) geomean 5.658Ki 4.002Ki -29.27% ¹ all samples are equal │ old │ new │ │ allocs/op │ allocs/op vs base │ Verify/full_of_periods-8 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=20) ¹ Verify/two_trailing_periods-8 12.000 ± 0% 9.000 ± 0% -25.00% (p=0.000 n=20) geomean 3.464 3.000 -13.40% ¹ all samples are equal Also, remove all remaining calls to strings.Split. Updates golang/go#71490 Change-Id: Icac3c7a81562161ab6533d892ba19247d6d5b943 GitHub-Last-Rev: 3a82900f747798f5f36065126385880277c0fce7 GitHub-Pull-Request: golang/oauth2#774 Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/655455 Commit-Queue: Neal Patel Reviewed-by: Roland Shoemaker LUCI-TryBot-Result: Go LUCI Reviewed-by: Neal Patel Auto-Submit: Neal Patel --- jws/jws.go | 34 +++++++++++++++++++++-------- jws/jws_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 79 insertions(+), 12 deletions(-) diff --git a/jws/jws.go b/jws/jws.go index 6f03a49..27ab061 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) { // Decode decodes a claim set from a JWS payload. func Decode(payload string) (*ClaimSet, error) { // decode returned id token to get expiry - s := strings.Split(payload, ".") - if len(s) < 2 { + _, claims, _, ok := parseToken(payload) + if !ok { // TODO(jbd): Provide more context about the error. return nil, errors.New("jws: invalid token received") } - decoded, err := base64.RawURLEncoding.DecodeString(s[1]) + decoded, err := base64.RawURLEncoding.DecodeString(claims) if err != nil { return nil, err } @@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { // Verify tests whether the provided JWT token's signature was produced by the private key // associated with the supplied public key. func Verify(token string, key *rsa.PublicKey) error { - if strings.Count(token, ".") != 2 { + header, claims, sig, ok := parseToken(token) + if !ok { return errors.New("jws: invalid token received, token must have 3 parts") } - - parts := strings.SplitN(token, ".", 3) - signedContent := parts[0] + "." + parts[1] - signatureString, err := base64.RawURLEncoding.DecodeString(parts[2]) + signatureString, err := base64.RawURLEncoding.DecodeString(sig) if err != nil { return err } h := sha256.New() - h.Write([]byte(signedContent)) + h.Write([]byte(header + tokenDelim + claims)) return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString) } + +func parseToken(s string) (header, claims, sig string, ok bool) { + header, s, ok = strings.Cut(s, tokenDelim) + if !ok { // no period found + return "", "", "", false + } + claims, s, ok = strings.Cut(s, tokenDelim) + if !ok { // only one period found + return "", "", "", false + } + sig, _, ok = strings.Cut(s, tokenDelim) + if ok { // three periods found + return "", "", "", false + } + return header, claims, sig, true +} + +const tokenDelim = "." diff --git a/jws/jws_test.go b/jws/jws_test.go index 39a136a..1776f56 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -7,6 +7,8 @@ package jws import ( "crypto/rand" "crypto/rsa" + "net/http" + "strings" "testing" ) @@ -39,8 +41,57 @@ func TestSignAndVerify(t *testing.T) { } func TestVerifyFailsOnMalformedClaim(t *testing.T) { - err := Verify("abc.def", nil) - if err == nil { - t.Error("got no errors; want improperly formed JWT not to be verified") + cases := []struct { + desc string + token string + }{ + { + desc: "no periods", + token: "aa", + }, { + desc: "only one period", + token: "a.a", + }, { + desc: "more than two periods", + token: "a.a.a.a", + }, + } + for _, tc := range cases { + f := func(t *testing.T) { + err := Verify(tc.token, nil) + if err == nil { + t.Error("got no errors; want improperly formed JWT not to be verified") + } + } + t.Run(tc.desc, f) + } +} + +func BenchmarkVerify(b *testing.B) { + cases := []struct { + desc string + token string + }{ + { + desc: "full of periods", + token: strings.Repeat(".", http.DefaultMaxHeaderBytes), + }, { + desc: "two trailing periods", + token: strings.Repeat("a", http.DefaultMaxHeaderBytes-2) + "..", + }, + } + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + b.Fatal(err) + } + for _, bc := range cases { + f := func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for range b.N { + Verify(bc.token, &privateKey.PublicKey) + } + } + b.Run(bc.desc, f) } }