]> git.puffer.fish Git - matthieu/gru.git/commitdiff
new db and readability
authorMatthieuCoder <matthieu@matthieu-dev.xyz>
Sat, 21 Jan 2023 21:55:20 +0000 (01:55 +0400)
committerMatthieuCoder <matthieu@matthieu-dev.xyz>
Sat, 21 Jan 2023 21:55:20 +0000 (01:55 +0400)
autofeur_db/src/bin/generate.rs
autofeur_db/src/bin/server.rs
autofeur_db/src/french_ipa.rs [deleted file]
autofeur_db/src/inference.rs
autofeur_db/src/lib.rs
autofeur_db/src/save.rs
autofeur_db/src/trie.rs
autofeur_nova/src/index.mts
docker-compose.yaml

index 8db8809ec61ac160701a7fbdfa104aa57633c0a9..7f85f419bd92b0edf152aa2cefcb648945904d46 100644 (file)
@@ -1,6 +1,5 @@
 use std::fs;
 
-use autofeur::french_ipa::parse_word;
 use autofeur::save::Save;
 use kdam::tqdm;
 
@@ -42,18 +41,10 @@ async fn main() {
         phonems.append(&mut pron);
     }
 
-    let mut invalid = 0;
     for phoneme in tqdm!(phonems.iter()) {
-        match parse_word(&phoneme) {
-            Some(a) => save.trie.insert(a),
-            None => {
-                invalid += 1;
-            }
-        }
+        save.trie.insert(&phoneme);
     }
 
-    println!("Invalid items count: {}", invalid);
-
     fs::write("assets/db.bin", bincode::serialize(&save).unwrap()).unwrap();
 
     println!("Generated to assets/db.bin");
index 376b4e180cce021937e5924d64cf67030378d2be..8c350f9205d897d6b138fa0bddd5816f6e0da7a1 100644 (file)
@@ -17,28 +17,47 @@ fn parse_query(query: &str) -> HashMap<String, String> {
         .collect()
 }
 
