/// inference/inference.rs, BAYES STAR, (c) coppola.ai 2024
use super::{
    graph::{PropositionFactor, PropositionGraph},
    table::{HashMapBeliefTable, InferenceResult, PropositionNode},
};
use crate::{
    common::{
        interface::PropositionDB,
        model::{FactorContext, InferenceModel},
        proposition_db,
    },
    inference::table::{GenericNodeType, HashMapInferenceResult},
    model::{
        objects::{Predicate, PredicateGroup, Proposition, PropositionGroup, EXISTENCE_FUNCTION},
        weights::CLASS_LABELS,
    },
    print_blue, print_green, print_red, print_yellow,
};
use redis::Connection;
use std::{
    borrow::Borrow,
    collections::{HashMap, HashSet, VecDeque},
    error::Error,
    rc::Rc,
};

struct Inferencer {
    model: Rc<InferenceModel>,
    proposition_graph: Rc<PropositionGraph>,
    pub data: HashMapBeliefTable,
    bfs_order: Vec<PropositionNode>,
}

fn reverse_prune_duplicates(raw_order: &Vec<(i32, PropositionNode)>) -> Vec<PropositionNode> {
    let mut seen = HashSet::new();
    let mut result = vec![];
    for (depth, node) in raw_order.iter().rev() {
        if !seen.contains(node) {
            result.push(node.clone());
        }
        seen.insert(node);
    }
    result.reverse();
    result
}
fn create_bfs_order(proposition_graph: &PropositionGraph) -> Vec<PropositionNode> {
    let mut queue = VecDeque::new();
    let mut buffer = vec![];
    for root in &proposition_graph.roots {
        queue.push_back((0, PropositionNode::from_single(&root)));
    }

    print_yellow!("create_bfs_order initial: queue {:?}", &queue);

    while let Some((depth, node)) = queue.pop_front() {
        buffer.push((depth, node.clone()));
        let forward = proposition_graph.get_all_forward(&node);
        for child in &forward {
            queue.push_back((depth + 1, child.clone()));
        }

        print_yellow!("create_bfs_order initial: queue {:?}", &queue);
        print_yellow!("create_bfs_order initial: buffer {:?}", &buffer);
    }

    let result = reverse_prune_duplicates(&buffer);
    print_yellow!("create_bfs_order result: {:?}", &result);
    result
}

impl Inferencer {
    // Initialize new Storage with a Redis connection
    pub fn new_mutable(
        model: Rc<InferenceModel>,
        proposition_graph: Rc<PropositionGraph>,
    ) -> Result<Box<Self>, redis::RedisError> {
        let bfs_order = create_bfs_order(&proposition_graph);
        Ok(Box::new(Inferencer {
            model,
            proposition_graph,
            data: HashMapBeliefTable::new(),
            bfs_order,
        }))
    }

    pub fn initialize(&mut self, proposition: &Proposition) -> Result<(), Box<dyn Error>> {
        print_red!("initialize: proposition {:?}", proposition.hash_string());
        // self.initialize_pi()?;
        self.initialize_lambda()?;
        self.send_pi_messages()?;
        self.update_marginals()?;
        Ok(())
    }

    pub fn update_marginals(&mut self) -> Result<(), Box<dyn Error>> {
        print_red!("update_marginals over {:?}", &self.bfs_order);
        for node in &self.bfs_order {
            let pi0 = self.data.get_pi_value(node, 0).unwrap();
            let pi1 = self.data.get_pi_value(node, 1).unwrap();
            let lambda0 = self.data.get_lambda_value(node, 0).unwrap();
            let lambda1 = self.data.get_lambda_value(node, 1).unwrap();
            let potential0 = pi0 * lambda0;
            let potential1 = pi1 * lambda1;
            let norm = potential0 + potential1;
            let probability0 = potential0 / norm;
            let probability1 = potential1 / norm;
            print_red!("node {:?} p0 {} p1 {}", node, probability0, probability1);
        }
        Ok(())
    }

    pub fn initialize_lambda(&mut self) -> Result<(), Box<dyn Error>> {
        print_red!("initialize_lambda: proposition");
        for node in &self.proposition_graph.all_nodes {
            print_red!("initializing: {}", node.debug_string());
            for outcome in CLASS_LABELS {
                self.data.set_lambda_value(node, outcome, 1f64);
            }
            for parent in &self.proposition_graph.get_all_backward(node) {
                print_red!(
                    "initializing lambda link from {} to {}",
                    node.debug_string(),
                    parent.debug_string()
                );
                for outcome in CLASS_LABELS {
                    self.data.set_lambda_message(node, parent, outcome, 1f64);
                }
            }
        }
        Ok(())
    }

