21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
34#define DEBUG_TYPE "arith-to-spirv-pattern"
45 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
47 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
68 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
69 << dstAttr <<
"' for type '" << dstType <<
"'\n");
73 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
74 <<
"' illegal: cannot fit into target type '"
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo =
false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr <<
" illegal: cannot fit into converted type '"
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
115 assert(type &&
"Not a valid type");
119 if (
auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
128 if (
auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
134 if (
auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
144 auto getNumBitwidth = [](
Type type) {
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (
auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(
b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
161 return rewriter.notifyMatchFailure(
163 llvm::formatv(
"failed to convert source type '{0}'", srcType));
175 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
182template <
typename Op,
typename SPIRVOp>
183struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
184 using OpConversionPattern<
Op>::OpConversionPattern;
187 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
188 ConversionPatternRewriter &rewriter)
const override {
189 assert(adaptor.getOperands().size() <= 3);
190 auto converter = this->
template getTypeConverter<SPIRVTypeConverter>();
191 Type dstType = converter->convertType(op.getType());
193 return rewriter.notifyMatchFailure(
195 llvm::formatv(
"failed to convert type {0} for SPIR-V", op.getType()));
198 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
200 dstType != op.getType()) {
201 return op.
emitError(
"bitwidth emulation is not implemented yet on "
202 "unsigned op pattern version");
205 auto overflowFlags = arith::IntegerOverflowFlags::none;
206 if (
auto overflowIface =
207 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208 if (converter->getTargetEnv().allows(
209 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210 overflowFlags = overflowIface.getOverflowAttr().getValue();
213 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214 op, dstType, adaptor.getOperands());
216 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
218 rewriter.getUnitAttr());
220 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
222 rewriter.getUnitAttr());
233struct ConstantCompositeOpPattern final
234 :
public OpConversionPattern<arith::ConstantOp> {
238 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter)
const override {
240 auto srcType = dyn_cast<ShapedType>(constOp.getType());
241 if (!srcType || srcType.getNumElements() == 1)
246 if (!isa<VectorType, RankedTensorType>(srcType))
247 return rewriter.notifyMatchFailure(constOp,
"unsupported ShapedType");
249 Type dstType = getTypeConverter()->convertType(srcType);
256 if (
auto denseElementsAttr =
257 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258 dstElementsAttr = denseElementsAttr;
259 }
else if (
auto resourceAttr =
260 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
264 return constOp->emitError(
"could not find resource blob");
270 bool detectedSplat =
false;
272 return constOp->emitError(
"resource is not a valid buffer");
277 return constOp->emitError(
"unsupported elements attribute");
280 ShapedType dstAttrType = dstElementsAttr.
getType();
284 if (srcType.getRank() > 1) {
285 if (isa<RankedTensorType>(srcType)) {
286 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
287 srcType.getElementType());
288 dstElementsAttr = dstElementsAttr.
reshape(dstAttrType);
295 Type srcElemType = srcType.getElementType();
299 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
300 dstElemType = arrayType.getElementType();
302 dstElemType = cast<VectorType>(dstType).getElementType();
306 if (srcElemType != dstElemType) {
308 if (isa<FloatType>(srcElemType)) {
309 for (FloatAttr srcAttr : dstElementsAttr.
getValues<FloatAttr>()) {
312 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
313 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
315 isa<IntegerType>(dstElemType)) {
324 elements.push_back(dstAttr);
329 for (IntegerAttr srcAttr : dstElementsAttr.
getValues<IntegerAttr>()) {
331 srcAttr, cast<IntegerType>(dstElemType), rewriter);
334 elements.push_back(dstAttr);
342 if (isa<RankedTensorType>(dstAttrType))
344 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
346 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
351 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
358struct ConstantScalarOpPattern final
359 :
public OpConversionPattern<arith::ConstantOp> {
363 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
364 ConversionPatternRewriter &rewriter)
const override {
365 Type srcType = constOp.getType();
366 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
367 if (shapedType.getNumElements() != 1)
369 srcType = shapedType.getElementType();
374 Attribute cstAttr = constOp.getValue();
375 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
376 cstAttr = elementsAttr.getSplatValue<Attribute>();
378 Type dstType = getTypeConverter()->convertType(srcType);
383 if (isa<FloatType>(srcType)) {
384 auto srcAttr = cast<FloatAttr>(cstAttr);
385 Attribute dstAttr = srcAttr;
389 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
390 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
392 dstType.getIntOrFloatBitWidth() == 8) {
397 }
else if (srcType != dstType) {
403 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
414 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
420 auto srcAttr = cast<IntegerAttr>(cstAttr);
421 IntegerAttr dstAttr =
425 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
441template <
typename SignedAbsOp>
444 assert(
lhs.getType() ==
rhs.getType());
445 assert(
lhs == signOperand ||
rhs == signOperand);
450 Value lhsAbs = SignedAbsOp::create(builder, loc, type,
lhs);
451 Value rhsAbs = SignedAbsOp::create(builder, loc, type,
rhs);
452 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
456 if (
lhs == signOperand)
457 isPositive = spirv::IEqualOp::create(builder, loc,
lhs, lhsAbs);
459 isPositive = spirv::IEqualOp::create(builder, loc,
rhs, rhsAbs);
460 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
461 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
469struct RemSIOpGLPattern final :
public OpConversionPattern<arith::RemSIOp> {
473 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
474 ConversionPatternRewriter &rewriter)
const override {
475 Value
result = emulateSignedRemainder<spirv::CLSAbsOp>(
476 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
477 adaptor.getOperands()[0], rewriter);
478 rewriter.replaceOp(op,
result);
485struct RemSIOpCLPattern final :
public OpConversionPattern<arith::RemSIOp> {
489 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
490 ConversionPatternRewriter &rewriter)
const override {
491 Value
result = emulateSignedRemainder<spirv::GLSAbsOp>(
492 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
493 adaptor.getOperands()[0], rewriter);
494 rewriter.replaceOp(op,
result);
508template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
509struct BitwiseOpPattern final :
public OpConversionPattern<Op> {
510 using OpConversionPattern<
Op>::OpConversionPattern;
513 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
514 ConversionPatternRewriter &rewriter)
const override {
515 assert(adaptor.getOperands().size() == 2);
516 Type dstType = this->getTypeConverter()->convertType(op.getType());
521 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
522 op, dstType, adaptor.getOperands());
524 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
525 op, dstType, adaptor.getOperands());
536struct XOrIOpLogicalPattern final :
public OpConversionPattern<arith::XOrIOp> {
540 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
541 ConversionPatternRewriter &rewriter)
const override {
542 assert(adaptor.getOperands().size() == 2);
547 Type dstType = getTypeConverter()->convertType(op.getType());
551 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
552 adaptor.getOperands());
560struct XOrIOpBooleanPattern final :
public OpConversionPattern<arith::XOrIOp> {
564 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
565 ConversionPatternRewriter &rewriter)
const override {
566 assert(adaptor.getOperands().size() == 2);
571 Type dstType = getTypeConverter()->convertType(op.getType());
575 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
576 op, dstType, adaptor.getOperands());
587struct UIToFPI1Pattern final :
public OpConversionPattern<arith::UIToFPOp> {
591 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
592 ConversionPatternRewriter &rewriter)
const override {
593 Type srcType = adaptor.getOperands().front().getType();
597 Type dstType = getTypeConverter()->convertType(op.getType());
601 Location loc = op.getLoc();
602 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
603 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
604 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
605 op, dstType, adaptor.getOperands().front(), one, zero);
615struct IndexCastIndexI1Pattern final
616 :
public OpConversionPattern<arith::IndexCastOp> {
620 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
621 ConversionPatternRewriter &rewriter)
const override {
625 Type dstType = getTypeConverter()->convertType(op.getType());
629 Location loc = op.getLoc();
631 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
632 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
639struct IndexCastI1IndexPattern final
640 :
public OpConversionPattern<arith::IndexCastOp> {
644 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
645 ConversionPatternRewriter &rewriter)
const override {
649 Type dstType = getTypeConverter()->convertType(op.getType());
653 Location loc = op.getLoc();
654 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
655 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
656 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
668struct ExtSII1Pattern final :
public OpConversionPattern<arith::ExtSIOp> {
672 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
673 ConversionPatternRewriter &rewriter)
const override {
674 Value operand = adaptor.getIn();
678 Location loc = op.getLoc();
679 Type dstType = getTypeConverter()->convertType(op.getType());
684 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
685 unsigned componentBitwidth = intTy.getWidth();
686 allOnes = spirv::ConstantOp::create(
687 rewriter, loc, intTy,
688 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
689 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
690 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
691 allOnes = spirv::ConstantOp::create(
692 rewriter, loc, vectorTy,
693 SplatElementsAttr::get(vectorTy,
694 APInt::getAllOnes(componentBitwidth)));
696 return rewriter.notifyMatchFailure(
697 loc, llvm::formatv(
"unhandled type: {0}", dstType));
700 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
701 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
709struct ExtSIPattern final :
public OpConversionPattern<arith::ExtSIOp> {
713 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
714 ConversionPatternRewriter &rewriter)
const override {
715 Type srcType = adaptor.getIn().getType();
719 Type dstType = getTypeConverter()->convertType(op.getType());
723 if (dstType == srcType) {
731 assert(srcBW < dstBW);
733 rewriter, op.getLoc());
738 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
739 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
743 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
744 op, dstType, shiftLOp, shiftSize);
746 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
747 adaptor.getOperands());
760struct ExtUII1Pattern final :
public OpConversionPattern<arith::ExtUIOp> {
764 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
765 ConversionPatternRewriter &rewriter)
const override {
766 Type srcType = adaptor.getOperands().front().getType();
770 Type dstType = getTypeConverter()->convertType(op.getType());
775 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
776 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
777 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
778 op, dstType, adaptor.getOperands().front(), one, zero);
785struct ExtUIPattern final :
public OpConversionPattern<arith::ExtUIOp> {
789 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
790 ConversionPatternRewriter &rewriter)
const override {
791 Type srcType = adaptor.getIn().getType();
795 Type dstType = getTypeConverter()->convertType(op.getType());
799 if (dstType == srcType) {
807 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
809 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
810 adaptor.getIn(), mask);
812 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
813 adaptor.getOperands());
825struct TruncII1Pattern final :
public OpConversionPattern<arith::TruncIOp> {
829 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
830 ConversionPatternRewriter &rewriter)
const override {
831 Type dstType = getTypeConverter()->convertType(op.getType());
838 Location loc = op.getLoc();
839 auto srcType = adaptor.getOperands().front().getType();
841 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
842 Value maskedSrc = spirv::BitwiseAndOp::create(
843 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
844 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
846 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
847 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
848 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
855struct TruncIPattern final :
public OpConversionPattern<arith::TruncIOp> {
859 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
860 ConversionPatternRewriter &rewriter)
const override {
861 Type srcType = adaptor.getIn().getType();
862 Type dstType = getTypeConverter()->convertType(op.getType());
869 if (dstType == srcType) {
876 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
877 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
878 adaptor.getIn(), mask);
881 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
882 adaptor.getOperands());
892static std::optional<spirv::FPRoundingMode>
893convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
894 switch (roundingMode) {
895 case arith::RoundingMode::downward:
896 return spirv::FPRoundingMode::RTN;
897 case arith::RoundingMode::to_nearest_even:
898 return spirv::FPRoundingMode::RTE;
899 case arith::RoundingMode::toward_zero:
900 return spirv::FPRoundingMode::RTZ;
901 case arith::RoundingMode::upward:
902 return spirv::FPRoundingMode::RTP;
903 case arith::RoundingMode::to_nearest_away:
908 llvm_unreachable(
"Unhandled rounding mode");
912template <
typename Op,
typename SPIRVOp>
913struct TypeCastingOpPattern final :
public OpConversionPattern<Op> {
914 using OpConversionPattern<
Op>::OpConversionPattern;
917 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
918 ConversionPatternRewriter &rewriter)
const override {
919 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
920 Type dstType = this->getTypeConverter()->convertType(op.getType());
927 if (dstType == srcType) {
930 rewriter.replaceOp(op, adaptor.getOperands().front());
933 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
934 if (
auto roundingModeOp =
935 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
936 if (arith::RoundingModeAttr roundingMode =
937 roundingModeOp.getRoundingModeAttr()) {
939 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
940 return rewriter.notifyMatchFailure(
942 llvm::formatv(
"unsupported rounding mode '{0}'", roundingMode));
947 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
948 op, dstType, adaptor.getOperands());
952 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
964class CmpIOpBooleanPattern final :
public OpConversionPattern<arith::CmpIOp> {
969 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
970 ConversionPatternRewriter &rewriter)
const override {
971 Type srcType = op.getLhs().getType();
974 Type dstType = getTypeConverter()->convertType(srcType);
978 switch (op.getPredicate()) {
979 case arith::CmpIPredicate::eq: {
980 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
984 case arith::CmpIPredicate::ne: {
985 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
986 op, adaptor.getLhs(), adaptor.getRhs());
989 case arith::CmpIPredicate::uge:
990 case arith::CmpIPredicate::ugt:
991 case arith::CmpIPredicate::ule:
992 case arith::CmpIPredicate::ult: {
995 Type type = rewriter.getI32Type();
996 if (
auto vectorType = dyn_cast<VectorType>(dstType))
997 type = VectorType::get(vectorType.getShape(), type);
999 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1001 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1003 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1015class CmpIOpPattern final :
public OpConversionPattern<arith::CmpIOp> {
1020 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter)
const override {
1022 Type srcType = op.getLhs().getType();
1025 Type dstType = getTypeConverter()->convertType(srcType);
1029 switch (op.getPredicate()) {
1030#define DISPATCH(cmpPredicate, spirvOp) \
1031 case cmpPredicate: \
1032 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1033 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1034 !hasSameBitwidth(srcType, dstType)) { \
1035 return op.emitError( \
1036 "bitwidth emulation is not implemented yet on unsigned op"); \
1038 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1039 adaptor.getRhs()); \
1042 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1043 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1044 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1045 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1046 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1047 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1048 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1049 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1050 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1051 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1064class CmpFOpPattern final :
public OpConversionPattern<arith::CmpFOp> {
1069 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1070 ConversionPatternRewriter &rewriter)
const override {
1071 switch (op.getPredicate()) {
1072#define DISPATCH(cmpPredicate, spirvOp) \
1073 case cmpPredicate: \
1074 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1075 adaptor.getRhs()); \
1079 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1080 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1081 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1082 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1083 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1084 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1086 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1087 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1088 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1089 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1090 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1091 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1104class CmpFOpNanKernelPattern final :
public OpConversionPattern<arith::CmpFOp> {
1109 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1110 ConversionPatternRewriter &rewriter)
const override {
1111 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1112 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1117 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1118 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1129class CmpFOpNanNonePattern final :
public OpConversionPattern<arith::CmpFOp> {
1134 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1135 ConversionPatternRewriter &rewriter)
const override {
1136 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1137 op.getPredicate() != arith::CmpFPredicate::UNO)
1140 Location loc = op.getLoc();
1143 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1144 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1146 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1149 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1152 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1153 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1155 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1156 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1157 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1160 rewriter.replaceOp(op, replace);
1170class AddUIExtendedOpPattern final
1171 :
public OpConversionPattern<arith::AddUIExtendedOp> {
1175 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter)
const override {
1177 Type dstElemTy = adaptor.getLhs().getType();
1178 Location loc = op->getLoc();
1179 Value
result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1182 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1184 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1188 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1189 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1191 rewriter.replaceOp(op, {sumResult, carryResult});
1201template <
typename ArithMulOp,
typename SPIRVMulOp>
1202class MulIExtendedOpPattern final :
public OpConversionPattern<ArithMulOp> {
1204 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1206 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1207 ConversionPatternRewriter &rewriter)
const override {
1208 Location loc = op->getLoc();
1210 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1212 Value low = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1214 Value high = spirv::CompositeExtractOp::create(rewriter, loc,
result,
1217 rewriter.replaceOp(op, {low, high});
1227class SelectOpPattern final :
public OpConversionPattern<arith::SelectOp> {
1231 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1232 ConversionPatternRewriter &rewriter)
const override {
1233 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1234 adaptor.getTrueValue(),
1235 adaptor.getFalseValue());
1246template <
typename Op,
typename SPIRVOp>
1247class MinimumMaximumFOpPattern final :
public OpConversionPattern<Op> {
1249 using OpConversionPattern<
Op>::OpConversionPattern;
1251 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
1252 ConversionPatternRewriter &rewriter)
const override {
1253 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1254 Type dstType = converter->convertType(op.getType());
1266 Location loc = op.
getLoc();
1268 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1270 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1271 rewriter.replaceOp(op, spirvOp);
1275 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1276 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1278 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1279 adaptor.getLhs(), spirvOp);
1280 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1281 adaptor.getRhs(), select1);
1283 rewriter.replaceOp(op, select2);
1294template <
typename Op,
typename SPIRVOp>
1295class MinNumMaxNumFOpPattern final :
public OpConversionPattern<Op> {
1296 template <
typename TargetOp>
1297 constexpr bool shouldInsertNanGuards()
const {
1298 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1302 using OpConversionPattern<
Op>::OpConversionPattern;
1304 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1305 ConversionPatternRewriter &rewriter)
const override {
1306 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1307 Type dstType = converter->convertType(op.getType());
1322 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1324 if (!shouldInsertNanGuards<SPIRVOp>() ||
1325 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1326 rewriter.replaceOp(op, spirvOp);
1330 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1331 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1333 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1334 adaptor.getRhs(), spirvOp);
1335 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1336 adaptor.getLhs(), select1);
1338 rewriter.replaceOp(op, select2);
1353 ConstantCompositeOpPattern,
1354 ConstantScalarOpPattern,
1355 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1356 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1357 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1361 RemSIOpGLPattern, RemSIOpCLPattern,
1362 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1363 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1364 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1365 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1374 ExtUIPattern, ExtUII1Pattern,
1375 ExtSIPattern, ExtSII1Pattern,
1376 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1377 TruncIPattern, TruncII1Pattern,
1378 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1379 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1380 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1381 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1382 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1383 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1384 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1385 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1386 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1387 CmpIOpBooleanPattern, CmpIOpPattern,
1388 CmpFOpNanNonePattern, CmpFOpPattern,
1389 AddUIExtendedOpPattern,
1390 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1391 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1394 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1395 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1396 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1397 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1403 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1404 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1405 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1406 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1411 >(typeConverter,
patterns.getContext());
1425struct ConvertArithToSPIRVPass
1429 void runOnOperation()
override {
1432 std::unique_ptr<SPIRVConversionTarget>
target =
1436 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1437 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1442 target->addLegalOp<UnrealizedConversionCastOp>();
1445 target->addIllegalDialect<arith::ArithDialect>();
1450 if (failed(applyPartialConversion(op, *
target, std::move(
patterns))))
1451 signalPassFailure();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector types a and b have the same non-zero total bitwidth (for a vector, element bitwidth times element count).
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ConvertArithToSPIRVPassBase Base
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.