This repository has been archived on 2021-07-08. You can view files and clone it, but cannot push or open issues or pull requests.
opensend/files.go

291 lines
9.0 KiB
Go

/*
Copyright © 2021 Arsen Musayelyan
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"bufio"
"bytes"
"ekyu.moe/base91"
"errors"
"fmt"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"io"
"io/ioutil"
"net"
"os"
"path/filepath"
"strconv"
"strings"
)
// EncodeToSafeString returns data encoded as base91 with every ";" in the
// encoded text replaced by "-", because ";" is the field separator of the
// transfer protocol used in this file.
func EncodeToSafeString(data []byte) string {
	return strings.ReplaceAll(base91.EncodeToString(data), ";", "-")
}
// DecodeSafeString reverses EncodeToSafeString: every "-" is turned back
// into ";" and the result is base91-decoded. The returned error is always
// nil; it exists only as part of the function's signature.
func DecodeSafeString(base91Str string) ([]byte, error) {
	unescaped := strings.ReplaceAll(base91Str, "-", ";")
	return base91.DecodeString(unescaped), nil
}
// SaveEncryptedKey writes encryptedKey to a newly created (or truncated)
// file at filePath and logs the number of bytes written.
//
// Any failure to create, write or close the file is reported through
// log.Fatal.
func SaveEncryptedKey(encryptedKey []byte, filePath string) {
	// Use ConsoleWriter logger
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
	// Create file at given file path
	keyFile, err := os.Create(filePath)
	if err != nil {
		log.Fatal().Err(err).Msg("Error creating file")
	}
	// Write encrypted key to file
	bytesWritten, err := keyFile.Write(encryptedKey)
	if err != nil {
		// Best-effort cleanup; the write error is the one worth reporting.
		_ = keyFile.Close()
		log.Fatal().Err(err).Msg("Error writing key to file")
	}
	// Close explicitly instead of deferring so that a failure to flush the
	// key to disk is reported rather than silently dropped.
	if err := keyFile.Close(); err != nil {
		log.Fatal().Err(err).Msg("Error closing file")
	}
	// Log bytes written
	log.Info().Str("file", filepath.Base(filePath)).Msg("Wrote " + strconv.Itoa(bytesWritten) + " bytes")
}
// SendFiles serves the contents of dir to a single receiver over TCP.
//
// It listens on port 9898, accepts exactly one connection and answers
// newline-terminated, semicolon-separated requests until a "stop" request
// arrives or the peer closes the connection:
//
//	key;          replies "OK;<base91 of dir/key.aes>;"
//	index;        replies "OK;<name>|<name>|...;" (names containing "key.aes" are omitted)
//	file;<name>;  replies "OK;<base91 of dir/name>;", or "ERR;" for an invalid request
//	stop;         replies "OK;" and returns
//
// Requests with any other first field are ignored without a reply.
// Unrecoverable errors are reported through log.Fatal.
func SendFiles(dir string) {
	// Use ConsoleWriter logger with normal FatalHook
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
	// Create TCP listener on port 9898
	listener, err := net.Listen("tcp", ":9898")
	if err != nil {
		log.Fatal().Err(err).Msg("Error starting listener")
	}
	// Release the port when this function returns
	defer listener.Close()
	// Accept connection on listener
	connection, err := listener.Accept()
	if err != nil {
		log.Fatal().Err(err).Msg("Error accepting connection")
	}
	// Close connection at the end of this function
	defer connection.Close()
	// Use ConsoleWriter logger with TCPFatalHook. The connection never
	// changes, so this is set up once rather than on every iteration.
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(TCPFatalHook{conn: connection})
	// A single reader must be shared by all reads: a bufio.Reader may buffer
	// bytes past the first newline, and those would be lost if a fresh reader
	// were created for every request.
	reader := bufio.NewReader(connection)
connectionLoop:
	for {
		// Attempt to read new message on connection
		data, err := reader.ReadString('\n')
		// EOF means the peer closed the connection; retrying would spin
		// forever, so stop serving instead.
		if errors.Is(err, io.EOF) {
			log.Warn().Msg("Connection closed before stop signal")
			break connectionLoop
		}
		// If non-EOF error, fatally log
		if err != nil {
			log.Fatal().Err(err).Msg("Error reading data")
		}
		// Split the request into fields. strings.Split always yields at least
		// one element, so indexing field 0 is safe.
		processedData := strings.Split(strings.TrimSpace(data), ";")
		switch processedData[0] {
		case "key":
			// Inform user client has requested key
			log.Info().Msg("Key requested")
			// Read saved key
			key, err := ioutil.ReadFile(filepath.Join(dir, "key.aes"))
			if err != nil {
				log.Fatal().Err(err).Msg("Error reading key")
			}
			// Write saved key to connection
			_, err = fmt.Fprintln(connection, "OK;"+EncodeToSafeString(key)+";")
			if err != nil {
				log.Fatal().Err(err).Msg("Error writing response")
			}
		case "index":
			// Inform user a client has requested the file index
			log.Info().Msg("Index requested")
			// Get directory listing
			dirListing, err := ioutil.ReadDir(dir)
			if err != nil {
				log.Fatal().Err(err).Msg("Error reading directory")
			}
			// Collect every entry name except the key
			var indexSlice []string
			for _, file := range dirListing {
				// NOTE(review): this is a substring match, so any name that
				// merely contains "key.aes" is also left out of the index.
				if !strings.Contains(file.Name(), "key.aes") {
					indexSlice = append(indexSlice, file.Name())
				}
			}
			// Join index slice into string
			indexStr := strings.Join(indexSlice, "|")
			// Write index to connection
			_, err = fmt.Fprintln(connection, "OK;"+indexStr+";")
			if err != nil {
				log.Fatal().Err(err).Msg("Error writing response")
			}
		case "file":
			// A request consisting of only "file" carries no file name
			if len(processedData) == 1 {
				log.Warn().Err(errors.New("unexpected eol")).Msg("Invalid file request")
				_, _ = fmt.Fprintln(connection, "ERR;")
				continue
			}
			file := processedData[1]
			// The name comes from the network. Only plain names directly
			// inside dir are served; anything carrying path components (such
			// as "../x" or an absolute path) is refused so a peer cannot read
			// files outside dir.
			if file != filepath.Base(file) || file == "." || file == ".." || file == string(filepath.Separator) {
				log.Warn().Err(errors.New("file name must not contain path components")).Msg("Invalid file request")
				_, _ = fmt.Fprintln(connection, "ERR;")
				continue
			}
			// Read file at specified location
			fileData, err := ioutil.ReadFile(filepath.Join(dir, file))
			if err != nil {
				// Best effort: warn and fall through, which answers with an
				// empty payload instead of aborting the whole transfer.
				log.Warn().Err(err).Msg("Error reading file")
			} else {
				// Inform user client has requested a file
				log.Info().Str("file", file).Msg("File requested")
			}
			// Write file as base91 to connection
			_, err = fmt.Fprintln(connection, "OK;"+EncodeToSafeString(fileData)+";")
			if err != nil {
				log.Fatal().Err(err).Msg("Error writing response")
			}
		case "stop":
			// Alert user that stop signal has been received
			log.Info().Msg("Received stop signal")
			// Acknowledge; the peer may already be gone, so the result is ignored
			_, _ = fmt.Fprintln(connection, "OK;")
			// Break out of connectionLoop
			break connectionLoop
		}
	}
}
// ConnectToSender dials the sender's transfer port (9898) and returns the
// established connection.
//
// senderAddr may be a bare host or a "host:port" pair; any port it carries
// is discarded and replaced with 9898. IPv6 literals, with or without
// brackets, are supported.
//
// While the sender refuses the connection the dial is retried indefinitely.
// Any other dial error is reported through log.Fatal.
func ConnectToSender(senderAddr string) net.Conn {
	// Extract the host. Splitting on the first ":" would truncate IPv6
	// literals, so let the net package parse the address.
	host, _, err := net.SplitHostPort(senderAddr)
	if err != nil {
		// No usable port part: treat the whole string as the host, dropping
		// the brackets of a bare "[ipv6]" literal so JoinHostPort does not
		// add a second pair.
		host = strings.TrimSuffix(strings.TrimPrefix(senderAddr, "["), "]")
	}
	serverAddr := net.JoinHostPort(host, "9898")
	for {
		// Try connecting to sender
		connection, err := net.Dial("tcp", serverAddr)
		if err == nil {
			// Return created connection
			return connection
		}
		// A refused connection means the sender is not listening yet, so
		// retry. NOTE(review): the retry has no delay between attempts.
		if !strings.Contains(err.Error(), "connection refused") {
			log.Fatal().Err(err).Msg("Error connecting to sender")
		}
	}
}
// RecvFiles downloads every file offered by the sender on connection.
//
// It first sends "index;" and expects "OK;<name>|<name>|...;" in return.
// For each listed name it then sends "file;<name>;", expects
// "OK;<base91 data>;", decodes the payload and writes it to a file of the
// same name inside the directory named by *workDir (declared outside this
// block).
//
// Any protocol, decoding or file system error is reported through
// log.Fatal.
func RecvFiles(connection net.Conn) {
	// Use ConsoleWriter logger
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
	// One reader for the whole exchange, so bytes buffered while reading one
	// reply are never dropped before the next read.
	reader := bufio.NewReader(connection)
	// Request index from sender
	_, err := fmt.Fprintln(connection, "index;")
	if err != nil {
		log.Fatal().Err(err).Msg("Error sending index request")
	}
	// Read received message
	message, err := reader.ReadString('\n')
	if err != nil {
		log.Fatal().Err(err).Msg("Error getting index")
	}
	// Process received message
	procMessage := strings.Split(strings.TrimSpace(message), ";")
	// If non-ok code returned, fatally log
	if procMessage[0] != "OK" {
		log.Fatal().Msg("Sender reported error")
		return
	}
	// An "OK" without a payload field would otherwise panic on indexing
	if len(procMessage) < 2 {
		log.Fatal().Str("data", message).Msg("Received data invalid")
		return
	}
	// Get index from message
	indexStr := strings.TrimSpace(procMessage[1])
	// Splitting an empty index would yield one empty name and a request for
	// a file that does not exist, so stop here instead.
	if indexStr == "" {
		log.Warn().Msg("Sender index is empty")
		return
	}
	for _, file := range strings.Split(indexStr, "|") {
		// The name is supplied by the sender. Accept only plain names so a
		// malicious index cannot make us write outside the work directory.
		if file != filepath.Base(file) || file == "." || file == ".." || file == string(filepath.Separator) {
			log.Fatal().Str("file", file).Msg("Sender sent invalid file name")
			return
		}
		// Get current file in index
		_, err = fmt.Fprintln(connection, "file;"+file+";")
		if err != nil {
			log.Fatal().Err(err).Msg("Error sending file request")
		}
		// Read received message
		message, err := reader.ReadString('\n')
		if err != nil {
			log.Fatal().Err(err).Msg("Error getting file")
		}
		// Process received message
		procMessage := strings.Split(strings.TrimSpace(message), ";")
		// If non-ok code returned, fatally log
		if procMessage[0] != "OK" {
			log.Fatal().Msg("Sender reported error")
			return
		}
		if len(procMessage) < 2 {
			log.Fatal().Str("data", message).Msg("Received data invalid")
			return
		}
		// Decode file data from base91 string before touching the disk
		fileData, err := DecodeSafeString(procMessage[1])
		if err != nil {
			log.Fatal().Err(err).Msg("Error decoding base91")
		}
		// Create new file at index filepath
		newFile, err := os.Create(filepath.Join(*workDir, file))
		if err != nil {
			log.Fatal().Err(err).Msg("Error creating file")
		}
		// Copy decoded data to new file
		bytesWritten, err := io.Copy(newFile, bytes.NewReader(fileData))
		if err != nil {
			// Best-effort cleanup; the write error is the one worth reporting.
			_ = newFile.Close()
			log.Fatal().Err(err).Msg("Error writing to file")
		}
		// A failed close can mean the data never reached the disk
		if err := newFile.Close(); err != nil {
			log.Fatal().Err(err).Msg("Error closing file")
		}
		// Log bytes written
		log.Info().Str("file", filepath.Base(file)).Msg("Wrote " + strconv.FormatInt(bytesWritten, 10) + " bytes")
	}
}
// SendSrvStopSignal tells the sender to stop by writing the line "stop;" to
// connection, then closes the connection. Both steps are best effort:
// their errors are deliberately ignored.
func SendSrvStopSignal(connection net.Conn) {
	// Runs after the write below, whether or not that write succeeded
	defer func() { _ = connection.Close() }()
	_, _ = fmt.Fprint(connection, "stop;\n")
}