22#include "llvm/ADT/TypeSwitch.h"
26#include "mlir/Dialect/X86/X86Interfaces.cpp.inc"
28#include "mlir/Dialect/X86/X86Dialect.cpp.inc"
30void x86::X86Dialect::initialize() {
32#define GET_TYPEDEF_LIST
33#include "mlir/Dialect/X86/X86Types.cpp.inc"
38#include "mlir/Dialect/X86/X86.cpp.inc"
46 return memRefDescriptor.
bufferPtr(rewriter, loc, typeConverter, type);
49LogicalResult x86::MaskCompressOp::verify() {
50 if (getSrc() && getConstantSrc())
51 return emitError(
"cannot use both src and constant_src");
54 return emitError(
"failed to verify that src and dst have same type");
56 if (getConstantSrc() && (getConstantSrc()->
getType() != getDst().
getType()))
58 "failed to verify that constant_src and dst have same type");
67 Adaptor adaptor(operands, *
this);
69 auto opType = adaptor.getA().getType();
71 if (adaptor.getSrc()) {
72 src = adaptor.getSrc();
73 }
else if (adaptor.getConstantSrc()) {
74 src = LLVM::ConstantOp::create(rewriter, loc, opType,
75 adaptor.getConstantSrcAttr());
78 src = LLVM::ConstantOp::create(rewriter, loc, opType, zeroAttr);
91 LLVM::ConstantOp::create(rewriter, getLoc(), rewriter.
getI8Type(), 0xff);
92 intrinsicOperands.push_back(scale);
94 return intrinsicOperands;
100 Adaptor adaptor(operands, *
this);
102 typeConverter, rewriter)};
108 Adaptor adaptor(operands, *
this);
110 typeConverter, rewriter)};
116 Adaptor adaptor(operands, *
this);
118 typeConverter, rewriter)};
123 const unsigned kMaxRows = 16;
124 const unsigned kBitsPerRow = 64 * 8;
125 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
126 if (tp.getDimSize(0) > kMaxRows)
127 return op->
emitOpError(
"bad row height: ") << tp.getDimSize(0);
128 if (col > kBitsPerRow || col & 0x1f)
129 return op->
emitOpError(
"bad column width: ") << (col >> 3);
137 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
138 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
139 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
140 if (cm != am || cn != bn || ak != bk)
142 << cm <<
" x " << cn <<
" x " << ak;
152 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
153 assert(llvm::isPowerOf2_64(width) && width >= 8);
154 unsigned bytes = width >> 3;
158 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
159 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
167 unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
169 Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
170 return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
178 assert(mType.getRank() >= 2 &&
"Invalid shape for AMX strides");
179 int64_t preLast = mType.getRank() - 2;
181 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
182 assert(llvm::isPowerOf2_64(width) && width >= 8);
183 unsigned bytes = width >> 3;
184 auto [strides, offset] = mType.getStridesAndOffset();
185 if (strides[preLast] == ShapedType::kDynamic) {
189 loc, mType, memrefDescriptor.
stride(rewriter, loc, preLast), rewriter);
193 return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
197LogicalResult x86::amx::TileZeroOp::verify() {
207template <
typename OpTy,
typename = std::enable_if_t<
208 std::is_same_v<OpTy, x86::amx::TileLoadOp> ||
209 std::is_same_v<OpTy, x86::amx::TileStoreOp>>>
211 MemRefType memrefTy = op.getMemRefType();
212 unsigned rank = memrefTy.getRank();
213 if (op.getIndices().size() != rank)
214 return op.emitOpError(
"requires ") << rank <<
" indices";
220 if (!op.getStride()) {
222 return op.emitOpError(
"requires at least 2D memref");
225 if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
227 return op.emitOpError(
"requires memref with unit innermost stride");
235 build(builder, state, res, base,
indices,
nullptr);
238LogicalResult x86::amx::TileLoadOp::verify() {
246 Adaptor adaptor(operands, *
this);
249 intrinsicOperands.append(
getTileSizes(loc, getTileType(), rewriter));
250 intrinsicOperands.push_back(
252 adaptor.getBase(), adaptor.getIndices()));
253 if (
Value stride = adaptor.getStride())
254 intrinsicOperands.push_back(
257 intrinsicOperands.push_back(
260 return intrinsicOperands;
265 build(builder, state, base,
indices, val,
nullptr);
268LogicalResult x86::amx::TileStoreOp::verify() {
276 Adaptor adaptor(operands, *
this);
279 intrinsicOperands.append(
getTileSizes(loc, getTileType(), rewriter));
280 intrinsicOperands.push_back(
282 adaptor.getBase(), adaptor.getIndices()));
283 if (
Value stride = adaptor.getStride())
284 intrinsicOperands.push_back(
287 intrinsicOperands.push_back(
289 intrinsicOperands.push_back(adaptor.getVal());
291 return intrinsicOperands;
294LogicalResult x86::amx::TileMulFOp::verify() {
303 Type ta = aType.getElementType();
304 Type tb = bType.getElementType();
305 Type tc = cType.getElementType();
307 return emitOpError(
"unsupported type combination");
315 Adaptor adaptor(operands, *
this);
323 tsza[1], adaptor.getAcc(),
324 adaptor.getLhs(), adaptor.getRhs()};
326 return intrinsicOperands;
329LogicalResult x86::amx::TileMulIOp::verify() {
338 Type ta = aType.getElementType();
339 Type tb = bType.getElementType();
340 Type tc = cType.getElementType();
342 return emitOpError(
"unsupported type combination");
350 Adaptor adaptor(operands, *
this);
358 tsza[1], adaptor.getAcc(),
359 adaptor.getLhs(), adaptor.getRhs()};
361 return intrinsicOperands;
379 return AMXTileType::getChecked(
384void x86::amx::TileType::print(
AsmPrinter &os)
const {
392#define GET_OP_CLASSES
393#include "mlir/Dialect/X86/X86.cpp.inc"
395#define GET_TYPEDEF_CLASSES
396#include "mlir/Dialect/X86/X86Types.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
static Value inferStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static Value computeStrideInBytes(Location loc, MemRefType mType, Value elementStride, RewriterBase &rewriter)
Returns stride expressed in number of bytes for the given elementStride stride encoded in number of e...
static LogicalResult tileTransferVerifier(OpTy op)
static LogicalResult verifyMultShape(Operation *op, x86::amx::TileType atp, x86::amx::TileType btp, x86::amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static LogicalResult verifyTileSize(Operation *op, x86::amx::TileType tp)
Verify that AMX supports the implied tile shape.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
mlir::x86::AMXTileType TileType
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This represents an operation in an abstracted form, suitable for use with the builder APIs.