added JSON marshaling support to the Status type (#323)

This commit is contained in:
Dmitry Holodov
2023-02-22 08:25:28 -06:00
committed by GitHub
parent 3d92587e21
commit 8ca306af68
10 changed files with 117 additions and 44 deletions

View File

@@ -811,7 +811,7 @@ func runGetStage(ctx *cli.Context) error {
return err
}
fmt.Printf("Stage=%s: %s\n", resp.Stage, resp.Info)
fmt.Printf("Stage=%s: %s\n", resp.Stage, resp.Description)
return nil
}

View File

@@ -29,7 +29,7 @@ type SubscribeSwapStatusRequest struct {
// SubscribeSwapStatusResponse ...
type SubscribeSwapStatusResponse struct {
Status string `json:"status"`
Status types.Status `json:"status" validate:"required"`
}
// DiscoverRequest ...

View File

@@ -1,16 +1,23 @@
// Package types is for types that are shared by multiple packages
package types
import (
"fmt"
)
// Status represents the stage that a swap is at.
type Status byte
// Status values
const (
// UnknownStatus is a placeholder for unmatched status strings and
// uninitialized variables
UnknownStatus Status = iota
// ExpectingKeys is the status of the taker between taking an offer and
// receiving a response with swap keys from the maker. It is also the
// maker's status after creating an offer up until receiving keys from a
// taker accepting the offer.
ExpectingKeys Status = iota
ExpectingKeys
// KeysExchanged is the status of the maker after a taker accepts his offer.
KeysExchanged
// ETHLocked is the taker status after locking her ETH up until confirming
@@ -29,8 +36,6 @@ const (
// CompletedAbort represents the case where the swap aborts before any funds
// are locked.
CompletedAbort
// UnknownStatus is a placeholder for unmatched status strings.
UnknownStatus
)
const unknownString string = "unknown"
@@ -60,7 +65,7 @@ func NewStatus(str string) Status {
}
}
// String ...
// String returns the status as a text string.
func (s Status) String() string {
switch s {
case ExpectingKeys:
@@ -84,8 +89,27 @@ func (s Status) String() string {
}
}
// Info returns a description of the swap stage.
func (s Status) Info() string {
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (s *Status) UnmarshalText(data []byte) error {
newStatus := NewStatus(string(data))
if newStatus == UnknownStatus {
return fmt.Errorf("unknown status %q", string(data))
}
*s = newStatus
return nil
}
// MarshalText implements the encoding.TextMarshaler interface.
func (s Status) MarshalText() ([]byte, error) {
textStr := s.String()
if textStr == unknownString {
return nil, fmt.Errorf("unknown status %d", s)
}
return []byte(textStr), nil
}
// Description returns a description of the swap stage.
func (s Status) Description() string {
switch s {
case ExpectingKeys:
return "keys have not yet been exchanged"

View File

@@ -0,0 +1,45 @@
package types
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestMarshalStatus(t *testing.T) {
type S struct {
Status Status `json:"status"`
}
const jsonText = `{
"status": "XMRLocked"
}`
s := new(S)
err := json.Unmarshal([]byte(jsonText), s)
require.NoError(t, err)
require.Equal(t, XMRLocked, s.Status)
jsonData, err := json.Marshal(s)
require.NoError(t, err)
require.JSONEq(t, jsonText, string(jsonData))
}
func TestUnmarshalStatus_fail(t *testing.T) {
type S struct {
Status Status `json:"status"`
}
const jsonText = `{
"status": "Garbage"
}`
s := new(S)
err := json.Unmarshal([]byte(jsonText), s)
require.ErrorContains(t, err, `unknown status "Garbage"`)
s.Status = 255 // not a valid value
_, err = json.Marshal(s)
require.ErrorContains(t, err, `unknown status 255`)
}

View File

@@ -28,6 +28,7 @@ func TestDatabase_OfferTable(t *testing.T) {
infoA := &swap.Info{
ID: types.Hash{0x1},
Provides: coins.ProvidesXMR,
Status: types.ExpectingKeys,
}
err = db.PutSwap(infoA)
require.NoError(t, err)
@@ -82,6 +83,7 @@ func TestDatabase_GetAllOffers_InvalidEntry(t *testing.T) {
swapEntry := &swap.Info{
ID: badOfferID,
Provides: coins.ProvidesXMR,
Status: types.KeysExchanged,
}
err = db.PutSwap(swapEntry)
require.NoError(t, err)
@@ -139,6 +141,7 @@ func TestDatabase_SwapTable(t *testing.T) {
ID: types.Hash{0x1},
Version: swap.CurInfoVersion,
Provides: coins.ProvidesXMR,
Status: types.ContractReady,
}
err = db.PutSwap(infoA)
require.NoError(t, err)
@@ -147,6 +150,7 @@ func TestDatabase_SwapTable(t *testing.T) {
ID: types.Hash{0x2},
Version: swap.CurInfoVersion,
Provides: coins.ProvidesXMR,
Status: types.XMRLocked,
}
err = db.PutSwap(infoB)
require.NoError(t, err)
@@ -175,7 +179,7 @@ func TestDatabase_GetAllSwaps_InvalidEntry(t *testing.T) {
ExpectedAmount: coins.StrToDecimal("0.15"),
ExchangeRate: coins.ToExchangeRate(coins.StrToDecimal("0.1")),
EthAsset: types.EthAsset{},
Status: 0,
Status: types.ETHLocked,
MoneroStartHeight: 0,
}
err = db.PutSwap(goodInfo)
@@ -224,14 +228,15 @@ func TestDatabase_SwapTable_Update(t *testing.T) {
infoA := &swap.Info{
ID: id,
Provides: coins.ProvidesXMR,
Status: types.XMRLocked,
}
err = db.PutSwap(infoA)
require.NoError(t, err)
infoB := &swap.Info{
ID: id,
Status: types.CompletedSuccess,
Provides: coins.ProvidesXMR,
Status: types.CompletedSuccess,
}
err = db.PutSwap(infoB)

View File

@@ -38,7 +38,7 @@ func Test_InfoMarshal(t *testing.T) {
"exchangeRate": "0.33",
"ethAsset": "ETH",
"moneroStartHeight": 200,
"status": 5
"status": "Success"
}`
require.JSONEq(t, expectedJSON, string(infoBytes))
}

View File

@@ -187,7 +187,7 @@ func (s *NetService) takeOffer(who peer.ID, offerID types.Hash, providesAmount *
// TakeOfferSyncResponse ...
type TakeOfferSyncResponse struct {
Status string `json:"status"`
Status types.Status `json:"status" validate:"required"`
}
// TakeOfferSync initiates a swap with the given peer by taking an offer they've made.
@@ -216,7 +216,7 @@ func (s *NetService) TakeOfferSync(
continue
}
resp.Status = info.Status.String()
resp.Status = info.Status
break
}

View File

@@ -74,7 +74,7 @@ type GetPastResponse struct {
ProvidedAmount *apd.Decimal `json:"providedAmount"`
ExpectedAmount *apd.Decimal `json:"expectedAmount"`
ExchangeRate *coins.ExchangeRate `json:"exchangeRate"`
Status string `json:"status"`
Status types.Status `json:"status" validate:"required"`
}
// GetPast returns information about a past swap, given its ID.
@@ -93,7 +93,7 @@ func (s *SwapService) GetPast(_ *http.Request, req *GetPastRequest, resp *GetPas
resp.ProvidedAmount = info.ProvidedAmount
resp.ExpectedAmount = info.ExpectedAmount
resp.ExchangeRate = info.ExchangeRate
resp.Status = info.Status.String()
resp.Status = info.Status
return nil
}
@@ -103,7 +103,7 @@ type GetOngoingResponse struct {
ProvidedAmount *apd.Decimal `json:"providedAmount"`
ExpectedAmount *apd.Decimal `json:"expectedAmount"`
ExchangeRate *coins.ExchangeRate `json:"exchangeRate"`
Status string `json:"status"`
Status types.Status `json:"status" validate:"required"`
}
// GetOngoingRequest ...
@@ -127,7 +127,7 @@ func (s *SwapService) GetOngoing(_ *http.Request, req *GetOngoingRequest, resp *
resp.ProvidedAmount = info.ProvidedAmount
resp.ExpectedAmount = info.ExpectedAmount
resp.ExchangeRate = info.ExchangeRate
resp.Status = info.Status.String()
resp.Status = info.Status
return nil
}
@@ -174,8 +174,8 @@ type GetStageRequest struct {
// GetStageResponse ...
type GetStageResponse struct {
Stage string `json:"stage"`
Info string `json:"info"`
Stage types.Status `json:"stage" validate:"required"`
Description string `json:"description" validate:"required"`
}
// GetStage returns the stage of the ongoing swap, if there is one.
@@ -190,8 +190,8 @@ func (s *SwapService) GetStage(_ *http.Request, req *GetStageRequest, resp *GetS
return err
}
resp.Stage = info.Status.String()
resp.Info = info.Status.Info()
resp.Stage = info.Status
resp.Description = info.Status.Description()
return nil
}

View File

@@ -227,7 +227,7 @@ func (s *wsServer) subscribeTakeOffer(ctx context.Context, conn *websocket.Conn,
}
resp := &rpctypes.SubscribeSwapStatusResponse{
Status: status.String(),
Status: status,
}
if err := writeResponse(conn, resp); err != nil {
@@ -262,7 +262,7 @@ func (s *wsServer) subscribeMakeOffer(ctx context.Context, conn *websocket.Conn,
}
resp := &rpctypes.SubscribeSwapStatusResponse{
Status: status.String(),
Status: status,
}
if err := writeResponse(conn, resp); err != nil {
@@ -298,7 +298,7 @@ func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn
}
resp := &rpctypes.SubscribeSwapStatusResponse{
Status: status.String(),
Status: status,
}
if err := writeResponse(conn, resp); err != nil {
@@ -321,7 +321,7 @@ func (s *wsServer) writeSwapExitStatus(conn *websocket.Conn, id types.Hash) erro
}
resp := &rpctypes.SubscribeSwapStatusResponse{
Status: info.Status.String(),
Status: info.Status,
}
if err := writeResponse(conn, resp); err != nil {

View File

@@ -224,15 +224,15 @@ func (c *wsClient) SubscribeSwapStatus(id types.Hash) (<-chan types.Status, erro
}
log.Debugf("received message over websockets: %s", message)
status := new(rpctypes.SubscribeSwapStatusResponse)
if err := vjson.UnmarshalStruct(resp.Result, status); err != nil {
statusResp := new(rpctypes.SubscribeSwapStatusResponse)
if err := vjson.UnmarshalStruct(resp.Result, statusResp); err != nil {
log.Warnf("failed to unmarshal response: %s", err)
break
}
s := types.NewStatus(status.Status)
respCh <- s
if !s.IsOngoing() {
status := statusResp.Status
respCh <- status
if !status.IsOngoing() {
return
}
}
@@ -280,9 +280,8 @@ func (c *wsClient) TakeOfferAndSubscribe(
defer close(respCh)
for {
s := types.NewStatus(status)
respCh <- s
if !s.IsOngoing() {
respCh <- status
if !status.IsOngoing() {
return
}
@@ -297,29 +296,29 @@ func (c *wsClient) TakeOfferAndSubscribe(
return respCh, nil
}
func (c *wsClient) readTakeOfferResponse() (string, error) {
func (c *wsClient) readTakeOfferResponse() (types.Status, error) {
message, err := c.read()
if err != nil {
return "", fmt.Errorf("failed to read websockets message: %s", err)
return 0, fmt.Errorf("failed to read websockets message: %s", err)
}
resp := new(rpctypes.Response)
err = vjson.UnmarshalStruct(message, resp)
if err != nil {
return "", fmt.Errorf("failed to unmarshal response: %w", err)
return 0, fmt.Errorf("failed to unmarshal response: %w", err)
}
if resp.Error != nil {
return "", fmt.Errorf("websocket server returned error: %w", resp.Error)
return 0, fmt.Errorf("websocket server returned error: %w", resp.Error)
}
log.Debugf("received message over websockets: %s", message)
status := new(rpctypes.SubscribeSwapStatusResponse)
if err := vjson.UnmarshalStruct(resp.Result, status); err != nil {
return "", fmt.Errorf("failed to unmarshal swap status response: %w", err)
statusResp := new(rpctypes.SubscribeSwapStatusResponse)
if err := vjson.UnmarshalStruct(resp.Result, statusResp); err != nil {
return 0, fmt.Errorf("failed to unmarshal swap status response: %w", err)
}
return status.Status, nil
return statusResp.Status, nil
}
func (c *wsClient) MakeOfferAndSubscribe(
@@ -402,13 +401,13 @@ func (c *wsClient) MakeOfferAndSubscribe(
}
log.Debugf("received message over websockets: %s", message)
status := new(rpctypes.SubscribeSwapStatusResponse)
if err := vjson.UnmarshalStruct(resp.Result, status); err != nil {
statusResp := new(rpctypes.SubscribeSwapStatusResponse)
if err := vjson.UnmarshalStruct(resp.Result, statusResp); err != nil {
log.Warnf("failed to unmarshal response: %s", err)
break
}
s := types.NewStatus(status.Status)
s := statusResp.Status
respCh <- s
if !s.IsOngoing() {
return