11 #include "../SPIRVCommon/Pattern.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31 #include "mlir/Conversion/Passes.h.inc"
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
104 assert(type &&
"Not a valid type");
108 if (
auto vecType = dyn_cast<VectorType>(type))
109 return vecType.getElementType().isInteger(1);
117 if (
auto vectorType = dyn_cast<VectorType>(type)) {
120 return builder.
create<spirv::ConstantOp>(loc, vectorType, attr);
123 if (
auto intType = dyn_cast<IntegerType>(type))
124 return builder.
create<spirv::ConstantOp>(
133 auto getNumBitwidth = [](
Type type) {
135 if (type.isIntOrFloat())
136 bw = type.getIntOrFloatBitWidth();
137 else if (
auto vecType = dyn_cast<VectorType>(type))
138 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
141 unsigned aBW = getNumBitwidth(a);
142 unsigned bBW = getNumBitwidth(b);
143 return aBW != 0 && bBW != 0 && aBW == bBW;
152 llvm::formatv(
"failed to convert source type '{0}'", srcType));
164 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
171 template <
typename Op,
typename SPIRVOp>
176 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
178 assert(adaptor.getOperands().size() <= 3);
179 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
180 Type dstType = converter->convertType(op.getType());
184 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
187 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
189 dstType != op.getType()) {
190 return op.
emitError(
"bitwidth emulation is not implemented yet on "
191 "unsigned op pattern version");
194 auto overflowFlags = arith::IntegerOverflowFlags::none;
195 if (
auto overflowIface =
196 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197 if (converter->getTargetEnv().allows(
198 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199 overflowFlags = overflowIface.getOverflowAttr().getValue();
202 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203 op, dstType, adaptor.getOperands());
205 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
209 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
222 struct ConstantCompositeOpPattern final
227 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
229 auto srcType = dyn_cast<ShapedType>(constOp.getType());
230 if (!srcType || srcType.getNumElements() == 1)
235 if (!isa<VectorType, RankedTensorType>(srcType))
238 Type dstType = getTypeConverter()->convertType(srcType);
245 if (
auto denseElementsAttr =
246 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247 dstElementsAttr = denseElementsAttr;
248 }
else if (
auto resourceAttr =
249 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
253 return constOp->emitError(
"could not find resource blob");
259 bool detectedSplat =
false;
261 return constOp->emitError(
"resource is not a valid buffer");
266 return constOp->emitError(
"unsupported elements attribute");
269 ShapedType dstAttrType = dstElementsAttr.
getType();
273 if (srcType.getRank() > 1) {
274 if (isa<RankedTensorType>(srcType)) {
276 srcType.getElementType());
277 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
284 Type srcElemType = srcType.getElementType();
288 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289 dstElemType = arrayType.getElementType();
291 dstElemType = cast<VectorType>(dstType).getElementType();
295 if (srcElemType != dstElemType) {
297 if (isa<FloatType>(srcElemType)) {
298 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
303 elements.push_back(dstAttr);
308 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
310 srcAttr, cast<IntegerType>(dstElemType), rewriter);
313 elements.push_back(dstAttr);
321 if (isa<RankedTensorType>(dstAttrType))
337 struct ConstantScalarOpPattern final
342 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
344 Type srcType = constOp.getType();
345 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
346 if (shapedType.getNumElements() != 1)
348 srcType = shapedType.getElementType();
354 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355 cstAttr = elementsAttr.getSplatValue<
Attribute>();
357 Type dstType = getTypeConverter()->convertType(srcType);
362 if (isa<FloatType>(srcType)) {
363 auto srcAttr = cast<FloatAttr>(cstAttr);
364 auto dstAttr = srcAttr;
368 if (srcType != dstType) {
391 auto srcAttr = cast<IntegerAttr>(cstAttr);
392 IntegerAttr dstAttr =
412 template <
typename SignedAbsOp>
416 assert(lhs == signOperand || rhs == signOperand);
421 Value lhsAbs = builder.
create<SignedAbsOp>(loc, type, lhs);
422 Value rhsAbs = builder.
create<SignedAbsOp>(loc, type, rhs);
427 if (lhs == signOperand)
428 isPositive = builder.
create<spirv::IEqualOp>(loc, lhs, lhsAbs);
430 isPositive = builder.
create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431 Value absNegate = builder.
create<spirv::SNegateOp>(loc, type,
abs);
432 return builder.
create<spirv::SelectOp>(loc, type, isPositive,
abs, absNegate);
443 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
445 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447 adaptor.getOperands()[0], rewriter);
459 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
461 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463 adaptor.getOperands()[0], rewriter);
478 template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
483 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
485 assert(adaptor.getOperands().size() == 2);
486 Type dstType = this->getTypeConverter()->convertType(op.getType());
491 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492 op, dstType, adaptor.getOperands());
494 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495 op, dstType, adaptor.getOperands());
510 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
512 assert(adaptor.getOperands().size() == 2);
517 Type dstType = getTypeConverter()->convertType(op.getType());
522 adaptor.getOperands());
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
536 assert(adaptor.getOperands().size() == 2);
541 Type dstType = getTypeConverter()->convertType(op.getType());
546 op, dstType, adaptor.getOperands());
561 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
563 Type srcType = adaptor.getOperands().front().getType();
567 Type dstType = getTypeConverter()->convertType(op.getType());
573 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
575 op, dstType, adaptor.getOperands().front(), one, zero);
590 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
592 Value operand = adaptor.getIn();
597 Type dstType = getTypeConverter()->convertType(op.getType());
602 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
603 unsigned componentBitwidth = intTy.getWidth();
604 allOnes = rewriter.
create<spirv::ConstantOp>(
606 rewriter.
getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
608 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609 allOnes = rewriter.
create<spirv::ConstantOp>(
612 APInt::getAllOnes(componentBitwidth)));
615 loc, llvm::formatv(
"unhandled type: {0}", dstType));
631 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
633 Type srcType = adaptor.getIn().getType();
637 Type dstType = getTypeConverter()->convertType(op.getType());
641 if (dstType == srcType) {
649 assert(srcBW < dstBW);
651 rewriter, op.getLoc());
656 auto shiftLOp = rewriter.
create<spirv::ShiftLeftLogicalOp>(
657 op.getLoc(), dstType, adaptor.getIn(), shiftSize);
662 op, dstType, shiftLOp, shiftSize);
665 adaptor.getOperands());
682 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
684 Type srcType = adaptor.getOperands().front().getType();
688 Type dstType = getTypeConverter()->convertType(op.getType());
694 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
696 op, dstType, adaptor.getOperands().front(), one, zero);
707 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
709 Type srcType = adaptor.getIn().getType();
713 Type dstType = getTypeConverter()->convertType(op.getType());
717 if (dstType == srcType) {
725 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
728 adaptor.getIn(), mask);
731 adaptor.getOperands());
747 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
749 Type dstType = getTypeConverter()->convertType(op.getType());
757 auto srcType = adaptor.getOperands().front().getType();
759 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760 Value maskedSrc = rewriter.
create<spirv::BitwiseAndOp>(
761 loc, srcType, adaptor.getOperands()[0], mask);
762 Value isOne = rewriter.
create<spirv::IEqualOp>(loc, maskedSrc, mask);
765 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
777 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
779 Type srcType = adaptor.getIn().getType();
780 Type dstType = getTypeConverter()->convertType(op.getType());
787 if (dstType == srcType) {
794 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
796 adaptor.getIn(), mask);
800 adaptor.getOperands());
810 static std::optional<spirv::FPRoundingMode>
811 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812 switch (roundingMode) {
813 case arith::RoundingMode::downward:
814 return spirv::FPRoundingMode::RTN;
815 case arith::RoundingMode::to_nearest_even:
816 return spirv::FPRoundingMode::RTE;
817 case arith::RoundingMode::toward_zero:
818 return spirv::FPRoundingMode::RTZ;
819 case arith::RoundingMode::upward:
820 return spirv::FPRoundingMode::RTP;
821 case arith::RoundingMode::to_nearest_away:
826 llvm_unreachable(
"Unhandled rounding mode");
830 template <
typename Op,
typename SPIRVOp>
835 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
837 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
838 Type dstType = this->getTypeConverter()->convertType(op.getType());
845 if (dstType == srcType) {
848 rewriter.
replaceOp(op, adaptor.getOperands().front());
851 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
852 if (
auto roundingModeOp =
853 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
854 if (arith::RoundingModeAttr roundingMode =
855 roundingModeOp.getRoundingModeAttr()) {
857 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
860 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
865 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
866 op, dstType, adaptor.getOperands());
887 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
889 Type srcType = op.getLhs().getType();
892 Type dstType = getTypeConverter()->convertType(srcType);
896 switch (op.getPredicate()) {
897 case arith::CmpIPredicate::eq: {
902 case arith::CmpIPredicate::ne: {
904 op, adaptor.getLhs(), adaptor.getRhs());
907 case arith::CmpIPredicate::uge:
908 case arith::CmpIPredicate::ugt:
909 case arith::CmpIPredicate::ule:
910 case arith::CmpIPredicate::ult: {
914 if (
auto vectorType = dyn_cast<VectorType>(dstType))
917 rewriter.
create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
919 rewriter.
create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
938 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
940 Type srcType = op.getLhs().getType();
943 Type dstType = getTypeConverter()->convertType(srcType);
947 switch (op.getPredicate()) {
948 #define DISPATCH(cmpPredicate, spirvOp) \
950 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
951 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
952 !hasSameBitwidth(srcType, dstType)) { \
953 return op.emitError( \
954 "bitwidth emulation is not implemented yet on unsigned op"); \
956 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
960 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
961 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
962 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
963 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
964 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
965 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
966 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
967 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
968 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
969 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
987 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
989 switch (op.getPredicate()) {
990 #define DISPATCH(cmpPredicate, spirvOp) \
992 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
997 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
998 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
999 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1000 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1001 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1002 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1004 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1005 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1006 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1007 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1008 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1009 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1027 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1029 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1035 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1052 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1054 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1055 op.getPredicate() != arith::CmpFPredicate::UNO)
1061 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1062 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1064 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1070 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1071 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1073 replace = rewriter.
create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1074 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1075 replace = rewriter.
create<spirv::LogicalNotOp>(loc, replace);
1088 class AddUIExtendedOpPattern final
1093 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1095 Type dstElemTy = adaptor.getLhs().getType();
1097 Value result = rewriter.
create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1100 Value sumResult = rewriter.
create<spirv::CompositeExtractOp>(
1102 Value carryValue = rewriter.
create<spirv::CompositeExtractOp>(
1106 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1107 Value carryResult = rewriter.
create<spirv::IEqualOp>(loc, carryValue, one);
1109 rewriter.
replaceOp(op, {sumResult, carryResult});
1119 template <
typename ArithMulOp,
typename SPIRVMulOp>
1124 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1128 rewriter.
create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1130 Value low = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1132 Value high = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1149 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1152 adaptor.getTrueValue(),
1153 adaptor.getFalseValue());
1164 template <
typename Op,
typename SPIRVOp>
1169 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1171 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1172 Type dstType = converter->convertType(op.getType());
1186 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1188 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1193 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1194 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1196 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1197 adaptor.getLhs(), spirvOp);
1198 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1199 adaptor.getRhs(), select1);
1212 template <
typename Op,
typename SPIRVOp>
1214 template <
typename TargetOp>
1215 constexpr
bool shouldInsertNanGuards()
const {
1216 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1222 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1224 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1225 Type dstType = converter->convertType(op.getType());
1240 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1242 if (!shouldInsertNanGuards<SPIRVOp>() ||
1243 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1248 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1249 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1251 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1252 adaptor.getRhs(), spirvOp);
1253 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1254 adaptor.getLhs(), select1);
1271 ConstantCompositeOpPattern,
1272 ConstantScalarOpPattern,
1273 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1274 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1275 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1279 RemSIOpGLPattern, RemSIOpCLPattern,
1280 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1281 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1282 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1283 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1292 ExtUIPattern, ExtUII1Pattern,
1293 ExtSIPattern, ExtSII1Pattern,
1294 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1295 TruncIPattern, TruncII1Pattern,
1296 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1297 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1298 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1299 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1300 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1301 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1302 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1303 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1304 CmpIOpBooleanPattern, CmpIOpPattern,
1305 CmpFOpNanNonePattern, CmpFOpPattern,
1306 AddUIExtendedOpPattern,
1307 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1308 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1311 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1312 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1313 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1314 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1320 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1321 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1322 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1323 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1328 >(typeConverter,
patterns.getContext());
1342 struct ConvertArithToSPIRVPass
1343 :
public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1346 void runOnOperation()
override {
1349 std::unique_ptr<SPIRVConversionTarget> target =
1353 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1358 target->addLegalOp<UnrealizedConversionCastOp>();
1361 target->addIllegalDialect<arith::ArithDialect>();
1367 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Fraction abs(const Fraction &f)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.