11 #include "../SPIRVCommon/Pattern.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31 #include "mlir/Conversion/Passes.h.inc"
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
104 assert(type &&
"Not a valid type");
108 if (
auto vecType = dyn_cast<VectorType>(type))
109 return vecType.getElementType().isInteger(1);
117 if (
auto vectorType = dyn_cast<VectorType>(type)) {
120 return builder.
create<spirv::ConstantOp>(loc, vectorType, attr);
123 if (
auto intType = dyn_cast<IntegerType>(type))
124 return builder.
create<spirv::ConstantOp>(
133 auto getNumBitwidth = [](
Type type) {
135 if (type.isIntOrFloat())
136 bw = type.getIntOrFloatBitWidth();
137 else if (
auto vecType = dyn_cast<VectorType>(type))
138 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
141 unsigned aBW = getNumBitwidth(a);
142 unsigned bBW = getNumBitwidth(b);
143 return aBW != 0 && bBW != 0 && aBW == bBW;
152 llvm::formatv(
"failed to convert source type '{0}'", srcType));
164 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
171 template <
typename Op,
typename SPIRVOp>
176 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
178 assert(adaptor.getOperands().size() <= 3);
179 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
180 Type dstType = converter->convertType(op.getType());
184 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
187 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
189 dstType != op.getType()) {
190 return op.
emitError(
"bitwidth emulation is not implemented yet on "
191 "unsigned op pattern version");
194 auto overflowFlags = arith::IntegerOverflowFlags::none;
195 if (
auto overflowIface =
196 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197 if (converter->getTargetEnv().allows(
198 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199 overflowFlags = overflowIface.getOverflowAttr().getValue();
202 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203 op, dstType, adaptor.getOperands());
205 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
209 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
222 struct ConstantCompositeOpPattern final
227 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
229 auto srcType = dyn_cast<ShapedType>(constOp.getType());
230 if (!srcType || srcType.getNumElements() == 1)
235 if (!isa<VectorType, RankedTensorType>(srcType))
238 Type dstType = getTypeConverter()->convertType(srcType);
245 if (
auto denseElementsAttr =
246 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247 dstElementsAttr = denseElementsAttr;
248 }
else if (
auto resourceAttr =
249 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
253 return constOp->emitError(
"could not find resource blob");
259 bool detectedSplat =
false;
261 return constOp->emitError(
"resource is not a valid buffer");
266 return constOp->emitError(
"unsupported elements attribute");
269 ShapedType dstAttrType = dstElementsAttr.
getType();
273 if (srcType.getRank() > 1) {
274 if (isa<RankedTensorType>(srcType)) {
276 srcType.getElementType());
277 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
284 Type srcElemType = srcType.getElementType();
288 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289 dstElemType = arrayType.getElementType();
291 dstElemType = cast<VectorType>(dstType).getElementType();
295 if (srcElemType != dstElemType) {
297 if (isa<FloatType>(srcElemType)) {
298 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
303 elements.push_back(dstAttr);
308 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
310 srcAttr, cast<IntegerType>(dstElemType), rewriter);
313 elements.push_back(dstAttr);
321 if (isa<RankedTensorType>(dstAttrType))
337 struct ConstantScalarOpPattern final
342 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
344 Type srcType = constOp.getType();
345 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
346 if (shapedType.getNumElements() != 1)
348 srcType = shapedType.getElementType();
354 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355 cstAttr = elementsAttr.getSplatValue<
Attribute>();
357 Type dstType = getTypeConverter()->convertType(srcType);
362 if (isa<FloatType>(srcType)) {
363 auto srcAttr = cast<FloatAttr>(cstAttr);
364 auto dstAttr = srcAttr;
368 if (srcType != dstType) {
391 auto srcAttr = cast<IntegerAttr>(cstAttr);
392 IntegerAttr dstAttr =
412 template <
typename SignedAbsOp>
416 assert(lhs == signOperand || rhs == signOperand);
421 Value lhsAbs = builder.
create<SignedAbsOp>(loc, type, lhs);
422 Value rhsAbs = builder.
create<SignedAbsOp>(loc, type, rhs);
427 if (lhs == signOperand)
428 isPositive = builder.
create<spirv::IEqualOp>(loc, lhs, lhsAbs);
430 isPositive = builder.
create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431 Value absNegate = builder.
create<spirv::SNegateOp>(loc, type,
abs);
432 return builder.
create<spirv::SelectOp>(loc, type, isPositive,
abs, absNegate);
443 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
445 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447 adaptor.getOperands()[0], rewriter);
459 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
461 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463 adaptor.getOperands()[0], rewriter);
478 template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
483 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
485 assert(adaptor.getOperands().size() == 2);
486 Type dstType = this->getTypeConverter()->convertType(op.getType());
491 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492 op, dstType, adaptor.getOperands());
494 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495 op, dstType, adaptor.getOperands());
510 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
512 assert(adaptor.getOperands().size() == 2);
517 Type dstType = getTypeConverter()->convertType(op.getType());
522 adaptor.getOperands());
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
536 assert(adaptor.getOperands().size() == 2);
541 Type dstType = getTypeConverter()->convertType(op.getType());
546 op, dstType, adaptor.getOperands());
561 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
563 Type srcType = adaptor.getOperands().front().getType();
567 Type dstType = getTypeConverter()->convertType(op.getType());
573 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
575 op, dstType, adaptor.getOperands().front(), one, zero);
590 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
592 Value operand = adaptor.getIn();
597 Type dstType = getTypeConverter()->convertType(op.getType());
602 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
603 unsigned componentBitwidth = intTy.getWidth();
604 allOnes = rewriter.
create<spirv::ConstantOp>(
606 rewriter.
getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
608 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609 allOnes = rewriter.
create<spirv::ConstantOp>(
612 APInt::getAllOnes(componentBitwidth)));
615 loc, llvm::formatv(
"unhandled type: {0}", dstType));
631 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
633 Type srcType = adaptor.getIn().getType();
637 Type dstType = getTypeConverter()->convertType(op.getType());
641 if (dstType == srcType) {
649 assert(srcBW < dstBW);
651 rewriter, op.getLoc());
656 auto shiftLOp = rewriter.
create<spirv::ShiftLeftLogicalOp>(
657 op.getLoc(), dstType, adaptor.getIn(), shiftSize);
662 op, dstType, shiftLOp, shiftSize);
665 adaptor.getOperands());
682 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
684 Type srcType = adaptor.getOperands().front().getType();
688 Type dstType = getTypeConverter()->convertType(op.getType());
694 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
696 op, dstType, adaptor.getOperands().front(), one, zero);
707 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
709 Type srcType = adaptor.getIn().getType();
713 Type dstType = getTypeConverter()->convertType(op.getType());
717 if (dstType == srcType) {
725 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
728 adaptor.getIn(), mask);
731 adaptor.getOperands());
747 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
749 Type dstType = getTypeConverter()->convertType(op.getType());
757 auto srcType = adaptor.getOperands().front().getType();
759 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760 Value maskedSrc = rewriter.
create<spirv::BitwiseAndOp>(
761 loc, srcType, adaptor.getOperands()[0], mask);
762 Value isOne = rewriter.
create<spirv::IEqualOp>(loc, maskedSrc, mask);
765 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
777 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
779 Type srcType = adaptor.getIn().getType();
780 Type dstType = getTypeConverter()->convertType(op.getType());
787 if (dstType == srcType) {
794 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
796 adaptor.getIn(), mask);
800 adaptor.getOperands());
810 static std::optional<spirv::FPRoundingMode>
811 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812 switch (roundingMode) {
813 case arith::RoundingMode::downward:
814 return spirv::FPRoundingMode::RTN;
815 case arith::RoundingMode::to_nearest_even:
816 return spirv::FPRoundingMode::RTE;
817 case arith::RoundingMode::toward_zero:
818 return spirv::FPRoundingMode::RTZ;
819 case arith::RoundingMode::upward:
820 return spirv::FPRoundingMode::RTP;
821 case arith::RoundingMode::to_nearest_away:
826 llvm_unreachable(
"Unhandled rounding mode");
830 template <
typename Op,
typename SPIRVOp>
835 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
837 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
838 Type dstType = this->getTypeConverter()->convertType(op.getType());
845 if (dstType == srcType) {
848 rewriter.
replaceOp(op, adaptor.getOperands().front());
850 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
851 op, dstType, adaptor.getOperands());
852 if (
auto roundingModeOp =
853 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
854 if (arith::RoundingModeAttr roundingMode =
855 roundingModeOp.getRoundingModeAttr()) {
857 convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
864 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
883 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
885 Type srcType = op.getLhs().getType();
888 Type dstType = getTypeConverter()->convertType(srcType);
892 switch (op.getPredicate()) {
893 case arith::CmpIPredicate::eq: {
898 case arith::CmpIPredicate::ne: {
900 op, adaptor.getLhs(), adaptor.getRhs());
903 case arith::CmpIPredicate::uge:
904 case arith::CmpIPredicate::ugt:
905 case arith::CmpIPredicate::ule:
906 case arith::CmpIPredicate::ult: {
910 if (
auto vectorType = dyn_cast<VectorType>(dstType))
913 rewriter.
create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
915 rewriter.
create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
934 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
936 Type srcType = op.getLhs().getType();
939 Type dstType = getTypeConverter()->convertType(srcType);
943 switch (op.getPredicate()) {
944 #define DISPATCH(cmpPredicate, spirvOp) \
946 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
947 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
948 !hasSameBitwidth(srcType, dstType)) { \
949 return op.emitError( \
950 "bitwidth emulation is not implemented yet on unsigned op"); \
952 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
956 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
957 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
958 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
959 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
960 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
961 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
962 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
963 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
964 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
965 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
983 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
985 switch (op.getPredicate()) {
986 #define DISPATCH(cmpPredicate, spirvOp) \
988 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
993 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
994 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
995 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
996 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
997 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
998 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1000 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1001 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1002 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1003 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1004 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1005 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1023 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1025 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1031 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1048 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1050 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1051 op.getPredicate() != arith::CmpFPredicate::UNO)
1057 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1058 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1060 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1066 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1067 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1069 replace = rewriter.
create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1070 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1071 replace = rewriter.
create<spirv::LogicalNotOp>(loc, replace);
1084 class AddUIExtendedOpPattern final
1089 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1091 Type dstElemTy = adaptor.getLhs().getType();
1093 Value result = rewriter.
create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1096 Value sumResult = rewriter.
create<spirv::CompositeExtractOp>(
1098 Value carryValue = rewriter.
create<spirv::CompositeExtractOp>(
1102 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1103 Value carryResult = rewriter.
create<spirv::IEqualOp>(loc, carryValue, one);
1105 rewriter.
replaceOp(op, {sumResult, carryResult});
1115 template <
typename ArithMulOp,
typename SPIRVMulOp>
1120 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1124 rewriter.
create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1126 Value low = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1128 Value high = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1145 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1148 adaptor.getTrueValue(),
1149 adaptor.getFalseValue());
1160 template <
typename Op,
typename SPIRVOp>
1165 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1167 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1168 Type dstType = converter->convertType(op.getType());
1182 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1184 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1189 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1190 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1192 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1193 adaptor.getLhs(), spirvOp);
1194 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1195 adaptor.getRhs(), select1);
1208 template <
typename Op,
typename SPIRVOp>
1210 template <
typename TargetOp>
1211 constexpr
bool shouldInsertNanGuards()
const {
1212 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1218 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1220 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1221 Type dstType = converter->convertType(op.getType());
1236 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1238 if (!shouldInsertNanGuards<SPIRVOp>() ||
1239 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1244 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1245 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1247 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1248 adaptor.getRhs(), spirvOp);
1249 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1250 adaptor.getLhs(), select1);
1267 ConstantCompositeOpPattern,
1268 ConstantScalarOpPattern,
1269 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1270 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1271 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1275 RemSIOpGLPattern, RemSIOpCLPattern,
1276 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1277 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1278 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1279 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1288 ExtUIPattern, ExtUII1Pattern,
1289 ExtSIPattern, ExtSII1Pattern,
1290 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1291 TruncIPattern, TruncII1Pattern,
1292 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1293 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1294 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1295 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1296 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1297 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1298 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1299 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1300 CmpIOpBooleanPattern, CmpIOpPattern,
1301 CmpFOpNanNonePattern, CmpFOpPattern,
1302 AddUIExtendedOpPattern,
1303 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1304 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1307 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1308 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1309 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1310 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1316 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1317 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1318 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1319 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1324 >(typeConverter,
patterns.getContext());
1338 struct ConvertArithToSPIRVPass
1339 :
public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1342 void runOnOperation()
override {
1345 std::unique_ptr<SPIRVConversionTarget> target =
1349 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1354 target->addLegalOp<UnrealizedConversionCastOp>();
1357 target->addIllegalDialect<arith::ArithDialect>();
1363 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Fraction abs(const Fraction &f)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.