17 #define GEN_PASS_DEF_ARITHEXPANDOPS
18 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
27 return rewriter.
create<arith::ConstantOp>(
40 Value a = op.getLhs();
41 Value b = op.getRhs();
44 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
46 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
47 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
48 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
49 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
62 Type type = op.getType();
63 Value a = op.getLhs();
64 Value b = op.getRhs();
70 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
71 Value x = rewriter.create<arith::SelectOp>(loc,
compare, minusOne, plusOne);
73 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
74 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
75 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
77 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
78 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
79 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
88 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
90 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
92 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
94 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
95 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
96 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
98 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
100 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
109 struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
114 Type type = op.getType();
115 Value a = op.getLhs();
116 Value b = op.getRhs();
122 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
123 Value x = rewriter.create<arith::SelectOp>(loc,
compare, plusOne, minusOne);
125 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
126 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
127 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
129 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
138 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
140 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
142 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
144 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
145 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
146 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
148 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
150 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
156 template <
typename OpTy, arith::CmpFPredicate pred>
163 Value lhs = op.getLhs();
164 Value rhs = op.getRhs();
168 static_assert(pred == arith::CmpFPredicate::UGT ||
169 pred == arith::CmpFPredicate::ULT,
170 "pred must be either UGT or ULT");
171 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
172 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
175 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
177 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
182 struct ArithExpandOpsPass
183 :
public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> {
184 void runOnOperation()
override {
190 target.addLegalDialect<arith::ArithDialect>();
201 std::move(patterns))))
211 .
add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
219 MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
220 MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>
226 return std::make_unique<ArithExpandOpsPass>();
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
std::unique_ptr< Pass > createArithExpandOpsPass()
Create a pass to legalize Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...