108 lines
3.2 KiB
Plaintext
108 lines
3.2 KiB
Plaintext
#include "SS13ArcadeEnv.ch"
|
|
#include <stdlib.h>
|
|
#include <stdio.h>
|
|
#include <stdbool.h>
|
|
#include <sys/socket.h>
|
|
#include <netinet/in.h>
|
|
#include <arpa/inet.h>
|
|
#include <sys/types.h>
|
|
#include <unistd.h>
|
|
#include <errno.h>
|
|
#include <string.h>
|
|
#include <time.h>
|
|
|
|
// AI
|
|
// Evaluates a CUDA runtime call exactly once. If the returned status is not
// cudaSuccess, prints the source location and the CUDA error string to stderr
// and terminates the process with exit status 1.
//
// The argument is parenthesised and the status is stored in a deliberately
// unusual local name so that expressions passed as `cmd` cannot be mis-parsed
// or accidentally capture the macro's own variable.
#define CUDA_CHECK(cmd) do { \
    cudaError_t cuda_check_status_ = (cmd); \
    if (cuda_check_status_ != cudaSuccess) { \
        fprintf(stderr, "Failed: %s:%d '%s'\n", __FILE__, __LINE__, \
                cudaGetErrorString(cuda_check_status_)); \
        exit(1); \
    } \
} while(0)
|
|
|
|
// Creates a TCP socket, binds it to `port` on all local IPv4 interfaces and
// puts it into the listening state (backlog of 5).
//
// socket_fd: out-parameter. Receives the listening descriptor on success and
//            is set to -1 on every failure path, so the caller never sees a
//            half-initialised descriptor.
// port:      TCP port in host byte order; must lie in 0..65535.
//
// Returns 0 on success and -1 on failure. Diagnostics go to stdout, matching
// the other "[Client]" messages in this file.
int SetupServer(int *socket_fd, int port) {
    printf("[Client] Starting Server\n");

    if (socket_fd == NULL) {
        printf("[Client] Invalid socket pointer\n");
        return -1;
    }
    *socket_fd = -1;

    // htons() takes a 16-bit value; reject anything it would silently truncate.
    if (port < 0 || port > 65535) {
        printf("[Client] Invalid port: %i\n", port);
        return -1;
    }

    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        printf("[Client] Socket error: %s (%i)\n", strerror(errno), errno);
        return -1;
    }

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));  // also clears sin_zero padding
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons((unsigned short)port);
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(fd, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
        // Report before close(), which may overwrite errno.
        printf("[Client] Bind error: %s (%i)\n", strerror(errno), errno);
        close(fd);
        return -1;
    }

    if (listen(fd, 5) < 0) {
        printf("[Client] Listen error: %s (%i)\n", strerror(errno), errno);
        close(fd);
        return -1;
    }

    *socket_fd = fd;
    return 0;
}
|
|
|
|
// Reads exactly `length` bytes from `sock` into `dst`, retrying after EINTR.
// Returns 1 once the buffer is full, 0 if the peer closed the connection
// before that, and -1 on a socket error (errno is left as set by recv).
static int RecvAll(int sock, char *dst, size_t length) {
    size_t done = 0;
    while (done < length) {
        ssize_t got = recv(sock, dst + done, length - done, 0);
        if (got == 0) {
            return 0;
        }
        if (got < 0) {
            if (errno == EINTR) {
                continue;
            }
            return -1;
        }
        done += (size_t)got;
    }
    return 1;
}

// Writes exactly `length` bytes from `src` to `sock`, retrying after EINTR.
// Returns 0 on success and -1 on a socket error (errno is left as set by send).
static int SendAll(int sock, const char *src, size_t length) {
    size_t done = 0;
    while (done < length) {
        ssize_t put = send(sock, src + done, length - done, 0);
        if (put < 0) {
            if (errno == EINTR) {
                continue;
            }
            return -1;
        }
        done += (size_t)put;
    }
    return 0;
}

