MLIR  18.0.0git
ArithToAMDGPU.cpp
Go to the documentation of this file.
1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 struct ArithToAMDGPUConversionPass final
30  : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
31  using impl::ArithToAMDGPUConversionPassBase<
32  ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
33 
34  void runOnOperation() override;
35 };
36 
37 struct ExtfOnFloat8RewritePattern final
38  : public OpRewritePattern<arith::ExtFOp> {
40 
41  LogicalResult match(arith::ExtFOp op) const override;
42  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
43 };
44 
45 struct TruncfToFloat8RewritePattern final
46  : public OpRewritePattern<arith::TruncFOp> {
48 
49  LogicalResult match(arith::TruncFOp op) const override;
50  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
51 };
52 } // end namespace
53 
54 static Value castF32To(Type elementType, Value f32, Location loc,
55  PatternRewriter &rewriter) {
56  if (elementType.isF32())
57  return f32;
58  if (elementType.getIntOrFloatBitWidth() < 32)
59  return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
60  if (elementType.getIntOrFloatBitWidth() > 32)
61  return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
62  llvm_unreachable("The only 32-bit float type is f32");
63 }
64 
65 LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
66  Type inType = op.getIn().getType();
67  if (auto inVecType = inType.dyn_cast<VectorType>()) {
68  if (inVecType.isScalable())
69  return failure();
70  if (inVecType.getShape().size() > 1)
71  // Multi-dimensional vectors are currently unsupported.
72  return failure();
73  inType = inVecType.getElementType();
74  }
75  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
76 }
77 
78 void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
79  PatternRewriter &rewriter) const {
80  Location loc = op.getLoc();
81  Value in = op.getIn();
82  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
83  if (!in.getType().isa<VectorType>()) {
84  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
85  loc, rewriter.getF32Type(), in, 0);
86  Value result = castF32To(outElemType, asFloat, loc, rewriter);
87  return rewriter.replaceOp(op, result);
88  }
89  VectorType inType = in.getType().cast<VectorType>();
90  int64_t numElements = inType.getNumElements();
91  Value zero = rewriter.createOrFold<arith::ConstantOp>(
92  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
93  Value result =
94  rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
95  if (inType.getShape().empty()) {
96  Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
97  // Recurse to send the 0-D vector case to the 1-D vector case
98  Value scalarExt =
99  rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
100  result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
101  return rewriter.replaceOp(op, result);
102  }
103  for (int64_t i = 0; i < numElements; i += 4) {
104  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
105  Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
106  loc, in, i, elemsThisOp, 1);
107  for (int64_t j = 0; j < elemsThisOp; ++j) {
108  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
109  loc, rewriter.getF32Type(), inSlice, j);
110  Value asType = castF32To(outElemType, asFloat, loc, rewriter);
111  result = rewriter.create<vector::InsertElementOp>(
112  loc, asType, result,
113  rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
114  }
115  }
116  rewriter.replaceOp(op, result);
117 }
118 
119 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
120  Type type = value.getType();
121  if (type.isF32())
122  return value;
123  if (type.getIntOrFloatBitWidth() < 32)
124  return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
125  if (type.getIntOrFloatBitWidth() > 32)
126  return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
127  llvm_unreachable("The only 32-bit float type is f32");
128 }
129 
130 LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
131  Type outType = op.getOut().getType();
132  if (auto outVecType = outType.dyn_cast<VectorType>()) {
133  if (outVecType.isScalable())
134  return failure();
135  if (outVecType.getShape().size() > 1)
136  // Multi-dimensional vectors are currently unsupported.
137  return failure();
138  outType = outVecType.getElementType();
139  }
140  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
141 }
142 
143 void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
144  PatternRewriter &rewriter) const {
145  Location loc = op.getLoc();
146  Value in = op.getIn();
147  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
148  VectorType truncResType = VectorType::get(4, outElemType);
149  if (!in.getType().isa<VectorType>()) {
150  Value asFloat = castToF32(in, loc, rewriter);
151  Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
152  loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
153  /*existing=*/nullptr);
154  Value result = rewriter.create<vector::ExtractElementOp>(
155  loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
156  return rewriter.replaceOp(op, result);
157  }
158  VectorType outType = op.getOut().getType().cast<VectorType>();
159  int64_t numElements = outType.getNumElements();
160  Value zero = rewriter.createOrFold<arith::ConstantOp>(
161  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
162  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
163  if (outType.getShape().empty()) {
164  Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
165  // Recurse to send the 0-D vector case to the 1-D vector case
166  Value scalarTrunc =
167  rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
168  result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
169  return rewriter.replaceOp(op, result);
170  }
171 
172  for (int64_t i = 0; i < numElements; i += 4) {
173  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
174  Value thisResult = nullptr;
175  for (int64_t j = 0; j < elemsThisOp; j += 2) {
176  Value elemA = rewriter.create<vector::ExtractElementOp>(
177  loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
178  Value asFloatA = castToF32(elemA, loc, rewriter);
179  Value asFloatB = nullptr;
180  if (j + 1 < elemsThisOp) {
181  Value elemB = rewriter.create<vector::ExtractElementOp>(
182  loc, in,
183  rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
184  asFloatB = castToF32(elemB, loc, rewriter);
185  }
186  thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
187  loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
188  }
189  if (elemsThisOp < 4)
190  thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
191  loc, thisResult, 0, elemsThisOp, 1);
192  result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
193  result, i, 1);
194  }
195  rewriter.replaceOp(op, result);
196 }
197 
199  RewritePatternSet &patterns) {
200  patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
201  patterns.getContext());
202 }
203 
204 void ArithToAMDGPUConversionPass::runOnOperation() {
205  Operation *op = getOperation();
206  RewritePatternSet patterns(op->getContext());
208  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
209  return signalPassFailure();
210 }
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
FloatType getF32Type()
Definition: Builders.cpp:63
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:505
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:339
U dyn_cast() const
Definition: Types.h:329
bool isF32() const
Definition: Types.cpp:51
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:42
bool isa() const
Definition: Types.h:319
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:39
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.