summaryrefslogtreecommitdiff
path: root/bin/cli
diff options
context:
space:
mode:
Diffstat (limited to 'bin/cli')
-rw-r--r--bin/cli/Cargo.toml13
-rw-r--r--bin/cli/src/inference.rs23
-rw-r--r--bin/cli/src/main.rs30
3 files changed, 66 insertions, 0 deletions
diff --git a/bin/cli/Cargo.toml b/bin/cli/Cargo.toml
new file mode 100644
index 0000000..cc2b393
--- /dev/null
+++ b/bin/cli/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "cli"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+
+tokio = { version = "1.44.2", features = ["full"] }
+db = { path = "../../libs/db" }
+deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" }
+async-trait = "0.1.88"
+bincode = { version = "2.0.1", features = ["serde"] }
+tch = {version = "0.20.0", features = ["download-libtorch"]}
diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs
new file mode 100644
index 0000000..f78f381
--- /dev/null
+++ b/bin/cli/src/inference.rs
@@ -0,0 +1,23 @@
+use async_trait::async_trait;
+use db::types::GraphemeString;
+use db::{
+ inference::InferenceError,
+ types::{InferenceService, PhonemeString},
+};
+use deepphonemizer::Phonemizer;
+
+#[derive(Debug)]
+pub struct DeepPhonemizerInference(pub Phonemizer);
+
+#[async_trait]
+impl InferenceService<InferenceError> for DeepPhonemizerInference {
+ async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> {
+ let pronemized = self
+ .0
+ .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap();
+
+ return Ok(PhonemeString(
+ PhonemeString(format!("{:?}", pronemized)).to_string(),
+ ));
+ }
+}
diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs
new file mode 100644
index 0000000..736ba8e
--- /dev/null
+++ b/bin/cli/src/main.rs
@@ -0,0 +1,30 @@
+use db::{save::Save, types::GraphemeString};
+use deepphonemizer::Phonemizer;
+use tch::Device;
+use std::sync::Arc;
+
+mod inference;
+
+use crate::inference::DeepPhonemizerInference;
+
+#[tokio::main]
+async fn main() {
+ let mut db = std::fs::File::open("db.bin").unwrap();
+
+ let save: Save =
+ bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+ let phonemizer = Phonemizer::from_checkpoint(
+ "data/model.pt",
+ "data/forward_config.yaml",
+ Device::cuda_if_available(),
+ None,
+ )
+ .unwrap();
+
+ let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+ let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+ println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap());
+}