11 #include "../SPIRVCommon/Pattern.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31 #include "mlir/Conversion/Passes.h.inc"
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
115 assert(type &&
"Not a valid type");
119 if (
auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
128 if (
auto vectorType = dyn_cast<VectorType>(type)) {
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
134 if (
auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
144 auto getNumBitwidth = [](
Type type) {
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (
auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
163 llvm::formatv(
"failed to convert source type '{0}'", srcType));
175 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
182 template <
typename Op,
typename SPIRVOp>
187 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
189 assert(adaptor.getOperands().size() <= 3);
190 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
191 Type dstType = converter->convertType(op.getType());
195 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
198 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
200 dstType != op.getType()) {
201 return op.
emitError(
"bitwidth emulation is not implemented yet on "
202 "unsigned op pattern version");
205 auto overflowFlags = arith::IntegerOverflowFlags::none;
206 if (
auto overflowIface =
207 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208 if (converter->getTargetEnv().allows(
209 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210 overflowFlags = overflowIface.getOverflowAttr().getValue();
213 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214 op, dstType, adaptor.getOperands());
216 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
220 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
233 struct ConstantCompositeOpPattern final
238 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
240 auto srcType = dyn_cast<ShapedType>(constOp.getType());
241 if (!srcType || srcType.getNumElements() == 1)
246 if (!isa<VectorType, RankedTensorType>(srcType))
249 Type dstType = getTypeConverter()->convertType(srcType);
256 if (
auto denseElementsAttr =
257 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258 dstElementsAttr = denseElementsAttr;
259 }
else if (
auto resourceAttr =
260 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
264 return constOp->emitError(
"could not find resource blob");
270 bool detectedSplat =
false;
272 return constOp->emitError(
"resource is not a valid buffer");
277 return constOp->emitError(
"unsupported elements attribute");
280 ShapedType dstAttrType = dstElementsAttr.
getType();
284 if (srcType.getRank() > 1) {
285 if (isa<RankedTensorType>(srcType)) {
287 srcType.getElementType());
288 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
295 Type srcElemType = srcType.getElementType();
299 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
300 dstElemType = arrayType.getElementType();
302 dstElemType = cast<VectorType>(dstType).getElementType();
306 if (srcElemType != dstElemType) {
308 if (isa<FloatType>(srcElemType)) {
309 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
312 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
313 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
315 isa<IntegerType>(dstElemType)) {
324 elements.push_back(dstAttr);
329 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
331 srcAttr, cast<IntegerType>(dstElemType), rewriter);
334 elements.push_back(dstAttr);
342 if (isa<RankedTensorType>(dstAttrType))
358 struct ConstantScalarOpPattern final
363 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
365 Type srcType = constOp.getType();
366 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
367 if (shapedType.getNumElements() != 1)
369 srcType = shapedType.getElementType();
375 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
376 cstAttr = elementsAttr.getSplatValue<
Attribute>();
378 Type dstType = getTypeConverter()->convertType(srcType);
383 if (isa<FloatType>(srcType)) {
384 auto srcAttr = cast<FloatAttr>(cstAttr);
389 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
390 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
392 dstType.getIntOrFloatBitWidth() == 8) {
397 }
else if (srcType != dstType) {
420 auto srcAttr = cast<IntegerAttr>(cstAttr);
421 IntegerAttr dstAttr =
441 template <
typename SignedAbsOp>
445 assert(lhs == signOperand || rhs == signOperand);
450 Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
451 Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
452 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
456 if (lhs == signOperand)
457 isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
459 isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
460 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
461 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
473 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
475 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
476 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
477 adaptor.getOperands()[0], rewriter);
489 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
491 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
492 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
493 adaptor.getOperands()[0], rewriter);
508 template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
513 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
515 assert(adaptor.getOperands().size() == 2);
516 Type dstType = this->getTypeConverter()->convertType(op.getType());
521 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
522 op, dstType, adaptor.getOperands());
524 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
525 op, dstType, adaptor.getOperands());
540 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
542 assert(adaptor.getOperands().size() == 2);
547 Type dstType = getTypeConverter()->convertType(op.getType());
552 adaptor.getOperands());
564 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
566 assert(adaptor.getOperands().size() == 2);
571 Type dstType = getTypeConverter()->convertType(op.getType());
576 op, dstType, adaptor.getOperands());
591 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
593 Type srcType = adaptor.getOperands().front().getType();
597 Type dstType = getTypeConverter()->convertType(op.getType());
603 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
605 op, dstType, adaptor.getOperands().front(), one, zero);
615 struct IndexCastIndexI1Pattern final
620 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
625 Type dstType = getTypeConverter()->convertType(op.getType());
639 struct IndexCastI1IndexPattern final
644 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
649 Type dstType = getTypeConverter()->convertType(op.getType());
655 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
672 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
674 Value operand = adaptor.getIn();
679 Type dstType = getTypeConverter()->convertType(op.getType());
684 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
685 unsigned componentBitwidth = intTy.getWidth();
686 allOnes = spirv::ConstantOp::create(
687 rewriter, loc, intTy,
688 rewriter.
getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
689 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
690 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
691 allOnes = spirv::ConstantOp::create(
692 rewriter, loc, vectorTy,
694 APInt::getAllOnes(componentBitwidth)));
697 loc, llvm::formatv(
"unhandled type: {0}", dstType));
713 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
715 Type srcType = adaptor.getIn().getType();
719 Type dstType = getTypeConverter()->convertType(op.getType());
723 if (dstType == srcType) {
731 assert(srcBW < dstBW);
733 rewriter, op.getLoc());
738 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
739 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
744 op, dstType, shiftLOp, shiftSize);
747 adaptor.getOperands());
764 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
766 Type srcType = adaptor.getOperands().front().getType();
770 Type dstType = getTypeConverter()->convertType(op.getType());
776 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
778 op, dstType, adaptor.getOperands().front(), one, zero);
789 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
791 Type srcType = adaptor.getIn().getType();
795 Type dstType = getTypeConverter()->convertType(op.getType());
799 if (dstType == srcType) {
807 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
810 adaptor.getIn(), mask);
813 adaptor.getOperands());
829 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
831 Type dstType = getTypeConverter()->convertType(op.getType());
839 auto srcType = adaptor.getOperands().front().getType();
841 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
842 Value maskedSrc = spirv::BitwiseAndOp::create(
843 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
844 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
847 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
859 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
861 Type srcType = adaptor.getIn().getType();
862 Type dstType = getTypeConverter()->convertType(op.getType());
869 if (dstType == srcType) {
876 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
878 adaptor.getIn(), mask);
882 adaptor.getOperands());
892 static std::optional<spirv::FPRoundingMode>
893 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
894 switch (roundingMode) {
895 case arith::RoundingMode::downward:
896 return spirv::FPRoundingMode::RTN;
897 case arith::RoundingMode::to_nearest_even:
898 return spirv::FPRoundingMode::RTE;
899 case arith::RoundingMode::toward_zero:
900 return spirv::FPRoundingMode::RTZ;
901 case arith::RoundingMode::upward:
902 return spirv::FPRoundingMode::RTP;
903 case arith::RoundingMode::to_nearest_away:
908 llvm_unreachable(
"Unhandled rounding mode");
912 template <
typename Op,
typename SPIRVOp>
917 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
919 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
920 Type dstType = this->getTypeConverter()->convertType(op.getType());
927 if (dstType == srcType) {
930 rewriter.
replaceOp(op, adaptor.getOperands().front());
933 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
934 if (
auto roundingModeOp =
935 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
936 if (arith::RoundingModeAttr roundingMode =
937 roundingModeOp.getRoundingModeAttr()) {
939 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
942 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
947 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
948 op, dstType, adaptor.getOperands());
969 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
971 Type srcType = op.getLhs().getType();
974 Type dstType = getTypeConverter()->convertType(srcType);
978 switch (op.getPredicate()) {
979 case arith::CmpIPredicate::eq: {
984 case arith::CmpIPredicate::ne: {
986 op, adaptor.getLhs(), adaptor.getRhs());
989 case arith::CmpIPredicate::uge:
990 case arith::CmpIPredicate::ugt:
991 case arith::CmpIPredicate::ule:
992 case arith::CmpIPredicate::ult: {
996 if (
auto vectorType = dyn_cast<VectorType>(dstType))
999 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1001 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1020 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1022 Type srcType = op.getLhs().getType();
1025 Type dstType = getTypeConverter()->convertType(srcType);
1029 switch (op.getPredicate()) {
1030 #define DISPATCH(cmpPredicate, spirvOp) \
1031 case cmpPredicate: \
1032 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1033 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1034 !hasSameBitwidth(srcType, dstType)) { \
1035 return op.emitError( \
1036 "bitwidth emulation is not implemented yet on unsigned op"); \
1038 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1039 adaptor.getRhs()); \
1042 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1043 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1044 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1045 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1046 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1047 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1048 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1049 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1050 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1051 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1069 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1071 switch (op.getPredicate()) {
1072 #define DISPATCH(cmpPredicate, spirvOp) \
1073 case cmpPredicate: \
1074 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1075 adaptor.getRhs()); \
1079 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1080 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1081 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1082 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1083 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1084 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1086 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1087 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1088 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1089 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1090 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1091 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1109 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1111 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1117 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1134 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1136 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1137 op.getPredicate() != arith::CmpFPredicate::UNO)
1143 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1144 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1146 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1152 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1153 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1155 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1156 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1157 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1170 class AddUIExtendedOpPattern final
1175 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1177 Type dstElemTy = adaptor.getLhs().getType();
1179 Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1182 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1184 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1188 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1189 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1191 rewriter.
replaceOp(op, {sumResult, carryResult});
1201 template <
typename ArithMulOp,
typename SPIRVMulOp>
1206 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1210 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1212 Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1214 Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1231 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1234 adaptor.getTrueValue(),
1235 adaptor.getFalseValue());
1246 template <
typename Op,
typename SPIRVOp>
1251 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1253 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1254 Type dstType = converter->convertType(op.getType());
1268 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1270 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1275 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1276 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1278 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1279 adaptor.getLhs(), spirvOp);
1280 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1281 adaptor.getRhs(), select1);
1294 template <
typename Op,
typename SPIRVOp>
1296 template <
typename TargetOp>
1297 constexpr
bool shouldInsertNanGuards()
const {
1298 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1304 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1306 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1307 Type dstType = converter->convertType(op.getType());
1322 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1324 if (!shouldInsertNanGuards<SPIRVOp>() ||
1325 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1330 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1331 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1333 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1334 adaptor.getRhs(), spirvOp);
1335 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1336 adaptor.getLhs(), select1);
1353 ConstantCompositeOpPattern,
1354 ConstantScalarOpPattern,
1355 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1356 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1357 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1361 RemSIOpGLPattern, RemSIOpCLPattern,
1362 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1363 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1364 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1365 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1374 ExtUIPattern, ExtUII1Pattern,
1375 ExtSIPattern, ExtSII1Pattern,
1376 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1377 TruncIPattern, TruncII1Pattern,
1378 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1379 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1380 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1381 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1382 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1383 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1384 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1385 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1386 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1387 CmpIOpBooleanPattern, CmpIOpPattern,
1388 CmpFOpNanNonePattern, CmpFOpPattern,
1389 AddUIExtendedOpPattern,
1390 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1391 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1394 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1395 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1396 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1397 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1403 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1404 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1405 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1406 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1411 >(typeConverter,
patterns.getContext());
1425 struct ConvertArithToSPIRVPass
1426 :
public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1429 void runOnOperation()
override {
1432 std::unique_ptr<SPIRVConversionTarget> target =
1436 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1437 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1442 target->addLegalOp<UnrealizedConversionCastOp>();
1445 target->addIllegalDialect<arith::ArithDialect>();
1451 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Fraction abs(const Fraction &f)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.