diff options
| author | Matthieu Pignolet <matthieu@puffer.fish> | 2025-10-30 16:04:25 +0100 |
|---|---|---|
| committer | Matthieu Pignolet <matthieu@puffer.fish> | 2025-10-30 16:04:25 +0100 |
| commit | 09a0afaff012b88f79afde35a3c2ebaeb010a176 (patch) | |
| tree | 874fe676977f56c5456e4cead31e1a6c86bc40dd /libs | |
| parent | c951f54be999fb7e508039ba23fa5b1fe7035743 (diff) | |
feat: commit allfeat/twilight-discord
Diffstat (limited to 'libs')
| -rw-r--r-- | libs/db/src/inference.rs | 24 | ||||
| -rw-r--r-- | libs/db/src/save.rs | 12 | ||||
| -rw-r--r-- | libs/db/src/trie.rs | 36 | ||||
| -rw-r--r-- | libs/db/src/types.rs | 11 |
4 files changed, 45 insertions, 38 deletions
diff --git a/libs/db/src/inference.rs b/libs/db/src/inference.rs index f2cbabf..25415df 100644 --- a/libs/db/src/inference.rs +++ b/libs/db/src/inference.rs @@ -5,7 +5,7 @@ use crate::{ use hypher::{Lang, hyphenate}; use itertools::Itertools; use log::*; -use std::{iter, ops::Add}; +use std::{iter, ops::Add, sync::Arc}; use thiserror::Error; use unicode_segmentation::UnicodeSegmentation; @@ -15,13 +15,14 @@ pub enum InferenceError { TrieError(), #[error("the matched value wasn't in the dictionary")] NotInDictionary(), - #[error("failt to cut the word")] + #[error("failed to cut the word")] WordCutError(), } -pub struct Inference<'a> { - inference_service: Box<dyn InferenceService<InferenceError>>, - save: Save<'a>, +#[derive(Debug, Clone)] +pub struct Inference { + inference_service: Arc<dyn InferenceService<InferenceError>>, + save: Save, } pub struct TLang(whatlang::Lang); @@ -62,7 +63,7 @@ impl Into<Lang> for TLang { } } -impl Inference<'_> { +impl Inference { fn matches(&self, source: &str, complete: &str) -> bool { let source_chars: Vec<&str> = source.graphemes(true).collect(); let complete_chars: Vec<&str> = complete.graphemes(true).collect(); @@ -79,10 +80,10 @@ impl Inference<'_> { return true; } - pub fn new<'a>( - inference_service: Box<dyn InferenceService<InferenceError>>, - save: Save<'a>, - ) -> Inference<'a> { + pub fn new( + inference_service: Arc<dyn InferenceService<InferenceError>>, + save: Save, + ) -> Inference { return Inference { inference_service, save, @@ -112,8 +113,7 @@ impl Inference<'_> { let completion = self .save .trie - .random_starting_with(phonemes.clone()) - .ok_or_else(InferenceError::TrieError)?; + .random_starting_with(phonemes.clone()).unwrap(); // store the found word let initial = cprefix.add(&completion.0); diff --git a/libs/db/src/save.rs b/libs/db/src/save.rs index e1b3577..87a2736 100644 --- a/libs/db/src/save.rs +++ b/libs/db/src/save.rs @@ -1,10 +1,12 @@ -use crate::{trie::Trie, types::{GraphemeString, PhonemeString}}; +use crate::{ + trie::Trie, + types::{GraphemeString, PhonemeString}, +}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -#[derive(Debug, Deserialize, Serialize, Default)] -pub struct Save<'a> { - #[serde(borrow = "'a")] - pub trie: Trie<'a>, +#[derive(Debug, Deserialize, Serialize, Default, Clone)] +pub struct Save { + pub trie: Trie, pub reverse_index: HashMap<PhonemeString, GraphemeString>, } diff --git a/libs/db/src/trie.rs b/libs/db/src/trie.rs index 38cbcd8..754d978 100644 --- a/libs/db/src/trie.rs +++ b/libs/db/src/trie.rs @@ -5,16 +5,15 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use unicode_segmentation::UnicodeSegmentation; -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct TrieNode<'a> { +#[derive(Debug, Serialize, Deserialize, Default, Clone)] +pub struct TrieNode { is_final: bool, - #[serde(borrow = "'a")] - child_nodes: HashMap<&'a str, TrieNode<'a>>, + child_nodes: HashMap<String, TrieNode>, child_count: u64, } -impl<'a> TrieNode<'a> { - pub fn new<'b>(is_final: bool) -> TrieNode<'b> { +impl TrieNode { + pub fn new<'b>(is_final: bool) -> TrieNode { TrieNode { is_final, child_nodes: HashMap::new(), @@ -22,7 +21,7 @@ impl<'a> TrieNode<'a> { } } - pub fn new_root<'b>() -> TrieNode<'b> { + pub fn new_root<'b>() -> TrieNode { TrieNode { is_final: false, child_nodes: HashMap::new(), @@ -31,15 +30,14 @@ impl<'a> TrieNode<'a> { } } -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct Trie<'a> { - #[serde(borrow = "'a")] - root_node: Box<TrieNode<'a>>, +#[derive(Debug, Serialize, Deserialize, Default, Clone)] +pub struct Trie { + root_node: Box<TrieNode>, } -impl<'a> Trie<'a> { +impl<'a> Trie { // Create a TrieStruct - pub fn new<'b>() -> Trie<'b> { + pub fn new<'b>() -> Trie { Trie { root_node: Box::new(TrieNode::new_root()), } @@ -62,12 +60,16 @@ impl<'a> Trie<'a> { } else { create = true; - current_node.child_nodes.insert(str, TrieNode::new(false)); + current_node + .child_nodes + .insert(str.to_string(), TrieNode::new(false)); current_node = current_node.child_nodes.get_mut(str).unwrap(); current_node.child_count = 1; } } else { - current_node.child_nodes.insert(str, TrieNode::new(false)); + current_node + .child_nodes + .insert(str.to_string(), TrieNode::new(false)); current_node = current_node.child_nodes.get_mut(str).unwrap(); // we will only have one final node current_node.child_count = 1; @@ -88,7 +90,7 @@ impl<'a> Trie<'a> { // Descend as far as possible into the tree for (_, str) in graphemes { // If we can descend further into the tree - if let Some(node) = current_node.child_nodes.get(&str) { + if let Some(node) = current_node.child_nodes.get(str) { current_node = node; builder += str; debug!("going into node {}", builder); @@ -124,7 +126,7 @@ impl<'a> Trie<'a> { .child_nodes .iter() .nth(weighted.sample(&mut rng)) - .expect("choosed value did not exist"); + .expect("choose value did not exist"); debug!("walking into node {}", key); current_node = node; diff --git a/libs/db/src/types.rs b/libs/db/src/types.rs index 6062c98..636fcb7 100644 --- a/libs/db/src/types.rs +++ b/libs/db/src/types.rs @@ -1,12 +1,15 @@ -use std::{error::Error, fmt::Display}; - use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use std::{ + error::Error, + fmt::{Debug, Display}, + sync::Arc, +}; // This is needed since we need dynamic dispatch // as any kind of inference service could be used. #[async_trait] -pub trait InferenceService<E: Error> : Send + Sync { +pub trait InferenceService<E: Error>: Send + Sync + Debug { async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, E>; } @@ -36,7 +39,7 @@ impl PhonemeString { /// !! This is a one-way operation. pub async fn from_grapheme<E: Error>( value: GraphemeString, - inference: &Box<dyn InferenceService<E>>, + inference: &Arc<dyn InferenceService<E>>, ) -> Result<Self, E> { inference.fetch(value).await } |
