diff --git a/cmd/booster-http/gateway_handler.go b/cmd/booster-http/gateway_handler.go index e0a998d3f..18a2a8306 100644 --- a/cmd/booster-http/gateway_handler.go +++ b/cmd/booster-http/gateway_handler.go @@ -69,7 +69,13 @@ func (h *gatewayHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // TODO: Allow this to be configurable expectedPaymentAmount := big.NewInt(10) - if hasPaid := checkPaymentChannelBalance(h.nitroRpcClient, chId, expectedPaymentAmount); !hasPaid { + hasPaid, err := checkPaymentChannelBalance(h.nitroRpcClient, chId, expectedPaymentAmount) + if err != nil { + webError(w, err, http.StatusPaymentRequired) + return + } + + if !hasPaid { webError(w, fmt.Errorf("payment of %d required", expectedPaymentAmount.Uint64()), http.StatusPaymentRequired) return } diff --git a/cmd/booster-http/server.go b/cmd/booster-http/server.go index 207daaf4e..429af4448 100644 --- a/cmd/booster-http/server.go +++ b/cmd/booster-http/server.go @@ -203,8 +203,14 @@ func (s *HttpServer) handleByPieceCid(w http.ResponseWriter, r *http.Request) { // TODO: Allow this to be configurable expectedPaymentAmount := big.NewInt(10) - if hasPaid := checkPaymentChannelBalance(s.nitroRpcClient, chId, expectedPaymentAmount); !hasPaid { - writeError(w, r, http.StatusPaymentRequired, "payment required") + hasPaid, err := checkPaymentChannelBalance(s.nitroRpcClient, chId, expectedPaymentAmount) + if err != nil { + writeError(w, r, http.StatusInternalServerError, err.Error()) + return + } + + if !hasPaid { + writeError(w, r, http.StatusPaymentRequired, fmt.Sprintf("payment of %d required", expectedPaymentAmount.Uint64())) return } } diff --git a/cmd/booster-http/util.go b/cmd/booster-http/util.go index 4041530bd..7560a26d2 100644 --- a/cmd/booster-http/util.go +++ b/cmd/booster-http/util.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "errors" "math/big" "net/http" @@ -18,13 +19,12 @@ func addCommas(count uint64) string { } // checkPaymentChannelBalance checks a payment channel balance and returns true if the AmountPaid is greater than the expected amount -func checkPaymentChannelBalance(rpcClient *rpc.RpcClient, paymentChannelId types.Destination, expectedAmount *big.Int) bool { +func checkPaymentChannelBalance(rpcClient *rpc.RpcClient, paymentChannelId types.Destination, expectedAmount *big.Int) (bool, error) { if rpcClient == nil { - panic("the rpcClient is nil") + return false, errors.New("the rpcClient is nil") } payCh := rpcClient.GetVirtualChannel(paymentChannelId) - return payCh.Balance.PaidSoFar.ToInt().Cmp(expectedAmount) > 0 - + return (payCh.Balance.PaidSoFar.ToInt().Cmp(expectedAmount) > 0), nil } type corsHandler struct {