kevinwang676 committed (verified)
Commit 5133bed · Parent: 97073f9

Create test.py

Files changed (1): test.py (+403 −0)
test.py ADDED
@@ -0,0 +1,403 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import deque
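
# Note: the Pong experiment ('ALE/Pong-v5') requires Gymnasium's Atari extras,
# e.g. `pip install "gymnasium[atari]"`. Depending on your gymnasium/ale-py
# versions, you may also need to register the ALE environments explicitly:
#   import ale_py
#   gym.register_envs(ale_py)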

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# ==================== Policy Networks ====================

class CartPolePolicy(nn.Module):
    """Policy network for CartPole environment"""
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(CartPolePolicy, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=-1)


class PongPolicy(nn.Module):
    """Policy network for Pong with CNN architecture"""
    def __init__(self, action_dim=2):
        super(PongPolicy, self).__init__()
        # CNN layers for processing 80x80 images
        self.conv1 = nn.Conv2d(1, 16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)

        # Calculate size after convolutions: 80 -> 19 -> 8
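        # Sanity check of that arithmetic (no padding): conv1 gives
        # floor((80 - 8) / 4) + 1 = 19, conv2 gives floor((19 - 4) / 2) + 1 = 8,
        # so the flattened feature size is 32 * 8 * 8 = 2048.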
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, action_dim)

    def forward(self, x):
        # x shape: (batch, 80, 80) -> add channel dimension
        if len(x.shape) == 2:
            x = x.unsqueeze(0).unsqueeze(0)
        elif len(x.shape) == 3:
            x = x.unsqueeze(1)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=-1)


# ==================== Helper Functions ====================

def preprocess(image):
    """Preprocess a 210x160x3 uint8 frame into an 80x80 2D float array"""
    image = image[35:195]  # crop
    image = image[::2, ::2, 0]  # downsample by factor of 2
    image[image == 144] = 0  # erase background (background type 1)
    image[image == 109] = 0  # erase background (background type 2)
    image[image != 0] = 1  # everything else (paddles, ball) just set to 1
    return np.reshape(image.astype(float).ravel(), [80, 80])
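
# Shape check (illustrative): preprocess(np.zeros((210, 160, 3), dtype=np.uint8))
# returns an (80, 80) float array; nonzero (non-background) pixels become 1.0.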


def compute_returns(rewards, gamma):
    """Compute discounted returns for each timestep"""
    returns = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns, dtype=torch.float32).to(device)
    # Normalize returns for more stable training
    if len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns
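
# Worked example (pre-normalization): rewards = [1, 1, 1] with gamma = 0.9
# gives R_3 = 1.0, R_2 = 1 + 0.9 * 1.0 = 1.9, R_1 = 1 + 0.9 * 1.9 = 2.71,
# i.e. returns [2.71, 1.9, 1.0] before they are standardized above.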


def moving_average(data, window_size):
    """Compute moving average"""
    if len(data) < window_size:
        return np.array([np.mean(data[:i+1]) for i in range(len(data))])

    moving_avg = []
    for i in range(len(data)):
        if i < window_size:
            moving_avg.append(np.mean(data[:i+1]))
        else:
            moving_avg.append(np.mean(data[i-window_size+1:i+1]))
    return np.array(moving_avg)
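
# Quick check (illustrative): moving_average([1, 2, 3, 4], window_size=2)
# -> array([1.0, 1.5, 2.5, 3.5]); early entries fall back to an expanding mean.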


# ==================== Policy Gradient Algorithm ====================

def train_policy_gradient(env_name, policy, optimizer, gamma, num_episodes,
                          max_steps=None, is_pong=False, action_map=None):
    """
    Train policy using the REINFORCE algorithm

    Args:
        env_name: Name of the gym environment
        policy: Policy network
        optimizer: PyTorch optimizer
        gamma: Discount factor
        num_episodes: Number of training episodes
        max_steps: Maximum steps per episode (None for the environment default)
        is_pong: Whether this is the Pong environment
        action_map: Mapping from policy action to env action (for Pong)
    """
    env = gym.make(env_name)
    episode_rewards = []

    for episode in range(num_episodes):
        state, _ = env.reset()

        # Preprocess state for Pong
        if is_pong:
            state = preprocess(state)
            prev_frame = None  # Track previous frame for motion

        log_probs = []
        rewards = []

        done = False
        step = 0

        while not done:
            # For Pong, use frame difference (motion signal)
            if is_pong:
                cur_frame = state
                if prev_frame is not None:
                    state_input = cur_frame - prev_frame
                else:
                    state_input = np.zeros_like(cur_frame, dtype=np.float32)
                prev_frame = cur_frame
                state_tensor = torch.FloatTensor(state_input).to(device)
            else:
                # Convert state to tensor
                state_tensor = torch.FloatTensor(state).to(device)

            # Get action probabilities
            action_probs = policy(state_tensor)

            # Sample action from the distribution
            dist = Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # Map action for Pong (0,1 -> 2,3)
            if is_pong:
                env_action = action_map[action.item()]
            else:
                env_action = action.item()

            # Take action in environment
            next_state, reward, terminated, truncated, _ = env.step(env_action)
            done = terminated or truncated

            # Preprocess next state for Pong
            if is_pong:
                next_state = preprocess(next_state)

            # Store log probability and reward
            log_probs.append(log_prob)
            rewards.append(reward)

            state = next_state
            step += 1

            if max_steps and step >= max_steps:
                break

        # Compute returns
        returns = compute_returns(rewards, gamma)

        # Compute policy gradient loss
        policy_loss = []
        for log_prob, R in zip(log_probs, returns):
            policy_loss.append(-log_prob * R)
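
        # The summed term below is the REINFORCE surrogate loss
        # L = -sum_t log pi(a_t | s_t) * G_t, so backpropagating it yields the
        # standard policy-gradient estimate (here G_t is the normalized return).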

        # Optimize policy
        optimizer.zero_grad()
        policy_loss = torch.stack(policy_loss).sum()
        policy_loss.backward()
        optimizer.step()

        # Record episode reward
        episode_reward = sum(rewards)
        episode_rewards.append(episode_reward)

        # Print progress
        if (episode + 1) % 100 == 0:
            avg_reward = np.mean(episode_rewards[-100:])
            print(f"Episode {episode + 1}/{num_episodes}, "
                  f"Avg Reward (last 100): {avg_reward:.2f}")

    env.close()
    return episode_rewards


def evaluate_policy(env_name, policy, num_episodes=500, is_pong=False, action_map=None):
    """Evaluate trained policy over multiple episodes"""
    env = gym.make(env_name)
    eval_rewards = []

    for episode in range(num_episodes):
        state, _ = env.reset()

        if is_pong:
            state = preprocess(state)
            prev_frame = None  # Track previous frame for motion

        episode_reward = 0
        done = False

        while not done:
            # For Pong, use frame difference (motion signal)
            if is_pong:
                cur_frame = state
                if prev_frame is not None:
                    state_input = cur_frame - prev_frame
                else:
                    state_input = np.zeros_like(cur_frame, dtype=np.float32)
                prev_frame = cur_frame
                state_tensor = torch.FloatTensor(state_input).to(device)
            else:
                state_tensor = torch.FloatTensor(state).to(device)

            with torch.no_grad():
                action_probs = policy(state_tensor)
                action = torch.argmax(action_probs).item()
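                # Note: evaluation acts greedily (argmax) rather than sampling
                # from the action distribution as during training.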

            if is_pong:
                env_action = action_map[action]
            else:
                env_action = action

            next_state, reward, terminated, truncated, _ = env.step(env_action)
            done = terminated or truncated

            if is_pong:
                next_state = preprocess(next_state)

            episode_reward += reward
            state = next_state

        eval_rewards.append(episode_reward)

        if (episode + 1) % 100 == 0:
            print(f"Evaluated {episode + 1}/{num_episodes} episodes")

    env.close()
    return eval_rewards


def plot_results(episode_rewards, eval_rewards, title, save_prefix):
    """Plot training curve and evaluation histogram"""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Plot training curve
    ax1 = axes[0]
    episodes = np.arange(1, len(episode_rewards) + 1)
    ma = moving_average(episode_rewards, 100)

    ax1.plot(episodes, episode_rewards, alpha=0.3, label='Episode Reward')
    ax1.plot(episodes, ma, linewidth=2, label='Moving Average (100 episodes)')
    ax1.set_xlabel('Episode')
    ax1.set_ylabel('Reward')
    ax1.set_title(f'{title} - Training Curve')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot evaluation histogram
    ax2 = axes[1]
    mean_reward = np.mean(eval_rewards)
    std_reward = np.std(eval_rewards)

    ax2.hist(eval_rewards, bins=30, edgecolor='black', alpha=0.7)
    ax2.axvline(mean_reward, color='red', linestyle='--', linewidth=2,
                label=f'Mean: {mean_reward:.2f}')
    ax2.set_xlabel('Episode Reward')
    ax2.set_ylabel('Frequency')
    ax2.set_title(f'{title} - Evaluation Histogram (500 episodes)\n'
                  f'Mean: {mean_reward:.2f}, Std: {std_reward:.2f}')
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f'{save_prefix}_results.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n{title} Evaluation Results:")
    print(f"Mean Reward: {mean_reward:.2f}")
    print(f"Std Reward: {std_reward:.2f}")


# ==================== Main Training Scripts ====================

def train_cartpole():
    """Train CartPole-v1"""
    print("\n" + "="*60)
    print("Training CartPole-v1")
    print("="*60 + "\n")

    # Environment parameters
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    env.close()

    # Hyperparameters
    gamma = 0.95
    learning_rate = 0.01
    num_episodes = 1000
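    # Note: CartPole-v1 truncates episodes at 500 steps, so 500 is the maximum
    # possible episode reward during training and evaluation.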

    # Initialize policy and optimizer
    policy = CartPolePolicy(state_dim, action_dim).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=learning_rate)

    # Train
    episode_rewards = train_policy_gradient(
        'CartPole-v1', policy, optimizer, gamma, num_episodes
    )

    # Evaluate
    print("\nEvaluating trained policy...")
    eval_rewards = evaluate_policy('CartPole-v1', policy, num_episodes=500)

    # Plot results
    plot_results(episode_rewards, eval_rewards, 'CartPole-v1', 'cartpole')

    # Save model
    torch.save(policy.state_dict(), 'cartpole_policy.pth')
    print("\nModel saved as 'cartpole_policy.pth'")

    return policy, episode_rewards, eval_rewards


def train_pong():
    """Train Pong-v5"""
    print("\n" + "="*60)
    print("Training Pong-v5")
    print("="*60 + "\n")

    # Hyperparameters
    gamma = 0.99
    learning_rate = 0.001  # Lower learning rate for Pong
    num_episodes = 3000  # Pong requires more episodes

    # Action mapping: policy outputs 0 or 1, map to RIGHT(2) or LEFT(3)
    action_map = {0: 2, 1: 3}  # 0->RIGHT, 1->LEFT
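    # In Pong, ALE's RIGHT (2) and LEFT (3) actions move the paddle up and
    # down, respectively.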

    # Initialize policy and optimizer
    policy = PongPolicy(action_dim=2).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=learning_rate)

    print(f"Using learning rate: {learning_rate}")
    print("Action mapping: 0->RIGHT(2), 1->LEFT(3)\n")

    # Train
    episode_rewards = train_policy_gradient(
        'ALE/Pong-v5', policy, optimizer, gamma, num_episodes,
        is_pong=True, action_map=action_map
    )

    # Evaluate
    print("\nEvaluating trained policy...")
    eval_rewards = evaluate_policy(
        'ALE/Pong-v5', policy, num_episodes=500,
        is_pong=True, action_map=action_map
    )

    # Plot results
    plot_results(episode_rewards, eval_rewards, 'Pong-v5', 'pong')

    # Save model
    torch.save(policy.state_dict(), 'pong_policy.pth')
    print("\nModel saved as 'pong_policy.pth'")

    return policy, episode_rewards, eval_rewards


# ==================== Run Training ====================

if __name__ == "__main__":
    # Train CartPole
    cartpole_policy, cartpole_train_rewards, cartpole_eval_rewards = train_cartpole()

    # Train Pong (this will take longer)
    print("\n\nNote: Pong training will take significantly longer (may take hours)")
    print("You may want to reduce num_episodes if just testing the code.\n")

    # Uncomment the line below to train Pong
    # pong_policy, pong_train_rewards, pong_eval_rewards = train_pong()