From cf7dbd0b9c92a3b225742c7c36bad6166f2f3267 Mon Sep 17 00:00:00 2001 From: Elara Musayelyan Date: Sun, 24 Oct 2021 01:09:27 -0700 Subject: [PATCH] Use request type for error response type --- socket.go | 74 +++++++++++++++++++++++++++---------------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/socket.go b/socket.go index 70dd20f..15d7600 100644 --- a/socket.go +++ b/socket.go @@ -22,7 +22,6 @@ import ( "bufio" "encoding/json" "fmt" - "math" "net" "os" "path/filepath" @@ -100,11 +99,6 @@ func startSocket(dev *infinitime.Device) error { func handleConnection(conn net.Conn, dev *infinitime.Device) { defer conn.Close() - // If firmware is updating, return error - if firmwareUpdating { - connErr(conn, nil, "Firmware update in progress") - return - } // Create new scanner on connection scanner := bufio.NewScanner(conn) @@ -113,16 +107,22 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { // Decode scanned message into types.Request err := json.Unmarshal(scanner.Bytes(), &req) if err != nil { - connErr(conn, err, "Error decoding JSON input") + connErr(conn, req.Type, err, "Error decoding JSON input") continue } + // If firmware is updating, return error + if firmwareUpdating { + connErr(conn, req.Type, nil, "Firmware update in progress") + return + } + switch req.Type { case types.ReqTypeHeartRate: // Get heart rate from watch heartRate, err := dev.HeartRate() if err != nil { - connErr(conn, err, "Error getting heart rate") + connErr(conn, req.Type, err, "Error getting heart rate") break } // Encode heart rate to connection @@ -133,7 +133,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { case types.ReqTypeWatchHeartRate: heartRateCh, cancel, err := dev.WatchHeartRate() if err != nil { - connErr(conn, err, "Error getting heart rate channel") + connErr(conn, req.Type, err, "Error getting heart rate channel") break } reqID := uuid.New().String() @@ -161,7 +161,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { // Get battery level from watch battLevel, err := dev.BatteryLevel() if err != nil { - connErr(conn, err, "Error getting battery level") + connErr(conn, req.Type, err, "Error getting battery level") break } // Encode battery level to connection @@ -172,7 +172,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { case types.ReqTypeWatchBattLevel: battLevelCh, cancel, err := dev.WatchBatteryLevel() if err != nil { - connErr(conn, err, "Error getting battery level channel") + connErr(conn, req.Type, err, "Error getting battery level channel") break } reqID := uuid.New().String() @@ -200,7 +200,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { // Get battery level from watch motionVals, err := dev.Motion() if err != nil { - connErr(conn, err, "Error getting motion values") + connErr(conn, req.Type, err, "Error getting motion values") break } // Encode battery level to connection @@ -211,7 +211,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { case types.ReqTypeWatchMotion: motionValCh, cancel, err := dev.WatchMotion() if err != nil { - connErr(conn, err, "Error getting heart rate channel") + connErr(conn, req.Type, err, "Error getting heart rate channel") break } reqID := uuid.New().String() @@ -240,7 +240,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { // Get battery level from watch stepCount, err := dev.StepCount() if err != nil { - connErr(conn, err, "Error getting step count") + connErr(conn, req.Type, err, "Error getting step count") break } // Encode battery level to connection @@ -251,7 +251,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { case types.ReqTypeWatchStepCount: stepCountCh, cancel, err := dev.WatchStepCount() if err != nil { - connErr(conn, err, "Error getting heart rate channel") + connErr(conn, req.Type, err, "Error getting heart rate channel") break } reqID := uuid.New().String() @@ -279,7 +279,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { // Get firmware version from watch version, err := dev.Version() if err != nil { - connErr(conn, err, "Error getting firmware version") + connErr(conn, req.Type, err, "Error getting firmware version") break } // Encode version to connection @@ -296,14 +296,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { case types.ReqTypeNotify: // If no data, return error if req.Data == nil { - connErr(conn, nil, "Data required for notify request") + connErr(conn, req.Type, nil, "Data required for notify request") break } var reqData types.ReqDataNotify // Decode data map to notify request data err = mapstructure.Decode(req.Data, &reqData) if err != nil { - connErr(conn, err, "Error decoding request data") + connErr(conn, req.Type, err, "Error decoding request data") break } maps := viper.GetStringSlice("notifs.translit.use") @@ -313,7 +313,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { // Send notification to watch err = dev.Notify(title, body) if err != nil { - connErr(conn, err, "Error sending notification") + connErr(conn, req.Type, err, "Error sending notification") break } // Encode empty types.Response to connection @@ -321,13 +321,13 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { case types.ReqTypeSetTime: // If no data, return error if req.Data == nil { - connErr(conn, nil, "Data required for settime request") + connErr(conn, req.Type, nil, "Data required for settime request") break } // Get string from data or return error reqTimeStr, ok := req.Data.(string) if !ok { - connErr(conn, nil, "Data for settime request must be RFC3339 formatted time string") + connErr(conn, req.Type, nil, "Data for settime request must be RFC3339 formatted time string") break } @@ -338,14 +338,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { // Parse time as RFC3339/ISO8601 reqTime, err = time.Parse(time.RFC3339, reqTimeStr) if err != nil { - connErr(conn, err, "Invalid time format. Time string must be formatted as ISO8601 or the word `now`") + connErr(conn, req.Type, err, "Invalid time format. Time string must be formatted as ISO8601 or the word `now`") break } } // Set time on watch err = dev.SetTime(reqTime) if err != nil { - connErr(conn, err, "Error setting device time") + connErr(conn, req.Type, err, "Error setting device time") break } // Encode empty types.Response to connection @@ -353,14 +353,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { case types.ReqTypeFwUpgrade: // If no data, return error if req.Data == nil { - connErr(conn, nil, "Data required for firmware upgrade request") + connErr(conn, req.Type, nil, "Data required for firmware upgrade request") break } var reqData types.ReqDataFwUpgrade // Decode data map to firmware upgrade request data err = mapstructure.Decode(req.Data, &reqData) if err != nil { - connErr(conn, err, "Error decoding request data") + connErr(conn, req.Type, err, "Error decoding request data") break } // Reset DFU to prepare for next update @@ -369,40 +369,40 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { case types.UpgradeTypeArchive: // If less than one file, return error if len(reqData.Files) < 1 { - connErr(conn, nil, "Archive upgrade requires one file with .zip extension") + connErr(conn, req.Type, nil, "Archive upgrade requires one file with .zip extension") break } // If file is not zip archive, return error if filepath.Ext(reqData.Files[0]) != ".zip" { - connErr(conn, nil, "Archive upgrade file must be a zip archive") + connErr(conn, req.Type, nil, "Archive upgrade file must be a zip archive") break } // Load DFU archive err := dev.DFU.LoadArchive(reqData.Files[0]) if err != nil { - connErr(conn, err, "Error loading archive file") + connErr(conn, req.Type, err, "Error loading archive file") break } case types.UpgradeTypeFiles: // If less than two files, return error if len(reqData.Files) < 2 { - connErr(conn, nil, "Files upgrade requires two files. First with .dat and second with .bin extension.") + connErr(conn, req.Type, nil, "Files upgrade requires two files. First with .dat and second with .bin extension.") break } // If first file is not init packet, return error if filepath.Ext(reqData.Files[0]) != ".dat" { - connErr(conn, nil, "First file must be a .dat file") + connErr(conn, req.Type, nil, "First file must be a .dat file") break } // If second file is not firmware image, return error if filepath.Ext(reqData.Files[1]) != ".bin" { - connErr(conn, nil, "Second file must be a .bin file") + connErr(conn, req.Type, nil, "Second file must be a .bin file") break } // Load individual DFU files err := dev.DFU.LoadFiles(reqData.Files[0], reqData.Files[1]) if err != nil { - connErr(conn, err, "Error loading firmware files") + connErr(conn, req.Type, err, "Error loading firmware files") break } } @@ -426,19 +426,19 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { // Start DFU err = dev.DFU.Start() if err != nil { - connErr(conn, err, "Error performing upgrade") + connErr(conn, req.Type, err, "Error performing upgrade") firmwareUpdating = false break } firmwareUpdating = false case types.ReqTypeCancel: if req.Data == nil { - connErr(conn, nil, "No data provided. Cancel request requires request ID string as data.") + connErr(conn, req.Type, nil, "No data provided. Cancel request requires request ID string as data.") continue } reqID, ok := req.Data.(string) if !ok { - connErr(conn, nil, "Invalid data. Cancel request required request ID string as data.") + connErr(conn, req.Type, nil, "Invalid data. Cancel request required request ID string as data.") } // Stop notifications done.Done(reqID) @@ -447,7 +447,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { } } -func connErr(conn net.Conn, err error, msg string) { +func connErr(conn net.Conn, resType int, err error, msg string) { var res types.Response // If error exists, add to types.Response, otherwise don't if err != nil { @@ -455,7 +455,7 @@ func connErr(conn net.Conn, err error, msg string) { res = types.Response{Message: fmt.Sprintf("%s: %s", msg, err)} } else { log.Error().Msg(msg) - res = types.Response{Message: msg, Type: math.MaxInt} + res = types.Response{Message: msg, Type: resType} } res.Error = true