22 #include "llvm/ADT/TypeSwitch.h"
26 #include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
28 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
30 void amx::AMXDialect::initialize() {
32 #define GET_TYPEDEF_LIST
33 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
38 #include "mlir/Dialect/AMX/AMX.cpp.inc"
44 const unsigned kMaxRows = 16;
45 const unsigned kBitsPerRow = 64 * 8;
46 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
47 if (tp.getDimSize(0) > kMaxRows)
48 return op->
emitOpError(
"bad row height: ") << tp.getDimSize(0);
49 if (col > kBitsPerRow || col & 0x1f)
50 return op->
emitOpError(
"bad column width: ") << (col >> 3);
56 amx::TileType btp, amx::TileType ctp,
58 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
59 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
60 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
61 if (cm != am || cn != bn || ak != bk)
63 << cm <<
" x " << cn <<
" x " << ak;
73 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
74 assert(llvm::isPowerOf2_64(width) && width >= 8);
75 unsigned bytes = width >> 3;
79 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
80 LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
88 unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
90 Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
91 return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
99 assert(mType.getRank() >= 2 &&
"Invalid shape for AMX strides");
100 int64_t preLast = mType.getRank() - 2;
102 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
103 assert(llvm::isPowerOf2_64(width) && width >= 8);
104 unsigned bytes = width >> 3;
105 auto [strides, offset] = mType.getStridesAndOffset();
106 if (strides[preLast] == ShapedType::kDynamic) {
110 loc, mType, memrefDescriptor.
stride(rewriter, loc, preLast), rewriter);
114 return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
129 template <
typename OpTy,
130 typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
131 std::is_same_v<OpTy, amx::TileStoreOp>>>
133 MemRefType memrefTy = op.getMemRefType();
134 unsigned rank = memrefTy.getRank();
135 if (op.getIndices().size() != rank)
136 return op.emitOpError(
"requires ") << rank <<
" indices";
142 if (!op.getStride()) {
144 return op.emitOpError(
"requires at least 2D memref");
147 if (
failed(memrefTy.getStridesAndOffset(strides, offset)) ||
149 return op.emitOpError(
"requires memref with unit innermost stride");
157 build(builder, state, res, base, indices,
nullptr);
167 Adaptor adaptor(operands, *
this);
170 intrinsicOperands.append(
getTileSizes(loc, getTileType(), rewriter));
171 intrinsicOperands.push_back(
173 adaptor.getBase(), adaptor.getIndices()));
174 if (
Value stride = adaptor.getStride())
175 intrinsicOperands.push_back(
178 intrinsicOperands.push_back(
181 return intrinsicOperands;
186 build(builder, state, base, indices, val,
nullptr);
196 Adaptor adaptor(operands, *
this);
199 intrinsicOperands.append(
getTileSizes(loc, getTileType(), rewriter));
200 intrinsicOperands.push_back(
202 adaptor.getBase(), adaptor.getIndices()));
203 if (
Value stride = adaptor.getStride())
204 intrinsicOperands.push_back(
207 intrinsicOperands.push_back(
209 intrinsicOperands.push_back(adaptor.getVal());
211 return intrinsicOperands;
215 amx::TileType aType = getLhsTileType();
216 amx::TileType bType = getRhsTileType();
217 amx::TileType cType = getTileType();
223 Type ta = aType.getElementType();
224 Type tb = bType.getElementType();
225 Type tc = cType.getElementType();
227 return emitOpError(
"unsupported type combination");
236 Adaptor adaptor(operands, *
this);
238 amx::TileType aType = getLhsTileType();
239 amx::TileType bType = getRhsTileType();
244 tsza[1], adaptor.getAcc(),
245 adaptor.getLhs(), adaptor.getRhs()};
247 return intrinsicOperands;
251 amx::TileType aType = getLhsTileType();
252 amx::TileType bType = getRhsTileType();
253 amx::TileType cType = getTileType();
259 Type ta = aType.getElementType();
260 Type tb = bType.getElementType();
261 Type tc = cType.getElementType();
263 return emitOpError(
"unsupported type combination");
272 Adaptor adaptor(operands, *
this);
274 amx::TileType aType = getLhsTileType();
275 amx::TileType bType = getRhsTileType();
280 tsza[1], adaptor.getAcc(),
281 adaptor.getLhs(), adaptor.getRhs()};
283 return intrinsicOperands;
301 return TileType::getChecked(
314 #define GET_OP_CLASSES
315 #include "mlir/Dialect/AMX/AMX.cpp.inc"
317 #define GET_TYPEDEF_CLASSES
318 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static Value inferStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
static Value computeStrideInBytes(Location loc, MemRefType mType, Value elementStride, RewriterBase &rewriter)
Returns stride expressed in number of bytes for the given elementStride stride encoded in number of e...
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
static LogicalResult tileTransferVerifier(OpTy op)
static Type getElementType(Type type)
Determine the element type of type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.