21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
34#define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
115 assert(type &&
"Not a valid type");
119 if (
auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
128 if (
auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
134 if (
auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
144 auto getNumBitwidth = [](
Type type) {
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (
auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(
b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
161 return rewriter.notifyMatchFailure(
163 llvm::formatv(
"failed to convert source type '{0}'", srcType));
177template <
typename Op,
typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<
Op>::OpConversionPattern;
182 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter)
const override {
184 assert(adaptor.getOperands().size() <= 3);
187 if (!adaptor.getOperands().empty() &&
190 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
191 Type dstType = converter->convertType(op.getType());
193 return rewriter.notifyMatchFailure(
195 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
198 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
200 dstType != op.getType()) {
201 return op.
emitError(
"bitwidth emulation is not implemented yet on "
202 "unsigned op pattern version");
205 auto overflowFlags = arith::IntegerOverflowFlags::none;
206 if (
auto overflowIface =
207 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208 if (converter->getTargetEnv().allows(
209 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210 overflowFlags = overflowIface.getOverflowAttr().getValue();
213 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214 op, dstType, adaptor.getOperands());
216 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
218 rewriter.getUnitAttr());
220 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
222 rewriter.getUnitAttr());
233struct ConstantCompositeOpPattern final
234 :
public OpConversionPattern<arith::ConstantOp> {
238 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter)
const override {
240 auto srcType = dyn_cast<ShapedType>(constOp.getType());
241 if (!srcType || srcType.getNumElements() == 1)
246 if (!isa<VectorType, RankedTensorType>(srcType))
247 return rewriter.notifyMatchFailure(constOp,
"unsupported ShapedType");
249 Type dstType = getTypeConverter()->convertType(srcType);
256 if (
auto denseElementsAttr =
257 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258 dstElementsAttr = denseElementsAttr;
259 }
else if (
auto resourceAttr =
260 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
264 return constOp->emitError(
"could not find resource blob");
271 return constOp->emitError(
"resource is not a valid buffer");
276 return constOp->emitError(
"unsupported elements attribute");
279 ShapedType dstAttrType = dstElementsAttr.
getType();
283 if (srcType.getRank() > 1) {
284 if (isa<RankedTensorType>(srcType)) {
285 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
286 srcType.getElementType());
287 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
294 Type srcElemType = srcType.getElementType();
298 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
299 dstElemType = arrayType.getElementType();
301 dstElemType = cast<VectorType>(dstType).getElementType();
305 if (srcElemType != dstElemType) {
307 if (isa<FloatType>(srcElemType)) {
308 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
311 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
312 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
314 isa<IntegerType>(dstElemType)) {
323 elements.push_back(dstAttr);
328 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
330 srcAttr, cast<IntegerType>(dstElemType), rewriter);
333 elements.push_back(dstAttr);
341 if (isa<RankedTensorType>(dstAttrType))
343 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
345 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
350 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
357struct ConstantScalarOpPattern final
358 :
public OpConversionPattern<arith::ConstantOp> {
362 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override {
364 Type srcType = constOp.getType();
365 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
366 if (shapedType.getNumElements() != 1)
368 srcType = shapedType.getElementType();
373 Attribute cstAttr = constOp.getValue();
374 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
375 cstAttr = elementsAttr.getSplatValue<Attribute>();
377 Type dstType = getTypeConverter()->convertType(srcType);
382 if (isa<FloatType>(srcType)) {
383 auto srcAttr = cast<FloatAttr>(cstAttr);
384 Attribute dstAttr = srcAttr;
388 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
389 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
391 dstType.getIntOrFloatBitWidth() == 8) {
396 }
else if (srcType != dstType) {
402 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
413 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
419 auto srcAttr = cast<IntegerAttr>(cstAttr);
420 IntegerAttr dstAttr =
424 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
440template <
typename SignedAbsOp>
443 assert(
lhs.getType() ==
rhs.getType());
444 assert(
lhs == signOperand ||
rhs == signOperand);
449 Value lhsAbs = SignedAbsOp::create(builder, loc, type,
lhs);
450 Value rhsAbs = SignedAbsOp::create(builder, loc, type,
rhs);
451 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
455 if (
lhs == signOperand)
456 isPositive = spirv::IEqualOp::create(builder, loc,
lhs, lhsAbs);
458 isPositive = spirv::IEqualOp::create(builder, loc,
rhs, rhsAbs);
459 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
460 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
468struct RemSIOpGLPattern final :
public OpConversionPattern<arith::RemSIOp> {
472 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
473 ConversionPatternRewriter &rewriter)
const override {
474 Value
result = emulateSignedRemainder<spirv::GLSAbsOp>(
475 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
476 adaptor.getOperands()[0], rewriter);
477 rewriter.replaceOp(op,
result);
484struct RemSIOpCLPattern final :
public OpConversionPattern<arith::RemSIOp> {
488 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter)
const override {
490 Value
result = emulateSignedRemainder<spirv::CLSAbsOp>(
491 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
492 adaptor.getOperands()[0], rewriter);
493 rewriter.replaceOp(op,
result);
507template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
508struct BitwiseOpPattern final :
public OpConversionPattern<Op> {
509 using OpConversionPattern<
Op>::OpConversionPattern;
512 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
513 ConversionPatternRewriter &rewriter)
const override {
514 assert(adaptor.getOperands().size() == 2);
515 Type dstType = this->getTypeConverter()->convertType(op.getType());
520 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
521 op, dstType, adaptor.getOperands());
523 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
524 op, dstType, adaptor.getOperands());
535struct XOrIOpLogicalPattern final :
public OpConversionPattern<arith::XOrIOp> {
539 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter)
const override {
541 assert(adaptor.getOperands().size() == 2);
546 Type dstType = getTypeConverter()->convertType(op.getType());
550 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
551 adaptor.getOperands());
559struct XOrIOpBooleanPattern final :
public OpConversionPattern<arith::XOrIOp> {
563 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
564 ConversionPatternRewriter &rewriter)
const override {
565 assert(adaptor.getOperands().size() == 2);
570 Type dstType = getTypeConverter()->convertType(op.getType());
574 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
575 op, dstType, adaptor.getOperands());
595template <
typename ArithOp,
typename SPIRVOp>
596struct BoolIOpPattern final :
public OpConversionPattern<ArithOp> {
597 BoolIOpPattern(
const TypeConverter &converter, MLIRContext *context)
600 : OpConversionPattern<ArithOp>(converter, context, 2) {}
603 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
604 ConversionPatternRewriter &rewriter)
const override {
608 Type dstType = this->getTypeConverter()->convertType(op.getType());
612 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
626template <
typename ArithOp>
627struct BoolIOpAndNotPattern final :
public OpConversionPattern<ArithOp> {
628 BoolIOpAndNotPattern(
const TypeConverter &converter, MLIRContext *context)
631 : OpConversionPattern<ArithOp>(converter, context, 2) {}
634 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
635 ConversionPatternRewriter &rewriter)
const override {
639 Type dstType = this->getTypeConverter()->convertType(op.getType());
643 Location loc = op.getLoc();
644 Value notRhs = spirv::LogicalNotOp::create(rewriter, loc, dstType,
645 adaptor.getOperands()[1]);
646 rewriter.replaceOpWithNewOp<spirv::LogicalAndOp>(
647 op, dstType, adaptor.getOperands()[0], notRhs);
654struct ShRSIBoolPattern final :
public OpConversionPattern<arith::ShRSIOp> {
655 ShRSIBoolPattern(
const TypeConverter &converter, MLIRContext *context)
658 : OpConversionPattern<arith::ShRSIOp>(converter, context,
662 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter)
const override {
667 rewriter.replaceOp(op, adaptor.getOperands().front());
679template <
typename ArithOp>
680struct BoolToValuePattern final :
public OpConversionPattern<ArithOp> {
681 using OpConversionPattern<ArithOp>::OpConversionPattern;
684 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
685 ConversionPatternRewriter &rewriter)
const override {
686 Type srcType = adaptor.getOperands().front().getType();
690 Type dstType = this->getTypeConverter()->convertType(op.getType());
694 Location loc = op.getLoc();
695 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
696 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
697 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
698 op, dstType, adaptor.getOperands().front(), one, zero);
714template <
typename ArithOp,
typename SPIRVOp,
bool IsSigned>
715struct IntToFPPattern final :
public OpConversionPattern<ArithOp> {
716 using OpConversionPattern<ArithOp>::OpConversionPattern;
719 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
720 ConversionPatternRewriter &rewriter)
const override {
721 Type srcType = adaptor.getOperands().front().getType();
725 Type dstType = this->getTypeConverter()->convertType(op.getType());
730 unsigned originalBitwidth =
732 unsigned convertedBitwidth =
735 if (originalBitwidth >= convertedBitwidth) {
736 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
741 Location loc = op.getLoc();
743 if constexpr (IsSigned) {
745 unsigned shiftAmount = convertedBitwidth - originalBitwidth;
748 Value shifted = spirv::ShiftLeftLogicalOp::create(
749 rewriter, loc, srcType, adaptor.getIn(), shiftSize);
750 cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
755 srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter,
757 cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
758 adaptor.getIn(), mask);
760 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned);
770struct IndexCastIndexI1Pattern final
771 :
public OpConversionPattern<arith::IndexCastOp> {
775 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
776 ConversionPatternRewriter &rewriter)
const override {
780 Type dstType = getTypeConverter()->convertType(op.getType());
784 Location loc = op.getLoc();
786 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
787 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
799struct ExtSII1Pattern final :
public OpConversionPattern<arith::ExtSIOp> {
803 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
804 ConversionPatternRewriter &rewriter)
const override {
805 Value operand = adaptor.getIn();
809 Location loc = op.getLoc();
810 Type dstType = getTypeConverter()->convertType(op.getType());
815 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
816 unsigned componentBitwidth = intTy.getWidth();
817 allOnes = spirv::ConstantOp::create(
818 rewriter, loc, intTy,
819 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
820 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
821 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
822 allOnes = spirv::ConstantOp::create(
823 rewriter, loc, vectorTy,
824 SplatElementsAttr::get(vectorTy,
825 APInt::getAllOnes(componentBitwidth)));
827 return rewriter.notifyMatchFailure(
828 loc, llvm::formatv(
"unhandled type: {0}", dstType));
831 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
832 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
840struct ExtSIPattern final :
public OpConversionPattern<arith::ExtSIOp> {
844 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
845 ConversionPatternRewriter &rewriter)
const override {
846 Type srcType = adaptor.getIn().getType();
850 Type dstType = getTypeConverter()->convertType(op.getType());
854 if (dstType == srcType) {
862 assert(srcBW < dstBW);
864 rewriter, op.getLoc());
866 return rewriter.notifyMatchFailure(op,
"unsupported type for shift");
871 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
872 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
876 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
877 op, dstType, shiftLOp, shiftSize);
879 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
880 adaptor.getOperands());
893struct ExtUIPattern final :
public OpConversionPattern<arith::ExtUIOp> {
897 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
898 ConversionPatternRewriter &rewriter)
const override {
899 Type srcType = adaptor.getIn().getType();
903 Type dstType = getTypeConverter()->convertType(op.getType());
907 if (dstType == srcType) {
915 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
918 return rewriter.notifyMatchFailure(op,
"unsupported type for mask");
919 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
920 adaptor.getIn(), mask);
922 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
923 adaptor.getOperands());
935struct TruncII1Pattern final :
public OpConversionPattern<arith::TruncIOp> {
939 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
940 ConversionPatternRewriter &rewriter)
const override {
941 Type dstType = getTypeConverter()->convertType(op.getType());
948 Location loc = op.getLoc();
949 auto srcType = adaptor.getOperands().front().getType();
951 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
952 Value maskedSrc = spirv::BitwiseAndOp::create(
953 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
954 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
956 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
957 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
958 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
965struct TruncIPattern final :
public OpConversionPattern<arith::TruncIOp> {
969 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
970 ConversionPatternRewriter &rewriter)
const override {
971 Type srcType = adaptor.getIn().getType();
972 Type dstType = getTypeConverter()->convertType(op.getType());
979 if (dstType == srcType) {
986 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
988 return rewriter.notifyMatchFailure(op,
"unsupported type for mask");
989 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
990 adaptor.getIn(), mask);
993 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
994 adaptor.getOperands());
1004static std::optional<spirv::FPRoundingMode>
1005convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
1006 switch (roundingMode) {
1007 case arith::RoundingMode::downward:
1008 return spirv::FPRoundingMode::RTN;
1009 case arith::RoundingMode::to_nearest_even:
1010 return spirv::FPRoundingMode::RTE;
1011 case arith::RoundingMode::toward_zero:
1012 return spirv::FPRoundingMode::RTZ;
1013 case arith::RoundingMode::upward:
1014 return spirv::FPRoundingMode::RTP;
1015 case arith::RoundingMode::to_nearest_away:
1018 return std::nullopt;
1020 llvm_unreachable(
"Unhandled rounding mode");
1024template <
typename Op,
typename SPIRVOp>
1025struct TypeCastingOpPattern final :
public OpConversionPattern<Op> {
1026 using OpConversionPattern<
Op>::OpConversionPattern;
1029 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1030 ConversionPatternRewriter &rewriter)
const override {
1031 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
1032 Type dstType = this->getTypeConverter()->convertType(op.getType());
1039 if (dstType == srcType) {
1042 rewriter.replaceOp(op, adaptor.getOperands().front());
1045 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
1046 if (
auto roundingModeOp =
1047 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
1048 if (arith::RoundingModeAttr roundingMode =
1049 roundingModeOp.getRoundingModeAttr()) {
1051 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
1052 return rewriter.notifyMatchFailure(
1054 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
1059 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
1060 op, dstType, adaptor.getOperands());
1064 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
1076class CmpIOpBooleanPattern final :
public OpConversionPattern<arith::CmpIOp> {
1081 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1082 ConversionPatternRewriter &rewriter)
const override {
1083 Type srcType = op.getLhs().getType();
1086 Type dstType = getTypeConverter()->convertType(srcType);
1090 switch (op.getPredicate()) {
1091 case arith::CmpIPredicate::eq: {
1092 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
1096 case arith::CmpIPredicate::ne: {
1097 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
1098 op, adaptor.getLhs(), adaptor.getRhs());
1101 case arith::CmpIPredicate::uge:
1102 case arith::CmpIPredicate::ugt:
1103 case arith::CmpIPredicate::ule:
1104 case arith::CmpIPredicate::ult: {
1107 Type type = rewriter.getI32Type();
1108 if (
auto vectorType = dyn_cast<VectorType>(dstType))
1109 type = VectorType::get(vectorType.getShape(), type);
1111 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1113 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1115 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1127class CmpIOpPattern final :
public OpConversionPattern<arith::CmpIOp> {
1132 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1133 ConversionPatternRewriter &rewriter)
const override {
1134 Type srcType = op.getLhs().getType();
1137 Type dstType = getTypeConverter()->convertType(srcType);
1141 switch (op.getPredicate()) {
1142#define DISPATCH(cmpPredicate, spirvOp) \
1143 case cmpPredicate: \
1144 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1145 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1146 !hasSameBitwidth(srcType, dstType)) { \
1147 return op.emitError( \
1148 "bitwidth emulation is not implemented yet on unsigned op"); \
1150 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1151 adaptor.getRhs()); \
1154 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1155 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1156 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1157 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1158 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1159 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1160 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1161 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1162 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1163 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1176class CmpFOpPattern final :
public OpConversionPattern<arith::CmpFOp> {
1181 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1182 ConversionPatternRewriter &rewriter)
const override {
1183 switch (op.getPredicate()) {
1184#define DISPATCH(cmpPredicate, spirvOp) \
1185 case cmpPredicate: \
1186 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1187 adaptor.getRhs()); \
1191 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1192 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1193 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1194 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1195 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1196 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1198 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1199 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1200 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1201 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1202 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1203 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1216class CmpFOpNanKernelPattern final :
public OpConversionPattern<arith::CmpFOp> {
1221 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1222 ConversionPatternRewriter &rewriter)
const override {
1223 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1224 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1229 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1230 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1241class CmpFOpNanNonePattern final :
public OpConversionPattern<arith::CmpFOp> {
1246 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1247 ConversionPatternRewriter &rewriter)
const override {
1248 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1249 op.getPredicate() != arith::CmpFPredicate::UNO)
1252 Location loc = op.getLoc();
1255 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1256 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1258 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1261 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1264 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1265 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1267 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1268 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1269 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1272 rewriter.replaceOp(op, replace);
1283template <
typename ArithExtendedOp,
typename SPIRVExtendedOp>
1284class BinaryExtendedOpPattern final
1285 :
public OpConversionPattern<ArithExtendedOp> {
1287 using OpConversionPattern<ArithExtendedOp>::OpConversionPattern;
1289 matchAndRewrite(ArithExtendedOp op,
typename ArithExtendedOp::Adaptor adaptor,
1290 ConversionPatternRewriter &rewriter)
const override {
1291 Type dstElemTy = adaptor.getLhs().getType();
1292 Location loc = op->getLoc();
1293 Value
result = SPIRVExtendedOp::create(rewriter, loc, adaptor.getLhs(),
1296 Value valueResult = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1298 Value flagValue = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1302 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1303 Value flagResult = spirv::IEqualOp::create(rewriter, loc, flagValue, one);
1305 rewriter.replaceOp(op, {valueResult, flagResult});
1315template <
typename ArithMulOp,
typename SPIRVMulOp>
1316class MulIExtendedOpPattern final :
public OpConversionPattern<ArithMulOp> {
1318 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1320 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1321 ConversionPatternRewriter &rewriter)
const override {
1322 Location loc = op->getLoc();
1324 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1326 Value low = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1328 Value high = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1331 rewriter.replaceOp(op, {low, high});
1341class SelectOpPattern final :
public OpConversionPattern<arith::SelectOp> {
1345 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1346 ConversionPatternRewriter &rewriter)
const override {
1347 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1348 adaptor.getTrueValue(),
1349 adaptor.getFalseValue());
1360template <
typename Op,
typename SPIRVOp>
1361class MinimumMaximumFOpPattern final :
public OpConversionPattern<Op> {
1363 using OpConversionPattern<
Op>::OpConversionPattern;
1365 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1366 ConversionPatternRewriter &rewriter)
const override {
1367 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1368 Type dstType = converter->convertType(op.getType());
1380 Location loc = op.
getLoc();
1382 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1384 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1385 rewriter.replaceOp(op, spirvOp);
1389 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1390 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1392 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1393 adaptor.getLhs(), spirvOp);
1394 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1395 adaptor.getRhs(), select1);
1397 rewriter.replaceOp(op, select2);
1408template <
typename Op,
typename SPIRVOp>
1409class MinNumMaxNumFOpPattern final :
public OpConversionPattern<Op> {
1410 template <
typename TargetOp>
1411 constexpr bool shouldInsertNanGuards()
const {
1412 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1416 using OpConversionPattern<
Op>::OpConversionPattern;
1418 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1419 ConversionPatternRewriter &rewriter)
const override {
1420 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1421 Type dstType = converter->convertType(op.getType());
1436 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1438 if (!shouldInsertNanGuards<SPIRVOp>() ||
1439 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1440 rewriter.replaceOp(op, spirvOp);
1444 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1445 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1447 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1448 adaptor.getRhs(), spirvOp);
1449 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1450 adaptor.getLhs(), select1);
1452 rewriter.replaceOp(op, select2);
1467 ConstantCompositeOpPattern,
1468 ConstantScalarOpPattern,
1469 BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>,
1470 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1471 BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>,
1472 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1473 BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,
1474 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1475 BoolIOpPattern<arith::DivUIOp, spirv::LogicalAndOp>,
1477 BoolIOpPattern<arith::DivSIOp, spirv::LogicalAndOp>,
1479 BoolIOpAndNotPattern<arith::RemUIOp>,
1481 BoolIOpAndNotPattern<arith::RemSIOp>,
1482 RemSIOpGLPattern, RemSIOpCLPattern,
1483 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1484 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1485 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1486 BoolIOpAndNotPattern<arith::ShLIOp>,
1487 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1488 BoolIOpAndNotPattern<arith::ShRUIOp>,
1498 ExtUIPattern, BoolToValuePattern<arith::ExtUIOp>,
1499 ExtSIPattern, ExtSII1Pattern,
1500 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1501 TruncIPattern, TruncII1Pattern,
1502 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1503 IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
1504 BoolToValuePattern<arith::UIToFPOp>,
1505 IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
1506 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1507 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1508 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1509 IndexCastIndexI1Pattern, BoolToValuePattern<arith::IndexCastOp>,
1510 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1511 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1512 CmpIOpBooleanPattern, CmpIOpPattern,
1513 CmpFOpNanNonePattern, CmpFOpPattern,
1514 BinaryExtendedOpPattern<arith::AddUIExtendedOp, spirv::IAddCarryOp>,
1515 BinaryExtendedOpPattern<arith::SubUIExtendedOp, spirv::ISubBorrowOp>,
1516 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1517 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1520 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1521 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1522 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1523 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1524 BoolIOpPattern<arith::MaxSIOp, spirv::LogicalAndOp>,
1525 BoolIOpPattern<arith::MaxUIOp, spirv::LogicalOrOp>,
1526 BoolIOpPattern<arith::MinSIOp, spirv::LogicalOrOp>,
1527 BoolIOpPattern<arith::MinUIOp, spirv::LogicalAndOp>,
1533 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1534 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1535 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1536 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1546 patterns.
add<CmpFOpNanKernelPattern>(typeConverter, patterns.
getContext(),
1555struct ConvertArithToSPIRVPass
1556 :
public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1559 void runOnOperation()
override {
1562 std::unique_ptr<SPIRVConversionTarget>
target =
1566 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1567 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1572 target->addLegalOp<UnrealizedConversionCastOp>();
1575 target->addIllegalDialect<arith::ArithDialect>();
1580 if (failed(applyPartialConversion(op, *
target, std::move(patterns))))
1581 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.