mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-08 23:18:15 -05:00
Add --jwt-id flag (#13218)
* add jwt-id flag * optimize unit test for jwt-id * Add jwt-id to help text * gofmt --------- Co-authored-by: Preston Van Loon <pvanloon@offchainlabs.com>
This commit is contained in:
@@ -24,6 +24,7 @@ const DefaultRPCHTTPTimeout = time.Second * 30
|
||||
type jwtTransport struct {
|
||||
underlyingTransport http.RoundTripper
|
||||
jwtSecret []byte
|
||||
jwtId string
|
||||
}
|
||||
|
||||
// RoundTrip ensures our transport implements http.RoundTripper interface from the
|
||||
@@ -32,12 +33,16 @@ type jwtTransport struct {
|
||||
// an JWT bearer token in the Authorization request header of every outgoing request
|
||||
// our HTTP client makes.
|
||||
func (t *jwtTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
claims := jwt.MapClaims{
|
||||
// Required claim for engine API auth. "iat" stands for issued at
|
||||
// and it must be a unix timestamp that is +/- 5 seconds from the current
|
||||
// timestamp at the moment the server verifies this value.
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
if len(t.jwtId) > 0 {
|
||||
claims["id"] = t.jwtId
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(t.jwtSecret)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not produce signed JWT token")
|
||||
|
||||
@@ -51,3 +51,91 @@ func TestJWTAuthTransport(t *testing.T) {
|
||||
_, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestJWTWithId(t *testing.T) {
|
||||
secret := bytesutil.PadTo([]byte("foo"), 32)
|
||||
jwtId := "abc"
|
||||
authTransport := &jwtTransport{
|
||||
underlyingTransport: http.DefaultTransport,
|
||||
jwtSecret: secret,
|
||||
jwtId: jwtId,
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: DefaultRPCHTTPTimeout,
|
||||
Transport: authTransport,
|
||||
}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqToken := r.Header.Get("Authorization")
|
||||
splitToken := strings.Split(reqToken, "Bearer")
|
||||
// The format should be `Bearer ${token}`.
|
||||
require.Equal(t, 2, len(splitToken))
|
||||
reqToken = strings.TrimSpace(splitToken[1])
|
||||
token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) {
|
||||
// We should be doing HMAC signing.
|
||||
_, ok := token.Method.(*jwt.SigningMethodHMAC)
|
||||
require.Equal(t, true, ok)
|
||||
return secret, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, token.Valid)
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
require.Equal(t, true, ok)
|
||||
item, ok := claims["iat"]
|
||||
require.Equal(t, true, ok)
|
||||
iat, ok := item.(float64)
|
||||
require.Equal(t, true, ok)
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
// The claims should have an "iat" field (issued at) that is at most, 5 seconds ago.
|
||||
since := time.Since(issuedAt)
|
||||
require.Equal(t, true, since <= time.Second*5)
|
||||
// check jwt claims id
|
||||
id, ok := claims["id"]
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, id, jwtId)
|
||||
}))
|
||||
defer srv.Close()
|
||||
_, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestJWTWithoutId(t *testing.T) {
|
||||
secret := bytesutil.PadTo([]byte("foo"), 32)
|
||||
authTransport := &jwtTransport{
|
||||
underlyingTransport: http.DefaultTransport,
|
||||
jwtSecret: secret,
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: DefaultRPCHTTPTimeout,
|
||||
Transport: authTransport,
|
||||
}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqToken := r.Header.Get("Authorization")
|
||||
splitToken := strings.Split(reqToken, "Bearer")
|
||||
// The format should be `Bearer ${token}`.
|
||||
require.Equal(t, 2, len(splitToken))
|
||||
reqToken = strings.TrimSpace(splitToken[1])
|
||||
token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) {
|
||||
// We should be doing HMAC signing.
|
||||
_, ok := token.Method.(*jwt.SigningMethodHMAC)
|
||||
require.Equal(t, true, ok)
|
||||
return secret, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, token.Valid)
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
require.Equal(t, true, ok)
|
||||
item, ok := claims["iat"]
|
||||
require.Equal(t, true, ok)
|
||||
iat, ok := item.(float64)
|
||||
require.Equal(t, true, ok)
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
// The claims should have an "iat" field (issued at) that is at most, 5 seconds ago.
|
||||
since := time.Since(issuedAt)
|
||||
require.Equal(t, true, since <= time.Second*5)
|
||||
_, ok = claims["id"]
|
||||
require.Equal(t, false, ok)
|
||||
}))
|
||||
defer srv.Close()
|
||||
_, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ type Endpoint struct {
|
||||
type AuthorizationData struct {
|
||||
Method authorization.Method
|
||||
Value string
|
||||
JwtId string
|
||||
}
|
||||
|
||||
// Equals compares two endpoints for equality.
|
||||
@@ -37,7 +38,7 @@ func (e Endpoint) HttpClient() *http.Client {
|
||||
if e.Auth.Method != authorization.Bearer {
|
||||
return http.DefaultClient
|
||||
}
|
||||
return NewHttpClientWithSecret(e.Auth.Value)
|
||||
return NewHttpClientWithSecret(e.Auth.Value, e.Auth.JwtId)
|
||||
}
|
||||
|
||||
// Equals compares two authorization data objects for equality.
|
||||
@@ -112,10 +113,11 @@ func Method(auth string) authorization.Method {
|
||||
|
||||
// NewHttpClientWithSecret returns a http client that utilizes
|
||||
// jwt authentication.
|
||||
func NewHttpClientWithSecret(secret string) *http.Client {
|
||||
func NewHttpClientWithSecret(secret, id string) *http.Client {
|
||||
authTransport := &jwtTransport{
|
||||
underlyingTransport: http.DefaultTransport,
|
||||
jwtSecret: []byte(secret),
|
||||
jwtId: id,
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: DefaultRPCHTTPTimeout,
|
||||
|
||||
Reference in New Issue
Block a user