use crate::{ save::Save, types::{GraphemeString, InferenceService, PhonemeString}, }; use hypher::{Lang, hyphenate}; use itertools::Itertools; use log::*; use std::{iter, ops::Add, sync::Arc}; use thiserror::Error; use unicode_segmentation::UnicodeSegmentation; #[derive(Error, Debug)] pub enum InferenceError { #[error("failed to fetch a random value from the trie")] TrieError(), #[error("the matched value wasn't in the dictionary")] NotInDictionary(), #[error("failed to cut the word")] WordCutError(), } #[derive(Debug, Clone)] pub struct Inference { inference_service: Arc>, save: Save, } pub struct TLang(whatlang::Lang); impl Into for TLang { fn into(self) -> Lang { match self.0 { whatlang::Lang::Eng => Lang::English, whatlang::Lang::Rus => Lang::Russian, whatlang::Lang::Spa => Lang::Spanish, whatlang::Lang::Por => Lang::Portuguese, whatlang::Lang::Ita => Lang::Italian, whatlang::Lang::Fra => Lang::French, whatlang::Lang::Deu => Lang::German, whatlang::Lang::Ukr => Lang::Ukrainian, whatlang::Lang::Kat => Lang::Georgian, whatlang::Lang::Pol => Lang::Polish, whatlang::Lang::Dan => Lang::Danish, whatlang::Lang::Swe => Lang::Swedish, whatlang::Lang::Fin => Lang::Finnish, whatlang::Lang::Tur => Lang::Turkish, whatlang::Lang::Nld => Lang::Dutch, whatlang::Lang::Hun => Lang::Hungarian, whatlang::Lang::Ces => Lang::Czech, whatlang::Lang::Ell => Lang::Greek, whatlang::Lang::Bul => Lang::Bulgarian, whatlang::Lang::Bel => Lang::Belarusian, whatlang::Lang::Hrv => Lang::Croatian, whatlang::Lang::Srp => Lang::Serbian, whatlang::Lang::Lit => Lang::Lithuanian, whatlang::Lang::Est => Lang::Estonian, whatlang::Lang::Tuk => Lang::Turkmen, whatlang::Lang::Afr => Lang::Afrikaans, whatlang::Lang::Lat => Lang::Latin, whatlang::Lang::Slk => Lang::Slovak, whatlang::Lang::Cat => Lang::Catalan, _ => Lang::English, } } } impl Inference { fn matches(&self, source: &str, complete: &str) -> bool { let source_chars: Vec<&str> = source.graphemes(true).collect(); let complete_chars: Vec<&str> = complete.graphemes(true).collect(); let s = usize::min(source_chars.len(), complete_chars.len()); for j in 0..s { let metric = kondrak_aline::delta(source_chars[j], complete_chars[j]); if metric > 20.0 { return false; } } return true; } pub fn new( inference_service: Arc>, save: Save, ) -> Inference { return Inference { inference_service, save, }; } pub async fn infer(&self, prefix: GraphemeString) -> Result { debug!("starting inference for {prefix}"); let cprefix = prefix.0.clone(); let lang = TLang( whatlang::detect_lang(&cprefix) .or(Some(whatlang::Lang::Eng)) .unwrap(), ); let langt: Lang = lang.into(); // The first step is to use the inference service in order to find the phoneme corresponding to the word. let phonemes = PhonemeString::from_grapheme(prefix.clone(), &self.inference_service).await?; // we then "cut" the word into his syllables let cprefix = prefix.0.clone(); let hyphens: Vec<&str> = hyphenate(&cprefix, langt).collect_vec(); let cprefix = prefix.0.clone(); // we then perform the random walk into the tree to find a random // word starting with the given phoneme combination let completion = self .save .trie .random_starting_with(phonemes.clone()).unwrap(); // store the found word let initial = cprefix.add(&completion.0); // cut the found word syllabes let mut completed_hyphens: Vec<&str> = hyphenate(&initial, langt).into_iter().collect_vec(); // we search the matched word in the reverse dictionary let matched = self .save .reverse_index .get(&completion) .ok_or_else(InferenceError::NotInDictionary)?; debug!( "found matching word '{}' that has phonetic '{}' that start with '{}'", initial, completion, phonemes ); // we store the index where the word should be cut let mut highest_index = 0usize; for (index, completed_hyphen) in completed_hyphens .splice(0..hyphens.len() - 1, iter::empty()) .enumerate() { let source_hyph = hyphens[index].to_lowercase(); let complete_hyph = completed_hyphen.to_lowercase(); debug!( "[{}] comparing hyphen src={} dst={}", index, source_hyph, complete_hyph ); if self.matches(&source_hyph, &complete_hyph) { highest_index = index } else { debug!( "[{}] found matching hyphen at index {}", index, highest_index ); break; } } // we finally just need to compute the end of the word which matches the sound let found = completed_hyphens.drain(highest_index + 1..).join(""); if found.len() == 0 { debug!("failed to find where to cut the word to match"); Err(InferenceError::WordCutError()) } else { debug!("found that {} is equivalent to {}", completion, found); Ok(format!("{} ({})", found, matched)) } } }