Open main menu

CDOT Wiki β

Changes

A-Team

3 bytes added, 05:42, 1 April 2019
After that and many coffees!
=====After that and many coffees!=====
/*
 * train - backward pass and gradient-descent update of the three weight
 * buffers d_W1 (784*128 elements), d_W2 (128*64) and d_W3 (64*10).
 *
 * Inputs read:   d_b_X, d_a1, d_a2, d_dyhat, d_W2, d_W3
 * Scratch/out:   d_dW1, d_dW2, d_dW3, d_dz1, d_dz2
 * Updated:       d_W1, d_W2, d_W3 (W -= lr * dW)
 * Overwritten:   d_a1, d_a2 are passed to kreluPrime (presumably replaced in
 *                place by the ReLU derivative -- confirm against kreluPrime).
 * Unused:        d_b_Y is accepted but never referenced.
 *
 * Helper semantics are not visible from this block and are assumed from the
 * call sites (TODO confirm):
 *   kdot(A, B, m, n, p, C)   - presumably C(m x p) = A(m x n) * B(n x p)
 *   ktranspose(M, rows, cols) - presumably returns a pointer to M transposed
 *   kreluPrime(M, count)      - called as a plain function on `count` elements
 *
 * NOTE(review): the element-wise loops below are plain serial loops, so every
 * thread that runs this kernel repeats them and re-issues the child launches;
 * this looks like it expects a single-thread launch -- confirm at the caller.
 * NOTE(review): the kdot launches are device-side launches and nothing here
 * waits for them before their outputs are read or their inputs are
 * overwritten -- verify the required synchronisation for the CUDA version used.
 */
__global__ void train(float* d_W1, float* d_W2, float* d_W3, float* d_b_X, float* d_b_Y, float* d_a2, float* d_a1, float* d_dyhat, float* d_dW3, float* d_dW2, float* d_dW1, float* d_dz2, float* d_dz1) {
    const int BATCH_SIZE = 256;
    const int N_INPUT = 784;    /* columns of d_b_X, rows of W1 */
    const int N_HIDDEN1 = 128;  /* columns of d_a1, W1 columns / W2 rows */
    const int N_HIDDEN2 = 64;   /* columns of d_a2, W2 columns / W3 rows */
    const int N_OUTPUT = 10;    /* columns of d_dyhat, W3 columns */

    /* Dividing by 256 is exact, so this equals the original (.01 / 256)
     * rounded to float while avoiding double arithmetic on the device. */
    const float lr = 0.01f / BATCH_SIZE;

    /* dW3 = a2^T * dyhat */
    kdot<<<50, 51>>>(ktranspose(d_a2, BATCH_SIZE, N_HIDDEN2), d_dyhat, N_HIDDEN2, BATCH_SIZE, N_OUTPUT, d_dW3);

    /* dz2 = (dyhat * W3^T) .* reluPrime(a2)
     * NOTE(review): 80*32 threads is fewer than the BATCH_SIZE*64 outputs;
     * whether that is enough depends on how kdot maps threads -- confirm. */
    kdot<<<80, 32>>>(d_dyhat, ktranspose(d_W3, N_HIDDEN2, N_OUTPUT), BATCH_SIZE, N_OUTPUT, N_HIDDEN2, d_dz2);
    /* Fixed: the original passed 128 * 64 here and looped to BATCH_SIZE * 10,
     * neither of which matches the BATCH_SIZE x 64 shape used by the
     * surrounding kdot/ktranspose calls. */
    kreluPrime(d_a2, BATCH_SIZE * N_HIDDEN2);
    for (int i = 0; i < BATCH_SIZE * N_HIDDEN2; i++) {
        d_dz2[i] = d_dz2[i] * d_a2[i];
    }

    /* dW2 = a1^T * dz2 */
    kdot<<<1024, 32>>>(ktranspose(d_a1, BATCH_SIZE, N_HIDDEN1), d_dz2, N_HIDDEN1, BATCH_SIZE, N_HIDDEN2, d_dW2);

    /* dz1 = (dz2 * W2^T) .* reluPrime(a1)
     * NOTE(review): 512*32 threads is fewer than the BATCH_SIZE*128 outputs;
     * confirm against kdot's thread mapping. */
    kdot<<<512, 32>>>(d_dz2, ktranspose(d_W2, N_HIDDEN1, N_HIDDEN2), BATCH_SIZE, N_HIDDEN2, N_HIDDEN1, d_dz1);
    /* Fixed: the original passed BATCH_SIZE * 784 (past the end of a
     * BATCH_SIZE x 128 buffer) and looped to 256 * 64 (half of dz1). */
    kreluPrime(d_a1, BATCH_SIZE * N_HIDDEN1);
    for (int i = 0; i < BATCH_SIZE * N_HIDDEN1; i++) {
        d_dz1[i] = d_dz1[i] * d_a1[i];
    }

    /* dW1 = X^T * dz1
     * NOTE(review): the third launch argument (32) is the dynamic shared
     * memory size in bytes, not a block dimension -- kept as in the original;
     * confirm it is intended. */
    kdot<<<512, 512, 32>>>(ktranspose(d_b_X, BATCH_SIZE, N_INPUT), d_dz1, N_INPUT, BATCH_SIZE, N_HIDDEN1, d_dW1);

    /* Updating the parameters */

    /* W3 = W3 - lr * dW3 */
    for (int i = 0; i < N_HIDDEN2 * N_OUTPUT; i++) {
        d_W3[i] = d_W3[i] - lr * d_dW3[i];
    }

    /* W2 = W2 - lr * dW2 */
    for (int i = 0; i < N_HIDDEN1 * N_HIDDEN2; i++) {
        d_W2[i] = d_W2[i] - lr * d_dW2[i];
    }

    /* W1 = W1 - lr * dW1 */
    for (int i = 0; i < N_INPUT * N_HIDDEN1; i++) {
        d_W1[i] = d_W1[i] - lr * d_dW1[i];
    }
}
=== Assignment 3 ===
113
edits