feat: Initial Commit
This commit is contained in:
commit
65837e3e29
|
@ -0,0 +1,6 @@
|
|||
CompileFlags:
|
||||
Add:
|
||||
- --cuda-path=/opt/cuda
|
||||
- --cuda-gpu-arch=sm_89 # Replace 89 with your actual GPU architecture, e.g., 86
|
||||
- -I/opt/cuda/include
|
||||
- -L/opt/cuda/lib64
|
|
@ -0,0 +1,2 @@
|
|||
build/
|
||||
.cache
|
|
@ -0,0 +1,6 @@
|
|||
cmake_minimum_required(VERSION 3.15)

# Default to compute capability 8.9 unless the caller already chose an
# architecture (-DCMAKE_CUDA_ARCHITECTURES=... or the CUDAARCHS environment
# variable). This is done before project() so the value is in place when the
# CUDA language is enabled.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND NOT DEFINED ENV{CUDAARCHS})
  set(CMAKE_CUDA_ARCHITECTURES 89)
endif()

project(SS13ArcadeEndpointCUDA LANGUAGES C CUDA)

# Emit compile_commands.json for editor tooling.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Default to a Debug build, but honour -DCMAKE_BUILD_TYPE=... when it is given.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Debug)
endif()

add_executable(SS13ArcadeEndpointCUDA main.cu SS13ArcadeEnv.cu)
|
|
@ -0,0 +1,40 @@
|
|||
#ifndef ARCADE_ENV_H
|
||||
#define ARCADE_ENV_H
|
||||
#include <curand_kernel.h>
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
/* Snapshot of one game's counters. Declared packed, so the five ints are laid
 * out back to back with no padding. */
struct __attribute__((packed)) State {
    int player_hp;
    int player_mp;
    int enemy_hp;
    int enemy_mp;
    int steps;
};
|
||||
|
||||
/* Result record produced for one environment by Step() or Reset().
 * Declared packed: the fields follow each other with no padding. */
struct __attribute__((packed)) Observation {
    struct State state;          /* game counters after the command ran */
    int reward;                  /* Step() writes -1, 0 or 1; Reset() writes 0 */
    bool terminated, truncated;  /* episode-end flags */
};
|
||||
|
||||
/* Mutable state of a single game instance. */
struct SS13ArcadeEnv {
    int player_hp;
    int player_mp;
    int enemy_hp;
    int enemy_mp;
    int steps;    /* incremented once per Step() call, zeroed by Reset() */
};
|
||||
|
||||
/* Returns value limited to the inclusive range [min, max]. */
__device__ int clamp(int value, int min, int max);
/* Returns x with its four bytes in reverse order. */
__device__ uint32_t device_ntohl(uint32_t x);

/* Kernel: each thread initialises states[threadIdx.x] from seed. */
__global__ void initCurrand(curandState *states, unsigned long seed);
/* Returns a pseudo-random integer derived from state, offset from low and
 * spanning (high - low + 1) values. */
__device__ int randint(curandState *state, int low, int high);

/* Kernel: sets the first n entries of envs to the starting game values. */
__global__ void CreateEnvironment(struct SS13ArcadeEnv *envs, int n);
/* Kernel: for each of the n environments, decodes its command from actionBufs
 * (opcode 0 = step, 1 = reset) and writes the result to the matching obs entry. */
__global__ void ApplyAction(curandState *states, char *actionBufs, struct SS13ArcadeEnv *envs, int n, struct Observation *obs);
/* Plays one turn (action 0 = attack, 1 = heal, 2 = charge) on instance and
 * fills obs with the outcome. */
