11 #include "../SPIRVCommon/Pattern.h"
20 #include "llvm/ADT/APInt.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/MathExtras.h"
29 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
30 #include "mlir/Conversion/Passes.h.inc"
33 #define DEBUG_TYPE "arith-to-spirv-pattern"
44 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
47 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
57 if (srcAttr.getValue().isIntN(dstType.getWidth()))
65 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
68 << dstAttr <<
"' for type '" << dstType <<
"'\n");
72 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
73 <<
"' illegal: cannot fit into target type '"
87 APFloat dstVal = srcAttr.getValue();
88 bool losesInfo =
false;
89 APFloat::opStatus status =
90 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
91 if (status != APFloat::opOK || losesInfo) {
92 LLVM_DEBUG(llvm::dbgs()
93 << srcAttr <<
" illegal: cannot fit into converted type '"
103 assert(type &&
"Not a valid type");
107 if (
auto vecType = dyn_cast<VectorType>(type))
108 return vecType.getElementType().isInteger(1);
116 if (
auto vectorType = dyn_cast<VectorType>(type)) {
119 return builder.
create<spirv::ConstantOp>(loc, vectorType, attr);
122 if (
auto intType = dyn_cast<IntegerType>(type))
123 return builder.
create<spirv::ConstantOp>(
132 auto getNumBitwidth = [](
Type type) {
134 if (type.isIntOrFloat())
135 bw = type.getIntOrFloatBitWidth();
136 else if (
auto vecType = dyn_cast<VectorType>(type))
137 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
140 unsigned aBW = getNumBitwidth(a);
141 unsigned bBW = getNumBitwidth(b);
142 return aBW != 0 && bBW != 0 && aBW == bBW;
151 llvm::formatv(
"failed to convert source type '{0}'", srcType));
163 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
170 template <
typename Op,
typename SPIRVOp>
175 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
177 assert(adaptor.getOperands().size() <= 3);
178 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
179 Type dstType = converter->convertType(op.getType());
183 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
186 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
188 dstType != op.getType()) {
189 return op.
emitError(
"bitwidth emulation is not implemented yet on "
190 "unsigned op pattern version");
193 auto overflowFlags = arith::IntegerOverflowFlags::none;
194 if (
auto overflowIface =
195 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
196 if (converter->getTargetEnv().allows(
197 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
198 overflowFlags = overflowIface.getOverflowAttr().getValue();
201 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
202 op, dstType, adaptor.getOperands());
204 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
208 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
221 struct ConstantCompositeOpPattern final
226 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
228 auto srcType = dyn_cast<ShapedType>(constOp.getType());
229 if (!srcType || srcType.getNumElements() == 1)
233 assert((isa<VectorType, RankedTensorType>(srcType)));
235 Type dstType = getTypeConverter()->convertType(srcType);
239 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
240 if (!dstElementsAttr)
243 ShapedType dstAttrType = dstElementsAttr.getType();
247 if (srcType.getRank() > 1) {
248 if (isa<RankedTensorType>(srcType)) {
250 srcType.getElementType());
251 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
258 Type srcElemType = srcType.getElementType();
262 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
263 dstElemType = arrayType.getElementType();
265 dstElemType = cast<VectorType>(dstType).getElementType();
269 if (srcElemType != dstElemType) {
271 if (isa<FloatType>(srcElemType)) {
272 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
277 elements.push_back(dstAttr);
282 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
284 srcAttr, cast<IntegerType>(dstElemType), rewriter);
287 elements.push_back(dstAttr);
295 if (isa<RankedTensorType>(dstAttrType))
311 struct ConstantScalarOpPattern final
316 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
318 Type srcType = constOp.getType();
319 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
320 if (shapedType.getNumElements() != 1)
322 srcType = shapedType.getElementType();
328 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
329 cstAttr = elementsAttr.getSplatValue<
Attribute>();
331 Type dstType = getTypeConverter()->convertType(srcType);
336 if (isa<FloatType>(srcType)) {
337 auto srcAttr = cast<FloatAttr>(cstAttr);
338 auto dstAttr = srcAttr;
342 if (srcType != dstType) {
365 auto srcAttr = cast<IntegerAttr>(cstAttr);
366 IntegerAttr dstAttr =
386 template <
typename SignedAbsOp>
390 assert(lhs == signOperand || rhs == signOperand);
395 Value lhsAbs = builder.
create<SignedAbsOp>(loc, type, lhs);
396 Value rhsAbs = builder.
create<SignedAbsOp>(loc, type, rhs);
401 if (lhs == signOperand)
402 isPositive = builder.
create<spirv::IEqualOp>(loc, lhs, lhsAbs);
404 isPositive = builder.
create<spirv::IEqualOp>(loc, rhs, rhsAbs);
405 Value absNegate = builder.
create<spirv::SNegateOp>(loc, type,
abs);
406 return builder.
create<spirv::SelectOp>(loc, type, isPositive,
abs, absNegate);
417 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
419 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
420 op.
getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
421 adaptor.getOperands()[0], rewriter);
433 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
435 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
436 op.
getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
437 adaptor.getOperands()[0], rewriter);
452 template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
457 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
459 assert(adaptor.getOperands().size() == 2);
460 Type dstType = this->getTypeConverter()->convertType(op.getType());
465 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
466 op, dstType, adaptor.getOperands());
468 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
469 op, dstType, adaptor.getOperands());
484 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
486 assert(adaptor.getOperands().size() == 2);
491 Type dstType = getTypeConverter()->convertType(op.getType());
496 adaptor.getOperands());
508 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
510 assert(adaptor.getOperands().size() == 2);
515 Type dstType = getTypeConverter()->convertType(op.getType());
520 op, dstType, adaptor.getOperands());
535 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
537 Type srcType = adaptor.getOperands().front().getType();
541 Type dstType = getTypeConverter()->convertType(op.getType());
547 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
549 op, dstType, adaptor.getOperands().front(), one, zero);
564 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
566 Value operand = adaptor.getIn();
571 Type dstType = getTypeConverter()->convertType(op.getType());
576 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
577 unsigned componentBitwidth = intTy.getWidth();
578 allOnes = rewriter.
create<spirv::ConstantOp>(
580 rewriter.
getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
581 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
582 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
583 allOnes = rewriter.
create<spirv::ConstantOp>(
586 APInt::getAllOnes(componentBitwidth)));
589 loc, llvm::formatv(
"unhandled type: {0}", dstType));
605 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
607 Type srcType = adaptor.getIn().getType();
611 Type dstType = getTypeConverter()->convertType(op.getType());
615 if (dstType == srcType) {
623 assert(srcBW < dstBW);
630 auto shiftLOp = rewriter.
create<spirv::ShiftLeftLogicalOp>(
631 op.
getLoc(), dstType, adaptor.getIn(), shiftSize);
636 op, dstType, shiftLOp, shiftSize);
639 adaptor.getOperands());
656 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
658 Type srcType = adaptor.getOperands().front().getType();
662 Type dstType = getTypeConverter()->convertType(op.getType());
668 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
670 op, dstType, adaptor.getOperands().front(), one, zero);
681 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
683 Type srcType = adaptor.getIn().getType();
687 Type dstType = getTypeConverter()->convertType(op.getType());
691 if (dstType == srcType) {
699 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
702 adaptor.getIn(), mask);
721 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
723 Type dstType = getTypeConverter()->convertType(op.getType());
731 auto srcType = adaptor.getOperands().front().getType();
733 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
734 Value maskedSrc = rewriter.
create<spirv::BitwiseAndOp>(
735 loc, srcType, adaptor.getOperands()[0], mask);
736 Value isOne = rewriter.
create<spirv::IEqualOp>(loc, maskedSrc, mask);
739 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
751 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
753 Type srcType = adaptor.getIn().getType();
754 Type dstType = getTypeConverter()->convertType(op.getType());
761 if (dstType == srcType) {
768 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.
getLoc());
770 adaptor.getIn(), mask);
785 template <
typename Op,
typename SPIRVOp>
790 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
792 assert(adaptor.getOperands().size() == 1);
793 Type srcType = adaptor.getOperands().front().getType();
794 Type dstType = this->getTypeConverter()->convertType(op.getType());
801 if (dstType == srcType) {
804 rewriter.
replaceOp(op, adaptor.getOperands().front());
806 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
807 adaptor.getOperands());
808 if (
auto roundingModeOp =
809 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
810 if (arith::RoundingModeAttr roundingMode =
811 roundingModeOp.getRoundingModeAttr()) {
832 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
834 Type srcType = op.getLhs().getType();
837 Type dstType = getTypeConverter()->convertType(srcType);
841 switch (op.getPredicate()) {
842 case arith::CmpIPredicate::eq: {
847 case arith::CmpIPredicate::ne: {
849 op, adaptor.getLhs(), adaptor.getRhs());
852 case arith::CmpIPredicate::uge:
853 case arith::CmpIPredicate::ugt:
854 case arith::CmpIPredicate::ule:
855 case arith::CmpIPredicate::ult: {
859 if (
auto vectorType = dyn_cast<VectorType>(dstType))
862 rewriter.
create<arith::ExtUIOp>(op.
getLoc(), type, adaptor.getLhs());
864 rewriter.
create<arith::ExtUIOp>(op.
getLoc(), type, adaptor.getRhs());
883 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
885 Type srcType = op.getLhs().getType();
888 Type dstType = getTypeConverter()->convertType(srcType);
892 switch (op.getPredicate()) {
893 #define DISPATCH(cmpPredicate, spirvOp) \
895 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
896 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
897 !hasSameBitwidth(srcType, dstType)) { \
898 return op.emitError( \
899 "bitwidth emulation is not implemented yet on unsigned op"); \
901 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
905 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
906 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
907 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
908 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
909 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
910 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
911 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
912 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
913 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
914 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
932 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
934 switch (op.getPredicate()) {
935 #define DISPATCH(cmpPredicate, spirvOp) \
937 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
942 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
943 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
944 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
945 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
946 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
947 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
949 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
950 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
951 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
952 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
953 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
954 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
972 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
974 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
980 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
997 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
999 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1000 op.getPredicate() != arith::CmpFPredicate::UNO)
1006 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1007 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1009 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1015 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1016 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1018 replace = rewriter.
create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1019 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1020 replace = rewriter.
create<spirv::LogicalNotOp>(loc, replace);
1033 class AddUIExtendedOpPattern final
1038 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1040 Type dstElemTy = adaptor.getLhs().getType();
1042 Value result = rewriter.
create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1045 Value sumResult = rewriter.
create<spirv::CompositeExtractOp>(
1047 Value carryValue = rewriter.
create<spirv::CompositeExtractOp>(
1051 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1052 Value carryResult = rewriter.
create<spirv::IEqualOp>(loc, carryValue, one);
1054 rewriter.
replaceOp(op, {sumResult, carryResult});
1064 template <
typename ArithMulOp,
typename SPIRVMulOp>
1069 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1073 rewriter.
create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1075 Value low = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1077 Value high = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1094 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1097 adaptor.getTrueValue(),
1098 adaptor.getFalseValue());
1109 template <
typename Op,
typename SPIRVOp>
1114 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1116 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1117 Type dstType = converter->convertType(op.getType());
1131 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1133 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1138 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1139 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1141 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1142 adaptor.getLhs(), spirvOp);
1143 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1144 adaptor.getRhs(), select1);
1157 template <
typename Op,
typename SPIRVOp>
1159 template <
typename TargetOp>
1160 constexpr
bool shouldInsertNanGuards()
const {
1161 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1167 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1169 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1170 Type dstType = converter->convertType(op.getType());
1185 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1187 if (!shouldInsertNanGuards<SPIRVOp>() ||
1188 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1193 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1194 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1196 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1197 adaptor.getRhs(), spirvOp);
1198 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1199 adaptor.getLhs(), select1);
1216 ConstantCompositeOpPattern,
1217 ConstantScalarOpPattern,
1218 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1219 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1220 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1224 RemSIOpGLPattern, RemSIOpCLPattern,
1225 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1226 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1227 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1228 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1237 ExtUIPattern, ExtUII1Pattern,
1238 ExtSIPattern, ExtSII1Pattern,
1239 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1240 TruncIPattern, TruncII1Pattern,
1241 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1242 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1243 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1244 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1245 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1246 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1247 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1248 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1249 CmpIOpBooleanPattern, CmpIOpPattern,
1250 CmpFOpNanNonePattern, CmpFOpPattern,
1251 AddUIExtendedOpPattern,
1252 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1253 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1256 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1257 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1258 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1259 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1265 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1266 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1267 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1268 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1278 patterns.
add<CmpFOpNanKernelPattern>(typeConverter, patterns.
getContext(),
1287 struct ConvertArithToSPIRVPass
1288 :
public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
1289 void runOnOperation()
override {
1292 std::unique_ptr<SPIRVConversionTarget> target =
1296 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1301 target->addLegalOp<UnrealizedConversionCastOp>();
1304 target->addIllegalDialect<arith::ArithDialect>();
1310 signalPassFailure();
1316 return std::make_unique<ConvertArithToSPIRVPass>();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass<> > createConvertArithToSPIRVPass()
Fraction abs(const Fraction &f)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.