//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_ASYNCPARALLELFORPASS
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
//   }
//
// And a dispatch function depending on the `asyncDispatch` option.
//
// When async dispatch is on: (pseudocode)
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for the [0, block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//     call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
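// For reference, this lowering is typically driven from the command line,
// e.g. (pass and option names assumed from the Async pass definitions, shown
// only as an illustration):
//
//   mlir-opt -async-parallel-for="async-dispatch=true num-workers=8" in.mlir
//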
struct AsyncParallelForPass
    : public impl::AsyncParallelForPassBase<AsyncParallelForPass> {
  using Base::Base;

  void runOnOperation() override;
};

struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(
      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  bool asyncDispatch;
  int32_t numWorkerThreads;
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};

struct ParallelComputeFunctionType {
  FunctionType type;
  SmallVector<Value> captures;
};

// Helper struct to parse parallel compute function argument list.
struct ParallelComputeFunctionArgs {
  BlockArgument blockIndex();
  BlockArgument blockSize();
  ArrayRef<BlockArgument> tripCounts();
  ArrayRef<BlockArgument> lowerBounds();
  ArrayRef<BlockArgument> steps();
  ArrayRef<BlockArgument> captures();

  unsigned numLoops;
  ArrayRef<BlockArgument> args;
};

struct ParallelComputeFunctionBounds {
  SmallVector<IntegerAttr> tripCounts;
  SmallVector<IntegerAttr> lowerBounds;
  SmallVector<IntegerAttr> upperBounds;
  SmallVector<IntegerAttr> steps;
};

struct ParallelComputeFunction {
  unsigned numLoops;
  func::FuncOp func;
  SmallVector<Value> captures;
};

} // namespace

BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
  return args.drop_front(2).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
  return args.drop_front(2 + 4 * numLoops);
}
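
// For illustration, the flat argument list parsed by the accessors above is
// laid out as follows (all slots are `index` arguments except the trailing
// captures; the offsets explain the `2 + k * numLoops` arithmetic, and the
// upper bounds occupy a slot but have no dedicated accessor):
//
//   args[0]                                : blockIndex
//   args[1]                                : blockSize
//   args[2 .. 2 + 1*numLoops)              : tripCounts
//   args[2 + 1*numLoops .. 2 + 2*numLoops) : lowerBounds
//   args[2 + 2*numLoops .. 2 + 3*numLoops) : upperBounds (unused)
//   args[2 + 3*numLoops .. 2 + 4*numLoops) : steps
//   args[2 + 4*numLoops .. )               : implicit captures
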
template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
  SmallVector<IntegerAttr> attrs(values.size());
  for (unsigned i = 0; i < values.size(); ++i)
    matchPattern(values[i], m_Constant(&attrs[i]));
  return attrs;
}

// Converts a one-dimensional iteration index in the [0, tripCount) interval
// into a multi-dimensional iteration coordinate.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {
  SmallVector<Value> coords(tripCounts.size());
  assert(!tripCounts.empty() && "tripCounts must be not empty");

  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = arith::RemSIOp::create(b, index, tripCounts[i]);
    index = arith::DivSIOp::create(b, index, tripCounts[i]);
  }

  return coords;
}
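
// As a worked example of the delinearization above: with tripCounts = [4, 5]
// (20 iterations total) the linear index 13 maps to the coordinate [2, 3],
// since 13 % 5 == 3 and 13 / 5 == 2.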

// Returns a function type and implicit captures for a parallel compute
// function. We'll need a list of implicit captures to set up block and value
// mapping when we clone the body of the parallel operation.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.getRegion(), op.getRegion(), captures);

  SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step. Lower bound, upper
  // bound and step passed as contiguous arguments:
  //   call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicit captures.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}

// Create a parallel compute function from the parallel operation.
static ParallelComputeFunction createParallelComputeFunction(
    scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ModuleOp module = op->getParentOfType<ModuleOp>();

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  func::FuncOp func = func::FuncOp::create(
      op.getLoc(),
      numBlockAlignedInnerLoops > 0 ? "parallel_compute_fn_with_aligned_loops"
                                    : "parallel_compute_fn",
      type);
  func.setPrivate();

  // Insert function into the module symbol table and assign it a unique name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create function entry block.
  Block *block =
      b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                    SmallVector<Location>(type.getNumInputs(), op.getLoc()));
  b.setInsertionPointToEnd(block);

  ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};

  // Block iteration position defined by the block index and size.
  BlockArgument blockIndex = args.blockIndex();
  BlockArgument blockSize = args.blockSize();

  // Constants used below.
  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Materialize known constants as constant operations in the function body.
  auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
    return llvm::to_vector(
        llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
          if (IntegerAttr attr = std::get<1>(tuple))
            return arith::ConstantOp::create(b, attr);
          return std::get<0>(tuple);
        }));
  };

  // Multi-dimensional parallel iteration space defined by the loop trip counts.
  auto tripCounts = values(args.tripCounts(), bounds.tripCounts);

  // Parallel operation lower bound and step.
  auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
  auto steps = values(args.steps(), bounds.steps);

  // Remaining arguments are implicit captures of the parallel operation.
  ArrayRef<BlockArgument> captures = args.captures();

  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);

  // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = arith::MulIOp::create(b, blockIndex, blockSize);

  // The last one-dimensional index in the block defined by the `blockIndex`:
  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
  Value blockEnd0 = arith::AddIOp::create(b, blockFirstIndex, blockSize);
  Value blockEnd1 = arith::MinSIOp::create(b, blockEnd0, tripCount);
  Value blockLastIndex = arith::SubIOp::create(b, blockEnd1, c1);

  // Convert one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute loops upper bounds derived from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  //
  // Block first and last coordinates can be the same along the outer compute
  // dimension when the inner compute dimension contains multiple blocks.
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = arith::AddIOp::create(b, blockLastCoord[i], c1);

  // Construct a loop nest out of scf.for operations that will iterate over
  // all coordinates in the [blockFirstCoord, blockLastCoord] range.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Parallel region induction variables computed from the multi-dimensional
  // iteration coordinate using parallel operation bounds and step:
  //
  //   computeBlockInductionVars[loopIdx] =
  //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());

  // We need to know if we are in the first or last iteration of the
  // multi-dimensional loop for each loop in the nest, so we can decide which
  // loop bounds we should use for the nested loops: bounds defined by the
  // compute block interval, or bounds defined by the parallel operation.
  //
  // Example: 2d parallel operation
  //                  i   j
  //   loop sizes:  [50, 50]
  //   first coord: [25, 25]
  //   last coord:  [30, 30]
  //
  // If `i` is equal to 25 then iteration over `j` should start at 25, when `i`
  // is between 25 and 30 it should start at 0. The upper bound for `j` should
  // be 50, except when `i` is equal to 30, then it should also be 30.
  //
  // Value at ith position specifies if all loops in [0, i) range of the loop
  // nest are in the first/last iteration.
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Builds the inner loop nest inside the async.execute operation that does
  // all the work concurrently.
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder b(loc, nestedBuilder);

      // Compute induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] =
          arith::AddIOp::create(b, lowerBounds[loopIdx],
                                arith::MulIOp::create(b, iv, steps[loopIdx]));

      // Check if we are inside the first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] = arith::CmpIOp::create(
          b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] = arith::CmpIOp::create(
          b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Check if the previous loop is in its first or last iteration.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = arith::AndIOp::create(
            b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = arith::AndIOp::create(
            b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building the loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
          // For block-aligned loops we always iterate starting from 0 up to
          // the loop trip counts.
          scf::ForOp::create(b, c0, tripCounts[loopIdx + 1], c1, ValueRange(),
                             workLoopBuilder(loopIdx + 1));

        } else {
          // Select nested loop lower/upper bounds depending on our position in
          // the multi-dimensional iteration space.
          auto lb = arith::SelectOp::create(b, isBlockFirstCoord[loopIdx],
                                            blockFirstCoord[loopIdx + 1], c0);

          auto ub = arith::SelectOp::create(b, isBlockLastCoord[loopIdx],
                                            blockEndCoord[loopIdx + 1],
                                            tripCounts[loopIdx + 1]);

          scf::ForOp::create(b, lb, ub, c1, ValueRange(),
                             workLoopBuilder(loopIdx + 1));
        }

        scf::YieldOp::create(b, loc);
        return;
      }

      // Copy the body of the parallel op into the inner-most loop.
      IRMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);

      for (auto &bodyOp : op.getRegion().front().without_terminator())
        b.clone(bodyOp, mapping);
      scf::YieldOp::create(b, loc);
    };
  };

  scf::ForOp::create(b, blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                     workLoopBuilder(0));
  func::ReturnOp::create(b, ValueRange());

  return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}
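
// As a worked example of the block decomposition above (illustrative): for
// tripCounts = [8, 10] and blockSize = 25, block index 0 covers the linear
// range [0, 25): blockFirstCoord = [0, 0], blockLastCoord = [2, 4], and
// blockEndCoord = [3, 5]. The generated nest runs i in [0, 3); for i < 2 the
// inner loop covers the full j range [0, 10), and for the last coordinate
// i == 2 it stops at j == 5, visiting exactly 25 iterations.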

// Creates a recursive async dispatch function for the given parallel compute
// function. The dispatch function keeps splitting the block range into halves
// until it reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
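// For example (illustrative): dispatching the block range [0, 8) spawns an
// async dispatch of [4, 8), then of [2, 4), then of [1, 2), and finally runs
// block 0 inline; each spawned dispatch recursively splits its own range the
// same way, so all 8 blocks eventually execute.
//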
static func::FuncOp
createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                            PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.getFunctionType().getInputs();

  // Compared to the parallel compute function, the async dispatch function
  // takes an additional !async.group argument. Also, instead of a single
  // `blockIndex` it takes `blockStart` and `blockEnd` arguments to define the
  // range of dispatched blocks (the compute function's `blockIndex` argument
  // slot is reused as `blockEnd`).
  SmallVector<Type> inputTypes;
  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
  inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  func::FuncOp func = func::FuncOp::create(loc, "async_dispatch_fn", type);
  func.setPrivate();

  // Insert function into the module symbol table and assign it a unique name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                               SmallVector<Location>(type.getNumInputs(), loc));
  b.setInsertionPointToEnd(block);

  Type indexTy = b.getIndexType();
  Value c1 = arith::ConstantIndexOp::create(b, 1);
  Value c2 = arith::ConstantIndexOp::create(b, 2);

  // Get the async group that will track async dispatch completion.
  Value group = block->getArgument(0);

  // Get the block iteration range: [blockStart, blockEnd)
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work-splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};
  SmallVector<Location> locations = {loc, loc};

  // Create a recursive dispatch loop.
  scf::WhileOp whileOp = scf::WhileOp::create(b, types, operands);
  Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations);
  Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations);

  // Set up the dispatch loop condition block: decide if we need to go into the
  // `after` block and launch one more async dispatch.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = arith::SubIOp::create(b, end, start);
    Value dispatch =
        arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, distance, c1);
    scf::ConditionOp::create(b, dispatch, before->getArguments());
  }

  // Set up the async dispatch loop body: recursively call the dispatch
  // function for the second half of the original range and go to the next
  // iteration.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = arith::SubIOp::create(b, end, start);
    Value halfDistance = arith::DivSIOp::create(b, distance, c2);
    Value midIndex = arith::AddIOp::create(b, start, halfDistance);

    // Call parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original `blockStart` and `blockEnd` with the new range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex;
      operands[2] = end;

      func::CallOp::create(executeBuilder, executeLoc, func.getSymName(),
                           func.getResultTypes(), operands);
      async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
    };

    // Create async.execute operation to dispatch half of the block range.
    auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
                                     executeBodyBuilder);
    AddToGroupOp::create(b, indexTy, execute.getToken(), group);
    scf::YieldOp::create(b, ValueRange({start, midIndex}));
  }

  // After dispatching async operations to process the tail of the block range
  // call the parallel compute function for the first block of the range.
  b.setInsertionPointAfter(whileOp);

  // Drop async dispatch specific arguments: async group, block start and end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  func::CallOp::create(b, computeFunc.func.getSymName(),
                       computeFunc.func.getResultTypes(), computeFuncOperands);
  func::ReturnOp::create(b, ValueRange());

  return func;
}

// Launch async dispatch of the parallel compute function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  func::FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Appends operands shared by the async dispatch and parallel compute
  // functions to the given operands vector.
  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
    operands.append(tripCounts);
    operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
    operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
    operands.append(op.getStep().begin(), op.getStep().end());
    operands.append(parallelComputeFunction.captures);
  };

  // Check if the block count is one; in this case we can skip the async
  // dispatch completely. If this is known statically, then canonicalization
  // will erase the async group operations.
  Value isSingleBlock =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, blockCount, c1);

  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Call parallel compute function for the single block.
    SmallVector<Value> operands = {c0, blockSize};
    appendBlockComputeOperands(operands);

    func::CallOp::create(b, parallelComputeFunction.func.getSymName(),
                         parallelComputeFunction.func.getResultTypes(),
                         operands);
    scf::YieldOp::create(b);
  };

  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Create an async.group to wait on all async tokens from the concurrent
    // execution of multiple parallel compute functions. The first block will
    // be executed synchronously in the caller thread.
    Value groupSize = arith::SubIOp::create(b, blockCount, c1);
    Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);

    // Launch the async dispatch function for the [0, blockCount) range.
    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
    appendBlockComputeOperands(operands);

    func::CallOp::create(b, asyncDispatchFunction.getSymName(),
                         asyncDispatchFunction.getResultTypes(), operands);

    // Wait for the completion of all parallel compute operations.
    AwaitAllOp::create(b, group);

    scf::YieldOp::create(b);
  };

  // Dispatch either single block compute function, or launch async dispatch.
  scf::IfOp::create(b, isSingleBlock, syncDispatch, asyncDispatch);
}

// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  func::FuncOp compute = parallelComputeFunction.func;

  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = arith::SubIOp::create(b, blockCount, c1);
  Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);

  // Call parallel compute function for all blocks.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns parallel compute function operands to process the given block.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.getLowerBound().begin(),
                               op.getLowerBound().end());
    computeFuncOperands.append(op.getUpperBound().begin(),
                               op.getUpperBound().end());
    computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // Induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder b(loc, loopBuilder);

    // Call parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(),
                           compute.getResultTypes(), computeFuncOperands(iv));
      async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
    };

    // Create async.execute operation to launch the parallel compute function.
    auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
                                     executeBodyBuilder);
    AddToGroupOp::create(b, rewriter.getIndexType(), execute.getToken(), group);
    scf::YieldOp::create(b);
  };

  // Iterate over all compute blocks and launch parallel compute operations.
  scf::ForOp::create(b, c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call parallel compute function for the first block in the caller thread.
  func::CallOp::create(b, compute.getSymName(), compute.getResultTypes(),
                       computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  AwaitAllOp::create(b, group);
}
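
// For example (illustrative): with blockCount = 4 the scf.for above launches
// blocks 1, 2 and 3 inside async.execute operations, block 0 runs inline in
// the caller thread, and await_all blocks until the group of the three async
// tokens completes.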

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support the rewrite for parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Computing minTaskSize emits IR and can be implemented as executing a cost
  // model on the body of the scf.parallel. Thus it needs to be computed before
  // the body of the scf.parallel has been manipulated.
  Value minTaskSize = computeMinTaskSize(b, op);

  // Make sure that all constants will be inside the parallel operation body to
  // reduce the number of parallel compute function arguments.
  cloneConstantsIntoTheRegion(op.getRegion(), rewriter);

  // Compute trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step);
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.getLowerBound()[i];
    auto ub = op.getUpperBound()[i];
    auto step = op.getStep()[i];
    auto range = b.createOrFold<arith::SubIOp>(ub, lb);
    tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
  }
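  // For example (illustrative): lb = 0, ub = 10, step = 3 gives
  // tripCount = ceil_div(10 - 0, 3) = 4, covering i = 0, 3, 6, 9.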

  // Compute a product of trip counts to get the 1-dimensional iteration space
  // for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);

  // Short-circuit no-op parallel loops (zero iterations) that can arise from
  // memrefs with dynamic dimension(s) equal to zero.
  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value isZeroIterations =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, tripCount, c0);

  // Do absolutely nothing if the trip count is zero.
  auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
    scf::YieldOp::create(nestedBuilder, loc);
  };

  // Compute the parallel block size and dispatch concurrent tasks computing
  // results for each block.
  auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Collect statically known constants defining the loop nest in the
    // parallel compute function. LLVM can't always push constants across the
    // non-trivial async dispatch call graph; by providing these values
    // explicitly we can choose to build a more efficient loop nest, and rely
    // on better constant folding, loop unrolling and vectorization.
    ParallelComputeFunctionBounds staticBounds = {
        integerConstants(tripCounts),
        integerConstants(op.getLowerBound()),
        integerConstants(op.getUpperBound()),
        integerConstants(op.getStep()),
    };

    // Find how many inner iteration dimensions are statically known and have a
    // product smaller than `512`. We align the parallel compute block size by
    // the product of the statically known dimensions, so that we can guarantee
    // that the inner loops execute from 0 to the loop trip counts, elide
    // dynamic loop boundaries, and give LLVM an opportunity to unroll the
    // loops. The constant `512` is arbitrary; ideally it should depend on how
    // many iterations LLVM will typically decide to unroll.
    static constexpr int64_t maxUnrollableIterations = 512;

    // The number of inner loops with a statically known number of iterations
    // less than the `maxUnrollableIterations` value.
    int numUnrollableLoops = 0;

    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };

    SmallVector<int64_t> numIterations(op.getNumLoops());
    numIterations.back() = getInt(staticBounds.tripCounts.back());

    for (int i = op.getNumLoops() - 2; i >= 0; --i) {
      int64_t tripCount = getInt(staticBounds.tripCounts[i]);
      int64_t innerIterations = numIterations[i + 1];
      numIterations[i] = tripCount * innerIterations;

      // Update the number of inner loops that we can potentially unroll.
      if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
        numUnrollableLoops++;
    }

    Value numWorkerThreadsVal;
    if (numWorkerThreads >= 0)
      numWorkerThreadsVal = arith::ConstantIndexOp::create(b, numWorkerThreads);
    else
      numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(b);

    // With a large number of threads the value of creating many compute blocks
    // is reduced because the problem typically becomes memory bound. For this
    // reason we scale the number of workers using an equivalent to the
    // following logic:
    //   float overshardingFactor = numWorkerThreads <= 4    ? 8.0
    //                              : numWorkerThreads <= 8  ? 4.0
    //                              : numWorkerThreads <= 16 ? 2.0
    //                              : numWorkerThreads <= 32 ? 1.0
    //                              : numWorkerThreads <= 64 ? 0.8
    //                                                       : 0.6;

    // Pairs of the non-inclusive lower end of a bracket and the factor by
    // which the number of workers is scaled if it falls in that bracket.
    const SmallVector<std::pair<int, float>> overshardingBrackets = {
        {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
    const float initialOvershardingFactor = 8.0f;

    Value scalingFactor = arith::ConstantFloatOp::create(
        b, b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
    for (const std::pair<int, float> &p : overshardingBrackets) {
      Value bracketBegin = arith::ConstantIndexOp::create(b, p.first);
      Value inBracket = arith::CmpIOp::create(
          b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
      Value bracketScalingFactor = arith::ConstantFloatOp::create(
          b, b.getF32Type(), llvm::APFloat(p.second));
      scalingFactor = arith::SelectOp::create(
          b, inBracket, bracketScalingFactor, scalingFactor);
    }
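
    // As a worked example of the bracket selection above: with 8 worker
    // threads the comparison 8 > 4 selects 4.0 and 8 > 8 fails, so the final
    // scalingFactor is 4.0, matching the `numWorkerThreads <= 8 ? 4.0` branch
    // of the pseudocode, and maxComputeBlocks below becomes 8 * 4.0 = 32.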
    Value numWorkersIndex =
        arith::IndexCastOp::create(b, b.getI32Type(), numWorkerThreadsVal);
    Value numWorkersFloat =
        arith::SIToFPOp::create(b, b.getF32Type(), numWorkersIndex);
    Value scaledNumWorkers =
        arith::MulFOp::create(b, scalingFactor, numWorkersFloat);
    Value scaledNumInt =
        arith::FPToSIOp::create(b, b.getI32Type(), scaledNumWorkers);
    Value scaledWorkers =
        arith::IndexCastOp::create(b, b.getIndexType(), scaledNumInt);

    Value maxComputeBlocks = arith::MaxSIOp::create(
        b, arith::ConstantIndexOp::create(b, 1), scaledWorkers);

    // Compute parallel block size from the parallel problem size:
    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, maxComputeBlocks),
    //                       minTaskSize))
    Value bs0 = arith::CeilDivSIOp::create(b, tripCount, maxComputeBlocks);
    Value bs1 = arith::MaxSIOp::create(b, bs0, minTaskSize);
    Value blockSize = arith::MinSIOp::create(b, tripCount, bs1);
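
    // For example (illustrative): with tripCount = 1000, maxComputeBlocks = 32
    // and minTaskSize = 10, bs0 = ceil_div(1000, 32) = 32, bs1 = max(32, 10) =
    // 32, and blockSize = min(1000, 32) = 32, so blockCount below becomes
    // ceil_div(1000, 32) = 32 blocks of (at most) 32 iterations each.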

    // Dispatch parallel compute function using async recursive work splitting,
    // or by submitting compute tasks sequentially from the caller thread.
    auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;

    // Create a parallel compute function that takes a block id and computes
    // the parallel operation body for a subset of the iteration space.

    // Compute the number of parallel compute blocks.
    Value blockCount = arith::CeilDivSIOp::create(b, tripCount, blockSize);

    // Dispatch parallel compute function without hints to unroll inner loops.
    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute =
          createParallelComputeFunction(op, staticBounds, 0, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
      scf::YieldOp::create(b);
    };

    // Dispatch parallel compute function with hints for unrolling inner loops.
    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute = createParallelComputeFunction(
          op, staticBounds, numUnrollableLoops, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      // Align the block size to be a multiple of the statically known
      // number of iterations in the inner loops.
      Value numIters = arith::ConstantIndexOp::create(
          b, numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value alignedBlockSize = arith::MulIOp::create(
          b, arith::CeilDivSIOp::create(b, blockSize, numIters), numIters);
      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
                 tripCounts);
      scf::YieldOp::create(b);
    };

    // Dispatch to the block-aligned compute function only if the computed
    // block size is larger than the number of iterations in the unrollable
    // inner loops, because otherwise it can reduce the available parallelism.
    if (numUnrollableLoops > 0) {
      Value numIters = arith::ConstantIndexOp::create(
          b, numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value useBlockAlignedComputeFn = arith::CmpIOp::create(
          b, arith::CmpIPredicate::sge, blockSize, numIters);

      scf::IfOp::create(b, useBlockAlignedComputeFn, dispatchBlockAligned,
                        dispatchDefault);
      scf::YieldOp::create(b);
    } else {
      dispatchDefault(b, loc);
    }
  };

  // Replace the `scf.parallel` operation with the parallel compute function.
  scf::IfOp::create(b, isZeroIterations, noOp, dispatch);

  // Parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  populateAsyncParallelForPatterns(
      patterns, asyncDispatch, numWorkerThreads,
      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
        return arith::ConstantIndexOp::create(builder, minTaskSize);
      });
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        computeMinTaskSize);
}
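
// A minimal usage sketch (hypothetical pass, not part of this file) of driving
// this lowering with a custom minimum task size:
//
//   void MyPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     // Assume async tasks only pay off above 256 iterations per block, and
//     // let the runtime decide the number of workers (numWorkerThreads = -1).
//     populateAsyncParallelForPatterns(
//         patterns, /*asyncDispatch=*/true, /*numWorkerThreads=*/-1,
//         [](ImplicitLocOpBuilder b, scf::ParallelOp op) {
//           return arith::ConstantIndexOp::create(b, 256);
//         });
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }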