MLIR  20.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::amx;
21 
22 namespace {
23 
24 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
25 /// dimension directly translates into the number of rows of the tiles.
26 /// The second dimensions needs to be scaled by the number of bytes.
27 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
28  const LLVMTypeConverter &typeConverter,
29  amx::TileType tType, Location loc) {
30  Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
31  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
32  assert(llvm::isPowerOf2_64(width) && width >= 8);
33  unsigned bytes = width >> 3;
34  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
35  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
36  return std::make_pair(
37  rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
38  rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
39 }
40 
41 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
42 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
43 /// Returns failure if proper stride couldn't be found.
44 FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
45  const LLVMTypeConverter &typeConverter,
46  MemRefType mType, Value base, Location loc) {
47  if (mType.getRank() < 2)
48  return failure();
49  int64_t preLast = mType.getRank() - 2;
50  Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
51  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
52  assert(llvm::isPowerOf2_64(width) && width >= 8);
53  unsigned bytes = width >> 3;
54  int64_t offset;
56  if (failed(getStridesAndOffset(mType, strides, offset)) ||
57  strides.back() != 1)
58  return failure();
59  if (strides[preLast] == ShapedType::kDynamic) {
60  // Dynamic stride needs code to compute the stride at runtime.
61  MemRefDescriptor memrefDescriptor(base);
62  auto attr = rewriter.getI64IntegerAttr(bytes);
63  Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
64  return rewriter
65  .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
66  memrefDescriptor.stride(rewriter, loc, preLast))
67  .getResult();
68  }
69  // Use direct constant for static stride.
70  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
71  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
72  .getResult();
73 }
74 
75 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
77  LogicalResult
78  matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
79  ConversionPatternRewriter &rewriter) const override {
80  amx::TileType tType = op.getTileType();
81  // Determine m x n tile sizes.
82  std::pair<Value, Value> tsz =
83  getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
84  // Replace operation with intrinsic.
85  Type resType = typeConverter->convertType(tType);
86  rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
87  tsz.second);
88  return success();
89  }
90 };
91 
92 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
94 
95  LogicalResult
96  matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
97  ConversionPatternRewriter &rewriter) const override {
98  MemRefType mType = op.getMemRefType();
99  amx::TileType tType = op.getTileType();
100  // Determine m x n tile sizes.
101  std::pair<Value, Value> tsz =
102  getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
103  // Determine stride.
104  auto stride = getStride(rewriter, *getTypeConverter(), mType,
105  adaptor.getBase(), op.getLoc());
106  if (failed(stride))
107  return failure();
108  // Replace operation with intrinsic.
109  Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
110  adaptor.getIndices(), rewriter);
111  Type resType = typeConverter->convertType(tType);
112  rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
113  op, resType, tsz.first, tsz.second, ptr, stride.value());
114  return success();
115  }
116 };
117 
118 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
120 
121  LogicalResult
122  matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
123  ConversionPatternRewriter &rewriter) const override {
124  MemRefType mType = op.getMemRefType();
125  amx::TileType tType = op.getTileType();
126  // Determine m x n tile sizes.
127  std::pair<Value, Value> tsz =
128  getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
129  // Determine stride.
130  auto stride = getStride(rewriter, *getTypeConverter(), mType,
131  adaptor.getBase(), op.getLoc());
132  if (failed(stride))
133  return failure();
134  // Replace operation with intrinsic.
135  Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
136  adaptor.getIndices(), rewriter);
137  rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
138  op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
139  return success();
140  }
141 };
142 
143 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
145  LogicalResult
146  matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
147  ConversionPatternRewriter &rewriter) const override {
148  amx::TileType aType = op.getLhsTileType();
149  amx::TileType bType = op.getRhsTileType();
150  amx::TileType cType = op.getTileType();
151  // Determine m x n x k tile sizes.
152  std::pair<Value, Value> tsza =
153  getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
154  std::pair<Value, Value> tszb =
155  getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
156  // Replace operation with intrinsic.
157  Type resType = typeConverter->convertType(cType);
158  if (aType.getElementType().isBF16())
159  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
160  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
161  adaptor.getLhs(), adaptor.getRhs());
162  else if (aType.getElementType().isF16())
163  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
164  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
165  adaptor.getLhs(), adaptor.getRhs());
166  else
167  llvm_unreachable("Unexpected element type for amx.mulf");
168  return success();
169  }
170 };
171 
172 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
174  LogicalResult
175  matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
176  ConversionPatternRewriter &rewriter) const override {
177  amx::TileType aType = op.getLhsTileType();
178  amx::TileType bType = op.getRhsTileType();
179  amx::TileType cType = op.getTileType();
180  // Determine m x n x k tile sizes.
181  std::pair<Value, Value> tsza =
182  getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
183  std::pair<Value, Value> tszb =
184  getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
185  // Replace operation with intrinsic.
186  Type resType = typeConverter->convertType(cType);
187  bool zexta = op.getIsZextLhs();
188  bool zextb = op.getIsZextRhs();
189  if (zexta && zextb)
190  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
191  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
192  adaptor.getLhs(), adaptor.getRhs());
193  else if (zexta && !zextb)
194  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
195  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
196  adaptor.getLhs(), adaptor.getRhs());
197  else if (!zexta && zextb)
198  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
199  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
200  adaptor.getLhs(), adaptor.getRhs());
201  else
202  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
203  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
204  adaptor.getLhs(), adaptor.getRhs());
205  return success();
206  }
207 };
208 
209 } // namespace
210 
212  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
213  patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
214  TileMulFConversion, TileMulIConversion>(converter);
215  converter.addConversion([&](amx::TileType type) {
216  return LLVM::LLVMX86AMXType::get(&converter.getContext());
217  });
218 }
219 
221  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
222  x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
223  x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
224  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
225  TileMulFOp>();
226 }
227 
228 namespace {
229 /// Implement the interface to convert AMX to LLVM.
230 struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
232 
233  void populateConvertToLLVMConversionPatterns(
234  ConversionTarget &target, LLVMTypeConverter &typeConverter,
235  RewritePatternSet &patterns) const final {
236  populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
237  }
238 };
239 } // namespace
240 
242  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
243  dialect->addInterfaces<AMXToLLVMDialectInterface>();
244  });
245 }
IntegerAttr getI16IntegerAttr(int16_t value)
Definition: Builders.cpp:257
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry)
Register LLVM conversion interface for AMX dialect.