MLIR  20.0.0git
AMXDialect.cpp
Go to the documentation of this file.
1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/TypeUtilities.h"
19 
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace mlir;
23 
24 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
25 
/// Register the AMX dialect's types and operations with this dialect
/// instance. The concrete type list is pulled in from AMXTypes.cpp.inc (with
/// GET_TYPEDEF_LIST defined) and the op list from AMX.cpp.inc (with
/// GET_OP_LIST defined); both are expanded directly into the template
/// argument lists of addTypes/addOperations.
26 void amx::AMXDialect::initialize() {
// Register every TileType-style type listed in the generated typedef list.
27  addTypes<
28 #define GET_TYPEDEF_LIST
29 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
30  >();
31 
// Register every operation listed in the generated op list.
32  addOperations<
33 #define GET_OP_LIST
34 #include "mlir/Dialect/AMX/AMX.cpp.inc"
35  >();
36 }
37 
38 /// Verify that AMX supports the implied tile shape.
39 static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
40  const unsigned kMaxRows = 16;
41  const unsigned kBitsPerRow = 64 * 8;
42  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
43  if (tp.getDimSize(0) > kMaxRows)
44  return op->emitOpError("bad row height: ") << tp.getDimSize(0);
45  if (col > kBitsPerRow || col & 0x1f)
46  return op->emitOpError("bad column width: ") << (col >> 3);
47  return success();
48 }
49 
50 /// Verify that AMX supports the multiplication.
51 static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
52  amx::TileType btp, amx::TileType ctp,
53  unsigned scale) {
54  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
55  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
56  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
57  if (cm != am || cn != bn || ak != bk)
58  return op->emitOpError("bad mult shape: ")
59  << cm << " x " << cn << " x " << ak;
60  return success();
61 }
62 
63 LogicalResult amx::TileZeroOp::verify() {
64  return verifyTileSize(*this, getTileType());
65 }
66 
67 LogicalResult amx::TileLoadOp::verify() {
68  unsigned rank = getMemRefType().getRank();
69  if (getIndices().size() != rank)
70  return emitOpError("requires ") << rank << " indices";
71  return verifyTileSize(*this, getTileType());
72 }
73 
74 LogicalResult amx::TileStoreOp::verify() {
75  unsigned rank = getMemRefType().getRank();
76  if (getIndices().size() != rank)
77  return emitOpError("requires ") << rank << " indices";
78  return verifyTileSize(*this, getTileType());
79 }
80 
81 LogicalResult amx::TileMulFOp::verify() {
82  amx::TileType aType = getLhsTileType();
83  amx::TileType bType = getRhsTileType();
84  amx::TileType cType = getTileType();
85  if (failed(verifyTileSize(*this, aType)) ||
86  failed(verifyTileSize(*this, bType)) ||
87  failed(verifyTileSize(*this, cType)) ||
88  failed(verifyMultShape(*this, aType, bType, cType, 1)))
89  return failure();
90  Type ta = aType.getElementType();
91  Type tb = bType.getElementType();
92  Type tc = cType.getElementType();
93  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
94  return emitOpError("unsupported type combination");
95  return success();
96 }
97 
98 LogicalResult amx::TileMulIOp::verify() {
99  amx::TileType aType = getLhsTileType();
100  amx::TileType bType = getRhsTileType();
101  amx::TileType cType = getTileType();
102  if (failed(verifyTileSize(*this, aType)) ||
103  failed(verifyTileSize(*this, bType)) ||
104  failed(verifyTileSize(*this, cType)) ||
105  failed(verifyMultShape(*this, aType, bType, cType, 2)))
106  return failure();
107  Type ta = aType.getElementType();
108  Type tb = bType.getElementType();
109  Type tc = cType.getElementType();
110  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
111  return emitOpError("unsupported type combination");
112  return success();
113 }
114 
116  if (parser.parseLess())
117  return nullptr;
118 
120  if (parser.parseDimensionList(shape, false, true))
121  return nullptr;
122 
123  Type elementType;
124  if (parser.parseType(elementType))
125  return nullptr;
126 
127  if (parser.parseGreater())
128  return nullptr;
129 
130  return TileType::get(shape, elementType);
131 }
132 
133 void amx::TileType::print(AsmPrinter &os) const {
134  os << "<";
136  os << 'x';
138  os << '>';
139 }
140 
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/AMX/AMX.cpp.inc"
143 
144 #define GET_TYPEDEF_CLASSES
145 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, amx::TileType btp, amx::TileType ctp, unsigned scale)
Verify that AMX supports the multiplication.
Definition: AMXDialect.cpp:51
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp)
Verify that AMX supports the implied tile shape.
Definition: AMXDialect.cpp:39
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printDimensionList(ArrayRef< int64_t > shape)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:59
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
bool isF16() const
Definition: Types.cpp:57
bool isBF16() const
Definition: Types.cpp:56
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:426