19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return rewriter.
create<arith::ConstantOp>(
35 return rewriter.
create<arith::ConstantOp>(loc, attr);
44 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
47 Value a = op.getLhs();
48 Value b = op.getRhs();
51 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
53 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
66 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
69 Type type = op.getType();
70 Value a = op.getLhs();
71 Value b = op.getRhs();
77 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78 Value x = rewriter.create<arith::SelectOp>(loc,
compare, minusOne, plusOne);
80 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
84 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
95 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
97 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
99 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
101 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
105 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
107 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
120 struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
122 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
125 Type type = op.getType();
126 Value a = op.getLhs();
127 Value b = op.getRhs();
129 Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
130 Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
131 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
132 loc, arith::CmpIPredicate::ne, a,
product);
136 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
138 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
140 Value signOpposite = rewriter.create<arith::CmpIOp>(
141 loc, arith::CmpIPredicate::ne, aNeg, bNeg);
143 rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
146 Value quotientMinusOne =
147 rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
149 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
155 template <
typename OpTy, arith::CmpIPredicate pred>
160 LogicalResult matchAndRewrite(OpTy op,
162 Value lhs = op.getLhs();
163 Value rhs = op.getRhs();
165 Value cmp = rewriter.create<arith::CmpIOp>(op.
getLoc(), pred, lhs, rhs);
166 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
171 template <
typename OpTy, arith::CmpFPredicate pred>
176 LogicalResult matchAndRewrite(OpTy op,
178 Value lhs = op.getLhs();
179 Value rhs = op.getRhs();
183 static_assert(pred == arith::CmpFPredicate::UGT ||
184 pred == arith::CmpFPredicate::ULT,
185 "pred must be either UGT or ULT");
186 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
187 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
190 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
192 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
197 template <
typename OpTy, arith::CmpFPredicate pred>
202 LogicalResult matchAndRewrite(OpTy op,
204 Value lhs = op.getLhs();
205 Value rhs = op.getRhs();
209 static_assert(pred == arith::CmpFPredicate::UGT ||
210 pred == arith::CmpFPredicate::ULT,
211 "pred must be either UGT or ULT");
212 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
213 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
216 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
218 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
225 LogicalResult matchAndRewrite(arith::ExtFOp op,
229 Type operandTy = operand.getType();
230 Type resultTy = op.getType();
235 return rewriter.notifyMatchFailure(op,
"not a ext of bf16 to f32.");
238 Type i16Ty = b.getI16Type();
239 Type i32Ty = b.getI32Type();
240 if (
auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
241 i16Ty = shapedTy.clone(i16Ty);
242 i32Ty = shapedTy.clone(i32Ty);
245 Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
246 Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
249 Value shl = b.create<arith::ShLIOp>(exti, c16);
250 Value result = b.create<arith::BitcastOp>(resultTy, shl);
252 rewriter.replaceOp(op, result);
257 struct BFloat16TruncFOpConverter :
public OpRewritePattern<arith::TruncFOp> {
259 LogicalResult matchAndRewrite(arith::TruncFOp op,
263 Type operandTy = operand.getType();
264 Type resultTy = op.getType();
269 return rewriter.notifyMatchFailure(op,
"not a trunc of f32 to bf16.");
272 if (op.getRoundingmodeAttr()) {
273 return rewriter.notifyMatchFailure(
274 op,
"only applicable to default rounding mode.");
277 Type i16Ty = b.getI16Type();
278 Type i32Ty = b.getI32Type();
279 Type f32Ty = b.getF32Type();
280 if (
auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
281 i16Ty = shapedTy.clone(i16Ty);
282 i32Ty = shapedTy.clone(i32Ty);
283 f32Ty = shapedTy.clone(f32Ty);
300 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
309 Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
312 b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
315 Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
322 Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
325 Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
326 Value normalCaseResult_i16 =
327 b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
331 b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
332 Value result = b.create<arith::BitcastOp>(resultTy, select);
333 rewriter.replaceOp(op, result);
338 struct ArithExpandOpsPass
339 :
public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
340 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
342 void runOnOperation()
override {
348 target.addLegalDialect<arith::ArithDialect>();
366 target.addDynamicallyLegalOp<arith::ExtFOp>(
367 [](arith::ExtFOp op) {
373 target.addDynamicallyLegalOp<arith::TruncFOp>(
374 [](arith::TruncFOp op) {
383 std::move(patterns))))
393 .
add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
398 patterns.
add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
406 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
407 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
408 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
409 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
410 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
411 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
412 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
413 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...