diff options
| author | Matthieu Pignolet <matthieu@puffer.fish> | 2025-10-30 16:04:25 +0100 |
|---|---|---|
| committer | Matthieu Pignolet <matthieu@puffer.fish> | 2025-10-30 16:04:25 +0100 |
| commit | 09a0afaff012b88f79afde35a3c2ebaeb010a176 (patch) | |
| tree | 874fe676977f56c5456e4cead31e1a6c86bc40dd /bin/cli/src | |
| parent | c951f54be999fb7e508039ba23fa5b1fe7035743 (diff) | |
feat: commit allfeat/twilight-discord
Diffstat (limited to 'bin/cli/src')
| -rw-r--r-- | bin/cli/src/inference.rs | 23 | ||||
| -rw-r--r-- | bin/cli/src/main.rs | 30 |
2 files changed, 53 insertions, 0 deletions
diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs new file mode 100644 index 0000000..f78f381 --- /dev/null +++ b/bin/cli/src/inference.rs @@ -0,0 +1,23 @@ +use async_trait::async_trait; +use db::types::GraphemeString; +use db::{ + inference::InferenceError, + types::{InferenceService, PhonemeString}, +}; +use deepphonemizer::Phonemizer; + +#[derive(Debug)] +pub struct DeepPhonemizerInference(pub Phonemizer); + +#[async_trait] +impl InferenceService<InferenceError> for DeepPhonemizerInference { + async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> { + let pronemized = self + .0 + .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap(); + + return Ok(PhonemeString( + PhonemeString(format!("{:?}", pronemized)).to_string(), + )); + } +} diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs new file mode 100644 index 0000000..736ba8e --- /dev/null +++ b/bin/cli/src/main.rs @@ -0,0 +1,30 @@ +use db::{save::Save, types::GraphemeString}; +use deepphonemizer::Phonemizer; +use tch::Device; +use std::sync::Arc; + +mod inference; + +use crate::inference::DeepPhonemizerInference; + +#[tokio::main] +async fn main() { + let mut db = std::fs::File::open("db.bin").unwrap(); + + let save: Save = + bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap(); + + let phonemizer = Phonemizer::from_checkpoint( + "data/model.pt", + "data/forward_config.yaml", + Device::cuda_if_available(), + None, + ) + .unwrap(); + + let inference_service = Arc::new(DeepPhonemizerInference(phonemizer)); + + let db = Arc::new(db::inference::Inference::new(inference_service, save)); + + println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap()); +} |
