MLIR  19.0.0git
AMXDialect.cpp
Go to the documentation of this file.
1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 using namespace mlir;
20 
21 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
22 
/// Dialect initialization hook: registers the AMX operations with this dialect.
/// The template argument list passed to addOperations<> is not written by
/// hand; it is the text that the generated AMX.cpp.inc expands to while
/// GET_OP_LIST is defined.
23 void amx::AMXDialect::initialize() {
24  addOperations<
25 #define GET_OP_LIST
26 #include "mlir/Dialect/AMX/AMX.cpp.inc"
27  >();
28 }
29 
30 /// Verify that AMX supports the implied tile shape.
31 static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
32  const unsigned kMaxRows = 16;
33  const unsigned kBitsPerRow = 64 * 8;
34  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
35  if (tp.getDimSize(0) > kMaxRows)
36  return op->emitOpError("bad row height: ") << tp.getDimSize(0);
37  if (col > kBitsPerRow || col & 0x1f)
38  return op->emitOpError("bad column width: ") << (col >> 3);
39  return success();
40 }
41 
42 /// Verify that AMX supports the multiplication.
43 static LogicalResult verifyMultShape(Operation *op, VectorType atp,
44  VectorType btp, VectorType ctp,
45  unsigned scale) {
46  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
47  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
48  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
49  if (cm != am || cn != bn || ak != bk)
50  return op->emitOpError("bad mult shape: ")
51  << cm << " x " << cn << " x " << ak;
52  return success();
53 }
54 
56  return verifyTileSize(*this, getVectorType());
57 }
58 
60  unsigned rank = getMemRefType().getRank();
61  if (getIndices().size() != rank)
62  return emitOpError("requires ") << rank << " indices";
63  return verifyTileSize(*this, getVectorType());
64 }
65 
67  unsigned rank = getMemRefType().getRank();
68  if (getIndices().size() != rank)
69  return emitOpError("requires ") << rank << " indices";
70  return verifyTileSize(*this, getVectorType());
71 }
72 
74  VectorType aType = getLhsVectorType();
75  VectorType bType = getRhsVectorType();
76  VectorType cType = getVectorType();
77  if (failed(verifyTileSize(*this, aType)) ||
78  failed(verifyTileSize(*this, bType)) ||
79  failed(verifyTileSize(*this, cType)) ||
80  failed(verifyMultShape(*this, aType, bType, cType, 1)))
81  return failure();
82  Type ta = aType.getElementType();
83  Type tb = bType.getElementType();
84  Type tc = cType.getElementType();
85  if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
86  return emitOpError("unsupported type combination");
87  return success();
88 }
89 
91  VectorType aType = getLhsVectorType();
92  VectorType bType = getRhsVectorType();
93  VectorType cType = getVectorType();
94  if (failed(verifyTileSize(*this, aType)) ||
95  failed(verifyTileSize(*this, bType)) ||
96  failed(verifyTileSize(*this, cType)) ||
97  failed(verifyMultShape(*this, aType, bType, cType, 2)))
98  return failure();
99  Type ta = aType.getElementType();
100  Type tb = bType.getElementType();
101  Type tc = cType.getElementType();
102  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
103  return emitOpError("unsupported type combination");
104  return success();
105 }
106 
107 #define GET_OP_CLASSES
108 #include "mlir/Dialect/AMX/AMX.cpp.inc"
static LogicalResult verifyMultShape(Operation *op, VectorType atp, VectorType btp, VectorType ctp, unsigned scale)
Verify that AMX supports the multiplication.
Definition: AMXDialect.cpp:43
static LogicalResult verifyTileSize(Operation *op, VectorType tp)
Verify that AMX supports the implied tile shape.
Definition: AMXDialect.cpp:31
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:51
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
bool isBF16() const
Definition: Types.cpp:48
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26