MLIR  19.0.0git
ShuffleRewriter.cpp
Go to the documentation of this file.
1 //===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements in-dialect rewriting of the shuffle op for types i64 and
10 // f64, rewriting each 64-bit shuffle into two 32-bit shuffles. This particular
11 // implementation, based on shifts and truncations, mirrors the IR that clang
12 // emits for shuffle operations at `-O3`.
13 //
14 //===----------------------------------------------------------------------===//
15 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
28 
29  void initialize() {
30  // Required as the pattern will replace the Op with 2 additional ShuffleOps.
31  setHasBoundedRewriteRecursion();
32  }
33  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
34  PatternRewriter &rewriter) const override {
35  auto loc = op.getLoc();
36  auto value = op.getValue();
37  auto valueType = value.getType();
38  auto valueLoc = value.getLoc();
39  auto i32 = rewriter.getI32Type();
40  auto i64 = rewriter.getI64Type();
41 
42  // If the type of the value is either i32 or f32, the op is already valid.
43  if (valueType.getIntOrFloatBitWidth() == 32)
44  return failure();
45 
46  Value lo, hi;
47 
48  // Float types must be converted to i64 to extract the bits.
49  if (isa<FloatType>(valueType))
50  value = rewriter.create<arith::BitcastOp>(valueLoc, i64, value);
51 
52  // Get the low bits by trunc(value).
53  lo = rewriter.create<arith::TruncIOp>(valueLoc, i32, value);
54 
55  // Get the high bits by trunc(value >> 32).
56  auto c32 = rewriter.create<arith::ConstantOp>(
57  valueLoc, rewriter.getIntegerAttr(i64, 32));
58  hi = rewriter.create<arith::ShRUIOp>(valueLoc, value, c32);
59  hi = rewriter.create<arith::TruncIOp>(valueLoc, i32, hi);
60 
61  // Shuffle the values.
62  ValueRange loRes =
63  rewriter
64  .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(),
65  op.getWidth(), op.getMode())
66  .getResults();
67  ValueRange hiRes =
68  rewriter
69  .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(),
70  op.getWidth(), op.getMode())
71  .getResults();
72 
73  // Convert lo back to i64.
74  lo = rewriter.create<arith::ExtUIOp>(valueLoc, i64, loRes[0]);
75 
76  // Convert hi back to i64.
77  hi = rewriter.create<arith::ExtUIOp>(valueLoc, i64, hiRes[0]);
78  hi = rewriter.create<arith::ShLIOp>(valueLoc, hi, c32);
79 
80  // Obtain the shuffled bits hi | lo.
81  value = rewriter.create<arith::OrIOp>(loc, hi, lo);
82 
83  // Convert the value back to float.
84  if (isa<FloatType>(valueType))
85  value = rewriter.create<arith::BitcastOp>(valueLoc, valueType, value);
86 
87  // Obtain the shuffle validity by combining both validities.
88  auto validity = rewriter.create<arith::AndIOp>(loc, loRes[1], hiRes[1]);
89 
90  // Replace the op.
91  rewriter.replaceOp(op, {value, validity});
92  return success();
93  }
94 };
95 } // namespace
96 
98  patterns.add<GpuShuffleRewriter>(patterns.getContext());
99 }
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:748
MLIRContext * getContext() const
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:809
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateGpuShufflePatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:357