Split Observation.done into separate terminated and truncated flags.

Step() sets terminated when the enemy or the player is defeated. When the
player is not defeated and instance->steps == 200, it sets truncated and
leaves reward untouched; previously the step limit also set reward = -1.

diff --git a/SS13ArcadeEnv.c b/SS13ArcadeEnv.c
index 4f9965e..b913469 100644
--- a/SS13ArcadeEnv.c
+++ b/SS13ArcadeEnv.c
@@ -17,7 +17,8 @@ void CreateEnvironment(struct SS13ArcadeEnv* env) {
 
 void Step(struct SS13ArcadeEnv* instance, const int action, struct Observation* obs) {
     int reward = 0;
-    bool done = false;
+    bool terminated = false;
+    bool truncated = false;
 
     switch (action) {
         case 0: // Attack
@@ -34,7 +35,7 @@ void Step(struct SS13ArcadeEnv* instance, const int action, struct Observation* obs) {
 
     if(instance->enemy_hp <= 0 || instance->enemy_mp <= 0) { // Enemy Defeated
         reward = 1;
-        done = true;
+        terminated = true;
     } else if(instance->enemy_mp <= 5 && rand() % 1 + 9 >= 7) { // Enemy Drain Player MP
         instance->player_mp -= rand() % 2 + 1;
     } else if(instance->enemy_hp <= 10 && instance->enemy_mp > 4) { // Enemy Heal
@@ -44,9 +45,11 @@ void Step(struct SS13ArcadeEnv* instance, const int action, struct Observation* obs) {
         instance->player_hp -= rand() % 3 + 3;
     }
 
-    if(instance->player_hp <= 0 || instance->player_mp <= 0 || instance->steps == 200) {
+    if(instance->player_hp <= 0 || instance->player_mp <= 0) {
         reward = -1;
-        done = true;
+        terminated = true;
+    } else if(instance->steps == 200) {
+        truncated = true;
     }
 
     instance->steps += 1;
@@ -59,7 +62,8 @@ void Step(struct SS13ArcadeEnv* instance, const int action, struct Observation* obs) {
 
     obs->state = state;
     obs->reward = reward;
-    obs->done = done;
+    obs->terminated = terminated;
+    obs->truncated = truncated;
 }
 
 void GetState(struct SS13ArcadeEnv* instance, struct State* state) {
diff --git a/SS13ArcadeEnv.h b/SS13ArcadeEnv.h
index 0a5d5f6..45af2a2 100644
--- a/SS13ArcadeEnv.h
+++ b/SS13ArcadeEnv.h
@@ -9,7 +9,8 @@ struct __attribute__((packed)) State {
 struct __attribute__((packed)) Observation {
     struct State state;
     int reward;
-    bool done;
+    bool terminated;
+    bool truncated;
 };
 
 struct SS13ArcadeEnv {