개요
내가 만든 모바일 게임인 HungryCat을 유니티 ML-Agent를 이용해서 강화학습 시켜보려고 한다.
다운로드 링크 : https://play.google.com/store/apps/details?id=com.Truer.HungryCat
먼저 게임 소개를 간단히 하자면 플래피 버드와 게임 방식이 똑같다.
터치하면 고양이가 점프를 하고, 굴뚝 사이를 지나가면 되는 게임이다.
먼저 강화학습을 위해 state, behavior, reward를 정의해 보겠다.
State
state는 2가지 방식을 생각할 수 있다.
첫째로는 RenderTexture를 활용한 화면 캡처 방식
두번째로는 오브젝트의 상태를 알려주는 방식
첫번째 방식은 조금 문제가 있는게, 배경이 시시각각 변화하고, 무엇보다 고양이가 중력의 영향을 받기 때문에 하나의 사진으로는 현재 상태를 정확히 알 수가 없다. 따라서 2~3장의 연속된 사진을 사용해야하는데, 이렇게 되면 그냥 두번째 방법을 사용하는 것 보다 훨씬 비효율적이게 된다.
따라서 오브젝트의 상태 (고양이의 현재 속도, 고양이의 현재 위치, 굴뚝의 위치, 고양이의 y축 속도)를 알려주는 방식으로 사용하려 한다.
Behaivor
behavior는 고양이가 점프를 하는것 (1) 하지 않는 것 (0) 2가지로 정의할 수 있다.
Reward
보상은 살아남을 때 마다 1, 죽으면 -10을 주는 식으로 설정하였다.
유니티 설정
먼저 씬을 하나로 통합하고, Agent 코드를 짠다.
CatAgent
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.SceneManagement;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Rendering;
using UnityEngine.Serialization;
public class CatAgent : Agent
{
public float strong = 8.0f;
Rigidbody2D rb;
static public GameObject pipe;
//초기화
public override void Initialize()
{
rb = GetComponent<Rigidbody2D>();
}
//state
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(this.gameObject.transform.position.y); //고양이 y값
sensor.AddObservation(pipe.transform.position.x); //파이프 x값
sensor.AddObservation(pipe.transform.position.y); //파이프 y값
sensor.AddObservation(rb.velocity.y); //고양이 속도
}
//에피소드 시작 초기화
public override void OnEpisodeBegin()
{
SetReward(0.0f);
Score.score = 0;
DestroyAllPipes();
Vector3 position = this.gameObject.transform.position;
position.y = 0;
this.gameObject.transform.position = position;
rb.velocity = Vector2.zero;
rb.angularVelocity = 0;
rb.gravityScale = 0;
this.gameObject.transform.rotation = Quaternion.identity;
}
const int noJump = 0;
const int Jump = 1;
//액션
public override void OnActionReceived(ActionBuffers actionBuffers)
{
if(rb.gravityScale == 0)
rb.gravityScale = 1.3f;
var action = actionBuffers.DiscreteActions[0];
AddReward(1.0f);
switch (action)
{
case noJump:
break;
case Jump:
rb.velocity = Vector2.up * strong;
break;
}
}
//파이프 전부 제거
void DestroyAllPipes()
{
// "Pipe" 태그를 가진 모든 게임 오브젝트를 찾습니다.
GameObject[] pipes = GameObject.FindGameObjectsWithTag("Pipe");
// 찾은 모든 게임 오브젝트를 삭제합니다.
foreach (GameObject pipe in pipes)
{
Destroy(pipe);
}
}
// 클릭하면 점프
void Update()
{
if(Input.GetMouseButtonDown(0))
{
rb.velocity = Vector2.up * strong;
}
}
//부딪히면 에피소드 종료
private void OnCollisionEnter2D(Collision2D other)
{
AddReward(-10.0f);
EndEpisode();
}
}
파이프 생성과 파이프 이동은 원본 게임과 똑같이 설정하므로 생략.
단 파이프가 생성될 때 CatAgent의 Pipe 정보를 업데이트를 해줘야 하므로 해당 코드를 추가한다.
void Start()
{
CatAgent.pipe = this.gameObject;
}
state는 위에서 설명했듯이 4종류이므로 Space Size를 4로 해주고,
고양이의 행동은 가만히 있기와 점프하기 2가지 이므로 Behavior Parameters의 Branch 0 Size를 2로 설정해주면 된다.
이제 A2C 알고리즘을 이용하여 학습을 시키려고 한다.
A2C 알고리즘
먼저 대전제는 인공지능은 크게 보면 함수라고 볼 수 있다.
입력 -> 출력의 형태이며, 무엇을 출력하느냐에 따라 알고리즘 틀이 결정된다.
A2C알고리즘은 액터-크리틱 알고리즘이다.
액터-크리틱은 가치 기반 강화학습과 정책 기반 강화학습을 결합한 형태이다.
가치기반 강화학습은 state -> 해당 state의 가치(상태 가치 함수) 혹은 state/action -> 해당 state에서 하는 action의 가치 (행동 가치함수)를 출력시키게 하는 방법이다.
정책 기반 강화학습은 state -> 정책(각 액션들을 수행할 확률분포)를 출력시키게 하는 방법이다.
A2C알고리즘은 한 신경망에서 마지막 출력층만 2개로 하는 것으로, 1개는 softmax함수를 이용하여 정책을 출력하게 하고, 나머지 하나는 해당 state의 가치를 출력하게 한다.
정책을 출력하는 정책 함수는 다음과 같이 나타낼 수 있다.
정책 가치 함수 = state s에서 action a를 할 확률 분포 * 각 action의 가치 들의 합
으로 표현될 수 있다.
이때 목표는 정책 가치 함수를 최대화 시키는 것이므로 정책 가치 함수에 -를 곱한 값을 손실함수로 만들면 된다.
state의 가치를 출력하는 함수의 손실함수는 평균 제곱 오차로 나타낸다.
평균제곱 오차는 타깃값과 현재 가치 함수 값의 차를 제곱한 뒤 1/2로 나눈 것이다.
이때 시간 i에서 타깃값(원래 정답)은 i에서의 보상 + 감가율*i+1에서의 상태에서의 가치 함수 값이다.
즉 보상+미래의가치 와 현재의 가치의 예상값의 오차를 줄이는 방식으로 학습시키게 된다.
이렇게 구한 손실함수 값을 더한 뒤 optimizer를 이용하여 역전파 방향으로 학습시키면 된다.
파이토치 코드
코드는 [파이토치와 유니티 ML-Agents로 배우는 강화학습]이라는 책에 나오는 A2C 알고리즘을 참고하여 만들었다.
import numpy as np
import random
import copy
import datetime
import platform
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque
from mlagents_envs.environment import UnityEnvironment, ActionTuple
from mlagents_envs.side_channel.engine_configuration_channel\
import EngineConfigurationChannel
state_size = 4
action_size = 2
load_model = False
train_mode = True
discount_factor = 0.8
learning_rate = 0.0000002
run_step = 50000 if train_mode else 0
test_step = 500
OBS = 2
print_interval = 100
save_interval = 100
game = "HungryCatLearn"
env_name = "HungryCatLearn/HungryCatLearn"
loaddate = ""
date_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
save_path = f"./saved_models/{game}/A2C/{date_time}"
load_path = f"./saved_models/{game}/A2C/{loaddate}"
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
class A2C(torch.nn.Module):
def __init__(self,**kwargs):
super(A2C,self).__init__(**kwargs)
self.d1 = torch.nn.Linear(state_size,256)
self.d2 = torch.nn.Linear(256,256)
self.pi = torch.nn.Linear(256,action_size)
self.v = torch.nn.Linear(256,1)
def forward(self,x):
x = F.relu(self.d1(x))
x = F.relu(self.d2(x))
#print(F.softmax(self.pi(x),dim=1))
return F.softmax(self.pi(x),dim=1),self.v(x)
class A2CAgent:
def __init__(self):
self.a2c = A2C().to(device)
self.optimizer = torch.optim.Adam(self.a2c.parameters(),lr=learning_rate)
self.writer = SummaryWriter(save_path)
if load_model == True:
print(f"Load model from {load_path}/ckpt ...")
checkpoint = torch.load(load_path+'/ckpt', map_location=device)
self.a2c.load_state_dict(checkpoint["network"])
self.optimizer.load_state_dict(checkpoint["optimizer"])
def get_action(self,state,trainig=True):
self.a2c.train(trainig)
pi, _ = self.a2c(torch.FloatTensor(state).to(device))
action = torch.multinomial(pi,num_samples=1).cpu().numpy()
return action
def train_model(self,state,action,reward,next_state,done):
state,action,reward,next_state,done = map(lambda x: torch.FloatTensor(x).to(device),[state,action,reward,next_state,done])
pi, value = self.a2c(state)
print(state)
with torch.no_grad():
_,next_value = self.a2c(next_state)
target_value = reward + discount_factor * next_value
critic_loss = F.mse_loss(target_value,value)
eye = torch.eye(action_size).to(device)
one_hot_action = eye[action.view(-1).long()]
advantage = (target_value - value).detach()
actor_loss = -(torch.log((one_hot_action*pi).sum(1))*advantage).mean()
total_loss = critic_loss + actor_loss
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
return actor_loss.item(),critic_loss.item()
def save_model(self):
print(f"Save Model to {save_path}/ckpt")
torch.save({
"network" : self.a2c.state_dict(),
"optimizer" : self.optimizer.state_dict(),
},save_path+'/ckpt')
def write_summary(self,score,actor_loss,critic_loss,step):
self.writer.add_scalar("run/score",score,step)
self.writer.add_scalar("model/acotr_loss",actor_loss,step)
self.writer.add_scalar("model/critic_loss",critic_loss,step)
if __name__ == '__main__':
engine_configuration_channel = EngineConfigurationChannel()
env = UnityEnvironment(file_name=env_name,
side_channels=[engine_configuration_channel])
env.reset()
behavior_name = list(env.behavior_specs.keys())[0]
spec = env.behavior_specs[behavior_name]
engine_configuration_channel.set_configuration_parameters(time_scale=12.0)
dec, term = env.get_steps(behavior_name)
agent = A2CAgent()
actor_losses, critic_losses, scores, episode, score = [],[],[],0,0
step = 0
while(step <= run_step + test_step):
if step == run_step:
if train_mode:
agent.save_model()
print("TEST START")
train_mode= False
engine_configuration_channel.set_configuration_parameters(time_scale=1.0)
#print(dec.obs)
state = dec.obs[0]
action = agent.get_action(state,train_mode)
#print(action)
action_tuple = ActionTuple()
action_tuple.add_discrete(action)
env.set_actions(behavior_name,action_tuple)
env.step()
dec,term = env.get_steps(behavior_name)
done = len(term.agent_id)>0
reward = term.reward if done else dec.reward
if done:
next_state = term.obs[0] #여기도 수정해야됨
else:
next_state = dec.obs[0]
score += reward[0]
if train_mode:
actor_loss, critic_loss = agent.train_model(state,action[0],[reward],next_state,[done])
actor_losses.append(actor_loss)
critic_losses.append(critic_loss)
if done:
print(f"Episode : {episode}, socre : {score:.1f}")
episode += 1
step += 1
scores.append(score)
score = 0
if episode % print_interval == 0 and episode != 0:
mean_score = np.mean(scores)
mean_actor_loss = np.mean(actor_losses) if len(actor_losses) > 0 else 0
mean_critic_loss = np.mean(critic_losses) if len(critic_losses) > 0 else 0
agent.write_summary(mean_score,mean_actor_loss,mean_critic_loss,step)
actor_losses, critic_losses, scores = [],[],[]
print(f"{episode} Episode / Step: {step} / Score: {mean_score:.2f}/ " + \
f"Actor loss: {mean_actor_loss:.2f} / Critic loss: {mean_critic_loss:.4f}")
if train_mode and episode % save_interval == 0:
agent.save_model()
print("End")
env.close()
state는 위에서 설명했듯이 4개이고 액션은 2개이다.
은닉층은 256개의 노드가 있는 relu를 활성화 함수로 하여 2층을 사용하였다.
출력은 가치를 출력하는 노드와 정책을 출력하는 노드 2개이다.
나머지 코드는 하이퍼파라미터와 데이터 저장 등을 담당하는 코드이다.
결과
결과는 코드를 잘못 짰는지, 계속 고양이가 무한으로 점프하는 방향으로 이루어 졌다.
하이퍼 파라미터의 문제인지 확인을 위해 learning rate를 높여보기도 하고, 작게 해보기도 하고, 점프와 가만히 있기의 인덱스를 바꿔보기도 했지만 결과는 똑같았다.
다음에는 다른 알고리즘을 사용하여 학습을 시켜봐야겠다.
혹시나 학습을 시켜보고 싶다면
ryujm1828/HungryCatReinforceLearning (github.com)
이곳에서 해볼 수 있다.
'인공지능 > 강화학습' 카테고리의 다른 글
HungryCat강화학습3 - BC + DQN 알고리즘 (with ML-Agent) (0) | 2023.08.05 |
---|---|
HungryCat 강화학습2 - DQN 알고리즘 (with ML-Agent) (0) | 2023.08.03 |
REINFORCE 알고리즘을 이용한 CartPole 강화학습 (0) | 2023.06.10 |
DQN을 이용한 CartPole 강화학습 (0) | 2023.06.07 |