summaryrefslogtreecommitdiff
path: root/src/strategy.rs
diff options
context:
space:
mode:
authorJustin Worthe <justin@worthe-it.co.za>2019-05-12 00:58:14 +0200
committerJustin Worthe <justin@worthe-it.co.za>2019-05-12 00:58:14 +0200
commit3d1676842e20c90bb5599daa2caefdea2bbf9fe8 (patch)
treec5c45ef35064fad01ed690be8555273224765e67 /src/strategy.rs
parentebe7d5cd5cc42d8f3f02ca926f4b920ada03765f (diff)
Outline of MCTS
Diffstat (limited to 'src/strategy.rs')
-rw-r--r--src/strategy.rs144
1 files changed, 113 insertions, 31 deletions
diff --git a/src/strategy.rs b/src/strategy.rs
index dd15854..ce65e54 100644
--- a/src/strategy.rs
+++ b/src/strategy.rs
@@ -1,52 +1,132 @@
use crate::command::Command;
-use crate::game::GameBoard;
+use crate::game::{GameBoard, SimulationOutcome};
use crate::geometry::*;
+use std::ops::*;
+use std::collections::HashMap;
+use time::{Duration, PreciseTime};
+
struct GameTree {
state: GameBoard,
next_states: Vec<([Command; 2], GameTree)>
}
-pub fn choose_move(state: &GameBoard) -> Command {
- let mut root = GameTree {
+pub fn choose_move(state: &GameBoard, start_time: &PreciseTime, max_time: Duration) -> Command {
+ let mut root_node = Node {
state: state.clone(),
- next_states: Vec::new()
+ score_sum: Score { val: 0 },
+ visit_count: 0,
+ children: HashMap::new()
};
- let mut last_depth = vec!(&mut root);
-
- for depth in 0.. {
- println!("Trying depth {}", depth);
- println!("{} wide", last_depth.len());
- let mut next_depth = Vec::new();
- for mut tree in last_depth {
- populate_next_states(&mut tree);
- for x in &mut tree.next_states {
- next_depth.push(&mut x.1);
- }
- }
- last_depth = next_depth;
+ while start_time.to(PreciseTime::now()) < max_time {
+ let _ = mcts(&mut root_node);
}
-
- Command::DoNothing
+
+ root_node
+ .children
+ .iter()
+ .max_by_key(|(_k, v)| v.score())
+ .map(|(k, _v)| k[0])
+ .unwrap_or(Command::DoNothing)
+}
+
+struct Node {
+ state: GameBoard,
+ score_sum: Score,
+ visit_count: u32,
+ children: HashMap<[Command; 2], Node>
+}
+
+impl Node {
+ fn score(&self) -> Score {
+ self.score_sum / self.visit_count
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+struct Score {
+ val: i32
}
-fn populate_next_states(tree: &mut GameTree) {
- let valid_player_moves = valid_moves(&tree.state, 0);
- let valid_opponent_moves = valid_moves(&tree.state, 1);
- for player_move in valid_player_moves {
- for opponent_move in &valid_opponent_moves {
- let commands = [player_move, *opponent_move];
- let mut new_state = tree.state.clone();
- let _ = new_state.simulate(commands);
- tree.next_states.push((commands, GameTree {
- state: new_state,
- next_states: Vec::new()
- }));
+impl AddAssign for Score {
+ fn add_assign(&mut self, other: Self) {
+ self.val += other.val;
+ }
+}
+
+impl Div<u32> for Score {
+ type Output = Self;
+ fn div(self, other: u32) -> Self {
+ Score {
+ val: self.val / other as i32
}
}
}
+fn mcts(node: &mut Node) -> Score {
+ if node.state.outcome != SimulationOutcome::Continue {
+ score(&node.state)
+ } else if has_unsimulated_outcomes(node) {
+ let commands = choose_unsimulated(&node);
+
+ let mut new_state = node.state.clone();
+ new_state.simulate(commands);
+ let score = rollout(&new_state);
+
+ let new_node = Node {
+ state: new_state,
+ score_sum: score,
+ visit_count: 1,
+ children: HashMap::new()
+ };
+ node.children.insert(commands, new_node);
+
+ update(node, commands, score);
+ score
+ } else {
+ let commands = select(node);
+ let score = mcts(node.children.get_mut(&commands).unwrap());
+ update(node, commands, score);
+ score
+ }
+}
+
+fn score(state: &GameBoard) -> Score {
+ // TODO
+ Score { val: 0 }
+}
+
+fn has_unsimulated_outcomes(node: &Node) -> bool {
+ // TODO
+ false
+}
+
+fn choose_unsimulated(node: &Node) -> [Command; 2] {
+ // TODO
+ [
+ Command::DoNothing,
+ Command::DoNothing
+ ]
+}
+
+fn rollout(state: &GameBoard) -> Score {
+ // TODO
+ Score { val: 0 }
+}
+
+fn select(node: &Node) -> [Command; 2] {
+ // TODO
+ [
+ Command::DoNothing,
+ Command::DoNothing
+ ]
+}
+
+fn update(node: &mut Node, commands: [Command; 2], score: Score) {
+ // TODO
+}
+
fn valid_moves(state: &GameBoard, player_index: usize) -> Vec<Command> {
let worm = state.players[player_index].active_worm();
@@ -60,9 +140,11 @@ fn valid_moves(state: &GameBoard, player_index: usize) -> Vec<Command> {
})
.collect::<Vec<_>>();
let mut shoots = Direction::all().iter()
+ .filter(|dir| state.find_target(worm.position, **dir, worm.weapon_range).is_some())
.map(|d| Command::Shoot(*d))
.collect::<Vec<_>>();
moves.append(&mut shoots);
+ moves.retain(|m| *m != Command::DoNothing);
moves
}