diff --git a/jws/jws.go b/jws/jws.go index dd22043..f6565db 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -12,10 +12,8 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" - "crypto/x509" "encoding/base64" "encoding/json" - "encoding/pem" "errors" "fmt" "strings" @@ -123,7 +121,7 @@ func Decode(payload string) (c *ClaimSet, err error) { } // Encode encodes a signed JWS with provided header and claim set. -func Encode(header *Header, c *ClaimSet, signature []byte) (payload string, err error) { +func Encode(header *Header, c *ClaimSet, signature *rsa.PrivateKey) (payload string, err error) { var encodedHeader, encodedClaimSet string encodedHeader, err = header.encode() if err != nil { @@ -135,14 +133,9 @@ func Encode(header *Header, c *ClaimSet, signature []byte) (payload string, err } ss := fmt.Sprintf("%s.%s", encodedHeader, encodedClaimSet) - parsed, err := parsePrivateKey(signature) - if err != nil { - return - } - h := sha256.New() h.Write([]byte(ss)) - b, err := rsa.SignPKCS1v15(rand.Reader, parsed, crypto.SHA256, h.Sum(nil)) + b, err := rsa.SignPKCS1v15(rand.Reader, signature, crypto.SHA256, h.Sum(nil)) if err != nil { return } @@ -168,26 +161,3 @@ func base64Decode(s string) ([]byte, error) { } return base64.URLEncoding.DecodeString(s) } - -// parsePrivateKey parses the key to extract the private key. -// It returns an error if private key is not provided or the -// provided key is invalid. -func parsePrivateKey(key []byte) (*rsa.PrivateKey, error) { - invalidPrivateKeyErr := errors.New("Private key is invalid.") - block, _ := pem.Decode(key) - if block == nil { - return nil, invalidPrivateKeyErr - } - parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) - if err != nil { - parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return nil, err - } - } - parsed, ok := parsedKey.(*rsa.PrivateKey) - if !ok { - return nil, invalidPrivateKeyErr - } - return parsed, nil -} diff --git a/jwt.go b/jwt.go index 1dace43..d1ea222 100644 --- a/jwt.go +++ b/jwt.go @@ -5,7 +5,10 @@ package oauth2 import ( + "crypto/rsa" + "crypto/x509" "encoding/json" + "encoding/pem" "errors" "io/ioutil" "net/http" @@ -28,7 +31,13 @@ type JWTOptions struct { // the configured OAuth provider. Email string `json:"email"` - // The path to the pem file. If you have a p12 file instead, you + // Private key to sign JWS payloads. + PrivateKey *rsa.PrivateKey `json:"-"` + + // The path to a pem container that includes your private key. + // If PrivateKey is set, this field is ignored. + // + // If you have a p12 file instead, you // can use `openssl` to export the private key into a pem file. // $ openssl pkcs12 -in key.p12 -out key.pem -nodes // Pem file should contain your private key. @@ -38,8 +47,6 @@ type JWTOptions struct { Scopes []string `json:"scopes"` } -// TODO(jbd): Add p12 support. - // NewJWTConfig creates a new configuration with the specified options // and OAuth2 provider endpoint. func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) { @@ -47,19 +54,26 @@ func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) { if err != nil { return nil, err } + if opts.PrivateKey != nil { + return &JWTConfig{opts: opts, aud: audURL, key: opts.PrivateKey}, nil + } contents, err := ioutil.ReadFile(opts.PemFilename) if err != nil { return nil, err } - return &JWTConfig{opts: opts, aud: audURL, signature: contents}, nil + parsedKey, err := parsePemKey(contents) + if err != nil { + return nil, err + } + return &JWTConfig{opts: opts, aud: audURL, key: parsedKey}, nil } // JWTConfig represents an OAuth 2.0 provider and client options to // provide authorized transports with a Bearer JWT token. type JWTConfig struct { - opts *JWTOptions - aud *url.URL - signature []byte + opts *JWTOptions + aud *url.URL + key *rsa.PrivateKey } // NewTransport creates a transport that is authorize with the @@ -94,7 +108,7 @@ func (c *JWTConfig) FetchToken(existing *Token) (token *Token, err error) { claimSet.Prn = existing.Subject } - payload, err := jws.Encode(defaultHeader, claimSet, c.signature) + payload, err := jws.Encode(defaultHeader, claimSet, c.key) if err != nil { return } @@ -140,3 +154,26 @@ func (c *JWTConfig) FetchToken(existing *Token) (token *Token, err error) { token.Expiry = time.Now().Add(time.Duration(b.ExpiresIn) * time.Second) return } + +// parsePemKey parses the pem file to extract the private key. +// It returns an error if private key is not provided or the +// provided key is invalid. +func parsePemKey(key []byte) (*rsa.PrivateKey, error) { + invalidPrivateKeyErr := errors.New("oauth2: private key is invalid") + block, _ := pem.Decode(key) + if block == nil { + return nil, invalidPrivateKeyErr + } + parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + } + parsed, ok := parsedKey.(*rsa.PrivateKey) + if !ok { + return nil, invalidPrivateKeyErr + } + return parsed, nil +}