MLIR  22.0.0git
MathToSPIRV.cpp
Go to the documentation of this file.
1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/FormatVariadic.h"
22 
23 #define DEBUG_TYPE "math-to-spirv-pattern"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
32 /// given type is not a 32-bit scalar/vector type.
33 static Value getScalarOrVectorI32Constant(Type type, int value,
34  OpBuilder &builder, Location loc) {
35  if (auto vectorType = dyn_cast<VectorType>(type)) {
36  if (!vectorType.getElementType().isInteger(32))
37  return nullptr;
38  SmallVector<int> values(vectorType.getNumElements(), value);
39  return builder.create<spirv::ConstantOp>(loc, type,
40  builder.getI32VectorAttr(values));
41  }
42  if (type.isInteger(32))
43  return builder.create<spirv::ConstantOp>(loc, type,
44  builder.getI32IntegerAttr(value));
45 
46  return nullptr;
47 }
48 
49 /// Check if the type is supported by math-to-spirv conversion. We expect to
50 /// only see scalars and vectors at this point, with higher-level types already
51 /// lowered.
52 static bool isSupportedSourceType(Type originalType) {
53  if (originalType.isIntOrIndexOrFloat())
54  return true;
55 
56  if (auto vecTy = dyn_cast<VectorType>(originalType)) {
57  if (!vecTy.getElementType().isIntOrIndexOrFloat())
58  return false;
59  if (vecTy.isScalable())
60  return false;
61  if (vecTy.getRank() > 1)
62  return false;
63 
64  return true;
65  }
66 
67  return false;
68 }
69 
70 /// Check if all `sourceOp` types are supported by math-to-spirv conversion.
71 /// Notify of a match failure othwerise and return a `failure` result.
72 /// This is intended to simplify type checks in `OpConversionPattern`s.
73 static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
74  Operation *sourceOp) {
75  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
76  llvm::append_range(allTypes, sourceOp->getResultTypes());
77 
78  for (Type ty : allTypes) {
79  if (!isSupportedSourceType(ty)) {
80  return rewriter.notifyMatchFailure(
81  sourceOp,
82  llvm::formatv(
83  "unsupported source type for Math to SPIR-V conversion: {0}",
84  ty));
85  }
86  }
87 
88  return success();
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Operation conversion
93 //===----------------------------------------------------------------------===//
94 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// generates plain RewritePatterns.
98 
99 namespace {
100 /// Converts elementwise unary, binary, and ternary standard operations to
101 /// SPIR-V operations. Checks that source `Op` types are supported.
102 template <typename Op, typename SPIRVOp>
103 struct CheckedElementwiseOpPattern final
104  : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
105  using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
106  using BasePattern::BasePattern;
107 
108  LogicalResult
109  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
110  ConversionPatternRewriter &rewriter) const override {
111  if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
112  return res;
113 
114  return BasePattern::matchAndRewrite(op, adaptor, rewriter);
115  }
116 };
117 
118 /// Converts math.copysign to SPIR-V ops.
119 struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
121 
122  LogicalResult
123  matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
124  ConversionPatternRewriter &rewriter) const override {
125  if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
126  failed(res))
127  return res;
128 
129  Type type = getTypeConverter()->convertType(copySignOp.getType());
130  if (!type)
131  return failure();
132 
133  FloatType floatType;
134  if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
135  floatType = scalarType;
136  } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
137  floatType = cast<FloatType>(vectorType.getElementType());
138  } else {
139  return failure();
140  }
141 
142  Location loc = copySignOp.getLoc();
143  int bitwidth = floatType.getWidth();
144  Type intType = rewriter.getIntegerType(bitwidth);
145  uint64_t intValue = uint64_t(1) << (bitwidth - 1);
146 
147  Value signMask = rewriter.create<spirv::ConstantOp>(
148  loc, intType, rewriter.getIntegerAttr(intType, intValue));
149  Value valueMask = rewriter.create<spirv::ConstantOp>(
150  loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
151 
152  if (auto vectorType = dyn_cast<VectorType>(type)) {
153  assert(vectorType.getRank() == 1);
154  int count = vectorType.getNumElements();
155  intType = VectorType::get(count, intType);
156 
157  SmallVector<Value> signSplat(count, signMask);
158  signMask =
159  rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
160 
161  SmallVector<Value> valueSplat(count, valueMask);
162  valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
163  valueSplat);
164  }
165 
166  Value lhsCast =
167  rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
168  Value rhsCast =
169  rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
170 
171  Value value = rewriter.create<spirv::BitwiseAndOp>(
172  loc, intType, ValueRange{lhsCast, valueMask});
173  Value sign = rewriter.create<spirv::BitwiseAndOp>(
174  loc, intType, ValueRange{rhsCast, signMask});
175 
176  Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
177  ValueRange{value, sign});
178  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
179  return success();
180  }
181 };
182 
183 /// Converts math.ctlz to SPIR-V ops.
184 ///
185 /// SPIR-V does not have a direct operations for counting leading zeros. If
186 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
187 /// it.
188 struct CountLeadingZerosPattern final
189  : public OpConversionPattern<math::CountLeadingZerosOp> {
191 
192  LogicalResult
193  matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
194  ConversionPatternRewriter &rewriter) const override {
195  if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
196  return res;
197 
198  Type type = getTypeConverter()->convertType(countOp.getType());
199  if (!type)
200  return failure();
201 
202  // We can only support 32-bit integer types for now.
203  unsigned bitwidth = 0;
204  if (isa<IntegerType>(type))
205  bitwidth = type.getIntOrFloatBitWidth();
206  if (auto vectorType = dyn_cast<VectorType>(type))
207  bitwidth = vectorType.getElementTypeBitWidth();
208  if (bitwidth != 32)
209  return failure();
210 
211  Location loc = countOp.getLoc();
212  Value input = adaptor.getOperand();
213  Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
214  Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
215  Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
216 
217  Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
218  // We need to subtract from 31 given that the index returned by GLSL
219  // FindUMsb is counted from the least significant bit. Theoretically this
220  // also gives the correct result even if the integer has all zero bits, in
221  // which case GL FindUMsb would return -1.
222  Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
223  // However, certain Vulkan implementations have driver bugs for the corner
224  // case where the input is zero. And.. it can be smart to optimize a select
225  // only involving the corner case. So separately compute the result when the
226  // input is either zero or one.
227  Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
228  Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
229  rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
230  subMsb);
231  return success();
232  }
233 };
234 
235 /// Converts math.expm1 to SPIR-V ops.
236 ///
237 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
238 /// these operations.
239 template <typename ExpOp>
240 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
242 
243  LogicalResult
244  matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
245  ConversionPatternRewriter &rewriter) const override {
246  assert(adaptor.getOperands().size() == 1);
247  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
248  failed(res))
249  return res;
250 
251  Location loc = operation.getLoc();
252  Type type = this->getTypeConverter()->convertType(operation.getType());
253  if (!type)
254  return failure();
255 
256  Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
257  auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
258  rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
259  return success();
260  }
261 };
262 
263 /// Converts math.log1p to SPIR-V ops.
264 ///
265 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
266 /// these operations.
267 template <typename LogOp>
268 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
270 
271  LogicalResult
272  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
273  ConversionPatternRewriter &rewriter) const override {
274  assert(adaptor.getOperands().size() == 1);
275  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
276  failed(res))
277  return res;
278 
279  Location loc = operation.getLoc();
280  Type type = this->getTypeConverter()->convertType(operation.getType());
281  if (!type)
282  return failure();
283 
284  auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
285  Value onePlus =
286  rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
287  rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
288  return success();
289  }
290 };
291 
292 /// Converts math.log2 and math.log10 to SPIR-V ops.
293 ///
294 /// SPIR-V does not have direct operations for log2 and log10. Explicitly
295 /// lower to these operations using:
296 /// log2(x) = log(x) * 1/log(2)
297 /// log10(x) = log(x) * 1/log(10)
298 
299 template <typename MathLogOp, typename SpirvLogOp>
300 struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
303 
304  static constexpr double log2Reciprocal =
305  1.442695040888963407359924681001892137426645954152985934135449407;
306  static constexpr double log10Reciprocal =
307  0.4342944819032518276511289189166050822943970058036665661144537832;
308 
309  LogicalResult
310  matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
311  ConversionPatternRewriter &rewriter) const override {
312  assert(adaptor.getOperands().size() == 1);
313  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
314  failed(res))
315  return res;
316 
317  Location loc = operation.getLoc();
318  Type type = this->getTypeConverter()->convertType(operation.getType());
319  if (!type)
320  return rewriter.notifyMatchFailure(operation, "type conversion failed");
321 
322  auto getConstantValue = [&](double value) {
323  if (auto floatType = dyn_cast<FloatType>(type)) {
324  return rewriter.create<spirv::ConstantOp>(
325  loc, type, rewriter.getFloatAttr(floatType, value));
326  }
327  if (auto vectorType = dyn_cast<VectorType>(type)) {
328  Type elemType = vectorType.getElementType();
329 
330  if (isa<FloatType>(elemType)) {
331  return rewriter.create<spirv::ConstantOp>(
332  loc, type,
334  vectorType, FloatAttr::get(elemType, value).getValue()));
335  }
336  }
337 
338  llvm_unreachable("unimplemented types for log2/log10");
339  };
340 
341  Value constantValue = getConstantValue(
342  std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
343  : log10Reciprocal);
344  Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
345  rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
346  constantValue);
347  return success();
348  }
349 };
350 
351 /// Converts math.powf to SPIRV-Ops.
352 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
354 
355  LogicalResult
356  matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
357  ConversionPatternRewriter &rewriter) const override {
358  if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
359  return res;
360 
361  Type dstType = getTypeConverter()->convertType(powfOp.getType());
362  if (!dstType)
363  return failure();
364 
365  // Get the scalar float type.
366  FloatType scalarFloatType;
367  if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
368  scalarFloatType = scalarType;
369  } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
370  scalarFloatType = cast<FloatType>(vectorType.getElementType());
371  } else {
372  return failure();
373  }
374 
375  // Get int type of the same shape as the float type.
376  Type scalarIntType = rewriter.getIntegerType(32);
377  Type intType = scalarIntType;
378  auto operandType = adaptor.getRhs().getType();
379  if (auto vectorType = dyn_cast<VectorType>(operandType)) {
380  auto shape = vectorType.getShape();
381  intType = VectorType::get(shape, scalarIntType);
382  }
383 
384  // Per GL Pow extended instruction spec:
385  // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
386  Location loc = powfOp.getLoc();
387  Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
388  Value lessThan =
389  rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
390 
391  // Per C/C++ spec:
392  // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
393  // > finite and negative and exponent is finite and non-integer.
394  // Calculate the reminder from the exponent and check whether it is zero.
395  Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
396  Value expRem =
397  rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
398  Value expRemNonZero =
399  rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
400  Value cmpNegativeWithFractionalExp =
401  rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
402  // Create NaN result and replace base value if conditions are met.
403  const auto &floatSemantics = scalarFloatType.getFloatSemantics();
404  const auto nan = APFloat::getNaN(floatSemantics);
405  Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
406  if (auto vectorType = dyn_cast<VectorType>(operandType))
407  nanAttr = DenseElementsAttr::get(vectorType, nan);
408 
409  Value NanValue =
410  rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
411  Value lhs = rewriter.create<spirv::SelectOp>(
412  loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
413  Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
414 
415  // TODO: The following just forcefully casts y into an integer value in
416  // order to properly propagate the sign, assuming integer y cases. It
417  // doesn't cover other cases and should be fixed.
418 
419  // Cast exponent to integer and calculate exponent % 2 != 0.
420  Value intRhs =
421  rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
422  Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
423  Value bitwiseAndOne =
424  rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
425  Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
426 
427  // calculate pow based on abs(lhs)^rhs.
428  Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
429  Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
430  // if the exponent is odd and lhs < 0, negate the result.
431  Value shouldNegate =
432  rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
433  rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
434  pow);
435  return success();
436  }
437 };
438 
439 /// Converts math.round to GLSL SPIRV extended ops.
440 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
442 
443  LogicalResult
444  matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
445  ConversionPatternRewriter &rewriter) const override {
446  if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
447  return res;
448 
449  Location loc = roundOp.getLoc();
450  Value operand = roundOp.getOperand();
451  Type ty = operand.getType();
452  Type ety = getElementTypeOrSelf(ty);
453 
454  auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
455  auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
456  Value half;
457  if (VectorType vty = dyn_cast<VectorType>(ty)) {
458  half = rewriter.create<spirv::ConstantOp>(
459  loc, vty,
461  rewriter.getFloatAttr(ety, 0.5).getValue()));
462  } else {
463  half = rewriter.create<spirv::ConstantOp>(
464  loc, ty, rewriter.getFloatAttr(ety, 0.5));
465  }
466 
467  auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
468  auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
469  auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
470  auto greater =
471  rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
472  auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
473  auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
474  rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
475  return success();
476  }
477 };
478 
479 } // namespace
480 
481 //===----------------------------------------------------------------------===//
482 // Pattern population
483 //===----------------------------------------------------------------------===//
484 
485 namespace mlir {
488  // Core patterns
489  patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
490 
491  // GLSL patterns
492  patterns
493  .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
494  Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
495  Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
496  ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
497  CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
498  CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
499  CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
500  CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
501  CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
502  CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
503  CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
504  CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
505  CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
506  CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
507  CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
508  CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
509  CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
510  CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
511  CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
512  CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
513  CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
514  CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
515  CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
516  CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
517  CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
518  CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
519  typeConverter, patterns.getContext());
520 
521  // OpenCL patterns
522  patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
523  Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
524  Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
525  CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
526  CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
527  CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
528  CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
529  CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
530  CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
531  CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
532  CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
533  CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
534  CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
535  CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
536  CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
537  CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
538  CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
539  CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
540  CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
541  CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
542  CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
543  CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
544  CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
545  CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
546  CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
547  CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
548  CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
549  CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
550  CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
551  typeConverter, patterns.getContext());
552 }
553 
554 } // namespace mlir
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
Definition: MathToSPIRV.cpp:73
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
Definition: MathToSPIRV.cpp:52
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
Definition: MathToSPIRV.cpp:33
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:117
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:700
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23