use async_trait::async_trait; use es::SSE; use eventsource_client as es; use futures::{Stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; use std::num::NonZeroU64; use std::time::Duration; use crate::ai::convo::AiCreativity; include!(concat!(env!("OUT_DIR"), "/codegen.rs")); fn creativity_to_temperature(creativity: AiCreativity) -> Option { match creativity { AiCreativity::Predictable => Some(0.5), AiCreativity::Normal => Some(0.7), AiCreativity::Creative => Some(1.0), } } pub fn create_input( gen_key: String, prompt: &str, grammar: Option, max_tokens: u64, retain_gramar_state: bool, creativity: AiCreativity, ) -> types::GenerationInput { types::GenerationInput { genkey: Some(gen_key), prompt: prompt.to_string(), grammar: grammar, grammar_retain_state: retain_gramar_state, use_default_badwordsids: false, max_context_length: None, max_length: NonZeroU64::new(max_tokens), min_p: None, mirostat: None, mirostat_eta: None, mirostat_tau: None, rep_pen: Some(1.1), temperature: creativity_to_temperature(creativity), tfs: None, top_a: Some(0.0), top_p: Some(0.92), typical: None, rep_pen_range: Some(320), top_k: None, sampler_order: vec![6, 0, 1, 3, 4, 2, 5], sampler_seed: None, stop_sequence: vec!["".to_string(), "".to_string()], } } pub struct WrappedGenerationError(String); impl From for WrappedGenerationError { fn from(value: es::Error) -> Self { WrappedGenerationError(format!("{:?}", value)) } } #[derive(Serialize, Deserialize)] struct AIEvent { token: String, } fn create_response_stream( client: impl es::Client, ) -> impl Stream> { client.stream().map(|sse| { sse.and_then(|event| match event { SSE::Event(ev) => serde_json::from_str::(&ev.data) .map(|r| r.token) .map_err(|err| es::Error::Unexpected(Box::new(err))), SSE::Comment(_) => Ok("".to_string()), }) }) } #[async_trait] pub trait SseGenerationExt { async fn sse_generate( &self, input: types::GenerationInput, ) -> std::result::Result; } #[async_trait] impl SseGenerationExt for Client { async fn sse_generate( &self, input: types::GenerationInput, ) -> std::result::Result { let params = serde_json::to_string(&input)?; let stream_url = format!("{}/extra/generate/stream", self.baseurl()); let reconnect_opts = es::ReconnectOptions::reconnect(true) .retry_initial(false) .delay(Duration::from_secs(1)) .backoff_factor(2) .delay_max(Duration::from_secs(60)) .build(); let client = es::ClientBuilder::for_url(&stream_url)? .header("accept", "application/json")? .header("Content-Type", "application/json")? .method("POST".to_string()) .body(params) .reconnect(reconnect_opts) .build(); let mut stream = create_response_stream(client); let mut response = String::new(); loop { let maybe_token = stream.try_next().await; match maybe_token { Ok(Some(token)) => response.push_str(&token), Err(es::Error::Eof) => break, Err(err) => return Err(err), _ => (), } } Ok(response) } }