//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_ASYNCPARALLELFORPASS
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
//   }
//
// And a dispatch function depending on the `asyncDispatch` option.
//
// When async dispatch is on (pseudocode):
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = %block_start + (%block_end - %block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for the [0, %block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//     call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
struct AsyncParallelForPass
    : public impl::AsyncParallelForPassBase<AsyncParallelForPass> {
  using Base::Base;

  void runOnOperation() override;
};

struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(
      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  bool asyncDispatch;
  int32_t numWorkerThreads;
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};

struct ParallelComputeFunctionType {
  FunctionType type;
  SmallVector<Value> captures;
};

// Helper struct to parse parallel compute function argument list.
struct ParallelComputeFunctionArgs {
  BlockArgument blockIndex();
  BlockArgument blockSize();
  ArrayRef<BlockArgument> tripCounts();
  ArrayRef<BlockArgument> lowerBounds();
  ArrayRef<BlockArgument> steps();
  ArrayRef<BlockArgument> captures();

  unsigned numLoops;
  ArrayRef<BlockArgument> args;
};

struct ParallelComputeFunctionBounds {
  SmallVector<IntegerAttr> tripCounts;
  SmallVector<IntegerAttr> lowerBounds;
  SmallVector<IntegerAttr> upperBounds;
  SmallVector<IntegerAttr> steps;
};

struct ParallelComputeFunction {
  unsigned numLoops;
  func::FuncOp func;
  llvm::SmallVector<Value> captures;
};

} // namespace

BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
  return args.drop_front(2).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
  return args.drop_front(2 + 4 * numLoops);
}

template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
  SmallVector<IntegerAttr> attrs(values.size());
  for (unsigned i = 0; i < values.size(); ++i)
    matchPattern(values[i], m_Constant(&attrs[i]));
  return attrs;
}

// Converts a one-dimensional iteration index in the [0, tripCount) interval
// into a multidimensional iteration coordinate.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {
  SmallVector<Value> coords(tripCounts.size());
  assert(!tripCounts.empty() && "tripCounts must not be empty");

  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = arith::RemSIOp::create(b, index, tripCounts[i]);
    index = arith::DivSIOp::create(b, index, tripCounts[i]);
  }

  return coords;
}

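// E.g., delinearizing index 42 with tripCounts = [5, 10] (example values)
// yields coords = [4, 2]: 42 % 10 = 2, then 42 / 10 = 4, and 4 % 5 = 4.
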
// Returns a function type and implicit captures for a parallel compute
// function. We need the list of implicit captures to set up the block and
// value mapping when we clone the body of the parallel operation.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.getRegion(), op.getRegion(), captures);

  SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step, passed as
  // contiguous arguments:
  //   call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicit captures.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}

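// For illustration, a two-loop scf.parallel with one implicit capture %arg of
// type memref<?xf32> (hypothetical example) gets a compute function type of:
//
//   (%blockIndex : index, %blockSize : index,
//    %tc0 : index, %tc1 : index,       // trip counts
//    %lb0 : index, %lb1 : index,       // lower bounds
//    %ub0 : index, %ub1 : index,       // upper bounds
//    %step0 : index, %step1 : index,   // steps
//    %arg : memref<?xf32>) -> ()       // implicit captures
//
// This is the argument order assumed by ParallelComputeFunctionArgs above.
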
// Create a parallel compute function from the parallel operation.
static ParallelComputeFunction createParallelComputeFunction(
    scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ModuleOp module = op->getParentOfType<ModuleOp>();

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  func::FuncOp func = func::FuncOp::create(
      op.getLoc(),
      numBlockAlignedInnerLoops > 0 ? "parallel_compute_fn_with_aligned_loops"
                                    : "parallel_compute_fn",
      type);
  func.setPrivate();

  // Insert the function into the module symbol table and assign it a unique
  // name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create function entry block.
  Block *block =
      b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                    SmallVector<Location>(type.getNumInputs(), op.getLoc()));
  b.setInsertionPointToEnd(block);

  ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};

  // Block iteration position defined by the block index and size.
  BlockArgument blockIndex = args.blockIndex();
  BlockArgument blockSize = args.blockSize();

  // Constants used below.
  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Materialize known constants as constant operations in the function body.
  auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
    return llvm::to_vector(
        llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
          if (IntegerAttr attr = std::get<1>(tuple))
            return arith::ConstantOp::create(b, attr);
          return std::get<0>(tuple);
        }));
  };

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  auto tripCounts = values(args.tripCounts(), bounds.tripCounts);

  // Parallel operation lower bound and step.
  auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
  auto steps = values(args.steps(), bounds.steps);

  // Remaining arguments are implicit captures of the parallel operation.
  ArrayRef<BlockArgument> captures = args.captures();

  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);

  // Find one-dimensional iteration bounds [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = arith::MulIOp::create(b, blockIndex, blockSize);

  // The last one-dimensional index in the block defined by the `blockIndex`:
  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
  Value blockEnd0 = arith::AddIOp::create(b, blockFirstIndex, blockSize);
  Value blockEnd1 = arith::MinSIOp::create(b, blockEnd0, tripCount);
  Value blockLastIndex = arith::SubIOp::create(b, blockEnd1, c1);
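
  // E.g., with blockIndex = 2, blockSize = 10 and tripCount = 25 (example
  // values): blockFirstIndex = 20, blockLastIndex = min(30, 25) - 1 = 24.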

  // Convert one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute the loop upper bounds derived from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  //
  // Block first and last coordinates can be the same along the outer compute
  // dimension when the inner compute dimension contains multiple blocks.
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = arith::AddIOp::create(b, blockLastCoord[i], c1);

  // Construct a loop nest out of scf.for operations that will iterate over
  // all coordinates in the [blockFirstCoord, blockLastCoord] range.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Parallel region induction variables computed from the multi-dimensional
  // iteration coordinate using parallel operation bounds and step:
  //
  //   computeBlockInductionVars[loopIdx] =
  //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());

  // We need to know if we are in the first or last iteration of the
  // multi-dimensional loop for each loop in the nest, so we can decide what
  // loop bounds to use for the nested loops: bounds defined by the compute
  // block interval, or bounds defined by the parallel operation.
  //
  // Example: 2d parallel operation
  //                  i   j
  //   loop sizes:  [50, 50]
  //   first coord: [25, 25]
  //   last coord:  [30, 30]
  //
  // If `i` is equal to 25 then iteration over `j` should start at 25; when `i`
  // is between 25 and 30 it should start at 0. The upper bound for `j` should
  // be 50, except when `i` is equal to 30, when it should also be 30.
  //
  // The value at the ith position specifies if all loops in the [0, i) range
  // of the loop nest are in the first/last iteration.
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Builds the inner loop nest inside the async.execute operation that does
  // all the work concurrently.
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder b(loc, nestedBuilder);

      // Compute induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] =
          arith::AddIOp::create(b, lowerBounds[loopIdx],
                                arith::MulIOp::create(b, iv, steps[loopIdx]));

      // Check if we are inside the first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] = arith::CmpIOp::create(
          b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] = arith::CmpIOp::create(
          b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Check if the previous loop is in its first or last iteration.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = arith::AndIOp::create(
            b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = arith::AndIOp::create(
            b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building the loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
          // For block-aligned loops we always iterate starting from 0 up to
          // the loop trip counts.
          scf::ForOp::create(b, c0, tripCounts[loopIdx + 1], c1, ValueRange(),
                             workLoopBuilder(loopIdx + 1));

        } else {
          // Select nested loop lower/upper bounds depending on our position in
          // the multi-dimensional iteration space.
          auto lb = arith::SelectOp::create(b, isBlockFirstCoord[loopIdx],
                                            blockFirstCoord[loopIdx + 1], c0);

          auto ub = arith::SelectOp::create(b, isBlockLastCoord[loopIdx],
                                            blockEndCoord[loopIdx + 1],
                                            tripCounts[loopIdx + 1]);

          scf::ForOp::create(b, lb, ub, c1, ValueRange(),
                             workLoopBuilder(loopIdx + 1));
        }

        scf::YieldOp::create(b, loc);
        return;
      }

      // Copy the body of the parallel op into the innermost loop.
      IRMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);

      for (auto &bodyOp : op.getRegion().front().without_terminator())
        b.clone(bodyOp, mapping);
      scf::YieldOp::create(b, loc);
    };
  };

  scf::ForOp::create(b, blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                     workLoopBuilder(0));
  func::ReturnOp::create(b, ValueRange());

  return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}

// Creates a recursive async dispatch function for the given parallel compute
// function. The dispatch function keeps splitting the block range into halves
// until it reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = %block_start + (%block_end - %block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
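// E.g., dispatching the range [0, 8) spawns async tasks for [4, 8), [2, 4)
// and [1, 2), then runs block 0 inline; every spawned task splits its own
// subrange the same way, so a range of n blocks is fanned out after O(log n)
// sequential steps in any single thread.
//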
static func::FuncOp
createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                            PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.getFunctionType().getInputs();

  // Compared to the parallel compute function, the async dispatch function
  // takes an additional !async.group argument. Also, instead of a single
  // `blockIndex` it takes `blockStart` and `blockEnd` arguments to define the
  // range of dispatched blocks.
  SmallVector<Type> inputTypes;
  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
  inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  func::FuncOp func = func::FuncOp::create(loc, "async_dispatch_fn", type);
  func.setPrivate();

  // Insert the function into the module symbol table and assign it a unique
  // name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                               SmallVector<Location>(type.getNumInputs(), loc));
  b.setInsertionPointToEnd(block);

  Type indexTy = b.getIndexType();
  Value c1 = arith::ConstantIndexOp::create(b, 1);
  Value c2 = arith::ConstantIndexOp::create(b, 2);

  // Get the async group that will track async dispatch completion.
  Value group = block->getArgument(0);

  // Get the block iteration range: [blockStart, blockEnd)
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work-splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};
  SmallVector<Location> locations = {loc, loc};

  // Create a recursive dispatch loop.
  scf::WhileOp whileOp = scf::WhileOp::create(b, types, operands);
  Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations);
  Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations);

  // Set up the dispatch loop condition block: decide if we need to go into the
  // `after` block and launch one more async dispatch.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = arith::SubIOp::create(b, end, start);
    Value dispatch =
        arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, distance, c1);
    scf::ConditionOp::create(b, dispatch, before->getArguments());
  }

  // Set up the async dispatch loop body: recursively call the dispatch
  // function for the second half of the original range and go to the next
  // iteration.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = arith::SubIOp::create(b, end, start);
    Value halfDistance = arith::DivSIOp::create(b, distance, c2);
    Value midIndex = arith::AddIOp::create(b, start, halfDistance);

    // Call the dispatch function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original `blockStart` and `blockEnd` with the new range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex;
      operands[2] = end;

      func::CallOp::create(executeBuilder, executeLoc, func.getSymName(),
                           func.getResultTypes(), operands);
      async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
    };

    // Create an async.execute operation to dispatch half of the block range.
    auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
                                     executeBodyBuilder);
    AddToGroupOp::create(b, indexTy, execute.getToken(), group);
    scf::YieldOp::create(b, ValueRange({start, midIndex}));
  }

  // After dispatching async operations to process the tail of the block range,
  // call the parallel compute function for the first block of the range.
  b.setInsertionPointAfter(whileOp);

  // Drop the async dispatch specific arguments: async group, block start and
  // block end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  func::CallOp::create(b, computeFunc.func.getSymName(),
                       computeFunc.func.getResultTypes(), computeFuncOperands);
  func::ReturnOp::create(b, ValueRange());

  return func;
}

// Launch async dispatch of the parallel compute function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  func::FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Appends operands shared by the async dispatch and parallel compute
  // functions to the given operands vector.
  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
    operands.append(tripCounts);
    operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
    operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
    operands.append(op.getStep().begin(), op.getStep().end());
    operands.append(parallelComputeFunction.captures);
  };

  // Check if the block count is one; in this case we can skip the async
  // dispatch completely. If this is known statically, then canonicalization
  // will erase the async group operations.
  Value isSingleBlock =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, blockCount, c1);

  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Call the parallel compute function for the single block.
    SmallVector<Value> operands = {c0, blockSize};
    appendBlockComputeOperands(operands);

    func::CallOp::create(b, parallelComputeFunction.func.getSymName(),
                         parallelComputeFunction.func.getResultTypes(),
                         operands);
    scf::YieldOp::create(b);
  };

  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Create an async.group to wait on all async tokens from the concurrent
    // execution of multiple parallel compute functions. The first block will
    // be executed synchronously in the caller thread.
    Value groupSize = arith::SubIOp::create(b, blockCount, c1);
    Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);

    // Launch the async dispatch function for the [0, blockCount) range.
    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
    appendBlockComputeOperands(operands);

    func::CallOp::create(b, asyncDispatchFunction.getSymName(),
                         asyncDispatchFunction.getResultTypes(), operands);

    // Wait for the completion of all parallel compute operations.
    AwaitAllOp::create(b, group);

    scf::YieldOp::create(b);
  };

  // Dispatch either the single block compute function, or launch async
  // dispatch.
  scf::IfOp::create(b, isSingleBlock, syncDispatch, asyncDispatch);
}

// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  func::FuncOp compute = parallelComputeFunction.func;

  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = arith::SubIOp::create(b, blockCount, c1);
  Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);

  // Call the parallel compute function for all blocks.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns the parallel compute function operands to process the given block.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.getLowerBound().begin(),
                               op.getLowerBound().end());
    computeFuncOperands.append(op.getUpperBound().begin(),
                               op.getUpperBound().end());
    computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // The induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder b(loc, loopBuilder);

    // Call the parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(),
                           compute.getResultTypes(), computeFuncOperands(iv));
      async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
    };

    // Create an async.execute operation to launch the parallel compute
    // function.
    auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
                                     executeBodyBuilder);
    AddToGroupOp::create(b, rewriter.getIndexType(), execute.getToken(), group);
    scf::YieldOp::create(b);
  };

  // Iterate over all compute blocks and launch parallel compute operations.
  scf::ForOp::create(b, c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call the parallel compute function for the first block in the caller
  // thread.
  func::CallOp::create(b, compute.getSymName(), compute.getResultTypes(),
                       computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  AwaitAllOp::create(b, group);
}

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support the rewrite for parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Computing minTaskSize emits IR and can be implemented as executing a cost
  // model on the body of the scf.parallel. Thus it needs to be computed before
  // the body of the scf.parallel has been manipulated.
  Value minTaskSize = computeMinTaskSize(b, op);

  // Make sure that all constants will be inside the parallel operation body to
  // reduce the number of parallel compute function arguments.
  cloneConstantsIntoTheRegion(op.getRegion(), rewriter);

  // Compute the trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step)
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.getLowerBound()[i];
    auto ub = op.getUpperBound()[i];
    auto step = op.getStep()[i];
    auto range = b.createOrFold<arith::SubIOp>(ub, lb);
    tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
  }
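
  // E.g., a loop with lb = 0, ub = 10 and step = 3 (example values) has
  // tripCount = ceil_div(10 - 0, 3) = 4, covering i = 0, 3, 6, 9.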

  // Compute a product of trip counts to get the 1-dimensional iteration space
  // for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);

  // Short-circuit no-op parallel loops (zero iterations) that can arise from
  // memrefs with dynamic dimension(s) equal to zero.
  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value isZeroIterations =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, tripCount, c0);

  // Do absolutely nothing if the trip count is zero.
  auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
    scf::YieldOp::create(nestedBuilder, loc);
  };

  // Compute the parallel block size and dispatch concurrent tasks computing
  // results for each block.
  auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Collect statically known constants defining the loop nest in the
    // parallel compute function. LLVM can't always push constants across the
    // non-trivial async dispatch call graph; by providing these values
    // explicitly we can choose to build a more efficient loop nest, and rely
    // on better constant folding, loop unrolling and vectorization.
    ParallelComputeFunctionBounds staticBounds = {
        integerConstants(tripCounts),
        integerConstants(op.getLowerBound()),
        integerConstants(op.getUpperBound()),
        integerConstants(op.getStep()),
    };

    // Find how many inner iteration dimensions are statically known, with a
    // product smaller than `512`. We align the parallel compute block size by
    // the product of the statically known dimensions, so that we can guarantee
    // that the inner loops execute from 0 to the loop trip counts. This lets
    // us elide dynamic loop boundaries and gives LLVM an opportunity to unroll
    // the loops. The constant `512` is arbitrary; it should depend on how many
    // iterations LLVM will typically decide to unroll.
    static constexpr int64_t maxUnrollableIterations = 512;

    // The number of inner loops with a statically known number of iterations
    // less than the `maxUnrollableIterations` value.
    int numUnrollableLoops = 0;

    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };

    SmallVector<int64_t> numIterations(op.getNumLoops());
    numIterations.back() = getInt(staticBounds.tripCounts.back());

    for (int i = op.getNumLoops() - 2; i >= 0; --i) {
      int64_t tripCount = getInt(staticBounds.tripCounts[i]);
      int64_t innerIterations = numIterations[i + 1];
      numIterations[i] = tripCount * innerIterations;

      // Update the number of inner loops that we can potentially unroll.
      if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
        numUnrollableLoops++;
    }
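
    // E.g., for a 3-loop nest with static trip counts [4, 8, 16] (example
    // values): numIterations = [512, 128, 16], and the two innermost loops
    // are unrollable (16 <= 512, 128 <= 512), so numUnrollableLoops = 2.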

    Value numWorkerThreadsVal;
    if (numWorkerThreads >= 0)
      numWorkerThreadsVal = arith::ConstantIndexOp::create(b, numWorkerThreads);
    else
      numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(b);

    // With a large number of threads the value of creating many compute blocks
    // is reduced, because the problem typically becomes memory bound. For this
    // reason we scale the number of workers using an equivalent of the
    // following logic:
    //
    //   float overshardingFactor = numWorkerThreads <= 4  ? 8.0
    //                            : numWorkerThreads <= 8  ? 4.0
    //                            : numWorkerThreads <= 16 ? 2.0
    //                            : numWorkerThreads <= 32 ? 1.0
    //                            : numWorkerThreads <= 64 ? 0.8
    //                                                     : 0.6;

    // Pairs of the non-inclusive lower end of the bracket and the factor that
    // the number of workers needs to be scaled with if it falls in that
    // bucket.
    const SmallVector<std::pair<int, float>> overshardingBrackets = {
        {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
    const float initialOvershardingFactor = 8.0f;

    Value scalingFactor = arith::ConstantFloatOp::create(
        b, b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
    for (const std::pair<int, float> &p : overshardingBrackets) {
      Value bracketBegin = arith::ConstantIndexOp::create(b, p.first);
      Value inBracket = arith::CmpIOp::create(
          b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
      Value bracketScalingFactor = arith::ConstantFloatOp::create(
          b, b.getF32Type(), llvm::APFloat(p.second));
      scalingFactor = arith::SelectOp::create(
          b, inBracket, bracketScalingFactor, scalingFactor);
    }
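
    // E.g., with 16 worker threads (example value) the select chain above
    // resolves to a scaling factor of 2.0, so maxComputeBlocks below becomes
    // max(1, 16 * 2.0) = 32.
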
    Value numWorkersIndex =
        arith::IndexCastOp::create(b, b.getI32Type(), numWorkerThreadsVal);
    Value numWorkersFloat =
        arith::SIToFPOp::create(b, b.getF32Type(), numWorkersIndex);
    Value scaledNumWorkers =
        arith::MulFOp::create(b, scalingFactor, numWorkersFloat);
    Value scaledNumInt =
        arith::FPToSIOp::create(b, b.getI32Type(), scaledNumWorkers);
    Value scaledWorkers =
        arith::IndexCastOp::create(b, b.getIndexType(), scaledNumInt);

    Value maxComputeBlocks = arith::MaxSIOp::create(
        b, arith::ConstantIndexOp::create(b, 1), scaledWorkers);

    // Compute the parallel block size from the parallel problem size:
    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, maxComputeBlocks),
    //                       minTaskSize))
    Value bs0 = arith::CeilDivSIOp::create(b, tripCount, maxComputeBlocks);
    Value bs1 = arith::MaxSIOp::create(b, bs0, minTaskSize);
    Value blockSize = arith::MinSIOp::create(b, tripCount, bs1);

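    // E.g., with tripCount = 1000, maxComputeBlocks = 32 and minTaskSize = 8
    // (example values): blockSize = min(1000, max(ceil_div(1000, 32), 8)) = 32
    // and the blockCount computed below is ceil_div(1000, 32) = 32.
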
    // Dispatch the parallel compute function using async recursive work
    // splitting, or by submitting compute tasks sequentially from the caller
    // thread.
    auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;

    // Create a parallel compute function that takes a block id and computes
    // the parallel operation body for a subset of the iteration space.

    // Compute the number of parallel compute blocks.
    Value blockCount = arith::CeilDivSIOp::create(b, tripCount, blockSize);

    // Dispatch the parallel compute function without hints to unroll inner
    // loops.
    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute =
          createParallelComputeFunction(op, staticBounds, 0, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
      scf::YieldOp::create(b);
    };

    // Dispatch the parallel compute function with hints for unrolling inner
    // loops.
    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute = createParallelComputeFunction(
          op, staticBounds, numUnrollableLoops, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      // Align the block size to be a multiple of the statically known
      // number of iterations in the inner loops.
      Value numIters = arith::ConstantIndexOp::create(
          b, numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value alignedBlockSize = arith::MulIOp::create(
          b, arith::CeilDivSIOp::create(b, blockSize, numIters), numIters);
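
      // E.g., with numIters = 50 and blockSize = 120 (example values):
      // alignedBlockSize = ceil_div(120, 50) * 50 = 150.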
      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
                 tripCounts);
      scf::YieldOp::create(b);
    };

    // Dispatch to the block-aligned compute function only if the computed
    // block size is larger than the number of iterations in the unrollable
    // inner loops, because otherwise it can reduce the available parallelism.
    if (numUnrollableLoops > 0) {
      Value numIters = arith::ConstantIndexOp::create(
          b, numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value useBlockAlignedComputeFn = arith::CmpIOp::create(
          b, arith::CmpIPredicate::sge, blockSize, numIters);

      scf::IfOp::create(b, useBlockAlignedComputeFn, dispatchBlockAligned,
                        dispatchDefault);
      scf::YieldOp::create(b);
    } else {
      dispatchDefault(b, loc);
    }
  };

  // Replace the `scf.parallel` operation with the parallel compute function.
  scf::IfOp::create(b, isZeroIterations, noOp, dispatch);

  // The parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  populateAsyncParallelForPatterns(
      patterns, asyncDispatch, numWorkerThreads,
      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
        return arith::ConstantIndexOp::create(builder, minTaskSize);
      });
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        computeMinTaskSize);
}