Generalize socket cancellation and update API accordingly

This commit is contained in:
Elara 2021-10-23 18:03:17 -07:00
parent ef4bad94b5
commit e198b769f9
9 changed files with 211 additions and 76 deletions

View File

@ -18,11 +18,11 @@ const DefaultAddr = "/tmp/itd/socket"
type Client struct { type Client struct {
conn net.Conn conn net.Conn
respCh chan types.Response respCh chan types.Response
heartRateCh chan uint8 heartRateCh chan types.Response
battLevelCh chan uint8 battLevelCh chan types.Response
stepCountCh chan uint32 stepCountCh chan types.Response
motionCh chan infinitime.MotionValues motionCh chan types.Response
dfuProgressCh chan DFUProgress dfuProgressCh chan types.Response
} }
// New creates a new client and sets it up // New creates a new client and sets it up
@ -91,27 +91,43 @@ func (c *Client) requestNoRes(req types.Request) error {
func (c *Client) handleResp(res types.Response) error { func (c *Client) handleResp(res types.Response) error {
switch res.Type { switch res.Type {
case types.ResTypeWatchHeartRate: case types.ResTypeWatchHeartRate:
c.heartRateCh <- uint8(res.Value.(float64)) c.heartRateCh <- res
case types.ResTypeWatchBattLevel: case types.ResTypeWatchBattLevel:
c.battLevelCh <- uint8(res.Value.(float64)) c.battLevelCh <- res
case types.ResTypeWatchStepCount: case types.ResTypeWatchStepCount:
c.stepCountCh <- uint32(res.Value.(float64)) c.stepCountCh <- res
case types.ResTypeWatchMotion: case types.ResTypeWatchMotion:
out := infinitime.MotionValues{} c.motionCh <- res
err := mapstructure.Decode(res.Value, &out)
if err != nil {
return err
}
c.motionCh <- out
case types.ResTypeDFUProgress: case types.ResTypeDFUProgress:
out := DFUProgress{} c.dfuProgressCh <- res
err := mapstructure.Decode(res.Value, &out)
if err != nil {
return err
}
c.dfuProgressCh <- out
default: default:
c.respCh <- res c.respCh <- res
} }
return nil return nil
} }
func decodeUint8(val interface{}) uint8 {
return uint8(val.(float64))
}
func decodeUint32(val interface{}) uint32 {
return uint32(val.(float64))
}
func decodeMotion(val interface{}) (infinitime.MotionValues, error) {
out := infinitime.MotionValues{}
err := mapstructure.Decode(val, &out)
if err != nil {
return out, err
}
return out, nil
}
func decodeDFUProgress(val interface{}) (DFUProgress, error) {
out := DFUProgress{}
err := mapstructure.Decode(val, &out)
if err != nil {
return out, err
}
return out, nil
}

View File

@ -1,8 +1,6 @@
package api package api
import ( import (
"reflect"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"go.arsenm.dev/infinitime" "go.arsenm.dev/infinitime"
"go.arsenm.dev/itd/internal/types" "go.arsenm.dev/itd/internal/types"
@ -48,15 +46,27 @@ func (c *Client) BatteryLevel() (uint8, error) {
// new battery level values as they update. Do not use after // new battery level values as they update. Do not use after
// calling cancellation function // calling cancellation function
func (c *Client) WatchBatteryLevel() (<-chan uint8, func(), error) { func (c *Client) WatchBatteryLevel() (<-chan uint8, func(), error) {
c.battLevelCh = make(chan uint8, 2) c.battLevelCh = make(chan types.Response, 2)
err := c.requestNoRes(types.Request{ err := c.requestNoRes(types.Request{
Type: types.ReqTypeBattLevel, Type: types.ReqTypeBattLevel,
}) })
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cancel := c.cancelFn(types.ReqTypeCancelBattLevel, c.battLevelCh) res := <-c.battLevelCh
return c.battLevelCh, cancel, nil done, cancel := c.cancelFn(res.ID, c.battLevelCh)
out := make(chan uint8, 2)
go func() {
for res := range c.battLevelCh {
select {
case <-done:
return
default:
out <- decodeUint8(res.Value)
}
}
}()
return out, cancel, nil
} }
// HeartRate gets the heart rate from the connected device // HeartRate gets the heart rate from the connected device
@ -68,33 +78,46 @@ func (c *Client) HeartRate() (uint8, error) {
return 0, err return 0, err
} }
return uint8(res.Value.(float64)), nil return decodeUint8(res.Value), nil
} }
// WatchHeartRate returns a channel which will contain // WatchHeartRate returns a channel which will contain
// new heart rate values as they update. Do not use after // new heart rate values as they update. Do not use after
// calling cancellation function // calling cancellation function
func (c *Client) WatchHeartRate() (<-chan uint8, func(), error) { func (c *Client) WatchHeartRate() (<-chan uint8, func(), error) {
c.heartRateCh = make(chan uint8, 2) c.heartRateCh = make(chan types.Response, 2)
err := c.requestNoRes(types.Request{ err := c.requestNoRes(types.Request{
Type: types.ReqTypeWatchHeartRate, Type: types.ReqTypeWatchHeartRate,
}) })
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cancel := c.cancelFn(types.ReqTypeCancelHeartRate, c.heartRateCh) res := <-c.heartRateCh
return c.heartRateCh, cancel, nil done, cancel := c.cancelFn(res.ID, c.heartRateCh)
out := make(chan uint8, 2)
go func() {
for res := range c.heartRateCh {
select {
case <-done:
return
default:
out <- decodeUint8(res.Value)
}
}
}()
return out, cancel, nil
} }
// cancelFn generates a cancellation function for the given // cancelFn generates a cancellation function for the given
// request type and channel // request type and channel
func (c *Client) cancelFn(reqType int, ch interface{}) func() { func (c *Client) cancelFn(reqID string, ch chan types.Response) (chan struct{}, func()) {
return func() { done := make(chan struct{}, 1)
reflectCh := reflect.ValueOf(ch) return done, func() {
reflectCh.Close() done <- struct{}{}
reflectCh.Set(reflect.Zero(reflectCh.Type())) close(ch)
c.requestNoRes(types.Request{ c.requestNoRes(types.Request{
Type: reqType, Type: types.ReqTypeCancel,
Data: reqID,
}) })
} }
} }
@ -115,15 +138,27 @@ func (c *Client) StepCount() (uint32, error) {
// new step count values as they update. Do not use after // new step count values as they update. Do not use after
// calling cancellation function // calling cancellation function
func (c *Client) WatchStepCount() (<-chan uint32, func(), error) { func (c *Client) WatchStepCount() (<-chan uint32, func(), error) {
c.stepCountCh = make(chan uint32, 2) c.stepCountCh = make(chan types.Response, 2)
err := c.requestNoRes(types.Request{ err := c.requestNoRes(types.Request{
Type: types.ReqTypeWatchStepCount, Type: types.ReqTypeWatchStepCount,
}) })
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cancel := c.cancelFn(types.ReqTypeCancelStepCount, c.stepCountCh) res := <-c.stepCountCh
return c.stepCountCh, cancel, nil done, cancel := c.cancelFn(res.ID, c.stepCountCh)
out := make(chan uint32, 2)
go func() {
for res := range c.stepCountCh {
select {
case <-done:
return
default:
out <- decodeUint32(res.Value)
}
}
}()
return out, cancel, nil
} }
// Motion gets the motion values from the connected device // Motion gets the motion values from the connected device
@ -146,13 +181,29 @@ func (c *Client) Motion() (infinitime.MotionValues, error) {
// new motion values as they update. Do not use after // new motion values as they update. Do not use after
// calling cancellation function // calling cancellation function
func (c *Client) WatchMotion() (<-chan infinitime.MotionValues, func(), error) { func (c *Client) WatchMotion() (<-chan infinitime.MotionValues, func(), error) {
c.motionCh = make(chan infinitime.MotionValues, 2) c.motionCh = make(chan types.Response, 5)
err := c.requestNoRes(types.Request{ err := c.requestNoRes(types.Request{
Type: types.ReqTypeWatchMotion, Type: types.ReqTypeWatchMotion,
}) })
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cancel := c.cancelFn(types.ReqTypeCancelMotion, c.motionCh) res := <-c.motionCh
return c.motionCh, cancel, nil done, cancel := c.cancelFn(res.ID, c.motionCh)
out := make(chan infinitime.MotionValues, 5)
go func() {
for res := range c.motionCh {
select {
case <-done:
return
default:
motion, err := decodeMotion(res.Value)
if err != nil {
continue
}
out <- motion
}
}
}()
return out, cancel, nil
} }

View File

@ -31,7 +31,18 @@ func (c *Client) FirmwareUpgrade(upgType UpgradeType, files ...string) (<-chan D
return nil, err return nil, err
} }
c.dfuProgressCh = make(chan DFUProgress, 5) c.dfuProgressCh = make(chan types.Response, 5)
return c.dfuProgressCh, nil out := make(chan DFUProgress, 5)
go func() {
for res := range c.dfuProgressCh {
progress, err := decodeDFUProgress(res.Value)
if err != nil {
continue
}
out <- progress
}
}()
return out, nil
} }

