#include "SS13ArcadeEnv.ch"

// ---------------------------------------------------------------------------
// Game / protocol constants
// ---------------------------------------------------------------------------

// Size in bytes of one environment's command record in ApplyAction's
// `actionBufs`: a 1-byte opcode (0x0 = Step, 0x1 = Reset) followed by a 4-byte
// action id in network (big-endian) byte order.
static constexpr int kActionRecordSize = 5;

// Starting stats applied by CreateEnvironment and Reset.
static constexpr int kInitialPlayerHp = 30;
static constexpr int kInitialPlayerMp = 10;
static constexpr int kInitialEnemyHp = 45;
static constexpr int kInitialEnemyMp = 20;

// Step reports `truncated` once the episode has already run this many steps
// (i.e. from the 201st Step call onwards, unless the episode terminated).
static constexpr int kMaxSteps = 200;

// Upper bound applied to every HP/MP stat after each Step.
static constexpr int kStatCap = 100;

// ---------------------------------------------------------------------------
// Device helpers
// ---------------------------------------------------------------------------

/**
 * Clamp `value` to the inclusive range [min, max]. Assumes min <= max.
 */
extern "C" __device__ int clamp(int value, int min, int max)
{
    if (value < min) return min;
    if (value > max) return max;
    return value;
}

/**
 * Convert a 32-bit value from network (big-endian) byte order to host order
 * by unconditionally reversing its four bytes. This assumes the device is
 * little-endian, which holds for CUDA GPUs.
 */
extern "C" __device__ uint32_t device_ntohl(uint32_t x)
{
    return ((x & 0x000000FFU) << 24) |
           ((x & 0x0000FF00U) << 8)  |
           ((x & 0x00FF0000U) >> 8)  |
           ((x & 0xFF000000U) >> 24);
}

/**
 * Draw a pseudo-random integer in the inclusive range [low, high] and advance
 * `*state` by one draw. Uses modulo reduction, so the distribution carries a
 * negligible bias for the small spans used here. Assumes low <= high.
 */
extern "C" __device__ int randint(curandState *state, int low, int high)
{
    int span = high - low + 1;
    return (int)(curand(state) % (unsigned int)span) + low;
}

/**
 * Copy the observable fields of `instance` into `state`.
 */
extern "C" __device__ void GetState(struct SS13ArcadeEnv *instance, struct State *state)
{
    state->player_hp = instance->player_hp;
    state->player_mp = instance->player_mp;
    state->enemy_hp = instance->enemy_hp;
    state->enemy_mp = instance->enemy_mp;
    state->steps = instance->steps;
}

// Put `instance` into the starting configuration shared by CreateEnvironment
// and Reset.
static __device__ void ResetEnvironmentState(struct SS13ArcadeEnv *instance)
{
    instance->player_hp = kInitialPlayerHp;
    instance->player_mp = kInitialPlayerMp;
    instance->enemy_hp = kInitialEnemyHp;
    instance->enemy_mp = kInitialEnemyMp;
    instance->steps = 0;
}

/**
 * Advance one environment by one turn and write the resulting observation.
 *
 * @param randState RNG state owned by this environment. It is advanced by a
 *                  variable number of draws that depends on the branch taken.
 * @param instance  Environment to mutate in place.
 * @param action    0 = Attack, 1 = Heal, 2 = Charge. Any other value is a
 *                  no-op for the player; the enemy still acts.
 * @param obs       Receives the new state, a reward (+1 enemy defeated,
 *                  -1 player defeated, else 0), and the terminated/truncated flags.
 */
