10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 18 #include "llvm/Support/Debug.h" 20 #define DEBUG_TYPE "arith-to-spirv-pattern" 31 struct ConstantCompositeOpPattern final
36 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
41 struct ConstantScalarOpPattern final
46 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
58 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
67 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
75 template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
80 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
89 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
99 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
109 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
119 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
129 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
134 template <
typename Op,
typename SPIRVOp>
139 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
149 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
159 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
169 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
180 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
191 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
215 if (
auto intAttr = srcAttr.
dyn_cast<IntegerAttr>())
216 return builder.
getBoolAttr(intAttr.getValue().getBoolValue());
226 if (srcAttr.getValue().isIntN(dstType.getWidth()))
234 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
236 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr <<
"' converted to '" 237 << dstAttr <<
"' for type '" << dstType <<
"'\n");
241 LLVM_DEBUG(llvm::dbgs() <<
"attribute '" << srcAttr
242 <<
"' illegal: cannot fit into target type '" 243 << dstType <<
"'\n");
244 return IntegerAttr();
252 if (!dstType.
isF32())
256 APFloat dstVal = srcAttr.getValue();
257 bool losesInfo =
false;
258 APFloat::opStatus status =
259 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
260 if (status != APFloat::opOK || losesInfo) {
261 LLVM_DEBUG(llvm::dbgs()
262 << srcAttr <<
" illegal: cannot fit into converted type '" 263 << dstType <<
"'\n");
274 if (
auto vecType = type.
dyn_cast<VectorType>())
275 return vecType.getElementType().isInteger(1);
282 auto getNumBitwidth = [](
Type type) {
284 if (type.isIntOrFloat())
285 bw = type.getIntOrFloatBitWidth();
286 else if (
auto vecType = type.dyn_cast<VectorType>())
287 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
290 unsigned aBW = getNumBitwidth(a);
291 unsigned bBW = getNumBitwidth(b);
292 return aBW != 0 && bBW != 0 && aBW == bBW;
300 arith::ConstantOp constOp, OpAdaptor adaptor,
302 auto srcType = constOp.getType().dyn_cast<ShapedType>();
303 if (!srcType || srcType.getNumElements() == 1)
307 assert((srcType.isa<VectorType, RankedTensorType>()));
309 auto dstType = getTypeConverter()->convertType(srcType);
314 if (!dstElementsAttr)
317 ShapedType dstAttrType = dstElementsAttr.getType();
320 if (srcType.getRank() > 1) {
321 if (srcType.isa<RankedTensorType>()) {
322 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
323 srcType.getElementType());
324 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
331 Type srcElemType = srcType.getElementType();
336 dstElemType = arrayType.getElementType();
342 if (srcElemType != dstElemType) {
345 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
350 elements.push_back(dstAttr);
355 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
357 srcAttr, dstElemType.
cast<IntegerType>(), rewriter);
360 elements.push_back(dstAttr);
368 if (dstAttrType.isa<RankedTensorType>())
369 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
371 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
386 arith::ConstantOp constOp, OpAdaptor adaptor,
388 Type srcType = constOp.getType();
389 if (
auto shapedType = srcType.
dyn_cast<ShapedType>()) {
390 if (shapedType.getNumElements() != 1)
392 srcType = shapedType.getElementType();
399 cstAttr = elementsAttr.getSplatValue<
Attribute>();
401 Type dstType = getTypeConverter()->convertType(srcType);
407 auto srcAttr = cstAttr.
cast<FloatAttr>();
408 auto dstAttr = srcAttr;
412 if (srcType != dstType) {
435 auto srcAttr = cstAttr.
cast<IntegerAttr>();
455 template <
typename SignedAbsOp>
459 assert(lhs == signOperand || rhs == signOperand);
464 Value lhsAbs = builder.
create<SignedAbsOp>(loc, type, lhs);
465 Value rhsAbs = builder.
create<SignedAbsOp>(loc, type, rhs);
470 if (lhs == signOperand)
471 isPositive = builder.
create<spirv::IEqualOp>(loc, lhs, lhsAbs);
473 isPositive = builder.
create<spirv::IEqualOp>(loc, rhs, rhsAbs);
474 Value absNegate = builder.
create<spirv::SNegateOp>(loc, type,
abs);
475 return builder.
create<spirv::SelectOp>(loc, type, isPositive,
abs, absNegate);
479 RemSIOpGLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
481 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
482 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
483 adaptor.getOperands()[0], rewriter);
494 RemSIOpCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
496 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
497 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
498 adaptor.getOperands()[0], rewriter);
508 template <
typename Op,
typename SPIRVLogicalOp,
typename SPIRVBitwiseOp>
510 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
511 Op op,
typename Op::Adaptor adaptor,
513 assert(adaptor.getOperands().size() == 2);
515 this->getTypeConverter()->convertType(op.getResult().getType());
519 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
520 adaptor.getOperands());
522 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
523 adaptor.getOperands());
533 arith::XOrIOp op, OpAdaptor adaptor,
535 assert(adaptor.getOperands().size() == 2);
540 auto dstType = getTypeConverter()->convertType(op.getType());
544 adaptor.getOperands());
554 arith::XOrIOp op, OpAdaptor adaptor,
556 assert(adaptor.getOperands().size() == 2);
561 auto dstType = getTypeConverter()->convertType(op.getType());
565 adaptor.getOperands());
574 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
576 auto srcType = adaptor.getOperands().front().getType();
581 this->getTypeConverter()->convertType(op.getResult().getType());
584 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
585 rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
586 op, dstType, adaptor.getOperands().front(), one, zero);
595 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
597 auto srcType = adaptor.getOperands().front().getType();
602 this->getTypeConverter()->convertType(op.getResult().getType());
605 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
606 rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
607 op, dstType, adaptor.getOperands().front(), one, zero);
616 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
619 this->getTypeConverter()->convertType(op.getResult().getType());
624 auto srcType = adaptor.getOperands().front().getType();
626 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
627 Value maskedSrc = rewriter.
create<spirv::BitwiseAndOp>(
628 loc, srcType, adaptor.getOperands()[0], mask);
629 Value isOne = rewriter.
create<spirv::IEqualOp>(loc, maskedSrc, mask);
632 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
641 template <
typename Op,
typename SPIRVOp>
642 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
643 Op op,
typename Op::Adaptor adaptor,
645 assert(adaptor.getOperands().size() == 1);
646 auto srcType = adaptor.getOperands().front().getType();
648 this->getTypeConverter()->convertType(op.getResult().getType());
651 if (dstType == srcType) {
654 rewriter.
replaceOp(op, adaptor.getOperands().front());
656 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
657 adaptor.getOperands());
667 arith::CmpIOp op, OpAdaptor adaptor,
669 Type srcType = op.getLhs().getType();
672 Type dstType = getTypeConverter()->convertType(srcType);
676 switch (op.getPredicate()) {
677 case arith::CmpIPredicate::eq: {
682 case arith::CmpIPredicate::ne: {
687 case arith::CmpIPredicate::uge:
688 case arith::CmpIPredicate::ugt:
689 case arith::CmpIPredicate::ule:
690 case arith::CmpIPredicate::ult: {
695 type = VectorType::get(
vectorType.getShape(), type);
697 rewriter.
create<arith::ExtUIOp>(op.
getLoc(), type, adaptor.getLhs());
699 rewriter.
create<arith::ExtUIOp>(op.
getLoc(), type, adaptor.getRhs());
716 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
718 Type srcType = op.getLhs().getType();
721 Type dstType = getTypeConverter()->convertType(srcType);
725 switch (op.getPredicate()) {
726 #define DISPATCH(cmpPredicate, spirvOp) \ 728 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 729 srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \ 730 return op.emitError( \ 731 "bitwidth emulation is not implemented yet on unsigned op"); \ 733 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ 737 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
738 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
739 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
740 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
741 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
742 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
743 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
744 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
745 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
746 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
758 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
760 switch (op.getPredicate()) {
761 #define DISPATCH(cmpPredicate, spirvOp) \ 763 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ 768 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
769 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
770 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
771 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
772 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
773 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
775 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
776 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
777 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
778 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
779 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
780 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
795 arith::CmpFOp op, OpAdaptor adaptor,
797 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
803 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
817 arith::CmpFOp op, OpAdaptor adaptor,
819 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
820 op.getPredicate() != arith::CmpFPredicate::UNO)
825 Value lhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getLhs());
826 Value rhsIsNan = rewriter.
create<spirv::IsNanOp>(loc, adaptor.getRhs());
828 Value replace = rewriter.
create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
829 if (op.getPredicate() == arith::CmpFPredicate::ORD)
830 replace = rewriter.
create<spirv::LogicalNotOp>(loc, replace);
841 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
844 adaptor.getTrueValue(),
845 adaptor.getFalseValue());
857 ConstantCompositeOpPattern,
858 ConstantScalarOpPattern,
865 RemSIOpGLPattern, RemSIOpCLPattern,
866 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
867 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
868 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
878 TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
879 TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
880 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
881 TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
882 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
883 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
884 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
885 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
886 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
887 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
888 CmpIOpBooleanPattern, CmpIOpPattern,
889 CmpFOpNanNonePattern, CmpFOpPattern,
903 patterns.
add<CmpFOpNanKernelPattern>(typeConverter, patterns.
getContext(),
912 struct ConvertArithmeticToSPIRVPass
913 :
public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
914 void runOnOperation()
override {
927 auto cast = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
932 target->addLegalOp<UnrealizedConversionCastOp>();
943 std::unique_ptr<OperationPass<>>
945 return std::make_unique<ConvertArithmeticToSPIRVPass>();
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder)
Returns signed remainder for lhs and rhs and lets the result follow the sign of signOperand.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Operation is a basic unit of execution within MLIR.
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal...
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
An attribute that represents a reference to a dense vector or tensor object.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
#define DISPATCH(cmpPredicate, spirvOp)
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static llvm::ManagedStatic< PassManagerOptions > options
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Type getType() const
Return the type of this value.
Location getLoc()
The source location the operation was defined or derived from.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
BoolAttr getBoolAttr(bool value)
This class implements a pattern rewriter for use with ConversionPatterns.
This provides public APIs that all operations should have.
std::unique_ptr< OperationPass<> > createConvertArithmeticToSPIRVPass()
FloatAttr getF32FloatAttr(float value)
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
MLIRContext * getContext() const
bool emulateNon32BitScalarTypes
Whether to emulate non-32-bit scalar types with 32-bit scalar types if no native support.
Type conversion from builtin types to SPIR-V types for shader interface.