#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp();

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};
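// Both analyses above are needed: dead-store elimination queries
// post-dominance (the overwriting store must execute whenever the candidate
// dead store does), while store-to-load forwarding queries dominance (the
// forwarded value must reach the read on every path).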
71 "This function only works for ops i the same region");
73 if (dominators.dominates(start, dest))
80 while (!worklist.empty()) {
81 Block *bb = worklist.pop_back_val();
82 if (!visited.insert(bb).second)
84 if (dominators.dominates(bb, destBlock))
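// The loop above is a conservative CFG walk: it follows successor edges from
// the start block and reports a path as soon as it visits the destination
// block or a block that dominates it.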
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Accesses that are provably disjoint can never block the overwrite.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
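// A minimal sketch of the IR this catches (illustrative types and names):
// the first transfer_write below is dead because the second one writes the
// same indices with the same vector type and post-dominates it, with no
// conflicting access in between:
//
//   vector.transfer_write %v0, %m[%c0, %c0] : vector<4xf32>, memref<8x8xf32>
//   vector.transfer_write %v1, %m[%c0, %c0] : vector<4xf32>, memref<8x8xf32>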
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // A write that is provably disjoint from the read can be ignored.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }
  if (lastwrite == nullptr)
    return;
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    if (writeAncestor == nullptr ||
        !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
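// A minimal sketch of forwarding (illustrative types and names): the read
// below is replaced by %v because the write dominates it and writes the same
// location with the same vector type, and no blocking write is reachable in
// between:
//
//   vector.transfer_write %v, %m[%c0] : vector<4xf32>, memref<8xf32>
//   %r = vector.transfer_read %m[%c0], %pad : memref<8xf32>, vector<4xf32>
//
// After the rewrite, all uses of %r use %v and the dead read is erased.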
/// Returns a copy of `shape` with unit dims removed.
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> reducedShape;
  llvm::copy_if(shape, std::back_inserter(reducedShape),
                [](int64_t dimSize) { return dimSize != 1; });
  return reducedShape;
}

/// Returns a copy of the mixed sizes with static unit dims removed. Dynamic
/// sizes are kept since they cannot be proven to be unit dims statically.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}
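// For example, a static shape (1, 4, 1, 8) reduces to (4, 8); in the mixed
// overload, a dynamic size is kept as ShapedType::kDynamic since it cannot
// be proven to be a unit dim statically.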
/// Converts `inputType` into a rank-reduced memref type with the unit dims
/// dropped.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Drops unit dims from `input` via a rank-reducing memref.subview.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 Location loc, Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}
/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dims from `oldType`. Scalable unit dims ([1]) are
/// kept, since their runtime extent is not known statically.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
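// For example, vector<1x[1]x4xf32> trims to vector<[1]x4xf32>: the fixed
// unit dim is dropped, while the scalable [1] dim is kept because its
// runtime extent may differ from 1.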
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // The mask operand for a dropped unit dim must be a constant 1.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};
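// A minimal sketch of the rewrite (illustrative types): a read whose source
// has unit dims becomes a rank-reduced subview, a lower-rank read, and a
// shape_cast back to the original vector type:
//
//   %r = vector.transfer_read %m[%c0, %c0, %c0], %pad
//       : memref<1x1x8xf32>, vector<1x1x8xf32>
// becomes roughly:
//   %sv = memref.subview %m[0, 0, 0] [1, 1, 8] [1, 1, 1]
//       : memref<1x1x8xf32> to memref<8xf32>
//   %r0 = vector.transfer_read %sv[%c0], %pad : memref<8xf32>, vector<8xf32>
//   %r  = vector.shape_cast %r0 : vector<8xf32> to vector<1x1x8xf32>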
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination
    // shape. Otherwise, this case is not supported yet.
    int vectorReducedRank = getReducedRank(vectorType.getShape());
    if (reducedRank != vectorReducedRank)
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    VectorType reducedVectorType = VectorType::get(
        getReducedShape(vectorType.getShape()), vectorType.getElementType());
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);

    return success();
  }
};
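// The write side mirrors the read side: the vector is shape_cast down to the
// reduced type and written through the rank-reduced subview, e.g. a
// vector<1x1x8xf32> write to memref<1x1x8xf32> becomes a vector<8xf32> write
// to the rank-reduced memref<8xf32> subview.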
static bool hasMatchingInnerContigousShape(MemRefType memrefType,
                                           ArrayRef<int64_t> targetShape) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return false;
  if (strides.back() != 1)
    return false;
  // Check that all target dims are contiguous.
  int64_t flatDim = 1;
  for (auto [targetDim, memrefDim, memrefStride] :
       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}
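// The intent is a row-major contiguity check over the trailing target dims:
// e.g. memref<4x8xf32> (strides [8, 1]) is contiguous, while
// memref<4x8xf32, strided<[16, 1]>> is not, since each row is padded.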
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}
/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices` the
/// indices for the collapsed memref.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}
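// For example, with firstDimToCollapse = 1, indices (%i, %c0, %c0) collapse
// to (%i, %c0); the check fails if any inner index is not a constant 0.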
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // Contiguity check is valid on memrefs only.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }
};
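// A minimal sketch (illustrative types): a 2-D read of a contiguous
// row-major memref becomes a 1-D read of the collapsed memref plus a
// shape_cast:
//
//   %r = vector.transfer_read %m[%c0, %c0], %pad
//       : memref<4x8xf32>, vector<4x8xf32>
// becomes roughly:
//   %cm = memref.collapse_shape %m [[0, 1]]
//       : memref<4x8xf32> into memref<32xf32>
//   %fr = vector.transfer_read %cm[%c0], %pad : memref<32xf32>, vector<32xf32>
//   %r  = vector.shape_cast %fr : vector<32xf32> to vector<4x8xf32>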
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // Contiguity check is valid on memrefs only.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};
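// Note that, unlike the read pattern, the write pattern ends with eraseOp
// rather than replaceOp: a vector.transfer_write on a memref has no results,
// so once the flat write is created the original op is simply removed.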
/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` logic is shared by both op kinds.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check that the xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check that all the uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};
/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
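// A minimal sketch (illustrative names): extracting a single scalar from a
// transfer_read becomes a direct load, with the extract position folded into
// the load index:
//
//   %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<4xf32>
//   %s = vector.extractelement %v[%j : index] : vector<4xf32>
// becomes roughly:
//   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
//   %s = memref.load %m[%idx] : memref<?xf32>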
/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>) to
/// memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Extract the scalar to store.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support 0-D vectors, so use extractelement.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};
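// A minimal sketch (illustrative types): an all-unit-dims write becomes a
// scalar store:
//
//   vector.transfer_write %v, %m[%i, %j] : vector<1x1xf32>, memref<4x4xf32>
// becomes roughly:
//   %s = vector.extract %v[0, 0] : vector<1x1xf32>
//   memref.store %s, %m[%i, %j] : memref<4x4xf32>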
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}
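// A minimal client-side usage sketch (not part of this file): the populate
// functions above are typically driven by the greedy pattern rewriter from
// inside a pass, e.g.:
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();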