17 #include "llvm/Support/Debug.h"
21 #define GEN_PASS_DEF_XEGPUFOLDALIASOPS
22 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
26 #define DEBUG_TYPE "xegpu-fold-alias-ops"
27 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
33 class XegpuCreateNdDescOpSubViewOpFolder final
38 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp,
43 LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
45 auto subViewOp = descOp.getSource().getDefiningOp<memref::SubViewOp>();
49 if (!subViewOp.hasUnitStride())
54 rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(),
55 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
56 descOp.getMixedOffsets(), resolvedOffsets);
59 descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(),
66 patterns.
add<XegpuCreateNdDescOpSubViewOpFolder>(patterns.
getContext());
71 struct XeGPUFoldAliasOpsPass final
72 :
public xegpu::impl::XeGPUFoldAliasOpsBase<XeGPUFoldAliasOpsPass> {
73 void runOnOperation()
override;
78 void XeGPUFoldAliasOpsPass::runOnOperation() {
static MLIRContext * getContext(OpFoldResult val)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns)
Appends patterns for folding aliasing ops into XeGPU ops into patterns.
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...