diff options
Diffstat (limited to 'bin')
| -rw-r--r-- | bin/bot/Cargo.toml | 25 |
| -rw-r--r-- | bin/bot/src/bot_config.rs | 10 |
| -rw-r--r-- | bin/bot/src/inference.rs | 23 |
| -rw-r--r-- | bin/bot/src/main.rs | 97 |
| -rw-r--r-- | bin/cli/Cargo.toml | 13 |
| -rw-r--r-- | bin/cli/src/inference.rs | 23 |
| -rw-r--r-- | bin/cli/src/main.rs | 30 |
7 files changed, 221 insertions, 0 deletions
diff --git a/bin/bot/Cargo.toml b/bin/bot/Cargo.toml
new file mode 100644
index 0000000..708a831
--- /dev/null
+++ b/bin/bot/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+name = "bot"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+anyhow = "1.0.98"
+config = "0.15.11"
+log = "0.4.27"
+pretty_env_logger = "0.5.0"
+serde = "1.0.219"
+tokio = { version = "1.44.2", features = ["rt-multi-thread"] }
+tracing = "0.1.41"
+tracing-subscriber = "0.3.19"
+twilight-cache-inmemory = "0.16.0"
+twilight-gateway = "0.16.0"
+twilight-http = "0.16.0"
+twilight-model = "0.16.0"
+bincode = { version = "2.0.1", features = ["serde"] }
+
+db = { path = "../../libs/db" }
+deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" }
+async-trait = "0.1.88"
+
+tch = {version = "0.20.0", features = ["download-libtorch"]}
diff --git a/bin/bot/src/bot_config.rs b/bin/bot/src/bot_config.rs
new file mode 100644
index 0000000..af6d714
--- /dev/null
+++ b/bin/bot/src/bot_config.rs
@@ -0,0 +1,10 @@
+use serde::{Serialize, Deserialize};
+use std::string::String;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Config {
+    pub token: String,
+    pub model_path: String,
+    pub config_path: String,
+    pub db_path: String,
+}
diff --git a/bin/bot/src/inference.rs b/bin/bot/src/inference.rs
new file mode 100644
index 0000000..f78f381
--- /dev/null
+++ b/bin/bot/src/inference.rs
@@ -0,0 +1,23 @@
+use async_trait::async_trait;
+use db::types::GraphemeString;
+use db::{
+    inference::InferenceError,
+    types::{InferenceService, PhonemeString},
+};
+use deepphonemizer::Phonemizer;
+
+#[derive(Debug)]
+pub struct DeepPhonemizerInference(pub Phonemizer);
+
+#[async_trait]
+impl InferenceService<InferenceError> for DeepPhonemizerInference {
+    async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> {
+        let pronemized = self
+            .0
+            .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap();
+
+        return Ok(PhonemeString(
+            PhonemeString(format!("{:?}", pronemized)).to_string(),
+        ));
+    }
+}
diff --git a/bin/bot/src/main.rs b/bin/bot/src/main.rs
new file mode 100644
index 0000000..574b417
--- /dev/null
+++ b/bin/bot/src/main.rs
@@ -0,0 +1,97 @@
+use deepphonemizer::Phonemizer;
+use inference::DeepPhonemizerInference;
+use std::{error::Error, sync::Arc};
+use tch::Device;
+use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType};
+use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
+use twilight_http::Client as HttpClient;
+
+mod bot_config;
+mod inference;
+
+use db::{save::Save, types::GraphemeString};
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
+    // Initialize the tracing subscriber.
+    tracing_subscriber::fmt::init();
+
+    let config = config::Config::builder()
+        .add_source(config::Environment::with_prefix("GRU"))
+        .build()?
+        .try_deserialize::<bot_config::Config>()?;
+    let token = config.token;
+
+    // Specify intents requesting events about things like new and updated
+    // messages in a guild and direct messages.
+
+    // Create a single shard.
+    let mut shard = Shard::new(
+        ShardId::ONE,
+        token.clone(),
+        Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT,
+    );
+
+    // The http client is separate from the gateway, so startup a new
+    // one, also use Arc such that it can be cloned to other threads.
+    let http = Arc::new(HttpClient::new(token));
+
+    // Since we only care about messages, make the cache only process messages.
+    let cache = DefaultInMemoryCache::builder()
+        .resource_types(ResourceType::MESSAGE)
+        .build();
+
+    let mut db = std::fs::File::open(config.db_path).unwrap();
+
+    let save: Save =
+        bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+    let phonemizer = Phonemizer::from_checkpoint(
+        config.model_path,
+        config.config_path,
+        Device::cuda_if_available(),
+        None,
+    )
+    .unwrap();
+
+    let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+    let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+    // Startup the event loop to process each event in the event stream as they
+    // come in.
+    while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
+        let Ok(event) = item else {
+            tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
+
+            continue;
+        };
+        // Update the cache.
+        cache.update(&event);
+
+        // Spawn a new task to handle the event
+        tokio::spawn(handle_event(event, Arc::clone(&http), db.clone()));
+    }
+
+    Ok(())
+}
+
+async fn handle_event(
+    event: Event,
+    http: Arc<HttpClient>,
+    inference: Arc<db::inference::Inference>,
+) -> Result<(), Box<dyn Error + Send + Sync>> {
+    match event {
+        Event::MessageCreate(msg) => {
+            let infer = inference.infer(GraphemeString(msg.content.clone())).await.unwrap();
+
+            println!("{}", infer);
+        }
+        Event::Ready(_) => {
+            println!("Shard is ready");
+        }
+        _ => {}
+    }
+
+    Ok(())
+}
diff --git a/bin/cli/Cargo.toml b/bin/cli/Cargo.toml
new file mode 100644
index 0000000..cc2b393
--- /dev/null
+++ b/bin/cli/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "cli"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+
+tokio = { version = "1.44.2", features = ["full"] }
+db = { path = "../../libs/db" }
+deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" }
+async-trait = "0.1.88"
+bincode = { version = "2.0.1", features = ["serde"] }
+tch = {version = "0.20.0", features = ["download-libtorch"]}
diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs
new file mode 100644
index 0000000..f78f381
--- /dev/null
+++ b/bin/cli/src/inference.rs
@@ -0,0 +1,23 @@
+use async_trait::async_trait;
+use db::types::GraphemeString;
+use db::{
+    inference::InferenceError,
+    types::{InferenceService, PhonemeString},
+};
+use deepphonemizer::Phonemizer;
+
+#[derive(Debug)]
+pub struct DeepPhonemizerInference(pub Phonemizer);
+
+#[async_trait]
+impl InferenceService<InferenceError> for DeepPhonemizerInference {
+    async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> {
+        let pronemized = self
+            .0
+            .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap();
+
+        return Ok(PhonemeString(
+            PhonemeString(format!("{:?}", pronemized)).to_string(),
+        ));
+    }
+}
diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs
new file mode 100644
index 0000000..736ba8e
--- /dev/null
+++ b/bin/cli/src/main.rs
@@ -0,0 +1,30 @@
+use db::{save::Save, types::GraphemeString};
+use deepphonemizer::Phonemizer;
+use tch::Device;
+use std::sync::Arc;
+
+mod inference;
+
+use crate::inference::DeepPhonemizerInference;
+
+#[tokio::main]
+async fn main() {
+    let mut db = std::fs::File::open("db.bin").unwrap();
+
+    let save: Save =
+        bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+    let phonemizer = Phonemizer::from_checkpoint(
+        "data/model.pt",
+        "data/forward_config.yaml",
+        Device::cuda_if_available(),
+        None,
+    )
+    .unwrap();
+
+    let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+    let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+    println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap());
+}
