The NanoGPT module executed in this assignment is a sequence-to-sequence model. The input is a sequence of words, e.g. the phrase "The course of true love never did run smooth". The model's output is a new sequence of words, a likely continuation of the input as determined by a model trained on a large corpus of Shakespeare text. For example, given the prefix above, the model's output might be "whispered cs149 students whilst coding on assignments".
```cpp
// Step #2: Implement Read/Write Accessors for a 4D Tensor
inline float fourDimRead(std::vector<float> &tensor, int &x, int &y, int &z, int &b,
                         const int &sizeX, const int &sizeY, const int &sizeZ) {
    return tensor[x * sizeX * sizeY * sizeZ + y * sizeY * sizeZ + z * sizeZ + b];
}

inline void fourDimWrite(std::vector<float> &tensor, int &x, int &y, int &z, int &b,
                         const int &sizeX, const int &sizeY, const int &sizeZ, float &val) {
    tensor[x * sizeX * sizeY * sizeZ + y * sizeY * sizeZ + z * sizeZ + b] = val;
}
```
The 4D tensor is flattened to 1D in row-major order, with x varying slowest and b varying fastest (i.e. unrolled in the order b, z, y, x). This gives good locality: elements adjacent along the last dimension are stored contiguously in memory.
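As a quick sanity check (a standalone sketch, not part of the starter code; the `offset` helper and the values in `main` are made up for illustration), the row-major offset formula implies the following strides for a (B, H, N, d) tensor:

```cpp
// Standalone sketch verifying the row-major offset formula used by the
// accessors above: element (b, h, i, j) of a (B, H, N, d) tensor lives at
// b*H*N*d + h*N*d + i*d + j, so consecutive j values are adjacent in memory.
#include <cassert>

int offset(int b, int h, int i, int j, int H, int N, int d) {
    return b * H * N * d + h * N * d + i * d + j;
}

int main() {
    const int H = 2, N = 4, d = 8;
    assert(offset(0, 0, 0, 1, H, N, d) == offset(0, 0, 0, 0, H, N, d) + 1); // j is contiguous
    assert(offset(0, 0, 1, 0, H, N, d) == offset(0, 0, 0, 0, H, N, d) + d); // stride d along i
    assert(offset(0, 1, 0, 0, H, N, d) == N * d);                           // stride N*d along h
    return 0;
}
```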
Part 1: A Simple (But Not So Efficient) Implementation of Attention
The first step is to implement an unoptimized, serial attention layer. The starter code gives an example of how to access a 4D tensor:
```cpp
// Here is an example of how to read/write 0's to Q (B, H, N, d) using the 4D accessors
// loop over Batch Size
for (int b = 0; b < B; b++) {
    // loop over Heads
    for (int h = 0; h < H; h++) {
        // loop over Sequence Length
        for (int i = 0; i < N; i++) {
            // loop over Embedding Dimensionality
            for (int j = 0; j < d; j++) {
                float val = fourDimRead(Q, b, h, i, j, H, N, d);
                val = 0.0;
                fourDimWrite(Q, b, h, i, j, H, N, d, val);
            }
        }
    }
}
```
Note that, just as 2D array accesses need num_cols (the size of the second dimension), the sizeX/sizeY/sizeZ arguments passed to the 4D accessor for an index (b, h, n, d) are H, N, and d. The batch size B is never needed to compute the offset.
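The 2D accessors used throughout the code below are not shown in this post. Assuming they follow the same convention (row-major layout with sizeX as the number of columns, which matches how they are called here), they would look roughly like this:

```cpp
// Presumed 2D accessors (signatures inferred from their call sites below):
// row-major layout where sizeX is the number of columns, so (x, y) maps to x*sizeX + y.
inline float twoDimRead(std::vector<float> &tensor, int &x, int &y, const int &sizeX) {
    return tensor[x * sizeX + y];
}

inline void twoDimWrite(std::vector<float> &tensor, int &x, int &y, const int &sizeX, float &val) {
    tensor[x * sizeX + y] = val;
}
```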
The attention layer to implement proceeds as follows:
Step 1
For each Batch, for each Head: iterate over Q and K and multiply Q by K^t, storing the result in QK^t. QK^t is preallocated and passed as an argument to myNaiveAttention. After indexing by batch and head, what remains are the (N, d) 2D matrices Q and K. Note also that K has shape (N, d) while the desired K^t has shape (d, N); the transpose can be skipped entirely by swapping the row/column indexing during the multiply, as the identity below shows.
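Concretely, each score entry can be computed without ever materializing K^t:

$$S_{ij} = (QK^{\top})_{ij} = \sum_{k=0}^{d-1} Q_{ik}\,K_{jk}, \qquad 0 \le i, j < N,$$

which reads row i of Q against row j of K, both of which are contiguous in memory.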
```cpp
torch::Tensor myNaiveAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
                               torch::Tensor QK_tTensor, int B, int H, int N, int d) {
    // Q, K, V are passed in with Shape: (B, H, N, d)
    // QK^t Intermediate Tensor has Shape (N, N)

    // Make O Tensor with Shape (B, H, N, d)
    at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);

    // Format O, Q, K, and V tensors into 4D vectors
    std::vector<float> O = formatTensor(OTensor);
    std::vector<float> Q = formatTensor(QTensor);
    std::vector<float> K = formatTensor(KTensor);
    std::vector<float> V = formatTensor(VTensor);

    // Format QK_t Tensor into a 2D vector.
    std::vector<float> QK_t = formatTensor(QK_tTensor);

    // -------- YOUR CODE HERE -------- //
    // calculate QK_t = Q(N, d) * K^t(d, N)
    // loop over Batch Size
    for (int b = 0; b < B; b++) {
        // loop over Heads
        for (int h = 0; h < H; h++) {
            // ijk loops over Sequence Length: S(i, j) = sum_k Q(i, k) * K(j, k)
            for (int i = 0; i < N; i++) {
                for (int j = 0; j < N; j++) {
                    float val = 0;
                    for (int k = 0; k < d; k++) {
                        val += fourDimRead(Q, b, h, i, k, H, N, d) *
                               fourDimRead(K, b, h, j, k, H, N, d);
                    }
                    twoDimWrite(QK_t, i, j, N, val);
                }
            }
            // softmax each row of QK_t in place
            for (int i = 0; i < N; i++) {
                float sum = 0.0;
                for (int j = 0; j < N; j++) {
                    sum += exp(twoDimRead(QK_t, i, j, N));
                }
                for (int j = 0; j < N; j++) {
                    float val = exp(twoDimRead(QK_t, i, j, N)) / sum;
                    twoDimWrite(QK_t, i, j, N, val);
                }
            }
            // ikj loops over Sequence Length: O(i, j) += QK_t(i, k) * V(k, j)
            for (int i = 0; i < N; i++) {
                for (int k = 0; k < N; k++) {
                    for (int j = 0; j < d; j++) {
                        float val = fourDimRead(O, b, h, i, j, H, N, d);
                        val += twoDimRead(QK_t, i, k, N) * fourDimRead(V, b, h, k, j, H, N, d);
                        fourDimWrite(O, b, h, i, j, H, N, d, val);
                    }
                }
            }
        }
    }

    // DO NOT EDIT THIS RETURN STATEMENT //
    // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
    return torch::from_blob(O.data(), {B, H, N, d},
                            torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
```
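One caveat worth noting (an aside, not something the assignment requires): the softmax above feeds raw scores straight into exp(), which can overflow for large score values. The standard numerically stable variant subtracts each row's maximum first, which leaves the result mathematically unchanged:

```cpp
// Numerically stable softmax over the rows of QK_t (sketch; a drop-in
// replacement for the softmax loops above). Subtracting the row max before
// exponentiating cannot change the result, since the factor exp(-rowMax)
// cancels between numerator and denominator. Needs <algorithm> and <limits>.
for (int i = 0; i < N; i++) {
    float rowMax = -std::numeric_limits<float>::infinity();
    for (int j = 0; j < N; j++) {
        rowMax = std::max(rowMax, twoDimRead(QK_t, i, j, N));
    }
    float sum = 0.0;
    for (int j = 0; j < N; j++) {
        sum += exp(twoDimRead(QK_t, i, j, N) - rowMax);
    }
    for (int j = 0; j < N; j++) {
        float val = exp(twoDimRead(QK_t, i, j, N) - rowMax) / sum;
        twoDimWrite(QK_t, i, j, N, val);
    }
}
```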
Part 2: Blocked Matrix Multiply and Unfused Softmax

Part 2 keeps the same three stages but tiles both matrix multiplies, so that the blocks of Q, K, QK^t, and V being worked on stay resident in cache while they are reused:

```cpp
torch::Tensor myUnfusedAttentionBlocked(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
                                        torch::Tensor QK_tTensor, int B, int H, int N, int d) {
    // Q, K, V are passed in with Shape: (B, H, N, d)
    // QK^t Intermediate Tensor has Shape (N, N)

    // Make O Tensor with Shape (B, H, N, d)
    at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);

    // Format O, Q, K, and V tensors into 4D vectors
    std::vector<float> O = formatTensor(OTensor);
    std::vector<float> Q = formatTensor(QTensor);
    std::vector<float> K = formatTensor(KTensor);
    std::vector<float> V = formatTensor(VTensor);

    // Format QK_t Tensor into a 2D vector.
    std::vector<float> QK_t = formatTensor(QK_tTensor);

    // -------- YOUR CODE HERE -------- //
    // calculate QK_t = Q(N, d) * K^t(d, N)
    // loop over Batch Size
    for (int b = 0; b < B; b++) {
        // loop over Heads
        for (int h = 0; h < H; h++) {
            // Blocked ijk loops: S(i, j) = sum_k Q(i, k) * K(j, k).
            // QK_t must be zeroed first because each ks tile accumulates into it.
            std::fill(QK_t.begin(), QK_t.end(), 0);
            int bsize_H = 64, bsize_W = 16;
            for (int is = 0; is < N; is += bsize_H) {
                for (int js = 0; js < N; js += bsize_W) {
                    for (int ks = 0; ks < d; ks += bsize_W) {
                        for (int i = is; i < std::min(is + bsize_H, N); i++) {
                            for (int j = js; j < std::min(js + bsize_W, N); j++) {
                                float val = twoDimRead(QK_t, i, j, N);
                                for (int k = ks; k < std::min(ks + bsize_W, d); k++) {
                                    val += fourDimRead(Q, b, h, i, k, H, N, d) *
                                           fourDimRead(K, b, h, j, k, H, N, d);
                                }
                                twoDimWrite(QK_t, i, j, N, val);
                            }
                        }
                    }
                }
            }
            // softmax each row of QK_t in place
            for (int i = 0; i < N; i++) {
                float sum = 0.0;
                for (int j = 0; j < N; j++) {
                    sum += exp(twoDimRead(QK_t, i, j, N));
                }
                for (int j = 0; j < N; j++) {
                    float val = exp(twoDimRead(QK_t, i, j, N)) / sum;
                    twoDimWrite(QK_t, i, j, N, val);
                }
            }
            // Blocked ikj loops: O(i, j) += QK_t(i, k) * V(k, j)
            for (int is = 0; is < N; is += bsize_H) {
                for (int js = 0; js < d; js += bsize_W) {
                    for (int ks = 0; ks < N; ks += bsize_W) {
                        for (int i = is; i < std::min(is + bsize_H, N); i++) {
                            for (int k = ks; k < std::min(ks + bsize_W, N); k++) {
                                for (int j = js; j < std::min(js + bsize_W, d); j++) {
                                    float val = fourDimRead(O, b, h, i, j, H, N, d);
                                    val += twoDimRead(QK_t, i, k, N) * fourDimRead(V, b, h, k, j, H, N, d);
                                    fourDimWrite(O, b, h, i, j, H, N, d, val);
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // DO NOT EDIT THIS RETURN STATEMENT //
    // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
    return torch::from_blob(O.data(), {B, H, N, d},
                            torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
```
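The tile sizes bsize_H = 64 and bsize_W = 16 were picked by experiment. A rough back-of-the-envelope check (a sketch with a hypothetical `tileBytes` helper; illustrative numbers only, actual cache sizes vary by machine) is that the tiles touched by the inner three loops should fit in L1 together:

```cpp
// Rough working-set estimate for one tile step of the blocked matmul above
// (illustrative only; the right tile sizes depend on the machine's caches).
#include <cstddef>
#include <cstdio>

size_t tileBytes(int bh, int bw) {
    size_t qTile  = (size_t)bh * bw * sizeof(float); // Q rows i in [is, is+bh), cols k in [ks, ks+bw)
    size_t kTile  = (size_t)bw * bw * sizeof(float); // K rows j in [js, js+bw), cols k in [ks, ks+bw)
    size_t qkTile = (size_t)bh * bw * sizeof(float); // QK_t block being accumulated
    return qTile + kTile + qkTile;
}

int main() {
    // ~9 KB for (64, 16): comfortably inside a typical 32 KB L1 data cache.
    std::printf("%zu bytes\n", tileBytes(64, 16));
    return 0;
}
```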
Again with N = 1000, the test results are:
```
REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  148.424ms
mem usage: 4512000 bytes

STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  99.278ms
mem usage: 4512000 bytes
```
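Part 3: Fused Attention

Parts 1 and 2 materialize the full N x N score matrix before the softmax and the multiply by V. Fused attention instead produces each output row end to end: for a fixed (b, h, i) it computes one row of scores, softmaxes that row in place, and immediately accumulates it against V, so only an N-length row lives in memory per thread instead of the N x N matrix. In math, each output row is

$$O_i = \mathrm{softmax}(Q_i K^{\top})\, V,$$

and since the rows are independent, the (b, h, i) loops can be parallelized with OpenMP.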
```cpp
torch::Tensor myFusedAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
                               torch::Tensor temp, int B, int H, int N, int d) {
    // Q, K, V are passed in with Shape: (B, H, N, d)

    // Make O Tensor with Shape (B, H, N, d)
    // and O Row Tensor with Shape (N)
    at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
    at::Tensor ORowTensor = at::zeros({N}, at::kFloat);

    // Format O, Q, K, and V tensors into 4D vectors
    std::vector<float> O = formatTensor(OTensor);
    std::vector<float> Q = formatTensor(QTensor);
    std::vector<float> K = formatTensor(KTensor);
    std::vector<float> V = formatTensor(VTensor);

    // Format ORow Tensor into a 1D vector
    // You can simply access this as ORow[i]
    std::vector<float> ORow = formatTensor(ORowTensor);

    // -------- YOUR CODE HERE -------- //
    // We give you a template of the first three loops for your convenience
    // loop over batch
    #pragma omp parallel for collapse(3)
    for (int b = 0; b < B; b++) {
        // loop over heads
        for (int h = 0; h < H; h++) {
            for (int i = 0; i < N; i++) {
                // ORow is moved inside so each OpenMP thread gets a local copy.
                at::Tensor ORowTensor = temp.index({torch::indexing::Slice(omp_get_thread_num(), torch::indexing::None)});
                std::vector<float> ORow = formatTensor(ORowTensor);
                std::fill(ORow.begin(), ORow.end(), 0);

                // One fused pass: compute row i of QK^t, exponentiating on the fly
                float sum = 0.0;
                for (int j = 0; j < N; j++) {
                    float val = 0;
                    for (int k = 0; k < d; k++) {
                        val += fourDimRead(Q, b, h, i, k, H, N, d) *
                               fourDimRead(K, b, h, j, k, H, N, d);
                    }
                    float tmp = exp(val);
                    ORow[j] = tmp;
                    sum += tmp;
                }
                // softmax: normalize the row
                for (int j = 0; j < N; j++)
                    ORow[j] = ORow[j] / sum;
                // calculate O(i, :) += ORow(k) * V(k, :)
                for (int k = 0; k < N; k++) {
                    for (int j = 0; j < d; j++) {
                        float val = fourDimRead(O, b, h, i, j, H, N, d);
                        val += ORow[k] * fourDimRead(V, b, h, k, j, H, N, d);
                        fourDimWrite(O, b, h, i, j, H, N, d, val);
                    }
                }
            }
        }
    }

    // DO NOT EDIT THIS RETURN STATEMENT //
    // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
    return torch::from_blob(O.data(), {B, H, N, d},
                            torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
```
Test results:
```
REFERENCE - FUSED ATTENTION statistics
cpu time:  36.62ms
mem usage: 544000 bytes

STUDENT - FUSED ATTENTION statistics
cpu time:  20.223ms
mem usage: 544000 bytes
```

Memory usage drops from 4512000 bytes to 544000 bytes: the fused version never allocates the N x N score matrix, only an N-length row per thread.
Part 4: Putting it all Together - Flash Attention