MLIR 22.0.0git
AMXDialect.cpp
Go to the documentation of this file.
1//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMX dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/Builders.h"
21
22#include "llvm/ADT/TypeSwitch.h"
23
24using namespace mlir;
25
26#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
27
28#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
29
/// Dialect initialization hook: registers the AMX types and operations with
/// this dialect instance.
30void amx::AMXDialect::initialize() {
// The type list is supplied by the GET_TYPEDEF_LIST section of
// AMXTypes.cpp.inc (presumably TableGen-generated -- confirm against the
// build rules).
31 addTypes<
32#define GET_TYPEDEF_LIST
33#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
34 >();
35
// The operation list is supplied by the GET_OP_LIST section of AMX.cpp.inc.
36 addOperations<
37#define GET_OP_LIST
38#include "mlir/Dialect/AMX/AMX.cpp.inc"
39 >();
40}
41
42/// Verify that AMX supports the implied tile shape.
43static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
44 const unsigned kMaxRows = 16;
45 const unsigned kBitsPerRow = 64 * 8;
46 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
47 if (tp.getDimSize(0) > kMaxRows)
48 return op->emitOpError("bad row height: ") << tp.getDimSize(0);
49 if (col > kBitsPerRow || col & 0x1f)
50 return op->emitOpError("bad column width: ") << (col >> 3);
51 return success();
52}
53
54/// Verify that AMX supports the multiplication.
55static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
56 amx::TileType btp, amx::TileType ctp,
57 unsigned scale) {
58 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
59 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
60 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
61 if (cm != am || cn != bn || ak != bk)
62 return op->emitOpError("bad mult shape: ")
63 << cm << " x " << cn << " x " << ak;
64 return success();
65}
66
67/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
68/// dimension directly translates into the number of rows of the tiles.
69/// The second dimensions needs to be scaled by the number of bytes.
70static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
71 RewriterBase &rewriter) {
72 Type llvmInt16Type = rewriter.getIntegerType(16);
73 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
74 assert(llvm::isPowerOf2_64(width) && width >= 8);
75 unsigned bytes = width >> 3;
76 auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
77 auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
78 return SmallVector<Value>{
79 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
80 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
81}
82
83/// Returns stride expressed in number of bytes for the given `elementStride`
84/// stride encoded in number of elements of the type `mType`.
85static Value computeStrideInBytes(Location loc, MemRefType mType,
86 Value elementStride, RewriterBase &rewriter) {
87 Type llvmInt64Type = rewriter.getIntegerType(64);
88 unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
89 auto attr = rewriter.getI64IntegerAttr(bytes);
90 Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
91 return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
92 .getResult();
93}
94
95/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
96/// shape may "envelop" the actual tile shape, and may be dynamically sized.
97static Value inferStride(Location loc, MemRefType mType, Value base,
98 RewriterBase &rewriter) {
99 assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
100 int64_t preLast = mType.getRank() - 2;
101 Type llvmInt64Type = rewriter.getIntegerType(64);
102 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
103 assert(llvm::isPowerOf2_64(width) && width >= 8);
104 unsigned bytes = width >> 3;
105 auto [strides, offset] = mType.getStridesAndOffset();
106 if (strides[preLast] == ShapedType::kDynamic) {
107 // Dynamic stride needs code to compute the stride at runtime.
108 MemRefDescriptor memrefDescriptor(base);
110 loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
111 }
112 // Use direct constant for static stride.
113 auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
114 return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
115 .getResult();
116}
117
118LogicalResult amx::TileZeroOp::verify() {
119 return verifyTileSize(*this, getTileType());
120}
121
123amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
124 const LLVMTypeConverter &typeConverter,
125 RewriterBase &rewriter) {
126 return getTileSizes(getLoc(), getTileType(), rewriter);
127}
128
129template <typename OpTy,
130 typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
131 std::is_same_v<OpTy, amx::TileStoreOp>>>
132static LogicalResult tileTransferVerifier(OpTy op) {
133 MemRefType memrefTy = op.getMemRefType();
134 unsigned rank = memrefTy.getRank();
135 if (op.getIndices().size() != rank)
136 return op.emitOpError("requires ") << rank << " indices";
137
138 if (failed(verifyTileSize(op, op.getTileType())))
139 return failure();
140
141 // Validate basic buffer properties when the stride is implicit.
142 if (!op.getStride()) {
143 if (rank < 2)
144 return op.emitOpError("requires at least 2D memref");
145 SmallVector<int64_t> strides;
146 int64_t offset;
147 if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
148 strides.back() != 1)
149 return op.emitOpError("requires memref with unit innermost stride");
150 }
151
152 return success();
153}
154
155void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
156 Value base, ValueRange indices) {
157 build(builder, state, res, base, indices, /*stride=*/nullptr);
158}
159
160LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
161
163amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
164 const LLVMTypeConverter &typeConverter,
165 RewriterBase &rewriter) {
166 auto loc = getLoc();
167 Adaptor adaptor(operands, *this);
168
169 SmallVector<Value> intrinsicOperands;
170 intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
171 intrinsicOperands.push_back(
172 LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
173 adaptor.getBase(), adaptor.getIndices()));
174 if (Value stride = adaptor.getStride())
175 intrinsicOperands.push_back(
176 computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
177 else
178 intrinsicOperands.push_back(
179 inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
180
181 return intrinsicOperands;
182}
183
184void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
185 Value base, ValueRange indices, Value val) {
186 build(builder, state, base, indices, val, /*stride=*/nullptr);
187}
188
189LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
190
192amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
193 const LLVMTypeConverter &typeConverter,
194 RewriterBase &rewriter) {
195 auto loc = getLoc();
196 Adaptor adaptor(operands, *this);
197
198 SmallVector<Value> intrinsicOperands;
199 intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
200 intrinsicOperands.push_back(
201 LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
202 adaptor.getBase(), adaptor.getIndices()));
203 if (Value stride = adaptor.getStride())
204 intrinsicOperands.push_back(
205 computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
206 else
207 intrinsicOperands.push_back(
208 inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
209 intrinsicOperands.push_back(adaptor.getVal());
210
211 return intrinsicOperands;
212}
213
214LogicalResult amx::TileMulFOp::verify() {
215 amx::TileType aType = getLhsTileType();
216 amx::TileType bType = getRhsTileType();
217 amx::TileType cType = getTileType();
218 if (failed(verifyTileSize(*this, aType)) ||
219 failed(verifyTileSize(*this, bType)) ||
220 failed(verifyTileSize(*this, cType)) ||
221 failed(verifyMultShape(*this, aType, bType, cType, 1)))
222 return failure();
223 Type ta = aType.getElementType();
224 Type tb = bType.getElementType();
225 Type tc = cType.getElementType();
226 if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
227 return emitOpError("unsupported type combination");
228 return success();
229}
230
232amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
233 const LLVMTypeConverter &typeConverter,
234 RewriterBase &rewriter) {
235 auto loc = getLoc();
236 Adaptor adaptor(operands, *this);
237
238 amx::TileType aType = getLhsTileType();
239 amx::TileType bType = getRhsTileType();
240 SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
241 SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
242
243 SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
244 tsza[1], adaptor.getAcc(),
245 adaptor.getLhs(), adaptor.getRhs()};
246
247 return intrinsicOperands;
248}
249
250LogicalResult amx::TileMulIOp::verify() {
251 amx::TileType aType = getLhsTileType();
252 amx::TileType bType = getRhsTileType();
253 amx::TileType cType = getTileType();
254 if (failed(verifyTileSize(*this, aType)) ||
255 failed(verifyTileSize(*this, bType)) ||
256 failed(verifyTileSize(*this, cType)) ||
257 failed(verifyMultShape(*this, aType, bType, cType, 2)))
258 return failure();
259 Type ta = aType.getElementType();
260 Type tb = bType.getElementType();
261 Type tc = cType.getElementType();
262 if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
263 return emitOpError("unsupported type combination");
264 return success();
265}
266
268amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
269 const LLVMTypeConverter &typeConverter,
270 RewriterBase &rewriter) {
271 auto loc = getLoc();
272 Adaptor adaptor(operands, *this);
273
274 amx::TileType aType = getLhsTileType();
275 amx::TileType bType = getRhsTileType();
276 SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
277 SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
278
279 SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
280 tsza[1], adaptor.getAcc(),
281 adaptor.getLhs(), adaptor.getRhs()};
282
283 return intrinsicOperands;
284}
285
286Type amx::TileType::parse(AsmParser &parser) {
287 if (parser.parseLess())
288 return nullptr;
289
291 if (parser.parseDimensionList(shape, false, true))
292 return nullptr;
293
294 Type elementType;
295 if (parser.parseType(elementType))
296 return nullptr;
297
298 if (parser.parseGreater())
299 return nullptr;
300
301 return TileType::getChecked(
302 [&] { return parser.emitError(parser.getNameLoc()); }, shape,
303 elementType);
304}
305
306void amx::TileType::print(AsmPrinter &os) const {
307 os << "<";
309 os << 'x';
311 os << '>';
312}
313
314#define GET_OP_CLASSES
315#include "mlir/Dialect/AMX/AMX.cpp.inc"
316
317#define GET_TYPEDEF_CLASSES
318#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static Value inferStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
static Value computeStrideInBytes(Location loc, MemRefType mType, Value elementStride, RewriterBase &rewriter)
Returns stride expressed in number of bytes for the given elementStride stride encoded in number of e...
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
static LogicalResult tileTransferVerifier(OpTy op)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
Definition Builders.cpp:217
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:478
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.