MLIR  19.0.0git
OptimizeForNVVM.cpp
Go to the documentation of this file.
1 //===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
16 
17 namespace mlir {
18 namespace NVVM {
19 #define GEN_PASS_DEF_NVVMOPTIMIZEFORTARGET
20 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
21 } // namespace NVVM
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 // Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
28 // (conditional) Newton iteration.
29 //
30 // This as accurate as promoting the division to fp32 in the NVPTX backend, but
31 // faster because it performs less Newton iterations, avoids the slow path
32 // for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
33 // by the same divisor.
34 struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
36 
37 private:
38  LogicalResult matchAndRewrite(LLVM::FDivOp op,
39  PatternRewriter &rewriter) const override;
40 };
41 
42 struct NVVMOptimizeForTarget
43  : public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
44  void runOnOperation() override;
45 
46  void getDependentDialects(DialectRegistry &registry) const override {
47  registry.insert<NVVM::NVVMDialect>();
48  }
49 };
50 } // namespace
51 
52 LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
53  PatternRewriter &rewriter) const {
54  if (!op.getType().isF16())
55  return rewriter.notifyMatchFailure(op, "not f16");
56  Location loc = op.getLoc();
57 
58  Type f32Type = rewriter.getF32Type();
59  Type i32Type = rewriter.getI32Type();
60 
61  // Extend lhs and rhs to fp32.
62  Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
63  Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());
64 
65  // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
66  Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
67  Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
68 
69  // Refine the approximation with one Newton iteration:
70  // float refined = approx + (lhs - approx * rhs) * rcp;
71  Value err = rewriter.create<LLVM::FMAOp>(
72  loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
73  Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
74 
75  // Use refined value if approx is normal (exponent neither all 0 or all 1).
76  Value mask = rewriter.create<LLVM::ConstantOp>(
77  loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
78  Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
79  Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
80  Value zero = rewriter.create<LLVM::ConstantOp>(
81  loc, i32Type, rewriter.getUI32IntegerAttr(0));
82  Value pred = rewriter.create<LLVM::OrOp>(
83  loc,
84  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
85  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
86  Value result =
87  rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
88 
89  // Replace with trucation back to fp16.
90  rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
91 
92  return success();
93 }
94 
95 void NVVMOptimizeForTarget::runOnOperation() {
96  MLIRContext *ctx = getOperation()->getContext();
97  RewritePatternSet patterns(ctx);
98  patterns.add<ExpandDivF16>(ctx);
99  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
100  return signalPassFailure();
101 }
102 
103 std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
104  return std::make_unique<NVVMOptimizeForTarget>();
105 }
FloatType getF32Type()
Definition: Builders.cpp:63
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerAttr getUI32IntegerAttr(uint32_t value)
Definition: Builders.cpp:225
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
std::unique_ptr< Pass > createOptimizeForTargetPass()
Creates a pass that optimizes LLVM IR for the NVVM target.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358