MLIR  19.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Dialect/AMX/Transforms.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::amx;
20 
21 namespace {
22 
23 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
24 /// dimension directly translates into the number of rows of the tiles.
25 /// The second dimensions needs to be scaled by the number of bytes.
26 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
27  const LLVMTypeConverter &typeConverter,
28  VectorType vType, Location loc) {
29  Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
30  unsigned width = vType.getElementType().getIntOrFloatBitWidth();
31  assert(llvm::isPowerOf2_64(width) && width >= 8);
32  unsigned bytes = width >> 3;
33  auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
34  auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
35  return std::make_pair(
36  rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
37  rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
38 }
39 
40 /// Verifies if the stride matches proper tile access.
41 LogicalResult verifyStride(MemRefType mType) {
42  if (mType.getRank() < 2)
43  return failure();
44  int64_t last = mType.getRank() - 1;
45  int64_t offset;
47  if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
48  return failure();
49  return success();
50 }
51 
52 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
53 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
54 Value getStride(ConversionPatternRewriter &rewriter,
55  const LLVMTypeConverter &typeConverter, MemRefType mType,
56  Value base, Location loc) {
57  assert(mType.getRank() >= 2);
58  int64_t last = mType.getRank() - 1;
59  Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
60  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
61  assert(llvm::isPowerOf2_64(width) && width >= 8);
62  unsigned bytes = width >> 3;
63  if (mType.isDynamicDim(last)) {
64  // Dynamic size needs code to compute the stride at runtime.
65  MemRefDescriptor memrefDescriptor(base);
66  auto attr = rewriter.getI64IntegerAttr(bytes);
67  Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
68  return rewriter.create<LLVM::MulOp>(
69  loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
70  }
71  // Use direct constant for static size.
72  auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
73  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
74 }
75 
76 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
79  matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
80  ConversionPatternRewriter &rewriter) const override {
81  VectorType vType = op.getVectorType();
82  // Determine m x n tile sizes.
83  std::pair<Value, Value> tsz =
84  getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
85  // Replace operation with intrinsic.
86  Type resType = typeConverter->convertType(vType);
87  rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
88  tsz.second);
89  return success();
90  }
91 };
92 
93 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
95 
97  matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
98  ConversionPatternRewriter &rewriter) const override {
99  MemRefType mType = op.getMemRefType();
100  VectorType vType = op.getVectorType();
101  // Determine m x n tile sizes.
102  std::pair<Value, Value> tsz =
103  getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
104  // Determine stride.
105  if (failed(verifyStride(mType)))
106  return failure();
107  Value stride = getStride(rewriter, *getTypeConverter(), mType,
108  adaptor.getBase(), op.getLoc());
109  // Replace operation with intrinsic.
110  Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
111  adaptor.getIndices(), rewriter);
112  Type resType = typeConverter->convertType(vType);
113  rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
114  op, resType, tsz.first, tsz.second, ptr, stride);
115  return success();
116  }
117 };
118 
119 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
121 
123  matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
124  ConversionPatternRewriter &rewriter) const override {
125  MemRefType mType = op.getMemRefType();
126  VectorType vType = op.getVectorType();
127  // Determine m x n tile sizes.
128  std::pair<Value, Value> tsz =
129  getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
130  // Determine stride.
131  if (failed(verifyStride(mType)))
132  return failure();
133  Value stride = getStride(rewriter, *getTypeConverter(), mType,
134  adaptor.getBase(), op.getLoc());
135  // Replace operation with intrinsic.
136  Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
137  adaptor.getIndices(), rewriter);
138  rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
139  op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
140  return success();
141  }
142 };
143 
144 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
147  matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
148  ConversionPatternRewriter &rewriter) const override {
149  VectorType aType = op.getLhsVectorType();
150  VectorType bType = op.getRhsVectorType();
151  VectorType cType = op.getVectorType();
152  // Determine m x n x k tile sizes.
153  std::pair<Value, Value> tsza =
154  getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
155  std::pair<Value, Value> tszb =
156  getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
157  // Replace operation with intrinsic.
158  Type resType = typeConverter->convertType(cType);
159  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
160  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
161  adaptor.getLhs(), adaptor.getRhs());
162  return success();
163  }
164 };
165 
166 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
169  matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
170  ConversionPatternRewriter &rewriter) const override {
171  VectorType aType = op.getLhsVectorType();
172  VectorType bType = op.getRhsVectorType();
173  VectorType cType = op.getVectorType();
174  // Determine m x n x k tile sizes.
175  std::pair<Value, Value> tsza =
176  getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
177  std::pair<Value, Value> tszb =
178  getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
179  // Replace operation with intrinsic.
180  Type resType = typeConverter->convertType(cType);
181  bool zexta = op.getIsZextLhs();
182  bool zextb = op.getIsZextRhs();
183  if (zexta && zextb)
184  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
185  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
186  adaptor.getLhs(), adaptor.getRhs());
187  else if (zexta && !zextb)
188  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
189  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
190  adaptor.getLhs(), adaptor.getRhs());
191  else if (!zexta && zextb)
192  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
193  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
194  adaptor.getLhs(), adaptor.getRhs());
195  else
196  rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
197  op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
198  adaptor.getLhs(), adaptor.getRhs());
199  return success();
200  }
201 };
202 
203 } // namespace
204 
206  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
207  patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
208  TileMulFConversion, TileMulIConversion>(converter);
209 }
210 
212  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
213  x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
214  x86_amx_tdpbusd, x86_amx_tdpbuud>();
215  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
216  TileMulFOp>();
217 }
IntegerAttr getI16IntegerAttr(int16_t value)
Definition: Builders.cpp:230
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
This class implements a pattern rewriter for use with ConversionPatterns.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:139
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaceOp is called with the results of the new op).
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26