SS13ArcadeTCPCUDA/SS13ArcadeEnv.cu

133 lines
4.2 KiB
Plaintext

#include "SS13ArcadeEnv.ch"
/**
 * Restricts `value` to the closed interval [min, max].
 *
 * @param value  number to restrict
 * @param min    lower bound (inclusive)
 * @param max    upper bound (inclusive)
 * @return `min` if value < min, otherwise `max` if value > max, otherwise
 *         `value`. The lower bound is tested first, so it takes precedence
 *         when value is below it.
 */
extern "C" int __device__ clamp(int value, int min, int max) {
    if (value >= min) {
        return (value > max) ? max : value;
    }
    return min;
}
/**
 * Device-side stand-in for ntohl(): reverses the order of the four bytes of x.
 *
 * The swap is unconditional, so the result matches a real network-to-host
 * conversion only when the executing side is little-endian.
 * (Provenance note kept from the original: this helper was AI generated.)
 *
 * @param x  32-bit value to byte-swap
 * @return x with its bytes in reverse order
 */
extern "C" uint32_t __device__ device_ntohl(uint32_t x) {
    const uint32_t b0 = x & 0xFFU;          // least significant byte
    const uint32_t b1 = (x >> 8) & 0xFFU;
    const uint32_t b2 = (x >> 16) & 0xFFU;
    const uint32_t b3 = x >> 24;            // most significant byte
    return (b0 << 24) | (b1 << 16) | (b2 << 8) | b3;
}
/**
 * Kernel: seeds one cuRAND generator state per thread.
 *
 * Each thread initialises states[threadIdx.x] with the common `seed`, using
 * its own threadIdx.x as the sequence number and an offset of 0.
 *
 * Only threadIdx.x is used (blockIdx is ignored) and there is no bounds
 * check, so `states` must hold at least blockDim.x entries.
 * NOTE(review): entries at index >= blockDim.x are never initialised here;
 * confirm the launch uses at least as many threads as there are environments.
 *
 * @param states  array of generator states to initialise
 * @param seed    seed shared by all generators
 */
extern "C" __global__ void initCurrand(curandState *states, unsigned long seed) {
    const int lane = threadIdx.x;
    curandState *slot = states + lane;
    curand_init(seed, lane, 0, slot);
}
/**
 * Draws a pseudo-random integer in the inclusive range [low, high].
 *
 * Consumes exactly one curand() value from `state` and reduces it modulo the
 * range width (high - low + 1), so the distribution carries the usual slight
 * modulo bias.
 * NOTE(review): callers in this file always pass low <= high; a width of
 * zero or less is not guarded against.
 *
 * @param state  generator state to advance
 * @param low    smallest value that can be returned
 * @param high   largest value that can be returned
 * @return a value between low and high, both inclusive
 */
extern "C" __device__ int randint(curandState *state, int low, int high) {
    // Same unsigned arithmetic the implicit conversions would perform.
    const unsigned int width = static_cast<unsigned int>(high - low + 1);
    const unsigned int offset = curand(state) % width;
    return static_cast<int>(offset + static_cast<unsigned int>(low));
}
/**
 * Kernel: puts the first `n` environments into their starting state
 * (player 30 HP / 10 MP, enemy 45 HP / 20 MP, step counter 0).
 *
 * Uses a grid-stride loop, so any 1-D launch configuration is valid: the
 * environments are divided among all threads of all blocks. With a single
 * block this visits exactly the same indices per thread as a block-stride
 * loop would.
 *
 * @param envs  array of at least `n` environments
 * @param n     number of environments to initialise; nothing happens if n <= 0
 */
extern "C" __global__ void CreateEnvironment(struct SS13ArcadeEnv *envs, int n) {
    // Global thread index and total thread count across the whole grid, so
    // that blocks share the work instead of each redoing all of it.
    const int first = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride = gridDim.x * blockDim.x;
    for (int i = first; i < n; i += stride) {
        envs[i].player_hp = 30;
        envs[i].player_mp = 10;
        envs[i].enemy_hp = 45;
        envs[i].enemy_mp = 20;
        envs[i].steps = 0;
    }
}
/**
 * Kernel: applies one pending command to each of the first `n` environments
 * and writes the resulting observation to obs[i].
 *
 * Command encoding as read by this kernel:
 *   - actionBufs[i] == 0x0 : Step. The four bytes that follow are copied into
 *     a uint32_t, byte-swapped with device_ntohl(), and passed to Step() as
 *     the action.
 *   - actionBufs[i] == 0x1 : Reset the environment.
 *   - any other opcode     : environment and obs[i] are left untouched.
 *
 * Uses a grid-stride loop so that each environment is handled by exactly one
 * thread for any 1-D launch configuration. Handling an environment from more
 * than one block would step it several times and race on envs[i], states[i]
 * and obs[i].
 *
 * NOTE(review): the opcode is read at byte i and its payload at bytes
 * i+1..i+4, so the records of neighbouring environments overlap whenever
 * n > 1. If the host packs fixed-size records (e.g. 5 bytes each) the index
 * should be scaled by the record size -- confirm against the host code.
 *
 * @param states      per-environment RNG states (states[i] is used for env i)
 * @param actionBufs  command bytes, see encoding above
 * @param envs        environment array, at least `n` entries
 * @param n           number of environments to process
 * @param obs         output observations, at least `n` entries
 */
