diff options
Diffstat (limited to 'bin/cli/src')
| -rw-r--r-- | bin/cli/src/inference.rs | 23 | ||||
| -rw-r--r-- | bin/cli/src/main.rs | 30 |
2 files changed, 53 insertions, 0 deletions
diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs new file mode 100644 index 0000000..f78f381 --- /dev/null +++ b/bin/cli/src/inference.rs @@ -0,0 +1,23 @@ +use async_trait::async_trait; +use db::types::GraphemeString; +use db::{ + inference::InferenceError, + types::{InferenceService, PhonemeString}, +}; +use deepphonemizer::Phonemizer; + +#[derive(Debug)] +pub struct DeepPhonemizerInference(pub Phonemizer); + +#[async_trait] +impl InferenceService<InferenceError> for DeepPhonemizerInference { + async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> { + let pronemized = self + .0 + .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap(); + + return Ok(PhonemeString( + PhonemeString(format!("{:?}", pronemized)).to_string(), + )); + } +} diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs new file mode 100644 index 0000000..736ba8e --- /dev/null +++ b/bin/cli/src/main.rs @@ -0,0 +1,30 @@ +use db::{save::Save, types::GraphemeString}; +use deepphonemizer::Phonemizer; +use tch::Device; +use std::sync::Arc; + +mod inference; + +use crate::inference::DeepPhonemizerInference; + +#[tokio::main] +async fn main() { + let mut db = std::fs::File::open("db.bin").unwrap(); + + let save: Save = + bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap(); + + let phonemizer = Phonemizer::from_checkpoint( + "data/model.pt", + "data/forward_config.yaml", + Device::cuda_if_available(), + None, + ) + .unwrap(); + + let inference_service = Arc::new(DeepPhonemizerInference(phonemizer)); + + let db = Arc::new(db::inference::Inference::new(inference_service, save)); + + println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap()); +} |
