Automatic Kernel Generation in PlaidML

May 19, 2018 | By: Jeremy Bruestle

Historically, an engineer-intensive aspect of developing a machine learning backend was producing efficient device kernels from mathematical expressions of tensor operations. Practitioners wishing to use cutting-edge research kernels in their networks had to either wait for this development cycle to complete or rely on CPU implementations that were an order of magnitude slower. Now, PlaidML, NNVM/TVM, and Tensor Comprehensions generate efficient kernels automatically from tensor expressions, bypassing this costly bottleneck. Below we’ll give an overview of how PlaidML transforms an operation requested by a graph in a frontend (such as Keras or ONNX) into an optimized OpenCL kernel.

We’ll also take a detailed look at the unique ways that PlaidML automatically performs several key aspects of generating efficiently parallelized kernels, including streamlining cache performance and minimizing edge-case conditionals.


Several core steps of our kernel generation algorithm are aimed at optimizing cache usage. We flatten Tile code, converting multidimensional tensor indices into pointer offsets. We split large tensors into tiles that can be cached in their entirety, and pick tile sizes that optimize these tiles for the specific hardware being used. We lay out these tiles within the cache to minimize bank conflicts and keep the multiply-accumulate units supplied with data. We thread the execution of the kernel to take advantage of the GPU’s SIMD processors. And finally, in codegen, we construct a sem-tree representation of the kernel we will generate.

Throughout this article we will follow a real example: a batched 3×3 convolution with “same” padding followed by a ReLU. This operation is simple enough to provide a clear, concrete example, yet complex enough to demonstrate both how PlaidML addresses the core challenges of kernel generation and how it handles common complications, including padding and the fusion of multiple operations into a single kernel. In addition, although the operation is straightforward, its performance is frequently critical to overall network performance.

A researcher designing a novel operation in PlaidML would write Tile code for that operation, so we’ll start there for our example. (For an introduction to Tile and its syntax, see our Tile tutorial.) The Tile function for such a convolution and ReLU looks like this:

function (D[N, X, Y, CI], K[I, J, CO, CI]) -> (R) {
  O[n, x, y, co : N, X, Y, CO] = +(D[n, x+i-1, y+j-1, ci] * K[i, j, co, ci]);
  R = (O > 0 ? O : 0);
}

From this code, PlaidML can automatically generate efficient kernels; all that remains for the researcher to do is make this code accessible to their preferred frontend, as described in our Operations tutorial.

Now let’s walk through how PlaidML generates a kernel from this Tile code in detail. We will focus primarily on code generation for GPUs, and we simplify some of the steps for clarity.

Preparing for generation

Before we reach the kernel generation phase of compilation, a number of transformations have already been done. These won’t be covered in detail here. Some have already been discussed in prior posts, and some we hope to write additional posts on soon. These early steps include:

In addition to the Tile code for our example, we need concrete values for all the relevant sizes. One relatively unusual feature of PlaidML is that it generates custom kernels based on the exact shapes of all of its input tensors, which lets the optimizer tailor each kernel for maximum performance. For example, a batch size 1 convolution might need a completely different tiling and loop ordering than a batch size 32 convolution to be efficient. And clearly a 3×3 convolution requires different tradeoffs than a 7×7 convolution.

So, given a batch size of 32, 64 channels in and out, and a 224×224 image, by the time we reach the kernel generation phase, we have determined:

  1. That we will fuse the ReLU into the convolution kernel (more on this in a bit)
  2. The ranges of the various indexes are:
    Index     Range
       ci        64
       co        64
        i         3
        j         3
        n        32
        x       224
        y       224
    
  3. There are additional constraints on the indexes not implied by the index ranges:
     0 <= x+i-1 < 224
     0 <= y+j-1 < 224
    

    Note: these constraints produce the zero padding needed by “same” convolutions

Flattening

The next step in kernel generation is flattening. Flattening transforms the Tile code into a simple table. Because we know the full memory layout of each tensor, we can convert each tensor access into a memory location represented as an affine polynomial over the indexes. As a simplified example, imagine a 16×16 matrix M, represented in memory as a packed C-style buffer. If we wanted to access M[i, i+j], we could just access the memory located 16*i + 1*(i + j) == 17*i + j elements into M. If each tensor dimension’s access is an affine polynomial of the index variables, then the entire memory access is also just an affine polynomial. Thus flattened, the entire convolution is just a set of numbers: for each index variable and each tensor, what is the index multiplier (the stride, in NumPy terms, though measured in elements rather than bytes)? In our example, again assuming standard C-style layout, we get (from our debug logging):

Contraction kernel_ID_0:
O[n, x, y, co : N, X, Y, CO] = +(D[n, -1 + i + x, -1 + j + y, ci] * K[i, j, co, ci])
             Range         O         D         K
      ci        64         0         1         1
      co        64         1         0        64
       i         3         0     14336     12288
       j         3         0        64      4096
       n        32   3211264   3211264         0
       x       224     14336     14336         0
       y       224        64        64         0
     off                   0    -14400         0

Here off represents the fixed offset (caused in this case by the -1 constant used to center the convolutions).

Looking at this table in detail, we can see why the listed strides correspond to the specified Tile function. Take for example the value for index i and tensor D. Index i is used in the second dimension of D, and each increment of i advances that dimension by 1. Since the lower two dimensions have sizes of 224 and 64 respectively, and since i is not used elsewhere in D, each increment of i moves our pointer in D by 1×224×64 = 14336 entries.
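The flattening step can be sketched in a few lines of Python. This is our own illustration rather than PlaidML's implementation: `row_major_strides` and `flatten_access` are hypothetical helpers, each tensor is assumed to be a packed C-style buffer, and each dimension's affine access is written as a dictionary from index name to coefficient, with the key `"1"` holding the constant term. Run on D, it reproduces the D column and offset of the table above.

```python
from math import prod

def row_major_strides(shape):
    """Element strides of a packed C-style (row-major) buffer."""
    return [prod(shape[d + 1:]) for d in range(len(shape))]

def flatten_access(shape, access):
    """Flatten a tensor access into per-index multipliers plus a constant offset.

    `access` holds one affine expression per tensor dimension, written as a dict
    {index_name: coefficient}; the special key "1" holds the constant term.
    """
    multipliers, offset = {}, 0
    for stride, expr in zip(row_major_strides(shape), access):
        for name, coeff in expr.items():
            if name == "1":
                offset += coeff * stride
            else:
                multipliers[name] = multipliers.get(name, 0) + coeff * stride
    return multipliers, offset

# D[n, x+i-1, y+j-1, ci] with D of shape [N, X, Y, CI] = [32, 224, 224, 64]
D_mult, D_off = flatten_access(
    [32, 224, 224, 64],
    [{"n": 1}, {"x": 1, "i": 1, "1": -1}, {"y": 1, "j": 1, "1": -1}, {"ci": 1}],
)
print(D_mult)  # {'n': 3211264, 'x': 14336, 'i': 14336, 'y': 64, 'j': 64, 'ci': 1}
print(D_off)   # -14400
```

The same call on K (shape `[3, 3, 64, 64]`, access `[{"i": 1}, {"j": 1}, {"co": 1}, {"ci": 1}]`) gives the K column, and the 16×16 example above gives multipliers 17 for i and 1 for j.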

In addition, we flatten the constraints:

Constraint: (0,0,-1,0,0,-1,0) <= -1
Constraint: (0,0,1,0,0,1,0) <= 224
Constraint: (0,0,0,-1,0,0,-1) <= -1
Constraint: (0,0,0,1,0,0,1) <= 224

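Here, the numbers on the left-hand side are multipliers for ci, co, i, etc., in the same order as they appear in the table above. At this level, all constraints are one-sided: each two-sided range such as 0 <= x+i-1 < 224 becomes two constraints of the form (multipliers · indexes) <= bound, in this case -i - x <= -1 and i + x <= 224.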
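To make the one-sided form concrete, here is a small Python sketch that converts a two-sided bound on an affine index expression into the two one-sided constraints listed above, and checks an index assignment against them. The helpers `one_sided` and `satisfies` are hypothetical, and we assume integer indexes, so that a strict `<` bound can be tightened to `<=` by subtracting 1.

```python
INDEX_ORDER = ["ci", "co", "i", "j", "n", "x", "y"]   # column order of the table above

def one_sided(coeffs, const, lo, hi):
    """Rewrite lo <= e + const < hi, with e = sum(coeffs * idx), as two (vec, bound) pairs meaning vec . idx <= bound."""
    vec = tuple(coeffs.get(name, 0) for name in INDEX_ORDER)
    neg = tuple(-c for c in vec)
    return [(neg, const - lo),          # lo <= e + const      <=>  -e <= const - lo
            (vec, hi - 1 - const)]      # e + const <= hi - 1  <=>   e <= hi - 1 - const

def satisfies(constraints, idx):
    """True if the index assignment `idx` (a dict) meets every one-sided constraint."""
    values = [idx[name] for name in INDEX_ORDER]
    return all(sum(v * a for v, a in zip(vec, values)) <= bound for vec, bound in constraints)

# x+i-1 and y+j-1 must both lie in [0, 224)
cons = one_sided({"x": 1, "i": 1}, -1, 0, 224) + one_sided({"y": 1, "j": 1}, -1, 0, 224)
for vec, bound in cons:
    print("Constraint:", vec, "<=", bound)
```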

Finally, we have additional post-convolution operations which take place once for each output generated. Those are given in the debug output as:

Output operations:
   _T1 = cmp_gt(O, 0)
   R = cond(_T1, O, 0)

These three tables fully specify the convolution/ReLU. That is: the tables are sufficient to generate a correct kernel with no additional information. At this point, we could write a (terribly inefficient) version of the kernel as:

set all values in O to zero.

for ci in [0, 64):
  for co in [0, 64):
    for i in [0, 3):
      for j in [0, 3):
        for n in [0, 32):
          for x in [0, 224):
            for y in [0, 224):
              if ((-1*i + -1*x) <= -1 and
                  (1*i + 1*x) <= 224 and
                  (-1*j + -1*y) <= -1 and
                  (1*j + 1*y) <= 224):
                O[1*co + 3211264*n + 14336*x + 64*y] +=
                  D[-14400 + 1*ci + 14336*i + 64*j + 3211264*n + 14336*x + 64*y] *
                  K[1*ci + 64*co + 12288*i + 4096*j];

for idx in [0, len(O)):
  _T1[idx] = cmp_gt(O[idx], 0)
  R[idx] = cond(_T1[idx], O[idx], 0)

Again, pay close attention to the relation between the tables above and the code. While the above pseudocode is basically correct, it is not efficient. There are three major causes of this inefficiency:

  1. No parallelism
  2. Poor cache performance
  3. Conditionals in the innermost loop

We will resolve #1 by using the GPU to parallelize the work across threads and work groups, and #3 during code generation by emitting two versions of the code: a fast path that runs most of the time and ignores constraints, and a slow path that runs rarely but checks them. Let’s focus first on #2: determining how to block work for good cache behavior turns out to be one of the most fundamental and difficult problems in automatic kernel generation, and it is the first decision we need to make, since it informs the rest of our choices.
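Before moving on, the naive pseudocode can be made runnable at toy sizes as a check on the flattened tables. The Python sketch below is our own illustration with hypothetical function names: it keeps the same loop order, one-sided constraints, and flattened strides (including the constant offset from the two -1's, recomputed for the small sizes), and compares the result against a direct multidimensional convolution with explicit zero padding.

```python
import random
from itertools import product

def naive_conv_relu(D, K, N, X, Y, CI, CO):
    """Flattened loop nest: same-padded 3x3 conv + ReLU over packed C-style buffers.

    Layouts: D is [N, X, Y, CI], K is [3, 3, CO, CI], output is [N, X, Y, CO].
    """
    sD_n, sD_x, sD_y = X * Y * CI, Y * CI, CI       # D strides (ci stride is 1)
    sK_i, sK_j, sK_co = 3 * CO * CI, CO * CI, CI    # K strides (ci stride is 1)
    sO_n, sO_x, sO_y = X * Y * CO, Y * CO, CO       # O strides (co stride is 1)
    off = -sD_x - sD_y                              # constant offset from the two -1's
    O = [0.0] * (N * X * Y * CO)
    for ci, co, i, j, n, x, y in product(range(CI), range(CO), range(3), range(3),
                                         range(N), range(X), range(Y)):
        # Flattened one-sided constraints (X and Y play the role of 224)
        if -i - x <= -1 and i + x <= X and -j - y <= -1 and j + y <= Y:
            O[co + sO_n * n + sO_x * x + sO_y * y] += (
                D[off + ci + sD_x * i + sD_y * j + sD_n * n + sD_x * x + sD_y * y]
                * K[ci + sK_co * co + sK_i * i + sK_j * j])
    return [o if o > 0 else 0.0 for o in O]        # ReLU output operations

def reference_conv_relu(D, K, N, X, Y, CI, CO):
    """Direct multidimensional version with explicit zero padding."""
    def d(n, x, y, ci):
        if 0 <= x < X and 0 <= y < Y:
            return D[((n * X + x) * Y + y) * CI + ci]
        return 0.0
    out = []
    for n, x, y, co in product(range(N), range(X), range(Y), range(CO)):
        acc = sum(d(n, x + i - 1, y + j - 1, ci) * K[((i * 3 + j) * CO + co) * CI + ci]
                  for i in range(3) for j in range(3) for ci in range(CI))
        out.append(max(acc, 0.0))
    return out

rng = random.Random(0)
N, X, Y, CI, CO = 2, 5, 4, 3, 2
D = [float(rng.randint(-3, 3)) for _ in range(N * X * Y * CI)]   # integer-valued, so sums are exact
K = [float(rng.randint(-3, 3)) for _ in range(9 * CO * CI)]
print(naive_conv_relu(D, K, N, X, Y, CI, CO) == reference_conv_relu(D, K, N, X, Y, CI, CO))  # True
```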

Picking tile sizes

Memory is king when it comes to making a GPU run at maximal efficiency. Because the GPU has so many available FLOPS, and DRAM is only so fast, it’s very easy for computations to be limited not by multiply-accumulate throughput but by DRAM access. Additionally, DRAM on the GPU is only fast when used in a very specific way: large contiguous accesses. Finally, even with all the data loaded into the GPU equivalent of L1 cache (shared memory in CUDA terminology, or local memory in OpenCL terminology), it’s imperative that the SIMD lanes performing multiply-accumulates each read from a different bank of that memory, or they will stall.

All of these hinge on grouping operations so that each value read from global memory is used multiple times before it leaves cache. To approach this problem systematically, we consider the space of all valid indexes, that is, all 64×64×3×3×32×224×224 = 59,190,018,048 multiply-accumulates that will be performed during the execution of the convolution. Imagine we group the operations into tiles, each operating over a sub-range of each index. For a given tile size we can determine:

PlaidML considers a large number of tile sizes, and for each tile size, evaluates these values. It then uses a model of the specific hardware it’s targeting to compute a cost for that tile size. The hardware model includes information such as:

PlaidML then chooses the best tile size. If the hardware model were perfectly accurate, and all tiles were created equal in terms of downstream code generation, this would always yield the best-performing kernel. In practice, our chosen tile size is usually close to the best tile size, but we provide a profile-guided optimization setting that lets PlaidML try out its top guesses and measure their performance. When enabled, this tile scanning typically improves performance by 10-40%, but by default it is disabled since it increases compile time significantly.

Getting back to our example, here is a small bit of the output of the tile size optimizer from the super-verbose logger:

Computing cost for tile size: [8, 32, 1, 3, 16, 8, 2]
  Compute score: to=236760072192 wg=12544 il=24 sm=20656 or=32768 mr=19456 mw=32768 op=256 rp=0 tu=256
  over memory
Computing cost for tile size: [8, 32, 1, 3, 16, 4, 4]
  Compute score: to=236760072192 wg=12544 il=24 sm=16272 or=32768 mr=15360 mw=32768 op=256 rp=0 tu=256
  over regs
Computing cost for tile size: [16, 32, 1, 1, 16, 2, 2]
  Compute score: to=236760072192 wg=50176 il=36 sm=6488 or=8192 mr=6144 mw=8192 op=256 rp=0 tu=256
  flops_per_byte=20.5714 occupancy=20
  roof_ratio=1 occ_ratio=1 thread_ratio=1 score=1
Computing cost for tile size: [8, 32, 2, 1, 16, 2, 2]
  Compute score: to=236760072192 wg=50176 il=48 sm=5264 or=8192 mr=5120 mw=8192 op=256 rp=0 tu=256
  flops_per_byte=18.5806 occupancy=20
  roof_ratio=0.929032 occ_ratio=1 thread_ratio=1 score=0.929032

The chosen tile size for our example is: [ci: 8, co: 32, i: 2, j: 3, n: 16, x: 2, y: 2]

This corresponds to grouping ci into groups of 8, co into groups of 32, and so on. The chosen tile size can differ substantially across hardware, or even for different input tensor sizes on the same hardware.
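Several of the logged quantities can be recomputed from the tile size alone. This Python sketch is our own reading of a few fields in the log above, not PlaidML's cost model: we assume `wg` is the number of work groups, `il` the number of outer-loop iterations per work group, `mr` the bytes of input read per iteration, and `or` the bytes of output per work group, all with 4-byte floats. Under that reading it reproduces those fields for every entry shown; it does not attempt `to`, `sm`, or the scoring itself.

```python
from math import ceil, prod

RANGES = {"ci": 64, "co": 64, "i": 3, "j": 3, "n": 32, "x": 224, "y": 224}
ORDER = ["ci", "co", "i", "j", "n", "x", "y"]   # order used in the log
OUTPUT_IDX = ["co", "n", "x", "y"]              # indexes of the output O
REDUCE_IDX = ["ci", "i", "j"]                   # indexes summed over
ELEM_BYTES = 4                                  # float32 (assumption)

def tile_stats(tile):
    """Hypothetical helper: per-tile quantities under our reading of the log fields."""
    t = dict(zip(ORDER, tile))
    work_groups = prod(ceil(RANGES[k] / t[k]) for k in OUTPUT_IDX)
    outer_iters = prod(ceil(RANGES[k] / t[k]) for k in REDUCE_IDX)
    # D elements per iteration: x overlaps with i, and y overlaps with j
    d_elems = t["n"] * (t["x"] + t["i"] - 1) * (t["y"] + t["j"] - 1) * t["ci"]
    k_elems = t["i"] * t["j"] * t["co"] * t["ci"]
    out_elems = prod(t[k] for k in OUTPUT_IDX)
    return {"wg": work_groups, "il": outer_iters,
            "mr": (d_elems + k_elems) * ELEM_BYTES,
            "or": out_elems * ELEM_BYTES}

print(tile_stats([16, 32, 1, 1, 16, 2, 2]))  # {'wg': 50176, 'il': 36, 'mr': 6144, 'or': 8192}
print(tile_stats([8, 32, 2, 3, 16, 2, 2]))   # {'wg': 50176, 'il': 16, 'mr': 12288, 'or': 8192}
```

For the chosen tile, this gives 50176 work groups, 16 outer-loop iterations, and 12288 bytes (3072 elements) of input per iteration, which matches the shared-memory tiles in the generated kernel below once padding is excluded.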

At this point we are ready for the next step: memory layout.

Memory layout

Based on the tile size that was picked, we will divide the work up into work groups (for more information on the work group concept, see any introduction to OpenCL programming). Each work group will produce one part of the output tensor. In this case, 32 output channels across 16 batch elements of a 2×2 spatial region (compare this to the tile size, for example co=32). All of these work groups are mutually independent, and will be scheduled by the GPU onto one of its multiple compute units. Each work group will be passed coordinates by the GPU specifying which region of the output space it needs to compute. It will perform this computation, write its results to the output buffer, and terminate. Since each work group writes to a distinct output region, there are no concurrency issues to consider.
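Because the work groups are independent, it is worth checking that together they cover the output exactly once. The Python sketch below uses the coordinate-unpacking formulas from the generated kernel shown later (co and n packed into dimension 2, x in dimension 1, y in dimension 0); the grid size of (112, 112, 4) work groups is our inference from those formulas and the tile size, not something stated by PlaidML, and `unpack` is a hypothetical name.

```python
from itertools import product

GRID = (112, 112, 4)   # inferred number of work groups along dims 0, 1, 2

def unpack(g0, g1, g2):
    """Map 3-D work-group ids to the first (co, n, x, y) of the output tile."""
    co_gid = (g2 // 2) * 32
    n_gid = (g2 % 2) * 16
    x_gid = g1 * 2
    y_gid = g0 * 2
    return co_gid, n_gid, x_gid, y_gid

starts = [unpack(*g) for g in product(*(range(d) for d in GRID))]
# Tiles are aligned to the tile size, so distinct starts mean disjoint tiles.
print(len(starts), len(set(starts)))  # 50176 50176
```

Since each tile holds 32 × 16 × 2 × 2 = 2048 outputs, 50176 disjoint tiles account for all 32 × 224 × 224 × 64 output elements.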

Within each work group, we need to perform all of the accumulations required to produce the outputs. Once again, we tile the work: an outer loop loads all of the relevant inputs, and an inner loop performs dense computation directly out of L1. In this case, the outer loops walk over the reduction variables, which appear only on the right-hand side: ci (input channels, 8 at a time) and i (the filter position along x, 2 at a time). The outer loops do not walk over j in this case, since the tile size of j is the full range of j.

Each of the inner loops performs direct multiply-accumulates, done in a thread-parallel way. At this stage the overall shape of the process has been determined; this includes the size of the output block for each work group, the number of outer loops, and the operations performed in the inner loop. However, the details of how memory is loaded into L1 and how the inner loop is performed must still be determined. We begin by deciding how L1 cache memory will be organized.

First, we determine how much of each input to load for the inner loop, and how to load it. Here we see an interesting effect. Even though we are computing a 2×2 spatial region of the output, each inner loop computes over a 2×3 sub-convolution (2 values of i, 3 of j), so we actually need to load a 3×4 spatial region from the input. To explain further: over the inner loop, the y offset is in {0, 1} and the j offset is in {0, 1, 2}, so the smallest value of y + j is 0 and the largest is 3, giving a range of 4. Note that while there are 6 possible values of (y, j), they access only 4 distinct memory locations. Likewise, with x in {0, 1} and i in {0, 1}, x + i has a range of 3.
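The footprint arithmetic can be checked directly. In this small Python sketch (`footprint` is a hypothetical helper), an output offset o and a filter offset r that are added together touch position o + r, so the number of positions to load is the number of distinct sums, not the number of pairs.

```python
def footprint(out_tile, red_tile):
    """Distinct positions o + r for o in [0, out_tile) and r in [0, red_tile)."""
    return len({o + r for o in range(out_tile) for r in range(red_tile)})

# Chosen tile: ci 8, n 16, x 2 with i 2, y 2 with j 3
y_j = footprint(2, 3)                  # 6 (y, j) pairs, only 4 distinct positions
x_i = footprint(2, 2)                  # 4 (x, i) pairs, only 3 distinct positions
d_tile = 8 * 16 * x_i * y_j            # D elements actually loaded per outer iteration
d_pairs = 8 * 16 * (2 * 2) * (2 * 3)   # if every (output, filter) pair were loaded separately
print(y_j, x_i, d_tile, d_pairs)       # 4 3 1536 3072
```

Loading the merged footprint halves the image data moved per outer iteration compared with loading each (output, filter) combination separately.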

When two index variables jointly determine a position within an input, as x and i (or y and j) do here, we must consider them as a unit when loading global memory into shared memory. We call this process merging indexes. We may also merge indexes for other reasons, such as reducing the number of separate indexes to consider. In this case, for loading the image input, x must be merged with i, and y must be merged with j. Each merged index can now be given a stride and size in global memory as well as in shared memory.

So for the image input, we have 4 logical indexes to consider when loading the image data on each inner loop: i_x, j_y, ci, and n. We can now decide how to place those values in shared/local memory, that is, for each index, what stride (in elements) do we multiply by to move one step in that index for both global memory as well as for shared memory. Also: how large is the shared memory buffer? Our memory layout algorithm computes the answer as follows:

         Range    Global     Local
 i_x         3     14336         1
 j_y         4        64       393
  ci         8         1        49
   n        16   3211264         3

The size of the input buffer used to hold the image data is 1572 elements (4 × 393). We call this set of decisions regarding the loading and storage of input data the read plan. Memory is generally arranged to reduce bank conflicts during both load and compute. To do this, the tiles of each input are placed into shared memory with padding, so that each stride is an odd number; this guarantees that the stride is relatively prime to the number of banks (always a power of 2). This reduces bank conflicts when threads in the same SIMD group (or warp in CUDA terminology) each access consecutive values along any dimension.
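The local strides in the table can be reproduced by a simple padding rule. This Python sketch is our reconstruction, not PlaidML's layout algorithm: walking the indexes from smallest local stride to largest, each stride is the span of the indexes before it, rounded up to the next odd number. It also checks the bank-conflict argument for a hypothetical 32-bank memory: 32 threads reading consecutive values along a dimension with an odd stride hit 32 different banks, while an even stride such as 48 collides.

```python
def padded_layout(ranges):
    """Local strides for indexes listed from smallest to largest stride.

    Each stride is the span of the indexes before it, rounded up to an odd number.
    Returns (strides, total buffer size in elements).
    """
    strides, span = [], 1
    for r in ranges:
        stride = span if span % 2 == 1 else span + 1
        strides.append(stride)
        span = stride * r
    return strides, span

def banks_hit(stride, lanes=32, banks=32):
    """Distinct banks touched when `lanes` threads read elements `stride` apart."""
    return len({(stride * k) % banks for k in range(lanes)})

print(padded_layout([3, 16, 8, 4]))  # D: i_x, n, ci, j_y -> ([1, 3, 49, 393], 1572)
print(padded_layout([32, 8, 6]))     # K: co, ci, j_i     -> ([1, 33, 265], 1590)
print(banks_hit(49), banks_hit(48))  # 32 2
```

The totals match the `D_shared[1572]` and `K_shared[1590]` buffers declared in the generated kernel below.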

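For the convolutional kernel K, a similar table is produced. Here j and i end up merged to reduce the index count. This is possible because the global stride of i (12288) is exactly 3 times the stride of j (4096), so the 2×3 block of (i, j) positions forms a single run of 6 evenly spaced elements.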

         Range    Global     Local  
 j_i         6      4096       265 
  ci         8         1        33
  co        32        64         1   
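Merging two indexes into one is possible when the outer index's stride equals the inner index's range times its stride, so the pair walks a single evenly spaced run. The hypothetical helper below (our sketch, written over (range, stride) pairs) checks this condition for the K tile in both global and local memory, using the local stride 795 for i that appears in the generated kernel.

```python
def try_merge(inner, outer):
    """Merge (range, stride) descriptors into one index, or return None if impossible."""
    (r_in, s_in), (r_out, s_out) = inner, outer
    if s_out == r_in * s_in:          # outer steps exactly over one full run of inner
        return (r_in * r_out, s_in)
    return None

print(try_merge((3, 4096), (2, 12288)))  # K global: j then i -> (6, 4096)
print(try_merge((3, 265), (2, 795)))     # K local:  j then i -> (6, 265)
print(try_merge((3, 64), (2, 14336)))    # D: j then x do not form one run -> None
```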

The outputs are accumulated into thread-local variables rather than shared/local memory. This is because one of the goals of an efficient kernel is to reduce inter-thread dependencies, since thread synchronization is expensive. Because all accumulations happen in independent threads, once the data is loaded into L1, no further synchronization is required. In some cases, where the total number of accumulators is smaller than the number of threads required for GPU efficiency, we perform partial accumulations in each thread, followed by a rollup phase in which thread-local data is passed through shared memory, with synchronization, to compute the final totals. As long as the number of output values per work group is at least the number of threads in the work group, this step is not needed.

In this case, the output is 2×2×16×32 = 2048 elements. Since we will be using 256 threads (a generally good number for the GPU in question, as determined by the hardware model), each thread will hold 8 accumulators. In part this tile size was chosen because:

  1. The amount of shared/local memory required is smaller than the available size.
  2. The number of accumulators used by each thread is small enough to not overflow the register file.
  3. Reads from global memory are wide enough (8 elements in this case) to maximize bandwidth on the hardware in question.
  4. The amount of accumulation (compute) per memory load is large.
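The claim that 256 threads with 8 accumulators each cover the 2048 outputs can be verified from the index formulas used in the compute phase of the generated kernel below. This Python sketch (with a hypothetical helper, `compute_slot`) enumerates every (thread, co_lid, n_lid) combination and records which output element and accumulator slot it uses.

```python
def compute_slot(tid, co_lid, n_lid):
    """(co, n, x, y) within the work group's tile, plus the accumulator index agg_idx."""
    co_tid = tid % 8
    y_tid = (tid // 8) % 2
    x_tid = (tid // 16) % 2
    n_tid = (tid // 32) % 8
    out = (8 * co_lid + co_tid, 8 * n_lid + n_tid, x_tid, y_tid)
    return out, co_lid + n_lid * 4

owners = {}   # output element -> list of (thread id, accumulator index)
for tid in range(256):
    for co_lid in range(4):
        for n_lid in range(2):
            out, agg_idx = compute_slot(tid, co_lid, n_lid)
            owners.setdefault(out, []).append((tid, agg_idx))

print(len(owners), max(len(v) for v in owners.values()))  # 2048 1
```

Every output element belongs to exactly one thread, so no accumulator is shared and no rollup phase is needed for this tile size.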

At this point, we move on to thread layout.

Threading

Each work group of the kernel will have the following overall structure:

Clear output accumulators
For each outer loop index:
  Load input 1 (Image)
  Load input 2 (Kernel)
  Accumulate inputs into the output (main compute)
Perform post-accumulation operations and write outputs

We now want to thread the process of loading / computing / storing so that each of the steps can be parallelized properly, taking advantage of the high degree of parallelism in modern GPUs. This also entails performing synchronization, making the code look something like:

Clear output accumulators  // Threaded
For each outer loop index:
  Load input 1 (Image)  // Threaded
  Load input 2 (Kernel)  // Threaded
  Sync()  // Make sure data is loaded before computing
  Accumulate inputs into the output (main compute)  // Threaded
  Sync()  // Make sure computations are done before loading new data
Perform post-accumulation operations and write outputs  // Threaded

Here Sync() represents a work-group wide thread synchronization. That is, all threads must reach the sync point before any threads begin executing instructions beyond the sync point.
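The load/sync/compute/sync pattern can be simulated on the CPU. In this Python sketch (our illustration only), ordinary threads stand in for the work items of one work group, a list stands in for shared memory, and `threading.Barrier` plays the role of Sync(). Each thread reads a value that a different thread loaded, which is only safe because of the first barrier; the second barrier keeps a fast thread from overwriting shared memory with the next tile while a slower thread is still reading the current one.

```python
import threading

NUM_THREADS = 8                      # stands in for the 256 work items of a work group
shared = [0] * NUM_THREADS           # stands in for shared/local memory
results = [0] * NUM_THREADS
barrier = threading.Barrier(NUM_THREADS)

def work_item(tid, tiles):
    acc = 0
    for tile in tiles:                          # outer loop over input tiles
        shared[tid] = tile[tid]                 # threaded load: one slot per thread
        barrier.wait()                          # Sync(): all loads done before compute
        acc += shared[(tid + 1) % NUM_THREADS]  # compute reads a slot another thread wrote
        barrier.wait()                          # Sync(): compute done before the next load
    results[tid] = acc

tiles = [list(range(NUM_THREADS)), [10 * v for v in range(NUM_THREADS)]]
threads = [threading.Thread(target=work_item, args=(t, tiles)) for t in range(NUM_THREADS)]
for th in threads:
    th.start()
for th in threads:
    th.join()
print(results)  # [11, 22, 33, 44, 55, 66, 77, 0]
```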

In each of the cases marked as threaded above, we have some number of operations to perform, and some number of available threads (here 256). We need to assign each operation to a thread such that all threads are well utilized, and no two threads require any synchronization beyond the two syncs in the pseudocode above. Additionally, we would like to organize loading/storing threads such that nearby threads (part of the same SIMD group or warp) access contiguous parts of global memory, and such that nearby computational threads do not suffer warp divergence; that is, they take any code branches as a group.

To perform thread assignment, we first construct a logical ordering of operations that places operations we want assigned to nearby threads near each other. We then divide the overall index space of all operations into two components. The low-order part of the index space is distributed “spatially” across threads (a thread idx), and the high-order part is distributed “temporally” via looping within a thread (a loop idx). A given logical index can actually be divided between both types of concrete index.

To make this clear, here is the load plan for the kernel K, and the resulting threaded kernel code (simplified and with comments added):

             Range    Global     Local  
     j_i         6      4096       265 
      ci         8         1        33
      co        32        64         1   
int tid = get_local_id(0);      // The thread id, from 0 - 255
int ci_tid = (tid % 8);         // ci is based on the low order 3 bits of the thread id
int co_tid = ((tid / 8) % 32);  // co is based on the high order 5 bits of the thread id 
for(int j_i_lid = 0; j_i_lid < 6; j_i_lid++)  // j_i is looped over in each thread
{
  // Compute the shared memory position
  int lidx = (((33 * ci_tid) + co_tid) + (265 * j_i_lid));  
  // Compute the global memory position, here gbase is determined earlier based on which
  // work group we are in and which outer loop is being executed.
  int gidx = (((gbase + ci_tid) + (64 * co_tid)) + (4096 * j_i_lid));
  // Perform the assignment.
  K_shared[lidx] = K_global[gidx];
}

In general, some index variables may need to be split (partially threaded and partially looped), and a number of edge cases must be handled (for example, a total number of operations that is not divisible by 256).
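The thread/loop split can be sketched as follows. This Python illustration is not PlaidML's algorithm: `split_indexes` and `assignments` are hypothetical, and for simplicity an index only receives a power-of-two share of the threads that divides both its range and the remaining thread count, which sidesteps the edge cases just mentioned. Applied to the image load, with indexes ordered low to high as ci, n, j_y, i_x, it reproduces the split used in the generated kernel below: j_y is divided into 2 thread positions and 2 loop iterations (j_y = 2 * j_y_lid + j_y_tid).

```python
from itertools import product

def split_indexes(ranges, num_threads):
    """Split (name, range) pairs, low order first, into (name, threads, loop trips)."""
    plan, remaining = [], num_threads
    for name, r in ranges:
        t = 1
        while remaining % (t * 2) == 0 and r % (t * 2) == 0:
            t *= 2
        plan.append((name, t, r // t))
        remaining //= t
    return plan

def assignments(plan, num_threads):
    """Yield (tid, logical coords); each coordinate is thread part + threads * loop part."""
    for tid in range(num_threads):
        rest, tparts = tid, []
        for _, t, _ in plan:                 # decode the thread id, low-order index first
            tparts.append(rest % t)
            rest //= t
        for lids in product(*(range(loops) for _, _, loops in plan)):
            yield tid, tuple(tp + t * lid for (_, t, _), tp, lid in zip(plan, tparts, lids))

D_plan = split_indexes([("ci", 8), ("n", 16), ("j_y", 4), ("i_x", 3)], 256)
print(D_plan)  # [('ci', 8, 1), ('n', 16, 1), ('j_y', 2, 2), ('i_x', 1, 3)]
```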

Code-gen

Once we have determined the memory layout and threading, the next step is generating code. For portability, we generate sem-trees, which represent the code in a platform-neutral way. This is a largely straightforward process, but it has a few complexities. First, the coordinates of each work group must be packed into the GPU’s notion of coordinates, which are usually limited to 3 dimensions, and then unpacked within the work group. Second, to prevent the edge cases induced by constraints from slowing down kernel execution, we build two versions of the inner loop, one that checks the constraints and one that does not. If we determine that there will be no out-of-bounds access for a given tile, we skip the per-element checks in the inner loop. Additionally, any post-accumulation code is added to the kernel prior to writing the outputs.
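The tile-level test that selects between the two inner-loop versions can be checked in isolation. The Python sketch below transcribes the fast-path condition from the generated kernel shown below for our chosen tile sizes (i: 2, x: 2, j: 3, y: 2); `fast_path` is a hypothetical name. It counts how many (i, x, j, y) tile combinations can skip the per-element checks; the ci, co, and n tiles do not affect the test.

```python
def fast_path(i_gid, x_gid, j_gid, y_gid, ti=2, tx=2, tj=3, ty=2, X=224, Y=224):
    """True if every (i, x, j, y) in the tile satisfies all four one-sided constraints."""
    return (-i_gid - x_gid <= -1                              # smallest i + x is at least 1
            and (i_gid + ti - 1) + (x_gid + tx - 1) <= X      # largest i + x is at most X
            and -j_gid - y_gid <= -1                          # smallest j + y is at least 1
            and (j_gid + tj - 1) + (y_gid + ty - 1) <= Y)     # largest j + y is at most Y

# Tile origins visited by the kernel: outer loops over i and j, work groups over x and y
tiles = [(i, x, j, y)
         for i in range(0, 3, 2) for x in range(0, 224, 2)
         for j in range(0, 3, 3) for y in range(0, 224, 2)]
fast = sum(fast_path(*t) for t in tiles)
print(fast, len(tiles))  # 24420 25088
```

Under these assumptions about 97% of combinations take the fast path; only tiles touching the image border pay for per-element checks. The test uses the full tile extent, so it is conservative: for i_gid = 2 it also accounts for i = 3, which the kernel's i_cond guard excludes anyway.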

The resulting sem-tree is then run through an optimization pass and lowered by the HAL (hardware abstraction layer) into whatever format the platform needs. This might be LLVM, OpenCL C, or even the Metal Shading Language.

The kernel will be further rewritten by additional back-end specific optimization passes. These optimizations include loop unrolling, strength reduction, hoisting and many more. While these further optimizations are critical to the performance of the generated kernels, they are well understood compiler technologies, and don’t provide any additional clarity, so we ignore them here. However, many decisions used in producing the code, for example the use of fixed sized loops, are guided by their impact on such optimizations.

Without further ado, we present the final pre-optimization kernel below, with some variable renaming and additional hand-written comments added for clarity.

// The convolution kernel, where R is the output of the ReLU, D is the image, and K is the convolutional kernel
__kernel void kernel_ID_0(__global float* R, __global const float* D, __global const float* K)
{
  // Shift the image pointer by the constant offset (-14400 = -14336 - 64) caused by the -1's in x+i-1 and y+j-1
  D = (D + -14400);
  // Set up the thread id
  int tid = get_local_id(0);
  // Set up the 8 per-thread output accumulators
  float agg[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  // Make room for the two input tiles in shared (OpenCL local) memory
  __local float D_shared[1572];
  __local float K_shared[1590];
  // Unpack the output coordinates from the work group coordinates
  // Each coordinate tells us how much to add to the associated index
  int co_gid = ((get_group_id(2) / 2) * 32);
  int y_gid = (get_group_id(0) * 2);
  int x_gid = (get_group_id(1) * 2);
  int n_gid = ((get_group_id(2) % 2) * 16);
  // Now we walk over the outer loops, stepping by the associated tile size
  // Each time through the outer loop will load a new pair of tiles and accumulate them
  // into the final outputs.
  for(int ci_gid = 0; ci_gid < 64; ci_gid += 8)
  {
    for(int i_gid = 0; i_gid < 4; i_gid += 2)
    {
      for(int j_gid = 0; j_gid < 3; j_gid += 3)
      {
        // Load a tile of D
        {
          // Use which tile we are on to compute a base offset
          int gbase = (((((ci_gid + (j_gid * 64)) + (y_gid * 64)) + (i_gid * 14336)) + (x_gid * 14336)) + (n_gid * 3211264));
          // Do a parallel load of the actual elements of D into D_shared
          int ci_tid = (tid % 8);
          int n_tid = ((tid / 8) % 16);
          int j_y_tid = ((tid / 128) % 2);
          for(int j_y_lid = 0; j_y_lid < 2; j_y_lid += 1)
          {
            int j_y = ((2 * j_y_lid) + j_y_tid);
            for(int i_x_lid = 0; i_x_lid < 3; i_x_lid += 1)
            {
              int lidx = ((((49 * ci_tid) + (3 * n_tid)) + (393 * j_y)) + i_x_lid);
              int gidx = ((((gbase + ci_tid) + (3211264 * n_tid)) + (64 * j_y)) + (14336 * i_x_lid));
              // Do the load, making sure to not read out of bounds.  Logical out of bounds elements
              // will be filled with first/last element of tensor, but subsequently ignored
              D_shared[lidx] = D[clamp(gidx, 14400, 102774847)];
            }
          }
        }
        // Load a tile of K
        {
          // Use which tile we are on to compute a base offset
          int gbase = (((ci_gid + (co_gid * 64)) + (j_gid * 4096)) + (i_gid * 12288));
          // Do a parallel load of the actual elements of K into K_shared
          int ci_tid = (tid % 8);
          int co_tid = ((tid / 8) % 32);
          for(int j_i_lid = 0; j_i_lid < 6; j_i_lid += 1)
          {
            int lidx = (((33 * ci_tid) + co_tid) + (265 * j_i_lid));
            int gidx = (((gbase + ci_tid) + (64 * co_tid)) + (4096 * j_i_lid));
            // Do the load, making sure to not read out of bounds.  Logical out of bounds elements
            // will be filled with first/last element of tensor, but subsequently ignored
            K_shared[lidx] = K[clamp(gidx, 0, 36863)];
          }
        }
        // Make sure that all of the threaded loads complete before doing the compute work
        barrier(CLK_LOCAL_MEM_FENCE);
        // Here we check if this compute block will ever go out of bounds
        if (((((((-1 * i_gid) + (-1 * x_gid)) <= -1) && ((((i_gid + 2) - 1) + ((x_gid + 2) - 1)) <= 224)) && (((-1 * j_gid) + (-1 * y_gid)) <= -1)) && ((((j_gid + 3) - 1) + ((y_gid + 2) - 1)) <= 224)))
        {
          // Compute block, fast case (no bounds checking)
          // We begin by computing the 'thread' portion of the indexes
          int co_tid = (tid % 8);
          int y_tid = ((tid / 8) % 2);
          int x_tid = ((tid / 16) % 2);
          int n_tid = ((tid / 32) % 8);
          // Now we loop over the remaining indexes/portions of indexes
          for(int i_lid = 0; i_lid < 2; i_lid += 1)
          {
            // We need to handle the fact that the tile size for i (2) doesn't
            // divide the range of i (3)
            int i_cond = ((i_lid < 1) || (i_gid != 2));
            if (i_cond)
            {
              // Innermost loops here will be unrolled by the compiler because they are fixed size
              for(int ci_lid = 0; ci_lid < 8; ci_lid += 1)
              {
                for(int j_lid = 0; j_lid < 3; j_lid += 1)
                {
                  for(int co_lid = 0; co_lid < 4; co_lid += 1)
                  {
                    int co = ((8 * co_lid) + co_tid);
                    for(int n_lid = 0; n_lid < 2; n_lid += 1)
                    {
                      int n = ((8 * n_lid) + n_tid);
                      // Load values from each input
                      float val1 = ((float) D_shared[((((((49 * ci_lid) + (393 * j_lid)) + (393 * y_tid)) + i_lid) + x_tid) + (3 * n))]);
                      float val2 = ((float) K_shared[((((33 * ci_lid) + co) + (265 * j_lid)) + (795 * i_lid))]);
                      // Decide which accumulator to add to
                      int agg_idx = (co_lid + (n_lid * 4));
                      // Do the actual multiply accumulate
                      float agg_rhs = (agg[agg_idx] + ((float) (val2 * val1)));
                      agg[agg_idx] = agg_rhs;
                    }
                  }
                }
              }
            }
          }
        }
        else
        {
          // Slow case (do edge handling).  Very similar to above but with index check in innermost loop
          // We begin by computing the 'thread' portion of the indexes
          int co_tid = (tid % 8);
          int y_tid = ((tid / 8) % 2);
          int x_tid = ((tid / 16) % 2);
          int n_tid = ((tid / 32) % 8);
          for(int i_lid = 0; i_lid < 2; i_lid += 1)
          {
            // We need to handle the fact that the tile size for i (2) doesn't
            // divide the range of i (3)
            int i_cond = ((i_lid < 1) || (i_gid != 2));
            if (i_cond)
            {
              // Innermost loops here will be unrolled by the compiler because they are fixed size
              for(int ci_lid = 0; ci_lid < 8; ci_lid += 1)
              {
                for(int j_lid = 0; j_lid < 3; j_lid += 1)
                {
                  for(int co_lid = 0; co_lid < 4; co_lid += 1)
                  {
                    int co = ((8 * co_lid) + co_tid);
                    for(int n_lid = 0; n_lid < 2; n_lid += 1)
                    {
                      int n = ((8 * n_lid) + n_tid);
                      // Load values from each input.  Note, some of these are logically off the end of the tensor.
                      // However, they are never outside of shared memory, and were filled with safe but 'garbage' values during load.
                      // Those out-of-bounds elements are never accumulated, however, so they do not affect correctness.
                      float val1 = ((float) D_shared[((((((49 * ci_lid) + (393 * j_lid)) + (393 * y_tid)) + i_lid) + x_tid) + (3 * n))]);
                      float val2 = ((float) K_shared[((((33 * ci_lid) + co) + (265 * j_lid)) + (795 * i_lid))]);
                      // Decide which accumulator to add to
                      int agg_idx = (co_lid + (n_lid * 4));
                      // Multiply
                      float agg_rhs = (agg[agg_idx] + ((float) (val2 * val1)));
                      // Check constraints and optionally accumulate the value in
                      agg[agg_idx] = (((((((-1 * (i_gid + i_lid)) + (-1 * (x_gid + x_tid))) <= -1) && (((i_gid + i_lid) + (x_gid + x_tid)) <= 224)) && (((-1 * (j_gid + j_lid)) + (-1 * (y_gid + y_tid))) <= -1)) && (((j_gid + j_lid) + (y_gid + y_tid)) <= 224))? agg_rhs: agg[agg_idx]);
                    }
                  }
                }
              }
            }
          }
        }
        // Sync threads to prevent new data being loaded before computation completes.
        barrier(CLK_LOCAL_MEM_FENCE);
      }
    }
  }
  // Thread the post-processing and output writing
  int co_tid = (tid % 8);
  int y_tid = ((tid / 8) % 2);
  int x_tid = ((tid / 16) % 2);
  int n_tid = ((tid / 32) % 8);
  for(int co_lid = 0; co_lid < 4; co_lid += 1)
  {
    int co = ((8 * co_lid) + co_tid);
    for(int n_lid = 0; n_lid < 2; n_lid += 1)
    {
      int n = ((8 * n_lid) + n_tid);
      // Get the accumulated value from the convolution
      float LO = agg[(co_lid + (n_lid * 4))];
      // Compute the position to write to the output tensor 
      int gout_idx = ((((co_gid + co) + (3211264 * (n_gid + n))) + (14336 * (x_gid + x_tid))) + (64 * (y_gid + y_tid)));
      // Do the RELU
      bool L_T1 = (((float) LO) > 0);
      float LR = (((bool) L_T1)? ((float) LO): 0);
      // Write the result
      R[gout_idx] = LR;
    }
  }
}

Conclusion

As you can see, PlaidML’s compiler methodology provides a powerful way to generate high-performance, GPU-accelerated kernels for nearly any operation. We hope that this document helps to explain how our system works under the hood. If any aspects of this document are unclear, feel free to reach out to us.


© 2018 Vertex.AI