I recently read a research paper titled "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations," authored by Philippe Tillet, H. T. Kung, and David Cox, presented at the 2019 ACM SIGPLAN International Workshop on Machine Learning and Programming Languages (MAPL), and thought I'd share my notes and insights.
The lack of efficient compute kernels is a real challenge in Deep Learning (DL). It applies to almost every deployment target, but it is especially critical for mobile, robotics, IoT and on-glasses deployments.
GPUs have revolutionized DL, but existing vendor libraries cover only a limited set of operations, so novel primitives often lack efficient implementations. Triton addresses this with an approach built around tiles, i.e. statically shaped multi-dimensional sub-arrays, which the compiler maps efficiently onto the GPU. It includes:
Triton-C, a C-like language for expressing tensor programs in terms of tiles
Triton-IR, an LLVM-based intermediate representation (IR)
Triton-JIT, a Just-In-Time (JIT) compiler
(all aimed at optimizing neural network computations)
Triton-C is tailored for neural network tasks, with a syntax and semantics designed to simplify tensor programming and optimize GPU usage. Triton-IR, based on LLVM, plays a crucial role in program analysis, allowing for sophisticated optimization of tile-level data flow and control flow. The Triton-JIT compiler is key in converting Triton-IR programs into efficient machine code, using both machine-independent and machine-dependent optimization techniques.
I had a look at some examples (first, ReLU!) and took some notes (below).
define kernel void @relu(float* %A, i32 %M, i32 %N) {
prologue:
  %rm = call i32<8> get_global_range(0);
  %rn = call i32<8> get_global_range(1);
  ; broadcast shapes
  %1 = reshape i32<8, 8> %M;
  %M0 = broadcast i32<8, 8> %1;
  %2 = reshape i32<8, 8> %N;
  %N0 = broadcast i32<8, 8> %2;
  ; broadcast global ranges
  %3 = reshape i32<8, 1> %rm;
  %rm_bc = broadcast i32<8, 8> %3;
  %4 = reshape i32<1, 8> %rn;
  %rn_bc = broadcast i32<8, 8> %4;
  ; compute mask
  %pm = icmp slt %rm_bc, %M0;
  %pn = icmp slt %rn_bc, %N0;
  %msk = and %pm, %pn;
  ; compute pointer
  %A0 = splat float*<8, 8> %A;
  %5 = getelementptr %A0, %rm_bc;
  %6 = mul %rn_bc, %M0;
  %pa = getelementptr %5, %6;
  ; compute result
  %a = load %pa;
  %_0 = splat float<8, 8> 0;
  %result = max %float %a, %_0;
  ; write back
  store fp32<8, 8> %pa, %result
}
Function Definition: define kernel void @relu(float* %A, i32 %M, i32 %N) defines a kernel function named relu that takes three parameters:
%A: A pointer to a floating-point array.
%M and %N: 32-bit integers giving the number of rows and columns of the matrix.
Prologue and Global Range Acquisition:
%rm = call i32 <8> get_global_range(0);
%rn = call i32 <8> get_global_range(1);
Each call returns an 8-element tile of global indices (i32<8>) along one axis: %rm holds the 8 row indices and %rn the 8 column indices that this kernel instance is responsible for, so together they identify an 8x8 block of the matrix.
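Broadcast Shapes: The reshape and broadcast operations turn the scalars %M and %N into 8x8 tiles (%M0 and %N0). The index tiles get the same treatment: %rm is reshaped to <8, 1> and broadcast across columns, while %rn is reshaped to <1, 8> and broadcast across rows. The IR spells broadcasting out explicitly, so both operands of each comparison below have the same 8x8 shape and can be compared element by element.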
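Compute Mask: The icmp slt instructions compare %rm_bc and %rn_bc with %M0 and %N0 to create masks (%pm and %pn). This checks which of the 64 (row, column) positions of the tile fall inside the M x N matrix, which matters for edge tiles when M or N is not a multiple of 8.
%msk is the logical AND of these masks and marks the valid positions. (In the listing as reproduced here, %msk is computed but not actually passed to the load or the store.)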
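A minimal plain-Python sketch of the broadcast and mask steps, with 8x8 tiles as nested lists. The helpers column_tile, row_tile, splat and tile_mask are hypothetical stand-ins for the IR's reshape/broadcast, splat and icmp/and instructions, and I assume instance (pid_m, pid_n) owns rows pid_m*8 to pid_m*8+7 and columns pid_n*8 to pid_n*8+7.

```python
TILE = 8  # matches the <8> and <8, 8> tile shapes in the listing


def column_tile(v):
    # reshape <8> -> <8, 1>, then broadcast to <8, 8>: element [i][j] = v[i]
    return [[v[i]] * TILE for i in range(TILE)]


def row_tile(v):
    # reshape <8> -> <1, 8>, then broadcast to <8, 8>: element [i][j] = v[j]
    return [list(v) for _ in range(TILE)]


def splat(x):
    # scalar such as %M replicated into every element of an 8x8 tile
    return [[x] * TILE for _ in range(TILE)]


def tile_mask(rm, rn, M, N):
    rm_bc, rn_bc = column_tile(rm), row_tile(rn)
    M0, N0 = splat(M), splat(N)
    # %pm = icmp slt %rm_bc, %M0 ; %pn = icmp slt %rn_bc, %N0 ; %msk = and %pm, %pn
    return [[rm_bc[i][j] < M0[i][j] and rn_bc[i][j] < N0[i][j]
             for j in range(TILE)] for i in range(TILE)]


# Edge tile of a 10 x 12 matrix: instance (1, 1) owns rows 8..15 and columns 8..15
rm = [1 * TILE + i for i in range(TILE)]
rn = [1 * TILE + j for j in range(TILE)]
msk = tile_mask(rm, rn, M=10, N=12)
print(sum(v for row in msk for v in row))  # 8
```

Only rows 8 and 9 and columns 8 to 11 are in bounds, so 2 x 4 = 8 of the 64 positions are active.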
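Compute Pointer: These instructions compute the address of every element of the tile. splat replicates the scalar pointer %A into an 8x8 tile of pointers (%A0), the first getelementptr adds the row indices, and the second adds the column indices times M. The element at (row, column) therefore lives at A + row + column*M, i.e. the matrix is stored in column-major order.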
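The pointer arithmetic boils down to an element offset of rm + rn*M. A small sketch to check that this is column-major addressing; tile_offsets is a hypothetical helper, and flat Python list indices stand in for pointers.

```python
def tile_offsets(rm, rn, M):
    # %5  = getelementptr %A0, %rm_bc  -> A + rm
    # %6  = mul %rn_bc, %M0            -> rn * M
    # %pa = getelementptr %5, %6       -> A + rm + rn * M
    return [[r + c * M for c in rn] for r in rm]


M = 3
# Column-major storage of the 3 x 2 matrix [[1, 2], [3, 4], [5, 6]]:
# column 0 is stored first, then column 1
A = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0]
offs = tile_offsets(rm=[0, 1, 2], rn=[0, 1], M=M)
print([[A[o] for o in row] for row in offs])  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
```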
Compute Result (and storage):
%a = load %pa; loads the 8x8 tile of values from the computed addresses.
%_0 = splat float<8, 8> 0; creates an 8x8 tile of zeros.
%result = max %float %a, %_0; applies the ReLU element-wise, i.e. max(0, x) for each element x of %a.
store fp32<8, 8> %pa, %result writes the result back to the same addresses, so the ReLU is applied in place.
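To see the whole kernel end to end, here is a plain-Python emulation of the grid of @relu instances on a flat, column-major buffer. It is a sketch under my own assumptions: one instance per 8x8 tile, the grid rounded up to cover edge tiles, and the mask applied explicitly to the load and store (which the listing as reproduced does not do). relu_kernel and relu are hypothetical names.

```python
import math

TILE = 8


def relu_kernel(A, M, N, pid_m, pid_n):
    """One instance of @relu on a flat, column-major M x N buffer (in place)."""
    rm = [pid_m * TILE + i for i in range(TILE)]  # %rm
    rn = [pid_n * TILE + j for j in range(TILE)]  # %rn
    for r in rm:
        for c in rn:
            if r < M and c < N:            # %msk: skip out-of-bounds positions
                off = r + c * M            # %pa = A + rm + rn * M
                A[off] = max(A[off], 0.0)  # load, max with %_0, store back


def relu(A, M, N):
    # Launch one instance per 8x8 tile, rounding up so partial tiles are covered
    for pid_m in range(math.ceil(M / TILE)):
        for pid_n in range(math.ceil(N / TILE)):
            relu_kernel(A, M, N, pid_m, pid_n)
    return A


M, N = 10, 3
A = [float(i - 15) for i in range(M * N)]  # values -15.0 .. 14.0
relu(A, M, N)
print(min(A), max(A))  # 0.0 14.0
```

With M = 10 the second row of instances only has 2 valid rows, which is exactly the case the mask exists for.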
I looked at the global range acquisition in more detail. These two lines (below) determine which part of the data each kernel instance works on. Let's break this down:
%rm = call i32 <8> get_global_range(0);
%rn = call i32 <8> get_global_range(1);
Global Range: In GPU computing, a kernel is executed over a grid of threads or work-items; in Triton's model it is a grid of kernel instances, each of which processes whole tiles. In OpenCL terminology, the global range is the full index space covered by that grid.
get_global_range Function: Despite the name, this call does not return the size of the grid. get_global_range(axis) returns a tile of global indices along the given axis, namely the indices assigned to the current kernel instance. Here get_global_range(0) yields the 8 row indices and get_global_range(1) the 8 column indices of this instance's 8x8 block.
i32 <8>: This is a tile type: an 8-element tile of 32-bit integers. Triton-IR extends LLVM-IR with multi-dimensional tile types (here <8> and <8, 8>), and how a tile's elements are spread over hardware threads and SIMD lanes is left to the Triton-JIT compiler rather than fixed by the program.
%rm and %rn: These hold the row and column indices of the current tile. The rest of the kernel uses them to build the bounds mask and the element pointers, so %rm[i] and %rn[j] identify the matrix element handled at position (i, j) of the tile.
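A minimal sketch of what each instance receives, assuming each instance gets a contiguous block of indices (the usual tiling). num_instances and global_range are hypothetical helpers for illustration, not Triton APIs.

```python
def num_instances(size, tile):
    # Grid size along one axis: round up so a partial last tile is still covered
    return -(-size // tile)


def global_range(pid, tile):
    # Tile of `tile` consecutive global indices owned by instance `pid`
    return [pid * tile + i for i in range(tile)]


M, TILE = 20, 8
ranges = [global_range(p, TILE) for p in range(num_instances(M, TILE))]
print(len(ranges))  # 3
print(ranges[-1])   # [16, 17, 18, 19, 20, 21, 22, 23]
```

Indices 20 to 23 of the last tile lie outside the 20 valid rows, which is why the kernel needs a bounds mask.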
Interesting! Then I looked at their experiments, especially their shift-convolution implementation (it made me think about how these abstractions could help us come up with new algorithms):
const tunable int TM = {16, 32, 64, 128};
const tunable int TN = {16, 32, 64, 128};
const tunable int TK = {8};
// outside the kernel: fill the per-channel shift table in constant memory
__constant__ int* delta = alloc_const int[512];
for (int c = 0; c < C; c++)
  delta[c] = c*H*W + shift_h[c]*W + shift_w[c];

void shift_conv(restrict read_only float *a,
                restrict read_only float *b, float *c,
                int M, int N, int K) {
  // prologue: tiles of global indices and initial pointers
  int rxa[TM] = get_global_range[TM](0);
  int ryb[TN] = get_global_range[TN](1);
  int rka[TK] = 0 ... TK;
  int rkb[TK] = 0 ... TK;
  float C[TM, TN] = 0;
  float* pxa[TM, TK] = a + rxa[:, newaxis];
  float* pb[TN, TK] = b + ryb[:, newaxis] + rkb*N;
  __constant__ int* pd[TK] = delta + rka;
  // main loop: accumulate over K in steps of TK
  for (int k = K; k > 0; k = k - TK) {
    int delta[TK] = *pd;
    float* pa[TM, TK] = pxa + delta[newaxis, :];
    float a[TM, TK] = *pa;
    float b[TN, TK] = *pb;
    C = dot(a, trans(b), C);
    pb = pb + TK*N;
    pd = pd + TK;
  }
  // epilogue: predicated store of the TM x TN result tile
  int rxc[TM] = get_global_range[TM](0);
  int ryc[TN] = get_global_range[TN](1);
  float* pc[TM, TN] = c + rxc[:, newaxis] + ryc*M;
  bool checkc0[TM] = rxc < M;
  bool checkc1[TN] = ryc < N;
  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1;
  @checkc *pc = C;
}
Some notes:
Tunable Parameters (TM, TN, TK): TM, TN and TK are declared as tunable, each with a set of candidate values, so the tile sizes can be auto-tuned. This matters because the best tile sizes differ across GPU architectures and problem sizes.
Constant Memory Allocation (delta): delta is allocated in constant memory, which is a read-only memory space on GPUs. This is used for storing shift indices, an essential part of shifted convolutions.
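Shift Offset Table: The setup loop computes delta[c] = c*H*W + shift_h[c]*W + shift_w[c] for each channel c. Assuming the input is stored as C planes of H x W values, this is the start of channel c's plane plus that channel's vertical and horizontal shift, so a shifted pixel is reached by adding one precomputed offset to an unshifted address.
Shift Convolution Function (shift_conv): This function performs the shifted convolution, which reduces to a matrix product over shifted inputs. It takes pointers to the input a, the weights b and the output c, and the dimensions M, N and K; as far as I can tell, M is the number of pixels, N the number of output channels and K the number of input channels.
Global Range Indexing: rxa and ryb are tiles of global row and column indices (TM and TN of them) for the current kernel instance, which is how the work is distributed across the GPU. rka and rkb are not global ranges: 0 ... TK is the local range 0, 1, ..., TK-1, used to step through K.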
Tile-Based Computation: The algorithm computes on tiles of the input matrices. This tiled approach is essential for efficient use of GPU resources.
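Convolution Computation: The loop runs over K in steps of TK. Each iteration loads a TM x TK tile of a at shifted addresses (pxa + delta) and a TN x TK tile of b, and accumulates dot(a, trans(b)) into the TM x TN accumulator C. After the loop, C is written to c with a predicated store (@checkc), so out-of-bounds rows and columns are skipped.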
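To check my understanding of how delta turns the shift into plain pointer arithmetic, I wrote a plain-Python emulation of the kernel's loop structure. The assumptions are mine, not the paper's: a single image stored as K channel planes of H x W values, weights laid out as b[n + k*N], shifts in {-1, 0, 1}, K divisible by TK (the Triton-C loop has no remainder handling), and a bounds guard on reads that the listing does not have. Because flat offsets wrap across rows and planes at the image border, I only compare interior pixels against a direct shift-then-1x1-convolution reference. All function names are hypothetical.

```python
def build_delta(K, H, W, shift_h, shift_w):
    # Same formula as the listing's setup loop: start of channel k's plane
    # plus its vertical (whole rows of W) and horizontal shift
    return [k * H * W + shift_h[k] * W + shift_w[k] for k in range(K)]


def load(buf, i):
    # Bounds guard absent from the listing; without it a negative offset
    # would silently wrap around in Python
    return buf[i] if 0 <= i < len(buf) else 0.0


def shift_conv(a, b, M, N, K, delta, TM=4, TN=2, TK=2):
    assert K % TK == 0  # like the Triton-C loop, assume K is a multiple of TK
    c = [0.0] * (M * N)
    for pid_m in range(-(-M // TM)):
        for pid_n in range(-(-N // TN)):
            rx = [pid_m * TM + i for i in range(TM)]  # rxa / rxc
            ry = [pid_n * TN + j for j in range(TN)]  # ryb / ryc
            C = [[0.0] * TN for _ in range(TM)]       # float C[TM, TN] = 0
            for k0 in range(0, K, TK):                # one pass of the k loop
                for i, m in enumerate(rx):
                    for j, n in enumerate(ry):
                        for k in range(k0, k0 + TK):
                            # a read at pxa + delta[k], b read at pb
                            C[i][j] += load(a, m + delta[k]) * load(b, n + k * N)
            for i, m in enumerate(rx):
                for j, n in enumerate(ry):
                    if m < M and n < N:               # @checkc *pc = C
                        c[m + n * M] = C[i][j]
    return c


# Toy problem: K = 4 input channels of 4 x 5 pixels, N = 3 output channels
H, W, K, N = 4, 5, 4, 3
M = H * W
shift_h = [-1, 0, 1, 0]
shift_w = [0, 1, 0, -1]
img = [[[float(100 * (k + 1) + h * W + w) for w in range(W)]
        for h in range(H)] for k in range(K)]
weight = [[float(k - n) for n in range(N)] for k in range(K)]
a = [img[k][h][w] for k in range(K) for h in range(H) for w in range(W)]  # a[k*M + h*W + w]
b = [weight[k][n] for k in range(K) for n in range(N)]                    # b[n + k*N]

delta = build_delta(K, H, W, shift_h, shift_w)
c = shift_conv(a, b, M, N, K, delta)

# Reference: shift each channel, then apply the 1x1 convolution (interior pixels only)
mismatches = 0
for h in range(1, H - 1):
    for w in range(1, W - 1):
        for n in range(N):
            ref = sum(img[k][h + shift_h[k]][w + shift_w[k]] * weight[k][n] for k in range(K))
            if c[h * W + w + n * M] != ref:
                mismatches += 1
print(delta, mismatches)  # [-5, 21, 45, 59] 0
```

The point the sketch makes is that the shift costs nothing inside the loop: it is folded into the offset table once, and the inner computation is an ordinary tiled matrix product.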
I hope you enjoyed it!