diff options
| author | Matthieu Pignolet <matthieu@puffer.fish> | 2025-10-30 16:04:25 +0100 |
|---|---|---|
| committer | Matthieu Pignolet <matthieu@puffer.fish> | 2025-10-30 16:04:25 +0100 |
| commit | 09a0afaff012b88f79afde35a3c2ebaeb010a176 (patch) | |
| tree | 874fe676977f56c5456e4cead31e1a6c86bc40dd /bin/bot/src/main.rs | |
| parent | c951f54be999fb7e508039ba23fa5b1fe7035743 (diff) | |
feat: commit allfeat/twilight-discord
Diffstat (limited to 'bin/bot/src/main.rs')
| -rw-r--r-- | bin/bot/src/main.rs | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/bin/bot/src/main.rs b/bin/bot/src/main.rs new file mode 100644 index 0000000..574b417 --- /dev/null +++ b/bin/bot/src/main.rs @@ -0,0 +1,97 @@ +use deepphonemizer::Phonemizer; +use inference::DeepPhonemizerInference; +use std::{error::Error, sync::Arc}; +use tch::Device; +use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType}; +use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _}; +use twilight_http::Client as HttpClient; + +mod bot_config; +mod inference; + +use db::{save::Save, types::GraphemeString}; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn Error + Send + Sync>> { + // Initialize the tracing subscriber. + tracing_subscriber::fmt::init(); + + let config = config::Config::builder() + .add_source(config::Environment::with_prefix("GRU")) + .build()? + .try_deserialize::<bot_config::Config>()?; + let token = config.token; + + // Specify intents requesting events about things like new and updated + // messages in a guild and direct messages. + + // Create a single shard. + let mut shard = Shard::new( + ShardId::ONE, + token.clone(), + Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT, + ); + + // The http client is separate from the gateway, so startup a new + // one, also use Arc such that it can be cloned to other threads. + let http = Arc::new(HttpClient::new(token)); + + // Since we only care about messages, make the cache only process messages. + let cache = DefaultInMemoryCache::builder() + .resource_types(ResourceType::MESSAGE) + .build(); + + let mut db = std::fs::File::open(config.db_path).unwrap(); + + let save: Save = + bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap(); + + let phonemizer = Phonemizer::from_checkpoint( + config.model_path, + config.config_path, + Device::cuda_if_available(), + None, + ) + .unwrap(); + + let inference_service = Arc::new(DeepPhonemizerInference(phonemizer)); + + let db = Arc::new(db::inference::Inference::new(inference_service, save)); + + // Startup the event loop to process each event in the event stream as they + // come in. + while let Some(item) = shard.next_event(EventTypeFlags::all()).await { + let Ok(event) = item else { + tracing::warn!(source = ?item.unwrap_err(), "error receiving event"); + + continue; + }; + // Update the cache. + cache.update(&event); + + // Spawn a new task to handle the event + tokio::spawn(handle_event(event, Arc::clone(&http), db.clone())); + } + + Ok(()) +} + +async fn handle_event( + event: Event, + http: Arc<HttpClient>, + inference: Arc<db::inference::Inference>, +) -> Result<(), Box<dyn Error + Send + Sync>> { + match event { + Event::MessageCreate(msg) => { + let infer = inference.infer(GraphemeString(msg.content.clone())).await.unwrap(); + + println!("{}", infer); + } + Event::Ready(_) => { + println!("Shard is ready"); + } + _ => {} + } + + Ok(()) +} |
