Support bidirectional requests over gateway

This commit is contained in:
Elara 2022-04-24 00:54:04 -07:00
parent 9034ef7c6b
commit 4b6f7d408e
1 changed files with 102 additions and 35 deletions

137
socket.go
View File

@ -19,10 +19,13 @@
package main
import (
"bytes"
"context"
"errors"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
@ -31,6 +34,7 @@ import (
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/smallnest/rpcx/server"
"github.com/smallnest/rpcx/share"
"github.com/vmihailenco/msgpack/v5"
"go.arsenm.dev/infinitime"
"go.arsenm.dev/infinitime/blefs"
@ -46,7 +50,7 @@ var (
ErrDFUInvalidFile = errors.New("provided file is invalid for given upgrade type")
ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type")
ErrDFUInvalidUpgType = errors.New("invalid upgrade type")
ErrRPCXUsingGateway = errors.New("bidirectional requests are unsupported over gateway")
ErrRPCXNoReturnURL = errors.New("bidirectional requests over gateway require a returnURL field in the metadata")
)
type DoneMap map[string]chan struct{}
@ -137,11 +141,11 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error {
}
func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
// Get client's connection
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// Get client message sender
msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available
if !ok {
return ErrRPCXUsingGateway
return ErrRPCXNoReturnURL
}
heartRateCh, cancel, err := i.dev.WatchHeartRate()
@ -168,7 +172,7 @@ func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
}
// Send response to connection if no done signal received
i.srv.SendMessage(clientConn, id, "HeartRateSample", nil, data)
msgSender.SendMessage(id, "HeartRateSample", nil, data)
}
}
}()
@ -184,11 +188,11 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error {
}
func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error {
// Get client's connection
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// Get client message sender
msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available
if !ok {
return ErrRPCXUsingGateway
return ErrRPCXNoReturnURL
}
battLevelCh, cancel, err := i.dev.WatchBatteryLevel()
@ -215,7 +219,7 @@ func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error
}
// Send response to connection if no done signal received
i.srv.SendMessage(clientConn, id, "BatteryLevelSample", nil, data)
msgSender.SendMessage(id, "BatteryLevelSample", nil, data)
}
}
}()
@ -231,11 +235,11 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er
}
func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
// Get client's connection
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// Get client message sender
msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available
if !ok {
return ErrRPCXUsingGateway
return ErrRPCXNoReturnURL
}
motionValsCh, cancel, err := i.dev.WatchMotion()
@ -262,7 +266,7 @@ func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
}
// Send response to connection if no done signal received
i.srv.SendMessage(clientConn, id, "MotionSample", nil, data)
msgSender.SendMessage(id, "MotionSample", nil, data)
}
}
}()
@ -278,11 +282,11 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error {
}
func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
// Get client's connection
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// Get client message sender
msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available
if !ok {
return ErrRPCXUsingGateway
return ErrRPCXNoReturnURL
}
stepCountCh, cancel, err := i.dev.WatchStepCount()
@ -309,7 +313,7 @@ func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
}
// Send response to connection if no done signal received
i.srv.SendMessage(clientConn, id, "StepCountSample", nil, data)
msgSender.SendMessage(id, "StepCountSample", nil, data)
}
}
}()
@ -386,8 +390,8 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
id := uuid.New().String()
*out = id
// Get client's connection
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// Get client message sender
msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available
if ok {
go func() {
@ -399,11 +403,11 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
continue
}
i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data)
msgSender.SendMessage(id, "DFUProgress", nil, data)
}
firmwareUpdating = false
i.srv.SendMessage(clientConn, id, "Done", nil, nil)
msgSender.SendMessage(id, "Done", nil, nil)
}()
}
@ -506,8 +510,8 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
id := uuid.New().String()
*out = id
// Get client's connection
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// Get client message sender
msgSender, ok := getMsgSender(ctx, fs.srv)
// If user is using gateway, the client connection will not be available
if ok {
go func() {
@ -522,10 +526,10 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
continue
}
fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data)
msgSender.SendMessage(id, "FSProgress", nil, data)
}
fs.srv.SendMessage(clientConn, id, "Done", nil, nil)
msgSender.SendMessage(id, "Done", nil, nil)
}()
}
@ -540,22 +544,22 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error {
fs.updateFS()
localFile, err := os.Create(paths[0])
if err != nil {
return err
}
remoteFile, err := fs.fs.Open(paths[1])
if err != nil {
return err
}
id := uuid.New().String()
*out = id
// Get client's connection
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// Get client message sender
msgSender, ok := getMsgSender(ctx, fs.srv)
// If user is using gateway, the client connection will not be available
if ok {
go func() {
@ -569,11 +573,11 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error
log.Error().Err(err).Msg("Error encoding filesystem transfer progress event")
continue
}
fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data)
msgSender.SendMessage(id, "FSProgress", nil, data)
}
fs.srv.SendMessage(clientConn, id, "Done", nil, nil)
msgSender.SendMessage(id, "Done", nil, nil)
localFile.Close()
remoteFile.Close()
}()
@ -608,3 +612,66 @@ func cleanPaths(paths []string) []string {
}
return paths
}
func getMsgSender(ctx context.Context, srv *server.Server) (MessageSender, bool) {
// Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// If the connection exists, use rpcMsgSender
if ok {
return &rpcMsgSender{srv, clientConn}, true
} else {
// Get metadata if it exists
metadata, ok := ctx.Value(share.ReqMetaDataKey).(map[string]string)
if !ok {
return nil, false
}
// Get returnURL field from metadata if it exists
returnURL, ok := metadata["returnURL"]
if !ok {
return nil, false
}
// Use httpMsgSender
return &httpMsgSender{returnURL}, true
}
}
type MessageSender interface {
SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error
}
type rpcMsgSender struct {
srv *server.Server
conn net.Conn
}
func (r *rpcMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
return r.srv.SendMessage(r.conn, servicePath, serviceMethod, metadata, data)
}
type httpMsgSender struct {
url string
}
func (h *httpMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
req, err := http.NewRequest(http.MethodPost, h.url, bytes.NewReader(data))
if err != nil {
return err
}
req.Header.Set("X-RPCX-ServicePath", servicePath)
req.Header.Set("X-RPCX-ServiceMethod", serviceMethod)
query := url.Values{}
for k, v := range metadata {
query.Set(k, v)
}
req.Header.Set("X-RPCX-Meta", query.Encode())
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
return res.Body.Close()
}