MLIR 23.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1//===- ExpandOps.cpp - Code to expand various math operations. ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements expansion of various math operations.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
20
21using namespace mlir;
22
23namespace mlir::math {
24#define GEN_PASS_DEF_MATHEXPANDOPSPASS
25#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
26} // namespace mlir::math
27
28/// Create a float constant.
29static Value createFloatConst(Location loc, Type type, APFloat value,
30 OpBuilder &b) {
31 bool losesInfo = false;
32 auto eltType = getElementTypeOrSelf(type);
33 // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
34 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
35 APFloat::rmNearestTiesToEven, &losesInfo);
36 auto attr = b.getFloatAttr(eltType, value);
37 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
38 return arith::ConstantOp::create(b, loc,
39 DenseElementsAttr::get(shapedTy, attr));
40 }
41
42 return arith::ConstantOp::create(b, loc, attr);
43}
44
45static Value createFloatConst(Location loc, Type type, double value,
46 OpBuilder &b) {
47 return createFloatConst(loc, type, APFloat(value), b);
48}
49
50/// Create an integer constant.
51static Value createIntConst(Location loc, Type type, int64_t value,
52 OpBuilder &b) {
53 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
54 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
55 return arith::ConstantOp::create(b, loc,
56 DenseElementsAttr::get(shapedTy, attr));
57 }
58
59 return arith::ConstantOp::create(b, loc, attr);
60}
61
63 Type opType = operand.getType();
64 Type i64Ty = b.getI64Type();
65 if (auto shapedTy = dyn_cast<ShapedType>(opType))
66 i64Ty = shapedTy.clone(i64Ty);
67 Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
68 Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
69 // The truncation does not preserve the sign when the truncated
70 // value is -0. So here the sign is copied again.
71 return math::CopySignOp::create(b, fpFixedConvert, operand);
72}
73
74// sinhf(float x) -> (exp(x) - exp(-x)) / 2
75static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
76 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
77 Value operand = op.getOperand();
78 Type opType = operand.getType();
79
80 Value exp = math::ExpOp::create(b, operand);
81 Value neg = arith::NegFOp::create(b, operand);
82 Value nexp = math::ExpOp::create(b, neg);
83 Value sub = arith::SubFOp::create(b, exp, nexp);
84 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
85 Value res = arith::MulFOp::create(b, sub, half);
86 rewriter.replaceOp(op, res);
87 return success();
88}
89
90// coshf(float x) -> (exp(x) + exp(-x)) / 2
91static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
92 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
93 Value operand = op.getOperand();
94 Type opType = operand.getType();
95
96 Value exp = math::ExpOp::create(b, operand);
97 Value neg = arith::NegFOp::create(b, operand);
98 Value nexp = math::ExpOp::create(b, neg);
99 Value add = arith::AddFOp::create(b, exp, nexp);
100 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
101 Value res = arith::MulFOp::create(b, add, half);
102 rewriter.replaceOp(op, res);
103 return success();
104}
105
106/// Expands tanh op into
107/// 1-exp^{-2x} / 1+exp^{-2x}
108/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
109/// We compute a "signs" value which is -1 if input is negative and +1 if input
110/// is positive. Then multiply the input by this value, guaranteeing that the
111/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
112/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
113/// result by `sign(x)` to retain sign of the real result.
114static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
115 auto floatType = op.getOperand().getType();
116 Location loc = op.getLoc();
117 Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
118 Value one = createFloatConst(loc, floatType, 1.0, rewriter);
119 Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
120
121 // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
122 Value isNegative = arith::CmpFOp::create(
123 rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
124 Value isNegativeFloat =
125 arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
126 Value isNegativeTimesNegTwo =
127 arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
128 Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
129
130 // Normalize input to positive value: y = sign(x) * x
131 Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
132
133 // Decompose on normalized input
134 Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
135 Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
136 Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
137 Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
138 Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
139
140 // Multiply result by sign(x) to retain signs from negative inputs
141 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
142
143 return success();
144}
145
146// Converts math.tan to math.sin, math.cos, and arith.divf.
147static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
148 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
149 Value operand = op.getOperand();
150 Type type = operand.getType();
151 Value sin = math::SinOp::create(b, type, operand);
152 Value cos = math::CosOp::create(b, type, operand);
153 Value div = arith::DivFOp::create(b, type, sin, cos);
154 rewriter.replaceOp(op, div);
155 return success();
156}
157
158// asinh(float x) -> log(x + sqrt(x**2 + 1))
159static LogicalResult convertAsinhOp(math::AsinhOp op,
160 PatternRewriter &rewriter) {
161 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
162 Value operand = op.getOperand();
163 Type opType = operand.getType();
164
165 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
166 Value fma = math::FmaOp::create(b, operand, operand, one);
167 Value sqrt = math::SqrtOp::create(b, fma);
168 Value add = arith::AddFOp::create(b, operand, sqrt);
169 Value res = math::LogOp::create(b, add);
170 rewriter.replaceOp(op, res);
171 return success();
172}
173
174// acosh(float x) -> log(x + sqrt(x**2 - 1))
175static LogicalResult convertAcoshOp(math::AcoshOp op,
176 PatternRewriter &rewriter) {
177 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
178 Value operand = op.getOperand();
179 Type opType = operand.getType();
180
181 Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
182 Value fma = math::FmaOp::create(b, operand, operand, negOne);
183 Value sqrt = math::SqrtOp::create(b, fma);
184 Value add = arith::AddFOp::create(b, operand, sqrt);
185 Value res = math::LogOp::create(b, add);
186 rewriter.replaceOp(op, res);
187 return success();
188}
189
190// atanh(float x) -> log((1 + x) / (1 - x)) / 2
191static LogicalResult convertAtanhOp(math::AtanhOp op,
192 PatternRewriter &rewriter) {
193 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
194 Value operand = op.getOperand();
195 Type opType = operand.getType();
196
197 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
198 Value add = arith::AddFOp::create(b, operand, one);
199 Value neg = arith::NegFOp::create(b, operand);
200 Value sub = arith::AddFOp::create(b, neg, one);
201 Value div = arith::DivFOp::create(b, add, sub);
202 Value log = math::LogOp::create(b, div);
203 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
204 Value res = arith::MulFOp::create(b, log, half);
205 rewriter.replaceOp(op, res);
206 return success();
207}
208
209static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
210 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
211 Value operandA = op.getOperand(0);
212 Value operandB = op.getOperand(1);
213 Value operandC = op.getOperand(2);
214 Type type = op.getType();
215 Value mult = arith::MulFOp::create(b, type, operandA, operandB);
216 Value add = arith::AddFOp::create(b, type, mult, operandC);
217 rewriter.replaceOp(op, add);
218 return success();
219}
220
221// Converts a ceilf() function to the following:
222// ceilf(float x) ->
223// y = (float)(int) x
224// if (x > y) then incr = 1 else incr = 0
225// y = y + incr <= replace this op with the ceilf op.
226static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
227 // Creating constants assumes the static shaped type.
228 auto shapedType = dyn_cast<ShapedType>(op.getType());
229 if (shapedType && !shapedType.hasStaticShape())
230 return failure();
231
232 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
233 Value operand = op.getOperand();
234 Type opType = operand.getType();
235 Type operandETy = getElementTypeOrSelf(opType);
236 FloatType floatTy = llvm::dyn_cast<FloatType>(operandETy);
237 const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
238
239 unsigned bitWidth = floatTy.getWidth();
240 unsigned mantissaWidth = floatTy.getFPMantissaWidth() - 1;
241 const int bias = (&semantics == &APFloat::Float8E8M0FNU())
242 ? -semantics.minExponent
243 : -(semantics.minExponent - 1);
244 bool hasNegativeZeroNaNEncoding =
245 (semantics.nanEncoding == llvm::fltNanEncoding::NegativeZero);
246
247 Type iTy = rewriter.getIntegerType(bitWidth);
248 if (auto shapedTy = dyn_cast<ShapedType>(opType))
249 iTy = shapedTy.clone(iTy);
250
251 // For IEEE-like floating-point formats with an unbiased exponent ≥
252 // `mantissaWidth` falls into one of these categories:
253 // - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
254 // numbers are already integral, or
255 // - a special value (NaN or ±Inf), which also satisfies this exponent
256 // condition.
257 // For all such cases, `ceilf(x)` is defined to return `x` directly.
258 Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
259 Value cMask = createIntConst(
260 op->getLoc(), iTy, static_cast<int64_t>((1ull << (bitWidth - 1)) - 1), b);
261 Value unsignedBits = arith::AndIOp::create(b, operandBitcast, cMask);
262 Value cThreshold = createIntConst(
263 op->getLoc(), iTy,
264 static_cast<int64_t>((uint64_t(bias + mantissaWidth)) << mantissaWidth),
265 b);
266 Value isLargeExp = arith::CmpIOp::create(b, arith::CmpIPredicate::uge,
267 unsignedBits, cThreshold);
268 Value isSpecialValOrLargeVal = isLargeExp;
269
270 // In FNUZ-suffixed floating point, NaN is represented by a sign bit of 1 and
271 // all 0s in the exponent and mantissa, therefore requires an explicit check.
272 if (hasNegativeZeroNaNEncoding) {
273 Value cNegZeroBits = createIntConst(
274 op->getLoc(), iTy, static_cast<int64_t>(1ull << (bitWidth - 1)), b);
275 Value isNegZeroEncoding = arith::CmpIOp::create(
276 b, arith::CmpIPredicate::eq, operandBitcast, cNegZeroBits);
277 isSpecialValOrLargeVal =
278 arith::OrIOp::create(b, isLargeExp, isNegZeroEncoding);
279 }
280
281 Value fpFixedConvert = createTruncatedFPValue(operand, b);
282
283 // Creating constants for later use.
284 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
285 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
286
287 Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
288 fpFixedConvert);
289 Value incrValue =
290 arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
291
292 Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
293 Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
294 rewriter.replaceOp(op, ret);
295 return success();
296}
297
298// Convert `math.fpowi` to a series of `arith.mulf` operations.
299// If the power is negative, we divide one by the result.
300// If both the base and power are zero, the result is 1.
301// In the case of non constant power, we convert the operation to `math.powf`.
302static LogicalResult convertFPowIOp(math::FPowIOp op,
303 PatternRewriter &rewriter) {
304 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
305 Value base = op.getOperand(0);
306 Value power = op.getOperand(1);
307 Type baseType = base.getType();
308
309 auto convertFPowItoPowf = [&]() -> LogicalResult {
310 Value castPowerToFp =
311 arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
312 Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
313 castPowerToFp);
314 rewriter.replaceOp(op, res);
315 return success();
316 };
317
318 Attribute cstAttr;
319 if (!matchPattern(power, m_Constant(&cstAttr)))
320 return convertFPowItoPowf();
321
322 APInt value;
323 if (!matchPattern(cstAttr, m_ConstantInt(&value)))
324 return convertFPowItoPowf();
325
326 int64_t powerInt = value.getSExtValue();
327 bool isNegative = powerInt < 0;
328 int64_t absPower = std::abs(powerInt);
329 Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
330 Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
331
332 while (absPower > 0) {
333 if (absPower & 1)
334 res = arith::MulFOp::create(b, baseType, base, res);
335 absPower >>= 1;
336 base = arith::MulFOp::create(b, baseType, base, base);
337 }
338
339 // Make sure not to introduce UB in case of negative power.
340 if (isNegative) {
341 auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
342 .getFloatSemantics();
343 Value zero =
344 createFloatConst(op->getLoc(), baseType,
345 APFloat::getZero(sem, /*Negative=*/false), rewriter);
346 Value negZero =
347 createFloatConst(op->getLoc(), baseType,
348 APFloat::getZero(sem, /*Negative=*/true), rewriter);
349 Value posInfinity =
350 createFloatConst(op->getLoc(), baseType,
351 APFloat::getInf(sem, /*Negative=*/false), rewriter);
352 Value negInfinity =
353 createFloatConst(op->getLoc(), baseType,
354 APFloat::getInf(sem, /*Negative=*/true), rewriter);
355 Value zeroEqCheck =
356 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
357 Value negZeroEqCheck =
358 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
359 res = arith::DivFOp::create(b, baseType, one, res);
360 res =
361 arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
362 res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
363 res);
364 }
365
366 rewriter.replaceOp(op, res);
367 return success();
368}
369
370// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
371// Some special cases where b is constant are handled separately:
372// when b == 0, or |b| == 0.5, 1.0, or 2.0.
373static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
374 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
375 Value operandA = op.getOperand(0);
376 Value operandB = op.getOperand(1);
377 auto typeA = operandA.getType();
378 auto typeB = operandB.getType();
379
380 auto &sem =
381 cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
382 APFloat valueB(sem);
383 auto mulf = [&](Value x, Value y) -> Value {
384 return arith::MulFOp::create(b, x, y);
385 };
386 if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
387 if (valueB.isZero()) {
388 // a^0 -> 1
389 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
390 rewriter.replaceOp(op, one);
391 return success();
392 }
393 if (valueB.isExactlyValue(1.0)) {
394 // a^1 -> a
395 rewriter.replaceOp(op, operandA);
396 return success();
397 }
398 if (valueB.isExactlyValue(-1.0)) {
399 // a^(-1) -> 1 / a
400 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
401 Value div = arith::DivFOp::create(b, one, operandA);
402 rewriter.replaceOp(op, div);
403 return success();
404 }
405 if (valueB.isExactlyValue(0.5)) {
406 // a^(1/2) -> sqrt(a)
407 Value sqrt = math::SqrtOp::create(b, operandA);
408 rewriter.replaceOp(op, sqrt);
409 return success();
410 }
411 if (valueB.isExactlyValue(-0.5)) {
412 // a^(-1/2) -> 1 / sqrt(a)
413 Value rsqrt = math::RsqrtOp::create(b, operandA);
414 rewriter.replaceOp(op, rsqrt);
415 return success();
416 }
417 if (valueB.isExactlyValue(2.0)) {
418 // a^2 -> a * a
419 rewriter.replaceOp(op, mulf(operandA, operandA));
420 return success();
421 }
422 if (valueB.isExactlyValue(-2.0)) {
423 // a^(-2) -> 1 / (a * a)
424 Value one =
425 createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
426 Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
427 rewriter.replaceOp(op, div);
428 return success();
429 }
430 if (valueB.isExactlyValue(3.0)) {
431 rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
432 return success();
433 }
434 }
435
436 Value logA = math::LogOp::create(b, operandA);
437 Value mult = arith::MulFOp::create(b, operandB, logA);
438 Value expResult = math::ExpOp::create(b, mult);
439 rewriter.replaceOp(op, expResult);
440 return success();
441}
442
443// exp2f(float x) -> exp(x * ln(2))
444// Proof: Let's say 2^x = y
445// ln(2^x) = ln(y)
446// x * ln(2) = ln(y) => e ^(x*ln(2)) = y
447static LogicalResult convertExp2fOp(math::Exp2Op op,
448 PatternRewriter &rewriter) {
449 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
450 Value operand = op.getOperand();
451 Type opType = operand.getType();
452 Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
453 Value mult = arith::MulFOp::create(b, opType, operand, ln2);
454 Value exp = math::ExpOp::create(b, op->getLoc(), mult);
455 rewriter.replaceOp(op, exp);
456 return success();
457}
458
459static LogicalResult convertRoundOp(math::RoundOp op,
460 PatternRewriter &rewriter) {
461 Location loc = op.getLoc();
462 ImplicitLocOpBuilder b(loc, rewriter);
463 Value operand = op.getOperand();
464 Type opType = operand.getType();
465 Type opEType = getElementTypeOrSelf(opType);
466
467 if (!opEType.isF32()) {
468 return rewriter.notifyMatchFailure(op, "not a round of f32.");
469 }
470
471 Type i32Ty = b.getI32Type();
472 if (auto shapedTy = dyn_cast<ShapedType>(opType))
473 i32Ty = shapedTy.clone(i32Ty);
474
475 Value half = createFloatConst(loc, opType, 0.5, b);
476 Value c23 = createIntConst(loc, i32Ty, 23, b);
477 Value c127 = createIntConst(loc, i32Ty, 127, b);
478 Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
479
480 Value incrValue = math::CopySignOp::create(b, half, operand);
481 Value add = arith::AddFOp::create(b, opType, operand, incrValue);
482 Value fpFixedConvert = createTruncatedFPValue(add, b);
483
484 // There are three cases where adding 0.5 to the value and truncating by
485 // converting to an i64 does not result in the correct behavior:
486 //
487 // 1. Special values: +-inf and +-nan
488 // Casting these special values to i64 has undefined behavior. To identify
489 // these values, we use the fact that these values are the only float
490 // values with the maximum possible biased exponent.
491 //
492 // 2. Large values: 2^23 <= |x| <= INT_64_MAX
493 // Adding 0.5 to a float larger than or equal to 2^23 results in precision
494 // errors that sometimes round the value up and sometimes round the value
495 // down. For example:
496 // 8388608.0 + 0.5 = 8388608.0
497 // 8388609.0 + 0.5 = 8388610.0
498 //
499 // 3. Very large values: |x| > INT_64_MAX
500 // Casting to i64 a value greater than the max i64 value will overflow the
501 // i64 leading to wrong outputs.
502 //
503 // All three cases satisfy the property `biasedExp >= 23`.
504 Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
505 Value operandExp = arith::AndIOp::create(
506 b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
507 Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
508 Value isSpecialValOrLargeVal = arith::CmpIOp::create(
509 b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
510
511 Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
512 fpFixedConvert);
513 rewriter.replaceOp(op, result);
514 return success();
515}
516
517// Converts math.ctlz to scf and arith operations. This is done
518// by performing a binary search on the bits.
519static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
520 PatternRewriter &rewriter) {
521 auto operand = op.getOperand();
522 auto operandTy = operand.getType();
523 auto eTy = getElementTypeOrSelf(operandTy);
524 Location loc = op.getLoc();
525
526 // Only expand for integer or float element types (index has no fixed bitwidth).
527 if (!eTy.isIntOrFloat()) {
528 return rewriter.notifyMatchFailure(op, "ctlz expansion only supports int or float types");
529 }
530
531 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
532 if (bitwidth > 64)
533 return failure();
534
535 uint64_t allbits = -1;
536 if (bitwidth < 64) {
537 allbits = allbits >> (64 - bitwidth);
538 }
539
540 Value x = operand;
541 Value count = createIntConst(loc, operandTy, 0, rewriter);
542 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
543 auto half = bw / 2;
544 auto bits = createIntConst(loc, operandTy, half, rewriter);
545 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
546
547 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
548 x, mask);
549 Value add = arith::AddIOp::create(rewriter, loc, count, bits);
550 Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
551
552 x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
553 count = arith::SelectOp::create(rewriter, loc, pred, add, count);
554 }
555
556 Value zero = createIntConst(loc, operandTy, 0, rewriter);
557 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
558 operand, zero);
559
560 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
561 Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
562 rewriter.replaceOp(op, sel);
563 return success();
564}
565
566// Convert `math.roundeven` into `math.round` + arith ops
567static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
568 PatternRewriter &rewriter) {
569 Location loc = op.getLoc();
570 ImplicitLocOpBuilder b(loc, rewriter);
571 auto operand = op.getOperand();
572 Type operandTy = operand.getType();
573 Type resultTy = op.getType();
574 Type operandETy = getElementTypeOrSelf(operandTy);
575 Type resultETy = getElementTypeOrSelf(resultTy);
576
577 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
578 return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
579 }
580
581 Type fTy = operandTy;
582 Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
583 if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
584 iTy = shapedTy.clone(iTy);
585 }
586
587 unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
588 // The width returned by getFPMantissaWidth includes the integer bit.
589 unsigned mantissaWidth =
590 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
591 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
592
593 // The names of the variables correspond to f32.
594 // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
595 // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
596 // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
597 Value c1Float = createFloatConst(loc, fTy, 1.0, b);
598 Value c0 = createIntConst(loc, iTy, 0, b);
599 Value c1 = createIntConst(loc, iTy, 1, b);
600 Value cNeg1 = createIntConst(loc, iTy, -1, b);
601 Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
602 Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
603 Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
604 Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
605 Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
606 Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
607
608 Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
609 Value round = math::RoundOp::create(b, operand);
610 Value roundBitcast = arith::BitcastOp::create(b, iTy, round);
611
612 // Get biased exponents for operand and round(operand)
613 Value operandExp = arith::AndIOp::create(
614 b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
615 Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
616 Value roundExp = arith::AndIOp::create(
617 b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
618 Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);
619
620 auto safeShiftRight = [&](Value x, Value shift) -> Value {
621 // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
622 Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
623 clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
624 return arith::ShRUIOp::create(b, x, clampedShift);
625 };
626
627 auto maskMantissa = [&](Value mantissa,
628 Value mantissaMaskRightShift) -> Value {
629 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
630 return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
631 };
632
633 // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
634 // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
635 // with `biasedExp > 23` (numbers where there is not enough precision to store
636 // decimals) are always even, and they satisfy the even condition trivially
637 // since the mantissa without all its bits is zero. The even condition
638 // is also true for +-0, since they have `biasedExp = -127` and the entire
639 // mantissa is zero. The case of +-1 has to be handled separately. Here
640 // we identify these values by noting that +-1 are the only whole numbers with
641 // `biasedExp == 0`.
642 //
643 // The special values +-inf and +-nan also satisfy the same property that
644 // whole non-unit even numbers satisfy. In particular, the special values have
645 // `biasedExp > 23`, so they get treated as large numbers with no room for
646 // decimals, which are always even.
647 Value roundBiasedExpEq0 =
648 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
649 Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
650 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
651 Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
652 b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
653 roundIsNotEvenOrSpecialVal =
654 arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
655
656 // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
657 // integers if the bit at index `biasedExp` starting from the left in the
658 // mantissa is 1 and all the bits to the right are zero. Values with
659 // `biasedExp >= 23` don't have decimals, so they are never halfway. The
660 // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
661 // so these are handled separately. In particular, if `biasedExp == -1`, the
662 // value is halfway if the entire mantissa is zero.
663 Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
664 b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
665 Value expectedOperandMaskedMantissa = arith::SelectOp::create(
666 b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
667 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
668 Value operandIsHalfway =
669 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
670 expectedOperandMaskedMantissa);
671 // Ensure `biasedExp` is in the valid range for half values.
672 Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
673 b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
674 Value operandBiasedExpLt23 = arith::CmpIOp::create(
675 b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
676 operandIsHalfway =
677 arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
678 operandIsHalfway =
679 arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);
680
681 // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
682 // case where `round` rounded in the opposite direction of `roundeven`.
683 Value sign = math::CopySignOp::create(b, c1Float, operand);
684 Value roundShifted = arith::SubFOp::create(b, round, sign);
685 // If the rounded value is even or a special value, we default to the behavior
686 // of `math.round`.
687 Value needsShift =
688 arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
689 Value result = arith::SelectOp::create(b, needsShift, roundShifted, round);
690 // The `x - sign` adjustment does not preserve the sign when we are adjusting
691 // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
692 // rounded to -0.0.
693 result = math::CopySignOp::create(b, result, operand);
694 rewriter.replaceOp(op, result);
695 return success();
696}
697
698// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
699static LogicalResult convertRsqrtOp(math::RsqrtOp op,
700 PatternRewriter &rewriter) {
701
702 auto operand = op.getOperand();
703 auto operandTy = operand.getType();
704 // Operand type must be shatic shaped type to create const float.
705 auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
706 if (shapedOperandType && !shapedOperandType.hasStaticShape())
707 return failure();
708
709 auto eTy = getElementTypeOrSelf(operandTy);
710 if (!isa<FloatType>(eTy))
711 return failure();
712
713 Location loc = op->getLoc();
714 auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
715 auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
716 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
717 return success();
718}
719
720// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
721static LogicalResult convertClampfOp(math::ClampFOp op,
722 PatternRewriter &rewriter) {
723 auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
724 op.getMax(), op.getFastmath());
725 rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMin(),
726 op.getFastmath());
727 return success();
728}
729
731 ArrayRef<StringRef> opMnemonics) {
732 auto filter = [&](StringRef name) {
733 // This should be a static assert and `consume_front` take a twine, but none
734 // is currently possible. TODO: augment `StringRef::consume_front` and make
735 // `getDialectNamespace` use `std::string_view`.
736 assert("math" == MathDialect::getDialectNamespace());
737 name.consume_front("math.");
738 return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
739 };
740 if (filter(CountLeadingZerosOp::getOperationName()))
742 if (filter(SinhOp::getOperationName()))
744 if (filter(CoshOp::getOperationName()))
746 if (filter(TanOp::getOperationName()))
748 if (filter(TanhOp::getOperationName()))
750 if (filter(AsinhOp::getOperationName()))
752 if (filter(AcoshOp::getOperationName()))
754 if (filter(AtanhOp::getOperationName()))
756 if (filter(FmaOp::getOperationName()))
758 if (filter(CeilOp::getOperationName()))
760 if (filter(Exp2Op::getOperationName()))
762 if (filter(PowFOp::getOperationName()))
764 if (filter(FPowIOp::getOperationName()))
766 if (filter(RoundOp::getOperationName()))
768 if (filter(RoundEvenOp::getOperationName()))
770 if (filter(RsqrtOp::getOperationName()))
772 if (filter(ClampFOp::getOperationName()))
774}
775
776//===----------------------------------------------------------------------===//
777// MathExpandOpsPass pass
778//===----------------------------------------------------------------------===//
779namespace {
780struct MathExpandOpsPass final
781 : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
782 using MathExpandOpsPassBase::MathExpandOpsPassBase;
783
784 void runOnOperation() override {
786 SmallVector<StringRef> mnemonics =
787 llvm::to_vector_of<StringRef>(opMnemonics);
789 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
790 return signalPassFailure();
791 }
792};
793} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
Definition ExpandOps.cpp:62
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertClampfOp(math::ClampFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
Definition ExpandOps.cpp:91
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
Definition ExpandOps.cpp:75
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
Definition ExpandOps.cpp:29
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
Definition ExpandOps.cpp:51
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into (1 - exp(-2x)) / (1 + exp(-2x)). To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
#define add(a, b)
#define div(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateExpansionPatterns(RewritePatternSet &patterns, ArrayRef< StringRef > opMnemonics={})
Adds patterns to expand math operations into other more fundamental operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
Definition Matchers.h:520