From 1f1bb67431c5f5ebd72d6759dde3d59fc368f46d Mon Sep 17 00:00:00 2001 From: Arsen Musayelyan Date: Mon, 7 Dec 2020 17:15:38 -0800 Subject: [PATCH] Use raw TCP instead of HTTP to transfer files --- config.go | 36 ++++++- files.go | 273 +++++++++++++++++++++++++++++---------------------- keyCrypto.go | 35 ++++--- logging.go | 23 +++++ main.go | 22 +++-- 5 files changed, 243 insertions(+), 146 deletions(-) diff --git a/config.go b/config.go index a70af89..f0a39df 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog/log" "io" "io/ioutil" + "net/url" "os" "path/filepath" "strconv" @@ -25,6 +26,24 @@ func NewConfig(actionType string, actionData string) *Config { return &Config{ActionType: actionType, ActionData: actionData} } +func (config *Config) Validate() { + // Parse URL in config + urlParser, err := url.Parse(config.ActionData) + // If there was an error parsing + if err != nil { + // Alert user of invalid url + log.Fatal().Err(err).Msg("Invalid URL") + // If scheme is not detected + } else if urlParser.Scheme == "" { + // Alert user of invalid scheme + log.Fatal().Msg("Invalid URL scheme") + // If host is not detected + } else if urlParser.Host == "" { + // Alert user of invalid host + log.Fatal().Msg("Invalid URL host") + } +} + // Create config file func (config *Config) CreateFile(dir string) { // Use ConsoleWriter logger @@ -137,8 +156,23 @@ func (config *Config) ExecuteAction(srcDir string, destDir string) { if err != nil { log.Fatal().Err(err).Msg("Error copying data to file") } // If action is url } else if config.ActionType == "url" { + // Parse received URL + urlParser, err := url.Parse(config.ActionData) + // If there was an error parsing + if err != nil { + // Alert user of invalid url + log.Fatal().Err(err).Msg("Invalid URL") + // If scheme is not detected + } else if urlParser.Scheme == "" { + // Alert user of invalid scheme + log.Fatal().Msg("Invalid URL scheme") + // If host is not detected + } else if urlParser.Host == "" { + // Alert user of invalid host + log.Fatal().Msg("Invalid URL host") + } // Attempt to open URL in browser - err := browser.OpenURL(config.ActionData) + err = browser.OpenURL(config.ActionData) if err != nil { log.Fatal().Err(err).Msg("Error opening browser") } // If action is dir } else if config.ActionType == "dir" { diff --git a/files.go b/files.go index cd23582..9b75981 100644 --- a/files.go +++ b/files.go @@ -1,19 +1,20 @@ package main import ( - "context" + "bufio" + "bytes" + "encoding/hex" + "errors" "fmt" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "io" "io/ioutil" "net" - "net/http" "os" "path/filepath" "strconv" "strings" - "time" ) // Save encrypted key to file @@ -34,147 +35,179 @@ func SaveEncryptedKey(encryptedKey []byte, filePath string) { // Create HTTP server to transmit files func SendFiles(dir string) { - // Use ConsoleWriter logger + // Use ConsoleWriter logger with normal FatalHook log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) - // Instantiate http.Server struct - srv := &http.Server{} - // Listen on all ipv4 addresses on port 9898 + // Create TCP listener on port 9898 listener, err := net.Listen("tcp", ":9898") if err != nil { log.Fatal().Err(err).Msg("Error starting listener") } - - // If client connects to /:filePath - http.HandleFunc("/", func(res http.ResponseWriter, req *http.Request) { - // Set file to first path components of URL, excluding first / - file := req.URL.Path[1:] - // Read file at specified location - fileData, err := ioutil.ReadFile(dir + "/" + file) - // If there was an error reading - if err != nil { - // Warn user of error - log.Warn().Err(err).Msg("Error reading file") - // Otherwise - } else { - // Inform user client has requested a file - log.Info().Str("file", file).Msg("GET File") - } - // Write file to ResponseWriter - _, err = fmt.Fprint(res, string(fileData)) - if err != nil { log.Fatal().Err(err).Msg("Error writing response") } - }) - - // If client connects to /index - http.HandleFunc("/index", func(res http.ResponseWriter, req *http.Request) { - // Inform user a client has requested the file index - log.Info().Msg("GET Index") - // Get directory listing - dirListing, err := ioutil.ReadDir(dir) - if err != nil { log.Fatal().Err(err).Msg("Error reading directory") } - // Create new slice to house filenames for index - var indexSlice []string - // For each file in listing - for _, file := range dirListing { - // If the file is not the key - if !strings.Contains(file.Name(), "key.aes") { - // Append the file path to indexSlice - indexSlice = append(indexSlice, file.Name()) + // Accept connection on listener + connection, err := listener.Accept() + if err != nil { log.Fatal().Err(err).Msg("Error accepting connection") } + // Close connection at the end of this function + defer connection.Close() + // Create for loop to listen for messages on connection + connectionLoop: for { + // Use ConsoleWriter logger with TCPFatalHook + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(TCPFatalHook{conn: connection}) + // Attempt to read new message on connection + data, err := bufio.NewReader(connection).ReadString('\n') + // If no message detected, try again + if err != nil && err.Error() == "EOF" { continue } + // If non-EOF error, fatally log + if err != nil { log.Fatal().Err(err).Msg("Error reading data") } + // Process received data + processedData := strings.Split(strings.TrimSpace(data), ";") + // If processedData is empty, alert the user of invalid data + if len(processedData) < 1 { log.Fatal().Str("data", data).Msg("Received data invalid") } + switch processedData[0] { + case "key": + // Inform user client has requested key + log.Info().Msg("Key requested") + // Read saved key + key, err := ioutil.ReadFile(dir + "/key.aes") + if err != nil { log.Fatal().Err(err).Msg("Error reading key") } + // Write saved key to ResponseWriter + _, err = fmt.Fprintln(connection, "OK;" + hex.EncodeToString(key) + ";") + if err != nil { log.Fatal().Err(err).Msg("Error writing response") } + case "index": + // Inform user a client has requested the file index + log.Info().Msg("Index requested") + // Get directory listing + dirListing, err := ioutil.ReadDir(dir) + if err != nil { log.Fatal().Err(err).Msg("Error reading directory") } + // Create new slice to house filenames for index + var indexSlice []string + // For each file in listing + for _, file := range dirListing { + // If the file is not the key + if !strings.Contains(file.Name(), "key.aes") { + // Append the file path to indexSlice + indexSlice = append(indexSlice, file.Name()) + } } + // Join index slice into string + indexStr := strings.Join(indexSlice, "|") + // Write index to ResponseWriter + _, err = fmt.Fprintln(connection, "OK;" + indexStr + ";") + if err != nil { log.Fatal().Err(err).Msg("Error writing response") } + case "file": + // If processedData only has one entry + if len(processedData) == 1 { + // Warn user of unexpected end of line + log.Warn().Err(errors.New("unexpected eol")).Msg("Invalid file request") + // Send error to connection + _, _ = fmt.Fprintln(connection, "ERR;") + // Break out of switch + break + } + // Set file to first path components of URL, excluding first / + file := processedData[1] + // Read file at specified location + fileData, err := ioutil.ReadFile(dir + "/" + file) + // If there was an error reading + if err != nil { + // Warn user of error + log.Warn().Err(err).Msg("Error reading file") + // Otherwise + } else { + // Inform user client has requested a file + log.Info().Str("file", file).Msg("File requested") + } + // Write file as hex to connection + _, err = fmt.Fprintln(connection, "OK;" + hex.EncodeToString(fileData) + ";") + if err != nil { log.Fatal().Err(err).Msg("Error writing response") } + case "stop": + // Alert user that stop signal has been received + log.Info().Msg("Received stop signal") + // Print ok message to connection + _, _ = fmt.Fprintln(connection, "OK;") + // Break out of connectionLoop + break connectionLoop } - // Join index slice into string - indexStr := strings.Join(indexSlice, ";") - // Write index to ResponseWriter - _, err = fmt.Fprint(res, indexStr) - if err != nil { log.Fatal().Err(err).Msg("Error writing response") } - }) + } +} - // If client connects to /key - http.HandleFunc("/key", func(res http.ResponseWriter, req *http.Request) { - // Inform user a client has requested the key - log.Info().Msg("GET Key") - // Read saved key - key, err := ioutil.ReadFile(dir + "/key.aes") - if err != nil { log.Fatal().Err(err).Msg("Error reading key") } - // Write saved key to ResponseWriter - _, err = fmt.Fprint(res, string(key)) - if err != nil { log.Fatal().Err(err).Msg("Error writing response") } - }) - - // If client connects to /stop - http.HandleFunc("/stop", func(res http.ResponseWriter, req *http.Request) { - // Inform user a client has requested server shutdown - log.Info().Msg("GET Stop") - log.Info().Msg("Stop signal received") - // Shutdown server and send to empty context - err := srv.Shutdown(context.Background()) - if err != nil { log.Fatal().Err(err).Msg("Error stopping server") } - }) - - // Start HTTP Server - _ = srv.Serve(listener) +func ConnectToSender(senderAddr string) net.Conn { + // Get server address by getting the IP without the port, and appending :9898 + serverAddr := strings.Split(senderAddr, ":")[0] + ":9898" + // Create error variable + var err error + // Create connection variable + var connection net.Conn + // Until break + for { + // Try connecting to sender + connection, err = net.Dial("tcp", serverAddr) + // If connection refused + if err != nil && strings.Contains(err.Error(), "connection refused") { + // Continue loop (retry) + continue + // If error other than connection refused + } else if err != nil { + // Fatally log + log.Fatal().Err(err).Msg("Error connecting to sender") + // If no error + } else { + // Break out of loop + break + } + } + // Returned created connection + return connection } // Get files from sender -func RecvFiles(senderAddr string) { +func RecvFiles(connection net.Conn) { // Use ConsoleWriter logger log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) - // Get server address by getting the IP without the port, prepending http:// and appending :9898 - serverAddr := "http://" + strings.Split(senderAddr, ":")[0] + ":9898" - var response *http.Response - // GET /index on sender's HTTP server - response, err := http.Get(serverAddr + "/index") - // If error occurred, retry every 500ms - if err != nil { - // Set index failed to true - indexGetFailed := true - for indexGetFailed { - // GET /index on sender's HTTP server - response, err = http.Get(serverAddr + "/index") - // If no error, set index failed to false - if err == nil { indexGetFailed = false } - // Wait 500ms - time.Sleep(500*time.Millisecond) - } - } - // Close response body at the end of this function - defer response.Body.Close() - // Create index slice for storage of file index - var index []string - // If server responded with 200 OK - if response.StatusCode == http.StatusOK { - // Read response body - body, err := ioutil.ReadAll(response.Body) - if err != nil { log.Fatal().Err(err).Msg("Error reading HTTP response") } - // Get string from body - bodyStr := string(body) - // Split string to form index - index = strings.Split(bodyStr, ";") - } - // For each file in the index + // Request index from sender + _, err := fmt.Fprintln(connection, "index;") + if err != nil { log.Fatal().Err(err).Msg("Error sending index request") } + // Read received message + message, err := bufio.NewReader(connection).ReadString('\n') + if err != nil { log.Fatal().Err(err).Msg("Error getting index") } + // Process received message + procMessage := strings.Split(strings.TrimSpace(message), ";") + // If non-ok code returned, fatally log + if procMessage[0] != "OK" { log.Fatal().Err(err).Msg("Sender reported error") } + // Get index from message + index := strings.Split(strings.TrimSpace(procMessage[1]), "|") for _, file := range index { - // GET current file in index - response, err := http.Get(serverAddr + "/" + filepath.Base(file)) + // Get current file in index + _, err = fmt.Fprintln(connection, "file;" + file + ";") + if err != nil { log.Fatal().Err(err).Msg("Error sending file request") } + // Read received message + message, err := bufio.NewReader(connection).ReadString('\n') if err != nil { log.Fatal().Err(err).Msg("Error getting file") } - // If server responded with 200 OK - if response.StatusCode == http.StatusOK { + // Process received message + procMessage := strings.Split(message, ";") + // If non-ok code returned + if procMessage[0] != "OK" { + // fatally log + log.Fatal().Err(err).Msg("Sender reported error") + // Otherwise + } else { // Create new file at index filepath newFile, err := os.Create(opensendDir + "/" + file) if err != nil { log.Fatal().Err(err).Msg("Error creating file") } + // Decode file data from hex string + fileData, err := hex.DecodeString(strings.TrimSpace(procMessage[1])) + if err != nil { log.Fatal().Err(err).Msg("Error decoding hex") } // Copy response body to new file - bytesWritten, err := io.Copy(newFile, response.Body) + bytesWritten, err := io.Copy(newFile, bytes.NewBuffer(fileData)) if err != nil { log.Fatal().Err(err).Msg("Error writing to file") } // Log bytes written log.Info().Str("file", filepath.Base(file)).Msg("Wrote " + strconv.Itoa(int(bytesWritten)) + " bytes") // Close new file newFile.Close() } - // Close response body - response.Body.Close() } } -// Send stop signal to sender's HTTP server -func SendSrvStopSignal(senderAddr string) { - // Get server address by getting the IP without the port, prepending http:// and appending :9898 - serverAddr := "http://" + strings.Split(senderAddr, ":")[0] + ":9898" - // GET /stop on sender's HTTP servers ignoring any errors - _, _ = http.Get(serverAddr + "/stop") +// Send stop signal to sender +func SendSrvStopSignal(connection net.Conn) { + // Send stop signal to connection + _, _ = fmt.Fprintln(connection, "stop;") + // Close connection + _ = connection.Close() } \ No newline at end of file diff --git a/keyCrypto.go b/keyCrypto.go index 51929c3..78074c7 100644 --- a/keyCrypto.go +++ b/keyCrypto.go @@ -1,13 +1,15 @@ package main import ( + "bufio" "crypto/rand" "crypto/rsa" "crypto/sha256" + "encoding/hex" + "fmt" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "io/ioutil" - "net/http" + "net" "os" "strings" ) @@ -26,27 +28,28 @@ func GenerateRSAKeypair() (*rsa.PrivateKey, *rsa.PublicKey) { } // Get public key from sender -func GetKey(senderAddr string) []byte { +func GetKey(connection net.Conn) []byte { // Use ConsoleWriter logger log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) - // Get server address by getting the IP without the port, prepending http:// and appending :9898 - serverAddr := "http://" + strings.Split(senderAddr, ":")[0] + ":9898" - // GET /key on the sender's HTTP server - response, err := http.Get(serverAddr + "/key") + // Send key request to connection + _, err := fmt.Fprintln(connection, "key;") + if err != nil { log.Fatal().Err(err).Msg("Error sending key request") } + // Read received message + message, err := bufio.NewReader(connection).ReadString('\n') if err != nil { log.Fatal().Err(err).Msg("Error getting key") } - // Close response body at the end of this function - defer response.Body.Close() - // If server responded with 200 OK - if response.StatusCode == http.StatusOK { - // Read response body into key - key, err := ioutil.ReadAll(response.Body) - if err != nil { log.Fatal().Err(err).Msg("Error reading HTTP response") } + // Process received message + procMessage := strings.Split(strings.TrimSpace(message), ";") + // If ok code returned + if procMessage[0] == "OK" { + // Decode received hex string into key + key, err := hex.DecodeString(procMessage[1]) + if err != nil { log.Fatal().Err(err).Msg("Error reading key") } // Return key return key // Otherwise } else { - // Fatally log status code - if err != nil { log.Fatal().Int("code", response.StatusCode).Msg("HTTP Error Response Code Received") } + // Fatally log + if err != nil { log.Fatal().Msg("Server reported error") } } // Return nil if all else fails return nil diff --git a/logging.go b/logging.go index 24abb20..9f4030a 100644 --- a/logging.go +++ b/logging.go @@ -1,12 +1,16 @@ package main import ( + "fmt" "github.com/rs/zerolog" + "net" "os" ) +// Fatal hook to run in case of Fatal error type FatalHook struct {} +// Run function on trigger func (hook FatalHook) Run(_ *zerolog.Event, level zerolog.Level, _ string) { // If log event is fatal if level == zerolog.FatalLevel { @@ -14,3 +18,22 @@ func (hook FatalHook) Run(_ *zerolog.Event, level zerolog.Level, _ string) { _ = os.RemoveAll(opensendDir) } } + +// TCP Fatal hook to run in case of Fatal error with open TCP connection +type TCPFatalHook struct { + conn net.Conn +} + +// Run function on trigger +func (hook TCPFatalHook) Run(_ *zerolog.Event, level zerolog.Level, _ string) { + // If log event is fatal + if level == zerolog.FatalLevel { + // Send error to connection + _, _ = fmt.Fprintln(hook.conn, "ERR;") + // Close connection + _ = hook.conn.Close() + // Attempt removal of opensend directory + _ = os.RemoveAll(opensendDir) + } + +} \ No newline at end of file diff --git a/main.go b/main.go index e978779..76b2b80 100644 --- a/main.go +++ b/main.go @@ -112,6 +112,14 @@ func main() { // Get IP of chosen receiver choiceIP = discoveredIPs[choiceIndex] } + // Instantiate Config object + config := NewConfig(*actionType, *actionData) + // Validate data in config struct + config.Validate() + // Collect any files that may be required for transaction into opensend directory + config.CollectFiles(opensendDir) + // Create config file in opensend directory + config.CreateFile(opensendDir) // Notify user of key exchange log.Info().Msg("Performing key exchange") // Exchange RSA keys with receiver @@ -122,12 +130,6 @@ func main() { key := EncryptKey(sharedKey, rawKey) // Save encrypted key in opensend directory as key.aes SaveEncryptedKey(key, opensendDir + "/key.aes") - // Instantiate Config object - config := NewConfig(*actionType, *actionData) - // Collect any files that may be required for transaction into opensend directory - config.CollectFiles(opensendDir) - // Create config file in opensend directory - config.CreateFile(opensendDir) // Notify user file encryption is beginning log.Info().Msg("Encrypting files") // Encrypt all files in opensend directory using shared key @@ -157,12 +159,14 @@ func main() { time.Sleep(300*time.Millisecond) // Notify user files are being received log.Info().Msg("Receiving files from server (This may take a while)") + // Connect to sender's TCP socket + connection := ConnectToSender(senderIP) // Get files from sender and place them into the opensend directory - RecvFiles(senderIP) + RecvFiles(connection) // Get encrypted shared key from sender - encryptedKey := GetKey(senderIP) + encryptedKey := GetKey(connection) // Send stop signal to sender's HTTP server - SendSrvStopSignal(senderIP) + SendSrvStopSignal(connection) // Decrypt shared key sharedKey := DecryptKey(encryptedKey, privateKey) // Notify user file decryption is beginning