__device__ void Step(curandState *randState, struct SS13ArcadeEnv *instance, const int action, struct Observation *obs);
/* Copies the counters of instance into state. */
__device__ void GetState(struct SS13ArcadeEnv *instance, struct State *state);
/* Restores instance to the starting values and fills obs with that state. */
__device__ void Reset(struct SS13ArcadeEnv *instance, struct Observation *obs);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
|
@ -0,0 +1,132 @@
|
|||
#include "SS13ArcadeEnv.ch"
|
||||
|
||||
// Limits value to the inclusive range [min, max]. The lower bound is tested
// first, so a value below min yields min.
extern "C" int __device__ clamp(int value, int min, int max) {
    return (value < min) ? min
         : (value > max) ? max
         : value;
}
|
||||
|
||||
// Reverses the byte order of a 32-bit word (byte 0 <-> byte 3, byte 1 <-> byte 2).
// The swap is unconditional; ApplyAction uses it on the action word it reads
// from the command buffer.
extern "C" uint32_t __device__ device_ntohl(uint32_t x) {
    const uint32_t byte0 = x & 0xFFU;
    const uint32_t byte1 = (x >> 8) & 0xFFU;
    const uint32_t byte2 = (x >> 16) & 0xFFU;
    const uint32_t byte3 = x >> 24;
    return (byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3;
}
|
||||
|
||||
// Kernel: every thread initialises one cuRAND generator state.
// Only threadIdx.x is used as the index, so the launch is expected to be a
// single block with one thread per entry of states. Each thread passes its
// own index as the sequence number and 0 as the offset.
extern "C" __global__ void initCurrand(curandState *states, unsigned long seed) {
    const int lane = threadIdx.x;
    curandState *slot = states + lane;
    curand_init(seed, lane, 0, slot);
}
|
||||
|
||||
// Draws one 32-bit value from state and maps it onto low..high (inclusive)
// by taking it modulo the number of values in the range.
extern "C" __device__ int randint(curandState *state, int low, int high) {
    const int width = high - low + 1;
    const unsigned int draw = curand(state);
    return (draw % width) + low;
}
|
||||
|
||||
// Kernel: puts the first n environments into the starting game state
// (player 30 HP / 10 MP, enemy 45 HP / 20 MP, step counter 0).
//
// Uses a grid-stride loop over the global thread index, so any launch
// configuration (including the single-block launch used by the host code)
// covers every environment exactly once.
extern "C" __global__ void CreateEnvironment(struct SS13ArcadeEnv *envs, int n) {
    const int first = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride = gridDim.x * blockDim.x;
    for (int i = first; i < n; i += stride) {
        envs[i].player_hp = 30;
        envs[i].player_mp = 10;
        envs[i].enemy_hp = 45;
        envs[i].enemy_mp = 20;
        envs[i].steps = 0;
    }
}
|
||||
|
||||
// Kernel: executes one command per environment.
//
// actionBufs holds n fixed-size records, one per environment, each made of
//   byte 0     : opcode (0x0 = step, 0x1 = reset)
//   bytes 1..4 : 32-bit action, byte-swapped with device_ntohl before use
//                (only read for the step opcode)
// Record i therefore starts at byte offset i * 5. Any other opcode leaves
// obs[i] untouched.
//
// Uses a grid-stride loop over the global thread index so each environment is
// processed exactly once regardless of the launch configuration.
extern "C" __global__ void ApplyAction(curandState *states, char *actionBufs, struct SS13ArcadeEnv *envs, int n, struct Observation *obs) {
    const size_t record_size = 1 + sizeof(uint32_t);
    const int first = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride = gridDim.x * blockDim.x;

    for (int i = first; i < n; i += stride) {
        const char *record = actionBufs + (size_t)i * record_size;
        const char opcode = record[0];

        if (opcode == 0x0) { // Step
            // The action word is not aligned inside the record, so copy it out bytewise.
            uint32_t action;
            memcpy(&action, record + 1, sizeof(action));
            action = device_ntohl(action);

            struct Observation new_obs;
            Step(&states[i], &envs[i], action, &new_obs);
            memcpy(&obs[i], &new_obs, sizeof(new_obs));
        } else if (opcode == 0x1) { // Reset
            struct Observation new_obs;
            Reset(&envs[i], &new_obs);
            memcpy(&obs[i], &new_obs, sizeof(new_obs));
        }
    }
}
|
||||
|
||||
// Plays one full turn on instance: the player's action, then the enemy's
// response, then the end-of-episode checks. The outcome is written to obs.
//
// action: 0 = attack, 1 = heal, 2 = charge; any other value makes the player
// do nothing while the enemy still takes its turn.
extern "C" __device__ void Step(curandState *randState, struct SS13ArcadeEnv *instance, const int action, struct Observation *obs) {
    int outcome = 0;
    bool finished = false;
    bool cut_short = false;

    // --- Player's move ---
    if (action == 0) {
        // Attack: 2-6 damage to the enemy
        instance->enemy_hp -= randint(randState, 2, 6);
    } else if (action == 1) {
        // Heal: gain 6-8 HP, spend 1-3 MP
        instance->player_hp += randint(randState, 6, 8);
        instance->player_mp -= randint(randState, 1, 3);
    } else if (action == 2) {
        // Charge: gain 4-7 MP
        instance->player_mp += randint(randState, 4, 7);
    }

    // --- Enemy's move ---
    const bool enemy_down = (instance->enemy_hp <= 0) || (instance->enemy_mp <= 0);
    if (enemy_down) {
        // Enemy defeated
        outcome = 1;
        finished = true;
    } else if (instance->enemy_mp <= 5 && randint(randState, 1, 10) >= 7) {
        // Low on MP: roll 1-10 and on 7 or more drain 2-3 MP from the player
        instance->player_mp -= randint(randState, 2, 3);
    } else if (instance->enemy_hp <= 10 && instance->enemy_mp > 4) {
        // Low on HP: trade 4 MP for 4 HP
        instance->enemy_hp += 4;
        instance->enemy_mp -= 4;
    } else {
        // Otherwise hit the player for 3-6
        instance->player_hp -= randint(randState, 3, 6);
    }

    // --- Episode end checks (a player loss overrides an enemy defeat) ---
    const bool player_down = (instance->player_hp <= 0) || (instance->player_mp <= 0);
    if (player_down) {
        outcome = -1;
        finished = true;
    } else if (instance->steps == 200) {
        // Compared before the counter is incremented below
        cut_short = true;
    }
    ++instance->steps;

    // Keep the reported HP/MP within 0..100 (enemy_mp is left as is)
    instance->player_hp = clamp(instance->player_hp, 0, 100);
    instance->player_mp = clamp(instance->player_mp, 0, 100);
    instance->enemy_hp = clamp(instance->enemy_hp, 0, 100);

    struct State snapshot;
    GetState(instance, &snapshot);

    obs->state = snapshot;
    obs->reward = outcome;
    obs->terminated = finished;
    obs->truncated = cut_short;
}
|
||||
|
||||
// Copies the five counters of instance, field by field, into state.
extern "C" __device__ void GetState(struct SS13ArcadeEnv *instance, struct State *state) {
    const struct SS13ArcadeEnv *src = instance;

    // Player side
    state->player_hp = src->player_hp;
    state->player_mp = src->player_mp;

    // Enemy side
    state->enemy_hp = src->enemy_hp;
    state->enemy_mp = src->enemy_mp;

    // Turn counter
    state->steps = src->steps;
}
|
||||
|
||||
// Puts instance back to the starting values (player 30 HP / 10 MP, enemy
// 45 HP / 20 MP, 0 steps) and reports that state in obs with a zero reward
// and both end-of-episode flags cleared.
extern "C" __device__ void Reset(struct SS13ArcadeEnv *instance, struct Observation *obs) {
    // Starting values
    instance->steps = 0;
    instance->player_hp = 30;
    instance->player_mp = 10;
    instance->enemy_hp = 45;
    instance->enemy_mp = 20;

    // Nothing has happened yet in the new episode
    obs->reward = 0;
    obs->terminated = false;
    obs->truncated = false;

    struct State snapshot;
    GetState(instance, &snapshot);
    obs->state = snapshot;
}
|
|
@ -0,0 +1,107 @@
|
|||
#include "SS13ArcadeEnv.ch"
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <stdbool.h>
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
|
||||
// Evaluates a CUDA runtime call exactly once. If the result is not
// cudaSuccess, prints the source location and the CUDA error string to
// stderr and terminates the process with exit status 1.
// The argument is parenthesised and the temporary has a macro-specific name
// so that arbitrary expressions (including ones that use a variable named
// `e`) can be passed safely.
#define CUDA_CHECK(cmd) do { \
    const cudaError_t cuda_check_status_ = (cmd); \
    if (cuda_check_status_ != cudaSuccess) { \
        fprintf(stderr, "Failed: %s:%d '%s'\n", __FILE__, __LINE__, \
                cudaGetErrorString(cuda_check_status_)); \
        exit(1); \
    } \
} while(0)
|
||||
|
||||
// Creates a TCP socket, binds it to `port` on all local IPv4 addresses and
// puts it into listening mode (backlog 5).
//
// socket_fd: out parameter; receives the listening descriptor on success and
//            is set to -1 on failure (nothing is left open).
// Returns 0 on success, -1 if socket(), bind() or listen() fails; the reason
// is printed to stdout.
int SetupServer(int *socket_fd, int port) {
    printf("[Client] Starting Server\n");

    *socket_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (*socket_fd < 0) {
        printf("[Client] Socket error: %s (%i)\n", strerror(errno), errno);
        *socket_fd = -1;
        return -1;
    }

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr)); // also clears the padding bytes
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(port);
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(*socket_fd, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
        // Report before close(), which may overwrite errno.
        printf("[Client] Bind error: %s (%i)\n", strerror(errno), errno);
        close(*socket_fd);
        *socket_fd = -1;
        return -1;
    }

    if (listen(*socket_fd, 5) < 0) {
        printf("[Client] Listen error: %s (%i)\n", strerror(errno), errno);
        close(*socket_fd);
        *socket_fd = -1;
        return -1;
    }
    return 0;
}
|
||||
|
||||
// Reads exactly len bytes from sock into buf, looping over short reads.
// Returns 1 when the buffer is full, 0 if the peer closed the connection
// first, and -1 on a socket error (errno is left set by recv).
static int RecvAll(int sock, char *buf, size_t len) {
    size_t got = 0;
    while (got < len) {
        ssize_t r = recv(sock, buf + got, len - got, 0);
        if (r == 0) return 0;
        if (r < 0) {
            if (errno == EINTR) continue;
            return -1;
        }
        got += (size_t)r;
    }
    return 1;
}

