17 #include "llvm/ADT/SetVector.h" 26 assert(warpOp.getBodyRegion().hasOneBlock() &&
27 "expected WarpOp with single block");
28 Block *warpOpBody = &warpOp.getBodyRegion().
front();
37 Value isLane0 = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
38 warpOp.getLaneid(), c0);
39 auto ifOp = rewriter.
create<scf::IfOp>(loc, isLane0,
41 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
47 Value val = it.value();
48 Value bbArg = warpOpBody->getArgument(it.index());
59 loc, warpOp.getLaneid(),
61 rewriter.
create<vector::StoreOp>(loc, val, buffer, storeOffset);
65 auto bbArgType = bbArg.
getType().
cast<VectorType>();
66 Value loadOp = rewriter.
create<vector::LoadOp>(loc, bbArgType, buffer, c0);
67 bbArgReplacements.push_back(loadOp);
71 if (!warpOp.getArgs().empty()) {
77 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
81 auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
82 Location yieldLoc = yieldOp.getLoc();
84 Value val = it.value();
85 Type resultType = warpOp->getResultTypes()[it.index()];
93 rewriter.
create<vector::StoreOp>(yieldLoc, val, buffer, c0);
95 rewriter.
create<memref::StoreOp>(yieldLoc, val, buffer, c0);
99 if (resultType == val.getType()) {
107 Value loadOp = rewriter.
create<memref::LoadOp>(loc, buffer, c0);
108 replacements.push_back(loadOp);
110 auto loadedVectorType = resultType.
cast<VectorType>();
111 int64_t loadSize = loadedVectorType.getShape()[0];
115 loc, warpOp.getLaneid(),
117 Value loadOp = rewriter.
create<vector::LoadOp>(loc, loadedVectorType,
119 replacements.push_back(loadOp);
124 if (!yieldOp.operands().empty()) {
132 rewriter.
create<scf::YieldOp>(yieldLoc);
135 rewriter.
replaceOp(warpOp, replacements);
147 auto newWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
148 warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
149 warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
151 Region &opBody = warpOp.getBodyRegion();
152 Region &newOpBody = newWarpOp.getBodyRegion();
156 assert(newWarpOp.getWarpRegion().hasOneBlock() &&
157 "expected WarpOp with single block");
160 cast<vector::YieldOp>(newOpBody.
getBlocks().begin()->getTerminator());
163 yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
174 warpOp.getResultTypes().end());
175 auto yield = cast<vector::YieldOp>(
176 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
177 llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
178 yield.getOperands().end());
179 for (
auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
180 if (yieldValues.insert(std::get<0>(newRet))) {
181 types.push_back(std::get<1>(newRet));
182 indices.push_back(yieldValues.size() - 1);
185 for (
auto &yieldOperand :
llvm::enumerate(yieldValues.getArrayRef())) {
186 if (yieldOperand.value() == std::get<0>(newRet)) {
187 indices.push_back(yieldOperand.index());
193 yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
195 rewriter, warpOp, yieldValues.getArrayRef(), types);
197 newWarpOp.getResults().take_front(warpOp.getNumResults()));
204 return llvm::all_of(op->
getOperands(), definedOutside) &&
212 auto yield = cast<vector::YieldOp>(
213 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
214 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
215 Value yieldValues = yieldOperand.get();
217 if (definedOp && fn(definedOp)) {
218 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
219 return &yieldOperand;
233 return rewriter.
create(res);
254 for (
unsigned i = 0, e = srcType.getRank(); i < e; i++) {
255 if (srcType.getDimSize(i) != dstType.getDimSize(i))
264 struct WarpOpToScfForPattern :
public OpRewritePattern<WarpExecuteOnLane0Op> {
284 static vector::TransferWriteOp cloneWriteOp(
RewriterBase &rewriter,
285 WarpExecuteOnLane0Op warpOp,
286 vector::TransferWriteOp writeOp,
287 VectorType targetType) {
288 assert(writeOp->getParentOp() == warpOp &&
289 "write must be nested immediately under warp");
293 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
297 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
299 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
320 struct WarpOpTransferWrite :
public OpRewritePattern<vector::TransferWriteOp> {
324 distributionMapFn(std::move(fn)) {}
329 vector::TransferWriteOp writeOp,
330 WarpExecuteOnLane0Op warpOp)
const {
331 VectorType writtenVectorType = writeOp.getVectorType();
335 if (writtenVectorType.getRank() == 0)
339 AffineMap map = distributionMapFn(writeOp);
341 return writeOp->emitError(
"multi-dim distribution not implemented yet");
345 writtenVectorType.getShape().end());
348 if (targetShape[position] % warpOp.getWarpSize() != 0)
350 targetShape[position] = targetShape[position] / warpOp.getWarpSize();
352 VectorType targetType =
353 VectorType::get(targetShape, writtenVectorType.getElementType());
357 vector::TransferWriteOp newWriteOp =
358 cloneWriteOp(rewriter, warpOp, writeOp, targetType);
362 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
367 newWriteOp.getIndices().end());
370 bindDims(newWarpOp.getContext(), d0, d1);
371 auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
374 unsigned indexPos = indexExpr.getPosition();
375 unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
379 {indices[indexPos], newWarpOp.getLaneid()});
381 newWriteOp.getIndicesMutable().assign(indices);
388 vector::TransferWriteOp writeOp,
389 WarpExecuteOnLane0Op warpOp)
const {
391 VectorType vecType = writeOp.getVectorType();
395 if (vecType.getNumElements() != 1)
399 if (llvm::all_of(warpOp.getOps(), [](
Operation &op) {
400 return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
408 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
412 auto secondWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
413 loc,
TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
414 Block &body = secondWarpOp.getBodyRegion().
front();
417 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
418 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
420 rewriter.
create<vector::YieldOp>(newWarpOp.getLoc());
424 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
427 if (writeOp.getMask())
430 auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
435 Operation *nextOp = writeOp.getOperation();
436 while ((nextOp = nextOp->getNextNode()))
440 if (!llvm::all_of(writeOp->getOperands(), [&](
Value value) {
441 return writeOp.getVector() ==
value ||
442 warpOp.isDefinedOutsideOfRegion(
value);
446 if (
succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
449 if (
succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
488 Value distributedVal = warpOp.getResult(operandIndex);
496 auto operandType = operand.get().getType().
cast<VectorType>();
498 VectorType::get(vecType.getShape(), operandType.getElementType());
500 auto operandType = operand.get().getType();
501 assert(!operandType.isa<VectorType>() &&
502 "unexpected yield of vector from op with scalar result type");
503 targetType = operandType;
505 retTypes.push_back(targetType);
506 yieldValues.push_back(operand.get());
510 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
514 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
515 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
520 rewriter, loc, elementWise, newOperands,
546 warpOp, [](
Operation *op) {
return isa<arith::ConstantOp>(op); });
556 warpOp.getResult(operandIndex).getType(), scalarAttr);
559 Value distConstant = rewriter.
create<arith::ConstantOp>(loc, newAttr);
588 warpOp, [](
Operation *op) {
return isa<vector::TransferReadOp>(op); });
593 Value distributedVal = warpOp.getResult(operandIndex);
596 read.getIndices().end());
598 AffineMap indexMap = map.compose(read.getPermutationMap());
601 for (
auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
603 bindDims(read.getContext(), d0, d1);
604 auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
607 unsigned indexPos = indexExpr.getPosition();
608 unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
610 distributedVal.
getType().
cast<VectorType>().getDimSize(vectorPos);
613 {indices[indexPos], warpOp.getLaneid()});
615 Value newRead = rewriter.
create<vector::TransferReadOp>(
616 read.getLoc(), distributedVal.
getType(), read.getSource(), indices,
617 read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
618 read.getInBoundsAttr());
632 auto yield = cast<vector::YieldOp>(
633 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
634 for (
OpResult result : warpOp.getResults()) {
635 if (result.use_empty())
637 resultTypes.push_back(result.getType());
638 yieldValues.push_back(yield.getOperand(result.getResultNumber()));
640 if (yield.getNumOperands() == yieldValues.size())
643 rewriter, warpOp, yieldValues, resultTypes);
645 for (
OpResult result : warpOp.getResults()) {
646 if (result.use_empty())
648 result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
657 struct WarpOpForwardOperand :
public OpRewritePattern<WarpExecuteOnLane0Op> {
663 auto yield = cast<vector::YieldOp>(
664 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
667 for (
OpOperand &operand : yield->getOpOperands()) {
668 Value result = warpOp.getResult(operand.getOperandNumber());
673 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
674 if (result.
getType() != operand.get().getType())
676 valForwarded = operand.get();
677 resultIndex = operand.getOperandNumber();
681 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
683 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
686 valForwarded = warpOperand;
687 resultIndex = operand.getOperandNumber();
692 warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
702 warpOp, [](
Operation *op) {
return isa<vector::BroadcastOp>(op); });
707 Location loc = broadcastOp.getLoc();
709 warpOp->getResultTypes()[operandNumber].
cast<VectorType>();
712 rewriter, warpOp, {broadcastOp.getSource()},
713 {broadcastOp.getSource().
getType()}, newRetIndices);
715 Value broadcasted = rewriter.
create<vector::BroadcastOp>(
716 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
729 warpOp, [](
Operation *op) {
return isa<vector::ExtractOp>(op); });
734 if (extractOp.getVectorType().getNumElements() != 1)
739 rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
742 Value newExtract = rewriter.
create<vector::ExtractOp>(
743 loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
785 auto yield = cast<vector::YieldOp>(
786 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
788 Operation *lastNode = yield->getPrevNode();
789 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
795 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
796 if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
798 auto forResult = yieldOperand.get().cast<
OpResult>();
799 newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
800 yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
801 resultIdx.push_back(yieldOperand.getOperandNumber());
807 auto newForOp = rewriter.
create<scf::ForOp>(
808 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
809 forOp.getStep(), newOperands);
811 auto innerWarp = rewriter.
create<WarpExecuteOnLane0Op>(
812 warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
813 warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
814 forOp.getResultTypes());
817 argMapping.push_back(newForOp.getInductionVar());
818 for (
Value args : innerWarp.getBody()->getArguments()) {
819 argMapping.push_back(args);
822 for (
Value operand : forOp.getBody()->getTerminator()->getOperands())
823 yieldOperands.push_back(operand);
824 rewriter.
eraseOp(forOp.getBody()->getTerminator());
825 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
827 rewriter.
create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
829 if (!innerWarp.getResults().empty())
830 rewriter.
create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
834 warpOp.getResult(res.value())
835 .replaceAllUsesWith(newForOp.getResult(res.index()));
836 newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
863 DistributedReductionFn distributedReductionFn,
866 distributedReductionFn(distributedReductionFn) {}
871 warpOp, [](
Operation *op) {
return isa<vector::ReductionOp>(op); });
877 auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
879 if (vectorType.getRank() != 1)
881 warpOp,
"Only rank 1 reductions can be distributed.");
883 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
885 warpOp,
"Reduction vector dimension must match was size.");
887 if (!reductionOp.getType().isF32() &&
888 !reductionOp.getType().isSignlessInteger(32))
891 "Reduction distribution currently only supports 32bits types.");
893 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
898 VectorType::get({numElements}, reductionOp.getType())};
899 if (reductionOp.getAcc()) {
900 yieldValues.push_back(reductionOp.getAcc());
901 retTypes.push_back(reductionOp.getAcc().getType());
905 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
908 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
910 Value perLaneReduction = rewriter.
create<vector::ReductionOp>(
911 reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
914 distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
915 reductionOp.getKind(), newWarpOp.getWarpSize());
916 if (reductionOp.getAcc()) {
918 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
919 newWarpOp.getResult(newRetIndices[1]));
926 DistributedReductionFn distributedReductionFn;
937 void mlir::vector::populateDistributeTransferWriteOpPatterns(
939 patterns.
add<WarpOpTransferWrite>(patterns.
getContext(), distributionMapFn);
942 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
944 patterns.
add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
945 WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
946 WarpOpScfForOp, WarpOpConstant>(patterns.
getContext());
949 void mlir::vector::populateDistributeReduction(
951 DistributedReductionFn distributedReductionFn) {
952 patterns.
add<WarpOpReduction>(patterns.
getContext(), distributedReductionFn);
955 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
956 Block *body = warpOp.getBody();
959 llvm::SmallSetVector<Operation *, 8> opsToMove;
962 auto isDefinedOutsideOfBody = [&](
Value value) {
963 auto *definingOp =
value.getDefiningOp();
964 return (definingOp && opsToMove.count(definingOp)) ||
965 warpOp.isDefinedOutsideOfRegion(
value);
971 bool hasVectorResult = llvm::any_of(op.
getResults(), [](
Value result) {
972 return result.
getType().isa<VectorType>();
974 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
975 opsToMove.insert(&op);
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options)
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
Operation is a basic unit of execution within MLIR.
static AffineMap calculateImplicitMap(Value yield, Value ret)
Currently the distribution map is implicit based on the vector shape.
BlockListType & getBlocks()
This is a value defined by a result of an operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Block represents an ordered list of Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
unsigned getNumOperands()
std::function< AffineMap(vector::TransferWriteOp)> DistributionMapFn
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ValueRange operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d)
Helper method to apply dimension ordering permutation.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
virtual void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
static constexpr const bool value
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, llvm::SmallVector< size_t > &indices)
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
MutableArrayRef< OpOperand > getOpOperands()
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool isSideEffectFree(Operation *op)
Returns true if the given operation is side-effect free.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
WarpAllocationFn warpAllocationFn
Attributes are known-constant values of operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Base type for affine expression.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class provides an abstraction over the various different ranges of value types.
unsigned getNumResults() const
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
IRValueT get() const
Return the current value being used by this operand.
This represents an operation in an abstracted form, suitable for use with the builder APIs...
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued...
This class represents an argument of a Block.
ArrayRef< AffineExpr > getResults() const
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
bool use_empty() const
Returns true if this value has no uses.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation...
static OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, std::function< bool(Operation *)> fn)
Return a value yielded by warpOp which statifies the filter lamdba condition and is not dead...
static llvm::ManagedStatic< PassManagerOptions > options
static int resultIndex(int i)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
AffineExpr getAffineConstantExpr(int64_t constant)
RAII guard to reset the insertion point of the builder when destroyed.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes)
Helper to create a new WarpExecuteOnLane0Op with different signature.
Type getType() const
Return the type of this value.
WarpSyncronizationFn warpSyncronizationFn
Specialization of arith.constant op that returns an integer of index type.
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an operand of an operation.
type_range getType() const
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
OperationName getName()
The name of an operation is the key identifier for it.
result_range getResults()
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class provides an abstraction over the different types of ranges over Values.
static LogicalResult rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, const WarpExecuteOnLane0LoweringOptions &options)
MLIRContext * getContext() const
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.