21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
34#define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
115 assert(type &&
"Not a valid type");
119 if (
auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
128 if (
auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
134 if (
auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
144 auto getNumBitwidth = [](
Type type) {
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (
auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(
b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
161 return rewriter.notifyMatchFailure(
163 llvm::formatv(
"failed to convert source type '{0}'", srcType));
177template <
typename Op,
typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<
Op>::OpConversionPattern;
182 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter)
const override {
184 assert(adaptor.getOperands().size() <= 3);
185 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
186 Type dstType = converter->convertType(op.getType());
188 return rewriter.notifyMatchFailure(
190 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
193 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
195 dstType != op.getType()) {
196 return op.
emitError(
"bitwidth emulation is not implemented yet on "
197 "unsigned op pattern version");
200 auto overflowFlags = arith::IntegerOverflowFlags::none;
201 if (
auto overflowIface =
202 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
203 if (converter->getTargetEnv().allows(
204 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
205 overflowFlags = overflowIface.getOverflowAttr().getValue();
208 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
209 op, dstType, adaptor.getOperands());
211 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
213 rewriter.getUnitAttr());
215 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
217 rewriter.getUnitAttr());
228struct ConstantCompositeOpPattern final
229 :
public OpConversionPattern<arith::ConstantOp> {
233 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter)
const override {
235 auto srcType = dyn_cast<ShapedType>(constOp.getType());
236 if (!srcType || srcType.getNumElements() == 1)
241 if (!isa<VectorType, RankedTensorType>(srcType))
242 return rewriter.notifyMatchFailure(constOp,
"unsupported ShapedType");
244 Type dstType = getTypeConverter()->convertType(srcType);
251 if (
auto denseElementsAttr =
252 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
253 dstElementsAttr = denseElementsAttr;
254 }
else if (
auto resourceAttr =
255 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
259 return constOp->emitError(
"could not find resource blob");
266 return constOp->emitError(
"resource is not a valid buffer");
271 return constOp->emitError(
"unsupported elements attribute");
274 ShapedType dstAttrType = dstElementsAttr.
getType();
278 if (srcType.getRank() > 1) {
279 if (isa<RankedTensorType>(srcType)) {
280 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
281 srcType.getElementType());
282 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
289 Type srcElemType = srcType.getElementType();
293 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
294 dstElemType = arrayType.getElementType();
296 dstElemType = cast<VectorType>(dstType).getElementType();
300 if (srcElemType != dstElemType) {
302 if (isa<FloatType>(srcElemType)) {
303 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
306 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
307 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
309 isa<IntegerType>(dstElemType)) {
318 elements.push_back(dstAttr);
323 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
325 srcAttr, cast<IntegerType>(dstElemType), rewriter);
328 elements.push_back(dstAttr);
336 if (isa<RankedTensorType>(dstAttrType))
338 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
340 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
345 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
352struct ConstantScalarOpPattern final
353 :
public OpConversionPattern<arith::ConstantOp> {
357 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter)
const override {
359 Type srcType = constOp.getType();
360 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
361 if (shapedType.getNumElements() != 1)
363 srcType = shapedType.getElementType();
368 Attribute cstAttr = constOp.getValue();
369 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
370 cstAttr = elementsAttr.getSplatValue<Attribute>();
372 Type dstType = getTypeConverter()->convertType(srcType);
377 if (isa<FloatType>(srcType)) {
378 auto srcAttr = cast<FloatAttr>(cstAttr);
379 Attribute dstAttr = srcAttr;
383 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
384 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
386 dstType.getIntOrFloatBitWidth() == 8) {
391 }
else if (srcType != dstType) {
397 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
408 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
414 auto srcAttr = cast<IntegerAttr>(cstAttr);
415 IntegerAttr dstAttr =
419 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
435template <
typename SignedAbsOp>
438 assert(
lhs.getType() ==
rhs.getType());
439 assert(
lhs == signOperand ||
rhs == signOperand);
444 Value lhsAbs = SignedAbsOp::create(builder, loc, type,
lhs);
445 Value rhsAbs = SignedAbsOp::create(builder, loc, type,
rhs);
446 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
450 if (
lhs == signOperand)
451 isPositive = spirv::IEqualOp::create(builder, loc,
lhs, lhsAbs);
453 isPositive = spirv::IEqualOp::create(builder, loc,
rhs, rhsAbs);
454 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
455 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
463struct RemSIOpGLPattern final :
public OpConversionPattern<arith::RemSIOp> {
467 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter)
const override {
469 Value
result = emulateSignedRemainder<spirv::CLSAbsOp>(
470 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
471 adaptor.getOperands()[0], rewriter);
472 rewriter.replaceOp(op,
result);
479struct RemSIOpCLPattern final :
public OpConversionPattern<arith::RemSIOp> {
483 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter)
const override {
485 Value
result = emulateSignedRemainder<spirv::GLSAbsOp>(
486 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
487 adaptor.getOperands()[0], rewriter);
488 rewriter.replaceOp(op,
result);
502template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
503struct BitwiseOpPattern final :
public OpConversionPattern<Op> {
504 using OpConversionPattern<
Op>::OpConversionPattern;
507 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
508 ConversionPatternRewriter &rewriter)
const override {
509 assert(adaptor.getOperands().size() == 2);
510 Type dstType = this->getTypeConverter()->convertType(op.getType());
515 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
516 op, dstType, adaptor.getOperands());
518 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
519 op, dstType, adaptor.getOperands());
530struct XOrIOpLogicalPattern final :
public OpConversionPattern<arith::XOrIOp> {
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535 ConversionPatternRewriter &rewriter)
const override {
536 assert(adaptor.getOperands().size() == 2);
541 Type dstType = getTypeConverter()->convertType(op.getType());
545 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
546 adaptor.getOperands());
554struct XOrIOpBooleanPattern final :
public OpConversionPattern<arith::XOrIOp> {
558 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter)
const override {
560 assert(adaptor.getOperands().size() == 2);
565 Type dstType = getTypeConverter()->convertType(op.getType());
569 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
570 op, dstType, adaptor.getOperands());
581struct UIToFPI1Pattern final :
public OpConversionPattern<arith::UIToFPOp> {
585 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter)
const override {
587 Type srcType = adaptor.getOperands().front().getType();
591 Type dstType = getTypeConverter()->convertType(op.getType());
595 Location loc = op.getLoc();
596 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
597 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
598 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
599 op, dstType, adaptor.getOperands().front(), one, zero);
609struct IndexCastIndexI1Pattern final
610 :
public OpConversionPattern<arith::IndexCastOp> {
614 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
615 ConversionPatternRewriter &rewriter)
const override {
619 Type dstType = getTypeConverter()->convertType(op.getType());
623 Location loc = op.getLoc();
625 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
626 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
633struct IndexCastI1IndexPattern final
634 :
public OpConversionPattern<arith::IndexCastOp> {
638 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
639 ConversionPatternRewriter &rewriter)
const override {
643 Type dstType = getTypeConverter()->convertType(op.getType());
647 Location loc = op.getLoc();
648 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
649 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
650 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
662struct ExtSII1Pattern final :
public OpConversionPattern<arith::ExtSIOp> {
666 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter)
const override {
668 Value operand = adaptor.getIn();
672 Location loc = op.getLoc();
673 Type dstType = getTypeConverter()->convertType(op.getType());
678 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
679 unsigned componentBitwidth = intTy.getWidth();
680 allOnes = spirv::ConstantOp::create(
681 rewriter, loc, intTy,
682 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
683 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
684 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
685 allOnes = spirv::ConstantOp::create(
686 rewriter, loc, vectorTy,
687 SplatElementsAttr::get(vectorTy,
688 APInt::getAllOnes(componentBitwidth)));
690 return rewriter.notifyMatchFailure(
691 loc, llvm::formatv(
"unhandled type: {0}", dstType));
694 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
695 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
703struct ExtSIPattern final :
public OpConversionPattern<arith::ExtSIOp> {
707 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
708 ConversionPatternRewriter &rewriter)
const override {
709 Type srcType = adaptor.getIn().getType();
713 Type dstType = getTypeConverter()->convertType(op.getType());
717 if (dstType == srcType) {
725 assert(srcBW < dstBW);
727 rewriter, op.getLoc());
732 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
733 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
737 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
738 op, dstType, shiftLOp, shiftSize);
740 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
741 adaptor.getOperands());
754struct ExtUII1Pattern final :
public OpConversionPattern<arith::ExtUIOp> {
758 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
759 ConversionPatternRewriter &rewriter)
const override {
760 Type srcType = adaptor.getOperands().front().getType();
764 Type dstType = getTypeConverter()->convertType(op.getType());
769 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
770 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
771 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
772 op, dstType, adaptor.getOperands().front(), one, zero);
779struct ExtUIPattern final :
public OpConversionPattern<arith::ExtUIOp> {
783 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
784 ConversionPatternRewriter &rewriter)
const override {
785 Type srcType = adaptor.getIn().getType();
789 Type dstType = getTypeConverter()->convertType(op.getType());
793 if (dstType == srcType) {
801 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
803 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
804 adaptor.getIn(), mask);
806 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
807 adaptor.getOperands());
819struct TruncII1Pattern final :
public OpConversionPattern<arith::TruncIOp> {
823 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
824 ConversionPatternRewriter &rewriter)
const override {
825 Type dstType = getTypeConverter()->convertType(op.getType());
832 Location loc = op.getLoc();
833 auto srcType = adaptor.getOperands().front().getType();
835 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
836 Value maskedSrc = spirv::BitwiseAndOp::create(
837 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
838 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
840 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
841 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
842 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
849struct TruncIPattern final :
public OpConversionPattern<arith::TruncIOp> {
853 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
854 ConversionPatternRewriter &rewriter)
const override {
855 Type srcType = adaptor.getIn().getType();
856 Type dstType = getTypeConverter()->convertType(op.getType());
863 if (dstType == srcType) {
870 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
871 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
872 adaptor.getIn(), mask);
875 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
876 adaptor.getOperands());
886static std::optional<spirv::FPRoundingMode>
887convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
888 switch (roundingMode) {
889 case arith::RoundingMode::downward:
890 return spirv::FPRoundingMode::RTN;
891 case arith::RoundingMode::to_nearest_even:
892 return spirv::FPRoundingMode::RTE;
893 case arith::RoundingMode::toward_zero:
894 return spirv::FPRoundingMode::RTZ;
895 case arith::RoundingMode::upward:
896 return spirv::FPRoundingMode::RTP;
897 case arith::RoundingMode::to_nearest_away:
902 llvm_unreachable(
"Unhandled rounding mode");
906template <
typename Op,
typename SPIRVOp>
907struct TypeCastingOpPattern final :
public OpConversionPattern<Op> {
908 using OpConversionPattern<
Op>::OpConversionPattern;
911 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
912 ConversionPatternRewriter &rewriter)
const override {
913 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
914 Type dstType = this->getTypeConverter()->convertType(op.getType());
921 if (dstType == srcType) {
924 rewriter.replaceOp(op, adaptor.getOperands().front());
927 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
928 if (
auto roundingModeOp =
929 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
930 if (arith::RoundingModeAttr roundingMode =
931 roundingModeOp.getRoundingModeAttr()) {
933 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
934 return rewriter.notifyMatchFailure(
936 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
941 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
942 op, dstType, adaptor.getOperands());
946 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
958class CmpIOpBooleanPattern final :
public OpConversionPattern<arith::CmpIOp> {
963 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
964 ConversionPatternRewriter &rewriter)
const override {
965 Type srcType = op.getLhs().getType();
968 Type dstType = getTypeConverter()->convertType(srcType);
972 switch (op.getPredicate()) {
973 case arith::CmpIPredicate::eq: {
974 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
978 case arith::CmpIPredicate::ne: {
979 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
980 op, adaptor.getLhs(), adaptor.getRhs());
983 case arith::CmpIPredicate::uge:
984 case arith::CmpIPredicate::ugt:
985 case arith::CmpIPredicate::ule:
986 case arith::CmpIPredicate::ult: {
989 Type type = rewriter.getI32Type();
990 if (
auto vectorType = dyn_cast<VectorType>(dstType))
991 type = VectorType::get(vectorType.getShape(), type);
993 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
995 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
997 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1009class CmpIOpPattern final :
public OpConversionPattern<arith::CmpIOp> {
1014 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1015 ConversionPatternRewriter &rewriter)
const override {
1016 Type srcType = op.getLhs().getType();
1019 Type dstType = getTypeConverter()->convertType(srcType);
1023 switch (op.getPredicate()) {
1024#define DISPATCH(cmpPredicate, spirvOp) \
1025 case cmpPredicate: \
1026 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1027 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1028 !hasSameBitwidth(srcType, dstType)) { \
1029 return op.emitError( \
1030 "bitwidth emulation is not implemented yet on unsigned op"); \
1032 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1033 adaptor.getRhs()); \
1036 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1037 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1038 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1039 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1040 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1041 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1042 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1043 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1044 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1045 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1058class CmpFOpPattern final :
public OpConversionPattern<arith::CmpFOp> {
1063 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1064 ConversionPatternRewriter &rewriter)
const override {
1065 switch (op.getPredicate()) {
1066#define DISPATCH(cmpPredicate, spirvOp) \
1067 case cmpPredicate: \
1068 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1069 adaptor.getRhs()); \
1073 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1074 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1075 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1076 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1077 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1078 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1080 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1081 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1082 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1083 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1084 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1085 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1098class CmpFOpNanKernelPattern final :
public OpConversionPattern<arith::CmpFOp> {
1103 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1104 ConversionPatternRewriter &rewriter)
const override {
1105 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1106 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1111 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1112 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1123class CmpFOpNanNonePattern final :
public OpConversionPattern<arith::CmpFOp> {
1128 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1129 ConversionPatternRewriter &rewriter)
const override {
1130 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1131 op.getPredicate() != arith::CmpFPredicate::UNO)
1134 Location loc = op.getLoc();
1137 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1138 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1140 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1143 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1146 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1147 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1149 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1150 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1151 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1154 rewriter.replaceOp(op, replace);
1164class AddUIExtendedOpPattern final
1165 :
public OpConversionPattern<arith::AddUIExtendedOp> {
1169 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1170 ConversionPatternRewriter &rewriter)
const override {
1171 Type dstElemTy = adaptor.getLhs().getType();
1172 Location loc = op->getLoc();
1173 Value
result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1176 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1178 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1182 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1183 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1185 rewriter.replaceOp(op, {sumResult, carryResult});
1195template <
typename ArithMulOp,
typename SPIRVMulOp>
1196class MulIExtendedOpPattern final :
public OpConversionPattern<ArithMulOp> {
1198 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1200 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1201 ConversionPatternRewriter &rewriter)
const override {
1202 Location loc = op->getLoc();
1204 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1206 Value low = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1208 Value high = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1211 rewriter.replaceOp(op, {low, high});
1221class SelectOpPattern final :
public OpConversionPattern<arith::SelectOp> {
1225 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1226 ConversionPatternRewriter &rewriter)
const override {
1227 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1228 adaptor.getTrueValue(),
1229 adaptor.getFalseValue());
1240template <
typename Op,
typename SPIRVOp>
1241class MinimumMaximumFOpPattern final :
public OpConversionPattern<Op> {
1243 using OpConversionPattern<
Op>::OpConversionPattern;
1245 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1246 ConversionPatternRewriter &rewriter)
const override {
1247 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1248 Type dstType = converter->convertType(op.getType());
1260 Location loc = op.
getLoc();
1262 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1264 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1265 rewriter.replaceOp(op, spirvOp);
1269 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1270 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1272 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1273 adaptor.getLhs(), spirvOp);
1274 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1275 adaptor.getRhs(), select1);
1277 rewriter.replaceOp(op, select2);
1288template <
typename Op,
typename SPIRVOp>
1289class MinNumMaxNumFOpPattern final :
public OpConversionPattern<Op> {
1290 template <
typename TargetOp>
1291 constexpr bool shouldInsertNanGuards()
const {
1292 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1296 using OpConversionPattern<
Op>::OpConversionPattern;
1298 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1299 ConversionPatternRewriter &rewriter)
const override {
1300 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1301 Type dstType = converter->convertType(op.getType());
1316 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1318 if (!shouldInsertNanGuards<SPIRVOp>() ||
1319 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1320 rewriter.replaceOp(op, spirvOp);
1324 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1325 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1327 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1328 adaptor.getRhs(), spirvOp);
1329 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1330 adaptor.getLhs(), select1);
1332 rewriter.replaceOp(op, select2);
1347 ConstantCompositeOpPattern,
1348 ConstantScalarOpPattern,
1349 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1350 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1351 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1355 RemSIOpGLPattern, RemSIOpCLPattern,
1356 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1357 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1358 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1359 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1368 ExtUIPattern, ExtUII1Pattern,
1369 ExtSIPattern, ExtSII1Pattern,
1370 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1371 TruncIPattern, TruncII1Pattern,
1372 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1373 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1374 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1375 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1376 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1377 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1378 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1379 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1380 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1381 CmpIOpBooleanPattern, CmpIOpPattern,
1382 CmpFOpNanNonePattern, CmpFOpPattern,
1383 AddUIExtendedOpPattern,
1384 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1385 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1388 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1389 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1390 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1391 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1397 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1398 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1399 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1400 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1405 >(typeConverter,
patterns.getContext());
1419struct ConvertArithToSPIRVPass
1423 void runOnOperation()
override {
1426 std::unique_ptr<SPIRVConversionTarget>
target =
1430 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1431 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1436 target->addLegalOp<UnrealizedConversionCastOp>();
1439 target->addIllegalDialect<arith::ArithDialect>();
1444 if (failed(applyPartialConversion(op, *
target, std::move(
patterns))))
1445 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ConvertArithToSPIRVPassBase Base
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.