// Writes exactly len bytes from buf to sock, looping over short writes.
// Returns 0 on success and -1 on a socket error (errno is left set by send).
static int SendAll(int sock, const char *buf, size_t len) {
    size_t sent = 0;
    while (sent < len) {
        // MSG_NOSIGNAL: report a vanished peer as an error instead of SIGPIPE.
        ssize_t w = send(sock, buf + sent, len - sent, MSG_NOSIGNAL);
        if (w < 0) {
            if (errno == EINTR) continue;
            return -1;
        }
        sent += (size_t)w;
    }
    return 0;
}

// Serves one client on `port`, driving `workers` game environments on the GPU.
//
// Per round the client sends workers * 5 bytes (one 5-byte command record per
// environment); the reply is the raw array of `workers` Observation structs.
// NOTE(review): the fixed 5-byte-per-worker framing is inferred from the
// buffer size used here; confirm against the client implementation.
//
// The kernels are launched as a single block with one thread per worker, so
// `workers` must be between 1 and the device's maximum threads per block.
//
// Returns 0 when the client disconnects cleanly, -1 on invalid arguments or a
// setup/socket error. CUDA failures terminate the process via CUDA_CHECK.
int StartEnv(int port, int workers) {
    if(workers <= 0) {
        printf("[Client] Invalid worker count: %d\n", workers);
        return -1;
    }

    int device = 0;
    int max_threads = 0;
    CUDA_CHECK(cudaGetDevice(&device));
    CUDA_CHECK(cudaDeviceGetAttribute(&max_threads, cudaDevAttrMaxThreadsPerBlock, device));
    if(workers > max_threads) {
        printf("[Client] Too many workers: %d (device limit is %d)\n", workers, max_threads);
        return -1;
    }

    int socket_fd;
    int setup = SetupServer(&socket_fd, port);
    if(setup < 0) {
        printf("[Client] Error setting up server\n");
        return -1;
    }

    struct sockaddr_in client_addr;
    socklen_t client_size = sizeof(client_addr);
    int client_sock = accept(socket_fd, (struct sockaddr*)&client_addr, &client_size);
    if(client_sock < 0) {
        printf("[Client] Accept error: %s (%d)\n", strerror(errno), errno);
        close(socket_fd);
        return -1;
    }

    srand(time(NULL));

    const size_t action_bytes = (size_t)workers * 5;
    const size_t obs_size = sizeof(struct Observation)*workers;
    const size_t envs_size = sizeof(struct SS13ArcadeEnv)*workers;

    // Host staging buffers live on the heap: their size depends on `workers`.
    char *h_actions = (char*)malloc(action_bytes);
    char *h_obs = (char*)malloc(obs_size);
    if(h_actions == NULL || h_obs == NULL) {
        printf("[Client] Out of host memory\n");
        free(h_actions);
        free(h_obs);
        shutdown(client_sock, SHUT_RDWR);
        close(client_sock);
        close(socket_fd);
        return -1;
    }

    // All device buffers are allocated once and reused for every round.
    curandState *d_states;
    struct SS13ArcadeEnv *d_envs;
    struct Observation *d_obs;
    char *d_actions;
    CUDA_CHECK(cudaMalloc((void**)&d_states, sizeof(curandState)*workers));
    CUDA_CHECK(cudaMalloc((void**)&d_envs, envs_size));
    CUDA_CHECK(cudaMalloc((void**)&d_obs, obs_size));
    CUDA_CHECK(cudaMalloc((void**)&d_actions, action_bytes));

    initCurrand<<<1, workers>>>(d_states, 1337);
    CUDA_CHECK(cudaGetLastError());       // launch-configuration errors
    CUDA_CHECK(cudaDeviceSynchronize());  // errors raised while the kernel ran

    CreateEnvironment<<<1,workers>>>(d_envs, workers);
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    int result = 0;
    while(true) {
        int received = RecvAll(client_sock, h_actions, action_bytes);
        if(received < 0) {
            printf("[Client] Receive Error: %s (%d)\n", strerror(errno), errno);
            result = -1;
            break;
        } else if(received == 0) {
            break; // client disconnected
        }

        CUDA_CHECK(cudaMemcpy(d_actions, h_actions, action_bytes, cudaMemcpyHostToDevice));

        ApplyAction<<<1,workers>>>(d_states, d_actions, d_envs, workers, d_obs);
        CUDA_CHECK(cudaGetLastError());
        CUDA_CHECK(cudaDeviceSynchronize());

        CUDA_CHECK(cudaMemcpy(h_obs, d_obs, obs_size, cudaMemcpyDeviceToHost));
        if(SendAll(client_sock, h_obs, obs_size) < 0) {
            printf("[Client] Send Error: %s (%d)\n", strerror(errno), errno);
            result = -1;
            break;
        }
    }

    cudaFree(d_actions);
    cudaFree(d_envs);
    cudaFree(d_obs);
    cudaFree(d_states);
    free(h_actions);
    free(h_obs);
    shutdown(client_sock, SHUT_RDWR);
    close(client_sock);
    close(socket_fd);
    return result;
}
|
||||
|
||||
// Args: {PORT} {WORKERS}
//
// Parses the two command-line arguments and runs the environment server.
// PORT must be a decimal integer in 0..65535 and WORKERS a positive decimal
// integer; anything else is rejected with a usage message.
// Exit status: 0 on normal completion, 1 on bad arguments or when StartEnv
// reports an error.
int main(int argc, char *argv[]) {
    if(argc != 3) {
        printf("Invalid number of arguments!\n");
        printf("Usage: SS13ArcadeEndpointCUDA {PORT} {WORKERS}\n");
        return 1;
    }

    char *end = NULL;

    errno = 0;
    long port = strtol(argv[1], &end, 10);
    if(errno != 0 || end == argv[1] || *end != '\0' || port < 0 || port > 65535) {
        printf("Invalid port: %s\n", argv[1]);
        printf("Usage: SS13ArcadeEndpointCUDA {PORT} {WORKERS}\n");
        return 1;
    }

    errno = 0;
    long workers = strtol(argv[2], &end, 10);
    // Upper bound keeps the value representable as the int StartEnv takes.
    if(errno != 0 || end == argv[2] || *end != '\0' || workers < 1 || workers > 2147483647L) {
        printf("Invalid worker count: %s\n", argv[2]);
        printf("Usage: SS13ArcadeEndpointCUDA {PORT} {WORKERS}\n");
        return 1;
    }

    int env = StartEnv((int)port, (int)workers);
    if(env < 0) {
        printf("[Client] Error starting environment\n");
        return 1;
    }
    return 0;
}
|
Loading…
Reference in New Issue