21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
34#define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
115 assert(type &&
"Not a valid type");
119 if (
auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
128 if (
auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
134 if (
auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
144 auto getNumBitwidth = [](
Type type) {
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (
auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(
b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
161 return rewriter.notifyMatchFailure(
163 llvm::formatv(
"failed to convert source type '{0}'", srcType));
177template <
typename Op,
typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<
Op>::OpConversionPattern;
182 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter)
const override {
184 assert(adaptor.getOperands().size() <= 3);
185 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
186 Type dstType = converter->convertType(op.getType());
188 return rewriter.notifyMatchFailure(
190 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
193 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
195 dstType != op.getType()) {
196 return op.
emitError(
"bitwidth emulation is not implemented yet on "
197 "unsigned op pattern version");
200 auto overflowFlags = arith::IntegerOverflowFlags::none;
201 if (
auto overflowIface =
202 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
203 if (converter->getTargetEnv().allows(
204 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
205 overflowFlags = overflowIface.getOverflowAttr().getValue();
208 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
209 op, dstType, adaptor.getOperands());
211 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
213 rewriter.getUnitAttr());
215 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
217 rewriter.getUnitAttr());
228struct ConstantCompositeOpPattern final
229 :
public OpConversionPattern<arith::ConstantOp> {
233 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter)
const override {
235 auto srcType = dyn_cast<ShapedType>(constOp.getType());
236 if (!srcType || srcType.getNumElements() == 1)
241 if (!isa<VectorType, RankedTensorType>(srcType))
242 return rewriter.notifyMatchFailure(constOp,
"unsupported ShapedType");
244 Type dstType = getTypeConverter()->convertType(srcType);
251 if (
auto denseElementsAttr =
252 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
253 dstElementsAttr = denseElementsAttr;
254 }
else if (
auto resourceAttr =
255 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
259 return constOp->emitError(
"could not find resource blob");
265 bool detectedSplat =
false;
267 return constOp->emitError(
"resource is not a valid buffer");
272 return constOp->emitError(
"unsupported elements attribute");
275 ShapedType dstAttrType = dstElementsAttr.
getType();
279 if (srcType.getRank() > 1) {
280 if (isa<RankedTensorType>(srcType)) {
281 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
282 srcType.getElementType());
283 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
290 Type srcElemType = srcType.getElementType();
294 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
295 dstElemType = arrayType.getElementType();
297 dstElemType = cast<VectorType>(dstType).getElementType();
301 if (srcElemType != dstElemType) {
303 if (isa<FloatType>(srcElemType)) {
304 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
307 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
308 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
310 isa<IntegerType>(dstElemType)) {
319 elements.push_back(dstAttr);
324 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
326 srcAttr, cast<IntegerType>(dstElemType), rewriter);
329 elements.push_back(dstAttr);
337 if (isa<RankedTensorType>(dstAttrType))
339 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
341 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
346 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
353struct ConstantScalarOpPattern final
354 :
public OpConversionPattern<arith::ConstantOp> {
358 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
359 ConversionPatternRewriter &rewriter)
const override {
360 Type srcType = constOp.getType();
361 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
362 if (shapedType.getNumElements() != 1)
364 srcType = shapedType.getElementType();
369 Attribute cstAttr = constOp.getValue();
370 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
371 cstAttr = elementsAttr.getSplatValue<Attribute>();
373 Type dstType = getTypeConverter()->convertType(srcType);
378 if (isa<FloatType>(srcType)) {
379 auto srcAttr = cast<FloatAttr>(cstAttr);
380 Attribute dstAttr = srcAttr;
384 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
385 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
387 dstType.getIntOrFloatBitWidth() == 8) {
392 }
else if (srcType != dstType) {
398 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
409 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
415 auto srcAttr = cast<IntegerAttr>(cstAttr);
416 IntegerAttr dstAttr =
420 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
436template <
typename SignedAbsOp>
439 assert(
lhs.getType() ==
rhs.getType());
440 assert(
lhs == signOperand ||
rhs == signOperand);
445 Value lhsAbs = SignedAbsOp::create(builder, loc, type,
lhs);
446 Value rhsAbs = SignedAbsOp::create(builder, loc, type,
rhs);
447 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
451 if (
lhs == signOperand)
452 isPositive = spirv::IEqualOp::create(builder, loc,
lhs, lhsAbs);
454 isPositive = spirv::IEqualOp::create(builder, loc,
rhs, rhsAbs);
455 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
456 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
464struct RemSIOpGLPattern final :
public OpConversionPattern<arith::RemSIOp> {
468 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter)
const override {
470 Value
result = emulateSignedRemainder<spirv::CLSAbsOp>(
471 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
472 adaptor.getOperands()[0], rewriter);
473 rewriter.replaceOp(op,
result);
480struct RemSIOpCLPattern final :
public OpConversionPattern<arith::RemSIOp> {
484 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter)
const override {
486 Value
result = emulateSignedRemainder<spirv::GLSAbsOp>(
487 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
488 adaptor.getOperands()[0], rewriter);
489 rewriter.replaceOp(op,
result);
503template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
504struct BitwiseOpPattern final :
public OpConversionPattern<Op> {
505 using OpConversionPattern<
Op>::OpConversionPattern;
508 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
509 ConversionPatternRewriter &rewriter)
const override {
510 assert(adaptor.getOperands().size() == 2);
511 Type dstType = this->getTypeConverter()->convertType(op.getType());
516 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
517 op, dstType, adaptor.getOperands());
519 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
520 op, dstType, adaptor.getOperands());
531struct XOrIOpLogicalPattern final :
public OpConversionPattern<arith::XOrIOp> {
535 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
536 ConversionPatternRewriter &rewriter)
const override {
537 assert(adaptor.getOperands().size() == 2);
542 Type dstType = getTypeConverter()->convertType(op.getType());
546 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
547 adaptor.getOperands());
555struct XOrIOpBooleanPattern final :
public OpConversionPattern<arith::XOrIOp> {
559 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter)
const override {
561 assert(adaptor.getOperands().size() == 2);
566 Type dstType = getTypeConverter()->convertType(op.getType());
570 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
571 op, dstType, adaptor.getOperands());
582struct UIToFPI1Pattern final :
public OpConversionPattern<arith::UIToFPOp> {
586 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
587 ConversionPatternRewriter &rewriter)
const override {
588 Type srcType = adaptor.getOperands().front().getType();
592 Type dstType = getTypeConverter()->convertType(op.getType());
596 Location loc = op.getLoc();
597 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
598 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
599 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
600 op, dstType, adaptor.getOperands().front(), one, zero);
610struct IndexCastIndexI1Pattern final
611 :
public OpConversionPattern<arith::IndexCastOp> {
615 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
616 ConversionPatternRewriter &rewriter)
const override {
620 Type dstType = getTypeConverter()->convertType(op.getType());
624 Location loc = op.getLoc();
626 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
627 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
634struct IndexCastI1IndexPattern final
635 :
public OpConversionPattern<arith::IndexCastOp> {
639 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
640 ConversionPatternRewriter &rewriter)
const override {
644 Type dstType = getTypeConverter()->convertType(op.getType());
648 Location loc = op.getLoc();
649 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
650 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
651 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
663struct ExtSII1Pattern final :
public OpConversionPattern<arith::ExtSIOp> {
667 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
668 ConversionPatternRewriter &rewriter)
const override {
669 Value operand = adaptor.getIn();
673 Location loc = op.getLoc();
674 Type dstType = getTypeConverter()->convertType(op.getType());
679 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
680 unsigned componentBitwidth = intTy.getWidth();
681 allOnes = spirv::ConstantOp::create(
682 rewriter, loc, intTy,
683 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
684 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
685 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
686 allOnes = spirv::ConstantOp::create(
687 rewriter, loc, vectorTy,
688 SplatElementsAttr::get(vectorTy,
689 APInt::getAllOnes(componentBitwidth)));
691 return rewriter.notifyMatchFailure(
692 loc, llvm::formatv(
"unhandled type: {0}", dstType));
695 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
696 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
704struct ExtSIPattern final :
public OpConversionPattern<arith::ExtSIOp> {
708 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
709 ConversionPatternRewriter &rewriter)
const override {
710 Type srcType = adaptor.getIn().getType();
714 Type dstType = getTypeConverter()->convertType(op.getType());
718 if (dstType == srcType) {
726 assert(srcBW < dstBW);
728 rewriter, op.getLoc());
733 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
734 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
738 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
739 op, dstType, shiftLOp, shiftSize);
741 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
742 adaptor.getOperands());
755struct ExtUII1Pattern final :
public OpConversionPattern<arith::ExtUIOp> {
759 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
760 ConversionPatternRewriter &rewriter)
const override {
761 Type srcType = adaptor.getOperands().front().getType();
765 Type dstType = getTypeConverter()->convertType(op.getType());
769 Location loc = op.getLoc();
770 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
771 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
772 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
773 op, dstType, adaptor.getOperands().front(), one, zero);
780struct ExtUIPattern final :
public OpConversionPattern<arith::ExtUIOp> {
784 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
785 ConversionPatternRewriter &rewriter)
const override {
786 Type srcType = adaptor.getIn().getType();
790 Type dstType = getTypeConverter()->convertType(op.getType());
794 if (dstType == srcType) {
802 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
804 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
805 adaptor.getIn(), mask);
807 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
808 adaptor.getOperands());
820struct TruncII1Pattern final :
public OpConversionPattern<arith::TruncIOp> {
824 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
825 ConversionPatternRewriter &rewriter)
const override {
826 Type dstType = getTypeConverter()->convertType(op.getType());
833 Location loc = op.getLoc();
834 auto srcType = adaptor.getOperands().front().getType();
836 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
837 Value maskedSrc = spirv::BitwiseAndOp::create(
838 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
839 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
841 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
842 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
843 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
850struct TruncIPattern final :
public OpConversionPattern<arith::TruncIOp> {
854 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
855 ConversionPatternRewriter &rewriter)
const override {
856 Type srcType = adaptor.getIn().getType();
857 Type dstType = getTypeConverter()->convertType(op.getType());
864 if (dstType == srcType) {
871 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
872 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
873 adaptor.getIn(), mask);
876 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
877 adaptor.getOperands());
887static std::optional<spirv::FPRoundingMode>
888convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
889 switch (roundingMode) {
890 case arith::RoundingMode::downward:
891 return spirv::FPRoundingMode::RTN;
892 case arith::RoundingMode::to_nearest_even:
893 return spirv::FPRoundingMode::RTE;
894 case arith::RoundingMode::toward_zero:
895 return spirv::FPRoundingMode::RTZ;
896 case arith::RoundingMode::upward:
897 return spirv::FPRoundingMode::RTP;
898 case arith::RoundingMode::to_nearest_away:
903 llvm_unreachable(
"Unhandled rounding mode");
907template <
typename Op,
typename SPIRVOp>
908struct TypeCastingOpPattern final :
public OpConversionPattern<Op> {
909 using OpConversionPattern<
Op>::OpConversionPattern;
912 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
913 ConversionPatternRewriter &rewriter)
const override {
914 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
915 Type dstType = this->getTypeConverter()->convertType(op.getType());
922 if (dstType == srcType) {
925 rewriter.replaceOp(op, adaptor.getOperands().front());
928 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
929 if (
auto roundingModeOp =
930 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
931 if (arith::RoundingModeAttr roundingMode =
932 roundingModeOp.getRoundingModeAttr()) {
934 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
935 return rewriter.notifyMatchFailure(
937 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
942 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
943 op, dstType, adaptor.getOperands());
947 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
959class CmpIOpBooleanPattern final :
public OpConversionPattern<arith::CmpIOp> {
964 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
965 ConversionPatternRewriter &rewriter)
const override {
966 Type srcType = op.getLhs().getType();
969 Type dstType = getTypeConverter()->convertType(srcType);
973 switch (op.getPredicate()) {
974 case arith::CmpIPredicate::eq: {
975 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
979 case arith::CmpIPredicate::ne: {
980 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
981 op, adaptor.getLhs(), adaptor.getRhs());
984 case arith::CmpIPredicate::uge:
985 case arith::CmpIPredicate::ugt:
986 case arith::CmpIPredicate::ule:
987 case arith::CmpIPredicate::ult: {
990 Type type = rewriter.getI32Type();
991 if (
auto vectorType = dyn_cast<VectorType>(dstType))
992 type = VectorType::get(vectorType.getShape(), type);
994 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
996 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
998 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1010class CmpIOpPattern final :
public OpConversionPattern<arith::CmpIOp> {
1015 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1016 ConversionPatternRewriter &rewriter)
const override {
1017 Type srcType = op.getLhs().getType();
1020 Type dstType = getTypeConverter()->convertType(srcType);
1024 switch (op.getPredicate()) {
1025#define DISPATCH(cmpPredicate, spirvOp) \
1026 case cmpPredicate: \
1027 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1028 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1029 !hasSameBitwidth(srcType, dstType)) { \
1030 return op.emitError( \
1031 "bitwidth emulation is not implemented yet on unsigned op"); \
1033 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1034 adaptor.getRhs()); \
1037 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1038 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1039 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1040 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1041 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1042 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1043 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1044 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1045 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1046 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1059class CmpFOpPattern final :
public OpConversionPattern<arith::CmpFOp> {
1064 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1065 ConversionPatternRewriter &rewriter)
const override {
1066 switch (op.getPredicate()) {
1067#define DISPATCH(cmpPredicate, spirvOp) \
1068 case cmpPredicate: \
1069 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1070 adaptor.getRhs()); \
1074 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1075 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1076 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1077 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1078 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1079 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1081 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1082 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1083 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1084 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1085 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1086 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1099class CmpFOpNanKernelPattern final :
public OpConversionPattern<arith::CmpFOp> {
1104 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1105 ConversionPatternRewriter &rewriter)
const override {
1106 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1107 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1112 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1113 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1124class CmpFOpNanNonePattern final :
public OpConversionPattern<arith::CmpFOp> {
1129 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1130 ConversionPatternRewriter &rewriter)
const override {
1131 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1132 op.getPredicate() != arith::CmpFPredicate::UNO)
1135 Location loc = op.getLoc();
1138 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1139 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1141 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1144 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1147 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1148 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1150 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1151 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1152 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1155 rewriter.replaceOp(op, replace);
1165class AddUIExtendedOpPattern final
1166 :
public OpConversionPattern<arith::AddUIExtendedOp> {
1170 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1171 ConversionPatternRewriter &rewriter)
const override {
1172 Type dstElemTy = adaptor.getLhs().getType();
1173 Location loc = op->getLoc();
1174 Value
result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1177 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1179 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1183 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1184 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1186 rewriter.replaceOp(op, {sumResult, carryResult});
1196template <
typename ArithMulOp,
typename SPIRVMulOp>
1197class MulIExtendedOpPattern final :
public OpConversionPattern<ArithMulOp> {
1199 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1201 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1202 ConversionPatternRewriter &rewriter)
const override {
1203 Location loc = op->getLoc();
1205 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1207 Value low = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1209 Value high = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1212 rewriter.replaceOp(op, {low, high});
1222class SelectOpPattern final :
public OpConversionPattern<arith::SelectOp> {
1226 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1227 ConversionPatternRewriter &rewriter)
const override {
1228 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1229 adaptor.getTrueValue(),
1230 adaptor.getFalseValue());
1241template <
typename Op,
typename SPIRVOp>
1242class MinimumMaximumFOpPattern final :
public OpConversionPattern<Op> {
1244 using OpConversionPattern<
Op>::OpConversionPattern;
1246 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1247 ConversionPatternRewriter &rewriter)
const override {
1248 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1249 Type dstType = converter->convertType(op.getType());
1261 Location loc = op.
getLoc();
1263 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1265 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1266 rewriter.replaceOp(op, spirvOp);
1270 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1271 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1273 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1274 adaptor.getLhs(), spirvOp);
1275 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1276 adaptor.getRhs(), select1);
1278 rewriter.replaceOp(op, select2);
1289template <
typename Op,
typename SPIRVOp>
1290class MinNumMaxNumFOpPattern final :
public OpConversionPattern<Op> {
1291 template <
typename TargetOp>
1292 constexpr bool shouldInsertNanGuards()
const {
1293 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1297 using OpConversionPattern<
Op>::OpConversionPattern;
1299 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1300 ConversionPatternRewriter &rewriter)
const override {
1301 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1302 Type dstType = converter->convertType(op.getType());
1317 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1319 if (!shouldInsertNanGuards<SPIRVOp>() ||
1320 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1321 rewriter.replaceOp(op, spirvOp);
1325 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1326 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1328 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1329 adaptor.getRhs(), spirvOp);
1330 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1331 adaptor.getLhs(), select1);
1333 rewriter.replaceOp(op, select2);
1348 ConstantCompositeOpPattern,
1349 ConstantScalarOpPattern,
1350 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1351 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1352 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1356 RemSIOpGLPattern, RemSIOpCLPattern,
1357 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1358 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1359 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1360 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1369 ExtUIPattern, ExtUII1Pattern,
1370 ExtSIPattern, ExtSII1Pattern,
1371 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1372 TruncIPattern, TruncII1Pattern,
1373 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1374 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1375 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1376 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1377 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1378 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1379 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1380 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1381 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1382 CmpIOpBooleanPattern, CmpIOpPattern,
1383 CmpFOpNanNonePattern, CmpFOpPattern,
1384 AddUIExtendedOpPattern,
1385 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1386 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1389 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1390 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1391 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1392 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1398 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1399 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1400 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1401 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1406 >(typeConverter,
patterns.getContext());
1420struct ConvertArithToSPIRVPass
1421 :
public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1424 void runOnOperation()
override {
1427 std::unique_ptr<SPIRVConversionTarget>
target =
1431 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1432 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1437 target->addLegalOp<UnrealizedConversionCastOp>();
1440 target->addIllegalDialect<arith::ArithDialect>();
1445 if (failed(applyPartialConversion(op, *
target, std::move(
patterns))))
1446 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.