MLIR  21.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Dialect/AMX/Transforms.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::amx;
21 
22 namespace {
23 
24 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
25 /// dimension directly translates into the number of rows of the tiles.
26 /// The second dimensions needs to be scaled by the number of bytes.
27 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
28  const LLVMTypeConverter &typeConverter,
29  amx::TileType tType, Location loc) {
30  Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
31  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
32  assert(llvm::isPowerOf2_64(width) && width >= 8);
33  unsigned bytes = width >> 3;
34  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
35  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
36  return std::make_pair(
37  rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
38  rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
39 }
40 
41 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
42 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
43 /// Returns failure if proper stride couldn't be found.
44 FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
45  const LLVMTypeConverter &typeConverter,
46  MemRefType mType, Value base, Location loc) {
47  if (mType.getRank() < 2)
48  return failure();
49  int64_t preLast = mType.getRank() - 2;
50  Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
51  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
52  assert(llvm::isPowerOf2_64(width) && width >= 8);
53  unsigned bytes = width >> 3;
54  int64_t offset;
56  if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
57  return failure();
58  if (strides[preLast] == ShapedType::kDynamic) {
59  // Dynamic stride needs code to compute the stride at runtime.
60  MemRefDescriptor memrefDescriptor(base);
61  auto attr = rewriter.getI64IntegerAttr(bytes);
62  Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
63  return rewriter
64  .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
65  memrefDescriptor.stride(rewriter, loc, preLast))
66  .getResult();
67  }
68  // Use direct constant for static stride.
69  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
70  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
71  .getResult();
72 }
73 
74 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
76  LogicalResult
77  matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
78  ConversionPatternRewriter &rewriter) const override {
79  amx::TileType tType = op.getTileType();
80  // Determine m x n tile sizes.
81  std::pair<Value, Value> tsz =
82  getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
83  // Replace operation with intrinsic.
84  Type resType = typeConverter->convertType(tType);
85  rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
86  tsz.second);
87  return success();
88  }
89 };
90 
91 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
93 
94  LogicalResult
95  matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
96  ConversionPatternRewriter &rewriter) const override {
97  MemRefType mType = op.getMemRefType();
98  amx::TileType tType = op.getTileType();
99  // Determine m x n tile sizes.
100  std::pair<Value, Value> tsz =
101  getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
102  // Determine stride.
103  auto stride = getStride(rewriter, *getTypeConverter(), mType,
104  adaptor.getBase(), op.getLoc());
105  if (failed(stride))
106  return failure();
107  // Replace operation with intrinsic.
108  Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
109  adaptor.getIndices(), rewriter);
110  Type resType = typeConverter->convertType(tType);
111  rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
112  op, resType, tsz.first, tsz.second, ptr, stride.value());
113  return success();
114  }
115 };
116 
117 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
119 
120  LogicalResult
121  matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
122  ConversionPatternRewriter &rewriter) const override {
123  MemRefType mType = op.getMemRefType();
124  amx::TileType tType = op.getTileType();
125  // Determine m x n tile sizes.
126  std::pair<Value, Value> tsz =
127  getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
128  // Determine stride.
129  auto stride = getStride(rewriter, *getTypeConverter(), mType,
130  adaptor.getBase(), op.getLoc());
131  if (failed(stride))
132  return failure();
133  // Replace operation with intrinsic.
134  Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
135  adaptor.getIndices(), rewriter);
136  rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
137  op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
138  return success();
139  }
140 };
141 
142 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
144  LogicalResult
145  matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
146  ConversionPatternRewriter &rewriter) const override {
147  amx::TileType aType = op.getLhsTileType();
148  amx::TileType bType = op.getRhsTileType();
149  amx::TileType cType = op.getTileType();
150  // Determine m x n x k tile sizes.
151  std::pair<Value, Value> tsza =
152  getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
153  std::pair<Value, Value> tszb =
154  getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
155  // Replace operation with intrinsic.
156  Type resType = typeConverter->convertType(cType);
157  if (aType.getElementType().isBF16())
158  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
159  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
160  adaptor.getLhs(), adaptor.getRhs());
161  else if (aType.getElementType().isF16())
162  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
163  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
164  adaptor.getLhs(), adaptor.getRhs());
165  else
166  llvm_unreachable("Unexpected element type for amx.mulf");
167  return success();
168  }
169 };
170 
171 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
173  LogicalResult
174  matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
175  ConversionPatternRewriter &rewriter) const override {
176  amx::TileType aType = op.getLhsTileType();
177  amx::TileType bType = op.getRhsTileType();
178  amx::TileType cType = op.getTileType();
179  // Determine m x n x k tile sizes.
180  std::pair<Value, Value> tsza =
181  getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
182  std::pair<Value, Value> tszb =
183  getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
184  // Replace operation with intrinsic.
185  Type resType = typeConverter->convertType(cType);
186  bool zexta = op.getIsZextLhs();
187  bool zextb = op.getIsZextRhs();
188  if (zexta && zextb)
189  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
190  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
191  adaptor.getLhs(), adaptor.getRhs());
192  else if (zexta && !zextb)
193  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
194  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
195  adaptor.getLhs(), adaptor.getRhs());
196  else if (!zexta && zextb)
197  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
198  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
199  adaptor.getLhs(), adaptor.getRhs());
200  else
201  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
202  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
203  adaptor.getLhs(), adaptor.getRhs());
204  return success();
205  }
206 };
207 
208 } // namespace
209 
212  patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
213  TileMulFConversion, TileMulIConversion>(converter);
214  converter.addConversion([&](amx::TileType type) {
215  return LLVM::LLVMX86AMXType::get(&converter.getContext());
216  });
217 }
218 
220  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
221  x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
222  x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
223  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
224  TileMulFOp>();
225 }
226 
227 namespace {
228 /// Implement the interface to convert AMX to LLVM.
229 struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
231 
232  void populateConvertToLLVMConversionPatterns(
233  ConversionTarget &target, LLVMTypeConverter &typeConverter,
234  RewritePatternSet &patterns) const final {
236  }
237 };
238 } // namespace
239 
241  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
242  dialect->addInterfaces<AMXToLLVMDialectInterface>();
243  });
244 }
IntegerAttr getI16IntegerAttr(int16_t value)
Definition: Builders.cpp:213
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry)
Register LLVM conversion interface for AMX dialect.