MLIR  19.0.0git
ShuffleRewriter.cpp
Go to the documentation of this file.
1 //===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements in-dialect rewriting of the shuffle op for types i64 and
10 // f64, rewriting each 64-bit shuffle into two 32-bit shuffles. This particular
11 // implementation, based on shifts and truncations, is the same lowering clang
12 // produces when emitting IR for shuffle operations at `-O3`.
13 //
14 //===----------------------------------------------------------------------===//
15 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
28 
29  void initialize() {
30  // Required as the pattern will replace the Op with 2 additional ShuffleOps.
31  setHasBoundedRewriteRecursion();
32  }
33  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
34  PatternRewriter &rewriter) const override {
35  auto loc = op.getLoc();
36  auto value = op.getValue();
37  auto valueType = value.getType();
38  auto valueLoc = value.getLoc();
39  auto i32 = rewriter.getI32Type();
40  auto i64 = rewriter.getI64Type();
41 
42  // If the type of the value is either i32 or f32, the op is already valid.
43  if (valueType.getIntOrFloatBitWidth() == 32)
44  return failure();
45 
46  Value lo, hi;
47 
48  // Float types must be converted to i64 to extract the bits.
49  if (isa<FloatType>(valueType))
50  value = rewriter.create<arith::BitcastOp>(valueLoc, i64, value);
51 
52  // Get the low bits by trunc(value).
53  lo = rewriter.create<arith::TruncIOp>(valueLoc, i32, value);
54 
55  // Get the high bits by trunc(value >> 32).
56  auto c32 = rewriter.create<arith::ConstantOp>(
57  valueLoc, rewriter.getIntegerAttr(i64, 32));
58  hi = rewriter.create<arith::ShRUIOp>(valueLoc, value, c32);
59  hi = rewriter.create<arith::TruncIOp>(valueLoc, i32, hi);
60 
61  // Shuffle the values.
62  ValueRange loRes =
63  rewriter
64  .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(),
65  op.getWidth(), op.getMode())
66  .getResults();
67  ValueRange hiRes =
68  rewriter
69  .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(),
70  op.getWidth(), op.getMode())
71  .getResults();
72 
73  // Convert lo back to i64.
74  lo = rewriter.create<arith::ExtUIOp>(valueLoc, i64, loRes[0]);
75 
76  // Convert hi back to i64.
77  hi = rewriter.create<arith::ExtUIOp>(valueLoc, i64, hiRes[0]);
78  hi = rewriter.create<arith::ShLIOp>(valueLoc, hi, c32);
79 
80  // Obtain the shuffled bits hi | lo.
81  value = rewriter.create<arith::OrIOp>(loc, hi, lo);
82 
83  // Convert the value back to float.
84  if (isa<FloatType>(valueType))
85  value = rewriter.create<arith::BitcastOp>(valueLoc, valueType, value);
86 
87  // Obtain the shuffle validity by combining both validities.
88  auto validity = rewriter.create<arith::AndIOp>(loc, loRes[1], hiRes[1]);
89 
90  // Replace the op.
91  rewriter.replaceOp(op, {value, validity});
92  return success();
93  }
94 };
95 } // namespace
96 
98  patterns.add<GpuShuffleRewriter>(patterns.getContext());
99 }
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateGpuShufflePatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358