#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
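// Illustration (a sketch, not part of the verifier logic): with the shared
// MappingId enum, a grid mapping such as #gpu.block<y> has mapping id 1 and is
// not linear, while #gpu.thread<linear_dim_2> is linear and has relative
// index 2, i.e. its offset from MappingId::LinearDim0.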
int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }

/// The logical linear id of an active physical id is its rank among the set
/// mask bits at or below its position, computed with a shift/subtract/popcount
/// sequence.
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}

/// A physical id is active iff its bit is set in the mask.
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}
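// Sketch of the IR emitted by the two builders above (SSA names are
// illustrative):
//   %mask   = arith.constant <mask> : i64
//   %one    = arith.constant 1 : i64
//   %shl    = arith.shli %one, %physical_id : i64
//   %below  = arith.subi %shl, %one : i64     // ones strictly below the bit
//   %bits   = arith.andi %mask, %below : i64
//   %logical_id = math.ctpop %bits : i64      // rank among active ids
// The predicate instead ANDs the mask with %shl itself and compares the
// result against zero with arith.cmpi ne.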
int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
         elementType.isSignedInteger(8) || elementType.isUnsignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";

  return success();
}
bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining for gpu dialect
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}
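// The sparse handle keywords returned by getSparseHandleKeyword() below
// surface in the IR with the dialect prefix, e.g. !gpu.sparse.spmat_handle.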
static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
}
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SmallVector<int64_t> shape;
    Type elementType;
    std::string operand;
    // ... parse `<` shape `x` element-type `,` "operand" `>` and build the
    // diagnostic callback `emitError` at the current location ...
    return MMAMatrixType::getChecked(emitError, shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName().getValue()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName().getValue()) +
                           " must contain exactly 3 elements");
  return success();
}
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult =
      module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
        // Ignore launches nested more or less deeply than functions in the
        // module we are currently checking.
        if (!launchOp->getParentOp() ||
            launchOp->getParentOp()->getParentOp() != module)
          return success();

        // Ignore launch ops with a missing kernel attribute; their own
        // verifier will report the error.
        if (!launchOp->getAttrOfType<SymbolRefAttr>(
                LaunchFuncOp::getKernelAttrName(launchOp->getName())))
          return success();

        // Check that `launch_func` refers to a well-formed GPU kernel
        // container.
        StringAttr kernelContainerName = launchOp.getKernelModuleName();
        Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
        if (!kernelContainer)
          return launchOp.emitOpError()
                 << "kernel container '" << kernelContainerName.getValue()
                 << "' is undefined";

        // If the container is a GPU binary op, there is nothing more to check.
        if (isa<BinaryOp>(kernelContainer))
          return success();

        auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
        if (!kernelModule)
          return launchOp.emitOpError()
                 << "kernel module '" << kernelContainerName.getValue()
                 << "' is undefined";

        // Check that `launch_func` refers to a well-formed kernel function.
        Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
        if (!kernelFunc)
          return launchOp.emitOpError("kernel function '")
                 << launchOp.getKernel() << "' is undefined";
        auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
        if (!kernelConvertedFunction) {
          InFlightDiagnostic diag = launchOp.emitOpError()
                                    << "referenced kernel '"
                                    << launchOp.getKernel()
                                    << "' is not a function";
          diag.attachNote(kernelFunc->getLoc())
              << "see the kernel definition here";
          return diag;
        }

        if (!kernelFunc->getAttrOfType<UnitAttr>(
                GPUDialect::getKernelFuncAttrName()))
          return launchOp.emitOpError("kernel function is missing the '")
                 << GPUDialect::getKernelFuncAttrName() << "' attribute";

        // If the kernel is not a GPU function (which happens during separate
        // compilation), do not check the type correspondence.
        auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
        if (!kernelGPUFunction)
          return success();

        unsigned actualNumArguments = launchOp.getNumKernelOperands();
        unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
        if (expectedNumArguments != actualNumArguments)
          return launchOp.emitOpError("got ")
                 << actualNumArguments << " kernel operands but expected "
                 << expectedNumArguments;

        auto functionType = kernelGPUFunction.getFunctionType();
        for (unsigned i = 0; i < expectedNumArguments; ++i) {
          if (launchOp.getKernelOperand(i).getType() !=
              functionType.getInput(i)) {
            return launchOp.emitOpError("type of function argument ")
                   << i << " does not match";
          }
        }

        return success();
      });

  return walkResult.wasInterrupted() ? failure() : success();
}
/// Parses an optional list of async operands with an optional leading keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with its leading keyword.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it has not already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}
LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }
  return success();
}
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only ops in the entry block of the launch can be made uniform.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  // Mark the reduction uniform when it is provably executed by all ids.
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }
  return nullptr;
}
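// For reference, both forms accepted by the verifier above (sketch):
//   %sum = gpu.all_reduce add %a {} : (f32) -> (f32)
//   %sum = gpu.all_reduce %a {
//   ^bb(%lhs : f32, %rhs : f32):
//     %s = arith.addf %lhs, %rhs : f32
//     gpu.yield %s : f32
//   } : (f32) -> (f32)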
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    printer << " " << attr.getValue();
}
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  // A reduction over a cluster of one lane is the identity.
  if (getClusterSize() == 1)
    return getValue();
  return nullptr;
}
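// Example of a clustered reduction accepted by the verifier above (a sketch;
// exact assembly of the cluster clause may differ by version):
//   %r = gpu.subgroup_reduce add %v cluster(size = 4, stride = 2)
//        : (f32) -> (f32)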
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a workgroup attribution count attribute, needed to tell workgroup and
  // private attributions apart in the region argument list.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes index
  // arguments, followed by one argument per attribution.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }

  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}
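// The custom form printed above reads like (sketch):
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
//              threads(%tx, %ty, %tz) in (%sx = %3, %sy = %4, %sz = %5) {
//     ...
//     gpu.terminator
//   }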
// Parse the size assignment blocks for blocks, threads or clusters.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices,
                    StringRef keyword) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  if (args.size() != 3) {
    return parser.emitError(parser.getNameLoc())
           << keyword << " expects 3 arguments, but got " << args.size();
  }
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0) {
    if (!asyncTokenType)
      return parser.emitError(
          parser.getNameLoc(),
          "gpu.launch requires 'async' keyword to return a value");
    result.types.push_back(asyncTokenType);
  }

  bool hasCluster = false;
  if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size/region-argument slots belong to the optional cluster
  // segment.
  if (hasCluster &&
      parseSizeAssignment(parser, sizesRef.drop_front(6),
                          regionArgsRef.slice(15, 3),
                          regionArgsRef.slice(12, 3),
                          LaunchOp::getClustersKeyword()))
    return failure();

  // The first segment assigns grid sizes and defines values for block
  // identifiers; the second segment assigns block sizes and defines values
  // for thread identifiers.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
                          LaunchOp::getBlocksKeyword()) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
      parseSizeAssignment(parser, sizesRef.slice(3, 3),
                          regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
                          LaunchOp::getThreadsKeyword()) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // Parse the optional dynamic shared memory size.
  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse the optional `module(@sym)` and `function(@sym)` clauses.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments: the leading config arguments all have
  // `index` type.
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  // ... parse optional workgroup/private attributions into regionArguments ...

  // Store the number of region arguments that were parsed as workgroup memory
  // attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      parser.getBuilder().getI64IntegerAttr(numWorkgroupAttrs));

  // Parse the body region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}
/// Simplify a gpu.launch when a dimension of the launch configuration is
/// statically known to be one: the corresponding id is then always zero.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}
StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}
static ParseResult parseLaunchDimType(
    OpAsmParser &parser, Type &dimTy,
    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    // If there is a colon, parse the custom launch dimension type.
    if (parser.parseType(dimTy))
      return failure();
  } else {
    // Otherwise, the launch dimension type defaults to `index`.
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}
static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}
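// Sketch of the IR the convenience builder above produces: the offset and
// width become materialized constants, e.g.
//   %offset = arith.constant 4 : i32
//   %width  = arith.constant 32 : i32
//   %res, %valid = gpu.shuffle xor %val, %offset, %width : f32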
LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }
  return success();
}
/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                                 PatternRewriter &rewriter) {
  auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
  if (!nextOp)
    return failure();
  std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
  std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
  // A barrier without the attribute fences all address spaces.
  if (!nextMemfence) {
    op.removeAddressSpacesAttr();
    rewriter.eraseOp(nextOp);
    return success();
  }
  if (*thisMemfence == *nextMemfence) {
    rewriter.eraseOp(nextOp);
    return success();
  }
  // Otherwise, merge the two address-space lists into this barrier.
  llvm::SmallSetVector<Attribute, 4> mergedSpaces;
  for (Attribute attr : thisMemfence->getValue())
    mergedSpaces.insert(attr);
  for (Attribute attr : nextMemfence->getValue())
    mergedSpaces.insert(attr);
  op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
  rewriter.eraseOp(nextOp);
  return success();
}
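// Example: the second barrier below is redundant and erased by the pattern
// above; barriers with different address-space lists are instead merged into
// a single barrier fencing the union of the spaces.
//   gpu.barrier
//   gpu.barrier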
void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

void BarrierOp::build(mlir::OpBuilder &odsBuilder,
                      mlir::OperationState &odsState,
                      std::optional<AddressSpace> addressSpace) {
  ArrayAttr addressSpacesAttr;
  if (addressSpace)
    addressSpacesAttr = odsBuilder.getArrayAttr(
        AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
  build(odsBuilder, odsState, addressSpacesAttr);
}

/// Fence only the address space of `memrefToFence`, when it can be determined
/// from the memref type; otherwise fence everything.
void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
                      Value memrefToFence) {
  std::optional<AddressSpace> addrSpaceToFence;
  if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
    if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
            memrefType.getMemorySpace()))
      addrSpaceToFence = addrSpaceAttr.getValue();
  return build(builder, odsState, addrSpaceToFence);
}
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If the keyword is absent, assume an empty attribution list.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  // ... parse the parenthesized argument list (with types and attributes) ...

  bool hadAttrs = llvm::any_of(
      ArrayRef(args).drop_front(existingArgs),
      [](const OpAsmParser::Argument &arg) { return arg.attrs != nullptr; });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return success();
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return success();
}
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  // ... parse the function signature (arguments, result types/attrs) into
  // entryArgs, resultTypes, and resultAttrs ...

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes and the body region.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}
void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
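// Printed form, for reference (sketch):
//   gpu.func @kernel(%arg0 : memref<?xf32>)
//       workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
//       kernel {
//     gpu.return
//   }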
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}

static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  MLIRContext *ctx = op.getContext();
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  SmallVector<Attribute> elements;
  if (allAttrs)
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
  if (!dict)
    return Attribute();
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}

static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<NamedAttribute> elems;
  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
  if (oldDict)
    elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  bool found = false;
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      found = true;
      if (!value) {
        // Erase the entry by swapping it to the end and popping it.
        std::swap(elems[i], elems[elems.size() - 1]);
        elems.pop_back();
      } else {
        mustSort = false;
        elems[i] = NamedAttribute(elems[i].getName(), value);
      }
      break;
    }
  }
  if (!found && value)
    elems.emplace_back(name, value);
  if (mustSort)
    DictionaryAttr::sortInPlace(elems);
  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
  setAttributionAttrs(op, index, newDict, attrsName);
}

void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return setAttributionAttr(*this, index, name, value,
                            getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return setAttributionAttr(*this, index, name, value,
                            getPrivateAttribAttrsAttrName());
}
LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";
  return success();
}
LogicalResult GPUFuncOp::verifyBody() {
  if (empty())
    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}
LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  if (targets)
    props.targets = targets;
  props.setSymName(builder.getStringAttr(name));
  props.offloadingHandler = offloadingHandler;
}

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayRef<Attribute> targets,
                        Attribute offloadingHandler) {
  build(builder, result, name,
        targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
        offloadingHandler);
}
bool GPUModuleOp::hasTarget(Attribute target) {
  if (ArrayAttr targets = getTargetsAttr())
    return llvm::count(targets.getValue(), target);
  return false;
}

void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
  ArrayAttr &targetsAttr = getProperties().targets;
  SmallVector<Attribute> targetsVector(targets);
  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
}

LogicalResult GPUModuleOp::verify() {
  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
  if (!targets)
    return success();

  for (auto target : targets) {
    if (auto verifyTargetAttr =
            llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
      if (verifyTargetAttr.verifyTarget(getOperation()).failed())
        return failure();
    }
  }
  return success();
}
void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayAttr objects) {
  auto &properties = result.getOrAddProperties<Properties>();
  result.attributes.push_back(builder.getNamedAttr(
      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
}

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayRef<Attribute> objects) {
  build(builder, result, name, offloadingHandler,
        objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
}
static ParseResult parseOffloadingHandler(OpAsmParser &parser,
                                          Attribute &offloadingHandler) {
  // ... parse an optional `<attr>` clause ...
  if (!offloadingHandler)
    offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
  return success();
}

static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
                                   Attribute offloadingHandler) {
  // Elide the default offloading handler.
  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
    printer << '<' << offloadingHandler << '>';
}
LogicalResult MemcpyOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}
/// Erases a common case of copy ops where a destination value is used only by
/// the copy op, alloc and dealloc ops.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having an Allocate memory effect.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its use by
    // `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // The folding only works if op either has a single async dependency and
    // produces a token, or has neither.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
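// For the SubgroupMma ops verified below, a representative tile computation
// is (a sketch, assuming 16x16x16 fragments):
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//          : !gpu.mma_matrix<16x16xf16, "AOp">,
//            !gpu.mma_matrix<16x16xf16, "BOp">
//            -> !gpu.mma_matrix<16x16xf16, "COp">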
LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!srcMemrefType.isLastDimUnitStride())
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!dstMemrefType.isLastDimUnitStride())
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (srcMatrixType.getOperand() != "COp")
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}
LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
      opTypes[C].getOperand() != "COp")
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
/// Remove a gpu.wait that depends on gpu.wait ops which themselves have no
/// dependencies: such tokens are trivially ready.
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.getAsyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};

/// Simplify trivial gpu.wait ops.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have async dependencies nor return a
    // token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace %t1 = gpu.wait async [%t0] by %t0 and erase the op.
    if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, op.getAsyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops where %t has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
LogicalResult AllocOp::verify() {
  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (getSymbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}
/// Folds memref.dim of a dynamically shaped gpu.alloc result to the matching
/// dynamic size operand.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
    if (!memrefType || index.value() >= memrefType.getRank() ||
        !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
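// Example rewrite performed by SimplifyDimOfAllocOp (sketch):
//   %m = gpu.alloc(%sz) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>   // replaced by %sz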
LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Attribute target, CompilationTarget format,
                                 StringAttr object, DictionaryAttr properties,
                                 KernelTableAttr kernels) {
  if (!target)
    return emitError() << "the target attribute cannot be null";
  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
    return success();
  return emitError() << "the target attribute must implement or promise the "
                        "`gpu::TargetAttrInterface`";
}
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}

void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
LogicalResult
gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              Attribute target) {
  // The target attribute is allowed to be null.
  if (target) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      if (intAttr.getInt() < 0) {
        return emitError() << "the object index must be positive";
      }
    } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
      return emitError()
             << "the target attribute must be a GPU Target attribute";
    }
  }
  return success();
}
LogicalResult gpu::DynamicSharedMemoryOp::verify() {
  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
    return emitOpError() << "must be inside an op with symbol table";

  MemRefType memrefType = getResultMemref().getType();
  // Check the address space.
  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
    return emitOpError() << "address space must be "
                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
                         << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
  }
  if (memrefType.hasStaticShape()) {
    return emitOpError() << "result memref type must be memref<?xi8, "
                            "#gpu.address_space<workgroup>>";
  }
  return success();
}
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;
  int64_t warpSize;
  // ... parse `(%laneid)` followed by `[warpSize]` ...
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));
  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  // Parse the optional `args(%operands : types)` clause.
  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  // ...
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // ... parse the optional result list and the region, then insert the
  // implicit terminator if it was elided ...
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
  return parser.parseOptionalAttrDict(result.attributes);
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // The warp region is always executed.
  regions.push_back(RegionSuccessor(&getWarpRegion()));
}

ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
  return getResults();
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*args=*/ValueRange(), /*blockArgTypes=*/TypeRange());
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  result.addOperands(laneId);
  result.addAttribute(getAttributeNames()[0],
                      builder.getI64IntegerAttr(warpSize));
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}
/// Helper check if the distributed vector type is consistent with the expanded
/// type and the distributed size.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match, there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (llvm::product_of(scales) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  gpu::YieldOp yield = getTerminator();
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
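// Example distribution checked above (sketch): a vector<32xf32> yielded by
// the region is distributed over a warp of 32 lanes, one element per lane:
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %v = "some.def"() : () -> vector<32xf32>
//     gpu.yield %v : vector<32xf32>
//   }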
void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}

Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    // Speculating a first_active_lane broadcast across control flow can
    // change the set of active lanes, so it is not speculatable.
    return Speculation::NotSpeculatable;
  case BroadcastType::specific_lane:
    return Speculation::Speculatable;
  }
  llvm_unreachable("Unknown BroadcastType");
}

LogicalResult gpu::SubgroupBroadcastOp::verify() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    if (getLane())
      return emitOpError()
             << "lane can only be specified for `specific_lane` broadcast";
    return success();
  case BroadcastType::specific_lane:
    if (!getLane())
      return emitOpError()
             << "lane must be specified for `specific_lane` broadcast";
    return success();
  }
  llvm_unreachable("Unknown BroadcastType");
}

OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
  // A broadcast of a broadcast folds to the earlier broadcast.
  if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
    return prev.getResult();
  return nullptr;
}
KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return get(kernel.getNameAttr(), kernel.getFunctionType(),
             kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
                    kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
  if (attrs.empty())
    return *this;
  NamedAttrList attrList;
  if (DictionaryAttr dict = getMetadata())
    attrList.append(dict);
  attrList.append(attrs);
  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
                                 attrList.getDictionary(getContext()));
}

LogicalResult
KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           StringAttr name, Type functionType,
                           ArrayAttr argAttrs, DictionaryAttr metadata) {
  if (name.empty())
    return emitError() << "the kernel name can't be empty";
  if (argAttrs) {
    if (llvm::any_of(argAttrs, [](Attribute attr) {
          return !llvm::isa<DictionaryAttr>(attr);
        }))
      return emitError()
             << "all attributes in the array must be a dictionary attribute";
  }
  return success();
}
KernelTableAttr KernelTableAttr::get(MLIRContext *context,
                                     ArrayRef<KernelMetadataAttr> kernels,
                                     bool isSorted) {
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Return the attribute directly if the array is already sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::get(context, kernels);
  // Otherwise sort a copy first.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::get(context, kernelsTmp);
}

KernelTableAttr KernelTableAttr::getChecked(
    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
    ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Return the attribute directly if the array is already sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::getChecked(emitError, context, kernels);
  // Otherwise sort a copy first.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::getChecked(emitError, context, kernelsTmp);
}

LogicalResult
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<KernelMetadataAttr> kernels) {
  if (kernels.size() < 2)
    return success();
  // Check that the kernels are uniquely named.
  if (std::adjacent_find(kernels.begin(), kernels.end(),
                         [](KernelMetadataAttr l, KernelMetadataAttr r) {
                           return l.getName() == r.getName();
                         }) != kernels.end()) {
    return emitError() << "expected all kernels to be uniquely named";
  }
  return success();
}

KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}

KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}
CompilationTarget TargetOptions::getDefaultCompilationTarget() {
  return CompilationTarget::Fatbin;
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
  llvm::StringSaver stringSaver(options.first);
  StringRef opts = cmdOptions;
  // Remove the surrounding quotes from the command-line options if present.
  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
    opts.consume_front("\""), opts.consume_back("\"");
  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
    opts.consume_front("'"), opts.consume_back("'");
#ifdef _WIN32
  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
                                       /*MarkEOLs=*/false);
#else
  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
                                   /*MarkEOLs=*/false);
#endif // _WIN32
  return options;
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions() const {
  return tokenizeCmdOptions(cmdOptions);
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
  size_t startPos = cmdOptions.find(startsWith);
  if (startPos == std::string::npos)
    return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
  // ... tokenize the suffix and erase it from the stored options ...
}
#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an efficient unique identifier for a specific C++ type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....