Compare commits
2 Commits
ccbaf41211
...
2ee0d8c265
Author | SHA1 | Date |
---|---|---|
projectmoon | 2ee0d8c265 | |
projectmoon | e937d066fc |
|
@ -6,8 +6,8 @@ use itertools::Itertools;
|
|||
|
||||
use crate::models::{
|
||||
coherence::{CoherenceFailure, SceneFix},
|
||||
world::scenes::{root_scene_id, Exit, Scene},
|
||||
Content, ContentContainer,
|
||||
world::scenes::{root_scene_id, Exit, Scene, SceneStub},
|
||||
Content, ContentContainer, ContentRelation,
|
||||
};
|
||||
|
||||
use super::generator::AiClient;
|
||||
|
@ -91,6 +91,46 @@ pub fn reverse_direction(direction: &str) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
fn find_exit_by_connection(exit: &Exit, connected_scene: &Scene) -> bool {
|
||||
let connected_key = connected_scene._key.as_deref().unwrap();
|
||||
Some(exit.scene_key.as_ref()) == connected_scene._key.as_deref()
|
||||
|| exit.scene_id.as_deref() == connected_scene._id.as_deref()
|
||||
|| exit.name.to_lowercase() == connected_scene.name.to_lowercase()
|
||||
|| exit.name == connected_key
|
||||
}
|
||||
|
||||
fn find_exit_by_direction(exit: &Exit, direction_from: &str) -> bool {
|
||||
exit.direction == reverse_direction(direction_from)
|
||||
}
|
||||
|
||||
// fn yoink<'a>(
|
||||
// //exits: Vec<&'a mut Exit>,
|
||||
// mut exits: Vec<&'a mut Exit>,
|
||||
// direction_from: &'a str,
|
||||
// connected_scene: &'a Scene,
|
||||
// ) -> Option<&'a mut Exit> {
|
||||
// if exits.len() > 1 {
|
||||
// exits
|
||||
// .iter()
|
||||
// .find_map(|exit| match find_exit_by_direction(*exit, direction_from) {
|
||||
// true => Some(&mut **exit),
|
||||
// _ => None,
|
||||
// })
|
||||
// .or_else(|| {
|
||||
// exits.iter().find_map(|&exit| {
|
||||
// match find_exit_by_connection(exit, connected_scene) {
|
||||
// true => Some(&mut *exit),
|
||||
// _ => None,
|
||||
// }
|
||||
// })
|
||||
// })
|
||||
// } else if exits.len() == 1 {
|
||||
// Some(exits[0])
|
||||
// } else {
|
||||
// None
|
||||
// }
|
||||
// }
|
||||
|
||||
/// Attempt to reconnect back to the connected scene. The model is
|
||||
/// not always good at this. Here, we correct it by attempting to
|
||||
/// find the exit by name, and also making sure the direction is
|
||||
|
@ -109,37 +149,41 @@ pub fn make_scene_from_stub_coherent(content: &mut ContentContainer, connected_s
|
|||
.map(|exit| exit.direction.as_ref())
|
||||
.unwrap_or("from");
|
||||
|
||||
// TODO fuzzy search exit - perhaps with Skim V2. Prefer scene
|
||||
// key or id match, but try to find it via name too
|
||||
let exit = new_scene.exits.iter_mut().find(|exit| {
|
||||
Some(exit.scene_key.as_ref()) == connected_scene._key.as_deref()
|
||||
let reversed_direction = reverse_direction(direction_from);
|
||||
|
||||
// Rethink this.
|
||||
// 1. Delete any exits that have the same direction as the reversed direction_from
|
||||
// AND do not point to connected scene ID.
|
||||
// 2. Find potential connected exit and modify as normal.
|
||||
new_scene.exits.retain(|exit| {
|
||||
(exit.direction == reversed_direction && exit.scene_key == connected_key)
|
||||
|| exit.direction != reversed_direction
|
||||
});
|
||||
|
||||
////////////////////////////////
|
||||
|
||||
// It is possible we have an exit that leads in the direction of
|
||||
// the way back, or one exit that is created to go back to the connected scene.
|
||||
let exit_to_change = new_scene.exits.iter_mut().find(|exit| {
|
||||
exit.direction == reverse_direction(direction_from)
|
||||
|| Some(exit.scene_key.as_ref()) == connected_scene._key.as_deref()
|
||||
|| exit.scene_id.as_deref() == connected_scene._id.as_deref()
|
||||
|| exit.name.to_lowercase() == connected_scene.name.to_lowercase()
|
||||
|| exit.name == connected_key
|
||||
|| exit.name == connected_id
|
||||
});
|
||||
|
||||
if let Some(exit) = exit {
|
||||
let bad_id = &exit.scene_id;
|
||||
if let Some(exit) = exit_to_change {
|
||||
// Delete the stub that was this exit, and update the exit
|
||||
content.contained.retain(|c| match &c.content {
|
||||
Content::SceneStub(stub) => stub._id != exit.scene_id,
|
||||
_ => true,
|
||||
});
|
||||
|
||||
let remove_stub_at_pos = content
|
||||
.contained
|
||||
.iter()
|
||||
.find_position(|c| match &c.content {
|
||||
Content::SceneStub(stub) => stub._id.as_ref() == bad_id.as_ref(),
|
||||
_ => false,
|
||||
});
|
||||
|
||||
// Fix up exit
|
||||
// Fix up exit to point back to connected scene.
|
||||
exit.scene_id = connected_scene._id.clone();
|
||||
exit.scene_key = connected_scene._key.as_ref().unwrap().clone();
|
||||
exit.direction = reverse_direction(direction_from);
|
||||
|
||||
// Delete the stub that was this exit, and update the exit
|
||||
// with the right stuff.
|
||||
if let Some((pos, _)) = remove_stub_at_pos {
|
||||
content.contained.swap_remove(pos);
|
||||
}
|
||||
} else {
|
||||
println!("WARNING: could not correct stub exit - creating manually");
|
||||
|
||||
|
@ -149,7 +193,11 @@ pub fn make_scene_from_stub_coherent(content: &mut ContentContainer, connected_s
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn make_scene_coherent(generator: &mut AiClient, scene: &mut Scene) -> Result<()> {
|
||||
pub async fn make_scene_coherent(
|
||||
generator: &mut AiClient,
|
||||
content: &mut ContentContainer,
|
||||
) -> Result<()> {
|
||||
let scene = content.owner.as_scene_mut();
|
||||
let failures = check_scene_coherence(&scene);
|
||||
let fixes = generator.fix_scene(&scene, failures).await?;
|
||||
let mut deletes = vec![]; // Needed because we Vec::retain after the fact
|
||||
|
@ -160,7 +208,20 @@ pub async fn make_scene_coherent(generator: &mut AiClient, scene: &mut Scene) ->
|
|||
index,
|
||||
new: fixed_exit,
|
||||
} => {
|
||||
// TODO could someday use swap_remove? only cloning
|
||||
// due to deletes below.
|
||||
let old_exit = scene.exits[index].clone();
|
||||
scene.exits[index] = fixed_exit.into();
|
||||
content.contained.retain(|c| match &c.content {
|
||||
Content::SceneStub(stub) => stub._id != old_exit.scene_id,
|
||||
_ => true,
|
||||
});
|
||||
|
||||
let fixed_exit = &scene.exits[index];
|
||||
|
||||
content
|
||||
.contained
|
||||
.push(ContentRelation::scene_stub(SceneStub::from(fixed_exit)));
|
||||
}
|
||||
SceneFix::DeleteExit(index) => {
|
||||
deletes.push(index);
|
||||
|
|
118
src/ai/convo.rs
118
src/ai/convo.rs
|
@ -1,8 +1,12 @@
|
|||
use crate::kobold_api::{create_input, Client as KoboldClient, SseGenerationExt};
|
||||
use crate::models::new_uuid_string;
|
||||
use anyhow::Result;
|
||||
use async_recursion::async_recursion;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::error::Category;
|
||||
use serde_json::Value;
|
||||
use std::borrow::BorrowMut;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
|
||||
/// Characters which can break the JSON deserialization. Do not rely
|
||||
|
@ -18,6 +22,49 @@ fn sanitize_json_response(mut json: String) -> String {
|
|||
json
|
||||
}
|
||||
|
||||
#[async_recursion(?Send)]
|
||||
async fn continue_execution<T: DeserializeOwned>(
|
||||
client: &KoboldClient,
|
||||
gen_key: &str,
|
||||
prompt: &AiPrompt,
|
||||
prompt_so_far: &mut String,
|
||||
resp_so_far: &mut String,
|
||||
) -> Result<T> {
|
||||
println!("hit continue execution");
|
||||
// Grammar state is retained here (as opposed to false
|
||||
// normally) to let the model continue to generate JSON.
|
||||
|
||||
let input = create_input(
|
||||
gen_key.to_string(),
|
||||
prompt_so_far,
|
||||
prompt.grammar.clone(),
|
||||
prompt.max_tokens,
|
||||
true,
|
||||
);
|
||||
|
||||
// TODO convert error to remove trait bound issue
|
||||
let resp = client.sse_generate(input).await.unwrap();
|
||||
|
||||
prompt_so_far.push_str(&resp);
|
||||
resp_so_far.push_str(&resp);
|
||||
|
||||
let resp: Value = match serde_json::from_str(&resp_so_far) {
|
||||
Ok(obj) => obj,
|
||||
Err(e) => match e.classify() {
|
||||
Category::Eof | Category::Syntax => {
|
||||
continue_execution(client, gen_key, prompt, prompt_so_far, resp_so_far).await?
|
||||
}
|
||||
_ => {
|
||||
return Err(e.into());
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let resp: T = serde_json::from_value(resp)?;
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
pub struct AiPrompt {
|
||||
pub prompt: String,
|
||||
pub grammar: Option<String>,
|
||||
|
@ -51,38 +98,45 @@ impl AiPrompt {
|
|||
}
|
||||
|
||||
pub struct AiConversation {
|
||||
prompt_so_far: String,
|
||||
gen_key: String,
|
||||
prompt_so_far: Rc<RefCell<String>>,
|
||||
client: Rc<KoboldClient>,
|
||||
}
|
||||
|
||||
impl AiConversation {
|
||||
pub fn new(client: Rc<KoboldClient>) -> AiConversation {
|
||||
AiConversation {
|
||||
prompt_so_far: String::new(),
|
||||
prompt_so_far: Rc::new(RefCell::new(String::new())),
|
||||
gen_key: new_uuid_string(),
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.prompt_so_far.is_empty()
|
||||
self.prompt_so_far.borrow().is_empty()
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.prompt_so_far = "".to_string();
|
||||
pub fn reset(&self) {
|
||||
let mut prompt_so_far = RefCell::borrow_mut(&self.prompt_so_far);
|
||||
prompt_so_far.clear();
|
||||
}
|
||||
|
||||
pub async fn execute<T: DeserializeOwned>(&mut self, prompt: &AiPrompt) -> Result<T> {
|
||||
pub async fn execute<T: DeserializeOwned>(&self, prompt: &AiPrompt) -> Result<T> {
|
||||
// Handle Mistral-instruct begin instruct mode.
|
||||
// https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
|
||||
// Only the very first one begins with <s>, not subsequent.
|
||||
if self.prompt_so_far.is_empty() {
|
||||
self.prompt_so_far.push_str("<s>");
|
||||
let mut prompt_so_far = RefCell::borrow_mut(&self.prompt_so_far);
|
||||
let prompt_so_far = &mut *prompt_so_far;
|
||||
if prompt_so_far.is_empty() {
|
||||
prompt_so_far.push_str("<s>");
|
||||
}
|
||||
|
||||
self.prompt_so_far.push_str(&prompt.prompt);
|
||||
prompt_so_far.push_str(&prompt.prompt);
|
||||
|
||||
// Will fail ?
|
||||
let input = create_input(
|
||||
&self.prompt_so_far,
|
||||
self.gen_key.clone(),
|
||||
&prompt_so_far,
|
||||
prompt.grammar.clone(),
|
||||
prompt.max_tokens,
|
||||
false,
|
||||
|
@ -96,7 +150,7 @@ impl AiConversation {
|
|||
.map(sanitize_json_response)
|
||||
.unwrap();
|
||||
|
||||
self.prompt_so_far.push_str(&str_resp);
|
||||
prompt_so_far.push_str(&str_resp);
|
||||
|
||||
let resp: T = match serde_json::from_str(&str_resp) {
|
||||
Ok(obj) => obj,
|
||||
|
@ -104,44 +158,26 @@ impl AiConversation {
|
|||
// If the resp is not fully valid JSON, request more
|
||||
// from the LLM.
|
||||
match e.classify() {
|
||||
Category::Eof => self.continue_execution(prompt, &mut str_resp).await?,
|
||||
Category::Eof => {
|
||||
continue_execution(
|
||||
&self.client,
|
||||
&self.gen_key,
|
||||
prompt,
|
||||
prompt_so_far,
|
||||
&mut str_resp,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
_ => return Err(e.into()),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// mistral 7b end of response token (for when BNF is used)
|
||||
if !self.prompt_so_far.trim().ends_with("</s>") {
|
||||
self.prompt_so_far.push_str("</s>");
|
||||
if !prompt_so_far.trim().ends_with("</s>") {
|
||||
prompt_so_far.push_str("</s>");
|
||||
}
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
#[async_recursion(?Send)]
|
||||
async fn continue_execution<T: DeserializeOwned>(
|
||||
&mut self,
|
||||
prompt: &AiPrompt,
|
||||
resp_so_far: &mut String,
|
||||
) -> Result<T> {
|
||||
// Grammar state is retained here (as opposed to false
|
||||
// normally) to let the model continue to generate JSON.
|
||||
let input = create_input(&self.prompt_so_far, None, prompt.max_tokens, true);
|
||||
|
||||
// TODO convert error to remove trait bound issue
|
||||
let resp = self.client.sse_generate(input).await.unwrap();
|
||||
|
||||
self.prompt_so_far.push_str(&resp);
|
||||
resp_so_far.push_str(&resp);
|
||||
|
||||
let resp: T = match serde_json::from_str(&resp_so_far) {
|
||||
Ok(obj) => obj,
|
||||
Err(e) => match e.classify() {
|
||||
Category::Eof => self.continue_execution(prompt, resp_so_far).await?,
|
||||
_ => return Err(e.into()),
|
||||
},
|
||||
};
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -168,8 +168,10 @@ impl AiClient {
|
|||
for failure in failures {
|
||||
let fix = match failure {
|
||||
CoherenceFailure::InvalidExitName(original_exit) => {
|
||||
println!("invalid exit name: {}", original_exit.name);
|
||||
let prompt = world_prompts::fix_exit_prompt(scene, original_exit);
|
||||
let fixed: ExitSeed = self.world_creation_convo.execute(&prompt).await?;
|
||||
println!("fixed with: {:?}", fixed);
|
||||
let position = find_exit_position(&scene.exits, original_exit)?;
|
||||
|
||||
SceneFix::FixedExit {
|
||||
|
|
|
@ -171,9 +171,7 @@ impl AiLogic {
|
|||
// coherence (that can invoke the LLM).
|
||||
let mut content = self.fill_in_scene_from_stub(seed, stub).await?;
|
||||
coherence::make_scene_from_stub_coherent(&mut content, connected_scene);
|
||||
|
||||
let mut scene = content.owner.as_scene_mut();
|
||||
coherence::make_scene_coherent(&mut self.generator, &mut scene).await?;
|
||||
coherence::make_scene_coherent(&mut self.generator, &mut content).await?;
|
||||
|
||||
self.generator.reset_world_creation();
|
||||
|
||||
|
@ -193,8 +191,7 @@ impl AiLogic {
|
|||
.await?;
|
||||
|
||||
let mut content = self.fill_in_scene(scene_seed).await?;
|
||||
let mut scene = content.owner.as_scene_mut();
|
||||
coherence::make_scene_coherent(&mut self.generator, &mut scene).await?;
|
||||
coherence::make_scene_coherent(&mut self.generator, &mut content).await?;
|
||||
|
||||
self.generator.reset_world_creation();
|
||||
Ok(content)
|
||||
|
@ -241,13 +238,7 @@ impl AiLogic {
|
|||
let mut stubs: Vec<_> = exits
|
||||
.iter()
|
||||
.map(|exit| {
|
||||
ContentRelation::scene_stub(SceneStub {
|
||||
_key: Some(exit.scene_key.clone()),
|
||||
name: exit.name.clone(),
|
||||
location: exit.location.clone(),
|
||||
is_stub: true,
|
||||
..Default::default()
|
||||
})
|
||||
ContentRelation::scene_stub(SceneStub::from(exit))
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
|
@ -235,7 +235,7 @@ pub fn scene_creation_prompt(scene_type: &str, fantasticalness: &str) -> AiPromp
|
|||
.replacen("{}", scene_type, 1)
|
||||
.replacen("{}", fantasticalness, 1),
|
||||
SCENE_BNF,
|
||||
1024,
|
||||
100,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,15 +3,21 @@ use es::SSE;
|
|||
use eventsource_client as es;
|
||||
use futures::{Stream, StreamExt, TryStreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use std::num::NonZeroU64;
|
||||
use std::time::Duration;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/codegen.rs"));
|
||||
|
||||
pub fn create_input(prompt: &str, grammar: Option<String>, max_tokens: u64, retain_gramar_state: bool) -> types::GenerationInput {
|
||||
pub fn create_input(
|
||||
gen_key: String,
|
||||
prompt: &str,
|
||||
grammar: Option<String>,
|
||||
max_tokens: u64,
|
||||
retain_gramar_state: bool,
|
||||
) -> types::GenerationInput {
|
||||
types::GenerationInput {
|
||||
genkey: Some(gen_key),
|
||||
prompt: prompt.to_string(),
|
||||
genkey: None,
|
||||
grammar: grammar,
|
||||
grammar_retain_state: retain_gramar_state,
|
||||
use_default_badwordsids: false,
|
||||
|
|
|
@ -87,6 +87,18 @@ impl Default for SceneStub {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<&Exit> for SceneStub {
|
||||
fn from(exit: &Exit) -> Self {
|
||||
Self {
|
||||
_key: Some(exit.scene_key.clone()),
|
||||
name: exit.name.clone(),
|
||||
location: exit.location.clone(),
|
||||
is_stub: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The stage is everything: a scene, the people ("actors") in it, the
|
||||
// props, etc.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
Loading…
Reference in New Issue