summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Worthe <justin@worthe-it.co.za>2018-09-08 13:01:33 +0200
committerJustin Worthe <justin@worthe-it.co.za>2018-09-08 13:01:33 +0200
commit694c3421b23e74da244675cbf7b2c3cfbb2866ab (patch)
treecc9a88131212167b23cd931534ad206e4b064c49
parent32e1dedc420c1011f63aaa90ed96fa19d2590a77 (diff)
Cached more values in tree exploration calcs
-rw-r--r--src/strategy/monte_carlo_tree.rs22
1 files changed, 17 insertions, 5 deletions
diff --git a/src/strategy/monte_carlo_tree.rs b/src/strategy/monte_carlo_tree.rs
index fe59e34..2d27b62 100644
--- a/src/strategy/monte_carlo_tree.rs
+++ b/src/strategy/monte_carlo_tree.rs
@@ -15,8 +15,10 @@ struct NodeStats {
wins: f32,
losses: f32,
attempts: f32,
+ average: f32,
+ confidence: f32,
explored: Vec<(Command, NodeStats)>,
- unexplored: Vec<Command>
+ unexplored: Vec<Command>,
}
impl NodeStats {
@@ -58,6 +60,8 @@ impl NodeStats {
wins: 0.,
losses: 0.,
attempts: 0.,
+ average: 0.,
+ confidence: 0.,
explored: Vec::with_capacity(commands.len()),
unexplored: commands
}
@@ -67,11 +71,12 @@ impl NodeStats {
debug_assert!(self.unexplored.is_empty());
debug_assert!(self.explored.len() > 0);
let total_attempts = self.explored.iter().map(|(_, n)| n.attempts).sum::<f32>();
+ let sqrt_n = total_attempts.sqrt();
let mut max_position = 0;
- let mut max_value = self.explored[0].1.ucb(total_attempts);
+ let mut max_value = self.explored[0].1.ucb(sqrt_n);
for i in 1..self.explored.len() {
- let value = self.explored[i].1.ucb(total_attempts);
+ let value = self.explored[i].1.ucb(sqrt_n);
if value > max_value {
max_position = i;
max_value = value;
@@ -80,8 +85,8 @@ impl NodeStats {
&mut self.explored[max_position]
}
- fn ucb(&self, n: f32) -> f32 {
- self.wins / self.attempts + (2.0 * n / self.attempts).sqrt()
+ fn ucb(&self, sqrt_n: f32) -> f32 {
+ self.average + sqrt_n * self.confidence
}
fn add_node<'a>(&'a mut self, player: &Player, command: Command) -> &'a mut (Command, NodeStats) {
@@ -94,13 +99,20 @@ impl NodeStats {
fn add_victory(&mut self) {
self.attempts += 1.;
self.wins += 1.;
+ self.update_confidence();
}
fn add_defeat(&mut self) {
self.attempts += 1.;
self.losses += 1.;
+ self.update_confidence();
}
fn add_draw(&mut self) {
self.attempts += 1.;
+ self.update_confidence();
+ }
+ fn update_confidence(&mut self) {
+ self.average = self.wins / self.attempts;
+ self.confidence = (2.0 / self.attempts).sqrt();
}
#[cfg(feature = "benchmarking")]