#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}
                          elementType, operand);

  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);

  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}
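// Illustrative module layout (a sketch, not from the original source; the
// exact textual syntax may vary across MLIR versions). It shows the unit
// `kernel` attribute checked by isKernel() and the `gpu.container_module`
// layout that verifyOperationAttribute() below walks over:
//
//   module attributes {gpu.container_module} {
//     gpu.module @kernels {
//       gpu.func @kernel(%arg0: f32) kernel { gpu.return }
//     }
//     func.func @main(%f: f32) {
//       %c1 = arith.constant 1 : index
//       gpu.launch_func @kernels::@kernel
//           blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
//           args(%f : f32)
//       return
//     }
//   }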
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
}
static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
}
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {

                                     shape, elementType, operand);
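// Illustrative textual form of the type parsed above and printed below (a
// sketch based on the GPU dialect documentation, not part of the original
// source):
//
//   !gpu.mma_matrix<16x16xf16, "AOp">
//
// i.e. a two-dimensional shape, an element type, and one of the "AOp"/"BOp"/
// "COp" operand tags that MMAMatrixType::verify checks.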
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })

        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';

        os << ", \"" << fragTy.getOperand() << "\"" << '>';

      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op, there is nothing else to verify.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // If the kernel function has already been converted to another dialect,
    // skip the argument type check.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
    return parser.emitError(loc, "needs to be named when marked 'async'");

  if (asyncDependencies.empty())
    return;

  llvm::interleaveComma(asyncDependencies, printer);
  p << ' ' << keyword << '(';
  llvm::interleaveComma(
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                  Type resType) {
  return (opName != gpu::AllReduceOperation::AND &&
          opName != gpu::AllReduceOperation::OR &&
          opName != gpu::AllReduceOperation::XOR) ||
         llvm::isa<IntegerType>(resType);
}
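// Illustrative forms of the reduction ops this helper constrains (a sketch
// based on the dialect documentation, not part of the original source):
//
//   %sum = gpu.all_reduce add %operand {} : (f32) -> (f32)
//
//   %acc = gpu.all_reduce %operand {
//   ^bb(%lhs : f32, %rhs : f32):
//     %r = arith.addf %lhs, %rhs : f32
//     "gpu.yield"(%r) : (f32) -> ()
//   } : (f32) -> (f32)
//
// The and/or/xor operations only make sense for integer accumulators, which
// is exactly what verifyReduceOpAndType enforces.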
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (!verifyReduceOpAndType(opName, getType()))
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` accumulator is only compatible with Integer type";
  }
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {

  std::optional<AllReduceOperation> op =
      gpu::symbolizeAllReduceOperation(enumStr);

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {

  gpu::AllReduceOperation opName = getOp();
  if (!verifyReduceOpAndType(opName, getType()))
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` accumulator is only compatible with Integer type";
OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {

  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())

  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,

  result.addAttribute(getNumWorkgroupAttributionsAttrName(),

  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)

  for (Type argTy : workgroupAttributions)

  for (Type argTy : privateAttributions)

  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();

  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();

  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();

  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
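// Illustrative printed form produced by this helper and LaunchOp::print below
// (a sketch based on the dialect documentation, not from the original source):
//
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c8, %gy = %c8, %gz = %c1)
//              threads(%tx, %ty, %tz) in (%sx = %c32, %sy = %c1, %sz = %c1) {
//     // The 12 leading region arguments are the block/thread ids and the
//     // grid/block sizes; the body region ends with gpu.terminator.
//     gpu.terminator
//   }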
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }

  p << ' ' << getBlocksKeyword();

  p << ' ' << getThreadsKeyword();

  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
  assert(indices.size() == 3 && "space for three indices expected");

  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
      sizes(LaunchOp::kNumConfigOperands);

      LaunchOp::kNumConfigRegionAttributes);

    result.types.push_back(asyncTokenType);

  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||

  bool hasDynamicSharedMemorySize = false;
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;

                       LaunchOp::kNumConfigRegionAttributes, index);

  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  unsigned numWorkgroupAttrs =
      regionArguments.size() - LaunchOp::kNumConfigRegionAttributes;
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),

  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {

      if (id.getUses().empty())
        return;

          rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);

    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::kNumConfigRegionAttributes + attr.getInt(), type, loc);

  return getBody().addArgument(type, loc);
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});

  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  printer << ": " << dimTy;

                                     parseElement, " in argument list");
  if (operands.empty())
    return;
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                       int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);

  return getBody().addArgument(type, loc);

                       StringRef name, FunctionType type,

  result.addAttribute(getNumWorkgroupAttributionsAttrName(),

  for (Type argTy : type.getInputs())

  for (Type argTy : workgroupAttributions)

  for (Type argTy : privateAttributions)

  body->getBlocks().push_back(entryBlock);
  size_t existingArgs = args.size();

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),

    attributionAttrs = nullptr;

  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)

    attributionAttrsVec.push_back(argument.attrs);

  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  StringAttr nameAttr;

          parser, false, entryArgs, isVariadic, resultTypes,

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);

      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

                         entryArgs, workgroupAttributionAttrs)))
    return failure();

  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),

  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

                         entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                               ArrayAttr attributes) {

  p << ' ' << keyword << '(';
  llvm::interleaveComma(

        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);

  FunctionType type = getFunctionType();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));

  if (isKernel())
    p << ' ' << getKernelKeyword();

      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}

static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));

    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)

    elements[index] = value;

  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}

static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {

  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}

static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {

  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {

      std::swap(elems[i], elems[elems.size() - 1]);

  elems.emplace_back(name, value);

  DictionaryAttr::sortInPlace(elems);

  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);

void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return setAttributionAttr(*this, index, name, value,
                            getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return setAttributionAttr(*this, index, name, value,
                            getPrivateAttribAttrsAttrName());
}
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();
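// Illustrative gpu.func with memory attributions (a sketch, not from the
// original source; attribute and address-space spellings depend on the MLIR
// version):
//
//   gpu.func @kernel(%arg0 : memref<?xf32>)
//       workgroup(%wg : memref<32xf32, #gpu.address_space<workgroup>>)
//       private(%pv : memref<1xf32, #gpu.address_space<private>>) kernel {
//     gpu.return
//   }
//
// The workgroup/private attributions are extra entry-block arguments appended
// after the function arguments, which is what the body checks above enforce.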
static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
                                               StringRef attrName) {
  auto maybeAttr = op->getAttr(attrName);
  if (!maybeAttr)
    return success();
  auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr);
  if (!array)
    return op.emitOpError(attrName + " must be a dense i32 array");
  if (array.size() != 3)
    return op.emitOpError(attrName + " must contain exactly 3 elements");
  return success();
}
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
                        StringRef name, ArrayAttr targets) {

  build(builder, result, name,
        targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets));

  StringAttr nameAttr;
  ArrayAttr targetsAttr;

  if (failed(*targetsAttrResult)) {

  if (Attribute attr = getTargetsAttr()) {

      (*this)->getAttrs(),
      {mlir::SymbolTable::getSymbolAttrName(), getTargetsAttrName()});

bool GPUModuleOp::hasTarget(Attribute target) {
  if (ArrayAttr targets = getTargetsAttr())
    return llvm::count(targets.getValue(), target);
  return false;
}

  ArrayAttr &targetsAttr = getProperties().targets;
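// Illustrative gpu.module carrying target attributes (a sketch; the target
// attribute shown is a hypothetical example):
//
//   gpu.module @kernels [#nvvm.target<chip = "sm_80">] {
//     gpu.func @kernel() kernel { gpu.return }
//   }
//
// hasTarget() above returns true when the queried attribute appears in the
// targets array.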
                     Attribute offloadingHandler, ArrayAttr objects) {

  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);

  build(builder, result, name, offloadingHandler,
        objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));

  if (!offloadingHandler)

  printer << '<' << offloadingHandler << '>';
//===----------------------------------------------------------------------===//
// GPU_MemcpyOp
//===----------------------------------------------------------------------===//

LogicalResult MemcpyOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}
namespace {

struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having Allocate memory effect in order to
    // perform the folding.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its
    // use by `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // We can perform the folding if and only if op has a single async
    // dependency and produces an async token as result, or if it does not have
    // any async dependency and does not produce any async token result.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

} // end anonymous namespace

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
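// Illustrative effect of EraseTrivialCopyOp (a sketch, not from the original
// source; op spellings follow the dialect documentation): when %dst is only
// allocated, written by the copy, and deallocated, the copy is dead, e.g.
//
//   %dst = gpu.alloc () : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>   // erased
//   gpu.dealloc %dst : memref<16xf32>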
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaLoadMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!isLastMemrefDimUnitStride(srcMemrefType))
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaStoreMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!isLastMemrefDimUnitStride(dstMemrefType))
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (!srcMatrixType.getOperand().equals("COp"))
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaComputeOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (!opTypes[A].getOperand().equals("AOp") ||
      !opTypes[B].getOperand().equals("BOp") ||
      !opTypes[C].getOperand().equals("COp"))
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}
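// Worked instance of the shape check above (illustrative, not from the
// original source): with
//   A : !gpu.mma_matrix<16x8xf16, "AOp">   // M x K
//   B : !gpu.mma_matrix<8x32xf16, "BOp">   // K x N
//   C : !gpu.mma_matrix<16x32xf16, "COp">  // M x N
// aShape[1] == bShape[0] (8), aShape[0] == cShape[0] (16), and
// bShape[1] == cShape[1] (32), so the matmul constraints are satisfied.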
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// GPU_WaitOp
//===----------------------------------------------------------------------===//
namespace {

struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.getAsyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    rewriter.updateRootInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};

struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have any async dependencies nor return
    // any async token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
    if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, op.getAsyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops, where %t has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

} // end anonymous namespace

void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
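// Illustrative folds performed by the two patterns above (a sketch; the exact
// textual forms of gpu.wait depend on the assembly format):
//
//   %t1 = gpu.wait async [%t0]   // uses of %t1 are replaced by %t0, op erased
//   gpu.wait                     // no dependencies and no token: erased
//   %t2 = gpu.wait async         // %t2 unused: erased
//
// and dependencies on gpu.wait ops that themselves have no operands are
// dropped from the async dependency list.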
//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//

LogicalResult AllocOp::verify() {
  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

  if (static_cast<int64_t>(getDynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (getSymbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}
namespace {

struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
    if (!memrefType || !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
//===----------------------------------------------------------------------===//
// GPU object attribute
//===----------------------------------------------------------------------===//

LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Attribute target, CompilationTarget format,
                                 StringAttr object, DictionaryAttr properties) {
  if (!target)
    return emitError() << "the target attribute cannot be null";
  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
    return success();
  return emitError() << "the target attribute must implement or promise the "
                        "`gpu::TargetAttrInterface`";
}
LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                          StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}

void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
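// Illustrative attribute and op forms (a sketch; the embedded target attribute
// and object string are placeholders, not from the original source):
//
//   #gpu.object<#nvvm.target, "BLOB">                  // default fatbin format
//   gpu.binary @program [#gpu.object<#nvvm.target, "BLOB">]
//
// printObject above only emits the "<format> = " prefix when the format is not
// the default fatbin.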
//===----------------------------------------------------------------------===//
// GPU select object attribute
//===----------------------------------------------------------------------===//

LogicalResult
gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              Attribute target) {
  // Check `target`, it can be null, an integer attr or a GPU Target attribute.
  if (target) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      if (intAttr.getInt() < 0) {
        return emitError() << "the object index must be positive";
      }
    } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
      return emitError()
             << "the target attribute must be a GPU Target attribute";
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//

TargetOptions::TargetOptions(
    StringRef toolkitPath, ArrayRef<std::string> linkFiles,
    StringRef cmdOptions, CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback)
    : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
                    cmdOptions, compilationTarget, getSymbolTableCallback) {}

TargetOptions::TargetOptions(
    TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
    StringRef cmdOptions, CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback)
    : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
      cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
      getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
TypeID TargetOptions::getTypeID() const { return typeID; }

StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }

ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }

StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }

SymbolTable *TargetOptions::getSymbolTable() const {
  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}

CompilationTarget TargetOptions::getCompilationTarget() const {
  return compilationTarget;
}

CompilationTarget TargetOptions::getDefaultCompilationTarget() {
  return CompilationTarget::Fatbin;
}
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions() const {
  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
  llvm::StringSaver stringSaver(options.first);
  StringRef opts = cmdOptions;
  // For a correct tokenization of the command line options, `opts` must be
  // unquoted; otherwise the tokenization function returns a single string (the
  // quoted `cmdOptions`), which is not the desired behavior.
  // Remove any quotes if they are at the beginning and end of the string:
  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
    opts.consume_front("\""), opts.consume_back("\"");
  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
    opts.consume_front("'"), opts.consume_back("'");
#ifdef _WIN32
  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
                                       /*MarkEOLs=*/false);
#else
  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
                                   /*MarkEOLs=*/false);
#endif // _WIN32
  return options;
}
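// Illustrative use of tokenizeCmdOptions (a hypothetical sketch, not from the
// original source; constructor arguments beyond cmdOptions use their defaults):
//
//   gpu::TargetOptions opts(/*toolkitPath=*/"", /*linkFiles=*/{},
//                           /*cmdOptions=*/"\"-O3 --fast-math\"");
//   auto [allocator, args] = opts.tokenizeCmdOptions();
//   // args now holds {"-O3", "--fast-math"}; the surrounding quotes were
//   // stripped before tokenization.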
#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"