summaryrefslogtreecommitdiff
path: root/bin/bot/src/main.rs
blob: 574b41715e2a5e5be189203d9dcc3f6878d8c725 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
use deepphonemizer::Phonemizer;
use inference::DeepPhonemizerInference;
use std::{error::Error, sync::Arc};
use tch::Device;
use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType};
use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
use twilight_http::Client as HttpClient;

mod bot_config;
mod inference;

use db::{save::Save, types::GraphemeString};

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
    // Initialize the tracing subscriber.
    tracing_subscriber::fmt::init();

    let config = config::Config::builder()
        .add_source(config::Environment::with_prefix("GRU"))
        .build()?
        .try_deserialize::<bot_config::Config>()?;
    let token = config.token;

    // Specify intents requesting events about things like new and updated
    // messages in a guild and direct messages.

    // Create a single shard.
    let mut shard = Shard::new(
        ShardId::ONE,
        token.clone(),
        Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT,
    );

    // The http client is separate from the gateway, so startup a new
    // one, also use Arc such that it can be cloned to other threads.
    let http = Arc::new(HttpClient::new(token));

    // Since we only care about messages, make the cache only process messages.
    let cache = DefaultInMemoryCache::builder()
        .resource_types(ResourceType::MESSAGE)
        .build();

    let mut db = std::fs::File::open(config.db_path).unwrap();

    let save: Save =
        bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();

    let phonemizer = Phonemizer::from_checkpoint(
        config.model_path,
        config.config_path,
        Device::cuda_if_available(),
        None,
    )
    .unwrap();

    let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));

    let db = Arc::new(db::inference::Inference::new(inference_service, save));

    // Startup the event loop to process each event in the event stream as they
    // come in.
    while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
        let Ok(event) = item else {
            tracing::warn!(source = ?item.unwrap_err(), "error receiving event");

            continue;
        };
        // Update the cache.
        cache.update(&event);

        // Spawn a new task to handle the event
        tokio::spawn(handle_event(event, Arc::clone(&http), db.clone()));
    }

    Ok(())
}

async fn handle_event(
    event: Event,
    http: Arc<HttpClient>,
    inference: Arc<db::inference::Inference>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    match event {
        Event::MessageCreate(msg) => {
            let infer = inference.infer(GraphemeString(msg.content.clone())).await.unwrap();

            println!("{}", infer);
        }
        Event::Ready(_) => {
            println!("Shard is ready");
        }
        _ => {}
    }

    Ok(())
}