extern "C" __device__ void Step(curandState *randState, struct SS13ArcadeEnv *instance, const int action, struct Observation *obs)
{
    int reward = 0;
    bool terminated = false;
    bool truncated = false;

    // --- Player turn ---
    switch (action) {
    case 0: // Attack: enemy loses 2-6 HP.
        instance->enemy_hp -= randint(randState, 2, 6);
        break;
    case 1: // Heal: player gains 6-8 HP and spends 1-3 MP.
        instance->player_hp += randint(randState, 6, 8);
        instance->player_mp -= randint(randState, 1, 3);
        break;
    case 2: // Charge: player gains 4-7 MP.
        instance->player_mp += randint(randState, 4, 7);
        break;
    default: // Unknown action: the player does nothing this turn.
        break;
    }

    // --- Enemy turn (skipped if the enemy was just defeated) ---
    if (instance->enemy_hp <= 0 || instance->enemy_mp <= 0) {
        reward = 1;
        terminated = true;
    } else if (instance->enemy_mp <= 5 && randint(randState, 1, 10) >= 7) {
        // When low on MP, 4-in-10 chance to drain 2-3 of the player's MP.
        // Short-circuit matters: the 1-10 roll is only drawn when enemy_mp <= 5.
        instance->player_mp -= randint(randState, 2, 3);
    } else if (instance->enemy_hp <= 10 && instance->enemy_mp > 4) {
        // Enemy heals 4 HP for 4 MP.
        instance->enemy_hp += 4;
        instance->enemy_mp -= 4;
    } else {
        // Enemy attacks for 3-6 HP.
        instance->player_hp -= randint(randState, 3, 6);
    }

    // Player defeat overrides any reward assigned above.
    if (instance->player_hp <= 0 || instance->player_mp <= 0) {
        reward = -1;
        terminated = true;
    } else if (instance->steps >= kMaxSteps) {
        // `>=` so truncation keeps being reported if the caller steps on
        // without resetting.
        truncated = true;
    }
    instance->steps += 1;

    instance->player_hp = clamp(instance->player_hp, 0, kStatCap);
    instance->player_mp = clamp(instance->player_mp, 0, kStatCap);
    instance->enemy_hp = clamp(instance->enemy_hp, 0, kStatCap);
    instance->enemy_mp = clamp(instance->enemy_mp, 0, kStatCap);

    struct State state;
    GetState(instance, &state);
    obs->state = state;
    obs->reward = reward;
    obs->terminated = terminated;
    obs->truncated = truncated;
}

/**
 * Restore `instance` to its starting stats and write the initial observation
 * (reward 0, not terminated, not truncated) into `obs`.
 */
extern "C" __device__ void Reset(struct SS13ArcadeEnv *instance, struct Observation *obs)
{
    ResetEnvironmentState(instance);

    struct State state;
    GetState(instance, &state);
    obs->state = state;
    obs->reward = 0;
    obs->terminated = false;
    obs->truncated = false;
}

// ---------------------------------------------------------------------------
// Kernels
// ---------------------------------------------------------------------------

/**
 * Initialise one cuRAND state per thread. The thread with global index
 * t = blockIdx.x * blockDim.x + threadIdx.x seeds states[t] with `seed` and
 * subsequence t.
 *
 * Launch config: 1D. There is no element-count parameter, so launch exactly as
 * many threads in total as `states` has entries. That count must be at least
 * the `n` later passed to ApplyAction. For a single-block launch this is
 * identical to indexing by threadIdx.x.
 */
extern "C" __global__ void initCurrand(curandState *states, unsigned long seed)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    curand_init(seed, index, 0, &states[index]);
}

/**
 * Put all `n` environments into their starting configuration.
 * Launch config: any 1D grid/block (grid-stride loop over [0, n)).
 */
extern "C" __global__ void CreateEnvironment(struct SS13ArcadeEnv *envs, int n)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    for (int i = index; i < n; i += stride) {
        ResetEnvironmentState(&envs[i]);
    }
}

/**
 * Apply one command to each of `n` environments and write its observation.
 *
 * `actionBufs` holds n consecutive records of kActionRecordSize (5) bytes.
 * Record i sits at byte offset i * 5:
 *   byte 0     opcode: 0x0 = Step, 0x1 = Reset
 *   bytes 1-4  action id for Step, big-endian (ignored for Reset)
 * Records with any other opcode leave envs[i] and obs[i] untouched.
 * NOTE: records are strided per env. Indexing the opcode by plain `i` made env
 * i's action bytes overlap env i+1's opcode. For n == 1 the layout is unchanged.
 *
 * `states` must contain n initialised cuRAND states (see initCurrand).
 * Launch config: any 1D grid/block (grid-stride loop over [0, n)).
 */
extern "C" __global__ void ApplyAction(curandState *states, char *actionBufs, struct SS13ArcadeEnv *envs, int n, struct Observation *obs)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    for (int i = index; i < n; i += stride) {
        const char *record = actionBufs + (size_t)i * kActionRecordSize;
        struct Observation new_obs;

        if (record[0] == 0x0) { // Step
            // The action payload is unaligned, so read it bytewise via memcpy.
            uint32_t action;
            memcpy(&action, record + 1, sizeof(action));
            Step(&states[i], &envs[i], (int)device_ntohl(action), &new_obs);
        } else if (record[0] == 0x1) { // Reset
            Reset(&envs[i], &new_obs);
        } else {
            continue; // Unknown opcode: leave obs[i] as it was.
        }

        obs[i] = new_obs;
    }
}