    pub fn send_pi_messages(&mut self) -> Result<(), Box<dyn Error>> {
        let bfs_order = self.bfs_order.clone();
        print_red!("send_pi_messages bfs_order: {:?}", &bfs_order);
        for node in &bfs_order {
            print_yellow!("send pi bfs selects {:?}", node);
            self.pi_visit_node(node)?;
        }
        Ok(())
    }

    fn pi_compute_root(&mut self, node: &PropositionNode) -> Result<(), Box<dyn Error>> {
        let root = node.extract_single();
        assert_eq!(root.predicate.function, EXISTENCE_FUNCTION.to_string());
        self.data
            .set_pi_value(&PropositionNode::from_single(&root), 1, 1.0f64);
        self.data
            .set_pi_value(&PropositionNode::from_single(&root), 0, 0.0f64);

        Ok(())
    }

    pub fn pi_set_from_evidence(&mut self, node: &PropositionNode) -> Result<(), Box<dyn Error>> {
        let as_single = node.extract_single();
        let probability = self
            .model
            .proposition_db
            .get_proposition_probability(&as_single)?
            .unwrap();
        self.data
            .set_pi_value(node, 1, probability);
        self.data
            .set_pi_value(node, 0, 1f64 - probability);
        Ok(())
    }

    pub fn pi_visit_node(&mut self, from_node: &PropositionNode) -> Result<(), Box<dyn Error>> {
        // Part 1: Compute pi for this node.
        if !self.is_root(from_node) {
            let is_observed = self.is_observed(from_node)?;
            if is_observed {
                self.pi_set_from_evidence(from_node)?;
            } else {
                self.pi_compute_generic(&from_node)?;
            }
        } else {
            self.pi_compute_root(from_node)?;
        }
        // Part 2: For each value of z, compute pi_X(z)
        let forward_groups = self.proposition_graph.get_all_forward(from_node);
        for (this_index, to_node) in forward_groups.iter().enumerate() {
            for class_label in &CLASS_LABELS {
                let mut lambda_part = 1f64;
                for (other_index, other_node) in forward_groups.iter().enumerate() {
                    if other_index != this_index {
                        let this_lambda = self
                            .data
                            .get_lambda_value(&other_node, *class_label)
                            .unwrap();
                        lambda_part *= this_lambda;
                    }
                }
                let pi_part = self.data.get_pi_value(&from_node, *class_label).unwrap();
                let message = pi_part * lambda_part;
                self.data
                    .set_pi_message(&from_node, &to_node, *class_label, message);
            }
        }
        // Success.
        Ok(())
    }

    fn is_root(&self, node: &PropositionNode) -> bool {
        if node.is_single() {
            let as_single = node.extract_single();
            let is_root = self.proposition_graph.roots.contains(&as_single);
            is_root
        } else {
            false
        }
    }

    fn is_observed(&self, node: &PropositionNode) -> Result<bool, Box<dyn Error>> {
        if node.is_single() {
            let as_single = node.extract_single();
            let has_evidence = self
                .model
                .proposition_db
                .get_proposition_probability(&as_single)?
                .is_some();
            print_green!(
                "is_observed? node {:?}, has_evidence {}",
                &as_single,
                has_evidence
            );
            Ok(has_evidence)
        } else {
            Ok(false)
        }
    }

    pub fn pi_compute_generic(&mut self, node: &PropositionNode) -> Result<(), Box<dyn Error>> {
        match &node.node {
            GenericNodeType::Single(proposition) => {
                self.pi_compute_single(node)?;
            }
            GenericNodeType::Group(group) => {
                self.pi_compute_group(node)?;
            }
        }
        Ok(())
    }

    // from_node is a single.. compute it from the group
    pub fn pi_compute_single(&mut self, node: &PropositionNode) -> Result<(), Box<dyn Error>> {
        let conclusion = node.extract_single();
        let parent_nodes = self.proposition_graph.get_all_backward(node);
        let premise_groups = groups_from_backlinks(&parent_nodes);
        let all_combinations = compute_each_combination(&parent_nodes);
        let mut sum_true = 0f64;
        let mut sum_false = 0f64;
        for combination in &all_combinations {
            let mut product = 1f64;
            for (index, parent_node) in parent_nodes.iter().enumerate() {
                let boolean_outcome = combination.get(parent_node).unwrap();
                let usize_outcome = if *boolean_outcome { 1 } else { 0 };
                let pi_x_z = self
                    .data
                    .get_pi_message(parent_node, node, usize_outcome)
                    .unwrap();
                print_red!(
                    "getting pi message parent_node {:?}, node {:?}, usize_outcome {}, pi_x_z {}",
                    &parent_node,
                    &node,
                    usize_outcome,
                    pi_x_z,
                );
                product *= pi_x_z;
            }
            let factor =
                self.build_factor_context_for_assignment(&premise_groups, combination, &conclusion);
            let prediction = self.model.model.predict(&factor)?;
            print_yellow!("local probability {}  for factor {:?}", &prediction.marginal, &factor);
            let true_marginal = &prediction.marginal;
            let false_marginal = 1f64 - true_marginal;
            sum_true += true_marginal * product;
            sum_false += false_marginal * product;
        }
        self.data.set_pi_value(node, 1, sum_true);
        self.data.set_pi_value(node, 0, sum_false);
        Ok(())
    }

