use deepphonemizer::Phonemizer; use inference::DeepPhonemizerInference; use std::{error::Error, sync::Arc}; use tch::Device; use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType}; use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _}; use twilight_http::Client as HttpClient; mod bot_config; mod inference; use db::{save::Save, types::GraphemeString}; #[tokio::main] async fn main() -> Result<(), Box> { // Initialize the tracing subscriber. tracing_subscriber::fmt::init(); let config = config::Config::builder() .add_source(config::Environment::with_prefix("GRU")) .build()? .try_deserialize::()?; let token = config.token; // Specify intents requesting events about things like new and updated // messages in a guild and direct messages. // Create a single shard. let mut shard = Shard::new( ShardId::ONE, token.clone(), Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT, ); // The http client is separate from the gateway, so startup a new // one, also use Arc such that it can be cloned to other threads. let http = Arc::new(HttpClient::new(token)); // Since we only care about messages, make the cache only process messages. let cache = DefaultInMemoryCache::builder() .resource_types(ResourceType::MESSAGE) .build(); let mut db = std::fs::File::open(config.db_path).unwrap(); let save: Save = bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap(); let phonemizer = Phonemizer::from_checkpoint( config.model_path, config.config_path, Device::cuda_if_available(), None, ) .unwrap(); let inference_service = Arc::new(DeepPhonemizerInference(phonemizer)); let db = Arc::new(db::inference::Inference::new(inference_service, save)); // Startup the event loop to process each event in the event stream as they // come in. while let Some(item) = shard.next_event(EventTypeFlags::all()).await { let Ok(event) = item else { tracing::warn!(source = ?item.unwrap_err(), "error receiving event"); continue; }; // Update the cache. cache.update(&event); // Spawn a new task to handle the event tokio::spawn(handle_event(event, Arc::clone(&http), db.clone())); } Ok(()) } async fn handle_event( event: Event, http: Arc, inference: Arc, ) -> Result<(), Box> { match event { Event::MessageCreate(msg) => { let infer = inference.infer(GraphemeString(msg.content.clone())).await.unwrap(); println!("{}", infer); } Event::Ready(_) => { println!("Shard is ready"); } _ => {} } Ok(()) }