Implement dir type and add Zstd compression

This commit is contained in:
Elara 2020-12-05 16:47:04 -08:00
parent 6e71ba1c1c
commit 5b976e6bac
4 changed files with 110 additions and 11 deletions

View File

@ -1,6 +1,7 @@
package main
import (
"archive/tar"
"encoding/json"
"github.com/pkg/browser"
"github.com/rs/zerolog"
@ -10,6 +11,7 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
)
// Create config type to store action type and data
@ -63,6 +65,42 @@ func (config *Config) CollectFiles(dir string) {
if err != nil { log.Fatal().Err(err).Msg("Error copying data to file") }
// Replace file path in config.ActionData with file name
config.ActionData = filepath.Base(config.ActionData)
} else if config.ActionType == "dir" {
// Create tar archive
tarFile, err := os.Create(dir + "/" + filepath.Base(config.ActionData) + ".tar")
if err != nil { log.Fatal().Err(err).Msg("Error creating file") }
// Close tar file at the end of this function
defer tarFile.Close()
// Create writer for tar archive
tarArchiver := tar.NewWriter(tarFile)
// Close archiver at the end of this function
defer tarArchiver.Close()
// Walk given directory
err = filepath.Walk(config.ActionData, func(path string, info os.FileInfo, err error) error {
// Return if error walking
if err != nil { return err }
// Skip if file is not normal mode
if !info.Mode().IsRegular() { return nil }
// Create tar header for file
header, err := tar.FileInfoHeader(info, info.Name())
if err != nil { return err }
// Change header name to reflect decompressed filepath
header.Name = strings.TrimPrefix(strings.ReplaceAll(path, config.ActionData, ""), string(filepath.Separator))
// Write header to archive
if err := tarArchiver.WriteHeader(header); err != nil { return err }
// Open source file
src, err := os.Open(path)
if err != nil { return err }
// Close source file at the end of this function
defer src.Close()
// Copy source bytes to tar archive
if _, err := io.Copy(tarArchiver, src); err != nil { return err }
// Return at the end of the function
return nil
})
if err != nil { log.Fatal().Err(err).Msg("Error creating tar archive") }
// Set config data to base path for receiver
config.ActionData = filepath.Base(config.ActionData)
}
}
@ -80,6 +118,9 @@ func (config *Config) ReadFile(filePath string) {
// Execute action specified in config
func (config *Config) ExecuteAction(srcDir string) {
// Get user's home directory
homeDir, err := os.UserHomeDir()
if err != nil { log.Fatal().Err(err).Msg("Error getting home directory") }
// Use ConsoleWriter logger
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
// If action is file
@ -89,9 +130,6 @@ func (config *Config) ExecuteAction(srcDir string) {
if err != nil { log.Fatal().Err(err).Msg("Error reading file from config") }
// Close source file at the end of this function
defer src.Close()
// Get user's home directory
homeDir, err := os.UserHomeDir()
if err != nil { log.Fatal().Err(err).Msg("Error getting home directory") }
// Create file in user's Downloads directory
dst, err := os.Create(homeDir + "/Downloads/" + config.ActionData)
if err != nil { log.Fatal().Err(err).Msg("Error creating file") }
@ -105,6 +143,51 @@ func (config *Config) ExecuteAction(srcDir string) {
// Attempt to open URL in browser
err := browser.OpenURL(config.ActionData)
if err != nil { log.Fatal().Err(err).Msg("Error opening browser") }
// If action is dir
} else if config.ActionType == "dir" {
// Set destination directory to ~/Downloads/{dir name}
dstDir := homeDir + "/Downloads/" + config.ActionData
// Try to create destination directory
err := os.Mkdir(dstDir, 0755)
if err != nil { log.Fatal().Err(err).Msg("Error creating directory") }
// Try to open tar archive file
tarFile, err := os.Open(srcDir + "/" + config.ActionData + ".tar")
if err != nil { log.Fatal().Err(err).Msg("Error opening tar archive") }
// Close tar archive file at the end of this function
defer tarFile.Close()
// Create tar reader to unarchive tar archive
tarUnarchiver := tar.NewReader(tarFile)
// Loop to recursively unarchive tar file
unarchiveLoop: for {
// Jump to next header in tar archive
header, err := tarUnarchiver.Next()
switch {
// If EOF
case err == io.EOF:
// break loop
break unarchiveLoop
case err != nil:
log.Fatal().Err(err).Msg("Error unarchiving tar archive")
// If nil header
case header == nil:
// Skip
continue
}
// Set target path to header name in destination dir
targetPath := filepath.Join(dstDir, header.Name)
switch header.Typeflag {
// If regular file
case tar.TypeReg:
// Try to create containing folder ignoring errors
_ = os.MkdirAll(strings.TrimSuffix(targetPath, filepath.Base(targetPath)), 0755)
// Create file with mode contained in header at target path
dstFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode))
if err != nil { log.Fatal().Err(err).Msg("Error creating file during unarchiving") }
// Copy data from tar archive into file
_, err = io.Copy(dstFile, tarUnarchiver)
if err != nil { log.Fatal().Err(err).Msg("Error copying data to file") }
}
}
// Catchall
} else {
// Log unknown action type

View File

@ -1,11 +1,13 @@
package main
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"encoding/hex"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"io"
@ -17,12 +19,20 @@ import (
)
// Encrypt given file using the shared key
func EncryptFile(filePath string, newFilePath string, sharedKey string) {
func CompressAndEncryptFile(filePath string, newFilePath string, sharedKey string) {
// Use ConsoleWriter logger
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
// Read data from file
data, err := ioutil.ReadFile(filePath)
file, err := os.Open(filePath)
if err != nil { log.Fatal().Err(err).Msg("Error opening file") }
compressedBuffer := new(bytes.Buffer)
zstdEncoder, err := zstd.NewWriter(compressedBuffer)
if err != nil { log.Fatal().Err(err).Msg("Error creating Zstd encoder") }
_, err = io.Copy(zstdEncoder, file)
if err != nil { log.Fatal().Err(err).Msg("Error reading file") }
zstdEncoder.Close()
data, err := ioutil.ReadAll(compressedBuffer)
if err != nil { log.Fatal().Err(err).Msg("Error reading compressed buffer") }
// Create md5 hash of password in order to make it the required size
md5Hash := md5.New()
md5Hash.Write([]byte(sharedKey))
@ -54,7 +64,7 @@ func EncryptFile(filePath string, newFilePath string, sharedKey string) {
}
// Decrypt given file using the shared key
func DecryptFile(filePath string, newFilePath string, sharedKey string) {
func DecryptAndDecompressFile(filePath string, newFilePath string, sharedKey string) {
// Use ConsoleWriter logger
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{})
// Read data from file
@ -76,16 +86,19 @@ func DecryptFile(filePath string, newFilePath string, sharedKey string) {
// Decrypt data
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil { log.Fatal().Err(err).Msg("Error decrypting data") }
zstdDecoder, err := zstd.NewReader(bytes.NewBuffer(plaintext))
if err != nil { log.Fatal().Err(err).Msg("Error creating Zstd decoder") }
// Create new file
newFile, err := os.Create(newFilePath)
if err != nil { log.Fatal().Err(err).Msg("Error creating file") }
// Defer file close
defer newFile.Close()
// Write ciphertext to new file
bytesWritten, err := newFile.Write(plaintext)
// Write plaintext to new file
bytesWritten, err := io.Copy(newFile, zstdDecoder)
if err != nil { log.Fatal().Err(err).Msg("Error writing to file") }
zstdDecoder.Close()
// Log bytes written and to which file
log.Info().Str("file", filepath.Base(newFilePath)).Msg("Wrote " + strconv.Itoa(bytesWritten) + " bytes")
log.Info().Str("file", filepath.Base(newFilePath)).Msg("Wrote " + strconv.Itoa(int(bytesWritten)) + " bytes")
}
// Encrypt files in given directory using shared key
@ -99,7 +112,7 @@ func EncryptFiles(dir string, sharedKey string) {
// If file is not a directory and is not the key
if !info.IsDir() && !strings.Contains(path, "key.aes"){
// Encrypt the file using shared key, appending .enc
EncryptFile(path, path + ".enc", sharedKey)
CompressAndEncryptFile(path, path + ".zst.enc", sharedKey)
// Remove unencrypted file
err := os.Remove(path)
if err != nil { return err }
@ -121,7 +134,7 @@ func DecryptFiles(dir string, sharedKey string) {
// If file is not a directory and is encrypted
if !info.IsDir() && strings.Contains(path, ".enc") {
// Decrypt the file using the shared key, removing .enc
DecryptFile(path, strings.TrimSuffix(path, ".enc"), sharedKey)
DecryptAndDecompressFile(path, strings.TrimSuffix(path, ".zst.enc"), sharedKey)
}
// Return nil if no errors occurred
return nil

1
go.mod
View File

@ -4,6 +4,7 @@ go 1.15
require (
github.com/grandcat/zeroconf v1.0.0
github.com/klauspost/compress v1.11.3
github.com/pkg/browser v0.0.0-20201112035734-206646e67786
github.com/rs/zerolog v1.20.0
)

2
go.sum
View File

@ -3,6 +3,8 @@ github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QH
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE=
github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs=
github.com/klauspost/compress v1.11.3 h1:dB4Bn0tN3wdCzQxnS8r06kV74qN/TAfaIS0bVE8h3jc=
github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM=
github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/pkg/browser v0.0.0-20201112035734-206646e67786 h1:4Gk0Dsp90g2YwfsxDOjvkEIgKGh+2R9FlvormRycveA=