MLIR 22.0.0git
PromoteShuffleToAMDGPU.cpp
Go to the documentation of this file.
//===- PromoteShuffleToAMDGPU.cpp - Promote shuffle to AMDGPU -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns to try to promote `gpu.shuffle`s to specialized
// AMDGPU intrinsics.
//
//===----------------------------------------------------------------------===//
13
16
21#include <optional>
22
23using namespace mlir;
24
25namespace {
26
27constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
28
29/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
30/// and offset must be a constant integer in the range [0, 31].
31struct PromoteShuffleToSwizzlePattern
32 : public OpRewritePattern<gpu::ShuffleOp> {
34
35 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
36 PatternRewriter &rewriter) const override {
37 if (op.getMode() != gpu::ShuffleMode::XOR)
38 return rewriter.notifyMatchFailure(op,
39 "only xor shuffle mode is supported");
40
41 if (!isConstantIntValue(op.getWidth(), 64))
42 return rewriter.notifyMatchFailure(op,
43 "only 64 width shuffle is supported");
44
45 std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
46 if (!offset)
47 return rewriter.notifyMatchFailure(op,
48 "offset must be a constant integer");
49
50 int64_t offsetValue = *offset;
51 if (offsetValue < 0 || offsetValue >= 32)
52 return rewriter.notifyMatchFailure(op,
53 "offset must be in the range [0, 31]");
54
55 Location loc = op.getLoc();
56 Value res = amdgpu::SwizzleBitModeOp::create(
57 rewriter, loc, op.getResult(0).getType(), op.getValue(), /*andMask=*/31,
58 /*orMask=*/0, /*xorMask=*/offsetValue);
59 Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
60 rewriter.replaceOp(op, {res, valid});
61 return success();
62 }
63};
64
65/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64
66/// and offset must be a constant integer in the set {16, 32}.
67struct PromoteShuffleToPermlanePattern
68 : public OpRewritePattern<gpu::ShuffleOp> {
70
71 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
72 PatternRewriter &rewriter) const override {
73 if (op.getMode() != gpu::ShuffleMode::XOR)
74 return rewriter.notifyMatchFailure(op,
75 "only xor shuffle mode is supported");
76
77 if (!isConstantIntValue(op.getWidth(), 64))
78 return rewriter.notifyMatchFailure(op,
79 "only 64 width shuffle is supported");
80
81 std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
82 if (!offset)
83 return rewriter.notifyMatchFailure(op,
84 "offset must be a constant integer");
85
86 int64_t offsetValue = *offset;
87 if (offsetValue != 16 && offsetValue != 32)
88 return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
89
90 Location loc = op.getLoc();
91 Value res = amdgpu::PermlaneSwapOp::create(
92 rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue);
93 Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
94 rewriter.replaceOp(op, {res, valid});
95 return success();
96 }
97};
98
99} // namespace
100
102 RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
103 patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
104 /*benefit*/ 1);
105 if (maybeChipset && *maybeChipset >= kGfx950)
106 patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
107 /*benefit*/ 2);
108}
constexpr Chipset kGfx950
return success()
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns, std::optional< amdgpu::Chipset > maybeChipset)
Tries to promote gpu.shuffles to specialized AMDGPU intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22