    pub fn pi_compute_group(&mut self, node: &PropositionNode) -> Result<(), Box<dyn Error>> {
        let parent_nodes = self.proposition_graph.get_all_backward(node);
        print_yellow!("pi_compute_group {:?}", &parent_nodes);
        let all_combinations = compute_each_combination(&parent_nodes);
        let mut sum_true = 0f64;
        let mut sum_false = 0f64;
        for combination in &all_combinations {
            let mut product = 1f64;
            let mut condition = true;
            for (index, parent_node) in parent_nodes.iter().enumerate() {
                let boolean_outcome = combination.get(parent_node).unwrap();
                let usize_outcome = if *boolean_outcome { 1 } else { 0 };
                print_green!(
                    "get pi message: parent_node {:?}, node {:?}, outcome: {}",
                    parent_node,
                    node,
                    usize_outcome
                );
                let pi_x_z = self
                    .data
                    .get_pi_message(parent_node, node, usize_outcome)
                    .unwrap();
                print_yellow!(
                    "boolean_outcome {} usize_outcome {} pi_x_z {}",
                    boolean_outcome,
                    usize_outcome,
                    pi_x_z
                );
                product *= pi_x_z;
                let combination_val = combination[parent_node];
                condition = condition && combination_val;
                print_yellow!(
                    "combination_val {} condition {}",
                    combination_val,
                    condition
                );
            }
            if condition {
                print_blue!("true combination: {:?}, product {}", &combination, product);
                sum_true += product;
            } else {
                print_blue!("false combination: {:?}, product {}", &combination, product);
                sum_false += product;
            }
        }
        self.data.set_pi_value(node, 1, sum_true);
        self.data.set_pi_value(node, 0, sum_false);
        Ok(())
    }

    // TODO: move this out of the class
    fn build_factor_context_for_assignment(
        &self,
        premises: &Vec<PropositionGroup>,
        premise_assignment: &HashMap<PropositionNode, bool>,
        conclusion: &Proposition,
    ) -> FactorContext {
        let mut probabilities = vec![];
        let mut factors = vec![];
        for proposition_group in premises {
            let node = PropositionNode::from_group(proposition_group);
            let assignment = *premise_assignment.get(&node).unwrap();
            if assignment {
                probabilities.push(1f64);
            } else {
                probabilities.push(0f64);
            }
            let inference = self
                .proposition_graph
                .get_inference_used(proposition_group, conclusion);
            let factor = PropositionFactor {
                premise: proposition_group.clone(),
                conclusion: conclusion.clone(),
                inference,
            };
            factors.push(factor);
        }
        let context = FactorContext {
            factor: factors,
            probabilities,
        };
        context
    }
}

// Return 1 HashMap for each of the 2^N ways to assign each of the N memebers of `propositions` to either true or false.
fn compute_each_combination(
    propositions: &Vec<PropositionNode>,
) -> Vec<HashMap<PropositionNode, bool>> {
    print_yellow!("compute_each_combination: propositions={:?}", &propositions);
    let n = propositions.len();
    let mut all_combinations = Vec::new();
    for i in 0..(1 << n) {
        let mut current_combination = HashMap::new();
        for j in 0..n {
            let prop = &propositions[j];
            let state = i & (1 << j) != 0;
            current_combination.insert(prop.clone(), state);
        }
        all_combinations.push(current_combination);
    }
    all_combinations
}

// Note: GraphicalModel contains PropositionDB, which contains the "evidence".
pub fn inference_compute_marginals(
    model: Rc<InferenceModel>,
    target: &Proposition,
) -> Result<Rc<dyn InferenceResult>, Box<dyn Error>> {
    let proposition_graph = PropositionGraph::new_shared(model.graph.clone(), target)?;
    // proposition_graph.visualize();
    let mut inferencer = Inferencer::new_mutable(model.clone(), proposition_graph.clone())?;
    inferencer.initialize(target)?;
    inferencer.data.print_debug();
    HashMapInferenceResult::new_shared(inferencer.data)
}

fn groups_from_backlinks(backlinks: &Vec<PropositionNode>) -> Vec<PropositionGroup> {
    let mut result = vec![];
    for backlink in backlinks {
        let group = backlink.extract_group();
        result.push(group);
    }
    result
}