/// model/exponential.rs, BAYES STAR, (c) coppola.ai 2024
use super::choose::extract_backimplications_from_proposition;
use super::config::ConfigurationOptions;
use super::objects::PredicateFactor;
use super::weights::{negative_feature, positive_feature, ExponentialWeights};
use crate::common::interface::{PropositionDB, PredictStatistics, TrainStatistics};
use crate::common::model::InferenceModel;
use crate::common::model::{FactorContext, FactorModel};
use crate::common::redis::RedisManager;
use crate::common::resources::FactoryResources;
use crate::model::objects::Predicate;
use crate::model::weights::CLASS_LABELS;
use crate::{print_yellow, print_blue};
use redis::Connection;
use std::cell::RefCell;
use std::collections::HashMap;
use std::error::Error;
use std::rc::Rc;
/// Factor model that scores each class label with `exp(weights · features)` and
/// keeps its feature weights in an [`ExponentialWeights`] store.
pub struct ExponentialModel {
/// Configuration cloned from the factory resources; `train` reads
/// `print_training_loss` from it.
config: ConfigurationOptions,
/// Weight store constructed from a connection obtained via `resources.redis`.
weights: ExponentialWeights,
}
impl ExponentialModel {
pub fn new_mutable(resources: &FactoryResources) -> Result<Box<dyn FactorModel>, Box<dyn Error>> {
let connection = resources.redis.get_connection()?;
let weights = ExponentialWeights::new(connection);
Ok(Box::new(ExponentialModel {
config: resources.config.clone(),
weights,
}))
}
pub fn new_shared(resources: &FactoryResources) -> Result<Rc<dyn FactorModel>, Box<dyn Error>> {
let connection = resources.redis.get_connection()?;
let weights = ExponentialWeights::new(connection);
Ok(Rc::new(ExponentialModel {
config: resources.config.clone(),
weights,
}))
}
}
/// Sparse dot product of two feature maps.
///
/// Walks the keys of `dict1` and, for each key that is also present in `dict2`,
/// adds `v1 * v2` to the running total. Keys missing from `dict2` are skipped,
/// so they contribute nothing. Every matched pair is passed to `print_blue!`
/// before it is accumulated.
fn dot_product(dict1: &HashMap<String, f64>, dict2: &HashMap<String, f64>) -> f64 {
    dict1
        .iter()
        .filter_map(|(key, &v1)| dict2.get(key).map(|&v2| (key, v1, v2)))
        // Explicit 0.0 seed so an empty intersection yields exactly 0.0.
        .fold(0.0, |total, (key, v1, v2)| {
            let product = v1 * v2;
            print_blue!("dot_product: key {}, v1 {}, v2 {}, product {}", key, v1, v2, product);
            total + product
        })
}
/// Returns the unnormalised potential `exp(weights · features)`.
///
/// The dot product only covers keys present in both maps (see `dot_product`).
/// The result is not normalised; callers divide by the sum over class labels.
pub fn compute_potential(weights: &HashMap<String, f64>, features: &HashMap<String, f64>) -> f64 {
    dot_product(weights, features).exp()
}
/// Builds one feature map per class label for `factor`.
///
/// For every premise in `factor.factor`, the premise's inference key is turned
/// into a positive and a negative feature name for the current class label. The
/// positive feature is assigned the premise probability `p` (taken from
/// `factor.probabilities` at the same index) and the negative feature `1 - p`.
///
/// The returned vector holds one map per entry of `CLASS_LABELS`, in iteration
/// order.
///
/// # Errors
///
/// Returns an error, instead of panicking on an out-of-bounds index, when
/// `factor.probabilities` has no entry for one of the premises.
pub fn features_from_factor(
    factor: &FactorContext,
) -> Result<Vec<HashMap<String, f64>>, Box<dyn Error>> {
    let mut vec_result = vec![];
    for class_label in CLASS_LABELS {
        let mut result = HashMap::new();
        for (i, premise) in factor.factor.iter().enumerate() {
            debug!("Processing backimplication {}", i);
            let feature = premise.inference.unique_key();
            debug!("Generated unique key for feature: {}", feature);
            // Checked lookup: premises and probabilities are parallel sequences,
            // and a length mismatch should surface as an error to the caller.
            let probability = factor.probabilities.get(i).copied().ok_or_else(|| {
                format!(
                    "features_from_factor: no probability for backimplication {} ({} available)",
                    i,
                    factor.probabilities.len()
                )
            })?;
            debug!(
                "Conjunction probability for backimplication {}: {}",
                i, probability
            );
            let posf = positive_feature(&feature, class_label);
            let negf = negative_feature(&feature, class_label);
            // Clones are needed because the names are logged after insertion.
            result.insert(posf.clone(), probability);
            result.insert(negf.clone(), 1.0 - probability);
            debug!(
                "Inserted features for backimplication {}: positive - {}, negative - {}",
                i, posf, negf
            );
        }
        vec_result.push(result);
    }
    trace!("features_from_backimplications completed successfully");
    Ok(vec_result)
}
/// Scales every feature value by `probability`.
///
/// Returns a new map with the same keys as `features`, where each value is
/// `value * probability`. The input map is left untouched.
pub fn compute_expected_features(
    probability: f64,
    features: &HashMap<String, f64>,
) -> HashMap<String, f64> {
    features
        .iter()
        .map(|(name, &value)| (name.clone(), value * probability))
        .collect()
}
/// Step size that `do_sgd_update` multiplies with each `gold - expected` difference.
const LEARNING_RATE: f64 = 0.025;
/// Applies one gradient step to every weight in `weights`.
///
/// For each feature in `weights` the new value is
/// `old + LEARNING_RATE * (gold - expected)`, where a feature missing from
/// `gold_features` or `expected_features` counts as `0.0`. Features that appear
/// only in the gold/expected maps are ignored.
///
/// When `print_training_loss` is set, one `info!` line per feature reports the
/// gold value, expected value, absolute difference and old/new weight.
///
/// Returns a fresh map; `weights` itself is not modified.
pub fn do_sgd_update(
    weights: &HashMap<String, f64>,
    gold_features: &HashMap<String, f64>,
    expected_features: &HashMap<String, f64>,
    print_training_loss: bool,
) -> HashMap<String, f64> {
    weights
        .iter()
        .map(|(feature, &old_weight)| {
            let gv = gold_features.get(feature).copied().unwrap_or(0.0);
            let ev = expected_features.get(feature).copied().unwrap_or(0.0);
            let gradient = gv - ev;
            let new_weight = old_weight + LEARNING_RATE * gradient;
            if print_training_loss {
                info!(
                    "feature: {}, gv: {}, ev: {}, loss: {}, old_weight: {}, new_weight: {}",
                    feature,
                    gv,
                    ev,
                    gradient.abs(),
                    old_weight,
                    new_weight
                );
            }
            (feature.clone(), new_weight)
        })
        .collect()
}
impl FactorModel for ExponentialModel {
fn initialize_connection(
&mut self,
implication: &PredicateFactor,
) -> Result<(), Box<dyn Error>> {
self.weights.initialize_weights(implication)?;
Ok(())
}
fn train(
&mut self,
factor: &FactorContext,
gold_probability: f64,
) -> Result<TrainStatistics, Box<dyn Error>> {
trace!("train_on_example - Getting features from backimplications");
let features = match features_from_factor(factor) {
Ok(f) => f,
Err(e) => {
trace!(
"train_on_example - Error in features_from_backimplications: {:?}",
e
);
return Err(e);
}
};
let mut weight_vectors = vec![];
let mut potentials = vec![];
for class_label in CLASS_LABELS {
for (feature, weight) in &features[class_label] {
trace!("feature {:?} {}", feature, weight);
}
trace!(
"train_on_example - Reading weights for class {}",
class_label
);
let weight_vector = match self
.weights
.read_weights(&features[class_label].keys().cloned().collect::<Vec<_>>())
{
Ok(w) => w,
Err(e) => {
trace!("train_on_example - Error in read_weights: {:?}", e);
return Err(e);
}
};
trace!("train_on_example - Computing probability");
let potential = compute_potential(&weight_vector, &features[class_label]);
trace!("train_on_example - Computed probability: {}", potential);
potentials.push(potential);
weight_vectors.push(weight_vector);
}
let normalization = potentials[0] + potentials[1];
for class_label in CLASS_LABELS {
let probability = potentials[class_label] / normalization;
trace!("train_on_example - Computing expected features");
let this_true_prob = if class_label == 0 {
1f64 - gold_probability
} else {
gold_probability
};
let gold = compute_expected_features(this_true_prob, &features[class_label]);
let expected = compute_expected_features(probability, &features[class_label]);
trace!("train_on_example - Performing SGD update");
let new_weight = do_sgd_update(
&weight_vectors[class_label],
&gold,
&expected,
self.config.print_training_loss,
);
trace!("train_on_example - Saving new weights");
self.weights.save_weights(&new_weight)?;
}
trace!("train_on_example - End");
Ok(TrainStatistics { loss: 1f64 })
}
fn predict(&self, factor: &FactorContext) -> Result<PredictStatistics, Box<dyn Error>> {
let features = match features_from_factor(factor) {
Ok(f) => f,
Err(e) => {
print_yellow!(
"inference_probability - Error in features_from_backimplications: {:?}",
e
);
return Err(e);
}
};
let mut potentials = vec![];
for class_label in CLASS_LABELS {
let this_features = &features[class_label];
for (feature, weight) in this_features.iter() {
print_yellow!("feature {:?} {}", &feature, weight);
}
print_yellow!("inference_probability - Reading weights");
let weight_vector = match self
.weights
.read_weights(&this_features.keys().cloned().collect::<Vec<_>>())
{
Ok(w) => w,
Err(e) => {
print_yellow!("inference_probability - Error in read_weights: {:?}", e);
return Err(e);
}
};
for (feature, weight) in weight_vector.iter() {
print_yellow!("weight {:?} {}", &feature, weight);
}
let potential = compute_potential(&weight_vector, &this_features);
print_yellow!("potential for {} {} {:?}", class_label, potential, &factor);
potentials.push(potential);
}
let normalization = potentials[0] + potentials[1];
let marginal = potentials[1] / normalization;
print_yellow!("dot_product: normalization {}, marginal {}", normalization, marginal);
Ok(PredictStatistics { marginal })
}
}