diff options
| author | Matthieu Pignolet <matthieu@puffer.fish> | 2025-05-19 17:05:09 +0400 |
|---|---|---|
| committer | Matthieu Pignolet <matthieu@puffer.fish> | 2025-05-19 17:05:09 +0400 |
| commit | ca310bf1b62988fd0a5bf91531a8147a0f6a3343 (patch) | |
| tree | a292ac0bf2374bbb8ae62bac8cbe80601de100a4 /libs/db | |
| parent | 4debeb335f171eecc034f905617bf985f8763a33 (diff) | |
feat: add db library
Diffstat (limited to 'libs/db')
| -rw-r--r-- | libs/db/Cargo.toml | 17 | ||||
| -rw-r--r-- | libs/db/src/inference.rs | 171 | ||||
| -rw-r--r-- | libs/db/src/lib.rs | 4 | ||||
| -rw-r--r-- | libs/db/src/save.rs | 10 | ||||
| -rw-r--r-- | libs/db/src/trie.rs | 156 | ||||
| -rw-r--r-- | libs/db/src/types.rs | 43 |
6 files changed, 401 insertions, 0 deletions
diff --git a/libs/db/Cargo.toml b/libs/db/Cargo.toml new file mode 100644 index 0000000..c6a7c1c --- /dev/null +++ b/libs/db/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "db" +version = "0.1.0" +edition = "2024" + +[dependencies] +unicode-segmentation = "1.10.0" +rand = "0.8.5" +serde = { version = "1.0.152", features = ["derive"] } +itertools = "0.14.0" +hypher = "0.1.5" +thiserror = "2.0.12" +async-trait = "0.1.88" +log = "0.4.27" +strsim = "0.11.1" +whatlang = "0.16.4" +kondrak-aline = "0.1.3" diff --git a/libs/db/src/inference.rs b/libs/db/src/inference.rs new file mode 100644 index 0000000..f2cbabf --- /dev/null +++ b/libs/db/src/inference.rs @@ -0,0 +1,171 @@ +use crate::{ + save::Save, + types::{GraphemeString, InferenceService, PhonemeString}, +}; +use hypher::{Lang, hyphenate}; +use itertools::Itertools; +use log::*; +use std::{iter, ops::Add}; +use thiserror::Error; +use unicode_segmentation::UnicodeSegmentation; + +#[derive(Error, Debug)] +pub enum InferenceError { + #[error("failed to fetch a random value from the trie")] + TrieError(), + #[error("the matched value wasn't in the dictionary")] + NotInDictionary(), + #[error("failt to cut the word")] + WordCutError(), +} + +pub struct Inference<'a> { + inference_service: Box<dyn InferenceService<InferenceError>>, + save: Save<'a>, +} + +pub struct TLang(whatlang::Lang); +impl Into<Lang> for TLang { + fn into(self) -> Lang { + match self.0 { + whatlang::Lang::Eng => Lang::English, + whatlang::Lang::Rus => Lang::Russian, + whatlang::Lang::Spa => Lang::Spanish, + whatlang::Lang::Por => Lang::Portuguese, + whatlang::Lang::Ita => Lang::Italian, + whatlang::Lang::Fra => Lang::French, + whatlang::Lang::Deu => Lang::German, + whatlang::Lang::Ukr => Lang::Ukrainian, + whatlang::Lang::Kat => Lang::Georgian, + whatlang::Lang::Pol => Lang::Polish, + whatlang::Lang::Dan => Lang::Danish, + whatlang::Lang::Swe => Lang::Swedish, + whatlang::Lang::Fin => Lang::Finnish, + whatlang::Lang::Tur => Lang::Turkish, + 
whatlang::Lang::Nld => Lang::Dutch, + whatlang::Lang::Hun => Lang::Hungarian, + whatlang::Lang::Ces => Lang::Czech, + whatlang::Lang::Ell => Lang::Greek, + whatlang::Lang::Bul => Lang::Bulgarian, + whatlang::Lang::Bel => Lang::Belarusian, + whatlang::Lang::Hrv => Lang::Croatian, + whatlang::Lang::Srp => Lang::Serbian, + whatlang::Lang::Lit => Lang::Lithuanian, + whatlang::Lang::Est => Lang::Estonian, + whatlang::Lang::Tuk => Lang::Turkmen, + whatlang::Lang::Afr => Lang::Afrikaans, + whatlang::Lang::Lat => Lang::Latin, + whatlang::Lang::Slk => Lang::Slovak, + whatlang::Lang::Cat => Lang::Catalan, + _ => Lang::English, + } + } +} + +impl Inference<'_> { + fn matches(&self, source: &str, complete: &str) -> bool { + let source_chars: Vec<&str> = source.graphemes(true).collect(); + let complete_chars: Vec<&str> = complete.graphemes(true).collect(); + + let s = usize::min(source_chars.len(), complete_chars.len()); + + for j in 0..s { + let metric = kondrak_aline::delta(source_chars[j], complete_chars[j]); + if metric > 20.0 { + return false; + } + } + + return true; + } + + pub fn new<'a>( + inference_service: Box<dyn InferenceService<InferenceError>>, + save: Save<'a>, + ) -> Inference<'a> { + return Inference { + inference_service, + save, + }; + } + + pub async fn infer(&self, prefix: GraphemeString) -> Result<String, InferenceError> { + debug!("starting inference for {prefix}"); + + let cprefix = prefix.0.clone(); + let lang = TLang( + whatlang::detect_lang(&cprefix) + .or(Some(whatlang::Lang::Eng)) + .unwrap(), + ); + let langt: Lang = lang.into(); + // The first step is to use the inference service in order to find the phoneme corresponding to the word. 
+ let phonemes = + PhonemeString::from_grapheme(prefix.clone(), &self.inference_service).await?; + // we then "cut" the word into its syllables + let cprefix = prefix.0.clone(); + let hyphens: Vec<&str> = hyphenate(&cprefix, langt).collect_vec(); + + let cprefix = prefix.0.clone(); + // we then perform the random walk into the tree to find a random + // word starting with the given phoneme combination + let completion = self + .save + .trie + .random_starting_with(phonemes.clone()) + .ok_or_else(InferenceError::TrieError)?; + + // store the found word + let initial = cprefix.add(&completion.0); + // cut the found word into syllables + let mut completed_hyphens: Vec<&str> = hyphenate(&initial, langt).into_iter().collect_vec(); + + // we search the matched word in the reverse dictionary + let matched = self + .save + .reverse_index + .get(&completion) + .ok_or_else(InferenceError::NotInDictionary)?; + + debug!( + "found matching word '{}' that has phonetic '{}' that start with '{}'", + initial, completion, phonemes + ); + + // we store the index where the word should be cut + let mut highest_index = 0usize; + for (index, completed_hyphen) in completed_hyphens + .splice(0..hyphens.len() - 1, iter::empty()) + .enumerate() + { + let source_hyph = hyphens[index].to_lowercase(); + let complete_hyph = completed_hyphen.to_lowercase(); + + debug!( + "[{}] comparing hyphen src={} dst={}", + index, source_hyph, complete_hyph + ); + + if self.matches(&source_hyph, &complete_hyph) { + highest_index = index + } else { + debug!( + "[{}] found matching hyphen at index {}", + index, highest_index + ); + break; + } + } + + // we finally just need to compute the end of the word which matches the sound + let found = completed_hyphens.drain(highest_index + 1..).join(""); + + if found.len() == 0 { + debug!("failed to find where to cut the word to match"); + Err(InferenceError::WordCutError()) + } else { + debug!("found that {} is equivalent to {}", completion, found); + Ok(format!("{} ({})", 
found, matched)) + } + } +} diff --git a/libs/db/src/lib.rs b/libs/db/src/lib.rs new file mode 100644 index 0000000..2eecd05 --- /dev/null +++ b/libs/db/src/lib.rs @@ -0,0 +1,4 @@ +pub mod trie; +pub mod save; +pub mod inference; +pub mod types;
\ No newline at end of file diff --git a/libs/db/src/save.rs b/libs/db/src/save.rs new file mode 100644 index 0000000..e1b3577 --- /dev/null +++ b/libs/db/src/save.rs @@ -0,0 +1,10 @@ +use crate::{trie::Trie, types::{GraphemeString, PhonemeString}}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct Save<'a> { + #[serde(borrow = "'a")] + pub trie: Trie<'a>, + pub reverse_index: HashMap<PhonemeString, GraphemeString>, +} diff --git a/libs/db/src/trie.rs b/libs/db/src/trie.rs new file mode 100644 index 0000000..38cbcd8 --- /dev/null +++ b/libs/db/src/trie.rs @@ -0,0 +1,156 @@ +use crate::types::PhonemeString; +use log::debug; +use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use unicode_segmentation::UnicodeSegmentation; + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct TrieNode<'a> { + is_final: bool, + #[serde(borrow = "'a")] + child_nodes: HashMap<&'a str, TrieNode<'a>>, + child_count: u64, +} + +impl<'a> TrieNode<'a> { + pub fn new<'b>(is_final: bool) -> TrieNode<'b> { + TrieNode { + is_final, + child_nodes: HashMap::new(), + child_count: 0, + } + } + + pub fn new_root<'b>() -> TrieNode<'b> { + TrieNode { + is_final: false, + child_nodes: HashMap::new(), + child_count: 0, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Trie<'a> { + #[serde(borrow = "'a")] + root_node: Box<TrieNode<'a>>, +} + +impl<'a> Trie<'a> { + // Create a TrieStruct + pub fn new<'b>() -> Trie<'b> { + Trie { + root_node: Box::new(TrieNode::new_root()), + } + } + + // Insert a string + pub fn insert(&mut self, char_list: &'a PhonemeString) { + let mut current_node: &mut TrieNode = self.root_node.as_mut(); + let iterator = char_list.0.graphemes(true); + let mut create = false; + + current_node.child_count += 1; + // Find the minimum existing match + for str in iterator { + 
if create == false { + if current_node.child_nodes.contains_key(str) { + current_node = current_node.child_nodes.get_mut(str).unwrap(); + // we mark the node as containing our children. + current_node.child_count += 1; + } else { + create = true; + + current_node.child_nodes.insert(str, TrieNode::new(false)); + current_node = current_node.child_nodes.get_mut(str).unwrap(); + current_node.child_count = 1; + } + } else { + current_node.child_nodes.insert(str, TrieNode::new(false)); + current_node = current_node.child_nodes.get_mut(str).unwrap(); + // we will only have one final node + current_node.child_count = 1; + } + } + current_node.is_final = true; + } + + // Find a string + pub fn random_starting_with(&self, prefix: PhonemeString) -> Option<PhonemeString> { + let mut current_node: &TrieNode = self.root_node.as_ref(); + // String for the return value + let mut builder = String::new(); + + // Iterator over each grapheme + let graphemes = prefix.0.graphemes(true).enumerate(); + + // Descend as far as possible into the tree + for (_, str) in graphemes { + // If we can descend further into the tree + if let Some(node) = current_node.child_nodes.get(&str) { + current_node = node; + builder += str; + debug!("going into node {}", builder); + } else { + // couldn't descend fully into the tree + // this basically means nothing exist in the tree + // with this prefix + debug!("no matches for prefix!"); + return None; + } + } + + debug!("Found common root node {}", builder); + builder = String::new(); + let mut rng = thread_rng(); + + let mut length = 0; + while current_node.child_nodes.len() != 0 { + // We need to choose a random child based on weights + let weighted = WeightedIndex::new(current_node.child_nodes.iter().map(|(_, node)| { + node.child_count + / (node + .child_nodes + .iter() + .map(|(_, b)| b.child_count) + .sum::<u64>() + + 1) + + 1 + })) + .expect("distribution creation should be valid"); + + let (key, node) = current_node + .child_nodes + .iter() + 
.nth(weighted.sample(&mut rng)) + .expect("choosed value did not exist"); + debug!("walking into node {}", key); + + current_node = node; + builder += key; + + // If this node is final and has children + if current_node.is_final && current_node.child_count > 0 { + // choose from current node or continue with children + let weighted = WeightedIndex::new(&[1, current_node.child_count / (length + 1)]) + .expect("distribution seems impossible"); + + if weighted.sample(&mut rng) == 0 { + // we chose this node! + // stop adding other characters + break; + } + } + length += 1; + } + + // If nothing was added after the prefix + if builder == "" { + return None; + } + + // selected word + Some(PhonemeString(builder)) + } +} diff --git a/libs/db/src/types.rs b/libs/db/src/types.rs new file mode 100644 index 0000000..6062c98 --- /dev/null +++ b/libs/db/src/types.rs @@ -0,0 +1,43 @@ +use std::{error::Error, fmt::Display}; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +// This is needed since we need dynamic dispatch +// as any kind of inference service could be used. +#[async_trait] +pub trait InferenceService<E: Error> : Send + Sync { + async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, E>; +} + +#[repr(transparent)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] +pub struct GraphemeString(pub String); + +#[repr(transparent)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] +pub struct PhonemeString(pub String); + +impl Display for PhonemeString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl Display for GraphemeString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl PhonemeString { + /// Converts a Grapheme representation of a word + /// into a phoneme one. + /// !! This is a one-way operation. 
+ pub async fn from_grapheme<E: Error>( + value: GraphemeString, + inference: &Box<dyn InferenceService<E>>, + ) -> Result<Self, E> { + inference.fetch(value).await + } +} |
