19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return rewriter.
create<arith::ConstantOp>(
35 return rewriter.
create<arith::ConstantOp>(loc, attr);
44 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
47 Value a = op.getLhs();
48 Value b = op.getRhs();
51 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
53 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
70 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
73 Type type = op.getType();
74 Value a = op.getLhs();
75 Value b = op.getRhs();
80 Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
81 Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
82 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
83 loc, arith::CmpIPredicate::ne, a,
product);
86 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
88 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
90 Value signEqual = rewriter.create<arith::CmpIOp>(
91 loc, arith::CmpIPredicate::eq, aNeg, bNeg);
93 rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
95 Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
97 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
110 struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
112 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
115 Type type = op.getType();
116 Value a = op.getLhs();
117 Value b = op.getRhs();
119 Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
120 Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
121 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
122 loc, arith::CmpIPredicate::ne, a,
product);
126 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
128 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
130 Value signOpposite = rewriter.create<arith::CmpIOp>(
131 loc, arith::CmpIPredicate::ne, aNeg, bNeg);
133 rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
136 Value quotientMinusOne =
137 rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
139 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
145 template <
typename OpTy, arith::CmpIPredicate pred>
150 LogicalResult matchAndRewrite(OpTy op,
152 Value lhs = op.getLhs();
153 Value rhs = op.getRhs();
155 Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
156 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
161 template <
typename OpTy, arith::CmpFPredicate pred>
166 LogicalResult matchAndRewrite(OpTy op,
168 Value lhs = op.getLhs();
169 Value rhs = op.getRhs();
173 static_assert(pred == arith::CmpFPredicate::UGT ||
174 pred == arith::CmpFPredicate::ULT,
175 "pred must be either UGT or ULT");
176 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
177 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
180 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
182 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
187 template <
typename OpTy, arith::CmpFPredicate pred>
192 LogicalResult matchAndRewrite(OpTy op,
194 Value lhs = op.getLhs();
195 Value rhs = op.getRhs();
199 static_assert(pred == arith::CmpFPredicate::UGT ||
200 pred == arith::CmpFPredicate::ULT,
201 "pred must be either UGT or ULT");
202 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
203 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
206 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
208 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
215 LogicalResult matchAndRewrite(arith::ExtFOp op,
218 auto operand = op.getOperand();
219 Type operandTy = operand.getType();
220 Type resultTy = op.getType();
225 return rewriter.notifyMatchFailure(op,
"not a ext of bf16 to f32.");
228 Type i16Ty = b.getI16Type();
229 Type i32Ty = b.getI32Type();
230 if (
auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
231 i16Ty = shapedTy.clone(i16Ty);
232 i32Ty = shapedTy.clone(i32Ty);
235 Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
236 Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
239 Value shl = b.create<arith::ShLIOp>(exti, c16);
240 Value result = b.create<arith::BitcastOp>(resultTy, shl);
242 rewriter.replaceOp(op, result);
247 struct BFloat16TruncFOpConverter :
public OpRewritePattern<arith::TruncFOp> {
249 LogicalResult matchAndRewrite(arith::TruncFOp op,
252 auto operand = op.getOperand();
253 Type operandTy = operand.getType();
254 Type resultTy = op.getType();
259 return rewriter.notifyMatchFailure(op,
"not a trunc of f32 to bf16.");
262 if (op.getRoundingmodeAttr()) {
263 return rewriter.notifyMatchFailure(
264 op,
"only applicable to default rounding mode.");
267 Type i16Ty = b.getI16Type();
268 Type i32Ty = b.getI32Type();
269 Type f32Ty = b.getF32Type();
270 if (
auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
271 i16Ty = shapedTy.clone(i16Ty);
272 i32Ty = shapedTy.clone(i32Ty);
273 f32Ty = shapedTy.clone(f32Ty);
290 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
299 Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
302 b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
305 Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
312 Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
315 Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
316 Value normalCaseResult_i16 =
317 b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
321 b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
322 Value result = b.create<arith::BitcastOp>(resultTy, select);
323 rewriter.replaceOp(op, result);
328 struct ArithExpandOpsPass
329 :
public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
330 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
332 void runOnOperation()
override {
338 target.addLegalDialect<arith::ArithDialect>();
356 target.addDynamicallyLegalOp<arith::ExtFOp>(
357 [](arith::ExtFOp op) {
363 target.addDynamicallyLegalOp<arith::TruncFOp>(
364 [](arith::TruncFOp op) {
383 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
388 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
396 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
397 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
398 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
399 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
400 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
401 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
402 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
403 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.