Add --jwt-id flag (#13218)

* add jwt-id flag

* optimize unit test for jwt-id

* Add jwt-id to help text

* gofmt

---------

Co-authored-by: Preston Van Loon <pvanloon@offchainlabs.com>
This commit is contained in:
Brandon Liu
2023-12-06 03:02:25 +08:00
committed by GitHub
parent 705e98e3c3
commit c78d698d89
11 changed files with 117 additions and 6 deletions

View File

@@ -24,6 +24,7 @@ const DefaultRPCHTTPTimeout = time.Second * 30
type jwtTransport struct {
underlyingTransport http.RoundTripper
jwtSecret []byte
jwtId string
}
// RoundTrip ensures our transport implements http.RoundTripper interface from the
@@ -32,12 +33,16 @@ type jwtTransport struct {
// an JWT bearer token in the Authorization request header of every outgoing request
// our HTTP client makes.
func (t *jwtTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
claims := jwt.MapClaims{
// Required claim for engine API auth. "iat" stands for issued at
// and it must be a unix timestamp that is +/- 5 seconds from the current
// timestamp at the moment the server verifies this value.
"iat": time.Now().Unix(),
})
}
if len(t.jwtId) > 0 {
claims["id"] = t.jwtId
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(t.jwtSecret)
if err != nil {
return nil, errors.Wrap(err, "could not produce signed JWT token")

View File

@@ -51,3 +51,91 @@ func TestJWTAuthTransport(t *testing.T) {
_, err := client.Get(srv.URL)
require.NoError(t, err)
}
func TestJWTWithId(t *testing.T) {
secret := bytesutil.PadTo([]byte("foo"), 32)
jwtId := "abc"
authTransport := &jwtTransport{
underlyingTransport: http.DefaultTransport,
jwtSecret: secret,
jwtId: jwtId,
}
client := &http.Client{
Timeout: DefaultRPCHTTPTimeout,
Transport: authTransport,
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqToken := r.Header.Get("Authorization")
splitToken := strings.Split(reqToken, "Bearer")
// The format should be `Bearer ${token}`.
require.Equal(t, 2, len(splitToken))
reqToken = strings.TrimSpace(splitToken[1])
token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) {
// We should be doing HMAC signing.
_, ok := token.Method.(*jwt.SigningMethodHMAC)
require.Equal(t, true, ok)
return secret, nil
})
require.NoError(t, err)
require.Equal(t, true, token.Valid)
claims, ok := token.Claims.(jwt.MapClaims)
require.Equal(t, true, ok)
item, ok := claims["iat"]
require.Equal(t, true, ok)
iat, ok := item.(float64)
require.Equal(t, true, ok)
issuedAt := time.Unix(int64(iat), 0)
// The claims should have an "iat" field (issued at) that is at most, 5 seconds ago.
since := time.Since(issuedAt)
require.Equal(t, true, since <= time.Second*5)
// check jwt claims id
id, ok := claims["id"]
require.Equal(t, true, ok)
require.Equal(t, id, jwtId)
}))
defer srv.Close()
_, err := client.Get(srv.URL)
require.NoError(t, err)
}
func TestJWTWithoutId(t *testing.T) {
secret := bytesutil.PadTo([]byte("foo"), 32)
authTransport := &jwtTransport{
underlyingTransport: http.DefaultTransport,
jwtSecret: secret,
}
client := &http.Client{
Timeout: DefaultRPCHTTPTimeout,
Transport: authTransport,
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqToken := r.Header.Get("Authorization")
splitToken := strings.Split(reqToken, "Bearer")
// The format should be `Bearer ${token}`.
require.Equal(t, 2, len(splitToken))
reqToken = strings.TrimSpace(splitToken[1])
token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) {
// We should be doing HMAC signing.
_, ok := token.Method.(*jwt.SigningMethodHMAC)
require.Equal(t, true, ok)
return secret, nil
})
require.NoError(t, err)
require.Equal(t, true, token.Valid)
claims, ok := token.Claims.(jwt.MapClaims)
require.Equal(t, true, ok)
item, ok := claims["iat"]
require.Equal(t, true, ok)
iat, ok := item.(float64)
require.Equal(t, true, ok)
issuedAt := time.Unix(int64(iat), 0)
// The claims should have an "iat" field (issued at) that is at most, 5 seconds ago.
since := time.Since(issuedAt)
require.Equal(t, true, since <= time.Second*5)
_, ok = claims["id"]
require.Equal(t, false, ok)
}))
defer srv.Close()
_, err := client.Get(srv.URL)
require.NoError(t, err)
}

View File

@@ -24,6 +24,7 @@ type Endpoint struct {
type AuthorizationData struct {
Method authorization.Method
Value string
JwtId string
}
// Equals compares two endpoints for equality.
@@ -37,7 +38,7 @@ func (e Endpoint) HttpClient() *http.Client {
if e.Auth.Method != authorization.Bearer {
return http.DefaultClient
}
return NewHttpClientWithSecret(e.Auth.Value)
return NewHttpClientWithSecret(e.Auth.Value, e.Auth.JwtId)
}
// Equals compares two authorization data objects for equality.
@@ -112,10 +113,11 @@ func Method(auth string) authorization.Method {
// NewHttpClientWithSecret returns a http client that utilizes
// jwt authentication.
func NewHttpClientWithSecret(secret string) *http.Client {
func NewHttpClientWithSecret(secret, id string) *http.Client {
authTransport := &jwtTransport{
underlyingTransport: http.DefaultTransport,
jwtSecret: []byte(secret),
jwtId: id,
}
return &http.Client{
Timeout: DefaultRPCHTTPTimeout,