MLIR  22.0.0git
MathToSPIRV.cpp
Go to the documentation of this file.
1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/FormatVariadic.h"
22 
23 #define DEBUG_TYPE "math-to-spirv-pattern"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
32 /// given type is not a 32-bit scalar/vector type.
33 static Value getScalarOrVectorI32Constant(Type type, int value,
34  OpBuilder &builder, Location loc) {
35  if (auto vectorType = dyn_cast<VectorType>(type)) {
36  if (!vectorType.getElementType().isInteger(32))
37  return nullptr;
38  SmallVector<int> values(vectorType.getNumElements(), value);
39  return spirv::ConstantOp::create(builder, loc, type,
40  builder.getI32VectorAttr(values));
41  }
42  if (type.isInteger(32))
43  return spirv::ConstantOp::create(builder, loc, type,
44  builder.getI32IntegerAttr(value));
45 
46  return nullptr;
47 }
48 
49 /// Check if the type is supported by math-to-spirv conversion. We expect to
50 /// only see scalars and vectors at this point, with higher-level types already
51 /// lowered.
52 static bool isSupportedSourceType(Type originalType) {
53  if (originalType.isIntOrIndexOrFloat())
54  return true;
55 
56  if (auto vecTy = dyn_cast<VectorType>(originalType)) {
57  if (!vecTy.getElementType().isIntOrIndexOrFloat())
58  return false;
59  if (vecTy.isScalable())
60  return false;
61  if (vecTy.getRank() > 1)
62  return false;
63 
64  return true;
65  }
66 
67  return false;
68 }
69 
70 /// Check if all `sourceOp` types are supported by math-to-spirv conversion.
71 /// Notify of a match failure othwerise and return a `failure` result.
72 /// This is intended to simplify type checks in `OpConversionPattern`s.
73 static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
74  Operation *sourceOp) {
75  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
76  llvm::append_range(allTypes, sourceOp->getResultTypes());
77 
78  for (Type ty : allTypes) {
79  if (!isSupportedSourceType(ty)) {
80  return rewriter.notifyMatchFailure(
81  sourceOp,
82  llvm::formatv(
83  "unsupported source type for Math to SPIR-V conversion: {0}",
84  ty));
85  }
86  }
87 
88  return success();
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Operation conversion
93 //===----------------------------------------------------------------------===//
94 
95 // Note that DRR cannot be used for the patterns in this file: we may need to
96 // convert types along the way, which requires ConversionPattern, whereas DRR
97 // only generates plain RewritePatterns.
98 
99 namespace {
100 /// Converts elementwise unary, binary, and ternary standard operations to
101 /// SPIR-V operations. Checks that source `Op` types are supported.
102 template <typename Op, typename SPIRVOp>
103 struct CheckedElementwiseOpPattern final
104  : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
105  using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
106  using BasePattern::BasePattern;
107 
108  LogicalResult
109  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
110  ConversionPatternRewriter &rewriter) const override {
111  if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
112  return res;
113 
114  return BasePattern::matchAndRewrite(op, adaptor, rewriter);
115  }
116 };
117 
118 /// Converts math.copysign to SPIR-V ops.
119 struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
121 
122  LogicalResult
123  matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
124  ConversionPatternRewriter &rewriter) const override {
125  if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
126  failed(res))
127  return res;
128 
129  Type type = getTypeConverter()->convertType(copySignOp.getType());
130  if (!type)
131  return failure();
132 
133  FloatType floatType;
134  if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
135  floatType = scalarType;
136  } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
137  floatType = cast<FloatType>(vectorType.getElementType());
138  } else {
139  return failure();
140  }
141 
142  Location loc = copySignOp.getLoc();
143  int bitwidth = floatType.getWidth();
144  Type intType = rewriter.getIntegerType(bitwidth);
145  uint64_t intValue = uint64_t(1) << (bitwidth - 1);
146 
147  Value signMask = spirv::ConstantOp::create(
148  rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
149  Value valueMask = spirv::ConstantOp::create(
150  rewriter, loc, intType,
151  rewriter.getIntegerAttr(intType, intValue - 1u));
152 
153  if (auto vectorType = dyn_cast<VectorType>(type)) {
154  assert(vectorType.getRank() == 1);
155  int count = vectorType.getNumElements();
156  intType = VectorType::get(count, intType);
157 
158  SmallVector<Value> signSplat(count, signMask);
159  signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
160  signSplat);
161 
162  SmallVector<Value> valueSplat(count, valueMask);
163  valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
164  valueSplat);
165  }
166 
167  Value lhsCast =
168  spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
169  Value rhsCast =
170  spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
171 
172  Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
173  ValueRange{lhsCast, valueMask});
174  Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
175  ValueRange{rhsCast, signMask});
176 
177  Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
178  ValueRange{value, sign});
179  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
180  return success();
181  }
182 };
183 
184 /// Converts math.ctlz to SPIR-V ops.
185 ///
186 /// SPIR-V does not have a direct operations for counting leading zeros. If
187 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
188 /// it.
189 struct CountLeadingZerosPattern final
190  : public OpConversionPattern<math::CountLeadingZerosOp> {
192 
193  LogicalResult
194  matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
195  ConversionPatternRewriter &rewriter) const override {
196  if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
197  return res;
198 
199  Type type = getTypeConverter()->convertType(countOp.getType());
200  if (!type)
201  return failure();
202 
203  // We can only support 32-bit integer types for now.
204  unsigned bitwidth = 0;
205  if (isa<IntegerType>(type))
206  bitwidth = type.getIntOrFloatBitWidth();
207  if (auto vectorType = dyn_cast<VectorType>(type))
208  bitwidth = vectorType.getElementTypeBitWidth();
209  if (bitwidth != 32)
210  return failure();
211 
212  Location loc = countOp.getLoc();
213  Value input = adaptor.getOperand();
214  Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
215  Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
216  Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
217 
218  Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
219  // We need to subtract from 31 given that the index returned by GLSL
220  // FindUMsb is counted from the least significant bit. Theoretically this
221  // also gives the correct result even if the integer has all zero bits, in
222  // which case GL FindUMsb would return -1.
223  Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
224  // However, certain Vulkan implementations have driver bugs for the corner
225  // case where the input is zero. And.. it can be smart to optimize a select
226  // only involving the corner case. So separately compute the result when the
227  // input is either zero or one.
228  Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
229  Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
230  rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
231  subMsb);
232  return success();
233  }
234 };
235 
236 /// Converts math.expm1 to SPIR-V ops.
237 ///
238 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
239 /// these operations.
240 template <typename ExpOp>
241 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
243 
244  LogicalResult
245  matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
246  ConversionPatternRewriter &rewriter) const override {
247  assert(adaptor.getOperands().size() == 1);
248  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
249  failed(res))
250  return res;
251 
252  Location loc = operation.getLoc();
253  Type type = this->getTypeConverter()->convertType(operation.getType());
254  if (!type)
255  return failure();
256 
257  Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
258  auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
259  rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
260  return success();
261  }
262 };
263 
264 /// Converts math.log1p to SPIR-V ops.
265 ///
266 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
267 /// these operations.
268 template <typename LogOp>
269 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
271 
272  LogicalResult
273  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
274  ConversionPatternRewriter &rewriter) const override {
275  assert(adaptor.getOperands().size() == 1);
276  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
277  failed(res))
278  return res;
279 
280  Location loc = operation.getLoc();
281  Type type = this->getTypeConverter()->convertType(operation.getType());
282  if (!type)
283  return failure();
284 
285  auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
286  Value onePlus =
287  spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
288  rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
289  return success();
290  }
291 };
292 
293 /// Converts math.log2 and math.log10 to SPIR-V ops.
294 ///
295 /// SPIR-V does not have direct operations for log2 and log10. Explicitly
296 /// lower to these operations using:
297 /// log2(x) = log(x) * 1/log(2)
298 /// log10(x) = log(x) * 1/log(10)
299 
300 template <typename MathLogOp, typename SpirvLogOp>
301 struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
304 
305  static constexpr double log2Reciprocal =
306  1.442695040888963407359924681001892137426645954152985934135449407;
307  static constexpr double log10Reciprocal =
308  0.4342944819032518276511289189166050822943970058036665661144537832;
309 
310  LogicalResult
311  matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
312  ConversionPatternRewriter &rewriter) const override {
313  assert(adaptor.getOperands().size() == 1);
314  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
315  failed(res))
316  return res;
317 
318  Location loc = operation.getLoc();
319  Type type = this->getTypeConverter()->convertType(operation.getType());
320  if (!type)
321  return rewriter.notifyMatchFailure(operation, "type conversion failed");
322 
323  auto getConstantValue = [&](double value) {
324  if (auto floatType = dyn_cast<FloatType>(type)) {
325  return spirv::ConstantOp::create(
326  rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
327  }
328  if (auto vectorType = dyn_cast<VectorType>(type)) {
329  Type elemType = vectorType.getElementType();
330 
331  if (isa<FloatType>(elemType)) {
332  return spirv::ConstantOp::create(
333  rewriter, loc, type,
335  vectorType, FloatAttr::get(elemType, value).getValue()));
336  }
337  }
338 
339  llvm_unreachable("unimplemented types for log2/log10");
340  };
341 
342  Value constantValue = getConstantValue(
343  std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
344  : log10Reciprocal);
345  Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand());
346  rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
347  constantValue);
348  return success();
349  }
350 };
351 
352 /// Converts math.powf to SPIRV-Ops.
353 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
355 
356  LogicalResult
357  matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
358  ConversionPatternRewriter &rewriter) const override {
359  if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
360  return res;
361 
362  Type dstType = getTypeConverter()->convertType(powfOp.getType());
363  if (!dstType)
364  return failure();
365 
366  // Get the scalar float type.
367  FloatType scalarFloatType;
368  if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
369  scalarFloatType = scalarType;
370  } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
371  scalarFloatType = cast<FloatType>(vectorType.getElementType());
372  } else {
373  return failure();
374  }
375 
376  // Get int type of the same shape as the float type.
377  Type scalarIntType = rewriter.getIntegerType(32);
378  Type intType = scalarIntType;
379  auto operandType = adaptor.getRhs().getType();
380  if (auto vectorType = dyn_cast<VectorType>(operandType)) {
381  auto shape = vectorType.getShape();
382  intType = VectorType::get(shape, scalarIntType);
383  }
384 
385  // Per GL Pow extended instruction spec:
386  // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
387  Location loc = powfOp.getLoc();
388  Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
389  Value lessThan =
390  spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
391 
392  // Per C/C++ spec:
393  // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
394  // > finite and negative and exponent is finite and non-integer.
395  // Calculate the reminder from the exponent and check whether it is zero.
396  Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
397  Value expRem =
398  spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
399  Value expRemNonZero =
400  spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
401  Value cmpNegativeWithFractionalExp =
402  spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
403  // Create NaN result and replace base value if conditions are met.
404  const auto &floatSemantics = scalarFloatType.getFloatSemantics();
405  const auto nan = APFloat::getNaN(floatSemantics);
406  Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
407  if (auto vectorType = dyn_cast<VectorType>(operandType))
408  nanAttr = DenseElementsAttr::get(vectorType, nan);
409 
410  Value NanValue =
411  spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
412  Value lhs =
413  spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
414  NanValue, adaptor.getLhs());
415  Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
416 
417  // TODO: The following just forcefully casts y into an integer value in
418  // order to properly propagate the sign, assuming integer y cases. It
419  // doesn't cover other cases and should be fixed.
420 
421  // Cast exponent to integer and calculate exponent % 2 != 0.
422  Value intRhs =
423  spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
424  Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
425  Value bitwiseAndOne =
426  spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne);
427  Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne);
428 
429  // calculate pow based on abs(lhs)^rhs.
430  Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
431  Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
432  // if the exponent is odd and lhs < 0, negate the result.
433  Value shouldNegate =
434  spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
435  rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
436  pow);
437  return success();
438  }
439 };
440 
441 /// Converts math.round to GLSL SPIRV extended ops.
442 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
444 
445  LogicalResult
446  matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
447  ConversionPatternRewriter &rewriter) const override {
448  if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
449  return res;
450 
451  Location loc = roundOp.getLoc();
452  Value operand = roundOp.getOperand();
453  Type ty = operand.getType();
454  Type ety = getElementTypeOrSelf(ty);
455 
456  auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
457  auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
458  Value half;
459  if (VectorType vty = dyn_cast<VectorType>(ty)) {
460  half = spirv::ConstantOp::create(
461  rewriter, loc, vty,
463  rewriter.getFloatAttr(ety, 0.5).getValue()));
464  } else {
465  half = spirv::ConstantOp::create(rewriter, loc, ty,
466  rewriter.getFloatAttr(ety, 0.5));
467  }
468 
469  auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
470  auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
471  auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
472  auto greater =
473  spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
474  auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
475  auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
476  rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
477  return success();
478  }
479 };
480 
481 } // namespace
482 
483 //===----------------------------------------------------------------------===//
484 // Pattern population
485 //===----------------------------------------------------------------------===//
486 
487 namespace mlir {
490  // Core patterns
491  patterns
492  .add<CopySignPattern,
493  CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
494  CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
495  CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
496  typeConverter, patterns.getContext());
497 
498  // GLSL patterns
499  patterns
500  .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
501  Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
502  Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
503  ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
504  CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
505  CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
506  CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
507  CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
508  CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
509  CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
510  CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
511  CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
512  CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
513  CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
514  CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
515  CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
516  CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
517  CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
518  CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
519  CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
520  CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
521  CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
522  CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
523  CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
524  CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
525  CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
526  typeConverter, patterns.getContext());
527 
528  // OpenCL patterns
529  patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
530  Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
531  Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
532  CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
533  CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
534  CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
535  CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
536  CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
537  CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
538  CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
539  CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
540  CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
541  CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
542  CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
543  CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
544  CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
545  CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
546  CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
547  CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
548  CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
549  CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
550  CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
551  CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
552  CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
553  CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
554  CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
555  CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
556  CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
557  CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
558  typeConverter, patterns.getContext());
559 }
560 
561 } // namespace mlir
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
Definition: MathToSPIRV.cpp:73
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
Definition: MathToSPIRV.cpp:52
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
Definition: MathToSPIRV.cpp:33
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:117
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23