summaryrefslogtreecommitdiff
path: root/libs/db/src/trie.rs
blob: 38cbcd880bea296b81649ee978d44a8c4caff0e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
use crate::types::PhonemeString;
use log::debug;
use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use unicode_segmentation::UnicodeSegmentation;

/// A single node of a grapheme-cluster trie.
///
/// Child keys are `&'a str` slices borrowed from the strings passed to `Trie::insert`
/// (or, when deserialized, borrowed from the serialized input via `#[serde(borrow)]`).
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct TrieNode<'a> {
    /// `true` when at least one inserted string ends exactly at this node.
    is_final: bool,
    /// Children keyed by the next grapheme cluster (as split by `graphemes(true)` in `Trie::insert`).
    #[serde(borrow = "'a")]
    child_nodes: HashMap<&'a str, TrieNode<'a>>,
    /// Number of insertions whose path passes through or ends at this node. Inserting the
    /// same string twice counts it twice. Read as a sampling weight by
    /// `Trie::random_starting_with`.
    child_count: u64,
}

impl<'a> TrieNode<'a> {
    pub fn new<'b>(is_final: bool) -> TrieNode<'b> {
        TrieNode {
            is_final,
            child_nodes: HashMap::new(),
            child_count: 0,
        }
    }

    pub fn new_root<'b>() -> TrieNode<'b> {
        TrieNode {
            is_final: false,
            child_nodes: HashMap::new(),
            child_count: 0,
        }
    }
}

/// A trie of grapheme clusters that can sample random completions of a prefix.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Trie<'a> {
    /// Root node. It has no key of its own; `Trie::insert` bumps its `child_count`
    /// once per insertion.
    #[serde(borrow = "'a")]
    root_node: Box<TrieNode<'a>>,
}

impl<'a> Trie<'a> {
    /// Creates an empty trie: a single non-final root node with no children.
    pub fn new<'b>() -> Trie<'b> {
        Trie {
            root_node: Box::new(TrieNode::new_root()),
        }
    }

    /// Inserts `char_list` into the trie, one extended grapheme cluster per level.
    ///
    /// The trie keeps `&'a str` slices of `char_list` as keys, which is why the argument
    /// must outlive the trie. Every node on the path (root included) has its
    /// `child_count` incremented by one, so duplicates are counted each time they are
    /// inserted. The last node on the path is marked final.
    pub fn insert(&mut self, char_list: &'a PhonemeString) {
        let mut current_node: &mut TrieNode<'a> = self.root_node.as_mut();
        current_node.child_count += 1;

        for grapheme in char_list.0.graphemes(true) {
            // Reuse the existing child or create a fresh non-final one. Either way this
            // insertion now passes through it, so its count goes up by one.
            current_node = current_node
                .child_nodes
                .entry(grapheme)
                .or_insert_with(|| TrieNode::new(false));
            current_node.child_count += 1;
        }
        current_node.is_final = true;
    }

    /// Picks a random completion of `prefix` by walking the trie below it.
    ///
    /// Returns `None` when `prefix` is not a path in the trie, or when the node it leads
    /// to has no children (nothing could be appended).
    ///
    /// NOTE(review): the returned string contains only the graphemes chosen *after* the
    /// prefix; the prefix itself is not included. Callers presumably prepend it
    /// themselves — confirm.
    ///
    /// Sampling rules:
    /// * At each step a child is chosen with weight
    ///   `child_count / (sum of its children's child_count + 1) + 1` (integer division,
    ///   so every weight is at least 1).
    /// * On reaching a final node that still has children, the walk stops there with
    ///   probability `1 / (1 + child_count / (length + 1))`. Here `length` is the number
    ///   of steps completed before this one.
    /// * Otherwise the walk continues until it reaches a node with no children.
    pub fn random_starting_with(&self, prefix: PhonemeString) -> Option<PhonemeString> {
        let mut current_node: &TrieNode = self.root_node.as_ref();
        let mut builder = String::new();

        // Descend along the prefix; give up as soon as it leaves the trie.
        for grapheme in prefix.0.graphemes(true) {
            match current_node.child_nodes.get(grapheme) {
                Some(node) => {
                    current_node = node;
                    builder.push_str(grapheme);
                    debug!("going into node {}", builder);
                }
                None => {
                    // Nothing in the trie starts with this prefix.
                    debug!("no matches for prefix!");
                    return None;
                }
            }
        }

        debug!("Found common root node {}", builder);
        // From here on `builder` only collects the randomly chosen suffix.
        builder.clear();
        let mut rng = thread_rng();

        let mut length: u64 = 0;
        while !current_node.child_nodes.is_empty() {
            // Snapshot the children once so the sampled index refers to exactly the
            // entry that was weighed, without walking the map a second time.
            let children: Vec<_> = current_node.child_nodes.iter().collect();
            let weighted = WeightedIndex::new(children.iter().map(|(_, node)| {
                let grandchildren: u64 =
                    node.child_nodes.values().map(|child| child.child_count).sum();
                node.child_count / (grandchildren + 1) + 1
            }))
            .expect("distribution creation should be valid");

            let (key, node) = children[weighted.sample(&mut rng)];
            debug!("walking into node {}", key);

            current_node = node;
            builder.push_str(key);

            // A final node with children is both a complete word and a branch point:
            // randomly decide whether to stop here or keep walking. `child_count` can't
            // be used to detect children, because every inserted node has a count >= 1.
            if current_node.is_final && !current_node.child_nodes.is_empty() {
                let stop_or_continue =
                    WeightedIndex::new(&[1, current_node.child_count / (length + 1)])
                        .expect("distribution seems impossible");

                if stop_or_continue.sample(&mut rng) == 0 {
                    // This node was chosen: stop adding characters.
                    break;
                }
            }
            length += 1;
        }

        // The prefix node had no children, so no completion was produced.
        if builder.is_empty() {
            return None;
        }

        Some(PhonemeString(builder))
    }
}