// Runs the environment server: listens on `port`, accepts a single client,
// then repeatedly receives one action message, launches ApplyAction on the
// GPU, and sends the resulting observations back until the client disconnects.
//
// port:    TCP port passed to SetupServer.
// workers: number of environments; every kernel is launched as a single block
//          of `workers` threads, so the device's per-block thread limit
//          applies (a too-large value is reported by the launch check).
//
// Each step exchanges workers*5 request bytes (presumably 5 action bytes per
// worker -- confirm against ApplyAction) for sizeof(struct Observation)*workers
// response bytes.
//
// Returns 0 after a clean client disconnect and -1 on setup or socket errors.
// CUDA failures terminate the process through CUDA_CHECK.
int StartEnv(int port, int workers) {
    if (workers <= 0) {
        printf("[Client] Invalid worker count: %d\n", workers);
        return -1;
    }

    int socket_fd = -1;
    int setup = SetupServer(&socket_fd, port);
    if (setup < 0) {
        printf("[Client] Error setting up server\n");
        if (socket_fd >= 0) {
            close(socket_fd);
        }
        return -1;
    }

    struct sockaddr_in client_addr;
    socklen_t client_size = sizeof(client_addr);
    int client_sock = accept(socket_fd, (struct sockaddr*)&client_addr, &client_size);
    if (client_sock < 0) {
        printf("[Client] Accept error: %s (%d)\n", strerror(errno), errno);
        close(socket_fd);
        return -1;
    }

    // Seeds the host C RNG. Nothing in this function calls rand(); kept for
    // any code elsewhere that does.
    srand(time(NULL));

    const size_t action_bytes = (size_t)workers * 5;
    const size_t obs_size = sizeof(struct Observation) * (size_t)workers;
    const size_t envs_size = sizeof(struct SS13ArcadeEnv) * (size_t)workers;

    // Heap staging buffers: their size depends on user input, so they must not
    // live on the stack.
    char *h_actions = (char*)malloc(action_bytes);
    char *h_obs = (char*)malloc(obs_size);
    if (h_actions == NULL || h_obs == NULL) {
        printf("[Client] Out of host memory\n");
        free(h_actions);
        free(h_obs);
        close(client_sock);
        close(socket_fd);
        return -1;
    }

    curandState *d_states = NULL;
    CUDA_CHECK(cudaMalloc((void**)&d_states, sizeof(curandState) * (size_t)workers));
    initCurrand<<<1, workers>>>(d_states, 1337);
    CUDA_CHECK(cudaGetLastError());       // launch-configuration errors
    CUDA_CHECK(cudaDeviceSynchronize());  // errors raised while the kernel ran

    struct SS13ArcadeEnv *d_envs = NULL;
    struct Observation *d_obs = NULL;
    CUDA_CHECK(cudaMalloc((void**)&d_envs, envs_size));
    CUDA_CHECK(cudaMalloc((void**)&d_obs, obs_size));

    CreateEnvironment<<<1, workers>>>(d_envs, workers);
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    // Allocated once and reused every step; allocating inside the loop leaked
    // one device buffer per message.
    char *d_buffer = NULL;
    CUDA_CHECK(cudaMalloc((void**)&d_buffer, action_bytes));

    int status = 0;
    while (true) {
        // TCP is a byte stream: wait for the complete message so the kernel
        // never reads bytes that were not actually received.
        int received = RecvAll(client_sock, h_actions, action_bytes);
        if (received < 0) {
            printf("[Client] Receive Error: %s (%d)\n", strerror(errno), errno);
            status = -1;
            break;
        }
        if (received == 0) {
            break;  // client closed the connection
        }

        CUDA_CHECK(cudaMemcpy(d_buffer, h_actions, action_bytes, cudaMemcpyHostToDevice));

        ApplyAction<<<1, workers>>>(d_states, d_buffer, d_envs, workers, d_obs);
        CUDA_CHECK(cudaGetLastError());
        CUDA_CHECK(cudaDeviceSynchronize());

        CUDA_CHECK(cudaMemcpy(h_obs, d_obs, obs_size, cudaMemcpyDeviceToHost));

        if (SendAll(client_sock, h_obs, obs_size) < 0) {
            printf("[Client] Send Error: %s (%d)\n", strerror(errno), errno);
            status = -1;
            break;
        }
    }

    cudaFree(d_buffer);
    cudaFree(d_envs);
    cudaFree(d_obs);
    cudaFree(d_states);
    free(h_actions);
    free(h_obs);

    shutdown(client_sock, SHUT_RDWR);
    close(client_sock);
    close(socket_fd);
    return status;
}
|
|
|
|
// Parses `text` as a base-10 integer in [min_value, max_value].
// Returns true and stores the value in *out only when the whole string is a
// valid number inside the range; otherwise returns false and leaves *out
// untouched.
static bool ParseIntArg(const char *text, long min_value, long max_value, int *out) {
    char *end = NULL;
    errno = 0;
    long value = strtol(text, &end, 10);
    if (errno != 0 || end == text || *end != '\0') {
        return false;  // overflow, empty input, or trailing garbage
    }
    if (value < min_value || value > max_value) {
        return false;
    }
    *out = (int)value;
    return true;
}

// Args: {PORT} {WORKERS}
//
// Entry point. Validates both command-line arguments and hands them to
// StartEnv. Returns 0 when StartEnv succeeds and 1 on bad usage, an invalid
// argument, or a StartEnv failure.
int main(int argc, char *argv[]) {
    if(argc != 3) {
        printf("Invalid number of arguments!\n");
        printf("Usage: SS13ArcadeEndpointCUDA {PORT} {WORKERS}\n");
        return 1;
    }

    int port = 0;
    int workers = 0;
    // atoi() would turn non-numeric input into 0 without any diagnostic.
    if(!ParseIntArg(argv[1], 0L, 65535L, &port)) {
        printf("Invalid port: %s\n", argv[1]);
        printf("Usage: SS13ArcadeEndpointCUDA {PORT} {WORKERS}\n");
        return 1;
    }
    // Upper bound is the largest 32-bit int so the value always fits `workers`.
    if(!ParseIntArg(argv[2], 1L, 2147483647L, &workers)) {
        printf("Invalid worker count: %s\n", argv[2]);
        printf("Usage: SS13ArcadeEndpointCUDA {PORT} {WORKERS}\n");
        return 1;
    }

    int env = StartEnv(port, workers);
    if(env < 0) {
        printf("[Client] Error starting environment");
        return 1;
    }
    return 0;
}
|