diff options
Diffstat (limited to 'bin/cli')
| -rw-r--r-- | bin/cli/Cargo.toml | 13 | ||||
| -rw-r--r-- | bin/cli/src/inference.rs | 23 | ||||
| -rw-r--r-- | bin/cli/src/main.rs | 30 |
3 files changed, 66 insertions, 0 deletions
diff --git a/bin/cli/Cargo.toml b/bin/cli/Cargo.toml new file mode 100644 index 0000000..cc2b393 --- /dev/null +++ b/bin/cli/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "cli" +version = "0.1.0" +edition = "2024" + +[dependencies] + +tokio = { version = "1.44.2", features = ["full"] } +db = { path = "../../libs/db" } +deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" } +async-trait = "0.1.88" +bincode = { version = "2.0.1", features = ["serde"] } +tch = {version = "0.20.0", features = ["download-libtorch"]} diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs new file mode 100644 index 0000000..f78f381 --- /dev/null +++ b/bin/cli/src/inference.rs @@ -0,0 +1,23 @@ +use async_trait::async_trait; +use db::types::GraphemeString; +use db::{ + inference::InferenceError, + types::{InferenceService, PhonemeString}, +}; +use deepphonemizer::Phonemizer; + +#[derive(Debug)] +pub struct DeepPhonemizerInference(pub Phonemizer); + +#[async_trait] +impl InferenceService<InferenceError> for DeepPhonemizerInference { + async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> { + let pronemized = self + .0 + .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap(); + + return Ok(PhonemeString( + PhonemeString(format!("{:?}", pronemized)).to_string(), + )); + } +} diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs new file mode 100644 index 0000000..736ba8e --- /dev/null +++ b/bin/cli/src/main.rs @@ -0,0 +1,30 @@ +use db::{save::Save, types::GraphemeString}; +use deepphonemizer::Phonemizer; +use tch::Device; +use std::sync::Arc; + +mod inference; + +use crate::inference::DeepPhonemizerInference; + +#[tokio::main] +async fn main() { + let mut db = std::fs::File::open("db.bin").unwrap(); + + let save: Save = + bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap(); + + let phonemizer = Phonemizer::from_checkpoint( + "data/model.pt", + "data/forward_config.yaml", + Device::cuda_if_available(), + None, + ) + .unwrap(); + + let inference_service = Arc::new(DeepPhonemizerInference(phonemizer)); + + let db = Arc::new(db::inference::Inference::new(inference_service, save)); + + println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap()); +} |
