19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return rewriter.
create<arith::ConstantOp>(
35 return rewriter.
create<arith::ConstantOp>(loc, attr);
47 Value a = op.getLhs();
48 Value b = op.getRhs();
51 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
53 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
69 Type type = op.getType();
70 Value a = op.getLhs();
71 Value b = op.getRhs();
77 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78 Value x = rewriter.create<arith::SelectOp>(loc,
compare, minusOne, plusOne);
80 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
84 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
95 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
97 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
99 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
101 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
105 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
107 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
120 struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
125 Type type = op.getType();
126 Value a = op.getLhs();
127 Value b = op.getRhs();
129 Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
130 Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
131 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
132 loc, arith::CmpIPredicate::ne, a,
product);
136 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
138 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
140 Value signOpposite = rewriter.create<arith::CmpIOp>(
141 loc, arith::CmpIPredicate::ne, aNeg, bNeg);
143 rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
146 Value quotientMinusOne =
147 rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
149 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
155 template <
typename OpTy, arith::CmpFPredicate pred>
162 Value lhs = op.getLhs();
163 Value rhs = op.getRhs();
167 static_assert(pred == arith::CmpFPredicate::UGT ||
168 pred == arith::CmpFPredicate::ULT,
169 "pred must be either UGT or ULT");
170 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
171 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
174 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
176 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
181 template <
typename OpTy, arith::CmpFPredicate pred>
188 Value lhs = op.getLhs();
189 Value rhs = op.getRhs();
193 static_assert(pred == arith::CmpFPredicate::UGT ||
194 pred == arith::CmpFPredicate::ULT,
195 "pred must be either UGT or ULT");
196 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
197 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
200 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
202 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
213 Type operandTy = operand.getType();
214 Type resultTy = op.getType();
219 return rewriter.notifyMatchFailure(op,
"not a ext of bf16 to f32.");
222 Type i16Ty = b.getI16Type();
223 Type i32Ty = b.getI32Type();
224 if (
auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
225 i16Ty = shapedTy.clone(i16Ty);
226 i32Ty = shapedTy.clone(i32Ty);
229 Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
230 Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
233 Value shl = b.create<arith::ShLIOp>(exti, c16);
234 Value result = b.create<arith::BitcastOp>(resultTy, shl);
236 rewriter.replaceOp(op, result);
241 struct BFloat16TruncFOpConverter :
public OpRewritePattern<arith::TruncFOp> {
247 Type operandTy = operand.getType();
248 Type resultTy = op.getType();
253 return rewriter.notifyMatchFailure(op,
"not a trunc of f32 to bf16.");
256 if (op.getRoundingmodeAttr()) {
257 return rewriter.notifyMatchFailure(
258 op,
"only applicable to default rounding mode.");
261 Type i16Ty = b.getI16Type();
262 Type i32Ty = b.getI32Type();
263 Type f32Ty = b.getF32Type();
264 if (
auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
265 i16Ty = shapedTy.clone(i16Ty);
266 i32Ty = shapedTy.clone(i32Ty);
267 f32Ty = shapedTy.clone(f32Ty);
284 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
293 Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
296 b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
299 Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
306 Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
309 Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
310 Value normalCaseResult_i16 =
311 b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
315 b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
316 Value result = b.create<arith::BitcastOp>(resultTy, select);
317 rewriter.replaceOp(op, result);
322 struct ArithExpandOpsPass
323 :
public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
324 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
326 void runOnOperation()
override {
332 target.addLegalDialect<arith::ArithDialect>();
346 target.addDynamicallyLegalOp<arith::ExtFOp>(
347 [](arith::ExtFOp op) {
353 target.addDynamicallyLegalOp<arith::TruncFOp>(
354 [](arith::TruncFOp op) {
363 std::move(patterns))))
373 .
add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
378 patterns.
add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
386 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
387 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
388 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
389 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...