MLIR  20.0.0git
MathToSPIRV.cpp
Go to the documentation of this file.
1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 #define DEBUG_TYPE "math-to-spirv-pattern"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Utility functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
34 /// given type is not a 32-bit scalar/vector type.
35 static Value getScalarOrVectorI32Constant(Type type, int value,
36  OpBuilder &builder, Location loc) {
37  if (auto vectorType = dyn_cast<VectorType>(type)) {
38  if (!vectorType.getElementType().isInteger(32))
39  return nullptr;
40  SmallVector<int> values(vectorType.getNumElements(), value);
41  return builder.create<spirv::ConstantOp>(loc, type,
42  builder.getI32VectorAttr(values));
43  }
44  if (type.isInteger(32))
45  return builder.create<spirv::ConstantOp>(loc, type,
46  builder.getI32IntegerAttr(value));
47 
48  return nullptr;
49 }
50 
51 /// Check if the type is supported by math-to-spirv conversion. We expect to
52 /// only see scalars and vectors at this point, with higher-level types already
53 /// lowered.
54 static bool isSupportedSourceType(Type originalType) {
55  if (originalType.isIntOrIndexOrFloat())
56  return true;
57 
58  if (auto vecTy = dyn_cast<VectorType>(originalType)) {
59  if (!vecTy.getElementType().isIntOrIndexOrFloat())
60  return false;
61  if (vecTy.isScalable())
62  return false;
63  if (vecTy.getRank() > 1)
64  return false;
65 
66  return true;
67  }
68 
69  return false;
70 }
71 
72 /// Check if all `sourceOp` types are supported by math-to-spirv conversion.
73 /// Notify of a match failure othwerise and return a `failure` result.
74 /// This is intended to simplify type checks in `OpConversionPattern`s.
75 static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
76  Operation *sourceOp) {
77  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
78  llvm::append_range(allTypes, sourceOp->getResultTypes());
79 
80  for (Type ty : allTypes) {
81  if (!isSupportedSourceType(ty)) {
82  return rewriter.notifyMatchFailure(
83  sourceOp,
84  llvm::formatv(
85  "unsupported source type for Math to SPIR-V conversion: {0}",
86  ty));
87  }
88  }
89 
90  return success();
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // Operation conversion
95 //===----------------------------------------------------------------------===//
96 
97 // Note that DRR cannot be used for the patterns in this file: we may need to
98 // convert type along the way, which requires ConversionPattern. DRR generates
99 // normal RewritePattern.
100 
101 namespace {
102 /// Converts elementwise unary, binary, and ternary standard operations to
103 /// SPIR-V operations. Checks that source `Op` types are supported.
104 template <typename Op, typename SPIRVOp>
105 struct CheckedElementwiseOpPattern final
106  : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
107  using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
108  using BasePattern::BasePattern;
109 
110  LogicalResult
111  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
112  ConversionPatternRewriter &rewriter) const override {
113  if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
114  return res;
115 
116  return BasePattern::matchAndRewrite(op, adaptor, rewriter);
117  }
118 };
119 
120 /// Converts math.copysign to SPIR-V ops.
121 struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
123 
124  LogicalResult
125  matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
126  ConversionPatternRewriter &rewriter) const override {
127  if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
128  failed(res))
129  return res;
130 
131  Type type = getTypeConverter()->convertType(copySignOp.getType());
132  if (!type)
133  return failure();
134 
135  FloatType floatType;
136  if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137  floatType = scalarType;
138  } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139  floatType = cast<FloatType>(vectorType.getElementType());
140  } else {
141  return failure();
142  }
143 
144  Location loc = copySignOp.getLoc();
145  int bitwidth = floatType.getWidth();
146  Type intType = rewriter.getIntegerType(bitwidth);
147  uint64_t intValue = uint64_t(1) << (bitwidth - 1);
148 
149  Value signMask = rewriter.create<spirv::ConstantOp>(
150  loc, intType, rewriter.getIntegerAttr(intType, intValue));
151  Value valueMask = rewriter.create<spirv::ConstantOp>(
152  loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
153 
154  if (auto vectorType = dyn_cast<VectorType>(type)) {
155  assert(vectorType.getRank() == 1);
156  int count = vectorType.getNumElements();
157  intType = VectorType::get(count, intType);
158 
159  SmallVector<Value> signSplat(count, signMask);
160  signMask =
161  rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
162 
163  SmallVector<Value> valueSplat(count, valueMask);
164  valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
165  valueSplat);
166  }
167 
168  Value lhsCast =
169  rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
170  Value rhsCast =
171  rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
172 
173  Value value = rewriter.create<spirv::BitwiseAndOp>(
174  loc, intType, ValueRange{lhsCast, valueMask});
175  Value sign = rewriter.create<spirv::BitwiseAndOp>(
176  loc, intType, ValueRange{rhsCast, signMask});
177 
178  Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
179  ValueRange{value, sign});
180  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
181  return success();
182  }
183 };
184 
185 /// Converts math.ctlz to SPIR-V ops.
186 ///
187 /// SPIR-V does not have a direct operations for counting leading zeros. If
188 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
189 /// it.
190 struct CountLeadingZerosPattern final
191  : public OpConversionPattern<math::CountLeadingZerosOp> {
193 
194  LogicalResult
195  matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
196  ConversionPatternRewriter &rewriter) const override {
197  if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
198  return res;
199 
200  Type type = getTypeConverter()->convertType(countOp.getType());
201  if (!type)
202  return failure();
203 
204  // We can only support 32-bit integer types for now.
205  unsigned bitwidth = 0;
206  if (isa<IntegerType>(type))
207  bitwidth = type.getIntOrFloatBitWidth();
208  if (auto vectorType = dyn_cast<VectorType>(type))
209  bitwidth = vectorType.getElementTypeBitWidth();
210  if (bitwidth != 32)
211  return failure();
212 
213  Location loc = countOp.getLoc();
214  Value input = adaptor.getOperand();
215  Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
216  Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
217  Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
218 
219  Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
220  // We need to subtract from 31 given that the index returned by GLSL
221  // FindUMsb is counted from the least significant bit. Theoretically this
222  // also gives the correct result even if the integer has all zero bits, in
223  // which case GL FindUMsb would return -1.
224  Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
225  // However, certain Vulkan implementations have driver bugs for the corner
226  // case where the input is zero. And.. it can be smart to optimize a select
227  // only involving the corner case. So separately compute the result when the
228  // input is either zero or one.
229  Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
230  Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
231  rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
232  subMsb);
233  return success();
234  }
235 };
236 
237 /// Converts math.expm1 to SPIR-V ops.
238 ///
239 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
240 /// these operations.
241 template <typename ExpOp>
242 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
244 
245  LogicalResult
246  matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
247  ConversionPatternRewriter &rewriter) const override {
248  assert(adaptor.getOperands().size() == 1);
249  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
250  failed(res))
251  return res;
252 
253  Location loc = operation.getLoc();
254  Type type = this->getTypeConverter()->convertType(operation.getType());
255  if (!type)
256  return failure();
257 
258  Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
259  auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
260  rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
261  return success();
262  }
263 };
264 
265 /// Converts math.log1p to SPIR-V ops.
266 ///
267 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
268 /// these operations.
269 template <typename LogOp>
270 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
272 
273  LogicalResult
274  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
275  ConversionPatternRewriter &rewriter) const override {
276  assert(adaptor.getOperands().size() == 1);
277  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
278  failed(res))
279  return res;
280 
281  Location loc = operation.getLoc();
282  Type type = this->getTypeConverter()->convertType(operation.getType());
283  if (!type)
284  return failure();
285 
286  auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
287  Value onePlus =
288  rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
289  rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
290  return success();
291  }
292 };
293 
294 /// Converts math.log2 and math.log10 to SPIR-V ops.
295 ///
296 /// SPIR-V does not have direct operations for log2 and log10. Explicitly
297 /// lower to these operations using:
298 /// log2(x) = log(x) * 1/log(2)
299 /// log10(x) = log(x) * 1/log(10)
300 
301 template <typename MathLogOp, typename SpirvLogOp>
302 struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
305 
306  static constexpr double log2Reciprocal =
307  1.442695040888963407359924681001892137426645954152985934135449407;
308  static constexpr double log10Reciprocal =
309  0.4342944819032518276511289189166050822943970058036665661144537832;
310 
311  LogicalResult
312  matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
313  ConversionPatternRewriter &rewriter) const override {
314  assert(adaptor.getOperands().size() == 1);
315  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
316  failed(res))
317  return res;
318 
319  Location loc = operation.getLoc();
320  Type type = this->getTypeConverter()->convertType(operation.getType());
321  if (!type)
322  return rewriter.notifyMatchFailure(operation, "type conversion failed");
323 
324  auto getConstantValue = [&](double value) {
325  if (auto floatType = dyn_cast<FloatType>(type)) {
326  return rewriter.create<spirv::ConstantOp>(
327  loc, type, rewriter.getFloatAttr(floatType, value));
328  }
329  if (auto vectorType = dyn_cast<VectorType>(type)) {
330  Type elemType = vectorType.getElementType();
331 
332  if (isa<FloatType>(elemType)) {
333  return rewriter.create<spirv::ConstantOp>(
334  loc, type,
336  vectorType, FloatAttr::get(elemType, value).getValue()));
337  }
338  }
339 
340  llvm_unreachable("unimplemented types for log2/log10");
341  };
342 
343  Value constantValue = getConstantValue(
344  std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
345  : log10Reciprocal);
346  Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
347  rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
348  constantValue);
349  return success();
350  }
351 };
352 
353 /// Converts math.powf to SPIRV-Ops.
354 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
356 
357  LogicalResult
358  matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
359  ConversionPatternRewriter &rewriter) const override {
360  if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
361  return res;
362 
363  Type dstType = getTypeConverter()->convertType(powfOp.getType());
364  if (!dstType)
365  return failure();
366 
367  // Get the scalar float type.
368  FloatType scalarFloatType;
369  if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
370  scalarFloatType = scalarType;
371  } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
372  scalarFloatType = cast<FloatType>(vectorType.getElementType());
373  } else {
374  return failure();
375  }
376 
377  // Get int type of the same shape as the float type.
378  Type scalarIntType = rewriter.getIntegerType(32);
379  Type intType = scalarIntType;
380  auto operandType = adaptor.getRhs().getType();
381  if (auto vectorType = dyn_cast<VectorType>(operandType)) {
382  auto shape = vectorType.getShape();
383  intType = VectorType::get(shape, scalarIntType);
384  }
385 
386  // Per GL Pow extended instruction spec:
387  // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
388  Location loc = powfOp.getLoc();
389  Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
390  Value lessThan =
391  rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
392 
393  // Per C/C++ spec:
394  // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
395  // > finite and negative and exponent is finite and non-integer.
396  // Calculate the reminder from the exponent and check whether it is zero.
397  Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
398  Value expRem =
399  rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
400  Value expRemNonZero =
401  rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
402  Value cmpNegativeWithFractionalExp =
403  rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
404  // Create NaN result and replace base value if conditions are met.
405  const auto &floatSemantics = scalarFloatType.getFloatSemantics();
406  const auto nan = APFloat::getNaN(floatSemantics);
407  Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
408  if (auto vectorType = dyn_cast<VectorType>(operandType))
409  nanAttr = DenseElementsAttr::get(vectorType, nan);
410 
411  Value NanValue =
412  rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
413  Value lhs = rewriter.create<spirv::SelectOp>(
414  loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
415  Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
416 
417  // TODO: The following just forcefully casts y into an integer value in
418  // order to properly propagate the sign, assuming integer y cases. It
419  // doesn't cover other cases and should be fixed.
420 
421  // Cast exponent to integer and calculate exponent % 2 != 0.
422  Value intRhs =
423  rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
424  Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
425  Value bitwiseAndOne =
426  rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
427  Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
428 
429  // calculate pow based on abs(lhs)^rhs.
430  Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
431  Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
432  // if the exponent is odd and lhs < 0, negate the result.
433  Value shouldNegate =
434  rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
435  rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
436  pow);
437  return success();
438  }
439 };
440 
441 /// Converts math.round to GLSL SPIRV extended ops.
442 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
444 
445  LogicalResult
446  matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
447  ConversionPatternRewriter &rewriter) const override {
448  if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
449  return res;
450 
451  Location loc = roundOp.getLoc();
452  Value operand = roundOp.getOperand();
453  Type ty = operand.getType();
454  Type ety = getElementTypeOrSelf(ty);
455 
456  auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
457  auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
458  Value half;
459  if (VectorType vty = dyn_cast<VectorType>(ty)) {
460  half = rewriter.create<spirv::ConstantOp>(
461  loc, vty,
463  rewriter.getFloatAttr(ety, 0.5).getValue()));
464  } else {
465  half = rewriter.create<spirv::ConstantOp>(
466  loc, ty, rewriter.getFloatAttr(ety, 0.5));
467  }
468 
469  auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
470  auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
471  auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
472  auto greater =
473  rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
474  auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
475  auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
476  rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
477  return success();
478  }
479 };
480 
481 } // namespace
482 
483 //===----------------------------------------------------------------------===//
484 // Pattern population
485 //===----------------------------------------------------------------------===//
486 
487 namespace mlir {
490  // Core patterns
491  patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
492 
493  // GLSL patterns
494  patterns
495  .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
496  Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
497  Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
498  ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
499  CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
500  CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
501  CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
502  CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
503  CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
504  CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
505  CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
506  CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
507  CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
508  CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
509  CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
510  CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
511  CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
512  CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
513  typeConverter, patterns.getContext());
514 
515  // OpenCL patterns
516  patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
517  Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
518  Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
519  CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
520  CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
521  CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
522  CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
523  CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
524  CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
525  CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
526  CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
527  CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
528  CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
529  CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
530  CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
531  CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
532  CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
533  CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
534  CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
535  CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
536  CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
537  typeConverter, patterns.getContext());
538 }
539 
540 } // namespace mlir
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
Definition: MathToSPIRV.cpp:75
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
Definition: MathToSPIRV.cpp:54
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
Definition: MathToSPIRV.cpp:35
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:162
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:131
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23