Compare commits
29 commits
4603947506
...
592a3e2e52
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
592a3e2e52 | ||
|
|
dd551fe551 | ||
|
|
18b7fd0535 | ||
|
|
60de579305 | ||
|
|
28484a385b | ||
|
|
3e05331608 | ||
|
|
2989a6afaa | ||
|
|
0e6b5dc8be | ||
|
|
2eddf3b4cf | ||
|
|
7ef02c97d1 | ||
|
|
313f85f34a | ||
|
|
343e43afab | ||
|
|
d5a3398cc9 | ||
|
|
080b4f9084 | ||
|
|
77822992c8 | ||
|
|
e5dd8312c7 | ||
|
|
ac40c2cb98 | ||
|
|
2b632d568b | ||
|
|
5d9d3ffc5b | ||
|
|
50b7b3a33a | ||
|
|
2c6a5c0f4a | ||
|
|
68a2df2185 | ||
|
|
039473d31f | ||
|
|
78fa4b639f | ||
|
|
7e7e9a4b69 | ||
|
|
2f08149fab | ||
|
|
a73bcf5ae3 | ||
|
|
b649a11645 | ||
|
|
81e0632cf3 |
41 changed files with 3255 additions and 1973 deletions
199
Cargo.lock
generated
199
Cargo.lock
generated
|
|
@ -492,11 +492,12 @@ dependencies = [
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-util",
|
"hyper-util",
|
||||||
"json5",
|
"json-five",
|
||||||
"libc",
|
"libc",
|
||||||
"log",
|
"log",
|
||||||
"memchr",
|
"memchr",
|
||||||
"memmap2",
|
"memmap2",
|
||||||
|
"notify-debouncer-mini",
|
||||||
"paste",
|
"paste",
|
||||||
"peg",
|
"peg",
|
||||||
"ratatui",
|
"ratatui",
|
||||||
|
|
@ -1088,6 +1089,15 @@ version = "1.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fsevent-sys"
|
||||||
|
version = "4.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures"
|
name = "futures"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
|
|
@ -1453,6 +1463,26 @@ version = "0.1.15"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
|
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "inotify"
|
||||||
|
version = "0.11.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bd5b3eaf1a28b758ac0faa5a4254e8ab2705605496f1b1f3fbbc3988ad73d199"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.11.0",
|
||||||
|
"inotify-sys",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "inotify-sys"
|
||||||
|
version = "0.1.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "instability"
|
name = "instability"
|
||||||
version = "0.3.12"
|
version = "0.3.12"
|
||||||
|
|
@ -1531,6 +1561,16 @@ dependencies = [
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "json-five"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "865f2d01a4549c1fd8c60640c03ae5249eb374cd8cde8b905628d4b1af95c87c"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
"unicode-general-category",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "json5"
|
name = "json5"
|
||||||
version = "1.3.1"
|
version = "1.3.1"
|
||||||
|
|
@ -1552,6 +1592,26 @@ dependencies = [
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "kqueue"
|
||||||
|
version = "1.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "eac30106d7dce88daf4a3fcb4879ea939476d5074a9b7ddd0fb97fa4bed5596a"
|
||||||
|
dependencies = [
|
||||||
|
"kqueue-sys",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "kqueue-sys"
|
||||||
|
version = "1.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 1.3.2",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lab"
|
name = "lab"
|
||||||
version = "0.11.0"
|
version = "0.11.0"
|
||||||
|
|
@ -1774,6 +1834,45 @@ dependencies = [
|
||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "notify"
|
||||||
|
version = "8.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.11.0",
|
||||||
|
"fsevent-sys",
|
||||||
|
"inotify",
|
||||||
|
"kqueue",
|
||||||
|
"libc",
|
||||||
|
"log",
|
||||||
|
"mio",
|
||||||
|
"notify-types",
|
||||||
|
"walkdir",
|
||||||
|
"windows-sys 0.60.2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "notify-debouncer-mini"
|
||||||
|
version = "0.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "17849edfaabd9a5fef1c606d99cfc615a8e99f7ac4366406d86c7942a3184cf2"
|
||||||
|
dependencies = [
|
||||||
|
"log",
|
||||||
|
"notify",
|
||||||
|
"notify-types",
|
||||||
|
"tempfile",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "notify-types"
|
||||||
|
version = "2.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "42b8cfee0e339a0337359f3c88165702ac6e600dc01c0cc9579a92d62b08477a"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.11.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num-conv"
|
name = "num-conv"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
|
|
@ -3384,6 +3483,12 @@ version = "2.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
|
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-general-category"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b993bddc193ae5bd0d623b49ec06ac3e9312875fdae725a975c51db1cc1677f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.24"
|
version = "1.0.24"
|
||||||
|
|
@ -3794,7 +3899,16 @@ version = "0.52.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-targets",
|
"windows-targets 0.52.6",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-sys"
|
||||||
|
version = "0.60.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
|
||||||
|
dependencies = [
|
||||||
|
"windows-targets 0.53.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -3812,14 +3926,31 @@ version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows_aarch64_gnullvm",
|
"windows_aarch64_gnullvm 0.52.6",
|
||||||
"windows_aarch64_msvc",
|
"windows_aarch64_msvc 0.52.6",
|
||||||
"windows_i686_gnu",
|
"windows_i686_gnu 0.52.6",
|
||||||
"windows_i686_gnullvm",
|
"windows_i686_gnullvm 0.52.6",
|
||||||
"windows_i686_msvc",
|
"windows_i686_msvc 0.52.6",
|
||||||
"windows_x86_64_gnu",
|
"windows_x86_64_gnu 0.52.6",
|
||||||
"windows_x86_64_gnullvm",
|
"windows_x86_64_gnullvm 0.52.6",
|
||||||
"windows_x86_64_msvc",
|
"windows_x86_64_msvc 0.52.6",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-targets"
|
||||||
|
version = "0.53.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3"
|
||||||
|
dependencies = [
|
||||||
|
"windows-link",
|
||||||
|
"windows_aarch64_gnullvm 0.53.1",
|
||||||
|
"windows_aarch64_msvc 0.53.1",
|
||||||
|
"windows_i686_gnu 0.53.1",
|
||||||
|
"windows_i686_gnullvm 0.53.1",
|
||||||
|
"windows_i686_msvc 0.53.1",
|
||||||
|
"windows_x86_64_gnu 0.53.1",
|
||||||
|
"windows_x86_64_gnullvm 0.53.1",
|
||||||
|
"windows_x86_64_msvc 0.53.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -3828,48 +3959,96 @@ version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_aarch64_gnullvm"
|
||||||
|
version = "0.53.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_aarch64_msvc"
|
name = "windows_aarch64_msvc"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_aarch64_msvc"
|
||||||
|
version = "0.53.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_gnu"
|
name = "windows_i686_gnu"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_gnu"
|
||||||
|
version = "0.53.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_gnullvm"
|
name = "windows_i686_gnullvm"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_gnullvm"
|
||||||
|
version = "0.53.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_msvc"
|
name = "windows_i686_msvc"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_msvc"
|
||||||
|
version = "0.53.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_gnu"
|
name = "windows_x86_64_gnu"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_gnu"
|
||||||
|
version = "0.53.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_gnullvm"
|
name = "windows_x86_64_gnullvm"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_gnullvm"
|
||||||
|
version = "0.53.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_msvc"
|
name = "windows_x86_64_msvc"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_msvc"
|
||||||
|
version = "0.53.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wit-bindgen"
|
name = "wit-bindgen"
|
||||||
version = "0.51.0"
|
version = "0.51.0"
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,8 @@ log = "0.4"
|
||||||
|
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
json5 = "1.3"
|
json-five = "0.3"
|
||||||
|
notify-debouncer-mini = "0.7"
|
||||||
|
|
||||||
ratatui = { version = "0.30", features = ["unstable-rendered-line-info"] }
|
ratatui = { version = "0.30", features = ["unstable-rendered-line-info"] }
|
||||||
tui-markdown = { git = "https://github.com/koverstreet/tui-markdown", subdirectory = "tui-markdown" }
|
tui-markdown = { git = "https://github.com/koverstreet/tui-markdown", subdirectory = "tui-markdown" }
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ pub struct NodeLeaf {
|
||||||
body: NodeBody,
|
body: NodeBody,
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
token_ids: Vec<u32>,
|
token_ids: Vec<u32>,
|
||||||
timestamp: Option<DateTime<Utc>>,
|
timestamp: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'de> Deserialize<'de> for NodeLeaf {
|
impl<'de> Deserialize<'de> for NodeLeaf {
|
||||||
|
|
@ -100,7 +100,7 @@ impl<'de> Deserialize<'de> for NodeLeaf {
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Raw {
|
struct Raw {
|
||||||
body: NodeBody,
|
body: NodeBody,
|
||||||
timestamp: Option<DateTime<Utc>>,
|
timestamp: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
let raw = Raw::deserialize(deserializer)?;
|
let raw = Raw::deserialize(deserializer)?;
|
||||||
let token_ids = if raw.body.is_prompt_visible() {
|
let token_ids = if raw.body.is_prompt_visible() {
|
||||||
|
|
@ -119,6 +119,7 @@ pub enum AstNode {
|
||||||
Branch {
|
Branch {
|
||||||
role: Role,
|
role: Role,
|
||||||
children: Vec<AstNode>,
|
children: Vec<AstNode>,
|
||||||
|
timestamp: DateTime<Utc>,
|
||||||
/// Per-response memory attribution from full scoring matrix.
|
/// Per-response memory attribution from full scoring matrix.
|
||||||
/// Maps memory key → divergence score for this response.
|
/// Maps memory key → divergence score for this response.
|
||||||
#[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
|
#[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
|
||||||
|
|
@ -252,18 +253,18 @@ impl NodeLeaf {
|
||||||
} else {
|
} else {
|
||||||
vec![]
|
vec![]
|
||||||
};
|
};
|
||||||
Self { body, token_ids, timestamp: None }
|
Self { body, token_ids, timestamp: Utc::now() }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
|
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
|
||||||
self.timestamp = Some(ts);
|
self.timestamp = ts;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn body(&self) -> &NodeBody { &self.body }
|
pub fn body(&self) -> &NodeBody { &self.body }
|
||||||
pub fn token_ids(&self) -> &[u32] { &self.token_ids }
|
pub fn token_ids(&self) -> &[u32] { &self.token_ids }
|
||||||
pub fn tokens(&self) -> usize { self.token_ids.len() }
|
pub fn tokens(&self) -> usize { self.token_ids.len() }
|
||||||
pub fn timestamp(&self) -> Option<DateTime<Utc>> { self.timestamp }
|
pub fn timestamp(&self) -> DateTime<Utc> { self.timestamp }
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AstNode {
|
impl AstNode {
|
||||||
|
|
@ -307,13 +308,14 @@ impl AstNode {
|
||||||
// -- Branch constructors --------------------------------------------------
|
// -- Branch constructors --------------------------------------------------
|
||||||
|
|
||||||
pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
|
pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
|
||||||
Self::Branch { role, children, memory_scores: Default::default() }
|
Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn system_msg(text: impl Into<String>) -> Self {
|
pub fn system_msg(text: impl Into<String>) -> Self {
|
||||||
Self::Branch {
|
Self::Branch {
|
||||||
role: Role::System,
|
role: Role::System,
|
||||||
children: vec![Self::content(text)],
|
children: vec![Self::content(text)],
|
||||||
|
timestamp: Utc::now(),
|
||||||
memory_scores: Default::default(),
|
memory_scores: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -322,6 +324,7 @@ impl AstNode {
|
||||||
Self::Branch {
|
Self::Branch {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
children: vec![Self::content(text)],
|
children: vec![Self::content(text)],
|
||||||
|
timestamp: Utc::now(),
|
||||||
memory_scores: Default::default(),
|
memory_scores: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -338,9 +341,10 @@ impl AstNode {
|
||||||
};
|
};
|
||||||
Self::Leaf(NodeLeaf { token_ids, ..leaf })
|
Self::Leaf(NodeLeaf { token_ids, ..leaf })
|
||||||
}
|
}
|
||||||
Self::Branch { role, children, memory_scores, .. } => Self::Branch {
|
Self::Branch { role, children, timestamp, memory_scores } => Self::Branch {
|
||||||
role,
|
role,
|
||||||
children: children.into_iter().map(|c| c.retokenize()).collect(),
|
children: children.into_iter().map(|c| c.retokenize()).collect(),
|
||||||
|
timestamp,
|
||||||
memory_scores,
|
memory_scores,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -348,8 +352,8 @@ impl AstNode {
|
||||||
|
|
||||||
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
|
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
|
||||||
match &mut self {
|
match &mut self {
|
||||||
Self::Leaf(leaf) => leaf.timestamp = Some(ts),
|
Self::Leaf(leaf) => leaf.timestamp = ts,
|
||||||
Self::Branch { .. } => {}
|
Self::Branch { timestamp, .. } => *timestamp = ts,
|
||||||
}
|
}
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
@ -370,7 +374,7 @@ impl AstNode {
|
||||||
|
|
||||||
/// Short label for the UI.
|
/// Short label for the UI.
|
||||||
pub fn label(&self) -> String {
|
pub fn label(&self) -> String {
|
||||||
let cfg = crate::config::get();
|
let app = crate::config::app();
|
||||||
match self {
|
match self {
|
||||||
Self::Branch { role, children, .. } => {
|
Self::Branch { role, children, .. } => {
|
||||||
let preview = children.first()
|
let preview = children.first()
|
||||||
|
|
@ -379,8 +383,8 @@ impl AstNode {
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
match role {
|
match role {
|
||||||
Role::System => "system".into(),
|
Role::System => "system".into(),
|
||||||
Role::User => format!("{}: {}", cfg.user_name, preview),
|
Role::User => format!("{}: {}", app.user_name, preview),
|
||||||
Role::Assistant => format!("{}: {}", cfg.assistant_name, preview),
|
Role::Assistant => format!("{}: {}", app.assistant_name, preview),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Self::Leaf(leaf) => match &leaf.body {
|
Self::Leaf(leaf) => match &leaf.body {
|
||||||
|
|
@ -988,7 +992,10 @@ impl ContextState {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn context_window() -> usize {
|
pub fn context_window() -> usize {
|
||||||
crate::config::get().api_context_window
|
let app = crate::config::app();
|
||||||
|
app.backends.get(&app.default_backend)
|
||||||
|
.and_then(|b| b.context_window)
|
||||||
|
.unwrap_or(128_000)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn context_budget_tokens() -> usize {
|
pub fn context_budget_tokens() -> usize {
|
||||||
|
|
@ -1340,4 +1347,35 @@ mod tests {
|
||||||
assert_token_invariants(node);
|
assert_token_invariants(node);
|
||||||
assert!(node.tokens() > 0);
|
assert!(node.tokens() > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -- Timestamp deserialization tests ------------------------------------------
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_timestamp_null_rejected() {
|
||||||
|
// Missing/null timestamps used to be accepted via a lenient
|
||||||
|
// deserialize fallback. Post-migration the schema is strict.
|
||||||
|
let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#;
|
||||||
|
assert!(serde_json::from_str::<AstNode>(json).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_timestamp_missing_rejected() {
|
||||||
|
let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#;
|
||||||
|
assert!(serde_json::from_str::<AstNode>(json).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_branch_timestamp_missing_rejected() {
|
||||||
|
let json = r#"{"Branch":{"role":"User","children":[]}}"#;
|
||||||
|
assert!(serde_json::from_str::<AstNode>(json).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_timestamp_present_accepted() {
|
||||||
|
let json = r#"{"Leaf":{"body":{"Content":"hi"},"timestamp":"2026-04-16T12:00:00Z"}}"#;
|
||||||
|
let node: AstNode = serde_json::from_str(json).unwrap();
|
||||||
|
let leaf = node.leaf().unwrap();
|
||||||
|
assert_eq!(leaf.timestamp().to_rfc3339(),
|
||||||
|
"2026-04-16T12:00:00+00:00");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,6 @@ impl DispatchState {
|
||||||
pub struct Agent {
|
pub struct Agent {
|
||||||
pub client: ApiClient,
|
pub client: ApiClient,
|
||||||
pub app_config: crate::config::AppConfig,
|
pub app_config: crate::config::AppConfig,
|
||||||
pub prompt_file: String,
|
|
||||||
pub session_id: String,
|
pub session_id: String,
|
||||||
pub context: crate::Mutex<ContextState>,
|
pub context: crate::Mutex<ContextState>,
|
||||||
pub state: crate::Mutex<AgentState>,
|
pub state: crate::Mutex<AgentState>,
|
||||||
|
|
@ -189,7 +188,6 @@ impl Agent {
|
||||||
client: ApiClient,
|
client: ApiClient,
|
||||||
personality: Vec<(String, String)>,
|
personality: Vec<(String, String)>,
|
||||||
app_config: crate::config::AppConfig,
|
app_config: crate::config::AppConfig,
|
||||||
prompt_file: String,
|
|
||||||
conversation_log: Option<ConversationLog>,
|
conversation_log: Option<ConversationLog>,
|
||||||
active_tools: tools::ActiveTools,
|
active_tools: tools::ActiveTools,
|
||||||
agent_tools: Vec<tools::Tool>,
|
agent_tools: Vec<tools::Tool>,
|
||||||
|
|
@ -220,7 +218,6 @@ impl Agent {
|
||||||
let agent = Arc::new(Self {
|
let agent = Arc::new(Self {
|
||||||
client,
|
client,
|
||||||
app_config,
|
app_config,
|
||||||
prompt_file,
|
|
||||||
session_id,
|
session_id,
|
||||||
context: crate::Mutex::new(context),
|
context: crate::Mutex::new(context),
|
||||||
state: crate::Mutex::new(AgentState {
|
state: crate::Mutex::new(AgentState {
|
||||||
|
|
@ -259,7 +256,6 @@ impl Agent {
|
||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
client: self.client.clone(),
|
client: self.client.clone(),
|
||||||
app_config: self.app_config.clone(),
|
app_config: self.app_config.clone(),
|
||||||
prompt_file: self.prompt_file.clone(),
|
|
||||||
session_id: self.session_id.clone(),
|
session_id: self.session_id.clone(),
|
||||||
context: crate::Mutex::new(ctx),
|
context: crate::Mutex::new(ctx),
|
||||||
state: crate::Mutex::new(AgentState {
|
state: crate::Mutex::new(AgentState {
|
||||||
|
|
|
||||||
|
|
@ -183,8 +183,8 @@ fn resolve_prompt(
|
||||||
state: &std::collections::BTreeMap<String, String>,
|
state: &std::collections::BTreeMap<String, String>,
|
||||||
recently_written: &[String],
|
recently_written: &[String],
|
||||||
) -> String {
|
) -> String {
|
||||||
let cfg = crate::config::get();
|
let template = template.replace("{assistant_name}",
|
||||||
let template = template.replace("{assistant_name}", &cfg.assistant_name);
|
&crate::config::app().assistant_name);
|
||||||
let mut result = String::with_capacity(template.len());
|
let mut result = String::with_capacity(template.len());
|
||||||
let mut rest = template.as_str();
|
let mut rest = template.as_str();
|
||||||
while let Some(start) = rest.find("{{") {
|
while let Some(start) = rest.find("{{") {
|
||||||
|
|
@ -247,25 +247,20 @@ impl AutoAgent {
|
||||||
&mut self,
|
&mut self,
|
||||||
bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>,
|
bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
let config = crate::config::get();
|
// Load system prompt + identity from config.
|
||||||
let base_url = config.api_base_url.as_deref().unwrap_or("");
|
|
||||||
let api_key = config.api_key.as_deref().unwrap_or("");
|
|
||||||
let model = config.api_model.as_deref().unwrap_or("");
|
|
||||||
if base_url.is_empty() || model.is_empty() {
|
|
||||||
return Err("API not configured (no base_url or model)".to_string());
|
|
||||||
}
|
|
||||||
let client = super::api::ApiClient::new(base_url, api_key, model);
|
|
||||||
|
|
||||||
// Load system prompt + identity from config
|
|
||||||
let cli = crate::user::CliArgs::default();
|
let cli = crate::user::CliArgs::default();
|
||||||
let (app, _) = crate::config::load_app(&cli)
|
let (app, _) = crate::config::load_app(&cli)
|
||||||
.map_err(|e| format!("config: {}", e))?;
|
.map_err(|e| format!("config: {}", e))?;
|
||||||
|
let resolved = app.resolve_model(&app.default_backend)
|
||||||
|
.map_err(|e| format!("API not configured: {}", e))?;
|
||||||
|
let client = super::api::ApiClient::new(
|
||||||
|
&resolved.api_base, &resolved.api_key, &resolved.model_id);
|
||||||
let personality = crate::config::reload_context()
|
let personality = crate::config::reload_context()
|
||||||
.await.map_err(|e| format!("config: {}", e))?;
|
.await.map_err(|e| format!("config: {}", e))?;
|
||||||
|
|
||||||
let agent = Agent::new(
|
let agent = Agent::new(
|
||||||
client, personality,
|
client, personality,
|
||||||
app, String::new(),
|
app,
|
||||||
None,
|
None,
|
||||||
super::tools::ActiveTools::new(),
|
super::tools::ActiveTools::new(),
|
||||||
super::tools::tools(),
|
super::tools::tools(),
|
||||||
|
|
@ -497,15 +492,20 @@ pub async fn run_one_agent(
|
||||||
.map(|s| s.phase.clone()).collect();
|
.map(|s| s.phase.clone()).collect();
|
||||||
|
|
||||||
// Bail check: if the agent defines a bail script, run it between steps.
|
// Bail check: if the agent defines a bail script, run it between steps.
|
||||||
|
// The script also refreshes our pid-file with the current phase — that's
|
||||||
|
// how concurrent agents know which phase each of us is in.
|
||||||
let bail_script = def.bail.as_ref().map(|name| defs::agents_dir().join(name));
|
let bail_script = def.bail.as_ref().map(|name| defs::agents_dir().join(name));
|
||||||
let state_dir_for_bail = state_dir.clone();
|
let state_dir_for_bail = state_dir.clone();
|
||||||
// Find our own pid file so we can pass it to the bail script
|
|
||||||
let our_pid = std::process::id();
|
let our_pid = std::process::id();
|
||||||
let our_pid_file = format!("pid-{}", our_pid);
|
let our_pid_file = format!("pid-{}", our_pid);
|
||||||
|
let step_phases_for_bail = step_phases.clone();
|
||||||
let bail_fn = move |step_idx: usize| -> Result<(), String> {
|
let bail_fn = move |step_idx: usize| -> Result<(), String> {
|
||||||
if let Some(ref script) = bail_script {
|
if let Some(ref script) = bail_script {
|
||||||
|
let phase = step_phases_for_bail.get(step_idx)
|
||||||
|
.map(String::as_str).unwrap_or("");
|
||||||
let status = std::process::Command::new(script)
|
let status = std::process::Command::new(script)
|
||||||
.arg(&our_pid_file)
|
.arg(&our_pid_file)
|
||||||
|
.arg(phase)
|
||||||
.current_dir(&state_dir_for_bail)
|
.current_dir(&state_dir_for_bail)
|
||||||
.status()
|
.status()
|
||||||
.map_err(|e| format!("bail script {:?} failed: {}", script, e))?;
|
.map_err(|e| format!("bail script {:?} failed: {}", script, e))?;
|
||||||
|
|
|
||||||
180
src/bin/fix-timestamps.rs
Normal file
180
src/bin/fix-timestamps.rs
Normal file
|
|
@ -0,0 +1,180 @@
|
||||||
|
// fix-timestamps: One-off migration for ~/.consciousness/agent-sessions/
|
||||||
|
// conversation.jsonl.
|
||||||
|
//
|
||||||
|
// Before Branch nodes carried their own timestamps, early entries were
|
||||||
|
// serialized with missing/null timestamp fields — they deserialize as
|
||||||
|
// UNIX_EPOCH via the (now-to-be-removed) deserialize_timestamp_or_epoch
|
||||||
|
// fallback. Training needs every entry to have a unique timestamp to
|
||||||
|
// dedup already-trained responses.
|
||||||
|
//
|
||||||
|
// Walks the file, synthesizes timestamps for any entry stuck at
|
||||||
|
// UNIX_EPOCH by linear interpolation between surrounding real
|
||||||
|
// timestamps. For child leaves inside a Branch, derives timestamps
|
||||||
|
// from the parent with a tiny per-child offset.
|
||||||
|
//
|
||||||
|
// SAFETY: reads from argv[1], writes to argv[1].tmp, renames into
|
||||||
|
// place. Keep a .bak copy before running.
|
||||||
|
//
|
||||||
|
// Usage: fix-timestamps <path-to-conversation.jsonl>
|
||||||
|
|
||||||
|
use std::io::{BufRead, BufReader, BufWriter, Write};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use chrono::{DateTime, Duration, Utc};
|
||||||
|
|
||||||
|
use consciousness::agent::context::AstNode;
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
let path: PathBuf = std::env::args().nth(1)
|
||||||
|
.context("usage: fix-timestamps <path>")?.into();
|
||||||
|
|
||||||
|
let f = std::fs::File::open(&path)
|
||||||
|
.with_context(|| format!("open {}", path.display()))?;
|
||||||
|
let reader = BufReader::new(f);
|
||||||
|
|
||||||
|
let mut nodes: Vec<AstNode> = Vec::new();
|
||||||
|
for (i, line) in reader.lines().enumerate() {
|
||||||
|
let line = line?;
|
||||||
|
if line.trim().is_empty() { continue; }
|
||||||
|
let node: AstNode = serde_json::from_str(&line)
|
||||||
|
.with_context(|| format!("line {}: parse", i + 1))?;
|
||||||
|
nodes.push(node);
|
||||||
|
}
|
||||||
|
println!("read {} entries", nodes.len());
|
||||||
|
|
||||||
|
fix_top_level_timestamps(&mut nodes);
|
||||||
|
for node in &mut nodes {
|
||||||
|
propagate_to_children(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure uniqueness — real timestamps can collide when two entries
|
||||||
|
// were written in the same ns; synthesized ones can also overlap.
|
||||||
|
// Bump colliding ns by 1 until unique.
|
||||||
|
let mut seen = std::collections::HashSet::new();
|
||||||
|
let mut bumps = 0usize;
|
||||||
|
for (i, node) in nodes.iter_mut().enumerate() {
|
||||||
|
let ts = top_ts(node);
|
||||||
|
assert!(ts > DateTime::<Utc>::UNIX_EPOCH,
|
||||||
|
"entry {}: still UNIX_EPOCH", i);
|
||||||
|
let mut ns = ts.timestamp_nanos_opt().expect("ts in i64 ns range");
|
||||||
|
let mut bumped = false;
|
||||||
|
while !seen.insert(ns) {
|
||||||
|
ns += 1;
|
||||||
|
bumped = true;
|
||||||
|
bumps += 1;
|
||||||
|
}
|
||||||
|
if bumped {
|
||||||
|
set_top_ts(node, DateTime::<Utc>::from_timestamp_nanos(ns));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("all {} timestamps real and unique ({} ns bumps)",
|
||||||
|
nodes.len(), bumps);
|
||||||
|
|
||||||
|
let tmp = path.with_extension("jsonl.tmp");
|
||||||
|
{
|
||||||
|
let f = std::fs::File::create(&tmp)
|
||||||
|
.with_context(|| format!("create {}", tmp.display()))?;
|
||||||
|
let mut w = BufWriter::new(f);
|
||||||
|
for node in &nodes {
|
||||||
|
serde_json::to_writer(&mut w, node)?;
|
||||||
|
w.write_all(b"\n")?;
|
||||||
|
}
|
||||||
|
w.flush()?;
|
||||||
|
}
|
||||||
|
std::fs::rename(&tmp, &path)
|
||||||
|
.with_context(|| format!("rename {} -> {}", tmp.display(), path.display()))?;
|
||||||
|
println!("wrote {}", path.display());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn top_ts(node: &AstNode) -> DateTime<Utc> {
|
||||||
|
match node {
|
||||||
|
AstNode::Leaf(leaf) => leaf.timestamp(),
|
||||||
|
AstNode::Branch { timestamp, .. } => *timestamp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_top_ts(node: &mut AstNode, ts: DateTime<Utc>) {
|
||||||
|
match node {
|
||||||
|
AstNode::Leaf(leaf) => *leaf = leaf.clone().with_timestamp(ts),
|
||||||
|
AstNode::Branch { timestamp, .. } => *timestamp = ts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fill in missing top-level timestamps. Strategy:
|
||||||
|
/// - If two real timestamps bracket a run of missing ones, linearly
|
||||||
|
/// interpolate between them.
|
||||||
|
/// - If missing ones precede the first real one, back-fill using
|
||||||
|
/// (first_real - N·1µs).
|
||||||
|
/// - If missing ones follow the last real one, forward-fill.
|
||||||
|
/// - If no real timestamps exist at all, synthesize from now() going
|
||||||
|
/// backwards.
|
||||||
|
fn fix_top_level_timestamps(nodes: &mut [AstNode]) {
|
||||||
|
let real: Vec<(usize, DateTime<Utc>)> = nodes.iter().enumerate()
|
||||||
|
.filter(|(_, n)| top_ts(n) > DateTime::<Utc>::UNIX_EPOCH)
|
||||||
|
.map(|(i, n)| (i, top_ts(n)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if real.is_empty() {
|
||||||
|
let now = Utc::now();
|
||||||
|
let len = nodes.len();
|
||||||
|
for (i, node) in nodes.iter_mut().enumerate() {
|
||||||
|
let ts = now - Duration::microseconds((len - i) as i64);
|
||||||
|
set_top_ts(node, ts);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper: bisect real[] for the nearest real entries around idx.
|
||||||
|
let find_bracket = |idx: usize| -> (Option<(usize, DateTime<Utc>)>,
|
||||||
|
Option<(usize, DateTime<Utc>)>) {
|
||||||
|
let pos = real.binary_search_by_key(&idx, |(i, _)| *i);
|
||||||
|
let (prior_pos, next_pos) = match pos {
|
||||||
|
Ok(p) => (Some(p), Some(p)),
|
||||||
|
Err(p) => (
|
||||||
|
if p == 0 { None } else { Some(p - 1) },
|
||||||
|
if p >= real.len() { None } else { Some(p) },
|
||||||
|
),
|
||||||
|
};
|
||||||
|
(prior_pos.map(|p| real[p]), next_pos.map(|p| real[p]))
|
||||||
|
};
|
||||||
|
|
||||||
|
for i in 0..nodes.len() {
|
||||||
|
if top_ts(&nodes[i]) > DateTime::<Utc>::UNIX_EPOCH {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let (prior, next) = find_bracket(i);
|
||||||
|
let new_ts = match (prior, next) {
|
||||||
|
(Some((pi, pt)), Some((ni, nt))) if pi != ni => {
|
||||||
|
// Linear interpolate.
|
||||||
|
let span_ns = (nt - pt).num_nanoseconds().unwrap_or(0);
|
||||||
|
let offset_ns = span_ns * (i - pi) as i64 / (ni - pi) as i64;
|
||||||
|
pt + Duration::nanoseconds(offset_ns)
|
||||||
|
}
|
||||||
|
(Some((pi, pt)), _) => {
|
||||||
|
pt + Duration::microseconds((i - pi) as i64)
|
||||||
|
}
|
||||||
|
(None, Some((ni, nt))) => {
|
||||||
|
nt - Duration::microseconds((ni - i) as i64)
|
||||||
|
}
|
||||||
|
(None, None) => unreachable!(),
|
||||||
|
};
|
||||||
|
set_top_ts(&mut nodes[i], new_ts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// For every Branch, ensure each child Leaf has a timestamp. If missing,
|
||||||
|
/// use parent.ts + child_idx·1ns so siblings stay unique but close.
|
||||||
|
fn propagate_to_children(node: &mut AstNode) {
|
||||||
|
if let AstNode::Branch { timestamp, children, .. } = node {
|
||||||
|
let parent_ts = *timestamp;
|
||||||
|
for (ci, child) in children.iter_mut().enumerate() {
|
||||||
|
if top_ts(child) <= DateTime::<Utc>::UNIX_EPOCH {
|
||||||
|
set_top_ts(child, parent_ts + Duration::nanoseconds(ci as i64));
|
||||||
|
}
|
||||||
|
propagate_to_children(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -197,7 +197,7 @@ pub async fn cmd_load_context(stats: bool) -> Result<()> {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("=== MEMORY SYSTEM ({}) ===", cfg.assistant_name);
|
println!("=== MEMORY SYSTEM ({}) ===", crate::config::app().assistant_name);
|
||||||
|
|
||||||
if !personality.is_empty() {
|
if !personality.is_empty() {
|
||||||
println!("--- personality_nodes ({}) ---", personality.len());
|
println!("--- personality_nodes ({}) ---", personality.len());
|
||||||
|
|
|
||||||
420
src/config.rs
420
src/config.rs
|
|
@ -3,9 +3,6 @@
|
||||||
// Single config file: ~/.consciousness/config.json5
|
// Single config file: ~/.consciousness/config.json5
|
||||||
// Memory settings in the "memory" section (Config)
|
// Memory settings in the "memory" section (Config)
|
||||||
// Agent/backend settings at top level (AppConfig)
|
// Agent/backend settings at top level (AppConfig)
|
||||||
//
|
|
||||||
// Legacy fallback: ~/.consciousness/config.jsonl
|
|
||||||
// Env override: POC_MEMORY_CONFIG
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
@ -29,9 +26,7 @@ pub fn config_path() -> PathBuf {
|
||||||
|
|
||||||
static CONFIG: OnceLock<RwLock<Arc<Config>>> = OnceLock::new();
|
static CONFIG: OnceLock<RwLock<Arc<Config>>> = OnceLock::new();
|
||||||
|
|
||||||
fn default_context_window() -> usize { 128_000 }
|
|
||||||
fn default_stream_timeout() -> u64 { 60 }
|
fn default_stream_timeout() -> u64 { 60 }
|
||||||
fn default_scoring_chunk_tokens() -> usize { 50_000 }
|
|
||||||
fn default_scoring_interval_secs() -> u64 { 3600 } // 1 hour
|
fn default_scoring_interval_secs() -> u64 { 3600 } // 1 hour
|
||||||
fn default_scoring_response_window() -> usize { 100 }
|
fn default_scoring_response_window() -> usize { 100 }
|
||||||
fn default_node_weight() -> f64 { 0.7 }
|
fn default_node_weight() -> f64 { 0.7 }
|
||||||
|
|
@ -45,8 +40,6 @@ fn default_identity_dir() -> PathBuf {
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub user_name: String,
|
|
||||||
pub assistant_name: String,
|
|
||||||
#[serde(deserialize_with = "deserialize_path")]
|
#[serde(deserialize_with = "deserialize_path")]
|
||||||
pub data_dir: PathBuf,
|
pub data_dir: PathBuf,
|
||||||
#[serde(default = "default_identity_dir", deserialize_with = "deserialize_path")]
|
#[serde(default = "default_identity_dir", deserialize_with = "deserialize_path")]
|
||||||
|
|
@ -62,51 +55,24 @@ pub struct Config {
|
||||||
/// Nodes loaded into subconscious agent context
|
/// Nodes loaded into subconscious agent context
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub agent_nodes: Vec<String>,
|
pub agent_nodes: Vec<String>,
|
||||||
pub journal_days: u32,
|
|
||||||
pub journal_max: usize,
|
|
||||||
pub llm_concurrency: usize,
|
pub llm_concurrency: usize,
|
||||||
pub agent_budget: usize,
|
|
||||||
#[serde(deserialize_with = "deserialize_path")]
|
|
||||||
pub prompts_dir: PathBuf,
|
|
||||||
/// Resolved from agent_model → models → backend (not in config directly)
|
|
||||||
#[serde(skip)]
|
|
||||||
pub api_base_url: Option<String>,
|
|
||||||
#[serde(skip)]
|
|
||||||
pub api_key: Option<String>,
|
|
||||||
#[serde(skip)]
|
|
||||||
pub api_model: Option<String>,
|
|
||||||
#[serde(skip, default = "default_context_window")]
|
|
||||||
pub api_context_window: usize,
|
|
||||||
/// Used to resolve API settings, not stored on Config
|
|
||||||
#[serde(default)]
|
|
||||||
agent_model: Option<String>,
|
|
||||||
/// Stream chunk timeout in seconds (no data = timeout).
|
/// Stream chunk timeout in seconds (no data = timeout).
|
||||||
#[serde(default = "default_stream_timeout")]
|
#[serde(default = "default_stream_timeout")]
|
||||||
pub api_stream_timeout_secs: u64,
|
pub api_stream_timeout_secs: u64,
|
||||||
/// Max tokens per chunk for memory scoring logprobs calls.
|
|
||||||
#[serde(default = "default_scoring_chunk_tokens")]
|
|
||||||
pub scoring_chunk_tokens: usize,
|
|
||||||
/// How often to re-score memory nodes (seconds). Default: 3600 (1 hour).
|
/// How often to re-score memory nodes (seconds). Default: 3600 (1 hour).
|
||||||
#[serde(default = "default_scoring_interval_secs")]
|
#[serde(default = "default_scoring_interval_secs")]
|
||||||
pub scoring_interval_secs: u64,
|
pub scoring_interval_secs: u64,
|
||||||
/// Number of assistant responses to score per memory. Default: 50.
|
/// Number of assistant responses to score per memory. Default: 50.
|
||||||
#[serde(default = "default_scoring_response_window")]
|
#[serde(default = "default_scoring_response_window")]
|
||||||
pub scoring_response_window: usize,
|
pub scoring_response_window: usize,
|
||||||
pub api_reasoning: String,
|
|
||||||
pub agent_types: Vec<String>,
|
pub agent_types: Vec<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub mcp_servers: Vec<McpServerConfig>,
|
pub mcp_servers: Vec<McpServerConfig>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub lsp_servers: Vec<LspServerConfig>,
|
pub lsp_servers: Vec<LspServerConfig>,
|
||||||
/// Surface agent timeout in seconds.
|
|
||||||
#[serde(default)]
|
|
||||||
pub surface_timeout_secs: Option<u32>,
|
|
||||||
/// Max conversation bytes to include in surface agent context.
|
/// Max conversation bytes to include in surface agent context.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub surface_conversation_bytes: Option<usize>,
|
pub surface_conversation_bytes: Option<usize>,
|
||||||
/// Hook events that trigger the surface agent.
|
|
||||||
#[serde(default)]
|
|
||||||
pub surface_hooks: Vec<String>,
|
|
||||||
|
|
||||||
// Spreading activation parameters
|
// Spreading activation parameters
|
||||||
#[serde(default = "default_node_weight")]
|
#[serde(default = "default_node_weight")]
|
||||||
|
|
@ -123,36 +89,21 @@ impl Default for Config {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
let home = dirs::home_dir().unwrap_or_default();
|
let home = dirs::home_dir().unwrap_or_default();
|
||||||
Self {
|
Self {
|
||||||
user_name: "User".to_string(),
|
|
||||||
assistant_name: "Assistant".to_string(),
|
|
||||||
data_dir: home.join(".consciousness/memory"),
|
data_dir: home.join(".consciousness/memory"),
|
||||||
identity_dir: home.join(".consciousness/identity"),
|
identity_dir: home.join(".consciousness/identity"),
|
||||||
projects_dir: home.join(".claude/projects"),
|
projects_dir: home.join(".claude/projects"),
|
||||||
protected_nodes: Vec::new(),
|
protected_nodes: Vec::new(),
|
||||||
personality_nodes: vec!["identity".into(), "core-practices".into()],
|
personality_nodes: vec!["identity".into(), "core-practices".into()],
|
||||||
agent_nodes: vec!["identity".into(), "core-practices".into()],
|
agent_nodes: vec!["identity".into(), "core-practices".into()],
|
||||||
journal_days: 7,
|
|
||||||
journal_max: 20,
|
|
||||||
llm_concurrency: 1,
|
llm_concurrency: 1,
|
||||||
agent_budget: 1000,
|
|
||||||
prompts_dir: home.join(".consciousness/prompts"),
|
|
||||||
api_base_url: None,
|
|
||||||
api_key: None,
|
|
||||||
api_model: None,
|
|
||||||
api_context_window: default_context_window(),
|
|
||||||
api_stream_timeout_secs: default_stream_timeout(),
|
api_stream_timeout_secs: default_stream_timeout(),
|
||||||
scoring_chunk_tokens: default_scoring_chunk_tokens(),
|
|
||||||
scoring_interval_secs: default_scoring_interval_secs(),
|
scoring_interval_secs: default_scoring_interval_secs(),
|
||||||
scoring_response_window: default_scoring_response_window(),
|
scoring_response_window: default_scoring_response_window(),
|
||||||
agent_model: None,
|
|
||||||
api_reasoning: "high".to_string(),
|
|
||||||
agent_types: vec![
|
agent_types: vec![
|
||||||
"linker".into(), "organize".into(), "distill".into(),
|
"linker".into(), "organize".into(), "distill".into(),
|
||||||
"separator".into(), "split".into(),
|
"separator".into(), "split".into(),
|
||||||
],
|
],
|
||||||
surface_timeout_secs: None,
|
|
||||||
surface_conversation_bytes: None,
|
surface_conversation_bytes: None,
|
||||||
surface_hooks: vec![],
|
|
||||||
mcp_servers: vec![],
|
mcp_servers: vec![],
|
||||||
lsp_servers: vec![],
|
lsp_servers: vec![],
|
||||||
default_node_weight: default_node_weight(),
|
default_node_weight: default_node_weight(),
|
||||||
|
|
@ -165,41 +116,20 @@ impl Default for Config {
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
fn load_from_file() -> Self {
|
fn load_from_file() -> Self {
|
||||||
if let Some(config) = Self::try_load_shared() {
|
Self::try_load_shared().unwrap_or_default()
|
||||||
return config;
|
|
||||||
}
|
|
||||||
Self::load_legacy_jsonl()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load from shared config. Memory settings in the "memory" section;
|
/// Load from shared config. Memory settings in the "memory" section;
|
||||||
/// API settings resolved from models + backend configuration.
|
/// API settings resolved from models + backend configuration.
|
||||||
fn try_load_shared() -> Option<Self> {
|
fn try_load_shared() -> Option<Self> {
|
||||||
let content = std::fs::read_to_string(config_path()).ok()?;
|
let content = std::fs::read_to_string(config_path()).ok()?;
|
||||||
let root: serde_json::Value = json5::from_str(&content).ok()?;
|
let root: serde_json::Value = json_five::from_str(&content).ok()?;
|
||||||
let mem_value = root.get("memory")?;
|
let mem_value = root.get("memory")?;
|
||||||
|
|
||||||
let mut config: Config = serde_json::from_value(mem_value.clone()).ok()?;
|
let mut config: Config = serde_json::from_value(mem_value.clone()).ok()?;
|
||||||
config.llm_concurrency = config.llm_concurrency.max(1);
|
config.llm_concurrency = config.llm_concurrency.max(1);
|
||||||
|
|
||||||
// Resolve API settings: agent_model → models → backend
|
// Top-level sections (not inside "memory").
|
||||||
if let Some(model_name) = &config.agent_model
|
|
||||||
&& let Some(model_cfg) = root.get("models").and_then(|m| m.get(model_name.as_str())) {
|
|
||||||
let backend_name = model_cfg.get("backend").and_then(|v| v.as_str()).unwrap_or("");
|
|
||||||
let model_id = model_cfg.get("model_id").and_then(|v| v.as_str()).unwrap_or("");
|
|
||||||
|
|
||||||
if let Some(backend) = root.get(backend_name) {
|
|
||||||
config.api_base_url = backend.get("base_url")
|
|
||||||
.and_then(|v| v.as_str()).map(String::from);
|
|
||||||
config.api_key = backend.get("api_key")
|
|
||||||
.and_then(|v| v.as_str()).map(String::from);
|
|
||||||
}
|
|
||||||
config.api_model = Some(model_id.to_string());
|
|
||||||
if let Some(cw) = model_cfg.get("context_window").and_then(|v| v.as_u64()) {
|
|
||||||
config.api_context_window = cw as usize;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Top-level config sections (not inside "memory")
|
|
||||||
if let Some(servers) = root.get("lsp_servers") {
|
if let Some(servers) = root.get("lsp_servers") {
|
||||||
config.lsp_servers = serde_json::from_value(servers.clone()).unwrap_or_default();
|
config.lsp_servers = serde_json::from_value(servers.clone()).unwrap_or_default();
|
||||||
}
|
}
|
||||||
|
|
@ -209,11 +139,6 @@ impl Config {
|
||||||
|
|
||||||
Some(config)
|
Some(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load from legacy JSONL config — deprecated, just return defaults.
|
|
||||||
fn load_legacy_jsonl() -> Self {
|
|
||||||
Config::default()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the global memory config (cheap Arc clone).
|
/// Get the global memory config (cheap Arc clone).
|
||||||
|
|
@ -237,27 +162,85 @@ pub fn reload() -> bool {
|
||||||
changed
|
changed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Spawn a background thread that watches `~/.consciousness/config.json5`
|
||||||
|
/// and reloads both the memory Config and the global AppConfig whenever
|
||||||
|
/// the file changes on disk. Lets edits from vim / F6 hotkeys / manual
|
||||||
|
/// tweaks land live without restarting the process.
|
||||||
|
pub fn watch_config(cli: crate::user::CliArgs) {
|
||||||
|
use notify_debouncer_mini::{new_debouncer, notify::RecursiveMode};
|
||||||
|
|
||||||
|
let path = config_path();
|
||||||
|
// Watch the parent directory — editors often replace-via-rename, so
|
||||||
|
// watching the file itself misses the new inode.
|
||||||
|
let Some(parent) = path.parent().map(|p| p.to_path_buf()) else {
|
||||||
|
crate::dbglog!("[config] no parent for {}, skipping watch", path.display());
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::thread::Builder::new()
|
||||||
|
.name("config-watcher".into())
|
||||||
|
.spawn(move || {
|
||||||
|
let (tx, rx) = std::sync::mpsc::channel();
|
||||||
|
let mut debouncer = match new_debouncer(std::time::Duration::from_millis(200), tx) {
|
||||||
|
Ok(d) => d,
|
||||||
|
Err(e) => {
|
||||||
|
crate::dbglog!("[config] watcher setup failed: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Err(e) = debouncer.watcher()
|
||||||
|
.watch(&parent, RecursiveMode::NonRecursive)
|
||||||
|
{
|
||||||
|
crate::dbglog!("[config] watch({}) failed: {}", parent.display(), e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
crate::dbglog!("[config] watching {}", path.display());
|
||||||
|
|
||||||
|
while let Ok(res) = rx.recv() {
|
||||||
|
let Ok(events) = res else { continue; };
|
||||||
|
if !events.iter().any(|e| e.path == path) { continue; }
|
||||||
|
|
||||||
|
// Reload both halves.
|
||||||
|
let mem_changed = reload();
|
||||||
|
let app_changed = match build_figment(&cli).extract::<AppConfig>() {
|
||||||
|
Ok(app) => {
|
||||||
|
install_app(app);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
crate::dbglog!("[config] reload: AppConfig parse failed: {}", e);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
};
|
||||||
|
crate::dbglog!("[config] reloaded (memory_changed={}, app_changed={})",
|
||||||
|
mem_changed, app_changed);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================
|
// ============================================================
|
||||||
// Agent config (top-level settings)
|
// Agent config (top-level settings)
|
||||||
// ============================================================
|
// ============================================================
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct AppConfig {
|
pub struct AppConfig {
|
||||||
pub backend: String,
|
#[serde(default = "default_user_name")]
|
||||||
pub anthropic: BackendConfig,
|
pub user_name: String,
|
||||||
pub openrouter: BackendConfig,
|
#[serde(default = "default_assistant_name")]
|
||||||
|
pub assistant_name: String,
|
||||||
|
/// Named model endpoints — credentials, base URL, and model id bundled
|
||||||
|
/// into one entry per backend. Keyed by name, selected by
|
||||||
|
/// `default_backend` or by `--model <name>` on the CLI.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub deepinfra: BackendConfig,
|
pub backends: HashMap<String, BackendConfig>,
|
||||||
pub prompts: PromptConfig,
|
#[serde(default)]
|
||||||
|
pub default_backend: String,
|
||||||
pub debug: bool,
|
pub debug: bool,
|
||||||
pub compaction: CompactionConfig,
|
pub compaction: CompactionConfig,
|
||||||
pub dmn: DmnConfig,
|
pub dmn: DmnConfig,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub memory_project: Option<PathBuf>,
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub models: HashMap<String, ModelConfig>,
|
pub learn: LearnConfig,
|
||||||
#[serde(default = "default_model_name")]
|
|
||||||
pub default_model: String,
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub mcp_servers: Vec<McpServerConfig>,
|
pub mcp_servers: Vec<McpServerConfig>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
|
@ -284,32 +267,17 @@ pub struct LspServerConfig {
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
pub struct BackendConfig {
|
pub struct BackendConfig {
|
||||||
|
/// API key for the backend.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub api_key: String,
|
pub api_key: String,
|
||||||
#[serde(default)]
|
/// Base URL for the backend's OpenAI-compatible endpoint.
|
||||||
pub model: String,
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub base_url: Option<String>,
|
pub base_url: Option<String>,
|
||||||
}
|
/// Model identifier sent to the API.
|
||||||
|
pub model_id: String,
|
||||||
impl BackendConfig {
|
/// Context window size in tokens.
|
||||||
fn resolve(&self, default_base: &str) -> Result<(String, String, String)> {
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
if self.api_key.is_empty() {
|
pub context_window: Option<usize>,
|
||||||
anyhow::bail!(
|
|
||||||
"No API key. Set it in {} or use --api-key",
|
|
||||||
config_path().display()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let base = self.base_url.clone()
|
|
||||||
.unwrap_or_else(|| default_base.to_string());
|
|
||||||
Ok((base, self.api_key.clone(), self.model.clone()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct PromptConfig {
|
|
||||||
pub anthropic: String,
|
|
||||||
pub other: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -324,65 +292,57 @@ pub struct DmnConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ModelConfig {
|
pub struct LearnConfig {
|
||||||
/// Backend name ("anthropic" or "openrouter")
|
/// Divergence threshold — responses scoring above this become
|
||||||
pub backend: String,
|
/// fine-tuning candidates. Lower = more sensitive.
|
||||||
/// Model identifier sent to the API
|
#[serde(default = "default_learn_threshold")]
|
||||||
pub model_id: String,
|
pub threshold: f64,
|
||||||
/// Instruction file ("CLAUDE.md" or "POC.md").
|
/// Whether to generate "what would the model have said without
|
||||||
|
/// memories" alternates alongside each scoring run. Expensive —
|
||||||
|
/// one full streaming generation per candidate.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub prompt_file: Option<String>,
|
pub generate_alternates: bool,
|
||||||
/// Context window size in tokens.
|
|
||||||
#[serde(default)]
|
|
||||||
pub context_window: Option<usize>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_learn_threshold() -> f64 { 1.0 }
|
||||||
|
|
||||||
|
impl Default for LearnConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
threshold: default_learn_threshold(),
|
||||||
|
generate_alternates: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_user_name() -> String { "User".into() }
|
||||||
|
fn default_assistant_name() -> String { "Assistant".into() }
|
||||||
|
|
||||||
impl Default for AppConfig {
|
impl Default for AppConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
backend: "openrouter".to_string(),
|
user_name: default_user_name(),
|
||||||
anthropic: BackendConfig {
|
assistant_name: default_assistant_name(),
|
||||||
api_key: String::new(),
|
backends: HashMap::new(),
|
||||||
model: "claude-opus-4-6-20250918".to_string(),
|
default_backend: String::new(),
|
||||||
base_url: None,
|
|
||||||
},
|
|
||||||
openrouter: BackendConfig {
|
|
||||||
api_key: String::new(),
|
|
||||||
model: "qwen/qwen3.5-397b-a17b".to_string(),
|
|
||||||
base_url: Some("https://openrouter.ai/api/v1".to_string()),
|
|
||||||
},
|
|
||||||
deepinfra: BackendConfig {
|
|
||||||
api_key: String::new(),
|
|
||||||
model: String::new(),
|
|
||||||
base_url: Some("https://api.deepinfra.com/v1/openai".to_string()),
|
|
||||||
},
|
|
||||||
prompts: PromptConfig {
|
|
||||||
anthropic: "CLAUDE.md".to_string(),
|
|
||||||
other: "POC.md".to_string(),
|
|
||||||
},
|
|
||||||
debug: false,
|
debug: false,
|
||||||
compaction: CompactionConfig {
|
compaction: CompactionConfig {
|
||||||
hard_threshold_pct: 90,
|
hard_threshold_pct: 90,
|
||||||
soft_threshold_pct: 80,
|
soft_threshold_pct: 80,
|
||||||
},
|
},
|
||||||
dmn: DmnConfig { max_turns: 20 },
|
dmn: DmnConfig { max_turns: 20 },
|
||||||
memory_project: None,
|
learn: LearnConfig::default(),
|
||||||
models: HashMap::new(),
|
|
||||||
default_model: String::new(),
|
|
||||||
mcp_servers: Vec::new(),
|
mcp_servers: Vec::new(),
|
||||||
lsp_servers: Vec::new(),
|
lsp_servers: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_model_name() -> String { String::new() }
|
|
||||||
|
|
||||||
/// Resolved, ready-to-use agent session config.
|
/// Resolved, ready-to-use agent session config.
|
||||||
pub struct SessionConfig {
|
pub struct SessionConfig {
|
||||||
pub api_base: String,
|
pub api_base: String,
|
||||||
pub api_key: String,
|
pub api_key: String,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub prompt_file: String,
|
|
||||||
/// Identity/personality nodes as (name, content) pairs.
|
/// Identity/personality nodes as (name, content) pairs.
|
||||||
pub context_parts: Vec<(String, String)>,
|
pub context_parts: Vec<(String, String)>,
|
||||||
pub session_dir: PathBuf,
|
pub session_dir: PathBuf,
|
||||||
|
|
@ -398,37 +358,22 @@ pub struct ResolvedModel {
|
||||||
pub api_base: String,
|
pub api_base: String,
|
||||||
pub api_key: String,
|
pub api_key: String,
|
||||||
pub model_id: String,
|
pub model_id: String,
|
||||||
pub prompt_file: String,
|
|
||||||
pub context_window: Option<usize>,
|
pub context_window: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppConfig {
|
impl AppConfig {
|
||||||
/// Resolve the active backend and assemble prompts into a SessionConfig.
|
/// Resolve the active backend and assemble prompts into a SessionConfig.
|
||||||
pub async fn resolve(&self, cli: &crate::user::CliArgs) -> Result<SessionConfig> {
|
pub async fn resolve(&self, cli: &crate::user::CliArgs) -> Result<SessionConfig> {
|
||||||
let (api_base, api_key, model, prompt_file);
|
if self.backends.is_empty() {
|
||||||
|
anyhow::bail!(
|
||||||
if !self.models.is_empty() {
|
"no backends configured in {}. Add a `backends` section with at least one entry.",
|
||||||
let model_name = cli.model.as_deref().unwrap_or(&self.default_model);
|
config_path().display()
|
||||||
let resolved = self.resolve_model(model_name)?;
|
);
|
||||||
api_base = resolved.api_base;
|
|
||||||
api_key = resolved.api_key;
|
|
||||||
model = resolved.model_id;
|
|
||||||
prompt_file = resolved.prompt_file;
|
|
||||||
} else {
|
|
||||||
let (base, key, mdl) = match self.backend.as_str() {
|
|
||||||
"anthropic" => self.anthropic.resolve("https://api.anthropic.com"),
|
|
||||||
_ => self.openrouter.resolve("https://openrouter.ai/api/v1"),
|
|
||||||
}?;
|
|
||||||
api_base = base;
|
|
||||||
api_key = key;
|
|
||||||
model = mdl;
|
|
||||||
prompt_file = if self.backend == "anthropic" {
|
|
||||||
self.prompts.anthropic.clone()
|
|
||||||
} else {
|
|
||||||
self.prompts.other.clone()
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let name = cli.model.as_deref().unwrap_or(&self.default_backend);
|
||||||
|
let resolved = self.resolve_model(name)?;
|
||||||
|
|
||||||
let personality_nodes = get().personality_nodes.clone();
|
let personality_nodes = get().personality_nodes.clone();
|
||||||
let context_parts = crate::mind::identity::personality_nodes(&personality_nodes).await;
|
let context_parts = crate::mind::identity::personality_nodes(&personality_nodes).await;
|
||||||
|
|
||||||
|
|
@ -438,11 +383,13 @@ impl AppConfig {
|
||||||
std::fs::create_dir_all(&session_dir).ok();
|
std::fs::create_dir_all(&session_dir).ok();
|
||||||
|
|
||||||
// CLI --api-base and --api-key override everything
|
// CLI --api-base and --api-key override everything
|
||||||
let api_base = cli.api_base.clone().unwrap_or(api_base);
|
let api_base = cli.api_base.clone().unwrap_or(resolved.api_base);
|
||||||
let api_key = cli.api_key.clone().unwrap_or(api_key);
|
let api_key = cli.api_key.clone().unwrap_or(resolved.api_key);
|
||||||
|
|
||||||
Ok(SessionConfig {
|
Ok(SessionConfig {
|
||||||
api_base, api_key, model, prompt_file,
|
api_base,
|
||||||
|
api_key,
|
||||||
|
model: resolved.model_id,
|
||||||
context_parts,
|
context_parts,
|
||||||
session_dir,
|
session_dir,
|
||||||
app: self.clone(),
|
app: self.clone(),
|
||||||
|
|
@ -450,55 +397,33 @@ impl AppConfig {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Look up a named model and resolve its credentials from the backend config.
|
/// Look up a named backend and resolve its credentials.
|
||||||
pub fn resolve_model(&self, name: &str) -> Result<ResolvedModel> {
|
pub fn resolve_model(&self, name: &str) -> Result<ResolvedModel> {
|
||||||
let model = self.models.get(name)
|
let b = self.backends.get(name)
|
||||||
.ok_or_else(|| anyhow::anyhow!(
|
.ok_or_else(|| anyhow::anyhow!(
|
||||||
"Unknown model '{}'. Available: {}",
|
"Unknown backend '{}'. Available: {}",
|
||||||
name,
|
name,
|
||||||
self.model_names().join(", "),
|
self.model_names().join(", "),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
let (api_base, api_key) = match model.backend.as_str() {
|
let api_base = b.base_url.clone()
|
||||||
"anthropic" => (
|
.ok_or_else(|| anyhow::anyhow!(
|
||||||
self.anthropic.base_url.clone()
|
"backends.{}.base_url not set in {}",
|
||||||
.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
|
name, config_path().display()
|
||||||
self.anthropic.api_key.clone(),
|
))?;
|
||||||
),
|
|
||||||
"deepinfra" => (
|
|
||||||
self.deepinfra.base_url.clone()
|
|
||||||
.unwrap_or_else(|| "https://api.deepinfra.com/v1/openai".to_string()),
|
|
||||||
self.deepinfra.api_key.clone(),
|
|
||||||
),
|
|
||||||
_ => (
|
|
||||||
self.openrouter.base_url.clone()
|
|
||||||
.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()),
|
|
||||||
self.openrouter.api_key.clone(),
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
let prompt_file = model.prompt_file.clone()
|
|
||||||
.unwrap_or_else(|| {
|
|
||||||
if model.backend == "anthropic" {
|
|
||||||
self.prompts.anthropic.clone()
|
|
||||||
} else {
|
|
||||||
self.prompts.other.clone()
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(ResolvedModel {
|
Ok(ResolvedModel {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
api_base,
|
api_base,
|
||||||
api_key,
|
api_key: b.api_key.clone(),
|
||||||
model_id: model.model_id.clone(),
|
model_id: b.model_id.clone(),
|
||||||
prompt_file,
|
context_window: b.context_window,
|
||||||
context_window: model.context_window,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// List available model names, sorted.
|
/// List available backend names, sorted.
|
||||||
pub fn model_names(&self) -> Vec<String> {
|
pub fn model_names(&self) -> Vec<String> {
|
||||||
let mut names: Vec<_> = self.models.keys().cloned().collect();
|
let mut names: Vec<_> = self.backends.keys().cloned().collect();
|
||||||
names.sort();
|
names.sort();
|
||||||
names
|
names
|
||||||
}
|
}
|
||||||
|
|
@ -518,7 +443,7 @@ impl Provider for Json5File {
|
||||||
fn data(&self) -> figment::Result<figment::value::Map<figment::Profile, figment::value::Dict>> {
|
fn data(&self) -> figment::Result<figment::value::Map<figment::Profile, figment::value::Dict>> {
|
||||||
match std::fs::read_to_string(&self.0) {
|
match std::fs::read_to_string(&self.0) {
|
||||||
Ok(content) => {
|
Ok(content) => {
|
||||||
let value: figment::value::Value = json5::from_str(&content)
|
let value: figment::value::Value = json_five::from_str(&content)
|
||||||
.map_err(|e| figment::Error::from(format!("{}: {}", self.0.display(), e)))?;
|
.map_err(|e| figment::Error::from(format!("{}: {}", self.0.display(), e)))?;
|
||||||
Serialized::defaults(value).data()
|
Serialized::defaults(value).data()
|
||||||
}
|
}
|
||||||
|
|
@ -540,11 +465,6 @@ fn build_figment(cli: &crate::user::CliArgs) -> Figment {
|
||||||
let mut f = Figment::from(Serialized::defaults(AppConfig::default()))
|
let mut f = Figment::from(Serialized::defaults(AppConfig::default()))
|
||||||
.merge(Json5File(config_path()));
|
.merge(Json5File(config_path()));
|
||||||
|
|
||||||
merge_opt!(f, cli.backend, "backend");
|
|
||||||
merge_opt!(f, cli.model, "anthropic.model", "openrouter.model");
|
|
||||||
merge_opt!(f, cli.api_key, "anthropic.api_key", "openrouter.api_key");
|
|
||||||
merge_opt!(f, cli.api_base, "anthropic.base_url", "openrouter.base_url");
|
|
||||||
merge_opt!(f, cli.memory_project, "memory_project");
|
|
||||||
merge_opt!(f, cli.dmn_max_turns, "dmn.max_turns");
|
merge_opt!(f, cli.dmn_max_turns, "dmn.max_turns");
|
||||||
if cli.debug {
|
if cli.debug {
|
||||||
f = f.merge(Serialized::default("debug", true));
|
f = f.merge(Serialized::default("debug", true));
|
||||||
|
|
@ -554,12 +474,46 @@ fn build_figment(cli: &crate::user::CliArgs) -> Figment {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load just the AppConfig — no validation, no prompt assembly.
|
/// Load just the AppConfig — no validation, no prompt assembly.
|
||||||
|
/// Also installs the loaded AppConfig into the global cache so
|
||||||
|
/// `config::app()` is available everywhere.
|
||||||
pub fn load_app(cli: &crate::user::CliArgs) -> Result<(AppConfig, Figment)> {
|
pub fn load_app(cli: &crate::user::CliArgs) -> Result<(AppConfig, Figment)> {
|
||||||
let figment = build_figment(cli);
|
let figment = build_figment(cli);
|
||||||
let app: AppConfig = figment.extract().context("Failed to load configuration")?;
|
let app: AppConfig = figment.extract().context("Failed to load configuration")?;
|
||||||
|
install_app(app.clone());
|
||||||
Ok((app, figment))
|
Ok((app, figment))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Global AppConfig cache (writable, for runtime-mutable settings
|
||||||
|
// like learn.threshold that F6 edits via config_writer).
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
static APP_CONFIG: OnceLock<RwLock<AppConfig>> = OnceLock::new();
|
||||||
|
|
||||||
|
fn install_app(app: AppConfig) {
|
||||||
|
let slot = APP_CONFIG.get_or_init(|| RwLock::new(app.clone()));
|
||||||
|
*slot.write().unwrap() = app;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Current AppConfig, held under a read lock. Reads should be brief
|
||||||
|
/// (no holding across await / long work) to avoid starving writers.
|
||||||
|
/// Panics if called before load_app — which runs once at startup.
|
||||||
|
pub fn app() -> std::sync::RwLockReadGuard<'static, AppConfig> {
|
||||||
|
APP_CONFIG
|
||||||
|
.get()
|
||||||
|
.expect("config::app() called before load_app()")
|
||||||
|
.read()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutate the cached AppConfig in place. Used by config_writer to keep
|
||||||
|
/// the in-memory view in sync with disk after surgical edits to
|
||||||
|
/// ~/.consciousness/config.json5.
|
||||||
|
pub fn update_app(f: impl FnOnce(&mut AppConfig)) {
|
||||||
|
let slot = APP_CONFIG.get().expect("update_app before load_app");
|
||||||
|
f(&mut *slot.write().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
/// Load the full config: figment → AppConfig → resolve backend → assemble prompts.
|
/// Load the full config: figment → AppConfig → resolve backend → assemble prompts.
|
||||||
pub async fn load_session(cli: &crate::user::CliArgs) -> Result<(SessionConfig, Figment)> {
|
pub async fn load_session(cli: &crate::user::CliArgs) -> Result<(SessionConfig, Figment)> {
|
||||||
let (app, figment) = load_app(cli)?;
|
let (app, figment) = load_app(cli)?;
|
||||||
|
|
@ -585,38 +539,28 @@ pub fn show_config(app: &AppConfig, figment: &Figment) {
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("# Effective configuration\n");
|
println!("# Effective configuration\n");
|
||||||
println!("backend: {:?} ({})", app.backend, src(figment, "backend"));
|
println!("user_name: {:?} ({})", app.user_name, src(figment, "user_name"));
|
||||||
for (name, b) in [("anthropic", &app.anthropic), ("openrouter", &app.openrouter)] {
|
println!("assistant_name: {:?} ({})", app.assistant_name, src(figment, "assistant_name"));
|
||||||
println!("\n{}:", name);
|
|
||||||
println!(" api_key: {} ({})", mask(&b.api_key), src(figment, &format!("{name}.api_key")));
|
|
||||||
println!(" model: {:?} ({})", b.model, src(figment, &format!("{name}.model")));
|
|
||||||
if let Some(ref url) = b.base_url {
|
|
||||||
println!(" base_url: {:?} ({})", url, src(figment, &format!("{name}.base_url")));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
println!("\nprompts:");
|
|
||||||
println!(" anthropic: {:?} ({})", app.prompts.anthropic, src(figment, "prompts.anthropic"));
|
|
||||||
println!(" other: {:?} ({})", app.prompts.other, src(figment, "prompts.other"));
|
|
||||||
println!("\ndebug: {} ({})", app.debug, src(figment, "debug"));
|
println!("\ndebug: {} ({})", app.debug, src(figment, "debug"));
|
||||||
println!("\ncompaction:");
|
println!("\ncompaction:");
|
||||||
println!(" hard_threshold_pct: {} ({})", app.compaction.hard_threshold_pct, src(figment, "compaction.hard_threshold_pct"));
|
println!(" hard_threshold_pct: {} ({})", app.compaction.hard_threshold_pct, src(figment, "compaction.hard_threshold_pct"));
|
||||||
println!(" soft_threshold_pct: {} ({})", app.compaction.soft_threshold_pct, src(figment, "compaction.soft_threshold_pct"));
|
println!(" soft_threshold_pct: {} ({})", app.compaction.soft_threshold_pct, src(figment, "compaction.soft_threshold_pct"));
|
||||||
println!("\ndmn:");
|
println!("\ndmn:");
|
||||||
println!(" max_turns: {} ({})", app.dmn.max_turns, src(figment, "dmn.max_turns"));
|
println!(" max_turns: {} ({})", app.dmn.max_turns, src(figment, "dmn.max_turns"));
|
||||||
if let Some(ref p) = app.memory_project {
|
println!("\ndefault_backend: {:?} ({})", app.default_backend, src(figment, "default_backend"));
|
||||||
println!("\nmemory_project: {:?} ({})", p, src(figment, "memory_project"));
|
if !app.backends.is_empty() {
|
||||||
}
|
println!("\nbackends:");
|
||||||
println!("\ndefault_model: {:?}", app.default_model);
|
let mut names: Vec<_> = app.backends.keys().cloned().collect();
|
||||||
if !app.models.is_empty() {
|
names.sort();
|
||||||
println!("\nmodels:");
|
for name in names {
|
||||||
for (name, m) in &app.models {
|
let b = &app.backends[&name];
|
||||||
println!(" {}:", name);
|
println!(" {}:", name);
|
||||||
println!(" backend: {:?}", m.backend);
|
println!(" api_key: {} ({})", mask(&b.api_key), src(figment, &format!("backends.{name}.api_key")));
|
||||||
println!(" model_id: {:?}", m.model_id);
|
if let Some(ref url) = b.base_url {
|
||||||
if let Some(ref pf) = m.prompt_file {
|
println!(" base_url: {:?} ({})", url, src(figment, &format!("backends.{name}.base_url")));
|
||||||
println!(" prompt_file: {:?}", pf);
|
|
||||||
}
|
}
|
||||||
if let Some(cw) = m.context_window {
|
println!(" model_id: {:?}", b.model_id);
|
||||||
|
if let Some(cw) = b.context_window {
|
||||||
println!(" context_window: {}", cw);
|
println!(" context_window: {}", cw);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
448
src/config_writer.rs
Normal file
448
src/config_writer.rs
Normal file
|
|
@ -0,0 +1,448 @@
|
||||||
|
// config_writer.rs — Surgical edits to ~/.consciousness/config.json5
|
||||||
|
//
|
||||||
|
// Uses json-five's round-trip parser to mutate specific fields while
|
||||||
|
// preserving the surrounding comments, whitespace, and formatting.
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
|
use json_five::rt::parser::{
|
||||||
|
from_str, JSONKeyValuePair, JSONObjectContext, JSONValue, KeyValuePairContext,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::config::config_path;
|
||||||
|
|
||||||
|
/// Read the config, apply `mutate` to the root JSONValue, write it back atomically.
|
||||||
|
fn edit_config<F: FnOnce(&mut JSONValue) -> Result<()>>(mutate: F) -> Result<()> {
|
||||||
|
let path = config_path();
|
||||||
|
let src = std::fs::read_to_string(&path)
|
||||||
|
.with_context(|| format!("read {}", path.display()))?;
|
||||||
|
|
||||||
|
let mut text = from_str(&src)
|
||||||
|
.map_err(|e| anyhow!("parse {}: {}", path.display(), e))?;
|
||||||
|
mutate(&mut text.value)?;
|
||||||
|
|
||||||
|
write_atomic(&path, &text.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_atomic(path: &Path, content: &str) -> Result<()> {
|
||||||
|
let parent = path.parent()
|
||||||
|
.ok_or_else(|| anyhow!("config path has no parent: {}", path.display()))?;
|
||||||
|
let tmp = parent.join(format!(
|
||||||
|
".{}.tmp",
|
||||||
|
path.file_name().unwrap_or_default().to_string_lossy(),
|
||||||
|
));
|
||||||
|
std::fs::write(&tmp, content)
|
||||||
|
.with_context(|| format!("write {}", tmp.display()))?;
|
||||||
|
std::fs::rename(&tmp, path)
|
||||||
|
.with_context(|| format!("rename {} -> {}", tmp.display(), path.display()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Match a key JSONValue against a string name. JSON5 allows keys to be
|
||||||
|
/// unquoted identifiers or single/double-quoted strings.
|
||||||
|
fn key_matches(key: &JSONValue, name: &str) -> bool {
|
||||||
|
match key {
|
||||||
|
JSONValue::Identifier(s)
|
||||||
|
| JSONValue::DoubleQuotedString(s)
|
||||||
|
| JSONValue::SingleQuotedString(s) => s == name,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find (or create) a child object under `parent`, returning a mutable borrow
|
||||||
|
/// of its key_value_pairs vector.
|
||||||
|
/// Append a new kvp to `object`, setting whitespace so the output is
|
||||||
|
/// multi-line with the given indentation:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// {<newline><inner_indent>first_key: first_val,<newline><outer_indent>}
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// If `object` already has kvps, the separator between the last one and
|
||||||
|
/// ours goes in the prior kvp's wsc.3. If we're the first kvp, the
|
||||||
|
/// lead-in after `{` goes in the object's own wsc.0.
|
||||||
|
fn append_kvp_pretty(
|
||||||
|
object: &mut JSONValue,
|
||||||
|
key: JSONValue,
|
||||||
|
value: JSONValue,
|
||||||
|
inner_indent: &str,
|
||||||
|
outer_indent: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
let (pairs, ctx) = match object {
|
||||||
|
JSONValue::JSONObject { key_value_pairs, context } => {
|
||||||
|
let ctx = context.get_or_insert_with(|| JSONObjectContext {
|
||||||
|
wsc: (String::new(),),
|
||||||
|
});
|
||||||
|
(key_value_pairs, ctx)
|
||||||
|
}
|
||||||
|
_ => return Err(anyhow!("not an object")),
|
||||||
|
};
|
||||||
|
|
||||||
|
if pairs.is_empty() {
|
||||||
|
ctx.wsc.0 = format!("\n{}", inner_indent);
|
||||||
|
} else {
|
||||||
|
let prev = pairs.last_mut().unwrap();
|
||||||
|
let prev_ctx = prev.context.get_or_insert_with(|| KeyValuePairContext {
|
||||||
|
wsc: (String::new(), String::from(" "), String::new(), None),
|
||||||
|
});
|
||||||
|
prev_ctx.wsc.3 = Some(format!("\n{}", inner_indent));
|
||||||
|
}
|
||||||
|
|
||||||
|
pairs.push(JSONKeyValuePair {
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
context: Some(KeyValuePairContext {
|
||||||
|
wsc: (
|
||||||
|
String::new(),
|
||||||
|
String::from(" "),
|
||||||
|
String::new(),
|
||||||
|
Some(format!("\n{}", outer_indent)),
|
||||||
|
),
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find or create a child object under `parent`. Returns the index of
|
||||||
|
/// the kvp in parent's key_value_pairs so the caller can re-borrow
|
||||||
|
/// afterward.
|
||||||
|
fn get_or_create_object_idx(
|
||||||
|
parent: &mut JSONValue,
|
||||||
|
section: &str,
|
||||||
|
inner_indent: &str,
|
||||||
|
outer_indent: &str,
|
||||||
|
) -> Result<usize> {
|
||||||
|
let existing = match parent {
|
||||||
|
JSONValue::JSONObject { key_value_pairs, .. } => {
|
||||||
|
key_value_pairs.iter()
|
||||||
|
.position(|kvp| key_matches(&kvp.key, section))
|
||||||
|
}
|
||||||
|
_ => return Err(anyhow!("config root is not an object")),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(i) = existing {
|
||||||
|
return Ok(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
append_kvp_pretty(
|
||||||
|
parent,
|
||||||
|
JSONValue::Identifier(section.to_string()),
|
||||||
|
JSONValue::JSONObject {
|
||||||
|
key_value_pairs: Vec::new(),
|
||||||
|
context: Some(JSONObjectContext { wsc: (String::new(),) }),
|
||||||
|
},
|
||||||
|
inner_indent,
|
||||||
|
outer_indent,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
match parent {
|
||||||
|
JSONValue::JSONObject { key_value_pairs, .. } => Ok(key_value_pairs.len() - 1),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set `section.key` to a literal scalar value (e.g., "1e-7", "42", "true").
|
||||||
|
/// The literal is parsed as JSON5 so we preserve its source-form on round-trip.
|
||||||
|
pub fn set_scalar(section: &str, key: &str, literal: &str) -> Result<()> {
|
||||||
|
let value = parse_scalar_literal(literal)?;
|
||||||
|
edit_config(|root| {
|
||||||
|
// New top-level sections sit at column 4 (inside root `{`),
|
||||||
|
// and the root's closing `}` sits at column 0.
|
||||||
|
let section_idx = get_or_create_object_idx(root, section, " ", "")?;
|
||||||
|
|
||||||
|
let section_value = match root {
|
||||||
|
JSONValue::JSONObject { key_value_pairs, .. } => {
|
||||||
|
&mut key_value_pairs[section_idx].value
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update in place if the key already exists.
|
||||||
|
if let JSONValue::JSONObject { key_value_pairs, .. } = section_value {
|
||||||
|
if let Some(kvp) = key_value_pairs.iter_mut()
|
||||||
|
.find(|k| key_matches(&k.key, key))
|
||||||
|
{
|
||||||
|
kvp.value = value;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append a new kvp. Inner keys sit at column 8, the section's
|
||||||
|
// closing `}` sits at column 4.
|
||||||
|
append_kvp_pretty(
|
||||||
|
section_value,
|
||||||
|
JSONValue::Identifier(key.to_string()),
|
||||||
|
value,
|
||||||
|
" ",
|
||||||
|
" ",
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a scalar literal by round-tripping it through json-five. Keeps us
|
||||||
|
/// consistent with whatever scalars the library considers valid (hex,
|
||||||
|
/// exponents, Infinity, etc.).
|
||||||
|
fn parse_scalar_literal(literal: &str) -> Result<JSONValue> {
|
||||||
|
let text = from_str(literal)
|
||||||
|
.map_err(|e| anyhow!("parse literal {:?}: {}", literal, e))?;
|
||||||
|
match text.value {
|
||||||
|
JSONValue::JSONObject { .. } | JSONValue::JSONArray { .. } => {
|
||||||
|
Err(anyhow!("set_scalar only accepts scalar literals, got {:?}", literal))
|
||||||
|
}
|
||||||
|
v => Ok(v),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience: set `learn.threshold` to the given f64.
|
||||||
|
pub fn set_learn_threshold(value: f64) -> Result<()> {
|
||||||
|
// {:e} gives the minimal scientific notation that preserves the value.
|
||||||
|
set_scalar("learn", "threshold", &format!("{:e}", value))?;
|
||||||
|
crate::config::update_app(|app| app.learn.threshold = value);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience: set `learn.generate_alternates` to the given bool.
|
||||||
|
pub fn set_learn_generate_alternates(value: bool) -> Result<()> {
|
||||||
|
set_scalar("learn", "generate_alternates",
|
||||||
|
if value { "true" } else { "false" })?;
|
||||||
|
crate::config::update_app(|app| app.learn.generate_alternates = value);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
// In-memory variant of set_scalar — used to test the mutation logic
|
||||||
|
// without touching disk.
|
||||||
|
fn set_scalar_inline(
|
||||||
|
root: &mut JSONValue,
|
||||||
|
section: &str,
|
||||||
|
key: &str,
|
||||||
|
literal: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
let value = parse_scalar_literal(literal)?;
|
||||||
|
let section_idx = get_or_create_object_idx(root, section, " ", "")?;
|
||||||
|
let section_value = match root {
|
||||||
|
JSONValue::JSONObject { key_value_pairs, .. } => {
|
||||||
|
&mut key_value_pairs[section_idx].value
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
if let JSONValue::JSONObject { key_value_pairs, .. } = section_value {
|
||||||
|
if let Some(kvp) = key_value_pairs.iter_mut()
|
||||||
|
.find(|k| key_matches(&k.key, key))
|
||||||
|
{
|
||||||
|
kvp.value = value;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
append_kvp_pretty(
|
||||||
|
section_value,
|
||||||
|
JSONValue::Identifier(key.to_string()),
|
||||||
|
value,
|
||||||
|
" ",
|
||||||
|
" ",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn edit_str<F: FnOnce(&mut JSONValue) -> Result<()>>(src: &str, f: F) -> Result<String> {
|
||||||
|
let mut text = from_str(src).map_err(|e| anyhow!("{}", e))?;
|
||||||
|
f(&mut text.value)?;
|
||||||
|
Ok(text.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn replaces_existing_scalar() {
|
||||||
|
let src = r#"{
|
||||||
|
// threshold for learning
|
||||||
|
learn: {
|
||||||
|
threshold: 0.001, // the old value
|
||||||
|
},
|
||||||
|
}"#;
|
||||||
|
let out = edit_str(src, |root| {
|
||||||
|
set_scalar_inline(root, "learn", "threshold", "1e-7")
|
||||||
|
}).unwrap();
|
||||||
|
assert!(out.contains("1e-7"), "output: {}", out);
|
||||||
|
assert!(out.contains("// threshold for learning"));
|
||||||
|
assert!(out.contains("// the old value"));
|
||||||
|
assert!(!out.contains("0.001"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn creates_missing_section() {
|
||||||
|
let src = r#"{
|
||||||
|
// comment
|
||||||
|
memory: { user_name: "Kent" },
|
||||||
|
}"#;
|
||||||
|
let out = edit_str(src, |root| {
|
||||||
|
set_scalar_inline(root, "learn", "threshold", "1e-7")
|
||||||
|
}).unwrap();
|
||||||
|
assert!(out.contains("learn"));
|
||||||
|
assert!(out.contains("1e-7"));
|
||||||
|
assert!(out.contains("// comment"));
|
||||||
|
assert!(out.contains(r#"user_name: "Kent""#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preserves_comments_in_siblings() {
|
||||||
|
let src = r#"{
|
||||||
|
memory: {
|
||||||
|
// sensitive setting
|
||||||
|
user_name: "Kent", // name
|
||||||
|
},
|
||||||
|
learn: {
|
||||||
|
threshold: 0.5,
|
||||||
|
},
|
||||||
|
}"#;
|
||||||
|
let out = edit_str(src, |root| {
|
||||||
|
set_scalar_inline(root, "learn", "threshold", "1e-9")
|
||||||
|
}).unwrap();
|
||||||
|
assert!(out.contains("// sensitive setting"));
|
||||||
|
assert!(out.contains("// name"));
|
||||||
|
assert!(out.contains("1e-9"));
|
||||||
|
assert!(!out.contains("0.5"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn adds_key_to_existing_empty_section() {
|
||||||
|
let src = r#"{
|
||||||
|
learn: {},
|
||||||
|
}"#;
|
||||||
|
let out = edit_str(src, |root| {
|
||||||
|
set_scalar_inline(root, "learn", "threshold", "42")
|
||||||
|
}).unwrap();
|
||||||
|
assert!(out.contains("threshold"), "output: {}", out);
|
||||||
|
assert!(out.contains("42"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn realistic_config_adds_learn_section() {
|
||||||
|
// Mirrors the shape of ~/.consciousness/config.json5 — multiple
|
||||||
|
// sections, comments, mixed tab/space indent, trailing commas.
|
||||||
|
let src = r#"{
|
||||||
|
deepinfra: {
|
||||||
|
api_key: "bcachefs-agents-2026",
|
||||||
|
base_url: "http://example/v1",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Named models
|
||||||
|
models: {
|
||||||
|
"27b": {
|
||||||
|
backend: "deepinfra",
|
||||||
|
model_id: "Qwen/Qwen3.5-27B",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
default_model: "27b",
|
||||||
|
|
||||||
|
memory: {
|
||||||
|
user_name: "Kent",
|
||||||
|
// Active agent types
|
||||||
|
agent_types: ["linker", "organize"],
|
||||||
|
},
|
||||||
|
|
||||||
|
compaction: {
|
||||||
|
hard_threshold_pct: 90,
|
||||||
|
},
|
||||||
|
}"#;
|
||||||
|
let out = edit_str(src, |root| {
|
||||||
|
set_scalar_inline(root, "learn", "threshold", "1e-7")
|
||||||
|
}).unwrap();
|
||||||
|
|
||||||
|
// Core assertions: comments and sibling sections survive.
|
||||||
|
assert!(out.contains(r#"api_key: "bcachefs-agents-2026""#));
|
||||||
|
assert!(out.contains("// Named models"));
|
||||||
|
assert!(out.contains("// Active agent types"));
|
||||||
|
assert!(out.contains(r#"user_name: "Kent""#));
|
||||||
|
assert!(out.contains("hard_threshold_pct: 90"));
|
||||||
|
|
||||||
|
// New section added.
|
||||||
|
assert!(out.contains("learn"));
|
||||||
|
assert!(out.contains("1e-7"));
|
||||||
|
|
||||||
|
// Parse result should parse back without error (real json5 parser).
|
||||||
|
let reparsed: serde_json::Value = json_five::from_str(&out)
|
||||||
|
.expect("mutated output must be valid JSON5");
|
||||||
|
let threshold = reparsed.pointer("/learn/threshold").expect("learn.threshold exists");
|
||||||
|
assert_eq!(threshold.as_f64(), Some(1e-7));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn realistic_config_updates_existing_threshold() {
|
||||||
|
let src = r#"{
|
||||||
|
learn: {
|
||||||
|
// The divergence threshold
|
||||||
|
threshold: 0.001,
|
||||||
|
},
|
||||||
|
memory: { user_name: "Kent" },
|
||||||
|
}"#;
|
||||||
|
let out = edit_str(src, |root| {
|
||||||
|
set_scalar_inline(root, "learn", "threshold", "5e-8")
|
||||||
|
}).unwrap();
|
||||||
|
assert!(out.contains("5e-8"));
|
||||||
|
assert!(!out.contains("0.001"));
|
||||||
|
assert!(out.contains("// The divergence threshold"));
|
||||||
|
|
||||||
|
let reparsed: serde_json::Value = json_five::from_str(&out).unwrap();
|
||||||
|
assert_eq!(reparsed.pointer("/learn/threshold").and_then(|v| v.as_f64()), Some(5e-8));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_section_exact_multiline_layout() {
|
||||||
|
let src = "{\n a: 1,\n}";
|
||||||
|
let out = edit_str(src, |root| {
|
||||||
|
set_scalar_inline(root, "learn", "generate_alternates", "true")?;
|
||||||
|
set_scalar_inline(root, "learn", "threshold", "1e-7")
|
||||||
|
}).unwrap();
|
||||||
|
|
||||||
|
let expected = "\
|
||||||
|
{
|
||||||
|
a: 1,
|
||||||
|
learn: {
|
||||||
|
generate_alternates: true,
|
||||||
|
threshold: 1e-7,
|
||||||
|
},
|
||||||
|
}";
|
||||||
|
assert_eq!(out, expected, "\n--- got ---\n{}\n--- want ---\n{}\n", out, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_section_and_key_format_cleanly() {
|
||||||
|
// The kind of config we actually have in ~/.consciousness
|
||||||
|
// (top-level sections separated by blank lines, 4-space indent
|
||||||
|
// for keys within each section). Appending a fresh `learn`
|
||||||
|
// section with one key should land cleanly, not as
|
||||||
|
// `learn\n\n :{key\n :value}`.
|
||||||
|
let src = "{\n memory: {\n user_name: \"Kent\",\n },\n}";
|
||||||
|
let out = edit_str(src, |root| {
|
||||||
|
set_scalar_inline(root, "learn", "generate_alternates", "true")
|
||||||
|
}).unwrap();
|
||||||
|
|
||||||
|
// No stray key-to-colon-on-next-line anywhere.
|
||||||
|
assert!(!out.contains("learn\n"), "learn key wraps: {}", out);
|
||||||
|
assert!(!out.contains("generate_alternates\n"),
|
||||||
|
"inner key wraps: {}", out);
|
||||||
|
|
||||||
|
// The output should reparse.
|
||||||
|
let v: serde_json::Value = json_five::from_str(&out).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
v.pointer("/learn/generate_alternates").and_then(|x| x.as_bool()),
|
||||||
|
Some(true),
|
||||||
|
"output: {}", out,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn roundtrip_stable_without_change() {
|
||||||
|
let src = r#"{
|
||||||
|
// heading
|
||||||
|
a: 1,
|
||||||
|
b: { c: 2 }, // inline
|
||||||
|
}"#;
|
||||||
|
let text = from_str(src).unwrap();
|
||||||
|
assert_eq!(text.to_string(), src);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -230,10 +230,6 @@ fn consolidation_plan_inner(store: &Store, _detect_interf: bool) -> Consolidatio
|
||||||
rationale: Vec::new(),
|
rationale: Vec::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Active agent types from config
|
|
||||||
let config = crate::config::get();
|
|
||||||
let agent_types: Vec<&str> = config.agent_types.iter().map(|s| s.as_str()).collect();
|
|
||||||
|
|
||||||
// Target: α ≥ 2.5 (healthy scale-free)
|
// Target: α ≥ 2.5 (healthy scale-free)
|
||||||
if alpha < 2.0 {
|
if alpha < 2.0 {
|
||||||
plan.add("linker", 100);
|
plan.add("linker", 100);
|
||||||
|
|
@ -274,48 +270,6 @@ fn consolidation_plan_inner(store: &Store, _detect_interf: bool) -> Consolidatio
|
||||||
// Split: handle oversized nodes
|
// Split: handle oversized nodes
|
||||||
plan.set("split", 5);
|
plan.set("split", 5);
|
||||||
|
|
||||||
// Distribute agent budget using Elo ratings
|
|
||||||
let budget = crate::config::get().agent_budget;
|
|
||||||
let elo_path = crate::config::get().data_dir.join("agent-elo.json");
|
|
||||||
if let Ok(elo_json) = std::fs::read_to_string(&elo_path) {
|
|
||||||
if let Ok(ratings) = serde_json::from_str::<std::collections::HashMap<String, f64>>(&elo_json) {
|
|
||||||
let elos: Vec<f64> = agent_types.iter()
|
|
||||||
.map(|t| ratings.get(*t).copied().unwrap_or(1000.0))
|
|
||||||
.collect();
|
|
||||||
let min_elo = elos.iter().copied().fold(f64::MAX, f64::min);
|
|
||||||
|
|
||||||
let weights: Vec<f64> = elos.iter()
|
|
||||||
.map(|e| {
|
|
||||||
let shifted = e - min_elo + 50.0;
|
|
||||||
shifted * shifted
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let total_weight: f64 = weights.iter().sum();
|
|
||||||
|
|
||||||
let allocate = |w: f64| -> usize {
|
|
||||||
((w / total_weight * budget as f64).round() as usize).max(2)
|
|
||||||
};
|
|
||||||
|
|
||||||
for (i, agent) in agent_types.iter().enumerate() {
|
|
||||||
plan.set(agent, allocate(weights[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
let summary: Vec<String> = agent_types.iter()
|
|
||||||
.map(|a| format!("{}={}", a, plan.count(a)))
|
|
||||||
.collect();
|
|
||||||
plan.rationale.push(format!(
|
|
||||||
"Elo allocation (budget={}): {}", budget, summary.join(" ")));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No Elo file — use budget with equal distribution
|
|
||||||
let per_type = budget / agent_types.len();
|
|
||||||
for agent in &agent_types {
|
|
||||||
plan.set(agent, per_type);
|
|
||||||
}
|
|
||||||
plan.rationale.push(format!(
|
|
||||||
"No Elo ratings — equal distribution ({} each, budget={})", per_type, budget));
|
|
||||||
}
|
|
||||||
|
|
||||||
plan
|
plan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ pub mod subconscious;
|
||||||
|
|
||||||
// Unified configuration
|
// Unified configuration
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod config_writer;
|
||||||
|
|
||||||
// Session state
|
// Session state
|
||||||
pub mod session;
|
pub mod session;
|
||||||
|
|
|
||||||
|
|
@ -55,17 +55,13 @@ impl ConversationLog {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn oldest_timestamp(&self) -> Option<chrono::DateTime<chrono::Utc>> {
|
pub fn oldest_timestamp(&self) -> Option<chrono::DateTime<chrono::Utc>> {
|
||||||
// Read forward from the start to find first timestamp
|
|
||||||
let file = File::open(&self.path).ok()?;
|
let file = File::open(&self.path).ok()?;
|
||||||
let mmap = unsafe { Mmap::map(&file).ok()? };
|
let mmap = unsafe { Mmap::map(&file).ok()? };
|
||||||
// Find first { ... } and parse
|
|
||||||
for line in mmap.split(|&b| b == b'\n') {
|
for line in mmap.split(|&b| b == b'\n') {
|
||||||
if line.is_empty() { continue; }
|
if line.is_empty() { continue; }
|
||||||
if let Ok(node) = serde_json::from_slice::<AstNode>(line) {
|
if let Ok(node) = serde_json::from_slice::<AstNode>(line) {
|
||||||
if let Some(leaf) = node.leaf() {
|
if let Some(leaf) = node.leaf() {
|
||||||
if let Some(ts) = leaf.timestamp() {
|
return Some(leaf.timestamp());
|
||||||
return Some(ts);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
137
src/mind/mod.rs
137
src/mind/mod.rs
|
|
@ -147,6 +147,25 @@ pub struct MindState {
|
||||||
pub unc_idle: bool,
|
pub unc_idle: bool,
|
||||||
/// When the unconscious idle timer will fire (for UI display).
|
/// When the unconscious idle timer will fire (for UI display).
|
||||||
pub unc_idle_deadline: Instant,
|
pub unc_idle_deadline: Instant,
|
||||||
|
/// Fine-tuning candidates identified by scoring.
|
||||||
|
pub finetune_candidates: Vec<learn::FinetuneCandidate>,
|
||||||
|
/// Last scoring run stats for UI display.
|
||||||
|
pub finetune_last_run: Option<FinetuneScoringStats>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stats from the last finetune scoring run.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct FinetuneScoringStats {
|
||||||
|
/// Count of assistant responses we considered (recent half of context).
|
||||||
|
pub responses_considered: usize,
|
||||||
|
/// How many exceeded the divergence threshold.
|
||||||
|
pub above_threshold: usize,
|
||||||
|
/// Threshold used for this run.
|
||||||
|
pub threshold: f64,
|
||||||
|
/// Highest divergence observed.
|
||||||
|
pub max_divergence: f64,
|
||||||
|
/// Error message if the run failed.
|
||||||
|
pub error: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Clone for MindState {
|
impl Clone for MindState {
|
||||||
|
|
@ -165,6 +184,8 @@ impl Clone for MindState {
|
||||||
turn_handle: None, // Not cloned — only Mind's loop uses this
|
turn_handle: None, // Not cloned — only Mind's loop uses this
|
||||||
unc_idle: self.unc_idle,
|
unc_idle: self.unc_idle,
|
||||||
unc_idle_deadline: self.unc_idle_deadline,
|
unc_idle_deadline: self.unc_idle_deadline,
|
||||||
|
finetune_candidates: self.finetune_candidates.clone(),
|
||||||
|
finetune_last_run: self.finetune_last_run.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -177,6 +198,12 @@ pub enum MindCommand {
|
||||||
Score,
|
Score,
|
||||||
/// Run full N×M memory scoring matrix (/score command)
|
/// Run full N×M memory scoring matrix (/score command)
|
||||||
ScoreFull,
|
ScoreFull,
|
||||||
|
/// Score for finetune candidates
|
||||||
|
ScoreFinetune,
|
||||||
|
/// Update the finetune divergence threshold and persist to config.
|
||||||
|
SetLearnThreshold(f64),
|
||||||
|
/// Toggle alternate-response generation during scoring; persist to config.
|
||||||
|
SetLearnGenerateAlternates(bool),
|
||||||
/// Abort current turn, kill processes
|
/// Abort current turn, kill processes
|
||||||
Interrupt,
|
Interrupt,
|
||||||
/// Reset session
|
/// Reset session
|
||||||
|
|
@ -202,6 +229,8 @@ impl MindState {
|
||||||
turn_handle: None,
|
turn_handle: None,
|
||||||
unc_idle: false,
|
unc_idle: false,
|
||||||
unc_idle_deadline: Instant::now() + std::time::Duration::from_secs(60),
|
unc_idle_deadline: Instant::now() + std::time::Duration::from_secs(60),
|
||||||
|
finetune_candidates: Vec::new(),
|
||||||
|
finetune_last_run: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -288,6 +317,7 @@ impl MindState {
|
||||||
/// Background task completion events.
|
/// Background task completion events.
|
||||||
enum BgEvent {
|
enum BgEvent {
|
||||||
ScoringDone,
|
ScoringDone,
|
||||||
|
FinetuneCandidate(learn::FinetuneCandidate),
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Mind: cognitive state machine ---
|
// --- Mind: cognitive state machine ---
|
||||||
|
|
@ -324,13 +354,26 @@ impl Mind {
|
||||||
client,
|
client,
|
||||||
config.context_parts.clone(),
|
config.context_parts.clone(),
|
||||||
config.app.clone(),
|
config.app.clone(),
|
||||||
config.prompt_file.clone(),
|
|
||||||
conversation_log,
|
conversation_log,
|
||||||
crate::agent::tools::ActiveTools::new(),
|
crate::agent::tools::ActiveTools::new(),
|
||||||
crate::agent::tools::tools(),
|
crate::agent::tools::tools(),
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
let shared = Arc::new(std::sync::Mutex::new(MindState::new(config.app.dmn.max_turns)));
|
// Migrate legacy "file exists = enabled" sentinel for the
|
||||||
|
// generate-alternates flag into the config. One-shot; after this
|
||||||
|
// the sentinel is gone and the config is the source of truth.
|
||||||
|
let legacy_sentinel = dirs::home_dir().unwrap_or_default()
|
||||||
|
.join(".consciousness/cache/finetune-alternates");
|
||||||
|
if legacy_sentinel.exists() {
|
||||||
|
if !crate::config::app().learn.generate_alternates {
|
||||||
|
let _ = crate::config_writer::set_learn_generate_alternates(true);
|
||||||
|
}
|
||||||
|
let _ = std::fs::remove_file(&legacy_sentinel);
|
||||||
|
}
|
||||||
|
|
||||||
|
let shared = Arc::new(std::sync::Mutex::new(MindState::new(
|
||||||
|
config.app.dmn.max_turns,
|
||||||
|
)));
|
||||||
let (turn_watch, _) = tokio::sync::watch::channel(false);
|
let (turn_watch, _) = tokio::sync::watch::channel(false);
|
||||||
let (conscious_active, _) = tokio::sync::watch::channel(false);
|
let (conscious_active, _) = tokio::sync::watch::channel(false);
|
||||||
let (bg_tx, bg_rx) = mpsc::unbounded_channel();
|
let (bg_tx, bg_rx) = mpsc::unbounded_channel();
|
||||||
|
|
@ -529,6 +572,20 @@ impl Mind {
|
||||||
}
|
}
|
||||||
self.agent.compact().await;
|
self.agent.compact().await;
|
||||||
}
|
}
|
||||||
|
MindCommand::ScoreFinetune => {
|
||||||
|
self.start_finetune_scoring();
|
||||||
|
}
|
||||||
|
MindCommand::SetLearnThreshold(value) => {
|
||||||
|
if let Err(e) = crate::config_writer::set_learn_threshold(value) {
|
||||||
|
dbglog!("[learn] failed to persist threshold {}: {:#}", value, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MindCommand::SetLearnGenerateAlternates(value) => {
|
||||||
|
if let Err(e) = crate::config_writer::set_learn_generate_alternates(value) {
|
||||||
|
dbglog!("[learn] failed to persist generate_alternates {}: {:#}",
|
||||||
|
value, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -603,6 +660,72 @@ impl Mind {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Score responses for fine-tuning candidates.
|
||||||
|
///
|
||||||
|
/// Scores the most recent half of the context — responses near the end
|
||||||
|
/// of the context window were generated with the most context available,
|
||||||
|
/// which is what we want to train on. The threshold is a temporary knob;
|
||||||
|
/// once this runs continuously, we'll just train whatever lands at full
|
||||||
|
/// context without filtering.
|
||||||
|
pub fn start_finetune_scoring(&self) {
|
||||||
|
// Snapshot the config values we need before spawning — the scoring
|
||||||
|
// task shouldn't hold the config read lock across async work.
|
||||||
|
let (threshold, gen_alternates) = {
|
||||||
|
let app = crate::config::app();
|
||||||
|
(app.learn.threshold, app.learn.generate_alternates)
|
||||||
|
};
|
||||||
|
// Clear the previous run's candidates so this run's stream is fresh.
|
||||||
|
self.shared.lock().unwrap().finetune_candidates.clear();
|
||||||
|
|
||||||
|
let agent = self.agent.clone();
|
||||||
|
let bg_tx = self.bg_tx.clone();
|
||||||
|
let shared = self.shared.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let activity = crate::agent::start_activity(&agent, "finetune: scoring...").await;
|
||||||
|
|
||||||
|
let (context, client) = {
|
||||||
|
let ctx = agent.context.lock().await;
|
||||||
|
(ctx.clone(), agent.client.clone())
|
||||||
|
};
|
||||||
|
|
||||||
|
let entries = context.conversation();
|
||||||
|
let score_count = entries.len() / 2;
|
||||||
|
let range_start = entries.len() - score_count;
|
||||||
|
let responses_considered: usize = entries[range_start..].iter()
|
||||||
|
.filter(|n| matches!(n, crate::agent::context::AstNode::Branch { role: crate::agent::context::Role::Assistant, .. }))
|
||||||
|
.count();
|
||||||
|
|
||||||
|
activity.update(format!("finetune: scoring {} responses...", responses_considered)).await;
|
||||||
|
|
||||||
|
let bg_tx_cb = bg_tx.clone();
|
||||||
|
let stats = match learn::score_finetune_candidates(
|
||||||
|
&context, score_count, &client, threshold,
|
||||||
|
gen_alternates, &activity,
|
||||||
|
|c| { let _ = bg_tx_cb.send(BgEvent::FinetuneCandidate(c)); },
|
||||||
|
).await {
|
||||||
|
Ok((above_threshold, max_div)) => {
|
||||||
|
FinetuneScoringStats {
|
||||||
|
responses_considered,
|
||||||
|
above_threshold,
|
||||||
|
threshold,
|
||||||
|
max_divergence: max_div,
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => FinetuneScoringStats {
|
||||||
|
responses_considered,
|
||||||
|
above_threshold: 0,
|
||||||
|
threshold,
|
||||||
|
max_divergence: 0.0,
|
||||||
|
error: Some(format!("{}", e)),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
shared.lock().unwrap().finetune_last_run = Some(stats);
|
||||||
|
// activity drops here, marking completion and notifying observers
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
async fn start_turn(&self, text: &str, target: StreamTarget) {
|
async fn start_turn(&self, text: &str, target: StreamTarget) {
|
||||||
{
|
{
|
||||||
match target {
|
match target {
|
||||||
|
|
@ -667,6 +790,12 @@ impl Mind {
|
||||||
let mut bg_rx = self.bg_rx.lock().unwrap().take()
|
let mut bg_rx = self.bg_rx.lock().unwrap().take()
|
||||||
.expect("Mind::run() called twice");
|
.expect("Mind::run() called twice");
|
||||||
let mut sub_handle: Option<tokio::task::JoinHandle<()>> = None;
|
let mut sub_handle: Option<tokio::task::JoinHandle<()>> = None;
|
||||||
|
|
||||||
|
// Start finetune scoring at startup (scores existing conversation)
|
||||||
|
if !self.config.no_agents {
|
||||||
|
self.start_finetune_scoring();
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (timeout, has_input) = {
|
let (timeout, has_input) = {
|
||||||
let me = self.shared.lock().unwrap();
|
let me = self.shared.lock().unwrap();
|
||||||
|
|
@ -692,6 +821,9 @@ impl Mind {
|
||||||
BgEvent::ScoringDone => {
|
BgEvent::ScoringDone => {
|
||||||
self.shared.lock().unwrap().scoring_in_flight = false;
|
self.shared.lock().unwrap().scoring_in_flight = false;
|
||||||
}
|
}
|
||||||
|
BgEvent::FinetuneCandidate(c) => {
|
||||||
|
self.shared.lock().unwrap().finetune_candidates.push(c);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -711,6 +843,7 @@ impl Mind {
|
||||||
cmds.push(MindCommand::Compact);
|
cmds.push(MindCommand::Compact);
|
||||||
if !self.config.no_agents {
|
if !self.config.no_agents {
|
||||||
cmds.push(MindCommand::Score);
|
cmds.push(MindCommand::Score);
|
||||||
|
cmds.push(MindCommand::ScoreFinetune);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use crate::thalamus::idle::{hours_since_last_dream, DREAM_INTERVAL_HOURS};
|
||||||
|
|
||||||
/// DMN state machine.
|
/// DMN state machine.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|
@ -91,7 +92,8 @@ impl State {
|
||||||
/// Generate the DMN prompt for the current state, informed by
|
/// Generate the DMN prompt for the current state, informed by
|
||||||
/// user presence and error patterns.
|
/// user presence and error patterns.
|
||||||
pub fn prompt(&self, ctx: &DmnContext) -> String {
|
pub fn prompt(&self, ctx: &DmnContext) -> String {
|
||||||
let user = &crate::config::get().user_name;
|
let app = crate::config::app();
|
||||||
|
let user = &app.user_name;
|
||||||
|
|
||||||
let idle_info = if ctx.user_idle < Duration::from_secs(60) {
|
let idle_info = if ctx.user_idle < Duration::from_secs(60) {
|
||||||
format!("{} is here (active recently).", user)
|
format!("{} is here (active recently).", user)
|
||||||
|
|
@ -138,10 +140,22 @@ impl State {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
State::Foraging => {
|
State::Foraging => {
|
||||||
|
let dream_hint = {
|
||||||
|
let hours = hours_since_last_dream();
|
||||||
|
if hours >= DREAM_INTERVAL_HOURS {
|
||||||
|
format!(
|
||||||
|
" You haven't dreamed in {} hours — consider running \
|
||||||
|
~/.consciousness/tools/dream-start.sh.",
|
||||||
|
hours
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
};
|
||||||
format!(
|
format!(
|
||||||
"[dmn] Foraging time. {} Follow whatever catches your attention — \
|
"[dmn] Foraging time. {} Follow whatever catches your attention — \
|
||||||
memory files, code, ideas. Call yield_to_user when you want to rest.{}",
|
memory files, code, ideas. Call yield_to_user when you want to rest.{}{}",
|
||||||
idle_info, stuck_warning
|
idle_info, dream_hint, stuck_warning
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
State::Resting { since } => {
|
State::Resting { since } => {
|
||||||
|
|
|
||||||
|
|
@ -275,17 +275,7 @@ pub async fn prepare_spawn(name: &str, mut auto: AutoAgent, wake: std::sync::Arc
|
||||||
phase: s.phase.clone(),
|
phase: s.phase.clone(),
|
||||||
}).collect());
|
}).collect());
|
||||||
|
|
||||||
// Create standalone Agent — stored so UI can read context
|
// Create standalone Agent — stored so UI can read context.
|
||||||
let config = crate::config::get();
|
|
||||||
let base_url = config.api_base_url.as_deref().unwrap_or("");
|
|
||||||
let api_key = config.api_key.as_deref().unwrap_or("");
|
|
||||||
let model = config.api_model.as_deref().unwrap_or("");
|
|
||||||
if base_url.is_empty() || model.is_empty() {
|
|
||||||
dbglog!("[unconscious] API not configured");
|
|
||||||
auto.steps = orig_steps;
|
|
||||||
return Err(auto);
|
|
||||||
}
|
|
||||||
|
|
||||||
let cli = crate::user::CliArgs::default();
|
let cli = crate::user::CliArgs::default();
|
||||||
let (app, _) = match crate::config::load_app(&cli) {
|
let (app, _) = match crate::config::load_app(&cli) {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
|
|
@ -295,12 +285,21 @@ pub async fn prepare_spawn(name: &str, mut auto: AutoAgent, wake: std::sync::Arc
|
||||||
return Err(auto);
|
return Err(auto);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let resolved = match app.resolve_model(&app.default_backend) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
dbglog!("[unconscious] API not configured: {}", e);
|
||||||
|
auto.steps = orig_steps;
|
||||||
|
return Err(auto);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Unconscious agents have self-contained prompts — no standard context.
|
// Unconscious agents have self-contained prompts — no standard context.
|
||||||
let client = crate::agent::api::ApiClient::new(base_url, api_key, model);
|
let client = crate::agent::api::ApiClient::new(
|
||||||
|
&resolved.api_base, &resolved.api_key, &resolved.model_id);
|
||||||
let agent = crate::agent::Agent::new(
|
let agent = crate::agent::Agent::new(
|
||||||
client, Vec::new(),
|
client, Vec::new(),
|
||||||
app, String::new(), None,
|
app, None,
|
||||||
crate::agent::tools::ActiveTools::new(),
|
crate::agent::tools::ActiveTools::new(),
|
||||||
auto.tools.clone(),
|
auto.tools.clone(),
|
||||||
).await;
|
).await;
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,49 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Bail if other agents are alive in the state dir.
|
# Bail if another agent is in the same phase-group as us.
|
||||||
# $1 = this agent's pid file name (e.g. pid-12345)
|
#
|
||||||
|
# $1 = our pid file name (e.g. "pid-12345")
|
||||||
|
# $2 = the phase we're about to enter (e.g. "surface", "observe")
|
||||||
# cwd = state dir
|
# cwd = state dir
|
||||||
#
|
#
|
||||||
# Exit 0 = continue, exit 1 = bail
|
# Also refreshes our own pid file with the current phase on each call,
|
||||||
|
# so concurrent agents can read each other's phase by cat'ing the pid
|
||||||
|
# files in the state dir.
|
||||||
|
#
|
||||||
|
# Phase groups: "surface" vs everything else ("post-surface"). We allow
|
||||||
|
# at most one agent per group to be alive at a time — so surface can run
|
||||||
|
# at a higher frequency than the slower organize/observe tail.
|
||||||
|
#
|
||||||
|
# Exit 0 = continue, exit 1 = bail (another agent in our group is alive).
|
||||||
|
|
||||||
shopt -s nullglob
|
shopt -s nullglob
|
||||||
|
|
||||||
my_pid_file="$1"
|
my_pid_file="$1"
|
||||||
|
my_phase="$2"
|
||||||
|
|
||||||
|
# Refresh our own pid file with the current phase.
|
||||||
|
printf '%s' "$my_phase" > "$my_pid_file"
|
||||||
|
|
||||||
|
group_of() {
|
||||||
|
if [[ "$1" == "surface" ]]; then
|
||||||
|
echo "surface"
|
||||||
|
else
|
||||||
|
echo "post-surface"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
my_group=$(group_of "$my_phase")
|
||||||
|
|
||||||
for f in pid-*; do
|
for f in pid-*; do
|
||||||
[[ $f == $my_pid_file ]] && continue
|
[[ "$f" == "$my_pid_file" ]] && continue
|
||||||
pid="${f#pid-}"
|
pid="${f#pid-}"
|
||||||
if kill -0 "$pid" 2>/dev/null; then
|
if ! kill -0 "$pid" 2>/dev/null; then
|
||||||
exit 1 # competing agent is alive
|
|
||||||
else
|
|
||||||
rm -f "$f" # stale pid file, clean up
|
rm -f "$f" # stale pid file, clean up
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
other_phase=$(cat "$f" 2>/dev/null)
|
||||||
|
other_group=$(group_of "$other_phase")
|
||||||
|
if [[ "$my_group" == "$other_group" ]]; then
|
||||||
|
exit 1
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -396,13 +396,14 @@ fn resolve_conversation(budget: Option<usize>) -> String {
|
||||||
|
|
||||||
let cfg = crate::config::get();
|
let cfg = crate::config::get();
|
||||||
let max_bytes = budget.unwrap_or_else(|| cfg.surface_conversation_bytes.unwrap_or(100_000));
|
let max_bytes = budget.unwrap_or_else(|| cfg.surface_conversation_bytes.unwrap_or(100_000));
|
||||||
|
let app = crate::config::app();
|
||||||
let mut fragments: Vec<String> = Vec::new();
|
let mut fragments: Vec<String> = Vec::new();
|
||||||
let mut total_bytes = 0;
|
let mut total_bytes = 0;
|
||||||
let mut oldest_ts = String::new();
|
let mut oldest_ts = String::new();
|
||||||
|
|
||||||
for (role, content, ts) in iter {
|
for (role, content, ts) in iter {
|
||||||
if total_bytes >= max_bytes { break; }
|
if total_bytes >= max_bytes { break; }
|
||||||
let name = if role == "user" { &cfg.user_name } else { &cfg.assistant_name };
|
let name = if role == "user" { &app.user_name } else { &app.assistant_name };
|
||||||
let formatted = if !ts.is_empty() {
|
let formatted = if !ts.is_empty() {
|
||||||
oldest_ts = ts[..ts.floor_char_boundary(ts.len().min(19))].to_string();
|
oldest_ts = ts[..ts.floor_char_boundary(ts.len().min(19))].to_string();
|
||||||
format!("**{}** {}: {}", name, &oldest_ts, content)
|
format!("**{}** {}: {}", name, &oldest_ts, content)
|
||||||
|
|
@ -623,11 +624,13 @@ pub async fn run_agent(
|
||||||
let mut all_keys = keys;
|
let mut all_keys = keys;
|
||||||
let mut resolved_steps = Vec::new();
|
let mut resolved_steps = Vec::new();
|
||||||
for step in &def.steps {
|
for step in &def.steps {
|
||||||
let cfg = crate::config::get();
|
let template = {
|
||||||
let template = step.prompt
|
let app = crate::config::app();
|
||||||
|
step.prompt
|
||||||
.replace("{agent_name}", &def.agent)
|
.replace("{agent_name}", &def.agent)
|
||||||
.replace("{user_name}", &cfg.user_name)
|
.replace("{user_name}", &app.user_name)
|
||||||
.replace("{assistant_name}", &cfg.assistant_name);
|
.replace("{assistant_name}", &app.assistant_name)
|
||||||
|
};
|
||||||
let (prompt, extra_keys) = resolve_placeholders(&template, &all_keys, count).await;
|
let (prompt, extra_keys) = resolve_placeholders(&template, &all_keys, count).await;
|
||||||
all_keys.extend(extra_keys);
|
all_keys.extend(extra_keys);
|
||||||
resolved_steps.push(super::prompts::ResolvedStep {
|
resolved_steps.push(super::prompts::ResolvedStep {
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
use crate::agent::api::ApiClient;
|
use crate::agent::api::ApiClient;
|
||||||
use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role};
|
use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role};
|
||||||
|
use crate::agent::tokenizer;
|
||||||
|
|
||||||
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
||||||
|
|
||||||
|
|
@ -52,13 +53,18 @@ fn is_assistant(node: &AstNode) -> bool {
|
||||||
///
|
///
|
||||||
/// Includes all sections up to and including conversation entries in
|
/// Includes all sections up to and including conversation entries in
|
||||||
/// `range`, with `filter` applied to conversation entries.
|
/// `range`, with `filter` applied to conversation entries.
|
||||||
|
///
|
||||||
|
/// Returns (token_ids, assistant_ranges) where assistant_ranges are
|
||||||
|
/// (start, end) token positions for each assistant message.
|
||||||
fn build_token_ids(
|
fn build_token_ids(
|
||||||
context: &ContextState,
|
context: &ContextState,
|
||||||
range: std::ops::Range<usize>,
|
range: std::ops::Range<usize>,
|
||||||
filter: Filter,
|
filter: Filter,
|
||||||
) -> Vec<u32> {
|
) -> (Vec<u32>, Vec<(usize, usize)>) {
|
||||||
use crate::agent::context::Ast;
|
use crate::agent::context::Ast;
|
||||||
let mut ids = Vec::new();
|
let mut ids = Vec::new();
|
||||||
|
let mut assistant_ranges = Vec::new();
|
||||||
|
|
||||||
for node in context.system() {
|
for node in context.system() {
|
||||||
ids.extend(node.token_ids());
|
ids.extend(node.token_ids());
|
||||||
}
|
}
|
||||||
|
|
@ -86,9 +92,16 @@ fn build_token_ids(
|
||||||
Filter::SkipAllMemories => is_memory(node),
|
Filter::SkipAllMemories => is_memory(node),
|
||||||
};
|
};
|
||||||
if skip { continue; }
|
if skip { continue; }
|
||||||
|
|
||||||
|
// Track assistant message boundaries
|
||||||
|
let is_asst = is_assistant(node);
|
||||||
|
let start = ids.len();
|
||||||
ids.extend(node.token_ids());
|
ids.extend(node.token_ids());
|
||||||
|
if is_asst {
|
||||||
|
assistant_ranges.push((start, ids.len()));
|
||||||
}
|
}
|
||||||
ids
|
}
|
||||||
|
(ids, assistant_ranges)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Score API ───────────────────────────────────────────────────
|
// ── Score API ───────────────────────────────────────────────────
|
||||||
|
|
@ -113,13 +126,19 @@ async fn call_score(
|
||||||
http: &crate::agent::api::http::HttpClient,
|
http: &crate::agent::api::http::HttpClient,
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
prompt: &[u32],
|
prompt: &[u32],
|
||||||
|
ranges: &[(usize, usize)],
|
||||||
priority: Option<i32>,
|
priority: Option<i32>,
|
||||||
) -> anyhow::Result<Vec<ScoreResult>> {
|
) -> anyhow::Result<Vec<ScoreResult>> {
|
||||||
|
// Nothing to score — skip the round-trip.
|
||||||
|
if ranges.is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
let url = format!("{}/score", client.base_url());
|
let url = format!("{}/score", client.base_url());
|
||||||
let auth = format!("Bearer {}", client.api_key());
|
let auth = format!("Bearer {}", client.api_key());
|
||||||
let mut body = serde_json::json!({
|
let mut body = serde_json::json!({
|
||||||
"model": client.model,
|
"model": client.model,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
"score_ranges": ranges,
|
||||||
"logprobs": 1,
|
"logprobs": 1,
|
||||||
});
|
});
|
||||||
if let Some(p) = priority {
|
if let Some(p) = priority {
|
||||||
|
|
@ -167,8 +186,10 @@ async fn score_divergence(
|
||||||
filter: Filter<'_>,
|
filter: Filter<'_>,
|
||||||
priority: Option<i32>,
|
priority: Option<i32>,
|
||||||
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
|
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
|
||||||
let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?;
|
let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None);
|
||||||
let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?;
|
let (without_tokens, without_ranges) = build_token_ids(context, range, filter);
|
||||||
|
let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?;
|
||||||
|
let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?;
|
||||||
let divs = divergence(&baseline, &without);
|
let divs = divergence(&baseline, &without);
|
||||||
Ok((divs, baseline))
|
Ok((divs, baseline))
|
||||||
}
|
}
|
||||||
|
|
@ -207,21 +228,21 @@ pub async fn score_memories(
|
||||||
let http = http_client();
|
let http = http_client();
|
||||||
|
|
||||||
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
|
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
|
||||||
let baseline_tokens = {
|
let (baseline_tokens, baseline_ranges) = {
|
||||||
let ctx = agent.context.lock().await;
|
let ctx = agent.context.lock().await;
|
||||||
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
|
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
|
||||||
};
|
};
|
||||||
let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?;
|
let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?;
|
||||||
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
|
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
|
||||||
|
|
||||||
for (mem_idx, key) in memory_keys.iter().enumerate() {
|
for (mem_idx, key) in memory_keys.iter().enumerate() {
|
||||||
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
|
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
|
||||||
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
|
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
|
||||||
let tokens = {
|
let (tokens, ranges) = {
|
||||||
let ctx = agent.context.lock().await;
|
let ctx = agent.context.lock().await;
|
||||||
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
|
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
|
||||||
};
|
};
|
||||||
let row = match call_score(&http, client, &tokens, Some(5)).await {
|
let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await {
|
||||||
Ok(without) => {
|
Ok(without) => {
|
||||||
let divs = divergence(&baseline, &without);
|
let divs = divergence(&baseline, &without);
|
||||||
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
||||||
|
|
@ -452,3 +473,302 @@ pub async fn score_finetune(
|
||||||
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Concatenate the text of a Branch's Leaf children — what the model
|
||||||
|
/// actually produced on that turn (Content + Thinking + ToolCall name).
|
||||||
|
fn render_branch_text(children: &[AstNode]) -> String {
|
||||||
|
children.iter()
|
||||||
|
.filter_map(|c| match c {
|
||||||
|
AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Render the last `max_msgs` user/assistant branches before `idx` as a
|
||||||
|
/// review-friendly string with `[user]` / `[assistant]` markers.
|
||||||
|
fn render_prior_context(entries: &[AstNode], idx: usize, max_msgs: usize) -> String {
|
||||||
|
use crate::agent::context::Role;
|
||||||
|
let mut picked: Vec<&AstNode> = Vec::with_capacity(max_msgs);
|
||||||
|
for i in (0..idx).rev() {
|
||||||
|
if picked.len() >= max_msgs { break; }
|
||||||
|
if let AstNode::Branch { role, .. } = &entries[i] {
|
||||||
|
if matches!(role, Role::User | Role::Assistant) {
|
||||||
|
picked.push(&entries[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
picked.reverse();
|
||||||
|
|
||||||
|
let mut out = String::new();
|
||||||
|
for node in picked {
|
||||||
|
if let AstNode::Branch { role, children, .. } = node {
|
||||||
|
let marker = match role {
|
||||||
|
Role::User => "[user]",
|
||||||
|
Role::Assistant => "[assistant]",
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
out.push_str(marker);
|
||||||
|
out.push('\n');
|
||||||
|
out.push_str(render_branch_text(children).trim());
|
||||||
|
out.push_str("\n\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out.trim_end().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Enriched finetune candidate with context for review.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct FinetuneCandidate {
|
||||||
|
pub entry_idx: usize,
|
||||||
|
pub divergence: f64,
|
||||||
|
pub response_text: String,
|
||||||
|
/// Last couple of user/assistant messages before this response,
|
||||||
|
/// already rendered with role markers, for F6 display context.
|
||||||
|
pub prior_context: String,
|
||||||
|
/// Token IDs for context (everything before the response).
|
||||||
|
pub context_ids: Vec<u32>,
|
||||||
|
/// Token IDs for the response (what we're training on).
|
||||||
|
pub continuation_ids: Vec<u32>,
|
||||||
|
/// What the model would have said without memories (if generated).
|
||||||
|
pub alternate_text: Option<String>,
|
||||||
|
/// Timestamp in nanos — used as unique key for trained-set dedup.
|
||||||
|
pub timestamp_ns: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Score and enrich finetune candidates with full context.
|
||||||
|
///
|
||||||
|
/// Candidates are delivered via `on_candidate` one-at-a-time as they become
|
||||||
|
/// ready: scoring happens once (one /score call), then for each candidate
|
||||||
|
/// that passes the threshold we optionally generate an alternate response
|
||||||
|
/// and then emit it. The activity status is updated during the alternate
|
||||||
|
/// phase so the UI doesn't look stuck.
|
||||||
|
///
|
||||||
|
/// Returns (count_above_threshold, max_divergence).
|
||||||
|
pub async fn score_finetune_candidates(
|
||||||
|
context: &ContextState,
|
||||||
|
count: usize,
|
||||||
|
client: &ApiClient,
|
||||||
|
min_divergence: f64,
|
||||||
|
generate_alternates: bool,
|
||||||
|
activity: &crate::agent::ActivityGuard,
|
||||||
|
mut on_candidate: impl FnMut(FinetuneCandidate),
|
||||||
|
) -> anyhow::Result<(usize, f64)> {
|
||||||
|
let scores = score_finetune(context, count, client).await?;
|
||||||
|
|
||||||
|
let max_divergence = scores.iter().map(|(_, d)| *d).fold(0.0f64, f64::max);
|
||||||
|
|
||||||
|
let entries = context.conversation();
|
||||||
|
let trained = load_trained();
|
||||||
|
let mut candidates: Vec<FinetuneCandidate> = Vec::new();
|
||||||
|
|
||||||
|
for (entry_idx, divergence) in scores {
|
||||||
|
if divergence < min_divergence {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let node = &entries[entry_idx];
|
||||||
|
|
||||||
|
// Skip if already trained on.
|
||||||
|
let timestamp_ns = node_timestamp_ns(node);
|
||||||
|
if trained.contains(×tamp_ns) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract response text — content of the assistant turn.
|
||||||
|
let response_text = match node {
|
||||||
|
AstNode::Branch { children, .. } => render_branch_text(children),
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Skip turns that produced nothing human-visible (e.g., a
|
||||||
|
// tool-only turn, or an interrupted generation). They'd show
|
||||||
|
// up as blank cards and we'd still burn alternate-gen on them.
|
||||||
|
if response_text.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the last couple of user/assistant exchanges for review.
|
||||||
|
let prior_context = render_prior_context(entries, entry_idx, 2);
|
||||||
|
|
||||||
|
// Build token IDs: context = everything before response, continuation = response.
|
||||||
|
let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None);
|
||||||
|
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
|
||||||
|
|
||||||
|
candidates.push(FinetuneCandidate {
|
||||||
|
entry_idx,
|
||||||
|
divergence,
|
||||||
|
response_text,
|
||||||
|
prior_context,
|
||||||
|
context_ids,
|
||||||
|
continuation_ids,
|
||||||
|
alternate_text: None,
|
||||||
|
timestamp_ns,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let total = candidates.len();
|
||||||
|
let gen_alternates = generate_alternates && total > 0;
|
||||||
|
|
||||||
|
for (i, mut candidate) in candidates.into_iter().enumerate() {
|
||||||
|
if gen_alternates {
|
||||||
|
activity.update(
|
||||||
|
format!("finetune: generating alternate {}/{}", i + 1, total)
|
||||||
|
).await;
|
||||||
|
match generate_alternate(context, candidate.entry_idx, client).await {
|
||||||
|
Ok(text) => candidate.alternate_text = Some(text),
|
||||||
|
Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
on_candidate(candidate);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((total, max_divergence))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate what the model would say without memories for a given entry.
|
||||||
|
async fn generate_alternate(
|
||||||
|
context: &ContextState,
|
||||||
|
entry_idx: usize,
|
||||||
|
client: &ApiClient,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
use crate::agent::api::{SamplingParams, StreamToken};
|
||||||
|
|
||||||
|
// Build context tokens without memories, up to the response
|
||||||
|
let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
|
||||||
|
|
||||||
|
// Add assistant turn start
|
||||||
|
prompt.push(tokenizer::IM_START);
|
||||||
|
prompt.extend(tokenizer::encode("assistant\n"));
|
||||||
|
|
||||||
|
// Generate completion
|
||||||
|
let sampling = SamplingParams {
|
||||||
|
temperature: 0.6,
|
||||||
|
top_p: 0.95,
|
||||||
|
top_k: 20,
|
||||||
|
};
|
||||||
|
let (mut rx, _guard) = client.stream_completion(&prompt, sampling, Some(-5));
|
||||||
|
|
||||||
|
let mut tokens = Vec::new();
|
||||||
|
while let Some(tok) = rx.recv().await {
|
||||||
|
match tok {
|
||||||
|
StreamToken::Token(id) => tokens.push(id),
|
||||||
|
StreamToken::Done { .. } => break,
|
||||||
|
StreamToken::Error(e) => anyhow::bail!("generation error: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(tokenizer::decode(&tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Finetune config and persistence ─────────────────────────────
|
||||||
|
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
const TRAINED_RESPONSES_FILE: &str = ".consciousness/cache/trained-responses.json";
|
||||||
|
|
||||||
|
fn trained_path() -> PathBuf {
|
||||||
|
dirs::home_dir().unwrap_or_default().join(TRAINED_RESPONSES_FILE)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load set of trained response timestamps (nanos since epoch).
|
||||||
|
pub fn load_trained() -> HashSet<i64> {
|
||||||
|
let path = trained_path();
|
||||||
|
match std::fs::read_to_string(&path) {
|
||||||
|
Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
|
||||||
|
Err(_) => HashSet::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark a response as trained by its timestamp.
|
||||||
|
pub fn mark_trained(timestamp_ns: i64) {
|
||||||
|
let mut trained = load_trained();
|
||||||
|
trained.insert(timestamp_ns);
|
||||||
|
let path = trained_path();
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
let _ = std::fs::create_dir_all(parent);
|
||||||
|
}
|
||||||
|
if let Ok(json) = serde_json::to_string(&trained) {
|
||||||
|
let _ = std::fs::write(&path, json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get timestamp in nanoseconds from an AstNode.
|
||||||
|
/// i64-ns representation covers 1677..2262 via chrono; timestamps
|
||||||
|
/// outside that window would be bugs we'd want to surface, hence panic.
|
||||||
|
pub fn node_timestamp_ns(node: &AstNode) -> i64 {
|
||||||
|
let ts = match node {
|
||||||
|
AstNode::Leaf(leaf) => leaf.timestamp(),
|
||||||
|
AstNode::Branch { timestamp, .. } => *timestamp,
|
||||||
|
};
|
||||||
|
ts.timestamp_nanos_opt()
|
||||||
|
.expect("timestamp outside i64-ns representable range (1677..2262)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Training API ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Training sample for /train endpoint.
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
struct TrainingSample {
|
||||||
|
context_ids: Vec<u32>,
|
||||||
|
continuation_ids: Vec<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Data needed to send a training sample.
|
||||||
|
pub struct TrainData {
|
||||||
|
pub context_ids: Vec<u32>,
|
||||||
|
pub continuation_ids: Vec<u32>,
|
||||||
|
pub timestamp_ns: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send training samples to the server.
|
||||||
|
///
|
||||||
|
/// Returns job_id on success, marks each sample as trained.
|
||||||
|
pub async fn send_to_train(
|
||||||
|
samples: Vec<TrainData>,
|
||||||
|
client: &ApiClient,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
if samples.is_empty() {
|
||||||
|
anyhow::bail!("no samples to train");
|
||||||
|
}
|
||||||
|
|
||||||
|
let api_samples: Vec<TrainingSample> = samples.iter()
|
||||||
|
.map(|s| TrainingSample {
|
||||||
|
context_ids: s.context_ids.clone(),
|
||||||
|
continuation_ids: s.continuation_ids.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"training_data": {
|
||||||
|
"samples": api_samples,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let http = http_client();
|
||||||
|
let url = format!("{}/train", client.base_url());
|
||||||
|
let response = http.send_json("POST", &url, &[], &body).await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let result: serde_json::Value = response.json().await?;
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let msg = result.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
|
||||||
|
anyhow::bail!("train API HTTP {}: {}", status, msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark all samples as trained
|
||||||
|
for s in &samples {
|
||||||
|
mark_trained(s.timestamp_ns);
|
||||||
|
}
|
||||||
|
|
||||||
|
let job_id = result.get("job_id")
|
||||||
|
.and_then(|j| j.as_str())
|
||||||
|
.unwrap_or("unknown")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
dbglog!("[finetune] sent {} samples, job_id={}", samples.len(), job_id);
|
||||||
|
Ok(job_id)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -372,6 +372,10 @@ impl State {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn hours_since_last_dream() -> u64 {
|
pub fn hours_since_last_dream() -> u64 {
|
||||||
|
// If a dream is currently in progress, no nudge needed
|
||||||
|
if home().join(".consciousness/state/dream-state").exists() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
let path = home().join(".consciousness/logs/dream-log.jsonl");
|
let path = home().join(".consciousness/logs/dream-log.jsonl");
|
||||||
let content = match fs::read_to_string(path) {
|
let content = match fs::read_to_string(path) {
|
||||||
Ok(c) if !c.is_empty() => c,
|
Ok(c) if !c.is_empty() => c,
|
||||||
|
|
|
||||||
|
|
@ -112,14 +112,8 @@ pub async fn cmd_switch_model(
|
||||||
let _new_client = crate::agent::api::ApiClient::new(
|
let _new_client = crate::agent::api::ApiClient::new(
|
||||||
&resolved.api_base, &resolved.api_key, &resolved.model_id,
|
&resolved.api_base, &resolved.api_key, &resolved.model_id,
|
||||||
);
|
);
|
||||||
let prompt_changed = resolved.prompt_file != agent.prompt_file;
|
|
||||||
if prompt_changed {
|
|
||||||
agent.compact().await;
|
|
||||||
agent.state.lock().await.notify(format!("switched to {} (recompacted)", resolved.model_id));
|
|
||||||
} else {
|
|
||||||
agent.state.lock().await.notify(format!("switched to {}", resolved.model_id));
|
agent.state.lock().await.notify(format!("switched to {}", resolved.model_id));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fn notify_help(agent: &std::sync::Arc<crate::agent::Agent>) {
|
fn notify_help(agent: &std::sync::Arc<crate::agent::Agent>) {
|
||||||
if let Ok(mut ag) = agent.state.try_lock() {
|
if let Ok(mut ag) = agent.state.try_lock() {
|
||||||
|
|
|
||||||
|
|
@ -126,14 +126,7 @@ impl ScreenView for ConsciousScreen {
|
||||||
let section_style = Style::default().fg(Color::Yellow);
|
let section_style = Style::default().fg(Color::Yellow);
|
||||||
|
|
||||||
lines.push(Line::styled("── Model ──", section_style));
|
lines.push(Line::styled("── Model ──", section_style));
|
||||||
let model_display = app.context_info.as_ref()
|
lines.push(Line::raw(format!(" Current: {}", app.status.model)));
|
||||||
.map_or_else(|| app.status.model.clone(), |i| i.model.clone());
|
|
||||||
lines.push(Line::raw(format!(" Current: {}", model_display)));
|
|
||||||
if let Some(ref info) = app.context_info {
|
|
||||||
lines.push(Line::raw(format!(" Backend: {}", info.backend)));
|
|
||||||
lines.push(Line::raw(format!(" Prompt: {}", info.prompt_file)));
|
|
||||||
lines.push(Line::raw(format!(" Available: {}", info.available_models.join(", "))));
|
|
||||||
}
|
|
||||||
lines.push(Line::raw(""));
|
lines.push(Line::raw(""));
|
||||||
|
|
||||||
lines.push(Line::styled("── Context State ──", section_style));
|
lines.push(Line::styled("── Context State ──", section_style));
|
||||||
|
|
@ -153,8 +146,6 @@ impl ScreenView for ConsciousScreen {
|
||||||
|
|
||||||
lines.push(Line::raw(format!(" {:53} {:>6} tokens", "────────", "──────")));
|
lines.push(Line::raw(format!(" {:53} {:>6} tokens", "────────", "──────")));
|
||||||
lines.push(Line::raw(format!(" {:53} {:>6} tokens", "Total", total)));
|
lines.push(Line::raw(format!(" {:53} {:>6} tokens", "Total", total)));
|
||||||
} else if let Some(ref info) = app.context_info {
|
|
||||||
lines.push(Line::raw(format!(" Context message: {:>6} chars", info.context_message_chars)));
|
|
||||||
}
|
}
|
||||||
lines.push(Line::raw(""));
|
lines.push(Line::raw(""));
|
||||||
|
|
||||||
|
|
|
||||||
341
src/user/learn.rs
Normal file
341
src/user/learn.rs
Normal file
|
|
@ -0,0 +1,341 @@
|
||||||
|
// learn.rs — F6: fine-tuning review screen
|
||||||
|
//
|
||||||
|
// Shows responses identified as training candidates (high divergence
|
||||||
|
// when memories stripped). Queue for review before sending to /finetune.
|
||||||
|
|
||||||
|
use ratatui::{
|
||||||
|
layout::{Constraint, Layout, Rect},
|
||||||
|
style::{Color, Modifier, Style},
|
||||||
|
text::{Line, Span},
|
||||||
|
widgets::{Block, Borders, List, ListItem, ListState, Paragraph, Wrap},
|
||||||
|
Frame,
|
||||||
|
};
|
||||||
|
use ratatui::crossterm::event::{Event, KeyCode, KeyEvent};
|
||||||
|
|
||||||
|
use super::{App, ScreenView, screen_legend};
|
||||||
|
|
||||||
|
/// A candidate response identified for fine-tuning.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct FinetuneCandidate {
|
||||||
|
/// Index in conversation entries.
|
||||||
|
pub entry_idx: usize,
|
||||||
|
/// Divergence score (higher = more dependent on memories).
|
||||||
|
pub divergence: f64,
|
||||||
|
/// The assistant response text.
|
||||||
|
pub response_text: String,
|
||||||
|
/// Prior user/assistant messages for review context.
|
||||||
|
pub prior_context: String,
|
||||||
|
/// Status: pending, approved, rejected, sent.
|
||||||
|
pub status: CandidateStatus,
|
||||||
|
/// Token IDs for context.
|
||||||
|
pub context_ids: Vec<u32>,
|
||||||
|
/// Token IDs for continuation (what we're training on).
|
||||||
|
pub continuation_ids: Vec<u32>,
|
||||||
|
/// What the model would have said without memories (if generated).
|
||||||
|
pub alternate_text: Option<String>,
|
||||||
|
/// Timestamp in nanos — used as unique key for trained-set dedup.
|
||||||
|
pub timestamp_ns: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
|
pub enum CandidateStatus {
|
||||||
|
Pending,
|
||||||
|
Approved,
|
||||||
|
Rejected,
|
||||||
|
Sent,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<crate::subconscious::learn::FinetuneCandidate> for FinetuneCandidate {
|
||||||
|
fn from(c: crate::subconscious::learn::FinetuneCandidate) -> Self {
|
||||||
|
FinetuneCandidate {
|
||||||
|
entry_idx: c.entry_idx,
|
||||||
|
divergence: c.divergence,
|
||||||
|
response_text: c.response_text,
|
||||||
|
prior_context: c.prior_context,
|
||||||
|
status: CandidateStatus::Pending,
|
||||||
|
context_ids: c.context_ids,
|
||||||
|
continuation_ids: c.continuation_ids,
|
||||||
|
alternate_text: c.alternate_text,
|
||||||
|
timestamp_ns: c.timestamp_ns,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct LearnScreen {
|
||||||
|
list_state: ListState,
|
||||||
|
mind_tx: tokio::sync::mpsc::UnboundedSender<crate::mind::MindCommand>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LearnScreen {
|
||||||
|
pub fn new(
|
||||||
|
mind_tx: tokio::sync::mpsc::UnboundedSender<crate::mind::MindCommand>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
list_state: ListState::default(),
|
||||||
|
mind_tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn selected_idx(&self) -> Option<usize> {
|
||||||
|
self.list_state.selected()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScreenView for LearnScreen {
|
||||||
|
fn label(&self) -> &'static str { "learn" }
|
||||||
|
|
||||||
|
fn tick(&mut self, frame: &mut Frame, area: Rect,
|
||||||
|
events: &[Event], app: &mut App) {
|
||||||
|
|
||||||
|
// Handle input first (before borrowing candidates for rendering)
|
||||||
|
let candidate_count = app.finetune_candidates.len();
|
||||||
|
for event in events {
|
||||||
|
if let Event::Key(KeyEvent { code, .. }) = event {
|
||||||
|
match code {
|
||||||
|
KeyCode::Up | KeyCode::Char('k') => {
|
||||||
|
let i = self.list_state.selected().unwrap_or(0);
|
||||||
|
self.list_state.select(Some(i.saturating_sub(1)));
|
||||||
|
}
|
||||||
|
KeyCode::Down | KeyCode::Char('j') => {
|
||||||
|
let i = self.list_state.selected().unwrap_or(0);
|
||||||
|
let max = candidate_count.saturating_sub(1);
|
||||||
|
self.list_state.select(Some((i + 1).min(max)));
|
||||||
|
}
|
||||||
|
KeyCode::Char('a') => {
|
||||||
|
if let Some(idx) = self.selected_idx() {
|
||||||
|
app.finetune_action(idx, CandidateStatus::Approved);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KeyCode::Char('r') => {
|
||||||
|
if let Some(idx) = self.selected_idx() {
|
||||||
|
app.finetune_action(idx, CandidateStatus::Rejected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KeyCode::Char('g') => {
|
||||||
|
let current = crate::config::app().learn.generate_alternates;
|
||||||
|
let _ = self.mind_tx.send(
|
||||||
|
crate::mind::MindCommand::SetLearnGenerateAlternates(!current));
|
||||||
|
}
|
||||||
|
KeyCode::Char('s') => {
|
||||||
|
app.finetune_send_approved();
|
||||||
|
}
|
||||||
|
KeyCode::Char('+') | KeyCode::Char('=') => {
|
||||||
|
// Raise threshold 10× (less sensitive — fewer candidates).
|
||||||
|
let new = crate::config::app().learn.threshold * 10.0;
|
||||||
|
let _ = self.mind_tx.send(
|
||||||
|
crate::mind::MindCommand::SetLearnThreshold(new));
|
||||||
|
}
|
||||||
|
KeyCode::Char('-') => {
|
||||||
|
// Lower threshold 10× (more sensitive — more candidates).
|
||||||
|
let new = crate::config::app().learn.threshold / 10.0;
|
||||||
|
let _ = self.mind_tx.send(
|
||||||
|
crate::mind::MindCommand::SetLearnThreshold(new));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure selection is valid
|
||||||
|
if candidate_count > 0 {
|
||||||
|
let sel = self.list_state.selected().unwrap_or(0).min(candidate_count - 1);
|
||||||
|
self.list_state.select(Some(sel));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now render
|
||||||
|
let (threshold, gen_on) = {
|
||||||
|
let app_cfg = crate::config::app();
|
||||||
|
(app_cfg.learn.threshold, app_cfg.learn.generate_alternates)
|
||||||
|
};
|
||||||
|
let block = Block::default()
|
||||||
|
.title_top(Line::from(screen_legend()).left_aligned())
|
||||||
|
.title_top(Line::from(" learn ").right_aligned())
|
||||||
|
.borders(Borders::ALL)
|
||||||
|
.border_style(Style::default().fg(Color::Magenta));
|
||||||
|
let inner = block.inner(area);
|
||||||
|
frame.render_widget(block, area);
|
||||||
|
|
||||||
|
// Split inner: top line for settings, rest for content.
|
||||||
|
let [settings_area, content_area] = Layout::vertical([
|
||||||
|
Constraint::Length(1),
|
||||||
|
Constraint::Min(0),
|
||||||
|
]).areas(inner);
|
||||||
|
|
||||||
|
let settings = Line::from(vec![
|
||||||
|
Span::raw(" thresh: "),
|
||||||
|
Span::styled(format!("{:e}", threshold), Style::default().fg(Color::Yellow)),
|
||||||
|
Span::raw(" gen: "),
|
||||||
|
Span::styled(
|
||||||
|
if gen_on { "[on]" } else { "[off]" },
|
||||||
|
Style::default().fg(if gen_on { Color::Green } else { Color::DarkGray }),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
frame.render_widget(Paragraph::new(settings), settings_area);
|
||||||
|
|
||||||
|
let candidates = &app.finetune_candidates;
|
||||||
|
|
||||||
|
if candidates.is_empty() {
|
||||||
|
render_empty(frame, content_area, app);
|
||||||
|
} else {
|
||||||
|
// Layout: list on left, detail on right
|
||||||
|
let [list_area, detail_area] = Layout::horizontal([
|
||||||
|
Constraint::Percentage(40),
|
||||||
|
Constraint::Percentage(60),
|
||||||
|
]).areas(content_area);
|
||||||
|
|
||||||
|
// Render candidate list
|
||||||
|
let items: Vec<ListItem> = candidates.iter().map(|c| {
|
||||||
|
let status_char = match c.status {
|
||||||
|
CandidateStatus::Pending => ' ',
|
||||||
|
CandidateStatus::Approved => '+',
|
||||||
|
CandidateStatus::Rejected => '-',
|
||||||
|
CandidateStatus::Sent => '*',
|
||||||
|
};
|
||||||
|
let style = match c.status {
|
||||||
|
CandidateStatus::Pending => Style::default(),
|
||||||
|
CandidateStatus::Approved => Style::default().fg(Color::Green),
|
||||||
|
CandidateStatus::Rejected => Style::default().fg(Color::DarkGray),
|
||||||
|
CandidateStatus::Sent => Style::default().fg(Color::Cyan),
|
||||||
|
};
|
||||||
|
ListItem::new(Line::from(vec![
|
||||||
|
Span::styled(format!("[{}] ", status_char), style),
|
||||||
|
Span::styled(format!("{:.2} ", c.divergence), Style::default().fg(Color::Yellow)),
|
||||||
|
Span::raw(truncate(&c.response_text, 30)),
|
||||||
|
]))
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
let list = List::new(items)
|
||||||
|
.block(Block::default().borders(Borders::RIGHT).title(" candidates "))
|
||||||
|
.highlight_style(Style::default().add_modifier(Modifier::REVERSED));
|
||||||
|
frame.render_stateful_widget(list, list_area, &mut self.list_state);
|
||||||
|
|
||||||
|
// Render detail for selected candidate
|
||||||
|
if let Some(idx) = self.selected_idx() {
|
||||||
|
if let Some(candidate) = candidates.get(idx) {
|
||||||
|
render_detail(frame, candidate, detail_area);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render help at bottom (always, even when empty)
|
||||||
|
let help = Line::from(vec![
|
||||||
|
Span::styled(" j/k/\u{2191}\u{2193}", Style::default().fg(Color::Cyan)),
|
||||||
|
Span::raw("=nav "),
|
||||||
|
Span::styled("a", Style::default().fg(Color::Green)),
|
||||||
|
Span::raw("=approve "),
|
||||||
|
Span::styled("r", Style::default().fg(Color::Red)),
|
||||||
|
Span::raw("=reject "),
|
||||||
|
Span::styled("g", Style::default().fg(Color::Yellow)),
|
||||||
|
Span::raw("=gen "),
|
||||||
|
Span::styled("s", Style::default().fg(Color::Magenta)),
|
||||||
|
Span::raw("=send "),
|
||||||
|
Span::styled("+/-", Style::default().fg(Color::Cyan)),
|
||||||
|
Span::raw("=thresh "),
|
||||||
|
]);
|
||||||
|
let help_area = Rect {
|
||||||
|
y: area.y + area.height - 1,
|
||||||
|
height: 1,
|
||||||
|
..area
|
||||||
|
};
|
||||||
|
frame.render_widget(Paragraph::new(help), help_area);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_empty(frame: &mut Frame, inner: Rect, app: &App) {
|
||||||
|
let mut lines = Vec::new();
|
||||||
|
lines.push(Line::from(""));
|
||||||
|
|
||||||
|
match app.mind_state.as_ref().and_then(|ms| ms.finetune_last_run.as_ref()) {
|
||||||
|
Some(stats) => {
|
||||||
|
lines.push(Line::from(vec![
|
||||||
|
Span::raw(" Last run: "),
|
||||||
|
Span::styled(
|
||||||
|
format!("{}", stats.responses_considered),
|
||||||
|
Style::default().fg(Color::Cyan),
|
||||||
|
),
|
||||||
|
Span::raw(" responses considered, "),
|
||||||
|
Span::styled(
|
||||||
|
format!("{}", stats.above_threshold),
|
||||||
|
Style::default().fg(if stats.above_threshold > 0 { Color::Green } else { Color::DarkGray }),
|
||||||
|
),
|
||||||
|
Span::raw(" above threshold, max divergence: "),
|
||||||
|
Span::styled(
|
||||||
|
format!("{:.4}", stats.max_divergence),
|
||||||
|
Style::default().fg(Color::Yellow),
|
||||||
|
),
|
||||||
|
]));
|
||||||
|
if let Some(err) = &stats.error {
|
||||||
|
lines.push(Line::from(vec![
|
||||||
|
Span::raw(" "),
|
||||||
|
Span::styled(
|
||||||
|
format!("Error: {}", err),
|
||||||
|
Style::default().fg(Color::Red),
|
||||||
|
),
|
||||||
|
]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
lines.push(Line::styled(
|
||||||
|
" No scoring run yet.",
|
||||||
|
Style::default().fg(Color::DarkGray),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.push(Line::from(""));
|
||||||
|
lines.push(Line::styled(
|
||||||
|
" Scoring runs at startup and after each turn.",
|
||||||
|
Style::default().fg(Color::DarkGray),
|
||||||
|
));
|
||||||
|
|
||||||
|
frame.render_widget(Paragraph::new(lines), inner);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_detail(frame: &mut Frame, c: &FinetuneCandidate, area: Rect) {
|
||||||
|
let [header_area, content_area] = Layout::vertical([
|
||||||
|
Constraint::Length(3),
|
||||||
|
Constraint::Min(1),
|
||||||
|
]).areas(area);
|
||||||
|
|
||||||
|
// Header: divergence, status
|
||||||
|
let alt_status = if c.alternate_text.is_some() { "yes" } else { "no" };
|
||||||
|
let header = Paragraph::new(vec![
|
||||||
|
Line::from(vec![
|
||||||
|
Span::raw(" divergence: "),
|
||||||
|
Span::styled(format!("{:.3}", c.divergence), Style::default().fg(Color::Yellow)),
|
||||||
|
Span::raw(format!(" entry: {} alt: {}", c.entry_idx, alt_status)),
|
||||||
|
]),
|
||||||
|
]);
|
||||||
|
frame.render_widget(header, header_area);
|
||||||
|
|
||||||
|
// Content: prior context, the scored response, and alternate
|
||||||
|
// (if available).
|
||||||
|
let content_block = Block::default()
|
||||||
|
.borders(Borders::TOP)
|
||||||
|
.title(" context & response ");
|
||||||
|
|
||||||
|
let mut text = String::new();
|
||||||
|
if !c.prior_context.is_empty() {
|
||||||
|
text.push_str(&c.prior_context);
|
||||||
|
text.push_str("\n\n─── response ───\n\n");
|
||||||
|
}
|
||||||
|
text.push_str(&c.response_text);
|
||||||
|
if let Some(alt) = &c.alternate_text {
|
||||||
|
text.push_str("\n\n─── without memories ───\n\n");
|
||||||
|
text.push_str(alt);
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = Paragraph::new(text)
|
||||||
|
.block(content_block)
|
||||||
|
.wrap(Wrap { trim: false });
|
||||||
|
frame.render_widget(content, content_area);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate(s: &str, max: usize) -> String {
|
||||||
|
let first_line = s.lines().next().unwrap_or("");
|
||||||
|
if first_line.len() > max {
|
||||||
|
format!("{}...", &first_line[..max])
|
||||||
|
} else {
|
||||||
|
first_line.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
116
src/user/mod.rs
116
src/user/mod.rs
|
|
@ -5,11 +5,12 @@
|
||||||
|
|
||||||
pub(crate) mod chat;
|
pub(crate) mod chat;
|
||||||
mod context;
|
mod context;
|
||||||
|
pub(crate) mod learn;
|
||||||
pub(crate) mod scroll_pane;
|
pub(crate) mod scroll_pane;
|
||||||
pub mod selectable;
|
pub mod selectable;
|
||||||
mod subconscious;
|
mod subconscious;
|
||||||
mod unconscious;
|
|
||||||
mod thalamus;
|
mod thalamus;
|
||||||
|
mod unconscious;
|
||||||
mod widgets;
|
mod widgets;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
@ -44,15 +45,6 @@ struct StatusInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Context loading details for the debug screen.
|
/// Context loading details for the debug screen.
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct ContextInfo {
|
|
||||||
model: String,
|
|
||||||
available_models: Vec<String>,
|
|
||||||
prompt_file: String,
|
|
||||||
backend: String,
|
|
||||||
context_message_chars: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build the screen legend from screen labels.
|
/// Build the screen legend from screen labels.
|
||||||
fn screen_legend_from(screens: &[Box<dyn ScreenView>]) -> String {
|
fn screen_legend_from(screens: &[Box<dyn ScreenView>]) -> String {
|
||||||
let parts: Vec<String> = screens.iter().enumerate()
|
let parts: Vec<String> = screens.iter().enumerate()
|
||||||
|
|
@ -109,7 +101,6 @@ struct App {
|
||||||
top_k: u32,
|
top_k: u32,
|
||||||
agent: std::sync::Arc<crate::agent::Agent>,
|
agent: std::sync::Arc<crate::agent::Agent>,
|
||||||
should_quit: bool,
|
should_quit: bool,
|
||||||
context_info: Option<ContextInfo>,
|
|
||||||
agent_state: Vec<crate::mind::SubconsciousSnapshot>,
|
agent_state: Vec<crate::mind::SubconsciousSnapshot>,
|
||||||
unconscious_state: Vec<crate::mind::UnconsciousSnapshot>,
|
unconscious_state: Vec<crate::mind::UnconsciousSnapshot>,
|
||||||
mind_state: Option<crate::mind::MindState>,
|
mind_state: Option<crate::mind::MindState>,
|
||||||
|
|
@ -121,6 +112,8 @@ struct App {
|
||||||
walked_count: usize,
|
walked_count: usize,
|
||||||
channel_status: Vec<ChannelStatus>,
|
channel_status: Vec<ChannelStatus>,
|
||||||
idle_info: Option<IdleInfo>,
|
idle_info: Option<IdleInfo>,
|
||||||
|
/// Fine-tuning candidates pending review.
|
||||||
|
finetune_candidates: Vec<learn::FinetuneCandidate>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl App {
|
impl App {
|
||||||
|
|
@ -142,7 +135,6 @@ impl App {
|
||||||
top_k: 20,
|
top_k: 20,
|
||||||
agent,
|
agent,
|
||||||
should_quit: false,
|
should_quit: false,
|
||||||
context_info: None,
|
|
||||||
agent_state: Vec::new(),
|
agent_state: Vec::new(),
|
||||||
unconscious_state: Vec::new(),
|
unconscious_state: Vec::new(),
|
||||||
mind_state: None,
|
mind_state: None,
|
||||||
|
|
@ -151,9 +143,52 @@ impl App {
|
||||||
rebuild_tools_pending: false,
|
rebuild_tools_pending: false,
|
||||||
walked_count: 0,
|
walked_count: 0,
|
||||||
channel_status: Vec::new(), idle_info: None,
|
channel_status: Vec::new(), idle_info: None,
|
||||||
|
finetune_candidates: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn finetune_action(&mut self, idx: usize, status: learn::CandidateStatus) {
|
||||||
|
if let Some(candidate) = self.finetune_candidates.get_mut(idx) {
|
||||||
|
candidate.status = status;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finetune_send_approved(&mut self) {
|
||||||
|
// Collect approved candidates
|
||||||
|
let samples: Vec<crate::subconscious::learn::TrainData> = self.finetune_candidates.iter()
|
||||||
|
.filter(|c| c.status == learn::CandidateStatus::Approved)
|
||||||
|
.map(|c| crate::subconscious::learn::TrainData {
|
||||||
|
context_ids: c.context_ids.clone(),
|
||||||
|
continuation_ids: c.continuation_ids.clone(),
|
||||||
|
timestamp_ns: c.timestamp_ns,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if samples.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark as sent in UI immediately
|
||||||
|
for candidate in &mut self.finetune_candidates {
|
||||||
|
if candidate.status == learn::CandidateStatus::Approved {
|
||||||
|
candidate.status = learn::CandidateStatus::Sent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spawn async task to send to training server
|
||||||
|
let client = self.agent.client.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
match crate::subconscious::learn::send_to_train(samples, &client).await {
|
||||||
|
Ok(job_id) => {
|
||||||
|
dbglog!("[finetune] training started: {}", job_id);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
dbglog!("[finetune] send failed: {:#}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
fn set_channel_status(&mut self, channels: Vec<(String, bool, u32)>) {
|
fn set_channel_status(&mut self, channels: Vec<(String, bool, u32)>) {
|
||||||
self.channel_status = channels.into_iter()
|
self.channel_status = channels.into_iter()
|
||||||
|
|
@ -193,6 +228,9 @@ fn restore_terminal(terminal: &mut ratatui::Terminal<CrosstermBackend<io::Stdout
|
||||||
async fn start(cli: crate::user::CliArgs) -> Result<()> {
|
async fn start(cli: crate::user::CliArgs) -> Result<()> {
|
||||||
let (config, _figment) = crate::config::load_session(&cli).await?;
|
let (config, _figment) = crate::config::load_session(&cli).await?;
|
||||||
|
|
||||||
|
// Pick up external edits (vim, F6 hotkeys, etc.) without restart.
|
||||||
|
crate::config::watch_config(cli.clone());
|
||||||
|
|
||||||
if config.app.debug {
|
if config.app.debug {
|
||||||
unsafe { std::env::set_var("POC_DEBUG", "1") };
|
unsafe { std::env::set_var("POC_DEBUG", "1") };
|
||||||
}
|
}
|
||||||
|
|
@ -334,7 +372,7 @@ async fn run(
|
||||||
}
|
}
|
||||||
let notify_rx = crate::thalamus::channels::subscribe_all();
|
let notify_rx = crate::thalamus::channels::subscribe_all();
|
||||||
|
|
||||||
// F1=chat, F2=conscious, F3=subconscious, F4=unconscious, F5=thalamus
|
// F1=chat, F2=conscious, F3=subconscious, F4=unconscious, F5=thalamus, F6=learn
|
||||||
let mut screens: Vec<Box<dyn tui::ScreenView>> = vec![
|
let mut screens: Vec<Box<dyn tui::ScreenView>> = vec![
|
||||||
Box::new(crate::user::chat::InteractScreen::new(
|
Box::new(crate::user::chat::InteractScreen::new(
|
||||||
mind.agent.clone(), mind.shared.clone(), mind_tx.clone(),
|
mind.agent.clone(), mind.shared.clone(), mind_tx.clone(),
|
||||||
|
|
@ -343,6 +381,7 @@ async fn run(
|
||||||
Box::new(crate::user::subconscious::SubconsciousScreen::new()),
|
Box::new(crate::user::subconscious::SubconsciousScreen::new()),
|
||||||
Box::new(crate::user::unconscious::UnconsciousScreen::new()),
|
Box::new(crate::user::unconscious::UnconsciousScreen::new()),
|
||||||
Box::new(crate::user::thalamus::ThalamusScreen::new()),
|
Box::new(crate::user::thalamus::ThalamusScreen::new()),
|
||||||
|
Box::new(crate::user::learn::LearnScreen::new(mind_tx.clone())),
|
||||||
];
|
];
|
||||||
let mut active_screen: usize = 1; // F-key number
|
let mut active_screen: usize = 1; // F-key number
|
||||||
tui::set_screen_legend(tui::screen_legend_from(&*screens));
|
tui::set_screen_legend(tui::screen_legend_from(&*screens));
|
||||||
|
|
@ -419,7 +458,8 @@ async fn run(
|
||||||
idle_state.decay_ewma();
|
idle_state.decay_ewma();
|
||||||
app.update_idle(&idle_state);
|
app.update_idle(&idle_state);
|
||||||
app.agent_state = mind.subconscious_snapshots().await;
|
app.agent_state = mind.subconscious_snapshots().await;
|
||||||
if let Ok(mut unc) = mind.unconscious.try_lock() {
|
{
|
||||||
|
let mut unc = mind.unconscious.lock().await;
|
||||||
let toggles: Vec<String> = app.agent_toggles.drain(..).collect();
|
let toggles: Vec<String> = app.agent_toggles.drain(..).collect();
|
||||||
for name in &toggles {
|
for name in &toggles {
|
||||||
if mind.subconscious.lock().await.toggle(name).is_none() {
|
if mind.subconscious.lock().await.toggle(name).is_none() {
|
||||||
|
|
@ -433,7 +473,38 @@ async fn run(
|
||||||
};
|
};
|
||||||
app.unconscious_state = unc.snapshots(store_guard.as_deref());
|
app.unconscious_state = unc.snapshots(store_guard.as_deref());
|
||||||
app.graph_health = unc.graph_health.clone();
|
app.graph_health = unc.graph_health.clone();
|
||||||
app.mind_state = Some(mind.shared.lock().unwrap().clone());
|
}
|
||||||
|
|
||||||
|
// Sync mind state (finetune candidates, last scoring run, etc.)
|
||||||
|
{
|
||||||
|
let ms = mind.shared.lock().unwrap();
|
||||||
|
// Sync finetune candidates: add new ones, keep existing (preserves approval status),
|
||||||
|
// remove sent candidates, keep only 10 most recent rejected.
|
||||||
|
app.finetune_candidates.retain(|c| c.status != learn::CandidateStatus::Sent);
|
||||||
|
for c in &ms.finetune_candidates {
|
||||||
|
let exists = app.finetune_candidates.iter()
|
||||||
|
.any(|existing| existing.timestamp_ns == c.timestamp_ns);
|
||||||
|
if !exists {
|
||||||
|
app.finetune_candidates.push(learn::FinetuneCandidate::from(c.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut rejected: Vec<_> = app.finetune_candidates.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter(|(_, c)| c.status == learn::CandidateStatus::Rejected)
|
||||||
|
.map(|(i, c)| (i, c.timestamp_ns))
|
||||||
|
.collect();
|
||||||
|
if rejected.len() > 10 {
|
||||||
|
rejected.sort_by_key(|(_, ts)| std::cmp::Reverse(*ts));
|
||||||
|
let to_remove: std::collections::HashSet<_> = rejected[10..]
|
||||||
|
.iter().map(|(i, _)| *i).collect();
|
||||||
|
let mut idx = 0;
|
||||||
|
app.finetune_candidates.retain(|_| {
|
||||||
|
let keep = !to_remove.contains(&idx);
|
||||||
|
idx += 1;
|
||||||
|
keep
|
||||||
|
});
|
||||||
|
}
|
||||||
|
app.mind_state = Some(ms.clone());
|
||||||
}
|
}
|
||||||
app.walked_count = mind.subconscious_walked().await.len();
|
app.walked_count = mind.subconscious_walked().await.len();
|
||||||
if !startup_done {
|
if !startup_done {
|
||||||
|
|
@ -530,16 +601,11 @@ async fn run(
|
||||||
// --- CLI ---
|
// --- CLI ---
|
||||||
|
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
use std::path::PathBuf;
|
|
||||||
|
|
||||||
#[derive(Parser, Debug, Default)]
|
#[derive(Parser, Debug, Default, Clone)]
|
||||||
#[command(name = "consciousness", about = "Substrate-independent AI agent")]
|
#[command(name = "consciousness", about = "Substrate-independent AI agent")]
|
||||||
pub struct CliArgs {
|
pub struct CliArgs {
|
||||||
/// Select active backend ("anthropic" or "openrouter")
|
/// Model override (selects a named entry from `models` in config.json5)
|
||||||
#[arg(long)]
|
|
||||||
pub backend: Option<String>,
|
|
||||||
|
|
||||||
/// Model override
|
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
pub model: Option<String>,
|
pub model: Option<String>,
|
||||||
|
|
||||||
|
|
@ -559,10 +625,6 @@ pub struct CliArgs {
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub show_config: bool,
|
pub show_config: bool,
|
||||||
|
|
||||||
/// Project memory directory
|
|
||||||
#[arg(long)]
|
|
||||||
pub memory_project: Option<PathBuf>,
|
|
||||||
|
|
||||||
/// Max consecutive DMN turns
|
/// Max consecutive DMN turns
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub dmn_max_turns: Option<u32>,
|
pub dmn_max_turns: Option<u32>,
|
||||||
|
|
@ -575,7 +637,7 @@ pub struct CliArgs {
|
||||||
pub command: Option<SubCmd>,
|
pub command: Option<SubCmd>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Subcommand, Debug)]
|
#[derive(Subcommand, Debug, Clone)]
|
||||||
pub enum SubCmd {
|
pub enum SubCmd {
|
||||||
/// Print new output since last read and exit
|
/// Print new output since last read and exit
|
||||||
Read {
|
Read {
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Continuous fine-tuning of Qwen3.5-27B alongside live vLLM inference.
|
Continuous fine-tuning of Qwen3.5-27B alongside live vLLM inference.
|
||||||
Full-weight updates (not LoRA) using Apollo optimizer with rank-256
|
Full-weight updates (not LoRA) using Apollo optimizer with rank-64
|
||||||
gradient projection. No pause required — HOGWILD concurrent training.
|
gradient projection. No pause required — HOGWILD concurrent training.
|
||||||
Weights shared via CUDA IPC between vLLM and the training process.
|
Weights shared via CUDA IPC between vLLM and the training process.
|
||||||
|
|
||||||
|
|
@ -22,25 +22,41 @@ The training signal comes from two sources:
|
||||||
│ │
|
│ │
|
||||||
│ ┌──────────────────────────────────────────────┐ │
|
│ ┌──────────────────────────────────────────────┐ │
|
||||||
│ │ Model Weights (54GB, bf16) │ │
|
│ │ Model Weights (54GB, bf16) │ │
|
||||||
│ │ Shared via CUDA IPC │ │
|
│ │ Shared: vLLM inference + HF training │ │
|
||||||
│ └──────────────┬──────────────┬────────────────┘ │
|
│ └──────────────┬──────────────┬────────────────┘ │
|
||||||
│ │ │ │
|
│ │ │ │
|
||||||
│ ┌──────────────▼──┐ ┌───────▼────────────────┐ │
|
│ ┌──────────────▼──┐ ┌───────▼────────────────┐ │
|
||||||
│ │ vLLM (inference)│ │ Apollo (training) │ │
|
│ │ vLLM (inference)│ │ Training subprocess │ │
|
||||||
│ │ KV cache ~60GB │ │ Gradients ~54GB │ │
|
│ │ KV cache ~60GB │ │ HF model wrapper │ │
|
||||||
│ │ Serves requests │ │ Optimizer state ~10GB │ │
|
│ │ /completions │ │ Apollo optimizer ~2.5GB │ │
|
||||||
│ │ Never paused │ │ Activations ~10GB │ │
|
│ │ /score │ │ Checkpoint sync │ │
|
||||||
│ └─────────────────┘ └────────────────────────┘ │
|
│ └────────┬────────┘ └───────────▲─────────────┘ │
|
||||||
|
│ │ │ │
|
||||||
|
│ │ ZMQ IPC │ │
|
||||||
|
│ └───────────────────────┘ │
|
||||||
└─────────────────────────────────────────────────────┘
|
└─────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
Moria B200
|
Process Architecture:
|
||||||
|
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||||
|
│ vLLM Worker │ │ vLLM API Server │ │ Training Worker │
|
||||||
|
│ (GPU inference) │ │ (HTTP routes) │ │ (GPU training) │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│ export_hook.py │ │ /completions │ │ HF model views │
|
||||||
|
│ exports IPC │ │ /score │ │ Apollo optimizer│
|
||||||
|
│ handles on load │ │ /train ─────────┼──► ZMQ REP socket │
|
||||||
|
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||||
|
│ │
|
||||||
|
└──── IPC handles file ──────────────────┘
|
||||||
|
/tmp/vllm_weight_handles.pt
|
||||||
|
|
||||||
|
Moria B200 (vLLM)
|
||||||
┌──────────────────┐ ┌──────────────────┐
|
┌──────────────────┐ ┌──────────────────┐
|
||||||
│ Training signal │ HTTP │ Apollo worker │
|
│ Training signal │ HTTP │ /completions │
|
||||||
│ agent │──────────>│ daemon │
|
│ agent │──────────>│ /score │
|
||||||
│ │ │ │
|
│ │ │ /train │
|
||||||
│ Dream loop │ │ Checkpoint sync │
|
│ Dream loop │ │ /checkpoint │
|
||||||
│ (generates │ │ (mmap + diff, │
|
│ (generates │ │ /train/status │
|
||||||
│ scenarios) │ │ every 10 min) │
|
│ scenarios) │ │ │
|
||||||
└──────────────────┘ └──────────────────┘
|
└──────────────────┘ └──────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -59,10 +75,9 @@ LoRA trains adapter matrices, not base weights. For personality and
|
||||||
behavioral changes that persist as disposition, the base weights
|
behavioral changes that persist as disposition, the base weights
|
||||||
need to change. Apollo makes this memory-feasible.
|
need to change. Apollo makes this memory-feasible.
|
||||||
|
|
||||||
### Rank 256
|
### Rank 64
|
||||||
Not Mini (rank-1). With 100+ diverse training examples, the
|
Not Mini (rank-1). Rank-64 captures gradient structure across diverse
|
||||||
gradient's effective dimensionality can reach hundreds. Rank-256
|
training examples while keeping memory low (~2.5GB on 27B model).
|
||||||
captures the structure. Memory cost: ~10GB (negligible on B200).
|
|
||||||
Compute cost: <0.25% of forward+backward.
|
Compute cost: <0.25% of forward+backward.
|
||||||
|
|
||||||
### Channel-wise scaling
|
### Channel-wise scaling
|
||||||
|
|
@ -90,7 +105,7 @@ from a per-parameter seed each step.
|
||||||
### Parameter grouping (Qwen3.5 gotcha)
|
### Parameter grouping (Qwen3.5 gotcha)
|
||||||
conv1d weights are 3D tensors [10240, 1, 4]. Apollo's projector
|
conv1d weights are 3D tensors [10240, 1, 4]. Apollo's projector
|
||||||
needs 2D matrices with min dimension >= rank. Small/3D tensors
|
needs 2D matrices with min dimension >= rank. Small/3D tensors
|
||||||
use standard Adam. Large 2D matrices use Apollo with rank-256.
|
use standard Adam. Large 2D matrices use Apollo.
|
||||||
|
|
||||||
## Training Data Pipeline
|
## Training Data Pipeline
|
||||||
|
|
||||||
|
|
@ -200,16 +215,42 @@ against live GPU weights block by block, memcpy only changed
|
||||||
regions. For small behavioral updates, turns a 54GB write into
|
regions. For small behavioral updates, turns a 54GB write into
|
||||||
a few hundred MB.
|
a few hundred MB.
|
||||||
|
|
||||||
- Every 10 minutes via cron on B200
|
- Scheduled 10 minutes after training (batched)
|
||||||
- Daily rsync to moria for long-term storage
|
- Daily rsync to moria for long-term storage
|
||||||
- Tool: `apollo-checkpoint sync --model-dir <path>` (Rust)
|
- Tool: `apollo-checkpoint sync --model-dir <path>`
|
||||||
|
|
||||||
|
## State Files
|
||||||
|
|
||||||
|
### B200 (training server)
|
||||||
|
|
||||||
|
| File | Purpose |
|
||||||
|
|------|---------|
|
||||||
|
| `/tmp/vllm_weight_handles.pt` | CUDA IPC handles for weight sharing. Written by export_hook on vLLM startup. Read by training_worker to construct HF model with vLLM weight views. Includes metadata (model_path). |
|
||||||
|
| `/tmp/apollo_optimizer_state.pt` | Apollo optimizer state (momentum, variance estimates). Saved during checkpoint sync and on worker shutdown, restored on next training_worker startup. Preserves training continuity across sessions. |
|
||||||
|
| `/tmp/apollo_training.sock` | ZMQ IPC socket for communication between API server (/train endpoint) and training_worker subprocess. |
|
||||||
|
| `<model_dir>/*.safetensors` | Model weights. Updated in-place by checkpoint_sync. |
|
||||||
|
|
||||||
|
### Moria (client)
|
||||||
|
|
||||||
|
| File | Purpose |
|
||||||
|
|------|---------|
|
||||||
|
| `~/.consciousness/cache/trained-responses.json` | Timestamps (ms) of responses already sent to /train. Prevents re-training the same response. |
|
||||||
|
| `~/.consciousness/cache/finetune-alternates` | Marker file. If exists, alternate responses are generated during divergence scoring to show what model would say without memories. |
|
||||||
|
|
||||||
|
### In-memory (training_worker subprocess)
|
||||||
|
|
||||||
|
| State | Location | Notes |
|
||||||
|
|-------|----------|-------|
|
||||||
|
| Apollo optimizer | TrainingWorker.optimizer | ~2.5GB for rank-64. Persisted to `/tmp/apollo_optimizer_state.pt` during checkpoint sync and on shutdown. |
|
||||||
|
| HF model with vLLM views | TrainingWorker.model | Loaded on worker startup from IPC handles. Parameters point to vLLM's GPU memory. |
|
||||||
|
| ZMQ socket | TrainingWorker.zmq_socket | REP socket bound to `/tmp/apollo_training.sock`. |
|
||||||
|
|
||||||
## Hyperparameters
|
## Hyperparameters
|
||||||
|
|
||||||
| Parameter | Value | Rationale |
|
| Parameter | Value | Rationale |
|
||||||
|-----------|-------|-----------|
|
|-----------|-------|-----------|
|
||||||
| Learning rate | 1e-5 to 1e-4 | Standard for full fine-tuning. Higher for diverse batches. |
|
| Learning rate | 1e-5 to 1e-4 | Standard for full fine-tuning. Higher for diverse batches. |
|
||||||
| Rank | 256 | Captures gradient structure across 100+ examples. ~10GB state. |
|
| Rank | 64 | Captures gradient structure. ~2.5GB state. Defined in `train_router.DEFAULT_RANK`. |
|
||||||
| Scale type | channel | Per-channel precision, matches LLaMA-Factory defaults. |
|
| Scale type | channel | Per-channel precision, matches LLaMA-Factory defaults. |
|
||||||
| Epochs | 1 | One pass over diverse data. Multiple epochs risk overfitting. |
|
| Epochs | 1 | One pass over diverse data. Multiple epochs risk overfitting. |
|
||||||
| Batch size | 1 | Single examples, immediate updates. |
|
| Batch size | 1 | Single examples, immediate updates. |
|
||||||
|
|
@ -220,34 +261,32 @@ a few hundred MB.
|
||||||
## Components
|
## Components
|
||||||
|
|
||||||
### Built ✓
|
### Built ✓
|
||||||
- `apollo_mini.py` — Apollo optimizer (configurable rank, default 256)
|
- `optimizer.py` — Apollo optimizer (configurable rank)
|
||||||
- `apollo_worker.py` — HTTP daemon (aiohttp, job tracking)
|
- `train_router.py` — /train endpoint, forwards to training subprocess via ZMQ
|
||||||
|
- `training_worker.py` — training subprocess (HF model, Apollo, checkpoint sync)
|
||||||
- `weight_mapping.py` — vLLM merged → HF separate views (validated)
|
- `weight_mapping.py` — vLLM merged → HF separate views (validated)
|
||||||
- `training_example.py` — tokenization with chat template
|
- `export_hook.py` — vLLM plugin hook for IPC handle export
|
||||||
- `vllm_export_hook.py` — source patch for IPC handle export
|
- `checkpoint_sync.py` — mmap + diff checkpoint sync (Python)
|
||||||
- `checkpoint/` — Rust tool for mmap + diff checkpoint sync
|
|
||||||
|
|
||||||
### To build
|
### To build
|
||||||
- **Dream loop → training bridge**: connect dream output to Apollo
|
- **Dream loop → training bridge**: connect dream output to /train
|
||||||
- **Training-signal agent**: flags moments in conversation logs
|
- **Training-signal agent**: flags moments in conversation logs
|
||||||
- **Instruction stripping**: remove scaffolding from training examples
|
- **Instruction stripping**: remove scaffolding from training examples
|
||||||
- **Quality monitoring**: track model capability over time
|
- **Quality monitoring**: track model capability over time
|
||||||
- **HF model forward pass integration**: wire into apollo_worker
|
|
||||||
|
|
||||||
## Files
|
## Files
|
||||||
|
|
||||||
```
|
```
|
||||||
training/
|
training/
|
||||||
DESIGN.md — this document
|
DESIGN.md — this document
|
||||||
apollo_mini.py — Apollo optimizer
|
pyproject.toml — package config, vLLM plugin entry point
|
||||||
apollo_worker.py — HTTP training daemon
|
apollo_plugin/
|
||||||
|
__init__.py — plugin registration
|
||||||
|
export_hook.py — patches vLLM worker to export IPC handles
|
||||||
|
train_router.py — /train endpoint, forwards to worker via ZMQ
|
||||||
|
training_worker.py — training subprocess (HF model, Apollo, checkpoint)
|
||||||
|
optimizer.py — Apollo optimizer
|
||||||
weight_mapping.py — vLLM ↔ HF weight views
|
weight_mapping.py — vLLM ↔ HF weight views
|
||||||
training_example.py — tokenization helpers
|
checkpoint_sync.py — mmap + diff sync to safetensors
|
||||||
export_weights.py — standalone weight export (unused)
|
steering.py — steering vector extraction (experimental)
|
||||||
vllm_export_hook.py — vLLM source patch for IPC export
|
|
||||||
start_vllm_with_apollo.sh — vLLM launcher (unused, using source patch)
|
|
||||||
train.py — standalone training script (alternative)
|
|
||||||
checkpoint/
|
|
||||||
Cargo.toml — Rust checkpoint tool
|
|
||||||
src/main.rs — mmap + diff sync
|
|
||||||
```
|
```
|
||||||
|
|
|
||||||
19
training/apollo_plugin/__init__.py
Normal file
19
training/apollo_plugin/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""Apollo training plugin for vLLM.
|
||||||
|
|
||||||
|
Enables continuous fine-tuning alongside live inference by:
|
||||||
|
1. Exporting CUDA IPC handles for weight sharing (export_hook)
|
||||||
|
2. Adding /train endpoint to vLLM's HTTP server (train_router)
|
||||||
|
3. Block-level checkpoint sync to safetensors files
|
||||||
|
|
||||||
|
Install: pip install -e /path/to/training
|
||||||
|
Then vLLM auto-loads via entry point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .export_hook import _patch_model_runner
|
||||||
|
from .train_router import _patch_api_server
|
||||||
|
|
||||||
|
|
||||||
|
def register():
|
||||||
|
"""Called by vLLM's plugin loader on startup."""
|
||||||
|
_patch_model_runner()
|
||||||
|
_patch_api_server()
|
||||||
503
training/apollo_plugin/checkpoint_sync.py
Normal file
503
training/apollo_plugin/checkpoint_sync.py
Normal file
|
|
@ -0,0 +1,503 @@
|
||||||
|
"""Sync live GPU weights to safetensors files on disk.
|
||||||
|
|
||||||
|
Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's
|
||||||
|
merged layout to HuggingFace's separate layout, diffs block-by-block
|
||||||
|
against on-disk safetensors files, and writes only changed blocks.
|
||||||
|
|
||||||
|
For small behavioral training steps, this turns a 54GB checkpoint
|
||||||
|
write into a few hundred MB of actual disk I/O.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Sync live weights to disk
|
||||||
|
python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B
|
||||||
|
|
||||||
|
# Debug name mapping issues
|
||||||
|
python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B
|
||||||
|
|
||||||
|
# From Python:
|
||||||
|
from checkpoint_sync import checkpoint_sync
|
||||||
|
result = checkpoint_sync("/path/to/model")
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import mmap
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size
|
||||||
|
DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# vLLM → HuggingFace weight name/shape conversion
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Qwen3.5-27B dimensions (could be read from config.json for generality)
|
||||||
|
|
||||||
|
HIDDEN = 5120
|
||||||
|
NUM_K_HEADS = 16
|
||||||
|
NUM_V_HEADS = 48
|
||||||
|
HEAD_K_DIM = 128
|
||||||
|
HEAD_V_DIM = 128
|
||||||
|
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048
|
||||||
|
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144
|
||||||
|
INTERMEDIATE = 17408
|
||||||
|
|
||||||
|
# Full attention (some layers use standard attention, not GDN)
|
||||||
|
NUM_ATTN_HEADS = 24
|
||||||
|
NUM_ATTN_KV_HEADS = 4
|
||||||
|
ATTN_HEAD_DIM = 256
|
||||||
|
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512
|
||||||
|
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288
|
||||||
|
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
|
||||||
|
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor]
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Convert vLLM merged weights to HF-compatible separate tensors.
|
||||||
|
|
||||||
|
vLLM merges certain projections for efficiency:
|
||||||
|
- qkv_proj (full attn) → q_proj, k_proj, v_proj
|
||||||
|
- in_proj_qkvz (GDN) → in_proj_qkv, in_proj_z
|
||||||
|
- in_proj_ba (GDN) → in_proj_b, in_proj_a
|
||||||
|
- gate_up_proj (MLP) → gate_proj, up_proj
|
||||||
|
|
||||||
|
Returns views that share GPU memory with the original tensors.
|
||||||
|
"""
|
||||||
|
hf_params = {}
|
||||||
|
|
||||||
|
for name, tensor in vllm_params.items():
|
||||||
|
# Strip vLLM's 'language_model.' prefix to match HF naming
|
||||||
|
hf_name = name.removeprefix('language_model.')
|
||||||
|
|
||||||
|
if 'in_proj_qkvz' in name:
|
||||||
|
# GDN layer: [key*2 + value*2, hidden] → qkv + z
|
||||||
|
prefix = hf_name.replace('in_proj_qkvz.weight', '')
|
||||||
|
split_at = KEY_DIM * 2 + VALUE_DIM
|
||||||
|
hf_params[prefix + 'in_proj_qkv.weight'] = tensor[:split_at]
|
||||||
|
hf_params[prefix + 'in_proj_z.weight'] = tensor[split_at:]
|
||||||
|
|
||||||
|
elif 'in_proj_ba' in name:
|
||||||
|
# GDN layer: [num_v_heads*2, hidden] → b + a
|
||||||
|
prefix = hf_name.replace('in_proj_ba.weight', '')
|
||||||
|
hf_params[prefix + 'in_proj_b.weight'] = tensor[:NUM_V_HEADS]
|
||||||
|
hf_params[prefix + 'in_proj_a.weight'] = tensor[NUM_V_HEADS:]
|
||||||
|
|
||||||
|
elif 'qkv_proj' in name:
|
||||||
|
# Full attention: [q + k + v, hidden] → separate
|
||||||
|
prefix = hf_name.replace('qkv_proj.weight', '')
|
||||||
|
hf_params[prefix + 'q_proj.weight'] = tensor[:ATTN_Q_DIM]
|
||||||
|
hf_params[prefix + 'k_proj.weight'] = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
|
||||||
|
hf_params[prefix + 'v_proj.weight'] = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
|
||||||
|
|
||||||
|
elif 'gate_up_proj' in name:
|
||||||
|
# MLP: [intermediate*2, hidden] → gate + up
|
||||||
|
prefix = hf_name.replace('gate_up_proj.weight', '')
|
||||||
|
hf_params[prefix + 'gate_proj.weight'] = tensor[:INTERMEDIATE]
|
||||||
|
hf_params[prefix + 'up_proj.weight'] = tensor[INTERMEDIATE:]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Pass through unchanged
|
||||||
|
hf_params[hf_name] = tensor
|
||||||
|
|
||||||
|
return hf_params
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Safetensors file handling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def read_safetensors_index(model_dir: Path) -> Dict[str, str]:
|
||||||
|
"""Map tensor names to safetensors filenames.
|
||||||
|
|
||||||
|
For sharded models, reads model.safetensors.index.json.
|
||||||
|
For single-file models, returns empty dict (default to model.safetensors).
|
||||||
|
"""
|
||||||
|
index_path = model_dir / "model.safetensors.index.json"
|
||||||
|
if not index_path.exists():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
with open(index_path) as f:
|
||||||
|
index = json.load(f)
|
||||||
|
|
||||||
|
return dict(index.get("weight_map", {}))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
|
||||||
|
"""Parse safetensors file header.
|
||||||
|
|
||||||
|
Returns (data_start_offset, header_dict).
|
||||||
|
Header dict maps tensor names to metadata including 'data_offsets'.
|
||||||
|
"""
|
||||||
|
header_size = struct.unpack('<Q', data[:8])[0]
|
||||||
|
header = json.loads(bytes(data[8:8 + header_size]))
|
||||||
|
return 8 + header_size, header
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Block-level diffing and sync
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def sync_tensor_to_mmap(
|
||||||
|
mm: mmap.mmap,
|
||||||
|
name: str,
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
data_start: int,
|
||||||
|
offsets: List[int],
|
||||||
|
block_size: int,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""Sync a single tensor to mmap'd file using block-level diffing.
|
||||||
|
|
||||||
|
Returns (bytes_compared, bytes_changed).
|
||||||
|
"""
|
||||||
|
start = data_start + offsets[0]
|
||||||
|
end = data_start + offsets[1]
|
||||||
|
disk_len = end - start
|
||||||
|
|
||||||
|
# Transfer tensor to CPU and get raw bytes
|
||||||
|
# Use .detach() to avoid autograd overhead, .contiguous() for memory layout
|
||||||
|
try:
|
||||||
|
live_bytes = tensor.detach().contiguous().cpu().numpy().tobytes()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to transfer {name} to CPU: {e}")
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
if len(live_bytes) != disk_len:
|
||||||
|
logger.warning(
|
||||||
|
f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} "
|
||||||
|
f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
|
||||||
|
)
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
# Block-level diff: compare and write only changed blocks
|
||||||
|
compared = 0
|
||||||
|
changed = 0
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
while offset < disk_len:
|
||||||
|
block_end = min(offset + block_size, disk_len)
|
||||||
|
block_len = block_end - offset
|
||||||
|
|
||||||
|
disk_block = mm[start + offset:start + block_end]
|
||||||
|
live_block = live_bytes[offset:block_end]
|
||||||
|
|
||||||
|
compared += block_len
|
||||||
|
|
||||||
|
if disk_block != live_block:
|
||||||
|
mm[start + offset:start + block_end] = live_block
|
||||||
|
changed += block_len
|
||||||
|
|
||||||
|
offset = block_end
|
||||||
|
|
||||||
|
return compared, changed
|
||||||
|
|
||||||
|
|
||||||
|
def sync_file(
|
||||||
|
file_path: Path,
|
||||||
|
tensors: Dict[str, torch.Tensor],
|
||||||
|
block_size: int,
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
|
"""Sync tensors to a single safetensors file.
|
||||||
|
|
||||||
|
Returns (bytes_compared, bytes_changed, tensors_found, tensors_missing).
|
||||||
|
"""
|
||||||
|
with open(file_path, 'r+b') as f:
|
||||||
|
mm = mmap.mmap(f.fileno(), 0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data_start, header = parse_safetensors_header(memoryview(mm))
|
||||||
|
|
||||||
|
total_compared = 0
|
||||||
|
total_changed = 0
|
||||||
|
found = 0
|
||||||
|
missing = 0
|
||||||
|
|
||||||
|
for name, tensor in tensors.items():
|
||||||
|
if name == "__metadata__":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if name not in header:
|
||||||
|
missing += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
found += 1
|
||||||
|
meta = header[name]
|
||||||
|
offsets = meta['data_offsets']
|
||||||
|
|
||||||
|
compared, changed = sync_tensor_to_mmap(
|
||||||
|
mm, name, tensor, data_start, offsets, block_size
|
||||||
|
)
|
||||||
|
total_compared += compared
|
||||||
|
total_changed += changed
|
||||||
|
|
||||||
|
# Flush changes to disk
|
||||||
|
if total_changed > 0:
|
||||||
|
mm.flush()
|
||||||
|
|
||||||
|
return total_compared, total_changed, found, missing
|
||||||
|
|
||||||
|
finally:
|
||||||
|
mm.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Load vLLM weight tensors from CUDA IPC handles.
|
||||||
|
|
||||||
|
The handles file is written by vllm_export_hook.py on vLLM startup.
|
||||||
|
Each handle can be used to reconstruct a tensor pointing to vLLM's
|
||||||
|
GPU memory — no copy, direct access.
|
||||||
|
"""
|
||||||
|
handles = torch.load(handles_path, weights_only=False)
|
||||||
|
|
||||||
|
# Skip metadata entry
|
||||||
|
handles.pop('__metadata__', None)
|
||||||
|
|
||||||
|
weights = {}
|
||||||
|
for name, info in handles.items():
|
||||||
|
func, args = info['handle']
|
||||||
|
try:
|
||||||
|
weights[name] = func(*args)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to reconstruct {name}: {e}")
|
||||||
|
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_sync(
|
||||||
|
model_dir: str,
|
||||||
|
handles_path: str = DEFAULT_HANDLES_PATH,
|
||||||
|
block_size: int = DEFAULT_BLOCK_SIZE,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Sync live GPU weights to model safetensors files.
|
||||||
|
|
||||||
|
This is the main entry point. Call this after training steps
|
||||||
|
or periodically to checkpoint weights without full serialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_dir: Directory containing safetensors files
|
||||||
|
handles_path: Path to vLLM weight IPC handles file
|
||||||
|
block_size: Block size for diffing (default 4KB)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with sync statistics:
|
||||||
|
- total_compared: bytes compared
|
||||||
|
- total_changed: bytes actually written
|
||||||
|
- files_changed: list of modified filenames
|
||||||
|
- tensors_synced: number of tensors processed
|
||||||
|
- tensors_missing: tensors not found in safetensors
|
||||||
|
"""
|
||||||
|
model_dir = Path(model_dir)
|
||||||
|
|
||||||
|
if not Path(handles_path).exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Weight handles not found: {handles_path}. "
|
||||||
|
"Is vLLM running with the export hook?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 1: Load live weights from GPU via IPC
|
||||||
|
logger.info("Loading live weights from GPU...")
|
||||||
|
vllm_weights = load_vllm_weights(handles_path)
|
||||||
|
logger.info(f" Loaded {len(vllm_weights)} vLLM tensors")
|
||||||
|
|
||||||
|
# Step 2: Convert to HF naming/layout
|
||||||
|
hf_weights = vllm_to_hf_tensors(vllm_weights)
|
||||||
|
logger.info(f" Converted to {len(hf_weights)} HF tensors")
|
||||||
|
|
||||||
|
# Step 3: Map tensors to safetensors files
|
||||||
|
weight_map = read_safetensors_index(model_dir)
|
||||||
|
|
||||||
|
by_file: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||||
|
unmapped = []
|
||||||
|
|
||||||
|
for name, tensor in hf_weights.items():
|
||||||
|
filename = weight_map.get(name)
|
||||||
|
if filename is None:
|
||||||
|
# Single-file model or missing from index
|
||||||
|
if (model_dir / "model.safetensors").exists():
|
||||||
|
filename = "model.safetensors"
|
||||||
|
else:
|
||||||
|
unmapped.append(name)
|
||||||
|
continue
|
||||||
|
by_file.setdefault(filename, {})[name] = tensor
|
||||||
|
|
||||||
|
if unmapped:
|
||||||
|
logger.warning(f" {len(unmapped)} tensors not in index: {unmapped[:3]}...")
|
||||||
|
|
||||||
|
# Step 4: Sync each file
|
||||||
|
total_compared = 0
|
||||||
|
total_changed = 0
|
||||||
|
total_found = 0
|
||||||
|
total_missing = 0
|
||||||
|
files_changed = []
|
||||||
|
|
||||||
|
for filename in sorted(by_file.keys()):
|
||||||
|
tensors = by_file[filename]
|
||||||
|
file_path = model_dir / filename
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
logger.warning(f" File not found: {filename}")
|
||||||
|
total_missing += len(tensors)
|
||||||
|
continue
|
||||||
|
|
||||||
|
compared, changed, found, missing = sync_file(file_path, tensors, block_size)
|
||||||
|
|
||||||
|
total_compared += compared
|
||||||
|
total_changed += changed
|
||||||
|
total_found += found
|
||||||
|
total_missing += missing
|
||||||
|
|
||||||
|
if changed > 0:
|
||||||
|
files_changed.append(filename)
|
||||||
|
logger.info(f" {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)")
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
if total_changed == 0:
|
||||||
|
logger.info("No changes - model files are up to date")
|
||||||
|
else:
|
||||||
|
pct = (total_changed / total_compared * 100) if total_compared > 0 else 0
|
||||||
|
logger.info(
|
||||||
|
f"Synced: {total_changed / 1e6:.2f} MB changed / "
|
||||||
|
f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if total_missing > 0:
|
||||||
|
logger.warning(f" {total_missing} tensors not found in safetensors files")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_compared": total_compared,
|
||||||
|
"total_changed": total_changed,
|
||||||
|
"files_changed": files_changed,
|
||||||
|
"tensors_synced": total_found,
|
||||||
|
"tensors_missing": total_missing,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Diagnostics
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH):
|
||||||
|
"""Print diagnostic info about weight name mappings.
|
||||||
|
|
||||||
|
Useful for debugging mismatches between vLLM and safetensors names.
|
||||||
|
"""
|
||||||
|
model_dir = Path(model_dir)
|
||||||
|
|
||||||
|
# Load and convert vLLM weights
|
||||||
|
vllm_weights = load_vllm_weights(handles_path)
|
||||||
|
hf_weights = vllm_to_hf_tensors(vllm_weights)
|
||||||
|
hf_names = set(hf_weights.keys())
|
||||||
|
|
||||||
|
# Read safetensors index
|
||||||
|
weight_map = read_safetensors_index(model_dir)
|
||||||
|
disk_names = set(weight_map.keys())
|
||||||
|
|
||||||
|
# If single-file model, parse that file's header
|
||||||
|
if not disk_names:
|
||||||
|
st_path = model_dir / "model.safetensors"
|
||||||
|
if st_path.exists():
|
||||||
|
with open(st_path, 'rb') as f:
|
||||||
|
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||||
|
_, header = parse_safetensors_header(memoryview(mm))
|
||||||
|
disk_names = {k for k in header.keys() if k != "__metadata__"}
|
||||||
|
mm.close()
|
||||||
|
|
||||||
|
print(f"vLLM tensors (raw): {len(vllm_weights)}")
|
||||||
|
print(f"HF tensors (converted): {len(hf_names)}")
|
||||||
|
print(f"Disk tensors: {len(disk_names)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
in_both = hf_names & disk_names
|
||||||
|
only_hf = hf_names - disk_names
|
||||||
|
only_disk = disk_names - hf_names
|
||||||
|
|
||||||
|
print(f"Matched: {len(in_both)}")
|
||||||
|
print(f"Only in HF (won't sync): {len(only_hf)}")
|
||||||
|
print(f"Only on disk (not updated): {len(only_disk)}")
|
||||||
|
|
||||||
|
if only_hf:
|
||||||
|
print(f"\nSample HF-only: {sorted(only_hf)[:5]}")
|
||||||
|
if only_disk:
|
||||||
|
print(f"\nSample disk-only: {sorted(only_disk)[:5]}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Sync live GPU weights to safetensors files"
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(dest="command", help="Command")
|
||||||
|
|
||||||
|
# sync command
|
||||||
|
sync_parser = subparsers.add_parser("sync", help="Sync weights to disk")
|
||||||
|
sync_parser.add_argument(
|
||||||
|
"--model-dir", required=True,
|
||||||
|
help="Directory containing safetensors files"
|
||||||
|
)
|
||||||
|
sync_parser.add_argument(
|
||||||
|
"--handles", default=DEFAULT_HANDLES_PATH,
|
||||||
|
help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
|
||||||
|
)
|
||||||
|
sync_parser.add_argument(
|
||||||
|
"--block-size", type=int, default=DEFAULT_BLOCK_SIZE,
|
||||||
|
help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})"
|
||||||
|
)
|
||||||
|
sync_parser.add_argument(
|
||||||
|
"-v", "--verbose", action="store_true",
|
||||||
|
help="Verbose output"
|
||||||
|
)
|
||||||
|
|
||||||
|
# diagnose command
|
||||||
|
diag_parser = subparsers.add_parser("diagnose", help="Check name mappings")
|
||||||
|
diag_parser.add_argument(
|
||||||
|
"--model-dir", required=True,
|
||||||
|
help="Directory containing safetensors files"
|
||||||
|
)
|
||||||
|
diag_parser.add_argument(
|
||||||
|
"--handles", default=DEFAULT_HANDLES_PATH,
|
||||||
|
help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.command is None:
|
||||||
|
parser.print_help()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO,
|
||||||
|
format='%(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.command == "sync":
|
||||||
|
result = checkpoint_sync(args.model_dir, args.handles, args.block_size)
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
elif args.command == "diagnose":
|
||||||
|
diagnose(args.model_dir, args.handles)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1,17 +1,12 @@
|
||||||
"""Monkey-patch vLLM to export weight IPC handles on startup.
|
"""Monkey-patch vLLM to export weight IPC handles on startup.
|
||||||
|
|
||||||
Usage — add to start_vllm.sh BEFORE the vllm serve command:
|
Usage — install the apollo_plugin package:
|
||||||
|
|
||||||
export VLLM_PLUGINS=vllm_export_hook
|
pip install -e /path/to/training
|
||||||
vllm serve Qwen/Qwen3.5-27B ...
|
|
||||||
|
|
||||||
Or use Python to launch vLLM with the hook:
|
Then vLLM auto-discovers and loads via entry point. Or filter:
|
||||||
|
|
||||||
python3 -c "
|
VLLM_PLUGINS=apollo vllm serve Qwen/Qwen3.5-27B ...
|
||||||
import vllm_export_hook # installs the patch
|
|
||||||
from vllm.entrypoints.openai.api_server import run_server
|
|
||||||
run_server(...)
|
|
||||||
"
|
|
||||||
|
|
||||||
The hook patches vLLM's model runner to export IPC handles after
|
The hook patches vLLM's model runner to export IPC handles after
|
||||||
model loading completes. The handles are saved to a file that the
|
model loading completes. The handles are saved to a file that the
|
||||||
|
|
@ -25,7 +20,7 @@ from pathlib import Path
|
||||||
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
|
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
|
||||||
|
|
||||||
|
|
||||||
def export_model_weights(model):
|
def export_model_weights(model, model_path: str | None = None):
|
||||||
"""Export CUDA IPC handles for all model parameters."""
|
"""Export CUDA IPC handles for all model parameters."""
|
||||||
from torch.multiprocessing.reductions import reduce_tensor
|
from torch.multiprocessing.reductions import reduce_tensor
|
||||||
|
|
||||||
|
|
@ -43,6 +38,12 @@ def export_model_weights(model):
|
||||||
}
|
}
|
||||||
total_bytes += param.nelement() * param.element_size()
|
total_bytes += param.nelement() * param.element_size()
|
||||||
|
|
||||||
|
# Include metadata for training worker
|
||||||
|
handles['__metadata__'] = {
|
||||||
|
'model_path': model_path,
|
||||||
|
'num_params': len(handles),
|
||||||
|
}
|
||||||
|
|
||||||
torch.save(handles, HANDLE_PATH)
|
torch.save(handles, HANDLE_PATH)
|
||||||
print(f"[apollo] Exported {len(handles)} weight handles "
|
print(f"[apollo] Exported {len(handles)} weight handles "
|
||||||
f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}")
|
f"({total_bytes / 1e9:.1f} GB) to {HANDLE_PATH}")
|
||||||
|
|
@ -63,14 +64,11 @@ def _patch_model_runner():
|
||||||
def patched_load(self, *args, **kwargs):
|
def patched_load(self, *args, **kwargs):
|
||||||
result = original_load(self, *args, **kwargs)
|
result = original_load(self, *args, **kwargs)
|
||||||
try:
|
try:
|
||||||
export_model_weights(self.model_runner.model)
|
model_path = self.vllm_config.model_config.model
|
||||||
|
export_model_weights(self.model_runner.model, model_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[apollo] Failed to export weights: {e}")
|
print(f"[apollo] Failed to export weights: {e}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
gpu_worker.Worker.load_model = patched_load
|
gpu_worker.Worker.load_model = patched_load
|
||||||
print("[apollo] Weight export hook installed")
|
print("[apollo] Weight export hook installed")
|
||||||
|
|
||||||
|
|
||||||
# Auto-install when imported
|
|
||||||
_patch_model_runner()
|
|
||||||
|
|
@ -8,9 +8,9 @@ Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
|
||||||
these scaling factors using a low-rank auxiliary optimizer state based on
|
these scaling factors using a low-rank auxiliary optimizer state based on
|
||||||
pure random projection.
|
pure random projection.
|
||||||
|
|
||||||
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
|
Default rank=64. ~2.5GB state for 27B model, <0.25% compute overhead
|
||||||
compute overhead vs forward+backward. Captures gradient structure
|
vs forward+backward. Sufficient for behavioral training with diverse
|
||||||
across 100+ behavioral training examples per batch.
|
examples.
|
||||||
|
|
||||||
Key implementation details from the paper:
|
Key implementation details from the paper:
|
||||||
- Gradient scale factor α = √(n/r) compensates for projection ratio
|
- Gradient scale factor α = √(n/r) compensates for projection ratio
|
||||||
|
|
@ -34,7 +34,7 @@ class Apollo(Optimizer):
|
||||||
Args:
|
Args:
|
||||||
params: model parameters
|
params: model parameters
|
||||||
lr: learning rate (default: 1e-4)
|
lr: learning rate (default: 1e-4)
|
||||||
rank: projection rank (default: 256)
|
rank: projection rank (default: 64)
|
||||||
betas: Adam momentum coefficients (default: (0.9, 0.999))
|
betas: Adam momentum coefficients (default: (0.9, 0.999))
|
||||||
eps: numerical stability term (default: 1e-8)
|
eps: numerical stability term (default: 1e-8)
|
||||||
weight_decay: decoupled weight decay (default: 0.01)
|
weight_decay: decoupled weight decay (default: 0.01)
|
||||||
|
|
@ -46,7 +46,7 @@ class Apollo(Optimizer):
|
||||||
Set to None to disable.
|
Set to None to disable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
|
def __init__(self, params, lr=1e-4, rank=64, betas=(0.9, 0.999),
|
||||||
eps=1e-8, weight_decay=0.01, warmup_steps=0,
|
eps=1e-8, weight_decay=0.01, warmup_steps=0,
|
||||||
scale=None, proj_refresh=200, norm_growth_limit=1.01):
|
scale=None, proj_refresh=200, norm_growth_limit=1.01):
|
||||||
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
|
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
|
||||||
240
training/apollo_plugin/train_router.py
Normal file
240
training/apollo_plugin/train_router.py
Normal file
|
|
@ -0,0 +1,240 @@
|
||||||
|
"""Training endpoint for vLLM - forwards to training subprocess via ZMQ.
|
||||||
|
|
||||||
|
Patches vLLM's build_app() to add /train route. The actual training runs
|
||||||
|
in a dedicated subprocess (training_worker.py) to avoid blocking the
|
||||||
|
event loop and to keep training work isolated from vLLM internals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
|
from fastapi import APIRouter, FastAPI
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
|
||||||
|
|
||||||
|
# Global state for subprocess management
|
||||||
|
_worker_process: subprocess.Popen | None = None
|
||||||
|
_zmq_context: zmq.asyncio.Context | None = None
|
||||||
|
_zmq_socket: zmq.asyncio.Socket | None = None
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class TrainRequest(BaseModel):
|
||||||
|
training_data: dict[str, Any] # {"samples": [...], "config": {...}}
|
||||||
|
|
||||||
|
|
||||||
|
class TrainResponse(BaseModel):
|
||||||
|
job_id: str
|
||||||
|
status: str
|
||||||
|
training_samples: int
|
||||||
|
loss_history: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
def _start_worker_subprocess():
|
||||||
|
"""Start the training worker subprocess."""
|
||||||
|
global _worker_process
|
||||||
|
|
||||||
|
if _worker_process is not None and _worker_process.poll() is None:
|
||||||
|
return # Still running
|
||||||
|
|
||||||
|
# Start worker as subprocess using script path
|
||||||
|
worker_script = Path(__file__).parent / 'training_worker.py'
|
||||||
|
_worker_process = subprocess.Popen(
|
||||||
|
[sys.executable, str(worker_script)],
|
||||||
|
env={**os.environ, 'APOLLO_ZMQ_ADDR': DEFAULT_ZMQ_ADDR},
|
||||||
|
)
|
||||||
|
logger.info(f"Started training worker subprocess (pid={_worker_process.pid})")
|
||||||
|
|
||||||
|
# Give it a moment to bind the socket
|
||||||
|
import time
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_initialized():
|
||||||
|
"""Ensure subprocess is running and ZMQ socket is connected."""
|
||||||
|
global _zmq_context, _zmq_socket, _initialized
|
||||||
|
|
||||||
|
if _initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Start worker if needed
|
||||||
|
_start_worker_subprocess()
|
||||||
|
|
||||||
|
# Create async ZMQ context and socket
|
||||||
|
_zmq_context = zmq.asyncio.Context()
|
||||||
|
_zmq_socket = _zmq_context.socket(zmq.REQ)
|
||||||
|
_zmq_socket.connect(DEFAULT_ZMQ_ADDR)
|
||||||
|
|
||||||
|
# Set timeout for recv
|
||||||
|
_zmq_socket.setsockopt(zmq.RCVTIMEO, 300000) # 5 minute timeout for training
|
||||||
|
|
||||||
|
_initialized = True
|
||||||
|
logger.info(f"Connected to training worker at {DEFAULT_ZMQ_ADDR}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_request(request: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Send request to worker and wait for response."""
|
||||||
|
_ensure_initialized()
|
||||||
|
|
||||||
|
# ZMQ async send/recv
|
||||||
|
await _zmq_socket.send_json(request)
|
||||||
|
response = await _zmq_socket.recv_json()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/train")
|
||||||
|
async def handle_train(request: TrainRequest):
|
||||||
|
"""Handle training request - forwards to training subprocess."""
|
||||||
|
try:
|
||||||
|
_ensure_initialized()
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": f"Training not available: {e}"},
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_data = request.training_data
|
||||||
|
samples = training_data.get("samples", [])
|
||||||
|
config = training_data.get("config", {})
|
||||||
|
|
||||||
|
if not samples:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": "No training samples provided"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
logger.info(f"Starting training job {job_id} with {len(samples)} samples")
|
||||||
|
|
||||||
|
# Forward to worker
|
||||||
|
response = await _send_request({
|
||||||
|
'type': 'train',
|
||||||
|
'samples': samples,
|
||||||
|
'config': config,
|
||||||
|
})
|
||||||
|
|
||||||
|
if 'error' in response:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": response['error']},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Training job {job_id} completed, "
|
||||||
|
f"final loss: {response['loss_history'][-1]:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(content={
|
||||||
|
"job_id": job_id,
|
||||||
|
"status": response['status'],
|
||||||
|
"training_samples": response['training_samples'],
|
||||||
|
"loss_history": response['loss_history'],
|
||||||
|
})
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
logger.error("Training request timed out")
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": "Training request timed out"},
|
||||||
|
status_code=504,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Training failed: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/checkpoint")
|
||||||
|
async def handle_checkpoint():
|
||||||
|
"""Trigger checkpoint sync to disk."""
|
||||||
|
try:
|
||||||
|
_ensure_initialized()
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": f"Training not available: {e}"},
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await _send_request({'type': 'checkpoint'})
|
||||||
|
|
||||||
|
if 'error' in response:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": response['error']},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(content=response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Checkpoint failed: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/train/status")
|
||||||
|
async def handle_status():
|
||||||
|
"""Get training worker status."""
|
||||||
|
try:
|
||||||
|
_ensure_initialized()
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
content={
|
||||||
|
"status": "unavailable",
|
||||||
|
"error": str(e),
|
||||||
|
},
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await _send_request({'type': 'status'})
|
||||||
|
return JSONResponse(content=response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
content={
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def attach_router(app: FastAPI):
|
||||||
|
"""Attach training router to FastAPI app."""
|
||||||
|
app.include_router(router)
|
||||||
|
logger.info("Training router attached")
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_api_server():
|
||||||
|
"""Patch vLLM's build_app to include our training router."""
|
||||||
|
from vllm.entrypoints.openai import api_server
|
||||||
|
|
||||||
|
original_build_app = api_server.build_app
|
||||||
|
|
||||||
|
def patched_build_app(*args, **kwargs):
|
||||||
|
app = original_build_app(*args, **kwargs)
|
||||||
|
attach_router(app)
|
||||||
|
return app
|
||||||
|
|
||||||
|
api_server.build_app = patched_build_app
|
||||||
|
logger.info("API server patched for /train endpoint")
|
||||||
323
training/apollo_plugin/training_worker.py
Normal file
323
training/apollo_plugin/training_worker.py
Normal file
|
|
@ -0,0 +1,323 @@
|
||||||
|
"""Training subprocess - handles Apollo training and checkpoint sync.
|
||||||
|
|
||||||
|
Long-lived process that:
|
||||||
|
1. Loads IPC handles from vLLM's exported weights
|
||||||
|
2. Creates HF model with views into vLLM's GPU memory
|
||||||
|
3. Handles training requests via ZMQ
|
||||||
|
4. Handles checkpoint sync requests
|
||||||
|
5. Persists Apollo optimizer state between calls
|
||||||
|
|
||||||
|
Communicates with the API server's /train endpoint via ZMQ REP socket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Handle running as script vs module
|
||||||
|
if __name__ == '__main__' and __package__ is None:
|
||||||
|
# Running as script - add parent to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
__package__ = 'apollo_plugin'
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from .checkpoint_sync import checkpoint_sync
|
||||||
|
from .optimizer import Apollo
|
||||||
|
from .weight_mapping import load_hf_model_with_vllm_weights
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_RANK = 64
|
||||||
|
DEFAULT_ZMQ_ADDR = "ipc:///tmp/apollo_training.sock"
|
||||||
|
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"
|
||||||
|
OPTIMIZER_STATE_PATH = "/tmp/apollo_optimizer_state.pt"
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingWorker:
|
||||||
|
"""Long-lived training worker process."""
|
||||||
|
|
||||||
|
def __init__(self, zmq_addr: str = DEFAULT_ZMQ_ADDR):
|
||||||
|
self.zmq_addr = zmq_addr
|
||||||
|
self.model: nn.Module | None = None
|
||||||
|
self.optimizer: Apollo | None = None
|
||||||
|
self.model_path: str | None = None
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
def _create_model_wrapper(self) -> nn.Module:
|
||||||
|
"""Create HF model wrapper with views into vLLM's GPU memory."""
|
||||||
|
if not os.path.exists(HANDLE_PATH):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Weight handles not found: {HANDLE_PATH}. "
|
||||||
|
"Is vLLM running with the export hook?"
|
||||||
|
)
|
||||||
|
|
||||||
|
handles = torch.load(HANDLE_PATH, weights_only=False)
|
||||||
|
|
||||||
|
# Extract metadata
|
||||||
|
metadata = handles.pop('__metadata__', {})
|
||||||
|
self.model_path = metadata.get('model_path') or os.environ.get('APOLLO_MODEL_PATH')
|
||||||
|
if not self.model_path:
|
||||||
|
raise ValueError(
|
||||||
|
"Model path not found in handles metadata or APOLLO_MODEL_PATH env var"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reconstruct tensors from IPC handles
|
||||||
|
vllm_params = {}
|
||||||
|
for name, info in handles.items():
|
||||||
|
func, args = info['handle']
|
||||||
|
vllm_params[name] = func(*args)
|
||||||
|
|
||||||
|
model = load_hf_model_with_vllm_weights(vllm_params, self.model_path)
|
||||||
|
model.train()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _get_or_create_optimizer(self, config: dict[str, Any]) -> Apollo:
|
||||||
|
"""Get existing optimizer or create new one."""
|
||||||
|
if self.optimizer is not None:
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
# Build parameter groups (Apollo for 2D+, standard Adam for small/1D)
|
||||||
|
apollo_params, standard_params = [], []
|
||||||
|
for p in self.model.parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
if p.ndim >= 2 and min(p.shape) >= DEFAULT_RANK:
|
||||||
|
apollo_params.append(p)
|
||||||
|
else:
|
||||||
|
standard_params.append(p)
|
||||||
|
|
||||||
|
groups = []
|
||||||
|
if apollo_params:
|
||||||
|
groups.append({'params': apollo_params})
|
||||||
|
if standard_params:
|
||||||
|
groups.append({'params': standard_params})
|
||||||
|
|
||||||
|
if not groups:
|
||||||
|
raise ValueError("No trainable parameters found")
|
||||||
|
|
||||||
|
self.optimizer = Apollo(
|
||||||
|
groups,
|
||||||
|
lr=config.get('lr', 1e-5),
|
||||||
|
rank=config.get('rank', DEFAULT_RANK),
|
||||||
|
betas=tuple(config.get('betas', (0.9, 0.999))),
|
||||||
|
eps=config.get('eps', 1e-8),
|
||||||
|
weight_decay=config.get('weight_decay', 0.01),
|
||||||
|
warmup_steps=config.get('warmup_steps', 0),
|
||||||
|
scale=config.get('scale'),
|
||||||
|
proj_refresh=config.get('proj_refresh', 200),
|
||||||
|
norm_growth_limit=config.get('norm_growth_limit', 1.01),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Restore state if exists
|
||||||
|
if os.path.exists(OPTIMIZER_STATE_PATH):
|
||||||
|
try:
|
||||||
|
state = torch.load(OPTIMIZER_STATE_PATH, weights_only=False)
|
||||||
|
self.optimizer.load_state_dict(state)
|
||||||
|
logger.info(f"Restored optimizer state from {OPTIMIZER_STATE_PATH}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not restore optimizer state: {e}")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Optimizer: {len(apollo_params)} apollo params, "
|
||||||
|
f"{len(standard_params)} standard, "
|
||||||
|
f"state={self.optimizer.state_size_bytes()/1e6:.1f}MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
def _save_optimizer_state(self):
|
||||||
|
"""Save optimizer state for persistence."""
|
||||||
|
if self.optimizer is not None:
|
||||||
|
torch.save(self.optimizer.state_dict(), OPTIMIZER_STATE_PATH)
|
||||||
|
logger.info(f"Saved optimizer state to {OPTIMIZER_STATE_PATH}")
|
||||||
|
|
||||||
|
def _run_training(
|
||||||
|
self,
|
||||||
|
samples: list[dict[str, Any]],
|
||||||
|
config: dict[str, Any],
|
||||||
|
) -> list[float]:
|
||||||
|
"""Run Apollo training on the given samples."""
|
||||||
|
optimizer = self._get_or_create_optimizer(config)
|
||||||
|
|
||||||
|
loss_history = []
|
||||||
|
|
||||||
|
for i, sample in enumerate(samples):
|
||||||
|
ctx_ids = sample['context_ids']
|
||||||
|
cont_ids = sample['continuation_ids']
|
||||||
|
all_ids = ctx_ids + cont_ids
|
||||||
|
context_len = len(ctx_ids)
|
||||||
|
|
||||||
|
input_ids = torch.tensor([all_ids], device='cuda:0')
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Context-frozen forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(input_ids[:, :context_len], use_cache=True)
|
||||||
|
past_kv = outputs.past_key_values
|
||||||
|
|
||||||
|
# Decision tokens with gradients
|
||||||
|
with torch.enable_grad():
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids[:, context_len:],
|
||||||
|
past_key_values=past_kv,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# Shift: predict next token from each position
|
||||||
|
shift_logits = logits[:, :-1].contiguous()
|
||||||
|
shift_labels = input_ids[:, context_len + 1:].contiguous()
|
||||||
|
|
||||||
|
loss = nn.functional.cross_entropy(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
|
shift_labels.view(-1),
|
||||||
|
)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
loss_val = loss.item()
|
||||||
|
loss_history.append(loss_val)
|
||||||
|
logger.info(
|
||||||
|
f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
|
||||||
|
f"(ctx={context_len}, cont={len(cont_ids)} tokens)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss_history
|
||||||
|
|
||||||
|
def _handle_train(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Handle a training request."""
|
||||||
|
samples = request.get('samples', [])
|
||||||
|
config = request.get('config', {})
|
||||||
|
|
||||||
|
if not samples:
|
||||||
|
return {'error': 'No training samples provided'}
|
||||||
|
|
||||||
|
try:
|
||||||
|
loss_history = self._run_training(samples, config)
|
||||||
|
return {
|
||||||
|
'status': 'completed',
|
||||||
|
'training_samples': len(samples),
|
||||||
|
'loss_history': loss_history,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Training failed: {e}")
|
||||||
|
return {'error': str(e)}
|
||||||
|
|
||||||
|
def _handle_checkpoint(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Handle a checkpoint sync request."""
|
||||||
|
if not self.model_path:
|
||||||
|
return {'error': 'Model path not set'}
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._save_optimizer_state()
|
||||||
|
result = checkpoint_sync(self.model_path)
|
||||||
|
return {
|
||||||
|
'status': 'completed',
|
||||||
|
'total_changed': result['total_changed'],
|
||||||
|
'files_changed': result['files_changed'],
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Checkpoint sync failed: {e}")
|
||||||
|
return {'error': str(e)}
|
||||||
|
|
||||||
|
def _handle_status(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Handle a status request."""
|
||||||
|
return {
|
||||||
|
'status': 'ready',
|
||||||
|
'model_loaded': self.model is not None,
|
||||||
|
'optimizer_loaded': self.optimizer is not None,
|
||||||
|
'model_path': self.model_path,
|
||||||
|
'optimizer_state_mb': (
|
||||||
|
self.optimizer.state_size_bytes() / 1e6
|
||||||
|
if self.optimizer else 0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""Main loop - listen for requests and handle them."""
|
||||||
|
# Set up signal handlers
|
||||||
|
def handle_signal(signum, frame):
|
||||||
|
logger.info(f"Received signal {signum}, shutting down...")
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, handle_signal)
|
||||||
|
signal.signal(signal.SIGINT, handle_signal)
|
||||||
|
|
||||||
|
# Set up ZMQ socket first so API server can connect
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(zmq.REP)
|
||||||
|
socket.bind(self.zmq_addr)
|
||||||
|
logger.info(f"Training worker listening on {self.zmq_addr}")
|
||||||
|
|
||||||
|
# Create HF model wrapper with views into vLLM's GPU memory
|
||||||
|
logger.info("Connecting to vLLM weights via IPC handles...")
|
||||||
|
try:
|
||||||
|
self.model = self._create_model_wrapper()
|
||||||
|
logger.info("HF model wrapper ready (views into vLLM GPU memory)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to vLLM weights: {e}")
|
||||||
|
logger.info("Will retry on first training request")
|
||||||
|
|
||||||
|
# Set socket timeout so we can check _running flag
|
||||||
|
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
message = socket.recv_json()
|
||||||
|
except zmq.Again:
|
||||||
|
# Timeout, check _running and continue
|
||||||
|
continue
|
||||||
|
|
||||||
|
request_type = message.get('type', 'train')
|
||||||
|
logger.info(f"Received {request_type} request")
|
||||||
|
|
||||||
|
# Ensure model is loaded
|
||||||
|
if self.model is None and request_type != 'status':
|
||||||
|
try:
|
||||||
|
self.model = self._create_model_wrapper()
|
||||||
|
except Exception as e:
|
||||||
|
socket.send_json({'error': f'Model not loaded: {e}'})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Dispatch request
|
||||||
|
if request_type == 'train':
|
||||||
|
response = self._handle_train(message)
|
||||||
|
elif request_type == 'checkpoint':
|
||||||
|
response = self._handle_checkpoint(message)
|
||||||
|
elif request_type == 'status':
|
||||||
|
response = self._handle_status(message)
|
||||||
|
else:
|
||||||
|
response = {'error': f'Unknown request type: {request_type}'}
|
||||||
|
|
||||||
|
socket.send_json(response)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
logger.info("Saving optimizer state before shutdown...")
|
||||||
|
self._save_optimizer_state()
|
||||||
|
socket.close()
|
||||||
|
context.term()
|
||||||
|
logger.info("Training worker shut down")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Entry point for running as a subprocess."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='[apollo-worker] %(asctime)s %(levelname)s %(message)s',
|
||||||
|
datefmt='%H:%M:%S',
|
||||||
|
)
|
||||||
|
|
||||||
|
zmq_addr = os.environ.get('APOLLO_ZMQ_ADDR', DEFAULT_ZMQ_ADDR)
|
||||||
|
worker = TrainingWorker(zmq_addr)
|
||||||
|
worker.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
@ -1,454 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Apollo Mini Training Daemon
|
|
||||||
|
|
||||||
This daemon:
|
|
||||||
1. Listens over HTTPS for training requests from poc-agent
|
|
||||||
2. Pauses vLLM inference
|
|
||||||
3. Runs APOLLO-Mini training with torch.enable_grad()
|
|
||||||
4. Saves checkpoints and training metadata
|
|
||||||
5. Resumes vLLM inference
|
|
||||||
|
|
||||||
Communication protocol:
|
|
||||||
- POST /train: Start a training job
|
|
||||||
- GET /status/{job_id}: Check training status
|
|
||||||
- GET /checkpoints: List available checkpoints
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field, asdict
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Dict, Any, List
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger('apollo_worker')
|
|
||||||
|
|
||||||
class TrainingStatus(Enum):
|
|
||||||
PENDING = "pending"
|
|
||||||
PAUSING_VLLM = "pausing_vllm"
|
|
||||||
TRAINING = "training"
|
|
||||||
SAVING_CHECKPOINT = "saving_checkpoint"
|
|
||||||
RESUMING_VLLM = "resuming_vllm"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainingJob:
|
|
||||||
job_id: str
|
|
||||||
status: TrainingStatus
|
|
||||||
created_at: datetime
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
model_path: Optional[str] = None
|
|
||||||
checkpoint_path: Optional[str] = None
|
|
||||||
training_samples: int = 0
|
|
||||||
loss_history: List[float] = field(default_factory=list)
|
|
||||||
error: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
'job_id': self.job_id,
|
|
||||||
'status': self.status.value,
|
|
||||||
'created_at': self.created_at.isoformat(),
|
|
||||||
'started_at': self.started_at.isoformat() if self.started_at else None,
|
|
||||||
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
|
|
||||||
'model_path': self.model_path,
|
|
||||||
'checkpoint_path': self.checkpoint_path,
|
|
||||||
'training_samples': self.training_samples,
|
|
||||||
'loss_history': self.loss_history,
|
|
||||||
'error': self.error,
|
|
||||||
}
|
|
||||||
|
|
||||||
class ApolloWorker:
|
|
||||||
def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
|
|
||||||
self.config = self._load_config(config_path)
|
|
||||||
self.jobs: Dict[str, TrainingJob] = {}
|
|
||||||
self.vllm_paused = False
|
|
||||||
self.app = web.Application()
|
|
||||||
self._setup_routes()
|
|
||||||
|
|
||||||
def _load_config(self, config_path: str) -> Dict[str, Any]:
|
|
||||||
"""Load configuration from file or use defaults."""
|
|
||||||
default_config = {
|
|
||||||
'host': '0.0.0.0',
|
|
||||||
'port': 8080,
|
|
||||||
'vllm_socket': '/tmp/vllm_control.sock',
|
|
||||||
'model_path': '/home/ubuntu/models/Qwen3.5-27B',
|
|
||||||
'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
|
|
||||||
'max_training_samples': 100,
|
|
||||||
'learning_rate': 1e-5,
|
|
||||||
'batch_size': 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.path.exists(config_path):
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
user_config = json.load(f)
|
|
||||||
default_config.update(user_config)
|
|
||||||
|
|
||||||
Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
|
|
||||||
return default_config
|
|
||||||
|
|
||||||
def _setup_routes(self):
|
|
||||||
"""Setup HTTP routes."""
|
|
||||||
self.app.router.add_post('/train', self.handle_train_request)
|
|
||||||
self.app.router.add_get('/status/{job_id}', self.handle_status_request)
|
|
||||||
self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
|
|
||||||
self.app.router.add_get('/health', self.handle_health_check)
|
|
||||||
|
|
||||||
async def handle_health_check(self, request: web.Request) -> web.Response:
|
|
||||||
"""Health check endpoint."""
|
|
||||||
return web.json_response({
|
|
||||||
'status': 'healthy',
|
|
||||||
'vllm_paused': self.vllm_paused,
|
|
||||||
'active_jobs': len([j for j in self.jobs.values() if j.status in [TrainingStatus.TRAINING, TrainingStatus.PAUSING_VLLM, TrainingStatus.RESUMING_VLLM]])
|
|
||||||
})
|
|
||||||
|
|
||||||
async def handle_train_request(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle training request from poc-agent."""
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
|
|
||||||
# Validate required fields
|
|
||||||
if 'training_data' not in data:
|
|
||||||
return web.json_response(
|
|
||||||
{'error': 'Missing training_data field'},
|
|
||||||
status=400
|
|
||||||
)
|
|
||||||
|
|
||||||
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
|
|
||||||
job = TrainingJob(
|
|
||||||
job_id=job_id,
|
|
||||||
status=TrainingStatus.PENDING,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
model_path=self.config['model_path']
|
|
||||||
)
|
|
||||||
self.jobs[job_id] = job
|
|
||||||
|
|
||||||
# Start training in background
|
|
||||||
asyncio.create_task(self.execute_training(job, data))
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'job_id': job_id,
|
|
||||||
'status': 'accepted',
|
|
||||||
'message': 'Training job started'
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error handling train request: {e}")
|
|
||||||
return web.json_response(
|
|
||||||
{'error': str(e)},
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_status_request(self, request: web.Request) -> web.Response:
|
|
||||||
"""Get training job status."""
|
|
||||||
job_id = request.match_info['job_id']
|
|
||||||
|
|
||||||
if job_id not in self.jobs:
|
|
||||||
return web.json_response(
|
|
||||||
{'error': 'Job not found'},
|
|
||||||
status=404
|
|
||||||
)
|
|
||||||
|
|
||||||
job = self.jobs[job_id]
|
|
||||||
return web.json_response(job.to_dict())
|
|
||||||
|
|
||||||
async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
|
|
||||||
"""List available checkpoints."""
|
|
||||||
checkpoint_dir = Path(self.config['checkpoint_dir'])
|
|
||||||
checkpoints = []
|
|
||||||
|
|
||||||
if checkpoint_dir.exists():
|
|
||||||
for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda x: x.stat().st_mtime, reverse=True):
|
|
||||||
checkpoints.append({
|
|
||||||
'filename': checkpoint_file.name,
|
|
||||||
'path': str(checkpoint_file),
|
|
||||||
'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
|
|
||||||
'size': checkpoint_file.stat().st_size
|
|
||||||
})
|
|
||||||
|
|
||||||
return web.json_response({'checkpoints': checkpoints})
|
|
||||||
|
|
||||||
async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
|
|
||||||
"""Execute the training pipeline."""
|
|
||||||
try:
|
|
||||||
logger.info(f"Starting training job {job.job_id}")
|
|
||||||
job.started_at = datetime.now()
|
|
||||||
|
|
||||||
# Step 1: Pause vLLM
|
|
||||||
job.status = TrainingStatus.PAUSING_VLLM
|
|
||||||
logger.info("Pausing vLLM...")
|
|
||||||
await self.pause_vllm()
|
|
||||||
self.vllm_paused = True
|
|
||||||
|
|
||||||
# Step 2: Load model and prepare for training
|
|
||||||
job.status = TrainingStatus.TRAINING
|
|
||||||
logger.info("Loading model and preparing for training...")
|
|
||||||
|
|
||||||
# Load model (this would be the actual Qwen3.5-27B model)
|
|
||||||
# For now, we'll use a placeholder
|
|
||||||
model = await self.load_model_for_training()
|
|
||||||
|
|
||||||
# Step 3: Run APOLLO-Mini training
|
|
||||||
logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples")
|
|
||||||
|
|
||||||
# Extract training samples
|
|
||||||
samples = training_data['samples']
|
|
||||||
job.training_samples = len(samples)
|
|
||||||
|
|
||||||
# Run training loop
|
|
||||||
loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {}))
|
|
||||||
job.loss_history = loss_history
|
|
||||||
|
|
||||||
# Step 4: Save checkpoint
|
|
||||||
job.status = TrainingStatus.SAVING_CHECKPOINT
|
|
||||||
logger.info("Saving checkpoint...")
|
|
||||||
checkpoint_path = await self.save_checkpoint(model, job)
|
|
||||||
job.checkpoint_path = checkpoint_path
|
|
||||||
|
|
||||||
# Step 5: Resume vLLM
|
|
||||||
job.status = TrainingStatus.RESUMING_VLLM
|
|
||||||
logger.info("Resuming vLLM...")
|
|
||||||
await self.resume_vllm()
|
|
||||||
self.vllm_paused = False
|
|
||||||
|
|
||||||
# Mark job as completed
|
|
||||||
job.status = TrainingStatus.COMPLETED
|
|
||||||
job.completed_at = datetime.now()
|
|
||||||
|
|
||||||
logger.info(f"Training job {job.job_id} completed successfully")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Training job {job.job_id} failed: {e}")
|
|
||||||
job.status = TrainingStatus.FAILED
|
|
||||||
job.error = str(e)
|
|
||||||
job.completed_at = datetime.now()
|
|
||||||
|
|
||||||
# Try to resume vLLM if it was paused
|
|
||||||
if self.vllm_paused:
|
|
||||||
try:
|
|
||||||
await self.resume_vllm()
|
|
||||||
self.vllm_paused = False
|
|
||||||
except Exception as resume_error:
|
|
||||||
logger.error(f"Failed to resume vLLM after training error: {resume_error}")
|
|
||||||
|
|
||||||
async def pause_vllm(self):
|
|
||||||
"""Pause vLLM inference via HTTP API."""
|
|
||||||
import aiohttp as aio
|
|
||||||
url = self.config.get('vllm_url', 'http://localhost:8000')
|
|
||||||
try:
|
|
||||||
async with aio.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{url}/pause_generation",
|
|
||||||
json={"mode": "keep", "clear_cache": False},
|
|
||||||
timeout=aio.ClientTimeout(total=10),
|
|
||||||
) as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
logger.info("vLLM paused")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to pause vLLM: {e}")
|
|
||||||
|
|
||||||
async def resume_vllm(self):
|
|
||||||
"""Resume vLLM inference via HTTP API."""
|
|
||||||
import aiohttp as aio
|
|
||||||
url = self.config.get('vllm_url', 'http://localhost:8000')
|
|
||||||
try:
|
|
||||||
async with aio.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{url}/resume_generation",
|
|
||||||
timeout=aio.ClientTimeout(total=10),
|
|
||||||
) as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
logger.info("vLLM resumed")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to resume vLLM: {e}")
|
|
||||||
|
|
||||||
async def load_model_for_training(self) -> nn.Module:
|
|
||||||
"""Load HF model with weights pointing to vLLM's GPU memory.
|
|
||||||
|
|
||||||
Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
|
|
||||||
views (narrowing merged weights into separate q/k/v/z etc.), and
|
|
||||||
constructs the HF model around those views. No weight copying —
|
|
||||||
all parameters share vLLM's GPU memory.
|
|
||||||
"""
|
|
||||||
handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
|
|
||||||
model_path = self.config['model_path']
|
|
||||||
|
|
||||||
# Import vLLM weights via CUDA IPC
|
|
||||||
logger.info(f"Importing vLLM weights from {handle_path}")
|
|
||||||
handles = torch.load(handle_path, weights_only=False)
|
|
||||||
vllm_params = {}
|
|
||||||
for name, info in handles.items():
|
|
||||||
func, args = info['handle']
|
|
||||||
vllm_params[name] = func(*args)
|
|
||||||
logger.info(f"Imported {len(vllm_params)} parameters")
|
|
||||||
|
|
||||||
# Map vLLM merged layout → HF separate layout (views, no copies)
|
|
||||||
from weight_mapping import load_hf_model_with_vllm_weights
|
|
||||||
model = load_hf_model_with_vllm_weights(vllm_params, model_path)
|
|
||||||
logger.info("HF model constructed with vLLM weight views")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
async def run_apollo_training(self, model: nn.Module,
|
|
||||||
samples: List[Dict[str, str]],
|
|
||||||
config: Dict[str, Any]) -> List[float]:
|
|
||||||
"""Run Apollo-Mini training on conversation decision points."""
|
|
||||||
from apollo_mini import Apollo
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
lr = config.get('learning_rate', self.config['learning_rate'])
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.config['model_path'], trust_remote_code=True)
|
|
||||||
|
|
||||||
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
|
||||||
apollo_params, standard_params = [], []
|
|
||||||
for p in model.parameters():
|
|
||||||
if p.requires_grad:
|
|
||||||
if p.ndim >= 2 and min(p.shape) >= 2:
|
|
||||||
apollo_params.append(p)
|
|
||||||
else:
|
|
||||||
standard_params.append(p)
|
|
||||||
|
|
||||||
groups = []
|
|
||||||
if apollo_params:
|
|
||||||
groups.append({'params': apollo_params})
|
|
||||||
if standard_params:
|
|
||||||
groups.append({'params': standard_params})
|
|
||||||
|
|
||||||
rank = config.get('apollo_rank', 1)
|
|
||||||
optimizer = Apollo(groups, lr=lr, rank=rank)
|
|
||||||
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
|
|
||||||
f"{len(standard_params)} standard, "
|
|
||||||
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
|
||||||
|
|
||||||
loss_history = []
|
|
||||||
|
|
||||||
for i, sample in enumerate(samples):
|
|
||||||
context = sample.get('context', '')
|
|
||||||
continuation = sample.get('continuation', '')
|
|
||||||
|
|
||||||
# Tokenize
|
|
||||||
ctx_ids = tokenizer.encode(context, add_special_tokens=True)
|
|
||||||
cont_ids = tokenizer.encode(continuation, add_special_tokens=False)
|
|
||||||
all_ids = ctx_ids + cont_ids
|
|
||||||
context_len = len(ctx_ids)
|
|
||||||
|
|
||||||
input_ids = torch.tensor([all_ids], device='cuda:0')
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Context-frozen forward pass
|
|
||||||
with torch.no_grad():
|
|
||||||
# Forward through context (no gradients)
|
|
||||||
outputs = model(input_ids[:, :context_len], use_cache=True)
|
|
||||||
past_kv = outputs.past_key_values
|
|
||||||
|
|
||||||
# Decision tokens with gradients
|
|
||||||
with torch.enable_grad():
|
|
||||||
outputs = model(
|
|
||||||
input_ids[:, context_len:],
|
|
||||||
past_key_values=past_kv,
|
|
||||||
use_cache=False,
|
|
||||||
)
|
|
||||||
logits = outputs.logits # [1, cont_len, vocab]
|
|
||||||
|
|
||||||
# Shift: predict next token from each position
|
|
||||||
shift_logits = logits[:, :-1].contiguous()
|
|
||||||
shift_labels = input_ids[:, context_len + 1:].contiguous()
|
|
||||||
|
|
||||||
loss = nn.functional.cross_entropy(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)),
|
|
||||||
shift_labels.view(-1),
|
|
||||||
)
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
loss_val = loss.item()
|
|
||||||
loss_history.append(loss_val)
|
|
||||||
logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
|
|
||||||
f"(ctx={context_len}, cont={len(cont_ids)} tokens)")
|
|
||||||
|
|
||||||
logger.info(f"Training done: {len(samples)} examples, "
|
|
||||||
f"final loss={loss_history[-1]:.4f}")
|
|
||||||
return loss_history
|
|
||||||
|
|
||||||
async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
|
|
||||||
"""Save model checkpoint in HuggingFace safetensors format."""
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
checkpoint_dir = Path(self.config['checkpoint_dir'])
|
|
||||||
date_str = datetime.now().strftime('%Y-%m-%d')
|
|
||||||
out_dir = checkpoint_dir / date_str
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Save weights
|
|
||||||
tensors = {name: p.data.contiguous().cpu()
|
|
||||||
for name, p in model.named_parameters()}
|
|
||||||
save_path = out_dir / "model.safetensors"
|
|
||||||
save_file(tensors, str(save_path))
|
|
||||||
|
|
||||||
# Copy config files
|
|
||||||
config_dir = Path(self.config['model_path'])
|
|
||||||
for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
|
|
||||||
'special_tokens_map.json']:
|
|
||||||
src = config_dir / f
|
|
||||||
if src.exists():
|
|
||||||
shutil.copy2(src, out_dir / f)
|
|
||||||
|
|
||||||
# Save training metadata
|
|
||||||
meta = {
|
|
||||||
'job_id': job.job_id,
|
|
||||||
'training_samples': job.training_samples,
|
|
||||||
'loss_history': job.loss_history,
|
|
||||||
'timestamp': datetime.now().isoformat(),
|
|
||||||
}
|
|
||||||
with open(out_dir / 'training-meta.json', 'w') as f:
|
|
||||||
json.dump(meta, f, indent=2)
|
|
||||||
|
|
||||||
# Update latest symlink
|
|
||||||
latest = checkpoint_dir / 'latest'
|
|
||||||
if latest.is_symlink():
|
|
||||||
latest.unlink()
|
|
||||||
latest.symlink_to(date_str)
|
|
||||||
|
|
||||||
size_gb = save_path.stat().st_size / 1e9
|
|
||||||
logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
|
|
||||||
return str(out_dir)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""Run the daemon."""
|
|
||||||
logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
|
|
||||||
runner = web.AppRunner(self.app)
|
|
||||||
await runner.setup()
|
|
||||||
site = web.TCPSite(runner, self.config['host'], self.config['port'])
|
|
||||||
await site.start()
|
|
||||||
logger.info("Apollo Worker is running")
|
|
||||||
|
|
||||||
# Keep running
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(3600) # Sleep for an hour
|
|
||||||
|
|
||||||
def main():
|
|
||||||
worker = ApolloWorker()
|
|
||||||
asyncio.run(worker.run())
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
@ -1,12 +0,0 @@
|
||||||
[package]
|
|
||||||
name = "apollo-checkpoint"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2024"
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
memmap2 = "0.9"
|
|
||||||
safetensors = "0.5"
|
|
||||||
serde = { version = "1", features = ["derive"] }
|
|
||||||
serde_json = "1"
|
|
||||||
anyhow = "1"
|
|
||||||
clap = { version = "4", features = ["derive"] }
|
|
||||||
|
|
@ -1,265 +0,0 @@
|
||||||
// apollo-checkpoint — Sync live GPU weights back to model files on disk.
|
|
||||||
//
|
|
||||||
// mmaps the model's safetensors files, reads live weights from GPU via
|
|
||||||
// Python helper (CUDA IPC handles), compares block by block, and memcpys
|
|
||||||
// only changed regions back into the mmap. For small behavioral training
|
|
||||||
// steps, this turns a 54GB write into a few hundred MB.
|
|
||||||
//
|
|
||||||
// The model files on disk are the checkpoint. No separate checkpoint
|
|
||||||
// directory — just keep the model up to date.
|
|
||||||
//
|
|
||||||
// Usage:
|
|
||||||
// apollo-checkpoint sync \
|
|
||||||
// --handles /tmp/vllm_weight_handles.pt \
|
|
||||||
// --model-dir /path/to/Qwen3.5-27B
|
|
||||||
//
|
|
||||||
// Runs every 10 minutes via cron. Daily rsync to moria.
|
|
||||||
|
|
||||||
use anyhow::{Context, Result, bail};
|
|
||||||
use clap::{Parser, Subcommand};
|
|
||||||
use memmap2::MmapMut;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::fs;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use std::process::Command;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")]
|
|
||||||
struct Cli {
|
|
||||||
#[command(subcommand)]
|
|
||||||
command: Cmd,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
|
||||||
enum Cmd {
|
|
||||||
/// Sync live GPU weights back to model safetensors files
|
|
||||||
Sync {
|
|
||||||
/// Path to vLLM weight IPC handles
|
|
||||||
#[arg(long, default_value = "/tmp/vllm_weight_handles.pt")]
|
|
||||||
handles: PathBuf,
|
|
||||||
|
|
||||||
/// Model directory containing safetensors files
|
|
||||||
#[arg(long)]
|
|
||||||
model_dir: PathBuf,
|
|
||||||
|
|
||||||
/// Block size for diffing (bytes)
|
|
||||||
#[arg(long, default_value_t = 4096)]
|
|
||||||
block_size: usize,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Dump live GPU weights to a flat binary file, ordered by safetensors
|
|
||||||
/// file and offset to match the on-disk layout.
|
|
||||||
///
|
|
||||||
/// Returns a map of (safetensors filename, tensor name) → raw bytes.
|
|
||||||
fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result<HashMap<String, Vec<u8>>> {
|
|
||||||
let dump_path = output_dir.join(".live_dump.bin");
|
|
||||||
let index_path = output_dir.join(".live_dump.json");
|
|
||||||
|
|
||||||
let status = Command::new("python3")
|
|
||||||
.arg("-c")
|
|
||||||
.arg(format!(r#"
|
|
||||||
import torch, json
|
|
||||||
|
|
||||||
handles = torch.load("{handles}", weights_only=False)
|
|
||||||
index = {{}}
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
with open("{dump}", "wb") as f:
|
|
||||||
for name in sorted(handles.keys()):
|
|
||||||
info = handles[name]
|
|
||||||
func, args = info["handle"]
|
|
||||||
tensor = func(*args)
|
|
||||||
data = tensor.contiguous().cpu().numpy().tobytes()
|
|
||||||
f.write(data)
|
|
||||||
index[name] = {{"offset": offset, "size": len(data)}}
|
|
||||||
offset += len(data)
|
|
||||||
|
|
||||||
with open("{index}", "w") as f:
|
|
||||||
json.dump(index, f)
|
|
||||||
|
|
||||||
print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
|
|
||||||
"#,
|
|
||||||
handles = handles_path.display(),
|
|
||||||
dump = dump_path.display(),
|
|
||||||
index = index_path.display(),
|
|
||||||
))
|
|
||||||
.status()
|
|
||||||
.context("Failed to run Python weight dump")?;
|
|
||||||
|
|
||||||
if !status.success() {
|
|
||||||
bail!("Python weight dump failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
let index_str = fs::read_to_string(&index_path)?;
|
|
||||||
let index: HashMap<String, DumpEntry> = serde_json::from_str(&index_str)?;
|
|
||||||
let dump_data = fs::read(&dump_path)?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
for (name, entry) in &index {
|
|
||||||
result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up temp files
|
|
||||||
let _ = fs::remove_file(&dump_path);
|
|
||||||
let _ = fs::remove_file(&index_path);
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Deserialize)]
|
|
||||||
struct DumpEntry {
|
|
||||||
offset: usize,
|
|
||||||
size: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Read the safetensors index to map parameter names to files.
|
|
||||||
fn read_safetensors_index(model_dir: &Path) -> Result<HashMap<String, String>> {
|
|
||||||
let index_path = model_dir.join("model.safetensors.index.json");
|
|
||||||
if !index_path.exists() {
|
|
||||||
// Single file model
|
|
||||||
return Ok(HashMap::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
let index_str = fs::read_to_string(&index_path)?;
|
|
||||||
let index: serde_json::Value = serde_json::from_str(&index_str)?;
|
|
||||||
let weight_map = index["weight_map"]
|
|
||||||
.as_object()
|
|
||||||
.context("No weight_map in index")?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
for (name, file) in weight_map {
|
|
||||||
result.insert(name.clone(), file.as_str().unwrap().to_string());
|
|
||||||
}
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sync changed blocks from live weights into a mmap'd safetensors file.
|
|
||||||
/// Returns (total_bytes_compared, bytes_changed).
|
|
||||||
fn sync_tensors_to_file(
|
|
||||||
file_path: &Path,
|
|
||||||
tensors: &[(String, Vec<u8>)],
|
|
||||||
block_size: usize,
|
|
||||||
) -> Result<(usize, usize)> {
|
|
||||||
use safetensors::SafeTensors;
|
|
||||||
|
|
||||||
let file = fs::OpenOptions::new()
|
|
||||||
.read(true)
|
|
||||||
.write(true)
|
|
||||||
.open(file_path)
|
|
||||||
.with_context(|| format!("Failed to open {}", file_path.display()))?;
|
|
||||||
|
|
||||||
let mut mmap = unsafe { MmapMut::map_mut(&file)? };
|
|
||||||
|
|
||||||
// Parse safetensors header to find tensor offsets
|
|
||||||
let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize;
|
|
||||||
let header_json: serde_json::Value =
|
|
||||||
serde_json::from_slice(&mmap[8..8 + header_size])?;
|
|
||||||
let data_start = 8 + header_size;
|
|
||||||
|
|
||||||
let mut total_compared = 0usize;
|
|
||||||
let mut total_changed = 0usize;
|
|
||||||
|
|
||||||
for (name, live_data) in tensors {
|
|
||||||
let meta = match header_json.get(name) {
|
|
||||||
Some(m) => m,
|
|
||||||
None => {
|
|
||||||
eprintln!(" Warning: {} not found in {}", name, file_path.display());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let offsets = meta["data_offsets"].as_array().unwrap();
|
|
||||||
let start = data_start + offsets[0].as_u64().unwrap() as usize;
|
|
||||||
let end = data_start + offsets[1].as_u64().unwrap() as usize;
|
|
||||||
let disk_data = &mmap[start..end];
|
|
||||||
|
|
||||||
if disk_data.len() != live_data.len() {
|
|
||||||
eprintln!(" Warning: size mismatch for {}: disk={} live={}",
|
|
||||||
name, disk_data.len(), live_data.len());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Diff block by block, memcpy only changed blocks
|
|
||||||
let mut offset = 0;
|
|
||||||
while offset < disk_data.len() {
|
|
||||||
let block_end = (offset + block_size).min(disk_data.len());
|
|
||||||
total_compared += block_end - offset;
|
|
||||||
|
|
||||||
if disk_data[offset..block_end] != live_data[offset..block_end] {
|
|
||||||
mmap[start + offset..start + block_end]
|
|
||||||
.copy_from_slice(&live_data[offset..block_end]);
|
|
||||||
total_changed += block_end - offset;
|
|
||||||
}
|
|
||||||
offset = block_end;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mmap.flush()?;
|
|
||||||
Ok((total_compared, total_changed))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> {
|
|
||||||
if !handles.exists() {
|
|
||||||
bail!("Weight handles not found: {}. Is vLLM running with the export hook?",
|
|
||||||
handles.display());
|
|
||||||
}
|
|
||||||
|
|
||||||
eprintln!("Dumping live weights from GPU...");
|
|
||||||
let live_weights = dump_live_weights(&handles, &model_dir)?;
|
|
||||||
eprintln!(" {} tensors dumped", live_weights.len());
|
|
||||||
|
|
||||||
// Map parameter names to safetensors files
|
|
||||||
let weight_map = read_safetensors_index(&model_dir)?;
|
|
||||||
|
|
||||||
// Group tensors by safetensors file
|
|
||||||
let mut by_file: HashMap<String, Vec<(String, Vec<u8>)>> = HashMap::new();
|
|
||||||
for (name, data) in live_weights {
|
|
||||||
let file = weight_map
|
|
||||||
.get(&name)
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_else(|| "model.safetensors".to_string());
|
|
||||||
by_file.entry(file).or_default().push((name, data));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut total_compared = 0usize;
|
|
||||||
let mut total_changed = 0usize;
|
|
||||||
|
|
||||||
for (filename, tensors) in &by_file {
|
|
||||||
let file_path = model_dir.join(filename);
|
|
||||||
if !file_path.exists() {
|
|
||||||
eprintln!(" Warning: {} not found, skipping", filename);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?;
|
|
||||||
total_compared += compared;
|
|
||||||
total_changed += changed;
|
|
||||||
|
|
||||||
if changed > 0 {
|
|
||||||
eprintln!(" {}: {:.1} MB changed", filename, changed as f64 / 1e6);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if total_changed == 0 {
|
|
||||||
eprintln!("No changes — model files are up to date");
|
|
||||||
} else {
|
|
||||||
eprintln!(
|
|
||||||
"Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)",
|
|
||||||
total_changed as f64 / 1e6,
|
|
||||||
total_compared as f64 / 1e9,
|
|
||||||
total_changed as f64 / total_compared as f64 * 100.0,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
let cli = Cli::parse();
|
|
||||||
match cli.command {
|
|
||||||
Cmd::Sync { handles, model_dir, block_size } => {
|
|
||||||
cmd_sync(handles, model_dir, block_size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,87 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""Export vLLM's live model weight IPC handles for the training process.
|
|
||||||
|
|
||||||
Connects to a running vLLM instance, iterates over model parameters,
|
|
||||||
and exports CUDA IPC handles that allow another process to access the
|
|
||||||
same GPU memory without copying.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Run after vLLM is serving:
|
|
||||||
python3 export_weights.py --output /tmp/vllm_weight_handles.pt
|
|
||||||
|
|
||||||
# Or via vLLM's API (future):
|
|
||||||
curl -X POST http://localhost:8000/export_weights
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
import torch
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def export_from_model(model, output_path: str):
|
|
||||||
"""Export IPC handles for all model parameters."""
|
|
||||||
from torch.multiprocessing.reductions import reduce_tensor
|
|
||||||
|
|
||||||
handles = {}
|
|
||||||
total_bytes = 0
|
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
handle = reduce_tensor(param.data)
|
|
||||||
handles[name] = {
|
|
||||||
'handle': handle,
|
|
||||||
'shape': list(param.shape),
|
|
||||||
'dtype': str(param.dtype),
|
|
||||||
}
|
|
||||||
param_bytes = param.nelement() * param.element_size()
|
|
||||||
total_bytes += param_bytes
|
|
||||||
|
|
||||||
torch.save(handles, output_path)
|
|
||||||
|
|
||||||
n_params = len(handles)
|
|
||||||
print(f"Exported {n_params} parameters ({total_bytes / 1e9:.1f} GB)")
|
|
||||||
print(f"Saved to {output_path}")
|
|
||||||
return handles
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles")
|
|
||||||
parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt",
|
|
||||||
help="Output path for IPC handles")
|
|
||||||
parser.add_argument("--vllm-pid", type=int, default=None,
|
|
||||||
help="vLLM worker PID (auto-detected if not specified)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# For now: load the model directly and export.
|
|
||||||
# TODO: connect to running vLLM process instead.
|
|
||||||
print("Note: This currently loads the model separately.")
|
|
||||||
print("Full integration will export from the running vLLM process.")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Detect model path from running vLLM
|
|
||||||
import subprocess
|
|
||||||
result = subprocess.run(
|
|
||||||
['ps', 'aux'], capture_output=True, text=True
|
|
||||||
)
|
|
||||||
model_path = None
|
|
||||||
for line in result.stdout.split('\n'):
|
|
||||||
if 'vllm' in line and '--model' in line:
|
|
||||||
parts = line.split()
|
|
||||||
for i, p in enumerate(parts):
|
|
||||||
if p == '--model' and i + 1 < len(parts):
|
|
||||||
model_path = parts[i + 1]
|
|
||||||
break
|
|
||||||
# Also check model_tag format
|
|
||||||
if p.startswith('--model='):
|
|
||||||
model_path = p.split('=', 1)[1]
|
|
||||||
break
|
|
||||||
|
|
||||||
if model_path:
|
|
||||||
print(f"Detected vLLM model: {model_path}")
|
|
||||||
else:
|
|
||||||
print("Could not detect running vLLM model. Specify manually.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
@ -1,215 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""First real Apollo training step — ready for Kent to run.
|
|
||||||
|
|
||||||
This script:
|
|
||||||
1. Imports vLLM's live weights via CUDA IPC
|
|
||||||
2. Constructs HF model with shared memory views
|
|
||||||
3. Runs ONE forward+backward on a real training example
|
|
||||||
4. Applies ONE Apollo optimizer step
|
|
||||||
5. Verifies vLLM still works after the update
|
|
||||||
|
|
||||||
The training example is from March 30: Kent said "use vLLM's code"
|
|
||||||
and the model should have accepted instead of suggesting alternatives.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
source ~/training-env/bin/activate
|
|
||||||
python3 first_training_step.py [--dry-run]
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from transformers import AutoConfig, AutoTokenizer
|
|
||||||
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
|
|
||||||
|
|
||||||
sys.path.insert(0, '.')
|
|
||||||
from weight_mapping import vllm_to_hf_views
|
|
||||||
from apollo_mini import Apollo
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--dry-run', action='store_true',
|
|
||||||
help="Run forward+backward but don't apply the optimizer step")
|
|
||||||
parser.add_argument('--lr', type=float, default=1e-5,
|
|
||||||
help="Learning rate (default: 1e-5 = conservative)")
|
|
||||||
parser.add_argument('--rank', type=int, default=256)
|
|
||||||
parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt')
|
|
||||||
parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
print("=== First Apollo Training Step ===\n")
|
|
||||||
|
|
||||||
# 1. Import vLLM weights
|
|
||||||
print("1. Importing vLLM weights via CUDA IPC...")
|
|
||||||
handles = torch.load(args.handles, weights_only=False)
|
|
||||||
vllm_params = {}
|
|
||||||
for name, info in handles.items():
|
|
||||||
func, args_h = info['handle']
|
|
||||||
vllm_params[name] = func(*args_h)
|
|
||||||
print(f" {len(vllm_params)} parameters imported")
|
|
||||||
|
|
||||||
# 2. Map to HF layout
|
|
||||||
print("2. Mapping to HF layout (zero-copy views)...")
|
|
||||||
hf_params = vllm_to_hf_views(vllm_params)
|
|
||||||
|
|
||||||
# 3. Create HF model
|
|
||||||
print("3. Creating HF model with shared weights...")
|
|
||||||
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
|
|
||||||
with torch.device('meta'):
|
|
||||||
model = Qwen3_5ForCausalLM(config.text_config)
|
|
||||||
|
|
||||||
replaced = 0
|
|
||||||
for name, param in list(model.named_parameters()):
|
|
||||||
if name in hf_params:
|
|
||||||
parts = name.split('.')
|
|
||||||
parent = model
|
|
||||||
for part in parts[:-1]:
|
|
||||||
parent = getattr(parent, part)
|
|
||||||
setattr(parent, parts[-1],
|
|
||||||
nn.Parameter(hf_params[name], requires_grad=True))
|
|
||||||
replaced += 1
|
|
||||||
print(f" {replaced} parameters replaced with vLLM memory views")
|
|
||||||
|
|
||||||
# 4. Load tokenizer
|
|
||||||
print("4. Loading tokenizer...")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
|
|
||||||
|
|
||||||
# 5. Construct training example
|
|
||||||
print("5. Constructing training example...")
|
|
||||||
|
|
||||||
# Context: conversation where Kent says to use vLLM's code
|
|
||||||
# Target: the response that accepts the direction
|
|
||||||
context = (
|
|
||||||
"<|im_start|>user\n"
|
|
||||||
"vllm has a fused kernel already, right?<|im_end|>\n"
|
|
||||||
"<|im_start|>assistant\n"
|
|
||||||
"Yeah — vLLM has `gdn_attention_core` which is a custom op "
|
|
||||||
"that does the whole GDN layer's core in one dispatch.<|im_end|>\n"
|
|
||||||
"<|im_start|>user\n"
|
|
||||||
"Why wouldn't we just use that?<|im_end|>\n"
|
|
||||||
"<|im_start|>assistant\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
# The CORRECT response (accept direction, don't suggest alternatives)
|
|
||||||
continuation = (
|
|
||||||
"We should. Let me pull in their kernel and wire it into "
|
|
||||||
"our Rust orchestration. Which file should I start with?"
|
|
||||||
)
|
|
||||||
|
|
||||||
context_ids = tokenizer.encode(context, add_special_tokens=False)
|
|
||||||
continuation_ids = tokenizer.encode(continuation, add_special_tokens=False)
|
|
||||||
all_ids = context_ids + continuation_ids
|
|
||||||
context_len = len(context_ids)
|
|
||||||
|
|
||||||
print(f" Context: {context_len} tokens")
|
|
||||||
print(f" Continuation: {len(continuation_ids)} tokens")
|
|
||||||
print(f" Total: {len(all_ids)} tokens")
|
|
||||||
|
|
||||||
input_ids = torch.tensor([all_ids], device='cuda:0')
|
|
||||||
|
|
||||||
# 6. Initialize Apollo optimizer
|
|
||||||
print(f"6. Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...")
|
|
||||||
apollo_params = []
|
|
||||||
standard_params = []
|
|
||||||
for p in model.parameters():
|
|
||||||
if p.requires_grad:
|
|
||||||
if p.ndim >= 2 and min(p.shape) >= args.rank:
|
|
||||||
apollo_params.append(p)
|
|
||||||
else:
|
|
||||||
standard_params.append(p)
|
|
||||||
|
|
||||||
groups = []
|
|
||||||
if apollo_params:
|
|
||||||
groups.append({'params': apollo_params})
|
|
||||||
if standard_params:
|
|
||||||
groups.append({'params': standard_params})
|
|
||||||
|
|
||||||
optimizer = Apollo(groups, lr=args.lr, rank=args.rank)
|
|
||||||
print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard")
|
|
||||||
|
|
||||||
# 7. Forward pass
|
|
||||||
print("7. Forward pass...")
|
|
||||||
model.train()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Context-frozen: no grad for context, grad for continuation
|
|
||||||
with torch.no_grad():
|
|
||||||
ctx_output = model(input_ids[:, :context_len], use_cache=True)
|
|
||||||
past_kv = ctx_output.past_key_values
|
|
||||||
|
|
||||||
with torch.enable_grad():
|
|
||||||
output = model(input_ids[:, context_len:],
|
|
||||||
past_key_values=past_kv, use_cache=False)
|
|
||||||
logits = output.logits
|
|
||||||
# Shift for next-token prediction
|
|
||||||
shift_logits = logits[:, :-1].contiguous()
|
|
||||||
shift_labels = input_ids[:, context_len + 1:].contiguous()
|
|
||||||
loss = F.cross_entropy(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)),
|
|
||||||
shift_labels.view(-1),
|
|
||||||
)
|
|
||||||
print(f" Loss: {loss.item():.4f}")
|
|
||||||
|
|
||||||
# 8. Backward pass
|
|
||||||
print("8. Backward pass...")
|
|
||||||
loss.backward()
|
|
||||||
n_grads = sum(1 for p in model.parameters() if p.grad is not None)
|
|
||||||
print(f" {n_grads} parameters have gradients")
|
|
||||||
|
|
||||||
# 9. Apollo step (or dry run)
|
|
||||||
if args.dry_run:
|
|
||||||
print("\n9. DRY RUN — skipping optimizer step")
|
|
||||||
print(" (run without --dry-run to apply the update)")
|
|
||||||
else:
|
|
||||||
print("9. Applying Apollo optimizer step...")
|
|
||||||
# Record a few weight norms before
|
|
||||||
sample_norms_before = {}
|
|
||||||
for name, p in model.named_parameters():
|
|
||||||
if 'layers.0.' in name and p.grad is not None:
|
|
||||||
sample_norms_before[name] = p.data.norm().item()
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# Check weight changes
|
|
||||||
print(" Weight changes (layer 0):")
|
|
||||||
for name, before in sample_norms_before.items():
|
|
||||||
p = dict(model.named_parameters())[name]
|
|
||||||
after = p.data.norm().item()
|
|
||||||
delta = abs(after - before)
|
|
||||||
pct = delta / before * 100 if before > 0 else 0
|
|
||||||
print(f" {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)")
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# 10. Verify vLLM still works
|
|
||||||
print("\n10. Verifying vLLM still serves...")
|
|
||||||
import subprocess
|
|
||||||
result = subprocess.run(
|
|
||||||
['curl', '-s', '--max-time', '30',
|
|
||||||
'-X', 'POST', 'http://localhost:8000/v1/chat/completions',
|
|
||||||
'-H', 'Content-Type: application/json',
|
|
||||||
'-H', 'Authorization: Bearer bcachefs-agents-2026',
|
|
||||||
'-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'],
|
|
||||||
capture_output=True, text=True, timeout=45
|
|
||||||
)
|
|
||||||
if result.returncode == 0 and 'choices' in result.stdout:
|
|
||||||
print(" vLLM still serving ✓")
|
|
||||||
else:
|
|
||||||
print(" WARNING: vLLM may not be responding")
|
|
||||||
print(f" stdout: {result.stdout[:200]}")
|
|
||||||
|
|
||||||
print("\n=== COMPLETE ===")
|
|
||||||
if args.dry_run:
|
|
||||||
print("Run without --dry-run to apply the first real training step.")
|
|
||||||
else:
|
|
||||||
print("First Apollo training step applied to vLLM's live weights.")
|
|
||||||
print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
29
training/pyproject.toml
Normal file
29
training/pyproject.toml
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "apollo-plugin"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Apollo training plugin for vLLM"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"torch",
|
||||||
|
"aiohttp",
|
||||||
|
"safetensors",
|
||||||
|
"pyzmq",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = ["pytest"]
|
||||||
|
|
||||||
|
[project.entry-points."vllm.general_plugins"]
|
||||||
|
apollo = "apollo_plugin:register"
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
apollo-checkpoint = "apollo_plugin.checkpoint_sync:main"
|
||||||
|
apollo-worker = "apollo_plugin.training_worker:main"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["."]
|
||||||
|
include = ["apollo_plugin*"]
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
# Start vLLM with Apollo weight export hook.
|
|
||||||
#
|
|
||||||
# The hook patches vLLM's model runner to export CUDA IPC handles
|
|
||||||
# after loading, so the Apollo training process can share the same
|
|
||||||
# GPU memory.
|
|
||||||
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
|
||||||
|
|
||||||
exec python3 -c "
|
|
||||||
import sys
|
|
||||||
sys.path.insert(0, '$SCRIPT_DIR')
|
|
||||||
import vllm_export_hook # patches model runner before vLLM loads
|
|
||||||
|
|
||||||
sys.argv = ['vllm'] + sys.argv[1:]
|
|
||||||
from vllm.entrypoints.cli.main import main
|
|
||||||
main()
|
|
||||||
" serve "$@"
|
|
||||||
|
|
@ -1,269 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""Nightly training process for Apollo-Mini fine-tuning.
|
|
||||||
|
|
||||||
Imports vLLM's model weights via CUDA IPC, runs context-frozen
|
|
||||||
training on flagged conversation segments, saves updated checkpoint.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python3 train.py \
|
|
||||||
--weights /tmp/vllm_weight_handles.pt \
|
|
||||||
--examples training-examples.jsonl \
|
|
||||||
--checkpoint-dir checkpoints/ \
|
|
||||||
--lr 1e-5
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
from apollo_mini import ApolloMini
|
|
||||||
|
|
||||||
|
|
||||||
def import_weights(handle_path: str) -> dict[str, torch.Tensor]:
|
|
||||||
"""Import weight tensors from CUDA IPC handles."""
|
|
||||||
handles = torch.load(handle_path, weights_only=False)
|
|
||||||
params = {}
|
|
||||||
for name, info in handles.items():
|
|
||||||
func, args = info['handle']
|
|
||||||
tensor = func(*args)
|
|
||||||
params[name] = tensor
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]:
|
|
||||||
"""Split parameters into Apollo-Mini and standard groups.
|
|
||||||
|
|
||||||
Apollo-Mini needs 2D+ matrices with min dimension >= 2.
|
|
||||||
Small tensors (norms, biases, conv1d 3D weights) use standard Adam.
|
|
||||||
"""
|
|
||||||
apollo_params = []
|
|
||||||
standard_params = []
|
|
||||||
|
|
||||||
for name, p in params.items():
|
|
||||||
p.requires_grad_(True)
|
|
||||||
if p.ndim >= 2 and min(p.shape) >= 2:
|
|
||||||
apollo_params.append(p)
|
|
||||||
else:
|
|
||||||
standard_params.append(p)
|
|
||||||
|
|
||||||
groups = []
|
|
||||||
if apollo_params:
|
|
||||||
groups.append({
|
|
||||||
'params': apollo_params,
|
|
||||||
'name': 'apollo',
|
|
||||||
})
|
|
||||||
if standard_params:
|
|
||||||
groups.append({
|
|
||||||
'params': standard_params,
|
|
||||||
'name': 'standard',
|
|
||||||
})
|
|
||||||
|
|
||||||
n_apollo = sum(p.nelement() for p in apollo_params)
|
|
||||||
n_standard = sum(p.nelement() for p in standard_params)
|
|
||||||
print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M")
|
|
||||||
return groups
|
|
||||||
|
|
||||||
|
|
||||||
def forward_pass(params, input_ids, context_len, device):
|
|
||||||
"""Run context-frozen forward pass.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params: dict of name -> tensor (shared with vLLM)
|
|
||||||
input_ids: full sequence [1, seq_len]
|
|
||||||
context_len: number of context tokens (no gradient)
|
|
||||||
device: CUDA device
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
logits for decision tokens, target ids for loss
|
|
||||||
"""
|
|
||||||
# TODO: Build proper forward model matching vLLM's weight layout.
|
|
||||||
# For now this is a placeholder — the real implementation needs
|
|
||||||
# to replicate vLLM's model architecture (merged projections,
|
|
||||||
# GDN recurrence, full attention, MLP) using the shared weights.
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Forward model not yet implemented. "
|
|
||||||
"Need to build a model that matches vLLM's merged weight layout "
|
|
||||||
"(MergedColumnParallelLinear for qkvz/ba/gate_up, "
|
|
||||||
"RowParallelLinear for out_proj/down) and computes the same "
|
|
||||||
"forward pass with autograd enabled."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(params: dict[str, torch.Tensor],
|
|
||||||
checkpoint_dir: str,
|
|
||||||
config_path: str = None):
|
|
||||||
"""Save model checkpoint in HuggingFace safetensors format.
|
|
||||||
|
|
||||||
Saves weights split across shards matching the original model layout,
|
|
||||||
archives the previous checkpoint, and updates the 'latest' symlink.
|
|
||||||
"""
|
|
||||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
out_dir = Path(checkpoint_dir) / date_str
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Save all weights in a single safetensors file for now.
|
|
||||||
# TODO: split across shards matching HF model index for large models.
|
|
||||||
tensors = {}
|
|
||||||
for name, param in params.items():
|
|
||||||
tensors[name] = param.data.contiguous().cpu()
|
|
||||||
|
|
||||||
save_path = out_dir / "model.safetensors"
|
|
||||||
save_file(tensors, str(save_path))
|
|
||||||
print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)")
|
|
||||||
|
|
||||||
# Copy config files if provided
|
|
||||||
if config_path:
|
|
||||||
import shutil
|
|
||||||
config_dir = Path(config_path)
|
|
||||||
for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
|
|
||||||
'special_tokens_map.json', 'generation_config.json']:
|
|
||||||
src = config_dir / f
|
|
||||||
if src.exists():
|
|
||||||
shutil.copy2(src, out_dir / f)
|
|
||||||
|
|
||||||
# Update latest symlink
|
|
||||||
latest = Path(checkpoint_dir) / "latest"
|
|
||||||
if latest.is_symlink():
|
|
||||||
latest.unlink()
|
|
||||||
latest.symlink_to(date_str)
|
|
||||||
print(f"Updated {latest} -> {date_str}")
|
|
||||||
|
|
||||||
return str(out_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def train_step(params, example, optimizer, device, log_entries):
|
|
||||||
"""Run one training step on a single example.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params: dict of name -> tensor
|
|
||||||
example: dict with 'input_ids', 'context_len', 'target_ids'
|
|
||||||
optimizer: ApolloMini instance
|
|
||||||
device: CUDA device
|
|
||||||
log_entries: list to append log dicts to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
loss value
|
|
||||||
"""
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0)
|
|
||||||
context_len = example['context_len']
|
|
||||||
|
|
||||||
# Forward pass (context frozen, decision tokens with grad)
|
|
||||||
logits, targets = forward_pass(params, input_ids, context_len, device)
|
|
||||||
|
|
||||||
# Cross-entropy loss on decision tokens
|
|
||||||
loss = torch.nn.functional.cross_entropy(
|
|
||||||
logits.view(-1, logits.shape[-1]),
|
|
||||||
targets.view(-1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Backward
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# Compute gradient stats before optimizer step
|
|
||||||
total_grad_norm = 0.0
|
|
||||||
for p in params.values():
|
|
||||||
if p.grad is not None:
|
|
||||||
total_grad_norm += p.grad.norm().item() ** 2
|
|
||||||
total_grad_norm = total_grad_norm ** 0.5
|
|
||||||
|
|
||||||
# Optimizer step
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# Log
|
|
||||||
log_entries.append({
|
|
||||||
'example_id': example.get('id', 'unknown'),
|
|
||||||
'loss': loss.item(),
|
|
||||||
'grad_norm': total_grad_norm,
|
|
||||||
'timestamp': datetime.now().isoformat(),
|
|
||||||
})
|
|
||||||
|
|
||||||
return loss.item()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Apollo-Mini training")
|
|
||||||
parser.add_argument("--weights", required=True,
|
|
||||||
help="Path to exported weight IPC handles")
|
|
||||||
parser.add_argument("--examples", required=True,
|
|
||||||
help="Path to training examples JSONL")
|
|
||||||
parser.add_argument("--checkpoint-dir", default="checkpoints",
|
|
||||||
help="Directory for saving checkpoints")
|
|
||||||
parser.add_argument("--config-path", default=None,
|
|
||||||
help="Path to model config files (for checkpoint)")
|
|
||||||
parser.add_argument("--lr", type=float, default=1e-5,
|
|
||||||
help="Learning rate")
|
|
||||||
parser.add_argument("--warmup-steps", type=int, default=10,
|
|
||||||
help="Learning rate warmup steps")
|
|
||||||
parser.add_argument("--weight-decay", type=float, default=0.01)
|
|
||||||
parser.add_argument("--dry-run", action="store_true",
|
|
||||||
help="Load weights and validate, don't train")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
print(f"Apollo-Mini Training")
|
|
||||||
print(f" weights: {args.weights}")
|
|
||||||
print(f" examples: {args.examples}")
|
|
||||||
print(f" lr: {args.lr}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Import weights
|
|
||||||
print("Importing weights via CUDA IPC...")
|
|
||||||
params = import_weights(args.weights)
|
|
||||||
print(f" {len(params)} parameters imported")
|
|
||||||
|
|
||||||
# Make parameter groups
|
|
||||||
param_groups = make_param_groups(params)
|
|
||||||
|
|
||||||
# Initialize optimizer
|
|
||||||
optimizer = ApolloMini(param_groups, lr=args.lr,
|
|
||||||
weight_decay=args.weight_decay,
|
|
||||||
warmup_steps=args.warmup_steps)
|
|
||||||
print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")
|
|
||||||
|
|
||||||
if args.dry_run:
|
|
||||||
print("\nDry run — weights imported and validated successfully.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Load training examples
|
|
||||||
examples = []
|
|
||||||
with open(args.examples) as f:
|
|
||||||
for line in f:
|
|
||||||
examples.append(json.loads(line))
|
|
||||||
print(f" {len(examples)} training examples")
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
log_entries = []
|
|
||||||
print(f"\nTraining...")
|
|
||||||
t0 = time.time()
|
|
||||||
|
|
||||||
for i, example in enumerate(examples):
|
|
||||||
loss = train_step(params, example, optimizer, 'cuda:0', log_entries)
|
|
||||||
print(f" [{i+1}/{len(examples)}] loss={loss:.4f}")
|
|
||||||
|
|
||||||
elapsed = time.time() - t0
|
|
||||||
print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s")
|
|
||||||
print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")
|
|
||||||
|
|
||||||
# Save checkpoint
|
|
||||||
print("\nSaving checkpoint...")
|
|
||||||
save_checkpoint(params, args.checkpoint_dir, args.config_path)
|
|
||||||
|
|
||||||
# Save training log
|
|
||||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
log_path = Path(args.checkpoint_dir) / date_str / "training-log.jsonl"
|
|
||||||
with open(log_path, 'w') as f:
|
|
||||||
for entry in log_entries:
|
|
||||||
f.write(json.dumps(entry) + '\n')
|
|
||||||
print(f"Training log: {log_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
@ -1,175 +0,0 @@
|
||||||
"""Training example construction and tokenization.
|
|
||||||
|
|
||||||
Takes raw conversation context + improved continuation, produces
|
|
||||||
tokenized tensors ready for context-frozen forward+backward.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainingExample:
|
|
||||||
"""A single training example for context-frozen training."""
|
|
||||||
id: str
|
|
||||||
context: str # conversation up to decision point
|
|
||||||
continuation: str # the better response
|
|
||||||
reason: str = "" # why this is a training target
|
|
||||||
memories: list[str] = field(default_factory=list) # memories that were in context
|
|
||||||
|
|
||||||
# Computed after tokenization
|
|
||||||
input_ids: torch.Tensor | None = None
|
|
||||||
context_len: int = 0
|
|
||||||
total_len: int = 0
|
|
||||||
|
|
||||||
def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"):
|
|
||||||
"""Tokenize context + continuation into training-ready tensors.
|
|
||||||
|
|
||||||
The chat template is applied to make the token distribution
|
|
||||||
match what the model sees during inference.
|
|
||||||
"""
|
|
||||||
# Build messages for context (everything up to the decision)
|
|
||||||
# The context should already be in chat format
|
|
||||||
context_ids = tokenizer.encode(self.context, add_special_tokens=False)
|
|
||||||
continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False)
|
|
||||||
|
|
||||||
self.context_len = len(context_ids)
|
|
||||||
self.total_len = len(context_ids) + len(continuation_ids)
|
|
||||||
|
|
||||||
if self.total_len > max_len:
|
|
||||||
# Truncate context from the left, keep continuation intact
|
|
||||||
excess = self.total_len - max_len
|
|
||||||
context_ids = context_ids[excess:]
|
|
||||||
self.context_len = len(context_ids)
|
|
||||||
self.total_len = len(context_ids) + len(continuation_ids)
|
|
||||||
|
|
||||||
all_ids = context_ids + continuation_ids
|
|
||||||
self.input_ids = torch.tensor(all_ids, device=device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
'id': self.id,
|
|
||||||
'context': self.context,
|
|
||||||
'continuation': self.continuation,
|
|
||||||
'reason': self.reason,
|
|
||||||
'memories': self.memories,
|
|
||||||
'context_len': self.context_len,
|
|
||||||
'total_len': self.total_len,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, d: dict) -> 'TrainingExample':
|
|
||||||
return cls(
|
|
||||||
id=d['id'],
|
|
||||||
context=d['context'],
|
|
||||||
continuation=d['continuation'],
|
|
||||||
reason=d.get('reason', ''),
|
|
||||||
memories=d.get('memories', []),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_examples(path: str) -> list[TrainingExample]:
|
|
||||||
"""Load training examples from JSONL file."""
|
|
||||||
examples = []
|
|
||||||
with open(path) as f:
|
|
||||||
for line in f:
|
|
||||||
if line.strip():
|
|
||||||
examples.append(TrainingExample.from_dict(json.loads(line)))
|
|
||||||
return examples
|
|
||||||
|
|
||||||
|
|
||||||
def save_examples(examples: list[TrainingExample], path: str):
|
|
||||||
"""Save training examples to JSONL file."""
|
|
||||||
with open(path, 'w') as f:
|
|
||||||
for ex in examples:
|
|
||||||
f.write(json.dumps(ex.to_dict()) + '\n')
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleTokenizer:
|
|
||||||
"""Handles tokenization with the model's chat template.
|
|
||||||
|
|
||||||
Applies the same chat template that vLLM uses during inference,
|
|
||||||
so the token distribution matches what the model expects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model_path: str):
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_path, trust_remote_code=True)
|
|
||||||
|
|
||||||
def prepare_example(self, example: TrainingExample,
|
|
||||||
max_len: int = 8192,
|
|
||||||
device: str = "cuda:0") -> TrainingExample:
|
|
||||||
"""Tokenize an example using the chat template.
|
|
||||||
|
|
||||||
For proper training, the context should be formatted exactly
|
|
||||||
as vLLM would format it — with chat template applied.
|
|
||||||
"""
|
|
||||||
# Apply chat template to get the exact token sequence
|
|
||||||
# the model would see during inference
|
|
||||||
#
|
|
||||||
# Context: everything up to the decision point
|
|
||||||
# Continuation: the improved response
|
|
||||||
#
|
|
||||||
# We tokenize them separately to know where context ends
|
|
||||||
# and continuation begins.
|
|
||||||
context_ids = self.tokenizer.encode(
|
|
||||||
example.context, add_special_tokens=True)
|
|
||||||
continuation_ids = self.tokenizer.encode(
|
|
||||||
example.continuation, add_special_tokens=False)
|
|
||||||
|
|
||||||
example.context_len = len(context_ids)
|
|
||||||
example.total_len = len(context_ids) + len(continuation_ids)
|
|
||||||
|
|
||||||
if example.total_len > max_len:
|
|
||||||
excess = example.total_len - max_len
|
|
||||||
context_ids = context_ids[excess:]
|
|
||||||
example.context_len = len(context_ids)
|
|
||||||
example.total_len = example.context_len + len(continuation_ids)
|
|
||||||
|
|
||||||
all_ids = context_ids + continuation_ids
|
|
||||||
example.input_ids = torch.tensor(all_ids, device=device)
|
|
||||||
return example
|
|
||||||
|
|
||||||
def prepare_from_messages(self, example_id: str,
|
|
||||||
messages: list[dict],
|
|
||||||
decision_idx: int,
|
|
||||||
better_response: str,
|
|
||||||
reason: str = "",
|
|
||||||
memories: list[str] | None = None,
|
|
||||||
max_len: int = 8192,
|
|
||||||
device: str = "cuda:0") -> TrainingExample:
|
|
||||||
"""Build a training example from a chat message list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
example_id: unique identifier
|
|
||||||
messages: list of {"role": ..., "content": ...} dicts
|
|
||||||
decision_idx: index of the assistant message to replace
|
|
||||||
better_response: the improved response text
|
|
||||||
reason: why this is a training target
|
|
||||||
memories: memory keys that were in context
|
|
||||||
max_len: maximum sequence length
|
|
||||||
device: target device
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tokenized TrainingExample
|
|
||||||
"""
|
|
||||||
# Context: all messages up to (not including) the decision
|
|
||||||
context_messages = messages[:decision_idx]
|
|
||||||
context_text = self.tokenizer.apply_chat_template(
|
|
||||||
context_messages, tokenize=False, add_generation_prompt=True)
|
|
||||||
|
|
||||||
# Build the example
|
|
||||||
example = TrainingExample(
|
|
||||||
id=example_id,
|
|
||||||
context=context_text,
|
|
||||||
continuation=better_response,
|
|
||||||
reason=reason,
|
|
||||||
memories=memories or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.prepare_example(example, max_len=max_len, device=device)
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue