From 9f798c57f8525fb3e83fa2427fbdb645ebdf65eb Mon Sep 17 00:00:00 2001 From: MatthieuCoder Date: Sun, 22 Jan 2023 01:55:20 +0400 Subject: [PATCH] new db and readability --- autofeur_db/src/bin/generate.rs | 11 +- autofeur_db/src/bin/server.rs | 37 +++++-- autofeur_db/src/french_ipa.rs | 129 --------------------- autofeur_db/src/inference.rs | 23 ++-- autofeur_db/src/lib.rs | 1 - autofeur_db/src/save.rs | 9 +- autofeur_db/src/trie.rs | 191 ++++++++++++++------------------ autofeur_nova/src/index.mts | 16 ++- docker-compose.yaml | 2 + 9 files changed, 144 insertions(+), 275 deletions(-) delete mode 100644 autofeur_db/src/french_ipa.rs diff --git a/autofeur_db/src/bin/generate.rs b/autofeur_db/src/bin/generate.rs index 8db8809..7f85f41 100644 --- a/autofeur_db/src/bin/generate.rs +++ b/autofeur_db/src/bin/generate.rs @@ -1,6 +1,5 @@ use std::fs; -use autofeur::french_ipa::parse_word; use autofeur::save::Save; use kdam::tqdm; @@ -42,18 +41,10 @@ async fn main() { phonems.append(&mut pron); } - let mut invalid = 0; for phoneme in tqdm!(phonems.iter()) { - match parse_word(&phoneme) { - Some(a) => save.trie.insert(a), - None => { - invalid += 1; - } - } + save.trie.insert(&phoneme); } - println!("Invalid items count: {}", invalid); - fs::write("assets/db.bin", bincode::serialize(&save).unwrap()).unwrap(); println!("Generated to assets/db.bin"); diff --git a/autofeur_db/src/bin/server.rs b/autofeur_db/src/bin/server.rs index 376b4e1..8c350f9 100644 --- a/autofeur_db/src/bin/server.rs +++ b/autofeur_db/src/bin/server.rs @@ -17,28 +17,47 @@ fn parse_query(query: &str) -> HashMap { .collect() } -async fn handler(request: Request) -> Result, anyhow::Error> { +fn anyhow_response(err: anyhow::Error) -> Response { + Response::builder() + .status(400) + .body(Body::from(err.root_cause().to_string())) + .unwrap() +} + +async fn handler(request: Request) -> Result, hyper::Error> { let save: &Arc = request.extensions().get().unwrap(); - let query = request + let query = match request .uri() .query() - .ok_or_else(|| anyhow!("query does not exists"))?; - let data = parse_query(query) + .ok_or_else(|| anyhow_response(anyhow!("query does not exists"))) + { + Ok(ok) => ok, + Err(err) => return Ok(err), + }; + let data = match parse_query(query) .get("grapheme") - .ok_or_else(|| anyhow!("grapheme argument is not specified"))? - .clone(); + .ok_or_else(|| anyhow_response(anyhow!("grapheme argument is not specified"))) + { + Ok(ok) => ok.clone(), + Err(err) => return Ok(err), + }; - let infered = save + let infered = match save .inference(&data) .await - .or_else(|_| Err(anyhow!("cannot find data")))?; + .or_else(|e| Err(anyhow_response(e.context("inference error")))) + { + Ok(ok) => ok, + Err(e) => return Ok(e), + }; Ok(Response::builder().body(Body::from(infered)).unwrap()) } #[tokio::main] async fn main() -> Result<(), Box> { - let checkpoint: Save = bincode::deserialize(&fs::read("assets/db.bin").unwrap()).unwrap(); + let data = Box::leak(Box::new(fs::read("assets/db.bin").unwrap())); + let checkpoint: Save = bincode::deserialize(data).unwrap(); let service = ServiceBuilder::new() .layer(AddExtensionLayer::new(Arc::new(checkpoint))) // Wrap a `Service` in our middleware stack diff --git a/autofeur_db/src/french_ipa.rs b/autofeur_db/src/french_ipa.rs deleted file mode 100644 index b758779..0000000 --- a/autofeur_db/src/french_ipa.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::hash::Hash; - -use unicode_segmentation::UnicodeSegmentation; - -macro_rules! ipa_element_to_number { - (@step $_idx:expr, $ident:ident,) => { - None - }; - - (@step $idx:expr, $ident:ident, $head:literal, $($tail:literal,)*) => { - if $ident == $head { - Some(Self($idx)) - } - else { - ipa_element_to_number!(@step $idx + 1usize, $ident, $($tail,)*) - } - }; -} -macro_rules! ipa_number_to_ipa { - (@step $_idx:expr, $ident:ident,) => { - "unreachable!()" - }; - - (@step $idx:expr, $ident:ident, $head:literal, $($tail:literal,)*) => { - if $ident == $idx { - $head - } - else { - ipa_number_to_ipa!(@step $idx + 1usize, $ident, $($tail,)*) - } - }; -} - -macro_rules! replace_expr { - ($_t:tt $sub:expr) => { - $sub - }; -} - -macro_rules! count_tts { - ($($tts:tt)*) => {0usize $(+ replace_expr!($tts 1usize))*}; -} - -macro_rules! ipa_map { - ($name:ident, $($l:literal),*) => { - use serde::{Deserialize, Serialize}; - #[derive(Eq, Hash, PartialEq, Debug, Copy, Clone, Serialize, Deserialize)] - pub struct $name(pub usize); - - impl $name { - pub const SIZE: usize = count_tts!($($l,)*); - - pub fn from_char(ch: &str) -> Option<$name> { - ipa_element_to_number!(@step 0usize, ch, $($l,)*) - } - pub fn to_char(self) -> &'static str { - let num = self.0; - ipa_number_to_ipa!(@step 0usize, num, $($l,)*) - } - } - }; -} - -ipa_map!( - FrenchIPAChar, - "a", - "ɑ", - "ɑ̃", - "e", - "ɛ", - "ɛ̃", - "ə", - "i", - "o", - "ɔ", - "ɔ̃", - "œ", - "œ̃", - "ø", - "u", - "y", - "j", - "É¥", - "w", - "b", - "d", - "f", - "g", - "k", - "l", - "m", - "n", - "ɲ", - "ŋ", - "p", - "ʁ", - "s", - "ʃ", - "t", - "v", - "z", - "ʒ", - "g", - "É¡", - "ɪ", - "ʊ", - "x", - "r" -); - -pub type FrenchIPAWord = Vec; - -pub fn parse_word(str: &str) -> Option { - let mut word = FrenchIPAWord::default(); - let graphemes: Vec<&str> = str.graphemes(true).collect(); - for (_, grapheme) in graphemes.iter().enumerate() { - let a = FrenchIPAChar::from_char(grapheme); - - word.push(match a { - None => { - println!("invalid char: {}", grapheme); - return None; - } - Some(a) => a, - }) - } - - Some(word) -} diff --git a/autofeur_db/src/inference.rs b/autofeur_db/src/inference.rs index 8b833a0..f8c1351 100644 --- a/autofeur_db/src/inference.rs +++ b/autofeur_db/src/inference.rs @@ -5,11 +5,11 @@ use itertools::Itertools; use levenshtein::levenshtein; use unicode_segmentation::UnicodeSegmentation; -use crate::{french_ipa::parse_word, save::Save}; +use crate::save::Save; async fn call_inference_service(word: &str) -> anyhow::Result { let server: Result = - env::var("PHONEMIZER").or_else(|_| Ok("".to_string())); + env::var("PHONEMIZER").or_else(|_| Ok("http://localhost:8000/".to_string())); Ok( reqwest::get(format!("{}?grapheme={}", server.unwrap(), word)) .await? @@ -18,22 +18,23 @@ async fn call_inference_service(word: &str) -> anyhow::Result { ) } -impl Save { +impl Save<'_> { pub async fn inference(&self, prefix: &str) -> anyhow::Result { let phonemes = call_inference_service(prefix).await?; - let ipa_phonemes = - parse_word(&phonemes).ok_or_else(|| anyhow!("failed to parse the word"))?; let completion = self .trie - .random_starting_with(ipa_phonemes) + .random_starting_with(&phonemes) .ok_or_else(|| anyhow!("no matches"))?; - let infered = phonemes.add(&completion); - let word = self - .reverse_index - .get(&infered) - .ok_or_else(|| anyhow!("matched values is not in dictionary"))?; + let infered = phonemes.clone().add(&completion); + let word = self.reverse_index.get(&infered).ok_or_else(|| { + anyhow!( + "matched value is not in dictionary {} {}", + infered, + phonemes + ) + })?; println!("Matching {} by adding {}", word, completion); diff --git a/autofeur_db/src/lib.rs b/autofeur_db/src/lib.rs index 9c8ba4b..6610fc0 100644 --- a/autofeur_db/src/lib.rs +++ b/autofeur_db/src/lib.rs @@ -1,4 +1,3 @@ pub mod trie; -pub mod french_ipa; pub mod save; pub mod inference; diff --git a/autofeur_db/src/save.rs b/autofeur_db/src/save.rs index c64a430..a2bfd4c 100644 --- a/autofeur_db/src/save.rs +++ b/autofeur_db/src/save.rs @@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize}; use crate::trie::Trie; #[derive(Debug, Deserialize, Serialize, Default)] -pub struct Save { - pub trie: Trie, - pub reverse_index: HashMap -} \ No newline at end of file +pub struct Save<'a> { + #[serde(borrow = "'a")] + pub trie: Trie<'a>, + pub reverse_index: HashMap, +} diff --git a/autofeur_db/src/trie.rs b/autofeur_db/src/trie.rs index 43a1be8..8e8decc 100644 --- a/autofeur_db/src/trie.rs +++ b/autofeur_db/src/trie.rs @@ -1,169 +1,150 @@ use std::collections::HashMap; -use rand::{thread_rng, Rng}; +use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng}; use serde::{Deserialize, Serialize}; - -use crate::french_ipa::{FrenchIPAChar, FrenchIPAWord}; +use unicode_segmentation::UnicodeSegmentation; #[derive(Debug, Serialize, Deserialize, Default)] -pub struct TrieNode { - value: Option, +pub struct TrieNode<'a> { is_final: bool, - child_nodes: HashMap, + #[serde(borrow = "'a")] + child_nodes: HashMap<&'a str, TrieNode<'a>>, child_count: u64, } -impl TrieNode { +impl<'a> TrieNode<'a> { // Create new node - pub fn new(c: FrenchIPAChar, is_final: bool) -> TrieNode { + pub fn new<'b>(is_final: bool) -> TrieNode<'b> { TrieNode { - value: Option::Some(c), is_final, - child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE), + child_nodes: HashMap::new(), child_count: 0, } } - pub fn new_root() -> TrieNode { + pub fn new_root<'b>() -> TrieNode<'b> { TrieNode { - value: Option::None, is_final: false, - child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE), + child_nodes: HashMap::new(), child_count: 0, } } } #[derive(Debug, Serialize, Deserialize, Default)] -pub struct Trie { - root_node: Box, +pub struct Trie<'a> { + #[serde(borrow = "'a")] + root_node: Box>, } -impl Trie { +impl<'a> Trie<'a> { // Create a TrieStruct - pub fn new() -> Trie { + pub fn new<'b>() -> Trie<'b> { Trie { root_node: Box::new(TrieNode::new_root()), } } // Insert a string - pub fn insert(&mut self, char_list: FrenchIPAWord) { + pub fn insert(&mut self, char_list: &'a str) { let mut current_node: &mut TrieNode = self.root_node.as_mut(); - let mut last_match = 0; + let iterator = char_list.graphemes(true); + let mut create = false; + current_node.child_count += 1; // Find the minimum existing math - for letter_counter in 0..char_list.len() { - if current_node - .child_nodes - .contains_key(&char_list[letter_counter]) - { - current_node = current_node - .child_nodes - .get_mut(&char_list[letter_counter]) - .unwrap(); - // we mark the node as containing our children. - current_node.child_count += 1; - } else { - last_match = letter_counter; - break; - } - last_match = letter_counter + 1; - } + for str in iterator { + if create == false { + if current_node.child_nodes.contains_key(str) { + current_node = current_node.child_nodes.get_mut(str).unwrap(); + // we mark the node as containing our children. + current_node.child_count += 1; + } else { + create = true; - // if we found an already exsting node - if last_match == char_list.len() { - current_node.is_final = true; - } else { - for new_counter in last_match..char_list.len() { - let key = char_list[new_counter]; - current_node - .child_nodes - .insert(key, TrieNode::new(char_list[new_counter], false)); - current_node = current_node.child_nodes.get_mut(&key).unwrap(); - current_node.child_count += 1; + current_node.child_nodes.insert(str, TrieNode::new(false)); + current_node = current_node.child_nodes.get_mut(str).unwrap(); + current_node.child_count = 1; + } + } else { + current_node.child_nodes.insert(str, TrieNode::new(false)); + current_node = current_node.child_nodes.get_mut(str).unwrap(); + // we will only have one final node + current_node.child_count = 1; } - current_node.is_final = true; } + current_node.is_final = true; } // Find a string - pub fn random_starting_with(&self, prefix: FrenchIPAWord) -> Option { + pub fn random_starting_with(&self, prefix: &str) -> Option { let mut current_node: &TrieNode = self.root_node.as_ref(); - let mut str = String::new(); - let mut i = prefix.len(); + // String for the return value + let mut builder = String::new(); + + // Iterator over each grapheme + let graphemes = prefix.graphemes(true).enumerate(); + // Descend as far as possible into the tree - for counter in prefix { - if let Some(node) = current_node.child_nodes.get(&counter) { + for (_, str) in graphemes { + // If we can descend further into the tree + if let Some(node) = current_node.child_nodes.get(&str) { current_node = node; - if let Some(value) = current_node.value { - str += value.to_char(); - i -= 1; - } + builder += str; + println!("going into node {}", builder); } else { // couldn't descend fully into the tree + // this basically means nothing exist in the tree + // with this prefix + println!("no matches for prefix!"); return None; } } - println!("Found common root node {}", str); + println!("Found common root node {}", builder); + builder = String::new(); + let mut rng = thread_rng(); - // Ignore the 0-len matches - if i == 0 && current_node.child_nodes.len() == 0 { - println!("removing 0-len match"); - return None; - } - str = String::new(); - - // now that we have the node we descend by respecting the probabilities - while current_node.child_nodes.len() != 0 && current_node.child_count > 0 { - println!("Descending into node {}", str); - let max = current_node.child_count; - let random_number = thread_rng().gen_range(0..max); - let mut increment = 0; - - let mut did_change = false; - // find node corresponding to the node - for (_, node) in ¤t_node.child_nodes { - if node.child_count + increment >= random_number { - println!("changing node"); - current_node = node; - did_change = true; - break; - } else { - println!( - "didn't change node: {}<{}", - node.child_count + increment, - random_number - ) - } - increment += node.child_count; - } - if did_change { - if let Some(value) = current_node.value { - println!("added {}", value.to_char()); - str += value.to_char(); - } - } else { - println!( - "WARNING: DIDNT CHANGE NODE child_count={}", - current_node.child_count - ) - } - // if this node is a final node, we have a probability of using it + while current_node.child_nodes.len() != 0 { + // We need to choose a random child based on weights + let weighted = WeightedIndex::new( + current_node + .child_nodes + .iter() + .map(|(_, node)| node.child_count), + ) + .expect("distribution creation should be valid"); + + let (key, node) = current_node + .child_nodes + .iter() + .nth(weighted.sample(&mut rng)) + .expect("choosed value did not exist"); + println!("waling into node {}", key); + + current_node = node; + builder += key; + + // If this node is final and has childrens if current_node.is_final && current_node.child_count > 0 { - let random_number = thread_rng().gen_range(0..current_node.child_count); - if random_number == 0 { + // choose from current node or continue with childrens + let weighted = WeightedIndex::new(&[1, current_node.child_count]) + .expect("distribution seems impossible"); + + if weighted.sample(&mut rng) == 0 { + // we choosed this node! + // stop adding other characters break; } } } - if str == "" { + // If we only added + if builder == "" { return None; } // selected word - Some(str) + Some(builder) } } diff --git a/autofeur_nova/src/index.mts b/autofeur_nova/src/index.mts index 1c0dfbd..a5a2efa 100644 --- a/autofeur_nova/src/index.mts +++ b/autofeur_nova/src/index.mts @@ -10,19 +10,23 @@ import { request } from "undici"; // `autofeur_db` service export const DB = process.env.DB || "http://localhost:3000"; // nats broker for connecting to nova -export const NATS = process.env.NATS || "localhost:4222"; +export const NATS = process.env.NATS || "192.168.0.17:4222"; // rest endpoint for connecting to nova -export const REST = process.env.REST || "http://localhost:8090/api"; +export const REST = process.env.REST || "http://192.168.0.17:8090/api"; /** * Completes a grapheme using the `autofeur_db` service. * @param grapheme Grapheme to complete * @returns Completed grapheme */ -export const completeWord = (grapheme: string) => - request(`${DB}?grapheme=${encodeURIComponent(grapheme)}`).then((x) => - x.body.text() - ); +export const completeWord = (grapheme: string) => { + let resp = request(`${DB}?grapheme=${encodeURIComponent(grapheme)}`); + return resp.then((x) => { + if (x.statusCode === 200) { + return x.body.text(); + } + }); +}; /** * Cleans a sentence for usage with this program, strips unwanted chars diff --git a/docker-compose.yaml b/docker-compose.yaml index f0f8d41..9b2c703 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -23,6 +23,8 @@ services: deep_phonemizer: build: deep_phonemizer restart: always + ports: + - 8000:8000 volumes: - ./deep_phonemizer/assets/model.pt:/app/assets/model.pt nats: -- 2.39.5