22 #include "llvm/ADT/TypeSwitch.h"
26 #include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
28 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
30 void amx::AMXDialect::initialize() {
32 #define GET_TYPEDEF_LIST
33 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
38 #include "mlir/Dialect/AMX/AMX.cpp.inc"
44 const unsigned kMaxRows = 16;
45 const unsigned kBitsPerRow = 64 * 8;
46 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
47 if (tp.getDimSize(0) > kMaxRows)
48 return op->
emitOpError(
"bad row height: ") << tp.getDimSize(0);
49 if (col > kBitsPerRow || col & 0x1f)
50 return op->
emitOpError(
"bad column width: ") << (col >> 3);
56 amx::TileType btp, amx::TileType ctp,
58 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
59 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
60 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
61 if (cm != am || cn != bn || ak != bk)
63 << cm <<
" x " << cn <<
" x " << ak;
73 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
74 assert(llvm::isPowerOf2_64(width) && width >= 8);
75 unsigned bytes = width >> 3;
79 rewriter.
create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
80 rewriter.
create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
87 assert(mType.getRank() >= 2 &&
"Invalid shape for AMX strides");
88 int64_t preLast = mType.getRank() - 2;
90 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
91 assert(llvm::isPowerOf2_64(width) && width >= 8);
92 unsigned bytes = width >> 3;
93 auto [strides, offset] = mType.getStridesAndOffset();
94 if (strides[preLast] == ShapedType::kDynamic) {
98 Value scale = rewriter.
create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
100 .
create<LLVM::MulOp>(loc, llvmInt64Type, scale,
101 memrefDescriptor.
stride(rewriter, loc, preLast))
106 return rewriter.
create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
123 unsigned rank = memrefTy.getRank();
125 return emitOpError(
"requires at least 2D memref");
127 return emitOpError(
"requires ") << rank <<
" indices";
130 if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
132 return emitOpError(
"requires memref with unit innermost stride");
141 Adaptor adaptor(operands, *
this);
144 intrinsicOperands.append(
getTileSizes(loc, getTileType(), rewriter));
145 intrinsicOperands.push_back(
147 adaptor.getBase(), adaptor.getIndices()));
148 intrinsicOperands.push_back(
151 return intrinsicOperands;
156 unsigned rank = memrefTy.getRank();
158 return emitOpError(
"requires at least 2D memref");
160 return emitOpError(
"requires ") << rank <<
" indices";
163 if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
165 return emitOpError(
"requires memref with unit innermost stride");
174 Adaptor adaptor(operands, *
this);
177 intrinsicOperands.append(
getTileSizes(loc, getTileType(), rewriter));
178 intrinsicOperands.push_back(
180 adaptor.getBase(), adaptor.getIndices()));
181 intrinsicOperands.push_back(
183 intrinsicOperands.push_back(adaptor.getVal());
185 return intrinsicOperands;
189 amx::TileType aType = getLhsTileType();
190 amx::TileType bType = getRhsTileType();
191 amx::TileType cType = getTileType();
197 Type ta = aType.getElementType();
198 Type tb = bType.getElementType();
199 Type tc = cType.getElementType();
201 return emitOpError(
"unsupported type combination");
210 Adaptor adaptor(operands, *
this);
212 amx::TileType aType = getLhsTileType();
213 amx::TileType bType = getRhsTileType();
218 tsza[1], adaptor.getAcc(),
219 adaptor.getLhs(), adaptor.getRhs()};
221 return intrinsicOperands;
225 amx::TileType aType = getLhsTileType();
226 amx::TileType bType = getRhsTileType();
227 amx::TileType cType = getTileType();
233 Type ta = aType.getElementType();
234 Type tb = bType.getElementType();
235 Type tc = cType.getElementType();
237 return emitOpError(
"unsupported type combination");
246 Adaptor adaptor(operands, *
this);
248 amx::TileType aType = getLhsTileType();
249 amx::TileType bType = getRhsTileType();
254 tsza[1], adaptor.getAcc(),
255 adaptor.getLhs(), adaptor.getRhs()};
257 return intrinsicOperands;
286 #define GET_OP_CLASSES
287 #include "mlir/Dialect/AMX/AMX.cpp.inc"
289 #define GET_TYPEDEF_CLASSES
290 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...