MLIR 23.0.0git
PromoteShuffleToAMDGPU.cpp
Go to the documentation of this file.
1//===- PromoteShuffleToAMDGPU.cpp - Promote shuffle to AMDGPU -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns to try to promote `gpu.shuffle`s to specialized
10// AMDGPU intrinsics.
11//
12//===----------------------------------------------------------------------===//
13
16
21#include <optional>
22
23using namespace mlir;
24
25namespace {
26
27constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
28
29/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
30/// and offset must be a constant integer in the range [0, 31].
31struct PromoteShuffleToSwizzlePattern
32 : public OpRewritePattern<gpu::ShuffleOp> {
34
35 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
36 PatternRewriter &rewriter) const override {
37 if (op.getMode() != gpu::ShuffleMode::XOR)
38 return rewriter.notifyMatchFailure(op,
39 "only xor shuffle mode is supported");
40
41 if (!isConstantIntValue(op.getWidth(), 64))
42 return rewriter.notifyMatchFailure(op,
43 "only 64 width shuffle is supported");
44
45 std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
46 if (!offset)
47 return rewriter.notifyMatchFailure(op,
48 "offset must be a constant integer");
49
50 int64_t offsetValue = *offset;
51 if (offsetValue < 0 || offsetValue >= 32)
52 return rewriter.notifyMatchFailure(op,
53 "offset must be in the range [0, 31]");
54
55 Location loc = op.getLoc();
56 Value res = amdgpu::SwizzleBitModeOp::create(
57 rewriter, loc, op.getResult(0).getType(), op.getValue(),
58 /*and_mask=*/31,
59 /*orMask=*/0, /*xorMask=*/offsetValue);
60 Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
61 rewriter.replaceOp(op, {res, valid});
62 return success();
63 }
64};
65
66/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64
67/// and offset must be a constant integer in the set {16, 32}.
68struct PromoteShuffleToPermlanePattern
69 : public OpRewritePattern<gpu::ShuffleOp> {
71
72 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
73 PatternRewriter &rewriter) const override {
74 if (op.getMode() != gpu::ShuffleMode::XOR)
75 return rewriter.notifyMatchFailure(op,
76 "only xor shuffle mode is supported");
77
78 if (!isConstantIntValue(op.getWidth(), 64))
79 return rewriter.notifyMatchFailure(op,
80 "only 64 width shuffle is supported");
81
82 std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
83 if (!offset)
84 return rewriter.notifyMatchFailure(op,
85 "offset must be a constant integer");
86
87 int64_t offsetValue = *offset;
88 if (offsetValue != 16 && offsetValue != 32)
89 return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
90
91 Location loc = op.getLoc();
92 Value res = amdgpu::PermlaneSwapOp::create(
93 rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue);
94 Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
95 rewriter.replaceOp(op, {res, valid});
96 return success();
97 }
98};
99
100} // namespace
101
103 RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
104 patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
105 /*benefit*/ 1);
106 if (maybeChipset && *maybeChipset >= kGfx950)
107 patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
108 /*benefit*/ 2);
109}
constexpr Chipset kGfx950
return success()
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns, std::optional< amdgpu::Chipset > maybeChipset)
Tries to promote gpu.shuffles to specialized AMDGPU intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22