summary refs log tree commit diff
path: root/libs
diff options
context:
space:
mode:
author Matthieu Pignolet <matthieu@puffer.fish> 2025-10-30 16:04:25 +0100
committer Matthieu Pignolet <matthieu@puffer.fish> 2025-10-30 16:04:25 +0100
commit 09a0afaff012b88f79afde35a3c2ebaeb010a176 (patch)
tree 874fe676977f56c5456e4cead31e1a6c86bc40dd /libs
parent c951f54be999fb7e508039ba23fa5b1fe7035743 (diff)
feat: commit all feat/twilight-discord
Diffstat (limited to 'libs')
-rw-r--r-- libs/db/src/inference.rs 24
-rw-r--r-- libs/db/src/save.rs 12
-rw-r--r-- libs/db/src/trie.rs 36
-rw-r--r-- libs/db/src/types.rs 11
4 files changed, 45 insertions, 38 deletions
diff --git a/libs/db/src/inference.rs b/libs/db/src/inference.rs
index f2cbabf..25415df 100644
--- a/libs/db/src/inference.rs
+++ b/libs/db/src/inference.rs
@@ -5,7 +5,7 @@ use crate::{
use hypher::{Lang, hyphenate};
use itertools::Itertools;
use log::*;
-use std::{iter, ops::Add};
+use std::{iter, ops::Add, sync::Arc};
use thiserror::Error;
use unicode_segmentation::UnicodeSegmentation;
@@ -15,13 +15,14 @@ pub enum InferenceError {
TrieError(),
#[error("the matched value wasn't in the dictionary")]
NotInDictionary(),
- #[error("failt to cut the word")]
+ #[error("failed to cut the word")]
WordCutError(),
}
-pub struct Inference<'a> {
- inference_service: Box<dyn InferenceService<InferenceError>>,
- save: Save<'a>,
+#[derive(Debug, Clone)]
+pub struct Inference {
+ inference_service: Arc<dyn InferenceService<InferenceError>>,
+ save: Save,
}
pub struct TLang(whatlang::Lang);
@@ -62,7 +63,7 @@ impl Into<Lang> for TLang {
}
}
-impl Inference<'_> {
+impl Inference {
fn matches(&self, source: &str, complete: &str) -> bool {
let source_chars: Vec<&str> = source.graphemes(true).collect();
let complete_chars: Vec<&str> = complete.graphemes(true).collect();
@@ -79,10 +80,10 @@ impl Inference<'_> {
return true;
}
- pub fn new<'a>(
- inference_service: Box<dyn InferenceService<InferenceError>>,
- save: Save<'a>,
- ) -> Inference<'a> {
+ pub fn new(
+ inference_service: Arc<dyn InferenceService<InferenceError>>,
+ save: Save,
+ ) -> Inference {
return Inference {
inference_service,
save,
@@ -112,8 +113,7 @@ impl Inference<'_> {
let completion = self
.save
.trie
- .random_starting_with(phonemes.clone())
- .ok_or_else(InferenceError::TrieError)?;
+ .random_starting_with(phonemes.clone()).unwrap();
// store the found word
let initial = cprefix.add(&completion.0);
diff --git a/libs/db/src/save.rs b/libs/db/src/save.rs
index e1b3577..87a2736 100644
--- a/libs/db/src/save.rs
+++ b/libs/db/src/save.rs
@@ -1,10 +1,12 @@
-use crate::{trie::Trie, types::{GraphemeString, PhonemeString}};
+use crate::{
+ trie::Trie,
+ types::{GraphemeString, PhonemeString},
+};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
-#[derive(Debug, Deserialize, Serialize, Default)]
-pub struct Save<'a> {
- #[serde(borrow = "'a")]
- pub trie: Trie<'a>,
+#[derive(Debug, Deserialize, Serialize, Default, Clone)]
+pub struct Save {
+ pub trie: Trie,
pub reverse_index: HashMap<PhonemeString, GraphemeString>,
}
diff --git a/libs/db/src/trie.rs b/libs/db/src/trie.rs
index 38cbcd8..754d978 100644
--- a/libs/db/src/trie.rs
+++ b/libs/db/src/trie.rs
@@ -5,16 +5,15 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use unicode_segmentation::UnicodeSegmentation;
-#[derive(Debug, Serialize, Deserialize, Default)]
-pub struct TrieNode<'a> {
+#[derive(Debug, Serialize, Deserialize, Default, Clone)]
+pub struct TrieNode {
is_final: bool,
- #[serde(borrow = "'a")]
- child_nodes: HashMap<&'a str, TrieNode<'a>>,
+ child_nodes: HashMap<String, TrieNode>,
child_count: u64,
}
-impl<'a> TrieNode<'a> {
- pub fn new<'b>(is_final: bool) -> TrieNode<'b> {
+impl TrieNode {
+ pub fn new<'b>(is_final: bool) -> TrieNode {
TrieNode {
is_final,
child_nodes: HashMap::new(),
@@ -22,7 +21,7 @@ impl<'a> TrieNode<'a> {
}
}
- pub fn new_root<'b>() -> TrieNode<'b> {
+ pub fn new_root<'b>() -> TrieNode {
TrieNode {
is_final: false,
child_nodes: HashMap::new(),
@@ -31,15 +30,14 @@ impl<'a> TrieNode<'a> {
}
}
-#[derive(Debug, Serialize, Deserialize, Default)]
-pub struct Trie<'a> {
- #[serde(borrow = "'a")]
- root_node: Box<TrieNode<'a>>,
+#[derive(Debug, Serialize, Deserialize, Default, Clone)]
+pub struct Trie {
+ root_node: Box<TrieNode>,
}
-impl<'a> Trie<'a> {
+impl<'a> Trie {
// Create a TrieStruct
- pub fn new<'b>() -> Trie<'b> {
+ pub fn new<'b>() -> Trie {
Trie {
root_node: Box::new(TrieNode::new_root()),
}
@@ -62,12 +60,16 @@ impl<'a> Trie<'a> {
} else {
create = true;
- current_node.child_nodes.insert(str, TrieNode::new(false));
+ current_node
+ .child_nodes
+ .insert(str.to_string(), TrieNode::new(false));
current_node = current_node.child_nodes.get_mut(str).unwrap();
current_node.child_count = 1;
}
} else {
- current_node.child_nodes.insert(str, TrieNode::new(false));
+ current_node
+ .child_nodes
+ .insert(str.to_string(), TrieNode::new(false));
current_node = current_node.child_nodes.get_mut(str).unwrap();
// we will only have one final node
current_node.child_count = 1;
@@ -88,7 +90,7 @@ impl<'a> Trie<'a> {
// Descend as far as possible into the tree
for (_, str) in graphemes {
// If we can descend further into the tree
- if let Some(node) = current_node.child_nodes.get(&str) {
+ if let Some(node) = current_node.child_nodes.get(str) {
current_node = node;
builder += str;
debug!("going into node {}", builder);
@@ -124,7 +126,7 @@ impl<'a> Trie<'a> {
.child_nodes
.iter()
.nth(weighted.sample(&mut rng))
- .expect("choosed value did not exist");
+ .expect("choose value did not exist");
debug!("walking into node {}", key);
current_node = node;
diff --git a/libs/db/src/types.rs b/libs/db/src/types.rs
index 6062c98..636fcb7 100644
--- a/libs/db/src/types.rs
+++ b/libs/db/src/types.rs
@@ -1,12 +1,15 @@
-use std::{error::Error, fmt::Display};
-
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
+use std::{
+ error::Error,
+ fmt::{Debug, Display},
+ sync::Arc,
+};
// This is needed since we need dynamic dispatch
// as any kind of inference service could be used.
#[async_trait]
-pub trait InferenceService<E: Error> : Send + Sync {
+pub trait InferenceService<E: Error>: Send + Sync + Debug {
async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, E>;
}
@@ -36,7 +39,7 @@ impl PhonemeString {
/// !! This is a one-way operation.
pub async fn from_grapheme<E: Error>(
value: GraphemeString,
- inference: &Box<dyn InferenceService<E>>,
+ inference: &Arc<dyn InferenceService<E>>,
) -> Result<Self, E> {
inference.fetch(value).await
}