28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallBitVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "fold-memref-alias-ops"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
38 #define GEN_PASS_DEF_FOLDMEMREFALIASOPS
39 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
64 memref::ExpandShapeOp expandShapeOp,
75 MemRefType srcType = expandShapeOp.getSrcType();
77 for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
78 if (srcType.isDynamicDim(i)) {
80 rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
83 srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
88 rewriter, loc, expandShapeOp.getResultType(),
89 expandShapeOp.getReassociationIndices(), srcShape);
90 if (!outputShape.has_value())
96 assert(!groups.empty() && "association indices groups cannot be empty");
99 int64_t groupSize = groups.size();
104 for (int64_t i = 0; i < groupSize; ++i) {
105 sizesVal[i] = (*outputShape)[groups[i]];
115 bindDimsList<AffineExpr>(ctx, dims);
116 bindSymbolsList<AffineExpr>(ctx, symbols);
126 for (int64_t i = 0; i < groupSize; i++)
127 dynamicIndices[i] = indices[groups[i]];
132 llvm::append_range(mapOperands, suffixProduct);
133 llvm::append_range(mapOperands, dynamicIndices);
141 groupSize, srcIndexExpr),
146 sourceIndices.push_back(
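// A minimal standalone sketch of the linearization step above, assuming all
// group sizes are static: within each reassociation group of the expand_shape,
// the expanded indices are collapsed back into a single source index using the
// suffix product of the group's output sizes (the real code builds this with
// AffineExpr so dynamic sizes work too). `linearizeGroup` is a hypothetical
// helper for illustration only, not part of this pass.
#include <cstdint>
#include <vector>

static int64_t linearizeGroup(const std::vector<int64_t> &groupIndices,
                              const std::vector<int64_t> &groupSizes) {
  // suffixProduct[i] = product of groupSizes[i+1 .. end); the last entry is 1.
  std::vector<int64_t> suffixProduct(groupSizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(groupSizes.size()) - 2; i >= 0; --i)
    suffixProduct[i] = suffixProduct[i + 1] * groupSizes[i + 1];
  int64_t linear = 0;
  for (int64_t i = 0, e = static_cast<int64_t>(groupIndices.size()); i < e; ++i)
    linear += groupIndices[i] * suffixProduct[i];
  return linear;
}
// e.g. a dimension expanded into sizes [3, 4]: expanded indices (2, 1) fold
// back to linearizeGroup({2, 1}, {3, 4}) == 2 * 4 + 1 == 9.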
167 memref::CollapseShapeOp collapseShapeOp,
174 assert(!groups.empty() && "association indices groups cannot be empty");
175 dynamicIndices.push_back(indices[cnt++]);
176 int64_t groupSize = groups.size();
183 for (int64_t i = 1; i < groupSize; ++i) {
184 sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
185 if (sizes[i] == ShapedType::kDynamic)
196 for (int64_t i = 0; i < groupSize; i++) {
200 delinearizingExprs[i]),
202 sourceIndices.push_back(
205 dynamicIndices.clear();
207 if (collapseShapeOp.getReassociationIndices().empty()) {
210 cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
211 for (int64_t i = 0; i < srcRank; i++) {
213 rewriter, loc, zeroAffineMap, dynamicIndices);
214 sourceIndices.push_back(
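// The collapse_shape direction inverts the mapping sketched earlier: a minimal
// standalone sketch, again assuming static sizes, where one collapsed index is
// split back into the group's source indices by repeated division and
// remainder against the suffix product of the group's source sizes (the real
// code emits floordiv/mod affine expressions). `delinearizeGroup` is a
// hypothetical helper for illustration only, not part of this pass.
#include <cstdint>
#include <vector>

static std::vector<int64_t>
delinearizeGroup(int64_t linearIndex, const std::vector<int64_t> &groupSizes) {
  std::vector<int64_t> strides(groupSizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(groupSizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * groupSizes[i + 1];
  std::vector<int64_t> indices(groupSizes.size());
  for (int64_t i = 0, e = static_cast<int64_t>(groupSizes.size()); i < e; ++i) {
    indices[i] = linearIndex / strides[i]; // floordiv by the dim's stride
    linearIndex %= strides[i];             // carry the remainder downward
  }
  return indices;
}
// e.g. sizes [3, 4] collapsed into one dim of size 12: collapsed index 9
// splits back into delinearizeGroup(9, {3, 4}) == {2, 1}.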
222 template <typename LoadOrStoreOpTy>
224 return op.getMemref();
228 return op.getSource();
232 return op.getSrcMemref();
244 return op.getSource();
248 return op.getSrcMemref();
252 return op.getDstMemref();
261 template <typename OpTy>
266 LogicalResult matchAndRewrite(OpTy loadOp,
271 template <typename OpTy>
276 LogicalResult matchAndRewrite(OpTy loadOp,
281 template <typename OpTy>
286 LogicalResult matchAndRewrite(OpTy loadOp,
291 template <typename OpTy>
296 LogicalResult matchAndRewrite(OpTy storeOp,
301 template <typename OpTy>
306 LogicalResult matchAndRewrite(OpTy storeOp,
311 template <typename OpTy>
312 class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
316 LogicalResult matchAndRewrite(OpTy storeOp,
325 LogicalResult matchAndRewrite(memref::SubViewOp subView,
327 auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
332 if (!subView.hasUnitStride()) {
335 if (!srcSubView.hasUnitStride()) {
341 llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
343 subView.getMixedSizes(), srcDroppedDims,
349 rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
350 srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
355 subView, subView.getType(), srcSubView.getSource(),
357 srcSubView.getMixedStrides());
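// A minimal standalone sketch of the offset composition this pattern relies
// on, restricted (as the pattern is) to unit strides on both subviews: each
// offset of the outer subview is added to the matching offset of the inner
// subview, and source dims dropped by the inner (rank-reducing) subview keep
// their offset unchanged. `composeUnitStrideOffsets` is a hypothetical helper
// for illustration only, assuming static offsets.
#include <cstdint>
#include <vector>

static std::vector<int64_t>
composeUnitStrideOffsets(const std::vector<int64_t> &srcOffsets,
                         const std::vector<bool> &srcDroppedDims,
                         const std::vector<int64_t> &outerOffsets) {
  std::vector<int64_t> folded(srcOffsets);
  for (int64_t d = 0, outer = 0, e = static_cast<int64_t>(srcOffsets.size());
       d < e; ++d) {
    if (srcDroppedDims[d])
      continue; // dropped dim: the inner subview's offset already covers it.
    folded[d] += outerOffsets[outer++];
  }
  return folded;
}
// e.g. an inner subview at offsets {2, 0, 3} that drops dim 1, wrapped by an
// outer subview at offsets {1, 4}, folds to one subview at offsets {3, 0, 7}.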
365 class NVGPUAsyncCopyOpSubViewOpFolder final
370 LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
382 for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
384 rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
385 expandedIndices.push_back(
388 return expandedIndices;
391 template <typename XferOp>
394 memref::SubViewOp subviewOp) {
396 !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
397 "must be a vector transfer op");
398 if (xferOp.hasOutOfBoundsDim())
400 if (!subviewOp.hasUnitStride()) {
402 xferOp, "non-1 stride subview, need to track strides in folded memref");
409 memref::SubViewOp subviewOp) {
414 vector::TransferReadOp readOp,
415 memref::SubViewOp subviewOp) {
420 vector::TransferWriteOp writeOp,
421 memref::SubViewOp subviewOp) {
425 template <typename OpTy>
426 LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
434 LogicalResult preconditionResult =
436 if (failed(preconditionResult))
437 return preconditionResult;
440 loadOp.getIndices().end());
443 if (auto affineLoadOp =
444 dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
445 AffineMap affineMap = affineLoadOp.getAffineMap();
447 affineMap, indices, loadOp.getLoc(), rewriter);
448 indices.assign(expandedIndices.begin(), expandedIndices.end());
452 rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
453 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
457 .Case([&](affine::AffineLoadOp op) {
459 loadOp, subViewOp.getSource(), sourceIndices);
461 .Case([&](memref::LoadOp op) {
463 loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
465 .Case([&](vector::LoadOp op) {
467 op, op.getType(), subViewOp.getSource(), sourceIndices);
469 .Case([&](vector::MaskedLoadOp op) {
471 op, op.getType(), subViewOp.getSource(), sourceIndices,
472 op.getMask(), op.getPassThru());
474 .Case([&](vector::TransferReadOp op) {
476 op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
478 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
479 subViewOp.getDroppedDims())),
480 op.getPadding(), op.getMask(), op.getInBoundsAttr());
482 .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
484 op, op.getType(), subViewOp.getSource(), sourceIndices,
485 op.getLeadDimension(), op.getTransposeAttr());
487 .Case([&](nvgpu::LdMatrixOp op) {
489 op, op.getType(), subViewOp.getSource(), sourceIndices,
490 op.getTranspose(), op.getNumTiles());
492 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
496 template <typename OpTy>
497 LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
506 loadOp.getIndices().end());
509 if (auto affineLoadOp =
510 dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
511 AffineMap affineMap = affineLoadOp.getAffineMap();
513 affineMap, indices, loadOp.getLoc(), rewriter);
514 indices.assign(expandedIndices.begin(), expandedIndices.end());
518 loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
521 .Case([&](affine::AffineLoadOp op) {
523 loadOp, expandShapeOp.getViewSource(), sourceIndices);
525 .Case([&](memref::LoadOp op) {
527 loadOp, expandShapeOp.getViewSource(), sourceIndices,
528 op.getNontemporal());
530 .Case([&](vector::LoadOp op) {
532 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
533 op.getNontemporal());
535 .Case([&](vector::MaskedLoadOp op) {
537 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
538 op.getMask(), op.getPassThru());
540 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
544 template <typename OpTy>
545 LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
548 .template getDefiningOp<memref::CollapseShapeOp>();
550 if (!collapseShapeOp)
554 loadOp.getIndices().end());
557 if (auto affineLoadOp =
558 dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
559 AffineMap affineMap = affineLoadOp.getAffineMap();
561 affineMap, indices, loadOp.getLoc(), rewriter);
562 indices.assign(expandedIndices.begin(), expandedIndices.end());
566 loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
569 .Case([&](affine::AffineLoadOp op) {
571 loadOp, collapseShapeOp.getViewSource(), sourceIndices);
573 .Case([&](memref::LoadOp op) {
575 loadOp, collapseShapeOp.getViewSource(), sourceIndices,
576 op.getNontemporal());
578 .Case([&](vector::LoadOp op) {
580 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
581 op.getNontemporal());
583 .Case([&](vector::MaskedLoadOp op) {
585 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
586 op.getMask(), op.getPassThru());
588 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
592 template <typename OpTy>
593 LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
601 LogicalResult preconditionResult =
603 if (failed(preconditionResult))
604 return preconditionResult;
607 storeOp.getIndices().end());
610 if (auto affineStoreOp =
611 dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
612 AffineMap affineMap = affineStoreOp.getAffineMap();
614 affineMap, indices, storeOp.getLoc(), rewriter);
615 indices.assign(expandedIndices.begin(), expandedIndices.end());
619 rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
620 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
624 .Case([&](affine::AffineStoreOp op) {
626 op, op.getValue(), subViewOp.getSource(), sourceIndices);
628 .Case([&](memref::StoreOp op) {
630 op, op.getValue(), subViewOp.getSource(), sourceIndices,
631 op.getNontemporal());
633 .Case([&](vector::TransferWriteOp op) {
635 op, op.getValue(), subViewOp.getSource(), sourceIndices,
637 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
638 subViewOp.getDroppedDims())),
639 op.getMask(), op.getInBoundsAttr());
641 .Case([&](vector::StoreOp op) {
643 op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
645 .Case([&](vector::MaskedStoreOp op) {
647 op, subViewOp.getSource(), sourceIndices, op.getMask(),
648 op.getValueToStore());
650 .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
652 op, op.getSrc(), subViewOp.getSource(), sourceIndices,
653 op.getLeadDimension(), op.getTransposeAttr());
655 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
659 template <typename OpTy>
660 LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
669 storeOp.getIndices().end());
672 if (auto affineStoreOp =
673 dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
674 AffineMap affineMap = affineStoreOp.getAffineMap();
676 affineMap, indices, storeOp.getLoc(), rewriter);
677 indices.assign(expandedIndices.begin(), expandedIndices.end());
681 storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
684 .Case([&](affine::AffineStoreOp op) {
686 storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
689 .Case([&](memref::StoreOp op) {
691 storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
692 sourceIndices, op.getNontemporal());
694 .Case([&](vector::StoreOp op) {
696 op, op.getValueToStore(), expandShapeOp.getViewSource(),
697 sourceIndices, op.getNontemporal());
699 .Case([&](vector::MaskedStoreOp op) {
701 op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
702 op.getValueToStore());
704 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
708 template <typename OpTy>
709 LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
712 .template getDefiningOp<memref::CollapseShapeOp>();
714 if (!collapseShapeOp)
718 storeOp.getIndices().end());
721 if (auto affineStoreOp =
722 dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
723 AffineMap affineMap = affineStoreOp.getAffineMap();
725 affineMap, indices, storeOp.getLoc(), rewriter);
726 indices.assign(expandedIndices.begin(), expandedIndices.end());
730 storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
733 .Case([&](affine::AffineStoreOp op) {
735 storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
738 .Case([&](memref::StoreOp op) {
740 storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
741 sourceIndices, op.getNontemporal());
743 .Case([&](vector::StoreOp op) {
745 op, op.getValueToStore(), collapseShapeOp.getViewSource(),
746 sourceIndices, op.getNontemporal());
748 .Case([&](vector::MaskedStoreOp op) {
750 op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
751 op.getValueToStore());
753 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
757 LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
760 LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");
763 copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
765 copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
767 if (!(srcSubViewOp || dstSubViewOp))
769 "source or destination");
773 copyOp.getSrcIndices().end());
777 LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
779 rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
780 srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
781 srcindices, foldedSrcIndices);
786 copyOp.getDstIndices().end());
790 LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
792 rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
793 dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
794 dstindices, foldedDstIndices);
801 (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
803 (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
804 foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
805 copyOp.getBypassL1Attr());
811 patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
812 LoadOpOfSubViewOpFolder<memref::LoadOp>,
813 LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
814 LoadOpOfSubViewOpFolder<vector::LoadOp>,
815 LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
816 LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
817 LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
818 StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
819 StoreOpOfSubViewOpFolder<memref::StoreOp>,
820 StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
821 StoreOpOfSubViewOpFolder<vector::StoreOp>,
822 StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
823 StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
824 LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
825 LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
826 LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
827 LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
828 StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
829 StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
830 StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
831 StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
832 LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
833 LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
834 LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
835 LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
836 StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
837 StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
838 StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
839 StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
840 SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
850 struct FoldMemRefAliasOpsPass final
851 : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
852 void runOnOperation() override;
857 void FoldMemRefAliasOpsPass::runOnOperation() {
864 return std::make_unique<FoldMemRefAliasOpsPass>();
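// A minimal sketch of how these rewrites can be driven outside the standalone
// pass: populate the patterns registered above and hand them to the greedy
// rewrite driver. This assumes the public entry points
// memref::populateFoldMemRefAliasOpPatterns (defined above) and
// applyPatternsGreedily from the greedy pattern rewrite driver; the wrapper
// function itself is hypothetical and only illustrates intended usage for a
// caller-provided `op`.
static LogicalResult foldMemRefAliasOpsGreedily(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  return applyPatternsGreedily(op, std::move(patterns));
}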