diff options
Diffstat (limited to 'libs/db')
| -rw-r--r-- | libs/db/Cargo.toml | 17 | ||||
| -rw-r--r-- | libs/db/src/inference.rs | 171 | ||||
| -rw-r--r-- | libs/db/src/lib.rs | 4 | ||||
| -rw-r--r-- | libs/db/src/save.rs | 10 | ||||
| -rw-r--r-- | libs/db/src/trie.rs | 156 | ||||
| -rw-r--r-- | libs/db/src/types.rs | 43 | 
6 files changed, 401 insertions, 0 deletions
diff --git a/libs/db/Cargo.toml b/libs/db/Cargo.toml new file mode 100644 index 0000000..c6a7c1c --- /dev/null +++ b/libs/db/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "db" +version = "0.1.0" +edition = "2024" + +[dependencies] +unicode-segmentation = "1.10.0" +rand = "0.8.5" +serde = { version = "1.0.152", features = ["derive"] } +itertools = "0.14.0" +hypher = "0.1.5" +thiserror = "2.0.12" +async-trait = "0.1.88" +log = "0.4.27" +strsim = "0.11.1" +whatlang = "0.16.4" +kondrak-aline = "0.1.3" diff --git a/libs/db/src/inference.rs b/libs/db/src/inference.rs new file mode 100644 index 0000000..f2cbabf --- /dev/null +++ b/libs/db/src/inference.rs @@ -0,0 +1,171 @@ +use crate::{ +    save::Save, +    types::{GraphemeString, InferenceService, PhonemeString}, +}; +use hypher::{Lang, hyphenate}; +use itertools::Itertools; +use log::*; +use std::{iter, ops::Add}; +use thiserror::Error; +use unicode_segmentation::UnicodeSegmentation; + +#[derive(Error, Debug)] +pub enum InferenceError { +    #[error("failed to fetch a random value from the trie")] +    TrieError(), +    #[error("the matched value wasn't in the dictionary")] +    NotInDictionary(), +    #[error("failt to cut the word")] +    WordCutError(), +} + +pub struct Inference<'a> { +    inference_service: Box<dyn InferenceService<InferenceError>>, +    save: Save<'a>, +} + +pub struct TLang(whatlang::Lang); +impl Into<Lang> for TLang { +    fn into(self) -> Lang { +        match self.0 { +            whatlang::Lang::Eng => Lang::English, +            whatlang::Lang::Rus => Lang::Russian, +            whatlang::Lang::Spa => Lang::Spanish, +            whatlang::Lang::Por => Lang::Portuguese, +            whatlang::Lang::Ita => Lang::Italian, +            whatlang::Lang::Fra => Lang::French, +            whatlang::Lang::Deu => Lang::German, +            whatlang::Lang::Ukr => Lang::Ukrainian, +            whatlang::Lang::Kat => Lang::Georgian, +            whatlang::Lang::Pol => Lang::Polish, +            whatlang::Lang::Dan => Lang::Danish, +            whatlang::Lang::Swe => Lang::Swedish, +            whatlang::Lang::Fin => Lang::Finnish, +            whatlang::Lang::Tur => Lang::Turkish, +            whatlang::Lang::Nld => Lang::Dutch, +            whatlang::Lang::Hun => Lang::Hungarian, +            whatlang::Lang::Ces => Lang::Czech, +            whatlang::Lang::Ell => Lang::Greek, +            whatlang::Lang::Bul => Lang::Bulgarian, +            whatlang::Lang::Bel => Lang::Belarusian, +            whatlang::Lang::Hrv => Lang::Croatian, +            whatlang::Lang::Srp => Lang::Serbian, +            whatlang::Lang::Lit => Lang::Lithuanian, +            whatlang::Lang::Est => Lang::Estonian, +            whatlang::Lang::Tuk => Lang::Turkmen, +            whatlang::Lang::Afr => Lang::Afrikaans, +            whatlang::Lang::Lat => Lang::Latin, +            whatlang::Lang::Slk => Lang::Slovak, +            whatlang::Lang::Cat => Lang::Catalan, +            _ => Lang::English, +        } +    } +} + +impl Inference<'_> { +    fn matches(&self, source: &str, complete: &str) -> bool { +        let source_chars: Vec<&str> = source.graphemes(true).collect(); +        let complete_chars: Vec<&str> = complete.graphemes(true).collect(); + +        let s = usize::min(source_chars.len(), complete_chars.len()); + +        for j in 0..s { +            let metric = kondrak_aline::delta(source_chars[j], complete_chars[j]); +            if metric > 20.0 { +                return false; +            } +        } + +        return true; +    } + +    pub fn new<'a>( +        inference_service: Box<dyn InferenceService<InferenceError>>, +        save: Save<'a>, +    ) -> Inference<'a> { +        return Inference { +            inference_service, +            save, +        }; +    } + +    pub async fn infer(&self, prefix: GraphemeString) -> Result<String, InferenceError> { +        debug!("starting inference for {prefix}"); + +        let cprefix = prefix.0.clone(); +        let lang = TLang( +            whatlang::detect_lang(&cprefix) +                .or(Some(whatlang::Lang::Eng)) +                .unwrap(), +        ); +        let langt: Lang = lang.into(); +        // The first step is to use the inference service in order to find the phoneme corresponding to the word. +        let phonemes = +            PhonemeString::from_grapheme(prefix.clone(), &self.inference_service).await?; +        // we then "cut" the word into his syllables +        let cprefix = prefix.0.clone(); +        let hyphens: Vec<&str> = hyphenate(&cprefix, langt).collect_vec(); + +        let cprefix = prefix.0.clone(); +        // we then perform the random walk into the tree to find a random +        // word starting with the given phoneme combination +        let completion = self +            .save +            .trie +            .random_starting_with(phonemes.clone()) +            .ok_or_else(InferenceError::TrieError)?; + +        // store the found word +        let initial = cprefix.add(&completion.0); +        // cut the found word syllabes +        let mut completed_hyphens: Vec<&str> = hyphenate(&initial, langt).into_iter().collect_vec(); + +        // we search the matched word in the reverse dictionary +        let matched = self +            .save +            .reverse_index +            .get(&completion) +            .ok_or_else(InferenceError::NotInDictionary)?; + +        debug!( +            "found matching word '{}' that has phonetic '{}' that start with '{}'", +            initial, completion, phonemes +        ); + +        // we store the index where the word should be cut +        let mut highest_index = 0usize; +        for (index, completed_hyphen) in completed_hyphens +            .splice(0..hyphens.len() - 1, iter::empty()) +            .enumerate() +        { +            let source_hyph = hyphens[index].to_lowercase(); +            let complete_hyph = completed_hyphen.to_lowercase(); + +            debug!( +                "[{}] comparing hyphen src={} dst={}", +                index, source_hyph, complete_hyph +            ); + +            if self.matches(&source_hyph, &complete_hyph) { +                highest_index = index +            } else { +                debug!( +                    "[{}] found matching hyphen at index {}", +                    index, highest_index +                ); +                break; +            } +        } + +        // we finally just need to compute the end of the word which matches the sound +        let found = completed_hyphens.drain(highest_index + 1..).join(""); + +        if found.len() == 0 { +            debug!("failed to find where to cut the word to match"); +            Err(InferenceError::WordCutError()) +        } else { +            debug!("found that {} is equivalent to {}", completion, found); +            Ok(format!("{} ({})", found, matched)) +        } +    } +} diff --git a/libs/db/src/lib.rs b/libs/db/src/lib.rs new file mode 100644 index 0000000..2eecd05 --- /dev/null +++ b/libs/db/src/lib.rs @@ -0,0 +1,4 @@ +pub mod trie; +pub mod save; +pub mod inference; +pub mod types;
\ No newline at end of file diff --git a/libs/db/src/save.rs b/libs/db/src/save.rs new file mode 100644 index 0000000..e1b3577 --- /dev/null +++ b/libs/db/src/save.rs @@ -0,0 +1,10 @@ +use crate::{trie::Trie, types::{GraphemeString, PhonemeString}}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct Save<'a> { +    #[serde(borrow = "'a")] +    pub trie: Trie<'a>, +    pub reverse_index: HashMap<PhonemeString, GraphemeString>, +} diff --git a/libs/db/src/trie.rs b/libs/db/src/trie.rs new file mode 100644 index 0000000..38cbcd8 --- /dev/null +++ b/libs/db/src/trie.rs @@ -0,0 +1,156 @@ +use crate::types::PhonemeString; +use log::debug; +use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use unicode_segmentation::UnicodeSegmentation; + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct TrieNode<'a> { +    is_final: bool, +    #[serde(borrow = "'a")] +    child_nodes: HashMap<&'a str, TrieNode<'a>>, +    child_count: u64, +} + +impl<'a> TrieNode<'a> { +    pub fn new<'b>(is_final: bool) -> TrieNode<'b> { +        TrieNode { +            is_final, +            child_nodes: HashMap::new(), +            child_count: 0, +        } +    } + +    pub fn new_root<'b>() -> TrieNode<'b> { +        TrieNode { +            is_final: false, +            child_nodes: HashMap::new(), +            child_count: 0, +        } +    } +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Trie<'a> { +    #[serde(borrow = "'a")] +    root_node: Box<TrieNode<'a>>, +} + +impl<'a> Trie<'a> { +    // Create a TrieStruct +    pub fn new<'b>() -> Trie<'b> { +        Trie { +            root_node: Box::new(TrieNode::new_root()), +        } +    } + +    // Insert a string +    pub fn insert(&mut self, char_list: &'a PhonemeString) { +        let mut current_node: &mut TrieNode = self.root_node.as_mut(); +        let iterator = char_list.0.graphemes(true); +        let mut create = false; + +        current_node.child_count += 1; +        // Find the minimum existing math +        for str in iterator { +            if create == false { +                if current_node.child_nodes.contains_key(str) { +                    current_node = current_node.child_nodes.get_mut(str).unwrap(); +                    // we mark the node as containing our children. +                    current_node.child_count += 1; +                } else { +                    create = true; + +                    current_node.child_nodes.insert(str, TrieNode::new(false)); +                    current_node = current_node.child_nodes.get_mut(str).unwrap(); +                    current_node.child_count = 1; +                } +            } else { +                current_node.child_nodes.insert(str, TrieNode::new(false)); +                current_node = current_node.child_nodes.get_mut(str).unwrap(); +                // we will only have one final node +                current_node.child_count = 1; +            } +        } +        current_node.is_final = true; +    } + +    // Find a string +    pub fn random_starting_with(&self, prefix: PhonemeString) -> Option<PhonemeString> { +        let mut current_node: &TrieNode = self.root_node.as_ref(); +        // String for the return value +        let mut builder = String::new(); + +        // Iterator over each grapheme +        let graphemes = prefix.0.graphemes(true).enumerate(); + +        // Descend as far as possible into the tree +        for (_, str) in graphemes { +            // If we can descend further into the tree +            if let Some(node) = current_node.child_nodes.get(&str) { +                current_node = node; +                builder += str; +                debug!("going into node {}", builder); +            } else { +                // couldn't descend fully into the tree +                // this basically means nothing exist in the tree +                // with this prefix +                debug!("no matches for prefix!"); +                return None; +            } +        } + +        debug!("Found common root node {}", builder); +        builder = String::new(); +        let mut rng = thread_rng(); + +        let mut length = 0; +        while current_node.child_nodes.len() != 0 { +            // We need to choose a random child based on weights +            let weighted = WeightedIndex::new(current_node.child_nodes.iter().map(|(_, node)| { +                node.child_count +                    / (node +                        .child_nodes +                        .iter() +                        .map(|(_, b)| b.child_count) +                        .sum::<u64>() +                        + 1) +                    + 1 +            })) +            .expect("distribution creation should be valid"); + +            let (key, node) = current_node +                .child_nodes +                .iter() +                .nth(weighted.sample(&mut rng)) +                .expect("choosed value did not exist"); +            debug!("walking into node {}", key); + +            current_node = node; +            builder += key; + +            // If this node is final and has childrens +            if current_node.is_final && current_node.child_count > 0 { +                // choose from current node or continue with childrens +                let weighted = WeightedIndex::new(&[1, current_node.child_count / (length + 1)]) +                    .expect("distribution seems impossible"); + +                if weighted.sample(&mut rng) == 0 { +                    // we choosed this node! +                    // stop adding other characters +                    break; +                } +            } +            length += 1; +        } + +        // If we only added +        if builder == "" { +            return None; +        } + +        // selected word +        Some(PhonemeString(builder)) +    } +} diff --git a/libs/db/src/types.rs b/libs/db/src/types.rs new file mode 100644 index 0000000..6062c98 --- /dev/null +++ b/libs/db/src/types.rs @@ -0,0 +1,43 @@ +use std::{error::Error, fmt::Display}; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +// This is needed since we need dynamic dispatch +// as any kind of inference service could be used. +#[async_trait] +pub trait InferenceService<E: Error> : Send + Sync { +    async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, E>; +} + +#[repr(transparent)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] +pub struct GraphemeString(pub String); + +#[repr(transparent)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] +pub struct PhonemeString(pub String); + +impl Display for PhonemeString { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        f.write_str(&self.0) +    } +} + +impl Display for GraphemeString { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        f.write_str(&self.0) +    } +} + +impl PhonemeString { +    /// Converts a Grapheme representation of a work +    /// into a phoneme one. +    /// !! This is a one-way operation. +    pub async fn from_grapheme<E: Error>( +        value: GraphemeString, +        inference: &Box<dyn InferenceService<E>>, +    ) -> Result<Self, E> { +        inference.fetch(value).await +    } +}  | 
