diff --git a/socket.go b/socket.go index 97ae24b..261fbc2 100644 --- a/socket.go +++ b/socket.go @@ -19,10 +19,13 @@ package main import ( + "bytes" "context" "errors" "io" "net" + "net/http" + "net/url" "os" "path/filepath" "strings" @@ -31,6 +34,7 @@ import ( "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/smallnest/rpcx/server" + "github.com/smallnest/rpcx/share" "github.com/vmihailenco/msgpack/v5" "go.arsenm.dev/infinitime" "go.arsenm.dev/infinitime/blefs" @@ -46,7 +50,7 @@ var ( ErrDFUInvalidFile = errors.New("provided file is invalid for given upgrade type") ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type") ErrDFUInvalidUpgType = errors.New("invalid upgrade type") - ErrRPCXUsingGateway = errors.New("bidirectional requests are unsupported over gateway") + ErrRPCXNoReturnURL = errors.New("bidirectional requests over gateway require a returnURL field in the metadata") ) type DoneMap map[string]chan struct{} @@ -137,11 +141,11 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error { } func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { - // Get client's connection - clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client message sender + msgSender, ok := getMsgSender(ctx, i.srv) // If user is using gateway, the client connection will not be available if !ok { - return ErrRPCXUsingGateway + return ErrRPCXNoReturnURL } heartRateCh, cancel, err := i.dev.WatchHeartRate() @@ -168,7 +172,7 @@ func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { } // Send response to connection if no done signal received - i.srv.SendMessage(clientConn, id, "HeartRateSample", nil, data) + msgSender.SendMessage(id, "HeartRateSample", nil, data) } } }() @@ -184,11 +188,11 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error { } func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error { - // Get client's connection - clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client message sender + msgSender, ok := getMsgSender(ctx, i.srv) // If user is using gateway, the client connection will not be available if !ok { - return ErrRPCXUsingGateway + return ErrRPCXNoReturnURL } battLevelCh, cancel, err := i.dev.WatchBatteryLevel() @@ -215,7 +219,7 @@ func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error } // Send response to connection if no done signal received - i.srv.SendMessage(clientConn, id, "BatteryLevelSample", nil, data) + msgSender.SendMessage(id, "BatteryLevelSample", nil, data) } } }() @@ -231,11 +235,11 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er } func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { - // Get client's connection - clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client message sender + msgSender, ok := getMsgSender(ctx, i.srv) // If user is using gateway, the client connection will not be available if !ok { - return ErrRPCXUsingGateway + return ErrRPCXNoReturnURL } motionValsCh, cancel, err := i.dev.WatchMotion() @@ -262,7 +266,7 @@ func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { } // Send response to connection if no done signal received - i.srv.SendMessage(clientConn, id, "MotionSample", nil, data) + msgSender.SendMessage(id, "MotionSample", nil, data) } } }() @@ -278,11 +282,11 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error { } func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { - // Get client's connection - clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client message sender + msgSender, ok := getMsgSender(ctx, i.srv) // If user is using gateway, the client connection will not be available if !ok { - return ErrRPCXUsingGateway + return ErrRPCXNoReturnURL } stepCountCh, cancel, err := i.dev.WatchStepCount() @@ -309,7 +313,7 @@ func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { } // Send response to connection if no done signal received - i.srv.SendMessage(clientConn, id, "StepCountSample", nil, data) + msgSender.SendMessage(id, "StepCountSample", nil, data) } } }() @@ -386,8 +390,8 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou id := uuid.New().String() *out = id - // Get client's connection - clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client message sender + msgSender, ok := getMsgSender(ctx, i.srv) // If user is using gateway, the client connection will not be available if ok { go func() { @@ -399,11 +403,11 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou continue } - i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data) + msgSender.SendMessage(id, "DFUProgress", nil, data) } firmwareUpdating = false - i.srv.SendMessage(clientConn, id, "Done", nil, nil) + msgSender.SendMessage(id, "Done", nil, nil) }() } @@ -506,8 +510,8 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { id := uuid.New().String() *out = id - // Get client's connection - clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client message sender + msgSender, ok := getMsgSender(ctx, fs.srv) // If user is using gateway, the client connection will not be available if ok { go func() { @@ -522,10 +526,10 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { continue } - fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) + msgSender.SendMessage(id, "FSProgress", nil, data) } - fs.srv.SendMessage(clientConn, id, "Done", nil, nil) + msgSender.SendMessage(id, "Done", nil, nil) }() } @@ -540,22 +544,22 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error { fs.updateFS() - + localFile, err := os.Create(paths[0]) if err != nil { return err } - + remoteFile, err := fs.fs.Open(paths[1]) if err != nil { return err } - + id := uuid.New().String() *out = id - - // Get client's connection - clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + + // Get client message sender + msgSender, ok := getMsgSender(ctx, fs.srv) // If user is using gateway, the client connection will not be available if ok { go func() { @@ -569,11 +573,11 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error log.Error().Err(err).Msg("Error encoding filesystem transfer progress event") continue } - - fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) + + msgSender.SendMessage(id, "FSProgress", nil, data) } - - fs.srv.SendMessage(clientConn, id, "Done", nil, nil) + + msgSender.SendMessage(id, "Done", nil, nil) localFile.Close() remoteFile.Close() }() @@ -608,3 +612,66 @@ func cleanPaths(paths []string) []string { } return paths } + +func getMsgSender(ctx context.Context, srv *server.Server) (MessageSender, bool) { + // Get client message sender + clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // If the connection exists, use rpcMsgSender + if ok { + return &rpcMsgSender{srv, clientConn}, true + } else { + // Get metadata if it exists + metadata, ok := ctx.Value(share.ReqMetaDataKey).(map[string]string) + if !ok { + return nil, false + } + // Get returnURL field from metadata if it exists + returnURL, ok := metadata["returnURL"] + if !ok { + return nil, false + } + // Use httpMsgSender + return &httpMsgSender{returnURL}, true + } + +} + +type MessageSender interface { + SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error +} + +type rpcMsgSender struct { + srv *server.Server + conn net.Conn +} + +func (r *rpcMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error { + + return r.srv.SendMessage(r.conn, servicePath, serviceMethod, metadata, data) +} + +type httpMsgSender struct { + url string +} + +func (h *httpMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error { + req, err := http.NewRequest(http.MethodPost, h.url, bytes.NewReader(data)) + if err != nil { + return err + } + + req.Header.Set("X-RPCX-ServicePath", servicePath) + req.Header.Set("X-RPCX-ServiceMethod", serviceMethod) + + query := url.Values{} + for k, v := range metadata { + query.Set(k, v) + } + req.Header.Set("X-RPCX-Meta", query.Encode()) + + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + return res.Body.Close() +}