20 #include "llvm/ADT/TypeSwitch.h"
24 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
26 void amx::AMXDialect::initialize() {
28 #define GET_TYPEDEF_LIST
29 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
34 #include "mlir/Dialect/AMX/AMX.cpp.inc"
40 const unsigned kMaxRows = 16;
41 const unsigned kBitsPerRow = 64 * 8;
42 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
43 if (tp.getDimSize(0) > kMaxRows)
44 return op->
emitOpError(
"bad row height: ") << tp.getDimSize(0);
45 if (col > kBitsPerRow || col & 0x1f)
46 return op->
emitOpError(
"bad column width: ") << (col >> 3);
52 amx::TileType btp, amx::TileType ctp,
54 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
55 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
56 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
57 if (cm != am || cn != bn || ak != bk)
59 << cm <<
" x " << cn <<
" x " << ak;
70 return emitOpError(
"requires ") << rank <<
" indices";
77 return emitOpError(
"requires ") << rank <<
" indices";
82 amx::TileType aType = getLhsTileType();
83 amx::TileType bType = getRhsTileType();
84 amx::TileType cType = getTileType();
90 Type ta = aType.getElementType();
91 Type tb = bType.getElementType();
92 Type tc = cType.getElementType();
94 return emitOpError(
"unsupported type combination");
99 amx::TileType aType = getLhsTileType();
100 amx::TileType bType = getRhsTileType();
101 amx::TileType cType = getTileType();
107 Type ta = aType.getElementType();
108 Type tb = bType.getElementType();
109 Type tc = cType.getElementType();
111 return emitOpError(
"unsupported type combination");
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/AMX/AMX.cpp.inc"
144 #define GET_TYPEDEF_CLASSES
145 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...