summaryrefslogtreecommitdiff
path: root/libs/db/src/inference.rs
blob: f2cbabf04386750fd2cfa866accdd01db9187287 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
use crate::{
    save::Save,
    types::{GraphemeString, InferenceService, PhonemeString},
};
use hypher::{Lang, hyphenate};
use itertools::Itertools;
use log::*;
use std::{iter, ops::Add};
use thiserror::Error;
use unicode_segmentation::UnicodeSegmentation;

#[derive(Error, Debug)]
pub enum InferenceError {
    #[error("failed to fetch a random value from the trie")]
    TrieError(),
    #[error("the matched value wasn't in the dictionary")]
    NotInDictionary(),
    #[error("failt to cut the word")]
    WordCutError(),
}

/// Bundles everything needed to run an inference: the service used to turn
/// graphemes into phonemes and the loaded save data.
pub struct Inference<'a> {
    /// Service handed to `PhonemeString::from_grapheme` when converting the
    /// prefix into phonemes.
    inference_service: Box<dyn InferenceService<InferenceError>>,
    /// Save data; `infer` reads its `trie` and `reverse_index` fields.
    save: Save<'a>,
}

/// Newtype around [`whatlang::Lang`] that exists so a conversion into the
/// hyphenation language ([`hypher::Lang`]) can be implemented in this crate.
pub struct TLang(whatlang::Lang);

/// Maps a detected language onto the matching hyphenation language.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl still
/// provides `TLang: Into<Lang>`, so existing `.into()` calls keep working,
/// while `Lang::from(..)` becomes available as well.
impl From<TLang> for Lang {
    fn from(lang: TLang) -> Self {
        match lang.0 {
            whatlang::Lang::Eng => Lang::English,
            whatlang::Lang::Rus => Lang::Russian,
            whatlang::Lang::Spa => Lang::Spanish,
            whatlang::Lang::Por => Lang::Portuguese,
            whatlang::Lang::Ita => Lang::Italian,
            whatlang::Lang::Fra => Lang::French,
            whatlang::Lang::Deu => Lang::German,
            whatlang::Lang::Ukr => Lang::Ukrainian,
            whatlang::Lang::Kat => Lang::Georgian,
            whatlang::Lang::Pol => Lang::Polish,
            whatlang::Lang::Dan => Lang::Danish,
            whatlang::Lang::Swe => Lang::Swedish,
            whatlang::Lang::Fin => Lang::Finnish,
            whatlang::Lang::Tur => Lang::Turkish,
            whatlang::Lang::Nld => Lang::Dutch,
            whatlang::Lang::Hun => Lang::Hungarian,
            whatlang::Lang::Ces => Lang::Czech,
            whatlang::Lang::Ell => Lang::Greek,
            whatlang::Lang::Bul => Lang::Bulgarian,
            whatlang::Lang::Bel => Lang::Belarusian,
            whatlang::Lang::Hrv => Lang::Croatian,
            whatlang::Lang::Srp => Lang::Serbian,
            whatlang::Lang::Lit => Lang::Lithuanian,
            whatlang::Lang::Est => Lang::Estonian,
            whatlang::Lang::Tuk => Lang::Turkmen,
            whatlang::Lang::Afr => Lang::Afrikaans,
            whatlang::Lang::Lat => Lang::Latin,
            whatlang::Lang::Slk => Lang::Slovak,
            whatlang::Lang::Cat => Lang::Catalan,
            // Every language without an explicit mapping above is hyphenated
            // with the English patterns.
            _ => Lang::English,
        }
    }
}

