summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthieu Pignolet <matthieu@puffer.fish>2025-05-19 17:05:09 +0400
committerMatthieu Pignolet <matthieu@puffer.fish>2025-05-19 17:05:09 +0400
commitca310bf1b62988fd0a5bf91531a8147a0f6a3343 (patch)
treea292ac0bf2374bbb8ae62bac8cbe80601de100a4
parent4debeb335f171eecc034f905617bf985f8763a33 (diff)
feat: add db library
-rw-r--r--libs/db/Cargo.toml17
-rw-r--r--libs/db/src/inference.rs171
-rw-r--r--libs/db/src/lib.rs4
-rw-r--r--libs/db/src/save.rs10
-rw-r--r--libs/db/src/trie.rs156
-rw-r--r--libs/db/src/types.rs43
6 files changed, 401 insertions, 0 deletions
diff --git a/libs/db/Cargo.toml b/libs/db/Cargo.toml
new file mode 100644
index 0000000..c6a7c1c
--- /dev/null
+++ b/libs/db/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "db"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+unicode-segmentation = "1.10.0"
+rand = "0.8.5"
+serde = { version = "1.0.152", features = ["derive"] }
+itertools = "0.14.0"
+hypher = "0.1.5"
+thiserror = "2.0.12"
+async-trait = "0.1.88"
+log = "0.4.27"
+strsim = "0.11.1"
+whatlang = "0.16.4"
+kondrak-aline = "0.1.3"
diff --git a/libs/db/src/inference.rs b/libs/db/src/inference.rs
new file mode 100644
index 0000000..f2cbabf
--- /dev/null
+++ b/libs/db/src/inference.rs
@@ -0,0 +1,171 @@
+use crate::{
+ save::Save,
+ types::{GraphemeString, InferenceService, PhonemeString},
+};
+use hypher::{Lang, hyphenate};
+use itertools::Itertools;
+use log::*;
+use std::{iter, ops::Add};
+use thiserror::Error;
+use unicode_segmentation::UnicodeSegmentation;
+
+/// Errors that can be returned while inferring a completion for a prefix.
+#[derive(Error, Debug)]
+pub enum InferenceError {
+    /// The trie could not provide a random word for the requested phonemes.
+    #[error("failed to fetch a random value from the trie")]
+    TrieError(),
+    /// The phoneme string picked from the trie is missing from the reverse index.
+    #[error("the matched value wasn't in the dictionary")]
+    NotInDictionary(),
+    /// No part of the completed word was left once it had been cut.
+    #[error("failed to cut the word")]
+    WordCutError(),
+}
+
+/// Bundles an inference service with the saved data used to complete a prefix.
+pub struct Inference<'a> {
+ /// Service asked to turn a grapheme string into its phoneme string.
+ inference_service: Box<dyn InferenceService<InferenceError>>,
+ /// Saved trie and reverse index that completions are looked up in.
+ save: Save<'a>,
+}
+
+/// Newtype over [`whatlang::Lang`] that can be converted into a hyphenation [`Lang`].
+pub struct TLang(whatlang::Lang);
+
+// Implemented as `From` rather than `Into`: the standard library's blanket
+// impl still provides `TLang: Into<Lang>`, so `.into()` call sites keep working.
+impl From<TLang> for Lang {
+    /// Maps a detected language onto the matching hyphenation language.
+    ///
+    /// Languages that are not listed below fall back to English.
+    fn from(detected: TLang) -> Self {
+        match detected.0 {
+            whatlang::Lang::Eng => Lang::English,
+            whatlang::Lang::Rus => Lang::Russian,
+            whatlang::Lang::Spa => Lang::Spanish,
+            whatlang::Lang::Por => Lang::Portuguese,
+            whatlang::Lang::Ita => Lang::Italian,
+            whatlang::Lang::Fra => Lang::French,
+            whatlang::Lang::Deu => Lang::German,
+            whatlang::Lang::Ukr => Lang::Ukrainian,
+            whatlang::Lang::Kat => Lang::Georgian,
+            whatlang::Lang::Pol => Lang::Polish,
+            whatlang::Lang::Dan => Lang::Danish,
+            whatlang::Lang::Swe => Lang::Swedish,
+            whatlang::Lang::Fin => Lang::Finnish,
+            whatlang::Lang::Tur => Lang::Turkish,
+            whatlang::Lang::Nld => Lang::Dutch,
+            whatlang::Lang::Hun => Lang::Hungarian,
+            whatlang::Lang::Ces => Lang::Czech,
+            whatlang::Lang::Ell => Lang::Greek,
+            whatlang::Lang::Bul => Lang::Bulgarian,
+            whatlang::Lang::Bel => Lang::Belarusian,
+            whatlang::Lang::Hrv => Lang::Croatian,
+            whatlang::Lang::Srp => Lang::Serbian,
+            whatlang::Lang::Lit => Lang::Lithuanian,
+            whatlang::Lang::Est => Lang::Estonian,
+            whatlang::Lang::Tuk => Lang::Turkmen,
+            whatlang::Lang::Afr => Lang::Afrikaans,
+            whatlang::Lang::Lat => Lang::Latin,
+            whatlang::Lang::Slk => Lang::Slovak,
+            whatlang::Lang::Cat => Lang::Catalan,
+            _ => Lang::English,
+        }
+    }
+}
+
+impl Inference<'_> {
+    /// Tells whether two syllables are close enough over their common length.
+    ///
+    /// Both inputs are split into grapheme clusters and compared pair by pair,
+    /// stopping at the end of the shorter one. The comparison fails as soon as
+    /// `kondrak_aline::delta` reports a value above 20 for a pair.
+    fn matches(&self, source: &str, complete: &str) -> bool {
+        let mut pairs = source.graphemes(true).zip(complete.graphemes(true));
+
+        !pairs.any(|(left, right)| kondrak_aline::delta(left, right) > 20.0)
+    }
+
+    /// Builds an [`Inference`] from an inference service and saved data.
+    pub fn new<'a>(
+        inference_service: Box<dyn InferenceService<InferenceError>>,
+        save: Save<'a>,
+    ) -> Inference<'a> {
+        Inference {
+            inference_service,
+            save,
+        }
+    }
+
+    /// Completes `prefix` with the ending of a randomly chosen word.
+    ///
+    /// The prefix is converted to phonemes by the inference service, a random
+    /// continuation of those phonemes is picked from the trie, and the
+    /// resulting word is cut by syllables to keep only its ending. The
+    /// returned string has the form `"<ending> (<reverse index entry>)"`.
+    ///
+    /// # Errors
+    ///
+    /// * Any error reported by the inference service.
+    /// * [`InferenceError::TrieError`] when the trie offers no continuation.
+    /// * [`InferenceError::NotInDictionary`] when the continuation is missing
+    ///   from the reverse index.
+    /// * [`InferenceError::WordCutError`] when nothing is left after the cut.
+    pub async fn infer(&self, prefix: GraphemeString) -> Result<String, InferenceError> {
+        debug!("starting inference for {prefix}");
+
+        // Detect the language of the prefix, assuming English when detection fails.
+        let lang = TLang(whatlang::detect_lang(&prefix.0).unwrap_or(whatlang::Lang::Eng));
+        let langt: Lang = lang.into();
+
+        // The first step is to use the inference service in order to find the
+        // phonemes corresponding to the word.
+        let phonemes =
+            PhonemeString::from_grapheme(prefix.clone(), &self.inference_service).await?;
+
+        // We then "cut" the word into its syllables.
+        let hyphens: Vec<&str> = hyphenate(&prefix.0, langt).collect();
+
+        // We then perform the random walk into the tree to find a random
+        // word starting with the given phoneme combination.
+        let completion = self
+            .save
+            .trie
+            .random_starting_with(phonemes.clone())
+            .ok_or_else(InferenceError::TrieError)?;
+
+        // Store the found word and cut it into syllables as well.
+        let initial = prefix.0.clone().add(&completion.0);
+        let mut completed_hyphens: Vec<&str> = hyphenate(&initial, langt).collect();
+
+        // We search the matched word in the reverse dictionary.
+        let matched = self
+            .save
+            .reverse_index
+            .get(&completion)
+            .ok_or_else(InferenceError::NotInDictionary)?;
+
+        debug!(
+            "found matching word '{}' that has phonetic '{}' that start with '{}'",
+            initial, completion, phonemes
+        );
+
+        // Every source syllable except the last one is compared with the
+        // syllable at the same position in the completed word. The range is
+        // clamped so that neither an empty syllable list nor a completed word
+        // with fewer syllables than the prefix can make `splice` panic.
+        let compared = hyphens
+            .len()
+            .saturating_sub(1)
+            .min(completed_hyphens.len());
+
+        // We store the index where the word should be cut.
+        let mut highest_index = 0usize;
+        for (index, (completed_hyphen, source_hyphen)) in completed_hyphens
+            .splice(0..compared, iter::empty())
+            .zip(hyphens.iter())
+            .enumerate()
+        {
+            let source_hyph = source_hyphen.to_lowercase();
+            let complete_hyph = completed_hyphen.to_lowercase();
+
+            debug!(
+                "[{}] comparing hyphen src={} dst={}",
+                index, source_hyph, complete_hyph
+            );
+
+            if self.matches(&source_hyph, &complete_hyph) {
+                highest_index = index;
+            } else {
+                debug!(
+                    "[{}] found matching hyphen at index {}",
+                    index, highest_index
+                );
+                break;
+            }
+        }
+
+        // `splice` has removed the compared syllables (even when the loop
+        // stopped early), so the cut position is clamped to what is left; an
+        // out-of-range position then yields an empty ending instead of a panic.
+        let cut = (highest_index + 1).min(completed_hyphens.len());
+
+        // We finally just need to compute the end of the word which matches the sound.
+        let found = completed_hyphens.drain(cut..).join("");
+
+        if found.is_empty() {
+            debug!("failed to find where to cut the word to match");
+            Err(InferenceError::WordCutError())
+        } else {
+            debug!("found that {} is equivalent to {}", completion, found);
+            Ok(format!("{} ({})", found, matched))
+        }
+    }
+}
diff --git a/libs/db/src/lib.rs b/libs/db/src/lib.rs
new file mode 100644
index 0000000..2eecd05
--- /dev/null
+++ b/libs/db/src/lib.rs
@@ -0,0 +1,4 @@
+pub mod trie;
+pub mod save;
+pub mod inference;
+pub mod types; \ No newline at end of file
diff --git a/libs/db/src/save.rs b/libs/db/src/save.rs
new file mode 100644
index 0000000..e1b3577
--- /dev/null
+++ b/libs/db/src/save.rs
@@ -0,0 +1,10 @@
+use crate::{trie::Trie, types::{GraphemeString, PhonemeString}};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+/// Serializable bundle of the trie and the phoneme-to-grapheme lookup table.
+#[derive(Debug, Deserialize, Serialize, Default)]
+pub struct Save<'a> {
+ /// Trie of phoneme strings; `serde(borrow)` lets it borrow from the deserialized input.
+ #[serde(borrow = "'a")]
+ pub trie: Trie<'a>,
+ /// Maps a phoneme string to a grapheme string.
+ pub reverse_index: HashMap<PhonemeString, GraphemeString>,
+}
diff --git a/libs/db/src/trie.rs b/libs/db/src/trie.rs
new file mode 100644
index 0000000..38cbcd8
--- /dev/null
+++ b/libs/db/src/trie.rs
@@ -0,0 +1,156 @@
+use crate::types::PhonemeString;
+use log::debug;
+use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use unicode_segmentation::UnicodeSegmentation;
+
+/// A single node of the trie, keyed in its parent by one grapheme cluster.
+#[derive(Debug, Serialize, Deserialize, Default)]
+pub struct TrieNode<'a> {
+ /// Set on the node where an inserted string ends.
+ is_final: bool,
+ /// Children of this node, indexed by the grapheme cluster leading to them.
+ #[serde(borrow = "'a")]
+ child_nodes: HashMap<&'a str, TrieNode<'a>>,
+ /// Number of insertions that went through this node; used as a weight by the random walk.
+ child_count: u64,
+}
+
+impl<'a> TrieNode<'a> {
+    /// Creates a node with no children and a zero count, final or not as requested.
+    pub fn new<'b>(is_final: bool) -> TrieNode<'b> {
+        TrieNode {
+            is_final,
+            ..TrieNode::default()
+        }
+    }
+
+    /// Creates the empty, non-final node used as the root of a trie.
+    pub fn new_root<'b>() -> TrieNode<'b> {
+        TrieNode::new(false)
+    }
+}
+
+/// Trie of phoneme strings split into grapheme clusters.
+#[derive(Debug, Serialize, Deserialize, Default)]
+pub struct Trie<'a> {
+ /// Root of the tree; it corresponds to the empty prefix.
+ #[serde(borrow = "'a")]
+ root_node: Box<TrieNode<'a>>,
+}
+
+impl<'a> Trie<'a> {
+    /// Creates an empty trie holding only a root node.
+    pub fn new<'b>() -> Trie<'b> {
+        Trie {
+            root_node: Box::new(TrieNode::new_root()),
+        }
+    }
+
+    /// Inserts a phoneme string, one grapheme cluster per trie level.
+    ///
+    /// Every node on the path, root included, has its `child_count`
+    /// incremented; missing nodes are created on the way and the node reached
+    /// last is marked as final.
+    pub fn insert(&mut self, char_list: &'a PhonemeString) {
+        let mut current_node: &mut TrieNode<'a> = self.root_node.as_mut();
+        current_node.child_count += 1;
+
+        for grapheme in char_list.0.graphemes(true) {
+            // One `entry` lookup both finds an existing child and creates a
+            // missing one. A fresh node starts with a count of 0, so the
+            // increment below leaves it at 1.
+            current_node = current_node.child_nodes.entry(grapheme).or_default();
+            current_node.child_count += 1;
+        }
+
+        current_node.is_final = true;
+    }
+
+    /// Picks a random continuation of `prefix` among the inserted strings.
+    ///
+    /// The trie is first descended along `prefix`, then walked at random from
+    /// there, each child being weighted from its `child_count`. The walk stops
+    /// on a leaf, or earlier on a final node when the random draw says so.
+    ///
+    /// Returns only the part that follows `prefix`. Returns `None` when
+    /// `prefix` is not in the trie or when nothing follows it.
+    pub fn random_starting_with(&self, prefix: PhonemeString) -> Option<PhonemeString> {
+        let mut current_node: &TrieNode = self.root_node.as_ref();
+        // Path followed so far, only kept for the debug output.
+        let mut descended = String::new();
+
+        // Descend as far as possible into the tree.
+        for grapheme in prefix.0.graphemes(true) {
+            let Some(node) = current_node.child_nodes.get(grapheme) else {
+                // Couldn't descend fully into the tree: nothing exists in the
+                // tree with this prefix.
+                debug!("no matches for prefix!");
+                return None;
+            };
+
+            current_node = node;
+            descended.push_str(grapheme);
+            debug!("going into node {}", descended);
+        }
+
+        debug!("Found common root node {}", descended);
+
+        // Continuation built by the random walk; the prefix is not part of it.
+        let mut builder = String::new();
+        let mut rng = thread_rng();
+
+        let mut length = 0u64;
+        while !current_node.child_nodes.is_empty() {
+            // We need to choose a random child based on weights. The trailing
+            // `+ 1` keeps every weight strictly positive.
+            let weighted = WeightedIndex::new(current_node.child_nodes.values().map(|node| {
+                let grandchildren = node
+                    .child_nodes
+                    .values()
+                    .map(|grandchild| grandchild.child_count)
+                    .sum::<u64>();
+
+                node.child_count / (grandchildren + 1) + 1
+            }))
+            .expect("distribution creation should be valid");
+
+            // The map is not modified in between, so it is iterated in the
+            // same order as when the weights were computed.
+            let (key, node) = current_node
+                .child_nodes
+                .iter()
+                .nth(weighted.sample(&mut rng))
+                .expect("choosed value did not exist");
+            debug!("walking into node {}", key);
+
+            current_node = node;
+            builder.push_str(key);
+
+            // If this node is final, choose between stopping here and
+            // continuing with its children.
+            if current_node.is_final && current_node.child_count > 0 {
+                let weighted = WeightedIndex::new(&[1, current_node.child_count / (length + 1)])
+                    .expect("distribution seems impossible");
+
+                if weighted.sample(&mut rng) == 0 {
+                    // This node was chosen: stop adding other characters.
+                    break;
+                }
+            }
+            length += 1;
+        }
+
+        // Nothing was added after the prefix.
+        if builder.is_empty() {
+            return None;
+        }
+
+        // Selected word.
+        Some(PhonemeString(builder))
+    }
+}
diff --git a/libs/db/src/types.rs b/libs/db/src/types.rs
new file mode 100644
index 0000000..6062c98
--- /dev/null
+++ b/libs/db/src/types.rs
@@ -0,0 +1,43 @@
+use std::{error::Error, fmt::Display};
+
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+
+/// A service able to turn a grapheme string into a phoneme string.
+///
+/// `E` is the error type returned when the lookup fails.
+// This is needed since we need dynamic dispatch
+// as any kind of inference service could be used.
+#[async_trait]
+pub trait InferenceService<E: Error> : Send + Sync {
+ /// Fetches the phoneme string corresponding to `prefix`.
+ async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, E>;
+}
+
+/// Newtype around a `String` holding the grapheme representation of a word.
+#[repr(transparent)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
+pub struct GraphemeString(pub String);
+
+/// Newtype around a `String` holding the phoneme representation of a word.
+#[repr(transparent)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
+pub struct PhonemeString(pub String);
+
+impl Display for PhonemeString {
+    /// Writes the wrapped string as-is.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let Self(inner) = self;
+        f.write_str(inner)
+    }
+}
+
+impl Display for GraphemeString {
+    /// Writes the wrapped string as-is.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let Self(inner) = self;
+        f.write_str(inner)
+    }
+}
+
+impl PhonemeString {
+ /// Converts the grapheme representation of a word
+ /// into its phoneme representation.
+ /// !! This is a one-way operation.
+ ///
+ /// The conversion is delegated to `inference.fetch`, whose error is
+ /// returned unchanged.
+ pub async fn from_grapheme<E: Error>(
+ value: GraphemeString,
+ inference: &Box<dyn InferenceService<E>>,
+ ) -> Result<Self, E> {
+ inference.fetch(value).await
+ }
+}