29 #include "llvm/ADT/TypeSwitch.h" 34 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc" 50 elementType, operand);
64 return elementType.
isF16() || elementType.
isF32();
71 if (!operand.equals(
"AOp") && !operand.equals(
"BOp") &&
72 !operand.equals(
"COp"))
73 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
75 if (shape.size() != 2)
76 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
79 return emitError() <<
"MMAMatrixType elements must be F16 or F32";
100 bool GPUDialect::isKernel(
Operation *op) {
101 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
102 return static_cast<bool>(isKernelAttr);
119 void GPUDialect::initialize() {
120 addTypes<AsyncTokenType>();
121 addTypes<MMAMatrixType>();
124 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc" 127 #define GET_ATTRDEF_LIST 128 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc" 130 addInterfaces<GPUInlinerInterface>();
141 if (keyword ==
"async.token")
144 if (keyword ==
"mma_matrix") {
173 shape, elementType, operand);
182 .Case<AsyncTokenType>([&](
Type) { os <<
"async.token"; })
186 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
189 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
191 .Default([](
Type) { llvm_unreachable(
"unexpected 'gpu' type kind"); });
197 attr.
getName() != getContainerModuleAttrName())
200 auto module = dyn_cast<ModuleOp>(op);
203 << getContainerModuleAttrName() <<
"' attribute to be attached to '" 204 << ModuleOp::getOperationName() <<
'\'';
206 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) ->
WalkResult {
209 if (!launchOp->getParentOp() ||
210 launchOp->getParentOp()->getParentOp() != module)
215 if (!launchOp->getAttrOfType<SymbolRefAttr>(
216 LaunchFuncOp::getKernelAttrName()))
220 StringAttr kernelModuleName = launchOp.getKernelModuleName();
221 auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
223 return launchOp.emitOpError()
224 <<
"kernel module '" << kernelModuleName.getValue()
228 Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
231 << launchOp.kernel() <<
"' is undefined";
232 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
233 if (!kernelConvertedFunction) {
235 <<
"referenced kernel '" << launchOp.kernel()
236 <<
"' is not a function";
242 GPUDialect::getKernelFuncAttrName()))
243 return launchOp.emitOpError(
"kernel function is missing the '")
244 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
249 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
250 if (!kernelGPUFunction)
253 unsigned actualNumArguments = launchOp.getNumKernelOperands();
254 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
255 if (expectedNumArguments != actualNumArguments)
256 return launchOp.emitOpError(
"got ")
257 << actualNumArguments <<
" kernel operands but expected " 258 << expectedNumArguments;
260 auto functionType = kernelGPUFunction.getFunctionType();
261 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
262 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
263 return launchOp.emitOpError(
"type of function argument ")
264 << i <<
" does not match";
284 return parser.
emitError(loc,
"needs to be named when marked 'async'");
299 if (asyncDependencies.empty())
304 llvm::interleaveComma(asyncDependencies, printer);
313 if (body().empty() != op().has_value())
314 return emitError(
"expected either an op attribute or a non-empty body");
315 if (!body().empty()) {
316 if (body().getNumArguments() != 2)
317 return emitError(
"expected two region arguments");
318 for (
auto argument : body().getArguments()) {
319 if (argument.getType() != getType())
320 return emitError(
"incorrect region argument type");
322 unsigned yieldCount = 0;
323 for (
Block &block : body()) {
324 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
325 if (yield.getNumOperands() != 1)
326 return emitError(
"expected one gpu.yield operand");
327 if (yield.getOperand(0).getType() != getType())
328 return emitError(
"incorrect gpu.yield type");
333 return emitError(
"expected gpu.yield op in region");
335 gpu::AllReduceOperation opName = *op();
336 if ((opName == gpu::AllReduceOperation::AND ||
337 opName == gpu::AllReduceOperation::OR ||
338 opName == gpu::AllReduceOperation::XOR) &&
339 !getType().isa<IntegerType>()) {
341 <<
'`' << gpu::stringifyAllReduceOperation(opName)
342 <<
"` accumulator is only compatible with Integer type";
350 AllReduceOperationAttr &attr) {
356 attr = AllReduceOperationAttr::get(parser.
getContext(), *op);
362 AllReduceOperationAttr attr) {
373 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
377 auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
395 Value dynamicSharedMemorySize,
Type asyncTokenType,
403 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
404 if (dynamicSharedMemorySize)
412 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
416 segmentSizes.front() = asyncDependencies.size();
417 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
423 assert(!body().empty() &&
"LaunchOp body must not be empty.");
424 auto args = body().getArguments();
429 assert(!body().empty() &&
"LaunchOp body must not be empty.");
430 auto args = body().getArguments();
435 assert(!body().empty() &&
"LaunchOp body must not be empty.");
436 auto args = body().getArguments();
441 assert(!body().empty() &&
"LaunchOp body must not be empty.");
442 auto args = body().getArguments();
443 return KernelDim3{args[9], args[10], args[11]};
446 KernelDim3 LaunchOp::getGridSizeOperandValues() {
447 auto operands = getOperands().drop_front(asyncDependencies().size());
448 return KernelDim3{operands[0], operands[1], operands[2]};
451 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
452 auto operands = getOperands().drop_front(asyncDependencies().size());
453 return KernelDim3{operands[3], operands[4], operands[5]};
460 if (!body().empty()) {
461 if (body().getNumArguments() !=
462 LaunchOp::kNumConfigOperands + getNumOperands() -
463 (dynamicSharedMemorySize() ? 1 : 0) - asyncDependencies().size())
464 return emitOpError(
"unexpected number of region arguments");
469 for (
Block &block : body()) {
472 if (block.back().getNumSuccessors() != 0)
474 if (!isa<gpu::TerminatorOp>(&block.back())) {
477 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
478 "' or a terminator with successors")
479 .attachNote(getLoc())
480 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
484 if (getNumResults() == 0 && asyncToken())
485 return emitOpError(
"needs to be named when async keyword is specified");
496 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
497 p << size.
x <<
" = " << operands.
x <<
", ";
498 p << size.
y <<
" = " << operands.
y <<
", ";
499 p << size.
z <<
" = " << operands.
z <<
')';
505 if (!asyncDependencies().empty())
506 p <<
" [" << asyncDependencies() <<
']';
509 p <<
' ' << getBlocksKeyword();
512 p <<
' ' << getThreadsKeyword();
515 if (dynamicSharedMemorySize())
516 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' ' 517 << dynamicSharedMemorySize();
522 LaunchOp::getOperandSegmentSizeAttr()});
536 assert(indices.size() == 3 &&
"space for three indices expected");
542 std::move(args.begin(), args.end(), indices.begin());
544 for (
int i = 0; i < 3; ++i) {
564 sizes(LaunchOp::kNumConfigOperands);
572 LaunchOp::kNumConfigRegionAttributes);
584 result.
types.push_back(asyncTokenType);
591 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
593 regionArgsRef.slice(6, 3),
594 regionArgsRef.slice(0, 3)) ||
595 parser.
parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
597 regionArgsRef.slice(9, 3),
598 regionArgsRef.slice(3, 3)) ||
604 bool hasDynamicSharedMemorySize =
false;
606 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
607 hasDynamicSharedMemorySize =
true;
620 LaunchOp::kNumConfigRegionAttributes, index);
623 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
625 arg.
ssaName = std::get<0>(ssaValueAndType);
626 arg.
type = std::get<1>(ssaValueAndType);
627 regionArguments.push_back(arg);
636 segmentSizes.front() = asyncDependencies.size();
637 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
638 result.
addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
652 bool simplified =
false;
653 auto constPropIdUses = [&](
Value id,
Value size) {
664 id.replaceAllUsesWith(zero);
667 constPropIdUses(op.getBlockIds().x, op.gridSizeX());
668 constPropIdUses(op.getBlockIds().y, op.gridSizeY());
669 constPropIdUses(op.getBlockIds().z, op.gridSizeZ());
670 constPropIdUses(op.getThreadIds().x, op.blockSizeX());
671 constPropIdUses(op.getThreadIds().y, op.blockSizeY());
672 constPropIdUses(op.getThreadIds().z, op.blockSizeZ());
698 blockSize.
y, blockSize.
z});
699 if (dynamicSharedMemorySize)
702 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
704 SymbolRefAttr::get(kernelModule.getNameAttr(),
705 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
708 segmentSizes.front() = asyncDependencies.size();
709 segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0;
710 segmentSizes.back() =
static_cast<int32_t
>(kernelOperands.size());
715 StringAttr LaunchFuncOp::getKernelModuleName() {
716 return kernel().getRootReference();
/// Returns the leaf reference of the symbol reference produced by kernel().
719 StringAttr LaunchFuncOp::getKernelName() {
return kernel().getLeafReference(); }
/// Returns the number of entries in operands().
721 unsigned LaunchFuncOp::getNumKernelOperands() {
return operands().size(); }
/// Returns the `i`-th entry of operands(). The index is used directly; this
/// function itself performs no range check on `i`.
723 Value LaunchFuncOp::getKernelOperand(
unsigned i) {
return operands()[i]; }
725 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
726 auto operands = getOperands().drop_front(asyncDependencies().size());
727 return KernelDim3{operands[0], operands[1], operands[2]};
730 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
731 auto operands = getOperands().drop_front(asyncDependencies().size());
732 return KernelDim3{operands[3], operands[4], operands[5]};
736 auto module = (*this)->getParentOfType<ModuleOp>();
738 return emitOpError(
"expected to belong to a module");
740 if (!module->getAttrOfType<UnitAttr>(
741 GPUDialect::getContainerModuleAttrName()))
742 return emitOpError(
"expected the closest surrounding module to have the '" +
743 GPUDialect::getContainerModuleAttrName() +
746 auto kernelAttr = (*this)->getAttrOfType<SymbolRefAttr>(getKernelAttrName());
748 return emitOpError(
"symbol reference attribute '" + getKernelAttrName() +
749 "' must be specified");
765 for (
auto &arg : args) {
766 argNames.push_back(arg.ssaName);
767 argTypes.push_back(arg.type);
774 if (operands.empty())
777 llvm::interleaveComma(llvm::zip(operands, types), printer,
778 [&](
const auto &pair) {
791 int32_t offset, int32_t width, ShuffleMode mode) {
792 build(builder, result, value,
807 auto attrName = getNumWorkgroupAttributionsAttrName();
808 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
809 (*this)->setAttr(attrName,
810 IntegerAttr::get(attr.getType(), attr.
getValue() + 1));
811 return getBody().insertArgument(
820 return getBody().addArgument(type, loc);
824 StringRef name, FunctionType type,
831 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
838 for (
Type argTy : type.getInputs())
840 for (
Type argTy : workgroupAttributions)
842 for (
Type argTy : privateAttributions)
885 parser,
false, entryArgs, isVariadic, resultTypes,
889 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
890 return parser.
emitError(signatureLocation)
891 <<
"gpu.func requires named arguments";
898 for (
auto &arg : entryArgs)
899 argTypes.push_back(arg.type);
913 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
914 result.
addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
924 result.
addAttribute(GPUDialect::getKernelFuncAttrName(),
942 p <<
' ' << keyword <<
'(';
943 llvm::interleaveComma(
960 p <<
' ' << getKernelKeyword();
963 p, *
this, type.getNumInputs(), type.getNumResults(),
964 {getNumWorkgroupAttributionsAttrName(),
965 GPUDialect::getKernelFuncAttrName()});
971 Type type = getFunctionTypeAttr().getValue();
972 if (!type.
isa<FunctionType>())
974 "' attribute of function type");
977 return emitOpError() <<
"expected void return type for kernel function";
984 unsigned memorySpace) {
985 for (
Value v : attributions) {
986 auto type = v.getType().dyn_cast<MemRefType>();
988 return op->
emitOpError() <<
"expected memref type in attribution";
990 if (type.getMemorySpaceAsInt() != memorySpace) {
992 <<
"expected memory space " << memorySpace <<
" in attribution";
1000 unsigned numFuncArguments = getNumArguments();
1001 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1002 unsigned numBlockArguments = front().getNumArguments();
1003 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1004 return emitOpError() <<
"expected at least " 1005 << numFuncArguments + numWorkgroupAttributions
1006 <<
" arguments to body region";
1009 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1010 Type blockArgType = front().getArgument(i).getType();
1011 if (funcArgTypes[i] != blockArgType)
1012 return emitOpError() <<
"expected body region argument #" << i
1013 <<
" to be of type " << funcArgTypes[i] <<
", got " 1018 GPUDialect::getWorkgroupAddressSpace())) ||
1020 GPUDialect::getPrivateAddressSpace())))
1031 GPUFuncOp
function = (*this)->getParentOfType<GPUFuncOp>();
1033 FunctionType funType =
function.getFunctionType();
1035 if (funType.getNumResults() != operands().size())
1036 return emitOpError()
1037 .append(
"expected ", funType.getNumResults(),
" result operands")
1038 .attachNote(
function.getLoc())
1039 .append(
"return type declared here");
1045 std::tie(type, operand) = pair.value();
1046 if (type != operand.
getType())
1047 return emitOpError() <<
"unexpected type `" << operand.
getType()
1048 <<
"' for operand #" << pair.index();
1065 StringAttr nameAttr;
1097 auto srcType = src().getType();
1098 auto dstType = dst().getType();
1101 return emitOpError(
"arguments have incompatible element type");
1104 return emitOpError(
"arguments have incompatible shape");
1118 Value dest = op.dst();
1123 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1128 return user != op &&
1129 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1135 if (op.asyncDependencies().size() > 1 ||
1136 ((op.asyncDependencies().empty() && op.asyncToken()) ||
1137 (!op.asyncDependencies().empty() && !op.asyncToken())))
1139 rewriter.
replaceOp(op, op.asyncDependencies());
1148 results.
add<EraseTrivialCopyOp>(context);
1163 return strides.back() == 1;
1167 auto srcType = srcMemref().getType();
1168 auto resType = res().
getType();
1171 auto srcMemrefType = srcType.cast<MemRefType>();
1172 auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
1176 "expected source memref most minor dim must have unit stride");
1181 "source memorySpace kGenericMemorySpace, kSharedMemorySpace or " 1182 "kGlobalMemorySpace only allowed");
1184 if (!operand.equals(
"AOp") && !operand.equals(
"BOp") &&
1185 !operand.equals(
"COp"))
1186 return emitError(
"only AOp, BOp and COp can be loaded");
1196 auto srcType = src().getType();
1197 auto dstType = dstMemref().getType();
1199 auto dstMemrefType = dstType.cast<MemRefType>();
1200 auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
1204 "expected destination memref most minor dim must have unit stride");
1208 return emitError(
"destination memorySpace of kGenericMemorySpace, " 1209 "kGlobalMemorySpace or kSharedMemorySpace only allowed");
1211 if (!srcMatrixType.getOperand().equals(
"COp"))
1213 "expected the operand matrix being stored to have 'COp' operand type");
1223 enum OperandMap { A, B, C };
1225 opTypes.push_back(opA().getType().cast<MMAMatrixType>());
1226 opTypes.push_back(opB().getType().cast<MMAMatrixType>());
1227 opTypes.push_back(opC().getType().cast<MMAMatrixType>());
1229 if (!opTypes[A].
getOperand().equals(
"AOp") ||
1232 return emitError(
"operands must be in the order AOp, BOp, COp");
1235 aShape = opTypes[A].getShape();
1236 bShape = opTypes[B].getShape();
1237 cShape = opTypes[C].getShape();
1239 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1240 bShape[1] != cShape[1])
1241 return emitError(
"operand shapes do not satisfy matmul constraints");
1250 bool folded =
false;
1252 auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>();
1254 operand.set(cast.getOperand());
1290 if (llvm::none_of(op.asyncDependencies(), predicate))
1293 for (
Value operand : op->getOperands()) {
1294 if (predicate(operand))
1296 validOperands.push_back(operand);
1298 op->setOperands(validOperands);
1318 if (op.asyncDependencies().empty() && !op.asyncToken()) {
1323 if (llvm::hasSingleElement(op.asyncDependencies()) && op.asyncToken()) {
1324 rewriter.
replaceOp(op, op.asyncDependencies());
1328 if (op.asyncToken() && op.asyncToken().use_empty()) {
1340 results.
add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
1348 auto memRefType = memref().getType().cast<MemRefType>();
1350 if (static_cast<int64_t>(dynamicSizes().size()) !=
1351 memRefType.getNumDynamicDims())
1352 return emitOpError(
"dimension operand count does not equal memref " 1353 "dynamic dimension count");
1355 unsigned numSymbols = 0;
1356 if (!memRefType.getLayout().isIdentity())
1357 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
1358 if (symbolOperands().size() != numSymbols) {
1360 "symbol operand count does not equal memref symbol count");
1379 auto memrefType = dimOp.source().getType().dyn_cast<MemRefType>();
1380 if (!memrefType || !memrefType.isDynamicDim(index.value()))
1383 auto alloc = dimOp.source().getDefiningOp<AllocOp>();
1387 Value substituteOp = *(alloc.dynamicSizes().begin() +
1388 memrefType.getDynamicDimIndex(index.value()));
1389 rewriter.
replaceOp(dimOp, substituteOp);
1398 results.
add<SimplifyDimOfAllocOp>(context);
1401 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc" 1402 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc" 1404 #define GET_ATTRDEF_CLASSES 1405 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc" 1407 #define GET_OP_CLASSES 1408 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc" virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
void printFunctionSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
static std::string diag(llvm::Value &v)
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
Operation is a basic unit of execution within MLIR.
Attribute getValue() const
Return the value of the attribute.
void addAsyncDependency(Operation *op, Value token)
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations...
BlockListType & getBlocks()
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
virtual void printType(Type type)
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
This class represents a diagnostic that is inflight and set to be reported.
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
Block represents an ordered list of Operations.
Generic memory space identifier.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
AttrClass getAttrOfType(StringAttr name)
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
void push_back(Block *block)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
unsigned getNumOperands()
Shared memory space identifier.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
static bool isLastMemrefDimUnitStride(MemRefType type)
Return true if the last dimension of the MemRefType has unit stride.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
This is the representation of an operand reference.
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, unsigned memorySpace)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
NamedAttribute getNamedAttr(StringRef name, Attribute val)
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
user_range getUsers() const
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseComma()=0
Parse a , token.
static constexpr const bool value
UnresolvedOperand ssaName
ParseResult parseSymbolName(StringAttr &result, StringRef attrName, NamedAttrList &attrs)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute named 'attrName'...
static ConcreteT get(MLIRContext *ctx, Args... args)
Get or create a new ConcreteT instance within the ctx.
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
NamedAttribute represents a combination of a name and an Attribute value.
MLIRContext * getContext()
Return the context this operation is associated with.
MutableArrayRef< OpOperand > getOpOperands()
IntegerAttr getI32IntegerAttr(int32_t value)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseGreater()=0
Parse a '>' token.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void addOperands(ValueRange newOperands)
Type getFunctionType(Builder &builder, ArrayRef< OpAsmParser::Argument > argAttrs, ArrayRef< Type > resultTypes)
Get a function type corresponding to an array of arguments (which have types) and a set of result typ...
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
IntegerAttr getI64IntegerAttr(int64_t value)
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DialectInlinerInterface(Dialect *dialect)
StringAttr getName() const
Return the name of the attribute.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
virtual ParseResult parseRParen()=0
Parse a ) token.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
This is the interface that must be implemented by the dialects of operations to be inlined...
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
This class provides an abstraction over the various different ranges of value types.
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
virtual ParseResult parseLess()=0
Parse a '<' token.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Location getLoc()
The source location the operation was defined or derived from.
This represents an operation in an abstracted form, suitable for use with the builder APIs...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
Parens surrounding zero or more operands.
Utility class for the GPU dialect to represent triples of Values accessible through ...
A utility result that is used to signal how to proceed with an ongoing walk:
This class represents an argument of a Block.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
This base class exposes generic asm parser hooks, usable across the various derived parsers...
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static StringRef getOperandSegmentSizeAttr()
static LogicalResult foldMemRefCast(Operation *op)
This is a common class used for patterns of the form "someop(memrefcast) -> someop".
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Type getElementType() const
Get elementType of a single element.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Region * addRegion()
Create a region that should be attached to the operation.
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Type getType() const
Return the type of this value.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ImplType * getImpl() const
Utility for easy access to the storage instance.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Specialization of arith.constant op that returns an integer of index type.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MLIRContext is the top-level object for a collection of MLIR operations.
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
This class represents an operand of an operation.
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
This class implements the operand iterators for the Operation class.
Global memory space identifier.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
unsigned getNumDims() const
Get number of dims.
This base class exposes generic asm printer hooks, usable across the various derived printers...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
GPUMemorySpace
GPU memory space identifiers.
virtual ParseResult parseEqual()=0
Parse a = token.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Square brackets supporting zero or more ops, or nothing.
MLIRContext * getContext() const
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class helps build Operations.
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one...
This class provides an abstraction over the different types of ranges over Values.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
StringAttr getStringAttr(const Twine &bytes)
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.
static ConcreteT getChecked(const Location &loc, Args... args)
Get or create a new ConcreteT instance within the ctx, defined at the given, potentially unknown...