summaryrefslogtreecommitdiff
path: root/bin
diff options
context:
space:
mode:
Diffstat (limited to 'bin')
-rw-r--r--bin/bot/Cargo.toml25
-rw-r--r--bin/bot/src/bot_config.rs10
-rw-r--r--bin/bot/src/inference.rs23
-rw-r--r--bin/bot/src/main.rs97
-rw-r--r--bin/cli/Cargo.toml13
-rw-r--r--bin/cli/src/inference.rs23
-rw-r--r--bin/cli/src/main.rs30
7 files changed, 221 insertions, 0 deletions
diff --git a/bin/bot/Cargo.toml b/bin/bot/Cargo.toml
new file mode 100644
index 0000000..708a831
--- /dev/null
+++ b/bin/bot/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+name = "bot"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+anyhow = "1.0.98"
+config = "0.15.11"
+log = "0.4.27"
+pretty_env_logger = "0.5.0"
+serde = "1.0.219"
+tokio = { version = "1.44.2", features = ["rt-multi-thread"] }
+tracing = "0.1.41"
+tracing-subscriber = "0.3.19"
+twilight-cache-inmemory = "0.16.0"
+twilight-gateway = "0.16.0"
+twilight-http = "0.16.0"
+twilight-model = "0.16.0"
+bincode = { version = "2.0.1", features = ["serde"] }
+
+db = { path = "../../libs/db" }
+deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" }
+async-trait = "0.1.88"
+
+tch = {version = "0.20.0", features = ["download-libtorch"]}
diff --git a/bin/bot/src/bot_config.rs b/bin/bot/src/bot_config.rs
new file mode 100644
index 0000000..af6d714
--- /dev/null
+++ b/bin/bot/src/bot_config.rs
@@ -0,0 +1,10 @@
+use serde::{Serialize, Deserialize};
+use std::string::String;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Config {
+ pub token: String,
+ pub model_path: String,
+ pub config_path: String,
+ pub db_path: String,
+}
diff --git a/bin/bot/src/inference.rs b/bin/bot/src/inference.rs
new file mode 100644
index 0000000..f78f381
--- /dev/null
+++ b/bin/bot/src/inference.rs
@@ -0,0 +1,23 @@
+use async_trait::async_trait;
+use db::types::GraphemeString;
+use db::{
+ inference::InferenceError,
+ types::{InferenceService, PhonemeString},
+};
+use deepphonemizer::Phonemizer;
+
+#[derive(Debug)]
+pub struct DeepPhonemizerInference(pub Phonemizer);
+
+#[async_trait]
+impl InferenceService<InferenceError> for DeepPhonemizerInference {
+ async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> {
+ let pronemized = self
+ .0
+ .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap();
+
+ return Ok(PhonemeString(
+ PhonemeString(format!("{:?}", pronemized)).to_string(),
+ ));
+ }
+}
diff --git a/bin/bot/src/main.rs b/bin/bot/src/main.rs
new file mode 100644
index 0000000..574b417
--- /dev/null
+++ b/bin/bot/src/main.rs
@@ -0,0 +1,97 @@
+use deepphonemizer::Phonemizer;
+use inference::DeepPhonemizerInference;
+use std::{error::Error, sync::Arc};
+use tch::Device;
+use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType};
+use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
+use twilight_http::Client as HttpClient;
+
+mod bot_config;
+mod inference;
+
+use db::{save::Save, types::GraphemeString};
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
+ // Initialize the tracing subscriber.
+ tracing_subscriber::fmt::init();
+
+ let config = config::Config::builder()
+ .add_source(config::Environment::with_prefix("GRU"))
+ .build()?
+ .try_deserialize::<bot_config::Config>()?;
+ let token = config.token;
+
+ // Specify intents requesting events about things like new and updated
+ // messages in a guild and direct messages.
+
+ // Create a single shard.
+ let mut shard = Shard::new(
+ ShardId::ONE,
+ token.clone(),
+ Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT,
+ );
+
+ // The http client is separate from the gateway, so startup a new
+ // one, also use Arc such that it can be cloned to other threads.
+ let http = Arc::new(HttpClient::new(token));
+
+ // Since we only care about messages, make the cache only process messages.
+ let cache = DefaultInMemoryCache::builder()
+ .resource_types(ResourceType::MESSAGE)
+ .build();
+
+ let mut db = std::fs::File::open(config.db_path).unwrap();
+
+ let save: Save =
+ bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+ let phonemizer = Phonemizer::from_checkpoint(
+ config.model_path,
+ config.config_path,
+ Device::cuda_if_available(),
+ None,
+ )
+ .unwrap();
+
+ let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+ let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+ // Startup the event loop to process each event in the event stream as they
+ // come in.
+ while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
+ let Ok(event) = item else {
+ tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
+
+ continue;
+ };
+ // Update the cache.
+ cache.update(&event);
+
+ // Spawn a new task to handle the event
+ tokio::spawn(handle_event(event, Arc::clone(&http), db.clone()));
+ }
+
+ Ok(())
+}
+
+async fn handle_event(
+ event: Event,
+ http: Arc<HttpClient>,
+ inference: Arc<db::inference::Inference>,
+) -> Result<(), Box<dyn Error + Send + Sync>> {
+ match event {
+ Event::MessageCreate(msg) => {
+ let infer = inference.infer(GraphemeString(msg.content.clone())).await.unwrap();
+
+ println!("{}", infer);
+ }
+ Event::Ready(_) => {
+ println!("Shard is ready");
+ }
+ _ => {}
+ }
+
+ Ok(())
+}
diff --git a/bin/cli/Cargo.toml b/bin/cli/Cargo.toml
new file mode 100644
index 0000000..cc2b393
--- /dev/null
+++ b/bin/cli/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "cli"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+
+tokio = { version = "1.44.2", features = ["full"] }
+db = { path = "../../libs/db" }
+deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" }
+async-trait = "0.1.88"
+bincode = { version = "2.0.1", features = ["serde"] }
+tch = {version = "0.20.0", features = ["download-libtorch"]}
diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs
new file mode 100644
index 0000000..f78f381
--- /dev/null
+++ b/bin/cli/src/inference.rs
@@ -0,0 +1,23 @@
+use async_trait::async_trait;
+use db::types::GraphemeString;
+use db::{
+ inference::InferenceError,
+ types::{InferenceService, PhonemeString},
+};
+use deepphonemizer::Phonemizer;
+
+#[derive(Debug)]
+pub struct DeepPhonemizerInference(pub Phonemizer);
+
+#[async_trait]
+impl InferenceService<InferenceError> for DeepPhonemizerInference {
+ async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> {
+ let pronemized = self
+ .0
+ .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap();
+
+ return Ok(PhonemeString(
+ PhonemeString(format!("{:?}", pronemized)).to_string(),
+ ));
+ }
+}
diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs
new file mode 100644
index 0000000..736ba8e
--- /dev/null
+++ b/bin/cli/src/main.rs
@@ -0,0 +1,30 @@
+use db::{save::Save, types::GraphemeString};
+use deepphonemizer::Phonemizer;
+use tch::Device;
+use std::sync::Arc;
+
+mod inference;
+
+use crate::inference::DeepPhonemizerInference;
+
+#[tokio::main]
+async fn main() {
+ let mut db = std::fs::File::open("db.bin").unwrap();
+
+ let save: Save =
+ bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+ let phonemizer = Phonemizer::from_checkpoint(
+ "data/model.pt",
+ "data/forward_config.yaml",
+ Device::cuda_if_available(),
+ None,
+ )
+ .unwrap();
+
+ let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+ let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+ println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap());
+}