/* * Copyright (C) 2021 Arsen Musayelyan * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ package main import ( "bufio" "bytes" ds "github.com/asticode/go-astideepspeech" "github.com/gen2brain/malgo" flag "github.com/spf13/pflag" "net" "os" "os/signal" "path/filepath" "strconv" "strings" "syscall" "time" ) var verbose *bool var execDir string var configDir string func main() { // Configure environment (paths to resources) var gopath, confPath string gopath, configDir, execDir, confPath = configEnv() // Define and parse command line flags tfLogLevel := flag.Int("tf-log-level", 2, "Log level for TensorFlow") verbose = flag.BoolP("verbose", "v", false, "Log more events") showDecode := flag.BoolP("show-decode", "d", false, "Show text to speech decodes") configPath := flag.StringP("config", "c", confPath, "Location of trident TOML config") modelPath := flag.StringP("model", "m", filepath.Join(execDir, "deepspeech.pbmm"), "Path to DeepSpeech model") scorerPath := flag.StringP("scorer", "s", filepath.Join(execDir, "deepspeech.scorer"), "Path to DeepSpeech scorer") socketPath := flag.StringP("socket", "S", filepath.Join(configDir, "trident.sock"), "Path to UNIX socket for IPC") GOPATH := flag.String("gopath", gopath, "GOPATH for use with plugins") flag.Parse() // Set TensorFlow log level to specified level (default 2) _ = os.Setenv("TF_CPP_MIN_LOG_LEVEL", strconv.Itoa(*tfLogLevel)) // Get and parse TOML config config, err := getConfig(*configPath) if err != nil { log.Fatal().Err(err).Msg("Error getting TOML config") } // Create new channel storing os.Signal sigChannel := make(chan os.Signal, 1) // Notify channel upon reception of specified signals signal.Notify(sigChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, ) // Create new goroutine to handle signals gracefully go func() { // Wait for signal sig := <-sigChannel // Log reception of signal log.Info().Str("signal", sig.String()).Msg("Received signal, shutting down") // If IPC is enabled in the config, remove the UNIX socket if config.IPCEnabled { _ = os.RemoveAll(*socketPath) } // Exit with code 0 os.Exit(0) }() // Create new DeepSpeech model model, err := ds.New(*modelPath) if err != nil { log.Fatal().Err(err).Msg("Error opening DeepSpeech model") } // Initialize available plugins plugins := initPlugins(*GOPATH) // If IPC is enabled in config if config.IPCEnabled { // Remove UNIX socket ignoring error _ = os.RemoveAll(*socketPath) // Listen on UNIX socket ln, err := net.Listen("unix", *socketPath) if err != nil { log.Fatal().Err(err).Msg("Error listening on UNIX socket") } go func() { for { // Accept any connection when it arrives conn, err := ln.Accept() if err != nil { log.Fatal().Err(err).Msg("Error accepting connection") } go func(conn net.Conn) { // Close connection at end of function defer conn.Close() // Create new scanner for connection (default is ScanLines) scanner := bufio.NewScanner(conn) // Scan until EOF for scanner.Scan() { // If error encountered, return from function if scanner.Err() != nil { return } // Get text from scanner input := scanner.Text() // Attempt to match text to action and return action action, ok := getAction(config, &input) // If match founc if ok { // Log performing action log.Info().Str("action", action.Name).Str("source", "socket").Msg("Performing action") // Perform returned action done, err := performAction(action, &input, plugins) if err != nil { log.Warn().Err(err).Str("action", action.Name).Msg("Error performing configured action") } // If action complete, close connection and return if done { conn.Close() return } } } }(conn) } }() } // Initialize audio context ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, func(message string) { log.Warn().Msg(message) }) if err != nil { log.Fatal().Err(err).Msg("Error initializing malgo context") } // Uninitialize and free at end of function defer func() { _ = ctx.Uninit() ctx.Free() }() // Set device configuration options deviceConfig := malgo.DefaultDeviceConfig(malgo.Capture) deviceConfig.Capture.Format = malgo.FormatS16 deviceConfig.Capture.Channels = 1 deviceConfig.Playback.Format = malgo.FormatS16 deviceConfig.Playback.Channels = 1 deviceConfig.SampleRate = uint32(model.SampleRate()) deviceConfig.Alsa.NoMMap = 1 // Create new buffer to store audio samples captured := &bytes.Buffer{} onRecvFrames := func(_, sample []byte, _ uint32) { // Upon receipt of sample, write to buffer captured.Write(sample) } log.Info().Msg("Listening to audio events") // Initialize audio device using configuration options device, err := malgo.InitDevice(ctx.Context, deviceConfig, malgo.DeviceCallbacks{ Data: onRecvFrames, }) if err != nil { log.Fatal().Err(err).Msg("Error initializing audio device") } // Uninitialize at end of function defer device.Uninit() // Start capture device (begin recording) err = device.Start() if err != nil { log.Fatal().Err(err).Msg("Error starting capture device") } // Set DeepSpeech scorer err = model.EnableExternalScorer(*scorerPath) if err != nil { log.Fatal().Err(err).Msg("Error opening DeepSpeech scorer") } // Create new stream for DeepSpeech model sttStream, err := model.NewStream() if err != nil { log.Fatal().Err(err).Msg("Error creating DeepSpeech stream") } // Create a safe stream using sync.Mutex safeStream := &SafeStream{Stream: sttStream} // Create goroutine to clean stream every minute go func() { for { time.Sleep(20 * time.Second) // Lock mutex of stream safeStream.Lock() // Reset stream and buffer resetStream(safeStream, model, captured) if *verbose { log.Debug().Msg("1m passed; cleaning stream") } // Unlock mutex of stream safeStream.Unlock() } }() var tts string listenForActivation := true for { time.Sleep(200 * time.Millisecond) // Convert captured raw audio to slice of int16 slice, err := convToInt16Slice(captured) if err != nil { log.Fatal().Err(err).Msg("Error converting captured audio feed") } // Reset buffer captured.Reset() // Lock mutex of stream safeStream.Lock() // Feed converted audio to stream safeStream.FeedAudioContent(slice) // Decode stream without destroying tts, err = safeStream.IntermediateDecode() if err != nil { log.Fatal().Err(err).Msg("Error intermediate decoding stream") } if *showDecode { log.Debug().Msg("TTS Decode: " + tts) } // If decoded string contains activation phrase and listenForActivation is true if strings.Contains(tts, config.ActivationPhrase) && listenForActivation { // Play activation tone err = playActivationTone(ctx) if err != nil { log.Fatal().Err(err).Msg("Error playing activation tone") } // Log detection of activation phrase log.Info().Msg("Activation phrase detected") // Reset stream and buffer resetStream(safeStream, model, captured) // Create new goroutine to listen for commands go func() { // Disable activation listenForActivation = false // Enable activation at end of function defer func() { listenForActivation = true }() // Create timeout channel to trigger after configured time timeout := time.After(config.ActivationTime) activationLoop: for { time.Sleep(100 * time.Millisecond) select { // If timeout has elapsed case <-timeout: log.Warn().Msg("Unknown command") break activationLoop // If timeout has not elapsed default: // Attempt to match decoded string to action action, ok := getAction(config, &tts) // If match found if ok { // Keep listening if user is talking for { // Get length of text to speech string ttsLen := len(tts) time.Sleep(time.Second) // If length has not changed if ttsLen == len(tts) { // Break out of for loop break } } // Log performing action log.Info().Str("action", action.Name).Str("source", "voice").Msg("Performing action") // Perform action matched by getAction() done, err := performAction(action, &tts, plugins) if err != nil { log.Warn().Err(err).Str("action", action.Name).Msg("Error performing configured action") } // If action is complete if done { // Lock mutex of stream safeStream.Lock() // Reset stream and buffer resetStream(safeStream, model, captured) // Unlock mutex of stream safeStream.Unlock() // Return from goroutine return } } } } }() } // Unlock mutex of stream safeStream.Unlock() } } // Function to reset stream and buffer func resetStream(s *SafeStream, model *ds.Model, captured *bytes.Buffer) { // Reset buffer captured.Reset() // Discard stream (workaround for lack of Clear function) s.Discard() // Create new stream, setting it to same location as old s.Stream, _ = model.NewStream() }