MLIR  21.0.0git
AMXDialect.cpp
Go to the documentation of this file.
1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/TypeUtilities.h"
21 
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 using namespace mlir;
25 
26 #include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
27 
28 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
29 
/// Registers the AMX dialect's types and operations. Both lists are expanded
/// from the included AMXTypes.cpp.inc / AMX.cpp.inc files (presumably TableGen
/// output), selected by the GET_TYPEDEF_LIST / GET_OP_LIST macros.
30 void amx::AMXDialect::initialize() {
31  addTypes<
32 #define GET_TYPEDEF_LIST
33 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
34  >();
35 
36  addOperations<
37 #define GET_OP_LIST
38 #include "mlir/Dialect/AMX/AMX.cpp.inc"
39  >();
40 }
41 
42 /// Verify that AMX supports the implied tile shape.
43 static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
44  const unsigned kMaxRows = 16;
45  const unsigned kBitsPerRow = 64 * 8;
46  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
47  if (tp.getDimSize(0) > kMaxRows)
48  return op->emitOpError("bad row height: ") << tp.getDimSize(0);
49  if (col > kBitsPerRow || col & 0x1f)
50  return op->emitOpError("bad column width: ") << (col >> 3);
51  return success();
52 }
53 
54 /// Verify that AMX supports the multiplication.
55 static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
56  amx::TileType btp, amx::TileType ctp,
57  unsigned scale) {
58  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
59  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
60  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
61  if (cm != am || cn != bn || ak != bk)
62  return op->emitOpError("bad mult shape: ")
63  << cm << " x " << cn << " x " << ak;
64  return success();
65 }
66 
67 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
68 /// dimension directly translates into the number of rows of the tiles.
69 /// The second dimensions needs to be scaled by the number of bytes.
70 static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
71  RewriterBase &rewriter) {
72  Type llvmInt16Type = rewriter.getIntegerType(16);
73  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
74  assert(llvm::isPowerOf2_64(width) && width >= 8);
75  unsigned bytes = width >> 3;
76  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
77  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
78  return SmallVector<Value>{
79  rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
80  rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
81 }
82 
83 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
84 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
85 static Value getStride(Location loc, MemRefType mType, Value base,
86  RewriterBase &rewriter) {
87  assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
88  int64_t preLast = mType.getRank() - 2;
89  Type llvmInt64Type = rewriter.getIntegerType(64);
90  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
91  assert(llvm::isPowerOf2_64(width) && width >= 8);
92  unsigned bytes = width >> 3;
93  auto [strides, offset] = mType.getStridesAndOffset();
94  if (strides[preLast] == ShapedType::kDynamic) {
95  // Dynamic stride needs code to compute the stride at runtime.
96  MemRefDescriptor memrefDescriptor(base);
97  auto attr = rewriter.getI64IntegerAttr(bytes);
98  Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
99  return rewriter
100  .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
101  memrefDescriptor.stride(rewriter, loc, preLast))
102  .getResult();
103  }
104  // Use direct constant for static stride.
105  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
106  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
107  .getResult();
108 }
109 
110 LogicalResult amx::TileZeroOp::verify() {
111  return verifyTileSize(*this, getTileType());
112 }
113 
115 amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
116  const LLVMTypeConverter &typeConverter,
117  RewriterBase &rewriter) {
118  return getTileSizes(getLoc(), getTileType(), rewriter);
119 }
120 
121 LogicalResult amx::TileLoadOp::verify() {
122  MemRefType memrefTy = getMemRefType();
123  unsigned rank = memrefTy.getRank();
124  if (rank < 2)
125  return emitOpError("requires at least 2D memref");
126  if (getIndices().size() != rank)
127  return emitOpError("requires ") << rank << " indices";
128  SmallVector<int64_t> strides;
129  int64_t offset;
130  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
131  strides.back() != 1)
132  return emitOpError("requires memref with unit innermost stride");
133  return verifyTileSize(*this, getTileType());
134 }
135 
137 amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
138  const LLVMTypeConverter &typeConverter,
139  RewriterBase &rewriter) {
140  auto loc = getLoc();
141  Adaptor adaptor(operands, *this);
142 
143  SmallVector<Value> intrinsicOperands;
144  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
145  intrinsicOperands.push_back(
146  LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
147  adaptor.getBase(), adaptor.getIndices()));
148  intrinsicOperands.push_back(
149  getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
150 
151  return intrinsicOperands;
152 }
153 
154 LogicalResult amx::TileStoreOp::verify() {
155  MemRefType memrefTy = getMemRefType();
156  unsigned rank = memrefTy.getRank();
157  if (rank < 2)
158  return emitOpError("requires at least 2D memref");
159  if (getIndices().size() != rank)
160  return emitOpError("requires ") << rank << " indices";
161  SmallVector<int64_t> strides;
162  int64_t offset;
163  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
164  strides.back() != 1)
165  return emitOpError("requires memref with unit innermost stride");
166  return verifyTileSize(*this, getTileType());
167 }
168 
170 amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
171  const LLVMTypeConverter &typeConverter,
172  RewriterBase &rewriter) {
173  auto loc = getLoc();
174  Adaptor adaptor(operands, *this);
175 
176  SmallVector<Value> intrinsicOperands;
177  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
178  intrinsicOperands.push_back(
179  LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
180  adaptor.getBase(), adaptor.getIndices()));
181  intrinsicOperands.push_back(
182  getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
183  intrinsicOperands.push_back(adaptor.getVal());
184 
185  return intrinsicOperands;
186 }
187 
188 LogicalResult amx::TileMulFOp::verify() {
189  amx::TileType aType = getLhsTileType();
190  amx::TileType bType = getRhsTileType();
191  amx::TileType cType = getTileType();
192  if (failed(verifyTileSize(*this, aType)) ||
193  failed(verifyTileSize(*this, bType)) ||
194  failed(verifyTileSize(*this, cType)) ||
195  failed(verifyMultShape(*this, aType, bType, cType, 1)))
196  return failure();
197  Type ta = aType.getElementType();
198  Type tb = bType.getElementType();
199  Type tc = cType.getElementType();
200  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
201  return emitOpError("unsupported type combination");
202  return success();
203 }
204 
206 amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
207  const LLVMTypeConverter &typeConverter,
208  RewriterBase &rewriter) {
209  auto loc = getLoc();
210  Adaptor adaptor(operands, *this);
211 
212  amx::TileType aType = getLhsTileType();
213  amx::TileType bType = getRhsTileType();
214  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
215  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
216 
217  SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
218  tsza[1], adaptor.getAcc(),
219  adaptor.getLhs(), adaptor.getRhs()};
220 
221  return intrinsicOperands;
222 }
223 
224 LogicalResult amx::TileMulIOp::verify() {
225  amx::TileType aType = getLhsTileType();
226  amx::TileType bType = getRhsTileType();
227  amx::TileType cType = getTileType();
228  if (failed(verifyTileSize(*this, aType)) ||
229  failed(verifyTileSize(*this, bType)) ||
230  failed(verifyTileSize(*this, cType)) ||
231  failed(verifyMultShape(*this, aType, bType, cType, 2)))
232  return failure();
233  Type ta = aType.getElementType();
234  Type tb = bType.getElementType();
235  Type tc = cType.getElementType();
236  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
237  return emitOpError("unsupported type combination");
238  return success();
239 }
240 
242 amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
243  const LLVMTypeConverter &typeConverter,
244  RewriterBase &rewriter) {
245  auto loc = getLoc();
246  Adaptor adaptor(operands, *this);
247 
248  amx::TileType aType = getLhsTileType();
249  amx::TileType bType = getRhsTileType();
250  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
251  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
252 
253  SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
254  tsza[1], adaptor.getAcc(),
255  adaptor.getLhs(), adaptor.getRhs()};
256 
257  return intrinsicOperands;
258 }
259 
261  if (parser.parseLess())
262  return nullptr;
263 
265  if (parser.parseDimensionList(shape, false, true))
266  return nullptr;
267 
268  Type elementType;
269  if (parser.parseType(elementType))
270  return nullptr;
271 
272  if (parser.parseGreater())
273  return nullptr;
274 
275  return TileType::get(shape, elementType);
276 }
277 
278 void amx::TileType::print(AsmPrinter &os) const {
279  os << "<";
281  os << 'x';
283  os << '>';
284 }
285 
286 #define GET_OP_CLASSES
287 #include "mlir/Dialect/AMX/AMX.cpp.inc"
288 
289 #define GET_TYPEDEF_CLASSES
290 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim tile shape to the two 16-bit tile sizes.
Definition: AMXDialect.cpp:70
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
Definition: AMXDialect.cpp:55
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
Definition: AMXDialect.cpp:43
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps a memref of rank >= 2 to the 64-bit byte stride of its second-to-last dimension.
Definition: AMXDialect.cpp:85
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
Definition: Builders.cpp:213
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
bool isBF16() const
Definition: Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition: Pattern.cpp:487
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423