-async fn handler(request: Request<Body>) -> Result<Response<Body>, anyhow::Error> {
+fn anyhow_response(err: anyhow::Error) -> Response<Body> {
+    Response::builder()
+        .status(400)
+        .body(Body::from(err.root_cause().to_string()))
+        .unwrap()
+}
+
+async fn handler(request: Request<Body>) -> Result<Response<Body>, hyper::Error> {
     let save: &Arc<Save> = request.extensions().get().unwrap();
-    let query = request
+    let query = match request
         .uri()
         .query()
-        .ok_or_else(|| anyhow!("query does not exists"))?;
-    let data = parse_query(query)
+        .ok_or_else(|| anyhow_response(anyhow!("query does not exists")))
+    {
+        Ok(ok) => ok,
+        Err(err) => return Ok(err),
+    };
+    let data = match parse_query(query)
         .get("grapheme")
-        .ok_or_else(|| anyhow!("grapheme argument is not specified"))?
-        .clone();
+        .ok_or_else(|| anyhow_response(anyhow!("grapheme argument is not specified")))
+    {
+        Ok(ok) => ok.clone(),
+        Err(err) => return Ok(err),
+    };
 
-    let infered = save
+    let infered = match save
         .inference(&data)
         .await
-        .or_else(|_| Err(anyhow!("cannot find data")))?;
+        .or_else(|e| Err(anyhow_response(e.context("inference error"))))
+    {
+        Ok(ok) => ok,
+        Err(e) => return Ok(e),
+    };
 
     Ok(Response::builder().body(Body::from(infered)).unwrap())
 }
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-    let checkpoint: Save = bincode::deserialize(&fs::read("assets/db.bin").unwrap()).unwrap();
+    let data = Box::leak(Box::new(fs::read("assets/db.bin").unwrap()));
+    let checkpoint: Save = bincode::deserialize(data).unwrap();
     let service = ServiceBuilder::new()
         .layer(AddExtensionLayer::new(Arc::new(checkpoint)))
         // Wrap a `Service` in our middleware stack
diff --git a/autofeur_db/src/french_ipa.rs b/autofeur_db/src/french_ipa.rs
deleted file mode 100644 (file)
index b758779..0000000
+++ /dev/null
@@ -1,129 +0,0 @@
-use std::hash::Hash;
-
-use unicode_segmentation::UnicodeSegmentation;
-
-macro_rules! ipa_element_to_number {
-    (@step $_idx:expr, $ident:ident,) => {
-        None
-    };
-
-    (@step $idx:expr, $ident:ident, $head:literal, $($tail:literal,)*) => {
-        if $ident == $head {
-            Some(Self($idx))
-        }
-        else {
-            ipa_element_to_number!(@step $idx + 1usize, $ident, $($tail,)*)
-        }
-    };
-}
-macro_rules! ipa_number_to_ipa {
-    (@step $_idx:expr, $ident:ident,) => {
-        "unreachable!()"
-    };
-
-    (@step $idx:expr, $ident:ident, $head:literal, $($tail:literal,)*) => {
-        if $ident == $idx {
-            $head
-        }
-        else {
-            ipa_number_to_ipa!(@step $idx + 1usize, $ident, $($tail,)*)
-        }
-    };
-}
-
-macro_rules! replace_expr {
-    ($_t:tt $sub:expr) => {
-        $sub
-    };
-}
-
-macro_rules! count_tts {
-    ($($tts:tt)*) => {0usize $(+ replace_expr!($tts 1usize))*};
-}
-
-macro_rules! ipa_map {
-    ($name:ident, $($l:literal),*) => {
-        use serde::{Deserialize, Serialize};
-        #[derive(Eq, Hash, PartialEq, Debug, Copy, Clone, Serialize, Deserialize)]
-        pub struct $name(pub usize);
-
-        impl $name {
-            pub const SIZE: usize = count_tts!($($l,)*);
-
-            pub fn from_char(ch: &str) -> Option<$name> {
-                ipa_element_to_number!(@step 0usize, ch, $($l,)*)
-            }
-            pub fn to_char(self) -> &'static str {
-                let num = self.0;
-                ipa_number_to_ipa!(@step 0usize, num, $($l,)*)
-            }
-        }
-    };
-}
-
-ipa_map!(
-    FrenchIPAChar,
-    "a",
-    "ɑ",
-    "ɑ̃",
-    "e",
-    "ɛ",
-    "ɛ̃",
-    "ə",
-    "i",
-    "o",
-    "ɔ",
-    "ɔ̃",
-    "œ",
-    "œ̃",
-    "ø",
-    "u",
-    "y",
-    "j",
-    "ɥ",
-    "w",
-    "b",
-    "d",
-    "f",
-    "g",
-    "k",
-    "l",
-    "m",
-    "n",
-    "ɲ",
-    "ŋ",
-    "p",
-    "ʁ",
-    "s",
-    "ʃ",
-    "t",
-    "v",
-    "z",
-    "ʒ",
-    "g",
-    "ɡ",
-    "ɪ",
-    "ʊ",
-    "x",
-    "r"
-);
-
-pub type FrenchIPAWord = Vec<FrenchIPAChar>;
-
-pub fn parse_word(str: &str) -> Option<FrenchIPAWord> {
-    let mut word = FrenchIPAWord::default();
-    let graphemes: Vec<&str> = str.graphemes(true).collect();
-    for (_, grapheme) in graphemes.iter().enumerate() {
-        let a = FrenchIPAChar::from_char(grapheme);
-
-        word.push(match a {
-            None => {
-                println!("invalid char: {}", grapheme);
-                return None;
-            }
-            Some(a) => a,
-        })
-    }
-
-    Some(word)
-}
index 8b833a00ae3bf369dda826b2a0930204df74eb5d..f8c1351eb5bb6fc4b4593411177ebad2b27a0d5d 100644 (file)
@@ -5,11 +5,11 @@ use itertools::Itertools;
 use levenshtein::levenshtein;
 use unicode_segmentation::UnicodeSegmentation;
 
