summaryrefslogtreecommitdiff
path: root/bin/cli/src/main.rs
blob: 736ba8e7ceb27ea1811d71a1398acb8db2e1b839 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
use db::{save::Save, types::GraphemeString};
use deepphonemizer::Phonemizer;
use tch::Device;
use std::sync::Arc;

mod inference;

use crate::inference::DeepPhonemizerInference;

/// Entry point: decodes the saved database from `db.bin`, loads the
/// phonemizer checkpoint, and prints the inference result for the grapheme
/// string `"bon"`.
///
/// # Panics
///
/// Panics if `db.bin` cannot be opened or decoded, if the phonemizer
/// checkpoint cannot be loaded, or if the inference call returns an error.
#[tokio::main]
async fn main() {
    // Decode inside its own scope so the file handle is closed as soon as the
    // save has been read, instead of staying open for the rest of `main`.
    let save: Save = {
        let file = std::fs::File::open("db.bin").expect("failed to open db.bin");
        // Buffer the reads: the decoder pulls from the reader incrementally,
        // and an unbuffered `File` would turn each of those into a syscall.
        let mut reader = std::io::BufReader::new(file);
        bincode::serde::decode_from_std_read(&mut reader, bincode::config::standard())
            .expect("failed to decode db.bin")
    };

    let phonemizer = Phonemizer::from_checkpoint(
        "data/model.pt",
        "data/forward_config.yaml",
        Device::cuda_if_available(),
        None,
    )
    .expect("failed to load phonemizer checkpoint");

    let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));

    let db = Arc::new(db::inference::Inference::new(inference_service, save));

    let result = db
        .infer(GraphemeString("bon".to_string()))
        .await
        .expect("inference failed");
    println!("{}", result);
}