22 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
23 #include "mlir/Conversion/Passes.h.inc"
29 struct ArithToAMDGPUConversionPass final
30 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
31 using impl::ArithToAMDGPUConversionPassBase<
32 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
34 void runOnOperation()
override;
37 struct ExtfOnFloat8RewritePattern final
45 struct TruncfToFloat8RewritePattern final
56 if (elementType.
isF32())
59 return rewriter.
create<arith::TruncFOp>(loc, elementType, f32);
61 return rewriter.
create<arith::ExtFOp>(loc, elementType, f32);
62 llvm_unreachable(
"The only 32-bit float type is f32");
65 LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op)
const {
66 Type inType = op.getIn().getType();
67 if (
auto inVecType = inType.
dyn_cast<VectorType>()) {
68 if (inVecType.isScalable())
70 if (inVecType.getShape().size() > 1)
73 inType = inVecType.getElementType();
81 Value in = op.getIn();
84 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
89 VectorType inType = in.
getType().
cast<VectorType>();
90 int64_t numElements = inType.getNumElements();
92 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
94 rewriter.
createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
95 if (inType.getShape().empty()) {
96 Value scalarIn = rewriter.
create<vector::ExtractElementOp>(loc, in);
99 rewriter.
create<arith::ExtFOp>(loc, outElemType, scalarIn);
100 result = rewriter.
create<vector::InsertElementOp>(loc, scalarExt, zero);
103 for (int64_t i = 0; i < numElements; i += 4) {
104 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
105 Value inSlice = rewriter.
create<vector::ExtractStridedSliceOp>(
106 loc, in, i, elemsThisOp, 1);
107 for (int64_t
j = 0;
j < elemsThisOp; ++
j) {
108 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
111 result = rewriter.
create<vector::InsertElementOp>(
127 llvm_unreachable(
"The only 32-bit float type is f32");
130 LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op)
const {
131 Type outType = op.getOut().getType();
132 if (
auto outVecType = outType.
dyn_cast<VectorType>()) {
133 if (outVecType.isScalable())
135 if (outVecType.getShape().size() > 1)
138 outType = outVecType.getElementType();
146 Value in = op.getIn();
151 Value asF8s = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
152 loc, truncResType, asFloat,
nullptr, 0,
154 Value result = rewriter.
create<vector::ExtractElementOp>(
155 loc, asF8s, rewriter.
createOrFold<arith::ConstantIndexOp>(loc, 0));
158 VectorType outType = op.getOut().getType().cast<VectorType>();
159 int64_t numElements = outType.getNumElements();
161 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
163 if (outType.getShape().empty()) {
164 Value scalarIn = rewriter.
create<vector::ExtractElementOp>(loc, in);
167 rewriter.
create<arith::TruncFOp>(loc, outElemType, scalarIn);
168 result = rewriter.
create<vector::InsertElementOp>(loc, scalarTrunc, zero);
172 for (int64_t i = 0; i < numElements; i += 4) {
173 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
174 Value thisResult =
nullptr;
175 for (int64_t
j = 0;
j < elemsThisOp;
j += 2) {
176 Value elemA = rewriter.
create<vector::ExtractElementOp>(
177 loc, in, rewriter.
create<arith::ConstantIndexOp>(loc, i +
j));
179 Value asFloatB =
nullptr;
180 if (
j + 1 < elemsThisOp) {
181 Value elemB = rewriter.
create<vector::ExtractElementOp>(
183 rewriter.
createOrFold<arith::ConstantIndexOp>(loc, i +
j + 1));
184 asFloatB =
castToF32(elemB, loc, rewriter);
186 thisResult = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
187 loc, truncResType, asFloatA, asFloatB,
j / 2, thisResult);
190 thisResult = rewriter.
create<vector::ExtractStridedSliceOp>(
191 loc, thisResult, 0, elemsThisOp, 1);
192 result = rewriter.
create<vector::InsertStridedSliceOp>(loc, thisResult,
200 patterns.
add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
204 void ArithToAMDGPUConversionPass::runOnOperation() {
209 return signalPassFailure();
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
FloatAttr getFloatAttr(Type type, double value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat8E4M3FNUZ() const
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isFloat8E5M2FNUZ() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.