-use crate::{french_ipa::parse_word, save::Save};
+use crate::save::Save;
 
 async fn call_inference_service(word: &str) -> anyhow::Result<String> {
     let server: Result<String, anyhow::Error> =
-        env::var("PHONEMIZER").or_else(|_| Ok("".to_string()));
+        env::var("PHONEMIZER").or_else(|_| Ok("http://localhost:8000/".to_string()));
     Ok(
         reqwest::get(format!("{}?grapheme={}", server.unwrap(), word))
             .await?
@@ -18,22 +18,23 @@ async fn call_inference_service(word: &str) -> anyhow::Result<String> {
     )
 }
 
-impl Save {
+impl Save<'_> {
     pub async fn inference(&self, prefix: &str) -> anyhow::Result<String> {
         let phonemes = call_inference_service(prefix).await?;
-        let ipa_phonemes =
-            parse_word(&phonemes).ok_or_else(|| anyhow!("failed to parse the word"))?;
 
         let completion = self
             .trie
-            .random_starting_with(ipa_phonemes)
+            .random_starting_with(&phonemes)
             .ok_or_else(|| anyhow!("no matches"))?;
 
-        let infered = phonemes.add(&completion);
-        let word = self
-            .reverse_index
-            .get(&infered)
-            .ok_or_else(|| anyhow!("matched values is not in dictionary"))?;
+        let infered = phonemes.clone().add(&completion);
+        let word = self.reverse_index.get(&infered).ok_or_else(|| {
+            anyhow!(
+                "matched value is not in dictionary {} {}",
+                infered,
+                phonemes
+            )
+        })?;
 
         println!("Matching {} by adding {}", word, completion);
 
index 9c8ba4b4892dfba8303bc204714409901822bb69..6610fc0609b491360450c9f01d17fba8b0a52e0a 100644 (file)
@@ -1,4 +1,3 @@
 pub mod trie;
-pub mod french_ipa;
 pub mod save;
 pub mod inference;
index c64a4304da30d2bdaf1319133ce35813bc7ee05c..a2bfd4cc5007083368528ecf1f6c560371b28871 100644 (file)
@@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize};
 use crate::trie::Trie;
 
 #[derive(Debug, Deserialize, Serialize, Default)]
-pub struct Save {
-    pub trie: Trie,
-    pub reverse_index: HashMap<String, String>
-}
\ No newline at end of file
+pub struct Save<'a> {
+    #[serde(borrow = "'a")]
+    pub trie: Trie<'a>,
+    pub reverse_index: HashMap<String, String>,
+}
index 43a1be8ec6850da4b072cff26274e4728e6ae697..8e8decc4109c314cb2e14d2fc0b1bc7224c799c4 100644 (file)
 use std::collections::HashMap;
 
-use rand::{thread_rng, Rng};
+use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng};
 use serde::{Deserialize, Serialize};
-
-use crate::french_ipa::{FrenchIPAChar, FrenchIPAWord};
+use unicode_segmentation::UnicodeSegmentation;
 
 #[derive(Debug, Serialize, Deserialize, Default)]
-pub struct TrieNode {
-    value: Option<FrenchIPAChar>,
+pub struct TrieNode<'a> {
     is_final: bool,
-    child_nodes: HashMap<FrenchIPAChar, TrieNode>,
+    #[serde(borrow = "'a")]
+    child_nodes: HashMap<&'a str, TrieNode<'a>>,
     child_count: u64,
 }
 