extern "C" __global__ void ApplyAction(curandState *states, char *actionBufs, struct SS13ArcadeEnv *envs, int n, struct Observation *obs) {
    const int first = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride = gridDim.x * blockDim.x;
    for (int i = first; i < n; i += stride) {
        const char opcode = actionBufs[i];
        if (opcode == 0x0) { // Step
            // memcpy because the payload is not guaranteed to be 4-byte aligned.
            uint32_t action;
            memcpy(&action, &actionBufs[i] + 1, sizeof(uint32_t));
            action = device_ntohl(action);
            struct Observation new_obs;
            Step(&states[i], &envs[i], action, &new_obs);
            memcpy(&obs[i], &new_obs, sizeof(struct Observation));
        } else if (opcode == 0x1) { // Reset
            struct Observation new_obs;
            Reset(&envs[i], &new_obs);
            memcpy(&obs[i], &new_obs, sizeof(struct Observation));
        }
    }
}
/**
 * Advances one environment by a single turn and fills in the observation.
 *
 * Turn order:
 *   1. Player action: 0 = Attack (enemy HP -2..6), 1 = Heal (player HP +6..8,
 *      player MP -1..3), 2 = Charge (player MP +4..7). Any other value does
 *      nothing.
 *   2. Enemy response: if the enemy has HP <= 0 or MP <= 0 it is defeated
 *      (reward 1, terminated). Otherwise it drains 2..3 player MP (when its
 *      MP <= 5 and a 1..10 roll is >= 7), or heals itself 4 HP for 4 MP (when
 *      its HP <= 10 and MP > 4), or else hits the player for 3..6 HP.
 *   3. Player check: HP <= 0 or MP <= 0 gives reward -1 and terminated;
 *      otherwise the episode is truncated if the step counter, read before
 *      it is incremented, equals 200.
 *   4. The step counter is incremented and player HP/MP and enemy HP are
 *      clamped to 0..100 (enemy MP is not clamped).
 *
 * Random draws are made in the order listed, and the drain roll is only drawn
 * when the enemy is not defeated and its MP is <= 5.
 *
 * @param randState  RNG state used for every roll this turn
 * @param instance   environment to mutate
 * @param action     player action code (see above)
 * @param obs        receives the new state, reward, terminated and truncated
 */
extern "C" __device__ void Step(curandState *randState, struct SS13ArcadeEnv *instance, const int action, struct Observation *obs) {
    int outcome = 0;
    bool finished = false;
    bool cutOff = false;

    // --- Player phase ---
    if (action == 0) {          // Attack
        instance->enemy_hp -= randint(randState, 2, 6);     // 2-6
    } else if (action == 1) {   // Heal
        instance->player_hp += randint(randState, 6, 8);    // 6-8
        instance->player_mp -= randint(randState, 1, 3);    // 1-3
    } else if (action == 2) {   // Charge
        instance->player_mp += randint(randState, 4, 7);    // 4-7
    }

    // --- Enemy phase ---
    const bool enemyDown = instance->enemy_hp <= 0 || instance->enemy_mp <= 0;
    if (enemyDown) { // Enemy Defeated
        outcome = 1;
        finished = true;
    } else if (instance->enemy_mp <= 5 && randint(randState, 1, 10) >= 7) { // Enemy Drain Player MP 1-10
        instance->player_mp -= randint(randState, 2, 3);    // 2-3
    } else if (instance->enemy_hp <= 10 && instance->enemy_mp > 4) { // Enemy Heal
        instance->enemy_hp += 4;
        instance->enemy_mp -= 4;
    } else { // Enemy attacks the player
        instance->player_hp -= randint(randState, 3, 6);    // 3-6
    }

    // --- Outcome for the player ---
    const bool playerDown = instance->player_hp <= 0 || instance->player_mp <= 0;
    if (playerDown) {
        outcome = -1;
        finished = true;
    } else if (instance->steps == 200) {
        cutOff = true;
    }

    // --- Bookkeeping ---
    instance->steps += 1;
    instance->player_hp = clamp(instance->player_hp, 0, 100);
    instance->player_mp = clamp(instance->player_mp, 0, 100);
    instance->enemy_hp = clamp(instance->enemy_hp, 0, 100);

    // --- Report ---
    struct State snapshot;
    GetState(instance, &snapshot);
    obs->state = snapshot;
    obs->reward = outcome;
    obs->terminated = finished;
    obs->truncated = cutOff;
}
/**
 * Copies the observable fields of an environment into a State record.
 *
 * @param instance  environment to read (not modified)
 * @param state     receives player_hp, player_mp, enemy_hp, enemy_mp and steps
 */
extern "C" __device__ void GetState(struct SS13ArcadeEnv *instance, struct State *state) {
    // Step counter first, then the two combatants.
    state->steps = instance->steps;

    state->enemy_hp = instance->enemy_hp;
    state->enemy_mp = instance->enemy_mp;

    state->player_hp = instance->player_hp;
    state->player_mp = instance->player_mp;
}
/**
 * Returns an environment to its starting state and reports that state.
 *
 * Starting values: player 30 HP / 10 MP, enemy 45 HP / 20 MP, step counter 0
 * (the same values CreateEnvironment writes). The observation carries the
 * fresh state with reward 0 and both terminated and truncated cleared.
 *
 * @param instance  environment to reset
 * @param obs       receives the post-reset observation
 */
extern "C" __device__ void Reset(struct SS13ArcadeEnv *instance, struct Observation *obs) {
    // Restore the initial stats.
    instance->steps = 0;
    instance->enemy_hp = 45;
    instance->enemy_mp = 20;
    instance->player_hp = 30;
    instance->player_mp = 10;

    // Publish the fresh state with a neutral outcome.
    struct State snapshot;
    GetState(instance, &snapshot);
    obs->state = snapshot;
    obs->truncated = false;
    obs->terminated = false;
    obs->reward = 0;
}