18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/IR/TypeUtilities.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/ErrorHandling.h"
38 #include "llvm/Support/StringSaver.h"
39 #include <cassert>
41 using namespace mlir;
42 using namespace mlir::gpu;
44 #include "mlir/Dialect/GPU/IR/"
46 //===----------------------------------------------------------------------===//
47 // GPU Device Mapping Attributes
48 //===----------------------------------------------------------------------===//
50 int64_t GPUBlockMappingAttr::getMappingId() const {
51  return static_cast<int64_t>(getBlock());
52 }
54 bool GPUBlockMappingAttr::isLinearMapping() const {
55  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
56 }
58 int64_t GPUBlockMappingAttr::getRelativeIndex() const {
59  return isLinearMapping()
60  ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
61  : getMappingId();
62 }
64 int64_t GPUWarpgroupMappingAttr::getMappingId() const {
65  return static_cast<int64_t>(getWarpgroup());
66 }
68 bool GPUWarpgroupMappingAttr::isLinearMapping() const {
69  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
70 }
72 int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
73  return isLinearMapping()
74  ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
75  : getMappingId();
76 }
78 int64_t GPUWarpMappingAttr::getMappingId() const {
79  return static_cast<int64_t>(getWarp());
80 }
82 bool GPUWarpMappingAttr::isLinearMapping() const {
83  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
84 }
86 int64_t GPUWarpMappingAttr::getRelativeIndex() const {
87  return isLinearMapping()
88  ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
89  : getMappingId();
90 }
92 int64_t GPUThreadMappingAttr::getMappingId() const {
93  return static_cast<int64_t>(getThread());
94 }
96 bool GPUThreadMappingAttr::isLinearMapping() const {
97  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
98 }
100 int64_t GPUThreadMappingAttr::getRelativeIndex() const {
101  return isLinearMapping()
102  ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
103  : getMappingId();
104 }
106 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
107  return static_cast<int64_t>(getAddressSpace());
108 }
110 bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
111  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
112 }
114 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
115  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
116 }
118 //===----------------------------------------------------------------------===//
119 // MMAMatrixType
120 //===----------------------------------------------------------------------===//
123  StringRef operand) {
124  return Base::get(elementType.getContext(), shape, elementType, operand);
125 }
129  ArrayRef<int64_t> shape, Type elementType,
130  StringRef operand) {
131  return Base::getChecked(emitError, elementType.getContext(), shape,
132  elementType, operand);
133 }
135 unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
138  return getImpl()->getShape();
139 }
141 Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
143 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
146  return elementType.isF16() || elementType.isF32() ||
147  elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
148  elementType.isInteger(32);
149 }
153  ArrayRef<int64_t> shape, Type elementType,
154  StringRef operand) {
155  if (!operand.equals("AOp") && !operand.equals("BOp") &&
156  !operand.equals("COp"))
157  return emitError() << "operand expected to be one of AOp, BOp or COp";
159  if (shape.size() != 2)
160  return emitError() << "MMAMatrixType must have exactly two dimensions";
162  if (!MMAMatrixType::isValidElementType(elementType))
163  return emitError()
164  << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
166  return success();
167 }
169 //===----------------------------------------------------------------------===//
170 // GPUDialect
171 //===----------------------------------------------------------------------===//
173 bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
174  if (!memorySpace)
175  return false;
176  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
177  return gpuAttr.getValue() == getWorkgroupAddressSpace();
178  return false;
179 }
181 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
182  Attribute memorySpace = type.getMemorySpace();
183  return isWorkgroupMemoryAddressSpace(memorySpace);
184 }
186 bool GPUDialect::isKernel(Operation *op) {
187  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
188  return static_cast<bool>(isKernelAttr);
189 }
191 namespace {
192 /// This class defines the interface for handling inlining with gpu
193 /// operations.
194 struct GPUInlinerInterface : public DialectInlinerInterface {
197  /// All gpu dialect ops can be inlined.
198  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
199  return true;
200  }
201 };
202 } // namespace
204 void GPUDialect::initialize() {
205  addTypes<AsyncTokenType>();
206  addTypes<MMAMatrixType>();
207  addTypes<SparseDnTensorHandleType>();
208  addTypes<SparseSpMatHandleType>();
209  addTypes<SparseSpGEMMOpHandleType>();
210  addOperations<
211 #define GET_OP_LIST
212 #include "mlir/Dialect/GPU/IR/"
213  >();
214  addAttributes<
215 #define GET_ATTRDEF_LIST
216 #include "mlir/Dialect/GPU/IR/"
217  >();
218  addInterfaces<GPUInlinerInterface>();
219  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
220  TerminatorOp>();
221 }
223 static std::string getSparseHandleKeyword(SparseHandleKind kind) {
224  switch (kind) {
226  return "sparse.dntensor_handle";
228  return "sparse.spmat_handle";
230  return "sparse.spgemmop_handle";
231  }
232  llvm_unreachable("unknown sparse handle kind");
233  return "";
234 }
237  // Parse the main keyword for the type.
238  StringRef keyword;
239  if (parser.parseKeyword(&keyword))
240  return Type();
241  MLIRContext *context = getContext();
243  // Handle 'async token' types.
244  if (keyword == "async.token")
245  return AsyncTokenType::get(context);
247  if (keyword == "mma_matrix") {
248  SMLoc beginLoc = parser.getNameLoc();
250  // Parse '<'.
251  if (parser.parseLess())
252  return nullptr;
254  // Parse the size and elementType.
255  SmallVector<int64_t> shape;
256  Type elementType;
257  if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
258  parser.parseType(elementType))
259  return nullptr;
261  // Parse ','
262  if (parser.parseComma())
263  return nullptr;
265  // Parse operand.
266  std::string operand;
267  if (failed(parser.parseOptionalString(&operand)))
268  return nullptr;
270  // Parse '>'.
271  if (parser.parseGreater())
272  return nullptr;
275  parser.getEncodedSourceLoc(beginLoc)),
276  shape, elementType, operand);
277  }
280  return SparseDnTensorHandleType::get(context);
282  return SparseSpMatHandleType::get(context);
284  return SparseSpGEMMOpHandleType::get(context);
286  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
287  return Type();
288 }
// TODO: print refined type here. Note that the printed form must stay in sync
// with the parser.
291 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
292  TypeSwitch<Type>(type)
293  .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
294  .Case<SparseDnTensorHandleType>([&](Type) {
296  })
297  .Case<SparseSpMatHandleType>(
299  .Case<SparseSpGEMMOpHandleType>([&](Type) {
301  })
302  .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
303  os << "mma_matrix<";
304  auto shape = fragTy.getShape();
305  for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
306  os << *dim << 'x';
307  os << shape.back() << 'x' << fragTy.getElementType();
308  os << ", \"" << fragTy.getOperand() << "\"" << '>';
309  })
310  .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
311 }
313 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
314  NamedAttribute attr) {
315  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
316  attr.getName() != getContainerModuleAttrName())
317  return success();
319  auto module = dyn_cast<ModuleOp>(op);
320  if (!module)
321  return op->emitError("expected '")
322  << getContainerModuleAttrName() << "' attribute to be attached to '"
323  << ModuleOp::getOperationName() << '\'';
325  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
326  // Ignore launches that are nested more or less deep than functions in the
327  // module we are currently checking.
328  if (!launchOp->getParentOp() ||
329  launchOp->getParentOp()->getParentOp() != module)
330  return success();
332  // Ignore launch ops with missing attributes here. The errors will be
333  // reported by the verifiers of those ops.
334  if (!launchOp->getAttrOfType<SymbolRefAttr>(
335  LaunchFuncOp::getKernelAttrName(launchOp->getName())))
336  return success();
338  // Check that `launch_func` refers to a well-formed GPU kernel container.
339  StringAttr kernelContainerName = launchOp.getKernelModuleName();
340  Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
341  if (!kernelContainer)
342  return launchOp.emitOpError()
343  << "kernel container '" << kernelContainerName.getValue()
344  << "' is undefined";
346  // If the container is a GPU binary op return success.
347  if (isa<BinaryOp>(kernelContainer))
348  return success();
350  auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
351  if (!kernelModule)
352  return launchOp.emitOpError()
353  << "kernel module '" << kernelContainerName.getValue()
354  << "' is undefined";
356  // Check that `launch_func` refers to a well-formed kernel function.
357  Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
358  if (!kernelFunc)
359  return launchOp.emitOpError("kernel function '")
360  << launchOp.getKernel() << "' is undefined";
361  auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
362  if (!kernelConvertedFunction) {
363  InFlightDiagnostic diag = launchOp.emitOpError()
364  << "referenced kernel '" << launchOp.getKernel()
365  << "' is not a function";
366  diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
367  return diag;
368  }
370  if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
371  GPUDialect::getKernelFuncAttrName()))
372  return launchOp.emitOpError("kernel function is missing the '")
373  << GPUDialect::getKernelFuncAttrName() << "' attribute";
375  // TODO: If the kernel isn't a GPU function (which happens during separate
376  // compilation), do not check type correspondence as it would require the
377  // verifier to be aware of the type conversion.
378  auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
379  if (!kernelGPUFunction)
380  return success();
382  unsigned actualNumArguments = launchOp.getNumKernelOperands();
383  unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
384  if (expectedNumArguments != actualNumArguments)
385  return launchOp.emitOpError("got ")
386  << actualNumArguments << " kernel operands but expected "
387  << expectedNumArguments;
389  auto functionType = kernelGPUFunction.getFunctionType();
390  for (unsigned i = 0; i < expectedNumArguments; ++i) {
391  if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
392  return launchOp.emitOpError("type of function argument ")
393  << i << " does not match";
394  }
395  }
397  return success();
398  });
400  return walkResult.wasInterrupted() ? failure() : success();
401 }
403 /// Parses an optional list of async operands with an optional leading keyword.
404 /// (`async`)? (`[` ssa-id-list `]`)?
405 ///
406 /// This method is used by the tablegen assembly format for async ops as well.
408  OpAsmParser &parser, Type &asyncTokenType,
410  auto loc = parser.getCurrentLocation();
411  if (succeeded(parser.parseOptionalKeyword("async"))) {
412  if (parser.getNumResults() == 0)
413  return parser.emitError(loc, "needs to be named when marked 'async'");
414  asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
415  }
416  return parser.parseOperandList(asyncDependencies,
418 }
420 /// Prints optional async dependencies with its leading keyword.
421 /// (`async`)? (`[` ssa-id-list `]`)?
422 // Used by the tablegen assembly format for several async ops.
424  Type asyncTokenType,
425  OperandRange asyncDependencies) {
426  if (asyncTokenType)
427  printer << "async";
428  if (asyncDependencies.empty())
429  return;
430  if (asyncTokenType)
431  printer << ' ';
432  printer << '[';
433  llvm::interleaveComma(asyncDependencies, printer);
434  printer << ']';
435 }
437 // GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
438 /// Parses a GPU function memory attribution.
439 ///
440 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
441 /// (`private` `(` ssa-id-and-type-list `)`)?
442 ///
443 /// Note that this function parses only one of the two similar parts, with the
444 /// keyword provided as argument.
445 static ParseResult
446 parseAttributions(OpAsmParser &parser, StringRef keyword,
448  // If we could not parse the keyword, just assume empty list and succeed.
449  if (failed(parser.parseOptionalKeyword(keyword)))
450  return success();
453  /*allowType=*/true);
454 }
456 /// Prints a GPU function memory attribution.
457 static void printAttributions(OpAsmPrinter &p, StringRef keyword,
458  ArrayRef<BlockArgument> values) {
459  if (values.empty())
460  return;
462  p << ' ' << keyword << '(';
463  llvm::interleaveComma(
464  values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
465  p << ')';
466 }
468 /// Verifies a GPU function memory attribution.
470  ArrayRef<BlockArgument> attributions,
471  gpu::AddressSpace memorySpace) {
472  for (Value v : attributions) {
473  auto type = llvm::dyn_cast<MemRefType>(v.getType());
474  if (!type)
475  return op->emitOpError() << "expected memref type in attribution";
477  // We can only verify the address space if it hasn't already been lowered
478  // from the AddressSpaceAttr to a target-specific numeric value.
479  auto addressSpace =
480  llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
481  if (!addressSpace)
482  continue;
483  if (addressSpace.getValue() != memorySpace)
484  return op->emitOpError()
485  << "expected memory space " << stringifyAddressSpace(memorySpace)
486  << " in attribution";
487  }
488  return success();
489 }
491 //===----------------------------------------------------------------------===//
492 // AllReduceOp
493 //===----------------------------------------------------------------------===//
495 static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
496  Type resType) {
497  using Kind = gpu::AllReduceOperation;
498  if (llvm::is_contained(
499  {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
500  opName)) {
501  if (!isa<FloatType>(resType))
502  return failure();
503  }
505  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
506  Kind::AND, Kind::OR, Kind::XOR},
507  opName)) {
508  if (!isa<IntegerType>(resType))
509  return failure();
510  }
512  return success();
513 }
515 LogicalResult gpu::AllReduceOp::verifyRegions() {
516  if (getBody().empty() != getOp().has_value())
517  return emitError("expected either an op attribute or a non-empty body");
518  if (!getBody().empty()) {
519  if (getBody().getNumArguments() != 2)
520  return emitError("expected two region arguments");
521  for (auto argument : getBody().getArguments()) {
522  if (argument.getType() != getType())
523  return emitError("incorrect region argument type");
524  }
525  unsigned yieldCount = 0;
526  for (Block &block : getBody()) {
527  if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
528  if (yield.getNumOperands() != 1)
529  return emitError("expected one gpu.yield operand");
530  if (yield.getOperand(0).getType() != getType())
531  return emitError("incorrect gpu.yield type");
532  ++yieldCount;
533  }
534  }
535  if (yieldCount == 0)
536  return emitError("expected gpu.yield op in region");
537  } else {
538  gpu::AllReduceOperation opName = *getOp();
539  if (failed(verifyReduceOpAndType(opName, getType()))) {
540  return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
541  << "` reduction operation is not compatible with type "
542  << getType();
543  }
544  }
546  return success();
547 }
550  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
551  if (!launchOp)
552  return false;
554  Region &body = launchOp.getBody();
555  assert(!body.empty() && "Invalid region");
557  // Only convert ops in gpu::launch entry block for now.
558  return op->getBlock() == &body.front();
559 }
561 OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
562  if (!getUniform() && canMakeGroupOpUniform(*this)) {
563  setUniform(true);
564  return getResult();
565  }
567  return nullptr;
568 }
570 // TODO: Support optional custom attributes (without dialect prefix).
572  AllReduceOperationAttr &attr) {
573  StringRef enumStr;
574  if (!parser.parseOptionalKeyword(&enumStr)) {
575  std::optional<AllReduceOperation> op =
576  gpu::symbolizeAllReduceOperation(enumStr);
577  if (!op)
578  return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
579  attr = AllReduceOperationAttr::get(parser.getContext(), *op);
580  }
581  return success();
582 }
584 static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
585  AllReduceOperationAttr attr) {
586  if (attr)
587  attr.print(printer);
588 }
590 //===----------------------------------------------------------------------===//
591 // SubgroupReduceOp
592 //===----------------------------------------------------------------------===//
595  Type elemType = getType();
596  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
597  if (vecTy.isScalable())
598  return emitOpError() << "is not compatible with scalable vector types";
600  elemType = vecTy.getElementType();
601  }
603  gpu::AllReduceOperation opName = getOp();
604  if (failed(verifyReduceOpAndType(opName, elemType))) {
605  return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
606  << "` reduction operation is not compatible with type "
607  << getType();
608  }
609  return success();
610 }
612 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
613  if (!getUniform() && canMakeGroupOpUniform(*this)) {
614  setUniform(true);
615  return getResult();
616  }
618  return nullptr;
619 }
621 //===----------------------------------------------------------------------===//
622 // AsyncOpInterface
623 //===----------------------------------------------------------------------===//
626  op->insertOperands(0, {token});
627  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
628  return;
629  auto attrName =
631  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
633  // Async dependencies is the only variadic operand.
634  if (!sizeAttr)
635  return;
637  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
638  ++sizes.front();
639  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
640 }
642 //===----------------------------------------------------------------------===//
643 // LaunchOp
644 //===----------------------------------------------------------------------===//
646 void LaunchOp::build(OpBuilder &builder, OperationState &result,
647  Value gridSizeX, Value gridSizeY, Value gridSizeZ,
648  Value getBlockSizeX, Value getBlockSizeY,
649  Value getBlockSizeZ, Value dynamicSharedMemorySize,
650  Type asyncTokenType, ValueRange asyncDependencies,
651  TypeRange workgroupAttributions,
652  TypeRange privateAttributions, Value clusterSizeX,
653  Value clusterSizeY, Value clusterSizeZ) {
654  OpBuilder::InsertionGuard g(builder);
656  // Add a WorkGroup attribution attribute. This attribute is required to
657  // identify private attributions in the list of block argguments.
658  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
659  builder.getI64IntegerAttr(workgroupAttributions.size()));
661  // Add Op operands.
662  result.addOperands(asyncDependencies);
663  if (asyncTokenType)
664  result.types.push_back(builder.getType<AsyncTokenType>());
666  // Add grid and block sizes as op operands, followed by the data operands.
667  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
668  getBlockSizeY, getBlockSizeZ});
669  if (clusterSizeX)
670  result.addOperands(clusterSizeX);
671  if (clusterSizeY)
672  result.addOperands(clusterSizeY);
673  if (clusterSizeZ)
674  result.addOperands(clusterSizeZ);
675  if (dynamicSharedMemorySize)
676  result.addOperands(dynamicSharedMemorySize);
678  // Create a kernel body region with kNumConfigRegionAttributes + N memory
679  // attributions, where the first kNumConfigRegionAttributes arguments have
680  // `index` type and the rest have the same types as the data operands.
681  Region *kernelRegion = result.addRegion();
682  Block *body = builder.createBlock(kernelRegion);
683  // TODO: Allow passing in proper locations here.
684  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
685  body->addArgument(builder.getIndexType(), result.location);
686  // Add WorkGroup & Private attributions to the region arguments.
687  for (Type argTy : workgroupAttributions)
688  body->addArgument(argTy, result.location);
689  for (Type argTy : privateAttributions)
690  body->addArgument(argTy, result.location);
691  // Fill OperandSegmentSize Attribute.
692  SmallVector<int32_t, 11> segmentSizes(11, 1);
693  segmentSizes.front() = asyncDependencies.size();
694  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
695  segmentSizes[7] = clusterSizeX ? 1 : 0;
696  segmentSizes[8] = clusterSizeY ? 1 : 0;
697  segmentSizes[9] = clusterSizeZ ? 1 : 0;
698  result.addAttribute(getOperandSegmentSizeAttr(),
699  builder.getDenseI32ArrayAttr(segmentSizes));
700 }
702 KernelDim3 LaunchOp::getBlockIds() {
703  assert(!getBody().empty() && "LaunchOp body must not be empty.");
704  auto args = getBody().getArguments();
705  return KernelDim3{args[0], args[1], args[2]};
706 }
708 KernelDim3 LaunchOp::getThreadIds() {
709  assert(!getBody().empty() && "LaunchOp body must not be empty.");
710  auto args = getBody().getArguments();
711  return KernelDim3{args[3], args[4], args[5]};
712 }
714 KernelDim3 LaunchOp::getGridSize() {
715  assert(!getBody().empty() && "LaunchOp body must not be empty.");
716  auto args = getBody().getArguments();
717  return KernelDim3{args[6], args[7], args[8]};
718 }
721  assert(!getBody().empty() && "LaunchOp body must not be empty.");
722  auto args = getBody().getArguments();
723  return KernelDim3{args[9], args[10], args[11]};
724 }
726 std::optional<KernelDim3> LaunchOp::getClusterIds() {
727  assert(!getBody().empty() && "LaunchOp body must not be empty.");
728  if (!hasClusterSize())
729  return std::nullopt;
730  auto args = getBody().getArguments();
731  return KernelDim3{args[12], args[13], args[14]};
732 }
734 std::optional<KernelDim3> LaunchOp::getClusterSize() {
735  assert(!getBody().empty() && "LaunchOp body must not be empty.");
736  if (!hasClusterSize())
737  return std::nullopt;
738  auto args = getBody().getArguments();
739  return KernelDim3{args[15], args[16], args[17]};
740 }
742 KernelDim3 LaunchOp::getGridSizeOperandValues() {
743  auto operands = getOperands().drop_front(getAsyncDependencies().size());
744  return KernelDim3{operands[0], operands[1], operands[2]};
745 }
747 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
748  auto operands = getOperands().drop_front(getAsyncDependencies().size());
749  return KernelDim3{operands[3], operands[4], operands[5]};
750 }
752 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
753  auto operands = getOperands().drop_front(getAsyncDependencies().size());
754  if (!hasClusterSize())
755  return std::nullopt;
756  return KernelDim3{operands[6], operands[7], operands[8]};
757 }
760  if (!(hasClusterSize()) &&
761  (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
762  return emitOpError() << "cluster size must be all present";
763  return success();
764 }
766 LogicalResult LaunchOp::verifyRegions() {
767  // Kernel launch takes kNumConfigOperands leading operands for grid/block
768  // sizes and transforms them into kNumConfigRegionAttributes region arguments
769  // for block/thread identifiers and grid/block sizes.
770  if (!getBody().empty()) {
771  if (getBody().getNumArguments() <
772  kNumConfigRegionAttributes + getNumWorkgroupAttributions())
773  return emitOpError("unexpected number of region arguments");
774  }
776  // Verify Attributions Address Spaces.
777  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
778  GPUDialect::getWorkgroupAddressSpace())) ||
779  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
780  GPUDialect::getPrivateAddressSpace())))
781  return failure();
783  // Block terminators without successors are expected to exit the kernel region
784  // and must be `gpu.terminator`.
785  for (Block &block : getBody()) {
786  if (block.empty())
787  continue;
788  if (block.back().getNumSuccessors() != 0)
789  continue;
790  if (!isa<gpu::TerminatorOp>(&block.back())) {
791  return block.back()
792  .emitError()
793  .append("expected '", gpu::TerminatorOp::getOperationName(),
794  "' or a terminator with successors")
795  .attachNote(getLoc())
796  .append("in '", LaunchOp::getOperationName(), "' body region");
797  }
798  }
800  if (getNumResults() == 0 && getAsyncToken())
801  return emitOpError("needs to be named when async keyword is specified");
803  return success();
804 }
806 // Pretty-print the kernel grid/block size assignment as
807 // (%iter-x, %iter-y, %iter-z) in
808 // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
809 // where %size-* and %iter-* will correspond to the body region arguments.
811  KernelDim3 operands, KernelDim3 ids) {
812  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
813  p << size.x << " = " << operands.x << ", ";
814  p << size.y << " = " << operands.y << ", ";
815  p << size.z << " = " << operands.z << ')';
816 }
818 void LaunchOp::print(OpAsmPrinter &p) {
819  if (getAsyncToken()) {
820  p << " async";
821  if (!getAsyncDependencies().empty())
822  p << " [" << getAsyncDependencies() << ']';
823  }
824  // Print the launch configuration.
825  if (hasClusterSize()) {
826  p << ' ' << getClustersKeyword();
827  printSizeAssignment(p, getClusterSize().value(),
828  getClusterSizeOperandValues().value(),
829  getClusterIds().value());
830  }
831  p << ' ' << getBlocksKeyword();
832  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
833  getBlockIds());
834  p << ' ' << getThreadsKeyword();
835  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
836  getThreadIds());
837  if (getDynamicSharedMemorySize())
838  p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
839  << getDynamicSharedMemorySize();
841  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
842  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
844  p << ' ';
846  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
847  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
848  LaunchOp::getOperandSegmentSizeAttr(),
849  getNumWorkgroupAttributionsAttrName()});
850 }
852 // Parse the size assignment blocks for blocks and threads. These have the form
853 // (%region_arg, %region_arg, %region_arg) in
854 // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
855 // where %region_arg are percent-identifiers for the region arguments to be
856 // introduced further (SSA defs), and %operand are percent-identifiers for the
857 // SSA value uses.
858 static ParseResult
863  assert(indices.size() == 3 && "space for three indices expected");
866  /*allowResultNumber=*/false) ||
867  parser.parseKeyword("in") || parser.parseLParen())
868  return failure();
869  std::move(args.begin(), args.end(), indices.begin());
871  for (int i = 0; i < 3; ++i) {
872  if (i != 0 && parser.parseComma())
873  return failure();
874  if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
875  parser.parseEqual() || parser.parseOperand(sizes[i]))
876  return failure();
877  }
879  return parser.parseRParen();
880 }
882 /// Parses a Launch operation.
883 /// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
884 /// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
885 /// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
886 /// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
887 /// memory-attribution
888 /// region attr-dict?
889 /// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
891  // Sizes of the grid and block.
893  sizes(LaunchOp::kNumConfigOperands);
895  // Actual (data) operands passed to the kernel.
898  // Region arguments to be created.
900  LaunchOp::kNumConfigRegionAttributes);
902  // Parse optional async dependencies.
904  Type asyncTokenType;
905  if (failed(
906  parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
907  parser.resolveOperands(asyncDependencies, asyncTokenType,
908  result.operands))
909  return failure();
910  if (parser.getNumResults() > 0)
911  result.types.push_back(asyncTokenType);
913  bool hasCluster = false;
914  if (succeeded(
915  parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
916  hasCluster = true;
917  sizes.resize(9);
918  regionArgs.resize(18);
919  }
921  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
923  // Last three segment assigns the cluster size. In the region argument
924  // list, this is last 6 arguments.
925  if (hasCluster) {
926  if (parseSizeAssignment(parser, sizesRef.drop_front(6),
927  regionArgsRef.slice(15, 3),
928  regionArgsRef.slice(12, 3)))
929  return failure();
930  }
931  // Parse the size assignment segments: the first segment assigns grid sizes
932  // and defines values for block identifiers; the second segment assigns block
933  // sizes and defines values for thread identifiers. In the region argument
934  // list, identifiers precede sizes, and block-related values precede
935  // thread-related values.
936  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
937  parseSizeAssignment(parser, sizesRef.take_front(3),
938  regionArgsRef.slice(6, 3),
939  regionArgsRef.slice(0, 3)) ||
940  parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
941  parseSizeAssignment(parser, sizesRef.drop_front(3),
942  regionArgsRef.slice(9, 3),
943  regionArgsRef.slice(3, 3)) ||
944  parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
945  result.operands))
946  return failure();
948  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
949  bool hasDynamicSharedMemorySize = false;
950  if (!parser.parseOptionalKeyword(
951  LaunchOp::getDynamicSharedMemorySizeKeyword())) {
952  hasDynamicSharedMemorySize = true;
953  if (parser.parseOperand(dynamicSharedMemorySize) ||
954  parser.resolveOperand(dynamicSharedMemorySize,
955  parser.getBuilder().getI32Type(),
956  result.operands))
957  return failure();
958  }
960  // Create the region arguments, it has kNumConfigRegionAttributes arguments
961  // that correspond to block/thread identifiers and grid/block sizes, all
962  // having `index` type, a variadic number of WorkGroup Attributions and
963  // a variadic number of Private Attributions. The number of WorkGroup
964  // Attributions is stored in the attr with name:
965  // LaunchOp::getNumWorkgroupAttributionsAttrName().
966  Type index = parser.getBuilder().getIndexType();
968  LaunchOp::kNumConfigRegionAttributes + 6, index);
970  SmallVector<OpAsmParser::Argument> regionArguments;
971  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
973  arg.ssaName = std::get<0>(ssaValueAndType);
974  arg.type = std::get<1>(ssaValueAndType);
975  regionArguments.push_back(arg);
976  }
978  Builder &builder = parser.getBuilder();
979  // Parse workgroup memory attributions.
980  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
981  regionArguments)))
982  return failure();
984  // Store the number of operands we just parsed as the number of workgroup
985  // memory attributions.
986  unsigned numWorkgroupAttrs = regionArguments.size() -
987  LaunchOp::kNumConfigRegionAttributes -
988  (hasCluster ? 6 : 0);
989  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
990  builder.getI64IntegerAttr(numWorkgroupAttrs));
992  // Parse private memory attributions.
993  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
994  regionArguments)))
995  return failure();
997  // Introduce the body region and parse it. The region has
998  // kNumConfigRegionAttributes arguments that correspond to
999  // block/thread identifiers and grid/block sizes, all having `index` type.
1000  Region *body = result.addRegion();
1001  if (parser.parseRegion(*body, regionArguments) ||
1002  parser.parseOptionalAttrDict(result.attributes))
1003  return failure();
1005  SmallVector<int32_t, 11> segmentSizes(11, 1);
1006  segmentSizes.front() = asyncDependencies.size();
1008  if (!hasCluster) {
1009  segmentSizes[7] = 0;
1010  segmentSizes[8] = 0;
1011  segmentSizes[9] = 0;
1012  }
1013  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1014  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1015  parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
1016  return success();
1017 }
1019 /// Simplify the gpu.launch when the range of a thread or block ID is
1020 /// trivially known to be one.
1021 struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1024  PatternRewriter &rewriter) const override {
1025  // If the range implies a single value for `id`, replace `id`'s uses by
1026  // zero.
1027  Value zero;
1028  bool simplified = false;
1029  auto constPropIdUses = [&](Value id, Value size) {
1030  // Check if size is trivially one.
1031  if (!matchPattern(size, m_One()))
1032  return;
1033  if (id.getUses().empty())
1034  return;
1035  if (!simplified) {
1036  // Create a zero value the first time.
1037  OpBuilder::InsertionGuard guard(rewriter);
1038  rewriter.setInsertionPointToStart(&op.getBody().front());
1039  zero =
1040  rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
1041  }
1042  rewriter.replaceAllUsesWith(id, zero);
1043  simplified = true;
1044  };
1045  constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1046  constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1047  constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1048  constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1049  constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1050  constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1052  return success(simplified);
1053  }
1054 };
1056 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1057  MLIRContext *context) {
1058  rewrites.add<FoldLaunchArguments>(context);
1059 }
1061 /// Adds a new block argument that corresponds to buffers located in
1062 /// workgroup memory.
1063 BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1064  auto attrName = getNumWorkgroupAttributionsAttrName();
1065  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1066  (*this)->setAttr(attrName,
1067  IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1068  return getBody().insertArgument(
1069  LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1070 }
1072 /// Adds a new block argument that corresponds to buffers located in
1073 /// private memory.
1074 BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1075  // Buffers on the private memory always come after buffers on the workgroup
1076  // memory.
1077  return getBody().addArgument(type, loc);
1078 }
1080 //===----------------------------------------------------------------------===//
1081 // LaunchFuncOp
1082 //===----------------------------------------------------------------------===//
1084 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1085  GPUFuncOp kernelFunc, KernelDim3 gridSize,
1086  KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1087  ValueRange kernelOperands, Type asyncTokenType,
1088  ValueRange asyncDependencies,
1089  std::optional<KernelDim3> clusterSize) {
1090  result.addOperands(asyncDependencies);
1091  if (asyncTokenType)
1092  result.types.push_back(builder.getType<AsyncTokenType>());
1094  // Add grid and block sizes as op operands, followed by the data operands.
1095  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1096  getBlockSize.y, getBlockSize.z});
1097  if (clusterSize.has_value())
1098  result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1099  if (dynamicSharedMemorySize)
1100  result.addOperands(dynamicSharedMemorySize);
1101  result.addOperands(kernelOperands);
1102  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1103  auto kernelSymbol =
1104  SymbolRefAttr::get(kernelModule.getNameAttr(),
1105  {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1107  Properties &prop = result.getOrAddProperties<Properties>();
1108  prop.kernel = kernelSymbol;
1109  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1110  // Initialize the segment sizes to 1.
1111  for (auto &sz : prop.operandSegmentSizes)
1112  sz = 1;
1113  prop.operandSegmentSizes[0] = asyncDependencies.size();
1114  if (!clusterSize.has_value()) {
1115  prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1116  prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1117  prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1118  }
1119  prop.operandSegmentSizes[segmentSizesLen - 3] =
1120  dynamicSharedMemorySize ? 1 : 0;
1121  prop.operandSegmentSizes[segmentSizesLen - 2] =
1122  static_cast<int32_t>(kernelOperands.size());
1123  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1124 }
1126 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1127  SymbolRefAttr kernel, KernelDim3 gridSize,
1128  KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1129  ValueRange kernelOperands, Value asyncObject,
1130  std::optional<KernelDim3> clusterSize) {
1131  // Add grid and block sizes as op operands, followed by the data operands.
1132  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1133  getBlockSize.y, getBlockSize.z});
1134  if (clusterSize.has_value())
1135  result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1136  if (dynamicSharedMemorySize)
1137  result.addOperands(dynamicSharedMemorySize);
1138  result.addOperands(kernelOperands);
1139  if (asyncObject)
1140  result.addOperands(asyncObject);
1141  Properties &prop = result.getOrAddProperties<Properties>();
1142  prop.kernel = kernel;
1143  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1144  // Initialize the segment sizes to 1.
1145  for (auto &sz : prop.operandSegmentSizes)
1146  sz = 1;
1147  prop.operandSegmentSizes[0] = 0;
1148  if (!clusterSize.has_value()) {
1149  prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1150  prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1151  prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1152  }
1153  prop.operandSegmentSizes[segmentSizesLen - 3] =
1154  dynamicSharedMemorySize ? 1 : 0;
1155  prop.operandSegmentSizes[segmentSizesLen - 2] =
1156  static_cast<int32_t>(kernelOperands.size());
1157  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1158 }
1160 StringAttr LaunchFuncOp::getKernelModuleName() {
1161  return getKernel().getRootReference();
1162 }
1164 StringAttr LaunchFuncOp::getKernelName() {
1165  return getKernel().getLeafReference();
1166 }
1168 unsigned LaunchFuncOp::getNumKernelOperands() {
1169  return getKernelOperands().size();
1170 }
1172 Value LaunchFuncOp::getKernelOperand(unsigned i) {
1173  return getKernelOperands()[i];
1174 }
1176 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1177  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1178  return KernelDim3{operands[0], operands[1], operands[2]};
1179 }
1181 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1182  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1183  return KernelDim3{operands[3], operands[4], operands[5]};
1184 }
1186 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1187  assert(hasClusterSize() &&
1188  "cluster size is not set, check hasClusterSize() first");
1189  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1190  return KernelDim3{operands[6], operands[7], operands[8]};
1191 }
1194  auto module = (*this)->getParentOfType<ModuleOp>();
1195  if (!module)
1196  return emitOpError("expected to belong to a module");
1198  if (!module->getAttrOfType<UnitAttr>(
1199  GPUDialect::getContainerModuleAttrName()))
1200  return emitOpError("expected the closest surrounding module to have the '" +
1201  GPUDialect::getContainerModuleAttrName() +
1202  "' attribute");
1204  if (hasClusterSize()) {
1205  if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1206  getClusterSizeZ().getType() != getClusterSizeX().getType())
1207  return emitOpError()
1208  << "expects types of the cluster dimensions must be the same";
1209  }
1211  return success();
1212 }
1214 static ParseResult
1216  std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1217  Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1218  if (succeeded(parser.parseOptionalColon())) {
1219  if (parser.parseType(dimTy))
1220  return failure();
1221  } else {
1222  dimTy = IndexType::get(parser.getContext());
1223  }
1224  if (clusterValue.has_value()) {
1225  clusterXTy = clusterYTy = clusterZTy = dimTy;
1226  }
1227  return success();
1228 }
1230 static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1231  Value clusterValue, Type clusterXTy,
1232  Type clusterYTy, Type clusterZTy) {
1233  if (!dimTy.isIndex())
1234  printer << ": " << dimTy;
1235 }
1238  OpAsmParser &parser,
1240  SmallVectorImpl<Type> &argTypes) {
1241  if (parser.parseOptionalKeyword("args"))
1242  return success();
1244  auto parseElement = [&]() -> ParseResult {
1245  return failure(parser.parseOperand(argNames.emplace_back()) ||
1246  parser.parseColonType(argTypes.emplace_back()));
1247  };
1250  parseElement, " in argument list");
1251 }
1254  OperandRange operands, TypeRange types) {
1255  if (operands.empty())
1256  return;
1257  printer << "args(";
1258  llvm::interleaveComma(llvm::zip(operands, types), printer,
1259  [&](const auto &pair) {
1260  printer.printOperand(std::get<0>(pair));
1261  printer << " : ";
1262  printer.printType(std::get<1>(pair));
1263  });
1264  printer << ")";
1265 }
1267 //===----------------------------------------------------------------------===//
1268 // ShuffleOp
1269 //===----------------------------------------------------------------------===//
1271 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1272  int32_t offset, int32_t width, ShuffleMode mode) {
1273  build(builder, result, value,
1274  builder.create<arith::ConstantOp>(result.location,
1275  builder.getI32IntegerAttr(offset)),
1276  builder.create<arith::ConstantOp>(result.location,
1277  builder.getI32IntegerAttr(width)),
1278  mode);
1279 }
1281 //===----------------------------------------------------------------------===//
1282 // BarrierOp
1283 //===----------------------------------------------------------------------===//
1285 namespace {
1287 /// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
1288 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1289  PatternRewriter &rewriter) {
1290  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1291  rewriter.eraseOp(op);
1292  return success();
1293  }
1294  return failure();
1295 }
1297 } // end anonymous namespace
1299 void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1300  MLIRContext *context) {
1301  results.add(eraseRedundantGpuBarrierOps);
1302 }
1304 //===----------------------------------------------------------------------===//
1305 // GPUFuncOp
1306 //===----------------------------------------------------------------------===//
1308 /// Adds a new block argument that corresponds to buffers located in
1309 /// workgroup memory.
1310 BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1311  auto attrName = getNumWorkgroupAttributionsAttrName();
1312  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1313  (*this)->setAttr(attrName,
1314  IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1315  return getBody().insertArgument(
1316  getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1317 }
1319 /// Adds a new block argument that corresponds to buffers located in
1320 /// private memory.
1321 BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1322  // Buffers on the private memory always come after buffers on the workgroup
1323  // memory.
1324  return getBody().addArgument(type, loc);
1325 }
1327 void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1328  StringRef name, FunctionType type,
1329  TypeRange workgroupAttributions,
1330  TypeRange privateAttributions,
1331  ArrayRef<NamedAttribute> attrs) {
1332  OpBuilder::InsertionGuard g(builder);
1335  builder.getStringAttr(name));
1336  result.addAttribute(getFunctionTypeAttrName(,
1337  TypeAttr::get(type));
1338  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1339  builder.getI64IntegerAttr(workgroupAttributions.size()));
1340  result.addAttributes(attrs);
1341  Region *body = result.addRegion();
1342  Block *entryBlock = builder.createBlock(body);
1344  // TODO: Allow passing in proper locations here.
1345  for (Type argTy : type.getInputs())
1346  entryBlock->addArgument(argTy, result.location);
1347  for (Type argTy : workgroupAttributions)
1348  entryBlock->addArgument(argTy, result.location);
1349  for (Type argTy : privateAttributions)
1350  entryBlock->addArgument(argTy, result.location);
1351 }
1353 /// Parses a GPU function memory attribution.
1354 ///
1355 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1356 /// (`private` `(` ssa-id-and-type-list `)`)?
1357 ///
1358 /// Note that this function parses only one of the two similar parts, with the
1359 /// keyword provided as argument.
1360 static ParseResult
1361 parseAttributions(OpAsmParser &parser, StringRef keyword,
1363  Attribute &attributionAttrs) {
1364  // If we could not parse the keyword, just assume empty list and succeed.
1365  if (failed(parser.parseOptionalKeyword(keyword)))
1366  return success();
1368  size_t existingArgs = args.size();
1369  ParseResult result =
1371  /*allowType=*/true, /*allowAttrs=*/true);
1372  if (failed(result))
1373  return result;
1375  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1376  [](const OpAsmParser::Argument &arg) -> bool {
1377  return arg.attrs && !arg.attrs.empty();
1378  });
1379  if (!hadAttrs) {
1380  attributionAttrs = nullptr;
1381  return result;
1382  }
1384  Builder &builder = parser.getBuilder();
1385  SmallVector<Attribute> attributionAttrsVec;
1386  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1387  if (!argument.attrs)
1388  attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1389  else
1390  attributionAttrsVec.push_back(argument.attrs);
1391  }
1392  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1393  return result;
1394 }
1396 /// Parses a GPU function.
1397 ///
1398 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1399 /// (`->` function-result-list)? memory-attribution `kernel`?
1400 /// function-attributes? region
1403  SmallVector<DictionaryAttr> resultAttrs;
1404  SmallVector<Type> resultTypes;
1405  bool isVariadic;
1407  // Parse the function name.
1408  StringAttr nameAttr;
1409  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1410  result.attributes))
1411  return failure();
1413  auto signatureLocation = parser.getCurrentLocation();
1415  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1416  resultAttrs)))
1417  return failure();
1419  if (!entryArgs.empty() && entryArgs[0]
1420  return parser.emitError(signatureLocation)
1421  << "gpu.func requires named arguments";
1423  // Construct the function type. More types will be added to the region, but
1424  // not to the function type.
1425  Builder &builder = parser.getBuilder();
1427  SmallVector<Type> argTypes;
1428  for (auto &arg : entryArgs)
1429  argTypes.push_back(arg.type);
1430  auto type = builder.getFunctionType(argTypes, resultTypes);
1431  result.addAttribute(getFunctionTypeAttrName(,
1432  TypeAttr::get(type));
1435  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(,
1436  getResAttrsAttrName(;
1438  Attribute workgroupAttributionAttrs;
1439  // Parse workgroup memory attributions.
1440  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1441  entryArgs, workgroupAttributionAttrs)))
1442  return failure();
1444  // Store the number of operands we just parsed as the number of workgroup
1445  // memory attributions.
1446  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1447  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1448  builder.getI64IntegerAttr(numWorkgroupAttrs));
1449  if (workgroupAttributionAttrs)
1450  result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(,
1451  workgroupAttributionAttrs);
1453  Attribute privateAttributionAttrs;
1454  // Parse private memory attributions.
1455  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1456  entryArgs, privateAttributionAttrs)))
1457  return failure();
1458  if (privateAttributionAttrs)
1459  result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(,
1460  privateAttributionAttrs);
1462  // Parse the kernel attribute if present.
1463  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1464  result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1465  builder.getUnitAttr());
1467  // Parse attributes.
1469  return failure();
1471  // Parse the region. If no argument names were provided, take all names
1472  // (including those of attributions) from the entry block.
1473  auto *body = result.addRegion();
1474  return parser.parseRegion(*body, entryArgs);
1475 }
1477 static void printAttributions(OpAsmPrinter &p, StringRef keyword,
1478  ArrayRef<BlockArgument> values,
1479  ArrayAttr attributes) {
1480  if (values.empty())
1481  return;
1483  p << ' ' << keyword << '(';
1484  llvm::interleaveComma(
1485  llvm::enumerate(values), p, [&p, attributes](auto pair) {
1486  BlockArgument v = pair.value();
1487  p << v << " : " << v.getType();
1489  size_t attributionIndex = pair.index();
1490  DictionaryAttr attrs;
1491  if (attributes && attributionIndex < attributes.size())
1492  attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1493  if (attrs)
1494  p.printOptionalAttrDict(attrs.getValue());
1495  });
1496  p << ')';
1497 }
1499 void GPUFuncOp::print(OpAsmPrinter &p) {
1500  p << ' ';
1501  p.printSymbolName(getName());
1503  FunctionType type = getFunctionType();
1504  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1505  /*isVariadic=*/false,
1506  type.getResults());
1508  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1509  getWorkgroupAttribAttrs().value_or(nullptr));
1510  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1511  getPrivateAttribAttrs().value_or(nullptr));
1512  if (isKernel())
1513  p << ' ' << getKernelKeyword();
1516  p, *this,
1517  {getNumWorkgroupAttributionsAttrName(),
1518  GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1519  getArgAttrsAttrName(), getResAttrsAttrName(),
1520  getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1521  p << ' ';
1522  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1523 }
1525 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1526  StringAttr attrName) {
1527  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1528  if (!allAttrs || index >= allAttrs.size())
1529  return DictionaryAttr();
1530  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1531 }
1533 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1534  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1535 }
1537 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1538  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1539 }
1541 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1542  DictionaryAttr value, StringAttr attrName) {
1543  MLIRContext *ctx = op.getContext();
1544  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1545  SmallVector<Attribute> elements;
1546  if (allAttrs)
1547  elements.append(allAttrs.begin(), allAttrs.end());
1548  while (elements.size() <= index)
1549  elements.push_back(DictionaryAttr::get(ctx));
1550  if (!value)
1551  elements[index] = DictionaryAttr::get(ctx);
1552  else
1553  elements[index] = value;
1554  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1555  op->setAttr(attrName, newValue);
1556 }
1558 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1559  DictionaryAttr value) {
1560  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1561 }
1563 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1564  DictionaryAttr value) {
1565  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1566 }
1568 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1569  StringAttr name, StringAttr attrsName) {
1570  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1571  if (!dict)
1572  return Attribute();
1573  return dict.get(name);
1574 }
1576 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1577  StringAttr name) {
1578  assert(index < getNumWorkgroupAttributions() &&
1579  "index must map to a workgroup attribution");
1580  return getAttributionAttr(*this, index, name,
1581  getWorkgroupAttribAttrsAttrName());
1582 }
1584 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1585  StringAttr name) {
1586  assert(index < getNumPrivateAttributions() &&
1587  "index must map to a private attribution");
1588  return getAttributionAttr(*this, index, name,
1589  getPrivateAttribAttrsAttrName());
1590 }
1592 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1593  Attribute value, StringAttr attrsName) {
1594  MLIRContext *ctx = op.getContext();
1596  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1597  if (oldDict)
1598  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1600  bool found = false;
1601  bool mustSort = true;
1602  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1603  if (elems[i].getName() == name) {
1604  found = true;
1605  if (!value) {
1606  std::swap(elems[i], elems[elems.size() - 1]);
1607  elems.pop_back();
1608  } else {
1609  mustSort = false;
1610  elems[i] = NamedAttribute(elems[i].getName(), value);
1611  }
1612  break;
1613  }
1614  }
1615  if (!found) {
1616  if (!value)
1617  return;
1618  elems.emplace_back(name, value);
1619  }
1620  if (mustSort) {
1621  DictionaryAttr::sortInPlace(elems);
1622  }
1623  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1624  setAttributionAttrs(op, index, newDict, attrsName);
1625 }
1627 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1628  Attribute value) {
1629  assert(index < getNumWorkgroupAttributions() &&
1630  "index must map to a workgroup attribution");
1631  setAttributionAttr(*this, index, name, value,
1632  getWorkgroupAttribAttrsAttrName());
1633 }
1635 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1636  Attribute value) {
1637  assert(index < getNumPrivateAttributions() &&
1638  "index must map to a private attribution");
1639  setAttributionAttr(*this, index, name, value,
1640  getPrivateAttribAttrsAttrName());
1641 }
1643 LogicalResult GPUFuncOp::verifyType() {
1644  if (isKernel() && getFunctionType().getNumResults() != 0)
1645  return emitOpError() << "expected void return type for kernel function";
1647  return success();
1648 }
1650 /// Verifies the body of the function.
1651 LogicalResult GPUFuncOp::verifyBody() {
1652  if (empty())
1653  return emitOpError() << "expected body with at least one block";
1654  unsigned numFuncArguments = getNumArguments();
1655  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1656  unsigned numBlockArguments = front().getNumArguments();
1657  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1658  return emitOpError() << "expected at least "
1659  << numFuncArguments + numWorkgroupAttributions
1660  << " arguments to body region";
1662  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1663  for (unsigned i = 0; i < numFuncArguments; ++i) {
1664  Type blockArgType = front().getArgument(i).getType();
1665  if (funcArgTypes[i] != blockArgType)
1666  return emitOpError() << "expected body region argument #" << i
1667  << " to be of type " << funcArgTypes[i] << ", got "
1668  << blockArgType;
1669  }
1671  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1672  GPUDialect::getWorkgroupAddressSpace())) ||
1673  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1674  GPUDialect::getPrivateAddressSpace())))
1675  return failure();
1677  return success();
1678 }
1680 static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
1681  StringRef attrName) {
1682  auto maybeAttr = op->getAttr(attrName);
1683  if (!maybeAttr)
1684  return success();
1685  auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr);
1686  if (!array)
1687  return op.emitOpError(attrName + " must be a dense i32 array");
1688  if (array.size() != 3)
1689  return op.emitOpError(attrName + " must contain exactly 3 elements");
1690  return success();
1691 }
1694  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownBlockSizeAttrName())))
1695  return failure();
1696  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownGridSizeAttrName())))
1697  return failure();
1698  return success();
1699 }
1701 //===----------------------------------------------------------------------===//
1702 // ReturnOp
1703 //===----------------------------------------------------------------------===//
1706  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1708  FunctionType funType = function.getFunctionType();
1710  if (funType.getNumResults() != getOperands().size())
1711  return emitOpError()
1712  .append("expected ", funType.getNumResults(), " result operands")
1713  .attachNote(function.getLoc())
1714  .append("return type declared here");
1716  for (const auto &pair : llvm::enumerate(
1717  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1718  auto [type, operand] = pair.value();
1719  if (type != operand.getType())
1720  return emitOpError() << "unexpected type `" << operand.getType()
1721  << "' for operand #" << pair.index();
1722  }
1723  return success();
1724 }
1726 //===----------------------------------------------------------------------===//
1727 // GPUModuleOp
1728 //===----------------------------------------------------------------------===//
1730 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1731  StringRef name, ArrayAttr targets,
1732  Attribute offloadingHandler) {
1733  ensureTerminator(*result.addRegion(), builder, result.location);
1734  result.attributes.push_back(builder.getNamedAttr(
1737  Properties &props = result.getOrAddProperties<Properties>();
1738  if (targets)
1739  props.targets = targets;
1740  props.offloadingHandler = offloadingHandler;
1741 }
1743 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1744  StringRef name, ArrayRef<Attribute> targets,
1745  Attribute offloadingHandler) {
1746  build(builder, result, name,
1747  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1748  offloadingHandler);
1749 }
1752  StringAttr nameAttr;
1753  ArrayAttr targetsAttr;
1756  result.attributes))
1757  return failure();
1759  Properties &props = result.getOrAddProperties<Properties>();
1761  // Parse the optional offloadingHandler
1762  if (succeeded(parser.parseOptionalLess())) {
1763  if (parser.parseAttribute(props.offloadingHandler))
1764  return failure();
1765  if (parser.parseGreater())
1766  return failure();
1767  }
1769  // Parse the optional array of target attributes.
1770  OptionalParseResult targetsAttrResult =
1771  parser.parseOptionalAttribute(targetsAttr, Type{});
1772  if (targetsAttrResult.has_value()) {
1773  if (failed(*targetsAttrResult)) {
1774  return failure();
1775  }
1776  props.targets = targetsAttr;
1777  }
1779  // If module attributes are present, parse them.
1780  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1781  return failure();
1783  // Parse the module body.
1784  auto *body = result.addRegion();
1785  if (parser.parseRegion(*body, {}))
1786  return failure();
1788  // Ensure that this module has a valid terminator.
1789  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1790  return success();
1791 }
1794  p << ' ';
1795  p.printSymbolName(getName());
1797  if (Attribute attr = getOffloadingHandlerAttr()) {
1798  p << " <";
1799  p.printAttribute(attr);
1800  p << ">";
1801  }
1803  if (Attribute attr = getTargetsAttr()) {
1804  p << ' ';
1805  p.printAttribute(attr);
1806  p << ' ';
1807  }
1809  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
1810  {mlir::SymbolTable::getSymbolAttrName(),
1811  getTargetsAttrName(),
1812  getOffloadingHandlerAttrName()});
1813  p << ' ';
1814  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1815  /*printBlockTerminators=*/false);
1816 }
1818 bool GPUModuleOp::hasTarget(Attribute target) {
1819  if (ArrayAttr targets = getTargetsAttr())
1820  return llvm::count(targets.getValue(), target);
1821  return false;
1822 }
1824 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1825  ArrayAttr &targetsAttr = getProperties().targets;
1826  SmallVector<Attribute> targetsVector(targets);
1827  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1828 }
1830 //===----------------------------------------------------------------------===//
1831 // GPUBinaryOp
1832 //===----------------------------------------------------------------------===//
1833 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1834  Attribute offloadingHandler, ArrayAttr objects) {
1835  auto &properties = result.getOrAddProperties<Properties>();
1836  result.attributes.push_back(builder.getNamedAttr(
1837  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1838  properties.objects = objects;
1839  if (offloadingHandler)
1840  properties.offloadingHandler = offloadingHandler;
1841  else
1842  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1843 }
1845 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1846  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1847  build(builder, result, name, offloadingHandler,
1848  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1849 }
1852  Attribute &offloadingHandler) {
1853  if (succeeded(parser.parseOptionalLess())) {
1854  if (parser.parseAttribute(offloadingHandler))
1855  return failure();
1856  if (parser.parseGreater())
1857  return failure();
1858  }
1859  if (!offloadingHandler)
1860  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1861  return success();
1862 }
1865  Attribute offloadingHandler) {
1866  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1867  printer << '<' << offloadingHandler << '>';
1868 }
1870 //===----------------------------------------------------------------------===//
1871 // GPUMemcpyOp
1872 //===----------------------------------------------------------------------===//
1874 LogicalResult MemcpyOp::verify() {
1875  auto srcType = getSrc().getType();
1876  auto dstType = getDst().getType();
1878  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1879  return emitOpError("arguments have incompatible element type");
1881  if (failed(verifyCompatibleShape(srcType, dstType)))
1882  return emitOpError("arguments have incompatible shape");
1884  return success();
1885 }
1887 namespace {
1889 /// Erases a common case of copy ops where a destination value is used only by
1890 /// the copy op, alloc and dealloc ops.
1891 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1892  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1894  LogicalResult matchAndRewrite(MemcpyOp op,
1895  PatternRewriter &rewriter) const override {
1896  Value dest = op.getDst();
1897  Operation *destDefOp = dest.getDefiningOp();
1898  // `dest` must be defined by an op having Allocate memory effect in order to
1899  // perform the folding.
1900  if (!destDefOp ||
1901  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1902  return failure();
1903  // We can erase `op` iff `dest` has no other use apart from its
1904  // use by `op` and dealloc ops.
1905  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1906  return user != op &&
1907  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1908  }))
1909  return failure();
1910  // We can perform the folding if and only if op has a single async
1911  // dependency and produces an async token as result, or if it does not have
1912  // any async dependency and does not produce any async token result.
1913  if (op.getAsyncDependencies().size() > 1 ||
1914  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1915  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1916  return failure();
1917  rewriter.replaceOp(op, op.getAsyncDependencies());
1918  return success();
1919  }
1920 };
1922 } // end anonymous namespace
1924 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1925  MLIRContext *context) {
1926  results.add<EraseTrivialCopyOp>(context);
1927 }
1929 //===----------------------------------------------------------------------===//
1930 // GPU_SubgroupMmaLoadMatrixOp
1931 //===----------------------------------------------------------------------===//
1933 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1934  auto srcType = getSrcMemref().getType();
1935  auto resType = getRes().getType();
1936  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1937  auto operand = resMatrixType.getOperand();
1938  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1940  if (!isLastMemrefDimUnitStride(srcMemrefType))
1941  return emitError(
1942  "expected source memref most minor dim must have unit stride");
1944  if (!operand.equals("AOp") && !operand.equals("BOp") &&
1945  !operand.equals("COp"))
1946  return emitError("only AOp, BOp and COp can be loaded");
1948  return success();
1949 }
1951 //===----------------------------------------------------------------------===//
1952 // GPU_SubgroupMmaStoreMatrixOp
1953 //===----------------------------------------------------------------------===//
1955 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1956  auto srcType = getSrc().getType();
1957  auto dstType = getDstMemref().getType();
1958  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1959  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1961  if (!isLastMemrefDimUnitStride(dstMemrefType))
1962  return emitError(
1963  "expected destination memref most minor dim must have unit stride");
1965  if (!srcMatrixType.getOperand().equals("COp"))
1966  return emitError(
1967  "expected the operand matrix being stored to have 'COp' operand type");
1969  return success();
1970 }
1972 //===----------------------------------------------------------------------===//
1973 // GPU_SubgroupMmaComputeOp
1974 //===----------------------------------------------------------------------===//
1976 LogicalResult SubgroupMmaComputeOp::verify() {
1977  enum OperandMap { A, B, C };
1978  SmallVector<MMAMatrixType, 3> opTypes;
1979  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1980  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1981  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1983  if (!opTypes[A].getOperand().equals("AOp") ||
1984  !opTypes[B].getOperand().equals("BOp") ||
1985  !opTypes[C].getOperand().equals("COp"))
1986  return emitError("operands must be in the order AOp, BOp, COp");
1988  ArrayRef<int64_t> aShape, bShape, cShape;
1989  aShape = opTypes[A].getShape();
1990  bShape = opTypes[B].getShape();
1991  cShape = opTypes[C].getShape();
1993  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1994  bShape[1] != cShape[1])
1995  return emitError("operand shapes do not satisfy matmul constraints");
1997  return success();
1998 }
2000 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2001  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2002  return memref::foldMemRefCast(*this);
2003 }
2005 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2006  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2007  return memref::foldMemRefCast(*this);
2008 }
2010 //===----------------------------------------------------------------------===//
2011 // GPU_WaitOp
2012 //===----------------------------------------------------------------------===//
2014 namespace {
2016 /// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2017 /// %t = gpu.wait async [] // No async dependencies.
2018 /// ... gpu.wait ... [%t, ...] // %t can be removed.
2019 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2020 public:
2021  using OpRewritePattern::OpRewritePattern;
2023  LogicalResult matchAndRewrite(WaitOp op,
2024  PatternRewriter &rewriter) const final {
2025  auto predicate = [](Value value) {
2026  auto waitOp = value.getDefiningOp<WaitOp>();
2027  return waitOp && waitOp->getNumOperands() == 0;
2028  };
2029  if (llvm::none_of(op.getAsyncDependencies(), predicate))
2030  return failure();
2031  SmallVector<Value> validOperands;
2032  for (Value operand : op->getOperands()) {
2033  if (predicate(operand))
2034  continue;
2035  validOperands.push_back(operand);
2036  }
2037  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2038  return success();
2039  }
2040 };
2042 /// Simplify trivial gpu.wait ops for the following patterns.
2043 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2044 /// dependencies).
2045 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2046 /// %t0.
2047 /// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2048 /// dependencies nor return any token.
2049 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2050 public:
2051  using OpRewritePattern::OpRewritePattern;
2053  LogicalResult matchAndRewrite(WaitOp op,
2054  PatternRewriter &rewriter) const final {
2055  // Erase gpu.wait ops that neither have any async dependencies nor return
2056  // any async token.
2057  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2058  rewriter.eraseOp(op);
2059  return success();
2060  }
2061  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2062  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2063  op.getAsyncToken()) {
2064  rewriter.replaceOp(op, op.getAsyncDependencies());
2065  return success();
2066  }
2067  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2068  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2069  rewriter.eraseOp(op);
2070  return success();
2071  }
2072  return failure();
2073  }
2074 };
2076 } // end anonymous namespace
2078 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2079  MLIRContext *context) {
2080  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2081 }
2083 //===----------------------------------------------------------------------===//
2084 // GPU_AllocOp
2085 //===----------------------------------------------------------------------===//
2087 LogicalResult AllocOp::verify() {
2088  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2090  if (static_cast<int64_t>(getDynamicSizes().size()) !=
2091  memRefType.getNumDynamicDims())
2092  return emitOpError("dimension operand count does not equal memref "
2093  "dynamic dimension count");
2095  unsigned numSymbols = 0;
2096  if (!memRefType.getLayout().isIdentity())
2097  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2098  if (getSymbolOperands().size() != numSymbols) {
2099  return emitOpError(
2100  "symbol operand count does not equal memref symbol count");
2101  }
2103  return success();
2104 }
2106 namespace {
2108 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2109 /// `memref::AllocOp`.
2110 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2111  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2113  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2114  PatternRewriter &rewriter) const override {
2115  std::optional<int64_t> index = dimOp.getConstantIndex();
2116  if (!index)
2117  return failure();
2119  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2120  if (!memrefType || !memrefType.isDynamicDim(index.value()))
2121  return failure();
2123  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2124  if (!alloc)
2125  return failure();
2127  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2128  memrefType.getDynamicDimIndex(index.value()));
2129  rewriter.replaceOp(dimOp, substituteOp);
2130  return success();
2131  }
2132 };
2134 } // namespace
2136 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2137  MLIRContext *context) {
2138  results.add<SimplifyDimOfAllocOp>(context);
2139 }
2141 //===----------------------------------------------------------------------===//
2142 // GPU object attribute
2143 //===----------------------------------------------------------------------===//
2145 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2146  Attribute target, CompilationTarget format,
2147  StringAttr object, DictionaryAttr properties) {
2148  if (!target)
2149  return emitError() << "the target attribute cannot be null";
2150  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2151  return success();
2152  return emitError() << "the target attribute must implement or promise the "
2153  "`gpu::TargetAttrInterface`";
2154 }
2156 namespace {
2157 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2158  StringAttr &object) {
2159  std::optional<CompilationTarget> formatResult;
2160  StringRef enumKeyword;
2161  auto loc = odsParser.getCurrentLocation();
2162  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2163  formatResult = CompilationTarget::Fatbin;
2164  if (!formatResult &&
2165  (formatResult =
2166  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2167  odsParser.parseEqual())
2168  return odsParser.emitError(loc, "expected an equal sign");
2169  if (!formatResult)
2170  return odsParser.emitError(loc, "expected keyword for GPU object format");
2171  FailureOr<StringAttr> objectResult =
2172  FieldParser<StringAttr>::parse(odsParser);
2173  if (failed(objectResult))
2174  return odsParser.emitError(odsParser.getCurrentLocation(),
2175  "failed to parse GPU_ObjectAttr parameter "
2176  "'object' which is to be a `StringAttr`");
2177  format = *formatResult;
2178  object = *objectResult;
2179  return success();
2180 }
2182 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2183  StringAttr object) {
2184  if (format != CompilationTarget::Fatbin)
2185  odsParser << stringifyEnum(format) << " = ";
2186  odsParser << object;
2187 }
2188 } // namespace
2190 //===----------------------------------------------------------------------===//
2191 // GPU select object attribute
2192 //===----------------------------------------------------------------------===//
2194 LogicalResult
2195 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2196  Attribute target) {
2197  // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2198  if (target) {
2199  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2200  if (intAttr.getInt() < 0) {
2201  return emitError() << "the object index must be positive";
2202  }
2203  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2204  return emitError()
2205  << "the target attribute must be a GPU Target attribute";
2206  }
2207  }
2208  return success();
2209 }
2211 //===----------------------------------------------------------------------===//
2212 // DynamicSharedMemoryOp
2213 //===----------------------------------------------------------------------===//
2215 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2216  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2217  return emitOpError() << "must be inside an op with symbol table";
2219  MemRefType memrefType = getResultMemref().getType();
2220  // Check address space
2221  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2222  return emitOpError() << "address space must be "
2223  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2224  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2225  }
2226  if (memrefType.hasStaticShape()) {
2227  return emitOpError() << "result memref type must be memref<?xi8, "
2228  "#gpu.address_space<workgroup>>";
2229  }
2230  return success();
2231 }
2233 //===----------------------------------------------------------------------===//
2234 // GPU target options
2235 //===----------------------------------------------------------------------===//
2237 TargetOptions::TargetOptions(
2238  StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2239  StringRef cmdOptions, CompilationTarget compilationTarget,
2240  function_ref<SymbolTable *()> getSymbolTableCallback)
2241  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2242  cmdOptions, compilationTarget, getSymbolTableCallback) {}
2244 TargetOptions::TargetOptions(
2245  TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2246  StringRef cmdOptions, CompilationTarget compilationTarget,
2247  function_ref<SymbolTable *()> getSymbolTableCallback)
2248  : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2249  cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2250  getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2252 TypeID TargetOptions::getTypeID() const { return typeID; }
2254 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2256 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2258 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2260 SymbolTable *TargetOptions::getSymbolTable() const {
2261  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2262 }
2264 CompilationTarget TargetOptions::getCompilationTarget() const {
2265  return compilationTarget;
2266 }
2268 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2269  return CompilationTarget::Fatbin;
2270 }
2272 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2273 TargetOptions::tokenizeCmdOptions() const {
2274  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2275  llvm::StringSaver stringSaver(options.first);
2276  StringRef opts = cmdOptions;
2277  // For a correct tokenization of the command line options `opts` must be
2278  // unquoted, otherwise the tokenization function returns a single string: the
2279  // unquoted `cmdOptions` -which is not the desired behavior.
2280  // Remove any quotes if they are at the beginning and end of the string:
2281  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2282  opts.consume_front("\""), opts.consume_back("\"");
2283  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2284  opts.consume_front("'"), opts.consume_back("'");
2285 #ifdef _WIN32
2286  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2287  /*MarkEOLs=*/false);
2288 #else
2289  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2290  /*MarkEOLs=*/false);
2291 #endif // _WIN32
2292  return options;
2293 }
2297 #include "mlir/Dialect/GPU/IR/"
2298 #include "mlir/Dialect/GPU/IR/"
2301 #include "mlir/Dialect/GPU/IR/"
2303 #define GET_OP_CLASSES
2304 #include "mlir/Dialect/GPU/IR/"
2306 #include "mlir/Dialect/GPU/IR/"