-impl TrieNode {
+impl<'a> TrieNode<'a> {
     // Create new node
-    pub fn new(c: FrenchIPAChar, is_final: bool) -> TrieNode {
+    pub fn new<'b>(is_final: bool) -> TrieNode<'b> {
         TrieNode {
-            value: Option::Some(c),
             is_final,
-            child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE),
+            child_nodes: HashMap::new(),
             child_count: 0,
         }
     }
 
-    pub fn new_root() -> TrieNode {
+    pub fn new_root<'b>() -> TrieNode<'b> {
         TrieNode {
-            value: Option::None,
             is_final: false,
-            child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE),
+            child_nodes: HashMap::new(),
             child_count: 0,
         }
     }
 }
 
 #[derive(Debug, Serialize, Deserialize, Default)]
-pub struct Trie {
-    root_node: Box<TrieNode>,
+pub struct Trie<'a> {
+    #[serde(borrow = "'a")]
+    root_node: Box<TrieNode<'a>>,
 }
 
-impl Trie {
+impl<'a> Trie<'a> {
     // Create a TrieStruct
-    pub fn new() -> Trie {
+    pub fn new<'b>() -> Trie<'b> {
         Trie {
             root_node: Box::new(TrieNode::new_root()),
         }
     }
 
     // Insert a string
-    pub fn insert(&mut self, char_list: FrenchIPAWord) {
+    pub fn insert(&mut self, char_list: &'a str) {
         let mut current_node: &mut TrieNode = self.root_node.as_mut();
-        let mut last_match = 0;
+        let iterator = char_list.graphemes(true);
+        let mut create = false;
 
+        current_node.child_count += 1;
         // Find the minimum existing math
-        for letter_counter in 0..char_list.len() {
-            if current_node
-                .child_nodes
-                .contains_key(&char_list[letter_counter])
-            {
-                current_node = current_node
-                    .child_nodes
-                    .get_mut(&char_list[letter_counter])
-                    .unwrap();
-                // we mark the node as containing our children.
-                current_node.child_count += 1;
-            } else {
-                last_match = letter_counter;
-                break;
-            }
-            last_match = letter_counter + 1;
-        }
+        for str in iterator {
+            if create == false {
+                if current_node.child_nodes.contains_key(str) {
+                    current_node = current_node.child_nodes.get_mut(str).unwrap();
+                    // we mark the node as containing our children.
+                    current_node.child_count += 1;
+                } else {
+                    create = true;
 
-        // if we found an already exsting node
-        if last_match == char_list.len() {
-            current_node.is_final = true;
-        } else {
-            for new_counter in last_match..char_list.len() {
-                let key = char_list[new_counter];
-                current_node
-                    .child_nodes
-                    .insert(key, TrieNode::new(char_list[new_counter], false));
-                current_node = current_node.child_nodes.get_mut(&key).unwrap();
-                current_node.child_count += 1;
+                    current_node.child_nodes.insert(str, TrieNode::new(false));
+                    current_node = current_node.child_nodes.get_mut(str).unwrap();
+                    current_node.child_count = 1;
+                }
+            } else {
+                current_node.child_nodes.insert(str, TrieNode::new(false));
+                current_node = current_node.child_nodes.get_mut(str).unwrap();
+                // we will only have one final node
+                current_node.child_count = 1;
             }
-            current_node.is_final = true;
         }
+        current_node.is_final = true;
     }
 
     // Find a string
-    pub fn random_starting_with(&self, prefix: FrenchIPAWord) -> Option<String> {
+    pub fn random_starting_with(&self, prefix: &str) -> Option<String> {
         let mut current_node: &TrieNode = self.root_node.as_ref();
-        let mut str = String::new();
-        let mut i = prefix.len();
+        // String for the return value
+        let mut builder = String::new();
+
+        // Iterator over each grapheme
+        let graphemes = prefix.graphemes(true).enumerate();
+
         // Descend as far as possible into the tree
-        for counter in prefix {
-            if let Some(node) = current_node.child_nodes.get(&counter) {
+        for (_, str) in graphemes {
+            // If we can descend further into the tree
+            if let Some(node) = current_node.child_nodes.get(&str) {
                 current_node = node;
-                if let Some(value) = current_node.value {
-                    str += value.to_char();
-                    i -= 1;
-                }
+                builder += str;
+                println!("going into node {}", builder);
             } else {
                 // couldn't descend fully into the tree
+                // this basically means nothing exist in the tree
+                // with this prefix
+                println!("no matches for prefix!");
                 return None;
             }
         }
 
-        println!("Found common root node {}", str);
+        println!("Found common root node {}", builder);
+        builder = String::new();
+        let mut rng = thread_rng();
 
-        // Ignore the 0-len matches
-        if i == 0 && current_node.child_nodes.len() == 0 {
-            println!("removing 0-len match");
-            return None;
-        }
-        str = String::new();
-
-        // now that we have the node we descend by respecting the probabilities
-        while current_node.child_nodes.len() != 0 && current_node.child_count > 0 {
-            println!("Descending into node {}", str);
-            let max = current_node.child_count;
-            let random_number = thread_rng().gen_range(0..max);
-            let mut increment = 0;
-
-            let mut did_change = false;
-            // find node corresponding to the node
-            for (_, node) in &current_node.child_nodes {
-                if node.child_count + increment >= random_number {
-                    println!("changing node");
-                    current_node = node;
-                    did_change = true;
-                    break;
-                } else {
-                    println!(
-                        "didn't change node: {}<{}",
-                        node.child_count + increment,
-                        random_number
-                    )
-                }
-                increment += node.child_count;
-            }
-            if did_change {
-                if let Some(value) = current_node.value {
-                    println!("added {}", value.to_char());
-                    str += value.to_char();
-                }
-            } else {
-                println!(
-                    "WARNING: DIDNT CHANGE NODE child_count={}",
-                    current_node.child_count
-                )
-            }
-            // if this node is a final node, we have a probability of using it
+        while current_node.child_nodes.len() != 0 {
+            // We need to choose a random child based on weights
+            let weighted = WeightedIndex::new(
+                current_node
+                    .child_nodes
+                    .iter()
+                    .map(|(_, node)| node.child_count),
+            )
+            .expect("distribution creation should be valid");
+
+            let (key, node) = current_node
+                .child_nodes
+                .iter()
+                .nth(weighted.sample(&mut rng))
+                .expect("choosed value did not exist");
+            println!("waling into node {}", key);
+
+            current_node = node;
+            builder += key;
+
+            // If this node is final and has childrens
             if current_node.is_final && current_node.child_count > 0 {
-                let random_number = thread_rng().gen_range(0..current_node.child_count);
-                if random_number == 0 {
+                // choose from current node or continue with childrens
+                let weighted = WeightedIndex::new(&[1, current_node.child_count])
+                    .expect("distribution seems impossible");
+
+                if weighted.sample(&mut rng) == 0 {
+                    // we choosed this node!
+                    // stop adding other characters
                     break;
                 }
             }
         }
 
-        if str == "" {
+        // If we only added
+        if builder == "" {
             return None;
         }
 
         // selected word
-        Some(str)
+        Some(builder)
     }
 }
index 1c0dfbdf47f60778ca38380b9342dfdf88bbf1d4..a5a2efae37d31858ea7278720eda38e4107511f2 100644 (file)
@@ -10,19 +10,23 @@ import { request } from "undici";
 // `autofeur_db` service
 export const DB = process.env.DB || "http://localhost:3000";
 // nats broker for connecting to nova
-export const NATS = process.env.NATS || "localhost:4222";
+export const NATS = process.env.NATS || "192.168.0.17:4222";
 // rest endpoint for connecting to nova
-export const REST = process.env.REST || "http://localhost:8090/api";
+export const REST = process.env.REST || "http://192.168.0.17:8090/api";
 
 /**
  * Completes a grapheme using the `autofeur_db` service.
  * @param grapheme Grapheme to complete
  * @returns Completed grapheme
  */
-export const completeWord = (grapheme: string) =>
-  request(`${DB}?grapheme=${encodeURIComponent(grapheme)}`).then((x) =>
-    x.body.text()
-  );
+export const completeWord = (grapheme: string) => {
+  let resp = request(`${DB}?grapheme=${encodeURIComponent(grapheme)}`);
+  return resp.then((x) => {
+    if (x.statusCode === 200) {
+      return x.body.text();
+    }
+  });
+};
 
 /**
  * Cleans a sentence for usage with this program, strips unwanted chars
index f0f8d41f24ca7ac0c9ce3c3b39f595e8f606c92b..9b2c703c12da56e57bed4ed6f3011d68a9db05ce 100644 (file)
@@ -23,6 +23,8 @@ services:
   deep_phonemizer:
     build: deep_phonemizer
     restart: always
+    ports:
+      - 8000:8000
     volumes:
       - ./deep_phonemizer/assets/model.pt:/app/assets/model.pt
   nats: