22 #include "llvm/ADT/TypeSwitch.h"
26 #include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
28 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
30 void amx::AMXDialect::initialize() {
32 #define GET_TYPEDEF_LIST
33 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
38 #include "mlir/Dialect/AMX/AMX.cpp.inc"
44 const unsigned kMaxRows = 16;
45 const unsigned kBitsPerRow = 64 * 8;
46 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
47 if (tp.getDimSize(0) > kMaxRows)
48 return op->
emitOpError(
"bad row height: ") << tp.getDimSize(0);
49 if (col > kBitsPerRow || col & 0x1f)
50 return op->
emitOpError(
"bad column width: ") << (col >> 3);
56 amx::TileType btp, amx::TileType ctp,
58 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
59 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
60 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
61 if (cm != am || cn != bn || ak != bk)
63 << cm <<
" x " << cn <<
" x " << ak;
73 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
74 assert(llvm::isPowerOf2_64(width) && width >= 8);
75 unsigned bytes = width >> 3;
79 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
80 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
87 assert(mType.getRank() >= 2 &&
"Invalid shape for AMX strides");
88 int64_t preLast = mType.getRank() - 2;
90 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
91 assert(llvm::isPowerOf2_64(width) && width >= 8);
92 unsigned bytes = width >> 3;
93 auto [strides, offset] = mType.getStridesAndOffset();
94 if (strides[preLast] == ShapedType::kDynamic) {
98 Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
99 return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
100 memrefDescriptor.
stride(rewriter, loc, preLast))
105 return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
122 unsigned rank = memrefTy.getRank();
124 return emitOpError(
"requires at least 2D memref");
126 return emitOpError(
"requires ") << rank <<
" indices";
129 if (
failed(memrefTy.getStridesAndOffset(strides, offset)) ||
131 return emitOpError(
"requires memref with unit innermost stride");
140 Adaptor adaptor(operands, *
this);
143 intrinsicOperands.append(
getTileSizes(loc, getTileType(), rewriter));
144 intrinsicOperands.push_back(
146 adaptor.getBase(), adaptor.getIndices()));
147 intrinsicOperands.push_back(
150 return intrinsicOperands;
155 unsigned rank = memrefTy.getRank();
157 return emitOpError(
"requires at least 2D memref");
159 return emitOpError(
"requires ") << rank <<
" indices";
162 if (
failed(memrefTy.getStridesAndOffset(strides, offset)) ||
164 return emitOpError(
"requires memref with unit innermost stride");
173 Adaptor adaptor(operands, *
this);
176 intrinsicOperands.append(
getTileSizes(loc, getTileType(), rewriter));
177 intrinsicOperands.push_back(
179 adaptor.getBase(), adaptor.getIndices()));
180 intrinsicOperands.push_back(
182 intrinsicOperands.push_back(adaptor.getVal());
184 return intrinsicOperands;
188 amx::TileType aType = getLhsTileType();
189 amx::TileType bType = getRhsTileType();
190 amx::TileType cType = getTileType();
196 Type ta = aType.getElementType();
197 Type tb = bType.getElementType();
198 Type tc = cType.getElementType();
200 return emitOpError(
"unsupported type combination");
209 Adaptor adaptor(operands, *
this);
211 amx::TileType aType = getLhsTileType();
212 amx::TileType bType = getRhsTileType();
217 tsza[1], adaptor.getAcc(),
218 adaptor.getLhs(), adaptor.getRhs()};
220 return intrinsicOperands;
224 amx::TileType aType = getLhsTileType();
225 amx::TileType bType = getRhsTileType();
226 amx::TileType cType = getTileType();
232 Type ta = aType.getElementType();
233 Type tb = bType.getElementType();
234 Type tc = cType.getElementType();
236 return emitOpError(
"unsupported type combination");
245 Adaptor adaptor(operands, *
this);
247 amx::TileType aType = getLhsTileType();
248 amx::TileType bType = getRhsTileType();
253 tsza[1], adaptor.getAcc(),
254 adaptor.getLhs(), adaptor.getRhs()};
256 return intrinsicOperands;
274 return TileType::getChecked(
287 #define GET_OP_CLASSES
288 #include "mlir/Dialect/AMX/AMX.cpp.inc"
290 #define GET_TYPEDEF_CLASSES
291 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static Type getElementType(Type type)
Determine the element type of type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...