MLIR  19.0.0git
XeGPUFoldAliasOps.cpp
Go to the documentation of this file.
1 //===- XeGPUFoldAliasOps.cpp - XeGPU alias ops folders ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
18 
19 namespace mlir {
20 namespace xegpu {
21 #define GEN_PASS_DEF_XEGPUFOLDALIASOPS
22 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
23 } // namespace xegpu
24 } // namespace mlir
25 
26 #define DEBUG_TYPE "xegpu-fold-alias-ops"
27 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
28 
29 using namespace mlir;
30 
31 namespace {
32 /// Merges subview operation with xegpu.create_nd_tdesc operation.
33 class XegpuCreateNdDescOpSubViewOpFolder final
34  : public OpRewritePattern<xegpu::CreateNdDescOp> {
35 public:
37 
38  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp,
39  PatternRewriter &rewriter) const override;
40 };
41 } // namespace
42 
43 LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
44  xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const {
45  auto subViewOp = descOp.getSource().getDefiningOp<memref::SubViewOp>();
46 
47  if (!subViewOp)
48  return rewriter.notifyMatchFailure(descOp, "not a subview producer");
49  if (!subViewOp.hasUnitStride())
50  return rewriter.notifyMatchFailure(descOp, "requires unit strides");
51 
52  SmallVector<Value> resolvedOffsets;
54  rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(),
55  subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
56  descOp.getMixedOffsets(), resolvedOffsets);
57 
58  rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
59  descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(),
60  getAsOpFoldResult(resolvedOffsets));
61 
62  return success();
63 }
64 
66  patterns.add<XegpuCreateNdDescOpSubViewOpFolder>(patterns.getContext());
67 }
68 
69 namespace {
70 
71 struct XeGPUFoldAliasOpsPass final
72  : public xegpu::impl::XeGPUFoldAliasOpsBase<XeGPUFoldAliasOpsPass> {
73  void runOnOperation() override;
74 };
75 
76 } // namespace
77 
78 void XeGPUFoldAliasOpsPass::runOnOperation() {
79  RewritePatternSet patterns(&getContext());
81  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
82 }
static MLIRContext * getContext(OpFoldResult val)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides, return the combined indices.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns)
Appends to `patterns` the patterns that fold aliasing ops into XeGPU ops.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358