diff options
Diffstat (limited to 'autofeur_db')
| -rw-r--r-- | autofeur_db/src/bin/generate.rs | 11 | ||||
| -rw-r--r-- | autofeur_db/src/bin/server.rs | 37 | ||||
| -rw-r--r-- | autofeur_db/src/french_ipa.rs | 129 | ||||
| -rw-r--r-- | autofeur_db/src/inference.rs | 23 | ||||
| -rw-r--r-- | autofeur_db/src/lib.rs | 1 | ||||
| -rw-r--r-- | autofeur_db/src/save.rs | 9 | ||||
| -rw-r--r-- | autofeur_db/src/trie.rs | 191 | 
7 files changed, 132 insertions, 269 deletions
diff --git a/autofeur_db/src/bin/generate.rs b/autofeur_db/src/bin/generate.rs index 8db8809..7f85f41 100644 --- a/autofeur_db/src/bin/generate.rs +++ b/autofeur_db/src/bin/generate.rs @@ -1,6 +1,5 @@  use std::fs; -use autofeur::french_ipa::parse_word;  use autofeur::save::Save;  use kdam::tqdm; @@ -42,18 +41,10 @@ async fn main() {          phonems.append(&mut pron);      } -    let mut invalid = 0;      for phoneme in tqdm!(phonems.iter()) { -        match parse_word(&phoneme) { -            Some(a) => save.trie.insert(a), -            None => { -                invalid += 1; -            } -        } +        save.trie.insert(&phoneme);      } -    println!("Invalid items count: {}", invalid); -      fs::write("assets/db.bin", bincode::serialize(&save).unwrap()).unwrap();      println!("Generated to assets/db.bin"); diff --git a/autofeur_db/src/bin/server.rs b/autofeur_db/src/bin/server.rs index 376b4e1..8c350f9 100644 --- a/autofeur_db/src/bin/server.rs +++ b/autofeur_db/src/bin/server.rs @@ -17,28 +17,47 @@ fn parse_query(query: &str) -> HashMap<String, String> {          .collect()  } -async fn handler(request: Request<Body>) -> Result<Response<Body>, anyhow::Error> { +fn anyhow_response(err: anyhow::Error) -> Response<Body> { +    Response::builder() +        .status(400) +        .body(Body::from(err.root_cause().to_string())) +        .unwrap() +} + +async fn handler(request: Request<Body>) -> Result<Response<Body>, hyper::Error> {      let save: &Arc<Save> = request.extensions().get().unwrap(); -    let query = request +    let query = match request          .uri()          .query() -        .ok_or_else(|| anyhow!("query does not exists"))?; -    let data = parse_query(query) +        .ok_or_else(|| anyhow_response(anyhow!("query does not exists"))) +    { +        Ok(ok) => ok, +        Err(err) => return Ok(err), +    }; +    let data = match parse_query(query)          .get("grapheme") -        .ok_or_else(|| anyhow!("grapheme argument is not 
specified"))? -        .clone(); +        .ok_or_else(|| anyhow_response(anyhow!("grapheme argument is not specified"))) +    { +        Ok(ok) => ok.clone(), +        Err(err) => return Ok(err), +    }; -    let infered = save +    let infered = match save          .inference(&data)          .await -        .or_else(|_| Err(anyhow!("cannot find data")))?; +        .or_else(|e| Err(anyhow_response(e.context("inference error")))) +    { +        Ok(ok) => ok, +        Err(e) => return Ok(e), +    };      Ok(Response::builder().body(Body::from(infered)).unwrap())  }  #[tokio::main]  async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { -    let checkpoint: Save = bincode::deserialize(&fs::read("assets/db.bin").unwrap()).unwrap(); +    let data = Box::leak(Box::new(fs::read("assets/db.bin").unwrap())); +    let checkpoint: Save = bincode::deserialize(data).unwrap();      let service = ServiceBuilder::new()          .layer(AddExtensionLayer::new(Arc::new(checkpoint)))          // Wrap a `Service` in our middleware stack diff --git a/autofeur_db/src/french_ipa.rs b/autofeur_db/src/french_ipa.rs deleted file mode 100644 index b758779..0000000 --- a/autofeur_db/src/french_ipa.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::hash::Hash; - -use unicode_segmentation::UnicodeSegmentation; - -macro_rules! ipa_element_to_number { -    (@step $_idx:expr, $ident:ident,) => { -        None -    }; - -    (@step $idx:expr, $ident:ident, $head:literal, $($tail:literal,)*) => { -        if $ident == $head { -            Some(Self($idx)) -        } -        else { -            ipa_element_to_number!(@step $idx + 1usize, $ident, $($tail,)*) -        } -    }; -} -macro_rules! 
ipa_number_to_ipa { -    (@step $_idx:expr, $ident:ident,) => { -        "unreachable!()" -    }; - -    (@step $idx:expr, $ident:ident, $head:literal, $($tail:literal,)*) => { -        if $ident == $idx { -            $head -        } -        else { -            ipa_number_to_ipa!(@step $idx + 1usize, $ident, $($tail,)*) -        } -    }; -} - -macro_rules! replace_expr { -    ($_t:tt $sub:expr) => { -        $sub -    }; -} - -macro_rules! count_tts { -    ($($tts:tt)*) => {0usize $(+ replace_expr!($tts 1usize))*}; -} - -macro_rules! ipa_map { -    ($name:ident, $($l:literal),*) => { -        use serde::{Deserialize, Serialize}; -        #[derive(Eq, Hash, PartialEq, Debug, Copy, Clone, Serialize, Deserialize)] -        pub struct $name(pub usize); - -        impl $name { -            pub const SIZE: usize = count_tts!($($l,)*); - -            pub fn from_char(ch: &str) -> Option<$name> { -                ipa_element_to_number!(@step 0usize, ch, $($l,)*) -            } -            pub fn to_char(self) -> &'static str { -                let num = self.0; -                ipa_number_to_ipa!(@step 0usize, num, $($l,)*) -            } -        } -    }; -} - -ipa_map!( -    FrenchIPAChar, -    "a", -    "ɑ", -    "ɑ̃", -    "e", -    "ɛ", -    "ɛ̃", -    "ə", -    "i", -    "o", -    "ɔ", -    "ɔ̃", -    "œ", -    "œ̃", -    "ø", -    "u", -    "y", -    "j", -    "ɥ", -    "w", -    "b", -    "d", -    "f", -    "g", -    "k", -    "l", -    "m", -    "n", -    "ɲ", -    "ŋ", -    "p", -    "ʁ", -    "s", -    "ʃ", -    "t", -    "v", -    "z", -    "ʒ", -    "g", -    "ɡ", -    "ɪ", -    "ʊ", -    "x", -    "r" -); - -pub type FrenchIPAWord = Vec<FrenchIPAChar>; - -pub fn parse_word(str: &str) -> Option<FrenchIPAWord> { -    let mut word = FrenchIPAWord::default(); -    let graphemes: Vec<&str> = str.graphemes(true).collect(); -    for (_, grapheme) in graphemes.iter().enumerate() { -        let a = FrenchIPAChar::from_char(grapheme); - -        word.push(match 
a { -            None => { -                println!("invalid char: {}", grapheme); -                return None; -            } -            Some(a) => a, -        }) -    } - -    Some(word) -} diff --git a/autofeur_db/src/inference.rs b/autofeur_db/src/inference.rs index 8b833a0..f8c1351 100644 --- a/autofeur_db/src/inference.rs +++ b/autofeur_db/src/inference.rs @@ -5,11 +5,11 @@ use itertools::Itertools;  use levenshtein::levenshtein;  use unicode_segmentation::UnicodeSegmentation; -use crate::{french_ipa::parse_word, save::Save}; +use crate::save::Save;  async fn call_inference_service(word: &str) -> anyhow::Result<String> {      let server: Result<String, anyhow::Error> = -        env::var("PHONEMIZER").or_else(|_| Ok("".to_string())); +        env::var("PHONEMIZER").or_else(|_| Ok("http://localhost:8000/".to_string()));      Ok(          reqwest::get(format!("{}?grapheme={}", server.unwrap(), word))              .await? @@ -18,22 +18,23 @@ async fn call_inference_service(word: &str) -> anyhow::Result<String> {      )  } -impl Save { +impl Save<'_> {      pub async fn inference(&self, prefix: &str) -> anyhow::Result<String> {          let phonemes = call_inference_service(prefix).await?; -        let ipa_phonemes = -            parse_word(&phonemes).ok_or_else(|| anyhow!("failed to parse the word"))?;          let completion = self              .trie -            .random_starting_with(ipa_phonemes) +            .random_starting_with(&phonemes)              .ok_or_else(|| anyhow!("no matches"))?; -        let infered = phonemes.add(&completion); -        let word = self -            .reverse_index -            .get(&infered) -            .ok_or_else(|| anyhow!("matched values is not in dictionary"))?; +        let infered = phonemes.clone().add(&completion); +        let word = self.reverse_index.get(&infered).ok_or_else(|| { +            anyhow!( +                "matched value is not in dictionary {} {}", +                infered, +                phonemes 
+            ) +        })?;          println!("Matching {} by adding {}", word, completion); diff --git a/autofeur_db/src/lib.rs b/autofeur_db/src/lib.rs index 9c8ba4b..6610fc0 100644 --- a/autofeur_db/src/lib.rs +++ b/autofeur_db/src/lib.rs @@ -1,4 +1,3 @@  pub mod trie; -pub mod french_ipa;  pub mod save;  pub mod inference; diff --git a/autofeur_db/src/save.rs b/autofeur_db/src/save.rs index c64a430..a2bfd4c 100644 --- a/autofeur_db/src/save.rs +++ b/autofeur_db/src/save.rs @@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize};  use crate::trie::Trie;  #[derive(Debug, Deserialize, Serialize, Default)] -pub struct Save { -    pub trie: Trie, -    pub reverse_index: HashMap<String, String> -}
\ No newline at end of file +pub struct Save<'a> { +    #[serde(borrow = "'a")] +    pub trie: Trie<'a>, +    pub reverse_index: HashMap<String, String>, +} diff --git a/autofeur_db/src/trie.rs b/autofeur_db/src/trie.rs index 43a1be8..8e8decc 100644 --- a/autofeur_db/src/trie.rs +++ b/autofeur_db/src/trie.rs @@ -1,169 +1,150 @@  use std::collections::HashMap; -use rand::{thread_rng, Rng}; +use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng};  use serde::{Deserialize, Serialize}; - -use crate::french_ipa::{FrenchIPAChar, FrenchIPAWord}; +use unicode_segmentation::UnicodeSegmentation;  #[derive(Debug, Serialize, Deserialize, Default)] -pub struct TrieNode { -    value: Option<FrenchIPAChar>, +pub struct TrieNode<'a> {      is_final: bool, -    child_nodes: HashMap<FrenchIPAChar, TrieNode>, +    #[serde(borrow = "'a")] +    child_nodes: HashMap<&'a str, TrieNode<'a>>,      child_count: u64,  } -impl TrieNode { +impl<'a> TrieNode<'a> {      // Create new node -    pub fn new(c: FrenchIPAChar, is_final: bool) -> TrieNode { +    pub fn new<'b>(is_final: bool) -> TrieNode<'b> {          TrieNode { -            value: Option::Some(c),              is_final, -            child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE), +            child_nodes: HashMap::new(),              child_count: 0,          }      } -    pub fn new_root() -> TrieNode { +    pub fn new_root<'b>() -> TrieNode<'b> {          TrieNode { -            value: Option::None,              is_final: false, -            child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE), +            child_nodes: HashMap::new(),              child_count: 0,          }      }  }  #[derive(Debug, Serialize, Deserialize, Default)] -pub struct Trie { -    root_node: Box<TrieNode>, +pub struct Trie<'a> { +    #[serde(borrow = "'a")] +    root_node: Box<TrieNode<'a>>,  } -impl Trie { +impl<'a> Trie<'a> {      // Create a TrieStruct -    pub fn new() -> Trie { +    pub fn new<'b>() -> Trie<'b> {   
       Trie {              root_node: Box::new(TrieNode::new_root()),          }      }      // Insert a string -    pub fn insert(&mut self, char_list: FrenchIPAWord) { +    pub fn insert(&mut self, char_list: &'a str) {          let mut current_node: &mut TrieNode = self.root_node.as_mut(); -        let mut last_match = 0; +        let iterator = char_list.graphemes(true); +        let mut create = false; +        current_node.child_count += 1;          // Find the minimum existing math -        for letter_counter in 0..char_list.len() { -            if current_node -                .child_nodes -                .contains_key(&char_list[letter_counter]) -            { -                current_node = current_node -                    .child_nodes -                    .get_mut(&char_list[letter_counter]) -                    .unwrap(); -                // we mark the node as containing our children. -                current_node.child_count += 1; -            } else { -                last_match = letter_counter; -                break; -            } -            last_match = letter_counter + 1; -        } +        for str in iterator { +            if create == false { +                if current_node.child_nodes.contains_key(str) { +                    current_node = current_node.child_nodes.get_mut(str).unwrap(); +                    // we mark the node as containing our children. 
+                    current_node.child_count += 1; +                } else { +                    create = true; -        // if we found an already exsting node -        if last_match == char_list.len() { -            current_node.is_final = true; -        } else { -            for new_counter in last_match..char_list.len() { -                let key = char_list[new_counter]; -                current_node -                    .child_nodes -                    .insert(key, TrieNode::new(char_list[new_counter], false)); -                current_node = current_node.child_nodes.get_mut(&key).unwrap(); -                current_node.child_count += 1; +                    current_node.child_nodes.insert(str, TrieNode::new(false)); +                    current_node = current_node.child_nodes.get_mut(str).unwrap(); +                    current_node.child_count = 1; +                } +            } else { +                current_node.child_nodes.insert(str, TrieNode::new(false)); +                current_node = current_node.child_nodes.get_mut(str).unwrap(); +                // we will only have one final node +                current_node.child_count = 1;              } -            current_node.is_final = true;          } +        current_node.is_final = true;      }      // Find a string -    pub fn random_starting_with(&self, prefix: FrenchIPAWord) -> Option<String> { +    pub fn random_starting_with(&self, prefix: &str) -> Option<String> {          let mut current_node: &TrieNode = self.root_node.as_ref(); -        let mut str = String::new(); -        let mut i = prefix.len(); +        // String for the return value +        let mut builder = String::new(); + +        // Iterator over each grapheme +        let graphemes = prefix.graphemes(true).enumerate(); +          // Descend as far as possible into the tree -        for counter in prefix { -            if let Some(node) = current_node.child_nodes.get(&counter) { +        for (_, str) in graphemes { +            
// If we can descend further into the tree +            if let Some(node) = current_node.child_nodes.get(&str) {                  current_node = node; -                if let Some(value) = current_node.value { -                    str += value.to_char(); -                    i -= 1; -                } +                builder += str; +                println!("going into node {}", builder);              } else {                  // couldn't descend fully into the tree +                // this basically means nothing exist in the tree +                // with this prefix +                println!("no matches for prefix!");                  return None;              }          } -        println!("Found common root node {}", str); +        println!("Found common root node {}", builder); +        builder = String::new(); +        let mut rng = thread_rng(); -        // Ignore the 0-len matches -        if i == 0 && current_node.child_nodes.len() == 0 { -            println!("removing 0-len match"); -            return None; -        } -        str = String::new(); - -        // now that we have the node we descend by respecting the probabilities -        while current_node.child_nodes.len() != 0 && current_node.child_count > 0 { -            println!("Descending into node {}", str); -            let max = current_node.child_count; -            let random_number = thread_rng().gen_range(0..max); -            let mut increment = 0; - -            let mut did_change = false; -            // find node corresponding to the node -            for (_, node) in &current_node.child_nodes { -                if node.child_count + increment >= random_number { -                    println!("changing node"); -                    current_node = node; -                    did_change = true; -                    break; -                } else { -                    println!( -                        "didn't change node: {}<{}", -                        node.child_count + increment, -         
               random_number -                    ) -                } -                increment += node.child_count; -            } -            if did_change { -                if let Some(value) = current_node.value { -                    println!("added {}", value.to_char()); -                    str += value.to_char(); -                } -            } else { -                println!( -                    "WARNING: DIDNT CHANGE NODE child_count={}", -                    current_node.child_count -                ) -            } -            // if this node is a final node, we have a probability of using it +        while current_node.child_nodes.len() != 0 { +            // We need to choose a random child based on weights +            let weighted = WeightedIndex::new( +                current_node +                    .child_nodes +                    .iter() +                    .map(|(_, node)| node.child_count), +            ) +            .expect("distribution creation should be valid"); + +            let (key, node) = current_node +                .child_nodes +                .iter() +                .nth(weighted.sample(&mut rng)) +                .expect("choosed value did not exist"); +            println!("waling into node {}", key); + +            current_node = node; +            builder += key; + +            // If this node is final and has childrens              if current_node.is_final && current_node.child_count > 0 { -                let random_number = thread_rng().gen_range(0..current_node.child_count); -                if random_number == 0 { +                // choose from current node or continue with childrens +                let weighted = WeightedIndex::new(&[1, current_node.child_count]) +                    .expect("distribution seems impossible"); + +                if weighted.sample(&mut rng) == 0 { +                    // we choosed this node! 
+                    // stop adding other characters                      break;                  }              }          } -        if str == "" { +        // If we only added +        if builder == "" {              return None;          }          // selected word -        Some(str) +        Some(builder)      }  }  | 
