MLIR 23.0.0git
MathToSPIRV.cpp
Go to the documentation of this file.
1//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Math dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Matchers.h"
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/FormatVariadic.h"
24
25#define DEBUG_TYPE "math-to-spirv-pattern"
26
27using namespace mlir;
28
29//===----------------------------------------------------------------------===//
30// Utility functions
31//===----------------------------------------------------------------------===//
32
33/// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
34/// given type is not a 32-bit scalar/vector type.
36 OpBuilder &builder, Location loc) {
37 if (auto vectorType = dyn_cast<VectorType>(type)) {
38 if (!vectorType.getElementType().isInteger(32))
39 return nullptr;
40 SmallVector<int> values(vectorType.getNumElements(), value);
41 return spirv::ConstantOp::create(builder, loc, type,
42 builder.getI32VectorAttr(values));
43 }
44 if (type.isInteger(32))
45 return spirv::ConstantOp::create(builder, loc, type,
46 builder.getI32IntegerAttr(value));
47
48 return nullptr;
49}
50
51/// Check if the type is supported by math-to-spirv conversion. We expect to
52/// only see scalars and vectors at this point, with higher-level types already
53/// lowered.
54static bool isSupportedSourceType(Type originalType) {
55 if (originalType.isIntOrIndexOrFloat())
56 return true;
57
58 if (auto vecTy = dyn_cast<VectorType>(originalType)) {
59 if (!vecTy.getElementType().isIntOrIndexOrFloat())
60 return false;
61 if (vecTy.isScalable())
62 return false;
63 if (vecTy.getRank() > 1)
64 return false;
65
66 return true;
67 }
68
69 return false;
70}
71
72/// Check if all `sourceOp` types are supported by math-to-spirv conversion.
73/// Notify of a match failure othwerise and return a `failure` result.
74/// This is intended to simplify type checks in `OpConversionPattern`s.
75static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
76 Operation *sourceOp) {
77 auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
78 llvm::append_range(allTypes, sourceOp->getResultTypes());
79
80 for (Type ty : allTypes) {
81 if (!isSupportedSourceType(ty)) {
82 return rewriter.notifyMatchFailure(
83 sourceOp,
84 llvm::formatv(
85 "unsupported source type for Math to SPIR-V conversion: {0}",
86 ty));
87 }
88 }
89
90 return success();
91}
92
93//===----------------------------------------------------------------------===//
94// Operation conversion
95//===----------------------------------------------------------------------===//
96
97// Note that DRR cannot be used for the patterns in this file: we may need to
98// convert type along the way, which requires ConversionPattern. DRR generates
99// normal RewritePattern.
100
101namespace {
102/// Converts elementwise unary, binary, and ternary standard operations to
103/// SPIR-V operations. Checks that source `Op` types are supported.
104template <typename Op, typename SPIRVOp>
105struct CheckedElementwiseOpPattern final
106 : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
107 using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
108 using BasePattern::BasePattern;
109
110 LogicalResult
111 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
112 ConversionPatternRewriter &rewriter) const override {
113 if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
114 return res;
115
116 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
117 }
118};
119
120/// Converts math.copysign to SPIR-V ops.
121struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
122 using Base::Base;
123
124 LogicalResult
125 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
126 ConversionPatternRewriter &rewriter) const override {
127 if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
128 failed(res))
129 return res;
130
131 Type type = getTypeConverter()->convertType(copySignOp.getType());
132 if (!type)
133 return failure();
134
135 FloatType floatType;
136 if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137 floatType = scalarType;
138 } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139 floatType = cast<FloatType>(vectorType.getElementType());
140 } else {
141 return failure();
142 }
143
144 Location loc = copySignOp.getLoc();
145 int bitwidth = floatType.getWidth();
146 Type intType = rewriter.getIntegerType(bitwidth);
147 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
148
149 Value signMask = spirv::ConstantOp::create(
150 rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
151 Value valueMask = spirv::ConstantOp::create(
152 rewriter, loc, intType,
153 rewriter.getIntegerAttr(intType, intValue - 1u));
154
155 if (auto vectorType = dyn_cast<VectorType>(type)) {
156 assert(vectorType.getRank() == 1);
157 int count = vectorType.getNumElements();
158 intType = VectorType::get(count, intType);
159
160 Repeated<Value> signSplat(count, signMask);
161 signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
162 signSplat);
163
164 Repeated<Value> valueSplat(count, valueMask);
165 valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
166 valueSplat);
167 }
168
169 Value lhsCast =
170 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
171 Value rhsCast =
172 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
173
174 Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
175 ValueRange{lhsCast, valueMask});
176 Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
177 ValueRange{rhsCast, signMask});
178
179 Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
180 ValueRange{value, sign});
181 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
182 return success();
183 }
184};
185
186/// Converts math.ctlz to SPIR-V ops.
187///
188/// OpenCL targets lower math.ctlz directly to OpenCL.std clz via the generic
189/// elementwise pattern. This pattern handles the shader fallback.
190///
191/// SPIR-V does not have a direct operations for counting leading zeros for
192/// glsl. If Shader capability is supported, we can leverage GL FindUMsb to
193/// calculate it.
194struct CountLeadingZerosPattern final
195 : public OpConversionPattern<math::CountLeadingZerosOp> {
196 using Base::Base;
197
198 LogicalResult
199 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
200 ConversionPatternRewriter &rewriter) const override {
201 if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
202 return res;
203
204 Type type = getTypeConverter()->convertType(countOp.getType());
205 if (!type)
206 return failure();
207
208 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
209 if (!typeConverter.getTargetEnv().allows(spirv::Capability::Shader))
210 return rewriter.notifyMatchFailure(countOp, "requires Shader capability");
211
212 // The GL FindUMsb fallback only supports 32-bit integer types for now.
213 unsigned bitwidth = 0;
214 if (isa<IntegerType>(type))
215 bitwidth = type.getIntOrFloatBitWidth();
216 if (auto vectorType = dyn_cast<VectorType>(type))
217 bitwidth = vectorType.getElementTypeBitWidth();
218 if (bitwidth != 32)
219 return failure();
220
221 Location loc = countOp.getLoc();
222 Value input = adaptor.getOperand();
223 Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
224 Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
225 Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
226
227 Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
228 // We need to subtract from 31 given that the index returned by GLSL
229 // FindUMsb is counted from the least significant bit. Theoretically this
230 // also gives the correct result even if the integer has all zero bits, in
231 // which case GL FindUMsb would return -1.
232 Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
233 // However, certain Vulkan implementations have driver bugs for the corner
234 // case where the input is zero. And.. it can be smart to optimize a select
235 // only involving the corner case. So separately compute the result when the
236 // input is either zero or one.
237 Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
238 Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
239 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
240 subMsb);
241 return success();
242 }
243};
244
245/// Converts math.cttz to GL FindILsb. GL FindILsb returns -1 for a zero
246/// input while math.cttz must return the bitwidth, so the zero case is
247/// patched up with a select.
248struct CountTrailingZerosPattern final
249 : public OpConversionPattern<math::CountTrailingZerosOp> {
250 using Base::Base;
251
252 LogicalResult
253 matchAndRewrite(math::CountTrailingZerosOp countOp, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter) const override {
255 if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
256 return res;
257
258 Type type = getTypeConverter()->convertType(countOp.getType());
259 if (!type)
260 return failure();
261
262 unsigned bitwidth = 0;
263 if (isa<IntegerType>(type))
264 bitwidth = type.getIntOrFloatBitWidth();
265 else if (auto vectorType = dyn_cast<VectorType>(type))
266 bitwidth = vectorType.getElementTypeBitWidth();
267 if (bitwidth != 32)
268 return failure();
269
270 Location loc = countOp.getLoc();
271 Value input = adaptor.getOperand();
272 Value val0 = getScalarOrVectorI32Constant(type, 0, rewriter, loc);
273 Value valBitwidth =
274 getScalarOrVectorI32Constant(type, bitwidth, rewriter, loc);
275
276 Value lsb = spirv::GLFindILsbOp::create(rewriter, loc, input);
277 Value isZero = spirv::IEqualOp::create(rewriter, loc, input, val0);
278 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, isZero, valBitwidth,
279 lsb);
280 return success();
281 }
282};
283
284/// Converts math.expm1 to SPIR-V ops.
285///
286/// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
287/// these operations.
288template <typename ExpOp>
289struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
290 using Base::Base;
291
292 LogicalResult
293 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
294 ConversionPatternRewriter &rewriter) const override {
295 assert(adaptor.getOperands().size() == 1);
296 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
297 failed(res))
298 return res;
299
300 Location loc = operation.getLoc();
301 Type type = this->getTypeConverter()->convertType(operation.getType());
302 if (!type)
303 return failure();
304
305 Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
306 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
307 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
308 return success();
309 }
310};
311
312/// Converts math.log1p to SPIR-V ops.
313///
314/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
315/// these operations.
316template <typename LogOp>
317struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
318 using Base::Base;
319
320 LogicalResult
321 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter) const override {
323 assert(adaptor.getOperands().size() == 1);
324 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
325 failed(res))
326 return res;
327
328 Location loc = operation.getLoc();
329 Type type = this->getTypeConverter()->convertType(operation.getType());
330 if (!type)
331 return failure();
332
333 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
334 Value onePlus =
335 spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
336 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
337 return success();
338 }
339};
340
341/// Converts math.log10 to GLSL SPIR-V ops.
342///
343/// GLSL.std.450 has no Log10 instruction. Lower it as:
344/// log10(x) = log(x) * 1/log(10)
345struct Log10OpPattern final : public OpConversionPattern<math::Log10Op> {
346 using Base::Base;
347
348 static constexpr double log10Reciprocal =
349 0.4342944819032518276511289189166050822943970058036665661144537832;
350
351 LogicalResult
352 matchAndRewrite(math::Log10Op operation, OpAdaptor adaptor,
353 ConversionPatternRewriter &rewriter) const override {
354 assert(adaptor.getOperands().size() == 1);
355 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
356 failed(res))
357 return res;
358
359 Location loc = operation.getLoc();
360 Type type = this->getTypeConverter()->convertType(operation.getType());
361 if (!type)
362 return rewriter.notifyMatchFailure(operation, "type conversion failed");
363
364 auto getConstantValue = [&](double value) {
365 if (auto floatType = dyn_cast<FloatType>(type)) {
366 return spirv::ConstantOp::create(
367 rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
368 }
369 if (auto vectorType = dyn_cast<VectorType>(type)) {
370 Type elemType = vectorType.getElementType();
371
372 if (isa<FloatType>(elemType)) {
373 return spirv::ConstantOp::create(
374 rewriter, loc, type,
376 vectorType, FloatAttr::get(elemType, value).getValue()));
377 }
378 }
379 llvm_unreachable("unimplemented type for log10");
380 };
381
382 Value constantValue = getConstantValue(log10Reciprocal);
383 Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getOperand());
384 rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
385 constantValue);
386 return success();
387 }
388};
389
390/// Converts math.powf to SPIRV-Ops.
391struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
392 using Base::Base;
393
394 LogicalResult
395 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
396 ConversionPatternRewriter &rewriter) const override {
397 if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
398 return res;
399
400 Type dstType = getTypeConverter()->convertType(powfOp.getType());
401 if (!dstType)
402 return failure();
403
404 Location loc = powfOp.getLoc();
405 Type operandType = adaptor.getRhs().getType();
406
407 // Parity-based lowering requires an integer-valued constant exponent.
408 // Otherwise fall back to exp(y*log(x)), which yields NaN for x<0 (matches
409 // C).
410 auto isOdd = [](const APFloat &v) {
411 APSInt i(/*BitWidth=*/64, /*isUnsigned=*/false);
412 bool ignored;
413 v.convertToInteger(i, APFloat::rmTowardZero, &ignored);
414 return i[0];
415 };
416
417 SmallVector<bool> oddMask;
418 Attribute rhsAttr;
419 if (matchPattern(adaptor.getRhs(), m_Constant(&rhsAttr))) {
420 TypeSwitch<Attribute>(rhsAttr)
421 .Case([&](FloatAttr a) {
422 if (a.getValue().isInteger())
423 oddMask.push_back(isOdd(a.getValue()));
424 })
425 .Case([&](SplatElementsAttr a) {
426 APFloat splat = a.getSplatValue<APFloat>();
427 if (splat.isInteger())
428 oddMask.push_back(isOdd(splat));
429 })
430 .Case([&](DenseElementsAttr a) {
431 SmallVector<bool> mask;
432 for (const APFloat &elt : a.getValues<APFloat>()) {
433 if (!elt.isInteger())
434 return;
435 mask.push_back(isOdd(elt));
436 }
437 oddMask = std::move(mask);
438 });
439 }
440
441 if (oddMask.empty()) {
442 Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getLhs());
443 Value mul = spirv::FMulOp::create(rewriter, loc, adaptor.getRhs(), log);
444 rewriter.replaceOpWithNewOp<spirv::GLExpOp>(powfOp, mul);
445 return success();
446 }
447
448 // GL.Pow is undefined for x < 0; take abs and conditionally negate the
449 // result for lanes whose exponent is odd.
450 Value abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getLhs());
451 Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
452
453 // No odd-parity element: result has the same sign as |lhs|^rhs >= 0.
454 if (llvm::none_of(oddMask, [](bool b) { return b; })) {
455 rewriter.replaceOp(powfOp, pow);
456 return success();
457 }
458
459 Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
460 Value lessThan =
461 spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
462 Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
463
464 Value shouldNegate;
465 if (llvm::all_equal(oddMask)) {
466 // Every lane has odd exponent: negate iff lhs < 0.
467 shouldNegate = lessThan;
468 } else {
469 // Mixed parity (non-splat dense vector): AND lhs<0 with a per-element
470 // constant odd-mask.
471 auto vecType = cast<VectorType>(operandType);
472 auto maskType = VectorType::get(vecType.getShape(), rewriter.getI1Type());
473 Value oddConst = spirv::ConstantOp::create(
474 rewriter, loc, maskType, DenseElementsAttr::get(maskType, oddMask));
475 shouldNegate =
476 spirv::LogicalAndOp::create(rewriter, loc, lessThan, oddConst);
477 }
478
479 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
480 pow);
481 return success();
482 }
483};
484
485/// Converts math.fpowi to spirv.CL.pown.
486struct PowIOpPattern final : public OpConversionPattern<math::FPowIOp> {
487 using Base::Base;
488
489 LogicalResult
490 matchAndRewrite(math::FPowIOp op, OpAdaptor adaptor,
491 ConversionPatternRewriter &rewriter) const override {
492 if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
493 return res;
494
495 Type dstType = getTypeConverter()->convertType(op.getType());
496 if (!dstType)
497 return failure();
498
499 rewriter.replaceOpWithNewOp<spirv::CLPownOp>(op, dstType, adaptor.getLhs(),
500 adaptor.getRhs());
501 return success();
502 }
503};
504
505/// Converts math.fpowi to GLSL SPIR-V ops. GL has no integer-power op, so the
506/// exponent is converted to float and lowered through spirv.GL.Pow. As GL.Pow
507/// is undefined for a negative base, the base is made positive and the result
508/// is negated when the base is negative and the exponent is odd.
509struct PowIOpGLPattern final : public OpConversionPattern<math::FPowIOp> {
510 using Base::Base;
511
512 LogicalResult
513 matchAndRewrite(math::FPowIOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter) const override {
515 if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
516 return res;
517
518 Type dstType = getTypeConverter()->convertType(op.getType());
519 if (!dstType)
520 return failure();
521
522 Location loc = op.getLoc();
523 Value base = adaptor.getLhs();
524 Value power = adaptor.getRhs();
525
526 Value expFloat =
527 spirv::ConvertSToFOp::create(rewriter, loc, dstType, power);
528 Value abs = spirv::GLFAbsOp::create(rewriter, loc, base);
529 Value pow = spirv::GLPowOp::create(rewriter, loc, abs, expFloat);
530
531 Value zeroF = spirv::ConstantOp::getZero(dstType, loc, rewriter);
532 Value lessThan = spirv::FOrdLessThanOp::create(rewriter, loc, base, zeroF);
533
534 Type powerType = power.getType();
535 Value oneI = spirv::ConstantOp::getOne(powerType, loc, rewriter);
536 Value lowBit = spirv::BitwiseAndOp::create(rewriter, loc, power, oneI);
537 Value isOdd = spirv::IEqualOp::create(rewriter, loc, lowBit, oneI);
538
539 Value shouldNegate =
540 spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
541 Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
542 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, shouldNegate, negate, pow);
543 return success();
544 }
545};
546
547/// Converts math.round to GLSL SPIRV extended ops.
548struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
549 using Base::Base;
550
551 LogicalResult
552 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
553 ConversionPatternRewriter &rewriter) const override {
554 if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
555 return res;
556
557 Location loc = roundOp.getLoc();
558 auto ty = getTypeConverter()->convertType(adaptor.getOperand().getType());
559 if (!ty) {
560 return rewriter.notifyMatchFailure(
561 roundOp->getLoc(),
562 llvm::formatv("failed to convert type {0} for SPIR-V",
563 roundOp.getType()));
564 }
565
566 Type ety = getElementTypeOrSelf(ty);
567
568 auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
569 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
570 Value half;
571 if (VectorType vty = dyn_cast<VectorType>(ty)) {
572 half = spirv::ConstantOp::create(
573 rewriter, loc, vty,
575 rewriter.getFloatAttr(ety, 0.5).getValue()));
576 } else {
577 half = spirv::ConstantOp::create(rewriter, loc, ty,
578 rewriter.getFloatAttr(ety, 0.5));
579 }
580
581 auto abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getOperand());
582 auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
583 auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
584 auto greater =
585 spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
586 auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
587 auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
588 rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add,
589 adaptor.getOperand());
590 return success();
591 }
592};
593
594} // namespace
595
596//===----------------------------------------------------------------------===//
597// Pattern population
598//===----------------------------------------------------------------------===//
599
600namespace mlir {
602 RewritePatternSet &patterns) {
603 // Core patterns
604 patterns
605 .add<CopySignPattern,
606 CheckedElementwiseOpPattern<math::CtPopOp, spirv::BitCountOp>,
607 CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
608 CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
609 CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>,
610 CheckedElementwiseOpPattern<math::IsNormalOp, spirv::IsNormalOp>>(
611 typeConverter, patterns.getContext());
612
613 // GLSL patterns
614 patterns
615 .add<CountLeadingZerosPattern, CountTrailingZerosPattern,
616 Log1pOpPattern<spirv::GLLogOp>, Log10OpPattern,
617 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, PowIOpGLPattern,
618 RoundOpPattern,
619 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
620 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
621 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
622 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
623 CheckedElementwiseOpPattern<math::ClampFOp, spirv::GLFClampOp>,
624 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
625 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
626 CheckedElementwiseOpPattern<math::Exp2Op, spirv::GLExp2Op>,
627 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
628 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
629 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
630 CheckedElementwiseOpPattern<math::Log2Op, spirv::GLLog2Op>,
631 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
632 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
633 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
634 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
635 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
636 CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
637 CheckedElementwiseOpPattern<math::TruncOp, spirv::GLTruncOp>,
638 CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
639 CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
640 CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
641 CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
642 CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
643 CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
644 CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
645 typeConverter, patterns.getContext());
646
647 // OpenCL patterns
648 patterns.add<
649 Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
650 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
651 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
652 CheckedElementwiseOpPattern<math::CountLeadingZerosOp, spirv::CLClzOp>,
653 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
654 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
655 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
656 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
657 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
658 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
659 CheckedElementwiseOpPattern<math::Exp2Op, spirv::CLExp2Op>,
660 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
661 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
662 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
663 CheckedElementwiseOpPattern<math::Log2Op, spirv::CLLog2Op>,
664 CheckedElementwiseOpPattern<math::Log10Op, spirv::CLLog10Op>,
665 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>, PowIOpPattern,
666 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
667 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
668 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
669 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
670 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
671 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
672 CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
673 CheckedElementwiseOpPattern<math::TruncOp, spirv::CLTruncOp>,
674 CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
675 CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
676 CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
677 CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
678 CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
679 CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
680 CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
681 typeConverter, patterns.getContext());
682}
683
684} // namespace mlir
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
#define mul(a, b)
#define add(a, b)
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:126
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
DynamicAPInt floor(const Fraction &f)
Definition Fraction.h:77
Fraction abs(const Fraction &f)
Definition Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:24