MLIR  22.0.0git
AMXDialect.cpp
Go to the documentation of this file.
1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/TypeUtilities.h"
21 
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 using namespace mlir;
25 
26 #include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
27 
28 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
29 
/// Registers the AMX dialect's types and operations with the dialect.
/// Both lists are pulled in from included .cpp.inc files; the GET_*_LIST
/// macro defined just before each include selects the list section of that
/// file.
30 void amx::AMXDialect::initialize() {
// Register every type named in the type list of AMXTypes.cpp.inc.
31  addTypes<
32 #define GET_TYPEDEF_LIST
33 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
34  >();
35 
// Register every operation named in the op list of AMX.cpp.inc.
36  addOperations<
37 #define GET_OP_LIST
38 #include "mlir/Dialect/AMX/AMX.cpp.inc"
39  >();
40 }
41 
42 /// Verify that AMX supports the implied tile shape.
43 static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
44  const unsigned kMaxRows = 16;
45  const unsigned kBitsPerRow = 64 * 8;
46  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
47  if (tp.getDimSize(0) > kMaxRows)
48  return op->emitOpError("bad row height: ") << tp.getDimSize(0);
49  if (col > kBitsPerRow || col & 0x1f)
50  return op->emitOpError("bad column width: ") << (col >> 3);
51  return success();
52 }
53 
54 /// Verify that AMX supports the multiplication.
55 static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
56  amx::TileType btp, amx::TileType ctp,
57  unsigned scale) {
58  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
59  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
60  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
61  if (cm != am || cn != bn || ak != bk)
62  return op->emitOpError("bad mult shape: ")
63  << cm << " x " << cn << " x " << ak;
64  return success();
65 }
66 
67 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
68 /// dimension directly translates into the number of rows of the tiles.
69 /// The second dimensions needs to be scaled by the number of bytes.
70 static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
71  RewriterBase &rewriter) {
72  Type llvmInt16Type = rewriter.getIntegerType(16);
73  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
74  assert(llvm::isPowerOf2_64(width) && width >= 8);
75  unsigned bytes = width >> 3;
76  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
77  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
78  return SmallVector<Value>{
79  LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
80  LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
81 }
82 
83 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
84 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
85 static Value getStride(Location loc, MemRefType mType, Value base,
86  RewriterBase &rewriter) {
87  assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
88  int64_t preLast = mType.getRank() - 2;
89  Type llvmInt64Type = rewriter.getIntegerType(64);
90  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
91  assert(llvm::isPowerOf2_64(width) && width >= 8);
92  unsigned bytes = width >> 3;
93  auto [strides, offset] = mType.getStridesAndOffset();
94  if (strides[preLast] == ShapedType::kDynamic) {
95  // Dynamic stride needs code to compute the stride at runtime.
96  MemRefDescriptor memrefDescriptor(base);
97  auto attr = rewriter.getI64IntegerAttr(bytes);
98  Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
99  return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
100  memrefDescriptor.stride(rewriter, loc, preLast))
101  .getResult();
102  }
103  // Use direct constant for static stride.
104  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
105  return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
106  .getResult();
107 }
108 
109 LogicalResult amx::TileZeroOp::verify() {
110  return verifyTileSize(*this, getTileType());
111 }
112 
114 amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
115  const LLVMTypeConverter &typeConverter,
116  RewriterBase &rewriter) {
117  return getTileSizes(getLoc(), getTileType(), rewriter);
118 }
119 
120 LogicalResult amx::TileLoadOp::verify() {
121  MemRefType memrefTy = getMemRefType();
122  unsigned rank = memrefTy.getRank();
123  if (rank < 2)
124  return emitOpError("requires at least 2D memref");
125  if (getIndices().size() != rank)
126  return emitOpError("requires ") << rank << " indices";
127  SmallVector<int64_t> strides;
128  int64_t offset;
129  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
130  strides.back() != 1)
131  return emitOpError("requires memref with unit innermost stride");
132  return verifyTileSize(*this, getTileType());
133 }
134 
136 amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
137  const LLVMTypeConverter &typeConverter,
138  RewriterBase &rewriter) {
139  auto loc = getLoc();
140  Adaptor adaptor(operands, *this);
141 
142  SmallVector<Value> intrinsicOperands;
143  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
144  intrinsicOperands.push_back(
145  LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
146  adaptor.getBase(), adaptor.getIndices()));
147  intrinsicOperands.push_back(
148  getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
149 
150  return intrinsicOperands;
151 }
152 
153 LogicalResult amx::TileStoreOp::verify() {
154  MemRefType memrefTy = getMemRefType();
155  unsigned rank = memrefTy.getRank();
156  if (rank < 2)
157  return emitOpError("requires at least 2D memref");
158  if (getIndices().size() != rank)
159  return emitOpError("requires ") << rank << " indices";
160  SmallVector<int64_t> strides;
161  int64_t offset;
162  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
163  strides.back() != 1)
164  return emitOpError("requires memref with unit innermost stride");
165  return verifyTileSize(*this, getTileType());
166 }
167 
169 amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
170  const LLVMTypeConverter &typeConverter,
171  RewriterBase &rewriter) {
172  auto loc = getLoc();
173  Adaptor adaptor(operands, *this);
174 
175  SmallVector<Value> intrinsicOperands;
176  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
177  intrinsicOperands.push_back(
178  LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
179  adaptor.getBase(), adaptor.getIndices()));
180  intrinsicOperands.push_back(
181  getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
182  intrinsicOperands.push_back(adaptor.getVal());
183 
184  return intrinsicOperands;
185 }
186 
187 LogicalResult amx::TileMulFOp::verify() {
188  amx::TileType aType = getLhsTileType();
189  amx::TileType bType = getRhsTileType();
190  amx::TileType cType = getTileType();
191  if (failed(verifyTileSize(*this, aType)) ||
192  failed(verifyTileSize(*this, bType)) ||
193  failed(verifyTileSize(*this, cType)) ||
194  failed(verifyMultShape(*this, aType, bType, cType, 1)))
195  return failure();
196  Type ta = aType.getElementType();
197  Type tb = bType.getElementType();
198  Type tc = cType.getElementType();
199  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
200  return emitOpError("unsupported type combination");
201  return success();
202 }
203 
205 amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
206  const LLVMTypeConverter &typeConverter,
207  RewriterBase &rewriter) {
208  auto loc = getLoc();
209  Adaptor adaptor(operands, *this);
210 
211  amx::TileType aType = getLhsTileType();
212  amx::TileType bType = getRhsTileType();
213  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
214  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
215 
216  SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
217  tsza[1], adaptor.getAcc(),
218  adaptor.getLhs(), adaptor.getRhs()};
219 
220  return intrinsicOperands;
221 }
222 
223 LogicalResult amx::TileMulIOp::verify() {
224  amx::TileType aType = getLhsTileType();
225  amx::TileType bType = getRhsTileType();
226  amx::TileType cType = getTileType();
227  if (failed(verifyTileSize(*this, aType)) ||
228  failed(verifyTileSize(*this, bType)) ||
229  failed(verifyTileSize(*this, cType)) ||
230  failed(verifyMultShape(*this, aType, bType, cType, 2)))
231  return failure();
232  Type ta = aType.getElementType();
233  Type tb = bType.getElementType();
234  Type tc = cType.getElementType();
235  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
236  return emitOpError("unsupported type combination");
237  return success();
238 }
239 
241 amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
242  const LLVMTypeConverter &typeConverter,
243  RewriterBase &rewriter) {
244  auto loc = getLoc();
245  Adaptor adaptor(operands, *this);
246 
247  amx::TileType aType = getLhsTileType();
248  amx::TileType bType = getRhsTileType();
249  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
250  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
251 
252  SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
253  tsza[1], adaptor.getAcc(),
254  adaptor.getLhs(), adaptor.getRhs()};
255 
256  return intrinsicOperands;
257 }
258 
260  if (parser.parseLess())
261  return nullptr;
262 
264  if (parser.parseDimensionList(shape, false, true))
265  return nullptr;
266 
267  Type elementType;
268  if (parser.parseType(elementType))
269  return nullptr;
270 
271  if (parser.parseGreater())
272  return nullptr;
273 
274  return TileType::getChecked(
275  [&] { return parser.emitError(parser.getNameLoc()); }, shape,
276  elementType);
277 }
278 
279 void amx::TileType::print(AsmPrinter &os) const {
280  os << "<";
282  os << 'x';
284  os << '>';
285 }
286 
287 #define GET_OP_CLASSES
288 #include "mlir/Dialect/AMX/AMX.cpp.inc"
289 
290 #define GET_TYPEDEF_CLASSES
291 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Definition: AMXDialect.cpp:70
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
Definition: AMXDialect.cpp:55
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
Definition: AMXDialect.cpp:43
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static Type getElementType(Type type)
Determine the element type of type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
Definition: Builders.cpp:216
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:111
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
bool isBF16() const
Definition: Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition: Pattern.cpp:478
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423