From 694c3421b23e74da244675cbf7b2c3cfbb2866ab Mon Sep 17 00:00:00 2001 From: Justin Worthe Date: Sat, 8 Sep 2018 13:01:33 +0200 Subject: Cached more values in tree exploration calcs --- src/strategy/monte_carlo_tree.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'src/strategy') diff --git a/src/strategy/monte_carlo_tree.rs b/src/strategy/monte_carlo_tree.rs index fe59e34..2d27b62 100644 --- a/src/strategy/monte_carlo_tree.rs +++ b/src/strategy/monte_carlo_tree.rs @@ -15,8 +15,10 @@ struct NodeStats { wins: f32, losses: f32, attempts: f32, + average: f32, + confidence: f32, explored: Vec<(Command, NodeStats)>, - unexplored: Vec<Command> + unexplored: Vec<Command>, } impl NodeStats { @@ -58,6 +60,8 @@ impl NodeStats { wins: 0., losses: 0., attempts: 0., + average: 0., + confidence: 0., explored: Vec::with_capacity(commands.len()), unexplored: commands } @@ -67,11 +71,12 @@ impl NodeStats { debug_assert!(self.unexplored.is_empty()); debug_assert!(self.explored.len() > 0); let total_attempts = self.explored.iter().map(|(_, n)| n.attempts).sum::<f32>(); + let sqrt_n = total_attempts.sqrt(); let mut max_position = 0; - let mut max_value = self.explored[0].1.ucb(total_attempts); + let mut max_value = self.explored[0].1.ucb(sqrt_n); for i in 1..self.explored.len() { - let value = self.explored[i].1.ucb(total_attempts); + let value = self.explored[i].1.ucb(sqrt_n); if value > max_value { max_position = i; max_value = value; @@ -80,8 +85,8 @@ impl NodeStats { &mut self.explored[max_position] } - fn ucb(&self, n: f32) -> f32 { - self.wins / self.attempts + (2.0 * n / self.attempts).sqrt() + fn ucb(&self, sqrt_n: f32) -> f32 { + self.average + sqrt_n * self.confidence } fn add_node<'a>(&'a mut self, player: &Player, command: Command) -> &'a mut (Command, NodeStats) { @@ -94,13 +99,20 @@ impl NodeStats { fn add_victory(&mut self) { self.attempts += 1.; self.wins += 1.; + self.update_confidence(); } fn add_defeat(&mut self) { self.attempts += 1.; 
self.losses += 1.; + self.update_confidence(); } fn add_draw(&mut self) { self.attempts += 1.; + self.update_confidence(); + } + fn update_confidence(&mut self) { + self.average = self.wins / self.attempts; + self.confidence = (2.0 / self.attempts).sqrt(); } #[cfg(feature = "benchmarking")] -- cgit v1.2.3