Add creativity levels to the prompt parsing

This commit is contained in:
projectmoon 2024-01-12 21:47:26 +01:00
parent 02a4897b4b
commit d5fb204df4
4 changed files with 50 additions and 7 deletions

View File

@ -5,7 +5,6 @@ use async_recursion::async_recursion;
use serde::de::DeserializeOwned;
use serde_json::error::Category;
use serde_json::Value;
use std::borrow::BorrowMut;
use std::cell::RefCell;
use std::rc::Rc;
@ -40,6 +39,7 @@ async fn continue_execution<T: DeserializeOwned>(
prompt.grammar.clone(),
prompt.max_tokens,
true,
prompt.creativity,
);
// TODO convert error to remove trait bound issue
@ -65,10 +65,18 @@ async fn continue_execution<T: DeserializeOwned>(
Ok(resp)
}
#[derive(Debug, Clone, Copy)]
pub enum AiCreativity {
Predictable,
Normal,
Creative,
}
pub struct AiPrompt {
pub prompt: String,
pub grammar: Option<String>,
pub max_tokens: u64,
pub creativity: AiCreativity,
}
impl AiPrompt {
@ -77,6 +85,7 @@ impl AiPrompt {
prompt: prompt.to_string(),
grammar: None,
max_tokens: 150,
creativity: AiCreativity::Normal,
}
}
@ -85,6 +94,7 @@ impl AiPrompt {
prompt: prompt.to_string(),
grammar: Some(grammar.to_string()),
max_tokens: 150,
creativity: AiCreativity::Normal,
}
}
@ -93,6 +103,25 @@ impl AiPrompt {
prompt: prompt.to_string(),
grammar: Some(grammar.to_string()),
max_tokens: tokens,
creativity: AiCreativity::Normal,
}
}
pub fn creative_with_grammar(prompt: &str, grammar: &str) -> AiPrompt {
AiPrompt {
prompt: prompt.to_string(),
grammar: Some(grammar.to_string()),
max_tokens: 150,
creativity: AiCreativity::Creative,
}
}
pub fn creative_with_grammar_and_size(prompt: &str, grammar: &str, tokens: u64) -> AiPrompt {
AiPrompt {
prompt: prompt.to_string(),
grammar: Some(grammar.to_string()),
max_tokens: tokens,
creativity: AiCreativity::Creative,
}
}
}
@ -122,11 +151,12 @@ impl AiConversation {
}
pub async fn execute<T: DeserializeOwned>(&self, prompt: &AiPrompt) -> Result<T> {
let mut prompt_so_far = RefCell::borrow_mut(&self.prompt_so_far);
let prompt_so_far = &mut *prompt_so_far;
// Handle Mistral-instruct begin instruct mode.
// https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
// Only the very first one begins with <s>, not subsequent.
let mut prompt_so_far = RefCell::borrow_mut(&self.prompt_so_far);
let prompt_so_far = &mut *prompt_so_far;
if prompt_so_far.is_empty() {
prompt_so_far.push_str("<s>");
}
@ -140,6 +170,7 @@ impl AiConversation {
prompt.grammar.clone(),
prompt.max_tokens,
false,
prompt.creativity,
);
// TODO working on removing trait bounds issue so we can use ? operator.

View File

@ -176,6 +176,7 @@ impl AiClient {
}
}
CoherenceFailure::DuplicateExits(bad_exits) => {
println!("found duplicate exits {:?}", bad_exits);
let position = find_exit_position(&scene.exits, bad_exits[0])?;
SceneFix::DeleteExit(position)
}

View File

@ -229,7 +229,7 @@ fn scene_info_for_person(scene: &SceneSeed) -> String {
}
pub fn scene_creation_prompt(scene_type: &str, fantasticalness: &str) -> AiPrompt {
AiPrompt::new_with_grammar_and_size(
AiPrompt::creative_with_grammar_and_size(
&SCENE_CREATION_PROMPT
.replacen("{SCENE_INSTRUCTIONS}", SCENE_INSTRUCTIONS, 1)
.replacen("{}", scene_type, 1)
@ -269,7 +269,7 @@ pub fn scene_from_stub_prompt(connected_scene: &Scene, stub: &SceneStub) -> AiPr
.map(|exit| exit.direction.as_ref())
.unwrap_or("back");
AiPrompt::new_with_grammar_and_size(
AiPrompt::creative_with_grammar_and_size(
&SCENE_FROM_STUB_PROMPT
.replacen("{SCENE_INSTRUCTIONS}", SCENE_INSTRUCTIONS, 1)
.replacen("{CONNECTED_SCENE_NAME}", &connected_scene.name, 1)
@ -289,7 +289,7 @@ pub fn scene_from_stub_prompt(connected_scene: &Scene, stub: &SceneStub) -> AiPr
}
pub fn person_creation_prompt(scene: &SceneSeed, person: &PersonSeed) -> AiPrompt {
AiPrompt::new_with_grammar_and_size(
AiPrompt::creative_with_grammar_and_size(
&PERSON_CREATION_PROMPT
.replacen("{NAME}", &person.name, 1)
.replacen("{RACE}", &person.race, 1)

View File

@ -6,14 +6,25 @@ use serde::{Deserialize, Serialize};
use std::num::NonZeroU64;
use std::time::Duration;
use crate::ai::convo::AiCreativity;
include!(concat!(env!("OUT_DIR"), "/codegen.rs"));
fn creativity_to_temperature(creativity: AiCreativity) -> Option<f64> {
match creativity {
AiCreativity::Predictable => Some(0.5),
AiCreativity::Normal => Some(0.7),
AiCreativity::Creative => Some(1.0),
}
}
pub fn create_input(
gen_key: String,
prompt: &str,
grammar: Option<String>,
max_tokens: u64,
retain_gramar_state: bool,
creativity: AiCreativity,
) -> types::GenerationInput {
types::GenerationInput {
genkey: Some(gen_key),
@ -28,7 +39,7 @@ pub fn create_input(
mirostat_eta: None,
mirostat_tau: None,
rep_pen: Some(1.1),
temperature: Some(0.7),
temperature: creativity_to_temperature(creativity),
tfs: None,
top_a: Some(0.0),
top_p: Some(0.92),