#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
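// Illustrative note on the accessors above (semantics sketch, not from the
// original file): for a grid-like mapping such as #gpu.thread<y>,
// getRelativeIndex() simply returns the dimension id, while for a linear
// mapping such as #gpu.thread<linear_dim_2> it subtracts the LinearDim0
// offset, yielding 2.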
int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }

Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Location loc, Value physicalLinearMappingId) const {
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}

Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Location loc, Value physicalLinearMappingId) const {
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}
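// Sketch of the IR the two helpers above would emit (SSA names are
// illustrative only, not from the original source):
//
//   %mask     = arith.constant <mask> : i64
//   %one      = arith.constant 1 : i64
//   %bit      = arith.shli %one, %phys_id : i64
//   %filter   = arith.subi %bit, %one : i64     // all bits below %phys_id
//   %filtered = arith.andi %mask, %filter : i64 // active ids below ours
//   %logical  = math.ctpop %filtered : i64      // logical linear id
//
// createIsActiveIdPredicate instead keeps only the bit at %phys_id and
// compares the masked value against zero with arith.cmpi ne.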
int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
         elementType.isSignedInteger(8) || elementType.isUnsignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";

  return success();
}
bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}
namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}
static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
}
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and element type.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ',' and the operand role.
    std::string operand;
    if (parser.parseComma() || failed(parser.parseOptionalString(&operand)) ||
        parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(
        mlir::detail::getDefaultDiagnosticEmitFn(
            parser.getEncodedSourceLoc(beginLoc)),
        shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op, all is well.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // If the kernel is not a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
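// Sketch of a module that satisfies this verifier (illustrative IR; the
// constants feeding the launch are elided):
//
//   module attributes {gpu.container_module} {
//     gpu.module @kernels {
//       gpu.func @kernel(%arg0: f32) kernel { gpu.return }
//     }
//     func.func @main(%f: f32) {
//       gpu.launch_func @kernels::@kernel
//           blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
//           args(%f : f32)
//       return
//     }
//   }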
/// Parses an optional list of async operands with an optional leading
/// keyword.
///     (`async`)? (`[` ssa-id-list `]`)?
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with its leading keyword.
///     (`async`)? (`[` ssa-id-list `]`)?
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}
/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}
LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }
  return success();
}
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");
  // The group op can be made uniform only if it is not inside any further
  // control flow within the launch body; the remaining checks are elided in
  // this listing.
  return llvm::hasSingleElement(body) &&
         op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }
  return nullptr;
}

static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}
static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();
  return nullptr;
}

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
  if (!sizeAttr)
    return; // Async dependencies are the only variadic operands.
  SmallVector<int32_t> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N
  // arguments, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);

  // Fill the OperandSegmentSize attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}
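// Informal summary of the entry-block argument layout backing the accessors
// above: args[0..2] hold block ids, args[3..5] thread ids, args[6..8] grid
// sizes, args[9..11] block sizes, and, only when a cluster is present,
// args[12..14] cluster ids and args[15..17] cluster sizes.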
LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}
LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}
void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }

  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}
static ParseResult parseSizeAssignment(
    OpAsmParser &parser, MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0) {
    if (!asyncTokenType)
      return parser.emitError(
          parser.getNameLoc(),
          "gpu.launch requires 'async' keyword to return a value");
    result.types.push_back(asyncTokenType);
  }

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size segments assign the cluster size; in the region
  // argument list these correspond to the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns
  // block sizes and defines values for thread identifiers.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.slice(3, 3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse the optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Parse the optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments: kNumConfigRegionAttributes arguments of
  // `index` type for block/thread identifiers and grid/block sizes, plus six
  // more when a cluster is present.
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of arguments we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}
namespace {

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      // Nothing to do if the id has no uses.
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

} // namespace

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}
/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}
StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}
static ParseResult parseLaunchDimType(
    OpAsmParser &parser, Type &dimTy,
    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}
static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}
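// The built op corresponds to textual IR such as (illustrative):
//   %shfl, %valid = gpu.shuffle xor %val, %offset, %width : f32
// where %offset and %width are the i32 constants materialized above.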
LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}
/// Remove gpu.barrier after gpu.barrier: the threads are already
/// synchronized. When the barriers fence different address-space sets, they
/// are merged into a single barrier fencing the union.
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                                 PatternRewriter &rewriter) {
  auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
  if (!nextOp)
    return failure();

  std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
  std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
  // The next barrier fences everything: this op's narrower fence is subsumed.
  if (!nextMemfence) {
    op.removeAddressSpacesAttr();
    rewriter.eraseOp(nextOp);
    return success();
  }
  // Identical fences: a single barrier suffices.
  if (*thisMemfence == *nextMemfence) {
    rewriter.eraseOp(nextOp);
    return success();
  }
  // Otherwise merge the two fenced address-space sets into one barrier.
  llvm::SmallSetVector<Attribute, 4> mergedSpaces;
  for (Attribute attr : *thisMemfence)
    mergedSpaces.insert(attr);
  for (Attribute attr : *nextMemfence)
    mergedSpaces.insert(attr);
  op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
  rewriter.eraseOp(nextOp);
  return success();
}
void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

void BarrierOp::build(mlir::OpBuilder &odsBuilder,
                      mlir::OperationState &odsState,
                      std::optional<AddressSpace> addressSpace) {
  ArrayAttr addressSpacesAttr;
  if (addressSpace)
    addressSpacesAttr = odsBuilder.getArrayAttr(
        AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
  build(odsBuilder, odsState, addressSpacesAttr);
}

/// Builds a barrier that fences only the address space of the given memref,
/// if it can be determined; otherwise fences everything.
void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
                      Value memrefToFence) {
  std::optional<AddressSpace> addrSpaceToFence;
  if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
    if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
            memrefType.getMemorySpace()))
      addrSpaceToFence = addrSpaceAttr.getValue();
  return build(builder, odsState, addrSpaceToFence);
}
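// Resulting textual forms (illustrative sketch): a plain `gpu.barrier` fences
// all address spaces, while the memref-based builder above narrows it, e.g.
//   gpu.barrier memfence [#gpu.address_space<workgroup>]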
/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
/// Parses a GPU function memory attribution, together with any per-argument
/// attribute dictionaries.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) {
                                 return arg.attrs != nullptr;
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of arguments we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}
void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
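// Round-trip example for the custom form printed/parsed above (illustrative
// IR):
//
//   gpu.func @kernel(%arg0: f32)
//       workgroup(%w: memref<32xf32, #gpu.address_space<workgroup>>)
//       private(%p: memref<4xf32, #gpu.address_space<private>>)
//       kernel {
//     gpu.return
//   }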
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}

static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<Attribute> elements;
  if (auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName)))
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}

static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
  if (!dict)
    return Attribute();
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<NamedAttribute> elems;
  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
  if (oldDict)
    elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  bool found = false;
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      found = true;
      if (!value) {
        // Erase the attribute by swapping it to the end and popping it off.
        std::swap(elems[i], elems[elems.size() - 1]);
        elems.pop_back();
      } else {
        mustSort = false;
        elems[i] = NamedAttribute(elems[i].getName(), value);
      }
      break;
    }
  }
  if (!found) {
    if (!value)
      return;
    elems.emplace_back(name, value);
  }
  if (mustSort) {
    DictionaryAttr::sortInPlace(elems);
  }
  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
  setAttributionAttrs(op, index, newDict, attrsName);
}

void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return setAttributionAttr(*this, index, name, value,
                            getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return setAttributionAttr(*this, index, name, value,
                            getPrivateAttribAttrsAttrName());
}
LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  if (empty())
    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}
LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  if (targets)
    props.targets = targets;
  props.setSymName(builder.getStringAttr(name));
  props.offloadingHandler = offloadingHandler;
}

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayRef<Attribute> targets,
                        Attribute offloadingHandler) {
  build(builder, result, name,
        targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
        offloadingHandler);
}

bool GPUModuleOp::hasTarget(Attribute target) {
  if (ArrayAttr targets = getTargetsAttr())
    return llvm::count(targets.getValue(), target);
  return false;
}

void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
  ArrayAttr &targetsAttr = getProperties().targets;
  SmallVector<Attribute> targetsVector(targets);
  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
}
LogicalResult GPUModuleOp::verify() {
  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
  if (!targets)
    return success();

  for (auto target : targets) {
    if (auto verifyTargetAttr =
            llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
      if (verifyTargetAttr.verifyTarget(getOperation()).failed())
        return failure();
    }
  }
  return success();
}
void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayAttr objects) {
  auto &properties = result.getOrAddProperties<Properties>();
  result.attributes.push_back(builder.getNamedAttr(
      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
}

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayRef<Attribute> objects) {
  build(builder, result, name, offloadingHandler,
        objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
}
static ParseResult parseOffloadingHandler(OpAsmParser &parser,
                                          Attribute &offloadingHandler) {
  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(offloadingHandler) || parser.parseGreater())
      return failure();
  }
  if (!offloadingHandler)
    offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
  return success();
}

static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
                                   Attribute offloadingHandler) {
  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
    printer << '<' << offloadingHandler << '>';
}
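// An op using the default handler prints without the angle-bracket clause
// (illustrative IR; the object payload is elided):
//   gpu.binary @kernels [#gpu.object<#nvvm.target, "...">]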
LogicalResult MemcpyOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}
namespace {

/// Erases a common case of copy ops where a destination value is used only by
/// the copy op, alloc and dealloc ops.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having an Allocate memory effect in
    // order to perform the folding.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its use by
    // `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // We can perform the folding iff op has a single async dependency and
    // produces an async token as result, or if it does not have any async
    // dependency and does not produce any async token result.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

} // namespace

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!srcMemrefType.isLastDimUnitStride())
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}
LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!dstMemrefType.isLastDimUnitStride())
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (srcMatrixType.getOperand() != "COp")
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}
LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
      opTypes[C].getOperand() != "COp")
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
namespace {

/// Remove a gpu.wait op use of a gpu.wait op def without async dependencies.
/// %t = gpu.wait async []       // No async dependencies.
/// ...  gpu.wait ... [%t, ...]  // %t can be removed.
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.getAsyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};

/// Simplify trivial gpu.wait ops for the following patterns.
/// 1. %t = gpu.wait async ..., where %t has no uses.
/// 2. %t1 = gpu.wait async [%t0]; here uses of %t1 can be replaced by %t0.
/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
///    dependencies nor return any token.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have any async dependencies nor return
    // any async token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace uses of %t1 = gpu.wait async [%t0] with %t0 and erase the op.
    if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, op.getAsyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops, where %t has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

} // namespace

void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
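// The two patterns above cover rewrites such as (illustrative IR):
//   %t0 = gpu.wait async []     // token with no dependencies
//   gpu.wait [%t0, %t1]         // -> gpu.wait [%t1]
//   %t2 = gpu.wait async [%t1]  // uses of %t2 -> replaced by %t1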
LogicalResult AllocOp::verify() {
  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (getSymbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}
namespace {

/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
    if (!memrefType || index.value() >= memrefType.getRank() ||
        !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
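// Folding example for the pattern above (illustrative IR):
//   %mem = gpu.alloc(%size) : memref<?xf32>
//   %d   = memref.dim %mem, %c0 : memref<?xf32>
// canonicalizes so that uses of %d become uses of %size.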
LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Attribute target, CompilationTarget format,
                                 StringAttr object, DictionaryAttr properties,
                                 KernelTableAttr kernels) {
  if (!target)
    return emitError() << "the target attribute cannot be null";
  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
    return success();
  return emitError() << "the target attribute must implement or promise the "
                        "`gpu::TargetAttrInterface`";
}
namespace {
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}

void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
} // namespace
LogicalResult
SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                         Attribute target) {
  // Check `target`; it can be null, an integer attribute, or a GPU Target
  // attribute.
  if (target) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      if (intAttr.getInt() < 0) {
        return emitError() << "the object index must be positive";
      }
    } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
      return emitError()
             << "the target attribute must be a GPU Target attribute";
    }
  }
  return success();
}
LogicalResult gpu::DynamicSharedMemoryOp::verify() {
  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
    return emitOpError() << "must be inside an op with symbol table";

  MemRefType memrefType = getResultMemref().getType();
  // Check the address space.
  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
    return emitOpError() << "address space must be "
                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
                         << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
  }
  if (memrefType.hasStaticShape()) {
    return emitOpError() << "result memref type must be memref<?xi8, "
                            "#gpu.address_space<workgroup>>";
  }
  return success();
}
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getBody(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse the predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse the optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
                         /*argTypes=*/{}))
    return failure();
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // The warp region is always executed.
  regions.push_back(RegionSuccessor(&getWarpRegion()));
}

ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
  return ValueRange();
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/ValueRange(), /*argTypes=*/TypeRange());
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  result.addOperands(laneId);
  result.addAttribute(getAttributeNames()[0],
                      builder.getI64IntegerAttr(warpSize));
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}
/// Helper to check that the distributed vector type is consistent with the
/// expanded type and the distributed size.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match, the distribution is trivial.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (llvm::product_of(scales) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  gpu::YieldOp yield = getTerminator();
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
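// Distribution example for the checks above (illustrative IR): with a warp
// size of 32, a yielded vector<128xf32> may distribute to a vector<4xf32>
// per lane, since 128 / 4 == 32:
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<4xf32>) {
//   ^bb0(%varg : vector<128xf32>):
//     ...
//     gpu.yield %result : vector<128xf32>
//   }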
void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}

Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    // Speculating a first_active_lane broadcast across control flow can
    // change the set of active lanes, so it must stay in place.
    return Speculation::NotSpeculatable;
  case BroadcastType::specific_lane:
    return Speculation::Speculatable;
  }
  llvm_unreachable("Unknown BroadcastType");
}

LogicalResult gpu::SubgroupBroadcastOp::verify() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    if (getLane())
      return emitOpError()
             << "lane can only be specified for `specific_lane` broadcast";
    return success();
  case BroadcastType::specific_lane:
    if (!getLane())
      return emitOpError()
             << "lane must be specified for `specific_lane` broadcast";
    return success();
  }
  llvm_unreachable("Unknown BroadcastType");
}

OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
  // A broadcast of a broadcast yields the same value.
  if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
    return prev.getResult();
  return nullptr;
}
KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return get(kernel.getNameAttr(), kernel.getFunctionType(),
             kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
                    kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
  if (attrs.empty())
    return *this;
  NamedAttrList attrList;
  if (DictionaryAttr dict = getMetadata())
    attrList.append(dict);
  attrList.append(attrs);
  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
                                 attrList.getDictionary(getContext()));
}

LogicalResult
KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           StringAttr name, Type functionType,
                           ArrayAttr argAttrs, DictionaryAttr metadata) {
  if (name.empty())
    return emitError() << "the kernel name can't be empty";
  if (argAttrs) {
    if (llvm::any_of(argAttrs, [](Attribute attr) {
          return !llvm::isa<DictionaryAttr>(attr);
        }))
      return emitError()
             << "all attributes in the array must be a dictionary attribute";
  }
  return success();
}
KernelTableAttr KernelTableAttr::get(MLIRContext *context,
                                     ArrayRef<KernelMetadataAttr> kernels,
                                     bool isSorted) {
  // Note that `llvm::is_sorted` is only invoked once even with assertions ON.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::get(context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::get(context, kernelsTmp);
}

KernelTableAttr KernelTableAttr::getChecked(
    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
    ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
  // Note that `llvm::is_sorted` is only invoked once even with assertions ON.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::getChecked(emitError, context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::getChecked(emitError, context, kernelsTmp);
}

LogicalResult
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<KernelMetadataAttr> kernels) {
  if (kernels.size() < 2)
    return success();
  // Check that the kernels are uniquely named.
  if (std::adjacent_find(kernels.begin(), kernels.end(),
                         [](KernelMetadataAttr l, KernelMetadataAttr r) {
                           return l.getName() == r.getName();
                         }) != kernels.end()) {
    return emitError() << "expected all kernels to be uniquely named";
  }
  return success();
}

KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}

KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}
CompilationTarget TargetOptions::getDefaultCompilationTarget() {
  return CompilationTarget::Fatbin;
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
  llvm::StringSaver stringSaver(options.first);
  StringRef opts = cmdOptions;
  // Strip a matching pair of surrounding quotes, if present.
  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
    opts.consume_front("\""), opts.consume_back("\"");
  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
    opts.consume_front("'"), opts.consume_back("'");
#ifdef _WIN32
  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
                                       /*MarkEOLs=*/false);
#else
  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
                                   /*MarkEOLs=*/false);
#endif // _WIN32
  return options;
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions() const {
  return tokenizeCmdOptions(cmdOptions);
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
  size_t startPos = cmdOptions.find(startsWith);
  if (startPos == std::string::npos)
    return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};

  auto tokenized =
      tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
  cmdOptions.resize(startPos);
  return tokenized;
}
#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
  p << " : " << getMemRefType() << ", " << getType();
}

static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
                                          VectorType vectorType) {
  if (memrefType.getElementType() != vectorType.getElementType())
    return op->emitOpError(
        "requires memref and vector types of the same elemental type");
  return success();
}
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an efficient unique identifier for a specific C++ type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....