MLIR  20.0.0git
MathToSPIRV.cpp
Go to the documentation of this file.
1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 #define DEBUG_TYPE "math-to-spirv-pattern"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Utility functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
34 /// given type is not a 32-bit scalar/vector type.
35 static Value getScalarOrVectorI32Constant(Type type, int value,
36  OpBuilder &builder, Location loc) {
37  if (auto vectorType = dyn_cast<VectorType>(type)) {
38  if (!vectorType.getElementType().isInteger(32))
39  return nullptr;
40  SmallVector<int> values(vectorType.getNumElements(), value);
41  return builder.create<spirv::ConstantOp>(loc, type,
42  builder.getI32VectorAttr(values));
43  }
44  if (type.isInteger(32))
45  return builder.create<spirv::ConstantOp>(loc, type,
46  builder.getI32IntegerAttr(value));
47 
48  return nullptr;
49 }
50 
51 /// Check if the type is supported by math-to-spirv conversion. We expect to
52 /// only see scalars and vectors at this point, with higher-level types already
53 /// lowered.
54 static bool isSupportedSourceType(Type originalType) {
55  if (originalType.isIntOrIndexOrFloat())
56  return true;
57 
58  if (auto vecTy = dyn_cast<VectorType>(originalType)) {
59  if (!vecTy.getElementType().isIntOrIndexOrFloat())
60  return false;
61  if (vecTy.isScalable())
62  return false;
63  if (vecTy.getRank() > 1)
64  return false;
65 
66  return true;
67  }
68 
69  return false;
70 }
71 
72 /// Check if all `sourceOp` types are supported by math-to-spirv conversion.
73 /// Notify of a match failure othwerise and return a `failure` result.
74 /// This is intended to simplify type checks in `OpConversionPattern`s.
75 static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
76  Operation *sourceOp) {
77  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
78  llvm::append_range(allTypes, sourceOp->getResultTypes());
79 
80  for (Type ty : allTypes) {
81  if (!isSupportedSourceType(ty)) {
82  return rewriter.notifyMatchFailure(
83  sourceOp,
84  llvm::formatv(
85  "unsupported source type for Math to SPIR-V conversion: {0}",
86  ty));
87  }
88  }
89 
90  return success();
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // Operation conversion
95 //===----------------------------------------------------------------------===//
96 
97 // Note that DRR cannot be used for the patterns in this file: we may need to
98 // convert type along the way, which requires ConversionPattern. DRR generates
99 // normal RewritePattern.
100 
101 namespace {
102 /// Converts elementwise unary, binary, and ternary standard operations to
103 /// SPIR-V operations. Checks that source `Op` types are supported.
104 template <typename Op, typename SPIRVOp>
105 struct CheckedElementwiseOpPattern final
106  : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
107  using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
108  using BasePattern::BasePattern;
109 
110  LogicalResult
111  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
112  ConversionPatternRewriter &rewriter) const override {
113  if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
114  return res;
115 
116  return BasePattern::matchAndRewrite(op, adaptor, rewriter);
117  }
118 };
119 
120 /// Converts math.copysign to SPIR-V ops.
121 struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
123 
124  LogicalResult
125  matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
126  ConversionPatternRewriter &rewriter) const override {
127  if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
128  failed(res))
129  return res;
130 
131  Type type = getTypeConverter()->convertType(copySignOp.getType());
132  if (!type)
133  return failure();
134 
135  FloatType floatType;
136  if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137  floatType = scalarType;
138  } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139  floatType = cast<FloatType>(vectorType.getElementType());
140  } else {
141  return failure();
142  }
143 
144  Location loc = copySignOp.getLoc();
145  int bitwidth = floatType.getWidth();
146  Type intType = rewriter.getIntegerType(bitwidth);
147  uint64_t intValue = uint64_t(1) << (bitwidth - 1);
148 
149  Value signMask = rewriter.create<spirv::ConstantOp>(
150  loc, intType, rewriter.getIntegerAttr(intType, intValue));
151  Value valueMask = rewriter.create<spirv::ConstantOp>(
152  loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
153 
154  if (auto vectorType = dyn_cast<VectorType>(type)) {
155  assert(vectorType.getRank() == 1);
156  int count = vectorType.getNumElements();
157  intType = VectorType::get(count, intType);
158 
159  SmallVector<Value> signSplat(count, signMask);
160  signMask =
161  rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
162 
163  SmallVector<Value> valueSplat(count, valueMask);
164  valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
165  valueSplat);
166  }
167 
168  Value lhsCast =
169  rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
170  Value rhsCast =
171  rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
172 
173  Value value = rewriter.create<spirv::BitwiseAndOp>(
174  loc, intType, ValueRange{lhsCast, valueMask});
175  Value sign = rewriter.create<spirv::BitwiseAndOp>(
176  loc, intType, ValueRange{rhsCast, signMask});
177 
178  Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
179  ValueRange{value, sign});
180  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
181  return success();
182  }
183 };
184 
185 /// Converts math.ctlz to SPIR-V ops.
186 ///
187 /// SPIR-V does not have a direct operations for counting leading zeros. If
188 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
189 /// it.
190 struct CountLeadingZerosPattern final
191  : public OpConversionPattern<math::CountLeadingZerosOp> {
193 
194  LogicalResult
195  matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
196  ConversionPatternRewriter &rewriter) const override {
197  if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
198  return res;
199 
200  Type type = getTypeConverter()->convertType(countOp.getType());
201  if (!type)
202  return failure();
203 
204  // We can only support 32-bit integer types for now.
205  unsigned bitwidth = 0;
206  if (isa<IntegerType>(type))
207  bitwidth = type.getIntOrFloatBitWidth();
208  if (auto vectorType = dyn_cast<VectorType>(type))
209  bitwidth = vectorType.getElementTypeBitWidth();
210  if (bitwidth != 32)
211  return failure();
212 
213  Location loc = countOp.getLoc();
214  Value input = adaptor.getOperand();
215  Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
216  Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
217  Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
218 
219  Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
220  // We need to subtract from 31 given that the index returned by GLSL
221  // FindUMsb is counted from the least significant bit. Theoretically this
222  // also gives the correct result even if the integer has all zero bits, in
223  // which case GL FindUMsb would return -1.
224  Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
225  // However, certain Vulkan implementations have driver bugs for the corner
226  // case where the input is zero. And.. it can be smart to optimize a select
227  // only involving the corner case. So separately compute the result when the
228  // input is either zero or one.
229  Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
230  Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
231  rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
232  subMsb);
233  return success();
234  }
235 };
236 
237 /// Converts math.expm1 to SPIR-V ops.
238 ///
239 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
240 /// these operations.
241 template <typename ExpOp>
242 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
244 
245  LogicalResult
246  matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
247  ConversionPatternRewriter &rewriter) const override {
248  assert(adaptor.getOperands().size() == 1);
249  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
250  failed(res))
251  return res;
252 
253  Location loc = operation.getLoc();
254  Type type = this->getTypeConverter()->convertType(operation.getType());
255  if (!type)
256  return failure();
257 
258  Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
259  auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
260  rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
261  return success();
262  }
263 };
264 
265 /// Converts math.log1p to SPIR-V ops.
266 ///
267 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
268 /// these operations.
269 template <typename LogOp>
270 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
272 
273  LogicalResult
274  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
275  ConversionPatternRewriter &rewriter) const override {
276  assert(adaptor.getOperands().size() == 1);
277  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
278  failed(res))
279  return res;
280 
281  Location loc = operation.getLoc();
282  Type type = this->getTypeConverter()->convertType(operation.getType());
283  if (!type)
284  return failure();
285 
286  auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
287  Value onePlus =
288  rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
289  rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
290  return success();
291  }
292 };
293 
294 /// Converts math.log2 and math.log10 to SPIR-V ops.
295 ///
296 /// SPIR-V does not have direct operations for log2 and log10. Explicitly
297 /// lower to these operations using:
298 /// log2(x) = log(x) * 1/log(2)
299 /// log10(x) = log(x) * 1/log(10)
300 
301 template <typename MathLogOp, typename SpirvLogOp>
302 struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
305 
306  static constexpr double log2Reciprocal =
307  1.442695040888963407359924681001892137426645954152985934135449407;
308  static constexpr double log10Reciprocal =
309  0.4342944819032518276511289189166050822943970058036665661144537832;
310 
311  LogicalResult
312  matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
313  ConversionPatternRewriter &rewriter) const override {
314  assert(adaptor.getOperands().size() == 1);
315  if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
316  failed(res))
317  return res;
318 
319  Location loc = operation.getLoc();
320  Type type = this->getTypeConverter()->convertType(operation.getType());
321  if (!type)
322  return rewriter.notifyMatchFailure(operation, "type conversion failed");
323 
324  auto getConstantValue = [&](double value) {
325  if (auto floatType = dyn_cast<FloatType>(type)) {
326  return rewriter.create<spirv::ConstantOp>(
327  loc, type, rewriter.getFloatAttr(floatType, value));
328  }
329  if (auto vectorType = dyn_cast<VectorType>(type)) {
330  Type elemType = vectorType.getElementType();
331 
332  if (isa<FloatType>(elemType)) {
333  return rewriter.create<spirv::ConstantOp>(
334  loc, type,
336  vectorType, FloatAttr::get(elemType, value).getValue()));
337  }
338  }
339 
340  llvm_unreachable("unimplemented types for log2/log10");
341  };
342 
343  Value constantValue = getConstantValue(
344  std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
345  : log10Reciprocal);
346  Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
347  rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
348  constantValue);
349  return success();
350  }
351 };
352 
353 /// Converts math.powf to SPIRV-Ops.
354 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
356 
357  LogicalResult
358  matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
359  ConversionPatternRewriter &rewriter) const override {
360  if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
361  return res;
362 
363  Type dstType = getTypeConverter()->convertType(powfOp.getType());
364  if (!dstType)
365  return failure();
366 
367  // Get the scalar float type.
368  FloatType scalarFloatType;
369  if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
370  scalarFloatType = scalarType;
371  } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
372  scalarFloatType = cast<FloatType>(vectorType.getElementType());
373  } else {
374  return failure();
375  }
376 
377  // Get int type of the same shape as the float type.
378  Type scalarIntType = rewriter.getIntegerType(32);
379  Type intType = scalarIntType;
380  if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
381  auto shape = vectorType.getShape();
382  intType = VectorType::get(shape, scalarIntType);
383  }
384 
385  // Per GL Pow extended instruction spec:
386  // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
387  Location loc = powfOp.getLoc();
388  Value zero =
389  spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
390  Value lessThan =
391  rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
392  Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
393 
394  // TODO: The following just forcefully casts y into an integer value in
395  // order to properly propagate the sign, assuming integer y cases. It
396  // doesn't cover other cases and should be fixed.
397 
398  // Cast exponent to integer and calculate exponent % 2 != 0.
399  Value intRhs =
400  rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
401  Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
402  Value bitwiseAndOne =
403  rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
404  Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
405 
406  // calculate pow based on abs(lhs)^rhs.
407  Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
408  Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
409  // if the exponent is odd and lhs < 0, negate the result.
410  Value shouldNegate =
411  rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
412  rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
413  pow);
414  return success();
415  }
416 };
417 
418 /// Converts math.round to GLSL SPIRV extended ops.
419 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
421 
422  LogicalResult
423  matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
424  ConversionPatternRewriter &rewriter) const override {
425  if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
426  return res;
427 
428  Location loc = roundOp.getLoc();
429  Value operand = roundOp.getOperand();
430  Type ty = operand.getType();
431  Type ety = getElementTypeOrSelf(ty);
432 
433  auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
434  auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
435  Value half;
436  if (VectorType vty = dyn_cast<VectorType>(ty)) {
437  half = rewriter.create<spirv::ConstantOp>(
438  loc, vty,
440  rewriter.getFloatAttr(ety, 0.5).getValue()));
441  } else {
442  half = rewriter.create<spirv::ConstantOp>(
443  loc, ty, rewriter.getFloatAttr(ety, 0.5));
444  }
445 
446  auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
447  auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
448  auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
449  auto greater =
450  rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
451  auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
452  auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
453  rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
454  return success();
455  }
456 };
457 
458 } // namespace
459 
460 //===----------------------------------------------------------------------===//
461 // Pattern population
462 //===----------------------------------------------------------------------===//
463 
464 namespace mlir {
466  RewritePatternSet &patterns) {
467  // Core patterns
468  patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
469 
470  // GLSL patterns
471  patterns
472  .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
473  Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
474  Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
475  ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
476  CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
477  CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
478  CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
479  CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
480  CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
481  CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
482  CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
483  CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
484  CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
485  CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
486  CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
487  CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
488  CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
489  CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
490  typeConverter, patterns.getContext());
491 
492  // OpenCL patterns
493  patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
494  Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
495  Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
496  CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
497  CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
498  CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
499  CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
500  CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
501  CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
502  CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
503  CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
504  CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
505  CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
506  CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
507  CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
508  CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
509  CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
510  CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
511  CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
512  CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
513  CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
514  typeConverter, patterns.getContext());
515 }
516 
517 } // namespace mlir
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
Definition: MathToSPIRV.cpp:75
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
Definition: MathToSPIRV.cpp:54
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
Definition: MathToSPIRV.cpp:35
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:228
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:250
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:273
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:99
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:150
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:212
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:126
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:61
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:128
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23