11 #include "../SPIRVCommon/Pattern.h"
20 #include "llvm/ADT/APInt.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/MathExtras.h"
29 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
30 #include "mlir/Conversion/Passes.h.inc"
33 #define DEBUG_TYPE "arith-to-spirv-pattern"
44 if (
auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 if (
auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
47 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
57 if (srcAttr.getValue().isIntN(dstType.getWidth()))
65 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '"
68 << dstAttr <<
"' for type '" << dstType <<
"'\n");
72 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
73 <<
"' illegal: cannot fit into target type '"
87 APFloat dstVal = srcAttr.getValue();
88 bool losesInfo =
false;
89 APFloat::opStatus status =
90 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
91 if (status != APFloat::opOK || losesInfo) {
92 LLVM_DEBUG(llvm::dbgs()
93 << srcAttr <<
" illegal: cannot fit into converted type '"
103 assert(type &&
"Not a valid type");
107 if (
auto vecType = dyn_cast<VectorType>(type))
108 return vecType.getElementType().isInteger(1);
116 if (
auto vectorType = dyn_cast<VectorType>(type)) {
119 return builder.
create<spirv::ConstantOp>(loc, vectorType, attr);
122 if (
auto intType = dyn_cast<IntegerType>(type))
123 return builder.
create<spirv::ConstantOp>(
132 auto getNumBitwidth = [](
Type type) {
134 if (type.isIntOrFloat())
135 bw = type.getIntOrFloatBitWidth();
136 else if (
auto vecType = dyn_cast<VectorType>(type))
137 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
140 unsigned aBW = getNumBitwidth(a);
141 unsigned bBW = getNumBitwidth(b);
142 return aBW != 0 && bBW != 0 && aBW == bBW;
151 llvm::formatv(
"failed to convert source type '{0}'", srcType));
168 struct ConstantCompositeOpPattern final
173 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
175 auto srcType = dyn_cast<ShapedType>(constOp.getType());
176 if (!srcType || srcType.getNumElements() == 1)
180 assert((isa<VectorType, RankedTensorType>(srcType)));
182 Type dstType = getTypeConverter()->convertType(srcType);
186 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
187 if (!dstElementsAttr)
190 ShapedType dstAttrType = dstElementsAttr.getType();
194 if (srcType.getRank() > 1) {
195 if (isa<RankedTensorType>(srcType)) {
197 srcType.getElementType());
198 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
205 Type srcElemType = srcType.getElementType();
209 if (
auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
210 dstElemType = arrayType.getElementType();
212 dstElemType = cast<VectorType>(dstType).getElementType();
216 if (srcElemType != dstElemType) {
218 if (isa<FloatType>(srcElemType)) {
219 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
224 elements.push_back(dstAttr);
229 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
231 srcAttr, cast<IntegerType>(dstElemType), rewriter);
234 elements.push_back(dstAttr);
242 if (isa<RankedTensorType>(dstAttrType))
258 struct ConstantScalarOpPattern final
263 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
265 Type srcType = constOp.getType();
266 if (
auto shapedType = dyn_cast<ShapedType>(srcType)) {
267 if (shapedType.getNumElements() != 1)
269 srcType = shapedType.getElementType();
275 if (
auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
276 cstAttr = elementsAttr.getSplatValue<
Attribute>();
278 Type dstType = getTypeConverter()->convertType(srcType);
283 if (isa<FloatType>(srcType)) {
284 auto srcAttr = cast<FloatAttr>(cstAttr);
285 auto dstAttr = srcAttr;
289 if (srcType != dstType) {
312 auto srcAttr = cast<IntegerAttr>(cstAttr);
313 IntegerAttr dstAttr =
333 template <
typename SignedAbsOp>
337 assert(lhs == signOperand || rhs == signOperand);
342 Value lhsAbs = builder.
create<SignedAbsOp>(loc, type, lhs);
343 Value rhsAbs = builder.
create<SignedAbsOp>(loc, type, rhs);
348 if (lhs == signOperand)
349 isPositive = builder.
create<spirv::IEqualOp>(loc, lhs, lhsAbs);
351 isPositive = builder.
create<spirv::IEqualOp>(loc, rhs, rhsAbs);
352 Value absNegate = builder.
create<spirv::SNegateOp>(loc, type,
abs);
353 return builder.
create<spirv::SelectOp>(loc, type, isPositive,
abs, absNegate);
364 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
366 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
367 op.
getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
368 adaptor.getOperands()[0], rewriter);
380 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
382 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
383 op.
getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
384 adaptor.getOperands()[0], rewriter);
399 template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
404 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
406 assert(adaptor.getOperands().size() == 2);
407 Type dstType = this->getTypeConverter()->convertType(op.getType());
412 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
413 op, dstType, adaptor.getOperands());
415 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
416 op, dstType, adaptor.getOperands());
431 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
433 assert(adaptor.getOperands().size() == 2);
438 Type dstType = getTypeConverter()->convertType(op.getType());
443 adaptor.getOperands());
455 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
457 assert(adaptor.getOperands().size() == 2);
462 Type dstType = getTypeConverter()->convertType(op.getType());
467 op, dstType, adaptor.getOperands());
482 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
484 Type srcType = adaptor.getOperands().front().getType();
488 Type dstType = getTypeConverter()->convertType(op.getType());
494 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
496 op, dstType, adaptor.getOperands().front(), one, zero);
511 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
513 Value operand = adaptor.getIn();
518 Type dstType = getTypeConverter()->convertType(op.getType());
523 if (
auto intTy = dyn_cast<IntegerType>(dstType)) {
524 unsigned componentBitwidth = intTy.getWidth();
525 allOnes = rewriter.
create<spirv::ConstantOp>(
527 rewriter.
getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
528 }
else if (
auto vectorTy = dyn_cast<VectorType>(dstType)) {
529 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
530 allOnes = rewriter.
create<spirv::ConstantOp>(
533 APInt::getAllOnes(componentBitwidth)));
536 loc, llvm::formatv(
"unhandled type: {0}", dstType));
552 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
554 Type srcType = adaptor.getIn().getType();
558 Type dstType = getTypeConverter()->convertType(op.getType());
562 if (dstType == srcType) {
570 assert(srcBW < dstBW);
577 auto shiftLOp = rewriter.
create<spirv::ShiftLeftLogicalOp>(
578 op.
getLoc(), dstType, adaptor.getIn(), shiftSize);
583 op, dstType, shiftLOp, shiftSize);
586 adaptor.getOperands());
603 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
605 Type srcType = adaptor.getOperands().front().getType();
609 Type dstType = getTypeConverter()->convertType(op.getType());
615 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
617 op, dstType, adaptor.getOperands().front(), one, zero);
628 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
630 Type srcType = adaptor.getIn().getType();
634 Type dstType = getTypeConverter()->convertType(op.getType());
638 if (dstType == srcType) {
646 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
649 adaptor.getIn(), mask);
668 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
670 Type dstType = getTypeConverter()->convertType(op.getType());
678 auto srcType = adaptor.getOperands().front().getType();
680 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
681 Value maskedSrc = rewriter.
create<spirv::BitwiseAndOp>(
682 loc, srcType, adaptor.getOperands()[0], mask);
683 Value isOne = rewriter.
create<spirv::IEqualOp>(loc, maskedSrc, mask);
686 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
698 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
700 Type srcType = adaptor.getIn().getType();
701 Type dstType = getTypeConverter()->convertType(op.getType());
708 if (dstType == srcType) {
715 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.
getLoc());
717 adaptor.getIn(), mask);
732 template <
typename Op,
typename SPIRVOp>
737 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
739 assert(adaptor.getOperands().size() == 1);
740 Type srcType = adaptor.getOperands().front().getType();
741 Type dstType = this->getTypeConverter()->convertType(op.getType());
748 if (dstType == srcType) {
751 rewriter.
replaceOp(op, adaptor.getOperands().front());
753 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
754 adaptor.getOperands());
770 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
772 Type srcType = op.getLhs().getType();
775 Type dstType = getTypeConverter()->convertType(srcType);
779 switch (op.getPredicate()) {
780 case arith::CmpIPredicate::eq: {
785 case arith::CmpIPredicate::ne: {
787 op, adaptor.getLhs(), adaptor.getRhs());
790 case arith::CmpIPredicate::uge:
791 case arith::CmpIPredicate::ugt:
792 case arith::CmpIPredicate::ule:
793 case arith::CmpIPredicate::ult: {
797 if (
auto vectorType = dyn_cast<VectorType>(dstType))
800 rewriter.
create<arith::ExtUIOp>(op.
getLoc(), type, adaptor.getLhs());
802 rewriter.
create<arith::ExtUIOp>(op.
getLoc(), type, adaptor.getRhs());
821 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
823 Type srcType = op.getLhs().getType();
826 Type dstType = getTypeConverter()->convertType(srcType);
830 switch (op.getPredicate()) {
831 #define DISPATCH(cmpPredicate, spirvOp) \
833 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
834 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
835 !hasSameBitwidth(srcType, dstType)) { \
836 return op.emitError( \
837 "bitwidth emulation is not implemented yet on unsigned op"); \
839 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
843 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
844 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
845 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
846 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
847 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
848 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
849 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
850 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
851 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
852 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
870 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
872 switch (op.getPredicate()) {
873 #define DISPATCH(cmpPredicate, spirvOp) \
875 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
880 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
881 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
882 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
883 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
884 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
885 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
887 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
888 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
889 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
890 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
891 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
892 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
910 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
912 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
918 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
935 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
937 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
938 op.getPredicate() != arith::CmpFPredicate::UNO)
942 auto *converter = getTypeConverter<SPIRVTypeConverter>();
945 if (converter->getOptions().enableFastMathMode) {
946 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
948 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
954 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
955 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
957 replace = rewriter.
create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
958 if (op.getPredicate() == arith::CmpFPredicate::ORD)
959 replace = rewriter.
create<spirv::LogicalNotOp>(loc, replace);
972 class AddUIExtendedOpPattern final
977 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
979 Type dstElemTy = adaptor.getLhs().getType();
981 Value result = rewriter.
create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
984 Value sumResult = rewriter.
create<spirv::CompositeExtractOp>(
986 Value carryValue = rewriter.
create<spirv::CompositeExtractOp>(
990 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
991 Value carryResult = rewriter.
create<spirv::IEqualOp>(loc, carryValue, one);
993 rewriter.
replaceOp(op, {sumResult, carryResult});
1003 template <
typename ArithMulOp,
typename SPIRVMulOp>
1008 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
1012 rewriter.
create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1014 Value low = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1016 Value high = rewriter.
create<spirv::CompositeExtractOp>(loc, result,
1033 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1036 adaptor.getTrueValue(),
1037 adaptor.getFalseValue());
1048 template <
typename Op,
typename SPIRVOp>
1053 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1055 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1056 Type dstType = converter->convertType(op.getType());
1070 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1072 if (converter->getOptions().enableFastMathMode) {
1077 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1078 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1080 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1081 adaptor.getLhs(), spirvOp);
1082 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1083 adaptor.getRhs(), select1);
1096 template <
typename Op,
typename SPIRVOp>
1098 template <
typename TargetOp>
1099 constexpr
bool shouldInsertNanGuards()
const {
1100 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1106 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1108 auto *converter = this->
template getTypeConverter<SPIRVTypeConverter>();
1109 Type dstType = converter->convertType(op.getType());
1124 rewriter.
create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1126 if (!shouldInsertNanGuards<SPIRVOp>() ||
1127 converter->getOptions().enableFastMathMode) {
1132 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
1133 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
1135 Value select1 = rewriter.
create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1136 adaptor.getRhs(), spirvOp);
1137 Value select2 = rewriter.
create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1138 adaptor.getLhs(), select1);
1155 ConstantCompositeOpPattern,
1156 ConstantScalarOpPattern,
1163 RemSIOpGLPattern, RemSIOpCLPattern,
1164 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1165 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1166 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1176 ExtUIPattern, ExtUII1Pattern,
1177 ExtSIPattern, ExtSII1Pattern,
1178 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1179 TruncIPattern, TruncII1Pattern,
1180 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1181 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1182 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1183 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1184 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1185 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1186 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1187 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1188 CmpIOpBooleanPattern, CmpIOpPattern,
1189 CmpFOpNanNonePattern, CmpFOpPattern,
1190 AddUIExtendedOpPattern,
1191 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1192 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1195 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1196 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1197 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1198 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1204 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1205 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1206 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1207 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1217 patterns.
add<CmpFOpNanKernelPattern>(typeConverter, patterns.
getContext(),
1226 struct ConvertArithToSPIRVPass
1227 :
public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
1228 void runOnOperation()
override {
1231 std::unique_ptr<SPIRVConversionTarget> target =
1235 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1236 options.enableFastMathMode = this->enableFastMath;
1241 target->addLegalOp<UnrealizedConversionCastOp>();
1244 target->addIllegalDialect<arith::ArithDialect>();
1250 signalPassFailure();
1256 return std::make_unique<ConvertArithToSPIRVPass>();
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector types a and b have the same total bitwidth (element bitwidth times element count for vectors).
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
BoolAttr getBoolAttr(bool value)
FloatAttr getF32FloatAttr(float value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass<> > createConvertArithToSPIRVPass()
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.