summary | refs | log | tree | commit | diff
path: root/src/strategy/monte_carlo_tree.rs
diff options
context:
space:
mode:
author: Justin Wernick <justin@worthe-it.co.za> 2022-04-19 21:26:49 +0200
committer: Justin Wernick <justin@worthe-it.co.za> 2022-04-19 21:26:49 +0200
commit: 7ec48d0d454499177b63bc5bd512a3a2d6baa839 (patch)
tree: 23d34d45dbb3ae977710361501a3dde3544734d1 /src/strategy/monte_carlo_tree.rs
parent: 1e21ebed15321aacbba53121cb40bbc60f4db1cc (diff)
Refile for merging repos
Diffstat (limited to 'src/strategy/monte_carlo_tree.rs')
-rw-r--r-- src/strategy/monte_carlo_tree.rs 243
1 files changed, 0 insertions, 243 deletions
diff --git a/src/strategy/monte_carlo_tree.rs b/src/strategy/monte_carlo_tree.rs
deleted file mode 100644
index 24b2088..0000000
--- a/src/strategy/monte_carlo_tree.rs
+++ /dev/null
@@ -1,243 +0,0 @@
-use engine::command::*;
-use engine::status::GameStatus;
-use engine::bitwise_engine::{Player, BitwiseGameState};
-use engine::constants::*;
-
-use rand::{Rng, XorShiftRng, SeedableRng};
-use time::{Duration, PreciseTime};
-
-use strategy::monte_carlo;
-
-use arrayvec::ArrayVec;
-
-#[derive(Debug)]
-struct NodeStats {
- wins: f32,
- losses: f32,
- attempts: f32,
- average: f32,
- confidence: f32,
- explored: Vec<(Command, NodeStats)>,
- unexplored: Vec<Command>,
-}
-
-impl NodeStats {
- fn create_node(player: &Player) -> NodeStats {
- let unoccupied_cells_count = player.unoccupied_cell_count();
- let unoccupied_cells = (0..unoccupied_cells_count)
- .map(|i| player.location_of_unoccupied_cell(i));
-
- let mut all_buildings: ArrayVec<[BuildingType; NUMBER_OF_BUILDING_TYPES]> = ArrayVec::new();
- if DEFENCE_PRICE <= player.energy {
- all_buildings.push(BuildingType::Defence);
- }
- if MISSILE_PRICE <= player.energy {
- all_buildings.push(BuildingType::Attack);
- }
- if ENERGY_PRICE <= player.energy {
- all_buildings.push(BuildingType::Energy);
- }
- if TESLA_PRICE <= player.energy && !player.has_max_teslas() {
- all_buildings.push(BuildingType::Tesla);
- }
-
- let building_command_count = unoccupied_cells.len()*all_buildings.len();
-
- let mut commands = Vec::with_capacity(building_command_count + 2);
-
- commands.push(Command::Nothing);
- if IRON_CURTAIN_PRICE <= player.energy && player.can_build_iron_curtain() {
- commands.push(Command::IronCurtain);
- }
-
- for position in unoccupied_cells {
- for &building in &all_buildings {
- commands.push(Command::Build(position, building));
- }
- }
-
- NodeStats {
- wins: 0.,
- losses: 0.,
- attempts: 0.,
- average: 0.,
- confidence: 0.,
- explored: Vec::with_capacity(commands.len()),
- unexplored: commands
- }
- }
-
- fn node_with_highest_ucb<'a>(&'a mut self) -> &'a mut (Command, NodeStats) {
- debug_assert!(self.unexplored.is_empty());
- debug_assert!(self.explored.len() > 0);
- let sqrt_n = self.attempts.sqrt();
-
- let mut max_position = 0;
- let mut max_value = self.explored[0].1.ucb(sqrt_n);
- for i in 1..self.explored.len() {
- let value = self.explored[i].1.ucb(sqrt_n);
- if value > max_value {
- max_position = i;
- max_value = value;
- }
- }
- &mut self.explored[max_position]
- }
-
- fn ucb(&self, sqrt_n: f32) -> f32 {
- self.average + sqrt_n * self.confidence
- }
-
- fn add_node<'a>(&'a mut self, player: &Player, command: Command) -> &'a mut (Command, NodeStats) {
- let node = NodeStats::create_node(player);
- self.explored.push((command, node));
- self.unexplored.retain(|c| *c != command);
- self.explored.last_mut().unwrap()
- }
-
- fn add_victory(&mut self) {
- self.attempts += 1.;
- self.wins += 1.;
- self.update_confidence();
- }
- fn add_defeat(&mut self) {
- self.attempts += 1.;
- self.losses += 1.;
- self.update_confidence();
- }
- fn add_draw(&mut self) {
- self.attempts += 1.;
- self.update_confidence();
- }
- fn update_confidence(&mut self) {
- self.average = self.wins / self.attempts;
- self.confidence = (2.0 / self.attempts).sqrt();
- }
-
- #[cfg(feature = "benchmarking")]
- fn count_explored(&self) -> usize {
- 1 + self.explored.iter().map(|(_, n)| n.count_explored()).sum::<usize>()
- }
-}
-
-pub fn choose_move(state: &BitwiseGameState, start_time: PreciseTime, max_time: Duration) -> Command {
- let mut rng = XorShiftRng::from_seed(INIT_SEED);
-
- let mut root = NodeStats::create_node(&state.player);
-
- while start_time.to(PreciseTime::now()) < max_time {
- tree_search(&state, &mut root, &mut rng);
- }
-
- #[cfg(feature = "benchmarking")]
- {
- println!("Explored nodes: {}", root.count_explored());
- }
-
- let (command, _) = root.node_with_highest_ucb();
- command.clone()
-}
-
-fn tree_search<R: Rng>(state: &BitwiseGameState, stats: &mut NodeStats, rng: &mut R) -> GameStatus {
- // root is opponent move
- // node being added is player move
-
- if state.round >= MAX_MOVES {
- return GameStatus::Draw
- }
-
- if stats.unexplored.is_empty() {
- let result = {
- let (next_command, next_tree) = stats.node_with_highest_ucb();
- tree_search_opponent(state, next_tree, next_command.clone(), rng)
- };
- match result {
- GameStatus::PlayerWon => {stats.add_defeat()},
- GameStatus::OpponentWon => {stats.add_victory()},
- _ => {stats.add_draw()}
- };
- result
- } else {
- let next_command = rng.choose(&stats.unexplored).expect("Partially explored had no options").clone();
- let result = {
- let (_, next_stats) = stats.add_node(&state.opponent, next_command);
-
- let opponent_random = monte_carlo::random_move(&state.opponent, &state.player, rng);
- let mut next_state = state.clone();
- next_state.simulate(next_command, opponent_random);
-
- let result = simulate_to_endstate(next_state, rng);
- match result {
- GameStatus::PlayerWon => {next_stats.add_victory()},
- GameStatus::OpponentWon => {next_stats.add_defeat()},
- _ => {next_stats.add_draw()}
- };
-
- result
- };
-
- match result {
- GameStatus::PlayerWon => {stats.add_defeat()},
- GameStatus::OpponentWon => {stats.add_victory()},
- _ => {stats.add_draw()}
- };
- result
- }
-}
-
-fn tree_search_opponent<R: Rng>(state: &BitwiseGameState, stats: &mut NodeStats, player_command: Command, rng: &mut R) -> GameStatus {
- // root is player move
- // node being added is opponent move
-
- if stats.unexplored.is_empty() {
- let result = {
- let (next_command, next_tree) = stats.node_with_highest_ucb();
- let mut next_state = state.clone();
- next_state.simulate(player_command, next_command.clone());
- tree_search(&next_state, next_tree, rng)
- };
- match result {
- GameStatus::PlayerWon => {stats.add_victory()},
- GameStatus::OpponentWon => {stats.add_defeat()},
- _ => {stats.add_draw()}
- };
- result
- } else {
- let next_command = rng.choose(&stats.unexplored).expect("Partially explored had no options").clone();
- let mut next_state = state.clone();
- next_state.simulate(player_command, next_command);
-
- let result = {
- let (_, next_stats) = stats.add_node(&next_state.player, next_command);
-
- let result = simulate_to_endstate(next_state, rng);
- match result {
- GameStatus::PlayerWon => {next_stats.add_defeat()},
- GameStatus::OpponentWon => {next_stats.add_victory()},
- _ => {next_stats.add_draw()}
- };
-
- result
- };
-
- match result {
- GameStatus::PlayerWon => {stats.add_victory()},
- GameStatus::OpponentWon => {stats.add_defeat()},
- _ => {stats.add_draw()}
- };
- result
- }
-}
-
-
-fn simulate_to_endstate<R: Rng>(mut state: BitwiseGameState, rng: &mut R) -> GameStatus {
- let mut status = GameStatus::Continue;
-
- while status == GameStatus::Continue && state.round < MAX_MOVES {
- let player_command = monte_carlo::random_move(&state.player, &state.opponent, rng);
- let opponent_command = monte_carlo::random_move(&state.opponent, &state.player, rng);
- status = state.simulate(player_command, opponent_command);
- }
- status
-}
-