//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/Transforms.h"

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_ASYNCPARALLELFORPASS
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
//   }
//
// And a dispatch function depending on the `asyncDispatch` option.
//
// When async dispatch is on: (pseudocode)
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call the parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for the [0, block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//     call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
struct AsyncParallelForPass
    : public impl::AsyncParallelForPassBase<AsyncParallelForPass> {
  using Base::Base;

  void runOnOperation() override;
};

struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(
      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  bool asyncDispatch;
  int32_t numWorkerThreads;
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};

struct ParallelComputeFunctionType {
  FunctionType type;
  SmallVector<Value> captures;
};

// Helper struct to parse parallel compute function argument list.
struct ParallelComputeFunctionArgs {
  BlockArgument blockIndex();
  BlockArgument blockSize();
  ArrayRef<BlockArgument> tripCounts();
  ArrayRef<BlockArgument> lowerBounds();
  ArrayRef<BlockArgument> steps();
  ArrayRef<BlockArgument> captures();

  unsigned numLoops;
  ArrayRef<BlockArgument> args;
};

struct ParallelComputeFunctionBounds {
  SmallVector<IntegerAttr> tripCounts;
  SmallVector<IntegerAttr> lowerBounds;
  SmallVector<IntegerAttr> upperBounds;
  SmallVector<IntegerAttr> steps;
};

struct ParallelComputeFunction {
  unsigned numLoops;
  func::FuncOp func;
  SmallVector<Value> captures;
};

} // namespace

BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
  return args.drop_front(2).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
  return args.drop_front(2 + 4 * numLoops);
}
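
// To make the offsets above concrete, here is the argument layout for a
// hypothetical two-loop (numLoops = 2) parallel operation; the upper bounds
// occupy the unnamed slots at offset 2 + 2 * numLoops:
//
//   index: 0           1          2    3    4    5    6    7    8   9   10...
//   args:  blockIndex  blockSize  tc0  tc1  lb0  lb1  ub0  ub1  s0  s1  captures...
//
//   tripCounts()  = args.drop_front(2).take_front(2)     -> [tc0, tc1]
//   lowerBounds() = args.drop_front(2 + 2).take_front(2) -> [lb0, lb1]
//   steps()       = args.drop_front(2 + 6).take_front(2) -> [s0, s1]
//   captures()    = args.drop_front(2 + 8)               -> [captures...]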

template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
  SmallVector<IntegerAttr> attrs(values.size());
  for (unsigned i = 0; i < values.size(); ++i)
    matchPattern(values[i], m_Constant(&attrs[i]));
  return attrs;
}

// Converts a one-dimensional iteration index in the [0, tripCount) interval
// into a multidimensional iteration coordinate.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {
  SmallVector<Value> coords(tripCounts.size());
  assert(!tripCounts.empty() && "tripCounts must not be empty");

  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = arith::RemSIOp::create(b, index, tripCounts[i]);
    index = arith::DivSIOp::create(b, index, tripCounts[i]);
  }

  return coords;
}
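
// A worked example of the delinearization above (the numbers are illustrative,
// not from the original source): for tripCounts = [5, 10] and index = 37 the
// loop runs innermost-first:
//
//   i = 1: coords[1] = 37 % 10 = 7,  index = 37 / 10 = 3
//   i = 0: coords[0] = 3 % 5 = 3
//
// giving the coordinate (3, 7), which linearizes back to 3 * 10 + 7 = 37.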

// Returns a function type and implicit captures for a parallel compute
// function. We need the list of implicit captures to set up the block and
// value mapping when we clone the body of the parallel operation.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.getRegion(), op.getRegion(), captures);

  SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bounds, upper bounds and steps. Lower bounds,
  // upper bounds and steps are passed as contiguous argument groups:
  //   call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicit captures.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to a vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}

// Create a parallel compute function from the parallel operation.
static ParallelComputeFunction createParallelComputeFunction(
    scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ModuleOp module = op->getParentOfType<ModuleOp>();

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  func::FuncOp func = func::FuncOp::create(
      op.getLoc(),
      numBlockAlignedInnerLoops > 0 ? "parallel_compute_fn_with_aligned_loops"
                                    : "parallel_compute_fn",
      type);
  func.setPrivate();

  // Insert the function into the module symbol table and assign it a unique
  // name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create the function entry block.
  Block *block =
      b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                    SmallVector<Location>(type.getNumInputs(), op.getLoc()));
  b.setInsertionPointToEnd(block);

  ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};

  // Block iteration position defined by the block index and size.
  BlockArgument blockIndex = args.blockIndex();
  BlockArgument blockSize = args.blockSize();

  // Constants used below.
  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Materialize known constants as constant operations in the function body.
  auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
    return llvm::map_to_vector(llvm::zip(args, attrs),
                               [&](auto tuple) -> Value {
                                 if (IntegerAttr attr = std::get<1>(tuple))
                                   return arith::ConstantOp::create(b, attr);
                                 return std::get<0>(tuple);
                               });
  };

  // Multi-dimensional parallel iteration space defined by the loop trip counts.
  auto tripCounts = values(args.tripCounts(), bounds.tripCounts);

  // Parallel operation lower bounds and steps.
  auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
  auto steps = values(args.steps(), bounds.steps);

  // Remaining arguments are implicit captures of the parallel operation.
  ArrayRef<BlockArgument> captures = args.captures();

  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);

  // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = arith::MulIOp::create(b, blockIndex, blockSize);

  // The last one-dimensional index in the block defined by the `blockIndex`:
  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
  Value blockEnd0 = arith::AddIOp::create(b, blockFirstIndex, blockSize);
  Value blockEnd1 = arith::MinSIOp::create(b, blockEnd0, tripCount);
  Value blockLastIndex = arith::SubIOp::create(b, blockEnd1, c1);

  // Convert one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute loop upper bounds derived from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  //
  // Block first and last coordinates can be the same along the outer compute
  // dimension when the inner compute dimension contains multiple blocks.
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = arith::AddIOp::create(b, blockLastCoord[i], c1);
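
  // A concrete illustration (the numbers are hypothetical, not from the
  // source): with tripCounts = [8, 10], blockFirstIndex = 23 and
  // blockSize = 15 we get blockLastIndex = min(23 + 15, 80) - 1 = 37, so
  //
  //   blockFirstCoord = delinearize(23) = (2, 3)
  //   blockLastCoord  = delinearize(37) = (3, 7)
  //   blockEndCoord   = (4, 8)
  //
  // i.e. the block spans parts of two outer iterations, which is why the loop
  // nest below selects the inner bounds based on the outer loop position.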

  // Construct a loop nest out of scf.for operations that will iterate over
  // all coordinates in the [blockFirstCoord, blockLastCoord] range.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Parallel region induction variables computed from the multi-dimensional
  // iteration coordinate using parallel operation bounds and step:
  //
  //   computeBlockInductionVars[loopIdx] =
  //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());

  // We need to know if we are in the first or last iteration of the
  // multi-dimensional loop for each loop in the nest, so we can decide which
  // loop bounds to use for the nested loops: bounds defined by the compute
  // block interval, or bounds defined by the parallel operation.
  //
  // Example: 2d parallel operation
  //                  i   j
  //   loop sizes:  [50, 50]
  //   first coord: [25, 25]
  //   last coord:  [30, 30]
  //
  // If `i` is equal to 25 then iteration over `j` should start at 25; when `i`
  // is between 25 and 30 it should start at 0. The upper bound for `j` should
  // be 50, except when `i` is equal to 30, when it should also be 30.
  //
  // The value at the ith position specifies if all loops in the [0, i) range
  // of the loop nest are in their first/last iteration.
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Builds the inner loop nest inside the async.execute operation that does
  // all the work concurrently.
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder b(loc, nestedBuilder);

      // Compute the induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] =
          arith::AddIOp::create(b, lowerBounds[loopIdx],
                                arith::MulIOp::create(b, iv, steps[loopIdx]));

      // Check if we are inside the first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] = arith::CmpIOp::create(
          b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] = arith::CmpIOp::create(
          b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Check if the previous loop is in its first or last iteration.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = arith::AndIOp::create(
            b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = arith::AndIOp::create(
            b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building the loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
          // For block-aligned loops we always iterate starting from 0 up to
          // the loop trip counts.
          scf::ForOp::create(b, c0, tripCounts[loopIdx + 1], c1, ValueRange(),
                             workLoopBuilder(loopIdx + 1));

        } else {
          // Select nested loop lower/upper bounds depending on our position in
          // the multi-dimensional iteration space.
          auto lb = arith::SelectOp::create(b, isBlockFirstCoord[loopIdx],
                                            blockFirstCoord[loopIdx + 1], c0);

          auto ub = arith::SelectOp::create(b, isBlockLastCoord[loopIdx],
                                            blockEndCoord[loopIdx + 1],
                                            tripCounts[loopIdx + 1]);

          scf::ForOp::create(b, lb, ub, c1, ValueRange(),
                             workLoopBuilder(loopIdx + 1));
        }

        scf::YieldOp::create(b, loc);
        return;
      }

      // Copy the body of the parallel op into the innermost loop.
      IRMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);

      for (auto &bodyOp : op.getRegion().front().without_terminator())
        b.clone(bodyOp, mapping);
      scf::YieldOp::create(b, loc);
    };
  };

  scf::ForOp::create(b, blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                     workLoopBuilder(0));
  func::ReturnOp::create(b, ValueRange());

  return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}

// Creates a recursive async dispatch function for the given parallel compute
// function. The dispatch function keeps splitting the block range into halves
// until it reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call the parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
static func::FuncOp
createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                            PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.getFunctionType().getInputs();

  // Compared to the parallel compute function, the async dispatch function
  // takes an additional !async.group argument. Also, instead of a single
  // `blockIndex` it takes `blockStart` and `blockEnd` arguments that define
  // the range of dispatched blocks.
  SmallVector<Type> inputTypes;
  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
  inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  func::FuncOp func = func::FuncOp::create(loc, "async_dispatch_fn", type);
  func.setPrivate();

  // Insert the function into the module symbol table and assign it a unique
  // name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create the function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                               SmallVector<Location>(type.getNumInputs(), loc));
  b.setInsertionPointToEnd(block);

  Type indexTy = b.getIndexType();
  Value c1 = arith::ConstantIndexOp::create(b, 1);
  Value c2 = arith::ConstantIndexOp::create(b, 2);

  // Get the async group that will track async dispatch completion.
  Value group = block->getArgument(0);

  // Get the block iteration range: [blockStart, blockEnd)
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work-splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};
  SmallVector<Location> locations = {loc, loc};

  // Create a recursive dispatch loop.
  scf::WhileOp whileOp = scf::WhileOp::create(b, types, operands);
  Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations);
  Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations);

  // Set up the dispatch loop condition block: decide if we need to go into the
  // `after` block and launch one more async dispatch.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = arith::SubIOp::create(b, end, start);
    Value dispatch =
        arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, distance, c1);
    scf::ConditionOp::create(b, dispatch, before->getArguments());
  }

  // Set up the async dispatch loop body: recursively call the dispatch
  // function for the second half of the original range and go to the next
  // iteration.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = arith::SubIOp::create(b, end, start);
    Value halfDistance = arith::DivSIOp::create(b, distance, c2);
    Value midIndex = arith::AddIOp::create(b, start, halfDistance);

    // Call the parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original `blockStart` and `blockEnd` with the new range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex;
      operands[2] = end;

      func::CallOp::create(executeBuilder, executeLoc, func.getSymName(),
                           func.getResultTypes(), operands);
      async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
    };

    // Create an async.execute operation to dispatch half of the block range.
    auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
                                     executeBodyBuilder);
    AddToGroupOp::create(b, indexTy, execute.getToken(), group);
    scf::YieldOp::create(b, ValueRange({start, midIndex}));
  }

  // After dispatching async operations to process the tail of the block range,
  // call the parallel compute function for the first block of the range.
  b.setInsertionPointAfter(whileOp);

  // Drop async dispatch specific arguments: async group, block start and end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  func::CallOp::create(b, computeFunc.func.getSymName(),
                       computeFunc.func.getResultTypes(), computeFuncOperands);
  func::ReturnOp::create(b, ValueRange());

  return func;
}
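
// To see the work-splitting shape, consider dispatching blocks [0, 8) (a
// hypothetical range, not from the source): the caller spawns async tasks for
// [4, 8), then [2, 4), then [1, 2), and finally computes block 0 inline; each
// spawned task splits its own subrange the same way, so every block is reached
// through a dispatch tree of logarithmic depth instead of a serial loop.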

// Launch async dispatch of the parallel compute function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  func::FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Appends operands shared by the async dispatch and parallel compute
  // functions to the given operands vector.
  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
    operands.append(tripCounts);
    operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
    operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
    operands.append(op.getStep().begin(), op.getStep().end());
    operands.append(parallelComputeFunction.captures);
  };

  // Check if the block count is one; in this case we can skip async dispatch
  // completely. If this is known statically, canonicalization will erase the
  // async group operations.
  Value isSingleBlock =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, blockCount, c1);

  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Call the parallel compute function for the single block.
    SmallVector<Value> operands = {c0, blockSize};
    appendBlockComputeOperands(operands);

    func::CallOp::create(b, parallelComputeFunction.func.getSymName(),
                         parallelComputeFunction.func.getResultTypes(),
                         operands);
    scf::YieldOp::create(b);
  };

  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Create an async.group to wait on all async tokens from the concurrent
    // execution of multiple parallel compute functions. The first block will
    // be executed synchronously in the caller thread.
    Value groupSize = arith::SubIOp::create(b, blockCount, c1);
    Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);

    // Launch the async dispatch function for the [0, blockCount) range.
    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
    appendBlockComputeOperands(operands);

    func::CallOp::create(b, asyncDispatchFunction.getSymName(),
                         asyncDispatchFunction.getResultTypes(), operands);

    // Wait for the completion of all parallel compute operations.
    AwaitAllOp::create(b, group);

    scf::YieldOp::create(b);
  };

  // Either dispatch the single-block compute function, or launch async
  // dispatch.
  scf::IfOp::create(b, isSingleBlock, syncDispatch, asyncDispatch);
}

// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  func::FuncOp compute = parallelComputeFunction.func;

  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = arith::SubIOp::create(b, blockCount, c1);
  Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);

  // Call the parallel compute function for all blocks.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns parallel compute function operands to process the given block.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.getLowerBound().begin(),
                               op.getLowerBound().end());
    computeFuncOperands.append(op.getUpperBound().begin(),
                               op.getUpperBound().end());
    computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // The induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder b(loc, loopBuilder);

    // Call the parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(),
                           compute.getResultTypes(), computeFuncOperands(iv));
      async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
    };

    // Create an async.execute operation to launch the parallel compute
    // function.
    auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
                                     executeBodyBuilder);
    AddToGroupOp::create(b, rewriter.getIndexType(), execute.getToken(), group);
    scf::YieldOp::create(b);
  };

  // Iterate over all compute blocks and launch parallel compute operations.
  scf::ForOp::create(b, c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call the parallel compute function for the first block in the caller
  // thread.
  func::CallOp::create(b, compute.getSymName(), compute.getResultTypes(),
                       computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  AwaitAllOp::create(b, group);
}

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support rewriting parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Computing minTaskSize emits IR and can be implemented as executing a cost
  // model on the body of the scf.parallel. Thus it needs to be computed before
  // the body of the scf.parallel has been manipulated.
  Value minTaskSize = computeMinTaskSize(b, op);

  // Make sure that all constants will be inside the parallel operation body to
  // reduce the number of parallel compute function arguments.
  cloneConstantsIntoTheRegion(op.getRegion(), rewriter);

  // Compute the trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step);
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.getLowerBound()[i];
    auto ub = op.getUpperBound()[i];
    auto step = op.getStep()[i];
    auto range = b.createOrFold<arith::SubIOp>(ub, lb);
    tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
  }

  // Compute a product of trip counts to get the 1-dimensional iteration space
  // for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);

  // Short-circuit no-op parallel loops (zero iterations) that can arise from
  // memrefs with dynamic dimension(s) equal to zero.
  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value isZeroIterations =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, tripCount, c0);

  // Do absolutely nothing if the trip count is zero.
  auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
    scf::YieldOp::create(nestedBuilder, loc);
  };

  // Compute the parallel block size and dispatch concurrent tasks computing
  // results for each block.
  auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Collect statically known constants defining the loop nest in the
    // parallel compute function. LLVM can't always push constants across the
    // non-trivial async dispatch call graph; by providing these values
    // explicitly we can build a more efficient loop nest and rely on better
    // constant folding, loop unrolling and vectorization.
    ParallelComputeFunctionBounds staticBounds = {
        integerConstants(tripCounts),
        integerConstants(op.getLowerBound()),
        integerConstants(op.getUpperBound()),
        integerConstants(op.getStep()),
    };

    // Find how many inner iteration dimensions are statically known and have a
    // product smaller than `512`. We align the parallel compute block size to
    // the product of the statically known dimensions, so that the inner loops
    // are guaranteed to execute from 0 up to their trip counts; this lets us
    // elide dynamic loop boundaries and gives LLVM an opportunity to unroll
    // the loops. The constant `512` is arbitrary; ideally it should depend on
    // how many iterations LLVM will typically decide to unroll.
    static constexpr int64_t maxUnrollableIterations = 512;

    // The number of inner loops whose statically known iteration count is no
    // greater than the `maxUnrollableIterations` value.
    int numUnrollableLoops = 0;

    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };

    SmallVector<int64_t> numIterations(op.getNumLoops());
    numIterations.back() = getInt(staticBounds.tripCounts.back());

    for (int i = op.getNumLoops() - 2; i >= 0; --i) {
      int64_t tripCount = getInt(staticBounds.tripCounts[i]);
      int64_t innerIterations = numIterations[i + 1];
      numIterations[i] = tripCount * innerIterations;

      // Update the number of inner loops that we can potentially unroll.
      if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
        numUnrollableLoops++;
    }
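
    // As a hypothetical example (not taken from the source): for a 3-d loop
    // with static trip counts [?, 8, 16] (outermost dynamic) the loop above
    // yields numIterations = [0, 128, 16], and both inner dimensions qualify
    // (16 <= 512 and 128 <= 512), so numUnrollableLoops = 2.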

    Value numWorkerThreadsVal;
    if (numWorkerThreads >= 0)
      numWorkerThreadsVal = arith::ConstantIndexOp::create(b, numWorkerThreads);
    else
      numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(b);

    // With a large number of threads the value of creating many compute blocks
    // is reduced because the problem typically becomes memory bound. For this
    // reason we scale the number of workers using an equivalent of the
    // following logic:
    //   float overshardingFactor = numWorkerThreads <= 4    ? 8.0
    //                              : numWorkerThreads <= 8  ? 4.0
    //                              : numWorkerThreads <= 16 ? 2.0
    //                              : numWorkerThreads <= 32 ? 1.0
    //                              : numWorkerThreads <= 64 ? 0.8
    //                                                       : 0.6;

    // Pairs of the exclusive lower end of a bracket and the factor by which
    // the number of workers is scaled if it falls into that bracket.
    const SmallVector<std::pair<int, float>> overshardingBrackets = {
        {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
    const float initialOvershardingFactor = 8.0f;

    Value scalingFactor = arith::ConstantFloatOp::create(
        b, b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
    for (const std::pair<int, float> &p : overshardingBrackets) {
      Value bracketBegin = arith::ConstantIndexOp::create(b, p.first);
      Value inBracket = arith::CmpIOp::create(
          b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
      Value bracketScalingFactor = arith::ConstantFloatOp::create(
          b, b.getF32Type(), llvm::APFloat(p.second));
      scalingFactor = arith::SelectOp::create(b, inBracket,
                                              bracketScalingFactor,
                                              scalingFactor);
    }
    Value numWorkersIndex =
        arith::IndexCastOp::create(b, b.getI32Type(), numWorkerThreadsVal);
    Value numWorkersFloat =
        arith::SIToFPOp::create(b, b.getF32Type(), numWorkersIndex);
    Value scaledNumWorkers =
        arith::MulFOp::create(b, scalingFactor, numWorkersFloat);
    Value scaledNumInt =
        arith::FPToSIOp::create(b, b.getI32Type(), scaledNumWorkers);
    Value scaledWorkers =
        arith::IndexCastOp::create(b, b.getIndexType(), scaledNumInt);

    Value maxComputeBlocks = arith::MaxSIOp::create(
        b, arith::ConstantIndexOp::create(b, 1), scaledWorkers);
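
    // For example (an illustration, not source-derived): with 16 worker
    // threads the select chain above resolves to a scaling factor of 2.0
    // (16 > 8 but not > 16), so maxComputeBlocks = max(1, 16 * 2.0) = 32.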

    // Compute parallel block size from the parallel problem size:
    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, maxComputeBlocks),
    //                       minTaskSize))
    Value bs0 = arith::CeilDivSIOp::create(b, tripCount, maxComputeBlocks);
    Value bs1 = arith::MaxSIOp::create(b, bs0, minTaskSize);
    Value blockSize = arith::MinSIOp::create(b, tripCount, bs1);
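
    // Continuing the hypothetical numbers above: with tripCount = 4096,
    // maxComputeBlocks = 32 and minTaskSize = 1 this gives
    // blockSize = min(4096, max(ceil_div(4096, 32), 1)) = 128.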

    // Dispatch the parallel compute function using async recursive work
    // splitting, or by submitting compute tasks sequentially from the caller
    // thread.
    auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;

    // Create a parallel compute function that takes a block id and computes
    // the parallel operation body for a subset of the iteration space.

    // Compute the number of parallel compute blocks.
    Value blockCount = arith::CeilDivSIOp::create(b, tripCount, blockSize);

    // Dispatch parallel compute function without hints to unroll inner loops.
    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute =
          createParallelComputeFunction(op, staticBounds, 0, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
      scf::YieldOp::create(b);
    };

    // Dispatch parallel compute function with hints for unrolling inner loops.
    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute = createParallelComputeFunction(
          op, staticBounds, numUnrollableLoops, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      // Align the block size to be a multiple of the statically known
      // number of iterations in the inner loops.
      Value numIters = arith::ConstantIndexOp::create(
          b, numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value alignedBlockSize = arith::MulIOp::create(
          b, arith::CeilDivSIOp::create(b, blockSize, numIters), numIters);
      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
                 tripCounts);
      scf::YieldOp::create(b);
    };

    // Dispatch to the block-aligned compute function only if the computed
    // block size is larger than the number of iterations in the unrollable
    // inner loops, because otherwise it can reduce the available parallelism.
    if (numUnrollableLoops > 0) {
      Value numIters = arith::ConstantIndexOp::create(
          b, numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value useBlockAlignedComputeFn = arith::CmpIOp::create(
          b, arith::CmpIPredicate::sge, blockSize, numIters);

      scf::IfOp::create(b, useBlockAlignedComputeFn, dispatchBlockAligned,
                        dispatchDefault);
      scf::YieldOp::create(b);
    } else {
      dispatchDefault(b, loc);
    }
  };

  // Replace the `scf.parallel` operation with the parallel compute function.
  scf::IfOp::create(b, isZeroIterations, noOp, dispatch);

  // The parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  populateAsyncParallelForPatterns(
      patterns, asyncDispatch, numWorkerThreads,
      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
        return arith::ConstantIndexOp::create(builder, minTaskSize);
      });
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        computeMinTaskSize);
}
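
// For reference, the transformation can be exercised from the command line
// through the registered pass options (the option names here follow the Async
// dialect pass registration; double-check them against Passes.td):
//
//   mlir-opt input.mlir \
//     -async-parallel-for="async-dispatch=true num-workers=8 min-task-size=1"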