impl Inference<'_> {
    /// Tells whether two syllables are close enough to be considered the same.
    ///
    /// Both strings are split into extended grapheme clusters which are then
    /// compared pairwise with `kondrak_aline::delta`. Only the common length
    /// is compared: the extra clusters of the longer string are ignored, so
    /// an empty string matches anything.
    fn matches(&self, source: &str, complete: &str) -> bool {
        // A single pair whose distance exceeds 20.0 rejects the whole syllable.
        !iter::zip(source.graphemes(true), complete.graphemes(true))
            .any(|(left, right)| kondrak_aline::delta(left, right) > 20.0)
    }

    /// Builds an [`Inference`] from the grapheme-to-phoneme service and the
    /// loaded save data.
    pub fn new<'a>(
        inference_service: Box<dyn InferenceService<InferenceError>>,
        save: Save<'a>,
    ) -> Inference<'a> {
        Inference {
            inference_service,
            save,
        }
    }

    /// Infers a completion for `prefix`.
    ///
    /// The prefix is converted to phonemes, a random trie entry starting with
    /// those phonemes is picked, and the resulting word is cut on syllable
    /// boundaries. On success the returned string is formatted as
    /// `"<remaining syllables> (<reverse index entry>)"`.
    ///
    /// # Errors
    ///
    /// * any error produced by `PhonemeString::from_grapheme`;
    /// * [`InferenceError::TrieError`] when the trie has no entry starting
    ///   with the prefix phonemes;
    /// * [`InferenceError::NotInDictionary`] when the picked entry is missing
    ///   from the reverse index;
    /// * [`InferenceError::WordCutError`] when no syllable is left after the
    ///   cut. This includes inputs whose hyphenation is too short to be cut,
    ///   which previously caused a panic.
    pub async fn infer(&self, prefix: GraphemeString) -> Result<String, InferenceError> {
        debug!("starting inference for {prefix}");

        // Hyphenation is language specific; fall back to English whenever the
        // language cannot be detected.
        let detected = whatlang::detect_lang(&prefix.0).unwrap_or(whatlang::Lang::Eng);
        let langt: Lang = TLang(detected).into();

        // The first step is to use the inference service in order to find the
        // phonemes corresponding to the prefix.
        let phonemes =
            PhonemeString::from_grapheme(prefix.clone(), &self.inference_service).await?;

        // We then "cut" the prefix into its syllables.
        let hyphens: Vec<&str> = hyphenate(&prefix.0, langt).collect_vec();

        // We then perform the random walk into the trie to find a random
        // entry starting with the given phoneme combination.
        let completion = self
            .save
            .trie
            .random_starting_with(phonemes.clone())
            .ok_or_else(InferenceError::TrieError)?;

        // Store the found word and cut it into syllables as well.
        let initial = prefix.0.clone().add(&completion.0);
        let completed_hyphens: Vec<&str> = hyphenate(&initial, langt).collect_vec();

        // We search the matched word in the reverse dictionary.
        let matched = self
            .save
            .reverse_index
            .get(&completion)
            .ok_or_else(InferenceError::NotInDictionary)?;

        debug!(
            "found matching word '{}' that has phonetic '{}' that start with '{}'",
            initial, completion, phonemes
        );

        // All the prefix syllables but the last one are compared with the
        // leading syllables of the completed word. The count is clamped so
        // that neither an empty hyphenation nor a completed word with fewer
        // syllables than the prefix can trigger an out-of-range panic.
        let compared = hyphens
            .len()
            .saturating_sub(1)
            .min(completed_hyphens.len());
        let (head, tail) = completed_hyphens.split_at(compared);

        // We store the index where the word should be cut.
        let mut highest_index = 0usize;
        for (index, (source, completed)) in iter::zip(&hyphens, head).enumerate() {
            let source_hyph = source.to_lowercase();
            let complete_hyph = completed.to_lowercase();

            debug!(
                "[{}] comparing hyphen src={} dst={}",
                index, source_hyph, complete_hyph
            );

            if self.matches(&source_hyph, &complete_hyph) {
                highest_index = index;
            } else {
                debug!(
                    "[{}] found matching hyphen at index {}",
                    index, highest_index
                );
                break;
            }
        }

        // We finally just need to compute the end of the word which matches
        // the sound. A cut position past the end yields an empty string and is
        // reported as an error below instead of panicking.
        let found = tail
            .get(highest_index + 1..)
            .unwrap_or_default()
            .concat();

        if found.is_empty() {
            debug!("failed to find where to cut the word to match");
            Err(InferenceError::WordCutError())
        } else {
            debug!("found that {} is equivalent to {}", completion, found);
            Ok(format!("{} ({})", found, matched))
        }
    }
}