summaryrefslogtreecommitdiff
path: root/bin/cli/src
diff options
context:
space:
mode:
Diffstat (limited to 'bin/cli/src')
-rw-r--r--bin/cli/src/inference.rs23
-rw-r--r--bin/cli/src/main.rs30
2 files changed, 53 insertions, 0 deletions
diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs
new file mode 100644
index 0000000..f78f381
--- /dev/null
+++ b/bin/cli/src/inference.rs
@@ -0,0 +1,23 @@
+use async_trait::async_trait;
+use db::types::GraphemeString;
+use db::{
+ inference::InferenceError,
+ types::{InferenceService, PhonemeString},
+};
+use deepphonemizer::Phonemizer;
+
+#[derive(Debug)]
+pub struct DeepPhonemizerInference(pub Phonemizer);
+
+#[async_trait]
+impl InferenceService<InferenceError> for DeepPhonemizerInference {
+ async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> {
+ let pronemized = self
+ .0
+ .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap();
+
+ return Ok(PhonemeString(
+ PhonemeString(format!("{:?}", pronemized)).to_string(),
+ ));
+ }
+}
diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs
new file mode 100644
index 0000000..736ba8e
--- /dev/null
+++ b/bin/cli/src/main.rs
@@ -0,0 +1,30 @@
+use db::{save::Save, types::GraphemeString};
+use deepphonemizer::Phonemizer;
+use tch::Device;
+use std::sync::Arc;
+
+mod inference;
+
+use crate::inference::DeepPhonemizerInference;
+
+#[tokio::main]
+async fn main() {
+ let mut db = std::fs::File::open("db.bin").unwrap();
+
+ let save: Save =
+ bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+ let phonemizer = Phonemizer::from_checkpoint(
+ "data/model.pt",
+ "data/forward_config.yaml",
+ Device::cuda_if_available(),
+ None,
+ )
+ .unwrap();
+
+ let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+ let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+ println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap());
+}