diff options
| author | Matthieu Pignolet <matthieu@puffer.fish> | 2025-10-30 16:04:25 +0100 |
|---|---|---|
| committer | Matthieu Pignolet <matthieu@puffer.fish> | 2025-10-30 16:04:25 +0100 |
| commit | 09a0afaff012b88f79afde35a3c2ebaeb010a176 (patch) | |
| tree | 874fe676977f56c5456e4cead31e1a6c86bc40dd | |
| parent | c951f54be999fb7e508039ba23fa5b1fe7035743 (diff) | |
feat: commit allfeat/twilight-discord
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | Cargo.lock | 748 | ||||
| -rw-r--r-- | Cargo.toml | 3 | ||||
| -rw-r--r-- | Makefile | 20 | ||||
| -rw-r--r-- | bin/bot/Cargo.toml | 25 | ||||
| -rw-r--r-- | bin/bot/src/bot_config.rs | 10 | ||||
| -rw-r--r-- | bin/bot/src/inference.rs | 23 | ||||
| -rw-r--r-- | bin/bot/src/main.rs | 97 | ||||
| -rw-r--r-- | bin/cli/Cargo.toml | 13 | ||||
| -rw-r--r-- | bin/cli/src/inference.rs | 23 | ||||
| -rw-r--r-- | bin/cli/src/main.rs | 30 | ||||
| -rw-r--r-- | libs/db/src/inference.rs | 24 | ||||
| -rw-r--r-- | libs/db/src/save.rs | 12 | ||||
| -rw-r--r-- | libs/db/src/trie.rs | 36 | ||||
| -rw-r--r-- | libs/db/src/types.rs | 11 |
15 files changed, 1025 insertions, 52 deletions
@@ -1,3 +1,5 @@ .env target/ data/ +db.bin +data/
\ No newline at end of file @@ -18,6 +18,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] name = "ahash" version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -122,6 +133,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] +name = "base64ct" +version = "1.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" + +[[package]] name = "bincode" version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -164,11 +181,15 @@ name = "bot" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "bincode", "config", "db", + "deepphonemizer", "log", "pretty_env_logger", "serde", + "tch", "tokio", "tracing", "tracing-subscriber", @@ -189,17 +210,45 @@ dependencies = [ ] [[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] name = "bytes" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.13+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" +dependencies = [ + "cc", + "pkg-config", +] + +[[package]] name = "cc" version = "1.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8691782945451c1c383942c4874dbe63814f61cb57ef773cda2972682b7bb3c0" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -216,6 +265,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + +[[package]] +name = "cli" +version = "0.1.0" +dependencies = [ + "async-trait", + "bincode", + "db", + "deepphonemizer", + "tch", + "tokio", +] + +[[package]] name = "combine" version = "4.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -259,12 +330,18 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom", + "getrandom 0.2.16", "once_cell", "tiny-keccak", ] [[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + +[[package]] name = "convert_case" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -382,6 +459,21 @@ dependencies = [ ] [[package]] +name = "deepphonemizer" +version = "1.0.0" +source = "git+https://github.com/MatthieuCoder/deepphonemizer-rs.git#e3bbb7911c57850e326e27fa453fb89a46031872" +dependencies = [ + "anyhow", + "dotenv", + "itertools", + "serde", + "serde-pickle", + "serde_json", + "serde_yaml", + "tch", +] + +[[package]] name = "deranged" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -398,6 +490,18 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -410,6 +514,12 @@ dependencies = [ ] [[package]] +name = "dotenv" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" + +[[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -473,6 +583,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] name = "futures-channel" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -529,7 +648,19 @@ checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", ] [[package]] @@ -558,6 +689,16 @@ dependencies = [ ] [[package]] +name = "half" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -592,6 +733,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" [[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] name = "http" version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -702,6 +852,113 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b24ad5637230df201ab1034d593f1d09bf7f2a9274f2e8897638078579f4265" [[package]] +name = "icu_collections" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" + +[[package]] +name = "icu_properties" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "potential_utf", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" + +[[package]] +name = "icu_provider" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" +dependencies = [ + "displaydoc", + "icu_locale_core", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "idna" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] name = "index" version = "0.1.0" dependencies = [ @@ -722,6 +979,15 @@ dependencies = [ ] [[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + +[[package]] name = "is-terminal" version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -733,6 +999,12 @@ dependencies = [ ] [[package]] +name = "iter-read" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294" + +[[package]] name = "itertools" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -770,6 +1042,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] +name = "jobserver" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +dependencies = [ + "getrandom 0.3.3", + "libc", +] + +[[package]] name = "json5" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -816,6 +1098,12 @@ dependencies = [ ] [[package]] +name = "litemap" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" + +[[package]] name = "lock_api" version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -832,6 +1120,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] name = "memchr" version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -853,11 +1151,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] [[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] name = "nu-ansi-term" version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -868,12 +1181,40 @@ dependencies = [ ] [[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] name = "num-traits" version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -929,6 +1270,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] name = "parking_lot_core" version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -942,12 +1293,35 @@ dependencies = [ ] [[package]] +name = "password-hash" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + +[[package]] name = "pathdiff" version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" [[package]] +name = "pbkdf2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +dependencies = [ + "digest", + "hmac", + "password-hash", + "sha2", +] + +[[package]] name = "percent-encoding" version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1017,6 +1391,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "potential_utf" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +dependencies = [ + "zerovec", +] + +[[package]] name = "powerfmt" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1060,6 +1458,12 @@ dependencies = [ ] [[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + +[[package]] name = "rand" version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1086,10 +1490,16 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.16", ] [[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] name = "redox_syscall" version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1135,7 +1545,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -1176,6 +1586,7 @@ version = "0.23.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -1247,6 +1658,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] name = "same-file" version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1303,6 +1724,19 @@ dependencies = [ ] [[package]] +name = "serde-pickle" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843" +dependencies = [ + "byteorder", + "iter-read", + "num-bigint", + "num-traits", + "serde", +] + +[[package]] name = "serde-value" version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1356,6 +1790,30 @@ dependencies = [ ] [[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] name = "sha1_smol" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1388,6 +1846,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] name = "simdutf8" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1419,6 +1886,12 @@ dependencies = [ ] [[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] name = "strsim" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1442,6 +1915,34 @@ dependencies = [ ] [[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tch" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a760143efe7e4bb5b56e95d01f52ee6773bc315202e7c47db6a6429b0705a1f2" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand", + "safetensors", + "thiserror 1.0.69", + "torch-sys", + "zip", +] + +[[package]] name = "termcolor" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1540,6 +2041,16 @@ dependencies = [ ] [[package]] +name = "tinystr" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] name = "tokio" version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1549,7 +2060,9 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -1647,6 +2160,21 @@ dependencies = [ ] [[package]] +name = "torch-sys" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad6fa4ac5662b84047081375b007f102d4968d5a0191f567a9776294445af9ac" +dependencies = [ + "anyhow", + "cc", + "libc", + "serde", + "serde_json", + "ureq", + "zip", +] + +[[package]] name = "tower-service" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1845,6 +2373,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + +[[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1857,6 +2391,41 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" [[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "url" +version = "2.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] name = "valuable" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1906,6 +2475,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] +name = "wasi" +version = "0.14.2+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +dependencies = [ + "wit-bindgen-rt", +] + +[[package]] name = "webpki-root-certs" version = "0.26.10" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1915,6 +2493,24 @@ dependencies = [ ] [[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.0", +] + +[[package]] +name = "webpki-roots" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2853738d1cc4f2da3a225c18ec6c3721abb31961096e9dbf5ab35fa88b19cfdb" +dependencies = [ + "rustls-pki-types", +] + +[[package]] name = "whatlang" version = "0.16.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2113,6 +2709,21 @@ dependencies = [ ] [[package]] +name = "wit-bindgen-rt" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags", +] + +[[package]] +name = "writeable" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" + +[[package]] name = "yaml-rust2" version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2124,6 +2735,30 @@ dependencies = [ ] [[package]] +name = "yoke" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] name = "zerocopy" version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2164,7 +2799,110 @@ dependencies = [ ] [[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd", +] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.15+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +dependencies = [ + "cc", + "pkg-config", +] @@ -4,5 +4,6 @@ resolver = "2" members = [ "libs/db", "bin/bot", - "bin/index" + "bin/index", + "bin/cli" ] @@ -6,22 +6,26 @@ LANGS := all data/: mkdir -p data/ +data/all.csv: data/ + 0>data/all.csv + cat data/*.csv > data/all.csv + data/%.csv: ipa-dict/data/%.txt data/ tr '\t' ',' < $< > $@ data/all: $(NAMES) .PHONY: data/all -data/all.csv: data/$(LANGS:=.csv) - 0>data/all.csv - cat data/*.csv > data/all.csv - build-indexer: bin/index libs/ - +db.bin: data/$(LANGS:=.csv) build-indexer + cat data/$(LANGS:=.csv) | cargo run --bin index --release > db.bin + +data/latin_ipa_forward.pt: data/ + wget "https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/latin_ipa_forward.pt" -O data/latin_ipa_forward.pt -data/db.bin: data/all.csv build-indexer - cat data/all.csv | cargo run --bin index --release > data/all.csv +data/model.pt: data/latin_ipa_forward.pt + ./utils/export.py clean: rm -r data/ -.PHONY: clean +.PHONY: clean
\ No newline at end of file diff --git a/bin/bot/Cargo.toml b/bin/bot/Cargo.toml new file mode 100644 index 0000000..708a831 --- /dev/null +++ b/bin/bot/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "bot" +version = "0.1.0" +edition = "2024" + +[dependencies] +anyhow = "1.0.98" +config = "0.15.11" +log = "0.4.27" +pretty_env_logger = "0.5.0" +serde = "1.0.219" +tokio = { version = "1.44.2", features = ["rt-multi-thread"] } +tracing = "0.1.41" +tracing-subscriber = "0.3.19" +twilight-cache-inmemory = "0.16.0" +twilight-gateway = "0.16.0" +twilight-http = "0.16.0" +twilight-model = "0.16.0" +bincode = { version = "2.0.1", features = ["serde"] } + +db = { path = "../../libs/db" } +deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" } +async-trait = "0.1.88" + +tch = {version = "0.20.0", features = ["download-libtorch"]} diff --git a/bin/bot/src/bot_config.rs b/bin/bot/src/bot_config.rs new file mode 100644 index 0000000..af6d714 --- /dev/null +++ b/bin/bot/src/bot_config.rs @@ -0,0 +1,10 @@ +use serde::{Serialize, Deserialize}; +use std::string::String; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Config { + pub token: String, + pub model_path: String, + pub config_path: String, + pub db_path: String, +} diff --git a/bin/bot/src/inference.rs b/bin/bot/src/inference.rs new file mode 100644 index 0000000..f78f381 --- /dev/null +++ b/bin/bot/src/inference.rs @@ -0,0 +1,23 @@ +use async_trait::async_trait; +use db::types::GraphemeString; +use db::{ + inference::InferenceError, + types::{InferenceService, PhonemeString}, +}; +use deepphonemizer::Phonemizer; + +#[derive(Debug)] +pub struct DeepPhonemizerInference(pub Phonemizer); + +#[async_trait] +impl InferenceService<InferenceError> for DeepPhonemizerInference { + async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> { + let pronemized = self + .0 + .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap(); + + return Ok(PhonemeString( + PhonemeString(format!("{:?}", pronemized)).to_string(), + )); + } +} diff --git a/bin/bot/src/main.rs b/bin/bot/src/main.rs new file mode 100644 index 0000000..574b417 --- /dev/null +++ b/bin/bot/src/main.rs @@ -0,0 +1,97 @@ +use deepphonemizer::Phonemizer; +use inference::DeepPhonemizerInference; +use std::{error::Error, sync::Arc}; +use tch::Device; +use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType}; +use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _}; +use twilight_http::Client as HttpClient; + +mod bot_config; +mod inference; + +use db::{save::Save, types::GraphemeString}; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn Error + Send + Sync>> { + // Initialize the tracing subscriber. + tracing_subscriber::fmt::init(); + + let config = config::Config::builder() + .add_source(config::Environment::with_prefix("GRU")) + .build()? + .try_deserialize::<bot_config::Config>()?; + let token = config.token; + + // Specify intents requesting events about things like new and updated + // messages in a guild and direct messages. + + // Create a single shard. + let mut shard = Shard::new( + ShardId::ONE, + token.clone(), + Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT, + ); + + // The http client is separate from the gateway, so startup a new + // one, also use Arc such that it can be cloned to other threads. + let http = Arc::new(HttpClient::new(token)); + + // Since we only care about messages, make the cache only process messages. + let cache = DefaultInMemoryCache::builder() + .resource_types(ResourceType::MESSAGE) + .build(); + + let mut db = std::fs::File::open(config.db_path).unwrap(); + + let save: Save = + bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap(); + + let phonemizer = Phonemizer::from_checkpoint( + config.model_path, + config.config_path, + Device::cuda_if_available(), + None, + ) + .unwrap(); + + let inference_service = Arc::new(DeepPhonemizerInference(phonemizer)); + + let db = Arc::new(db::inference::Inference::new(inference_service, save)); + + // Startup the event loop to process each event in the event stream as they + // come in. + while let Some(item) = shard.next_event(EventTypeFlags::all()).await { + let Ok(event) = item else { + tracing::warn!(source = ?item.unwrap_err(), "error receiving event"); + + continue; + }; + // Update the cache. + cache.update(&event); + + // Spawn a new task to handle the event + tokio::spawn(handle_event(event, Arc::clone(&http), db.clone())); + } + + Ok(()) +} + +async fn handle_event( + event: Event, + http: Arc<HttpClient>, + inference: Arc<db::inference::Inference>, +) -> Result<(), Box<dyn Error + Send + Sync>> { + match event { + Event::MessageCreate(msg) => { + let infer = inference.infer(GraphemeString(msg.content.clone())).await.unwrap(); + + println!("{}", infer); + } + Event::Ready(_) => { + println!("Shard is ready"); + } + _ => {} + } + + Ok(()) +} diff --git a/bin/cli/Cargo.toml b/bin/cli/Cargo.toml new file mode 100644 index 0000000..cc2b393 --- /dev/null +++ b/bin/cli/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "cli" +version = "0.1.0" +edition = "2024" + +[dependencies] + +tokio = { version = "1.44.2", features = ["full"] } +db = { path = "../../libs/db" } +deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" } +async-trait = "0.1.88" +bincode = { version = "2.0.1", features = ["serde"] } +tch = {version = "0.20.0", features = ["download-libtorch"]} diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs new file mode 100644 index 0000000..f78f381 --- /dev/null +++ b/bin/cli/src/inference.rs @@ -0,0 +1,23 @@ +use async_trait::async_trait; +use db::types::GraphemeString; +use db::{ + inference::InferenceError, + types::{InferenceService, PhonemeString}, +}; +use deepphonemizer::Phonemizer; + +#[derive(Debug)] +pub struct DeepPhonemizerInference(pub Phonemizer); + +#[async_trait] +impl InferenceService<InferenceError> for DeepPhonemizerInference { + async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> { + let pronemized = self + .0 + .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap(); + + return Ok(PhonemeString( + PhonemeString(format!("{:?}", pronemized)).to_string(), + )); + } +} diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs new file mode 100644 index 0000000..736ba8e --- /dev/null +++ b/bin/cli/src/main.rs @@ -0,0 +1,30 @@ +use db::{save::Save, types::GraphemeString}; +use deepphonemizer::Phonemizer; +use tch::Device; +use std::sync::Arc; + +mod inference; + +use crate::inference::DeepPhonemizerInference; + +#[tokio::main] +async fn main() { + let mut db = std::fs::File::open("db.bin").unwrap(); + + let save: Save = + bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap(); + + let phonemizer = Phonemizer::from_checkpoint( + "data/model.pt", + "data/forward_config.yaml", + Device::cuda_if_available(), + None, + ) + .unwrap(); + + let inference_service = Arc::new(DeepPhonemizerInference(phonemizer)); + + let db = Arc::new(db::inference::Inference::new(inference_service, save)); + + println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap()); +} diff --git a/libs/db/src/inference.rs b/libs/db/src/inference.rs index f2cbabf..25415df 100644 --- a/libs/db/src/inference.rs +++ b/libs/db/src/inference.rs @@ -5,7 +5,7 @@ use crate::{ use hypher::{Lang, hyphenate}; use itertools::Itertools; use log::*; -use std::{iter, ops::Add}; +use std::{iter, ops::Add, sync::Arc}; use thiserror::Error; use unicode_segmentation::UnicodeSegmentation; @@ -15,13 +15,14 @@ pub enum InferenceError { TrieError(), #[error("the matched value wasn't in the dictionary")] NotInDictionary(), - #[error("failt to cut the word")] + #[error("failed to cut the word")] WordCutError(), } -pub struct Inference<'a> { - inference_service: Box<dyn InferenceService<InferenceError>>, - save: Save<'a>, +#[derive(Debug, Clone)] +pub struct Inference { + inference_service: Arc<dyn InferenceService<InferenceError>>, + save: Save, } pub struct TLang(whatlang::Lang); @@ -62,7 +63,7 @@ impl Into<Lang> for TLang { } } -impl Inference<'_> { +impl Inference { fn matches(&self, source: &str, complete: &str) -> bool { let source_chars: Vec<&str> = source.graphemes(true).collect(); let complete_chars: Vec<&str> = complete.graphemes(true).collect(); @@ -79,10 +80,10 @@ impl Inference<'_> { return true; } - pub fn new<'a>( - inference_service: Box<dyn InferenceService<InferenceError>>, - save: Save<'a>, - ) -> Inference<'a> { + pub fn new( + inference_service: Arc<dyn InferenceService<InferenceError>>, + save: Save, + ) -> Inference { return Inference { inference_service, save, @@ -112,8 +113,7 @@ impl Inference<'_> { let completion = self .save .trie - .random_starting_with(phonemes.clone()) - .ok_or_else(InferenceError::TrieError)?; + .random_starting_with(phonemes.clone()).unwrap(); // store the found word let initial = cprefix.add(&completion.0); diff --git a/libs/db/src/save.rs b/libs/db/src/save.rs index e1b3577..87a2736 100644 --- a/libs/db/src/save.rs +++ b/libs/db/src/save.rs @@ -1,10 +1,12 @@ -use crate::{trie::Trie, types::{GraphemeString, PhonemeString}}; +use crate::{ + trie::Trie, + types::{GraphemeString, PhonemeString}, +}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -#[derive(Debug, Deserialize, Serialize, Default)] -pub struct Save<'a> { - #[serde(borrow = "'a")] - pub trie: Trie<'a>, +#[derive(Debug, Deserialize, Serialize, Default, Clone)] +pub struct Save { + pub trie: Trie, pub reverse_index: HashMap<PhonemeString, GraphemeString>, } diff --git a/libs/db/src/trie.rs b/libs/db/src/trie.rs index 38cbcd8..754d978 100644 --- a/libs/db/src/trie.rs +++ b/libs/db/src/trie.rs @@ -5,16 +5,15 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use unicode_segmentation::UnicodeSegmentation; -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct TrieNode<'a> { +#[derive(Debug, Serialize, Deserialize, Default, Clone)] +pub struct TrieNode { is_final: bool, - #[serde(borrow = "'a")] - child_nodes: HashMap<&'a str, TrieNode<'a>>, + child_nodes: HashMap<String, TrieNode>, child_count: u64, } -impl<'a> TrieNode<'a> { - pub fn new<'b>(is_final: bool) -> TrieNode<'b> { +impl TrieNode { + pub fn new<'b>(is_final: bool) -> TrieNode { TrieNode { is_final, child_nodes: HashMap::new(), @@ -22,7 +21,7 @@ impl<'a> TrieNode<'a> { } } - pub fn new_root<'b>() -> TrieNode<'b> { + pub fn new_root<'b>() -> TrieNode { TrieNode { is_final: false, child_nodes: HashMap::new(), @@ -31,15 +30,14 @@ impl<'a> TrieNode<'a> { } } -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct Trie<'a> { - #[serde(borrow = "'a")] - root_node: Box<TrieNode<'a>>, +#[derive(Debug, Serialize, Deserialize, Default, Clone)] +pub struct Trie { + root_node: Box<TrieNode>, } -impl<'a> Trie<'a> { +impl<'a> Trie { // Create a TrieStruct - pub fn new<'b>() -> Trie<'b> { + pub fn new<'b>() -> Trie { Trie { root_node: Box::new(TrieNode::new_root()), } @@ -62,12 +60,16 @@ impl<'a> Trie<'a> { } else { create = true; - current_node.child_nodes.insert(str, TrieNode::new(false)); + current_node + .child_nodes + .insert(str.to_string(), TrieNode::new(false)); current_node = current_node.child_nodes.get_mut(str).unwrap(); current_node.child_count = 1; } } else { - current_node.child_nodes.insert(str, TrieNode::new(false)); + current_node + .child_nodes + .insert(str.to_string(), TrieNode::new(false)); current_node = current_node.child_nodes.get_mut(str).unwrap(); // we will only have one final node current_node.child_count = 1; @@ -88,7 +90,7 @@ impl<'a> Trie<'a> { // Descend as far as possible into the tree for (_, str) in graphemes { // If we can descend further into the tree - if let Some(node) = current_node.child_nodes.get(&str) { + if let Some(node) = current_node.child_nodes.get(str) { current_node = node; builder += str; debug!("going into node {}", builder); @@ -124,7 +126,7 @@ impl<'a> Trie<'a> { .child_nodes .iter() .nth(weighted.sample(&mut rng)) - .expect("choosed value did not exist"); + .expect("choose value did not exist"); debug!("walking into node {}", key); current_node = node; diff --git a/libs/db/src/types.rs b/libs/db/src/types.rs index 6062c98..636fcb7 100644 --- a/libs/db/src/types.rs +++ b/libs/db/src/types.rs @@ -1,12 +1,15 @@ -use std::{error::Error, fmt::Display}; - use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use std::{ + error::Error, + fmt::{Debug, Display}, + sync::Arc, +}; // This is needed since we need dynamic dispatch // as any kind of inference service could be used. #[async_trait] -pub trait InferenceService<E: Error> : Send + Sync { +pub trait InferenceService<E: Error>: Send + Sync + Debug { async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, E>; } @@ -36,7 +39,7 @@ impl PhonemeString { /// !! This is a one-way operation. pub async fn from_grapheme<E: Error>( value: GraphemeString, - inference: &Box<dyn InferenceService<E>>, + inference: &Arc<dyn InferenceService<E>>, ) -> Result<Self, E> { inference.fetch(value).await } |
