11 #include "../SPIRVCommon/Pattern.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
31 #include "mlir/Conversion/Passes.h.inc"
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
104 assert(type &&
"Not a valid type");
108 if (
auto vecType = dyn_cast<VectorType>(type))
109 return vecType.getElementType().isInteger(1);
117 if (
auto vectorType = dyn_cast<VectorType>(type)) {
120 return builder.
create<spirv::ConstantOp>(loc, vectorType, attr);
123 if (
auto intType = dyn_cast<IntegerType>(type))
124 return builder.
create<spirv::ConstantOp>(
133 auto getNumBitwidth = [](
Type type) {
135 if (type.isIntOrFloat())
136 bw = type.getIntOrFloatBitWidth();
137 else if (
auto vecType = dyn_cast<VectorType>(type))
138 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
141 unsigned aBW = getNumBitwidth(a);
142 unsigned bBW = getNumBitwidth(b);
143 return aBW != 0 && bBW != 0 && aBW == bBW;
152 llvm::formatv(
"failed to convert source type '{0}'", srcType));
164 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
171 template <
typename Op,
typename SPIRVOp>
176 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
178 assert(adaptor.getOperands().size() <= 3);
179 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
180 Type dstType = converter->convertType(op.getType());
184 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
187 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
189 dstType != op.getType()) {
190 return op.
emitError(
"bitwidth emulation is not implemented yet on "
191 "unsigned op pattern version");
194 auto overflowFlags = arith::IntegerOverflowFlags::none;
195 if (
auto overflowIface =
196 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197 if (converter->getTargetEnv().allows(
198 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199 overflowFlags = overflowIface.getOverflowAttr().getValue();
202 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203 op, dstType, adaptor.getOperands());
205 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
209 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
222 struct ConstantCompositeOpPattern final
227 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
229 auto srcType = dyn_cast<ShapedType>(constOp.getType());
230 if (!srcType || srcType.getNumElements() == 1)
235 if (!isa<VectorType, RankedTensorType>(srcType))
238 Type dstType = getTypeConverter()->convertType(srcType);
245 if (
auto denseElementsAttr =
246 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247 dstElementsAttr = denseElementsAttr;
248 }
else if (
auto resourceAttr =
249 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
253 return constOp->emitError(
"could not find resource blob");
259 bool detectedSplat =
false;
261 return constOp->emitError(
"resource is not a valid buffer");
266 return constOp->emitError(
"unsupported elements attribute");
269 ShapedType dstAttrType = dstElementsAttr.
getType();
273 if (srcType.getRank() > 1) {
274 if (isa<RankedTensorType>(srcType)) {
276 srcType.getElementType());
277 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
284 Type srcElemType = srcType.getElementType();
288 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289 dstElemType = arrayType.getElementType();
291 dstElemType = cast<VectorType>(dstType).getElementType();
295 if (srcElemType != dstElemType) {
297 if (isa<FloatType>(srcElemType)) {
298 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
303 elements.push_back(dstAttr);
308 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
310 srcAttr, cast<IntegerType>(dstElemType), rewriter);
313 elements.push_back(dstAttr);
321 if (isa<RankedTensorType>(dstAttrType))
337 struct ConstantScalarOpPattern final
342 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
344 Type srcType = constOp.getType();
345 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
346 if (shapedType.getNumElements() != 1)
348 srcType = shapedType.getElementType();
354 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355 cstAttr = elementsAttr.getSplatValue<
Attribute>();
357 Type dstType = getTypeConverter()->convertType(srcType);
362 if (isa<FloatType>(srcType)) {
363 auto srcAttr = cast<FloatAttr>(cstAttr);
364 auto dstAttr = srcAttr;
368 if (srcType != dstType) {
391 auto srcAttr = cast<IntegerAttr>(cstAttr);
392 IntegerAttr dstAttr =
412 template <
typename SignedAbsOp>
416 assert(lhs == signOperand || rhs == signOperand);
421 Value lhsAbs = builder.
create<SignedAbsOp>(loc, type, lhs);
422 Value rhsAbs = builder.
create<SignedAbsOp>(loc, type, rhs);
427 if (lhs == signOperand)
428 isPositive = builder.
create<spirv::IEqualOp>(loc, lhs, lhsAbs);
430 isPositive = builder.
create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431 Value absNegate = builder.
create<spirv::SNegateOp>(loc, type,
abs);
432 return builder.
create<spirv::SelectOp>(loc, type, isPositive,
abs, absNegate);
443 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
445 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447 adaptor.getOperands()[0], rewriter);
459 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
461 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463 adaptor.getOperands()[0], rewriter);
478 template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
483 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
485 assert(adaptor.getOperands().size() == 2);
486 Type dstType = this->getTypeConverter()->convertType(op.getType());
491 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492 op, dstType, adaptor.getOperands());
494 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495 op, dstType, adaptor.getOperands());
510 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
512 assert(adaptor.getOperands().size() == 2);
517 Type dstType = getTypeConverter()->convertType(op.getType());
522 adaptor.getOperands());
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
536 assert(adaptor.getOperands().size() == 2);
541 Type dstType = getTypeConverter()->convertType(op.getType());
546 op, dstType, adaptor.getOperands());
561 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
563 Type srcType = adaptor.getOperands().front().getType();
567 Type dstType = getTypeConverter()->convertType(op.getType());
573 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
575 op, dstType, adaptor.getOperands().front(), one, zero);
590 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
592 Value operand = adaptor.getIn();
597 Type dstType = getTypeConverter()->convertType(op.getType());
602 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
603 unsigned componentBitwidth = intTy.getWidth();
604 allOnes = rewriter.
create<spirv::ConstantOp>(
606 rewriter.
getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
608 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609 allOnes = rewriter.
create<spirv::ConstantOp>(
612 APInt::getAllOnes(componentBitwidth)));
615 loc, llvm::formatv(
"unhandled type: {0}", dstType));
631 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
633 Type srcType = adaptor.getIn().getType();
637 Type dstType = getTypeConverter()->convertType(op.getType());
641 if (dstType == srcType) {
649 assert(srcBW < dstBW);
651 rewriter, op.getLoc());
656 auto shiftLOp = rewriter.
create<spirv::ShiftLeftLogicalOp>(
657 op.getLoc(), dstType, adaptor.getIn(), shiftSize);
662 op, dstType, shiftLOp, shiftSize);
665 adaptor.getOperands());
682 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
684 Type srcType = adaptor.getOperands().front().getType();
688 Type dstType = getTypeConverter()->convertType(op.getType());
694 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
696 op, dstType, adaptor.getOperands().front(), one, zero);
707 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
709 Type srcType = adaptor.getIn().getType();
713 Type dstType = getTypeConverter()->convertType(op.getType());
717 if (dstType == srcType) {
725 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
728 adaptor.getIn(), mask);
731 adaptor.getOperands());
747 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
749 Type dstType = getTypeConverter()->convertType(op.getType());
757 auto srcType = adaptor.getOperands().front().getType();
759 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760 Value maskedSrc = rewriter.
create<spirv::BitwiseAndOp>(
761 loc, srcType, adaptor.getOperands()[0], mask);
762 Value isOne = rewriter.
create<spirv::IEqualOp>(loc, maskedSrc, mask);
765 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
777 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
779 Type srcType = adaptor.getIn().getType();
780 Type dstType = getTypeConverter()->convertType(op.getType());
787 if (dstType == srcType) {
794 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
796 adaptor.getIn(), mask);
800 adaptor.getOperands());
810 static std::optional<spirv::FPRoundingMode>
811 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812 switch (roundingMode) {
813 case arith::RoundingMode::downward:
814 return spirv::FPRoundingMode::RTN;
815 case arith::RoundingMode::to_nearest_even:
816 return spirv::FPRoundingMode::RTE;
817 case arith::RoundingMode::toward_zero:
818 return spirv::FPRoundingMode::RTZ;
819 case arith::RoundingMode::upward:
820 return spirv::FPRoundingMode::RTP;
821 case arith::RoundingMode::to_nearest_away:
826 llvm_unreachable(
"Unhandled rounding mode");
830 template <
typename Op,
typename SPIRVOp>
835 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
837 assert(adaptor.getOperands().size() == 1);
838 Type srcType = adaptor.getOperands().front().getType();
839 Type dstType = this->getTypeConverter()->convertType(op.getType());
846 if (dstType == srcType) {
849 rewriter.
replaceOp(op, adaptor.getOperands().front());
851 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
852 op, dstType, adaptor.getOperands());
853 if (
auto roundingModeOp =
854 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
855 if (arith::RoundingModeAttr roundingMode =
856 roundingModeOp.getRoundingModeAttr()) {
858 convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
865 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
884 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
886 Type srcType = op.getLhs().getType();
889 Type dstType = getTypeConverter()->convertType(srcType);
893 switch (op.getPredicate()) {
894 case arith::CmpIPredicate::eq: {
899 case arith::CmpIPredicate::ne: {
901 op, adaptor.getLhs(), adaptor.getRhs());
904 case arith::CmpIPredicate::uge:
905 case arith::CmpIPredicate::ugt:
906 case arith::CmpIPredicate::ule:
907 case arith::CmpIPredicate::ult: {
911 if (
auto vectorType = dyn_cast<VectorType>(dstType))
914 rewriter.
create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
916 rewriter.
create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
935 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
937 Type srcType = op.getLhs().getType();
940 Type dstType = getTypeConverter()->convertType(srcType);
944 switch (op.getPredicate()) {
945 #define DISPATCH(cmpPredicate, spirvOp) \
947 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
948 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
949 !hasSameBitwidth(srcType, dstType)) { \
950 return op.emitError( \
951 "bitwidth emulation is not implemented yet on unsigned op"); \
953 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
957 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
958 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
959 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
960 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
961 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
962 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
963 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
964 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
965 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
966 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
984 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
986 switch (op.getPredicate()) {
987 #define DISPATCH(cmpPredicate, spirvOp) \
989 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
994 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
995 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
996 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
997 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
998 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
999 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1001 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1002 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1003 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1004 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1005 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1006 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1024 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1026 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1032 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1049 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1051 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1052 op.getPredicate() != arith::CmpFPredicate::UNO)
1058 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1059 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1061 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1067 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1068 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1070 replace = rewriter.
create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1071 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1072 replace = rewriter.
create<spirv::LogicalNotOp>(loc, replace);
1085 class AddUIExtendedOpPattern final
1090 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1092 Type dstElemTy = adaptor.getLhs().getType();
1094 Value result = rewriter.
create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1097 Value sumResult = rewriter.
create<spirv::CompositeExtractOp>(
1099 Value carryValue = rewriter.
create<spirv::CompositeExtractOp>(
1103 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1104 Value carryResult = rewriter.
create<spirv::IEqualOp>(loc, carryValue, one);
1106 rewriter.
replaceOp(op, {sumResult, carryResult});
1116 template <
typename ArithMulOp,
typename SPIRVMulOp>
1121 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1125 rewriter.
create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1127 Value low = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1129 Value high = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1146 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1149 adaptor.getTrueValue(),
1150 adaptor.getFalseValue());
1161 template <
typename Op,
typename SPIRVOp>
1166 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1168 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1169 Type dstType = converter->convertType(op.getType());
1183 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1185 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1190 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1191 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1193 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1194 adaptor.getLhs(), spirvOp);
1195 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1196 adaptor.getRhs(), select1);
1209 template <
typename Op,
typename SPIRVOp>
1211 template <
typename TargetOp>
1212 constexpr
bool shouldInsertNanGuards()
const {
1213 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1219 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1221 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1222 Type dstType = converter->convertType(op.getType());
1237 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1239 if (!shouldInsertNanGuards<SPIRVOp>() ||
1240 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1245 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1246 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1248 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1249 adaptor.getRhs(), spirvOp);
1250 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1251 adaptor.getLhs(), select1);
1268 ConstantCompositeOpPattern,
1269 ConstantScalarOpPattern,
1270 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1271 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1272 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1276 RemSIOpGLPattern, RemSIOpCLPattern,
1277 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1278 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1279 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1280 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1289 ExtUIPattern, ExtUII1Pattern,
1290 ExtSIPattern, ExtSII1Pattern,
1291 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1292 TruncIPattern, TruncII1Pattern,
1293 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1294 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1295 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1296 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1297 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1298 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1299 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1300 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1301 CmpIOpBooleanPattern, CmpIOpPattern,
1302 CmpFOpNanNonePattern, CmpFOpPattern,
1303 AddUIExtendedOpPattern,
1304 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1305 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1308 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1309 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1310 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1311 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1317 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1318 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1319 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1320 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1330 patterns.
add<CmpFOpNanKernelPattern>(typeConverter, patterns.
getContext(),
1339 struct ConvertArithToSPIRVPass
1340 :
public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
1341 void runOnOperation()
override {
1344 std::unique_ptr<SPIRVConversionTarget> target =
1348 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1353 target->addLegalOp<UnrealizedConversionCastOp>();
1356 target->addIllegalDialect<arith::ArithDialect>();
1362 signalPassFailure();
1368 return std::make_unique<ConvertArithToSPIRVPass>();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass<> > createConvertArithToSPIRVPass()
Fraction abs(const Fraction &f)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.