30 auto newWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
31 warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
32 warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
34 Region &opBody = warpOp.getBodyRegion();
35 Region &newOpBody = newWarpOp.getBodyRegion();
39 assert(newWarpOp.getWarpRegion().hasOneBlock() &&
40 "expected WarpOp with single block");
43 cast<gpu::YieldOp>(newOpBody.
getBlocks().begin()->getTerminator());
46 yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
56 warpOp.getResultTypes().end());
57 auto yield = cast<gpu::YieldOp>(
58 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
59 llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
60 yield.getOperands().end());
61 for (
auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
62 if (yieldValues.insert(value)) {
63 types.push_back(type);
64 indices.push_back(yieldValues.size() - 1);
67 for (
auto [idx, yieldOperand] :
69 if (yieldOperand == value) {
70 indices.push_back(idx);
76 yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
78 rewriter, warpOp, yieldValues.getArrayRef(), types);
80 newWarpOp.getResults().take_front(warpOp.getNumResults()));
85 WarpExecuteOnLane0Op warpOp,
87 auto yield = cast<gpu::YieldOp>(
88 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
89 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
90 Value yieldValues = yieldOperand.get();
92 if (definedOp && fn(definedOp)) {
93 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
107 if (originalShape == distributedShape) {
108 delinearizedIds.clear();
113 for (
auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
114 if (large % small != 0)
116 sizes.push_back(large / small);
118 if (std::accumulate(sizes.begin(), sizes.end(), 1,
119 std::multiplies<int64_t>()) != warpSize)
125 int64_t usedThreads = 1;
128 delinearizedIds.assign(sizes.size(), zero);
130 for (
int i = sizes.size() - 1; i >= 0; --i) {
131 usedThreads *= sizes[i];
132 if (usedThreads == warpSize) {
135 delinearizedIds[i] = laneId;
141 builder, loc, s0.
floorDiv(usedThreads), {laneId});
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
Block represents an ordered list of Operations.
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns an integer of index type.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by originalShape and distributedShape together.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.