diff --git a/api/client.go b/api/client.go index 67aa16b..45846af 100644 --- a/api/client.go +++ b/api/client.go @@ -18,11 +18,11 @@ const DefaultAddr = "/tmp/itd/socket" type Client struct { conn net.Conn respCh chan types.Response - heartRateCh chan uint8 - battLevelCh chan uint8 - stepCountCh chan uint32 - motionCh chan infinitime.MotionValues - dfuProgressCh chan DFUProgress + heartRateCh chan types.Response + battLevelCh chan types.Response + stepCountCh chan types.Response + motionCh chan types.Response + dfuProgressCh chan types.Response } // New creates a new client and sets it up @@ -91,27 +91,43 @@ func (c *Client) requestNoRes(req types.Request) error { func (c *Client) handleResp(res types.Response) error { switch res.Type { case types.ResTypeWatchHeartRate: - c.heartRateCh <- uint8(res.Value.(float64)) + c.heartRateCh <- res case types.ResTypeWatchBattLevel: - c.battLevelCh <- uint8(res.Value.(float64)) + c.battLevelCh <- res case types.ResTypeWatchStepCount: - c.stepCountCh <- uint32(res.Value.(float64)) + c.stepCountCh <- res case types.ResTypeWatchMotion: - out := infinitime.MotionValues{} - err := mapstructure.Decode(res.Value, &out) - if err != nil { - return err - } - c.motionCh <- out + c.motionCh <- res case types.ResTypeDFUProgress: - out := DFUProgress{} - err := mapstructure.Decode(res.Value, &out) - if err != nil { - return err - } - c.dfuProgressCh <- out + c.dfuProgressCh <- res default: c.respCh <- res } return nil } + +func decodeUint8(val interface{}) uint8 { + return uint8(val.(float64)) +} + +func decodeUint32(val interface{}) uint32 { + return uint32(val.(float64)) +} + +func decodeMotion(val interface{}) (infinitime.MotionValues, error) { + out := infinitime.MotionValues{} + err := mapstructure.Decode(val, &out) + if err != nil { + return out, err + } + return out, nil +} + +func decodeDFUProgress(val interface{}) (DFUProgress, error) { + out := DFUProgress{} + err := mapstructure.Decode(val, &out) + if err != nil { + return out, err + } + return out, nil +} diff --git a/api/info.go b/api/info.go index c0b2aef..a422df7 100644 --- a/api/info.go +++ b/api/info.go @@ -1,8 +1,6 @@ package api import ( - "reflect" - "github.com/mitchellh/mapstructure" "go.arsenm.dev/infinitime" "go.arsenm.dev/itd/internal/types" @@ -48,15 +46,27 @@ func (c *Client) BatteryLevel() (uint8, error) { // new battery level values as they update. Do not use after // calling cancellation function func (c *Client) WatchBatteryLevel() (<-chan uint8, func(), error) { - c.battLevelCh = make(chan uint8, 2) + c.battLevelCh = make(chan types.Response, 2) err := c.requestNoRes(types.Request{ Type: types.ReqTypeBattLevel, }) if err != nil { return nil, nil, err } - cancel := c.cancelFn(types.ReqTypeCancelBattLevel, c.battLevelCh) - return c.battLevelCh, cancel, nil + res := <-c.battLevelCh + done, cancel := c.cancelFn(res.ID, c.battLevelCh) + out := make(chan uint8, 2) + go func() { + for res := range c.battLevelCh { + select { + case <-done: + return + default: + out <- decodeUint8(res.Value) + } + } + }() + return out, cancel, nil } // HeartRate gets the heart rate from the connected device @@ -68,33 +78,46 @@ func (c *Client) HeartRate() (uint8, error) { return 0, err } - return uint8(res.Value.(float64)), nil + return decodeUint8(res.Value), nil } // WatchHeartRate returns a channel which will contain // new heart rate values as they update. Do not use after // calling cancellation function func (c *Client) WatchHeartRate() (<-chan uint8, func(), error) { - c.heartRateCh = make(chan uint8, 2) + c.heartRateCh = make(chan types.Response, 2) err := c.requestNoRes(types.Request{ Type: types.ReqTypeWatchHeartRate, }) if err != nil { return nil, nil, err } - cancel := c.cancelFn(types.ReqTypeCancelHeartRate, c.heartRateCh) - return c.heartRateCh, cancel, nil + res := <-c.heartRateCh + done, cancel := c.cancelFn(res.ID, c.heartRateCh) + out := make(chan uint8, 2) + go func() { + for res := range c.heartRateCh { + select { + case <-done: + return + default: + out <- decodeUint8(res.Value) + } + } + }() + return out, cancel, nil } // cancelFn generates a cancellation function for the given // request type and channel -func (c *Client) cancelFn(reqType int, ch interface{}) func() { - return func() { - reflectCh := reflect.ValueOf(ch) - reflectCh.Close() - reflectCh.Set(reflect.Zero(reflectCh.Type())) +func (c *Client) cancelFn(reqID string, ch chan types.Response) (chan struct{}, func()) { + done := make(chan struct{}, 1) + return done, func() { + done <- struct{}{} + close(ch) c.requestNoRes(types.Request{ - Type: reqType, + Type: types.ReqTypeCancel, + Data: reqID, }) } } @@ -115,15 +138,27 @@ func (c *Client) StepCount() (uint32, error) { // new step count values as they update. Do not use after // calling cancellation function func (c *Client) WatchStepCount() (<-chan uint32, func(), error) { - c.stepCountCh = make(chan uint32, 2) + c.stepCountCh = make(chan types.Response, 2) err := c.requestNoRes(types.Request{ Type: types.ReqTypeWatchStepCount, }) if err != nil { return nil, nil, err } - cancel := c.cancelFn(types.ReqTypeCancelStepCount, c.stepCountCh) - return c.stepCountCh, cancel, nil + res := <-c.stepCountCh + done, cancel := c.cancelFn(res.ID, c.stepCountCh) + out := make(chan uint32, 2) + go func() { + for res := range c.stepCountCh { + select { + case <-done: + return + default: + out <- decodeUint32(res.Value) + } + } + }() + return out, cancel, nil } // Motion gets the motion values from the connected device @@ -146,13 +181,29 @@ func (c *Client) Motion() (infinitime.MotionValues, error) { // new motion values as they update. Do not use after // calling cancellation function func (c *Client) WatchMotion() (<-chan infinitime.MotionValues, func(), error) { - c.motionCh = make(chan infinitime.MotionValues, 2) + c.motionCh = make(chan types.Response, 5) err := c.requestNoRes(types.Request{ Type: types.ReqTypeWatchMotion, }) if err != nil { return nil, nil, err } - cancel := c.cancelFn(types.ReqTypeCancelMotion, c.motionCh) - return c.motionCh, cancel, nil + res := <-c.motionCh + done, cancel := c.cancelFn(res.ID, c.motionCh) + out := make(chan infinitime.MotionValues, 5) + go func() { + for res := range c.motionCh { + select { + case <-done: + return + default: + motion, err := decodeMotion(res.Value) + if err != nil { + continue + } + out <- motion + } + } + }() + return out, cancel, nil } diff --git a/api/upgrade.go b/api/upgrade.go index 3311b74..1f1d44e 100644 --- a/api/upgrade.go +++ b/api/upgrade.go @@ -31,7 +31,18 @@ func (c *Client) FirmwareUpgrade(upgType UpgradeType, files ...string) (<-chan D return nil, err } - c.dfuProgressCh = make(chan DFUProgress, 5) + c.dfuProgressCh = make(chan types.Response, 5) - return c.dfuProgressCh, nil + out := make(chan DFUProgress, 5) + go func() { + for res := range c.dfuProgressCh { + progress, err := decodeDFUProgress(res.Value) + if err != nil { + continue + } + out <- progress + } + }() + + return out, nil } diff --git a/cmd/test/main.go b/cmd/test/main.go new file mode 100644 index 0000000..f387c25 --- /dev/null +++ b/cmd/test/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "fmt" + "time" + + "go.arsenm.dev/itd/api" +) + +func main() { + itd, _ := api.New(api.DefaultAddr) + defer itd.Close() + + fmt.Println(itd.Address()) + + mCh, cancel, _ := itd.WatchMotion() + + go func() { + time.Sleep(10 * time.Second) + cancel() + fmt.Println("canceled") + }() + + for m := range mCh { + fmt.Println(m) + } + +} diff --git a/go.mod b/go.mod index 9183c53..a8cb8cd 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/go-gl/gl v0.0.0-20210905235341-f7a045908259 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20210727001814-0db043d8d5be // indirect github.com/godbus/dbus/v5 v5.0.5 + github.com/google/uuid v1.1.2 github.com/mattn/go-colorable v0.1.11 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect github.com/mitchellh/mapstructure v1.4.2 diff --git a/go.sum b/go.sum index 57ef322..107892e 100644 --- a/go.sum +++ b/go.sum @@ -193,6 +193,7 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= diff --git a/internal/types/types.go b/internal/types/types.go index a945982..cf6a47f 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -9,15 +9,12 @@ const ( ReqTypeNotify ReqTypeSetTime ReqTypeWatchHeartRate - ReqTypeCancelHeartRate ReqTypeWatchBattLevel - ReqTypeCancelBattLevel ReqTypeMotion ReqTypeWatchMotion - ReqTypeCancelMotion ReqTypeStepCount ReqTypeWatchStepCount - ReqTypeCancelStepCount + ReqTypeCancel ) const ( @@ -29,15 +26,12 @@ const ( ResTypeNotify ResTypeSetTime ResTypeWatchHeartRate - ResTypeCancelHeartRate ResTypeWatchBattLevel - ResTypeCancelBattLevel ResTypeMotion ResTypeWatchMotion - ResTypeCancelMotion ResTypeStepCount ResTypeWatchStepCount - ResTypeCancelStepCount + ResTypeCancel ) const ( @@ -54,6 +48,7 @@ type Response struct { Type int `json:"type"` Value interface{} `json:"value,omitempty"` Message string `json:"msg,omitempty"` + ID string `json:"id,omitempty"` Error bool `json:"error"` } diff --git a/socket.go b/socket.go index 82c9970..356f3be 100644 --- a/socket.go +++ b/socket.go @@ -27,6 +27,7 @@ import ( "path/filepath" "time" + "github.com/google/uuid" "github.com/mitchellh/mapstructure" "github.com/rs/zerolog/log" "github.com/spf13/viper" @@ -35,6 +36,29 @@ import ( "go.arsenm.dev/itd/translit" ) +type DoneMap map[string]chan struct{} + +func (dm DoneMap) Exists(key string) bool { + _, ok := dm[key] + return ok +} + +func (dm DoneMap) Done(key string) { + ch := dm[key] + ch <- struct{}{} +} + +func (dm DoneMap) Create(key string) { + dm[key] = make(chan struct{}, 1) +} + +func (dm DoneMap) Remove(key string) { + close(dm[key]) + delete(dm, key) +} + +var done = DoneMap{} + func startSocket(dev *infinitime.Device) error { // Make socket directory if non-existant err := os.MkdirAll(filepath.Dir(viper.GetString("socket.path")), 0755) @@ -81,11 +105,6 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { return } - heartRateDone := make(chan struct{}) - battLevelDone := make(chan struct{}) - stepCountDone := make(chan struct{}) - motionDone := make(chan struct{}) - // Create new scanner on connection scanner := bufio.NewScanner(conn) for scanner.Scan() { @@ -116,27 +135,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { connErr(conn, err, "Error getting heart rate channel") break } + reqID := uuid.New().String() go func() { + done.Create(reqID) // For every heart rate value for heartRate := range heartRateCh { select { - case <-heartRateDone: + case <-done[reqID]: // Stop notifications if done signal received cancel() + done.Remove(reqID) return default: // Encode response to connection if no done signal received json.NewEncoder(conn).Encode(types.Response{ Type: types.ResTypeWatchHeartRate, + ID: reqID, Value: heartRate, }) } } }() - case types.ReqTypeCancelHeartRate: - // Stop heart rate notifications - heartRateDone <- struct{}{} - json.NewEncoder(conn).Encode(types.Response{}) case types.ReqTypeBattLevel: // Get battery level from watch battLevel, err := dev.BatteryLevel() @@ -155,27 +174,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { connErr(conn, err, "Error getting battery level channel") break } + reqID := uuid.New().String() go func() { + done.Create(reqID) // For every battery level value for battLevel := range battLevelCh { select { - case <-battLevelDone: + case <-done[reqID]: // Stop notifications if done signal received cancel() + done.Remove(reqID) return default: // Encode response to connection if no done signal received json.NewEncoder(conn).Encode(types.Response{ Type: types.ResTypeWatchBattLevel, + ID: reqID, Value: battLevel, }) } } }() - case types.ReqTypeCancelBattLevel: - // Stop battery level notifications - battLevelDone <- struct{}{} - json.NewEncoder(conn).Encode(types.Response{}) case types.ReqTypeMotion: // Get battery level from watch motionVals, err := dev.Motion() @@ -194,27 +213,28 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { connErr(conn, err, "Error getting heart rate channel") break } + reqID := uuid.New().String() go func() { + done.Create(reqID) // For every motion event for motionVals := range motionValCh { select { - case <-motionDone: + case <-done[reqID]: // Stop notifications if done signal received cancel() + done.Remove(reqID) + return default: // Encode response to connection if no done signal received json.NewEncoder(conn).Encode(types.Response{ Type: types.ResTypeWatchMotion, + ID: reqID, Value: motionVals, }) } } }() - case types.ReqTypeCancelMotion: - // Stop motion notifications - motionDone <- struct{}{} - json.NewEncoder(conn).Encode(types.Response{}) case types.ReqTypeStepCount: // Get battery level from watch stepCount, err := dev.StepCount() @@ -233,27 +253,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { connErr(conn, err, "Error getting heart rate channel") break } + reqID := uuid.New().String() go func() { + done.Create(reqID) // For every step count value for stepCount := range stepCountCh { select { - case <-stepCountDone: + case <-done[reqID]: // Stop notifications if done signal received cancel() + done.Remove(reqID) return default: // Encode response to connection if no done signal received json.NewEncoder(conn).Encode(types.Response{ Type: types.ResTypeWatchStepCount, + ID: reqID, Value: stepCount, }) } } }() - case types.ReqTypeCancelStepCount: - // Stop step count notifications - stepCountDone <- struct{}{} - json.NewEncoder(conn).Encode(types.Response{}) case types.ReqTypeFwVersion: // Get firmware version from watch version, err := dev.Version() @@ -409,6 +429,18 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { break } firmwareUpdating = false + case types.ReqTypeCancel: + if req.Data == nil { + connErr(conn, nil, "No data provided. Cancel request requires request ID string as data.") + continue + } + reqID, ok := req.Data.(string) + if !ok { + connErr(conn, nil, "Invalid data. Cancel request required request ID string as data.") + } + // Stop notifications + done.Done(reqID) + json.NewEncoder(conn).Encode(types.Response{Type: types.ResTypeCancel}) } } } diff --git a/test b/test new file mode 100755 index 0000000..831c240 Binary files /dev/null and b/test differ