Use raw TCP instead of HTTP to transfer files

This commit is contained in:
Elara 2020-12-07 17:15:38 -08:00
parent 8331e6b543
commit 1f1bb67431
5 changed files with 243 additions and 146 deletions

View File

@ -8,6 +8,7 @@ import (
"github.com/rs/zerolog/log"
"io"
"io/ioutil"
"net/url"
"os"
"path/filepath"
"strconv"
@ -25,6 +26,24 @@ func NewConfig(actionType string, actionData string) *Config {
return &Config{ActionType: actionType, ActionData: actionData}
}
func (config *Config) Validate() {
// Parse URL in config
urlParser, err := url.Parse(config.ActionData)
// If there was an error parsing
if err != nil {
// Alert user of invalid url
log.Fatal().Err(err).Msg("Invalid URL")
// If scheme is not detected
} else if urlParser.Scheme == "" {
// Alert user of invalid scheme
log.Fatal().Msg("Invalid URL scheme")
// If host is not detected
} else if urlParser.Host == "" {
// Alert user of invalid host
log.Fatal().Msg("Invalid URL host")
}
}
// Create config file
func (config *Config) CreateFile(dir string) {
// Use ConsoleWriter logger
@ -137,8 +156,23 @@ func (config *Config) ExecuteAction(srcDir string, destDir string) {
if err != nil { log.Fatal().Err(err).Msg("Error copying data to file") }
// If action is url
} else if config.ActionType == "url" {
// Parse received URL
urlParser, err := url.Parse(config.ActionData)
// If there was an error parsing
if err != nil {
// Alert user of invalid url
log.Fatal().Err(err).Msg("Invalid URL")
// If scheme is not detected
} else if urlParser.Scheme == "" {
// Alert user of invalid scheme
log.Fatal().Msg("Invalid URL scheme")
// If host is not detected
} else if urlParser.Host == "" {
// Alert user of invalid host
log.Fatal().Msg("Invalid URL host")
}
// Attempt to open URL in browser
err := browser.OpenURL(config.ActionData)
err = browser.OpenURL(config.ActionData)
if err != nil { log.Fatal().Err(err).Msg("Error opening browser") }
// If action is dir
} else if config.ActionType == "dir" {

273
files.go
View File

@ -1,19 +1,20 @@
package main
import (
"context"
"bufio"
"bytes"
"encoding/hex"
"errors"
"fmt"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"io"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
// Save encrypted key to file
@ -34,147 +35,179 @@ func SaveEncryptedKey(encryptedKey []byte, filePath string) {
// Create HTTP server to transmit files
func SendFiles(dir string) {
// Use ConsoleWriter logger
// Use ConsoleWriter logger with normal FatalHook
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
// Instantiate http.Server struct
srv := &http.Server{}
// Listen on all ipv4 addresses on port 9898
// Create TCP listener on port 9898
listener, err := net.Listen("tcp", ":9898")
if err != nil { log.Fatal().Err(err).Msg("Error starting listener") }
// If client connects to /:filePath
http.HandleFunc("/", func(res http.ResponseWriter, req *http.Request) {
// Set file to first path components of URL, excluding first /
file := req.URL.Path[1:]
// Read file at specified location
fileData, err := ioutil.ReadFile(dir + "/" + file)
// If there was an error reading
if err != nil {
// Warn user of error
log.Warn().Err(err).Msg("Error reading file")
// Otherwise
} else {
// Inform user client has requested a file
log.Info().Str("file", file).Msg("GET File")
}
// Write file to ResponseWriter
_, err = fmt.Fprint(res, string(fileData))
if err != nil { log.Fatal().Err(err).Msg("Error writing response") }
})
// If client connects to /index
http.HandleFunc("/index", func(res http.ResponseWriter, req *http.Request) {
// Inform user a client has requested the file index
log.Info().Msg("GET Index")
// Get directory listing
dirListing, err := ioutil.ReadDir(dir)
if err != nil { log.Fatal().Err(err).Msg("Error reading directory") }
// Create new slice to house filenames for index
var indexSlice []string
// For each file in listing
for _, file := range dirListing {
// If the file is not the key
if !strings.Contains(file.Name(), "key.aes") {
// Append the file path to indexSlice
indexSlice = append(indexSlice, file.Name())
// Accept connection on listener
connection, err := listener.Accept()
if err != nil { log.Fatal().Err(err).Msg("Error accepting connection") }
// Close connection at the end of this function
defer connection.Close()
// Create for loop to listen for messages on connection
connectionLoop: for {
// Use ConsoleWriter logger with TCPFatalHook
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(TCPFatalHook{conn: connection})
// Attempt to read new message on connection
data, err := bufio.NewReader(connection).ReadString('\n')
// If no message detected, try again
if err != nil && err.Error() == "EOF" { continue }
// If non-EOF error, fatally log
if err != nil { log.Fatal().Err(err).Msg("Error reading data") }
// Process received data
processedData := strings.Split(strings.TrimSpace(data), ";")
// If processedData is empty, alert the user of invalid data
if len(processedData) < 1 { log.Fatal().Str("data", data).Msg("Received data invalid") }
switch processedData[0] {
case "key":
// Inform user client has requested key
log.Info().Msg("Key requested")
// Read saved key
key, err := ioutil.ReadFile(dir + "/key.aes")
if err != nil { log.Fatal().Err(err).Msg("Error reading key") }
// Write saved key to ResponseWriter
_, err = fmt.Fprintln(connection, "OK;" + hex.EncodeToString(key) + ";")
if err != nil { log.Fatal().Err(err).Msg("Error writing response") }
case "index":
// Inform user a client has requested the file index
log.Info().Msg("Index requested")
// Get directory listing
dirListing, err := ioutil.ReadDir(dir)
if err != nil { log.Fatal().Err(err).Msg("Error reading directory") }
// Create new slice to house filenames for index
var indexSlice []string
// For each file in listing
for _, file := range dirListing {
// If the file is not the key
if !strings.Contains(file.Name(), "key.aes") {
// Append the file path to indexSlice
indexSlice = append(indexSlice, file.Name())
}
}
// Join index slice into string
indexStr := strings.Join(indexSlice, "|")
// Write index to ResponseWriter
_, err = fmt.Fprintln(connection, "OK;" + indexStr + ";")
if err != nil { log.Fatal().Err(err).Msg("Error writing response") }
case "file":
// If processedData only has one entry
if len(processedData) == 1 {
// Warn user of unexpected end of line
log.Warn().Err(errors.New("unexpected eol")).Msg("Invalid file request")
// Send error to connection
_, _ = fmt.Fprintln(connection, "ERR;")
// Break out of switch
break
}
// Set file to first path components of URL, excluding first /
file := processedData[1]
// Read file at specified location
fileData, err := ioutil.ReadFile(dir + "/" + file)
// If there was an error reading
if err != nil {
// Warn user of error
log.Warn().Err(err).Msg("Error reading file")
// Otherwise
} else {
// Inform user client has requested a file
log.Info().Str("file", file).Msg("File requested")
}
// Write file as hex to connection
_, err = fmt.Fprintln(connection, "OK;" + hex.EncodeToString(fileData) + ";")
if err != nil { log.Fatal().Err(err).Msg("Error writing response") }
case "stop":
// Alert user that stop signal has been received
log.Info().Msg("Received stop signal")
// Print ok message to connection
_, _ = fmt.Fprintln(connection, "OK;")
// Break out of connectionLoop
break connectionLoop
}
// Join index slice into string
indexStr := strings.Join(indexSlice, ";")
// Write index to ResponseWriter
_, err = fmt.Fprint(res, indexStr)
if err != nil { log.Fatal().Err(err).Msg("Error writing response") }
})
}
}
// If client connects to /key
http.HandleFunc("/key", func(res http.ResponseWriter, req *http.Request) {
// Inform user a client has requested the key
log.Info().Msg("GET Key")
// Read saved key
key, err := ioutil.ReadFile(dir + "/key.aes")
if err != nil { log.Fatal().Err(err).Msg("Error reading key") }
// Write saved key to ResponseWriter
_, err = fmt.Fprint(res, string(key))
if err != nil { log.Fatal().Err(err).Msg("Error writing response") }
})
// If client connects to /stop
http.HandleFunc("/stop", func(res http.ResponseWriter, req *http.Request) {
// Inform user a client has requested server shutdown
log.Info().Msg("GET Stop")
log.Info().Msg("Stop signal received")
// Shutdown server and send to empty context
err := srv.Shutdown(context.Background())
if err != nil { log.Fatal().Err(err).Msg("Error stopping server") }
})
// Start HTTP Server
_ = srv.Serve(listener)
func ConnectToSender(senderAddr string) net.Conn {
// Get server address by getting the IP without the port, and appending :9898
serverAddr := strings.Split(senderAddr, ":")[0] + ":9898"
// Create error variable
var err error
// Create connection variable
var connection net.Conn
// Until break
for {
// Try connecting to sender
connection, err = net.Dial("tcp", serverAddr)
// If connection refused
if err != nil && strings.Contains(err.Error(), "connection refused") {
// Continue loop (retry)
continue
// If error other than connection refused
} else if err != nil {
// Fatally log
log.Fatal().Err(err).Msg("Error connecting to sender")
// If no error
} else {
// Break out of loop
break
}
}
// Returned created connection
return connection
}
// Get files from sender
func RecvFiles(senderAddr string) {
func RecvFiles(connection net.Conn) {
// Use ConsoleWriter logger
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
// Get server address by getting the IP without the port, prepending http:// and appending :9898
serverAddr := "http://" + strings.Split(senderAddr, ":")[0] + ":9898"
var response *http.Response
// GET /index on sender's HTTP server
response, err := http.Get(serverAddr + "/index")
// If error occurred, retry every 500ms
if err != nil {
// Set index failed to true
indexGetFailed := true
for indexGetFailed {
// GET /index on sender's HTTP server
response, err = http.Get(serverAddr + "/index")
// If no error, set index failed to false
if err == nil { indexGetFailed = false }
// Wait 500ms
time.Sleep(500*time.Millisecond)
}
}
// Close response body at the end of this function
defer response.Body.Close()
// Create index slice for storage of file index
var index []string
// If server responded with 200 OK
if response.StatusCode == http.StatusOK {
// Read response body
body, err := ioutil.ReadAll(response.Body)
if err != nil { log.Fatal().Err(err).Msg("Error reading HTTP response") }
// Get string from body
bodyStr := string(body)
// Split string to form index
index = strings.Split(bodyStr, ";")
}
// For each file in the index
// Request index from sender
_, err := fmt.Fprintln(connection, "index;")
if err != nil { log.Fatal().Err(err).Msg("Error sending index request") }
// Read received message
message, err := bufio.NewReader(connection).ReadString('\n')
if err != nil { log.Fatal().Err(err).Msg("Error getting index") }
// Process received message
procMessage := strings.Split(strings.TrimSpace(message), ";")
// If non-ok code returned, fatally log
if procMessage[0] != "OK" { log.Fatal().Err(err).Msg("Sender reported error") }
// Get index from message
index := strings.Split(strings.TrimSpace(procMessage[1]), "|")
for _, file := range index {
// GET current file in index
response, err := http.Get(serverAddr + "/" + filepath.Base(file))
// Get current file in index
_, err = fmt.Fprintln(connection, "file;" + file + ";")
if err != nil { log.Fatal().Err(err).Msg("Error sending file request") }
// Read received message
message, err := bufio.NewReader(connection).ReadString('\n')
if err != nil { log.Fatal().Err(err).Msg("Error getting file") }
// If server responded with 200 OK
if response.StatusCode == http.StatusOK {
// Process received message
procMessage := strings.Split(message, ";")
// If non-ok code returned
if procMessage[0] != "OK" {
// fatally log
log.Fatal().Err(err).Msg("Sender reported error")
// Otherwise
} else {
// Create new file at index filepath
newFile, err := os.Create(opensendDir + "/" + file)
if err != nil { log.Fatal().Err(err).Msg("Error creating file") }
// Decode file data from hex string
fileData, err := hex.DecodeString(strings.TrimSpace(procMessage[1]))
if err != nil { log.Fatal().Err(err).Msg("Error decoding hex") }
// Copy response body to new file
bytesWritten, err := io.Copy(newFile, response.Body)
bytesWritten, err := io.Copy(newFile, bytes.NewBuffer(fileData))
if err != nil { log.Fatal().Err(err).Msg("Error writing to file") }
// Log bytes written
log.Info().Str("file", filepath.Base(file)).Msg("Wrote " + strconv.Itoa(int(bytesWritten)) + " bytes")
// Close new file
newFile.Close()
}
// Close response body
response.Body.Close()
}
}
// Send stop signal to sender's HTTP server
func SendSrvStopSignal(senderAddr string) {
// Get server address by getting the IP without the port, prepending http:// and appending :9898
serverAddr := "http://" + strings.Split(senderAddr, ":")[0] + ":9898"
// GET /stop on sender's HTTP servers ignoring any errors
_, _ = http.Get(serverAddr + "/stop")
// Send stop signal to sender
func SendSrvStopSignal(connection net.Conn) {
// Send stop signal to connection
_, _ = fmt.Fprintln(connection, "stop;")
// Close connection
_ = connection.Close()
}

View File

@ -1,13 +1,15 @@
package main
import (
"bufio"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"io/ioutil"
"net/http"
"net"
"os"
"strings"
)
@ -26,27 +28,28 @@ func GenerateRSAKeypair() (*rsa.PrivateKey, *rsa.PublicKey) {
}
// Get public key from sender
func GetKey(senderAddr string) []byte {
func GetKey(connection net.Conn) []byte {
// Use ConsoleWriter logger
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
// Get server address by getting the IP without the port, prepending http:// and appending :9898
serverAddr := "http://" + strings.Split(senderAddr, ":")[0] + ":9898"
// GET /key on the sender's HTTP server
response, err := http.Get(serverAddr + "/key")
// Send key request to connection
_, err := fmt.Fprintln(connection, "key;")
if err != nil { log.Fatal().Err(err).Msg("Error sending key request") }
// Read received message
message, err := bufio.NewReader(connection).ReadString('\n')
if err != nil { log.Fatal().Err(err).Msg("Error getting key") }
// Close response body at the end of this function
defer response.Body.Close()
// If server responded with 200 OK
if response.StatusCode == http.StatusOK {
// Read response body into key
key, err := ioutil.ReadAll(response.Body)
if err != nil { log.Fatal().Err(err).Msg("Error reading HTTP response") }
// Process received message
procMessage := strings.Split(strings.TrimSpace(message), ";")
// If ok code returned
if procMessage[0] == "OK" {
// Decode received hex string into key
key, err := hex.DecodeString(procMessage[1])
if err != nil { log.Fatal().Err(err).Msg("Error reading key") }
// Return key
return key
// Otherwise
} else {
// Fatally log status code
if err != nil { log.Fatal().Int("code", response.StatusCode).Msg("HTTP Error Response Code Received") }
// Fatally log
if err != nil { log.Fatal().Msg("Server reported error") }
}
// Return nil if all else fails
return nil

View File

@ -1,12 +1,16 @@
package main
import (
"fmt"
"github.com/rs/zerolog"
"net"
"os"
)
// Fatal hook to run in case of Fatal error
type FatalHook struct {}
// Run function on trigger
func (hook FatalHook) Run(_ *zerolog.Event, level zerolog.Level, _ string) {
// If log event is fatal
if level == zerolog.FatalLevel {
@ -14,3 +18,22 @@ func (hook FatalHook) Run(_ *zerolog.Event, level zerolog.Level, _ string) {
_ = os.RemoveAll(opensendDir)
}
}
// TCP Fatal hook to run in case of Fatal error with open TCP connection
type TCPFatalHook struct {
conn net.Conn
}
// Run function on trigger
func (hook TCPFatalHook) Run(_ *zerolog.Event, level zerolog.Level, _ string) {
// If log event is fatal
if level == zerolog.FatalLevel {
// Send error to connection
_, _ = fmt.Fprintln(hook.conn, "ERR;")
// Close connection
_ = hook.conn.Close()
// Attempt removal of opensend directory
_ = os.RemoveAll(opensendDir)
}
}

22
main.go
View File

@ -112,6 +112,14 @@ func main() {
// Get IP of chosen receiver
choiceIP = discoveredIPs[choiceIndex]
}
// Instantiate Config object
config := NewConfig(*actionType, *actionData)
// Validate data in config struct
config.Validate()
// Collect any files that may be required for transaction into opensend directory
config.CollectFiles(opensendDir)
// Create config file in opensend directory
config.CreateFile(opensendDir)
// Notify user of key exchange
log.Info().Msg("Performing key exchange")
// Exchange RSA keys with receiver
@ -122,12 +130,6 @@ func main() {
key := EncryptKey(sharedKey, rawKey)
// Save encrypted key in opensend directory as key.aes
SaveEncryptedKey(key, opensendDir + "/key.aes")
// Instantiate Config object
config := NewConfig(*actionType, *actionData)
// Collect any files that may be required for transaction into opensend directory
config.CollectFiles(opensendDir)
// Create config file in opensend directory
config.CreateFile(opensendDir)
// Notify user file encryption is beginning
log.Info().Msg("Encrypting files")
// Encrypt all files in opensend directory using shared key
@ -157,12 +159,14 @@ func main() {
time.Sleep(300*time.Millisecond)
// Notify user files are being received
log.Info().Msg("Receiving files from server (This may take a while)")
// Connect to sender's TCP socket
connection := ConnectToSender(senderIP)
// Get files from sender and place them into the opensend directory
RecvFiles(senderIP)
RecvFiles(connection)
// Get encrypted shared key from sender
encryptedKey := GetKey(senderIP)
encryptedKey := GetKey(connection)
// Send stop signal to sender's HTTP server
SendSrvStopSignal(senderIP)
SendSrvStopSignal(connection)
// Decrypt shared key
sharedKey := DecryptKey(encryptedKey, privateKey)
// Notify user file decryption is beginning