1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
|
use crate::types::PhonemeString;
use log::debug;
use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use unicode_segmentation::UnicodeSegmentation;
/// A single node of a [`Trie`].
///
/// Each edge to a child is labelled with one grapheme cluster, stored as an owned `String`.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct TrieNode {
/// `true` when a string passed to `Trie::insert` ends exactly at this node.
is_final: bool,
/// Children of this node, keyed by the grapheme cluster that leads to them.
child_nodes: HashMap<String, TrieNode>,
/// Number of `Trie::insert` calls whose path went through this node (the root is
/// counted on every insert). `Trie::random_starting_with` derives its sampling
/// weights from this value.
child_count: u64,
}
impl TrieNode {
pub fn new<'b>(is_final: bool) -> TrieNode {
TrieNode {
is_final,
child_nodes: HashMap::new(),
child_count: 0,
}
}
pub fn new_root<'b>() -> TrieNode {
TrieNode {
is_final: false,
child_nodes: HashMap::new(),
child_count: 0,
}
}
}
/// A trie whose edges are grapheme clusters of the inserted `PhonemeString`s.
///
/// Besides insertion it supports picking a random continuation of a prefix,
/// weighted by the per-node `child_count` values.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct Trie {
/// Root node; every insertion and every prefix lookup starts its walk here.
root_node: Box<TrieNode>,
}
impl Trie {
    /// Creates an empty trie that contains only a non-final root node.
    pub fn new() -> Trie {
        Trie {
            root_node: Box::new(TrieNode::new_root()),
        }
    }

    /// Inserts `char_list` into the trie, one grapheme cluster per level.
    ///
    /// Every node on the path, including the root, has its `child_count`
    /// incremented, so inserting the same string several times raises the
    /// weight of that path. The node reached after the last grapheme is marked
    /// as final; for an empty string that node is the root itself.
    pub fn insert(&mut self, char_list: &PhonemeString) {
        let mut current_node: &mut TrieNode = self.root_node.as_mut();
        current_node.child_count += 1;

        for grapheme in char_list.0.graphemes(true) {
            // Reuse the existing child or create a fresh one. A fresh node
            // starts with a count of zero, so the increment below leaves it at
            // one, exactly like an existing node gaining one more passage.
            current_node = current_node
                .child_nodes
                .entry(grapheme.to_string())
                .or_insert_with(|| TrieNode::new(false));
            current_node.child_count += 1;
        }

        current_node.is_final = true;
    }

    /// Picks a random continuation of `prefix` from the strings in the trie.
    ///
    /// The returned value holds only the graphemes chosen *after* the prefix;
    /// the prefix itself is not repeated in it.
    ///
    /// Returns `None` when no path in the trie spells out `prefix`, or when the
    /// node reached by `prefix` has no children (nothing can be appended).
    ///
    /// # Panics
    ///
    /// Panics if a weighted distribution cannot be built or the sampled index
    /// does not exist. Both are treated as internal invariant violations.
    pub fn random_starting_with(&self, prefix: PhonemeString) -> Option<PhonemeString> {
        let mut current_node: &TrieNode = self.root_node.as_ref();

        // Descend along the prefix. `walked` only feeds the debug output.
        let mut walked = String::new();
        for grapheme in prefix.0.graphemes(true) {
            match current_node.child_nodes.get(grapheme) {
                Some(node) => {
                    current_node = node;
                    walked.push_str(grapheme);
                    debug!("going into node {}", walked);
                }
                None => {
                    // The prefix leaves the tree: nothing stored starts with it.
                    debug!("no matches for prefix!");
                    return None;
                }
            }
        }
        debug!("Found common root node {}", walked);

        // From here on only the randomly chosen continuation is collected.
        let mut builder = String::new();
        let mut rng = thread_rng();
        // Number of graphemes appended before the current step.
        let mut length: u64 = 0;

        while !current_node.child_nodes.is_empty() {
            // One weight per child, in the map's iteration order (the same
            // order `iter().nth(..)` uses below). The `+ 1` in the divisor
            // avoids dividing by zero for leaf children; the trailing `+ 1`
            // keeps every weight strictly positive.
            let weighted = WeightedIndex::new(current_node.child_nodes.iter().map(|(_, node)| {
                let grandchildren_count: u64 = node
                    .child_nodes
                    .values()
                    .map(|grandchild| grandchild.child_count)
                    .sum();
                node.child_count / (grandchildren_count + 1) + 1
            }))
            .expect("distribution creation should be valid");

            let (key, node) = current_node
                .child_nodes
                .iter()
                .nth(weighted.sample(&mut rng))
                .expect("choose value did not exist");
            debug!("walking into node {}", key);
            current_node = node;
            builder.push_str(key);

            // A final node may end the walk here. Note that `child_count` is at
            // least 1 for every node created by `insert`, so for such tries
            // this guard reduces to `is_final`.
            if current_node.is_final && current_node.child_count > 0 {
                // Index 0 means "stop here", index 1 means "keep descending".
                // The weight for descending shrinks as the continuation grows
                // and may reach zero, which forces a stop.
                let weighted = WeightedIndex::new(&[1, current_node.child_count / (length + 1)])
                    .expect("distribution seems impossible");
                if weighted.sample(&mut rng) == 0 {
                    break;
                }
            }
            length += 1;
        }

        // Nothing could be appended to the prefix.
        if builder.is_empty() {
            return None;
        }

        Some(PhonemeString(builder))
    }
}
|