MLIR  22.0.0git
AMXDialect.cpp
Go to the documentation of this file.
1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/TypeUtilities.h"
21 
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 using namespace mlir;
25 
26 #include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
27 
28 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
29 
/// Registers all AMX dialect types and operations with the dialect.
/// The type list and op list are expanded from the generated `.cpp.inc`
/// files, selected via the GET_TYPEDEF_LIST / GET_OP_LIST macros.
void amx::AMXDialect::initialize() {
  // Tile type(s) declared for this dialect.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
      >();

  // All AMX operations (tile load/store/zero/multiply, ...).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
      >();
}
41 
42 /// Verify that AMX supports the implied tile shape.
43 static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
44  const unsigned kMaxRows = 16;
45  const unsigned kBitsPerRow = 64 * 8;
46  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
47  if (tp.getDimSize(0) > kMaxRows)
48  return op->emitOpError("bad row height: ") << tp.getDimSize(0);
49  if (col > kBitsPerRow || col & 0x1f)
50  return op->emitOpError("bad column width: ") << (col >> 3);
51  return success();
52 }
53 
54 /// Verify that AMX supports the multiplication.
55 static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
56  amx::TileType btp, amx::TileType ctp,
57  unsigned scale) {
58  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
59  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
60  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
61  if (cm != am || cn != bn || ak != bk)
62  return op->emitOpError("bad mult shape: ")
63  << cm << " x " << cn << " x " << ak;
64  return success();
65 }
66 
67 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
68 /// dimension directly translates into the number of rows of the tiles.
69 /// The second dimensions needs to be scaled by the number of bytes.
70 static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
71  RewriterBase &rewriter) {
72  Type llvmInt16Type = rewriter.getIntegerType(16);
73  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
74  assert(llvm::isPowerOf2_64(width) && width >= 8);
75  unsigned bytes = width >> 3;
76  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
77  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
78  return SmallVector<Value>{
79  LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
80  LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
81 }
82 
83 /// Returns stride expressed in number of bytes for the given `elementStride`
84 /// stride encoded in number of elements of the type `mType`.
85 static Value computeStrideInBytes(Location loc, MemRefType mType,
86  Value elementStride, RewriterBase &rewriter) {
87  Type llvmInt64Type = rewriter.getIntegerType(64);
88  unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
89  auto attr = rewriter.getI64IntegerAttr(bytes);
90  Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
91  return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
92  .getResult();
93 }
94 
95 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
96 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
97 static Value inferStride(Location loc, MemRefType mType, Value base,
98  RewriterBase &rewriter) {
99  assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
100  int64_t preLast = mType.getRank() - 2;
101  Type llvmInt64Type = rewriter.getIntegerType(64);
102  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
103  assert(llvm::isPowerOf2_64(width) && width >= 8);
104  unsigned bytes = width >> 3;
105  auto [strides, offset] = mType.getStridesAndOffset();
106  if (strides[preLast] == ShapedType::kDynamic) {
107  // Dynamic stride needs code to compute the stride at runtime.
108  MemRefDescriptor memrefDescriptor(base);
109  return computeStrideInBytes(
110  loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
111  }
112  // Use direct constant for static stride.
113  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
114  return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
115  .getResult();
116 }
117 
118 LogicalResult amx::TileZeroOp::verify() {
119  return verifyTileSize(*this, getTileType());
120 }
121 
123 amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
124  const LLVMTypeConverter &typeConverter,
125  RewriterBase &rewriter) {
126  return getTileSizes(getLoc(), getTileType(), rewriter);
127 }
128 
129 template <typename OpTy,
130  typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
131  std::is_same_v<OpTy, amx::TileStoreOp>>>
132 static LogicalResult tileTransferVerifier(OpTy op) {
133  MemRefType memrefTy = op.getMemRefType();
134  unsigned rank = memrefTy.getRank();
135  if (op.getIndices().size() != rank)
136  return op.emitOpError("requires ") << rank << " indices";
137 
138  if (failed(verifyTileSize(op, op.getTileType())))
139  return failure();
140 
141  // Validate basic buffer properties when the stride is implicit.
142  if (!op.getStride()) {
143  if (rank < 2)
144  return op.emitOpError("requires at least 2D memref");
145  SmallVector<int64_t> strides;
146  int64_t offset;
147  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
148  strides.back() != 1)
149  return op.emitOpError("requires memref with unit innermost stride");
150  }
151 
152  return success();
153 }
154 
155 void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
156  Value base, ValueRange indices) {
157  build(builder, state, res, base, indices, /*stride=*/nullptr);
158 }
159 
160 LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
161 
163 amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
164  const LLVMTypeConverter &typeConverter,
165  RewriterBase &rewriter) {
166  auto loc = getLoc();
167  Adaptor adaptor(operands, *this);
168 
169  SmallVector<Value> intrinsicOperands;
170  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
171  intrinsicOperands.push_back(
172  LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
173  adaptor.getBase(), adaptor.getIndices()));
174  if (Value stride = adaptor.getStride())
175  intrinsicOperands.push_back(
176  computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
177  else
178  intrinsicOperands.push_back(
179  inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
180 
181  return intrinsicOperands;
182 }
183 
184 void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
185  Value base, ValueRange indices, Value val) {
186  build(builder, state, base, indices, val, /*stride=*/nullptr);
187 }
188 
189 LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
190 
192 amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
193  const LLVMTypeConverter &typeConverter,
194  RewriterBase &rewriter) {
195  auto loc = getLoc();
196  Adaptor adaptor(operands, *this);
197 
198  SmallVector<Value> intrinsicOperands;
199  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
200  intrinsicOperands.push_back(
201  LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
202  adaptor.getBase(), adaptor.getIndices()));
203  if (Value stride = adaptor.getStride())
204  intrinsicOperands.push_back(
205  computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
206  else
207  intrinsicOperands.push_back(
208  inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
209  intrinsicOperands.push_back(adaptor.getVal());
210 
211  return intrinsicOperands;
212 }
213 
214 LogicalResult amx::TileMulFOp::verify() {
215  amx::TileType aType = getLhsTileType();
216  amx::TileType bType = getRhsTileType();
217  amx::TileType cType = getTileType();
218  if (failed(verifyTileSize(*this, aType)) ||
219  failed(verifyTileSize(*this, bType)) ||
220  failed(verifyTileSize(*this, cType)) ||
221  failed(verifyMultShape(*this, aType, bType, cType, 1)))
222  return failure();
223  Type ta = aType.getElementType();
224  Type tb = bType.getElementType();
225  Type tc = cType.getElementType();
226  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
227  return emitOpError("unsupported type combination");
228  return success();
229 }
230 
232 amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
233  const LLVMTypeConverter &typeConverter,
234  RewriterBase &rewriter) {
235  auto loc = getLoc();
236  Adaptor adaptor(operands, *this);
237 
238  amx::TileType aType = getLhsTileType();
239  amx::TileType bType = getRhsTileType();
240  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
241  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
242 
243  SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
244  tsza[1], adaptor.getAcc(),
245  adaptor.getLhs(), adaptor.getRhs()};
246 
247  return intrinsicOperands;
248 }
249 
250 LogicalResult amx::TileMulIOp::verify() {
251  amx::TileType aType = getLhsTileType();
252  amx::TileType bType = getRhsTileType();
253  amx::TileType cType = getTileType();
254  if (failed(verifyTileSize(*this, aType)) ||
255  failed(verifyTileSize(*this, bType)) ||
256  failed(verifyTileSize(*this, cType)) ||
257  failed(verifyMultShape(*this, aType, bType, cType, 2)))
258  return failure();
259  Type ta = aType.getElementType();
260  Type tb = bType.getElementType();
261  Type tc = cType.getElementType();
262  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
263  return emitOpError("unsupported type combination");
264  return success();
265 }
266 
268 amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
269  const LLVMTypeConverter &typeConverter,
270  RewriterBase &rewriter) {
271  auto loc = getLoc();
272  Adaptor adaptor(operands, *this);
273 
274  amx::TileType aType = getLhsTileType();
275  amx::TileType bType = getRhsTileType();
276  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
277  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
278 
279  SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
280  tsza[1], adaptor.getAcc(),
281  adaptor.getLhs(), adaptor.getRhs()};
282 
283  return intrinsicOperands;
284 }
285 
287  if (parser.parseLess())
288  return nullptr;
289 
291  if (parser.parseDimensionList(shape, false, true))
292  return nullptr;
293 
294  Type elementType;
295  if (parser.parseType(elementType))
296  return nullptr;
297 
298  if (parser.parseGreater())
299  return nullptr;
300 
301  return TileType::getChecked(
302  [&] { return parser.emitError(parser.getNameLoc()); }, shape,
303  elementType);
304 }
305 
306 void amx::TileType::print(AsmPrinter &os) const {
307  os << "<";
309  os << 'x';
311  os << '>';
312 }
313 
314 #define GET_OP_CLASSES
315 #include "mlir/Dialect/AMX/AMX.cpp.inc"
316 
317 #define GET_TYPEDEF_CLASSES
318 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Definition: AMXDialect.cpp:70
static Value inferStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:97
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
Definition: AMXDialect.cpp:55
static Value computeStrideInBytes(Location loc, MemRefType mType, Value elementStride, RewriterBase &rewriter)
Returns stride expressed in number of bytes for the given elementStride stride encoded in number of elements of the type mType.
Definition: AMXDialect.cpp:85
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
Definition: AMXDialect.cpp:43
static LogicalResult tileTransferVerifier(OpTy op)
Definition: AMXDialect.cpp:132
static Type getElementType(Type type)
Determine the element type of type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
Definition: Builders.cpp:217
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
This class helps build Operations.
Definition: Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
bool isBF16() const
Definition: Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition: Pattern.cpp:478
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.