22#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
23#include "mlir/Conversion/Passes.h.inc"
39static FailureOr<FuncOp>
42 auto i32Type = IntegerType::get(symTable->getContext(), 32);
43 auto i64Type = IntegerType::get(symTable->getContext(), 64);
44 std::string funcName = (llvm::Twine(
"_mlir_apfloat_") + name).str();
46 {i32Type, i64Type, i64Type}, symbolTables);
53template <
typename Fn,
typename... Values>
57 auto vecTy1 = dyn_cast<VectorType>(operand1.
getType());
60 assert(vecTy1 == dyn_cast<VectorType>(operand2.
getType()) &&
61 "expected same vector types");
65 return fn(operand1, operand2, resultType);
70 vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
74 scalars2.assign(vecTy1.getNumElements(),
Value());
78 vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
82 auto resultVecType = cast<VectorType>(resultType);
84 for (
auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
85 Value result = fn(scalar1, scalar2, resultVecType.getElementType());
90 return vector::FromElementsOp::create(
92 vecTy1.cloneWith(std::nullopt, results.front().getType()),
101 Type type = value.getType();
102 if (
auto vecTy = dyn_cast<VectorType>(type)) {
103 type = vecTy.getElementType();
107 op,
"only integers and floats (or vectors thereof) are supported");
111 "bitwidth > 64 bits is not supported");
117template <
typename OpTy>
132 FailureOr<FuncOp> fn =
141 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
144 auto floatTy = cast<FloatType>(resultType);
145 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
146 auto int64Type = rewriter.getI64Type();
147 Value lhsBits = arith::ExtUIOp::create(
148 rewriter, loc, int64Type,
149 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
150 Value rhsBits = arith::ExtUIOp::create(
151 rewriter, loc, int64Type,
152 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
155 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
156 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
157 auto resultOp = func::CallOp::create(rewriter, loc,
159 SymbolRefAttr::get(*fn), params);
162 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
163 resultOp->getResult(0));
164 return arith::BitcastOp::create(rewriter, loc, floatTy,
167 rewriter.replaceOp(op, repl);
175template <
typename OpTy>
187 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
188 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
189 FailureOr<FuncOp> fn =
191 {i32Type, i32Type, i64Type});
199 rewriter, loc, op.getOperand(),
Value(), op.getType(),
202 auto inFloatTy = cast<FloatType>(operand1.getType());
203 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
204 Value operandBits = arith::ExtUIOp::create(
205 rewriter, loc, i64Type,
206 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
209 Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
210 auto outFloatTy = cast<FloatType>(resultType);
212 getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
213 std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
214 auto resultOp = func::CallOp::create(rewriter, loc,
216 SymbolRefAttr::get(*fn), params);
219 auto outIntWType = rewriter.
getIntegerType(outFloatTy.getWidth());
220 Value truncatedBits = arith::TruncIOp::create(
221 rewriter, loc, outIntWType, resultOp->getResult(0));
222 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
225 rewriter.replaceOp(op, repl);
232template <
typename OpTy>
245 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
246 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
247 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
248 FailureOr<FuncOp> fn =
250 {i32Type, i32Type, i1Type, i64Type});
258 rewriter, loc, op.getOperand(),
Value(), op.getType(),
261 auto inFloatTy = cast<FloatType>(operand1.getType());
262 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
263 Value operandBits = arith::ExtUIOp::create(
264 rewriter, loc, i64Type,
265 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
268 Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
269 auto outIntTy = cast<IntegerType>(resultType);
270 Value outWidthValue = arith::ConstantOp::create(
271 rewriter, loc, i32Type,
272 rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
273 Value isUnsignedValue = arith::ConstantOp::create(
274 rewriter, loc, i1Type,
275 rewriter.getIntegerAttr(i1Type, isUnsigned));
276 SmallVector<Value> params = {inSemValue, outWidthValue,
277 isUnsignedValue, operandBits};
278 auto resultOp = func::CallOp::create(rewriter, loc,
280 SymbolRefAttr::get(*fn), params);
283 return arith::TruncIOp::create(rewriter, loc, outIntTy,
284 resultOp->getResult(0));
286 rewriter.replaceOp(op, repl);
294template <
typename OpTy>
307 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
308 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
309 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
311 rewriter,
symTable,
"_mlir_apfloat_convert_from_int",
312 {i32Type, i32Type, i1Type, i64Type});
320 rewriter, loc, op.getOperand(),
Value(), op.getType(),
323 auto inIntTy = cast<IntegerType>(operand1.getType());
324 Value operandBits = operand1;
325 if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
328 arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
331 arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
336 auto outFloatTy = cast<FloatType>(resultType);
339 Value inWidthValue = arith::ConstantOp::create(
340 rewriter, loc, i32Type,
342 Value isUnsignedValue = arith::ConstantOp::create(
343 rewriter, loc, i1Type,
346 isUnsignedValue, operandBits};
347 auto resultOp = func::CallOp::create(rewriter, loc,
349 SymbolRefAttr::get(*fn), params);
352 auto outIntWType = rewriter.
getIntegerType(outFloatTy.getWidth());
353 Value truncatedBits = arith::TruncIOp::create(
354 rewriter, loc, outIntWType, resultOp->getResult(0));
355 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
358 rewriter.replaceOp(op, repl);
377 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
378 auto i8Type = IntegerType::get(
symTable->getContext(), 8);
379 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
380 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
381 FailureOr<FuncOp> fn =
383 {i32Type, i64Type, i64Type},
nullptr, i8Type);
391 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
394 auto floatTy = cast<FloatType>(lhs.getType());
395 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
396 Value lhsBits = arith::ExtUIOp::create(
397 rewriter, loc, i64Type,
398 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
399 Value rhsBits = arith::ExtUIOp::create(
400 rewriter, loc, i64Type,
401 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
404 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
405 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
406 Value comparisonResult =
407 func::CallOp::create(rewriter, loc,
TypeRange(i8Type),
408 SymbolRefAttr::get(*fn), params)
413 auto checkResult = [&](llvm::APFloat::cmpResult val) {
414 return arith::CmpIOp::create(
415 rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
416 arith::ConstantOp::create(
417 rewriter, loc, i8Type,
425 Value first = checkResult(vals.front());
426 if (vals.size() == 1)
428 Value rest = checkResults(vals.drop_front());
429 return arith::OrIOp::create(rewriter, loc, first, rest)
435 switch (op.getPredicate()) {
436 case arith::CmpFPredicate::AlwaysFalse:
438 arith::ConstantOp::create(rewriter, loc, i1Type,
442 case arith::CmpFPredicate::OEQ:
443 result = checkResult(llvm::APFloat::cmpEqual);
445 case arith::CmpFPredicate::OGT:
446 result = checkResult(llvm::APFloat::cmpGreaterThan);
448 case arith::CmpFPredicate::OGE:
450 {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
452 case arith::CmpFPredicate::OLT:
453 result = checkResult(llvm::APFloat::cmpLessThan);
455 case arith::CmpFPredicate::OLE:
457 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
459 case arith::CmpFPredicate::ONE:
462 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
464 case arith::CmpFPredicate::ORD:
466 result = checkResults({llvm::APFloat::cmpLessThan,
467 llvm::APFloat::cmpGreaterThan,
468 llvm::APFloat::cmpEqual});
470 case arith::CmpFPredicate::UEQ:
472 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
474 case arith::CmpFPredicate::UGT:
476 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
478 case arith::CmpFPredicate::UGE:
479 result = checkResults({llvm::APFloat::cmpUnordered,
480 llvm::APFloat::cmpGreaterThan,
481 llvm::APFloat::cmpEqual});
483 case arith::CmpFPredicate::ULT:
485 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
487 case arith::CmpFPredicate::ULE:
488 result = checkResults({llvm::APFloat::cmpUnordered,
489 llvm::APFloat::cmpLessThan,
490 llvm::APFloat::cmpEqual});
492 case arith::CmpFPredicate::UNE:
494 result = checkResults({llvm::APFloat::cmpLessThan,
495 llvm::APFloat::cmpGreaterThan,
496 llvm::APFloat::cmpUnordered});
498 case arith::CmpFPredicate::UNO:
499 result = checkResult(llvm::APFloat::cmpUnordered);
501 case arith::CmpFPredicate::AlwaysTrue:
503 arith::ConstantOp::create(rewriter, loc, i1Type,
510 rewriter.replaceOp(op, repl);
528 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
529 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
531 rewriter,
symTable,
"_mlir_apfloat_neg", {i32Type, i64Type});
542 auto floatTy = cast<FloatType>(operand1.getType());
543 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
544 Value operandBits = arith::ExtUIOp::create(
545 rewriter, loc, i64Type,
546 arith::BitcastOp::create(rewriter, loc, intWType, operand1));
549 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
550 SmallVector<Value> params = {semValue, operandBits};
552 func::CallOp::create(rewriter, loc,
TypeRange(i64Type),
553 SymbolRefAttr::get(*fn), params)
557 Value truncatedBits =
558 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
559 return arith::BitcastOp::create(rewriter, loc, floatTy,
562 rewriter.replaceOp(op, repl);
570struct ArithToAPFloatConversionPass final
571 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
574 void runOnOperation()
override;
577void ArithToAPFloatConversionPass::runOnOperation() {
579 RewritePatternSet
patterns(context);
580 patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context,
"add",
582 patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
583 context,
"subtract", getOperation());
584 patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
585 context,
"multiply", getOperation());
586 patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
587 context,
"divide", getOperation());
588 patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
589 context,
"remainder", getOperation());
590 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
591 context,
"minnum", getOperation());
592 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
593 context,
"maxnum", getOperation());
594 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
595 context,
"minimum", getOperation());
596 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
597 context,
"maximum", getOperation());
599 .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
600 CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
601 context, getOperation());
602 patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
604 patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
606 patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
608 patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
611 ScopedDiagnosticHandler scopedHandler(context, [&
result](Diagnostic &
diag) {
612 if (
diag.getSeverity() == DiagnosticSeverity::Error) {
621 return signalPassFailure();
static Value forEachScalarValue(RewriterBase &rewriter, Location loc, Value operand1, Value operand2, Type resultType, Fn fn)
Given two operands of vector type and vector result type (with the same shape), call the given functi...
static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op)
Check preconditions for the conversion:
static FailureOr< FuncOp > lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables=nullptr)
Helper function to look up or create the symbol for a runtime library function for a binary arithmeti...
static std::string diag(const llvm::Value &value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents a collection of SymbolTables.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< FuncOp > lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
BinaryArithOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(arith::NegFOp op, PatternRewriter &rewriter) const override
NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})