21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
34#define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
115 assert(type &&
"Not a valid type");
119 if (
auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
128 if (
auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
134 if (
auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
144 auto getNumBitwidth = [](
Type type) {
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (
auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(
b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
161 return rewriter.notifyMatchFailure(
163 llvm::formatv(
"failed to convert source type '{0}'", srcType));
177template <
typename Op,
typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<
Op>::OpConversionPattern;
182 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter)
const override {
184 assert(adaptor.getOperands().size() <= 3);
185 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
186 Type dstType = converter->convertType(op.getType());
188 return rewriter.notifyMatchFailure(
190 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
193 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
195 dstType != op.getType()) {
196 return op.
emitError(
"bitwidth emulation is not implemented yet on "
197 "unsigned op pattern version");
200 auto overflowFlags = arith::IntegerOverflowFlags::none;
201 if (
auto overflowIface =
202 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
203 if (converter->getTargetEnv().allows(
204 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
205 overflowFlags = overflowIface.getOverflowAttr().getValue();
208 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
209 op, dstType, adaptor.getOperands());
211 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
213 rewriter.getUnitAttr());
215 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
217 rewriter.getUnitAttr());
228struct ConstantCompositeOpPattern final
229 :
public OpConversionPattern<arith::ConstantOp> {
233 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter)
const override {
235 auto srcType = dyn_cast<ShapedType>(constOp.getType());
236 if (!srcType || srcType.getNumElements() == 1)
241 if (!isa<VectorType, RankedTensorType>(srcType))
242 return rewriter.notifyMatchFailure(constOp,
"unsupported ShapedType");
244 Type dstType = getTypeConverter()->convertType(srcType);
251 if (
auto denseElementsAttr =
252 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
253 dstElementsAttr = denseElementsAttr;
254 }
else if (
auto resourceAttr =
255 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
259 return constOp->emitError(
"could not find resource blob");
266 return constOp->emitError(
"resource is not a valid buffer");
271 return constOp->emitError(
"unsupported elements attribute");
274 ShapedType dstAttrType = dstElementsAttr.
getType();
278 if (srcType.getRank() > 1) {
279 if (isa<RankedTensorType>(srcType)) {
280 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
281 srcType.getElementType());
282 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
289 Type srcElemType = srcType.getElementType();
293 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
294 dstElemType = arrayType.getElementType();
296 dstElemType = cast<VectorType>(dstType).getElementType();
300 if (srcElemType != dstElemType) {
302 if (isa<FloatType>(srcElemType)) {
303 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
306 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
307 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
309 isa<IntegerType>(dstElemType)) {
318 elements.push_back(dstAttr);
323 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
325 srcAttr, cast<IntegerType>(dstElemType), rewriter);
328 elements.push_back(dstAttr);
336 if (isa<RankedTensorType>(dstAttrType))
338 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
340 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
345 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
352struct ConstantScalarOpPattern final
353 :
public OpConversionPattern<arith::ConstantOp> {
357 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter)
const override {
359 Type srcType = constOp.getType();
360 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
361 if (shapedType.getNumElements() != 1)
363 srcType = shapedType.getElementType();
368 Attribute cstAttr = constOp.getValue();
369 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
370 cstAttr = elementsAttr.getSplatValue<Attribute>();
372 Type dstType = getTypeConverter()->convertType(srcType);
377 if (isa<FloatType>(srcType)) {
378 auto srcAttr = cast<FloatAttr>(cstAttr);
379 Attribute dstAttr = srcAttr;
383 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
384 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
386 dstType.getIntOrFloatBitWidth() == 8) {
391 }
else if (srcType != dstType) {
397 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
408 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
414 auto srcAttr = cast<IntegerAttr>(cstAttr);
415 IntegerAttr dstAttr =
419 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
435template <
typename SignedAbsOp>
438 assert(
lhs.getType() ==
rhs.getType());
439 assert(
lhs == signOperand ||
rhs == signOperand);
444 Value lhsAbs = SignedAbsOp::create(builder, loc, type,
lhs);
445 Value rhsAbs = SignedAbsOp::create(builder, loc, type,
rhs);
446 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
450 if (
lhs == signOperand)
451 isPositive = spirv::IEqualOp::create(builder, loc,
lhs, lhsAbs);
453 isPositive = spirv::IEqualOp::create(builder, loc,
rhs, rhsAbs);
454 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
455 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
463struct RemSIOpGLPattern final :
public OpConversionPattern<arith::RemSIOp> {
467 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter)
const override {
469 Value
result = emulateSignedRemainder<spirv::CLSAbsOp>(
470 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
471 adaptor.getOperands()[0], rewriter);
472 rewriter.replaceOp(op,
result);
479struct RemSIOpCLPattern final :
public OpConversionPattern<arith::RemSIOp> {
483 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter)
const override {
485 Value
result = emulateSignedRemainder<spirv::GLSAbsOp>(
486 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
487 adaptor.getOperands()[0], rewriter);
488 rewriter.replaceOp(op,
result);
502template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
503struct BitwiseOpPattern final :
public OpConversionPattern<Op> {
504 using OpConversionPattern<
Op>::OpConversionPattern;
507 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
508 ConversionPatternRewriter &rewriter)
const override {
509 assert(adaptor.getOperands().size() == 2);
510 Type dstType = this->getTypeConverter()->convertType(op.getType());
515 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
516 op, dstType, adaptor.getOperands());
518 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
519 op, dstType, adaptor.getOperands());
530struct XOrIOpLogicalPattern final :
public OpConversionPattern<arith::XOrIOp> {
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535 ConversionPatternRewriter &rewriter)
const override {
536 assert(adaptor.getOperands().size() == 2);
541 Type dstType = getTypeConverter()->convertType(op.getType());
545 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
546 adaptor.getOperands());
554struct XOrIOpBooleanPattern final :
public OpConversionPattern<arith::XOrIOp> {
558 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter)
const override {
560 assert(adaptor.getOperands().size() == 2);
565 Type dstType = getTypeConverter()->convertType(op.getType());
569 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
570 op, dstType, adaptor.getOperands());
581struct UIToFPI1Pattern final :
public OpConversionPattern<arith::UIToFPOp> {
585 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter)
const override {
587 Type srcType = adaptor.getOperands().front().getType();
591 Type dstType = getTypeConverter()->convertType(op.getType());
595 Location loc = op.getLoc();
596 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
597 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
598 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
599 op, dstType, adaptor.getOperands().front(), one, zero);
611template <
typename ArithOp,
typename SPIRVOp,
bool IsSigned>
612struct IntToFPPattern final :
public OpConversionPattern<ArithOp> {
613 using OpConversionPattern<ArithOp>::OpConversionPattern;
616 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
617 ConversionPatternRewriter &rewriter)
const override {
618 Type srcType = adaptor.getOperands().front().getType();
622 Type dstType = this->getTypeConverter()->convertType(op.getType());
627 unsigned originalBitwidth =
629 unsigned convertedBitwidth =
632 if (originalBitwidth >= convertedBitwidth) {
633 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
638 Location loc = op.getLoc();
640 if constexpr (IsSigned) {
642 unsigned shiftAmount = convertedBitwidth - originalBitwidth;
645 Value shifted = spirv::ShiftLeftLogicalOp::create(
646 rewriter, loc, srcType, adaptor.getIn(), shiftSize);
647 cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
652 srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter,
654 cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
655 adaptor.getIn(), mask);
657 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned);
667struct IndexCastIndexI1Pattern final
668 :
public OpConversionPattern<arith::IndexCastOp> {
672 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
673 ConversionPatternRewriter &rewriter)
const override {
677 Type dstType = getTypeConverter()->convertType(op.getType());
681 Location loc = op.getLoc();
683 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
684 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
691struct IndexCastI1IndexPattern final
692 :
public OpConversionPattern<arith::IndexCastOp> {
696 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
697 ConversionPatternRewriter &rewriter)
const override {
701 Type dstType = getTypeConverter()->convertType(op.getType());
705 Location loc = op.getLoc();
706 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
707 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
708 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
720struct ExtSII1Pattern final :
public OpConversionPattern<arith::ExtSIOp> {
724 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
725 ConversionPatternRewriter &rewriter)
const override {
726 Value operand = adaptor.getIn();
730 Location loc = op.getLoc();
731 Type dstType = getTypeConverter()->convertType(op.getType());
736 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
737 unsigned componentBitwidth = intTy.getWidth();
738 allOnes = spirv::ConstantOp::create(
739 rewriter, loc, intTy,
740 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
741 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
742 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
743 allOnes = spirv::ConstantOp::create(
744 rewriter, loc, vectorTy,
746 APInt::getAllOnes(componentBitwidth)));
748 return rewriter.notifyMatchFailure(
749 loc, llvm::formatv(
"unhandled type: {0}", dstType));
752 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
753 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
761struct ExtSIPattern final :
public OpConversionPattern<arith::ExtSIOp> {
765 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
766 ConversionPatternRewriter &rewriter)
const override {
767 Type srcType = adaptor.getIn().getType();
771 Type dstType = getTypeConverter()->convertType(op.getType());
775 if (dstType == srcType) {
783 assert(srcBW < dstBW);
785 rewriter, op.getLoc());
787 return rewriter.notifyMatchFailure(op,
"unsupported type for shift");
792 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
793 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
797 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
798 op, dstType, shiftLOp, shiftSize);
800 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
801 adaptor.getOperands());
814struct ExtUII1Pattern final :
public OpConversionPattern<arith::ExtUIOp> {
818 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
819 ConversionPatternRewriter &rewriter)
const override {
820 Type srcType = adaptor.getOperands().front().getType();
824 Type dstType = getTypeConverter()->convertType(op.getType());
828 Location loc = op.getLoc();
829 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
830 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
831 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
832 op, dstType, adaptor.getOperands().front(), one, zero);
839struct ExtUIPattern final :
public OpConversionPattern<arith::ExtUIOp> {
843 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
844 ConversionPatternRewriter &rewriter)
const override {
845 Type srcType = adaptor.getIn().getType();
849 Type dstType = getTypeConverter()->convertType(op.getType());
853 if (dstType == srcType) {
861 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
864 return rewriter.notifyMatchFailure(op,
"unsupported type for mask");
865 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
866 adaptor.getIn(), mask);
868 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
869 adaptor.getOperands());
881struct TruncII1Pattern final :
public OpConversionPattern<arith::TruncIOp> {
885 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
886 ConversionPatternRewriter &rewriter)
const override {
887 Type dstType = getTypeConverter()->convertType(op.getType());
894 Location loc = op.getLoc();
895 auto srcType = adaptor.getOperands().front().getType();
897 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
898 Value maskedSrc = spirv::BitwiseAndOp::create(
899 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
900 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
902 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
903 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
904 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
911struct TruncIPattern final :
public OpConversionPattern<arith::TruncIOp> {
915 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
916 ConversionPatternRewriter &rewriter)
const override {
917 Type srcType = adaptor.getIn().getType();
918 Type dstType = getTypeConverter()->convertType(op.getType());
925 if (dstType == srcType) {
932 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
934 return rewriter.notifyMatchFailure(op,
"unsupported type for mask");
935 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
936 adaptor.getIn(), mask);
939 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
940 adaptor.getOperands());
950static std::optional<spirv::FPRoundingMode>
951convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
952 switch (roundingMode) {
953 case arith::RoundingMode::downward:
954 return spirv::FPRoundingMode::RTN;
955 case arith::RoundingMode::to_nearest_even:
956 return spirv::FPRoundingMode::RTE;
957 case arith::RoundingMode::toward_zero:
958 return spirv::FPRoundingMode::RTZ;
959 case arith::RoundingMode::upward:
960 return spirv::FPRoundingMode::RTP;
961 case arith::RoundingMode::to_nearest_away:
966 llvm_unreachable(
"Unhandled rounding mode");
970template <
typename Op,
typename SPIRVOp>
971struct TypeCastingOpPattern final :
public OpConversionPattern<Op> {
972 using OpConversionPattern<
Op>::OpConversionPattern;
975 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
976 ConversionPatternRewriter &rewriter)
const override {
977 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
978 Type dstType = this->getTypeConverter()->convertType(op.getType());
985 if (dstType == srcType) {
988 rewriter.replaceOp(op, adaptor.getOperands().front());
991 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
992 if (
auto roundingModeOp =
993 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
994 if (arith::RoundingModeAttr roundingMode =
995 roundingModeOp.getRoundingModeAttr()) {
997 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
998 return rewriter.notifyMatchFailure(
1000 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
1005 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
1006 op, dstType, adaptor.getOperands());
1010 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
1022class CmpIOpBooleanPattern final :
public OpConversionPattern<arith::CmpIOp> {
1027 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1028 ConversionPatternRewriter &rewriter)
const override {
1029 Type srcType = op.getLhs().getType();
1032 Type dstType = getTypeConverter()->convertType(srcType);
1036 switch (op.getPredicate()) {
1037 case arith::CmpIPredicate::eq: {
1038 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
1042 case arith::CmpIPredicate::ne: {
1043 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
1044 op, adaptor.getLhs(), adaptor.getRhs());
1047 case arith::CmpIPredicate::uge:
1048 case arith::CmpIPredicate::ugt:
1049 case arith::CmpIPredicate::ule:
1050 case arith::CmpIPredicate::ult: {
1053 Type type = rewriter.getI32Type();
1054 if (
auto vectorType = dyn_cast<VectorType>(dstType))
1055 type = VectorType::get(vectorType.getShape(), type);
1057 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1059 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1061 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1073class CmpIOpPattern final :
public OpConversionPattern<arith::CmpIOp> {
1078 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1079 ConversionPatternRewriter &rewriter)
const override {
1080 Type srcType = op.getLhs().getType();
1083 Type dstType = getTypeConverter()->convertType(srcType);
1087 switch (op.getPredicate()) {
1088#define DISPATCH(cmpPredicate, spirvOp) \
1089 case cmpPredicate: \
1090 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1091 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1092 !hasSameBitwidth(srcType, dstType)) { \
1093 return op.emitError( \
1094 "bitwidth emulation is not implemented yet on unsigned op"); \
1096 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1097 adaptor.getRhs()); \
1100 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1101 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1102 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1103 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1104 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1105 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1106 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1107 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1108 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1109 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1122class CmpFOpPattern final :
public OpConversionPattern<arith::CmpFOp> {
1127 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1128 ConversionPatternRewriter &rewriter)
const override {
1129 switch (op.getPredicate()) {
1130#define DISPATCH(cmpPredicate, spirvOp) \
1131 case cmpPredicate: \
1132 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1133 adaptor.getRhs()); \
1137 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1138 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1139 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1140 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1141 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1142 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1144 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1145 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1146 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1147 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1148 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1149 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1162class CmpFOpNanKernelPattern final :
public OpConversionPattern<arith::CmpFOp> {
1167 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1168 ConversionPatternRewriter &rewriter)
const override {
1169 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1170 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1175 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1176 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1187class CmpFOpNanNonePattern final :
public OpConversionPattern<arith::CmpFOp> {
1192 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1193 ConversionPatternRewriter &rewriter)
const override {
1194 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1195 op.getPredicate() != arith::CmpFPredicate::UNO)
1198 Location loc = op.getLoc();
1201 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1202 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1204 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1207 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1210 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1211 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1213 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1214 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1215 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1218 rewriter.replaceOp(op, replace);
1228class AddUIExtendedOpPattern final
1229 :
public OpConversionPattern<arith::AddUIExtendedOp> {
1233 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1234 ConversionPatternRewriter &rewriter)
const override {
1235 Type dstElemTy = adaptor.getLhs().getType();
1236 Location loc = op->getLoc();
1237 Value
result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1240 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1242 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1246 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1247 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1249 rewriter.replaceOp(op, {sumResult, carryResult});
1259template <
typename ArithMulOp,
typename SPIRVMulOp>
1260class MulIExtendedOpPattern final :
public OpConversionPattern<ArithMulOp> {
1262 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1264 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1265 ConversionPatternRewriter &rewriter)
const override {
1266 Location loc = op->getLoc();
1268 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1270 Value low = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1272 Value high = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1275 rewriter.replaceOp(op, {low, high});
1285class SelectOpPattern final :
public OpConversionPattern<arith::SelectOp> {
1289 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1290 ConversionPatternRewriter &rewriter)
const override {
1291 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1292 adaptor.getTrueValue(),
1293 adaptor.getFalseValue());
1304template <
typename Op,
typename SPIRVOp>
1305class MinimumMaximumFOpPattern final :
public OpConversionPattern<Op> {
1307 using OpConversionPattern<
Op>::OpConversionPattern;
1309 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1310 ConversionPatternRewriter &rewriter)
const override {
1311 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1312 Type dstType = converter->convertType(op.getType());
1324 Location loc = op.
getLoc();
1326 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1328 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1329 rewriter.replaceOp(op, spirvOp);
1333 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1334 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1336 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1337 adaptor.getLhs(), spirvOp);
1338 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1339 adaptor.getRhs(), select1);
1341 rewriter.replaceOp(op, select2);
1352template <
typename Op,
typename SPIRVOp>
1353class MinNumMaxNumFOpPattern final :
public OpConversionPattern<Op> {
1354 template <
typename TargetOp>
1355 constexpr bool shouldInsertNanGuards()
const {
1356 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1360 using OpConversionPattern<
Op>::OpConversionPattern;
1362 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1363 ConversionPatternRewriter &rewriter)
const override {
1364 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1365 Type dstType = converter->convertType(op.getType());
1380 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1382 if (!shouldInsertNanGuards<SPIRVOp>() ||
1383 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1384 rewriter.replaceOp(op, spirvOp);
1388 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1389 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1391 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1392 adaptor.getRhs(), spirvOp);
1393 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1394 adaptor.getLhs(), select1);
1396 rewriter.replaceOp(op, select2);
1411 ConstantCompositeOpPattern,
1412 ConstantScalarOpPattern,
1413 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1414 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1415 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1419 RemSIOpGLPattern, RemSIOpCLPattern,
1420 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1421 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1422 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1423 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1432 ExtUIPattern, ExtUII1Pattern,
1433 ExtSIPattern, ExtSII1Pattern,
1434 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1435 TruncIPattern, TruncII1Pattern,
1436 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1437 IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
1439 IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
1440 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1441 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1442 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1443 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1444 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1445 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1446 CmpIOpBooleanPattern, CmpIOpPattern,
1447 CmpFOpNanNonePattern, CmpFOpPattern,
1448 AddUIExtendedOpPattern,
1449 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1450 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1453 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1454 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1455 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1456 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1462 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1463 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1464 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1465 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1475 patterns.
add<CmpFOpNanKernelPattern>(typeConverter, patterns.
getContext(),
1484struct ConvertArithToSPIRVPass
1488 void runOnOperation()
override {
1491 std::unique_ptr<SPIRVConversionTarget>
target =
1495 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1496 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1501 target->addLegalOp<UnrealizedConversionCastOp>();
1504 target->addIllegalDialect<arith::ArithDialect>();
1509 if (failed(applyPartialConversion(op, *
target, std::move(patterns))))
1510 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector types a and b have the same total bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ConvertArithToSPIRVPassBase Base
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.