28
cmd/test/main.go Normal file
View File

@ -0,0 +1,28 @@
package main
import (
"fmt"
"time"
"go.arsenm.dev/itd/api"
)
func main() {
itd, _ := api.New(api.DefaultAddr)
defer itd.Close()
fmt.Println(itd.Address())
mCh, cancel, _ := itd.WatchMotion()
go func() {
time.Sleep(10 * time.Second)
cancel()
fmt.Println("canceled")
}()
for m := range mCh {
fmt.Println(m)
}
}

1
go.mod
View File

@ -13,6 +13,7 @@ require (
github.com/go-gl/gl v0.0.0-20210905235341-f7a045908259 // indirect github.com/go-gl/gl v0.0.0-20210905235341-f7a045908259 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20210727001814-0db043d8d5be // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20210727001814-0db043d8d5be // indirect
github.com/godbus/dbus/v5 v5.0.5 github.com/godbus/dbus/v5 v5.0.5
github.com/google/uuid v1.1.2
github.com/mattn/go-colorable v0.1.11 // indirect github.com/mattn/go-colorable v0.1.11 // indirect
github.com/mattn/go-runewidth v0.0.13 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/mitchellh/mapstructure v1.4.2 github.com/mitchellh/mapstructure v1.4.2

1
go.sum
View File

@ -193,6 +193,7 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=

View File

@ -9,15 +9,12 @@ const (
ReqTypeNotify ReqTypeNotify
ReqTypeSetTime ReqTypeSetTime
ReqTypeWatchHeartRate ReqTypeWatchHeartRate
ReqTypeCancelHeartRate
ReqTypeWatchBattLevel ReqTypeWatchBattLevel
ReqTypeCancelBattLevel
ReqTypeMotion ReqTypeMotion
ReqTypeWatchMotion ReqTypeWatchMotion
ReqTypeCancelMotion
ReqTypeStepCount ReqTypeStepCount
ReqTypeWatchStepCount ReqTypeWatchStepCount
ReqTypeCancelStepCount ReqTypeCancel
) )
const ( const (
@ -29,15 +26,12 @@ const (
ResTypeNotify ResTypeNotify
ResTypeSetTime ResTypeSetTime
ResTypeWatchHeartRate ResTypeWatchHeartRate
ResTypeCancelHeartRate
ResTypeWatchBattLevel ResTypeWatchBattLevel
ResTypeCancelBattLevel
ResTypeMotion ResTypeMotion
ResTypeWatchMotion ResTypeWatchMotion
ResTypeCancelMotion
ResTypeStepCount ResTypeStepCount
ResTypeWatchStepCount ResTypeWatchStepCount
ResTypeCancelStepCount ResTypeCancel
) )
const ( const (
@ -54,6 +48,7 @@ type Response struct {
Type int `json:"type"` Type int `json:"type"`
Value interface{} `json:"value,omitempty"` Value interface{} `json:"value,omitempty"`
Message string `json:"msg,omitempty"` Message string `json:"msg,omitempty"`
ID string `json:"id,omitempty"`
Error bool `json:"error"` Error bool `json:"error"`
} }

View File

@ -27,6 +27,7 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"github.com/google/uuid"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -35,6 +36,29 @@ import (
"go.arsenm.dev/itd/translit" "go.arsenm.dev/itd/translit"
) )
type DoneMap map[string]chan struct{}
func (dm DoneMap) Exists(key string) bool {
_, ok := dm[key]
return ok
}
func (dm DoneMap) Done(key string) {
ch := dm[key]
ch <- struct{}{}
}
func (dm DoneMap) Create(key string) {
dm[key] = make(chan struct{}, 1)
}
func (dm DoneMap) Remove(key string) {
close(dm[key])
delete(dm, key)
}
var done = DoneMap{}
func startSocket(dev *infinitime.Device) error { func startSocket(dev *infinitime.Device) error {
// Make socket directory if non-existant // Make socket directory if non-existant
err := os.MkdirAll(filepath.Dir(viper.GetString("socket.path")), 0755) err := os.MkdirAll(filepath.Dir(viper.GetString("socket.path")), 0755)
@ -81,11 +105,6 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
return return
} }
heartRateDone := make(chan struct{})
battLevelDone := make(chan struct{})
stepCountDone := make(chan struct{})
motionDone := make(chan struct{})
// Create new scanner on connection // Create new scanner on connection
scanner := bufio.NewScanner(conn) scanner := bufio.NewScanner(conn)
for scanner.Scan() { for scanner.Scan() {
@ -116,27 +135,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
connErr(conn, err, "Error getting heart rate channel") connErr(conn, err, "Error getting heart rate channel")
break break
} }
reqID := uuid.New().String()
go func() { go func() {
done.Create(reqID)
// For every heart rate value // For every heart rate value
for heartRate := range heartRateCh { for heartRate := range heartRateCh {
select { select {
case <-heartRateDone: case <-done[reqID]:
// Stop notifications if done signal received // Stop notifications if done signal received
cancel() cancel()
done.Remove(reqID)
return return
default: default:
// Encode response to connection if no done signal received // Encode response to connection if no done signal received
json.NewEncoder(conn).Encode(types.Response{ json.NewEncoder(conn).Encode(types.Response{
Type: types.ResTypeWatchHeartRate, Type: types.ResTypeWatchHeartRate,
ID: reqID,
Value: heartRate, Value: heartRate,
}) })
} }
} }
}() }()
case types.ReqTypeCancelHeartRate:
// Stop heart rate notifications
heartRateDone <- struct{}{}
json.NewEncoder(conn).Encode(types.Response{})
case types.ReqTypeBattLevel: case types.ReqTypeBattLevel:
// Get battery level from watch // Get battery level from watch
battLevel, err := dev.BatteryLevel() battLevel, err := dev.BatteryLevel()
@ -155,27 +174,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
connErr(conn, err, "Error getting battery level channel") connErr(conn, err, "Error getting battery level channel")
break break
} }
reqID := uuid.New().String()
go func() { go func() {
done.Create(reqID)
// For every battery level value // For every battery level value
for battLevel := range battLevelCh { for battLevel := range battLevelCh {
select { select {
case <-battLevelDone: case <-done[reqID]:
// Stop notifications if done signal received // Stop notifications if done signal received
cancel() cancel()
done.Remove(reqID)
return return
default: default:
// Encode response to connection if no done signal received // Encode response to connection if no done signal received
json.NewEncoder(conn).Encode(types.Response{ json.NewEncoder(conn).Encode(types.Response{
Type: types.ResTypeWatchBattLevel, Type: types.ResTypeWatchBattLevel,
ID: reqID,
Value: battLevel, Value: battLevel,
}) })
} }
} }
}() }()
case types.ReqTypeCancelBattLevel:
// Stop battery level notifications
battLevelDone <- struct{}{}
json.NewEncoder(conn).Encode(types.Response{})
case types.ReqTypeMotion: case types.ReqTypeMotion:
// Get battery level from watch // Get battery level from watch
motionVals, err := dev.Motion() motionVals, err := dev.Motion()
@ -194,27 +213,28 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
connErr(conn, err, "Error getting heart rate channel") connErr(conn, err, "Error getting heart rate channel")
break break
} }
reqID := uuid.New().String()
go func() { go func() {
done.Create(reqID)
// For every motion event // For every motion event
for motionVals := range motionValCh { for motionVals := range motionValCh {
select { select {
case <-motionDone: case <-done[reqID]:
// Stop notifications if done signal received // Stop notifications if done signal received
cancel() cancel()
done.Remove(reqID)
return return
default: default:
// Encode response to connection if no done signal received // Encode response to connection if no done signal received
json.NewEncoder(conn).Encode(types.Response{ json.NewEncoder(conn).Encode(types.Response{
Type: types.ResTypeWatchMotion, Type: types.ResTypeWatchMotion,
ID: reqID,
Value: motionVals, Value: motionVals,
}) })
} }
} }
}() }()
case types.ReqTypeCancelMotion:
// Stop motion notifications
motionDone <- struct{}{}
json.NewEncoder(conn).Encode(types.Response{})
case types.ReqTypeStepCount: case types.ReqTypeStepCount:
// Get battery level from watch // Get battery level from watch
stepCount, err := dev.StepCount() stepCount, err := dev.StepCount()
@ -233,27 +253,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
connErr(conn, err, "Error getting heart rate channel") connErr(conn, err, "Error getting heart rate channel")
break break
} }
reqID := uuid.New().String()
go func() { go func() {
done.Create(reqID)
// For every step count value // For every step count value
for stepCount := range stepCountCh { for stepCount := range stepCountCh {
select { select {
case <-stepCountDone: case <-done[reqID]:
// Stop notifications if done signal received // Stop notifications if done signal received
cancel() cancel()
done.Remove(reqID)
return return
default: default:
// Encode response to connection if no done signal received // Encode response to connection if no done signal received
json.NewEncoder(conn).Encode(types.Response{ json.NewEncoder(conn).Encode(types.Response{
Type: types.ResTypeWatchStepCount, Type: types.ResTypeWatchStepCount,
ID: reqID,
Value: stepCount, Value: stepCount,
}) })
} }
} }
}() }()
case types.ReqTypeCancelStepCount:
// Stop step count notifications
stepCountDone <- struct{}{}
json.NewEncoder(conn).Encode(types.Response{})
case types.ReqTypeFwVersion: case types.ReqTypeFwVersion:
// Get firmware version from watch // Get firmware version from watch
version, err := dev.Version() version, err := dev.Version()
@ -409,6 +429,18 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
break break
} }
firmwareUpdating = false firmwareUpdating = false
case types.ReqTypeCancel:
if req.Data == nil {
connErr(conn, nil, "No data provided. Cancel request requires request ID string as data.")
continue
}
reqID, ok := req.Data.(string)
if !ok {
connErr(conn, nil, "Invalid data. Cancel request required request ID string as data.")
}
// Stop notifications
done.Done(reqID)
json.NewEncoder(conn).Encode(types.Response{Type: types.ResTypeCancel})
} }
} }
} }

BIN
test Executable file

Binary file not shown.