mirror of
https://github.com/golang/oauth2.git
synced 2025-07-21 00:00:09 +08:00
jws: improve fix for CVE-2025-22868
The fix for CVE-2025-22868 relies on strings.Count, which isn't ideal because it precludes failing fast when the token contains an unexpected number of periods. Moreover, Verify still allocates more than necessary. Eschew strings.Count in favor of strings.Cut. Some benchmark results: goos: darwin goarch: amd64 pkg: golang.org/x/oauth2/jws cpu: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz │ old │ new │ │ sec/op │ sec/op vs base │ Verify/full_of_periods-8 24862.50n ± 1% 57.87n ± 0% -99.77% (p=0.000 n=20) Verify/two_trailing_periods-8 3.485m ± 1% 3.445m ± 1% -1.13% (p=0.003 n=20) geomean 294.3µ 14.12µ -95.20% │ old │ new │ │ B/op │ B/op vs base │ Verify/full_of_periods-8 16.00 ± 0% 16.00 ± 0% ~ (p=1.000 n=20) ¹ Verify/two_trailing_periods-8 2.001Mi ± 0% 1.001Mi ± 0% -49.98% (p=0.000 n=20) geomean 5.658Ki 4.002Ki -29.27% ¹ all samples are equal │ old │ new │ │ allocs/op │ allocs/op vs base │ Verify/full_of_periods-8 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=20) ¹ Verify/two_trailing_periods-8 12.000 ± 0% 9.000 ± 0% -25.00% (p=0.000 n=20) geomean 3.464 3.000 -13.40% ¹ all samples are equal Also, remove all remaining calls to strings.Split. Updates golang/go#71490 Change-Id: Icac3c7a81562161ab6533d892ba19247d6d5b943 GitHub-Last-Rev: 3a82900f747798f5f36065126385880277c0fce7 GitHub-Pull-Request: golang/oauth2#774 Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/655455 Commit-Queue: Neal Patel <nealpatel@google.com> Reviewed-by: Roland Shoemaker <roland@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Neal Patel <nealpatel@google.com> Auto-Submit: Neal Patel <nealpatel@google.com>
This commit is contained in:
parent
0042180b24
commit
ce56909505
34
jws/jws.go
34
jws/jws.go
@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) {
|
|||||||
// Decode decodes a claim set from a JWS payload.
|
// Decode decodes a claim set from a JWS payload.
|
||||||
func Decode(payload string) (*ClaimSet, error) {
|
func Decode(payload string) (*ClaimSet, error) {
|
||||||
// decode returned id token to get expiry
|
// decode returned id token to get expiry
|
||||||
s := strings.Split(payload, ".")
|
_, claims, _, ok := parseToken(payload)
|
||||||
if len(s) < 2 {
|
if !ok {
|
||||||
// TODO(jbd): Provide more context about the error.
|
// TODO(jbd): Provide more context about the error.
|
||||||
return nil, errors.New("jws: invalid token received")
|
return nil, errors.New("jws: invalid token received")
|
||||||
}
|
}
|
||||||
decoded, err := base64.RawURLEncoding.DecodeString(s[1])
|
decoded, err := base64.RawURLEncoding.DecodeString(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
|
|||||||
// Verify tests whether the provided JWT token's signature was produced by the private key
|
// Verify tests whether the provided JWT token's signature was produced by the private key
|
||||||
// associated with the supplied public key.
|
// associated with the supplied public key.
|
||||||
func Verify(token string, key *rsa.PublicKey) error {
|
func Verify(token string, key *rsa.PublicKey) error {
|
||||||
if strings.Count(token, ".") != 2 {
|
header, claims, sig, ok := parseToken(token)
|
||||||
|
if !ok {
|
||||||
return errors.New("jws: invalid token received, token must have 3 parts")
|
return errors.New("jws: invalid token received, token must have 3 parts")
|
||||||
}
|
}
|
||||||
|
signatureString, err := base64.RawURLEncoding.DecodeString(sig)
|
||||||
parts := strings.SplitN(token, ".", 3)
|
|
||||||
signedContent := parts[0] + "." + parts[1]
|
|
||||||
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
h.Write([]byte(signedContent))
|
h.Write([]byte(header + tokenDelim + claims))
|
||||||
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
|
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseToken(s string) (header, claims, sig string, ok bool) {
|
||||||
|
header, s, ok = strings.Cut(s, tokenDelim)
|
||||||
|
if !ok { // no period found
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
|
claims, s, ok = strings.Cut(s, tokenDelim)
|
||||||
|
if !ok { // only one period found
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
|
sig, _, ok = strings.Cut(s, tokenDelim)
|
||||||
|
if ok { // three periods found
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
|
return header, claims, sig, true
|
||||||
|
}
|
||||||
|
|
||||||
|
const tokenDelim = "."
|
||||||
|
@ -7,6 +7,8 @@ package jws
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,8 +41,57 @@ func TestSignAndVerify(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyFailsOnMalformedClaim(t *testing.T) {
|
func TestVerifyFailsOnMalformedClaim(t *testing.T) {
|
||||||
err := Verify("abc.def", nil)
|
cases := []struct {
|
||||||
if err == nil {
|
desc string
|
||||||
t.Error("got no errors; want improperly formed JWT not to be verified")
|
token string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "no periods",
|
||||||
|
token: "aa",
|
||||||
|
}, {
|
||||||
|
desc: "only one period",
|
||||||
|
token: "a.a",
|
||||||
|
}, {
|
||||||
|
desc: "more than two periods",
|
||||||
|
token: "a.a.a.a",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
f := func(t *testing.T) {
|
||||||
|
err := Verify(tc.token, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("got no errors; want improperly formed JWT not to be verified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Run(tc.desc, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkVerify(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
token string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "full of periods",
|
||||||
|
token: strings.Repeat(".", http.DefaultMaxHeaderBytes),
|
||||||
|
}, {
|
||||||
|
desc: "two trailing periods",
|
||||||
|
token: strings.Repeat("a", http.DefaultMaxHeaderBytes-2) + "..",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, bc := range cases {
|
||||||
|
f := func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
Verify(bc.token, &privateKey.PublicKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.Run(bc.desc, f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user