MLIR  14.0.0git
AMXDialect.cpp
Go to the documentation of this file.
1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 using namespace mlir;
20 
21 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
22 
23 void amx::AMXDialect::initialize() {
24  addOperations<
25 #define GET_OP_LIST
26 #include "mlir/Dialect/AMX/AMX.cpp.inc"
27  >();
28 }
29 
30 /// Verify that AMX supports the implied tile shape.
31 static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
32  const unsigned kMaxRows = 16;
33  const unsigned kBitsPerRow = 64 * 8;
34  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
35  if (tp.getDimSize(0) > kMaxRows)
36  return op->emitOpError("bad row height: ") << tp.getDimSize(0);
37  if (col > kBitsPerRow || col & 0x1f)
38  return op->emitOpError("bad column width: ") << (col >> 3);
39  return success();
40 }
41 
42 /// Verify that AMX supports the multiplication.
43 static LogicalResult verifyMultShape(Operation *op, VectorType atp,
44  VectorType btp, VectorType ctp,
45  unsigned scale) {
46  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
47  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
48  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
49  if (cm != am || cn != bn || ak != bk)
50  return op->emitOpError("bad mult shape: ")
51  << cm << " x " << cn << " x " << ak;
52  return success();
53 }
54 
55 static LogicalResult verify(amx::TileZeroOp op) {
56  return verifyTileSize(op, op.getVectorType());
57 }
58 
59 static LogicalResult verify(amx::TileLoadOp op) {
60  unsigned rank = op.getMemRefType().getRank();
61  if (llvm::size(op.indices()) != rank)
62  return op.emitOpError("requires ") << rank << " indices";
63  return verifyTileSize(op, op.getVectorType());
64 }
65 
66 static LogicalResult verify(amx::TileStoreOp op) {
67  unsigned rank = op.getMemRefType().getRank();
68  if (llvm::size(op.indices()) != rank)
69  return op.emitOpError("requires ") << rank << " indices";
70  return verifyTileSize(op, op.getVectorType());
71 }
72 
73 static LogicalResult verify(amx::TileMulFOp op) {
74  VectorType aType = op.getLhsVectorType();
75  VectorType bType = op.getRhsVectorType();
76  VectorType cType = op.getVectorType();
77  if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
78  failed(verifyTileSize(op, cType)) ||
79  failed(verifyMultShape(op, aType, bType, cType, 1)))
80  return failure();
81  Type ta = aType.getElementType();
82  Type tb = bType.getElementType();
83  Type tc = cType.getElementType();
84  if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
85  return op.emitOpError("unsupported type combination");
86  return success();
87 }
88 
89 static LogicalResult verify(amx::TileMulIOp op) {
90  VectorType aType = op.getLhsVectorType();
91  VectorType bType = op.getRhsVectorType();
92  VectorType cType = op.getVectorType();
93  if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
94  failed(verifyTileSize(op, cType)) ||
95  failed(verifyMultShape(op, aType, bType, cType, 2)))
96  return failure();
97  Type ta = aType.getElementType();
98  Type tb = bType.getElementType();
99  Type tc = cType.getElementType();
100  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
101  return op.emitOpError("unsupported type combination");
102  return success();
103 }
104 
105 #define GET_OP_CLASSES
106 #include "mlir/Dialect/AMX/AMX.cpp.inc"
Include the generated interface declarations.
static LogicalResult verifyTileSize(Operation *op, VectorType tp)
Verify that AMX supports the implied tile shape.
Definition: AMXDialect.cpp:31
bool isF32() const
Definition: Types.cpp:23
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:72
static LogicalResult verifyMultShape(Operation *op, VectorType atp, VectorType btp, VectorType ctp, unsigned scale)
Verify that AMX supports the multiplication.
Definition: AMXDialect.cpp:43
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op ", which is convenient for verifiers.
Definition: Operation.cpp:518
bool isBF16() const
Definition: Types.cpp:21