MLIR 22.0.0git
MathToSPIRV.cpp
Go to the documentation of this file.
1//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Math dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/FormatVariadic.h"
22
23#define DEBUG_TYPE "math-to-spirv-pattern"
24
25using namespace mlir;
26
27//===----------------------------------------------------------------------===//
28// Utility functions
29//===----------------------------------------------------------------------===//
30
31/// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
32/// given type is not a 32-bit scalar/vector type.
34 OpBuilder &builder, Location loc) {
35 if (auto vectorType = dyn_cast<VectorType>(type)) {
36 if (!vectorType.getElementType().isInteger(32))
37 return nullptr;
38 SmallVector<int> values(vectorType.getNumElements(), value);
39 return spirv::ConstantOp::create(builder, loc, type,
40 builder.getI32VectorAttr(values));
41 }
42 if (type.isInteger(32))
43 return spirv::ConstantOp::create(builder, loc, type,
44 builder.getI32IntegerAttr(value));
45
46 return nullptr;
47}
48
49/// Check if the type is supported by math-to-spirv conversion. We expect to
50/// only see scalars and vectors at this point, with higher-level types already
51/// lowered.
52static bool isSupportedSourceType(Type originalType) {
53 if (originalType.isIntOrIndexOrFloat())
54 return true;
55
56 if (auto vecTy = dyn_cast<VectorType>(originalType)) {
57 if (!vecTy.getElementType().isIntOrIndexOrFloat())
58 return false;
59 if (vecTy.isScalable())
60 return false;
61 if (vecTy.getRank() > 1)
62 return false;
63
64 return true;
65 }
66
67 return false;
68}
69
70/// Check if all `sourceOp` types are supported by math-to-spirv conversion.
71/// Notify of a match failure othwerise and return a `failure` result.
72/// This is intended to simplify type checks in `OpConversionPattern`s.
73static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
74 Operation *sourceOp) {
75 auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
76 llvm::append_range(allTypes, sourceOp->getResultTypes());
77
78 for (Type ty : allTypes) {
79 if (!isSupportedSourceType(ty)) {
80 return rewriter.notifyMatchFailure(
81 sourceOp,
82 llvm::formatv(
83 "unsupported source type for Math to SPIR-V conversion: {0}",
84 ty));
85 }
86 }
87
88 return success();
89}
90
91//===----------------------------------------------------------------------===//
92// Operation conversion
93//===----------------------------------------------------------------------===//
94
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern, whereas DRR
// generates plain RewritePatterns.
98
99namespace {
100/// Converts elementwise unary, binary, and ternary standard operations to
101/// SPIR-V operations. Checks that source `Op` types are supported.
102template <typename Op, typename SPIRVOp>
103struct CheckedElementwiseOpPattern final
104 : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
105 using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
106 using BasePattern::BasePattern;
107
108 LogicalResult
109 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
110 ConversionPatternRewriter &rewriter) const override {
111 if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
112 return res;
113
114 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
115 }
116};
117
118/// Converts math.copysign to SPIR-V ops.
119struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
120 using Base::Base;
121
122 LogicalResult
123 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter) const override {
125 if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
126 failed(res))
127 return res;
128
129 Type type = getTypeConverter()->convertType(copySignOp.getType());
130 if (!type)
131 return failure();
132
133 FloatType floatType;
134 if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
135 floatType = scalarType;
136 } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
137 floatType = cast<FloatType>(vectorType.getElementType());
138 } else {
139 return failure();
140 }
141
142 Location loc = copySignOp.getLoc();
143 int bitwidth = floatType.getWidth();
144 Type intType = rewriter.getIntegerType(bitwidth);
145 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
146
147 Value signMask = spirv::ConstantOp::create(
148 rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
149 Value valueMask = spirv::ConstantOp::create(
150 rewriter, loc, intType,
151 rewriter.getIntegerAttr(intType, intValue - 1u));
152
153 if (auto vectorType = dyn_cast<VectorType>(type)) {
154 assert(vectorType.getRank() == 1);
155 int count = vectorType.getNumElements();
156 intType = VectorType::get(count, intType);
157
158 SmallVector<Value> signSplat(count, signMask);
159 signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
160 signSplat);
161
162 SmallVector<Value> valueSplat(count, valueMask);
163 valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
164 valueSplat);
165 }
166
167 Value lhsCast =
168 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
169 Value rhsCast =
170 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
171
172 Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
173 ValueRange{lhsCast, valueMask});
174 Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
175 ValueRange{rhsCast, signMask});
176
177 Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
178 ValueRange{value, sign});
179 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
180 return success();
181 }
182};
183
184/// Converts math.ctlz to SPIR-V ops.
185///
186/// SPIR-V does not have a direct operations for counting leading zeros. If
187/// Shader capability is supported, we can leverage GL FindUMsb to calculate
188/// it.
189struct CountLeadingZerosPattern final
190 : public OpConversionPattern<math::CountLeadingZerosOp> {
191 using Base::Base;
192
193 LogicalResult
194 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
195 ConversionPatternRewriter &rewriter) const override {
196 if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
197 return res;
198
199 Type type = getTypeConverter()->convertType(countOp.getType());
200 if (!type)
201 return failure();
202
203 // We can only support 32-bit integer types for now.
204 unsigned bitwidth = 0;
205 if (isa<IntegerType>(type))
206 bitwidth = type.getIntOrFloatBitWidth();
207 if (auto vectorType = dyn_cast<VectorType>(type))
208 bitwidth = vectorType.getElementTypeBitWidth();
209 if (bitwidth != 32)
210 return failure();
211
212 Location loc = countOp.getLoc();
213 Value input = adaptor.getOperand();
214 Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
215 Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
216 Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
217
218 Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
219 // We need to subtract from 31 given that the index returned by GLSL
220 // FindUMsb is counted from the least significant bit. Theoretically this
221 // also gives the correct result even if the integer has all zero bits, in
222 // which case GL FindUMsb would return -1.
223 Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
224 // However, certain Vulkan implementations have driver bugs for the corner
225 // case where the input is zero. And.. it can be smart to optimize a select
226 // only involving the corner case. So separately compute the result when the
227 // input is either zero or one.
228 Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
229 Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
230 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
231 subMsb);
232 return success();
233 }
234};
235
236/// Converts math.expm1 to SPIR-V ops.
237///
238/// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
239/// these operations.
240template <typename ExpOp>
241struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
242 using Base::Base;
243
244 LogicalResult
245 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter) const override {
247 assert(adaptor.getOperands().size() == 1);
248 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
249 failed(res))
250 return res;
251
252 Location loc = operation.getLoc();
253 Type type = this->getTypeConverter()->convertType(operation.getType());
254 if (!type)
255 return failure();
256
257 Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
258 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
259 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
260 return success();
261 }
262};
263
264/// Converts math.log1p to SPIR-V ops.
265///
266/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
267/// these operations.
268template <typename LogOp>
269struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
270 using Base::Base;
271
272 LogicalResult
273 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
274 ConversionPatternRewriter &rewriter) const override {
275 assert(adaptor.getOperands().size() == 1);
276 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
277 failed(res))
278 return res;
279
280 Location loc = operation.getLoc();
281 Type type = this->getTypeConverter()->convertType(operation.getType());
282 if (!type)
283 return failure();
284
285 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
286 Value onePlus =
287 spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
288 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
289 return success();
290 }
291};
292
293/// Converts math.log2 and math.log10 to SPIR-V ops.
294///
295/// SPIR-V does not have direct operations for log2 and log10. Explicitly
296/// lower to these operations using:
297/// log2(x) = log(x) * 1/log(2)
298/// log10(x) = log(x) * 1/log(10)
299
300template <typename MathLogOp, typename SpirvLogOp>
301struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
302 using OpConversionPattern<MathLogOp>::OpConversionPattern;
303 using typename OpConversionPattern<MathLogOp>::OpAdaptor;
304
305 static constexpr double log2Reciprocal =
306 1.442695040888963407359924681001892137426645954152985934135449407;
307 static constexpr double log10Reciprocal =
308 0.4342944819032518276511289189166050822943970058036665661144537832;
309
310 LogicalResult
311 matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
312 ConversionPatternRewriter &rewriter) const override {
313 assert(adaptor.getOperands().size() == 1);
314 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
315 failed(res))
316 return res;
317
318 Location loc = operation.getLoc();
319 Type type = this->getTypeConverter()->convertType(operation.getType());
320 if (!type)
321 return rewriter.notifyMatchFailure(operation, "type conversion failed");
322
323 auto getConstantValue = [&](double value) {
324 if (auto floatType = dyn_cast<FloatType>(type)) {
325 return spirv::ConstantOp::create(
326 rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
327 }
328 if (auto vectorType = dyn_cast<VectorType>(type)) {
329 Type elemType = vectorType.getElementType();
330
331 if (isa<FloatType>(elemType)) {
332 return spirv::ConstantOp::create(
333 rewriter, loc, type,
335 vectorType, FloatAttr::get(elemType, value).getValue()));
336 }
337 }
338
339 llvm_unreachable("unimplemented types for log2/log10");
340 };
341
342 Value constantValue = getConstantValue(
343 std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
344 : log10Reciprocal);
345 Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand());
346 rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
347 constantValue);
348 return success();
349 }
350};
351
352/// Converts math.powf to SPIRV-Ops.
353struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
354 using Base::Base;
355
356 LogicalResult
357 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
360 return res;
361
362 Type dstType = getTypeConverter()->convertType(powfOp.getType());
363 if (!dstType)
364 return failure();
365
366 // Get the scalar float type.
367 FloatType scalarFloatType;
368 if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
369 scalarFloatType = scalarType;
370 } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
371 scalarFloatType = cast<FloatType>(vectorType.getElementType());
372 } else {
373 return failure();
374 }
375
376 // Get int type of the same shape as the float type.
377 Type scalarIntType = rewriter.getIntegerType(32);
378 Type intType = scalarIntType;
379 auto operandType = adaptor.getRhs().getType();
380 if (auto vectorType = dyn_cast<VectorType>(operandType)) {
381 auto shape = vectorType.getShape();
382 intType = VectorType::get(shape, scalarIntType);
383 }
384
385 // Per GL Pow extended instruction spec:
386 // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
387 Location loc = powfOp.getLoc();
388 Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
389 Value lessThan =
390 spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
391
392 // Per C/C++ spec:
393 // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
394 // > finite and negative and exponent is finite and non-integer.
395 // Calculate the reminder from the exponent and check whether it is zero.
396 Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
397 Value expRem =
398 spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
399 Value expRemNonZero =
400 spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
401 Value cmpNegativeWithFractionalExp =
402 spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
403 // Create NaN result and replace base value if conditions are met.
404 const auto &floatSemantics = scalarFloatType.getFloatSemantics();
405 const auto nan = APFloat::getNaN(floatSemantics);
406 Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
407 if (auto vectorType = dyn_cast<VectorType>(operandType))
408 nanAttr = DenseElementsAttr::get(vectorType, nan);
409
410 Value nanValue =
411 spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
412 Value lhs =
413 spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
414 nanValue, adaptor.getLhs());
415 Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
416
417 // TODO: The following just forcefully casts y into an integer value in
418 // order to properly propagate the sign, assuming integer y cases. It
419 // doesn't cover other cases and should be fixed.
420
421 // Cast exponent to integer and calculate exponent % 2 != 0.
422 Value intRhs =
423 spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
424 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
425 Value bitwiseAndOne =
426 spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne);
427 Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne);
428
429 // calculate pow based on abs(lhs)^rhs.
430 Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
431 Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
432 // if the exponent is odd and lhs < 0, negate the result.
433 Value shouldNegate =
434 spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
435 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
436 pow);
437 return success();
438 }
439};
440
441/// Converts math.round to GLSL SPIRV extended ops.
442struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
443 using Base::Base;
444
445 LogicalResult
446 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
447 ConversionPatternRewriter &rewriter) const override {
448 if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
449 return res;
450
451 Location loc = roundOp.getLoc();
452 Value operand = roundOp.getOperand();
453 Type ty = operand.getType();
454 Type ety = getElementTypeOrSelf(ty);
455
456 auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
457 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
458 Value half;
459 if (VectorType vty = dyn_cast<VectorType>(ty)) {
460 half = spirv::ConstantOp::create(
461 rewriter, loc, vty,
463 rewriter.getFloatAttr(ety, 0.5).getValue()));
464 } else {
465 half = spirv::ConstantOp::create(rewriter, loc, ty,
466 rewriter.getFloatAttr(ety, 0.5));
467 }
468
469 auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
470 auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
471 auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
472 auto greater =
473 spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
474 auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
475 auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
476 rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
477 return success();
478 }
479};
480
481} // namespace
482
483//===----------------------------------------------------------------------===//
484// Pattern population
485//===----------------------------------------------------------------------===//
486
487namespace mlir {
490 // Core patterns
492 .add<CopySignPattern,
493 CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
494 CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
495 CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
496 typeConverter, patterns.getContext());
497
498 // GLSL patterns
500 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
501 Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
502 Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
503 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
504 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
505 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
506 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
507 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
508 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
509 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
510 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
511 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
512 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
513 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
514 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
515 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
516 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
517 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
518 CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
519 CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
520 CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
521 CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
522 CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
523 CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
524 CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
525 CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
526 typeConverter, patterns.getContext());
527
528 // OpenCL patterns
529 patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
530 Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
531 Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
532 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
533 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
534 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
535 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
536 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
537 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
538 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
539 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
540 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
541 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
542 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
543 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
544 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
545 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
546 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
547 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
548 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
549 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
550 CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
551 CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
552 CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
553 CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
554 CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
555 CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
556 CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
557 CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
558 typeConverter, patterns.getContext());
559}
560
561} // namespace mlir
return success()
lhs
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
#define add(a, b)
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:122
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
DynamicAPInt floor(const Fraction &f)
Definition Fraction.h:77
Fraction abs(const Fraction &f)
Definition Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:23