diff --git a/src/commands/coherence.rs b/src/commands/coherence.rs index a2e90d1..6759d84 100644 --- a/src/commands/coherence.rs +++ b/src/commands/coherence.rs @@ -1,4 +1,5 @@ use super::converter::validate_event_coherence; +use super::partition; use crate::{ ai::logic::AiLogic, db::Database, @@ -38,20 +39,9 @@ impl CommandCoherence<'_> { &self, failures: Vec, ) -> ExecutionConversionResult { - let (successes, failures): (Vec<_>, Vec<_>) = stream::iter(failures.into_iter()) - .then(|failure| self.cohere_event(failure)) - .fold( - (vec![], vec![]), - |(mut successes, mut failures), res| async { - match res { - Ok(event) => successes.push(event), - Err(err) => failures.push(err), - }; - - (successes, failures) - }, - ) - .await; + let (successes, failures) = partition!( + stream::iter(failures.into_iter()).then(|failure| self.cohere_event(failure)) + ); // TODO we need to use LLM on events that have failed non-LLM coherence. @@ -68,7 +58,7 @@ impl CommandCoherence<'_> { } async fn cohere_event(&self, failure: EventCoherenceFailure) -> CoherenceResult { - let event = async { + let event_fix = async { match failure { EventCoherenceFailure::TargetDoesNotExist(event) => { self.fix_target_does_not_exist(event).await @@ -77,7 +67,7 @@ impl CommandCoherence<'_> { } }; - event + event_fix .and_then(|e| validate_event_coherence(&self.db, e)) .await } diff --git a/src/commands/converter.rs b/src/commands/converter.rs index 0aba2a5..9ef2824 100644 --- a/src/commands/converter.rs +++ b/src/commands/converter.rs @@ -1,3 +1,5 @@ +use super::coherence::strip_prefixes; +use super::partition; use crate::{ db::Database, models::commands::{ @@ -9,7 +11,6 @@ use anyhow::Result; use futures::stream::{self, StreamExt, TryStreamExt}; use itertools::{Either, Itertools}; use std::convert::TryFrom; -use super::coherence::strip_prefixes; use strum::VariantNames; @@ -85,15 +86,9 @@ pub async fn convert_raw_execution( }); // Coherence validation of converted events. - let (successes, incoherent_events): (Vec<_>, Vec<_>) = stream::iter(converted.into_iter()) - .then(|event| validate_event_coherence(db, event)) - .collect::>() - .await - .into_iter() - .partition_map(|res| match res { - Ok(event) => Either::Left(event), - Err(err) => Either::Right(err), - }); + let (successes, incoherent_events) = partition!( + stream::iter(converted.into_iter()).then(|event| validate_event_coherence(db, event)) + ); let failure_len = conversion_failures.len() + incoherent_events.len(); diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 063d102..2d995e0 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -12,6 +12,27 @@ use crate::{ use anyhow::Result; use std::rc::Rc; +/// Splits up a stream of results into successes and failures. +macro_rules! partition { + ($stream: expr) => { + $stream + .fold( + (vec![], vec![]), + |(mut successes, mut failures), res| async { + match res { + Ok(event) => successes.push(event), + Err(err) => failures.push(err), + }; + + (successes, failures) + }, + ) + .await + }; +} + +pub(self) use partition; + pub mod builtins; pub mod coherence; pub mod converter;