21#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
22#include "mlir/Conversion/Passes.h.inc"
29 StringRef name, FunctionType funcT,
bool setPrivate,
32 assert(!symTable->getRegion(0).empty() &&
"expected non-empty region");
33 b.setInsertionPointToStart(&symTable->getRegion(0).front());
34 FuncOp funcOp = FuncOp::create(
b, symTable->getLoc(), name, funcT);
38 SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
39 symbolTable.
insert(funcOp, symTable->getRegion(0).front().begin());
47static FailureOr<FuncOp>
51 Type resultType = {}) {
53 resultType = IntegerType::get(symTable->getContext(), 64);
54 std::string funcName = (llvm::Twine(
"_mlir_apfloat_") + name).str();
55 auto funcT = FunctionType::get(
b.getContext(), paramTypes, {resultType});
56 FailureOr<FuncOp>
func =
79static FailureOr<FuncOp>
82 auto i32Type = IntegerType::get(symTable->getContext(), 32);
83 auto i64Type = IntegerType::get(symTable->getContext(), 64);
89 int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
90 return arith::ConstantOp::create(
b, loc,
b.getI32Type(),
91 b.getIntegerAttr(
b.getI32Type(), sem));
98template <
typename Fn,
typename... Values>
102 auto vecTy1 = dyn_cast<VectorType>(operand1.
getType());
105 assert(vecTy1 == dyn_cast<VectorType>(operand2.
getType()) &&
106 "expected same vector types");
110 return fn(operand1, operand2, resultType);
115 vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
119 scalars2.assign(vecTy1.getNumElements(),
Value());
123 vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
127 auto resultVecType = cast<VectorType>(resultType);
129 for (
auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
130 Value result = fn(scalar1, scalar2, resultVecType.getElementType());
131 results.push_back(
result);
135 return vector::FromElementsOp::create(
137 vecTy1.cloneWith(std::nullopt, results.front().getType()),
146 Type type = value.getType();
147 if (
auto vecTy = dyn_cast<VectorType>(type)) {
148 type = vecTy.getElementType();
152 op,
"only integers and floats (or vectors thereof) are supported");
156 "bitwidth > 64 bits is not supported");
162template <
typename OpTy>
177 FailureOr<FuncOp> fn =
186 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
189 auto floatTy = cast<FloatType>(resultType);
190 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
191 auto int64Type = rewriter.getI64Type();
192 Value lhsBits = arith::ExtUIOp::create(
193 rewriter, loc, int64Type,
194 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
195 Value rhsBits = arith::ExtUIOp::create(
196 rewriter, loc, int64Type,
197 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
200 Value semValue = getSemanticsValue(rewriter, loc, floatTy);
201 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
202 auto resultOp = func::CallOp::create(rewriter, loc,
204 SymbolRefAttr::get(*fn), params);
207 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
208 resultOp->getResult(0));
209 return arith::BitcastOp::create(rewriter, loc, floatTy,
212 rewriter.replaceOp(op, repl);
216 SymbolOpInterface symTable;
220template <
typename OpTy>
232 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
233 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
235 rewriter,
symTable,
"convert", {i32Type, i32Type, i64Type});
243 rewriter, loc, op.getOperand(),
Value(), op.getType(),
246 auto inFloatTy = cast<FloatType>(operand1.getType());
247 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
248 Value operandBits = arith::ExtUIOp::create(
249 rewriter, loc, i64Type,
250 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
253 Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
254 auto outFloatTy = cast<FloatType>(resultType);
255 Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
256 std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
257 auto resultOp = func::CallOp::create(rewriter, loc,
259 SymbolRefAttr::get(*fn), params);
262 auto outIntWType = rewriter.
getIntegerType(outFloatTy.getWidth());
263 Value truncatedBits = arith::TruncIOp::create(
264 rewriter, loc, outIntWType, resultOp->getResult(0));
265 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
275template <
typename OpTy>
288 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
289 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
290 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
291 FailureOr<FuncOp> fn =
293 {i32Type, i32Type, i1Type, i64Type});
301 rewriter, loc, op.getOperand(),
Value(), op.getType(),
304 auto inFloatTy = cast<FloatType>(operand1.getType());
305 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
306 Value operandBits = arith::ExtUIOp::create(
307 rewriter, loc, i64Type,
308 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
311 Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
312 auto outIntTy = cast<IntegerType>(resultType);
313 Value outWidthValue = arith::ConstantOp::create(
314 rewriter, loc, i32Type,
315 rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
316 Value isUnsignedValue = arith::ConstantOp::create(
317 rewriter, loc, i1Type,
318 rewriter.getIntegerAttr(i1Type, isUnsigned));
319 SmallVector<Value> params = {inSemValue, outWidthValue,
320 isUnsignedValue, operandBits};
321 auto resultOp = func::CallOp::create(rewriter, loc,
323 SymbolRefAttr::get(*fn), params);
326 return arith::TruncIOp::create(rewriter, loc, outIntTy,
327 resultOp->getResult(0));
337template <
typename OpTy>
350 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
351 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
352 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
353 FailureOr<FuncOp> fn =
355 {i32Type, i32Type, i1Type, i64Type});
363 rewriter, loc, op.getOperand(),
Value(), op.getType(),
366 auto inIntTy = cast<IntegerType>(operand1.getType());
367 Value operandBits = operand1;
368 if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
371 arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
374 arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
379 auto outFloatTy = cast<FloatType>(resultType);
381 Value inWidthValue = arith::ConstantOp::create(
382 rewriter, loc, i32Type,
384 Value isUnsignedValue = arith::ConstantOp::create(
385 rewriter, loc, i1Type,
388 isUnsignedValue, operandBits};
389 auto resultOp = func::CallOp::create(rewriter, loc,
391 SymbolRefAttr::get(*fn), params);
394 auto outIntWType = rewriter.
getIntegerType(outFloatTy.getWidth());
395 Value truncatedBits = arith::TruncIOp::create(
396 rewriter, loc, outIntWType, resultOp->getResult(0));
397 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
419 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
420 auto i8Type = IntegerType::get(
symTable->getContext(), 8);
421 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
422 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
423 FailureOr<FuncOp> fn =
425 {i32Type, i64Type, i64Type},
nullptr, i8Type);
433 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
436 auto floatTy = cast<FloatType>(lhs.getType());
437 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
438 Value lhsBits = arith::ExtUIOp::create(
439 rewriter, loc, i64Type,
440 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
441 Value rhsBits = arith::ExtUIOp::create(
442 rewriter, loc, i64Type,
443 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
446 Value semValue = getSemanticsValue(rewriter, loc, floatTy);
447 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
448 Value comparisonResult =
449 func::CallOp::create(rewriter, loc,
TypeRange(i8Type),
450 SymbolRefAttr::get(*fn), params)
455 auto checkResult = [&](llvm::APFloat::cmpResult val) {
456 return arith::CmpIOp::create(
457 rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
458 arith::ConstantOp::create(
459 rewriter, loc, i8Type,
467 Value first = checkResult(vals.front());
468 if (vals.size() == 1)
470 Value rest = checkResults(vals.drop_front());
471 return arith::OrIOp::create(rewriter, loc, first, rest)
477 switch (op.getPredicate()) {
478 case arith::CmpFPredicate::AlwaysFalse:
480 arith::ConstantOp::create(rewriter, loc, i1Type,
484 case arith::CmpFPredicate::OEQ:
485 result = checkResult(llvm::APFloat::cmpEqual);
487 case arith::CmpFPredicate::OGT:
488 result = checkResult(llvm::APFloat::cmpGreaterThan);
490 case arith::CmpFPredicate::OGE:
492 {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
494 case arith::CmpFPredicate::OLT:
495 result = checkResult(llvm::APFloat::cmpLessThan);
497 case arith::CmpFPredicate::OLE:
499 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
501 case arith::CmpFPredicate::ONE:
504 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
506 case arith::CmpFPredicate::ORD:
508 result = checkResults({llvm::APFloat::cmpLessThan,
509 llvm::APFloat::cmpGreaterThan,
510 llvm::APFloat::cmpEqual});
512 case arith::CmpFPredicate::UEQ:
514 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
516 case arith::CmpFPredicate::UGT:
518 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
520 case arith::CmpFPredicate::UGE:
521 result = checkResults({llvm::APFloat::cmpUnordered,
522 llvm::APFloat::cmpGreaterThan,
523 llvm::APFloat::cmpEqual});
525 case arith::CmpFPredicate::ULT:
527 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
529 case arith::CmpFPredicate::ULE:
530 result = checkResults({llvm::APFloat::cmpUnordered,
531 llvm::APFloat::cmpLessThan,
532 llvm::APFloat::cmpEqual});
534 case arith::CmpFPredicate::UNE:
536 result = checkResults({llvm::APFloat::cmpLessThan,
537 llvm::APFloat::cmpGreaterThan,
538 llvm::APFloat::cmpUnordered});
540 case arith::CmpFPredicate::UNO:
541 result = checkResult(llvm::APFloat::cmpUnordered);
543 case arith::CmpFPredicate::AlwaysTrue:
545 arith::ConstantOp::create(rewriter, loc, i1Type,
570 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
571 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
572 FailureOr<FuncOp> fn =
584 auto floatTy = cast<FloatType>(operand1.getType());
585 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
586 Value operandBits = arith::ExtUIOp::create(
587 rewriter, loc, i64Type,
588 arith::BitcastOp::create(rewriter, loc, intWType, operand1));
591 Value semValue = getSemanticsValue(rewriter, loc, floatTy);
592 SmallVector<Value> params = {semValue, operandBits};
594 func::CallOp::create(rewriter, loc,
TypeRange(i64Type),
595 SymbolRefAttr::get(*fn), params)
599 Value truncatedBits =
600 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
601 return arith::BitcastOp::create(rewriter, loc, floatTy,
612struct ArithToAPFloatConversionPass final
616 void runOnOperation()
override;
619void ArithToAPFloatConversionPass::runOnOperation() {
621 RewritePatternSet
patterns(context);
622 patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context,
"add",
624 patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
625 context,
"subtract", getOperation());
626 patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
627 context,
"multiply", getOperation());
628 patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
629 context,
"divide", getOperation());
630 patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
631 context,
"remainder", getOperation());
632 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
633 context,
"minnum", getOperation());
634 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
635 context,
"maxnum", getOperation());
636 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
637 context,
"minimum", getOperation());
638 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
639 context,
"maximum", getOperation());
641 .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
642 CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
643 context, getOperation());
644 patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
646 patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
648 patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
650 patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
653 ScopedDiagnosticHandler scopedHandler(context, [&
result](Diagnostic &
diag) {
654 if (
diag.getSeverity() == DiagnosticSeverity::Error) {
663 return signalPassFailure();
static Value forEachScalarValue(RewriterBase &rewriter, Location loc, Value operand1, Value operand2, Type resultType, Fn fn)
Given two operands of vector type and vector result type (with the same shape), call the given functi...
static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op)
Check preconditions for the conversion:
static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
static FailureOr< FuncOp > lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables=nullptr)
Helper function to look up or create the symbol for a runtime library function for a binary arithmeti...
static FailureOr< FuncOp > lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, FunctionType funcT, bool setPrivate, SymbolTableCollection *symbolTables=nullptr)
static std::string diag(const llvm::Value &value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents a collection of SymbolTables.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< FuncOp > lookupFnDecl(SymbolOpInterface symTable, StringRef name, FunctionType funcT, SymbolTableCollection *symbolTables=nullptr)
Look up a FuncOp with signature resultTypes(paramTypes) and name / name`.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
BinaryArithOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(arith::NegFOp op, PatternRewriter &rewriter) const override
NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})