19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return rewriter.
create<arith::ConstantOp>(
35 return rewriter.
create<arith::ConstantOp>(loc, attr);
47 Value a = op.getLhs();
48 Value b = op.getRhs();
51 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
53 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56 rewriter.replaceOpWithNewOp<arith::SelectOp>(op,
compare, zero, plusOne);
69 Type type = op.getType();
70 Value a = op.getLhs();
71 Value b = op.getRhs();
77 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78 Value x = rewriter.create<arith::SelectOp>(loc,
compare, minusOne, plusOne);
80 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
84 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
95 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
97 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
99 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
101 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
105 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
107 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
116 struct FloorDivSIOpConverter :
public OpRewritePattern<arith::FloorDivSIOp> {
121 Type type = op.getType();
122 Value a = op.getLhs();
123 Value b = op.getRhs();
129 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
130 Value x = rewriter.create<arith::SelectOp>(loc,
compare, plusOne, minusOne);
132 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
133 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
134 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
136 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
145 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
147 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
149 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
151 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
152 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
153 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
155 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
157 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
163 template <
typename OpTy, arith::CmpFPredicate pred>
170 Value lhs = op.getLhs();
171 Value rhs = op.getRhs();
175 static_assert(pred == arith::CmpFPredicate::UGT ||
176 pred == arith::CmpFPredicate::ULT,
177 "pred must be either UGT or ULT");
178 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
179 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
182 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
184 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
195 Type operandTy = operand.getType();
196 Type resultTy = op.getType();
201 return rewriter.notifyMatchFailure(op,
"not a ext of bf16 to f32.");
204 Type i16Ty = b.getI16Type();
205 Type i32Ty = b.getI32Type();
206 if (
auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
207 i16Ty = shapedTy.clone(i16Ty);
208 i32Ty = shapedTy.clone(i32Ty);
211 Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
212 Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
215 Value shl = b.create<arith::ShLIOp>(exti, c16);
216 Value result = b.create<arith::BitcastOp>(resultTy, shl);
218 rewriter.replaceOp(op, result);
223 struct BFloat16TruncFOpConverter :
public OpRewritePattern<arith::TruncFOp> {
229 Type operandTy = operand.getType();
230 Type resultTy = op.getType();
235 return rewriter.notifyMatchFailure(op,
"not a trunc of f32 to bf16.");
238 Type i1Ty = b.getI1Type();
239 Type i16Ty = b.getI16Type();
240 Type i32Ty = b.getI32Type();
241 Type f32Ty = b.getF32Type();
242 if (
auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
243 i1Ty = shapedTy.clone(i1Ty);
244 i16Ty = shapedTy.clone(i16Ty);
245 i32Ty = shapedTy.clone(i32Ty);
246 f32Ty = shapedTy.clone(f32Ty);
249 Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
260 Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
265 cManRound = b.create<arith::SubIOp>(cManRound, sign);
268 Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
269 Value manRound = b.create<arith::AddIOp>(man, cManRound);
272 Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
273 Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);
276 Value exp = b.create<arith::AndIOp>(bitcast, expMask);
277 Value expCarry = b.create<arith::AddIOp>(exp, manRound);
278 expCarry = b.create<arith::AndIOp>(expCarry, expMask);
282 b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
283 exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
286 Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
287 Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
288 man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
291 Value rounded = b.create<arith::ShLIOp>(sign, c31);
292 rounded = b.create<arith::OrIOp>(rounded, exp);
293 rounded = b.create<arith::OrIOp>(rounded, man);
296 Value shr = b.create<arith::ShRUIOp>(rounded, c16);
297 Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
298 Value result = b.create<arith::BitcastOp>(resultTy, trunc);
300 rewriter.replaceOp(op, result);
305 struct ArithExpandOpsPass
306 :
public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
307 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
309 void runOnOperation()
override {
315 target.addLegalDialect<arith::ArithDialect>();
327 target.addDynamicallyLegalOp<arith::ExtFOp>(
328 [](arith::ExtFOp op) {
334 target.addDynamicallyLegalOp<arith::TruncFOp>(
335 [](arith::TruncFOp op) {
344 std::move(patterns))))
354 .
add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
359 patterns.
add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
367 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
368 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...