1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
use deepphonemizer::Phonemizer;
use inference::DeepPhonemizerInference;
use std::{error::Error, sync::Arc};
use tch::Device;
use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType};
use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
use twilight_http::Client as HttpClient;
mod bot_config;
mod inference;
use db::{save::Save, types::GraphemeString};
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
// Initialize the tracing subscriber.
tracing_subscriber::fmt::init();
let config = config::Config::builder()
.add_source(config::Environment::with_prefix("GRU"))
.build()?
.try_deserialize::<bot_config::Config>()?;
let token = config.token;
// Specify intents requesting events about things like new and updated
// messages in a guild and direct messages.
// Create a single shard.
let mut shard = Shard::new(
ShardId::ONE,
token.clone(),
Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT,
);
// The http client is separate from the gateway, so startup a new
// one, also use Arc such that it can be cloned to other threads.
let http = Arc::new(HttpClient::new(token));
// Since we only care about messages, make the cache only process messages.
let cache = DefaultInMemoryCache::builder()
.resource_types(ResourceType::MESSAGE)
.build();
let mut db = std::fs::File::open(config.db_path).unwrap();
let save: Save =
bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
let phonemizer = Phonemizer::from_checkpoint(
config.model_path,
config.config_path,
Device::cuda_if_available(),
None,
)
.unwrap();
let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
let db = Arc::new(db::inference::Inference::new(inference_service, save));
// Startup the event loop to process each event in the event stream as they
// come in.
while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
let Ok(event) = item else {
tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
continue;
};
// Update the cache.
cache.update(&event);
// Spawn a new task to handle the event
tokio::spawn(handle_event(event, Arc::clone(&http), db.clone()));
}
Ok(())
}
/// Handles a single gateway event.
///
/// - `MessageCreate`: runs phoneme inference on the message text and prints
///   the result to stdout. Messages whose content is empty or whitespace-only
///   are skipped. If inference fails, the failure is reported on stderr and
///   the message is skipped, rather than panicking the task.
/// - `Ready`: prints a readiness notice.
/// - Every other event is ignored.
///
/// NOTE(review): `http` is not used yet — presumably reserved for sending
/// replies through the REST API.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` is kept so future handlers
/// can propagate failures to the caller.
async fn handle_event(
    event: Event,
    http: Arc<HttpClient>,
    inference: Arc<db::inference::Inference>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    match event {
        Event::MessageCreate(msg) => {
            // Nothing to phonemize in an empty / whitespace-only message.
            if msg.content.trim().is_empty() {
                return Ok(());
            }
            // One bad message must not crash the task: report and move on.
            match inference.infer(GraphemeString(msg.content.clone())).await {
                Ok(phonemes) => println!("{phonemes}"),
                Err(error) => eprintln!("inference failed: {error:?}"),
            }
        }
        Event::Ready(_) => {
            println!("Shard is ready");
        }
        _ => {}
    }
    Ok(())
}
|