diff --git a/socket.go b/socket.go index 385cdb1..97ae24b 100644 --- a/socket.go +++ b/socket.go @@ -46,6 +46,7 @@ var ( ErrDFUInvalidFile = errors.New("provided file is invalid for given upgrade type") ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type") ErrDFUInvalidUpgType = errors.New("invalid upgrade type") + ErrRPCXUsingGateway = errors.New("bidirectional requests are unsupported over gateway") ) type DoneMap map[string]chan struct{} @@ -116,7 +117,7 @@ func startSocket(dev *infinitime.Device) error { return err } - go srv.ServeListener("unix", ln) + go srv.ServeListener("tcp", ln) // Log socket start log.Info().Str("path", k.String("socket.path")).Msg("Started control socket") @@ -136,7 +137,12 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error { } func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { - clientConn := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client's connection + clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // If user is using gateway, the client connection will not be available + if !ok { + return ErrRPCXUsingGateway + } heartRateCh, cancel, err := i.dev.WatchHeartRate() if err != nil { @@ -178,7 +184,12 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error { } func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error { - clientConn := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client's connection + clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // If user is using gateway, the client connection will not be available + if !ok { + return ErrRPCXUsingGateway + } battLevelCh, cancel, err := i.dev.WatchBatteryLevel() if err != nil { @@ -220,7 +231,12 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er } func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { - clientConn := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client's connection + clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // If user is using gateway, the client connection will not be available + if !ok { + return ErrRPCXUsingGateway + } motionValsCh, cancel, err := i.dev.WatchMotion() if err != nil { @@ -262,7 +278,12 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error { } func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { - clientConn := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client's connection + clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // If user is using gateway, the client connection will not be available + if !ok { + return ErrRPCXUsingGateway + } stepCountCh, cancel, err := i.dev.WatchStepCount() if err != nil { @@ -365,23 +386,26 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou id := uuid.New().String() *out = id - clientConn := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // Get client's connection + clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // If user is using gateway, the client connection will not be available + if ok { + go func() { + // For every progress event + for event := range i.dev.DFU.Progress() { + data, err := msgpack.Marshal(event) + if err != nil { + log.Error().Err(err).Msg("Error encoding DFU progress event") + continue + } - go func() { - // For every progress event - for event := range i.dev.DFU.Progress() { - data, err := msgpack.Marshal(event) - if err != nil { - log.Error().Err(err).Msg("Error encoding DFU progress event") - continue + i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data) } - i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data) - } - - firmwareUpdating = false - i.srv.SendMessage(clientConn, id, "Done", nil, nil) - }() + firmwareUpdating = false + i.srv.SendMessage(clientConn, id, "Done", nil, nil) + }() + } // Set firmwareUpdating firmwareUpdating = true @@ -463,7 +487,6 @@ func (fs *FS) ReadDir(_ context.Context, dir string, out *[]api.FileInfo) error func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { fs.updateFS() - clientConn := ctx.Value(server.RemoteConnContextKey).(net.Conn) localFile, err := os.Open(paths[1]) if err != nil { @@ -483,23 +506,28 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { id := uuid.New().String() *out = id - go func() { - // For every progress event - for sent := range remoteFile.Progress() { - data, err := msgpack.Marshal(api.FSTransferProgress{ - Total: remoteFile.Size(), - Sent: sent, - }) - if err != nil { - log.Error().Err(err).Msg("Error encoding filesystem transfer progress event") - continue + // Get client's connection + clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // If user is using gateway, the client connection will not be available + if ok { + go func() { + // For every progress event + for sent := range remoteFile.Progress() { + data, err := msgpack.Marshal(api.FSTransferProgress{ + Total: remoteFile.Size(), + Sent: sent, + }) + if err != nil { + log.Error().Err(err).Msg("Error encoding filesystem transfer progress event") + continue + } + + fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) } - fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) - } - - fs.srv.SendMessage(clientConn, id, "Done", nil, nil) - }() + fs.srv.SendMessage(clientConn, id, "Done", nil, nil) + }() + } go func() { io.Copy(remoteFile, localFile) @@ -512,40 +540,44 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error { fs.updateFS() - clientConn := ctx.Value(server.RemoteConnContextKey).(net.Conn) - + localFile, err := os.Create(paths[0]) if err != nil { return err } - + remoteFile, err := fs.fs.Open(paths[1]) if err != nil { return err } - + id := uuid.New().String() *out = id - - go func() { - // For every progress event - for rcvd := range remoteFile.Progress() { - data, err := msgpack.Marshal(api.FSTransferProgress{ - Total: remoteFile.Size(), - Sent: rcvd, - }) - if err != nil { - log.Error().Err(err).Msg("Error encoding filesystem transfer progress event") - continue + + // Get client's connection + clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) + // If user is using gateway, the client connection will not be available + if ok { + go func() { + // For every progress event + for rcvd := range remoteFile.Progress() { + data, err := msgpack.Marshal(api.FSTransferProgress{ + Total: remoteFile.Size(), + Sent: rcvd, + }) + if err != nil { + log.Error().Err(err).Msg("Error encoding filesystem transfer progress event") + continue + } + + fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) } - - fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) - } - - fs.srv.SendMessage(clientConn, id, "Done", nil, nil) - localFile.Close() - remoteFile.Close() - }() + + fs.srv.SendMessage(clientConn, id, "Done", nil, nil) + localFile.Close() + remoteFile.Close() + }() + } go io.Copy(localFile, remoteFile)