MLIR  14.0.0git
NVVMDialect.cpp
Go to the documentation of this file.
1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
18 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Operation.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/IR/Attributes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/Support/SourceMgr.h"
31 
32 using namespace mlir;
33 using namespace NVVM;
34 
35 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
36 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // Printing/parsing for NVVM ops
40 //===----------------------------------------------------------------------===//
41 
43  p << " " << op->getOperands();
44  if (op->getNumResults() > 0)
45  p << " : " << op->getResultTypes();
46 }
47 
48 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
50  OperationState &result) {
51  MLIRContext *context = parser.getContext();
52  auto int32Ty = IntegerType::get(context, 32);
53  auto int1Ty = IntegerType::get(context, 1);
54 
56  Type type;
57  return failure(parser.parseOperandList(ops) ||
58  parser.parseOptionalAttrDict(result.attributes) ||
59  parser.parseColonType(type) ||
60  parser.addTypeToList(type, result.types) ||
61  parser.resolveOperands(ops, {int32Ty, int1Ty},
62  parser.getNameLoc(), result.operands));
63 }
64 
65 static LogicalResult verify(MmaOp op) {
66  MLIRContext *context = op.getContext();
67  auto f16Ty = Float16Type::get(context);
68  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
69  auto f32Ty = Float32Type::get(context);
70  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
71  context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
72  auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
73  context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
74 
75  SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
76  op.getOperandTypes().end());
77  if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
78  operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
79  f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
80  f32Ty, f32Ty, f32Ty}) {
81  return op.emitOpError(
82  "expected operands to be 4 <halfx2>s followed by either "
83  "4 <halfx2>s or 8 floats");
84  }
85  if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
86  return op.emitOpError("expected result type to be a struct of either 4 "
87  "<halfx2>s or 8 floats");
88  }
89 
90  auto alayout = op->getAttrOfType<StringAttr>("alayout");
91  auto blayout = op->getAttrOfType<StringAttr>("blayout");
92 
93  if (!(alayout && blayout) ||
94  !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
95  !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
96  return op.emitOpError(
97  "alayout and blayout attributes must be set to either "
98  "\"row\" or \"col\"");
99  }
100 
101  if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
102  f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
103  f32Ty, f32Ty, f32Ty} &&
104  op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
105  blayout.getValue() == "col") {
106  return success();
107  }
108  return op.emitOpError("unimplemented mma.sync variant");
109 }
110 
111 std::pair<mlir::Type, unsigned>
112 inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) {
113  unsigned numberElements = 0;
114  Type elementType;
115  OpBuilder builder(context);
116  Type f16x2 = VectorType::get(2, builder.getF16Type());
117  if (type == NVVM::MMATypes::f16) {
118  elementType = f16x2;
119  if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
120  numberElements = 8;
121  else
122  numberElements = 4;
123  } else if (type == NVVM::MMATypes::f32) {
124  elementType = builder.getF32Type();
125  numberElements = 8;
126  } else if (type == NVVM::MMATypes::tf32) {
127  elementType = builder.getI32Type();
128  numberElements = 4;
129  }
130  assert(numberElements != 0 && elementType != nullptr);
131  return std::make_pair(elementType, numberElements);
132 }
133 
134 static LogicalResult verify(NVVM::WMMALoadOp op) {
135  unsigned addressSpace =
136  op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
137  if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
138  return op.emitOpError("expected source pointer in memory "
139  "space 0, 1, 3");
140 
141  if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
142  op.eltype(), op.frag()) == 0)
143  return op.emitOpError() << "invalid attribute combination";
144  std::pair<Type, unsigned> typeInfo =
145  inferMMAType(op.eltype(), op.frag(), op.getContext());
147  op.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
148  if (op.getType() != dstType)
149  return op.emitOpError("expected destination type is a structure of ")
150  << typeInfo.second << " elements of type " << typeInfo.first;
151  return success();
152 }
153 
154 static LogicalResult verify(NVVM::WMMAStoreOp op) {
155  unsigned addressSpace =
156  op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
157  if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
158  return op.emitOpError("expected operands to be a source pointer in memory "
159  "space 0, 1, 3");
160 
161  if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
162  op.eltype()) == 0)
163  return op.emitOpError() << "invalid attribute combination";
164  std::pair<Type, unsigned> typeInfo =
165  inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext());
166  if (op.args().size() != typeInfo.second)
167  return op.emitOpError()
168  << "expected " << typeInfo.second << " data operands";
169  if (llvm::any_of(op.args(), [&typeInfo](Value operands) {
170  return operands.getType() != typeInfo.first;
171  }))
172  return op.emitOpError()
173  << "expected data operands of type " << typeInfo.first;
174  return success();
175 }
176 
177 static LogicalResult verify(NVVM::WMMAMmaOp op) {
178  if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(),
179  op.layoutB(), op.eltypeA(),
180  op.eltypeB()) == 0)
181  return op.emitOpError() << "invalid attribute combination";
182  std::pair<Type, unsigned> typeInfoA =
183  inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext());
184  std::pair<Type, unsigned> typeInfoB =
185  inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext());
186  std::pair<Type, unsigned> typeInfoC =
187  inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext());
188  SmallVector<Type, 32> arguments;
189  arguments.append(typeInfoA.second, typeInfoA.first);
190  arguments.append(typeInfoB.second, typeInfoB.first);
191  arguments.append(typeInfoC.second, typeInfoC.first);
192  unsigned numArgs = arguments.size();
193  if (op.args().size() != numArgs)
194  return op.emitOpError() << "expected " << numArgs << " arguments";
195  for (unsigned i = 0; i < numArgs; i++) {
196  if (op.args()[i].getType() != arguments[i])
197  return op.emitOpError()
198  << "expected argument " << i << " to be of type " << arguments[i];
199  }
201  op.getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
202  if (op.getType() != dstType)
203  return op.emitOpError("expected destination type is a structure of ")
204  << typeInfoC.second << " elements of type " << typeInfoC.first;
205  return success();
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // NVVMDialect initialization, type parsing, and registration.
210 //===----------------------------------------------------------------------===//
211 
212 // TODO: This should be the llvm.nvvm dialect once this is supported.
213 void NVVMDialect::initialize() {
214  addOperations<
215 #define GET_OP_LIST
216 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
217  >();
218  addAttributes<
219 #define GET_ATTRDEF_LIST
220 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
221  >();
222 
223  // Support unknown operations because not all NVVM operations are
224  // registered.
225  allowUnknownOperations();
226 }
227 
228 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
229  NamedAttribute attr) {
230  // Kernel function attribute should be attached to functions.
231  if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) {
232  if (!isa<LLVM::LLVMFuncOp>(op)) {
233  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
234  << "' attribute attached to unexpected op";
235  }
236  }
237  return success();
238 }
239 
240 #define GET_OP_CLASSES
241 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
242 
243 #define GET_ATTRDEF_CLASSES
244 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
Include the generated interface declarations.
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:830
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:371
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:247
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
FloatType getF16Type()
Definition: Builders.cpp:38
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
FloatType getF32Type()
Definition: Builders.cpp:40
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
std::pair< mlir::Type, unsigned > inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context)
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
SmallVector< Value, 4 > operands
virtual ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:137
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:32
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This represents an operation in an abstracted form, suitable for use with the builder APIs...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
NamedAttrList attributes
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
LLVM dialect pointer type.
Definition: LLVMTypes.h:181
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:61
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
This class helps build Operations.
Definition: Builders.h:177
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op)
Definition: NVVMDialect.cpp:42
static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, OperationState &result)
Definition: NVVMDialect.cpp:49
result_type_range getResultTypes()
Definition: Operation.h:297
IntegerType getI32Type()
Definition: Builders.cpp:54
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
SmallVector< Type, 4 > types
Types of the results of this operation.