// Source: trident/main.go (330 lines, 9.7 KiB, Go)
// Last modified: 2021-04-22 02:29:14 +00:00
/*
* Copyright (C) 2021 Arsen Musayelyan
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package main
import (
	"bufio"
	"bytes"
	"net"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	ds "github.com/asticode/go-astideepspeech"
	"github.com/gen2brain/malgo"
	flag "github.com/spf13/pflag"
)
// Package-level settings populated by main.
var (
	// verbose is bound to the --verbose/-v flag; when true, extra events are logged.
	verbose *bool
	// execDir is the directory returned by configEnv; main uses it as the
	// default location of the DeepSpeech model and scorer files.
	execDir string
	// configDir is the directory returned by configEnv; main uses it as the
	// default location of the IPC UNIX socket.
	configDir string
)
func main() {
// Configure environment (paths to resources)
var gopath, confPath string
gopath, configDir, execDir, confPath = configEnv()
// Define and parse command line flags
tfLogLevel := flag.Int("tf-log-level", 2, "Log level for TensorFlow")
verbose = flag.BoolP("verbose", "v", false, "Log more events")
showDecode := flag.BoolP("show-decode", "d", false, "Show text to speech decodes")
2021-04-22 02:29:14 +00:00
configPath := flag.StringP("config", "c", confPath, "Location of trident TOML config")
modelPath := flag.StringP("model", "m", filepath.Join(execDir, "deepspeech.pbmm"), "Path to DeepSpeech model")
scorerPath := flag.StringP("scorer", "s", filepath.Join(execDir, "deepspeech.scorer"), "Path to DeepSpeech scorer")
socketPath := flag.StringP("socket", "S", filepath.Join(configDir, "trident.sock"), "Path to UNIX socket for IPC")
GOPATH := flag.String("gopath", gopath, "GOPATH for use with plugins")
flag.Parse()
// Set TensorFlow log level to specified level (default 2)
_ = os.Setenv("TF_CPP_MIN_LOG_LEVEL", strconv.Itoa(*tfLogLevel))
// Get and parse TOML config
config, err := getConfig(*configPath)
if err != nil {
log.Fatal().Err(err).Msg("Error getting TOML config")
}
// Create new channel storing os.Signal
sigChannel := make(chan os.Signal, 1)
// Notify channel upon reception of specified signals
signal.Notify(sigChannel,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGHUP,
syscall.SIGQUIT,
)
// Create new goroutine to handle signals gracefully
go func() {
// Wait for signal
sig := <-sigChannel
// Log reception of signal
log.Info().Str("signal", sig.String()).Msg("Received signal, shutting down")
// If IPC is enabled in the config, remove the UNIX socket
if config.IPCEnabled {
_ = os.RemoveAll(*socketPath)
}
// Exit with code 0
os.Exit(0)
}()
// Create new DeepSpeech model
model, err := ds.New(*modelPath)
if err != nil {
log.Fatal().Err(err).Msg("Error opening DeepSpeech model")
}
// Initialize available plugins
plugins := initPlugins(*GOPATH)
// If IPC is enabled in config
if config.IPCEnabled {
// Remove UNIX socket ignoring error
_ = os.RemoveAll(*socketPath)
// Listen on UNIX socket
ln, err := net.Listen("unix", *socketPath)
if err != nil {
log.Fatal().Err(err).Msg("Error listening on UNIX socket")
}
go func() {
for {
// Accept any connection when it arrives
conn, err := ln.Accept()
if err != nil {
log.Fatal().Err(err).Msg("Error accepting connection")
}
go func(conn net.Conn) {
// Close connection at end of function
defer conn.Close()
// Create new scanner for connection (default is ScanLines)
scanner := bufio.NewScanner(conn)
// Scan until EOF
for scanner.Scan() {
// If error encountered, return from function
if scanner.Err() != nil {
return
}
// Get text from scanner
input := scanner.Text()
// Attempt to match text to action and return action
action, ok := getAction(config, &input)
// If match founc
if ok {
// Log performing action
log.Info().Str("action", action.Name).Str("source", "socket").Msg("Performing action")
// Perform returned action
done, err := performAction(action, &input, plugins)
if err != nil {
log.Warn().Err(err).Str("action", action.Name).Msg("Error performing configured action")
}
// If action complete, close connection and return
if done {
conn.Close()
return
}
}
}
}(conn)
}
}()
}
// Initialize audio context
ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, func(message string) {
log.Warn().Msg(message)
})
if err != nil {
log.Fatal().Err(err).Msg("Error initializing malgo context")
}
// Uninitialize and free at end of function
defer func() {
_ = ctx.Uninit()
ctx.Free()
}()
// Set device configuration options
deviceConfig := malgo.DefaultDeviceConfig(malgo.Capture)
deviceConfig.Capture.Format = malgo.FormatS16
deviceConfig.Capture.Channels = 1
deviceConfig.Playback.Format = malgo.FormatS16
deviceConfig.Playback.Channels = 1
deviceConfig.SampleRate = uint32(model.SampleRate())
deviceConfig.Alsa.NoMMap = 1
// Create new buffer to store audio samples
captured := &bytes.Buffer{}
onRecvFrames := func(_, sample []byte, _ uint32) {
// Upon receipt of sample, write to buffer
captured.Write(sample)
}
log.Info().Msg("Listening to audio events")
// Initialize audio device using configuration options
device, err := malgo.InitDevice(ctx.Context, deviceConfig, malgo.DeviceCallbacks{
Data: onRecvFrames,
})
if err != nil {
log.Fatal().Err(err).Msg("Error initializing audio device")
}
// Uninitialize at end of function
defer device.Uninit()
// Start capture device (begin recording)
err = device.Start()
if err != nil {
log.Fatal().Err(err).Msg("Error starting capture device")
}
// Set DeepSpeech scorer
err = model.EnableExternalScorer(*scorerPath)
if err != nil {
log.Fatal().Err(err).Msg("Error opening DeepSpeech scorer")
}
// Create new stream for DeepSpeech model
sttStream, err := model.NewStream()
if err != nil {
log.Fatal().Err(err).Msg("Error creating DeepSpeech stream")
}
// Create a safe stream using sync.Mutex
safeStream := &SafeStream{Stream: sttStream}
// Create goroutine to clean stream every minute
go func() {
for {
time.Sleep(20*time.Second)
2021-04-22 02:29:14 +00:00
// Lock mutex of stream
safeStream.Lock()
// Reset stream and buffer
resetStream(safeStream, model, captured)
if *verbose {
log.Debug().Msg("1m passed; cleaning stream")
}
// Unlock mutex of stream
safeStream.Unlock()
}
}()
var tts string
listenForActivation := true
for {
time.Sleep(200*time.Millisecond)
2021-04-22 02:29:14 +00:00
// Convert captured raw audio to slice of int16
slice, err := convToInt16Slice(captured)
if err != nil {
log.Fatal().Err(err).Msg("Error converting captured audio feed")
}
// Reset buffer
captured.Reset()
// Lock mutex of stream
safeStream.Lock()
// Feed converted audio to stream
safeStream.FeedAudioContent(slice)
// Decode stream without destroying
tts, err = safeStream.IntermediateDecode()
if err != nil {
log.Fatal().Err(err).Msg("Error intermediate decoding stream")
}
if *showDecode {
log.Debug().Msg("TTS Decode: " + tts)
}
2021-04-22 02:29:14 +00:00
// If decoded string contains activation phrase and listenForActivation is true
if strings.Contains(tts, config.ActivationPhrase) && listenForActivation {
// Play activation tone
err = playActivationTone(ctx)
if err != nil {
log.Fatal().Err(err).Msg("Error playing activation tone")
}
// Log detection of activation phrase
log.Info().Msg("Activation phrase detected")
// Reset stream and buffer
resetStream(safeStream, model, captured)
// Create new goroutine to listen for commands
go func() {
// Disable activation
listenForActivation = false
// Enable activation at end of function
defer func() {
listenForActivation = true
}()
// Create timeout channel to trigger after configured time
timeout := time.After(config.ActivationTime)
activationLoop:
for {
time.Sleep(100 * time.Millisecond)
select {
// If timeout has elapsed
case <-timeout:
log.Warn().Msg("Unknown command")
break activationLoop
// If timeout has not elapsed
default:
// Attempt to match decoded string to action
action, ok := getAction(config, &tts)
// If match found
if ok {
// Keep listening if user is talking
for {
// Get length of text to speech string
ttsLen := len(tts)
time.Sleep(time.Second)
// If length has not changed
if ttsLen == len(tts) {
// Break out of for loop
break
}
}
// Log performing action
log.Info().Str("action", action.Name).Str("source", "voice").Msg("Performing action")
// Perform action matched by getAction()
done, err := performAction(action, &tts, plugins)
if err != nil {
log.Warn().Err(err).Str("action", action.Name).Msg("Error performing configured action")
}
// If action is complete
if done {
// Lock mutex of stream
safeStream.Lock()
// Reset stream and buffer
resetStream(safeStream, model, captured)
// Unlock mutex of stream
safeStream.Unlock()
// Return from goroutine
return
}
}
}
}
}()
}
// Unlock mutex of stream
safeStream.Unlock()
}
}
// resetStream clears the captured audio buffer and replaces s's DeepSpeech
// stream with a fresh one created from model. The stream type has no Clear
// function, so the old stream is discarded and a new one takes its place.
//
// The new stream is created before the old one is discarded. If creation
// fails, the error is returned and s keeps its existing stream, rather than
// being left with an unusable one that would crash later feeds and decodes.
// Callers in this file hold s's lock while calling resetStream.
func resetStream(s *SafeStream, model *ds.Model, captured *bytes.Buffer) error {
	captured.Reset()

	fresh, err := model.NewStream()
	if err != nil {
		return err
	}
	s.Discard()
	s.Stream = fresh
	return nil
}