MLIR 23.0.0git
X86Dialect.cpp
Go to the documentation of this file.
1//===- X86Dialect.cpp - MLIR X86 ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the X86 dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/Builders.h"
21
22#include "llvm/ADT/TypeSwitch.h"
23
24using namespace mlir;
25
26#include "mlir/Dialect/X86/X86Interfaces.cpp.inc"
27
28#include "mlir/Dialect/X86/X86Dialect.cpp.inc"
29
/// Registers the types and operations of the X86 dialect. Both template
/// argument lists are not spelled out here; they are spliced in from the
/// generated `.cpp.inc` files included between the angle brackets.
30void x86::X86Dialect::initialize() {
// Type registration. GET_TYPEDEF_LIST is defined right before the include;
// presumably it selects the type-list section of the generated file.
31 addTypes<
32#define GET_TYPEDEF_LIST
33#include "mlir/Dialect/X86/X86Types.cpp.inc"
34 >();
35
// Operation registration, following the same define-then-include scheme with
// GET_OP_LIST.
36 addOperations<
37#define GET_OP_LIST
38#include "mlir/Dialect/X86/X86.cpp.inc"
39 >();
40}
41
42static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
43 const LLVMTypeConverter &typeConverter,
44 RewriterBase &rewriter) {
45 MemRefDescriptor memRefDescriptor(buffer);
46 return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
47}
48
49LogicalResult x86::MaskCompressOp::verify() {
50 if (getSrc() && getConstantSrc())
51 return emitError("cannot use both src and constant_src");
52
53 if (getSrc() && (getSrc().getType() != getDst().getType()))
54 return emitError("failed to verify that src and dst have same type");
55
56 if (getConstantSrc() && (getConstantSrc()->getType() != getDst().getType()))
57 return emitError(
58 "failed to verify that constant_src and dst have same type");
59
60 return success();
61}
62
63SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
64 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
65 RewriterBase &rewriter) {
66 auto loc = getLoc();
67 Adaptor adaptor(operands, *this);
68
69 auto opType = adaptor.getA().getType();
70 Value src;
71 if (adaptor.getSrc()) {
72 src = adaptor.getSrc();
73 } else if (adaptor.getConstantSrc()) {
74 src = LLVM::ConstantOp::create(rewriter, loc, opType,
75 adaptor.getConstantSrcAttr());
76 } else {
77 auto zeroAttr = rewriter.getZeroAttr(opType);
78 src = LLVM::ConstantOp::create(rewriter, loc, opType, zeroAttr);
79 }
80
81 return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
82}
83
85x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
86 const LLVMTypeConverter &typeConverter,
87 RewriterBase &rewriter) {
88 SmallVector<Value> intrinsicOperands(operands);
89 // Dot product of all elements, broadcasted to all elements.
90 Value scale =
91 LLVM::ConstantOp::create(rewriter, getLoc(), rewriter.getI8Type(), 0xff);
92 intrinsicOperands.push_back(scale);
93
94 return intrinsicOperands;
95}
96
97SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
98 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
99 RewriterBase &rewriter) {
100 Adaptor adaptor(operands, *this);
101 return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
102 typeConverter, rewriter)};
103}
104
105SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
106 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
107 RewriterBase &rewriter) {
108 Adaptor adaptor(operands, *this);
109 return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
110 typeConverter, rewriter)};
111}
112
113SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
114 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
115 RewriterBase &rewriter) {
116 Adaptor adaptor(operands, *this);
117 return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
118 typeConverter, rewriter)};
119}
120
121/// Verify that AMX supports the implied tile shape.
122static LogicalResult verifyTileSize(Operation *op, x86::amx::TileType tp) {
123 const unsigned kMaxRows = 16;
124 const unsigned kBitsPerRow = 64 * 8;
125 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
126 if (tp.getDimSize(0) > kMaxRows)
127 return op->emitOpError("bad row height: ") << tp.getDimSize(0);
128 if (col > kBitsPerRow || col & 0x1f)
129 return op->emitOpError("bad column width: ") << (col >> 3);
130 return success();
131}
132
133/// Verify that AMX supports the multiplication.
134static LogicalResult verifyMultShape(Operation *op, x86::amx::TileType atp,
136 x86::amx::TileType ctp, unsigned scale) {
137 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
138 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
139 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
140 if (cm != am || cn != bn || ak != bk)
141 return op->emitOpError("bad mult shape: ")
142 << cm << " x " << cn << " x " << ak;
143 return success();
144}
145
146/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
147/// dimension directly translates into the number of rows of the tiles.
148/// The second dimensions needs to be scaled by the number of bytes.
150 RewriterBase &rewriter) {
151 Type llvmInt16Type = rewriter.getIntegerType(16);
152 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
153 assert(llvm::isPowerOf2_64(width) && width >= 8);
154 unsigned bytes = width >> 3;
155 auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
156 auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
157 return SmallVector<Value>{
158 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
159 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
160}
161
162/// Returns stride expressed in number of bytes for the given `elementStride`
163/// stride encoded in number of elements of the type `mType`.
164static Value computeStrideInBytes(Location loc, MemRefType mType,
165 Value elementStride, RewriterBase &rewriter) {
166 Type llvmInt64Type = rewriter.getIntegerType(64);
167 unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
168 auto attr = rewriter.getI64IntegerAttr(bytes);
169 Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
170 return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
171 .getResult();
172}
173
174/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
175/// shape may "envelop" the actual tile shape, and may be dynamically sized.
176static Value inferStride(Location loc, MemRefType mType, Value base,
177 RewriterBase &rewriter) {
178 assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
179 int64_t preLast = mType.getRank() - 2;
180 Type llvmInt64Type = rewriter.getIntegerType(64);
181 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
182 assert(llvm::isPowerOf2_64(width) && width >= 8);
183 unsigned bytes = width >> 3;
184 auto [strides, offset] = mType.getStridesAndOffset();
185 if (strides[preLast] == ShapedType::kDynamic) {
186 // Dynamic stride needs code to compute the stride at runtime.
187 MemRefDescriptor memrefDescriptor(base);
189 loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
190 }
191 // Use direct constant for static stride.
192 auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
193 return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
194 .getResult();
195}
196
197LogicalResult x86::amx::TileZeroOp::verify() {
198 return verifyTileSize(*this, getTileType());
199}
200
201SmallVector<Value> x86::amx::TileZeroOp::getIntrinsicOperands(
202 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
203 RewriterBase &rewriter) {
204 return getTileSizes(getLoc(), getTileType(), rewriter);
205}
206
207template <typename OpTy, typename = std::enable_if_t<
208 std::is_same_v<OpTy, x86::amx::TileLoadOp> ||
209 std::is_same_v<OpTy, x86::amx::TileStoreOp>>>
210static LogicalResult tileTransferVerifier(OpTy op) {
211 MemRefType memrefTy = op.getMemRefType();
212 unsigned rank = memrefTy.getRank();
213 if (op.getIndices().size() != rank)
214 return op.emitOpError("requires ") << rank << " indices";
215
216 if (failed(verifyTileSize(op, op.getTileType())))
217 return failure();
218
219 // Validate basic buffer properties when the stride is implicit.
220 if (!op.getStride()) {
221 if (rank < 2)
222 return op.emitOpError("requires at least 2D memref");
223 SmallVector<int64_t> strides;
224 int64_t offset;
225 if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
226 strides.back() != 1)
227 return op.emitOpError("requires memref with unit innermost stride");
228 }
229
230 return success();
231}
232
233void x86::amx::TileLoadOp::build(OpBuilder &builder, OperationState &state,
234 Type res, Value base, ValueRange indices) {
235 build(builder, state, res, base, indices, /*stride=*/nullptr);
236}
237
238LogicalResult x86::amx::TileLoadOp::verify() {
239 return tileTransferVerifier(*this);
240}
241
242SmallVector<Value> x86::amx::TileLoadOp::getIntrinsicOperands(
243 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
244 RewriterBase &rewriter) {
245 auto loc = getLoc();
246 Adaptor adaptor(operands, *this);
247
248 SmallVector<Value> intrinsicOperands;
249 intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
250 intrinsicOperands.push_back(
251 LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
252 adaptor.getBase(), adaptor.getIndices()));
253 if (Value stride = adaptor.getStride())
254 intrinsicOperands.push_back(
255 computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
256 else
257 intrinsicOperands.push_back(
258 inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
259
260 return intrinsicOperands;
261}
262
263void x86::amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
264 Value base, ValueRange indices, Value val) {
265 build(builder, state, base, indices, val, /*stride=*/nullptr);
266}
267
268LogicalResult x86::amx::TileStoreOp::verify() {
269 return tileTransferVerifier(*this);
270}
271
272SmallVector<Value> x86::amx::TileStoreOp::getIntrinsicOperands(
273 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
274 RewriterBase &rewriter) {
275 auto loc = getLoc();
276 Adaptor adaptor(operands, *this);
277
278 SmallVector<Value> intrinsicOperands;
279 intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
280 intrinsicOperands.push_back(
281 LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
282 adaptor.getBase(), adaptor.getIndices()));
283 if (Value stride = adaptor.getStride())
284 intrinsicOperands.push_back(
285 computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
286 else
287 intrinsicOperands.push_back(
288 inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
289 intrinsicOperands.push_back(adaptor.getVal());
290
291 return intrinsicOperands;
292}
293
294LogicalResult x86::amx::TileMulFOp::verify() {
295 x86::amx::TileType aType = getLhsTileType();
296 x86::amx::TileType bType = getRhsTileType();
297 x86::amx::TileType cType = getTileType();
298 if (failed(verifyTileSize(*this, aType)) ||
299 failed(verifyTileSize(*this, bType)) ||
300 failed(verifyTileSize(*this, cType)) ||
301 failed(verifyMultShape(*this, aType, bType, cType, 1)))
302 return failure();
303 Type ta = aType.getElementType();
304 Type tb = bType.getElementType();
305 Type tc = cType.getElementType();
306 if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
307 return emitOpError("unsupported type combination");
308 return success();
309}
310
311SmallVector<Value> x86::amx::TileMulFOp::getIntrinsicOperands(
312 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
313 RewriterBase &rewriter) {
314 auto loc = getLoc();
315 Adaptor adaptor(operands, *this);
316
317 x86::amx::TileType aType = getLhsTileType();
318 x86::amx::TileType bType = getRhsTileType();
319 SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
320 SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
321
322 SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
323 tsza[1], adaptor.getAcc(),
324 adaptor.getLhs(), adaptor.getRhs()};
325
326 return intrinsicOperands;
327}
328
329LogicalResult x86::amx::TileMulIOp::verify() {
330 x86::amx::TileType aType = getLhsTileType();
331 x86::amx::TileType bType = getRhsTileType();
332 x86::amx::TileType cType = getTileType();
333 if (failed(verifyTileSize(*this, aType)) ||
334 failed(verifyTileSize(*this, bType)) ||
335 failed(verifyTileSize(*this, cType)) ||
336 failed(verifyMultShape(*this, aType, bType, cType, 2)))
337 return failure();
338 Type ta = aType.getElementType();
339 Type tb = bType.getElementType();
340 Type tc = cType.getElementType();
341 if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
342 return emitOpError("unsupported type combination");
343 return success();
344}
345
346SmallVector<Value> x86::amx::TileMulIOp::getIntrinsicOperands(
347 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
348 RewriterBase &rewriter) {
349 auto loc = getLoc();
350 Adaptor adaptor(operands, *this);
351
352 x86::amx::TileType aType = getLhsTileType();
353 x86::amx::TileType bType = getRhsTileType();
354 SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
355 SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
356
357 SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
358 tsza[1], adaptor.getAcc(),
359 adaptor.getLhs(), adaptor.getRhs()};
360
361 return intrinsicOperands;
362}
363
364Type x86::amx::TileType::parse(AsmParser &parser) {
365 if (parser.parseLess())
366 return nullptr;
367
369 if (parser.parseDimensionList(shape, false, true))
370 return nullptr;
371
372 Type elementType;
373 if (parser.parseType(elementType))
374 return nullptr;
375
376 if (parser.parseGreater())
377 return nullptr;
378
379 return AMXTileType::getChecked(
380 [&] { return parser.emitError(parser.getNameLoc()); }, shape,
381 elementType);
382}
383
384void x86::amx::TileType::print(AsmPrinter &os) const {
385 os << "<";
387 os << 'x';
389 os << '>';
390}
391
392#define GET_OP_CLASSES
393#include "mlir/Dialect/X86/X86.cpp.inc"
394
395#define GET_TYPEDEF_CLASSES
396#include "mlir/Dialect/X86/X86Types.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
static Value inferStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static Value computeStrideInBytes(Location loc, MemRefType mType, Value elementStride, RewriterBase &rewriter)
Returns stride expressed in number of bytes for the given elementStride stride encoded in number of elements of the type mType.
static LogicalResult tileTransferVerifier(OpTy op)
static LogicalResult verifyMultShape(Operation *op, x86::amx::TileType atp, x86::amx::TileType btp, x86::amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static LogicalResult verifyTileSize(Operation *op, x86::amx::TileType tp)
Verify that AMX supports the implied tile shape.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
Definition Builders.cpp:221
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI8Type()
Definition Builders.cpp:63
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
mlir::x86::AMXTileType TileType
Definition X86Dialect.h:40
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This represents an operation in an abstracted form, suitable for use with the builder APIs.