MLIR 22.0.0git
DistributionUtils.cpp
Go to the documentation of this file.
1//===- DistributionUtils.cpp - Distribution tools for GPUOps --------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements distribution utility methods.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Value.h"
17#include "llvm/ADT/DenseMap.h"
18#include "llvm/ADT/STLExtras.h"
19
20#include <numeric>
21
22using namespace mlir;
23using namespace mlir::gpu;
24
// Rebuilds `warpOp` as a new WarpExecuteOnLane0Op whose results have the types
// `newReturnTypes`, reusing (moving, not cloning) the original body.
//
// The new op is created immediately before `warpOp` with the same lane id,
// warp size, args and block-argument types. The body region of `warpOp` is
// moved into the new op, the block the new op was created with is erased, and
// the gpu.yield terminator of the moved body is rewritten to yield exactly
// `newYieldedValues`.
//
// Note: `warpOp` itself is neither replaced nor erased here. After this call
// its body region has been moved out, so the caller must replace/erase it.
// No check is made here that `newYieldedValues` matches `newReturnTypes`;
// keeping them consistent is the caller's job.
//
// Returns the newly created op. The rewriter's insertion point is restored on
// exit by the InsertionGuard.
25WarpExecuteOnLane0Op
27 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
28 ValueRange newYieldedValues, TypeRange newReturnTypes) const {
29 // Create a new op before the existing one, with the extra operands.
30 OpBuilder::InsertionGuard g(rewriter);
31 rewriter.setInsertionPoint(warpOp);
32 auto newWarpOp = WarpExecuteOnLane0Op::create(
33 rewriter, warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(),
34 warpOp.getWarpSize(), warpOp.getArgs(),
35 warpOp.getBody()->getArgumentTypes());
36
37 Region &opBody = warpOp.getBodyRegion();
38 Region &newOpBody = newWarpOp.getBodyRegion();
// Remember the block the new op was built with, so it can be dropped once the
// old body has been moved in front of it.
39 Block &newOpFirstBlock = newOpBody.front();
40 rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
41 rewriter.eraseBlock(&newOpFirstBlock);
42 assert(newWarpOp.getWarpRegion().hasOneBlock() &&
43 "expected WarpOp with single block");
44
// The moved body's terminator still yields the old values; retarget it to the
// requested ones.
45 auto yield =
46 cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
47
48 rewriter.modifyOpInPlace(
49 yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
50 return newWarpOp;
51}
52
// Rebuilds `warpOp` with extra results appended for `newYieldedValues`, which
// are paired element-wise with `newReturnTypes`. zip_equal requires the two
// ranges to have the same length.
//
// Existing results keep their positions. A value that `warpOp` already yields
// is not yielded again; the index of its first existing occurrence is reused
// instead. For every entry of `newYieldedValues`, the result index in the new
// op is pushed onto `indices`. `indices` is appended to, not cleared.
//
// `warpOp` is then replaced (rewriter.replaceOp) by the leading
// warpOp.getNumResults() results of the new op, and the new op is returned.
//
// NOTE(review): `indexLookup` is not updated when a new value is appended. A
// value that occurs more than once in `newYieldedValues`, and is not already
// yielded, therefore gets a separate new result for each occurrence.
53WarpExecuteOnLane0Op
55 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
56 ValueRange newYieldedValues, TypeRange newReturnTypes,
58 SmallVector<Type> types(warpOp.getResultTypes().begin(),
59 warpOp.getResultTypes().end());
60 gpu::YieldOp yield = warpOp.getTerminator();
61 SmallVector<Value> yieldValues(yield.getOperands().begin(),
62 yield.getOperands().end());
63 llvm::SmallDenseMap<Value, unsigned> indexLookup;
64 // Record the value -> first index mapping for faster lookup.
65 for (auto [i, v] : llvm::enumerate(yieldValues)) {
66 if (!indexLookup.count(v))
67 indexLookup[v] = i;
68 }
69
70 for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
71 // If the value already exists in the yield, don't create a new output.
72 if (indexLookup.count(value)) {
73 indices.push_back(indexLookup[value]);
74 } else {
75 // If the value is new, add it to the yield and to the types.
76 yieldValues.push_back(value);
77 types.push_back(type);
78 indices.push_back(yieldValues.size() - 1);
79 }
80 }
81
// The original results come first in `types`/`yieldValues`. The first
// warpOp.getNumResults() results of the new op therefore correspond one-to-one
// to the results of `warpOp`.
82 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
83 rewriter, warpOp, yieldValues, types);
84 rewriter.replaceOp(warpOp,
85 newWarpOp.getResults().take_front(warpOp.getNumResults()));
86 return newWarpOp;
87}
88
90 WarpExecuteOnLane0Op warpOp,
91 llvm::function_ref<bool(Operation *)> fn) const {
// Returns the first operand of warpOp's gpu.yield terminator (in operand order)
// that meets all of these conditions:
//   (a) the yielded value is produced by an operation,
//   (b) `fn` accepts that defining operation, and
//   (c) the corresponding result of `warpOp` has at least one use.
// Returns nullptr if no operand qualifies.
92 gpu::YieldOp yield = warpOp.getTerminator();
93 for (OpOperand &yieldOperand : yield->getOpOperands()) {
94 Value yieldValues = yieldOperand.get();
// Values without a defining op (e.g. block arguments) are skipped.
95 Operation *definedOp = yieldValues.getDefiningOp();
96 if (definedOp && fn(definedOp)) {
// Skip dead results: the yield operand's index is the warp result's index.
97 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
98 return &yieldOperand;
99 }
100 }
101 return nullptr;
102}
103
105 OpBuilder &builder, Location loc, ArrayRef<int64_t> originalShape,
106 ArrayRef<int64_t> distributedShape, int64_t warpSize, Value laneId,
107 SmallVectorImpl<Value> &delinearizedIds) const {
// Delinearizes `laneId` over the per-dimension distribution factors
// sizes[d] = originalShape[d] / distributedShape[d], with the innermost
// dimension varying fastest.
//
// On success, returns true. `delinearizedIds` then holds one index value per
// dimension, or is empty when nothing is distributed.
//
// Returns false, leaving `delinearizedIds` untouched, if:
//   - some dimension does not divide evenly, or
//   - the product of the factors differs from `warpSize`.
108 // If the original shape and the distributed shape are the same, we don't
109 // distribute at all--every thread is handling the whole. For such case, we
110 // should not rely on lane IDs later. So just return an empty lane ID vector.
111 if (originalShape == distributedShape) {
112 delinearizedIds.clear();
113 return true;
114 }
115
117 for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
118 if (large % small != 0)
119 return false;
120 sizes.push_back(large / small);
121 }
122 if (llvm::product_of(sizes) != warpSize)
123 return false;
124
125 AffineExpr s0;
126 bindSymbols(builder.getContext(), s0);
127
128 int64_t usedThreads = 1;
129
130 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
131 delinearizedIds.assign(sizes.size(), zero);
132
133 for (int i = sizes.size() - 1; i >= 0; --i) {
134 usedThreads *= sizes[i];
135 if (usedThreads == warpSize) {
136 // We've used up all available threads. Don't need to perform modulo
137 // anymore. And we can stop the calculation for further dimensions.
// Outer dimensions not visited keep the zero constant. Their factors multiply
// to 1, since the product of all factors equals warpSize.
138 delinearizedIds[i] = laneId;
139 break;
140 }
// `laneId` has already been divided by the factors of all inner dimensions.
// So the index along dimension i is laneId % sizes[i], and only sizes[i] must
// be stripped before moving outward.
// Dividing by the cumulative `usedThreads` instead over-divides from the
// second inner dimension on. With sizes [2, 4, 4], lane 16 would get index 0
// instead of 1 along dimension 0.
141 delinearizedIds[i] =
142 affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
144 builder, loc, s0.floorDiv(sizes[i]), {laneId});
145 }
146 return true;
147}
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Block represents an ordered list of Operations.
Definition Block.h:33
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
BlockListType & getBlocks()
Definition Region.h:45
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Include the generated interface declarations.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return the yield operand of warpOp whose value is defined by an op that satisfies the filter lambda condition and whose corresponding warp result is not dead.