21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
34#define DEBUG_TYPE "arith-to-spirv-pattern"
// convertBoolAttr (fragment): produces a BoolAttr from srcAttr.
// The body of the BoolAttr branch is not part of this listing.
// An IntegerAttr is mapped to a BoolAttr via APInt::getBoolValue().
// NOTE(review): the fallback for other attribute kinds is on lines not shown.
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
// convertIntegerAttr (fragment): re-expresses srcAttr with integer type
// dstType.
// It first checks whether the value fits dstType's width (APInt::isIntN).
// Otherwise it retries, treating the value as signed (APInt::isSignedIntN),
// and logs that conversion under LLVM_DEBUG.
// If neither check passes, a debug message reports that the value cannot fit.
// NOTE(review): the return statements of each path are on lines not present
// in this listing.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
// convertFloatAttr (body fragment): converts the source APFloat to IEEE
// single precision, rounding toward zero.
// Any status other than opOK, or any reported loss of information, is
// treated as "cannot fit" and logged under LLVM_DEBUG.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
107 ConversionPatternRewriter &rewriter) {
// Reinterprets the float's raw bit pattern (bitcastToAPInt) as an IntegerAttr
// of dstType. No numeric float-to-int conversion takes place.
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
// isBoolScalarOrVector (fragment): asserts that the type is non-null.
// A vector type qualifies when its element type is i1.
// NOTE(review): the scalar check and the final fallback are on lines not
// present in this listing.
115 assert(type &&
"Not a valid type");
119 if (
auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
// getScalarOrVectorConstInt (fragment): materializes `value` as a
// spirv.Constant of `type`.
// For a vector type, an IntegerAttr of the vector's element type is built
// and turned into `attr`. The construction of `attr` is not shown.
// For an integer type, a scalar constant is created.
128 if (
auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
134 if (
auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
// hasSameBitwidth (fragment): compares the total bit size of two types.
// A scalar int/float contributes its own width.
// A vector contributes element width * element count.
// Presumably any other type leaves `bw` at 0 (its declaration is not shown).
// A zero width never compares equal.
144 auto getNumBitwidth = [](
Type type) {
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (
auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(
b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
// getTypeConversionFailure (fragment): reports a match failure that names
// the source type the type converter could not handle.
161 return rewriter.notifyMatchFailure(
163 llvm::formatv(
"failed to convert source type '{0}'", srcType));
// ElementwiseArithOpPattern: one-to-one lowering of an arith op to SPIRVOp.
// - The op has at most three operands, per the assert.
// - The result type goes through SPIRVTypeConverter; a match failure is
//   reported if it cannot be converted.
// - SPIR-V ops carrying the UnsignedOp trait emit an error when the
//   converted type differs from the original, because bitwidth emulation is
//   not implemented for them. Part of that condition is on a line not shown.
// - When the target env allows SPV_KHR_no_integer_wrap_decoration, the arith
//   nsw/nuw overflow flags become unit attributes on the new op. The
//   attribute names are on lines not shown.
177template <
typename Op,
typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<
Op>::OpConversionPattern;
182 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter)
const override {
184 assert(adaptor.getOperands().size() <= 3);
187 if (!adaptor.getOperands().empty() &&
190 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
191 Type dstType = converter->convertType(op.getType());
193 return rewriter.notifyMatchFailure(
195 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
198 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
200 dstType != op.getType()) {
201 return op.
emitError(
"bitwidth emulation is not implemented yet on "
202 "unsigned op pattern version");
205 auto overflowFlags = arith::IntegerOverflowFlags::none;
206 if (
auto overflowIface =
207 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208 if (converter->getTargetEnv().allows(
209 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210 overflowFlags = overflowIface.getOverflowAttr().getValue();
213 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214 op, dstType, adaptor.getOperands());
216 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
218 rewriter.getUnitAttr());
220 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
222 rewriter.getUnitAttr());
// ConstantCompositeOpPattern: lowers a multi-element vector or ranked-tensor
// arith.constant to spirv.Constant.
// - Non-shaped and single-element constants are not handled here.
// - DenseElementsAttr and DenseResourceElementsAttr values are accepted.
//   Errors are emitted for a missing or invalid resource blob, and for any
//   other elements attribute.
// - Ranked tensors of rank > 1 are reshaped to 1-D.
// - When the converted element type differs from the source element type,
//   elements are converted one by one:
//   - Floats may be re-expressed as integers when
//     emulateUnsupportedFloatTypes is set and the target element type is an
//     integer.
//   - Integers go through a width conversion; the callee is on a line not
//     shown.
// - The destination attribute type is rebuilt as a tensor or vector of the
//   new element type.
233struct ConstantCompositeOpPattern final
234 :
public OpConversionPattern<arith::ConstantOp> {
238 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter)
const override {
240 auto srcType = dyn_cast<ShapedType>(constOp.getType());
241 if (!srcType || srcType.getNumElements() == 1)
246 if (!isa<VectorType, RankedTensorType>(srcType))
247 return rewriter.notifyMatchFailure(constOp,
"unsupported ShapedType");
249 Type dstType = getTypeConverter()->convertType(srcType);
256 if (
auto denseElementsAttr =
257 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258 dstElementsAttr = denseElementsAttr;
259 }
else if (
auto resourceAttr =
260 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
264 return constOp->emitError(
"could not find resource blob");
271 return constOp->emitError(
"resource is not a valid buffer");
276 return constOp->emitError(
"unsupported elements attribute");
279 ShapedType dstAttrType = dstElementsAttr.
getType();
283 if (srcType.getRank() > 1) {
284 if (isa<RankedTensorType>(srcType)) {
285 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
286 srcType.getElementType());
287 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
294 Type srcElemType = srcType.getElementType();
298 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
299 dstElemType = arrayType.getElementType();
301 dstElemType = cast<VectorType>(dstType).getElementType();
305 if (srcElemType != dstElemType) {
307 if (isa<FloatType>(srcElemType)) {
308 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
311 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
312 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
314 isa<IntegerType>(dstElemType)) {
323 elements.push_back(dstAttr);
328 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
330 srcAttr, cast<IntegerType>(dstElemType), rewriter);
333 elements.push_back(dstAttr);
341 if (isa<RankedTensorType>(dstAttrType))
343 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
345 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
350 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
// ConstantScalarOpPattern: lowers a scalar (or single-element shaped)
// arith.constant to spirv.Constant.
// - Shaped types with more than one element are presumably rejected; the
//   return is on a line not shown.
// - A DenseElementsAttr value is replaced by its splat value.
// - Float constants may be re-expressed as integer bit patterns when
//   emulateUnsupportedFloatTypes is set and the converted type is 8 bits wide.
// - Otherwise, a float whose type changed is converted (details not shown).
// - Integer constants are converted to the destination integer type on
//   lines not shown.
357struct ConstantScalarOpPattern final
358 :
public OpConversionPattern<arith::ConstantOp> {
362 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override {
364 Type srcType = constOp.getType();
365 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
366 if (shapedType.getNumElements() != 1)
368 srcType = shapedType.getElementType();
373 Attribute cstAttr = constOp.getValue();
374 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
375 cstAttr = elementsAttr.getSplatValue<Attribute>();
377 Type dstType = getTypeConverter()->convertType(srcType);
382 if (isa<FloatType>(srcType)) {
383 auto srcAttr = cast<FloatAttr>(cstAttr);
384 Attribute dstAttr = srcAttr;
388 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
389 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
391 dstType.getIntOrFloatBitWidth() == 8) {
396 }
else if (srcType != dstType) {
402 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
413 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
419 auto srcAttr = cast<IntegerAttr>(cstAttr);
420 IntegerAttr dstAttr =
424 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
// emulateSignedRemainder: computes a signed remainder as |lhs| umod |rhs|,
// using SignedAbsOp for the absolute values and spirv.UMod for the remainder.
// The result's sign then follows `signOperand`, which must be lhs or rhs
// (see the asserts). The remainder is negated whenever signOperand differs
// from its own absolute value.
440template <
typename SignedAbsOp>
443 assert(
lhs.getType() ==
rhs.getType());
444 assert(
lhs == signOperand ||
rhs == signOperand);
449 Value lhsAbs = SignedAbsOp::create(builder, loc, type,
lhs);
450 Value rhsAbs = SignedAbsOp::create(builder, loc, type,
rhs);
451 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
455 if (
lhs == signOperand)
456 isPositive = spirv::IEqualOp::create(builder, loc,
lhs, lhsAbs);
458 isPositive = spirv::IEqualOp::create(builder, loc,
rhs, rhsAbs);
459 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
460 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
468struct RemSIOpGLPattern final :
public OpConversionPattern<arith::RemSIOp> {
472 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
473 ConversionPatternRewriter &rewriter)
const override {
474 Value
result = emulateSignedRemainder<spirv::CLSAbsOp>(
475 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
476 adaptor.getOperands()[0], rewriter);
477 rewriter.replaceOp(op,
result);
484struct RemSIOpCLPattern final :
public OpConversionPattern<arith::RemSIOp> {
488 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter)
const override {
490 Value
result = emulateSignedRemainder<spirv::GLSAbsOp>(
491 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
492 adaptor.getOperands()[0], rewriter);
493 rewriter.replaceOp(op,
result);
// BitwiseOpPattern: lowers a binary arith bitwise op either to SPIRVLogicalOp
// or to SPIRVBitwiseOp, always with the converted result type.
// NOTE(review): the condition that selects between the two is on lines not
// shown. Presumably it checks for boolean operands; confirm.
507template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
508struct BitwiseOpPattern final :
public OpConversionPattern<Op> {
509 using OpConversionPattern<
Op>::OpConversionPattern;
512 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
513 ConversionPatternRewriter &rewriter)
const override {
514 assert(adaptor.getOperands().size() == 2);
515 Type dstType = this->getTypeConverter()->convertType(op.getType());
520 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
521 op, dstType, adaptor.getOperands());
523 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
524 op, dstType, adaptor.getOperands());
// XOrIOpLogicalPattern: lowers arith.xori to spirv.BitwiseXor with the
// converted result type.
// NOTE(review): the guard restricting which operand types this pattern
// accepts is on lines not shown.
535struct XOrIOpLogicalPattern final :
public OpConversionPattern<arith::XOrIOp> {
539 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter)
const override {
541 assert(adaptor.getOperands().size() == 2);
546 Type dstType = getTypeConverter()->convertType(op.getType());
550 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
551 adaptor.getOperands());
// XOrIOpBooleanPattern: lowers arith.xori to spirv.LogicalNotEqual, since
// "not equal" is xor on booleans.
// NOTE(review): the boolean-type guard is on lines not shown.
559struct XOrIOpBooleanPattern final :
public OpConversionPattern<arith::XOrIOp> {
563 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
564 ConversionPatternRewriter &rewriter)
const override {
565 assert(adaptor.getOperands().size() == 2);
570 Type dstType = getTypeConverter()->convertType(op.getType());
574 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
575 op, dstType, adaptor.getOperands());
// BoolIOpPattern: lowers an integer arith op on booleans directly to the
// logical SPIR-V op SPIRVOp.
// It is constructed with benefit 2, so it is tried before same-op patterns
// registered with the default benefit.
597 // NOTE(review): the boolean-type guard is on lines not shown.
595template <
typename ArithOp,
typename SPIRVOp>
596struct BoolIOpPattern final :
public OpConversionPattern<ArithOp> {
597 BoolIOpPattern(
const TypeConverter &converter, MLIRContext *context)
600 : OpConversionPattern<ArithOp>(converter, context, 2) {}
603 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
604 ConversionPatternRewriter &rewriter)
const override {
608 Type dstType = this->getTypeConverter()->convertType(op.getType());
612 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
// BoolIOpAndNotPattern: lowers an arith op on booleans to
// `lhs AND (NOT rhs)`, using spirv.LogicalNot followed by spirv.LogicalAnd.
// It is constructed with benefit 2.
// NOTE(review): the boolean-type guard is on lines not shown.
626template <
typename ArithOp>
627struct BoolIOpAndNotPattern final :
public OpConversionPattern<ArithOp> {
628 BoolIOpAndNotPattern(
const TypeConverter &converter, MLIRContext *context)
631 : OpConversionPattern<ArithOp>(converter, context, 2) {}
634 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
635 ConversionPatternRewriter &rewriter)
const override {
639 Type dstType = this->getTypeConverter()->convertType(op.getType());
643 Location loc = op.getLoc();
644 Value notRhs = spirv::LogicalNotOp::create(rewriter, loc, dstType,
645 adaptor.getOperands()[1]);
646 rewriter.replaceOpWithNewOp<spirv::LogicalAndOp>(
647 op, dstType, adaptor.getOperands()[0], notRhs);
// ShRSIBoolPattern: replaces arith.shrsi with its first operand, i.e. the
// value being shifted. The benefit argument and the type guard are on lines
// not shown.
654struct ShRSIBoolPattern final :
public OpConversionPattern<arith::ShRSIOp> {
655 ShRSIBoolPattern(
const TypeConverter &converter, MLIRContext *context)
658 : OpConversionPattern<arith::ShRSIOp>(converter, context,
662 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter)
const override {
667 rewriter.replaceOp(op, adaptor.getOperands().front());
// UIToFPI1Pattern: lowers arith.uitofp by selecting between the constants
// one and zero of the converted result type, keyed on the operand.
// NOTE(review): the check on srcType (presumably i1) is on lines not shown.
678struct UIToFPI1Pattern final :
public OpConversionPattern<arith::UIToFPOp> {
682 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
683 ConversionPatternRewriter &rewriter)
const override {
684 Type srcType = adaptor.getOperands().front().getType();
688 Type dstType = getTypeConverter()->convertType(op.getType());
692 Location loc = op.getLoc();
693 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696 op, dstType, adaptor.getOperands().front(), one, zero);
// IntToFPPattern: lowers an int-to-float arith op to SPIRVOp.
// - If the original bitwidth is >= the converted bitwidth, the operand is
//   converted directly.
// - Otherwise the operand value is first "cleaned" in srcType before
//   conversion:
//   - Signed: shift left, then arithmetic-shift right, by
//     (convertedBitwidth - originalBitwidth) bits.
//   - Unsigned: AND with a mask of the low originalBitwidth bits.
// NOTE(review): how originalBitwidth and convertedBitwidth are computed is
// on lines not shown.
708template <
typename ArithOp,
typename SPIRVOp,
bool IsSigned>
709struct IntToFPPattern final :
public OpConversionPattern<ArithOp> {
710 using OpConversionPattern<ArithOp>::OpConversionPattern;
713 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
714 ConversionPatternRewriter &rewriter)
const override {
715 Type srcType = adaptor.getOperands().front().getType();
719 Type dstType = this->getTypeConverter()->convertType(op.getType());
724 unsigned originalBitwidth =
726 unsigned convertedBitwidth =
729 if (originalBitwidth >= convertedBitwidth) {
730 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
735 Location loc = op.getLoc();
737 if constexpr (IsSigned) {
739 unsigned shiftAmount = convertedBitwidth - originalBitwidth;
742 Value shifted = spirv::ShiftLeftLogicalOp::create(
743 rewriter, loc, srcType, adaptor.getIn(), shiftSize);
744 cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
749 srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter,
751 cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
752 adaptor.getIn(), mask);
754 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned);
// IndexCastIndexI1Pattern: lowers an arith.index_cast to i1 as a comparison,
// `0 != in`, using spirv.INotEqual against a zero of the input's type.
764struct IndexCastIndexI1Pattern final
765 :
public OpConversionPattern<arith::IndexCastOp> {
769 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
770 ConversionPatternRewriter &rewriter)
const override {
774 Type dstType = getTypeConverter()->convertType(op.getType());
778 Location loc = op.getLoc();
780 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
781 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
// IndexCastI1IndexPattern: lowers an arith.index_cast from i1 by selecting
// between the constants one and zero of the converted result type.
788struct IndexCastI1IndexPattern final
789 :
public OpConversionPattern<arith::IndexCastOp> {
793 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
794 ConversionPatternRewriter &rewriter)
const override {
798 Type dstType = getTypeConverter()->convertType(op.getType());
802 Location loc = op.getLoc();
803 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
804 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
805 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
// ExtSII1Pattern: lowers arith.extsi from a boolean operand as
// select(operand, all-ones, zero). This is sign extension of a single bit.
// The all-ones constant is built for scalar integer or vector results; any
// other converted type reports a match failure.
817struct ExtSII1Pattern final :
public OpConversionPattern<arith::ExtSIOp> {
821 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
822 ConversionPatternRewriter &rewriter)
const override {
823 Value operand = adaptor.getIn();
827 Location loc = op.getLoc();
828 Type dstType = getTypeConverter()->convertType(op.getType());
833 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
834 unsigned componentBitwidth = intTy.getWidth();
835 allOnes = spirv::ConstantOp::create(
836 rewriter, loc, intTy,
837 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
838 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
839 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
840 allOnes = spirv::ConstantOp::create(
841 rewriter, loc, vectorTy,
842 SplatElementsAttr::get(vectorTy,
843 APInt::getAllOnes(componentBitwidth)));
845 return rewriter.notifyMatchFailure(
846 loc, llvm::formatv(
"unhandled type: {0}", dstType));
849 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
850 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
// ExtSIPattern: lowers arith.extsi.
// - If type conversion made the source and destination types identical (the
//   narrow type is emulated in the wider one), the value is sign-extended in
//   place: shift left, then arithmetic-shift right by `shiftSize`. That
//   amount is computed on lines not shown.
// - Otherwise a plain spirv.SConvert is emitted.
858struct ExtSIPattern final :
public OpConversionPattern<arith::ExtSIOp> {
862 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
863 ConversionPatternRewriter &rewriter)
const override {
864 Type srcType = adaptor.getIn().getType();
868 Type dstType = getTypeConverter()->convertType(op.getType());
872 if (dstType == srcType) {
880 assert(srcBW < dstBW);
882 rewriter, op.getLoc());
884 return rewriter.notifyMatchFailure(op,
"unsupported type for shift");
889 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
890 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
894 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
895 op, dstType, shiftLOp, shiftSize);
897 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
898 adaptor.getOperands());
// ExtUII1Pattern: lowers arith.extui from a boolean operand by selecting
// between the constants one and zero of the converted result type.
// NOTE(review): the srcType guard is on lines not shown.
911struct ExtUII1Pattern final :
public OpConversionPattern<arith::ExtUIOp> {
915 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
916 ConversionPatternRewriter &rewriter)
const override {
917 Type srcType = adaptor.getOperands().front().getType();
921 Type dstType = getTypeConverter()->convertType(op.getType());
925 Location loc = op.getLoc();
926 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
927 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
928 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
929 op, dstType, adaptor.getOperands().front(), one, zero);
// ExtUIPattern: lowers arith.extui.
// - If type conversion made the source and destination types identical, the
//   value is zero-extended in place by ANDing with a mask of the low
//   `bitwidth` bits. `bitwidth` is computed on lines not shown.
// - Otherwise a plain spirv.UConvert is emitted.
936struct ExtUIPattern final :
public OpConversionPattern<arith::ExtUIOp> {
940 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
941 ConversionPatternRewriter &rewriter)
const override {
942 Type srcType = adaptor.getIn().getType();
946 Type dstType = getTypeConverter()->convertType(op.getType());
950 if (dstType == srcType) {
958 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
961 return rewriter.notifyMatchFailure(op,
"unsupported type for mask");
962 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
963 adaptor.getIn(), mask);
965 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
966 adaptor.getOperands());
// TruncII1Pattern: lowers arith.trunci to a boolean result by testing the
// lowest bit.
// It computes (src & 1) == 1 and then selects between the constants one and
// zero of the converted result type.
978struct TruncII1Pattern final :
public OpConversionPattern<arith::TruncIOp> {
982 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
983 ConversionPatternRewriter &rewriter)
const override {
984 Type dstType = getTypeConverter()->convertType(op.getType());
991 Location loc = op.getLoc();
992 auto srcType = adaptor.getOperands().front().getType();
994 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
995 Value maskedSrc = spirv::BitwiseAndOp::create(
996 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
997 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
999 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
1000 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
1001 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
// TruncIPattern: lowers arith.trunci.
// - If type conversion made the source and destination types identical, the
//   value is truncated in place by ANDing with a mask of the low `bw` bits.
//   `bw` is computed on lines not shown.
// - Otherwise a spirv.SConvert is emitted.
1008struct TruncIPattern final :
public OpConversionPattern<arith::TruncIOp> {
1012 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1013 ConversionPatternRewriter &rewriter)
const override {
1014 Type srcType = adaptor.getIn().getType();
1015 Type dstType = getTypeConverter()->convertType(op.getType());
1022 if (dstType == srcType) {
1029 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
1031 return rewriter.notifyMatchFailure(op,
"unsupported type for mask");
1032 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
1033 adaptor.getIn(), mask);
1036 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
1037 adaptor.getOperands());
// Maps an arith rounding mode to the matching SPIR-V FPRoundingMode:
//   downward -> RTN, to_nearest_even -> RTE, toward_zero -> RTZ,
//   upward -> RTP.
// Returns std::nullopt for modes without a mapping (to_nearest_away).
1047static std::optional<spirv::FPRoundingMode>
1048convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
1049 switch (roundingMode) {
1050 case arith::RoundingMode::downward:
1051 return spirv::FPRoundingMode::RTN;
1052 case arith::RoundingMode::to_nearest_even:
1053 return spirv::FPRoundingMode::RTE;
1054 case arith::RoundingMode::toward_zero:
1055 return spirv::FPRoundingMode::RTZ;
1056 case arith::RoundingMode::upward:
1057 return spirv::FPRoundingMode::RTP;
1058 case arith::RoundingMode::to_nearest_away:
// No FPRoundingMode value is mapped for ties-away-from-zero, so callers
// receive std::nullopt and must treat the mode as unsupported.
1061 return std::nullopt;
1063 llvm_unreachable(
"Unhandled rounding mode");
// TypeCastingOpPattern: lowers a single-operand arith cast op to SPIRVOp.
// - If the converted result type equals the operand type, the operand is
//   forwarded unchanged.
// - Otherwise, if the op implements ArithRoundingModeInterface and carries a
//   rounding mode, that mode is translated with
//   convertArithRoundingModeToSPIRV. Unsupported modes report a match
//   failure.
// - The translated mode is attached to the new op as an FPRoundingModeAttr.
//   The attribute name is on lines not shown.
1067template <
typename Op,
typename SPIRVOp>
1068struct TypeCastingOpPattern final :
public OpConversionPattern<Op> {
1069 using OpConversionPattern<
Op>::OpConversionPattern;
1072 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1073 ConversionPatternRewriter &rewriter)
const override {
1074 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
1075 Type dstType = this->getTypeConverter()->convertType(op.getType());
1082 if (dstType == srcType) {
1085 rewriter.replaceOp(op, adaptor.getOperands().front());
1088 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
1089 if (
auto roundingModeOp =
1090 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
1091 if (arith::RoundingModeAttr roundingMode =
1092 roundingModeOp.getRoundingModeAttr()) {
1094 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
1095 return rewriter.notifyMatchFailure(
1097 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
1102 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
1103 op, dstType, adaptor.getOperands());
1107 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
// CmpIOpBooleanPattern: lowers arith.cmpi on boolean operands.
// - eq maps to spirv.LogicalEqual.
// - ne maps to spirv.LogicalNotEqual.
// - The unsigned orderings (uge/ugt/ule/ult) zero-extend both operands to i32
//   (or a vector of i32 with the same shape) via arith.extui and re-emit
//   arith.cmpi on the widened values.
// NOTE(review): the boolean-type guard and the remaining cases are on lines
// not shown.
1119class CmpIOpBooleanPattern final :
public OpConversionPattern<arith::CmpIOp> {
1124 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1125 ConversionPatternRewriter &rewriter)
const override {
1126 Type srcType = op.getLhs().getType();
1129 Type dstType = getTypeConverter()->convertType(srcType);
1133 switch (op.getPredicate()) {
1134 case arith::CmpIPredicate::eq: {
1135 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
1139 case arith::CmpIPredicate::ne: {
1140 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
1141 op, adaptor.getLhs(), adaptor.getRhs());
1144 case arith::CmpIPredicate::uge:
1145 case arith::CmpIPredicate::ugt:
1146 case arith::CmpIPredicate::ule:
1147 case arith::CmpIPredicate::ult: {
1150 Type type = rewriter.getI32Type();
1151 if (
auto vectorType = dyn_cast<VectorType>(dstType))
1152 type = VectorType::get(vectorType.getShape(), type);
1154 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1156 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1158 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
// CmpIOpPattern: lowers arith.cmpi to the matching SPIR-V integer compare.
// The DISPATCH macro defines one `case` per predicate. For SPIR-V ops with
// the UnsignedOp trait, it emits an error if the compare would need bitwidth
// emulation, i.e. when all of the following hold:
//   - the source is not index typed,
//   - the source type changed under conversion, and
//   - the converted type has a different total bitwidth.
// NOTE: the macro body continues across lines via trailing backslashes, so
// no comments are placed inside it.
1170class CmpIOpPattern final :
public OpConversionPattern<arith::CmpIOp> {
1175 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter)
const override {
1177 Type srcType = op.getLhs().getType();
1180 Type dstType = getTypeConverter()->convertType(srcType);
1184 switch (op.getPredicate()) {
1185#define DISPATCH(cmpPredicate, spirvOp) \
1186 case cmpPredicate: \
1187 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1188 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1189 !hasSameBitwidth(srcType, dstType)) { \
1190 return op.emitError( \
1191 "bitwidth emulation is not implemented yet on unsigned op"); \
1193 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1194 adaptor.getRhs()); \
1197 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1198 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1199 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1200 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1201 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1202 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1203 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1204 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1205 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1206 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
// CmpFOpPattern: lowers arith.cmpf to SPIR-V float compares.
// The O* predicates map to spirv.FOrd* ops and the U* predicates map to
// spirv.FUnord* ops. ORD and UNO are not dispatched here.
// NOTE: the DISPATCH macro body continues across lines via trailing
// backslashes, so no comments are placed inside it.
1219class CmpFOpPattern final :
public OpConversionPattern<arith::CmpFOp> {
1224 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1225 ConversionPatternRewriter &rewriter)
const override {
1226 switch (op.getPredicate()) {
1227#define DISPATCH(cmpPredicate, spirvOp) \
1228 case cmpPredicate: \
1229 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1230 adaptor.getRhs()); \
1234 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1235 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1236 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1237 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1238 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1239 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1241 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1242 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1243 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1244 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1245 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1246 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
// CmpFOpNanKernelPattern: lowers arith.cmpf ORD to spirv.Ordered and
// arith.cmpf UNO to spirv.Unordered.
1259class CmpFOpNanKernelPattern final :
public OpConversionPattern<arith::CmpFOp> {
1264 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1265 ConversionPatternRewriter &rewriter)
const override {
1266 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1267 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1272 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1273 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
// CmpFOpNanNonePattern: lowers arith.cmpf ORD/UNO without the
// Ordered/Unordered ops.
// - With the `nnan` fast-math flag, the result is a constant: true for ORD,
//   false for UNO.
// - Otherwise UNO = isnan(lhs) || isnan(rhs), and ORD is its logical
//   negation.
1284class CmpFOpNanNonePattern final :
public OpConversionPattern<arith::CmpFOp> {
1289 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1290 ConversionPatternRewriter &rewriter)
const override {
1291 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1292 op.getPredicate() != arith::CmpFPredicate::UNO)
1295 Location loc = op.getLoc();
1298 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1299 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1301 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1304 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1307 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1308 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1310 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1311 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1312 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1315 rewriter.replaceOp(op, replace);
// BinaryExtendedOpPattern: lowers an arith add/sub-with-carry op to
// SPIRVExtendedOp.
// Two members are extracted from the composite result: the value, and a flag
// word. The flag is compared against 1 (in the lhs type) to produce the
// boolean carry/borrow result.
// NOTE(review): the extraction indices are on lines not shown; presumably
// 0 and 1.
1326template <
typename ArithExtendedOp,
typename SPIRVExtendedOp>
1327class BinaryExtendedOpPattern final
1328 :
public OpConversionPattern<ArithExtendedOp> {
1330 using OpConversionPattern<ArithExtendedOp>::OpConversionPattern;
1332 matchAndRewrite(ArithExtendedOp op,
typename ArithExtendedOp::Adaptor adaptor,
1333 ConversionPatternRewriter &rewriter)
const override {
1334 Type dstElemTy = adaptor.getLhs().getType();
1335 Location loc = op->getLoc();
1336 Value
result = SPIRVExtendedOp::create(rewriter, loc, adaptor.getLhs(),
1339 Value valueResult = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1341 Value flagValue = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1345 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1346 Value flagResult = spirv::IEqualOp::create(rewriter, loc, flagValue, one);
1348 rewriter.replaceOp(op, {valueResult, flagResult});
// MulIExtendedOpPattern: lowers an arith extended multiply to SPIRVMulOp and
// splits the composite result into its low and high halves. These become
// the op's two results.
// NOTE(review): the extraction indices are on lines not shown.
1358template <
typename ArithMulOp,
typename SPIRVMulOp>
1359class MulIExtendedOpPattern final :
public OpConversionPattern<ArithMulOp> {
1361 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1363 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1364 ConversionPatternRewriter &rewriter)
const override {
1365 Location loc = op->getLoc();
1367 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1369 Value low = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1371 Value high = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1374 rewriter.replaceOp(op, {low, high});
// SelectOpPattern: lowers arith.select directly to spirv.Select with the
// same condition, true value, and false value.
1384class SelectOpPattern final :
public OpConversionPattern<arith::SelectOp> {
1388 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1389 ConversionPatternRewriter &rewriter)
const override {
1390 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1391 adaptor.getTrueValue(),
1392 adaptor.getFalseValue());
// MinimumMaximumFOpPattern: lowers arith.minimumf/maximumf to SPIRVOp.
// - With the `nnan` fast-math flag, SPIRVOp's result is used as-is.
// - Otherwise NaN is propagated explicitly: if lhs is NaN the result is lhs,
//   else if rhs is NaN the result is rhs, else SPIRVOp's result.
1403template <
typename Op,
typename SPIRVOp>
1404class MinimumMaximumFOpPattern final :
public OpConversionPattern<Op> {
1406 using OpConversionPattern<
Op>::OpConversionPattern;
1408 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1409 ConversionPatternRewriter &rewriter)
const override {
1410 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1411 Type dstType = converter->convertType(op.getType());
1423 Location loc = op.
getLoc();
1425 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1427 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1428 rewriter.replaceOp(op, spirvOp);
1432 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1433 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1435 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1436 adaptor.getLhs(), spirvOp);
1437 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1438 adaptor.getRhs(), select1);
1440 rewriter.replaceOp(op, select2);
// MinNumMaxNumFOpPattern: lowers arith.minnumf/maxnumf to SPIRVOp.
// - NaN guards are added only when SPIRVOp is spirv.GL.FMax or spirv.GL.FMin
//   (see shouldInsertNanGuards), and only without the `nnan` flag.
// - The guards return the non-NaN operand: if lhs is NaN the result is rhs,
//   else if rhs is NaN the result is lhs, else SPIRVOp's result.
1451template <
typename Op,
typename SPIRVOp>
1452class MinNumMaxNumFOpPattern final :
public OpConversionPattern<Op> {
1453 template <
typename TargetOp>
1454 constexpr bool shouldInsertNanGuards()
const {
1455 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1459 using OpConversionPattern<
Op>::OpConversionPattern;
1461 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1462 ConversionPatternRewriter &rewriter)
const override {
1463 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1464 Type dstType = converter->convertType(op.getType());
1479 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1481 if (!shouldInsertNanGuards<SPIRVOp>() ||
1482 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1483 rewriter.replaceOp(op, spirvOp);
1487 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1488 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1490 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1491 adaptor.getRhs(), spirvOp);
1492 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1493 adaptor.getLhs(), select1);
1495 rewriter.replaceOp(op, select2);
// Pattern registration list (fragment). Each entry is one conversion pattern.
// - For several ops, a BoolIOpPattern or BoolIOpAndNotPattern (both
//   constructed with benefit 2) sits next to the general pattern for the
//   same op.
// - Both GL and CL variants of the float min/max patterns are listed.
// - CmpFOpNanKernelPattern is added in a separate call; its trailing
//   constructor argument is on a line not shown.
1510 ConstantCompositeOpPattern,
1511 ConstantScalarOpPattern,
1512 BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>,
1513 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1514 BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>,
1515 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1516 BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,
1517 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1518 BoolIOpPattern<arith::DivUIOp, spirv::LogicalAndOp>,
1520 BoolIOpPattern<arith::DivSIOp, spirv::LogicalAndOp>,
1522 BoolIOpAndNotPattern<arith::RemUIOp>,
1524 BoolIOpAndNotPattern<arith::RemSIOp>,
1525 RemSIOpGLPattern, RemSIOpCLPattern,
1526 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1527 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1528 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1529 BoolIOpAndNotPattern<arith::ShLIOp>,
1530 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1531 BoolIOpAndNotPattern<arith::ShRUIOp>,
1541 ExtUIPattern, ExtUII1Pattern,
1542 ExtSIPattern, ExtSII1Pattern,
1543 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1544 TruncIPattern, TruncII1Pattern,
1545 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1546 IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
1548 IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
1549 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1550 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1551 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1552 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1553 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1554 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1555 CmpIOpBooleanPattern, CmpIOpPattern,
1556 CmpFOpNanNonePattern, CmpFOpPattern,
1557 BinaryExtendedOpPattern<arith::AddUIExtendedOp, spirv::IAddCarryOp>,
1558 BinaryExtendedOpPattern<arith::SubUIExtendedOp, spirv::ISubBorrowOp>,
1559 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1560 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1563 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1564 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1565 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1566 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1567 BoolIOpPattern<arith::MaxSIOp, spirv::LogicalAndOp>,
1568 BoolIOpPattern<arith::MaxUIOp, spirv::LogicalOrOp>,
1569 BoolIOpPattern<arith::MinSIOp, spirv::LogicalOrOp>,
1570 BoolIOpPattern<arith::MinUIOp, spirv::LogicalAndOp>,
1576 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1577 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1578 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1579 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1589 patterns.
add<CmpFOpNanKernelPattern>(typeConverter, patterns.
getContext(),
// ConvertArithToSPIRVPass: driver for the arith-to-SPIR-V conversion.
// runOnOperation does the following:
//   1. Builds a SPIRVConversionTarget.
//   2. Copies the emulateLT32BitScalarTypes and emulateUnsupportedFloatTypes
//      pass options into `options` (presumably the type-converter options;
//      its declaration is not shown).
//   3. Keeps UnrealizedConversionCastOp legal.
//   4. Marks the whole arith dialect illegal.
//   5. Runs a partial conversion, signalling pass failure if it fails.
1598struct ConvertArithToSPIRVPass
1599 :
public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1602 void runOnOperation()
override {
1605 std::unique_ptr<SPIRVConversionTarget>
target =
1609 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1610 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1615 target->addLegalOp<UnrealizedConversionCastOp>();
1618 target->addIllegalDialect<arith::ArithDialect>();
1623 if (failed(applyPartialConversion(op, *
target, std::move(patterns))))
1624 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.