21 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
23 void amx::AMXDialect::initialize() {
26 #include "mlir/Dialect/AMX/AMX.cpp.inc"
32 const unsigned kMaxRows = 16;
33 const unsigned kBitsPerRow = 64 * 8;
34 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
35 if (tp.getDimSize(0) > kMaxRows)
36 return op->
emitOpError(
"bad row height: ") << tp.getDimSize(0);
37 if (col > kBitsPerRow || col & 0x1f)
38 return op->
emitOpError(
"bad column width: ") << (col >> 3);
44 VectorType btp, VectorType ctp,
46 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
47 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
48 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
49 if (cm != am || cn != bn || ak != bk)
51 << cm <<
" x " << cn <<
" x " << ak;
62 return emitOpError(
"requires ") << rank <<
" indices";
69 return emitOpError(
"requires ") << rank <<
" indices";
74 VectorType aType = getLhsVectorType();
75 VectorType bType = getRhsVectorType();
82 Type ta = aType.getElementType();
83 Type tb = bType.getElementType();
84 Type tc = cType.getElementType();
86 return emitOpError(
"unsupported type combination");
91 VectorType aType = getLhsVectorType();
92 VectorType bType = getRhsVectorType();
99 Type ta = aType.getElementType();
100 Type tb = bType.getElementType();
101 Type tc = cType.getElementType();
103 return emitOpError(
"unsupported type combination");
107 #define GET_OP_CLASSES
108 #include "mlir/Dialect/AMX/AMX.cpp.inc"
static LogicalResult verifyMultShape(Operation *op, VectorType atp, VectorType btp, VectorType ctp, unsigned scale)
Verify that AMX supports the multiplication.
static LogicalResult verifyTileSize(Operation *op, VectorType tp)
Verify that AMX supports the implied tile shape.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.