MLIR  16.0.0git
MathToSPIRV.cpp
Go to the documentation of this file.
1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "math-to-spirv-pattern"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
32 /// given type is not a 32-bit scalar/vector type.
34  OpBuilder &builder, Location loc) {
35  if (auto vectorType = type.dyn_cast<VectorType>()) {
36  if (!vectorType.getElementType().isInteger(32))
37  return nullptr;
38  SmallVector<int> values(vectorType.getNumElements(), value);
39  return builder.create<spirv::ConstantOp>(loc, type,
40  builder.getI32VectorAttr(values));
41  }
42  if (type.isInteger(32))
43  return builder.create<spirv::ConstantOp>(loc, type,
44  builder.getI32IntegerAttr(value));
45 
46  return nullptr;
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // Operation conversion
51 //===----------------------------------------------------------------------===//
52 
53 // Note that DRR cannot be used for the patterns in this file: we may need to
54 // convert type along the way, which requires ConversionPattern. DRR generates
55 // normal RewritePattern.
56 
57 namespace {
58 /// Converts math.copysign to SPIR-V ops.
59 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
61 
63  matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
64  ConversionPatternRewriter &rewriter) const override {
65  auto type = getTypeConverter()->convertType(copySignOp.getType());
66  if (!type)
67  return failure();
68 
69  FloatType floatType;
70  if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
71  floatType = scalarType;
72  } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
73  floatType = vectorType.getElementType().cast<FloatType>();
74  } else {
75  return failure();
76  }
77 
78  Location loc = copySignOp.getLoc();
79  int bitwidth = floatType.getWidth();
80  Type intType = rewriter.getIntegerType(bitwidth);
81  uint64_t intValue = uint64_t(1) << (bitwidth - 1);
82 
83  Value signMask = rewriter.create<spirv::ConstantOp>(
84  loc, intType, rewriter.getIntegerAttr(intType, intValue));
85  Value valueMask = rewriter.create<spirv::ConstantOp>(
86  loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
87 
88  if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
89  assert(vectorType.getRank() == 1);
90  int count = vectorType.getNumElements();
91  intType = VectorType::get(count, intType);
92 
93  SmallVector<Value> signSplat(count, signMask);
94  signMask =
95  rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
96 
97  SmallVector<Value> valueSplat(count, valueMask);
98  valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
99  valueSplat);
100  }
101 
102  Value lhsCast =
103  rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
104  Value rhsCast =
105  rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
106 
107  Value value = rewriter.create<spirv::BitwiseAndOp>(
108  loc, intType, ValueRange{lhsCast, valueMask});
109  Value sign = rewriter.create<spirv::BitwiseAndOp>(
110  loc, intType, ValueRange{rhsCast, signMask});
111 
112  Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
113  ValueRange{value, sign});
114  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
115  return success();
116  }
117 };
118 
119 /// Converts math.ctlz to SPIR-V ops.
120 ///
121 /// SPIR-V does not have a direct operations for counting leading zeros. If
122 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
123 /// it.
124 class CountLeadingZerosPattern final
125  : public OpConversionPattern<math::CountLeadingZerosOp> {
127 
129  matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
130  ConversionPatternRewriter &rewriter) const override {
131  auto type = getTypeConverter()->convertType(countOp.getType());
132  if (!type)
133  return failure();
134 
135  // We can only support 32-bit integer types for now.
136  unsigned bitwidth = 0;
137  if (type.isa<IntegerType>())
138  bitwidth = type.getIntOrFloatBitWidth();
139  if (auto vectorType = type.dyn_cast<VectorType>())
140  bitwidth = vectorType.getElementTypeBitWidth();
141  if (bitwidth != 32)
142  return failure();
143 
144  Location loc = countOp.getLoc();
145  Value input = adaptor.getOperand();
146  Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
147  Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
148  Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
149 
150  Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
151  // We need to subtract from 31 given that the index returned by GLSL
152  // FindUMsb is counted from the least significant bit. Theoretically this
153  // also gives the correct result even if the integer has all zero bits, in
154  // which case GL FindUMsb would return -1.
155  Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
156  // However, certain Vulkan implementations have driver bugs for the corner
157  // case where the input is zero. And.. it can be smart to optimize a select
158  // only involving the corner case. So separately compute the result when the
159  // input is either zero or one.
160  Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
161  Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
162  rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
163  subMsb);
164  return success();
165  }
166 };
167 
168 /// Converts math.expm1 to SPIR-V ops.
169 ///
170 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
171 /// these operations.
172 template <typename ExpOp>
173 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
175 
177  matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
178  ConversionPatternRewriter &rewriter) const override {
179  assert(adaptor.getOperands().size() == 1);
180  Location loc = operation.getLoc();
181  auto type = this->getTypeConverter()->convertType(operation.getType());
182  auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
183  auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
184  rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
185  return success();
186  }
187 };
188 
189 /// Converts math.log1p to SPIR-V ops.
190 ///
191 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
192 /// these operations.
193 template <typename LogOp>
194 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
196 
198  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
199  ConversionPatternRewriter &rewriter) const override {
200  assert(adaptor.getOperands().size() == 1);
201  Location loc = operation.getLoc();
202  auto type = this->getTypeConverter()->convertType(operation.getType());
203  auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
204  auto onePlus =
205  rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
206  rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
207  return success();
208  }
209 };
210 
211 /// Converts math.powf to SPIRV-Ops.
212 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
214 
216  matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
217  ConversionPatternRewriter &rewriter) const override {
218  auto dstType = getTypeConverter()->convertType(powfOp.getType());
219  if (!dstType)
220  return failure();
221 
222  // Per GL Pow extended instruction spec:
223  // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
224  Location loc = powfOp.getLoc();
225  Value zero =
226  spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
227  Value lessThan =
228  rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
229  Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
230  Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
231  Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
232  rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow);
233  return success();
234  }
235 };
236 
237 /// Converts math.round to GLSL SPIRV extended ops.
238 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
240 
242  matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
243  ConversionPatternRewriter &rewriter) const override {
244  Location loc = roundOp.getLoc();
245  auto operand = roundOp.getOperand();
246  auto ty = operand.getType();
247  auto ety = getElementTypeOrSelf(ty);
248 
249  auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
250  auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
251  Value half;
252  if (VectorType vty = ty.dyn_cast<VectorType>()) {
253  half = rewriter.create<spirv::ConstantOp>(
254  loc, vty,
256  rewriter.getFloatAttr(ety, 0.5).getValue()));
257  } else {
258  half = rewriter.create<spirv::ConstantOp>(
259  loc, ty, rewriter.getFloatAttr(ety, 0.5));
260  }
261 
262  auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
263  auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
264  auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
265  auto greater =
266  rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
267  auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
268  auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
269  rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
270  return success();
271  }
272 };
273 
274 } // namespace
275 
276 //===----------------------------------------------------------------------===//
277 // Pattern population
278 //===----------------------------------------------------------------------===//
279 
280 namespace mlir {
/// Adds the Math-to-SPIR-V conversion patterns to `patterns`; every pattern is
/// constructed with `typeConverter` and the pattern set's context.
// NOTE(review): the line that opens this function's signature is not visible
// in this listing. The cross-reference text further down names it
// populateMathToSPIRVPatterns(SPIRVTypeConverter &, RewritePatternSet &) —
// confirm against the real source.
282  RewritePatternSet &patterns) {
283  // Core patterns
284  patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
285 
286  // GLSL patterns
// NOTE(review): the template argument list below has no visible closing '>',
// so some pattern entries appear to be missing from this listing — confirm
// the full list against the real source.
287  patterns
288  .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
289  ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
301  typeConverter, patterns.getContext());
302 
303  // OpenCL patterns
// NOTE(review): as above, this template argument list appears truncated in
// the listing — confirm against the real source.
304  patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
319  typeConverter, patterns.getContext());
320 }
321 
322 } // namespace mlir
Include the generated interface declarations.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:369
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:21
int64_t floor(Fraction f)
Definition: Fraction.h:63
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:217
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:172
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
unsigned getWidth()
Return the bitwidth of this float type.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:194
U dyn_cast() const
Definition: Types.h:270
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
Definition: MathToSPIRV.cpp:33
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:109
This class implements a pattern rewriter for use with ConversionPatterns.
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
MLIRContext * getContext() const
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type conversion from builtin types to SPIR-V types for shader interface.