#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
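// Illustrative note (not in the original source): consumers of these mapping
// attributes typically branch on isLinearMapping() and then use
// getRelativeIndex() as a stable zero-based dimension index. For example,
// #gpu.thread<x> has mapping id 0 and relative index 0, while
// #gpu.thread<linear_dim_1> has mapping id LinearDim0 + 1 and relative
// index 1, so both families collapse to the same 0-based indexing scheme.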
int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}
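// Worked example (illustrative): with mask 0b1010 and physical id 3,
// filter = (1 << 3) - 1 = 0b0111, filteredId = mask & filter = 0b0010, and
// popcount(0b0010) = 1, so physical id 3 becomes logical linear id 1. Active
// ids are thus renumbered densely in ascending mask-bit order.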
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}
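// Worked example (illustrative): with the same mask 0b1010, physical id 1
// yields (1 << 1) & 0b1010 = 0b0010 != 0, so the predicate is true (active),
// while physical id 0 yields (1 << 0) & 0b1010 = 0, i.e. inactive.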
int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable(
      "GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable(
      "GPUMemorySpaceMappingAttr does not support relative index");
}
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
         elementType.isSignedInteger(8) || elementType.isUnsignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";

  return success();
}
bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}
namespace {
/// Interface for inlining within and into GPU operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All GPU dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}
static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
}
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<', the dimension list, the element type, and the operand
    // string ("AOp", "BOp", or "COp"), then '>'.
    SmallVector<int64_t> shape;
    Type elementType;
    std::string operand;
    if (parser.parseLess() ||
        parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType) || parser.parseComma() ||
        failed(parser.parseOptionalString(&operand)) || parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op, the kernel is not further
    // verifiable here: return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // If the kernel is not a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
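// For illustration, a module layout this verifier accepts (sketch; %c1 is
// assumed to be a constant index value defined earlier):
//   module attributes {gpu.container_module} {
//     gpu.module @kernels {
//       gpu.func @kernel(%arg0: f32) kernel { gpu.return }
//     }
//     func.func @main(%f: f32) {
//       gpu.launch_func @kernels::@kernel blocks in (%c1, %c1, %c1)
//           threads in (%c1, %c1, %c1) args(%f : f32)
//       return
//     }
//   }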
/// Parses an optional list of async operands with an optional leading
/// keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with its leading keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}
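// Printed form (illustrative): an async op with dependencies renders as
//   %token = gpu.wait async [%dep0, %dep1]
// while a synchronous op with dependencies renders as
//   gpu.wait [%dep0, %dep1]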
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}
LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }
  return success();
}
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }
  return nullptr;
}
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  // A reduction over a cluster of a single lane is the identity.
  if (getClusterSize() == 1)
    return getValue();
  return nullptr;
}
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a workgroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add workgroup and private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
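// Resulting operand segment layout (derived from the code above): segment 0
// holds the async dependencies, segments 1-3 the grid sizes, segments 4-6 the
// block sizes, segments 7-9 the optional cluster sizes, and segment 10 the
// optional dynamic shared memory size.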
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}
std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}
KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}
LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify attribution address spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}
void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }

  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}
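// Printed form (illustrative):
//   gpu.launch blocks(%bx, %by, %bz) in (%gsx = %0, %gsy = %1, %gsz = %2)
//              threads(%tx, %ty, %tz) in (%bsx = %3, %bsy = %4, %bsz = %5) {
//     ...
//     gpu.terminator
//   }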
/// Parses a size assignment of the form
///   `(` ssa-id `=` ssa-use `,` ssa-id `=` ssa-use `,` ssa-id `=` ssa-use `)`
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 4> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (asyncTokenType)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three segments assign the cluster size; in the region argument
  // list these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns
  // block sizes and defines values for thread identifiers.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse the optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Parse the optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments: kNumConfigRegionAttributes index-typed
  // arguments for identifiers and sizes, plus variadic attributions.
  Builder &builder = parser.getBuilder();
  Type index = builder.getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}
namespace {
/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};
} // namespace
void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  // Bump the stored number of workgroup attributions and insert the new
  // argument right after the existing workgroup attributions.
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Private attributions are at the end of the block argument list: simply
  // append an argument.
  return getBody().addArgument(type, loc);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}
StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}
static ParseResult parseLaunchDimType(
    OpAsmParser &parser, Type &dimTy,
    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    // If there is a type annotation, parse the dimension type.
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op,
                               Type dimTy, Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}
static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}
LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}
/// Remove gpu.barrier ops that are immediately followed by another barrier:
/// the second barrier makes the first one redundant.
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                                 PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Private attributions are at the end of the block argument list: simply
  // append an argument.
  return getBody().addArgument(type, loc);
}
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) {
                                 return arg.attrs != nullptr;
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}
/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}
void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
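// Printed form (illustrative):
//   gpu.func @kernel(%arg0: f32)
//       workgroup(%wg: memref<32xf32, #gpu.address_space<workgroup>>)
//       private(%pv: memref<1xf32, #gpu.address_space<private>>) kernel {
//     gpu.return
//   }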
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}
DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}
static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  MLIRContext *ctx = op.getContext();
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  SmallVector<Attribute> elements;
  if (allAttrs)
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}
void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
  if (!dict)
    return Attribute();
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<NamedAttribute> elems;
  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
  if (oldDict)
    elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  bool found = false;
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      found = true;
      if (!value) {
        // Erase the entry; swapping with the back keeps this O(1) but may
        // break ordering, hence the sort below.
        std::swap(elems[i], elems[elems.size() - 1]);
        elems.pop_back();
      } else {
        mustSort = false;
        elems[i] = NamedAttribute(elems[i].getName(), value);
      }
      break;
    }
  }
  if (!found) {
    if (!value)
      return;
    elems.emplace_back(name, value);
  }
  if (mustSort)
    DictionaryAttr::sortInPlace(elems);
  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
  setAttributionAttrs(op, index, newDict, attrsName);
}
void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(*this, index, name, value,
                     getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(*this, index, name, value,
                     getPrivateAttribAttrsAttrName());
}
LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  if (empty())
    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}
LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  if (targets)
    props.targets = targets;
  props.setSymName(builder.getStringAttr(name));
  props.offloadingHandler = offloadingHandler;
}

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayRef<Attribute> targets,
                        Attribute offloadingHandler) {
  build(builder, result, name,
        targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
        offloadingHandler);
}

bool GPUModuleOp::hasTarget(Attribute target) {
  if (ArrayAttr targets = getTargetsAttr())
    return llvm::count(targets.getValue(), target);
  return false;
}
void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
  ArrayAttr &targetsAttr = getProperties().targets;
  SmallVector<Attribute> targetsVector(targets);
  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
}

LogicalResult GPUModuleOp::verify() {
  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");

  if (!targets)
    return success();

  for (auto target : targets) {
    if (auto verifyTargetAttr =
            llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
      if (verifyTargetAttr.verifyTarget(getOperation()).failed())
        return failure();
    }
  }
  return success();
}
void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayAttr objects) {
  auto &properties = result.getOrAddProperties<Properties>();
  result.attributes.push_back(builder.getNamedAttr(
      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
}

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayRef<Attribute> objects) {
  build(builder, result, name, offloadingHandler,
        objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
}
static ParseResult parseOffloadingHandler(OpAsmParser &parser,
                                          Attribute &offloadingHandler) {
  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(offloadingHandler) || parser.parseGreater())
      return failure();
  }
  if (!offloadingHandler)
    offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
  return success();
}

static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
                                   Attribute offloadingHandler) {
  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
    printer << '<' << offloadingHandler << '>';
}
LogicalResult MemcpyOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}
namespace {

/// Erases a common case of copy ops where a destination value is used only by
/// the copy op, alloc and dealloc ops.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having an Allocate memory effect in
    // order to perform the folding.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its use by
    // `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // We can perform the folding if and only if op has a single async
    // dependency and produces an async token as result, or if it does not
    // have any async dependency and does not produce any async token result.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

} // namespace
void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!srcMemrefType.isLastDimUnitStride())
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}
LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!dstMemrefType.isLastDimUnitStride())
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (srcMatrixType.getOperand() != "COp")
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}
LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
      opTypes[C].getOperand() != "COp")
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
namespace {

/// Remove a gpu.wait op use of a gpu.wait op def without async dependencies.
///   %t = gpu.wait async []       // No async dependencies.
///   ...  gpu.wait ... [%t, ...]  // %t can be removed.
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.getAsyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};
/// Simplify trivial gpu.wait ops for the following patterns.
/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
///    dependencies).
/// 2. %t1 = gpu.wait async [%t0]: uses of %t1 can be replaced with %t0.
/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
///    dependencies nor return any token.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have any async dependencies nor return
    // any async token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the
    // op.
    if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, op.getAsyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops, where %t has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

} // namespace
void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
LogicalResult AllocOp::verify() {
  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (getSymbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}
namespace {

/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
    if (!memrefType || index.value() >= memrefType.getRank() ||
        !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

} // namespace
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Attribute target, CompilationTarget format,
                                 StringAttr object, DictionaryAttr properties,
                                 KernelTableAttr kernels) {
  if (!target)
    return emitError() << "the target attribute cannot be null";
  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
    return success();
  return emitError() << "the target attribute must implement or promise the "
                        "`gpu::TargetAttrInterface`";
}
namespace {
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}

void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
} // namespace
LogicalResult
SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                         Attribute target) {
  // Check that `target`, if present, is a valid attribute.
  if (target) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      if (intAttr.getInt() < 0) {
        return emitError() << "the object index must be positive";
      }
    } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
      return emitError()
             << "the target attribute must be a GPU Target attribute";
    }
  }
  return success();
}
LogicalResult gpu::DynamicSharedMemoryOp::verify() {
  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
    return emitOpError() << "must be inside an op with symbol table";

  MemRefType memrefType = getResultMemref().getType();
  // Check the address space.
  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
    return emitOpError() << "address space must be "
                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
                         << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
  }
  if (memrefType.hasStaticShape()) {
    return emitOpError() << "result memref type must be memref<?xi8, "
                            "#gpu.address_space<workgroup>>";
  }
  return success();
}
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse the lane id operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse the optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
                         /*argTypes=*/{}))
    return failure();
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getOperation(), getResults()));
    return;
  }

  // The warp region is always executed.
  regions.push_back(RegionSuccessor(&getWarpRegion()));
}
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/{}, /*argTypes=*/{});
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  result.addOperands(laneId);
  result.addAttribute(getAttributeNames()[0],
                      builder.getI64IntegerAttr(warpSize));
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}
/// Helper to check that the distributed vector type is consistent with the
/// expanded type and the distributed size.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match, there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (llvm::product_of(scales) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  gpu::YieldOp yield = getTerminator();
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}

Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    // A first_active_lane broadcast cannot be speculated, because moving it
    // across control flow can change the set of active lanes.
    return Speculation::NotSpeculatable;
  case BroadcastType::specific_lane:
    // A specific_lane broadcast can be speculated, as it does not depend on
    // the active lanes.
    return Speculation::Speculatable;
  }
  llvm_unreachable("unknown broadcast type");
}

LogicalResult gpu::SubgroupBroadcastOp::verify() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    if (getLane())
      return emitOpError()
             << "lane can only be specified for `specific_lane` broadcast";
    return success();
  case BroadcastType::specific_lane:
    if (!getLane())
      return emitOpError()
             << "lane must be specified for `specific_lane` broadcast";
    return success();
  }
  llvm_unreachable("unknown broadcast type");
}

OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
  // Broadcast of a broadcast is already uniform, so it folds to the inner
  // broadcast's result.
  if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
    return prev.getResult();
  return nullptr;
}
KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return get(kernel.getNameAttr(), kernel.getFunctionType(),
             kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
                    kernel.getAllArgAttrs(), metadata);
}
KernelMetadataAttr
KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
  if (attrs.empty())
    return *this;
  NamedAttrList attrList;
  if (DictionaryAttr dict = getMetadata())
    attrList.append(dict);
  attrList.append(attrs);
  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
                                 attrList.getDictionary(getContext()));
}

LogicalResult
KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           StringAttr name, Type functionType,
                           ArrayAttr argAttrs, DictionaryAttr metadata) {
  if (name.empty())
    return emitError() << "the kernel name can't be empty";
  if (argAttrs) {
    if (llvm::any_of(argAttrs, [](Attribute attr) {
          return !llvm::isa<DictionaryAttr>(attr);
        }))
      return emitError()
             << "all attributes in the array must be a dictionary attribute";
  }
  return success();
}
KernelTableAttr KernelTableAttr::get(MLIRContext *context,
                                     ArrayRef<KernelMetadataAttr> kernels,
                                     bool isSorted) {
  // Note that `is_sorted` is always only invoked once, even with assertions
  // on.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::get(context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::get(context, kernelsTmp);
}

KernelTableAttr KernelTableAttr::getChecked(
    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
    ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
  // Note that `is_sorted` is always only invoked once, even with assertions
  // on.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::getChecked(emitError, context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::getChecked(emitError, context, kernelsTmp);
}
LogicalResult
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<KernelMetadataAttr> kernels) {
  if (kernels.size() < 2)
    return success();
  // Check that the kernels are uniquely named.
  if (std::adjacent_find(kernels.begin(), kernels.end(),
                         [](KernelMetadataAttr l, KernelMetadataAttr r) {
                           return l.getName() == r.getName();
                         }) != kernels.end()) {
    return emitError() << "expected all kernels to be uniquely named";
  }
  return success();
}

KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}

KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}
CompilationTarget TargetOptions::getDefaultCompilationTarget() {
  return CompilationTarget::Fatbin;
}
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
  llvm::StringSaver stringSaver(options.first);
  StringRef opts = cmdOptions;
  // For a correct tokenization of the command line options, `opts` must be
  // unquoted, otherwise the tokenization function returns a single string:
  // the unquoted `cmdOptions`.
  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
    opts.consume_front("\""), opts.consume_back("\"");
  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
    opts.consume_front("'"), opts.consume_back("'");
#ifdef _WIN32
  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
                                       /*MarkEOLs=*/false);
#else
  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
                                   /*MarkEOLs=*/false);
#endif // _WIN32
  return options;
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions() const {
  return tokenizeCmdOptions(cmdOptions);
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
  size_t startPos = cmdOptions.find(startsWith);
  if (startPos == std::string::npos)
    return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};

  auto tokenized =
      tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
  cmdOptions.resize(startPos);
  return tokenized;
}
#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
Paren
Parens surrounding zero or more operands.
OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
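As an illustration (not from this file; `parser` is an assumed AsmParser), a callback parsing a square-bracketed integer list might look like:

// Sketch: parses e.g. `[1, 2, 3]` into `values`.
SmallVector<int64_t> values;
if (parser.parseCommaSeparatedList(
        AsmParser::Delimiter::Square,
        [&]() -> ParseResult {
          int64_t value;
          if (parser.parseInteger(value))
            return failure();
          values.push_back(value);
          return success();
        },
        "in integer list"))
  return failure();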
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
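A hedged sketch of its use (`parser` assumed in scope):

// Parses the `4x?x` dimension prefix of a shaped type; `?` is stored as
// ShapedType::kDynamic.
SmallVector<int64_t> dims;
if (parser.parseDimensionList(dims, /*allowDynamic=*/true,
                              /*withTrailingX=*/true))
  return failure();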
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. an '@'-prefixed symbol name, quoting it if necessary.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
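These helpers compose; a minimal sketch, assuming `ctx` is an MLIRContext *:

// Builds the dictionary {flags = [42 : i32, "gpu"]}.
Builder b(ctx);
ArrayAttr flags =
    b.getArrayAttr({b.getI32IntegerAttr(42), b.getStringAttr("gpu")});
DictionaryAttr dict = b.getDictionaryAttr({b.getNamedAttr("flags", flags)});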
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
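A hedged sketch of the usual combination in a hand-written parse method (the op name is hypothetical):

// Parses `%a, %b : i32, i64` style operand/type lists and resolves the
// named operands into `result.operands`.
static ParseResult parseExampleOp(OpAsmParser &parser,
                                  OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  SmallVector<Type> types;
  SMLoc loc = parser.getCurrentLocation();
  if (parser.parseOperandList(operands) || parser.parseColonTypeList(types) ||
      parser.resolveOperands(operands, types, loc, result.operands))
    return failure();
  return success();
}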
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
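A small sketch of the two together (region, argTy, and loc are assumed in scope):

// Create an entry block with one argument and insert at its start.
Block *entry = builder.createBlock(&region, region.end(), TypeRange{argTy},
                                   {loc});
builder.setInsertionPointToStart(entry);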
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value; otherwise, add a new attribute with the specified name/value.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
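A hedged verifier sketch showing the convention (ExampleOp is hypothetical):

// Diagnostics read like: 'example.op' op expected at least one operand.
LogicalResult ExampleOp::verify() {
  if (getNumOperands() == 0)
    return emitOpError("expected at least one operand");
  return success();
}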
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
bool isParent() const
Returns true if branching from the parent op.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
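A sketch of these hooks inside a rewrite pattern (ExampleOp and getInput are hypothetical):

LogicalResult matchAndRewrite(ExampleOp op,
                              PatternRewriter &rewriter) const override {
  // Forward the pass-through op to its input, then erase it.
  rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
  rewriter.eraseOp(op);
  return success();
}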
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an efficient unique identifier for a specific C++ type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
unsigned getNumDims() const
Get number of dims.
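For instance, a minimal sketch assuming `b` is a Builder:

// The 16x16xf16 "AOp" operand of C += A*B; getChecked would report a
// diagnostic for invalid shapes instead of asserting.
MMAMatrixType aType =
    MMAMatrixType::get({16, 16}, b.getF16Type(), /*operand=*/"AOp");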
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options, the compilation target, and a callback for obtaining the parent symbol table.
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substring of the command line options that starts with startsWith and runs to the end of the command line options; the matched substring is removed from cmdOptions.
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
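A construction sketch (values illustrative only; omitted parameters keep their defaults, including the Fatbin target):

TargetOptions targetOpts(/*toolkitPath=*/"/usr/local/cuda",
                         /*librariesToLink=*/{}, /*cmdOptions=*/"-v");
CompilationTarget target = targetOpts.getCompilationTarget(); // Fatbin
auto [alloc, args] = targetOpts.tokenizeCmdOptions();         // {"-v"}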
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator points to the found element or to the position where the element would be inserted.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
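For example (a sketch; `value` is an assumed Value):

// True when `value` is a constant one, e.g. when folding mul(x, 1) to x.
bool isOne = matchPattern(value, m_One());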
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
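Concretely, dimensions are compatible when equal or when either side is dynamic; a sketch:

// {4, ?} vs {4, 8}: succeeds because ? is compatible with 8.
LogicalResult compatible =
    verifyCompatibleShape({4, ShapedType::kDynamic}, {4, 8});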
llvm::function_ref< Fn > function_ref
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z, similarly to CUDA notation.