Add creativity levels to the prompt parsing
This commit is contained in:
parent
02a4897b4b
commit
d5fb204df4
|
@ -5,7 +5,6 @@ use async_recursion::async_recursion;
|
|||
use serde::de::DeserializeOwned;
|
||||
use serde_json::error::Category;
|
||||
use serde_json::Value;
|
||||
use std::borrow::BorrowMut;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
|
||||
|
@ -40,6 +39,7 @@ async fn continue_execution<T: DeserializeOwned>(
|
|||
prompt.grammar.clone(),
|
||||
prompt.max_tokens,
|
||||
true,
|
||||
prompt.creativity,
|
||||
);
|
||||
|
||||
// TODO convert error to remove trait bound issue
|
||||
|
@ -65,10 +65,18 @@ async fn continue_execution<T: DeserializeOwned>(
|
|||
Ok(resp)
|
||||
}
|
||||
|
||||
/// How adventurous the model should be when sampling output.
///
/// The backend maps each level to a sampler temperature via
/// `creativity_to_temperature` (Predictable → 0.5, Normal → 0.7,
/// Creative → 1.0): higher creativity means a higher temperature.
#[derive(Debug, Clone, Copy)]
pub enum AiCreativity {
    /// Lowest temperature (0.5); favors the most likely continuation.
    Predictable,
    /// Default level (0.7); used by the plain `AiPrompt` constructors.
    Normal,
    /// Highest temperature (1.0); used by the `creative_*` constructors.
    Creative,
}
|
||||
|
||||
/// A single request to the language model: the text to complete plus
/// the generation constraints passed through to `create_input`.
pub struct AiPrompt {
    // Raw prompt text sent to the model.
    pub prompt: String,
    // Optional grammar used to constrain model output; `None` disables
    // the constraint. (Exact grammar format is backend-defined —
    // TODO confirm against the generation API.)
    pub grammar: Option<String>,
    // Maximum number of tokens the model may generate (constructors
    // default this to 150).
    pub max_tokens: u64,
    // Sampling-temperature preset for this prompt.
    pub creativity: AiCreativity,
}
|
||||
|
||||
impl AiPrompt {
|
||||
|
@ -77,6 +85,7 @@ impl AiPrompt {
|
|||
prompt: prompt.to_string(),
|
||||
grammar: None,
|
||||
max_tokens: 150,
|
||||
creativity: AiCreativity::Normal,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -85,6 +94,7 @@ impl AiPrompt {
|
|||
prompt: prompt.to_string(),
|
||||
grammar: Some(grammar.to_string()),
|
||||
max_tokens: 150,
|
||||
creativity: AiCreativity::Normal,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -93,6 +103,25 @@ impl AiPrompt {
|
|||
prompt: prompt.to_string(),
|
||||
grammar: Some(grammar.to_string()),
|
||||
max_tokens: tokens,
|
||||
creativity: AiCreativity::Normal,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn creative_with_grammar(prompt: &str, grammar: &str) -> AiPrompt {
|
||||
AiPrompt {
|
||||
prompt: prompt.to_string(),
|
||||
grammar: Some(grammar.to_string()),
|
||||
max_tokens: 150,
|
||||
creativity: AiCreativity::Creative,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn creative_with_grammar_and_size(prompt: &str, grammar: &str, tokens: u64) -> AiPrompt {
|
||||
AiPrompt {
|
||||
prompt: prompt.to_string(),
|
||||
grammar: Some(grammar.to_string()),
|
||||
max_tokens: tokens,
|
||||
creativity: AiCreativity::Creative,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -122,11 +151,12 @@ impl AiConversation {
|
|||
}
|
||||
|
||||
pub async fn execute<T: DeserializeOwned>(&self, prompt: &AiPrompt) -> Result<T> {
|
||||
let mut prompt_so_far = RefCell::borrow_mut(&self.prompt_so_far);
|
||||
let prompt_so_far = &mut *prompt_so_far;
|
||||
|
||||
// Handle Mistral-instruct begin instruct mode.
|
||||
// https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
|
||||
// Only the very first one begins with <s>, not subsequent.
|
||||
let mut prompt_so_far = RefCell::borrow_mut(&self.prompt_so_far);
|
||||
let prompt_so_far = &mut *prompt_so_far;
|
||||
if prompt_so_far.is_empty() {
|
||||
prompt_so_far.push_str("<s>");
|
||||
}
|
||||
|
@ -140,6 +170,7 @@ impl AiConversation {
|
|||
prompt.grammar.clone(),
|
||||
prompt.max_tokens,
|
||||
false,
|
||||
prompt.creativity,
|
||||
);
|
||||
|
||||
// TODO working on removing trait bounds issue so we can use ? operator.
|
||||
|
|
|
@ -176,6 +176,7 @@ impl AiClient {
|
|||
}
|
||||
}
|
||||
CoherenceFailure::DuplicateExits(bad_exits) => {
|
||||
println!("found duplicate exits {:?}", bad_exits);
|
||||
let position = find_exit_position(&scene.exits, bad_exits[0])?;
|
||||
SceneFix::DeleteExit(position)
|
||||
}
|
||||
|
|
|
@ -229,7 +229,7 @@ fn scene_info_for_person(scene: &SceneSeed) -> String {
|
|||
}
|
||||
|
||||
pub fn scene_creation_prompt(scene_type: &str, fantasticalness: &str) -> AiPrompt {
|
||||
AiPrompt::new_with_grammar_and_size(
|
||||
AiPrompt::creative_with_grammar_and_size(
|
||||
&SCENE_CREATION_PROMPT
|
||||
.replacen("{SCENE_INSTRUCTIONS}", SCENE_INSTRUCTIONS, 1)
|
||||
.replacen("{}", scene_type, 1)
|
||||
|
@ -269,7 +269,7 @@ pub fn scene_from_stub_prompt(connected_scene: &Scene, stub: &SceneStub) -> AiPr
|
|||
.map(|exit| exit.direction.as_ref())
|
||||
.unwrap_or("back");
|
||||
|
||||
AiPrompt::new_with_grammar_and_size(
|
||||
AiPrompt::creative_with_grammar_and_size(
|
||||
&SCENE_FROM_STUB_PROMPT
|
||||
.replacen("{SCENE_INSTRUCTIONS}", SCENE_INSTRUCTIONS, 1)
|
||||
.replacen("{CONNECTED_SCENE_NAME}", &connected_scene.name, 1)
|
||||
|
@ -289,7 +289,7 @@ pub fn scene_from_stub_prompt(connected_scene: &Scene, stub: &SceneStub) -> AiPr
|
|||
}
|
||||
|
||||
pub fn person_creation_prompt(scene: &SceneSeed, person: &PersonSeed) -> AiPrompt {
|
||||
AiPrompt::new_with_grammar_and_size(
|
||||
AiPrompt::creative_with_grammar_and_size(
|
||||
&PERSON_CREATION_PROMPT
|
||||
.replacen("{NAME}", &person.name, 1)
|
||||
.replacen("{RACE}", &person.race, 1)
|
||||
|
|
|
@ -6,14 +6,25 @@ use serde::{Deserialize, Serialize};
|
|||
use std::num::NonZeroU64;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::ai::convo::AiCreativity;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/codegen.rs"));
|
||||
|
||||
fn creativity_to_temperature(creativity: AiCreativity) -> Option<f64> {
|
||||
match creativity {
|
||||
AiCreativity::Predictable => Some(0.5),
|
||||
AiCreativity::Normal => Some(0.7),
|
||||
AiCreativity::Creative => Some(1.0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_input(
|
||||
gen_key: String,
|
||||
prompt: &str,
|
||||
grammar: Option<String>,
|
||||
max_tokens: u64,
|
||||
retain_gramar_state: bool,
|
||||
creativity: AiCreativity,
|
||||
) -> types::GenerationInput {
|
||||
types::GenerationInput {
|
||||
genkey: Some(gen_key),
|
||||
|
@ -28,7 +39,7 @@ pub fn create_input(
|
|||
mirostat_eta: None,
|
||||
mirostat_tau: None,
|
||||
rep_pen: Some(1.1),
|
||||
temperature: Some(0.7),
|
||||
temperature: creativity_to_temperature(creativity),
|
||||
tfs: None,
|
||||
top_a: Some(0.0),
|
||||
top_p: Some(0.92),
|
||||
|
|
Loading…
Reference in New Issue