summaryrefslogtreecommitdiff
path: root/bin/bot/src/main.rs
diff options
context:
space:
mode:
authorMatthieu Pignolet <matthieu@puffer.fish>2025-10-30 16:04:25 +0100
committerMatthieu Pignolet <matthieu@puffer.fish>2025-10-30 16:04:25 +0100
commit09a0afaff012b88f79afde35a3c2ebaeb010a176 (patch)
tree874fe676977f56c5456e4cead31e1a6c86bc40dd /bin/bot/src/main.rs
parentc951f54be999fb7e508039ba23fa5b1fe7035743 (diff)
feat: commit allfeat/twilight-discord
Diffstat (limited to 'bin/bot/src/main.rs')
-rw-r--r--bin/bot/src/main.rs97
1 files changed, 97 insertions, 0 deletions
diff --git a/bin/bot/src/main.rs b/bin/bot/src/main.rs
new file mode 100644
index 0000000..574b417
--- /dev/null
+++ b/bin/bot/src/main.rs
@@ -0,0 +1,97 @@
+use deepphonemizer::Phonemizer;
+use inference::DeepPhonemizerInference;
+use std::{error::Error, sync::Arc};
+use tch::Device;
+use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType};
+use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
+use twilight_http::Client as HttpClient;
+
+mod bot_config;
+mod inference;
+
+use db::{save::Save, types::GraphemeString};
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
+ // Initialize the tracing subscriber.
+ tracing_subscriber::fmt::init();
+
+ let config = config::Config::builder()
+ .add_source(config::Environment::with_prefix("GRU"))
+ .build()?
+ .try_deserialize::<bot_config::Config>()?;
+ let token = config.token;
+
+ // Specify intents requesting events about things like new and updated
+ // messages in a guild and direct messages.
+
+ // Create a single shard.
+ let mut shard = Shard::new(
+ ShardId::ONE,
+ token.clone(),
+ Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT,
+ );
+
+ // The http client is separate from the gateway, so startup a new
+ // one, also use Arc such that it can be cloned to other threads.
+ let http = Arc::new(HttpClient::new(token));
+
+ // Since we only care about messages, make the cache only process messages.
+ let cache = DefaultInMemoryCache::builder()
+ .resource_types(ResourceType::MESSAGE)
+ .build();
+
+ let mut db = std::fs::File::open(config.db_path).unwrap();
+
+ let save: Save =
+ bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+ let phonemizer = Phonemizer::from_checkpoint(
+ config.model_path,
+ config.config_path,
+ Device::cuda_if_available(),
+ None,
+ )
+ .unwrap();
+
+ let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+ let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+ // Startup the event loop to process each event in the event stream as they
+ // come in.
+ while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
+ let Ok(event) = item else {
+ tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
+
+ continue;
+ };
+ // Update the cache.
+ cache.update(&event);
+
+ // Spawn a new task to handle the event
+ tokio::spawn(handle_event(event, Arc::clone(&http), db.clone()));
+ }
+
+ Ok(())
+}
+
+async fn handle_event(
+ event: Event,
+ http: Arc<HttpClient>,
+ inference: Arc<db::inference::Inference>,
+) -> Result<(), Box<dyn Error + Send + Sync>> {
+ match event {
+ Event::MessageCreate(msg) => {
+ let infer = inference.infer(GraphemeString(msg.content.clone())).await.unwrap();
+
+ println!("{}", infer);
+ }
+ Event::Ready(_) => {
+ println!("Shard is ready");
+ }
+ _ => {}
+ }
+
+ Ok(())
+}