MLIR 22.0.0git
ShuffleRewriter.cpp
Go to the documentation of this file.
1//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements in-dialect rewriting of the shuffle op for types i64 and
10// f64, rewriting 64bit shuffles into two 32bit shuffles. This particular
11// implementation using shifts and truncations can be obtained using clang: by
12// emitting IR for shuffle operations with `-O3`.
13//
14//===----------------------------------------------------------------------===//
15
19#include "mlir/IR/Builders.h"
21
22using namespace mlir;
23
24namespace {
25struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
26 using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;
27
28 void initialize() {
29 // Required as the pattern will replace the Op with 2 additional ShuffleOps.
30 setHasBoundedRewriteRecursion();
31 }
32 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
33 PatternRewriter &rewriter) const override {
34 auto loc = op.getLoc();
35 auto value = op.getValue();
36 auto valueType = value.getType();
37 auto valueLoc = value.getLoc();
38 auto i32 = rewriter.getI32Type();
39 auto i64 = rewriter.getI64Type();
40
41 // If the type of the value is either i32 or f32, the op is already valid.
42 if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
43 return rewriter.notifyMatchFailure(
44 op, "only 64-bit int/float types are supported");
45
46 Value lo, hi;
47
48 // Float types must be converted to i64 to extract the bits.
49 if (isa<FloatType>(valueType))
50 value = arith::BitcastOp::create(rewriter, valueLoc, i64, value);
51
52 // Get the low bits by trunc(value).
53 lo = arith::TruncIOp::create(rewriter, valueLoc, i32, value);
54
55 // Get the high bits by trunc(value >> 32).
56 auto c32 = arith::ConstantOp::create(rewriter, valueLoc,
57 rewriter.getIntegerAttr(i64, 32));
58 hi = arith::ShRUIOp::create(rewriter, valueLoc, value, c32);
59 hi = arith::TruncIOp::create(rewriter, valueLoc, i32, hi);
60
61 // Shuffle the values.
62 ValueRange loRes =
63 gpu::ShuffleOp::create(rewriter, op.getLoc(), lo, op.getOffset(),
64 op.getWidth(), op.getMode())
65 .getResults();
66 ValueRange hiRes =
67 gpu::ShuffleOp::create(rewriter, op.getLoc(), hi, op.getOffset(),
68 op.getWidth(), op.getMode())
69 .getResults();
70
71 // Convert lo back to i64.
72 lo = arith::ExtUIOp::create(rewriter, valueLoc, i64, loRes[0]);
73
74 // Convert hi back to i64.
75 hi = arith::ExtUIOp::create(rewriter, valueLoc, i64, hiRes[0]);
76 hi = arith::ShLIOp::create(rewriter, valueLoc, hi, c32);
77
78 // Obtain the shuffled bits hi | lo.
79 value = arith::OrIOp::create(rewriter, loc, hi, lo);
80
81 // Convert the value back to float.
82 if (isa<FloatType>(valueType))
83 value = arith::BitcastOp::create(rewriter, valueLoc, valueType, value);
84
85 // Obtain the shuffle validity by combining both validities.
86 auto validity = arith::AndIOp::create(rewriter, loc, loRes[1], hiRes[1]);
87
88 // Replace the op.
89 rewriter.replaceOp(op, {value, validity});
90 return success();
91 }
92};
93} // namespace
94
96 patterns.add<GpuShuffleRewriter>(patterns.getContext());
97}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getI32Type()
Definition Builders.cpp:63
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Include the generated interface declarations.
void populateGpuShufflePatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...