11 #include "../SPIRVCommon/Pattern.h" 
   21 #include "llvm/ADT/APInt.h" 
   22 #include "llvm/ADT/ArrayRef.h" 
   23 #include "llvm/ADT/STLExtras.h" 
   24 #include "llvm/Support/Debug.h" 
   25 #include "llvm/Support/MathExtras.h" 
   30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS 
   31 #include "mlir/Conversion/Passes.h.inc" 
   34 #define DEBUG_TYPE "arith-to-spirv-pattern" 
   45   if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
 
   47   if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
 
   48     return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
 
   58   if (srcAttr.getValue().isIntN(dstType.getWidth()))
 
   66   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
 
   68     LLVM_DEBUG(llvm::dbgs() << 
"attribute '" << srcAttr << 
"' converted to '" 
   69                             << dstAttr << 
"' for type '" << dstType << 
"'\n");
 
   73   LLVM_DEBUG(llvm::dbgs() << 
"attribute '" << srcAttr
 
   74                           << 
"' illegal: cannot fit into target type '" 
   88   APFloat dstVal = srcAttr.getValue();
 
   89   bool losesInfo = 
false;
 
   90   APFloat::opStatus status =
 
   91       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
 
   92   if (status != APFloat::opOK || losesInfo) {
 
   93     LLVM_DEBUG(llvm::dbgs()
 
   94                << srcAttr << 
" illegal: cannot fit into converted type '" 
  108   APFloat floatVal = floatAttr.getValue();
 
  109   APInt intVal = floatVal.bitcastToAPInt();
 
  115   assert(type && 
"Not a valid type");
 
  119   if (
auto vecType = dyn_cast<VectorType>(type))
 
  120     return vecType.getElementType().isInteger(1);
 
  128   if (
auto vectorType = dyn_cast<VectorType>(type)) {
 
  131     return spirv::ConstantOp::create(builder, loc, vectorType, attr);
 
  134   if (
auto intType = dyn_cast<IntegerType>(type))
 
  135     return spirv::ConstantOp::create(builder, loc, type,
 
  144   auto getNumBitwidth = [](
Type type) {
 
  146     if (type.isIntOrFloat())
 
  147       bw = type.getIntOrFloatBitWidth();
 
  148     else if (
auto vecType = dyn_cast<VectorType>(type))
 
  149       bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
 
  152   unsigned aBW = getNumBitwidth(a);
 
  153   unsigned bBW = getNumBitwidth(b);
 
  154   return aBW != 0 && bBW != 0 && aBW == bBW;
 
  163       llvm::formatv(
"failed to convert source type '{0}'", srcType));
 
  175   return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
 
  182 template <
typename Op, 
typename SPIRVOp>
 
  187   matchAndRewrite(
Op op, 
typename Op::Adaptor adaptor,
 
  189     assert(adaptor.getOperands().size() <= 3);
 
  190     auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
 
  191     Type dstType = converter->convertType(op.getType());
 
  195           llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
 
  198     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
 
  200         dstType != op.getType()) {
 
  201       return op.
emitError(
"bitwidth emulation is not implemented yet on " 
  202                           "unsigned op pattern version");
 
  205     auto overflowFlags = arith::IntegerOverflowFlags::none;
 
  206     if (
auto overflowIface =
 
  207             dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
 
  208       if (converter->getTargetEnv().allows(
 
  209               spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
 
  210         overflowFlags = overflowIface.getOverflowAttr().getValue();
 
  213     auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
 
  214         op, dstType, adaptor.getOperands());
 
  216     if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
 
  220     if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
 
  233 struct ConstantCompositeOpPattern final
 
  238   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
 
  240     auto srcType = dyn_cast<ShapedType>(constOp.getType());
 
  241     if (!srcType || srcType.getNumElements() == 1)
 
  246     if (!isa<VectorType, RankedTensorType>(srcType))
 
  249     Type dstType = getTypeConverter()->convertType(srcType);
 
  256     if (
auto denseElementsAttr =
 
  257             dyn_cast<DenseElementsAttr>(constOp.getValue())) {
 
  258       dstElementsAttr = denseElementsAttr;
 
  259     } 
else if (
auto resourceAttr =
 
  260                    dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
 
  264         return constOp->emitError(
"could not find resource blob");
 
  270       bool detectedSplat = 
false;
 
  272         return constOp->emitError(
"resource is not a valid buffer");
 
  277       return constOp->emitError(
"unsupported elements attribute");
 
  280     ShapedType dstAttrType = dstElementsAttr.
getType();
 
  284     if (srcType.getRank() > 1) {
 
  285       if (isa<RankedTensorType>(srcType)) {
 
  287                                             srcType.getElementType());
 
  288         dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
 
  295     Type srcElemType = srcType.getElementType();
 
  299     if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
 
  300       dstElemType = arrayType.getElementType();
 
  302       dstElemType = cast<VectorType>(dstType).getElementType();
 
  306     if (srcElemType != dstElemType) {
 
  308       if (isa<FloatType>(srcElemType)) {
 
  309         for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
 
  312           auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
 
  313           if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
 
  315               isa<IntegerType>(dstElemType)) {
 
  324           elements.push_back(dstAttr);
 
  329         for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
 
  331               srcAttr, cast<IntegerType>(dstElemType), rewriter);
 
  334           elements.push_back(dstAttr);
 
  342       if (isa<RankedTensorType>(dstAttrType))
 
  358 struct ConstantScalarOpPattern final
 
  363   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
 
  365     Type srcType = constOp.getType();
 
  366     if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
 
  367       if (shapedType.getNumElements() != 1)
 
  369       srcType = shapedType.getElementType();
 
  375     if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
 
  376       cstAttr = elementsAttr.getSplatValue<
Attribute>();
 
  378     Type dstType = getTypeConverter()->convertType(srcType);
 
  383     if (isa<FloatType>(srcType)) {
 
  384       auto srcAttr = cast<FloatAttr>(cstAttr);
 
  389       auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
 
  390       if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
 
  392           dstType.getIntOrFloatBitWidth() == 8) {
 
  397       } 
else if (srcType != dstType) {
 
  420     auto srcAttr = cast<IntegerAttr>(cstAttr);
 
  421     IntegerAttr dstAttr =
 
  441 template <
typename SignedAbsOp>
 
  445   assert(lhs == signOperand || rhs == signOperand);
 
  450   Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
 
  451   Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
 
  452   Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
 
  456   if (lhs == signOperand)
 
  457     isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
 
  459     isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
 
  460   Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
 
  461   return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
 
  473   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
 
  475     Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
 
  476         op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
 
  477         adaptor.getOperands()[0], rewriter);
 
  489   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
 
  491     Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
 
  492         op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
 
  493         adaptor.getOperands()[0], rewriter);
 
  508 template <
typename Op, 
typename SPIRVLogicalOp, 
typename SPIRVBitwiseOp>
 
  513   matchAndRewrite(
Op op, 
typename Op::Adaptor adaptor,
 
  515     assert(adaptor.getOperands().size() == 2);
 
  516     Type dstType = this->getTypeConverter()->convertType(op.getType());
 
  521       rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
 
  522           op, dstType, adaptor.getOperands());
 
  524       rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
 
  525           op, dstType, adaptor.getOperands());
 
  540   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
 
  542     assert(adaptor.getOperands().size() == 2);
 
  547     Type dstType = getTypeConverter()->convertType(op.getType());
 
  552                                                      adaptor.getOperands());
 
  564   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
 
  566     assert(adaptor.getOperands().size() == 2);
 
  571     Type dstType = getTypeConverter()->convertType(op.getType());
 
  576         op, dstType, adaptor.getOperands());
 
  591   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
 
  593     Type srcType = adaptor.getOperands().front().getType();
 
  597     Type dstType = getTypeConverter()->convertType(op.getType());
 
  603     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
 
  605         op, dstType, adaptor.getOperands().front(), one, zero);
 
  615 struct IndexCastIndexI1Pattern final
 
  620   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
 
  625     Type dstType = getTypeConverter()->convertType(op.getType());
 
  639 struct IndexCastI1IndexPattern final
 
  644   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
 
  649     Type dstType = getTypeConverter()->convertType(op.getType());
 
  655     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
 
  672   matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
 
  674     Value operand = adaptor.getIn();
 
  679     Type dstType = getTypeConverter()->convertType(op.getType());
 
  684     if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
 
  685       unsigned componentBitwidth = intTy.getWidth();
 
  686       allOnes = spirv::ConstantOp::create(
 
  687           rewriter, loc, intTy,
 
  688           rewriter.
getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
 
  689     } 
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
 
  690       unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
 
  691       allOnes = spirv::ConstantOp::create(
 
  692           rewriter, loc, vectorTy,
 
  694                                  APInt::getAllOnes(componentBitwidth)));
 
  697           loc, llvm::formatv(
"unhandled type: {0}", dstType));
 
  713   matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
 
  715     Type srcType = adaptor.getIn().getType();
 
  719     Type dstType = getTypeConverter()->convertType(op.getType());
 
  723     if (dstType == srcType) {
 
  731       assert(srcBW < dstBW);
 
  733                                                   rewriter, op.getLoc());
 
  738       auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
 
  739           rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
 
  744           op, dstType, shiftLOp, shiftSize);
 
  747                                                      adaptor.getOperands());
 
  764   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
 
  766     Type srcType = adaptor.getOperands().front().getType();
 
  770     Type dstType = getTypeConverter()->convertType(op.getType());
 
  776     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
 
  778         op, dstType, adaptor.getOperands().front(), one, zero);
 
  789   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
 
  791     Type srcType = adaptor.getIn().getType();
 
  795     Type dstType = getTypeConverter()->convertType(op.getType());
 
  799     if (dstType == srcType) {
 
  807           dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
 
  810                                                        adaptor.getIn(), mask);
 
  813                                                      adaptor.getOperands());
 
  829   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
 
  831     Type dstType = getTypeConverter()->convertType(op.getType());
 
  839     auto srcType = adaptor.getOperands().front().getType();
 
  841     Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
 
  842     Value maskedSrc = spirv::BitwiseAndOp::create(
 
  843         rewriter, loc, srcType, adaptor.getOperands()[0], mask);
 
  844     Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
 
  847     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
 
  859   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
 
  861     Type srcType = adaptor.getIn().getType();
 
  862     Type dstType = getTypeConverter()->convertType(op.getType());
 
  869     if (dstType == srcType) {
 
  876           dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
 
  878                                                        adaptor.getIn(), mask);
 
  882                                                      adaptor.getOperands());
 
  892 static std::optional<spirv::FPRoundingMode>
 
  893 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
 
  894   switch (roundingMode) {
 
  895   case arith::RoundingMode::downward:
 
  896     return spirv::FPRoundingMode::RTN;
 
  897   case arith::RoundingMode::to_nearest_even:
 
  898     return spirv::FPRoundingMode::RTE;
 
  899   case arith::RoundingMode::toward_zero:
 
  900     return spirv::FPRoundingMode::RTZ;
 
  901   case arith::RoundingMode::upward:
 
  902     return spirv::FPRoundingMode::RTP;
 
  903   case arith::RoundingMode::to_nearest_away:
 
  908   llvm_unreachable(
"Unhandled rounding mode");
 
  912 template <
typename Op, 
typename SPIRVOp>
 
  917   matchAndRewrite(
Op op, 
typename Op::Adaptor adaptor,
 
  919     Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
 
  920     Type dstType = this->getTypeConverter()->convertType(op.getType());
 
  927     if (dstType == srcType) {
 
  930       rewriter.
replaceOp(op, adaptor.getOperands().front());
 
  933       std::optional<spirv::FPRoundingMode> rm = std::nullopt;
 
  934       if (
auto roundingModeOp =
 
  935               dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
 
  936         if (arith::RoundingModeAttr roundingMode =
 
  937                 roundingModeOp.getRoundingModeAttr()) {
 
  939                     convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
 
  942                 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
 
  947       auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
 
  948           op, dstType, adaptor.getOperands());
 
  969   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
 
  971     Type srcType = op.getLhs().getType();
 
  974     Type dstType = getTypeConverter()->convertType(srcType);
 
  978     switch (op.getPredicate()) {
 
  979     case arith::CmpIPredicate::eq: {
 
  984     case arith::CmpIPredicate::ne: {
 
  986           op, adaptor.getLhs(), adaptor.getRhs());
 
  989     case arith::CmpIPredicate::uge:
 
  990     case arith::CmpIPredicate::ugt:
 
  991     case arith::CmpIPredicate::ule:
 
  992     case arith::CmpIPredicate::ult: {
 
  996       if (
auto vectorType = dyn_cast<VectorType>(dstType))
 
  999           arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
 
 1001           arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
 
 1020   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
 
 1022     Type srcType = op.getLhs().getType();
 
 1025     Type dstType = getTypeConverter()->convertType(srcType);
 
 1029     switch (op.getPredicate()) {
 
 1030 #define DISPATCH(cmpPredicate, spirvOp)                                        \ 
 1031   case cmpPredicate:                                                           \ 
 1032     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \ 
 1033         !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType &&      \ 
 1034         !hasSameBitwidth(srcType, dstType)) {                                  \ 
 1035       return op.emitError(                                                     \ 
 1036           "bitwidth emulation is not implemented yet on unsigned op");         \
 
 1038     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
 
 1039                                          adaptor.getRhs());                    \
 
 1042       DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
 
 1043       DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
 
 1044       DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
 
 1045       DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
 
 1046       DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
 
 1047       DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
 
 1048       DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
 
 1049       DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
 
 1050       DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
 
 1051       DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
 
 1069   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
 
 1071     switch (op.getPredicate()) {
 
 1072 #define DISPATCH(cmpPredicate, spirvOp)                                        \ 
 1073   case cmpPredicate:                                                           \ 
 1074     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \ 
 1075                                          adaptor.getRhs());                    \ 
 1079       DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
 
 1080       DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
 
 1081       DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
 
 1082       DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
 
 1083       DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
 
 1084       DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
 
 1086       DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
 
 1087       DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
 
 1088       DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
 
 1089       DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
 
 1090       DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
 
 1091       DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
 
 1109   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
 
 1111     if (op.getPredicate() == arith::CmpFPredicate::ORD) {
 
 1117     if (op.getPredicate() == arith::CmpFPredicate::UNO) {
 
 1134   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
 
 1136     if (op.getPredicate() != arith::CmpFPredicate::ORD &&
 
 1137         op.getPredicate() != arith::CmpFPredicate::UNO)
 
 1143     if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
 
 1144       if (op.getPredicate() == arith::CmpFPredicate::ORD) {
 
 1146         replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
 
 1152       Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
 
 1153       Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
 
 1155       replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
 
 1156       if (op.getPredicate() == arith::CmpFPredicate::ORD)
 
 1157         replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
 
 1170 class AddUIExtendedOpPattern final
 
 1175   matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
 
 1177     Type dstElemTy = adaptor.getLhs().getType();
 
 1179     Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
 
 1182     Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
 
 1184     Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
 
 1188     Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
 
 1189     Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
 
 1191     rewriter.
replaceOp(op, {sumResult, carryResult});
 
 1201 template <
typename ArithMulOp, 
typename SPIRVMulOp>
 
 1206   matchAndRewrite(ArithMulOp op, 
typename ArithMulOp::Adaptor adaptor,
 
 1210         SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
 
 1212     Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
 
 1214     Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
 
 1231   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
 
 1234                                                  adaptor.getTrueValue(),
 
 1235                                                  adaptor.getFalseValue());
 
 1246 template <
typename Op, 
typename SPIRVOp>
 
 1251   matchAndRewrite(
Op op, 
typename Op::Adaptor adaptor,
 
 1253     auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
 
 1254     Type dstType = converter->convertType(op.getType());
 
 1268         SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
 
 1270     if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
 
 1275     Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
 
 1276     Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
 
 1278     Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
 
 1279                                             adaptor.getLhs(), spirvOp);
 
 1280     Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
 
 1281                                             adaptor.getRhs(), select1);
 
 1294 template <
typename Op, 
typename SPIRVOp>
 
 1296   template <
typename TargetOp>
 
 1297   constexpr 
bool shouldInsertNanGuards()
 const {
 
 1298     return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
 
 1304   matchAndRewrite(
Op op, 
typename Op::Adaptor adaptor,
 
 1306     auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
 
 1307     Type dstType = converter->convertType(op.getType());
 
 1322         SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
 
 1324     if (!shouldInsertNanGuards<SPIRVOp>() ||
 
 1325         bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
 
 1330     Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
 
 1331     Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
 
 1333     Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
 
 1334                                             adaptor.getRhs(), spirvOp);
 
 1335     Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
 
 1336                                             adaptor.getLhs(), select1);
 
 1353     ConstantCompositeOpPattern,
 
 1354     ConstantScalarOpPattern,
 
 1355     ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
 
 1356     ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
 
 1357     ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
 
 1361     RemSIOpGLPattern, RemSIOpCLPattern,
 
 1362     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
 
 1363     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
 
 1364     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
 
 1365     ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
 
 1374     ExtUIPattern, ExtUII1Pattern,
 
 1375     ExtSIPattern, ExtSII1Pattern,
 
 1376     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
 
 1377     TruncIPattern, TruncII1Pattern,
 
 1378     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
 
 1379     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
 
 1380     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
 
 1381     TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
 
 1382     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
 
 1383     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
 
 1384     IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
 
 1385     TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
 
 1386     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
 
 1387     CmpIOpBooleanPattern, CmpIOpPattern,
 
 1388     CmpFOpNanNonePattern, CmpFOpPattern,
 
 1389     AddUIExtendedOpPattern,
 
 1390     MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
 
 1391     MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
 
 1394     MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
 
 1395     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
 
 1396     MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
 
 1397     MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
 
 1403     MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
 
 1404     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
 
 1405     MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
 
 1406     MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
 
 1411   >(typeConverter, 
patterns.getContext());
 
 1425 struct ConvertArithToSPIRVPass
 
 1426     : 
public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
 
 1429   void runOnOperation()
 override {
 
 1432     std::unique_ptr<SPIRVConversionTarget> target =
 
 1436     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
 
 1437     options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
 
 1442     target->addLegalOp<UnrealizedConversionCastOp>();
 
 1445     target->addIllegalDialect<arith::ArithDialect>();
 
 1451       signalPassFailure();
 
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Fraction abs(const Fraction &f)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.