summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthieu Pignolet <matthieu@puffer.fish>2025-10-30 16:04:25 +0100
committerMatthieu Pignolet <matthieu@puffer.fish>2025-10-30 16:04:25 +0100
commit09a0afaff012b88f79afde35a3c2ebaeb010a176 (patch)
tree874fe676977f56c5456e4cead31e1a6c86bc40dd
parentc951f54be999fb7e508039ba23fa5b1fe7035743 (diff)
feat: commit allfeat/twilight-discord
-rw-r--r--.gitignore2
-rw-r--r--Cargo.lock748
-rw-r--r--Cargo.toml3
-rw-r--r--Makefile20
-rw-r--r--bin/bot/Cargo.toml25
-rw-r--r--bin/bot/src/bot_config.rs10
-rw-r--r--bin/bot/src/inference.rs23
-rw-r--r--bin/bot/src/main.rs97
-rw-r--r--bin/cli/Cargo.toml13
-rw-r--r--bin/cli/src/inference.rs23
-rw-r--r--bin/cli/src/main.rs30
-rw-r--r--libs/db/src/inference.rs24
-rw-r--r--libs/db/src/save.rs12
-rw-r--r--libs/db/src/trie.rs36
-rw-r--r--libs/db/src/types.rs11
15 files changed, 1025 insertions, 52 deletions
diff --git a/.gitignore b/.gitignore
index 57c0a60..c882eb8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
.env
target/
data/
+db.bin
+data/ \ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index 31d9e28..403328c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -18,6 +18,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
+name = "aes"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
+dependencies = [
+ "cfg-if",
+ "cipher",
+ "cpufeatures",
+]
+
+[[package]]
name = "ahash"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -122,6 +133,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
+name = "base64ct"
+version = "1.7.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"
+
+[[package]]
name = "bincode"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -164,11 +181,15 @@ name = "bot"
version = "0.1.0"
dependencies = [
"anyhow",
+ "async-trait",
+ "bincode",
"config",
"db",
+ "deepphonemizer",
"log",
"pretty_env_logger",
"serde",
+ "tch",
"tokio",
"tracing",
"tracing-subscriber",
@@ -189,17 +210,45 @@ dependencies = [
]
[[package]]
+name = "byteorder"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
+
+[[package]]
name = "bytes"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
[[package]]
+name = "bzip2"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8"
+dependencies = [
+ "bzip2-sys",
+ "libc",
+]
+
+[[package]]
+name = "bzip2-sys"
+version = "0.1.13+1.0.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14"
+dependencies = [
+ "cc",
+ "pkg-config",
+]
+
+[[package]]
name = "cc"
version = "1.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8691782945451c1c383942c4874dbe63814f61cb57ef773cda2972682b7bb3c0"
dependencies = [
+ "jobserver",
+ "libc",
"shlex",
]
@@ -216,6 +265,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
+name = "cipher"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
+dependencies = [
+ "crypto-common",
+ "inout",
+]
+
+[[package]]
+name = "cli"
+version = "0.1.0"
+dependencies = [
+ "async-trait",
+ "bincode",
+ "db",
+ "deepphonemizer",
+ "tch",
+ "tokio",
+]
+
+[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -259,12 +330,18 @@ version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e"
dependencies = [
- "getrandom",
+ "getrandom 0.2.16",
"once_cell",
"tiny-keccak",
]
[[package]]
+name = "constant_time_eq"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
+
+[[package]]
name = "convert_case"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -382,6 +459,21 @@ dependencies = [
]
[[package]]
+name = "deepphonemizer"
+version = "1.0.0"
+source = "git+https://github.com/MatthieuCoder/deepphonemizer-rs.git#e3bbb7911c57850e326e27fa453fb89a46031872"
+dependencies = [
+ "anyhow",
+ "dotenv",
+ "itertools",
+ "serde",
+ "serde-pickle",
+ "serde_json",
+ "serde_yaml",
+ "tch",
+]
+
+[[package]]
name = "deranged"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -398,6 +490,18 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
+ "subtle",
+]
+
+[[package]]
+name = "displaydoc"
+version = "0.2.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
]
[[package]]
@@ -410,6 +514,12 @@ dependencies = [
]
[[package]]
+name = "dotenv"
+version = "0.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
+
+[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -473,6 +583,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
+name = "form_urlencoded"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
+dependencies = [
+ "percent-encoding",
+]
+
+[[package]]
name = "futures-channel"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -529,7 +648,19 @@ checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
dependencies = [
"cfg-if",
"libc",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
+]
+
+[[package]]
+name = "getrandom"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "r-efi",
+ "wasi 0.14.2+wasi-0.2.4",
]
[[package]]
@@ -558,6 +689,16 @@ dependencies = [
]
[[package]]
+name = "half"
+version = "2.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9"
+dependencies = [
+ "cfg-if",
+ "crunchy",
+]
+
+[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -592,6 +733,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e"
[[package]]
+name = "hmac"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
+dependencies = [
+ "digest",
+]
+
+[[package]]
name = "http"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -702,6 +852,113 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b24ad5637230df201ab1034d593f1d09bf7f2a9274f2e8897638078579f4265"
[[package]]
+name = "icu_collections"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47"
+dependencies = [
+ "displaydoc",
+ "potential_utf",
+ "yoke",
+ "zerofrom",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_locale_core"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a"
+dependencies = [
+ "displaydoc",
+ "litemap",
+ "tinystr",
+ "writeable",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_normalizer"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979"
+dependencies = [
+ "displaydoc",
+ "icu_collections",
+ "icu_normalizer_data",
+ "icu_properties",
+ "icu_provider",
+ "smallvec",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_normalizer_data"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3"
+
+[[package]]
+name = "icu_properties"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b"
+dependencies = [
+ "displaydoc",
+ "icu_collections",
+ "icu_locale_core",
+ "icu_properties_data",
+ "icu_provider",
+ "potential_utf",
+ "zerotrie",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_properties_data"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632"
+
+[[package]]
+name = "icu_provider"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af"
+dependencies = [
+ "displaydoc",
+ "icu_locale_core",
+ "stable_deref_trait",
+ "tinystr",
+ "writeable",
+ "yoke",
+ "zerofrom",
+ "zerotrie",
+ "zerovec",
+]
+
+[[package]]
+name = "idna"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e"
+dependencies = [
+ "idna_adapter",
+ "smallvec",
+ "utf8_iter",
+]
+
+[[package]]
+name = "idna_adapter"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344"
+dependencies = [
+ "icu_normalizer",
+ "icu_properties",
+]
+
+[[package]]
name = "index"
version = "0.1.0"
dependencies = [
@@ -722,6 +979,15 @@ dependencies = [
]
[[package]]
+name = "inout"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
+dependencies = [
+ "generic-array",
+]
+
+[[package]]
name = "is-terminal"
version = "0.4.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -733,6 +999,12 @@ dependencies = [
]
[[package]]
+name = "iter-read"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
+
+[[package]]
name = "itertools"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -770,6 +1042,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
[[package]]
+name = "jobserver"
+version = "0.1.33"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a"
+dependencies = [
+ "getrandom 0.3.3",
+ "libc",
+]
+
+[[package]]
name = "json5"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -816,6 +1098,12 @@ dependencies = [
]
[[package]]
+name = "litemap"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
+
+[[package]]
name = "lock_api"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -832,6 +1120,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
+name = "matrixmultiply"
+version = "0.3.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
+dependencies = [
+ "autocfg",
+ "rawpointer",
+]
+
+[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -853,11 +1151,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
dependencies = [
"libc",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
[[package]]
+name = "ndarray"
+version = "0.16.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "portable-atomic",
+ "portable-atomic-util",
+ "rawpointer",
+]
+
+[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -868,12 +1181,40 @@ dependencies = [
]
[[package]]
+name = "num-bigint"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
+dependencies = [
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-complex"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -929,6 +1270,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
+name = "parking_lot"
+version = "0.12.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
+dependencies = [
+ "lock_api",
+ "parking_lot_core",
+]
+
+[[package]]
name = "parking_lot_core"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -942,12 +1293,35 @@ dependencies = [
]
[[package]]
+name = "password-hash"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
+dependencies = [
+ "base64ct",
+ "rand_core",
+ "subtle",
+]
+
+[[package]]
name = "pathdiff"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
+name = "pbkdf2"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
+dependencies = [
+ "digest",
+ "hmac",
+ "password-hash",
+ "sha2",
+]
+
+[[package]]
name = "percent-encoding"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1017,6 +1391,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
+name = "portable-atomic"
+version = "1.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
+
+[[package]]
+name = "portable-atomic-util"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
+dependencies = [
+ "portable-atomic",
+]
+
+[[package]]
+name = "potential_utf"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585"
+dependencies = [
+ "zerovec",
+]
+
+[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1060,6 +1458,12 @@ dependencies = [
]
[[package]]
+name = "r-efi"
+version = "5.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5"
+
+[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1086,10 +1490,16 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
- "getrandom",
+ "getrandom 0.2.16",
]
[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
+[[package]]
name = "redox_syscall"
version = "0.5.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1135,7 +1545,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
dependencies = [
"cc",
"cfg-if",
- "getrandom",
+ "getrandom 0.2.16",
"libc",
"untrusted",
"windows-sys 0.52.0",
@@ -1176,6 +1586,7 @@ version = "0.23.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0"
dependencies = [
+ "log",
"once_cell",
"ring",
"rustls-pki-types",
@@ -1247,6 +1658,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
+name = "safetensors"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
+[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1303,6 +1724,19 @@ dependencies = [
]
[[package]]
+name = "serde-pickle"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
+dependencies = [
+ "byteorder",
+ "iter-read",
+ "num-bigint",
+ "num-traits",
+ "serde",
+]
+
+[[package]]
name = "serde-value"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1356,6 +1790,30 @@ dependencies = [
]
[[package]]
+name = "serde_yaml"
+version = "0.9.34+deprecated"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
+dependencies = [
+ "indexmap",
+ "itoa",
+ "ryu",
+ "serde",
+ "unsafe-libyaml",
+]
+
+[[package]]
+name = "sha1"
+version = "0.10.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest",
+]
+
+[[package]]
name = "sha1_smol"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1388,6 +1846,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
+name = "signal-hook-registry"
+version = "1.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1"
+dependencies = [
+ "libc",
+]
+
+[[package]]
name = "simdutf8"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1419,6 +1886,12 @@ dependencies = [
]
[[package]]
+name = "stable_deref_trait"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
+
+[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1442,6 +1915,34 @@ dependencies = [
]
[[package]]
+name = "synstructure"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "tch"
+version = "0.20.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a760143efe7e4bb5b56e95d01f52ee6773bc315202e7c47db6a6429b0705a1f2"
+dependencies = [
+ "half",
+ "lazy_static",
+ "libc",
+ "ndarray",
+ "rand",
+ "safetensors",
+ "thiserror 1.0.69",
+ "torch-sys",
+ "zip",
+]
+
+[[package]]
name = "termcolor"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1540,6 +2041,16 @@ dependencies = [
]
[[package]]
+name = "tinystr"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b"
+dependencies = [
+ "displaydoc",
+ "zerovec",
+]
+
+[[package]]
name = "tokio"
version = "1.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1549,7 +2060,9 @@ dependencies = [
"bytes",
"libc",
"mio",
+ "parking_lot",
"pin-project-lite",
+ "signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys 0.52.0",
@@ -1647,6 +2160,21 @@ dependencies = [
]
[[package]]
+name = "torch-sys"
+version = "0.20.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad6fa4ac5662b84047081375b007f102d4968d5a0191f567a9776294445af9ac"
+dependencies = [
+ "anyhow",
+ "cc",
+ "libc",
+ "serde",
+ "serde_json",
+ "ureq",
+ "zip",
+]
+
+[[package]]
name = "tower-service"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1845,6 +2373,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
+name = "unsafe-libyaml"
+version = "0.2.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
+
+[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1857,6 +2391,41 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae"
[[package]]
+name = "ureq"
+version = "2.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
+dependencies = [
+ "base64 0.22.1",
+ "flate2",
+ "log",
+ "once_cell",
+ "rustls",
+ "rustls-pki-types",
+ "serde",
+ "serde_json",
+ "url",
+ "webpki-roots 0.26.11",
+]
+
+[[package]]
+name = "url"
+version = "2.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60"
+dependencies = [
+ "form_urlencoded",
+ "idna",
+ "percent-encoding",
+]
+
+[[package]]
+name = "utf8_iter"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
+
+[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1906,6 +2475,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
+name = "wasi"
+version = "0.14.2+wasi-0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
+dependencies = [
+ "wit-bindgen-rt",
+]
+
+[[package]]
name = "webpki-root-certs"
version = "0.26.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1915,6 +2493,24 @@ dependencies = [
]
[[package]]
+name = "webpki-roots"
+version = "0.26.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
+dependencies = [
+ "webpki-roots 1.0.0",
+]
+
+[[package]]
+name = "webpki-roots"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2853738d1cc4f2da3a225c18ec6c3721abb31961096e9dbf5ab35fa88b19cfdb"
+dependencies = [
+ "rustls-pki-types",
+]
+
+[[package]]
name = "whatlang"
version = "0.16.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2113,6 +2709,21 @@ dependencies = [
]
[[package]]
+name = "wit-bindgen-rt"
+version = "0.39.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
+dependencies = [
+ "bitflags",
+]
+
+[[package]]
+name = "writeable"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb"
+
+[[package]]
name = "yaml-rust2"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2124,6 +2735,30 @@ dependencies = [
]
[[package]]
+name = "yoke"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc"
+dependencies = [
+ "serde",
+ "stable_deref_trait",
+ "yoke-derive",
+ "zerofrom",
+]
+
+[[package]]
+name = "yoke-derive"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+ "synstructure",
+]
+
+[[package]]
name = "zerocopy"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2164,7 +2799,110 @@ dependencies = [
]
[[package]]
+name = "zerofrom"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
+dependencies = [
+ "zerofrom-derive",
+]
+
+[[package]]
+name = "zerofrom-derive"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+ "synstructure",
+]
+
+[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
+
+[[package]]
+name = "zerotrie"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595"
+dependencies = [
+ "displaydoc",
+ "yoke",
+ "zerofrom",
+]
+
+[[package]]
+name = "zerovec"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428"
+dependencies = [
+ "yoke",
+ "zerofrom",
+ "zerovec-derive",
+]
+
+[[package]]
+name = "zerovec-derive"
+version = "0.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "zip"
+version = "0.6.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
+dependencies = [
+ "aes",
+ "byteorder",
+ "bzip2",
+ "constant_time_eq",
+ "crc32fast",
+ "crossbeam-utils",
+ "flate2",
+ "hmac",
+ "pbkdf2",
+ "sha1",
+ "time",
+ "zstd",
+]
+
+[[package]]
+name = "zstd"
+version = "0.11.2+zstd.1.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
+dependencies = [
+ "zstd-safe",
+]
+
+[[package]]
+name = "zstd-safe"
+version = "5.0.2+zstd.1.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db"
+dependencies = [
+ "libc",
+ "zstd-sys",
+]
+
+[[package]]
+name = "zstd-sys"
+version = "2.0.15+zstd.1.5.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237"
+dependencies = [
+ "cc",
+ "pkg-config",
+]
diff --git a/Cargo.toml b/Cargo.toml
index 1d94d9d..ae9edfb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,5 +4,6 @@ resolver = "2"
members = [
"libs/db",
"bin/bot",
- "bin/index"
+ "bin/index",
+ "bin/cli"
]
diff --git a/Makefile b/Makefile
index d9599f3..7812149 100644
--- a/Makefile
+++ b/Makefile
@@ -6,22 +6,26 @@ LANGS := all
data/:
mkdir -p data/
+data/all.csv: data/
+ 0>data/all.csv
+ cat data/*.csv > data/all.csv
+
data/%.csv: ipa-dict/data/%.txt data/
tr '\t' ',' < $< > $@
data/all: $(NAMES)
.PHONY: data/all
-data/all.csv: data/$(LANGS:=.csv)
- 0>data/all.csv
- cat data/*.csv > data/all.csv
-
build-indexer: bin/index libs/
-
+db.bin: data/$(LANGS:=.csv) build-indexer
+ cat data/$(LANGS:=.csv) | cargo run --bin index --release > db.bin
+
+data/latin_ipa_forward.pt: data/
+ wget "https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/latin_ipa_forward.pt" -O data/latin_ipa_forward.pt
-data/db.bin: data/all.csv build-indexer
- cat data/all.csv | cargo run --bin index --release > data/all.csv
+data/model.pt: data/latin_ipa_forward.pt
+ ./utils/export.py
clean:
rm -r data/
-.PHONY: clean
+.PHONY: clean \ No newline at end of file
diff --git a/bin/bot/Cargo.toml b/bin/bot/Cargo.toml
new file mode 100644
index 0000000..708a831
--- /dev/null
+++ b/bin/bot/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+name = "bot"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+anyhow = "1.0.98"
+config = "0.15.11"
+log = "0.4.27"
+pretty_env_logger = "0.5.0"
+serde = "1.0.219"
+tokio = { version = "1.44.2", features = ["rt-multi-thread"] }
+tracing = "0.1.41"
+tracing-subscriber = "0.3.19"
+twilight-cache-inmemory = "0.16.0"
+twilight-gateway = "0.16.0"
+twilight-http = "0.16.0"
+twilight-model = "0.16.0"
+bincode = { version = "2.0.1", features = ["serde"] }
+
+db = { path = "../../libs/db" }
+deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" }
+async-trait = "0.1.88"
+
+tch = {version = "0.20.0", features = ["download-libtorch"]}
diff --git a/bin/bot/src/bot_config.rs b/bin/bot/src/bot_config.rs
new file mode 100644
index 0000000..af6d714
--- /dev/null
+++ b/bin/bot/src/bot_config.rs
@@ -0,0 +1,10 @@
+use serde::{Serialize, Deserialize};
+use std::string::String;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Config {
+ pub token: String,
+ pub model_path: String,
+ pub config_path: String,
+ pub db_path: String,
+}
diff --git a/bin/bot/src/inference.rs b/bin/bot/src/inference.rs
new file mode 100644
index 0000000..f78f381
--- /dev/null
+++ b/bin/bot/src/inference.rs
@@ -0,0 +1,23 @@
+use async_trait::async_trait;
+use db::types::GraphemeString;
+use db::{
+ inference::InferenceError,
+ types::{InferenceService, PhonemeString},
+};
+use deepphonemizer::Phonemizer;
+
+#[derive(Debug)]
+pub struct DeepPhonemizerInference(pub Phonemizer);
+
+#[async_trait]
+impl InferenceService<InferenceError> for DeepPhonemizerInference {
+ async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> {
+ let pronemized = self
+ .0
+ .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap();
+
+ return Ok(PhonemeString(
+ PhonemeString(format!("{:?}", pronemized)).to_string(),
+ ));
+ }
+}
diff --git a/bin/bot/src/main.rs b/bin/bot/src/main.rs
new file mode 100644
index 0000000..574b417
--- /dev/null
+++ b/bin/bot/src/main.rs
@@ -0,0 +1,97 @@
+use deepphonemizer::Phonemizer;
+use inference::DeepPhonemizerInference;
+use std::{error::Error, sync::Arc};
+use tch::Device;
+use twilight_cache_inmemory::{DefaultInMemoryCache, ResourceType};
+use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
+use twilight_http::Client as HttpClient;
+
+mod bot_config;
+mod inference;
+
+use db::{save::Save, types::GraphemeString};
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
+ // Initialize the tracing subscriber.
+ tracing_subscriber::fmt::init();
+
+ let config = config::Config::builder()
+ .add_source(config::Environment::with_prefix("GRU"))
+ .build()?
+ .try_deserialize::<bot_config::Config>()?;
+ let token = config.token;
+
+ // Specify intents requesting events about things like new and updated
+ // messages in a guild and direct messages.
+
+ // Create a single shard.
+ let mut shard = Shard::new(
+ ShardId::ONE,
+ token.clone(),
+ Intents::GUILD_MESSAGES | Intents::DIRECT_MESSAGES | Intents::MESSAGE_CONTENT,
+ );
+
+ // The http client is separate from the gateway, so startup a new
+ // one, also use Arc such that it can be cloned to other threads.
+ let http = Arc::new(HttpClient::new(token));
+
+ // Since we only care about messages, make the cache only process messages.
+ let cache = DefaultInMemoryCache::builder()
+ .resource_types(ResourceType::MESSAGE)
+ .build();
+
+ let mut db = std::fs::File::open(config.db_path).unwrap();
+
+ let save: Save =
+ bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+ let phonemizer = Phonemizer::from_checkpoint(
+ config.model_path,
+ config.config_path,
+ Device::cuda_if_available(),
+ None,
+ )
+ .unwrap();
+
+ let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+ let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+ // Startup the event loop to process each event in the event stream as they
+ // come in.
+ while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
+ let Ok(event) = item else {
+ tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
+
+ continue;
+ };
+ // Update the cache.
+ cache.update(&event);
+
+ // Spawn a new task to handle the event
+ tokio::spawn(handle_event(event, Arc::clone(&http), db.clone()));
+ }
+
+ Ok(())
+}
+
+async fn handle_event(
+ event: Event,
+ http: Arc<HttpClient>,
+ inference: Arc<db::inference::Inference>,
+) -> Result<(), Box<dyn Error + Send + Sync>> {
+ match event {
+ Event::MessageCreate(msg) => {
+ let infer = inference.infer(GraphemeString(msg.content.clone())).await.unwrap();
+
+ println!("{}", infer);
+ }
+ Event::Ready(_) => {
+ println!("Shard is ready");
+ }
+ _ => {}
+ }
+
+ Ok(())
+}
diff --git a/bin/cli/Cargo.toml b/bin/cli/Cargo.toml
new file mode 100644
index 0000000..cc2b393
--- /dev/null
+++ b/bin/cli/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "cli"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+
+tokio = { version = "1.44.2", features = ["full"] }
+db = { path = "../../libs/db" }
+deepphonemizer = { git = "https://github.com/MatthieuCoder/deepphonemizer-rs.git" }
+async-trait = "0.1.88"
+bincode = { version = "2.0.1", features = ["serde"] }
+tch = {version = "0.20.0", features = ["download-libtorch"]}
diff --git a/bin/cli/src/inference.rs b/bin/cli/src/inference.rs
new file mode 100644
index 0000000..f78f381
--- /dev/null
+++ b/bin/cli/src/inference.rs
@@ -0,0 +1,23 @@
+use async_trait::async_trait;
+use db::types::GraphemeString;
+use db::{
+ inference::InferenceError,
+ types::{InferenceService, PhonemeString},
+};
+use deepphonemizer::Phonemizer;
+
+#[derive(Debug)]
+pub struct DeepPhonemizerInference(pub Phonemizer);
+
+#[async_trait]
+impl InferenceService<InferenceError> for DeepPhonemizerInference {
+ async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, InferenceError> {
+ let pronemized = self
+ .0
+ .phonemize(prefix.0, "fr_fr".to_string(), "", true, 1).unwrap();
+
+ return Ok(PhonemeString(
+ PhonemeString(format!("{:?}", pronemized)).to_string(),
+ ));
+ }
+}
diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs
new file mode 100644
index 0000000..736ba8e
--- /dev/null
+++ b/bin/cli/src/main.rs
@@ -0,0 +1,30 @@
+use db::{save::Save, types::GraphemeString};
+use deepphonemizer::Phonemizer;
+use tch::Device;
+use std::sync::Arc;
+
+mod inference;
+
+use crate::inference::DeepPhonemizerInference;
+
+#[tokio::main]
+async fn main() {
+ let mut db = std::fs::File::open("db.bin").unwrap();
+
+ let save: Save =
+ bincode::serde::decode_from_std_read(&mut db, bincode::config::standard()).unwrap();
+
+ let phonemizer = Phonemizer::from_checkpoint(
+ "data/model.pt",
+ "data/forward_config.yaml",
+ Device::cuda_if_available(),
+ None,
+ )
+ .unwrap();
+
+ let inference_service = Arc::new(DeepPhonemizerInference(phonemizer));
+
+ let db = Arc::new(db::inference::Inference::new(inference_service, save));
+
+ println!("{}", db.infer(GraphemeString("bon".to_string())).await.unwrap());
+}
diff --git a/libs/db/src/inference.rs b/libs/db/src/inference.rs
index f2cbabf..25415df 100644
--- a/libs/db/src/inference.rs
+++ b/libs/db/src/inference.rs
@@ -5,7 +5,7 @@ use crate::{
use hypher::{Lang, hyphenate};
use itertools::Itertools;
use log::*;
-use std::{iter, ops::Add};
+use std::{iter, ops::Add, sync::Arc};
use thiserror::Error;
use unicode_segmentation::UnicodeSegmentation;
@@ -15,13 +15,14 @@ pub enum InferenceError {
TrieError(),
#[error("the matched value wasn't in the dictionary")]
NotInDictionary(),
- #[error("failt to cut the word")]
+ #[error("failed to cut the word")]
WordCutError(),
}
-pub struct Inference<'a> {
- inference_service: Box<dyn InferenceService<InferenceError>>,
- save: Save<'a>,
+#[derive(Debug, Clone)]
+pub struct Inference {
+ inference_service: Arc<dyn InferenceService<InferenceError>>,
+ save: Save,
}
pub struct TLang(whatlang::Lang);
@@ -62,7 +63,7 @@ impl Into<Lang> for TLang {
}
}
-impl Inference<'_> {
+impl Inference {
fn matches(&self, source: &str, complete: &str) -> bool {
let source_chars: Vec<&str> = source.graphemes(true).collect();
let complete_chars: Vec<&str> = complete.graphemes(true).collect();
@@ -79,10 +80,10 @@ impl Inference<'_> {
return true;
}
- pub fn new<'a>(
- inference_service: Box<dyn InferenceService<InferenceError>>,
- save: Save<'a>,
- ) -> Inference<'a> {
+ pub fn new(
+ inference_service: Arc<dyn InferenceService<InferenceError>>,
+ save: Save,
+ ) -> Inference {
return Inference {
inference_service,
save,
@@ -112,8 +113,7 @@ impl Inference<'_> {
let completion = self
.save
.trie
- .random_starting_with(phonemes.clone())
- .ok_or_else(InferenceError::TrieError)?;
+ .random_starting_with(phonemes.clone()).unwrap();
// store the found word
let initial = cprefix.add(&completion.0);
diff --git a/libs/db/src/save.rs b/libs/db/src/save.rs
index e1b3577..87a2736 100644
--- a/libs/db/src/save.rs
+++ b/libs/db/src/save.rs
@@ -1,10 +1,12 @@
-use crate::{trie::Trie, types::{GraphemeString, PhonemeString}};
+use crate::{
+ trie::Trie,
+ types::{GraphemeString, PhonemeString},
+};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
-#[derive(Debug, Deserialize, Serialize, Default)]
-pub struct Save<'a> {
- #[serde(borrow = "'a")]
- pub trie: Trie<'a>,
+#[derive(Debug, Deserialize, Serialize, Default, Clone)]
+pub struct Save {
+ pub trie: Trie,
pub reverse_index: HashMap<PhonemeString, GraphemeString>,
}
diff --git a/libs/db/src/trie.rs b/libs/db/src/trie.rs
index 38cbcd8..754d978 100644
--- a/libs/db/src/trie.rs
+++ b/libs/db/src/trie.rs
@@ -5,16 +5,15 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use unicode_segmentation::UnicodeSegmentation;
-#[derive(Debug, Serialize, Deserialize, Default)]
-pub struct TrieNode<'a> {
+#[derive(Debug, Serialize, Deserialize, Default, Clone)]
+pub struct TrieNode {
is_final: bool,
- #[serde(borrow = "'a")]
- child_nodes: HashMap<&'a str, TrieNode<'a>>,
+ child_nodes: HashMap<String, TrieNode>,
child_count: u64,
}
-impl<'a> TrieNode<'a> {
- pub fn new<'b>(is_final: bool) -> TrieNode<'b> {
+impl TrieNode {
+ pub fn new<'b>(is_final: bool) -> TrieNode {
TrieNode {
is_final,
child_nodes: HashMap::new(),
@@ -22,7 +21,7 @@ impl<'a> TrieNode<'a> {
}
}
- pub fn new_root<'b>() -> TrieNode<'b> {
+ pub fn new_root<'b>() -> TrieNode {
TrieNode {
is_final: false,
child_nodes: HashMap::new(),
@@ -31,15 +30,14 @@ impl<'a> TrieNode<'a> {
}
}
-#[derive(Debug, Serialize, Deserialize, Default)]
-pub struct Trie<'a> {
- #[serde(borrow = "'a")]
- root_node: Box<TrieNode<'a>>,
+#[derive(Debug, Serialize, Deserialize, Default, Clone)]
+pub struct Trie {
+ root_node: Box<TrieNode>,
}
-impl<'a> Trie<'a> {
+impl<'a> Trie {
// Create a TrieStruct
- pub fn new<'b>() -> Trie<'b> {
+ pub fn new<'b>() -> Trie {
Trie {
root_node: Box::new(TrieNode::new_root()),
}
@@ -62,12 +60,16 @@ impl<'a> Trie<'a> {
} else {
create = true;
- current_node.child_nodes.insert(str, TrieNode::new(false));
+ current_node
+ .child_nodes
+ .insert(str.to_string(), TrieNode::new(false));
current_node = current_node.child_nodes.get_mut(str).unwrap();
current_node.child_count = 1;
}
} else {
- current_node.child_nodes.insert(str, TrieNode::new(false));
+ current_node
+ .child_nodes
+ .insert(str.to_string(), TrieNode::new(false));
current_node = current_node.child_nodes.get_mut(str).unwrap();
// we will only have one final node
current_node.child_count = 1;
@@ -88,7 +90,7 @@ impl<'a> Trie<'a> {
// Descend as far as possible into the tree
for (_, str) in graphemes {
// If we can descend further into the tree
- if let Some(node) = current_node.child_nodes.get(&str) {
+ if let Some(node) = current_node.child_nodes.get(str) {
current_node = node;
builder += str;
debug!("going into node {}", builder);
@@ -124,7 +126,7 @@ impl<'a> Trie<'a> {
.child_nodes
.iter()
.nth(weighted.sample(&mut rng))
- .expect("choosed value did not exist");
+ .expect("choose value did not exist");
debug!("walking into node {}", key);
current_node = node;
diff --git a/libs/db/src/types.rs b/libs/db/src/types.rs
index 6062c98..636fcb7 100644
--- a/libs/db/src/types.rs
+++ b/libs/db/src/types.rs
@@ -1,12 +1,15 @@
-use std::{error::Error, fmt::Display};
-
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
+use std::{
+ error::Error,
+ fmt::{Debug, Display},
+ sync::Arc,
+};
// This is needed since we need dynamic dispatch
// as any kind of inference service could be used.
#[async_trait]
-pub trait InferenceService<E: Error> : Send + Sync {
+pub trait InferenceService<E: Error>: Send + Sync + Debug {
async fn fetch(&self, prefix: GraphemeString) -> Result<PhonemeString, E>;
}
@@ -36,7 +39,7 @@ impl PhonemeString {
/// !! This is a one-way operation.
pub async fn from_grapheme<E: Error>(
value: GraphemeString,
- inference: &Box<dyn InferenceService<E>>,
+ inference: &Arc<dyn InferenceService<E>>,
) -> Result<Self, E> {
inference.fetch(value).await
}