MLIR  22.0.0git
PromoteShuffleToAMDGPU.cpp
Go to the documentation of this file.
1 //===- PromoteShuffleToAMDGPU.cpp - Promote shuffle to AMDGPU -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains patterns to try to promote `gpu.shuffle`s to specialized
10 // AMDGPU intrinsics.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 
20 #include "mlir/IR/PatternMatch.h"
21 #include <optional>
22 
23 using namespace mlir;
24 
25 namespace {
26 
27 constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
28 
29 /// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
30 /// and offset must be a constant integer in the range [0, 31].
31 struct PromoteShuffleToSwizzlePattern
32  : public OpRewritePattern<gpu::ShuffleOp> {
34 
35  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
36  PatternRewriter &rewriter) const override {
37  if (op.getMode() != gpu::ShuffleMode::XOR)
38  return rewriter.notifyMatchFailure(op,
39  "only xor shuffle mode is supported");
40 
41  if (!isConstantIntValue(op.getWidth(), 64))
42  return rewriter.notifyMatchFailure(op,
43  "only 64 width shuffle is supported");
44 
45  std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
46  if (!offset)
47  return rewriter.notifyMatchFailure(op,
48  "offset must be a constant integer");
49 
50  int64_t offsetValue = *offset;
51  if (offsetValue < 0 || offsetValue >= 32)
52  return rewriter.notifyMatchFailure(op,
53  "offset must be in the range [0, 31]");
54 
55  Location loc = op.getLoc();
56  Value res = amdgpu::SwizzleBitModeOp::create(
57  rewriter, loc, op.getResult(0).getType(), op.getValue(), /*andMask=*/31,
58  /*orMask=*/0, /*xorMask=*/offsetValue);
59  Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
60  rewriter.replaceOp(op, {res, valid});
61  return success();
62  }
63 };
64 
65 /// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64
66 /// and offset must be a constant integer in the set {16, 32}.
67 struct PromoteShuffleToPermlanePattern
68  : public OpRewritePattern<gpu::ShuffleOp> {
70 
71  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
72  PatternRewriter &rewriter) const override {
73  if (op.getMode() != gpu::ShuffleMode::XOR)
74  return rewriter.notifyMatchFailure(op,
75  "only xor shuffle mode is supported");
76 
77  if (!isConstantIntValue(op.getWidth(), 64))
78  return rewriter.notifyMatchFailure(op,
79  "only 64 width shuffle is supported");
80 
81  std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
82  if (!offset)
83  return rewriter.notifyMatchFailure(op,
84  "offset must be a constant integer");
85 
86  int64_t offsetValue = *offset;
87  if (offsetValue != 16 && offsetValue != 32)
88  return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
89 
90  Location loc = op.getLoc();
91  Value res = amdgpu::PermlaneSwapOp::create(
92  rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue);
93  Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
94  rewriter.replaceOp(op, {res, valid});
95  return success();
96  }
97 };
98 
99 } // namespace
100 
102  RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
103  patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
104  /*benefit*/ 1);
105  if (maybeChipset && *maybeChipset >= kGfx950)
106  patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
107  /*benefit*/ 2);
108 }
constexpr Chipset kGfx950
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition: ArithOps.cpp:258
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns, std::optional< amdgpu::Chipset > maybeChipset)
Tries to promote gpu.shuffles to specialized AMDGPU intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319