Compare commits

...

2 Commits

7 changed files with 189 additions and 81 deletions

View File

@ -6,8 +6,8 @@ use itertools::Itertools;
use crate::models::{
coherence::{CoherenceFailure, SceneFix},
world::scenes::{root_scene_id, Exit, Scene},
Content, ContentContainer,
world::scenes::{root_scene_id, Exit, Scene, SceneStub},
Content, ContentContainer, ContentRelation,
};
use super::generator::AiClient;
@ -91,6 +91,46 @@ pub fn reverse_direction(direction: &str) -> String {
}
}
fn find_exit_by_connection(exit: &Exit, connected_scene: &Scene) -> bool {
let connected_key = connected_scene._key.as_deref().unwrap();
Some(exit.scene_key.as_ref()) == connected_scene._key.as_deref()
|| exit.scene_id.as_deref() == connected_scene._id.as_deref()
|| exit.name.to_lowercase() == connected_scene.name.to_lowercase()
|| exit.name == connected_key
}
fn find_exit_by_direction(exit: &Exit, direction_from: &str) -> bool {
exit.direction == reverse_direction(direction_from)
}
// fn yoink<'a>(
// //exits: Vec<&'a mut Exit>,
// mut exits: Vec<&'a mut Exit>,
// direction_from: &'a str,
// connected_scene: &'a Scene,
// ) -> Option<&'a mut Exit> {
// if exits.len() > 1 {
// exits
// .iter()
// .find_map(|exit| match find_exit_by_direction(*exit, direction_from) {
// true => Some(&mut **exit),
// _ => None,
// })
// .or_else(|| {
// exits.iter().find_map(|&exit| {
// match find_exit_by_connection(exit, connected_scene) {
// true => Some(&mut *exit),
// _ => None,
// }
// })
// })
// } else if exits.len() == 1 {
// Some(exits[0])
// } else {
// None
// }
// }
/// Attempt to reconnect back to the connected scene. The model is
/// not always good at this. Here, we correct it by attempting to
/// find the exit by name, and also making sure the direction is
@ -109,37 +149,41 @@ pub fn make_scene_from_stub_coherent(content: &mut ContentContainer, connected_s
.map(|exit| exit.direction.as_ref())
.unwrap_or("from");
// TODO fuzzy search exit - perhaps with Skim V2. Prefer scene
// key or id match, but try to find it via name too
let exit = new_scene.exits.iter_mut().find(|exit| {
Some(exit.scene_key.as_ref()) == connected_scene._key.as_deref()
let reversed_direction = reverse_direction(direction_from);
// Rethink this.
// 1. Delete any exits that have the same direction as the reversed direction_from
// AND do not point to connected scene ID.
// 2. Find potential connected exit and modify as normal.
new_scene.exits.retain(|exit| {
(exit.direction == reversed_direction && exit.scene_key == connected_key)
|| exit.direction != reversed_direction
});
////////////////////////////////
// It is possible we have an exit that leads in the direction of
// the way back, or one exit that is created to go back to the connected scene.
let exit_to_change = new_scene.exits.iter_mut().find(|exit| {
exit.direction == reverse_direction(direction_from)
|| Some(exit.scene_key.as_ref()) == connected_scene._key.as_deref()
|| exit.scene_id.as_deref() == connected_scene._id.as_deref()
|| exit.name.to_lowercase() == connected_scene.name.to_lowercase()
|| exit.name == connected_key
|| exit.name == connected_id
});
if let Some(exit) = exit {
let bad_id = &exit.scene_id;
if let Some(exit) = exit_to_change {
// Delete the stub that was this exit, and update the exit
content.contained.retain(|c| match &c.content {
Content::SceneStub(stub) => stub._id != exit.scene_id,
_ => true,
});
let remove_stub_at_pos = content
.contained
.iter()
.find_position(|c| match &c.content {
Content::SceneStub(stub) => stub._id.as_ref() == bad_id.as_ref(),
_ => false,
});
// Fix up exit
// Fix up exit to point back to connected scene.
exit.scene_id = connected_scene._id.clone();
exit.scene_key = connected_scene._key.as_ref().unwrap().clone();
exit.direction = reverse_direction(direction_from);
// Delete the stub that was this exit, and update the exit
// with the right stuff.
if let Some((pos, _)) = remove_stub_at_pos {
content.contained.swap_remove(pos);
}
} else {
println!("WARNING: could not correct stub exit - creating manually");
@ -149,7 +193,11 @@ pub fn make_scene_from_stub_coherent(content: &mut ContentContainer, connected_s
}
}
pub async fn make_scene_coherent(generator: &mut AiClient, scene: &mut Scene) -> Result<()> {
pub async fn make_scene_coherent(
generator: &mut AiClient,
content: &mut ContentContainer,
) -> Result<()> {
let scene = content.owner.as_scene_mut();
let failures = check_scene_coherence(&scene);
let fixes = generator.fix_scene(&scene, failures).await?;
let mut deletes = vec![]; // Needed because we Vec::retain after the fact
@ -160,7 +208,20 @@ pub async fn make_scene_coherent(generator: &mut AiClient, scene: &mut Scene) ->
index,
new: fixed_exit,
} => {
// TODO could someday use swap_remove? only cloning
// due to deletes below.
let old_exit = scene.exits[index].clone();
scene.exits[index] = fixed_exit.into();
content.contained.retain(|c| match &c.content {
Content::SceneStub(stub) => stub._id != old_exit.scene_id,
_ => true,
});
let fixed_exit = &scene.exits[index];
content
.contained
.push(ContentRelation::scene_stub(SceneStub::from(fixed_exit)));
}
SceneFix::DeleteExit(index) => {
deletes.push(index);

View File

@ -1,8 +1,12 @@
use crate::kobold_api::{create_input, Client as KoboldClient, SseGenerationExt};
use crate::models::new_uuid_string;
use anyhow::Result;
use async_recursion::async_recursion;
use serde::de::DeserializeOwned;
use serde_json::error::Category;
use serde_json::Value;
use std::borrow::BorrowMut;
use std::cell::RefCell;
use std::rc::Rc;
/// Characters which can break the JSON deserialization. Do not rely
@ -18,6 +22,49 @@ fn sanitize_json_response(mut json: String) -> String {
json
}
#[async_recursion(?Send)]
async fn continue_execution<T: DeserializeOwned>(
client: &KoboldClient,
gen_key: &str,
prompt: &AiPrompt,
prompt_so_far: &mut String,
resp_so_far: &mut String,
) -> Result<T> {
println!("hit continue execution");
// Grammar state is retained here (as opposed to false
// normally) to let the model continue to generate JSON.
let input = create_input(
gen_key.to_string(),
prompt_so_far,
prompt.grammar.clone(),
prompt.max_tokens,
true,
);
// TODO convert error to remove trait bound issue
let resp = client.sse_generate(input).await.unwrap();
prompt_so_far.push_str(&resp);
resp_so_far.push_str(&resp);
let resp: Value = match serde_json::from_str(&resp_so_far) {
Ok(obj) => obj,
Err(e) => match e.classify() {
Category::Eof | Category::Syntax => {
continue_execution(client, gen_key, prompt, prompt_so_far, resp_so_far).await?
}
_ => {
return Err(e.into());
}
},
};
let resp: T = serde_json::from_value(resp)?;
Ok(resp)
}
pub struct AiPrompt {
pub prompt: String,
pub grammar: Option<String>,
@ -51,38 +98,45 @@ impl AiPrompt {
}
pub struct AiConversation {
prompt_so_far: String,
gen_key: String,
prompt_so_far: Rc<RefCell<String>>,
client: Rc<KoboldClient>,
}
impl AiConversation {
pub fn new(client: Rc<KoboldClient>) -> AiConversation {
AiConversation {
prompt_so_far: String::new(),
prompt_so_far: Rc::new(RefCell::new(String::new())),
gen_key: new_uuid_string(),
client,
}
}
pub fn is_empty(&self) -> bool {
self.prompt_so_far.is_empty()
self.prompt_so_far.borrow().is_empty()
}
pub fn reset(&mut self) {
self.prompt_so_far = "".to_string();
pub fn reset(&self) {
let mut prompt_so_far = RefCell::borrow_mut(&self.prompt_so_far);
prompt_so_far.clear();
}
pub async fn execute<T: DeserializeOwned>(&mut self, prompt: &AiPrompt) -> Result<T> {
pub async fn execute<T: DeserializeOwned>(&self, prompt: &AiPrompt) -> Result<T> {
// Handle Mistral-instruct begin instruct mode.
// https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
// Only the very first one begins with <s>, not subsequent.
if self.prompt_so_far.is_empty() {
self.prompt_so_far.push_str("<s>");
let mut prompt_so_far = RefCell::borrow_mut(&self.prompt_so_far);
let prompt_so_far = &mut *prompt_so_far;
if prompt_so_far.is_empty() {
prompt_so_far.push_str("<s>");
}
self.prompt_so_far.push_str(&prompt.prompt);
prompt_so_far.push_str(&prompt.prompt);
// Will fail ?
let input = create_input(
&self.prompt_so_far,
self.gen_key.clone(),
&prompt_so_far,
prompt.grammar.clone(),
prompt.max_tokens,
false,
@ -96,7 +150,7 @@ impl AiConversation {
.map(sanitize_json_response)
.unwrap();
self.prompt_so_far.push_str(&str_resp);
prompt_so_far.push_str(&str_resp);
let resp: T = match serde_json::from_str(&str_resp) {
Ok(obj) => obj,
@ -104,44 +158,26 @@ impl AiConversation {
// If the resp is not fully valid JSON, request more
// from the LLM.
match e.classify() {
Category::Eof => self.continue_execution(prompt, &mut str_resp).await?,
Category::Eof => {
continue_execution(
&self.client,
&self.gen_key,
prompt,
prompt_so_far,
&mut str_resp,
)
.await?
}
_ => return Err(e.into()),
}
}
};
// mistral 7b end of response token (for when BNF is used)
if !self.prompt_so_far.trim().ends_with("</s>") {
self.prompt_so_far.push_str("</s>");
if !prompt_so_far.trim().ends_with("</s>") {
prompt_so_far.push_str("</s>");
}
Ok(resp)
}
#[async_recursion(?Send)]
async fn continue_execution<T: DeserializeOwned>(
&mut self,
prompt: &AiPrompt,
resp_so_far: &mut String,
) -> Result<T> {
// Grammar state is retained here (as opposed to false
// normally) to let the model continue to generate JSON.
let input = create_input(&self.prompt_so_far, None, prompt.max_tokens, true);
// TODO convert error to remove trait bound issue
let resp = self.client.sse_generate(input).await.unwrap();
self.prompt_so_far.push_str(&resp);
resp_so_far.push_str(&resp);
let resp: T = match serde_json::from_str(&resp_so_far) {
Ok(obj) => obj,
Err(e) => match e.classify() {
Category::Eof => self.continue_execution(prompt, resp_so_far).await?,
_ => return Err(e.into()),
},
};
Ok(resp)
}
}

View File

@ -168,8 +168,10 @@ impl AiClient {
for failure in failures {
let fix = match failure {
CoherenceFailure::InvalidExitName(original_exit) => {
println!("invalid exit name: {}", original_exit.name);
let prompt = world_prompts::fix_exit_prompt(scene, original_exit);
let fixed: ExitSeed = self.world_creation_convo.execute(&prompt).await?;
println!("fixed with: {:?}", fixed);
let position = find_exit_position(&scene.exits, original_exit)?;
SceneFix::FixedExit {

View File

@ -171,9 +171,7 @@ impl AiLogic {
// coherence (that can invoke the LLM).
let mut content = self.fill_in_scene_from_stub(seed, stub).await?;
coherence::make_scene_from_stub_coherent(&mut content, connected_scene);
let mut scene = content.owner.as_scene_mut();
coherence::make_scene_coherent(&mut self.generator, &mut scene).await?;
coherence::make_scene_coherent(&mut self.generator, &mut content).await?;
self.generator.reset_world_creation();
@ -193,8 +191,7 @@ impl AiLogic {
.await?;
let mut content = self.fill_in_scene(scene_seed).await?;
let mut scene = content.owner.as_scene_mut();
coherence::make_scene_coherent(&mut self.generator, &mut scene).await?;
coherence::make_scene_coherent(&mut self.generator, &mut content).await?;
self.generator.reset_world_creation();
Ok(content)
@ -241,13 +238,7 @@ impl AiLogic {
let mut stubs: Vec<_> = exits
.iter()
.map(|exit| {
ContentRelation::scene_stub(SceneStub {
_key: Some(exit.scene_key.clone()),
name: exit.name.clone(),
location: exit.location.clone(),
is_stub: true,
..Default::default()
})
ContentRelation::scene_stub(SceneStub::from(exit))
})
.collect();

View File

@ -235,7 +235,7 @@ pub fn scene_creation_prompt(scene_type: &str, fantasticalness: &str) -> AiPromp
.replacen("{}", scene_type, 1)
.replacen("{}", fantasticalness, 1),
SCENE_BNF,
1024,
100,
)
}

View File

@ -3,15 +3,21 @@ use es::SSE;
use eventsource_client as es;
use futures::{Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::num::NonZeroU64;
use std::time::Duration;
include!(concat!(env!("OUT_DIR"), "/codegen.rs"));
pub fn create_input(prompt: &str, grammar: Option<String>, max_tokens: u64, retain_gramar_state: bool) -> types::GenerationInput {
pub fn create_input(
gen_key: String,
prompt: &str,
grammar: Option<String>,
max_tokens: u64,
retain_gramar_state: bool,
) -> types::GenerationInput {
types::GenerationInput {
genkey: Some(gen_key),
prompt: prompt.to_string(),
genkey: None,
grammar: grammar,
grammar_retain_state: retain_gramar_state,
use_default_badwordsids: false,

View File

@ -87,6 +87,18 @@ impl Default for SceneStub {
}
}
impl From<&Exit> for SceneStub {
fn from(exit: &Exit) -> Self {
Self {
_key: Some(exit.scene_key.clone()),
name: exit.name.clone(),
location: exit.location.clone(),
is_stub: true,
..Default::default()
}
}
}
// The stage is everything: a scene, the people ("actors") in it, the
// props, etc.
#[derive(Serialize, Deserialize, Debug, Clone)]