21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
34#define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
115 assert(type &&
"Not a valid type");
119 if (
auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
128 if (
auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
134 if (
auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
144 auto getNumBitwidth = [](
Type type) {
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (
auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(
b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
161 return rewriter.notifyMatchFailure(
163 llvm::formatv(
"failed to convert source type '{0}'", srcType));
177template <
typename Op,
typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<
Op>::OpConversionPattern;
182 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter)
const override {
184 assert(adaptor.getOperands().size() <= 3);
187 if (!adaptor.getOperands().empty() &&
190 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
191 Type dstType = converter->convertType(op.getType());
193 return rewriter.notifyMatchFailure(
195 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
198 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
200 dstType != op.getType()) {
201 return op.
emitError(
"bitwidth emulation is not implemented yet on "
202 "unsigned op pattern version");
205 auto overflowFlags = arith::IntegerOverflowFlags::none;
206 if (
auto overflowIface =
207 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208 if (converter->getTargetEnv().allows(
209 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210 overflowFlags = overflowIface.getOverflowAttr().getValue();
213 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214 op, dstType, adaptor.getOperands());
216 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
218 rewriter.getUnitAttr());
220 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
222 rewriter.getUnitAttr());
233struct ConstantCompositeOpPattern final
234 :
public OpConversionPattern<arith::ConstantOp> {
238 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter)
const override {
240 auto srcType = dyn_cast<ShapedType>(constOp.getType());
241 if (!srcType || srcType.getNumElements() == 1)
246 if (!isa<VectorType, RankedTensorType>(srcType))
247 return rewriter.notifyMatchFailure(constOp,
"unsupported ShapedType");
249 Type dstType = getTypeConverter()->convertType(srcType);
256 if (
auto denseElementsAttr =
257 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258 dstElementsAttr = denseElementsAttr;
259 }
else if (
auto resourceAttr =
260 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
264 return constOp->emitError(
"could not find resource blob");
271 return constOp->emitError(
"resource is not a valid buffer");
276 return constOp->emitError(
"unsupported elements attribute");
279 ShapedType dstAttrType = dstElementsAttr.
getType();
283 if (srcType.getRank() > 1) {
284 if (isa<RankedTensorType>(srcType)) {
285 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
286 srcType.getElementType());
287 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
294 Type srcElemType = srcType.getElementType();
298 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
299 dstElemType = arrayType.getElementType();
301 dstElemType = cast<VectorType>(dstType).getElementType();
305 if (srcElemType != dstElemType) {
307 if (isa<FloatType>(srcElemType)) {
308 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
311 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
312 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
314 isa<IntegerType>(dstElemType)) {
323 elements.push_back(dstAttr);
328 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
330 srcAttr, cast<IntegerType>(dstElemType), rewriter);
333 elements.push_back(dstAttr);
341 if (isa<RankedTensorType>(dstAttrType))
343 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
345 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
350 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
357struct ConstantScalarOpPattern final
358 :
public OpConversionPattern<arith::ConstantOp> {
362 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override {
364 Type srcType = constOp.getType();
365 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
366 if (shapedType.getNumElements() != 1)
368 srcType = shapedType.getElementType();
373 Attribute cstAttr = constOp.getValue();
374 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
375 cstAttr = elementsAttr.getSplatValue<Attribute>();
377 Type dstType = getTypeConverter()->convertType(srcType);
382 if (isa<FloatType>(srcType)) {
383 auto srcAttr = cast<FloatAttr>(cstAttr);
384 Attribute dstAttr = srcAttr;
388 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
389 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
391 dstType.getIntOrFloatBitWidth() == 8) {
396 }
else if (srcType != dstType) {
402 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
413 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
419 auto srcAttr = cast<IntegerAttr>(cstAttr);
420 IntegerAttr dstAttr =
424 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
440template <
typename SignedAbsOp>
443 assert(
lhs.getType() ==
rhs.getType());
444 assert(
lhs == signOperand ||
rhs == signOperand);
449 Value lhsAbs = SignedAbsOp::create(builder, loc, type,
lhs);
450 Value rhsAbs = SignedAbsOp::create(builder, loc, type,
rhs);
451 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
455 if (
lhs == signOperand)
456 isPositive = spirv::IEqualOp::create(builder, loc,
lhs, lhsAbs);
458 isPositive = spirv::IEqualOp::create(builder, loc,
rhs, rhsAbs);
459 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
460 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
468struct RemSIOpGLPattern final :
public OpConversionPattern<arith::RemSIOp> {
472 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
473 ConversionPatternRewriter &rewriter)
const override {
474 Value
result = emulateSignedRemainder<spirv::CLSAbsOp>(
475 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
476 adaptor.getOperands()[0], rewriter);
477 rewriter.replaceOp(op,
result);
484struct RemSIOpCLPattern final :
public OpConversionPattern<arith::RemSIOp> {
488 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter)
const override {
490 Value
result = emulateSignedRemainder<spirv::GLSAbsOp>(
491 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
492 adaptor.getOperands()[0], rewriter);
493 rewriter.replaceOp(op,
result);
507template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
508struct BitwiseOpPattern final :
public OpConversionPattern<Op> {
509 using OpConversionPattern<
Op>::OpConversionPattern;
512 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
513 ConversionPatternRewriter &rewriter)
const override {
514 assert(adaptor.getOperands().size() == 2);
515 Type dstType = this->getTypeConverter()->convertType(op.getType());
520 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
521 op, dstType, adaptor.getOperands());
523 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
524 op, dstType, adaptor.getOperands());
535struct XOrIOpLogicalPattern final :
public OpConversionPattern<arith::XOrIOp> {
539 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter)
const override {
541 assert(adaptor.getOperands().size() == 2);
546 Type dstType = getTypeConverter()->convertType(op.getType());
550 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
551 adaptor.getOperands());
559struct XOrIOpBooleanPattern final :
public OpConversionPattern<arith::XOrIOp> {
563 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
564 ConversionPatternRewriter &rewriter)
const override {
565 assert(adaptor.getOperands().size() == 2);
570 Type dstType = getTypeConverter()->convertType(op.getType());
574 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
575 op, dstType, adaptor.getOperands());
595template <
typename ArithOp,
typename SPIRVOp>
596struct BoolIOpPattern final :
public OpConversionPattern<ArithOp> {
597 BoolIOpPattern(
const TypeConverter &converter, MLIRContext *context)
600 : OpConversionPattern<ArithOp>(converter, context, 2) {}
603 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
604 ConversionPatternRewriter &rewriter)
const override {
608 Type dstType = this->getTypeConverter()->convertType(op.getType());
612 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
626template <
typename ArithOp>
627struct BoolIOpAndNotPattern final :
public OpConversionPattern<ArithOp> {
628 BoolIOpAndNotPattern(
const TypeConverter &converter, MLIRContext *context)
631 : OpConversionPattern<ArithOp>(converter, context, 2) {}
634 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
635 ConversionPatternRewriter &rewriter)
const override {
639 Type dstType = this->getTypeConverter()->convertType(op.getType());
643 Location loc = op.getLoc();
644 Value notRhs = spirv::LogicalNotOp::create(rewriter, loc, dstType,
645 adaptor.getOperands()[1]);
646 rewriter.replaceOpWithNewOp<spirv::LogicalAndOp>(
647 op, dstType, adaptor.getOperands()[0], notRhs);
654struct ShRSIBoolPattern final :
public OpConversionPattern<arith::ShRSIOp> {
655 ShRSIBoolPattern(
const TypeConverter &converter, MLIRContext *context)
658 : OpConversionPattern<arith::ShRSIOp>(converter, context,
662 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter)
const override {
667 rewriter.replaceOp(op, adaptor.getOperands().front());
678struct UIToFPI1Pattern final :
public OpConversionPattern<arith::UIToFPOp> {
682 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
683 ConversionPatternRewriter &rewriter)
const override {
684 Type srcType = adaptor.getOperands().front().getType();
688 Type dstType = getTypeConverter()->convertType(op.getType());
692 Location loc = op.getLoc();
693 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696 op, dstType, adaptor.getOperands().front(), one, zero);
708template <
typename ArithOp,
typename SPIRVOp,
bool IsSigned>
709struct IntToFPPattern final :
public OpConversionPattern<ArithOp> {
710 using OpConversionPattern<ArithOp>::OpConversionPattern;
713 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
714 ConversionPatternRewriter &rewriter)
const override {
715 Type srcType = adaptor.getOperands().front().getType();
719 Type dstType = this->getTypeConverter()->convertType(op.getType());
724 unsigned originalBitwidth =
726 unsigned convertedBitwidth =
729 if (originalBitwidth >= convertedBitwidth) {
730 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
737 if constexpr (IsSigned) {
739 unsigned shiftAmount = convertedBitwidth - originalBitwidth;
742 Value shifted = spirv::ShiftLeftLogicalOp::create(
743 rewriter, loc, srcType, adaptor.getIn(), shiftSize);
744 cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
749 srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter,
751 cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
752 adaptor.getIn(), mask);
754 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned);
764struct IndexCastIndexI1Pattern final
765 :
public OpConversionPattern<arith::IndexCastOp> {
769 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
770 ConversionPatternRewriter &rewriter)
const override {
774 Type dstType = getTypeConverter()->convertType(op.getType());
780 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
781 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
788struct IndexCastI1IndexPattern final
789 :
public OpConversionPattern<arith::IndexCastOp> {
793 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
794 ConversionPatternRewriter &rewriter)
const override {
798 Type dstType = getTypeConverter()->convertType(op.getType());
803 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
804 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
805 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
817struct ExtSII1Pattern final :
public OpConversionPattern<arith::ExtSIOp> {
821 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
822 ConversionPatternRewriter &rewriter)
const override {
823 Value operand = adaptor.getIn();
827 Location loc = op.getLoc();
828 Type dstType = getTypeConverter()->convertType(op.getType());
833 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
834 unsigned componentBitwidth = intTy.getWidth();
835 allOnes = spirv::ConstantOp::create(
836 rewriter, loc, intTy,
837 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
838 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
839 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
840 allOnes = spirv::ConstantOp::create(
841 rewriter, loc, vectorTy,
842 SplatElementsAttr::get(vectorTy,
843 APInt::getAllOnes(componentBitwidth)));
845 return rewriter.notifyMatchFailure(
846 loc, llvm::formatv(
"unhandled type: {0}", dstType));
849 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
850 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
858struct ExtSIPattern final :
public OpConversionPattern<arith::ExtSIOp> {
862 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
863 ConversionPatternRewriter &rewriter)
const override {
864 Type srcType = adaptor.getIn().getType();
868 Type dstType = getTypeConverter()->convertType(op.getType());
872 if (dstType == srcType) {
880 assert(srcBW < dstBW);
882 rewriter, op.getLoc());
884 return rewriter.notifyMatchFailure(op,
"unsupported type for shift");
889 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
890 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
894 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
895 op, dstType, shiftLOp, shiftSize);
897 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
898 adaptor.getOperands());
911struct ExtUII1Pattern final :
public OpConversionPattern<arith::ExtUIOp> {
915 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
916 ConversionPatternRewriter &rewriter)
const override {
917 Type srcType = adaptor.getOperands().front().getType();
921 Type dstType = getTypeConverter()->convertType(op.getType());
925 Location loc = op.getLoc();
926 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
927 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
928 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
929 op, dstType, adaptor.getOperands().front(), one, zero);
936struct ExtUIPattern final :
public OpConversionPattern<arith::ExtUIOp> {
940 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
941 ConversionPatternRewriter &rewriter)
const override {
942 Type srcType = adaptor.getIn().getType();
946 Type dstType = getTypeConverter()->convertType(op.getType());
950 if (dstType == srcType) {
958 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
961 return rewriter.notifyMatchFailure(op,
"unsupported type for mask");
962 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
963 adaptor.getIn(), mask);
965 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
966 adaptor.getOperands());
978struct TruncII1Pattern final :
public OpConversionPattern<arith::TruncIOp> {
982 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
983 ConversionPatternRewriter &rewriter)
const override {
984 Type dstType = getTypeConverter()->convertType(op.getType());
991 Location loc = op.getLoc();
992 auto srcType = adaptor.getOperands().front().getType();
994 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
995 Value maskedSrc = spirv::BitwiseAndOp::create(
996 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
997 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
999 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
1000 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
1001 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
1008struct TruncIPattern final :
public OpConversionPattern<arith::TruncIOp> {
1012 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1013 ConversionPatternRewriter &rewriter)
const override {
1014 Type srcType = adaptor.getIn().getType();
1015 Type dstType = getTypeConverter()->convertType(op.getType());
1022 if (dstType == srcType) {
1029 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
1031 return rewriter.notifyMatchFailure(op,
"unsupported type for mask");
1032 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
1033 adaptor.getIn(), mask);
1036 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
1037 adaptor.getOperands());
1047static std::optional<spirv::FPRoundingMode>
1048convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
1049 switch (roundingMode) {
1050 case arith::RoundingMode::downward:
1051 return spirv::FPRoundingMode::RTN;
1052 case arith::RoundingMode::to_nearest_even:
1053 return spirv::FPRoundingMode::RTE;
1054 case arith::RoundingMode::toward_zero:
1055 return spirv::FPRoundingMode::RTZ;
1056 case arith::RoundingMode::upward:
1057 return spirv::FPRoundingMode::RTP;
1058 case arith::RoundingMode::to_nearest_away:
1061 return std::nullopt;
1063 llvm_unreachable(
"Unhandled rounding mode");
1067template <
typename Op,
typename SPIRVOp>
1068struct TypeCastingOpPattern final :
public OpConversionPattern<Op> {
1069 using OpConversionPattern<
Op>::OpConversionPattern;
1072 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1073 ConversionPatternRewriter &rewriter)
const override {
1074 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
1075 Type dstType = this->getTypeConverter()->convertType(op.getType());
1082 if (dstType == srcType) {
1085 rewriter.replaceOp(op, adaptor.getOperands().front());
1088 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
1089 if (
auto roundingModeOp =
1090 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
1091 if (arith::RoundingModeAttr roundingMode =
1092 roundingModeOp.getRoundingModeAttr()) {
1094 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
1095 return rewriter.notifyMatchFailure(
1097 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
1102 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
1103 op, dstType, adaptor.getOperands());
1107 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
1119class CmpIOpBooleanPattern final :
public OpConversionPattern<arith::CmpIOp> {
1124 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1125 ConversionPatternRewriter &rewriter)
const override {
1126 Type srcType = op.getLhs().getType();
1129 Type dstType = getTypeConverter()->convertType(srcType);
1133 switch (op.getPredicate()) {
1134 case arith::CmpIPredicate::eq: {
1135 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
1139 case arith::CmpIPredicate::ne: {
1140 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
1141 op, adaptor.getLhs(), adaptor.getRhs());
1144 case arith::CmpIPredicate::uge:
1145 case arith::CmpIPredicate::ugt:
1146 case arith::CmpIPredicate::ule:
1147 case arith::CmpIPredicate::ult: {
1150 Type type = rewriter.getI32Type();
1151 if (
auto vectorType = dyn_cast<VectorType>(dstType))
1152 type = VectorType::get(vectorType.getShape(), type);
1154 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1156 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1158 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1170class CmpIOpPattern final :
public OpConversionPattern<arith::CmpIOp> {
1175 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter)
const override {
1177 Type srcType = op.getLhs().getType();
1180 Type dstType = getTypeConverter()->convertType(srcType);
1184 switch (op.getPredicate()) {
1185#define DISPATCH(cmpPredicate, spirvOp) \
1186 case cmpPredicate: \
1187 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1188 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1189 !hasSameBitwidth(srcType, dstType)) { \
1190 return op.emitError( \
1191 "bitwidth emulation is not implemented yet on unsigned op"); \
1193 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1194 adaptor.getRhs()); \
1197 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1198 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1199 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1200 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1201 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1202 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1203 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1204 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1205 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1206 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1219class CmpFOpPattern final :
public OpConversionPattern<arith::CmpFOp> {
1224 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1225 ConversionPatternRewriter &rewriter)
const override {
1226 switch (op.getPredicate()) {
1227#define DISPATCH(cmpPredicate, spirvOp) \
1228 case cmpPredicate: \
1229 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1230 adaptor.getRhs()); \
1234 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1235 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1236 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1237 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1238 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1239 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1241 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1242 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1243 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1244 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1245 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1246 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1259class CmpFOpNanKernelPattern final :
public OpConversionPattern<arith::CmpFOp> {
1264 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1265 ConversionPatternRewriter &rewriter)
const override {
1266 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1267 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1272 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1273 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1284class CmpFOpNanNonePattern final :
public OpConversionPattern<arith::CmpFOp> {
1289 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1290 ConversionPatternRewriter &rewriter)
const override {
1291 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1292 op.getPredicate() != arith::CmpFPredicate::UNO)
1295 Location loc = op.getLoc();
1298 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1299 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1301 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1304 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1307 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1308 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1310 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1311 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1312 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1315 rewriter.replaceOp(op, replace);
1325class AddUIExtendedOpPattern final
1326 :
public OpConversionPattern<arith::AddUIExtendedOp> {
1330 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1331 ConversionPatternRewriter &rewriter)
const override {
1332 Type dstElemTy = adaptor.getLhs().getType();
1333 Location loc = op->getLoc();
1334 Value
result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1337 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1339 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1343 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1344 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1346 rewriter.replaceOp(op, {sumResult, carryResult});
1356template <
typename ArithMulOp,
typename SPIRVMulOp>
1357class MulIExtendedOpPattern final :
public OpConversionPattern<ArithMulOp> {
1359 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1361 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1362 ConversionPatternRewriter &rewriter)
const override {
1363 Location loc = op->getLoc();
1365 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1367 Value low = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1369 Value high = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1372 rewriter.replaceOp(op, {low, high});
1382class SelectOpPattern final :
public OpConversionPattern<arith::SelectOp> {
1386 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1387 ConversionPatternRewriter &rewriter)
const override {
1388 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1389 adaptor.getTrueValue(),
1390 adaptor.getFalseValue());
1401template <
typename Op,
typename SPIRVOp>
1402class MinimumMaximumFOpPattern final :
public OpConversionPattern<Op> {
1404 using OpConversionPattern<
Op>::OpConversionPattern;
1406 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1407 ConversionPatternRewriter &rewriter)
const override {
1408 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1409 Type dstType = converter->convertType(op.getType());
1421 Location loc = op.
getLoc();
1423 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1425 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1426 rewriter.replaceOp(op, spirvOp);
1430 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1431 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1433 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1434 adaptor.getLhs(), spirvOp);
1435 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1436 adaptor.getRhs(), select1);
1438 rewriter.replaceOp(op, select2);
1449template <
typename Op,
typename SPIRVOp>
1450class MinNumMaxNumFOpPattern final :
public OpConversionPattern<Op> {
1451 template <
typename TargetOp>
1452 constexpr bool shouldInsertNanGuards()
const {
1453 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1457 using OpConversionPattern<
Op>::OpConversionPattern;
1459 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1460 ConversionPatternRewriter &rewriter)
const override {
1461 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1462 Type dstType = converter->convertType(op.getType());
1477 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1479 if (!shouldInsertNanGuards<SPIRVOp>() ||
1480 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1481 rewriter.replaceOp(op, spirvOp);
1485 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1486 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1488 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1489 adaptor.getRhs(), spirvOp);
1490 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1491 adaptor.getLhs(), select1);
1493 rewriter.replaceOp(op, select2);
1508 ConstantCompositeOpPattern,
1509 ConstantScalarOpPattern,
1510 BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>,
1511 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1512 BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>,
1513 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1514 BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,
1515 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1516 BoolIOpPattern<arith::DivUIOp, spirv::LogicalAndOp>,
1518 BoolIOpPattern<arith::DivSIOp, spirv::LogicalAndOp>,
1520 BoolIOpAndNotPattern<arith::RemUIOp>,
1522 BoolIOpAndNotPattern<arith::RemSIOp>,
1523 RemSIOpGLPattern, RemSIOpCLPattern,
1524 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1525 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1526 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1527 BoolIOpAndNotPattern<arith::ShLIOp>,
1528 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1529 BoolIOpAndNotPattern<arith::ShRUIOp>,
1539 ExtUIPattern, ExtUII1Pattern,
1540 ExtSIPattern, ExtSII1Pattern,
1541 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1542 TruncIPattern, TruncII1Pattern,
1543 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1544 IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
1546 IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
1547 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1548 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1549 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1550 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1551 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1552 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1553 CmpIOpBooleanPattern, CmpIOpPattern,
1554 CmpFOpNanNonePattern, CmpFOpPattern,
1555 AddUIExtendedOpPattern,
1556 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1557 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1560 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1561 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1562 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1563 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1564 BoolIOpPattern<arith::MaxSIOp, spirv::LogicalAndOp>,
1565 BoolIOpPattern<arith::MaxUIOp, spirv::LogicalOrOp>,
1566 BoolIOpPattern<arith::MinSIOp, spirv::LogicalOrOp>,
1567 BoolIOpPattern<arith::MinUIOp, spirv::LogicalAndOp>,
1573 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1574 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1575 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1576 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1586 patterns.
add<CmpFOpNanKernelPattern>(typeConverter, patterns.
getContext(),
1595struct ConvertArithToSPIRVPass
1599 void runOnOperation()
override {
1602 std::unique_ptr<SPIRVConversionTarget>
target =
1606 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1607 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1612 target->addLegalOp<UnrealizedConversionCastOp>();
1615 target->addIllegalDialect<arith::ArithDialect>();
1620 if (failed(applyPartialConversion(op, *
target, std::move(patterns))))
1621 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ConvertArithToSPIRVPassBase Base
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.